Example #1
    def build(self, input_shapes):
        context_shape = input_shapes[1]
        height, width = context_shape[1:3]
        ff_size = self.config.ff_size
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_heads
        res = [height, width]

        self.pos_embed = coltran_layers.PositionEmbed(axes=[1, 2],
                                                      max_lengths=res)
        # Shift along the width axis (dimension 1 of res).
        self.shift_right = coltran_layers.Shift(dimension=1, resolution=res)

        self.residual_layers, self.layer_norms, self.cmlp_layers = [], [], []
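        # Two normalization layers per inner layer: one for the attention sublayer, one for the feed-forward sublayer.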
        num_norms = 2 * self.config.num_inner_layers
        if self.cond_ln:
            for _ in range(num_norms):
                curr_norm = coltran_layers.ConditionalLayerNorm(
                    spatial_average=self.cond_ln_sp_ave,
                    sequence=self.cond_ln_seq,
                    out_init=self.cond_ln_init,
                    out_act=self.cond_ln_act)
                self.layer_norms.append(curr_norm)
        else:
            self.layer_norms = [
                layers.LayerNormalization() for _ in range(num_norms)
            ]

        for layer_ind in range(self.config.num_inner_layers):

            # masked row
            mask_row = coltran_layers.SelfAttentionND(
                hidden_size=hidden_size,
                num_heads=num_heads,
                mask='future',
                nd_block_size=[1, width],
                resolution=[height, width],
                cond_q=self.cond_att_q,
                cond_k=self.cond_att_k,
                cond_v=self.cond_att_v,
                cond_init=self.cond_att_init,
                cond_scale=self.cond_att_scale,
                cond_act=self.cond_att_act,
                name='mask_row_att_%d' % layer_ind)

            ff_block = tf.keras.Sequential([
                layers.Dense(units=ff_size, activation='relu'),
                layers.Dense(units=hidden_size)
            ], name='dense_%d' % layer_ind)

            self.residual_layers.append(mask_row)
            self.residual_layers.append(ff_block)

            # Conditional MLP layers.
            if self.cond_mlp == 'shift':
                shift_c = layers.Dense(units=hidden_size,
                                       name='shift_c_%d' % layer_ind)
                self.cmlp_layers.append(shift_c)
            elif self.cond_mlp == 'affine':
                aff_c = layers.Dense(units=2 * hidden_size,
                                     name='affine_c_%d' % layer_ind)
                self.cmlp_layers.append(aff_c)
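
The build method above reads only four fields from self.config (hidden_size, ff_size, num_heads, num_inner_layers). A minimal sketch of such a config, assuming an attribute-style object such as ml_collections.ConfigDict; both the library choice and the values are illustrative, not taken from the snippet:

import ml_collections

# Hypothetical config sketch; the build method above reads exactly these four fields.
config = ml_collections.ConfigDict()
config.hidden_size = 256      # channel width of the attention and residual blocks
config.ff_size = 512          # hidden width of the two-layer feed-forward blocks
config.num_heads = 4          # heads per SelfAttentionND layer
config.num_inner_layers = 3   # each inner layer adds one masked row-attention and one FF sublayer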
Example #2
 def test_column_attention(self):
   # unmasked column attention
   column = layers.SelfAttentionND(
       hidden_size=256, num_heads=4, nd_block_size=[32, 1],
       resolution=[32, 32])
   x = tf.random.uniform(shape=[4, 32, 2, 3])
   output = column(inputs=x)
   self.assertEqual(output.shape, (4, 32, 2, 256))
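
With nd_block_size=[32, 1] the layer attends over column blocks; the call leaves the batch and spatial dimensions untouched and projects the trailing channel dimension from 3 to hidden_size=256, as the shape assertion checks.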
Example #3
 def test_col_attention_mask(self):
   col_mask = layers.SelfAttentionND(
       hidden_size=256, num_heads=8, nd_block_size=[4, 1],
       resolution=[4, 4], mask="future")
   x = tf.random.uniform(shape=[4, 4, 2, 3])
   output = col_mask(inputs=x)
   self.assertEqual(output.shape, (4, 4, 2, 256))
   self.assertEqual(col_mask.attention_dim_k, -4)
   self.assertEqual(col_mask.attention_dim_q, -4)
Example #4
 def test_row_attention_mask(self):
   row_mask = layers.SelfAttentionND(
       hidden_size=256, num_heads=4, nd_block_size=[1, 32],
       resolution=[32, 32], mask="future")
   x = tf.random.uniform(shape=[4, 2, 32, 3])
   output = row_mask(inputs=x)
   self.assertEqual(row_mask.attention_dim_k, -3)
   self.assertEqual(row_mask.attention_dim_q, -3)
   self.assertEqual(output.shape, (4, 2, 32, 256))
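
Together with the masked column test above, this pins down how the two variants differ: only nd_block_size changes ([4, 1] for columns, [1, 32] for rows), and the layer reports a different attention axis in its blocked representation (-4 for column blocks, -3 for row blocks); the output check is the same channel projection to hidden_size in both cases.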
Example #5
 def test_self_attention_nd_cond_scale_false(self):
   row_mask = layers.SelfAttentionND(
       hidden_size=256, num_heads=4, nd_block_size=[1, 32],
       resolution=[32, 32], cond_q=True, cond_k=True, cond_v=True,
       cond_scale=False)
   inputs = tf.random.uniform(shape=(1, 3, 32, 32, 3))
   cond_inputs = tf.random.uniform(shape=(1, 3, 32, 32, 3))
   output = row_mask(inputs=(inputs, cond_inputs))
   self.assertEqual(output.shape, (1, 3, 32, 32, 256))
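
When cond_q, cond_k and cond_v are enabled, the layer is called with an (inputs, cond_inputs) pair instead of a single tensor; here both tensors share the shape (1, 3, 32, 32, 3), and the output again has its channel dimension replaced by hidden_size. Passing cond_scale=False exercises the conditioning path with the conditional scale disabled.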
Example #6
 def test_row_attention(self):
   # unmasked row attention
   row = layers.SelfAttentionND(
       hidden_size=256, num_heads=4, nd_block_size=[1, 32],
       resolution=[32, 32])
   x = tf.random.uniform(shape=[4, 2, 32, 3])
   output = row(inputs=x)
   self.assertEqual(row.attention_dim_q, -3)
   self.assertEqual(row.attention_dim_k, -3)
   self.assertEqual(output.shape, (4, 2, 32, 256))
Example #7
    def build(self, input_shapes):
        embed_shape = input_shapes[0]
        height, width, num_filters = embed_shape[1:]
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_heads
        ff_size = self.config.ff_size
        res = [height, width]

        self.pos_embed = coltran_layers.PositionEmbed(axes=[1, 2],
                                                      max_lengths=res)

        self.residual_layers, self.layer_norms, self.cmlp_layers = [], [], []
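        # Four normalization layers per outer layer, one per residual sublayer below.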
        num_norms = self.config.num_outer_layers * 4
        if self.cond_ln:
            for _ in range(num_norms):
                curr_norm = coltran_layers.ConditionalLayerNorm(
                    spatial_average=self.cond_ln_sp_ave,
                    sequence=self.cond_ln_seq,
                    out_init=self.cond_ln_init,
                    out_act=self.cond_ln_act)
                self.layer_norms.append(curr_norm)
        else:
            self.layer_norms = [
                layers.LayerNormalization() for _ in range(num_norms)
            ]

        for layer_ind in range(self.config.num_outer_layers):
            # unmasked row
            unmask_row = coltran_layers.SelfAttentionND(
                hidden_size=hidden_size,
                num_heads=num_heads,
                nd_block_size=[1, width],
                resolution=[height, width],
                cond_q=self.cond_att_q,
                cond_k=self.cond_att_k,
                cond_v=self.cond_att_v,
                cond_init=self.cond_att_init,
                cond_scale=self.cond_att_scale,
                cond_act=self.cond_att_act,
                name='unmask_row_att_%d' % layer_ind)

            ff_row = tf.keras.Sequential([
                layers.Dense(units=ff_size, activation='relu'),
                layers.Dense(units=num_filters)
            ], name='row_dense_%d' % layer_ind)

            # masked column
            mask_col = coltran_layers.SelfAttentionND(
                hidden_size=hidden_size,
                num_heads=num_heads,
                mask='future',
                nd_block_size=[height, 1],
                resolution=[height, width],
                cond_q=self.cond_att_q,
                cond_k=self.cond_att_k,
                cond_v=self.cond_att_v,
                cond_act=self.cond_att_act,
                cond_init=self.cond_att_init,
                cond_scale=self.cond_att_scale,
                name='mask_col_att_%d' % layer_ind)

            ff_col = tf.keras.Sequential([
                layers.Dense(units=ff_size, activation='relu'),
                layers.Dense(units=num_filters)
            ], name='col_dense_%d' % layer_ind)

            self.residual_layers.append(unmask_row)
            self.residual_layers.append(ff_row)
            self.residual_layers.append(mask_col)
            self.residual_layers.append(ff_col)

            # Conditional MLP layers.
            if self.cond_mlp == 'shift':
                shift_r = layers.Dense(units=hidden_size,
                                       name='shift_r_%d' % layer_ind)
                shift_c = layers.Dense(units=hidden_size,
                                       name='shift_c_%d' % layer_ind)
                self.cmlp_layers.append(shift_r)
                self.cmlp_layers.append(shift_c)
            elif self.cond_mlp == 'affine':
                aff_r = layers.Dense(units=2 * hidden_size,
                                     name='affine_r_%d' % layer_ind)
                aff_c = layers.Dense(units=2 * hidden_size,
                                     name='affine_c_%d' % layer_ind)
                self.cmlp_layers.append(aff_r)
                self.cmlp_layers.append(aff_c)

        self.shift_down = coltran_layers.Shift(dimension=0, resolution=res)
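
Compared with Example #1, each outer layer here contributes four residual sublayers (unmasked row attention, row feed-forward, masked column attention, column feed-forward), which is why num_norms is 4 * num_outer_layers and why the feed-forward blocks project back to num_filters taken from the embedding shape rather than to hidden_size. The trailing Shift(dimension=0) shifts activations along the height axis, presumably so that the masked column attention conditions each position only on the rows above it.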