Skip to content

Instantly share code, notes, and snippets.

@dribnet
Forked from karpathy/stablediffusionwalk.py
Last active August 16, 2022 03:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dribnet/f2b39cc212d7c927c3c8734cbe45db6c to your computer and use it in GitHub Desktop.
Save dribnet/f2b39cc212d7c927c3c8734cbe45db6c to your computer and use it in GitHub Desktop.
Hacky Stable Diffusion code for generating videos
"""
Draws many samples from a Stable Diffusion model by slerp-ing (spherically interpolating)
around the latent noise space, and dumps the frames to outputs/mov1/. You can then
stitch the frames into a video with, e.g.:
$ ffmpeg -r 10 -f image2 -s 512x512 -i outputs/mov1/frame%06d.jpg -vcodec libx264 -crf 10 -pix_fmt yuv420p test.mp4
THIS FILE IS HACKY AND NOT CONFIGURABLE READ THE CODE, MAKE EDITS TO PATHS AND SETTINGS YOU LIKE
THIS FILE IS HACKY AND NOT CONFIGURABLE READ THE CODE, MAKE EDITS TO PATHS AND SETTINGS YOU LIKE
THIS FILE IS HACKY AND NOT CONFIGURABLE READ THE CODE, MAKE EDITS TO PATHS AND SETTINGS YOU LIKE
The slerp implementation is courtesy of @xsteenbrugge -- thanks!
You need access to the Stable Diffusion checkpoints at https://huggingface.co/CompVis
(export your Hugging Face token in the HF_TOKEN environment variable), and you must
install the remaining dependencies (e.g. the diffusers library).
"""
from diffusers import StableDiffusionPipeline
from time import time
from PIL import Image
from einops import rearrange
import numpy as np
import torch
from torch import autocast
from torchvision.utils import make_grid
torch.manual_seed(42)
import os

# Hugging Face access token for the gated CompVis checkpoint; a KeyError here means
# the HF_TOKEN environment variable is not set.
HF_TOKEN = os.environ['HF_TOKEN']
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=HF_TOKEN)
torch_device = 'cuda'
# Only the sub-modules used by diffuse() below are moved to the GPU.
pipe.unet.to(torch_device)
pipe.vae.to(torch_device)
pipe.text_encoder.to(torch_device)
print('w00t')

batch_size = 1
height = 512
width = 512

# One copy of the prompt per batch element: diffuse() concatenates these embeddings with
# per-sample unconditional embeddings and the latents are drawn with batch_size rows, so
# the three must agree when batch_size is edited.
prompt = ["ultrarealistic steam punk neural network machine in the shape of a brain, placed on a pedestal, covered with neurons made of gears. dramatic lighting. #unrealengine"] * batch_size

# Pad to the tokenizer's full context length, as diffusers' StableDiffusionPipeline does;
# padding=True would pad only to the longest prompt and give the UNet a shorter
# cross-attention context than the reference pipeline uses.
text_input = pipe.tokenizer(
    prompt,
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
# Inference only: don't build (and keep alive) an autograd graph for these embeddings.
with torch.no_grad():
    text_embeddings = pipe.text_encoder(text_input.input_ids.to(torch_device))[0]
@torch.no_grad()
def diffuse(text_embeddings, init, guidance_scale=7.5, num_inference_steps=50):
    """Denoise ``init`` with classifier-free guidance and decode it to images.

    Args:
        text_embeddings: prompt embeddings from ``pipe.text_encoder``, shaped (n, t, d).
            ``n`` must match the batch dimension of ``init``.
        init: starting latent noise; the caller in this file draws it as
            (batch_size, pipe.unet.in_channels, height // 8, width // 8).
        guidance_scale: classifier-free guidance weight applied to
            (cond - uncond) noise predictions.
        num_inference_steps: number of scheduler timesteps to run (default 50).

    Returns:
        A numpy array of shape (n, H, W, C) with values clamped to [0, 1].
    """
    n_prompts, max_length = text_embeddings.shape[0], text_embeddings.shape[1]

    # Unconditional (empty-prompt) embeddings: one per prompt, padded to the same token
    # length so they can be stacked with the conditional ones.
    uncond_input = pipe.tokenizer([""] * n_prompts, padding="max_length", max_length=max_length, return_tensors="pt")
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]
    # [uncond; cond] so a single UNet call evaluates both halves of the guidance.
    guidance_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    latents = init.clone()
    pipe.scheduler.set_timesteps(num_inference_steps)
    for t in pipe.scheduler.timesteps:
        # predict the noise residual for both the unconditional and conditional copy
        latent_model_input = torch.cat([latents] * 2)
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=guidance_embeddings)["sample"]
        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # compute the previous noisy sample x_t -> x_t-1
        latents = pipe.scheduler.step(noise_pred, t, latents)["prev_sample"]

    # post-process: 0.18215 is presumably the SD VAE latent scale factor (undo it before decoding)
    latents = 1 / 0.18215 * latents
    decoded = pipe.vae.decode(latents)
    # Older diffusers return a tensor here; newer ones return a DecoderOutput with `.sample`.
    image = getattr(decoded, "sample", decoded)
    # Map decoder output from [-1, 1] to [0, 1].
    image = (image / 2 + 0.5).clamp(0, 1)
    return image.cpu().permute(0, 2, 3, 1).numpy()
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Spherically interpolate between ``v0`` and ``v1`` at fraction ``t``.

    Args:
        t: interpolation fraction; 0 returns ``v0``, 1 returns ``v1``.
        v0, v1: numpy arrays or torch tensors of the same shape.
        DOT_THRESHOLD: if |cos(angle)| between the inputs exceeds this, the inputs are
            treated as collinear and plain linear interpolation is used instead.

    Returns:
        A torch tensor on ``v0``'s device with ``v0``'s dtype when ``v0`` is a tensor;
        otherwise a numpy array.
    """
    # Anything that isn't a numpy array is treated as a torch tensor (as before).
    inputs_are_torch = not isinstance(v0, np.ndarray)
    if inputs_are_torch:
        input_device = v0.device
        input_dtype = v0.dtype
        v0 = v0.cpu().numpy()
    # Convert v1 on its own so mixed numpy/torch inputs also work.
    if not isinstance(v1, np.ndarray):
        v1 = v1.cpu().numpy()

    # Cosine of the angle between the two (flattened) inputs.
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        # sin(theta_0) is ~0 here, so the slerp weights would be ill-conditioned.
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        # Cast back so e.g. a float64 `t` can't silently promote the latents to float64.
        v2 = torch.from_numpy(np.asarray(v2)).to(device=input_device, dtype=input_dtype)
    return v2
# DREAM
# Frames go to this directory; create it up front so the first save doesn't fail.
os.makedirs('outputs/mov1', exist_ok=True)

# sample start
init1 = torch.randn((batch_size, pipe.unet.in_channels, height // 8, width // 8)).to(torch_device)
n = 0
while True:
    # sample destination
    init2 = torch.randn((batch_size, pipe.unet.in_channels, height // 8, width // 8)).to(torch_device)
    # endpoint=False: t == 1 would reproduce init2, which is exactly the t == 0 frame of
    # the next segment, so including it renders a duplicate frame (a stutter) per segment.
    for t in np.linspace(0, 1, 200, endpoint=False):
        init = slerp(float(t), init1, init2)
        image = diffuse(text_embeddings, init, guidance_scale=10.0)
        # Round rather than truncate when mapping [0, 1] floats to 8-bit pixels.
        im = Image.fromarray((image[0] * 255).round().astype(np.uint8))
        im.save('outputs/mov1/frame%06d.jpg' % n)
        print('dreaming... ', n)
        n += 1
    init1 = init2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment