Example #1
    def build(self, input_shapes):
        context_shape = input_shapes[1]
        height, width = context_shape[1:3]
        ff_size = self.config.ff_size
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_heads
        res = [height, width]

        self.pos_embed = coltran_layers.PositionEmbed(axes=[1, 2],
                                                      max_lengths=res)
        # Shift each row right by one pixel so position (i, j) only sees
        # previously generated pixels within its row.
        self.shift_right = coltran_layers.Shift(dimension=1, resolution=res)

        self.residual_layers, self.layer_norms, self.cmlp_layers = [], [], []
        num_norms = 2 * self.config.num_inner_layers
        if self.cond_ln:
            for _ in range(num_norms):
                curr_norm = coltran_layers.ConditionalLayerNorm(
                    spatial_average=self.cond_ln_sp_ave,
                    sequence=self.cond_ln_seq,
                    out_init=self.cond_ln_init,
                    out_act=self.cond_ln_act)
                self.layer_norms.append(curr_norm)
        else:
            self.layer_norms = [
                layers.LayerNormalization() for _ in range(num_norms)
            ]

        for layer_ind in range(self.config.num_inner_layers):

            # Masked row self-attention: each pixel attends within its row,
            # with a 'future' mask for autoregressive generation.
            mask_row = coltran_layers.SelfAttentionND(
                hidden_size=hidden_size,
                num_heads=num_heads,
                mask='future',
                nd_block_size=[1, width],
                resolution=[height, width],
                cond_q=self.cond_att_q,
                cond_k=self.cond_att_k,
                cond_v=self.cond_att_v,
                cond_init=self.cond_att_init,
                cond_scale=self.cond_att_scale,
                cond_act=self.cond_att_act,
                name='mask_row_att_%d' % layer_ind)

            # Pointwise feed-forward block applied after the attention layer.
            ff_block = tf.keras.Sequential([
                layers.Dense(units=ff_size, activation='relu'),
                layers.Dense(units=hidden_size),
            ], name='dense_%d' % layer_ind)

            self.residual_layers.append(mask_row)
            self.residual_layers.append(ff_block)

            # Conditional MLP layers.
            if self.cond_mlp == 'shift':
                shift_c = layers.Dense(units=hidden_size,
                                       name='shift_c_%d' % layer_ind)
                self.cmlp_layers.append(shift_c)
            elif self.cond_mlp == 'affine':
                aff_c = layers.Dense(units=2 * hidden_size,
                                     name='affine_c_%d' % layer_ind)
                self.cmlp_layers.append(aff_c)
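
All of the build() methods in these examples follow Keras's deferred-build contract: Keras invokes build() once, with the shapes of the first inputs, and only then are weights and sublayers created. A minimal, self-contained sketch of that lifecycle (the LazyBlock class here is hypothetical, not from the source):

import tensorflow as tf

class LazyBlock(tf.keras.layers.Layer):
    def build(self, input_shape):
        # Weights and sublayers are created here, once the input
        # dimensionality is known.
        self.dense = tf.keras.layers.Dense(units=input_shape[-1])

    def call(self, x):
        return self.dense(x)

block = LazyBlock()
x = tf.zeros((2, 8, 8, 64))
y = block(x)  # first call triggers build() with the shape (2, 8, 8, 64)
assert y.shape == x.shape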
Example #2
    def build(self, input_shape):
        x_shape = input_shape[0]
        height, width, features = x_shape[-3:]
        # Parameter-free normalization: the scale and shift come from the
        # conditioning pathway below rather than trained gamma/beta.
        self.layer_norm = layers.LayerNormalization(trainable=False,
                                                    name='normalize')

        if self.spatial_average == 'learnable':
            self.spatial_weights = self.add_weight(
                name='spatial_average',
                shape=(1, height, width, 1),
                initializer=tf.keras.initializers.Ones())
        self.channel_dense = layers.Dense(units=2 * features,
                                          kernel_initializer=self.out_init)
        super(ConditionalLayerNorm, self).build(input_shape)
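
The weights created here (a frozen LayerNormalization, optional learnable spatial weights, and a Dense producing 2 * features outputs) suggest the usual conditional-normalization pattern: pool the conditioning signal spatially, project it to a per-channel scale and shift, and apply both to the normalized input. A hedged sketch of that pattern (this is not ConditionalLayerNorm's actual call(); coltran's combination order and options differ):

import tensorflow as tf

def cond_layer_norm_sketch(x, cond, layer_norm, channel_dense,
                           spatial_weights=None):
    # Optionally re-weight spatial positions before pooling.
    if spatial_weights is not None:
        cond = cond * spatial_weights
    # Spatial average: one conditioning vector per channel.
    cond = tf.reduce_mean(cond, axis=(1, 2), keepdims=True)
    # Project to per-channel (scale, shift); channel_dense has
    # 2 * features units, so the split yields two feature-sized tensors.
    scale, shift = tf.split(channel_dense(cond), 2, axis=-1)
    return layer_norm(x) * scale + shift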
Example #3
    def build(self, input_shape):
        # encoder graph
        self.encoder = core.GrayScaleEncoder(self.enc_cfg)
        if self.is_parallel_loss:
            self.parallel_dense = layers.Dense(units=self.num_symbols,
                                               name='parallel_logits',
                                               use_bias=False)

        # decoder graph: outer decoder -> inner decoder -> logits.
        self.pixel_embed_layer = layers.Dense(units=self.hidden_size,
                                              use_bias=False)
        self.outer_decoder = core.OuterDecoder(self.dec_cfg)
        self.inner_decoder = core.InnerDecoder(self.dec_cfg)
        self.final_dense = layers.Dense(units=self.num_symbols,
                                        name='auto_logits')
        self.final_norm = layers.LayerNormalization()
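
The comment in this build() gives the dataflow: grayscale encoder, then outer decoder, then inner decoder, then normalized logits. A hedged sketch of a forward pass wired that way (argument names are illustrative; the real call() in coltran threads additional arguments and the optional parallel head):

def forward_sketch(self, gray, targets):
    z = self.encoder(gray)                      # context from the grayscale image
    embedded = self.pixel_embed_layer(targets)  # embed target pixel symbols
    row_context = self.outer_decoder([embedded, z])
    activations = self.inner_decoder([embedded, row_context, z])
    return self.final_dense(self.final_norm(activations))  # per-pixel logits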
Example #4
    def build(self, input_shapes):
        ff_size, hidden_size = self.config.ff_size, self.config.hidden_size
        num_heads = self.config.num_heads
        height, width = input_shapes[1:3]

        self.pos_embed = PositionEmbed(axes=[1, 2],
                                       max_lengths=[height, width])

        self.residual_layers = []
        num_norms = 4 * self.config.num_encoder_layers
        self.layer_norms = [
            layers.LayerNormalization() for _ in range(num_norms)
        ]

        for _ in range(self.config.num_encoder_layers):
            # unmasked row
            unmask_row = SelfAttentionND(hidden_size=hidden_size,
                                         num_heads=num_heads,
                                         nd_block_size=[1, width],
                                         resolution=[height, width])

            ff_row = tf.keras.Sequential([
                layers.Dense(units=ff_size, activation='relu'),
                layers.Dense(units=hidden_size)
            ])

            # unmasked column
            unmask_col = SelfAttentionND(hidden_size=hidden_size,
                                         num_heads=num_heads,
                                         nd_block_size=[height, 1],
                                         resolution=[height, width])

            ff_col = tf.keras.Sequential([
                layers.Dense(units=ff_size, activation='relu'),
                layers.Dense(units=hidden_size)
            ])

            self.residual_layers.append(unmask_row)
            self.residual_layers.append(ff_row)
            self.residual_layers.append(unmask_col)
            self.residual_layers.append(ff_col)
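
Here num_norms = 4 * num_encoder_layers matches the four residual sublayers appended per iteration (row attention, row feed-forward, column attention, column feed-forward), i.e. one LayerNormalization per sublayer. A hedged sketch of the corresponding post-norm residual loop (coltran's actual call() also handles dropout and training flags):

def encode_sketch(self, x):
    x = self.pos_embed(x)
    # One norm per residual sublayer, applied after the residual add.
    for layer, norm in zip(self.residual_layers, self.layer_norms):
        x = norm(x + layer(x))
    return x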
Example #5
    def build(self, input_shapes):
        embed_shape = input_shapes[0]
        height, width, num_filters = embed_shape[1:]
        hidden_size = self.config.hidden_size
        num_heads = self.config.num_heads
        ff_size = self.config.ff_size
        res = [height, width]

        self.pos_embed = coltran_layers.PositionEmbed(axes=[1, 2],
                                                      max_lengths=res)

        self.residual_layers, self.layer_norms, self.cmlp_layers = [], [], []
        num_norms = self.config.num_outer_layers * 4
        if self.cond_ln:
            for _ in range(num_norms):
                curr_norm = coltran_layers.ConditionalLayerNorm(
                    spatial_average=self.cond_ln_sp_ave,
                    sequence=self.cond_ln_seq,
                    out_init=self.cond_ln_init,
                    out_act=self.cond_ln_act)
                self.layer_norms.append(curr_norm)
        else:
            self.layer_norms = [
                layers.LayerNormalization() for _ in range(num_norms)
            ]

        for layer_ind in range(self.config.num_outer_layers):
            # unmasked row
            unmask_row = coltran_layers.SelfAttentionND(
                hidden_size=hidden_size,
                num_heads=num_heads,
                nd_block_size=[1, width],
                resolution=[height, width],
                cond_q=self.cond_att_q,
                cond_k=self.cond_att_k,
                cond_v=self.cond_att_v,
                cond_init=self.cond_att_init,
                cond_scale=self.cond_att_scale,
                cond_act=self.cond_att_act,
                name='unmask_row_att_%d' % layer_ind)

            ff_row = tf.keras.Sequential([
                layers.Dense(units=ff_size, activation='relu'),
                layers.Dense(units=num_filters),
            ], name='row_dense_%d' % layer_ind)

            # masked column
            mask_col = coltran_layers.SelfAttentionND(
                hidden_size=hidden_size,
                num_heads=num_heads,
                mask='future',
                nd_block_size=[height, 1],
                resolution=[height, width],
                cond_q=self.cond_att_q,
                cond_k=self.cond_att_k,
                cond_v=self.cond_att_v,
                cond_act=self.cond_att_act,
                cond_init=self.cond_att_init,
                cond_scale=self.cond_att_scale,
                name='mask_col_att_%d' % layer_ind)

            ff_col = tf.keras.Sequential([
                layers.Dense(units=ff_size, activation='relu'),
                layers.Dense(units=num_filters),
            ], name='col_dense_%d' % layer_ind)

            self.residual_layers.append(unmask_row)
            self.residual_layers.append(ff_row)
            self.residual_layers.append(mask_col)
            self.residual_layers.append(ff_col)

            # Conditional MLP layers.
            if self.cond_mlp == 'shift':
                shift_r = layers.Dense(units=hidden_size,
                                       name='shift_r_%d' % layer_ind)
                shift_c = layers.Dense(units=hidden_size,
                                       name='shift_c_%d' % layer_ind)
                self.cmlp_layers.append(shift_r)
                self.cmlp_layers.append(shift_c)
            elif self.cond_mlp == 'affine':
                aff_r = layers.Dense(units=2 * hidden_size,
                                     name='affine_r_%d' % layer_ind)
                aff_c = layers.Dense(units=2 * hidden_size,
                                     name='affine_c_%d' % layer_ind)
                self.cmlp_layers.append(aff_r)
                self.cmlp_layers.append(aff_c)

        # Shift rows down by one so each row is conditioned only on the
        # rows above it.
        self.shift_down = coltran_layers.Shift(dimension=0, resolution=res)
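
The 'shift' and 'affine' branches build Dense layers of hidden_size and 2 * hidden_size units respectively, which points at the standard conditioning recipe: an additive shift, or a per-channel scale and shift. A hedged sketch of how one of these layers could be applied (placement within coltran's call() may differ):

import tensorflow as tf

def apply_cond_mlp_sketch(x, cond, cmlp_layer, cond_mlp):
    if cond_mlp == 'shift':
        # Dense(hidden_size): additive conditioning.
        return x + cmlp_layer(cond)
    elif cond_mlp == 'affine':
        # Dense(2 * hidden_size): split into per-channel scale and shift.
        scale, shift = tf.split(cmlp_layer(cond), 2, axis=-1)
        return x * scale + shift
    return x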