def build(self, input_shapes):
  """Creates the sub-layers of the row-wise autoregressive core.

  Args:
    input_shapes: sequence of shapes; input_shapes[1] is the conditioning
      context whose (height, width) fix the attention resolution.
  """
  height, width = input_shapes[1][1:3]
  spatial = [height, width]
  cfg = self.config
  hidden_size = cfg.hidden_size

  self.pos_embed = coltran_layers.PositionEmbed(
      axes=[1, 2], max_lengths=spatial)
  # Shifts activations one step along dimension 1 (width) so that each
  # position is conditioned only on already-generated positions.
  self.shift_right = coltran_layers.Shift(dimension=1, resolution=spatial)

  self.residual_layers = []
  self.cmlp_layers = []
  # Two normalizations per inner layer: one per residual sub-block
  # (attention + feed-forward).
  num_norms = 2 * cfg.num_inner_layers
  if self.cond_ln:
    self.layer_norms = [
        coltran_layers.ConditionalLayerNorm(
            spatial_average=self.cond_ln_sp_ave,
            sequence=self.cond_ln_seq,
            out_init=self.cond_ln_init,
            out_act=self.cond_ln_act)
        for _ in range(num_norms)
    ]
  else:
    self.layer_norms = [
        layers.LayerNormalization() for _ in range(num_norms)]

  for idx in range(cfg.num_inner_layers):
    # Future-masked self-attention restricted to a single row.
    row_att = coltran_layers.SelfAttentionND(
        hidden_size=hidden_size, num_heads=cfg.num_heads,
        mask='future', nd_block_size=[1, width],
        resolution=[height, width],
        cond_q=self.cond_att_q, cond_k=self.cond_att_k,
        cond_v=self.cond_att_v, cond_init=self.cond_att_init,
        cond_scale=self.cond_att_scale, cond_act=self.cond_att_act,
        name='mask_row_att_%d' % idx)
    # Position-wise feed-forward block.
    pointwise = tf.keras.Sequential([
        layers.Dense(units=cfg.ff_size, activation='relu'),
        layers.Dense(units=hidden_size)
    ], name='dense_%d' % idx)
    self.residual_layers.extend([row_att, pointwise])

    # Optional conditional MLP projecting the context into a per-layer
    # shift (or scale+shift for 'affine') of size hidden_size.
    if self.cond_mlp == 'shift':
      self.cmlp_layers.append(
          layers.Dense(units=hidden_size, name='shift_c_%d' % idx))
    elif self.cond_mlp == 'affine':
      self.cmlp_layers.append(
          layers.Dense(units=2 * hidden_size, name='affine_c_%d' % idx))
def test_position_embed(self):
  """PositionEmbed preserves the input shape and builds per-axis tables."""
  layer = layers.PositionEmbed(axes=[1, 2], max_lengths=[64, 32])
  batch = tf.random.uniform(shape=(8, 64, 32, 256))
  out = layer(batch)
  # Adding position embeddings must not change the activation shape.
  self.assertEqual(out.shape, (8, 64, 32, 256))
  # One embedding table per embedded axis: rank-3 for the height axis,
  # rank-2 for the width axis.
  for var in layer.variables:
    expected = (64, 1, 256) if len(var.shape) == 3 else (32, 256)
    self.assertEqual(var.shape, expected)
def build(self, input_shapes):
  """Creates the sub-layers of the outer row/column transformer.

  Args:
    input_shapes: sequence of shapes; input_shapes[0] is the input
      embedding whose (height, width, num_filters) fix the resolution
      and the feed-forward output width.
  """
  height, width, num_filters = input_shapes[0][1:]
  spatial = [height, width]
  cfg = self.config
  hidden_size = cfg.hidden_size

  self.pos_embed = coltran_layers.PositionEmbed(
      axes=[1, 2], max_lengths=spatial)

  self.residual_layers = []
  self.cmlp_layers = []
  # Four normalizations per outer layer: one for each residual sub-block
  # (row attention, row FF, column attention, column FF).
  num_norms = cfg.num_outer_layers * 4
  if self.cond_ln:
    self.layer_norms = [
        coltran_layers.ConditionalLayerNorm(
            spatial_average=self.cond_ln_sp_ave,
            sequence=self.cond_ln_seq,
            out_init=self.cond_ln_init,
            out_act=self.cond_ln_act)
        for _ in range(num_norms)
    ]
  else:
    self.layer_norms = [
        layers.LayerNormalization() for _ in range(num_norms)]

  for idx in range(cfg.num_outer_layers):
    # Unmasked self-attention within each row.
    row_att = coltran_layers.SelfAttentionND(
        hidden_size=hidden_size, num_heads=cfg.num_heads,
        nd_block_size=[1, width], resolution=[height, width],
        cond_q=self.cond_att_q, cond_k=self.cond_att_k,
        cond_v=self.cond_att_v, cond_init=self.cond_att_init,
        cond_scale=self.cond_att_scale, cond_act=self.cond_att_act,
        name='unmask_row_att_%d' % idx)
    row_ff = tf.keras.Sequential([
        layers.Dense(units=cfg.ff_size, activation='relu'),
        layers.Dense(units=num_filters)
    ], name='row_dense_%d' % idx)

    # Future-masked self-attention within each column.
    col_att = coltran_layers.SelfAttentionND(
        hidden_size=hidden_size, num_heads=cfg.num_heads,
        mask='future', nd_block_size=[height, 1],
        resolution=[height, width],
        cond_q=self.cond_att_q, cond_k=self.cond_att_k,
        cond_v=self.cond_att_v, cond_act=self.cond_att_act,
        cond_init=self.cond_att_init, cond_scale=self.cond_att_scale,
        name='mask_col_att_%d' % idx)
    col_ff = tf.keras.Sequential([
        layers.Dense(units=cfg.ff_size, activation='relu'),
        layers.Dense(units=num_filters)
    ], name='col_dense_%d' % idx)

    self.residual_layers.extend([row_att, row_ff, col_att, col_ff])

    # Conditional MLP layers: one projection per attention axis (row and
    # column), producing a shift or a scale+shift of the activations.
    if self.cond_mlp == 'shift':
      self.cmlp_layers.append(
          layers.Dense(units=hidden_size, name='shift_r_%d' % idx))
      self.cmlp_layers.append(
          layers.Dense(units=hidden_size, name='shift_c_%d' % idx))
    elif self.cond_mlp == 'affine':
      self.cmlp_layers.append(
          layers.Dense(units=2 * hidden_size, name='affine_r_%d' % idx))
      self.cmlp_layers.append(
          layers.Dense(units=2 * hidden_size, name='affine_c_%d' % idx))

  # Shifts activations one step along dimension 0 (height) so each row is
  # conditioned only on rows above it.
  self.shift_down = coltran_layers.Shift(dimension=0, resolution=spatial)