Instructions to use BiliSakura/BitDance-Tokenizer-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/BitDance-Tokenizer-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline from diffusers.utils import load_image # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("BiliSakura/BitDance-Tokenizer-diffusers", dtype=torch.bfloat16, device_map="cuda") prompt = "Turn this cat into a dog" input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png") image = pipe(image=input_image, prompt=prompt).images[0] - Notebooks
- Google Colab
- Kaggle
| from contextlib import nullcontext | |
| from typing import List, Optional, Tuple, Union | |
| import torch | |
| from einops import rearrange | |
| from PIL import Image | |
| from tqdm.auto import tqdm | |
| from diffusers import DiffusionPipeline | |
| from diffusers.pipelines.pipeline_utils import ImagePipelineOutput | |
| from .constants import SUPPORTED_IMAGE_SIZES | |
| PromptType = Union[str, List[str]] | |
| def _get_pkv_seq_len(past_key_values) -> int: | |
| """Get cached sequence length from past_key_values (supports tuple and DynamicCache).""" | |
| if hasattr(past_key_values, "get_seq_length"): | |
| return past_key_values.get_seq_length() | |
| return past_key_values[0][0].shape[2] | |
| class BitDanceDiffusionPipeline(DiffusionPipeline): | |
| model_cpu_offload_seq = "text_encoder->projector->diffusion_head->autoencoder" | |
| def __init__( | |
| self, | |
| tokenizer, | |
| text_encoder, | |
| autoencoder, | |
| diffusion_head, | |
| projector, | |
| supported_image_sizes: Optional[List[List[int]]] = None, | |
| dtype: Optional[torch.dtype] = None, | |
| ) -> None: | |
| super().__init__() | |
| self.register_modules( | |
| tokenizer=tokenizer, | |
| text_encoder=text_encoder, | |
| autoencoder=autoencoder, | |
| diffusion_head=diffusion_head, | |
| projector=projector, | |
| ) | |
| image_sizes = supported_image_sizes or SUPPORTED_IMAGE_SIZES | |
| self.register_to_config(supported_image_sizes=[list(size) for size in image_sizes]) | |
| self.hidden_size = self.text_encoder.config.hidden_size | |
| self.vae_patch_size = self.autoencoder.patch_size | |
| self.parallel_num = int(self.diffusion_head.config.parallel_num) | |
| self.ps = int(self.parallel_num**0.5) | |
| if self.ps * self.ps != self.parallel_num: | |
| raise ValueError( | |
| f"parallel_num must be a perfect square (got {self.parallel_num})." | |
| ) | |
| self._build_pos_embed() | |
| def supported_image_sizes(self) -> List[List[int]]: | |
| return [list(size) for size in self.config.supported_image_sizes] | |
| def _execution_device_fallback(self) -> torch.device: | |
| if getattr(self, "_execution_device", None) is not None: | |
| return self._execution_device | |
| return next(self.text_encoder.parameters()).device | |
| def _build_pos_embed(self) -> None: | |
| max_resolution = max(max(size) for size in self.supported_image_sizes) | |
| max_len = max_resolution // self.vae_patch_size | |
| pos_embed_1d = self._get_1d_sincos_pos_embed(self.hidden_size // 2, max_len) | |
| self.pos_embed_1d = pos_embed_1d | |
| def _get_1d_sincos_pos_embed(dim: int, max_len: int, pe_interpolation: float = 1.0) -> torch.Tensor: | |
| if dim % 2 != 0: | |
| raise ValueError(f"dim must be even, got {dim}") | |
| omega = torch.arange(dim // 2, dtype=torch.float32) | |
| omega /= dim / 2.0 | |
| omega = 1.0 / 10000**omega | |
| pos = torch.arange(max_len, dtype=torch.float32) / pe_interpolation | |
| out = torch.einsum("m,d->md", pos, omega) | |
| emb_sin = torch.sin(out) | |
| emb_cos = torch.cos(out) | |
| return torch.cat([emb_sin, emb_cos], dim=1) | |
| def _get_2d_embed(self, h: int, w: int, ps: int = 1) -> torch.Tensor: | |
| emb_v = self.pos_embed_1d[:h] | |
| emb_h = self.pos_embed_1d[:w] | |
| grid_v = emb_v.view(h, 1, self.hidden_size // 2).repeat(1, w, 1) | |
| grid_h = emb_h.view(1, w, self.hidden_size // 2).repeat(h, 1, 1) | |
| pos_embed = torch.cat([grid_h, grid_v], dim=-1) | |
| return rearrange(pos_embed, "(h p1) (w p2) c -> (h w p1 p2) c", p1=ps, p2=ps) | |
| def _encode_prompt_to_embeds( | |
| self, | |
| prompt: str, | |
| image_size: Tuple[int, int], | |
| num_images_per_prompt: int, | |
| guidance_scale: float, | |
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: | |
| device = self._execution_device_fallback() | |
| model = self.text_encoder.model | |
| tokenizer = self.tokenizer | |
| cond_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" | |
| uncond_prompt = "<|im_start|>assistant\n" | |
| cond_ids = torch.tensor(tokenizer.encode(cond_prompt), device=device, dtype=torch.long) | |
| cond_emb = model.embed_tokens(cond_ids) | |
| uncond_emb = None | |
| if guidance_scale > 1.0: | |
| uncond_ids = torch.tensor(tokenizer.encode(uncond_prompt), device=device, dtype=torch.long) | |
| uncond_emb = model.embed_tokens(uncond_ids) | |
| image_h, image_w = image_size | |
| img_start_id = tokenizer.convert_tokens_to_ids("<|vision_start|>") | |
| res_h_token_id = tokenizer.convert_tokens_to_ids(f"<|res_{image_h // self.vae_patch_size}|>") | |
| res_w_token_id = tokenizer.convert_tokens_to_ids(f"<|res_{image_w // self.vae_patch_size}|>") | |
| img_start_emb = model.embed_tokens(torch.tensor([img_start_id, res_h_token_id, res_w_token_id], device=device)) | |
| for i in range(1, self.parallel_num): | |
| query_token_id = tokenizer.convert_tokens_to_ids(f"<|query_{i}|>") | |
| query_token = torch.tensor([query_token_id], device=device, dtype=torch.long) | |
| query_embed = model.embed_tokens(query_token) | |
| img_start_emb = torch.cat([img_start_emb, query_embed], dim=0) | |
| input_embeds_cond = torch.cat([cond_emb, img_start_emb], dim=0).unsqueeze(0).repeat(num_images_per_prompt, 1, 1) | |
| input_embeds_uncond = None | |
| if guidance_scale > 1.0 and uncond_emb is not None: | |
| input_embeds_uncond = torch.cat([uncond_emb, img_start_emb], dim=0).unsqueeze(0).repeat(num_images_per_prompt, 1, 1) | |
| return input_embeds_cond, input_embeds_uncond, img_start_emb | |
| def _decode_tokens_to_image(self, image_latents: torch.Tensor, image_size: Tuple[int, int], ps: int = 1) -> torch.Tensor: | |
| h, w = image_size | |
| image_latents = rearrange(image_latents, "b (h w p1 p2) c -> b c (h p1) (w p2)", h=h // ps, w=w // ps, p1=ps, p2=ps) | |
| return self.autoencoder.decode(image_latents) | |
| def _generate_single_prompt( | |
| self, | |
| prompt: str, | |
| height: int, | |
| width: int, | |
| num_inference_steps: int, | |
| guidance_scale: float, | |
| num_images_per_prompt: int, | |
| generator: Optional[torch.Generator], | |
| show_progress_bar: bool, | |
| ) -> torch.Tensor: | |
| image_size = (height, width) | |
| if list(image_size) not in self.supported_image_sizes: | |
| raise ValueError( | |
| f"image_size {list(image_size)} is not supported. " | |
| f"Please choose from {self.supported_image_sizes}" | |
| ) | |
| h, w = height // self.vae_patch_size, width // self.vae_patch_size | |
| max_length = h * w | |
| step_width = self.parallel_num | |
| if max_length % step_width != 0: | |
| raise ValueError( | |
| f"max_length ({max_length}) must be divisible by parallel_num ({step_width})." | |
| ) | |
| num_steps = max_length // step_width | |
| device = self._execution_device_fallback() | |
| model = self.text_encoder.model | |
| dtype = next(self.text_encoder.parameters()).dtype | |
| input_embeds_cond, input_embeds_uncond, _ = self._encode_prompt_to_embeds( | |
| prompt=prompt, | |
| image_size=image_size, | |
| num_images_per_prompt=num_images_per_prompt, | |
| guidance_scale=guidance_scale, | |
| ) | |
| pos_embed_for_diff = self._get_2d_embed(h, w, ps=self.ps).unsqueeze(0).to(device=device, dtype=dtype) | |
| autocast_ctx = ( | |
| torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16) | |
| if device.type == "cuda" | |
| else nullcontext() | |
| ) | |
| with autocast_ctx: | |
| outputs_c = model(inputs_embeds=input_embeds_cond[:, :-step_width, :], use_cache=True) | |
| pkv_c = outputs_c.past_key_values | |
| bi_attn_mask = torch.ones( | |
| (input_embeds_cond.shape[0], 1, step_width, step_width + _get_pkv_seq_len(pkv_c)), | |
| dtype=torch.bool, | |
| device=device, | |
| ) | |
| outputs_c = model( | |
| inputs_embeds=input_embeds_cond[:, -step_width:, :], | |
| past_key_values=pkv_c, | |
| use_cache=True, | |
| attention_mask=bi_attn_mask, | |
| ) | |
| pkv_c = outputs_c.past_key_values | |
| hidden_c = outputs_c.last_hidden_state[:, -step_width:] | |
| hidden_u = None | |
| pkv_u = None | |
| if guidance_scale > 1.0 and input_embeds_uncond is not None: | |
| outputs_u = model(inputs_embeds=input_embeds_uncond[:, :-step_width, :], use_cache=True) | |
| pkv_u = outputs_u.past_key_values | |
| bi_attn_mask_u = torch.ones( | |
| (input_embeds_uncond.shape[0], 1, step_width, step_width + _get_pkv_seq_len(pkv_u)), | |
| dtype=torch.bool, | |
| device=device, | |
| ) | |
| outputs_u = model( | |
| inputs_embeds=input_embeds_uncond[:, -step_width:, :], | |
| past_key_values=pkv_u, | |
| use_cache=True, | |
| attention_mask=bi_attn_mask_u, | |
| ) | |
| pkv_u = outputs_u.past_key_values | |
| hidden_u = outputs_u.last_hidden_state[:, -step_width:] | |
| out_tokens = [] | |
| step_iter = range(num_steps) | |
| if show_progress_bar: | |
| step_iter = tqdm(step_iter, total=num_steps, desc="Decoding steps") | |
| for step in step_iter: | |
| if guidance_scale > 1.0 and hidden_u is not None: | |
| h_fused = torch.cat([hidden_c, hidden_u], dim=0) | |
| else: | |
| h_fused = hidden_c | |
| pos_slice = pos_embed_for_diff[:, step * step_width : (step + 1) * step_width, :] | |
| h_fused = h_fused + pos_slice | |
| pred_latents = self.diffusion_head.sample( | |
| h_fused, | |
| num_sampling_steps=num_inference_steps, | |
| cfg=guidance_scale, | |
| generator=generator, | |
| ) | |
| curr_tokens = torch.sign(pred_latents) | |
| curr_embeds = self.projector(curr_tokens) | |
| out_tokens.append(curr_tokens[:num_images_per_prompt]) | |
| model_input = curr_embeds + pos_slice | |
| bi_attn_mask = torch.ones( | |
| (model_input.shape[0], 1, model_input.shape[1], model_input.shape[1] + _get_pkv_seq_len(pkv_c)), | |
| dtype=torch.bool, | |
| device=device, | |
| ) | |
| outputs_c = model( | |
| inputs_embeds=model_input[:num_images_per_prompt], | |
| past_key_values=pkv_c, | |
| use_cache=True, | |
| attention_mask=bi_attn_mask[:num_images_per_prompt], | |
| ) | |
| pkv_c = outputs_c.past_key_values | |
| hidden_c = outputs_c.last_hidden_state[:, -step_width:] | |
| if guidance_scale > 1.0 and hidden_u is not None and pkv_u is not None: | |
| bi_attn_mask_u = torch.ones( | |
| (model_input.shape[0], 1, model_input.shape[1], model_input.shape[1] + _get_pkv_seq_len(pkv_u)), | |
| dtype=torch.bool, | |
| device=device, | |
| ) | |
| outputs_u = model( | |
| inputs_embeds=model_input[num_images_per_prompt:], | |
| past_key_values=pkv_u, | |
| use_cache=True, | |
| attention_mask=bi_attn_mask_u[num_images_per_prompt:], | |
| ) | |
| pkv_u = outputs_u.past_key_values | |
| hidden_u = outputs_u.last_hidden_state[:, -step_width:] | |
| full_output = torch.cat(out_tokens, dim=1) | |
| return self._decode_tokens_to_image(full_output, image_size=(h, w), ps=self.ps) | |
| def __call__( | |
| self, | |
| prompt: PromptType, | |
| height: int = 1024, | |
| width: int = 1024, | |
| num_inference_steps: int = 50, | |
| guidance_scale: float = 7.5, | |
| num_images_per_prompt: int = 1, | |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | |
| output_type: str = "pil", | |
| return_dict: bool = True, | |
| show_progress_bar: bool = False, | |
| ) -> Union[ImagePipelineOutput, Tuple]: | |
| prompts = [prompt] if isinstance(prompt, str) else list(prompt) | |
| if len(prompts) == 0: | |
| raise ValueError("prompt must be a non-empty string or list of strings.") | |
| if isinstance(generator, list) and len(generator) != len(prompts): | |
| raise ValueError("When passing a list of generators, its length must equal len(prompt).") | |
| image_tensors = [] | |
| for i, prompt_text in enumerate(prompts): | |
| prompt_generator = generator[i] if isinstance(generator, list) else generator | |
| images = self._generate_single_prompt( | |
| prompt=prompt_text, | |
| height=height, | |
| width=width, | |
| num_inference_steps=num_inference_steps, | |
| guidance_scale=guidance_scale, | |
| num_images_per_prompt=num_images_per_prompt, | |
| generator=prompt_generator, | |
| show_progress_bar=show_progress_bar, | |
| ) | |
| image_tensors.append(images) | |
| images_pt = torch.cat(image_tensors, dim=0) | |
| images_pt_01 = torch.clamp((images_pt + 1.0) / 2.0, 0.0, 1.0) | |
| if output_type == "pt": | |
| output_images = images_pt_01 | |
| elif output_type == "np": | |
| output_images = images_pt_01.permute(0, 2, 3, 1).float().cpu().numpy() | |
| elif output_type == "pil": | |
| images_uint8 = ( | |
| torch.clamp(127.5 * images_pt + 128.0, 0, 255) | |
| .permute(0, 2, 3, 1) | |
| .to("cpu", dtype=torch.uint8) | |
| .numpy() | |
| ) | |
| output_images = [Image.fromarray(image) for image in images_uint8] | |
| else: | |
| raise ValueError(f"Unsupported output_type={output_type}. Expected 'pil', 'np', or 'pt'.") | |
| if not return_dict: | |
| return (output_images,) | |
| return ImagePipelineOutput(images=output_images) | |
| def generate( | |
| self, | |
| prompt: str, | |
| height: int = 1024, | |
| width: int = 1024, | |
| num_sampling_steps: int = 50, | |
| guidance_scale: float = 7.5, | |
| num_images: int = 1, | |
| seed: Optional[int] = None, | |
| ) -> List[Image.Image]: | |
| generator = None | |
| if seed is not None: | |
| device = self._execution_device_fallback() | |
| generator_device = "cuda" if device.type == "cuda" else "cpu" | |
| generator = torch.Generator(device=generator_device).manual_seed(seed) | |
| output = self( | |
| prompt=prompt, | |
| height=height, | |
| width=width, | |
| num_inference_steps=num_sampling_steps, | |
| guidance_scale=guidance_scale, | |
| num_images_per_prompt=num_images, | |
| generator=generator, | |
| output_type="pil", | |
| return_dict=True, | |
| show_progress_bar=True, | |
| ) | |
| return output.images | |