How do I visualize the latent representation produced by the Stable Diffusion VAE?

6 days ago 13
ARTICLE AD BOX

I am trying to visualize the latent representation produced by the VAE inside a Stable Diffusion pipeline

from diffusers import StableDiffusionPipeline import torch # A CUDA ordinal is simply the integer ID of a GPU in a system that has one or more GPUs. def get_device(cuda_ordinal=None): if torch.cuda.is_available(): return torch.device("cuda", cuda_ordinal) if torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") device=get_device() pipe=StableDiffusionPipeline.from_pretrained( "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16" ).to(device)

I wrote load_image helper

from PIL import Image from io import BytesIO import requests def load_image(url, size=None, return_tensor=False): if url.startswith("http"): response = requests.get(url) img = Image.open(BytesIO(response.content)) else: img = Image.open(url) if size is not None: img = img.resize(size) if return_tensor: return TF.to_tensor(img) return img

I successfully loaded and plotted with matplotlib library

import matplotlib.pyplot as plt im=load_image("https://media.gettyimages.com/id/2244914756/photo/a-small-island-shaped-like-a-straw-hat-by-the-blue-lake.jpg?s=2048x2048&w=gi&k=20&c=wzAXUuF06cnKTYrV7Xs8DG9jHPvrAf-1tW2Vs53VXOg=",size=(512,512)) plt.imshow(im) plt.axis('off') # remove axes plt.show()

The VAE compresses this image into a four-channel latent representation and i like to visualize each channel of VAE's latent representation:

from torchvision import transforms with torch.inference_mode(): tensor_im=transforms.ToTensor()(im).unsqueeze(0).to(device)*2-1 tensor_im=tensor_im.half() latent=pipe.vae.encode(tensor_im) latents=latent.latent_dist.sample() latents=latents*0.18215 latents.shape #torch.Size([1, 4, 64, 64])

this is helper function to plot each representation

import matplotlib.pyplot as plt def show_latent_channels(latents): lat = latents[0] # [4, 64, 64] num_ch = lat.shape[0] fig, axes = plt.subplots(1, num_ch, figsize=(12,3)) for i in range(num_ch): ch = lat[i] # normalize channel to [0,1] for visualization ch = (ch - ch.min()) / (ch.max() - ch.min()) axes[i].imshow(ch.cpu(), cmap="viridis") axes[i].set_title(f"Channel {i}") axes[i].axis("off") plt.show()

when I call show_latent_channels(latents) I get this

enter image description here
My goal is simply to “see” what the VAE is storing inside the latent space but matplotlib shows an empty plot

Read Entire Article