def body_sharded(self, sharded_features):
    """Run the sharded image-transformer decoder over per-shard features.

    Args:
        sharded_features: dict of lists of per-shard tensors; must contain
            "inputs" and "targets". (Shapes assumed per cia.* contract —
            not visible here; confirm against callers.)

    Returns:
        A pair (output, extra_loss): the per-shard decoder outputs from
        cia.create_output, and the auxiliary loss produced by
        cia.transformer_layers_sharded.
    """
    parallel = self._data_parallelism
    hparams = copy.copy(self._hparams)
    inputs = sharded_features["inputs"]
    targets = sharded_features["targets"]

    # Filter widths > 1 need causal ("LEFT") padding so the convolutional
    # q/kv projections cannot peek at future positions.
    q_padding = "LEFT" if hparams.q_filter_width > 1 else "VALID"
    kv_padding = "LEFT" if hparams.kv_filter_width > 1 else "VALID"

    # Build decoder inputs (and the 2-D layout dims) on every shard.
    decoder_input, rows, cols = parallel(
        cia.prepare_decoder_inputs, inputs, targets, hparams)

    # Decoder-only transformer stack: no encoder output, no explicit
    # self-attention bias (the attention type carries the masking).
    decoder_output, extra_loss = cia.transformer_layers_sharded(
        parallel,
        self._ps_devices,
        decoder_input,
        hparams.num_hidden_layers,
        hparams,
        self_attention_bias=None,
        enc_output=None,
        attention_type=hparams.dec_attention_type,
        q_padding=q_padding,
        kv_padding=kv_padding,
        name="decoder")

    output = parallel(
        cia.create_output, decoder_output, rows, cols, targets, hparams)
    return output, extra_loss
def body_sharded(self, sharded_features):
    """Run the sharded image-transformer decoder over per-shard features.

    Args:
        sharded_features: dict of lists of per-shard tensors; must contain
            "inputs" and "targets". (Shapes assumed per cia.* contract —
            not visible here; confirm against callers.)

    Returns:
        A pair (output, extra_loss): the per-shard decoder outputs from
        cia.create_output, and the auxiliary loss produced by
        cia.transformer_layers_sharded.
    """
    dp = self._data_parallelism
    hparams = copy.copy(self._hparams)
    inputs = sharded_features["inputs"]
    targets = sharded_features["targets"]

    # Prepare decoder inputs (and the 2-D layout dims) on every shard.
    decoder_input, rows, cols = dp(cia.prepare_decoder_inputs,
                                   inputs, targets, hparams)

    # TODO(nikip): derive q_padding/kv_padding from hparams.q_filter_width
    # and hparams.kv_filter_width (as the sibling body_sharded variant does)
    # and pass them through. The original computed both values and
    # immediately `del`-ed them, so this call has always used the callee's
    # default padding — the dead computation is removed here, behavior is
    # unchanged.
    decoder_output, extra_loss = cia.transformer_layers_sharded(
        dp,
        self._ps_devices,
        decoder_input,
        hparams.num_hidden_layers,
        hparams,
        self_attention_bias=None,
        enc_output=None,
        attention_type=hparams.dec_attention_type,
        name="decoder")

    output = dp(cia.create_output, decoder_output, rows, cols, targets,
                hparams)
    return output, extra_loss