def build(self, input_shapes):
  """Builds the masked row-attention decoder stack.

  Args:
    input_shapes: pair of shapes; input_shapes[1] is the conditioning
      context, assumed (batch, height, width, ...) — TODO confirm with
      callers.

  Creates, in order: position embedding, right-shift layer, the per-layer
  norms (conditional or plain), and for every inner layer a future-masked
  row self-attention followed by a two-layer feed-forward block, plus the
  optional conditional-MLP projections. Creation order is kept stable so
  Keras auto-naming / checkpoints are unaffected.
  """
  height, width = input_shapes[1][1:3]
  cfg = self.config
  ff_size, hidden_size, num_heads = cfg.ff_size, cfg.hidden_size, cfg.num_heads
  res = [height, width]

  self.pos_embed = coltran_layers.PositionEmbed(axes=[1, 2], max_lengths=res)
  self.shift_right = coltran_layers.Shift(dimension=1, resolution=res)

  self.residual_layers = []
  self.cmlp_layers = []
  # Two norms per inner layer: one before attention, one before the MLP.
  num_norms = 2 * cfg.num_inner_layers
  if self.cond_ln:
    self.layer_norms = [
        coltran_layers.ConditionalLayerNorm(
            spatial_average=self.cond_ln_sp_ave,
            sequence=self.cond_ln_seq,
            out_init=self.cond_ln_init,
            out_act=self.cond_ln_act)
        for _ in range(num_norms)
    ]
  else:
    self.layer_norms = [
        layers.LayerNormalization() for _ in range(num_norms)]

  for idx in range(cfg.num_inner_layers):
    # Future-masked attention within a single row (nd_block_size=[1, width]).
    row_attention = coltran_layers.SelfAttentionND(
        hidden_size=hidden_size, num_heads=num_heads, mask='future',
        nd_block_size=[1, width], resolution=[height, width],
        cond_q=self.cond_att_q, cond_k=self.cond_att_k,
        cond_v=self.cond_att_v, cond_init=self.cond_att_init,
        cond_scale=self.cond_att_scale, cond_act=self.cond_att_act,
        name='mask_row_att_%d' % idx)
    feed_forward = tf.keras.Sequential([
        layers.Dense(units=ff_size, activation='relu'),
        layers.Dense(units=hidden_size)
    ], name='dense_%d' % idx)
    self.residual_layers.extend([row_attention, feed_forward])

    # Conditional MLP: shift-only or affine (scale + shift) projection.
    if self.cond_mlp == 'shift':
      self.cmlp_layers.append(
          layers.Dense(units=hidden_size, name='shift_c_%d' % idx))
    elif self.cond_mlp == 'affine':
      self.cmlp_layers.append(
          layers.Dense(units=2 * hidden_size, name='affine_c_%d' % idx))
def build(self, input_shape):
  """Creates the sub-layers of the conditional layer norm.

  Args:
    input_shape: pair of shapes; input_shape[0] is the normalized input,
      whose last three dims are taken as (height, width, features).
  """
  height, width, n_features = input_shape[0][-3:]
  # Frozen normalization; the conditioning path supplies scale and shift.
  self.layer_norm = layers.LayerNormalization(
      trainable=False, name='normalize')
  if self.spatial_average == 'learnable':
    # One learnable weight per spatial position, broadcast over batch
    # and channels; initialized to ones (i.e. a plain average).
    self.spatial_weights = self.add_weight(
        name='spatial_average',
        shape=(1, height, width, 1),
        initializer=tf.keras.initializers.Ones())
  # Projects the conditioning signal to 2 * features units — presumably
  # a concatenated (scale, shift) pair; confirm against call().
  self.channel_dense = layers.Dense(
      units=2 * n_features, kernel_initializer=self.out_init)
  super(ConditionalLayerNorm, self).build(input_shape)
def build(self, input_shape): # encoder graph self.encoder = core.GrayScaleEncoder(self.enc_cfg) if self.is_parallel_loss: self.parallel_dense = layers.Dense(units=self.num_symbols, name='parallel_logits', use_bias=False) # decoder graph: outer decoder -> inner decoder -> logits. self.pixel_embed_layer = layers.Dense(units=self.hidden_size, use_bias=False) self.outer_decoder = core.OuterDecoder(self.dec_cfg) self.inner_decoder = core.InnerDecoder(self.dec_cfg) self.final_dense = layers.Dense(units=self.num_symbols, name='auto_logits') self.final_norm = layers.LayerNormalization()
def build(self, input_shapes):
  """Builds the axial (row/column) self-attention encoder stack.

  Args:
    input_shapes: a single tensor shape; dims 1 and 2 are taken as
      (height, width).

  Each encoder layer contributes four residual sub-layers, appended in a
  fixed order: unmasked row attention, its feed-forward block, unmasked
  column attention, its feed-forward block. Four layer norms per layer
  match that count.
  """
  cfg = self.config
  ff_size, hidden_size, num_heads = cfg.ff_size, cfg.hidden_size, cfg.num_heads
  height, width = input_shapes[1:3]

  self.pos_embed = PositionEmbed(axes=[1, 2], max_lengths=[height, width])

  self.residual_layers = []
  num_norms = 4 * cfg.num_encoder_layers
  self.layer_norms = [layers.LayerNormalization() for _ in range(num_norms)]

  def make_ff():
    # Fresh two-layer MLP (expand to ff_size, project back to hidden_size).
    return tf.keras.Sequential([
        layers.Dense(units=ff_size, activation='relu'),
        layers.Dense(units=hidden_size)
    ])

  for _ in range(cfg.num_encoder_layers):
    # Unmasked attention across each row.
    row_attention = SelfAttentionND(
        hidden_size=hidden_size, num_heads=num_heads,
        nd_block_size=[1, width], resolution=[height, width])
    # Unmasked attention across each column.
    col_attention = SelfAttentionND(
        hidden_size=hidden_size, num_heads=num_heads,
        nd_block_size=[height, 1], resolution=[height, width])
    self.residual_layers.extend(
        [row_attention, make_ff(), col_attention, make_ff()])
def build(self, input_shapes):
  """Builds the outer decoder: unmasked rows, masked columns.

  Args:
    input_shapes: pair of shapes; input_shapes[0] is the embedding,
      assumed (batch, height, width, num_filters) — TODO confirm.

  Each outer layer contributes four residual sub-layers in a fixed order
  (row attention, row MLP, masked column attention, column MLP), with
  four matching layer norms. A down-shift layer is created last. The
  creation order is kept identical so Keras naming / checkpoints are
  unaffected.
  """
  embed_shape = input_shapes[0]
  height, width, num_filters = embed_shape[1:]
  cfg = self.config
  hidden_size, num_heads, ff_size = cfg.hidden_size, cfg.num_heads, cfg.ff_size
  res = [height, width]

  self.pos_embed = coltran_layers.PositionEmbed(axes=[1, 2], max_lengths=res)

  self.residual_layers = []
  self.cmlp_layers = []
  num_norms = cfg.num_outer_layers * 4
  if self.cond_ln:
    self.layer_norms = [
        coltran_layers.ConditionalLayerNorm(
            spatial_average=self.cond_ln_sp_ave,
            sequence=self.cond_ln_seq,
            out_init=self.cond_ln_init,
            out_act=self.cond_ln_act)
        for _ in range(num_norms)
    ]
  else:
    self.layer_norms = [
        layers.LayerNormalization() for _ in range(num_norms)]

  for idx in range(cfg.num_outer_layers):
    # Row attention sees the whole row — no mask.
    row_attention = coltran_layers.SelfAttentionND(
        hidden_size=hidden_size, num_heads=num_heads,
        nd_block_size=[1, width], resolution=[height, width],
        cond_q=self.cond_att_q, cond_k=self.cond_att_k,
        cond_v=self.cond_att_v, cond_init=self.cond_att_init,
        cond_scale=self.cond_att_scale, cond_act=self.cond_att_act,
        name='unmask_row_att_%d' % idx)
    row_mlp = tf.keras.Sequential([
        layers.Dense(units=ff_size, activation='relu'),
        layers.Dense(units=num_filters)
    ], name='row_dense_%d' % idx)

    # Column attention is future-masked to keep the autoregressive order.
    col_attention = coltran_layers.SelfAttentionND(
        hidden_size=hidden_size, num_heads=num_heads, mask='future',
        nd_block_size=[height, 1], resolution=[height, width],
        cond_q=self.cond_att_q, cond_k=self.cond_att_k,
        cond_v=self.cond_att_v, cond_act=self.cond_att_act,
        cond_init=self.cond_att_init, cond_scale=self.cond_att_scale,
        name='mask_col_att_%d' % idx)
    col_mlp = tf.keras.Sequential([
        layers.Dense(units=ff_size, activation='relu'),
        layers.Dense(units=num_filters)
    ], name='col_dense_%d' % idx)

    self.residual_layers.extend(
        [row_attention, row_mlp, col_attention, col_mlp])

    # Conditional MLP layers: one projection each for the row and the
    # column branch (shift-only or affine).
    if self.cond_mlp == 'shift':
      self.cmlp_layers.append(
          layers.Dense(units=hidden_size, name='shift_r_%d' % idx))
      self.cmlp_layers.append(
          layers.Dense(units=hidden_size, name='shift_c_%d' % idx))
    elif self.cond_mlp == 'affine':
      self.cmlp_layers.append(
          layers.Dense(units=2 * hidden_size, name='affine_r_%d' % idx))
      self.cmlp_layers.append(
          layers.Dense(units=2 * hidden_size, name='affine_c_%d' % idx))

  self.shift_down = coltran_layers.Shift(dimension=0, resolution=res)