Example #1
    # Assumes the module-level imports used alongside these snippets:
    #   import tensorflow as tf
    #   from tensorflow.keras import layers
    #   from coltran.models import layers as coltran_layers
    def build(self, input_shapes):
        # The spatial grid size comes from the context tensor's shape.
        context_shape = input_shapes[1]
        height, width = context_shape[1:3]
        ff_size = self.config.ff_size
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_heads
        res = [height, width]

        self.pos_embed = coltran_layers.PositionEmbed(axes=[1, 2],
                                                      max_lengths=res)
        # Shift right by one pixel so each position conditions only on pixels
        # to its left within the row.
        self.shift_right = coltran_layers.Shift(dimension=1, resolution=res)

        self.residual_layers, self.layer_norms, self.cmlp_layers = [], [], []
        # Two layer norms per inner layer: one for the attention block and one
        # for the feed-forward block.
        num_norms = 2 * self.config.num_inner_layers
        if self.cond_ln:
            for _ in range(num_norms):
                curr_norm = coltran_layers.ConditionalLayerNorm(
                    spatial_average=self.cond_ln_sp_ave,
                    sequence=self.cond_ln_seq,
                    out_init=self.cond_ln_init,
                    out_act=self.cond_ln_act)
                self.layer_norms.append(curr_norm)
        else:
            self.layer_norms = [
                layers.LayerNormalization() for _ in range(num_norms)
            ]

        for layer_ind in range(self.config.num_inner_layers):
            # Masked row attention: each pixel attends only to earlier pixels
            # in its own row.
            mask_row = coltran_layers.SelfAttentionND(
                hidden_size=hidden_size,
                num_heads=num_heads,
                mask='future',
                nd_block_size=[1, width],
                resolution=[height, width],
                cond_q=self.cond_att_q,
                cond_k=self.cond_att_k,
                cond_v=self.cond_att_v,
                cond_init=self.cond_att_init,
                cond_scale=self.cond_att_scale,
                cond_act=self.cond_att_act,
                name='mask_row_att_%d' % layer_ind)

            ff_block = tf.keras.Sequential([
                layers.Dense(units=ff_size, activation='relu'),
                layers.Dense(units=hidden_size),
            ], name='dense_%d' % layer_ind)

            self.residual_layers.append(mask_row)
            self.residual_layers.append(ff_block)

            # Conditional MLP layers.
            if self.cond_mlp == 'shift':
                shift_c = layers.Dense(units=hidden_size,
                                       name='shift_c_%d' % layer_ind)
                self.cmlp_layers.append(shift_c)
            elif self.cond_mlp == 'affine':
                aff_c = layers.Dense(units=2 * hidden_size,
                                     name='affine_c_%d' % layer_ind)
                self.cmlp_layers.append(aff_c)
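
The `build` above only constructs the sub-layers; a companion `call` would then run them as residual blocks. Below is a minimal, self-contained sketch of that pattern, with a plain `Dense` standing in for `coltran_layers.SelfAttentionND` (the stand-in layers and shapes here are illustrative assumptions, not ColTran's actual `call`):

    import tensorflow as tf
    from tensorflow.keras import layers

    hidden_size, ff_size = 128, 512
    # One "inner layer" worth of residual blocks: an attention stand-in plus
    # a feed-forward block, each paired with one LayerNormalization (which is
    # why num_norms = 2 * num_inner_layers above).
    residual_layers = [
        layers.Dense(hidden_size),
        tf.keras.Sequential([layers.Dense(ff_size, activation='relu'),
                             layers.Dense(hidden_size)]),
    ]
    layer_norms = [layers.LayerNormalization() for _ in residual_layers]

    x = tf.random.uniform((8, 32, 32, hidden_size))
    for layer, norm in zip(residual_layers, layer_norms):
        x = norm(x + layer(x))  # residual connection followed by layer norm
    print(x.shape)  # (8, 32, 32, 128)
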
Example #2
    def test_conditional_layer_norm(self, spatial_average, sequence):
        # Here `layers` refers to the coltran layers module, not tf.keras.
        cond_layer_norm = layers.ConditionalLayerNorm(
            spatial_average=spatial_average, sequence=sequence)
        x = tf.random.uniform(shape=(8, 32, 32, 128))
        cond_inputs = tf.random.uniform(shape=(8, 32, 32, 128))
        out = cond_layer_norm(inputs=(x, cond_inputs))
        self.assertEqual(out.shape, (8, 32, 32, 128))
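
The test method above receives `spatial_average` and `sequence` as arguments, which suggests it runs under a parameterized test harness. A minimal sketch of that harness using absl's `parameterized` (with a plain `LayerNormalization` and a hypothetical parameter grid, since the values `ConditionalLayerNorm` accepts are not shown here):

    from absl.testing import parameterized
    import tensorflow as tf
    from tensorflow.keras import layers


    class LayerNormTest(tf.test.TestCase, parameterized.TestCase):

        # Hypothetical grid; the real test would enumerate the
        # spatial_average/sequence options ConditionalLayerNorm supports.
        @parameterized.parameters((True,), (False,))
        def test_layer_norm_shape(self, center):
            norm = layers.LayerNormalization(center=center)
            x = tf.random.uniform(shape=(8, 32, 32, 128))
            self.assertEqual(norm(x).shape, (8, 32, 32, 128))


    if __name__ == '__main__':
        tf.test.main()
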
Example #3
    # Assumes the same module-level imports as Example #1.
    def build(self, input_shapes):
        # The spatial grid size and channel count come from the embedding.
        embed_shape = input_shapes[0]
        height, width, num_filters = embed_shape[1:]
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_heads
        ff_size = self.config.ff_size
        res = [height, width]

        self.pos_embed = coltran_layers.PositionEmbed(axes=[1, 2],
                                                      max_lengths=res)

        self.residual_layers, self.layer_norms, self.cmlp_layers = [], [], []
        # Four layer norms per outer layer: row attention, row feed-forward,
        # column attention, and column feed-forward.
        num_norms = self.config.num_outer_layers * 4
        if self.cond_ln:
            for _ in range(num_norms):
                curr_norm = coltran_layers.ConditionalLayerNorm(
                    spatial_average=self.cond_ln_sp_ave,
                    sequence=self.cond_ln_seq,
                    out_init=self.cond_ln_init,
                    out_act=self.cond_ln_act)
                self.layer_norms.append(curr_norm)
        else:
            self.layer_norms = [
                layers.LayerNormalization() for _ in range(num_norms)
            ]

        for layer_ind in range(self.config.num_outer_layers):
            # Unmasked row attention: each pixel attends to its whole row.
            unmask_row = coltran_layers.SelfAttentionND(
                hidden_size=hidden_size,
                num_heads=num_heads,
                nd_block_size=[1, width],
                resolution=[height, width],
                cond_q=self.cond_att_q,
                cond_k=self.cond_att_k,
                cond_v=self.cond_att_v,
                cond_init=self.cond_att_init,
                cond_scale=self.cond_att_scale,
                cond_act=self.cond_att_act,
                name='unmask_row_att_%d' % layer_ind)

            ff_row = tf.keras.Sequential([
                layers.Dense(units=ff_size, activation='relu'),
                layers.Dense(units=num_filters),
            ], name='row_dense_%d' % layer_ind)

            # Masked column attention: each pixel attends only to earlier
            # pixels in its own column.
            mask_col = coltran_layers.SelfAttentionND(
                hidden_size=hidden_size,
                num_heads=num_heads,
                mask='future',
                nd_block_size=[height, 1],
                resolution=[height, width],
                cond_q=self.cond_att_q,
                cond_k=self.cond_att_k,
                cond_v=self.cond_att_v,
                cond_act=self.cond_att_act,
                cond_init=self.cond_att_init,
                cond_scale=self.cond_att_scale,
                name='mask_col_att_%d' % layer_ind)

            ff_col = tf.keras.Sequential([
                layers.Dense(units=ff_size, activation='relu'),
                layers.Dense(units=num_filters),
            ], name='col_dense_%d' % layer_ind)

            self.residual_layers.append(unmask_row)
            self.residual_layers.append(ff_row)
            self.residual_layers.append(mask_col)
            self.residual_layers.append(ff_col)

            # Conditional MLP layers.
            if self.cond_mlp == 'shift':
                shift_r = layers.Dense(units=hidden_size,
                                       name='shift_r_%d' % layer_ind)
                shift_c = layers.Dense(units=hidden_size,
                                       name='shift_c_%d' % layer_ind)
                self.cmlp_layers.append(shift_r)
                self.cmlp_layers.append(shift_c)
            elif self.cond_mlp == 'affine':
                aff_r = layers.Dense(units=2 * hidden_size,
                                     name='affine_r_%d' % layer_ind)
                aff_c = layers.Dense(units=2 * hidden_size,
                                     name='affine_c_%d' % layer_ind)
                self.cmlp_layers.append(aff_r)
                self.cmlp_layers.append(aff_c)

        # Shift down by one row so each output row conditions only on the
        # rows above it.
        self.shift_down = coltran_layers.Shift(dimension=0, resolution=res)
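
Example #3 interleaves unmasked row attention with masked column attention, an axial attention pattern. Below is a toy sketch of one row-then-column pass, with `tf.keras.layers.MultiHeadAttention` standing in for `coltran_layers.SelfAttentionND` (the stand-in layer and shapes are illustrative assumptions):

    import tensorflow as tf
    from tensorflow.keras import layers

    b, h, w, d = 2, 8, 8, 32
    x = tf.random.uniform((b, h, w, d))
    row_att = layers.MultiHeadAttention(num_heads=4, key_dim=8)
    col_att = layers.MultiHeadAttention(num_heads=4, key_dim=8)

    # Unmasked row pass: treat every row as an independent sequence.
    rows = tf.reshape(x, (b * h, w, d))
    x = tf.reshape(row_att(rows, rows), (b, h, w, d))

    # Masked column pass: transpose so columns become sequences, then apply a
    # lower-triangular mask so each pixel only sees the pixels above it.
    cols = tf.reshape(tf.transpose(x, (0, 2, 1, 3)), (b * w, h, d))
    causal = tf.cast(
        tf.linalg.band_part(tf.ones((b * w, h, h)), -1, 0), tf.bool)
    cols = col_att(cols, cols, attention_mask=causal)
    x = tf.transpose(tf.reshape(cols, (b, w, h, d)), (0, 2, 1, 3))
    print(x.shape)  # (2, 8, 8, 32)
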