def testSymbolModalityTargets(self):
  """Symbol top + loss give logits of the expected shape and a scalar loss."""
  batch_size, length, height = 10, 6, 7
  hidden_size, vocab_size = 9, 11
  num_datashards = 5
  symbol = modalities.ModalityType.SYMBOL

  model_hparams = common_hparams.basic_params1()
  model_hparams.hidden_size = hidden_size
  model_hparams.mode = tf.estimator.ModeKeys.TRAIN

  # Random inputs: body activations are drawn first, then target ids.
  body_output = np.random.randint(
      100, size=(batch_size, length, height, hidden_size))
  targets = np.random.randint(vocab_size, size=(batch_size, length, height, 1))

  # Split the batch evenly over several (CPU) data shards.
  parallel = expert_utils.Parallelism(["/device:CPU:0"] * num_datashards)
  body_shards = tf.split(tf.to_float(body_output), num_datashards)
  target_shards = tf.split(targets, num_datashards)

  logit_shards = parallel(
      modalities.get_top(symbol), body_shards, target_shards, model_hparams,
      vocab_size)
  loss_fn = modalities.get_loss(symbol)
  weights_fn = modalities.get_weights_fn(symbol)
  loss_numerators, loss_denominators = parallel(
      loss_fn, logit_shards, target_shards, model_hparams, vocab_size,
      weights_fn)

  # Normalized loss; the denominator is clamped to at least 1.
  train_loss = tf.add_n(loss_numerators) / tf.maximum(
      1.0, tf.add_n(loss_denominators))
  logits = tf.concat(logit_shards, 0)

  self.evaluate(tf.global_variables_initializer())
  logits_value, loss_value = self.evaluate((logits, train_loss))
  expected_logits_shape = (batch_size, length, height, 1, vocab_size)
  self.assertEqual(logits_value.shape, expected_logits_shape)
  self.assertEqual(loss_value.shape, ())
def symbols_to_logits_fn(ids, i, cache):
  """Compute logits for the next symbol from the most recently decoded id.

  Args:
    ids: ids decoded so far; only the last position (`ids[:, -1:]`) is used.
    i: current decoding step; selects row `i` of the self-attention bias and
      is passed to `preprocess_targets`.
    cache: decoding cache. It is read for the encoder output and bias, passed
      to `self.decode` and `update_decoder_attention_history`, and its
      "body_outputs" entry is extended along axis 2 with this step's output.

  Returns:
    A `(logits, cache)` tuple. The logits are the target top's output squeezed
    over axes 1, 2 and 3. When `partial_targets` is given and
    `i < partial_targets_length`, they are replaced by a one-hot forcing of
    the partial target token.
  """
  step_ids = ids[:, -1:]
  step_targets = preprocess_targets(
      tf.expand_dims(tf.expand_dims(step_ids, axis=2), axis=3), i)
  step_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

  with tf.variable_scope("body"):
    body_outputs = dp(
        self.decode,
        step_targets,
        cache.get("encoder_output"),
        cache.get("encoder_decoder_attention_bias"),
        step_bias,
        hparams,
        cache,
        nonpadding=features_to_nonpadding(features, "targets"))

  update_decoder_attention_history(cache)
  # Append this step's decoder output to the running record kept in the cache.
  cache["body_outputs"] = tf.concat(
      [cache["body_outputs"], body_outputs[0]], axis=2)

  name_fn = hparams.name.get("targets", modalities.get_name(target_modality))
  with tf.variable_scope(name_fn(hparams, target_vocab_size)):
    top_fn = hparams.top.get("targets", modalities.get_top(target_modality))
    shard_logits = dp(top_fn, body_outputs, None, hparams, target_vocab_size)
  step_logits = tf.squeeze(shard_logits[0], axis=[1, 2, 3])

  if partial_targets is None:
    return step_logits, cache

  # Inside the user-supplied prefix, override the model so that the given
  # token is always chosen: 0 for that token, -1e9 for all others. Processing
  # the whole prefix in one pass to fill the cache would be faster, but would
  # need broader changes.
  num_classes = tf.shape(step_logits)[1]

  def forced_logits():
    forced_ids = tf.tile(partial_targets[:, i], [beam_size])
    return tf.one_hot(forced_ids, num_classes, 0.0, -1e9)

  chosen = tf.cond(
      tf.less(i, partial_targets_length),
      forced_logits,
      lambda: step_logits)
  return chosen, cache
def testGetForAllModalities(self):
  """Every modality type must provide a default for each component."""
  for modality in modalities.ModalityType.get_choices():
    # Resolve every default component first, then check none is missing.
    components = (
        modalities.get_bottom(modality),
        modalities.get_loss(modality),
        modalities.get_name(modality),
        modalities.get_targets_bottom(modality),
        modalities.get_top(modality),
        modalities.get_weights_fn(modality),
    )
    messages = (
        "{} has no default bottom".format(modality),
        "{} has no default loss".format(modality),
        "{} has no default name".format(modality),
        "{} has no default targets_bottom".format(modality),
        "{} has no default top".format(modality),
        "{} has no default weights_fn".format(modality),
    )
    for component, message in zip(components, messages):
      self.assertIsNotNone(component, msg=message)
def infer_step(recent_output, recent_logits, unused_loss):
  """Inference step.

  Pads `recent_output` by one position along axis 1 and re-runs
  `self.sample` on it. It then appends the sample for the new position (or
  the last position when the target top is pointwise) to `recent_output`.

  Args:
    recent_output: samples generated so far, extended along axis 1.
    recent_logits: logits generated so far, extended along axis 1.
    unused_loss: ignored; present only to match the loop-state signature.

  Returns:
    A tuple `(samples, logits, loss)`:
      - the extended samples;
      - the extended logits;
      - the sum of all non-None values in the losses dict from `self.sample`.
  """
  if not tf.executing_eagerly():
    if self._target_modality_is_real:
      dim = self._problem_hparams.vocab_size["targets"]
      if dim is not None and hasattr(self._hparams, "vocab_divisor"):
        # Round dim up to the next multiple of vocab_divisor.
        dim += (-dim) % self._hparams.vocab_divisor
      recent_output.set_shape([None, None, None, dim])
    else:
      recent_output.set_shape([None, None, None, 1])
  # Add one zero-padded position at the end of axis 1; the new sample is read
  # from that position below (unless the top is pointwise).
  padded = tf.pad(recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])
  features["targets"] = padded
  # This is inefficient in that it generates samples at all timesteps,
  # not just the last one, except if target_modality is pointwise.
  samples, logits, losses = self.sample(features)
  # Concatenate the already-generated recent_output with last timestep
  # of the newly-generated samples.
  top = self._hparams.top.get("targets", modalities.get_top(target_modality))
  if getattr(top, "pointwise", False):
    cur_sample = samples[:, -1, :, :]
  else:
    cur_sample = samples[:, common_layers.shape_list(recent_output)[1], :, :]
  cur_sample = tf.expand_dims(cur_sample, axis=1)
  if not self._target_modality_is_real:
    # Discrete targets are int64 ids.
    cur_sample = tf.to_int64(cur_sample)
  samples = tf.concat([recent_output, cur_sample], axis=1)
  # Only discrete (id) samples have a trailing dimension of 1. Real-valued
  # samples keep the `dim` set on recent_output above, and forcing 1 there
  # would be an incompatible-shape error.
  if not self._target_modality_is_real and not tf.executing_eagerly():
    samples.set_shape([None, None, None, 1])

  # Assuming we have one shard for logits.
  logits = tf.concat([recent_logits, logits[:, -1:]], 1)
  loss = sum(l for l in losses.values() if l is not None)
  return samples, logits, loss
def symbols_to_logits_fn(ids, ids_tag, i, cache):
  """Go from ids to logits for next symbol.

  Runs one decoding step:
    1. Embeds the most recent token id and tag id.
    2. Combines them with an FFN layer.
    3. Runs the decoder stack(s) and the prediction cascade, which also
       yields tag logits.
    4. Applies the target top to produce token logits.

  Args:
    ids: decoded token ids so far; only the last position (`ids[:, -1:]`) is
      used.
    ids_tag: decoded tag ids so far; only the last position is used.
    i: current decoding step. It selects row `i` of the self-attention bias
      and is passed to both target preprocessing functions.
    cache: decoding cache. It is read for the encoder output and bias, and is
      passed to the FFN layer, `self.decode` and
      `update_decoder_attention_history`.

  Returns:
    A tuple `(ret, logits_tag, cache)`:
      - `ret`: token logits squeezed over axes 1 and 2, possibly replaced by
        forced one-hot logits inside the partial-target prefix;
      - `logits_tag`: tag logits squeezed over axis 1;
      - `cache`: the cache.
  """
  # Only the newest position is fed through the decoder at each step.
  ids = ids[:, -1:]
  targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
  targets = preprocess_targets_method(targets, i)
  ids_tag = ids_tag[:, -1:]
  targets_tag = tf.expand_dims(tf.expand_dims(ids_tag, axis=2), axis=3)
  targets_tag = preprocess_targets_tag_method(targets_tag, i)

  # Self-attention bias for query position i over key positions 0..i.
  bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

  with tf.variable_scope('body'):
    with tf.variable_scope('edit_ops_layer'):
      with tf.variable_scope('ffn'):
        # Fuse the token and tag inputs:
        #   - layer-preprocess each input;
        #   - concatenate them along axis 2 and run an FFN;
        #   - post-process the FFN output against `x` (the original targets).
        # NOTE(review): tf.concat over a single tensor looks like an identity
        # wrapper here; confirm it is intentional.
        x = targets
        preproc = lambda z: common_layers.layer_preprocess(
            z, hparams, layer_collection=None)
        layer_inputs = [
            tf.concat(preproc(x), axis=0),
            tf.concat(preproc(targets_tag), axis=0),
        ]
        y = transformer_layers.transformer_ffn_layer(
            tf.concat(layer_inputs, axis=2),
            hparams,
            conv_padding='LEFT',
            nonpadding_mask=features_to_nonpadding(
                features, 'targets'),
            losses=None,
            cache=cache,
            decode_loop_step=None,
            layer_collection=None,
        )
        targets = common_layers.layer_postprocess(
            x, y, hparams)
    if hparams.middle_prediction:
      # Shrinks the decoder depth by middle_prediction_layer_factor
      # (presumably read by self.decode for both decode calls below).
      # NOTE(review): this mutates the shared `hparams` in place. Each call
      # (or re-trace) of this function divides the layer count again, and
      # the reduced value stays visible to any later reader of hparams.
      # Confirm this runs only once, or compute the value locally.
      num_decoder_layers = (hparams.num_decoder_layers or
                            hparams.num_hidden_layers)
      hparams.num_decoder_layers = int(
          num_decoder_layers / hparams.middle_prediction_layer_factor)

    body_outputs = dp(
        self.decode,
        targets,
        cache.get('encoder_output'),
        cache.get('encoder_decoder_attention_bias'),
        bias,
        hparams,
        cache,
        nonpadding=features_to_nonpadding(features, 'targets'),
    )[0]

    # The prediction cascade returns updated body outputs plus per-shard tag
    # outputs. The tag logits are taken from shard 0 under
    # 'targets_error_tag'.
    body_outputs, logits_tag = dp(
        self._prediction_cascade_predict,
        hparams,
        features_to_nonpadding(features, 'targets'),
        cache.get('encoder_decoder_attention_bias'),
        cache.get('encoder_output'),
        body_outputs,
    )
    logits_tag = logits_tag[0]['targets_error_tag']
    if hparams.middle_prediction:
      # Second decode pass, in its own variable scope, on the sum of the
      # step input and the cascade's body output (shard 0).
      with tf.variable_scope('after_prediction'):
        body_outputs = dp(
            self.decode,
            targets + body_outputs[0],
            cache.get('encoder_output'),
            cache.get('encoder_decoder_attention_bias'),
            bias,
            hparams,
            cache,
            nonpadding=features_to_nonpadding(
                features, 'targets'),
        )

  update_decoder_attention_history(cache)

  modality_name = hparams.name.get(
      'targets',
      modalities.get_name(target_modality))(hparams, target_vocab_size)
  with tf.variable_scope('targets/' + modality_name):
    top = hparams.top.get('targets', modalities.get_top(target_modality))
    logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

  # Remove the size-1 axes 1 and 2 (tf.squeeze fails if they are not size 1).
  ret = tf.squeeze(logits, axis=[1, 2])
  if partial_targets is not None:
    # While i < partial_targets_length, replace the model logits with a
    # one-hot forcing: 0 for the given token (tiled beam_size times), -1e9
    # for every other token. This makes decoding reproduce the prefix.
    vocab_size = tf.shape(ret)[1]
    def forced_logits():
      return tf.one_hot(
          tf.tile(partial_targets[:, i], [beam_size]),
          vocab_size,
          0.0,
          -1e9,
      )

    # `lambda: ret` runs inside tf.cond, before `ret` is rebound, so it
    # returns the model logits computed above.
    ret = tf.cond(
        tf.less(i, partial_targets_length),
        forced_logits,
        lambda: ret,
    )
  logits_tag = tf.squeeze(logits_tag, axis=[1])
  return ret, logits_tag, cache