
Glimpse PoW

VAEs

Introduction

Autoencoders are a class of neural networks primarily used for unsupervised learning and dimensionality reduction. The fundamental idea behind autoencoders is to encode input data into a lower-dimensional representation and then decode it back to the original data, training the network to minimize the reconstruction error. They are also used for tasks such as inpainting and anomaly detection in medical images.

Architecture

The basic architecture of an autoencoder consists of two main components: the encoder and the decoder.

  • Encoder
    The encoder is responsible for transforming the input data into a compressed or latent representation. It typically consists of one or more layers of neurons that progressively reduce the dimensions of the input.
  • Decoder
    The decoder, on the other hand, takes the compressed representation produced by the encoder and attempts to reconstruct the original input data. Like the encoder, it often consists of one or more layers, but in the reverse order, gradually increasing the dimensions.
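A minimal sketch of this encoder/decoder structure in PyTorch follows. It assumes flattened 784-dimensional inputs (for example 28 × 28 images) and a 16-dimensional latent code. The class name `TinyAutoencoder` and all layer sizes are illustrative choices, not part of the original.

```python
import torch
from torch import nn
from torch.nn import functional as F


class TinyAutoencoder(nn.Module):
    """Hypothetical fully connected autoencoder: 784 -> 128 -> 16 -> 128 -> 784."""

    def __init__(self, input_dim: int = 784, hidden_dim: int = 128, latent_dim: int = 16):
        super().__init__()
        # Encoder: progressively reduce the dimensionality down to the latent code
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )
        # Decoder: mirror the encoder, increasing the dimensionality back to the input size
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),  # outputs in [0, 1], like pixel intensities
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.encoder(x)      # deterministic latent vector
        return self.decoder(z)   # reconstruction of x


model = TinyAutoencoder()
x = torch.rand(8, 784)               # a batch of 8 fake "images"
x_hat = model(x)
recon_loss = F.mse_loss(x_hat, x)    # reconstruction error to minimize during training
print(x_hat.shape, model.encoder(x).shape)
```

During training, `recon_loss` would be minimized with an optimizer such as `torch.optim.Adam`. The latent vector is a single fixed point per input, which is the property the VAE changes.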

[Figure: autoencoder architecture]

  • Difference between traditional AEs and VAEs
    A variational autoencoder addresses the problem of the non-regularized latent space of a plain autoencoder and gives the entire latent space generative capability. The encoder of an AE outputs a single latent vector per input. The encoder of a VAE instead outputs the parameters of a predefined distribution in the latent space for every input. The VAE then constrains this latent distribution to stay close to a standard normal distribution (through the KL term of the loss), and this constraint is what regularizes the latent space.
  • Probabilistic Nature
    Unlike deterministic autoencoders, VAEs model the latent space probabilistically. Each input is encoded as a probability distribution over latent codes rather than as a single fixed vector, which gives a more nuanced representation of uncertainty in the data. A latent code is then sampled from this distribution and passed to the decoder.
  • Role of Latent Space
    The latent space in VAEs serves as a continuous, structured representation of the input data. Because it is continuous by design, it allows easy interpolation. Each point in the latent space corresponds to a potential output. This enables smooth transitions between data points and ensures that points close together in the latent space decode to similar outputs.
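To make the contrast with a plain autoencoder concrete, here is a minimal sketch of a VAE encoder head. It assumes flattened 784-dimensional inputs and a 16-dimensional latent space; the names `GaussianEncoder` and `sample_latent` and the sizes are illustrative. Instead of one latent vector, the head returns a mean and a log-variance per input, and a latent code is drawn from the resulting Gaussian. Encoding the same input twice therefore gives two different but nearby codes.

```python
import torch
from torch import nn


class GaussianEncoder(nn.Module):
    """Hypothetical VAE encoder: maps x to the parameters of q(z|x) = N(mu, diag(exp(log_var)))."""

    def __init__(self, input_dim: int = 784, hidden_dim: int = 128, latent_dim: int = 16):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)       # mean of the latent Gaussian
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)  # log-variance (unconstrained, unlike a std)

    def forward(self, x: torch.Tensor):
        h = self.body(x)
        return self.fc_mu(h), self.fc_log_var(h)


def sample_latent(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Draw one z per input from N(mu, exp(log_var))."""
    std = torch.exp(0.5 * log_var)
    return mu + std * torch.randn_like(std)


torch.manual_seed(0)
enc = GaussianEncoder()
x = torch.rand(4, 784)
mu, log_var = enc(x)
z1 = sample_latent(mu, log_var)
z2 = sample_latent(mu, log_var)   # same input, different sample
mu_again, _ = enc(x)              # the distribution parameters themselves are deterministic
```

Predicting the log-variance rather than the standard deviation keeps the variance positive without any constraint on the layer's output.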

Math

1) Probabilistic Modeling: In VAEs, the latent space is modeled as a probability distribution, usually a multivariate Gaussian with diagonal covariance. This distribution is parameterized by mean and standard deviation vectors, which are the outputs of the probabilistic encoder $q_\phi(z \mid x)$. A latent representation $z$ sampled from this distribution is then passed to the probabilistic decoder $p_\theta(x \mid z)$, which models the distribution of the reconstructed data.

2) Loss Function: The loss function for VAEs comprises two components. The first is the reconstruction loss, which measures how well the model reconstructs the input, as in vanilla autoencoders. The second is the KL divergence, which measures how closely the learned latent distribution matches a chosen prior distribution, usually a standard Gaussian. Together, these terms encourage the model to learn a latent representation that both captures the data distribution and stays consistent with the specified prior.
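For a diagonal Gaussian encoder and a standard normal prior, the KL term has the closed form $\mathrm{KL} = -\tfrac{1}{2}\sum_j \big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)$. This is the expression used in `loss_function` of the BetaVAE code. The sketch below checks it against PyTorch's `torch.distributions.kl_divergence` and combines it with a summed MSE reconstruction term. The function names are illustrative, and the equal weighting of the two terms is an assumption, since real implementations scale them differently.

```python
import torch
from torch.distributions import Normal, kl_divergence
from torch.nn import functional as F


def kl_to_standard_normal(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dims, one value per sample."""
    return -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)


def vae_loss(x_hat, x, mu, log_var):
    """Hypothetical VAE loss: summed squared error + KL, both averaged over the batch."""
    recon = F.mse_loss(x_hat, x, reduction="sum") / x.shape[0]
    kl = kl_to_standard_normal(mu, log_var).mean()
    return recon + kl, recon, kl


torch.manual_seed(0)
mu = torch.randn(5, 3)
log_var = torch.randn(5, 3)

closed_form = kl_to_standard_normal(mu, log_var)
reference = kl_divergence(
    Normal(mu, torch.exp(0.5 * log_var)),                # q(z|x), parameterized by its std
    Normal(torch.zeros_like(mu), torch.ones_like(mu)),   # prior p(z) = N(0, I)
).sum(dim=1)

x = torch.rand(5, 10)
loss, recon, kl = vae_loss(x + 0.1, x, mu, log_var)
```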

ELBO

$$p_\theta(x) = \int p_\theta(z)\, p_\theta(x \mid z)\, dz$$
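With the usual modeling choices (assumptions rather than requirements), the two factors inside this integral are a standard normal prior and a decoder distribution whose parameters come from a decoder network. The decoder network is written here as $f_\theta$, a symbol introduced only for this illustration:

```latex
\begin{aligned}
p_\theta(z) &= \mathcal{N}(z;\, 0,\, I) \\
p_\theta(x \mid z) &= \mathcal{N}\big(x;\, f_\theta(z),\, \sigma^2 I\big)
  \quad \text{(real-valued } x\text{)}
  \qquad \text{or} \qquad
  \prod_{d} \operatorname{Bernoulli}\big(x_d;\, f_\theta(z)_d\big)
  \quad \text{(binary } x\text{)}
\end{aligned}
```

Even with these simple factors the integral has no closed form, because $f_\theta$ is a nonlinear neural network.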

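The marginal likelihood $p_\theta(x)$ is the probability of an image averaged over all of its possible latent representations. It is intractable: with a neural-network decoder the integral has no closed form, and covering the whole latent space by sampling is far too expensive (see Bayesian statistics and the reparameterization trick).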
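The reparameterization trick mentioned here is what makes sampling from $q_\phi(z \mid x)$ trainable. Instead of sampling $z \sim \mathcal{N}(\mu, \sigma^2)$ directly, we sample noise $\epsilon \sim \mathcal{N}(0, I)$ and compute $z = \mu + \sigma \odot \epsilon$, so $z$ becomes a differentiable function of $\mu$ and $\log \sigma^2$. Below is a minimal check of the resulting gradients; the toy values and the objective $\sum z$ are arbitrary.

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
log_var = torch.tensor([0.0, 1.0, -2.0], requires_grad=True)

# All randomness lives in eps, which does not depend on the parameters
eps = torch.randn(3)
std = torch.exp(0.5 * log_var)
z = mu + std * eps            # same formula as BetaVAE.reparameterize

# Any downstream objective works; here simply the sum of the latent values
z.sum().backward()

# Analytically: dz/dmu = 1 and dz/dlog_var = 0.5 * std * eps
expected_log_var_grad = 0.5 * std.detach() * eps
```

`torch.distributions.Normal(mu, std).rsample()` implements the same idea. A plain `sample()` would cut the gradient path to `mu` and `log_var`.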

There are several approaches to solving the intractability of $p_\theta(x)$. One of them is variational inference, and the intuition behind it is the following. For a specific image $x^{(i)}$, there is a tiny area in $z$ with a high probability of generating $x^{(i)}$; for all other $z$, the probability is close to 0. For example, think about an image of a black cat. There is probably a tiny region in $z$ that represents black cats with the head on the left and two open eyes, and the probability of generating this black cat from the area representing white cats is ~0.

The key idea is that instead of integrating over all of $z$, we compute $p_\theta(x^{(i)})$ by sampling only from the tiny area in $z$ that is most likely to generate $x^{(i)}$. To find that area, we need the posterior $p_\theta(z \mid x)$. Unfortunately, the posterior is intractable as well, since by Bayes' rule it requires $p_\theta(x)$ itself. We can, however, approximate it with a model $q_\phi(z \mid x)$ called the probabilistic encoder.
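Introducing $q_\phi(z \mid x)$ is also what makes training possible. Multiplying and dividing the integrand by $q_\phi(z \mid x)$ and applying Jensen's inequality gives the standard evidence lower bound (ELBO), the quantity this section is named after:

```latex
\begin{aligned}
\log p_\theta(x)
  &= \log \int q_\phi(z \mid x)\, \frac{p_\theta(z)\, p_\theta(x \mid z)}{q_\phi(z \mid x)}\, dz \\
  &\ge \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log \frac{p_\theta(z)\, p_\theta(x \mid z)}{q_\phi(z \mid x)}\right] \\
  &= \underbrace{\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big]}_{\text{reconstruction}}
     - \underbrace{\mathrm{KL}\big(q_\phi(z \mid x)\,\|\,p_\theta(z)\big)}_{\text{regularization}}
   \;=\; \mathrm{ELBO}(x)
\end{aligned}
```

The gap in the inequality is exactly $\mathrm{KL}\big(q_\phi(z \mid x)\,\|\,p_\theta(z \mid x)\big)$. Maximizing the ELBO therefore both fits the data and pulls the encoder toward the true posterior. Maximizing it is equivalent to minimizing the two-part loss (reconstruction + KL) described in the Math section.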

With the probabilistic encoder $q_\phi(z \mid x)$, we compute $p_\theta(x^{(i)})$ by first passing $x^{(i)}$ through it. The output is a distribution concentrated on a tiny area of $z$. We then sample from that distribution and evaluate $p_\theta(x \mid z)$ on the samples. I’ll get into these details later.
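This intuition can be checked on a one-dimensional toy model where everything is known in closed form, an assumption made purely for illustration. The prior is $p(z) = \mathcal{N}(0, 1)$ and the decoder is $p(x \mid z) = \mathcal{N}(x; z, s^2)$ with small $s$, so the exact marginal is $p(x) = \mathcal{N}(x; 0, 1 + s^2)$. Averaging $p(x \mid z)$ over samples from the prior wastes almost every sample. Sampling $z$ from a narrow "encoder" distribution $q(z \mid x)$ and reweighting by $p(z)/q(z \mid x)$ (importance sampling) concentrates the samples where they matter. Here $q$ is set to the exact posterior $\mathcal{N}\big(x/(1+s^2),\, s^2/(1+s^2)\big)$, which a trained encoder would only approximate.

```python
import math
import random


def normal_pdf(v: float, mean: float, var: float) -> float:
    """Density of N(mean, var) at v."""
    return math.exp(-0.5 * (v - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)


random.seed(0)
s2 = 0.01 ** 2   # decoder variance s^2: each z can only "explain" a tiny range of x
x = 1.5          # the observed data point x^(i)
n = 1000         # number of samples for each estimator

# Exact marginal likelihood, available only because this toy model is linear-Gaussian
exact = normal_pdf(x, 0.0, 1.0 + s2)

# Naive Monte Carlo: sample z from the prior p(z) and average p(x|z);
# only the few samples landing within a few s of x contribute anything
prior_samples = [random.gauss(0.0, 1.0) for _ in range(n)]
naive = sum(normal_pdf(x, z, s2) for z in prior_samples) / n

# Importance sampling: sample z from a narrow q(z|x) around the relevant region
# and reweight each sample by p(z) p(x|z) / q(z|x)
q_mean, q_var = x / (1.0 + s2), s2 / (1.0 + s2)   # exact posterior, used as q
q_samples = [random.gauss(q_mean, math.sqrt(q_var)) for _ in range(n)]
importance = sum(
    normal_pdf(z, 0.0, 1.0) * normal_pdf(x, z, s2) / normal_pdf(z, q_mean, q_var)
    for z in q_samples
) / n

print(f"exact={exact:.5f}  naive={naive:.5f}  importance={importance:.5f}")
```

Because $q$ here is the exact posterior, every importance weight equals $p(x)$ and the estimate is exact. With a learned, approximate encoder the weights vary, but the estimator is usually far less noisy than sampling from the prior, as long as $q$ covers the posterior well.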


Probabilistic Decoding: for binary data the decoder models a Bernoulli distribution; for real-valued data it models a Gaussian distribution.
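In practice, the choice of decoder distribution fixes the reconstruction term of the loss. The small check below uses illustrative tensor sizes and a fixed unit variance for the Gaussian, both assumptions. It shows that the Bernoulli negative log-likelihood is binary cross-entropy, and that a Gaussian with fixed variance gives half the squared error plus a constant. This is why the `F.mse_loss` in the BetaVAE code corresponds to a Gaussian decoder, up to an additive constant and a scale factor.

```python
import math

import torch
from torch.distributions import Bernoulli, Normal
from torch.nn import functional as F

torch.manual_seed(0)
D = 6  # number of "pixels"

# Binary data -> Bernoulli decoder: the decoder outputs one logit per pixel
x_bin = torch.randint(0, 2, (D,)).float()
logits = torch.randn(D)
bernoulli_nll = -Bernoulli(logits=logits).log_prob(x_bin).sum()
bce = F.binary_cross_entropy_with_logits(logits, x_bin, reduction="sum")

# Real-valued data -> Gaussian decoder with fixed unit variance: the decoder outputs the mean
x_real = torch.randn(D)
mean = torch.randn(D)
gaussian_nll = -Normal(mean, 1.0).log_prob(x_real).sum()
half_sse_plus_const = 0.5 * ((x_real - mean) ** 2).sum() + 0.5 * D * math.log(2 * math.pi)
```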


Code

Beta (Disentangled) VAE

```python
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from typing import List

try:
    from models import BaseVAE  # abstract base class (an nn.Module) from the PyTorch-VAE repo
except ImportError:
    BaseVAE = nn.Module  # fallback so the class can also be used on its own


class BetaVAE(BaseVAE):

    num_iter = 0  # Global static variable to keep track of iterations

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 beta: int = 4,
                 gamma: float = 1000.,
                 max_capacity: int = 25,
                 Capacity_max_iter: int = 100000,
                 loss_type: str = 'B',
                 **kwargs) -> None:
        super(BetaVAE, self).__init__()

        self.latent_dim = latent_dim
        self.beta = beta
        self.gamma = gamma
        self.loss_type = loss_type
        self.C_max = torch.Tensor([max_capacity])
        self.C_stop_iter = Capacity_max_iter
        out_channels = in_channels  # the decoder reconstructs as many channels as the input has

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        hidden_dims = list(hidden_dims)  # copy, so reverse() below does not mutate the caller's list
        self.last_hidden_dim = hidden_dims[-1]

        # Build Encoder: every stride-2 conv halves H and W, so with the default
        # 5 layers a 64 x 64 input ends as a 2 x 2 map (hence the "* 4" below)
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)   # mean of q(z|x)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)  # log-variance of q(z|x)

        # Build Decoder: mirror of the encoder, each transposed conv doubles H and W
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels=out_channels,
                                      kernel_size=3, padding=1),
                            nn.Tanh())  # outputs in [-1, 1], so inputs should be scaled to [-1, 1]

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes [mu, log_var]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        result = self.decoder_input(z)
        result = result.view(-1, self.last_hidden_dim, 2, 2)  # back to a 2 x 2 feature map
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: z = mu + std * eps with eps ~ N(0, I),
        so gradients flow back to mu and logvar. A single sample per input
        is used as a Monte Carlo estimate of the expectation in the loss.
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) Sampled latent codes [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        self.num_iter += 1
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

        recons_loss = F.mse_loss(recons, input)

        # Closed-form KL(q(z|x) || N(0, I)), summed over latent dims, averaged over the batch
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        if self.loss_type == 'H':  # https://openreview.net/forum?id=Sy2fzU9gl
            loss = recons_loss + self.beta * kld_weight * kld_loss
        elif self.loss_type == 'B':  # https://arxiv.org/pdf/1804.03599.pdf
            # Target capacity C grows linearly from 0 to C_max over C_stop_iter iterations
            self.C_max = self.C_max.to(input.device)
            C = torch.clamp(self.C_max / self.C_stop_iter * self.num_iter, 0, self.C_max.item())
            loss = recons_loss + self.gamma * kld_weight * (kld_loss - C).abs()
        else:
            raise ValueError('Undefined loss type.')

        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': kld_loss}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
```
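The `hidden_dims[-1] * 4` in `fc_mu`/`fc_var` and the `view(..., 2, 2)` in `decode` hard-code a 2 × 2 feature map after the encoder. With the default five stride-2 convolutions, that only holds for 64 × 64 inputs. The hypothetical helpers below trace the spatial size through both halves to confirm this. They apply the standard PyTorch output-size formulas for `Conv2d` and `ConvTranspose2d` (dilation 1).

```python
def conv2d_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Spatial output size of nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size: int, kernel: int = 3, stride: int = 2,
                         padding: int = 1, output_padding: int = 1) -> int:
    """Spatial output size of nn.ConvTranspose2d (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def encoder_map_size(input_size: int, n_layers: int = 5) -> int:
    """H (= W) of the encoder output; must be 2 for hidden_dims[-1] * 4 to match fc_mu."""
    size = input_size
    for _ in range(n_layers):
        size = conv2d_out(size)
    return size


def decoder_output_size(start: int = 2, n_layers: int = 5) -> int:
    """H (= W) of the image produced by decode(), which always starts from a 2 x 2 map."""
    size = start
    for _ in range(n_layers - 1):          # self.decoder: n_layers - 1 transposed convs
        size = conv_transpose2d_out(size)
    size = conv_transpose2d_out(size)      # final_layer: transposed conv
    return conv2d_out(size, stride=1)      # final_layer: 3x3 conv with padding 1


print(encoder_map_size(64), decoder_output_size())   # -> 2 64
print(encoder_map_size(128))                          # -> 4
```

A 128 × 128 input would therefore produce a 4 × 4 map, so the flattened encoder output would no longer match `fc_mu`. Other resolutions need a different number of layers or a different flatten size.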

Similar

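Beta (Disentangled) VAEs: multiply the KL term (the second part of the loss) by a weight β > 1. The stronger pull toward the factorized prior pushes the model to use only the latent dimensions it really needs, each capturing a separate factor of variation, at some cost in reconstruction quality. The code above also implements a capacity-controlled variant (`loss_type='B'`). Instead of weighting the KL term itself, this variant penalizes its deviation from a target capacity C that grows during training.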
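The two objectives selected by `loss_type` in the BetaVAE code can be written as plain scalar functions to see how they differ. The function names are hypothetical; the arguments mirror the code's `beta`, `gamma`, `max_capacity`, `Capacity_max_iter` and `kld_weight`.

```python
def beta_vae_h_loss(recons_loss: float, kld: float, beta: float = 4.0,
                    kld_weight: float = 1.0) -> float:
    """loss_type 'H': the KL term is simply scaled by beta."""
    return recons_loss + beta * kld_weight * kld


def capacity(num_iter: int, max_capacity: float = 25.0, stop_iter: int = 100000) -> float:
    """Target KL capacity C, growing linearly from 0 to max_capacity over stop_iter iterations."""
    return min(max(max_capacity / stop_iter * num_iter, 0.0), max_capacity)


def beta_vae_b_loss(recons_loss: float, kld: float, num_iter: int, gamma: float = 1000.0,
                    kld_weight: float = 1.0, max_capacity: float = 25.0,
                    stop_iter: int = 100000) -> float:
    """loss_type 'B': penalize |KL - C| instead of the KL itself."""
    c = capacity(num_iter, max_capacity, stop_iter)
    return recons_loss + gamma * kld_weight * abs(kld - c)


print(beta_vae_h_loss(1.0, 2.0))            # -> 9.0
print(capacity(0), capacity(200000))        # -> 0.0 25.0
```

Early in training C is near 0, so the 'B' objective acts like a heavily weighted KL penalty. As C grows, the encoder is allowed to put more information into the latent code.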

References

https://yonigottesman.github.io/2023/03/11/vae.html
https://huggingface.co/learn/computer-vision-course/en/unit5/generative-models/variational_autoencoders
https://github.com/AntixK/PyTorch-VAE/tree/master/models