def __init__(
    self,
    *,
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    num_visual_tokens = 512,
    visual_enc_depth = 6,
    visual_heads = 8,
    visual_image_size = 256,
    visual_patch_size = 32,
    channels = 3
):
    super().__init__()

    # text encoder
    self.text_emb = nn.Embedding(num_text_tokens, dim_text)
    self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
    self.text_transformer = Transformer(causal = False, seq_len = text_seq_len, dim = dim_text, depth = text_enc_depth, heads = text_heads)
    self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)

    # image encoder over flattened, linearly projected patches (ViT-style)
    assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
    num_patches = (visual_image_size // visual_patch_size) ** 2
    patch_dim = channels * visual_patch_size ** 2

    self.visual_patch_size = visual_patch_size
    self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
    self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
    self.visual_transformer = Transformer(causal = False, seq_len = num_patches, dim = dim_image, depth = visual_enc_depth, heads = visual_heads)
    self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)

    # learned temperature for scaling the contrastive similarity logits
    self.temperature = nn.Parameter(torch.tensor(1.))
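# Standalone sketch (an assumption, not taken from this source): it mirrors
# how a patch embedding like `to_visual_embedding` above is typically applied
# in a forward pass: flatten non-overlapping patches with einops, then
# project each one to dim_image.
import torch
from torch import nn
from einops import rearrange

image = torch.randn(4, 3, 256, 256)   # (batch, channels, height, width)
p = 32                                # visual_patch_size
patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
to_visual_embedding = nn.Linear(3 * p * p, 512)   # patch_dim -> dim_image
tokens = to_visual_embedding(patches)             # (4, 64, 512)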
def __init__(
    self,
    *,
    dim,
    vae,
    num_text_tokens=10000,
    text_seq_len=256,
    depth,
    heads=8,
    dim_head=64,
    reversible=False,
    attn_dropout=0.,
    ff_dropout=0.,
    sparse_attn=False,
    attn_types=None,
    loss_img_weight=7,
):
    super().__init__()
    assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024)), \
        'vae must be an instance of DiscreteVAE, OpenAIDiscreteVAE, or VQGanVAE1024'

    image_size = vae.image_size
    num_image_tokens = vae.num_tokens
    image_fmap_size = (vae.image_size // (2**vae.num_layers))
    image_seq_len = image_fmap_size**2

    # reserve a unique padding token for each text position (text_seq_len of them)
    num_text_tokens = num_text_tokens + text_seq_len

    self.text_emb = nn.Embedding(num_text_tokens, dim)
    self.image_emb = nn.Embedding(num_image_tokens, dim)

    self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim)  # +1 for <bos>
    self.image_pos_emb = AxialPositionalEmbedding(
        dim, axial_shape=(image_fmap_size, image_fmap_size))

    self.num_text_tokens = num_text_tokens  # for offsetting logits index and calculating cross entropy loss
    self.num_image_tokens = num_image_tokens

    self.text_seq_len = text_seq_len
    self.image_seq_len = image_seq_len

    seq_len = text_seq_len + image_seq_len
    total_tokens = num_text_tokens + num_image_tokens
    self.total_tokens = total_tokens
    self.total_seq_len = seq_len

    self.vae = vae

    self.transformer = Transformer(dim=dim,
                                   causal=True,
                                   seq_len=seq_len,
                                   depth=depth,
                                   heads=heads,
                                   dim_head=dim_head,
                                   reversible=reversible,
                                   attn_dropout=attn_dropout,
                                   ff_dropout=ff_dropout,
                                   attn_types=attn_types,
                                   image_fmap_size=image_fmap_size,
                                   sparse_attn=sparse_attn)

    self.to_logits = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, self.total_tokens),
    )

    # build a static mask so text positions can only predict text tokens and
    # image positions can only predict image tokens
    seq_range = torch.arange(seq_len)
    logits_range = torch.arange(total_tokens)

    seq_range = rearrange(seq_range, 'n -> () n ()')
    logits_range = rearrange(logits_range, 'd -> () () d')

    logits_mask = (((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
                   ((seq_range < text_seq_len) & (logits_range >= num_text_tokens)))

    self.register_buffer('logits_mask', logits_mask, persistent=False)

    self.loss_img_weight = loss_img_weight
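# Standalone toy reconstruction of the logits-mask logic above: text positions
# are barred from image-token logits and image positions from text-token
# logits. Applying the mask with a large negative fill before the softmax is
# an assumption that mirrors common practice, not a quote from this source.
import torch
from einops import rearrange

text_seq_len, num_text_tokens = 4, 6
image_seq_len, num_image_tokens = 3, 5
seq_len = text_seq_len + image_seq_len
total_tokens = num_text_tokens + num_image_tokens

seq_range = rearrange(torch.arange(seq_len), 'n -> () n ()')
logits_range = rearrange(torch.arange(total_tokens), 'd -> () () d')

logits_mask = (((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
               ((seq_range < text_seq_len) & (logits_range >= num_text_tokens)))

logits = torch.randn(1, seq_len, total_tokens)
logits.masked_fill_(logits_mask, -torch.finfo(logits.dtype).max)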
def __init__(
    self,
    *,
    dim,
    vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth,
    heads = 8,
    dim_head = 64,
    reversible = False,
    attn_dropout = 0.,
    ff_dropout = 0.,
    sparse_attn = False
):
    super().__init__()
    assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'

    image_size = vae.image_size
    num_image_tokens = vae.num_tokens
    image_fmap_size = image_size // (2 ** vae.num_layers)
    image_seq_len = image_fmap_size ** 2

    self.text_emb = nn.Embedding(num_text_tokens, dim)
    self.text_pos_emb = nn.Embedding(text_seq_len, dim)
    # the axial shape must match the VAE token grid, not the pixel resolution
    self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size))

    self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
    self.num_image_tokens = num_image_tokens

    self.text_seq_len = text_seq_len
    self.image_seq_len = image_seq_len

    seq_len = text_seq_len + image_seq_len
    total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
    self.total_tokens = total_tokens

    self.vae = vae
    # reuse the VAE codebook as the image token embedding (the assert above
    # guarantees the vae is present, so no separate nn.Embedding is needed)
    self.image_emb = vae.codebook

    self.transformer = Transformer(
        dim = dim,
        causal = True,
        seq_len = seq_len,
        depth = depth,
        heads = heads,
        dim_head = dim_head,
        reversible = reversible,
        attn_dropout = attn_dropout,
        ff_dropout = ff_dropout,
        sparse_attn = sparse_attn
    )

    self.to_logits = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, self.total_tokens),
    )

    # mask out invalid position / token pairings; the boundaries are shifted
    # back by one since each position predicts the next token, and EOS (the
    # last logit index) is only allowed at the final position
    seq_range = torch.arange(seq_len)
    logits_range = torch.arange(total_tokens)

    seq_range = rearrange(seq_range, 'n -> () n ()')
    logits_range = rearrange(logits_range, 'd -> () () d')

    logits_mask = (
        ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
        ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
        ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
    )

    self.register_buffer('logits_mask', logits_mask)
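# Worked shape bookkeeping for the constructor above (illustrative numbers,
# assuming a VAE with image_size = 256, num_layers = 3 and an 8192-entry
# codebook; none of these values are fixed by the source):
image_size, num_layers = 256, 3
text_seq_len, num_text_tokens, num_image_tokens = 256, 10000, 8192
image_fmap_size = image_size // (2 ** num_layers)      # 256 // 8 = 32
image_seq_len = image_fmap_size ** 2                   # 32 * 32 = 1024
seq_len = text_seq_len + image_seq_len                 # 256 + 1024 = 1280
total_tokens = num_text_tokens + num_image_tokens + 1  # +1 for EOS = 18193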