コード例 #1
0
ファイル: core.py プロジェクト: zizhuang99/google-research
    def call(self, inputs, training=True):
        """Run the conditional residual stack over the inputs.

        Args:
          inputs: tuple of (embeddings, channel_context) tensors.
          training: Python bool; enables residual dropout when True.

        Returns:
          Tensor produced by the final `shift_down` layer.
        """
        embeddings, channel_context = inputs
        mlp_idx = 0

        x = self.pos_embed(embeddings)
        if self.skip:
            x += channel_context
        inputs = x

        for layer, norm in zip(self.residual_layers, self.layer_norms):
            # Conditional self-attention: attention layers may take the
            # channel context as an extra input.
            if 'att' in layer.name and self.cond_att:
                out = layer((inputs, channel_context))
            else:
                out = layer(inputs)

            # Conditional MLP: dense layers are optionally modulated by the
            # channel context via a dedicated conditioning layer.
            if 'dense' in layer.name and self.cond_mlp:
                out = cond_with_context(out, self.cmlp_layers[mlp_idx],
                                        channel_context, self.cond_mlp,
                                        self.cond_mlp_act)
                mlp_idx += 1

            out = coltran_layers.residual_dropout(
                inputs, out, self.dropout, training)

            # Conditional layer norm also consumes the channel context.
            inputs = norm((out, channel_context)) if self.cond_ln else norm(out)

        return self.shift_down(inputs)
コード例 #2
0
ファイル: core.py プロジェクト: zizhuang99/google-research
    def call(self, inputs, row_ind=None, training=True):
        """Run the conditional residual stack with upper and channel context.

        Args:
          inputs: tuple of (embeddings, upper_context, channel_context)
            tensors.
          row_ind: optional index; when provided, only that row's positional
            embedding is added (special case used during sampling).
          training: Python bool; enables residual dropout when True.

        Returns:
          Tensor produced by the final layer norm of the stack.
        """
        embeddings, upper_context, channel_context = inputs

        embeddings = self.shift_right(embeddings)
        if row_ind is None:
            embeddings = self.pos_embed(embeddings)
        else:
            # Sampling path: add just the positional embedding for this row.
            shape = embeddings.shape.as_list()
            row_pos = get_pos_embeddings(self.pos_embed, shape)
            embeddings += row_pos[:, row_ind:row_ind + 1]

        inputs = embeddings
        if self.skip:
            inputs = inputs + channel_context + upper_context

        all_context = tf.concat((channel_context, upper_context), -1)

        mlp_idx = 0
        for layer, norm in zip(self.residual_layers, self.layer_norms):
            # Conditional self-attention over the concatenated context.
            if 'att' in layer.name and self.cond_att:
                out = layer((inputs, all_context))
            else:
                out = layer(inputs)

            # Conditional MLP, also conditioned on the concatenated context.
            if 'dense' in layer.name and self.cond_mlp:
                out = cond_with_context(out, self.cmlp_layers[mlp_idx],
                                        all_context, self.cond_mlp,
                                        self.cond_mlp_act)
                mlp_idx += 1

            out = coltran_layers.residual_dropout(
                inputs, out, self.dropout, training)

            # Conditional layer norm uses channel_context only: providing
            # all_context here would violate causal masking due to the
            # spatial averaging.
            inputs = norm((out, channel_context)) if self.cond_ln else norm(out)

        return inputs