def call(self, inputs, training=True):
    """Apply the residual layer stack to position-embedded inputs, then shift.

    Args:
      inputs: A 2-tuple `(embeddings, channel_context)`.
      training: Forwarded to `coltran_layers.residual_dropout`.

    Returns:
      The result of `self.shift_down` applied to the activations produced by
      the last layer norm (or to the position-embedded inputs when there are
      no residual layers).
    """
    embeddings, channel_context = inputs

    hidden = self.pos_embed(embeddings)
    if self.skip:
        hidden += channel_context

    # Index of the next conditional-MLP layer to consume from
    # `self.cmlp_layers`; it only advances on conditioned 'dense' layers.
    cmlp_index = 0
    for block, layer_norm in zip(self.residual_layers, self.layer_norms):
        # Layers are told apart by substring matches on their names.
        use_cond_att = 'att' in block.name and self.cond_att
        block_out = block((hidden, channel_context) if use_cond_att else hidden)

        if 'dense' in block.name and self.cond_mlp:
            block_out = cond_with_context(
                block_out, self.cmlp_layers[cmlp_index], channel_context,
                self.cond_mlp, self.cond_mlp_act)
            cmlp_index += 1

        block_out = coltran_layers.residual_dropout(
            hidden, block_out, self.dropout, training)

        # The normalized activations become the input of the next layer.
        hidden = (layer_norm((block_out, channel_context))
                  if self.cond_ln else layer_norm(block_out))

    return self.shift_down(hidden)
def call(self, inputs, row_ind=None, training=True):
    """Apply the residual layer stack to right-shifted, position-embedded inputs.

    Args:
      inputs: A 3-tuple `(embeddings, upper_context, channel_context)`.
      row_ind: If None, `self.pos_embed` is applied to the shifted embeddings
        directly. Otherwise (used during sampling) only the slice
        `[:, row_ind:row_ind + 1]` of the positional embeddings returned by
        `get_pos_embeddings` is added to them.
      training: Forwarded to `coltran_layers.residual_dropout`.

    Returns:
      The activations produced by the last layer norm (or the embedded
      inputs, plus the contexts when `self.skip` is set, if there are no
      residual layers).
    """
    embeddings, upper_context, channel_context = inputs

    embeddings = self.shift_right(embeddings)
    if row_ind is not None:
        # Special case during sampling: add only the positional embeddings
        # of the requested row.
        embed_shape = embeddings.shape.as_list()
        row_pos_embed = get_pos_embeddings(self.pos_embed, embed_shape)
        row_pos_embed = row_pos_embed[:, row_ind:row_ind + 1]
        embeddings += row_pos_embed
    else:
        embeddings = self.pos_embed(embeddings)

    hidden = embeddings
    if self.skip:
        hidden += channel_context
        hidden += upper_context

    # Conditional attention and conditional MLP layers see both contexts,
    # concatenated along the last axis.
    all_context = tf.concat((channel_context, upper_context), -1)

    # Index of the next conditional-MLP layer to consume from
    # `self.cmlp_layers`; it only advances on conditioned 'dense' layers.
    cmlp_index = 0
    for block, layer_norm in zip(self.residual_layers, self.layer_norms):
        # Conditional self-attention.
        use_cond_att = 'att' in block.name and self.cond_att
        block_out = block((hidden, all_context) if use_cond_att else hidden)

        # Conditional MLP.
        if 'dense' in block.name and self.cond_mlp:
            block_out = cond_with_context(
                block_out, self.cmlp_layers[cmlp_index], all_context,
                self.cond_mlp, self.cond_mlp_act)
            cmlp_index += 1

        block_out = coltran_layers.residual_dropout(
            hidden, block_out, self.dropout, training)

        # Conditional layer norm. Only `channel_context` is passed here:
        # providing all context violates causal masking due to the spatial
        # averaging.
        hidden = (layer_norm((block_out, channel_context))
                  if self.cond_ln else layer_norm(block_out))

    return hidden