
# ==============================================
# Stable Diffusion — Implementation 
# Source: StableDiffusion.ipynb
# Source blog post: https://dzdata.medium.com/intro-to-diffusion-model-part-1-29fe7724c043
# ----------------------------------------------
# Notes:
#  - This file was converted to a .py file by Abhijit Challapalli under the guidance of Instructor Dr. Farhad Kamangar
#  - for the CSE 5368 Neural Networks class assignment, Fall 2025.
# ==============================================

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
import datetime
import time
import requests
from PIL import Image
from tqdm import tqdm
import sys, os, subprocess, importlib

# Ensure 'datasets[vision]' is available even outside notebooks
def pip_install(spec: str, *, target: str | None = None, upgrade: bool = True,
                quiet: bool = False, prefer_binary: bool = True) -> None:
    """Install ``spec`` with the pip that belongs to the running interpreter.

    Args:
        spec: requirement specifier passed to ``pip install``.
        target: optional directory passed as ``--target``.
        upgrade: add ``--upgrade``.
        quiet: add ``-q``.
        prefer_binary: add ``--prefer-binary``.

    Raises:
        RuntimeError: if pip exits with a non-zero status.
    """
    flags = []
    if upgrade:
        flags.append("--upgrade")
    if target:
        flags.extend(["--target", target])
    if prefer_binary:
        flags.append("--prefer-binary")
    if quiet:
        flags.append("-q")

    # pip's output is not captured, so it goes straight to the console.
    result = subprocess.run([sys.executable, "-m", "pip", "install", spec, *flags])
    if result.returncode != 0:
        raise RuntimeError(f"pip install failed for {spec} (exit {result.returncode})")

# Import `datasets`, installing `datasets[vision]` on the fly if it is missing.
try:
    import datasets
except ImportError:
    pip_install("datasets[vision]")
    # Packages installed while the interpreter is running may be missed by the
    # import system's cached directory listings; refresh them before retrying.
    importlib.invalidate_caches()
    import datasets



#Q : What are we trying to do in this plot?
# Toy forward-noising demo: repeatedly apply x_t = sqrt(alpha_t) * x_{t-1} + N(0, beta_t)
# to a small 1-D vector and compare a histogram of the result with the N(0, 1) density.
T = 100  # number of noising steps
x0 = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 1, 2, 3, 4, 5, 6 ,7, 8, 9, 10, 11, 12, 13 ,14, 15, 16, 17, 18, 19, 20], dtype=float)
beta = np.linspace(0.01, 0.1, T)  # per-step noise variances, increasing linearly
alpha = 1 - beta
alpha_cumprod = np.cumprod(alpha)  # cumulative product of alphas; not used by the loop below

#Q : Do you think the mean of the x_t in this case can ever be a zero vector? 
# if so when? Justify your answer.
x_t = x0  # rebinding, not an in-place update, so x0 itself is never modified

#Q : What happens if the Beta value is too small (almost equal to zero but not zero) 
# and if it's too large (still < 1), Justify your answer.
for t in range(T):
    # np.random.normal takes a standard deviation, hence sqrt(beta[t]).
    noise = np.random.normal(0, np.sqrt(beta[t]), size=x_t.shape)
    x_t = np.sqrt(alpha[t]) * x_t + noise

x_T = x_t

#Q : How does changing alpha value here impact x_t?
# Histogram of the final values vs. the standard Gaussian density on [-3, 3].
# (`alpha=0.7` in plt.hist is bar opacity, unrelated to the `alpha` array above.)
plt.hist(x_T, bins=10, density=True, alpha=0.7, label="x_T Distribution")
x = np.linspace(-3, 3, 100)
plt.plot(x, 1 / np.sqrt(2 * np.pi) * np.exp(-x**2 / 2), label="Standard Gaussian")
plt.title("Distribution of x_T (Final Noisy Data)")
plt.xlabel("Value")
plt.ylabel("Density")
plt.legend()
#plt.show() # You can uncomment this line to view the plot




def plot_probability_mass(mu=0.3):
    """
    Plot a Gaussian N(mu, 0.05^2) density and shade the windows of half-width
    ``delta`` around the two multiples of 1/255 that bracket ``mu``.

    For each of those two points the CDF difference over its window, multiplied
    by 2/255 (the "scaled" mass), is shown in the legend and printed. The figure
    is built but not shown (plt.show is commented out).

    Args:
        mu (float): mean of the Gaussian; the commented-out slider uses [0, 1].
    """
    sigma = 0.05  # Standard deviation
    # NOTE(review): the discrete values x_below/x_above are 1/255 apart, so a
    # +/- 1/255 window around each overlaps its neighbour; a true half bin width
    # here would be 0.5/255. The legend text hard-codes "±1/255" — confirm intent.
    delta = 1 / 255  # Half bin width

    # Multiples of 1/255 just below/above mu (equal when mu * 255 is an integer).
    x_below = np.floor(mu * 255) / 255  # Nearest discrete value below
    x_above = np.ceil(mu * 255) / 255   # Nearest discrete value above

    x = np.linspace(mu - 3*sigma, mu + 3*sigma, 1000)
    pdf = norm.pdf(x, mu, sigma)

    # x-ranges used only for shading the two windows.
    x_range_below = np.linspace(x_below - delta, x_below + delta, 100)
    x_range_above = np.linspace(x_above - delta, x_above + delta, 100)

    pdf_below = norm.pdf(x_below, mu, sigma)
    pdf_above = norm.pdf(x_above, mu, sigma)

    # Probability mass in each window (CDF difference), then multiplied by 2/255.
    # NOTE(review): the product is not itself a probability; it is labelled "Scaled".
    p_mass_below = (norm.cdf(x_below + delta, mu, sigma) - norm.cdf(x_below - delta, mu, sigma)) * (2 / 255)
    p_mass_above = (norm.cdf(x_above + delta, mu, sigma) - norm.cdf(x_above - delta, mu, sigma)) * (2 / 255)

    plt.figure(figsize=(8, 5))
    plt.plot(x, pdf, label="Gaussian PDF (μ={}, σ=0.05)".format(mu))
    plt.axvline(x_below, color='r', linestyle='--', label=f"x={x_below:.3f} (PDF={pdf_below:.3f})")
    plt.axvline(x_above, color='b', linestyle='--', label=f"x={x_above:.3f} (PDF={pdf_above:.3f})")

    plt.fill_between(x_range_below, norm.pdf(x_range_below, mu, sigma), color='red', alpha=0.5,
                     label=f"P({x_below:.3f}±1/255) * (2/255)={p_mass_below:.7f}")
    plt.fill_between(x_range_above, norm.pdf(x_range_above, mu, sigma), color='blue', alpha=0.5,
                     label=f"P({x_above:.3f}±1/255) * (2/255)={p_mass_above:.7f}")

    plt.title("Interactive Probability Mass Calculation")
    plt.xlabel("x values")
    plt.ylabel("Density")
    plt.legend()
    #plt.show() # You can uncomment this line to view the plot

    print(f"Selected μ: {mu:.3f}")
    print(f"Nearest Discrete Values: {x_below:.3f}, {x_above:.3f}")
    print(f"Scaled Probability Mass for x={x_below:.3f}: {p_mass_below:.7f}")
    print(f"Scaled Probability Mass for x={x_above:.3f}: {p_mass_above:.7f}")

# Interactive version (notebook only; requires `import ipywidgets as widgets`):
# mu_slider = widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=0.3, description="μ (Mean)")
# widgets.interactive(plot_probability_mass, mu=mu_slider)


#Q : What's shown in this plot? What is silu? What does it do?
# Apply SiLU (x * sigmoid(x)) to four sample values and plot each value
# next to its transformed counterpart.
x_values = torch.tensor([-2.0, -0.5, 0.5, 3.0])

x_silu_values = torch.nn.functional.silu(x_values)

# Convert to NumPy for matplotlib.
x_values_np = x_values.numpy()
x_silu_values_np = x_silu_values.numpy()

# Plot original values and transformed values
plt.figure(figsize=(6, 4))
plt.scatter(range(len(x_values_np)), x_values_np, color='red', label="Original Values", marker='o')
plt.scatter(range(len(x_silu_values_np)), x_silu_values_np, color='blue', label="SiLU Transformed", marker='s')

# Dashed connector from each original value to its SiLU output.
for i in range(len(x_values_np)):
    plt.plot([i, i], [x_values_np[i], x_silu_values_np[i]], 'k--', alpha=0.5)

plt.axhline(0, color='black', linestyle='--', alpha=0.5)
plt.xlabel("Pixel Index")
plt.ylabel("Value")
plt.title("Original vs. SiLU Transformed Values")
plt.legend()
#plt.show() # You can uncomment this line to view the plot



def space_to_depth(x, size=2):
    """
    Downscale method that uses the depth (channel) dimension to
    downscale the spatial dimensions.

    Every non-overlapping ``size x size`` spatial patch is folded into the
    channel axis, so no values are discarded.

    Args:
        x (torch.Tensor): a 4-D tensor (B, C, H, W) to downscale
        size (int): the scaling factor; H and W must both be divisible by it

    Returns:
        (torch.Tensor): tensor of shape (B, C * size * size, H // size, W // size)

    Raises:
        ValueError: if H or W is not divisible by ``size``.
    """
    b, c, h, w = x.shape
    if h % size or w % size:
        # Without this check the reshape below either fails with an opaque
        # error or silently folds the leftover elements into the batch axis.
        raise ValueError(f"spatial dims ({h}, {w}) must be divisible by size={size}")
    out_h = h // size
    out_w = w // size
    out_c = c * (size * size)
    # Use the real batch size rather than -1 so the batch axis cannot absorb a mismatch.
    x = x.reshape((b, c, out_h, size, out_w, size))
    #Q : When this 'reshape' is applied, what is the input shape and the exact output shape?
    x = x.permute((0, 1, 3, 5, 2, 4))
    #Q : When this 'permute' is applied, what is the input shape and the exact output shape?
    x = x.reshape((b, out_c, out_h, out_w))
    return x


class SpaceToDepth(nn.Module):
  """
  Module wrapper around ``space_to_depth`` with a fixed scaling factor.

  Args:
    size (int): spatial downscaling factor forwarded to ``space_to_depth``.
  """

  def __init__(self, size):
    super().__init__()
    self.size = size

  def forward(self, x):
    # Delegate to the functional version so both share one implementation.
    return space_to_depth(x, size=self.size)

# Q : Why do we apply residual connection ? How does it help during the training?
class Residual(nn.Module):
  """
  Add a wrapped callable's output back onto its input:
  ``forward(x, ...) == x + func(x, ...)``.

  Args:
    func (callable): module or function applied to the input. Any extra
      positional/keyword arguments passed to ``forward`` are forwarded to it.
  """
  def __init__(self, func):
    super().__init__()
    self.func = func

  def forward(self, x, *args, **kwargs):
    branch = self.func(x, *args, **kwargs)
    return x + branch

# Q : What does this function do
# Given a tensor of shape (B,C,H,W) (2,64,64,64) and apply upsample and give shapes after each step? (Even though it's not called here)
def upsample(in_channels, out_channels=None):
  """
  Build a 2x spatial upsampling stage: nearest-neighbour upsampling by a
  factor of 2 followed by a 3x3 convolution (padding 1).

  Args:
    in_channels (int): channels of the incoming feature map.
    out_channels (int, optional): channels produced; defaults to ``in_channels``.

  Returns:
    nn.Sequential: the upsample + conv stage.
  """
  if out_channels is None:
    out_channels = in_channels
  # Q : What does nearest do here?
  return nn.Sequential(
      nn.Upsample(scale_factor=2, mode='nearest'),
      nn.Conv2d(in_channels, out_channels, 3, padding=1),
  )

# Given a tensor of shape (B,C,H,W) (2,64,64,64) and apply downsample and give shapes after each step? (Even though it's not called here)
def downsample(in_channels, out_channels=None):
  """
  Build a 2x spatial downsampling stage. ``SpaceToDepth(2)`` folds each 2x2
  patch into the channels (4x as many), then a 1x1 convolution maps
  ``4 * in_channels`` to ``out_channels``.

  Args:
    in_channels (int): channels of the incoming feature map.
    out_channels (int, optional): channels produced; defaults to ``in_channels``.

  Returns:
    nn.Sequential: the space-to-depth + conv stage.
  """
  if out_channels is None:
    out_channels = in_channels
  return nn.Sequential(
      SpaceToDepth(2),
      nn.Conv2d(4 * in_channels, out_channels, 1),
  )

#Q : What is the role of Sinusidial Position Embedding, what is it embedding? Explain.
class SinusodialPositionEmbedding(nn.Module):
  """
  Sinusoidal embedding of integer steps.

  With ``d = embedding_dim`` and frequency index ``i``, column ``2i`` holds
  ``sin(t / 10000**(2i/d))`` and column ``2i + 1`` holds ``cos(t / 10000**(2i/d))``.

  Args:
    embedding_dim (int): output size ``d``. Odd sizes are supported; the last
      column is then a sine.
  """
  def __init__(self, embedding_dim):
    super().__init__()
    self.embedding_dim = embedding_dim

  def forward(self, time_steps):
    """
    Args:
      time_steps (torch.Tensor): 1-D tensor of ``N`` steps.

    Returns:
      torch.Tensor: float tensor of shape ``(N, embedding_dim)`` on ``time_steps.device``.
    """
    #Q : What is the shape of positions before and after the unsqueeze?
    positions = torch.unsqueeze(time_steps, 1)
    embeddings = torch.zeros((time_steps.shape[0], self.embedding_dim), device=time_steps.device)
    # Even columns need ceil(d/2) frequencies and odd columns floor(d/2). For an
    # even d both are d/2, giving exactly the usual formulation.
    num_sin = (self.embedding_dim + 1) // 2
    num_cos = self.embedding_dim // 2
    denominators = 10_000 ** (2 * torch.arange(num_sin, device=time_steps.device) / self.embedding_dim)
    #What does embeddings[:, 0::2] and embeddings[:, 1::2] do here?
    embeddings[:, 0::2] = torch.sin(positions / denominators)
    embeddings[:, 1::2] = torch.cos(positions / denominators[:num_cos])
    return embeddings


class WeightStandardizedConv2d(nn.Conv2d):
  """
  Conv2d whose kernel is standardized (zero mean, unit variance per output
  filter) on every forward pass before being applied.

  https://arxiv.org/abs/1903.10520
  weight standardization purportedly works synergistically with group normalization
  """

  def forward(self, x):
    # Looser epsilon for any input dtype other than float32.
    eps = 1e-5 if x.dtype == torch.float32 else 1e-3

    weight = self.weight
    # Statistics per output filter, over (in_channels, kH, kW).
    mean = weight.mean(dim=[1,2,3], keepdim=True)
    variance = weight.var(dim=[1,2,3], keepdim=True, correction=0)
    # eps keeps the result finite when a filter's weights are all equal
    # (zero variance), e.g. a 1x1 conv with a single input channel.
    normalized_weight = (weight - mean) * torch.rsqrt(variance + eps)

    return F.conv2d(
        x,
        normalized_weight,
        self.bias,
        self.stride,
        self.padding,
        self.dilation,
        self.groups
    )

class PreGroupNorm(nn.Module):
  """
  Normalize the input with GroupNorm, then apply ``func`` to the result.

  Args:
    dim (int): number of channels to normalize.
    func (callable): module or function applied after normalization.
    groups (int): number of GroupNorm groups (default 1: all channels together).
  """
  def __init__(self, dim , func, groups=1):
    super().__init__()
    # Registration order (func, then group_norm) is kept for state_dict/optimizer ordering.
    self.func = func
    self.group_norm = nn.GroupNorm(groups, dim)

  def forward(self, x):
    return self.func(self.group_norm(x))

class Block(nn.Module):
  """
  Weight-standardized 3x3 conv -> GroupNorm -> optional scale/shift -> SiLU.

  Args:
    in_channels (int): input channels.
    out_channels (int): output channels.
    groups (int): GroupNorm groups.
  """
  def __init__(self, in_channels, out_channels, groups=8):
    super().__init__()
    self.proj = WeightStandardizedConv2d(in_channels, out_channels, 3, padding=1)
    self.norm = nn.GroupNorm(groups, out_channels)
    self.act = nn.SiLU()

  def forward(self, x, scale_shift=None):
    """Apply the block. ``scale_shift`` is an optional ``(scale, shift)`` pair
    broadcast against the normalized activations."""
    h = self.norm(self.proj(x))
    if scale_shift is not None:
      scale, shift = scale_shift
      # "+ 1" makes a zero scale leave the normalized activations unchanged.
      h = h * (scale + 1) + shift
    return self.act(h)

class ResnetBlock(nn.Module):
  """
  Two stacked ``Block``s plus a residual path, optionally conditioned on a
  time embedding.

  When ``time_emb_dim`` is given, SiLU + Linear map the time embedding to
  ``2 * out_channels`` values. These are split into a (scale, shift) pair that
  modulates the first ``Block``. The residual path is an identity when the
  channel count is unchanged and a 1x1 convolution otherwise.
  """
  def __init__(self, in_channels, out_channels, time_emb_dim=None, groups=8):
    super().__init__()
    # Submodules are created in the original order (mlp, block1, block2,
    # res_conv) so parameter order and random initialization are unchanged.
    self.mlp = (
        nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, 2 * out_channels))
        if time_emb_dim is not None
        else None
    )
    self.block1 = Block(in_channels, out_channels, groups)
    self.block2 = Block(out_channels, out_channels, groups)
    self.res_conv = (
        nn.Identity()
        if in_channels == out_channels
        else nn.Conv2d(in_channels, out_channels, 1)
    )

  def forward(self, x, time_emb=None):
    scale_shift = None
    if self.mlp is not None and time_emb is not None:
      cond = self.mlp(time_emb)
      # Append two singleton dims so the per-channel values broadcast over H and W.
      cond = cond.view(*cond.shape, 1, 1)
      scale_shift = cond.chunk(2, dim=1)

    out = self.block2(self.block1(x, scale_shift=scale_shift))
    return out + self.res_conv(x)


class Attention(nn.Module):
  """
  Multi-head softmax self-attention over the spatial positions of a feature map.

  A bias-free 1x1 conv produces queries, keys and values with
  ``num_heads * dim_head`` channels each. Softmax attention is computed per
  head between all ``h * w`` positions, and a 1x1 conv maps back to
  ``in_channels``.
  """
  def __init__(self, in_channels, num_heads=4, dim_head=32):
    super().__init__()
    self.num_heads = num_heads
    self.dim_head = dim_head
    self.scale_factor = 1 / (dim_head) ** 0.5
    self.hidden_dim = num_heads * dim_head
    self.input_to_qkv = nn.Conv2d(in_channels, 3 * self.hidden_dim, 1, bias=False)
    self.to_output = nn.Conv2d(self.hidden_dim, in_channels, 1)

  #Q : Why do we need attention? is it applied to each image independently in the batch? Explain
  def forward(self, x):
    b, c, h, w = x.shape
    n_pos = h * w
    #Q : What is happening in the steps below, how is the input image shapes change and why? Explain.
    qkv = self.input_to_qkv(x)
    q, k, v = (part.view(b, self.num_heads, self.dim_head, n_pos) for part in qkv.chunk(3, dim=1))
    q = q * self.scale_factor

    #Q : Explain the couple of steps below and what operation is done between q,k and why??
    logits = torch.einsum("b h c i, b h c j -> b h i j", q, k)
    # Subtracting the (detached) row maximum leaves the softmax unchanged and
    # only guards against overflow in exp.
    logits = logits - logits.amax(dim=-1, keepdim=True).detach()
    weights = logits.softmax(dim=-1)

    #Q : Explain the couple of steps below and what operation is done with v and why??
    mixed = torch.einsum("b h i j, b h c j -> b h i c", weights, v)
    mixed = mixed.permute(0, 1, 3, 2).reshape((b, self.hidden_dim, h, w))
    return self.to_output(mixed)


class LinearAttention(nn.Module):
  """
  Linear-complexity attention over the spatial positions of a feature map.

  Queries are softmax-normalized over their feature axis and keys over the
  spatial axis. Keys and values are first summarized into a per-head
  ``dim_head x dim_head`` context matrix, so both einsums scale linearly with
  ``h * w``. The output passes through a 1x1 conv and a single-group GroupNorm.
  """
  def __init__(self, in_channels, num_heads=4, dim_head=32):
    super().__init__()
    self.num_heads = num_heads
    self.dim_head = dim_head
    self.scale_factor = 1 / (dim_head) ** 0.5
    self.hidden_dim = num_heads * dim_head
    self.input_to_qkv = nn.Conv2d(in_channels, 3 * self.hidden_dim, 1, bias=False)
    self.to_output = nn.Sequential(
        nn.Conv2d(self.hidden_dim, in_channels, 1),
        nn.GroupNorm(1, in_channels)
    )

  def forward(self, x):
    b, c, h, w = x.shape
    qkv = self.input_to_qkv(x)
    q, k, v = map(lambda t: t.view(b, self.num_heads, self.dim_head, h * w), qkv.chunk(3, dim=1))

    q = q.softmax(dim=-2)  # over the dim_head (feature) axis
    k = k.softmax(dim=-1)  # over the h*w (spatial) axis

    q = q * self.scale_factor
    # (b, heads, dim_head, dim_head): keys/values aggregated over all positions.
    context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
    output = torch.einsum("b h d e, b h d n -> b h e n", context, q)
    # einsum does not promise a contiguous result, so use reshape rather than view.
    output = output.reshape((b, self.hidden_dim, h, w))
    return self.to_output(output)


#Q : Why do we need Unet in diffusion? What is the role of it? Justify your answer
class DiffusionUnet(nn.Module):
  """
  U-Net noise predictor conditioned on the diffusion time step.

  Args:
    dim (int): base channel width. Stage ``i`` uses ``dim * dim_mults[i]``
      channels; the sinusoidal time embedding has size ``dim`` and is
      projected to ``4 * dim``.
    init_dim (int, optional): channels after the initial 1x1 conv; defaults to ``dim``.
    output_dim (int, optional): channels of the prediction; defaults to ``channels``.
    dim_mults (tuple[int, ...]): per-stage channel multipliers. Every stage but
      the last halves the spatial size on the way down and doubles it on the way up.
    channels (int): channels of the input image.
    resnet_block_groups (int): GroupNorm groups inside every ResnetBlock.

  ``forward(x, time)`` takes images ``(B, channels, H, W)`` and a 1-D tensor of
  ``B`` time steps. It returns ``output_dim`` channels at the input resolution.
  H and W must be divisible by ``2 ** (len(dim_mults) - 1)``.
  """
  def __init__(self, dim, init_dim=None, output_dim=None, dim_mults=(1, 2, 4, 8), channels=3, resnet_block_groups=4):
    super().__init__()

    self.channels = channels
    #Q : What is the value in init_dim
    init_dim = init_dim if init_dim is not None else dim
    self.init_conv = nn.Conv2d(self.channels, init_dim, 1)
    #Q : Can you print the values of the dims?
    dims = [init_dim] + [m * dim for m in dim_mults]
    #Q : Can you give us the output of the below line? Justify your answer.
    input_output_dims = list(zip(dims[:-1], dims[1:]))

    time_dim = 4 * dim  # time embedding

    self.time_mlp = nn.Sequential(
        SinusodialPositionEmbedding(dim),
        nn.Linear(dim, time_dim),
        nn.GELU(),
        nn.Linear(time_dim, time_dim)
    )

    # --- down layers ---
    # Each stage: two ResnetBlocks, linear attention, then a 2x downsample
    # (the last stage only changes channels with a 3x3 conv).
    self.down_layers = nn.ModuleList([])
    for ii, (dim_in, dim_out) in enumerate(input_output_dims, 1):
        is_last = ii == len(input_output_dims)
        self.down_layers.append(nn.ModuleList([
            ResnetBlock(dim_in, dim_in, time_emb_dim=time_dim, groups=resnet_block_groups),
            ResnetBlock(dim_in, dim_in, time_emb_dim=time_dim, groups=resnet_block_groups),
            Residual(PreGroupNorm(dim_in, LinearAttention(dim_in))),
            downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1),
        ]))

    # --- middle/bottleneck layers ---
    mid_dim = dims[-1]
    self.mid_block1    = ResnetBlock(mid_dim, mid_dim, time_emb_dim=time_dim, groups=resnet_block_groups)
    self.mid_attention = Residual(PreGroupNorm(mid_dim, Attention(mid_dim)))
    self.mid_block2    = ResnetBlock(mid_dim, mid_dim, time_emb_dim=time_dim, groups=resnet_block_groups)

    # --- up layers ---
    # Mirror of the down path. Block inputs are widened by dim_in because the
    # saved down-path activations are concatenated on the channel axis.
    self.up_layers = nn.ModuleList([])
    for ii, (dim_in, dim_out) in enumerate(reversed(input_output_dims), 1):
        is_last = ii == len(input_output_dims)
        self.up_layers.append(nn.ModuleList([
            ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim=time_dim, groups=resnet_block_groups),
            ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim=time_dim, groups=resnet_block_groups),
            Residual(PreGroupNorm(dim_out, LinearAttention(dim_out))),
            upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1),
        ]))

    # --- output head ---
    # Built once, after the up path. Its input is the last up stage's output
    # (init_dim channels) concatenated with the init_conv output (init_dim
    # channels), hence 2 * init_dim (equal to 2 * dim when init_dim == dim).
    self.output_dim = output_dim if output_dim is not None else channels
    self.final_res_block = ResnetBlock(2 * init_dim, dim, time_emb_dim=time_dim, groups=resnet_block_groups)
    self.final_conv = nn.Conv2d(dim, self.output_dim, 1)

  def forward(self, x, time):
    #Q : Consider the dataset we trained butterfiles dataset, and determine the shapes
    x = self.init_conv(x) #Q : Shape of x before and after convolution
    init_result = x.clone()
    t = self.time_mlp(time) #Q : What is the shape of t?
    h = []

    #Down Layers
    for block1, block2, attention, downsample_block in self.down_layers:
      x = block1(x, t) #Q : What is the shape of block1 ?
      h.append(x) 

      x = block2(x, t) #Q : What is the shape of block2 ?
      x = attention(x) #Q : Shape of X after attention ?

      h.append(x)

      x = downsample_block(x) #Q : What is the shape of x

    # Bottleneck layers
    x = self.mid_block1(x, t) #Q : What is the shape of x
    x = self.mid_attention(x)
    x = self.mid_block2(x ,t) #Q : What is the shape of x

    #Up layers
    for block1, block2, attention, upsample_block in self.up_layers:
      x = torch.cat((x , h.pop()), dim=1) #Q : What is the shape of x
      x = block1(x, t) #Q : What is the shape of x

      x = torch.cat((x, h.pop()), dim=1)
      x = block2(x, t) #Q : What is the shape of x

      x = attention(x)

      x = upsample_block(x) #Q : What is the shape of x

    x = torch.cat((x, init_result), dim=1)
    x = self.final_res_block(x, t)
    x = self.final_conv(x) #Q : What is the shape of x
    return x


#Q : What does linspace do here?
def linear_schedule(num_timesteps):
  """
  Linearly spaced betas from 1e-4 to 0.02 with a 0 prepended.

  Returns:
    torch.Tensor: 1-D float tensor of length ``num_timesteps + 1`` whose first
    entry is 0.
  """
  beta_start, beta_end = 1e-4, 0.02
  ramp = torch.linspace(beta_start, beta_end, num_timesteps)
  #Q : Why do we concatenate tensor([0]) here? Justify your answer.
  leading_zero = torch.zeros(1, dtype=ramp.dtype)
  return torch.cat((leading_zero, ramp))

#Q : What is s in this function, what does it do?
def cosine_schedule(num_timesteps, s=0.008):
  """
  Cosine noise schedule.

  With ``f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2)`` and ``T = num_timesteps``,
  ``alpha_bar(t) = f(t) / f(0)`` and ``beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1)``
  for t = 1..T, clipped to [0.0001, 0.999].

  Returns:
    torch.Tensor: 1-D tensor of length ``num_timesteps``. Unlike
    ``linear_schedule``, no leading 0 is prepended.
  """
  def squared_cosine(step):
    return torch.cos((step / num_timesteps + s) / (1 + s) * 0.5 * torch.pi) ** 2

  steps = torch.linspace(0, num_timesteps, num_timesteps + 1)
  #Why do they divide the 1D tensor with its first element? Justify your answer?
  alphas_cumprod = squared_cosine(steps) / squared_cosine(torch.tensor([0]))
  ratios = alphas_cumprod[1:] / alphas_cumprod[:-1]
  #Q : What does clip do?
  return torch.clip(1 - ratios, 0.0001, 0.999)


# Download the sample photo used below to illustrate the forward (noising) process.
import io

url = 'https://images.pexels.com/photos/1557208/pexels-photo-1557208.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2'
# The timeout stops the script from hanging on a stalled connection, and
# raise_for_status() surfaces HTTP errors instead of handing an error page to PIL.
response = requests.get(url, timeout=30)
response.raise_for_status()
# .content is fully read and content-decoded by requests, unlike the .raw stream.
image_raw_data = io.BytesIO(response.content)
# Force 3 channels: the transforms below normalize exactly three channels.
image = Image.open(image_raw_data).convert("RGB")


from torchvision.transforms import RandomHorizontalFlip, Compose, ToTensor, CenterCrop, Resize, Normalize
#Preprocessing
# Resize to a fixed square, randomly mirror (so x0 itself is random),
# convert to a float tensor, then map each channel with (x - 0.5) / 0.5.
image_size = 64
transform = Compose([
  Resize((image_size, image_size)),
  # CenterCrop(image_size),
  RandomHorizontalFlip(),
  ToTensor(),
  Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

# Q: What is the shape of x0, before applying this transform? 
# and what does unsqueeze do, what is the shape after it?
x0 = transform(image).unsqueeze(0)


from torchvision.transforms import ToPILImage, Lambda

# Inverse of the (x - 0.5) / 0.5 normalization: Normalize(mean=-1, std=2)
# computes (x + 1) / 2, mapping [-1, 1] back to [0, 1].
reverse_transform_pil = Compose([
  Normalize(mean=(-1, -1, -1), std=(2, 2, 2)),
  # ToPILImage turns floats into uint8 via x * 255 without clipping, so values
  # outside [0, 1] (common for noisy images) would be corrupted; clamp first.
  Lambda(lambda t: t.clamp(0., 1.)),
  ToPILImage()
])

# Tensor-only inverse normalization (no clamping).
reverse_transform_tensor = Compose([
  Normalize(mean=(-1, -1, -1), std=(2, 2, 2)),
])

#Q : In this function what is the role of gather? 
def sample_by_t(tensor_to_sample, timesteps, x_shape):
  """
  Look up ``tensor_to_sample[t]`` for each ``t`` in ``timesteps`` and shape the
  result for broadcasting against a tensor of shape ``x_shape``.

  Args:
    tensor_to_sample (torch.Tensor): 1-D per-timestep table. It is expected on
      the CPU because the indices are moved there.
    timesteps (torch.Tensor): 1-D integer tensor of length B.
    x_shape (tuple): shape of the tensor the result will be combined with.

  Returns:
    torch.Tensor: shape ``(B, 1, ..., 1)`` with ``len(x_shape)`` dims, on
    ``timesteps.device``.
  """
  picked = tensor_to_sample.gather(-1, timesteps.cpu())
  #Q : What is the shape of the sampled_tensor before and afer reshape
  trailing_ones = (1,) * (len(x_shape) - 1)
  picked = picked.reshape((timesteps.shape[0],) + trailing_ones)
  return picked.to(timesteps.device)

#Q : What is this function implementing? Justify your answer.
def sample_q(x0, t, noise=None):
  """
  Noise ``x0`` straight to step ``t``:
  ``x_t = sqrt_alphas_bar_t[t] * x0 + sqrt_1_minus_alphas_bar_t[t] * noise``.

  Args:
    x0 (torch.Tensor): clean batch.
    t (torch.Tensor): 1-D tensor of time steps, one per batch element.
    noise (torch.Tensor, optional): noise shaped like ``x0``; drawn with
      ``torch.randn_like`` when omitted.
  """
  #Q : Why do we need this condition?
  noise = torch.randn_like(x0) if noise is None else noise

  signal_scale = sample_by_t(sqrt_alphas_bar_t, t, x0.shape)
  noise_scale = sample_by_t(sqrt_1_minus_alphas_bar_t, t, x0.shape)
  return signal_scale * x0 + noise_scale * noise

#Q :Why do we need squeeze here? What is the shape of the x_noisy image before and after?
def get_noisy_image(x0, t, transform=reverse_transform_pil):
  """
  Noise ``x0`` to step ``t`` with ``sample_q``, drop size-1 dimensions, and
  convert the result with ``transform`` (``reverse_transform_pil`` by default).
  """
  return transform(sample_q(x0, t).squeeze())


def show_noisy_images(noisy_images):
  """
  Tile rows of PIL images into one RGB image, display it with matplotlib, and return it.

  Args:
    noisy_images (list[list[PIL.Image.Image]]): one inner list per row. Every
      tile is assumed to have the size of ``noisy_images[0][0]``.

  Returns:
    PIL.Image.Image: the combined image, with a 1-pixel gap between tiles.
  """
  num_of_image_sets = len(noisy_images)
  num_of_images_in_set = len(noisy_images[0])
  # PIL sizes are (width, height); use both so non-square tiles lay out correctly.
  tile_w, tile_h = noisy_images[0][0].size

  full_image = Image.new('RGB', (tile_w * num_of_images_in_set + (num_of_images_in_set - 1),
                                 tile_h * num_of_image_sets + (num_of_image_sets - 1)))
  for set_index, image_set in enumerate(noisy_images):
    for image_index, image in enumerate(image_set):
      full_image.paste(image, (image_index * (tile_w + 1), set_index * (tile_h + 1)))

  plt.imshow(full_image)
  plt.axis('off')
  return full_image



#Q : What are we calculating in the next few lines? Justify your answer.
# Per-timestep tables read by sample_q / sample_p. Each has num_timesteps + 1
# entries (linear_schedule prepends a 0) and is indexed by t.
num_timesteps = 1000
betas_t = linear_schedule(num_timesteps)
alphas_t = 1. - betas_t
alphas_bar_t = torch.cumprod(alphas_t, dim=0)
# Shifted by one step; entry 0 is set to 0.
alphas_bar_t_minus_1 = torch.cat((torch.tensor([0]), alphas_bar_t[:-1]))
one_over_sqrt_alphas_t = 1. / torch.sqrt(alphas_t)
sqrt_alphas_t = torch.sqrt(alphas_t)
sqrt_alphas_bar_t = torch.sqrt(alphas_bar_t)
sqrt_alphas_bar_t_minus_1 = torch.sqrt(alphas_bar_t_minus_1)
sqrt_1_minus_alphas_bar_t = torch.sqrt(1. - alphas_bar_t)

#Q : What is this formula implementing? Does it mean the variance is fixed here and why?
# NOTE(review): entry 0 evaluates to (1 - 0) / (1 - 1) * 0 = NaN. sample_p reads
# this table only when t != 1, so that entry does not appear to be used.
posterior_variance = (1. - alphas_bar_t_minus_1) / (1. - alphas_bar_t) * betas_t


# Visual check of the forward process: the demo image x0 noised to t = 0, 50, ..., 200.
show_noisy_images([[get_noisy_image(x0, torch.tensor([t])) for t in [0, 50, 100, 150, 200]]])


@torch.no_grad()
def sample_p(model, x_t, t, clipping=True):
  """
  Sample from p_θ(xₜ₋₁|xₜ) to get xₜ₋₁ according to Algorithm 2 (one reverse step).

  Args:
    model: noise predictor, called as ``model(x_t, t)``.
    x_t (torch.Tensor): current noisy batch.
    t (torch.Tensor): 1-D long tensor of time steps. Only ``t[0]`` is checked
      for the final step, so all entries are expected to be equal.
    clipping (bool): if True, reconstruct x₀ from the predicted noise, clip it
      to [-1, 1], and use the posterior-mean formula. Otherwise use the direct
      noise-parameterized mean.

  Returns:
    torch.Tensor: xₜ₋₁. At ``t == 1`` the mean is returned with no added noise.
  """
  def at_t(table):
    # Per-sample lookup of a schedule table, shaped to broadcast with x_t.
    return sample_by_t(table, t, x_t.shape)

  betas_t_sampled = at_t(betas_t)
  sqrt_1_minus_alphas_bar_t_sampled = at_t(sqrt_1_minus_alphas_bar_t)

  #Q : Explain the steps of reconstruction when clipping is True? 
  if clipping:
    eps = model(x_t, t)
    x0_reconstruct = 1 / at_t(sqrt_alphas_bar_t) * (x_t - sqrt_1_minus_alphas_bar_t_sampled * eps)
    x0_reconstruct = torch.clip(x0_reconstruct, -1., 1.)
    denom = 1 - at_t(alphas_bar_t)
    coef_x0 = (at_t(sqrt_alphas_bar_t_minus_1) * betas_t_sampled) / denom
    coef_xt = (at_t(sqrt_alphas_t) * (1 - at_t(alphas_bar_t_minus_1))) / denom
    predicted_mean = coef_x0 * x0_reconstruct + coef_xt * x_t

  #Q : Explain the steps of reconstruction when clipping is False? 
  else:
    eps = model(x_t, t)
    predicted_mean = at_t(one_over_sqrt_alphas_t) * (x_t - betas_t_sampled / sqrt_1_minus_alphas_bar_t_sampled * eps)

  #Q : What do you conclude about the clipping in reconstruction of the image? 
  if t[0].item() == 1:
    return predicted_mean

  noise = torch.randn_like(x_t)
  return predicted_mean + torch.sqrt(at_t(posterior_variance)) * noise




@torch.no_grad()
def sampling(model, shape, image_noise_steps_to_keep=1):
  """
  Implementing Algorithm 2 - sampling. Start from Gaussian noise and apply
  ``sample_p`` for t = num_timesteps, ..., 1.

  Args:
    model (torch.nn.Module): the model that predicts the noise
    shape (tuple): shape of the data (batch, channels, image_size, image_size)
    image_noise_steps_to_keep (int): besides the initial noise, keep the
      result of every step whose t is <= this value (the default 1 keeps
      only the final images).

  Returns:
    (list): CPU tensors from the different steps of the reverse process,
      starting with the initial noise
  """
  batch = shape[0]
  images = torch.randn(shape, device=device)  # pure noise
  kept = [images.cpu()]

  for timestep in tqdm(range(num_timesteps, 0, -1), desc='sampling timestep'):
    step = torch.full((batch,), timestep, device=device, dtype=torch.long)
    images = sample_p(model, images, step)
    if timestep <= image_noise_steps_to_keep:
      kept.append(images.cpu())
  return kept


from datasets import load_dataset
from torch.utils.data import DataLoader
#dataset = load_dataset("amaye15/fruit", split='train') # You can uncomment this line and load another dataset to try an run the code.


# Training data: the Hugging Face "huggan/smithsonian_butterflies_subset" train split.
dataset = load_dataset("huggan/smithsonian_butterflies_subset", split='train')

# Print the dataset summary and its first record for inspection.
print(dataset)
print(dataset[0])


from torchvision.transforms import RandomHorizontalFlip, Compose, ToTensor, Resize, Normalize

# Same preprocessing as for the demo image: resize, random mirror, tensor,
# then (x - 0.5) / 0.5 per channel.
image_size = 64
transform = Compose([
  Resize((image_size, image_size)),
  RandomHorizontalFlip(),
  ToTensor(),
  Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

def transforms(data):
  """
  Batch transform for ``dataset.set_transform``. Converts every PIL image in
  ``data['image']`` to RGB, applies ``transform``, and returns
  ``{'images': [...]}``.
  """
  # Normalize expects exactly 3 channels; grayscale/RGBA/palette images would fail.
  images = [transform(im.convert("RGB")) for im in data['image']]
  return {'images': images}

dataset.set_transform(transforms)

batch_size=32
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Peek at one training image (only displayed when evaluated in a notebook cell).
batch = next(iter(train_dataloader))
reverse_transform_pil(batch['images'][20])

from pathlib import Path
# Output directory for the sample grids saved during and after training.
results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)

def compute_loss(model, x0, t, noise=None):
  """
  Training loss for one batch. Noise ``x0`` to step ``t`` with ``sample_q``
  and return the mean L1 distance between the true and the predicted noise.

  Args:
    model: noise predictor, called as ``model(x_t, t)``.
    x0 (torch.Tensor): clean batch.
    t (torch.Tensor): 1-D tensor of time steps, one per batch element.
    noise (torch.Tensor, optional): noise shaped like ``x0``; drawn when omitted.
  """
  if noise is None:
    noise = torch.randn_like(x0)
  predicted_noise = model(sample_q(x0, t, noise), t)
  return F.l1_loss(noise, predicted_noise)

from torch.optim import Adam
# --- setup ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model = DiffusionUnet(dim=image_size, channels=3, dim_mults=(1, 2, 4, 8)).to(device)
optimizer = Adam(model.parameters(), lr=1e-4)

from pathlib import Path
checkpoint_path = Path("./saved_model.pth")
start_epoch = 0          # first epoch the training loop runs
model_loaded = False     # controls whether we train

# --- loading model ---
if checkpoint_path.exists():
    print(f"Loading checkpoint: {checkpoint_path}")
    ckpt = torch.load(str(checkpoint_path), map_location=device)
    # Accept both {"model": state_dict, ...} checkpoints and a bare state_dict.
    state = ckpt.get("model", ckpt)
    model.load_state_dict(state)
    if "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
    # "epoch" is saved as the number of completed epochs. Record it as the
    # resume point used by range(start_epoch, epochs); current_epoch is kept
    # for any existing readers.
    current_epoch = int(ckpt.get("epoch", 0))
    start_epoch = current_epoch
    model_loaded = True
    model.eval()
else:
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

import numpy as np
from torchvision.utils import save_image

epochs = 5
loss_steps = 50       # print the running mean loss every `loss_steps` steps
sample_every = 1000   # save a sample grid every `sample_every` steps
loss_for_mean = np.zeros(loss_steps)  # most recent losses, indexed by step % loss_steps

# --- train only if no checkpoint ---
if not model_loaded:
    prev_time = time.time()
    for epoch in range(start_epoch, epochs):
        for batch_index, batch in enumerate(train_dataloader):
            images = batch['images'].to(device)
            # Valid steps are 1..num_timesteps (sampling starts at num_timesteps).
            # randint's upper bound is exclusive, hence the + 1.
            t = torch.randint(1, num_timesteps + 1, (images.shape[0],), device=device).long()
            loss = compute_loss(model, images, t)
            current_step = batch_index + epoch * len(train_dataloader)

            if current_step % loss_steps == 0:
                batches_left = epochs * len(train_dataloader) - current_step
                # Extrapolate the remaining time from the last `loss_steps` steps.
                time_left = datetime.timedelta(
                    seconds=batches_left * (time.time() - prev_time) / loss_steps
                )
                prev_time = time.time()
                print(f'Loss at epoch {epoch}, batch {batch_index}: {loss_for_mean.mean()} | time remaining: {time_left}')
                loss_for_mean[:] = 0

            loss_for_mean[current_step % loss_steps] = loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if current_step % sample_every == 0:
                batch_to_sample = 5
                sample_images_list = sampling(model, (batch_to_sample, 3, image_size, image_size))
                sample_images = torch.cat(sample_images_list, dim=0)
                sample_images = reverse_transform_tensor(sample_images)
                save_image(sample_images, str(results_folder / f'sample_{current_step}.png'), nrow=batch_to_sample)

        # save checkpoint after each epoch
        torch.save({
            "epoch": epoch + 1,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }, str(checkpoint_path))
else:
    print("Checkpoint found; skipping training")


# Sample 4 images and keep the initial noise plus the result of every step
# t <= 300. Save them as one grid with one row of 4 per kept step.
steps = sampling(model, (4, 3, image_size, image_size), image_noise_steps_to_keep=300)
all_imgs = torch.cat(steps, dim=0)

all_imgs = reverse_transform_tensor(all_imgs)
save_image(all_imgs, results_folder / "grid_final.png", nrow=4)
