
Generating Images with Your Own Stable Diffusion Model in Colab

by kimjunhee9339 2024. 6. 20.

This time, we will build our own stable diffusion model in Colab and use it to generate images (image generation seemed like the easiest task to try first).

# Load the required modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from PIL import Image
from torchvision.models import vgg19
import math
import torch_xla                       # only needed on a TPU runtime
import torch_xla.core.xla_model as xm  # only needed on a TPU runtime
import clip                            # OpenAI CLIP (not used in the code below)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = xm.xla_device()  # use this instead on a TPU runtime

First, load the required modules.
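Of these, clip and torch_xla are not actually used in the code below (torch_xla is only relevant on a TPU runtime), so those imports can be dropped on a GPU runtime. If you do want CLIP, it is not preinstalled in Colab; installing it typically looks like this (a separate Colab cell, not part of the model code):

# Colab install cell (only needed if you keep the clip import above)
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git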

The model architecture is as follows.

[Figure: model architecture]

First, we build an attention block.

class AttentionBlock(nn.Module):
  def __init__(self, group, out_channels, num_heads):
    super(AttentionBlock, self).__init__()
    self.attention_norm = nn.GroupNorm(group, out_channels)
    self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
  def forward(self, x):
    # Self-attention over spatial positions: (B, C, H, W) -> (B, H*W, C) tokens
    batch_size, channels, h, w = x.shape
    in_attn = x.reshape(batch_size, channels, h*w)
    in_attn = self.attention_norm(in_attn)
    in_attn = in_attn.transpose(1, 2)
    out_attn, _ = self.attention(in_attn, in_attn, in_attn)
    out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
    return out_attn
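As a quick sanity check (my own addition, not from the original code), the block preserves the input shape; note that out_channels must be divisible by both group and num_heads for GroupNorm and MultiheadAttention to accept it:

attn = AttentionBlock(group=8, out_channels=64, num_heads=8)
x = torch.randn(2, 64, 16, 16)
print(attn(x).shape)  # torch.Size([2, 64, 16, 16]) - spatial shape is preserved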

Note that stable diffusion builds on the latent diffusion architecture, so we also need an encoder that maps images into the latent space and a decoder that reconstructs them, and these have to be trained as well.

The encoder is as follows.

class Block(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
    super(Block, self).__init__()
    # Conv -> BatchNorm -> SiLU -> 2x downsampling
    self.block1 = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
        nn.BatchNorm2d(out_channels),
        nn.SiLU(),
        nn.AvgPool2d(2),
    )
  def forward(self, x):
    x = self.block1(x)
    return x

class Encoder(nn.Module):
  def __init__(self):
    super(Encoder, self).__init__()
    self.block1 = Block(3, 8)    # //2
    self.block2 = Block(8, 16)   # //4
    self.block3 = Block(16, 32)  # //8
    self.block4 = Block(32, 64)  # //16
    self.getmu = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=False)
    self.logvar = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=False)
    self.sig = nn.Sigmoid()
  def forward(self, x):
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.block4(x)
    mu = self.sig(self.getmu(x))
    logvar = self.sig(self.logvar(x))
    # Reparameterization trick: z = mu + sigma * eps
    std = torch.exp(0.5 * logvar)
    z = mu + std * torch.randn_like(std)
    return z, mu, logvar
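With four downsampling blocks, the encoder shrinks a 512x512 RGB image by a factor of 16 on each side and produces a 64-channel latent. A quick check:

enc = Encoder()
img = torch.randn(1, 3, 512, 512)
z, mu, logvar = enc(img)
print(z.shape)  # torch.Size([1, 64, 32, 32])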

 

The decoder is as follows.

class BlockD(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
    super(BlockD, self).__init__()
    # Transposed conv -> BatchNorm -> SiLU -> 2x upsampling
    self.block1 = nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
        nn.BatchNorm2d(out_channels),
        nn.SiLU(),
        nn.UpsamplingBilinear2d(scale_factor=2),
    )
  def forward(self, x):
    x = self.block1(x)
    return x

class Decoder(nn.Module):
  def __init__(self):
    super(Decoder, self).__init__()
    self.block1 = BlockD(64, 32)  # *2
    self.block2 = BlockD(32, 16)  # *4
    self.block3 = BlockD(16, 8)   # *8
    self.block4 = BlockD(8, 3)    # *16
    self.block5 = nn.ConvTranspose2d(3, 3, 1, 1, 0, bias=False)
    self.tanh = nn.Tanh()
  def forward(self, x):
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.block4(x)
    x = self.block5(x)
    # Tanh maps the output into [-1, 1]
    x = self.tanh(x)
    return x

We connect the encoder and decoder to form the VAE model.

class VAE(nn.Module):
  def __init__(self):
    super(VAE, self).__init__()
    self.encoder=Encoder()
    self.decoder=Decoder()
  def forward(self,x):
    z,mu,logvar=self.encoder(x)
    output=self.decoder(z)
    return output,mu,logvar
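Note that the training loop later in this post optimizes the VAE with a reconstruction (MSE) loss only. A full VAE objective also adds a KL divergence term that pulls the latent distribution toward a standard normal; a minimal sketch (vae_loss and the kl_weight value are my own illustrative choices, not from the original code):

def vae_loss(output, target, mu, logvar, kl_weight=1e-6):
  # Reconstruction term plus KL divergence of N(mu, sigma^2) from N(0, I)
  recon = F.mse_loss(output, target)
  kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
  return recon + kl_weight * kl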

 

Next, we build a time scheduler that controls the temporal side of the diffusion model.

class TimeSchdule:
  def __init__(self):
    # Linear beta schedule over 1000 diffusion steps
    self.betas = torch.linspace(0.0001, 0.02, 1000).to(device)
    self.alphas = 1 - self.betas
    self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
    self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
    self.one_minus_alpha_cum_prod = 1. - self.alpha_cum_prod
    self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1. - self.alpha_cum_prod)
  def add_noise(self, original, noise, t):
    # Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    original_shape = original.shape
    batch_size = original_shape[0]

    sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod[t].reshape(batch_size)  # (batch_size,)
    sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod[t].reshape(batch_size)

    for _ in range(len(original_shape) - 1):
      sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)  # (batch_size, 1, 1, 1)
      sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
    return sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise
  def sample_prev_timestep(self, xt, noise_pred, t):
    # Reverse process: estimate x_0, then take one step back toward x_{t-1}
    x0 = (xt - (self.sqrt_one_minus_alpha_cum_prod[t] * noise_pred)) / self.sqrt_alpha_cum_prod[t]
    x0 = torch.clamp(x0, -1., 1.)

    mean = xt - ((self.betas[t] * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod[t]))
    mean = mean / torch.sqrt(self.alphas[t])

    if (t == 0).all():
      # At the final step no extra noise is added
      return mean, x0
    else:
      variance = (self.one_minus_alpha_cum_prod[t - 1]) / (self.one_minus_alpha_cum_prod[t])
      variance = variance * (self.betas[t])
      sigma = variance ** 0.5
      z = torch.randn_like(xt)
      return mean + sigma * z, x0
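add_noise implements the closed-form forward process with a per-sample timestep t, and sample_prev_timestep takes one reverse step. A quick check on dummy latents (my own addition):

schedule = TimeSchdule()
x0 = torch.randn(4, 64, 32, 32, device=device)
eps = torch.randn_like(x0)
t = torch.randint(0, 1000, (4,), device=device)
xt = schedule.add_noise(x0, eps, t)
print(xt.shape)  # torch.Size([4, 64, 32, 32])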

Next, we build the UNet model that performs the diffusion (denoising) step.

class BlockU1(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
    super(BlockU1, self).__init__()
    self.attention = AttentionBlock(group=8, out_channels=in_channels, num_heads=8)
    # Conv -> normalization -> 2x downsampling
    self.block1 = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
        nn.BatchNorm2d(out_channels),
        nn.GroupNorm(8, out_channels),
        nn.AvgPool2d(2),
    )
  def forward(self, x):
    # Self-attention with a residual connection, then downsample
    x = x + self.attention(x)
    x = self.block1(x)
    return x

class BlockU2(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
    super(BlockU2, self).__init__()
    self.attention = AttentionBlock(group=8, out_channels=in_channels, num_heads=8)
    # Conv -> normalization -> 2x upsampling
    self.block1 = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
        nn.BatchNorm2d(out_channels),
        nn.GroupNorm(8, out_channels),
        nn.UpsamplingBilinear2d(scale_factor=2),
    )
  def forward(self, x):
    # Self-attention with a residual connection, then upsample
    x = x + self.attention(x)
    x = self.block1(x)
    return x

class UNet(nn.Module):
  def __init__(self, concat_channel=0):
    super(UNet, self).__init__()
    self.block1 = BlockU1(64 + concat_channel, 128)
    self.block2 = BlockU1(128, 128)
    self.block3 = BlockU2(128, 128)
    self.block4 = BlockU2(128, 64)
  def forward(self, x, condition=None):
    # If a conditioning tensor is given, concatenate it channel-wise
    if condition is not None:
      x = torch.cat((x, condition), dim=1)
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.block4(x)
    return x
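Since the UNet has two downsampling and two upsampling blocks, the latent's spatial size and channel count are preserved end to end, which is what lets its output feed straight back into the scheduler and decoder. A quick check:

unet = UNet()
z = torch.randn(1, 64, 32, 32)
print(unet(z).shape)  # torch.Size([1, 64, 32, 32])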

Finally, the stable diffusion model is assembled as follows.

class StableDiffusion(nn.Module):
  def __init__(self, concat_channel=0):
    super(StableDiffusion, self).__init__()
    self.time_schdule = TimeSchdule()
    self.vae = VAE()
    self.unet = UNet(concat_channel=concat_channel)
    # Load previously trained weights from Drive if they exist
    try:
      self.vae.load_state_dict(torch.load('/content/drive/MyDrive/My_Models/stable_diffusion_vae.pt', map_location=device))
      print("vae loaded")
    except Exception:
      pass

    self.encoder = self.vae.encoder
    self.decoder = self.vae.decoder
    try:
      self.unet.load_state_dict(torch.load('/content/drive/MyDrive/My_Models/stable_diffusion_unet.pt', map_location=device))
      print("unet loaded")
    except Exception:
      print("no model")
  def forward(self, x, noise, t, text=None):
    # Encode the image into latent space, add noise at timestep t,
    # predict the noise with the UNet, then take one reverse-diffusion step
    z, mu, logvar = self.encoder(x)
    z = self.time_schdule.add_noise(z, noise, t)
    if text is not None:
      noise_pred = self.unet(z, text)
    else:
      noise_pred = self.unet(z)
    z, z0 = self.time_schdule.sample_prev_timestep(z, noise_pred, t[0])
    output0 = self.decoder(z0)  # decoded estimate of x_0
    output = self.decoder(z)    # decoded x_{t-1}
    return output, output0
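The forward pass above only takes a single reverse step, which is convenient for training previews. To actually generate an image from scratch once the VAE and UNet are trained (see the training loop below), you start from pure Gaussian noise in latent space, iterate the reverse process over all 1000 steps, and decode at the end. A minimal sketch, assuming vae and unet already hold trained weights and live on device (the generate function and its arguments are my own, not from the original post):

@torch.inference_mode()
def generate(vae, unet, schedule, num_images=4):
  # Start from pure noise in latent space and denoise step by step
  vae.eval()
  unet.eval()
  z = torch.randn(num_images, 64, 512 // 16, 512 // 16, device=device)
  for step in reversed(range(1000)):
    t = torch.tensor(step, device=device)
    noise_pred = unet(z)
    z, z0 = schedule.sample_prev_timestep(z, noise_pred, t)
  # Decode the final latent back to image space
  return vae.decoder(z)

# Example usage:
# images = generate(vae, unet, TimeSchdule())
# torchvision.utils.save_image(images, "/content/samples.png")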

The training procedure is shown below.
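The dataset and checkpoint paths in the training code assume Google Drive is already mounted at /content/drive. In Colab this is usually done once per session with:

from google.colab import drive
drive.mount('/content/drive')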

# Put your own image dataset here
traindataset = datasets.ImageFolder(root="/content/drive/MyDrive/transparent", transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((512, 512))
    ]))
trainloader = torch.utils.data.DataLoader(traindataset, batch_size=32, shuffle=True, drop_last=True)
vae = VAE()
unet = UNet()
try:
  vae.load_state_dict(torch.load('/content/drive/MyDrive/My_Models/stable_diffusion_vae.pt', map_location=device))
  unet.load_state_dict(torch.load('/content/drive/MyDrive/My_Models/stable_diffusion_unet.pt', map_location=device))
  print("model loaded")
except Exception:
  print("model load failed")
vae.train()
unet.train()
vae.to(device)
unet.to(device)
time_schdule = TimeSchdule()
optimizer_vae = optim.Adam(vae.parameters(), lr=1e-4)
optimizer_unet = optim.Adam(unet.parameters(), lr=1e-4)
criterion = nn.MSELoss()
for epoch in range(100):
  for i, (image, _) in enumerate(trainloader):
    image = image.to(device)
    noise = torch.randn(32, 64, 512 // 16, 512 // 16).to(device)
    t = torch.randint(0, 1000, (32,)).to(device)
    # Optimize the VAE (reconstruction loss only; the usual KL term is omitted here)
    optimizer_vae.zero_grad()
    vae.train()
    output, mu, logvar = vae(image)
    loss_vae = criterion(output, image)
    loss_vae.backward()
    optimizer_vae.step()
    # Optimize the UNet: predict the noise that was added to the latent
    optimizer_unet.zero_grad()
    vae.eval()
    z, mu, logvar = vae.encoder(image)
    z = time_schdule.add_noise(z, noise, t)
    noise_pred = unet(z)
    loss_unet = criterion(noise_pred, noise)
    loss_unet.backward()
    optimizer_unet.step()
    if i % 100 == 0:
      with torch.inference_mode():
        print("epoch:", epoch, "batch:", i, "loss vae:", loss_vae.item(), "loss unet:", loss_unet.item())
        # Quick preview: take one reverse step from the current latent and decode the x_0 estimate
        z, mu, logvar = vae.encoder(image)
        noise_pred = unet(z)
        z, z0 = time_schdule.sample_prev_timestep(z, noise_pred, t[0])
        output = vae.decoder(z0)
        torchvision.utils.save_image(output, f"/content/output_{epoch}_{i}.png")
        torch.save(vae.state_dict(), '/content/drive/MyDrive/My_Models/stable_diffusion_vae.pt')
        torch.save(unet.state_dict(), '/content/drive/MyDrive/My_Models/stable_diffusion_unet.pt')

After a few epochs, the model produces images like the ones below, which shows that it is indeed learning.

[Figure: images generated during training]

Using a better pretrained VAE will give good results faster.
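For example, the VAE used by Stable Diffusion is available through the diffusers library. A rough sketch of loading it (requires pip installing diffusers; its latent has 4 channels rather than the 64 used here, so the UNet's channel counts would have to be adjusted, and inputs are expected in [-1, 1]):

from diffusers import AutoencoderKL

pretrained_vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
with torch.inference_mode():
  img = torch.randn(1, 3, 512, 512, device=device)          # dummy image in [-1, 1]
  latents = pretrained_vae.encode(img).latent_dist.sample()  # (1, 4, 64, 64)
  recon = pretrained_vae.decode(latents).sample              # (1, 3, 512, 512)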

 

With more training, the model should be able to produce images of even higher quality.