def testAllLayerParams(self):
  """Verifies the set of layer classes instantiated by the test model."""
  with self.session(use_gpu=False, graph=tf.Graph()):
    model = self._testParams().Instantiate()
    model.FPropDefaultTheta()
    found = base_layer.RecursiveFindLayerParams(model.params)
    actual_names = sorted(lp.cls.__name__ for lp in found)
    # Duplicate entries are intentional: some layer classes (Conv2DLayer,
    # LSTMCellSimple) appear more than once in the model graph.
    expected_names = sorted([
        'Adam',
        'AdditiveAttention',
        'AsciiTokenizer',
        'AsrDecoder',
        'AsrEncoder',
        'AsrModel',
        'BeamSearchHelper',
        'TargetSequenceSampler',
        'ConvLSTMCell',
        'Conv2DLayer',
        'Conv2DLayer',
        'EmbeddingLayer',
        'HighwaySkipLayer',
        'LSTMCellSimple',
        'LSTMCellSimple',
        'NullContextualizer',
        'NullFusion',
        'NullLm',
        'Learner',
        'PiecewiseConstantLearningRateSchedule',
        'ProjectionLayer',
        'SimpleFullSoftmax',
        'SpectrumAugmenter',
        'StackingOverTime',
        'TestInputGenerator',
    ])
    self.assertEqual(expected_names, actual_names)
def _DecoderParams(self,
                   per_word_avg_loss=False,
                   dtype=tf.float32,
                   decoder_cls=decoder.MTDecoderV1):
  """Builds a small decoder config for tests.

  Args:
    per_word_avg_loss: whether the loss is averaged per word.
    dtype: dtype applied to every layer params found in the config.
    decoder_cls: decoder class whose Params() to construct.

  Returns:
    A fully-populated decoder params object.
  """
  dp = decoder_cls.Params()
  dp.name = 'decoder'
  dp.source_dim = 4
  dp.rnn_cell_dim = 4
  dp.rnn_layers = 3
  dp.target_seq_len = 5
  dp.per_word_avg_loss = per_word_avg_loss
  dp.dtype = dtype
  dp.random_seed = 12345
  # Keep the test tiny: small vocab/embedding and a single shard.
  dp.emb.vocab_size = 16
  dp.emb.embedding_dim = 4
  dp.emb.max_num_shards = 1
  dp.attention.hidden_dim = 2
  dp.softmax.num_classes = 16
  dp.softmax.num_shards = 1
  # Deterministic initializers so test runs are reproducible.
  dp.emb.params_init = py_utils.WeightInit.Uniform(0.04, 12345)
  dp.atten_rnn_cell_tpl.params_init = py_utils.WeightInit.Uniform(0.04, 12345)
  dp.rnn_cell_tpl.params_init = py_utils.WeightInit.Uniform(0.04, 12345)
  dp.softmax.params_init = py_utils.WeightInit.Uniform(0.04, 123)
  # Propagate the requested dtype to every nested layer params.
  for layer_p in base_layer.RecursiveFindLayerParams(dp):
    layer_p.dtype = dtype
  return dp
def _DecoderParams(self,
                   per_word_avg_loss=False,
                   is_transparent=False,
                   dtype=tf.float32,
                   fprop_dtype=None,
                   use_task_emb=False,
                   init_step_ids=False):
  """Builds a small TransformerDecoder config for tests.

  Args:
    per_word_avg_loss: whether the loss is averaged per word.
    is_transparent: whether the decoder consumes transparent encodings.
    dtype: dtype applied to every layer params found in the config.
    fprop_dtype: optional dtype for forward-prop computation.
    use_task_emb: whether to add a task embedding (copied from token_emb).
    init_step_ids: forwarded to p.init_step_ids.

  Returns:
    A fully-populated TransformerDecoder params object.
  """
  dp = decoder.TransformerDecoder.Params()
  dp.name = 'decoder'
  dp.source_dim = 4
  dp.model_dim = 4
  dp.num_trans_layers = 6
  # Disable variational noise for deterministic tests.
  no_vn = py_utils.VariationalNoiseParams(1.0, False, False)
  dp.token_emb.vn = no_vn
  dp.token_emb.vocab_size = 20
  dp.token_emb.embedding_dim = 4
  dp.token_emb.max_num_shards = 1
  dp.token_emb.params_init = py_utils.WeightInit.GaussianSqrtDim(seed=12345)
  dp.position_emb.embedding_dim = 4
  if use_task_emb:
    # Copy the (already configured) token embedding, then shrink its vocab.
    dp.task_emb = dp.token_emb.Copy()
    dp.task_emb.vocab_size = 4
  dp.trans_tpl.vn = no_vn
  dp.init_step_ids = init_step_ids
  dp.trans_tpl.source_dim = 4
  dp.trans_tpl.tr_atten_tpl.source_dim = 4
  dp.trans_tpl.tr_atten_tpl.num_attention_heads = 2
  dp.trans_tpl.tr_fflayer_tpl.input_dim = 4
  dp.trans_tpl.tr_fflayer_tpl.hidden_dim = 8
  dp.label_smoothing = layers.LocalizedLabelSmoother.Params()
  dp.label_smoothing.offsets = [-2, -1, 1, 2]
  dp.label_smoothing.weights = [0.015, 0.035, 0.035, 0.015]
  dp.softmax.vn = no_vn
  dp.softmax.num_classes = 20
  dp.softmax.num_shards = 1
  dp.per_word_avg_loss = per_word_avg_loss
  dp.random_seed = 1234
  dp.dtype = dtype
  dp.target_seq_len = 5
  dp.is_transparent = is_transparent
  # Propagate the requested dtype to every nested layer params.
  for layer_p in base_layer.RecursiveFindLayerParams(dp):
    layer_p.dtype = dtype
  py_utils.UpdateFpropDtype(dp, fprop_dtype)
  return dp
def _DecoderParams(self, per_word_avg_loss=False, dtype=tf.float32):
  """Builds a small MTDecoderV1 config for tests.

  Args:
    per_word_avg_loss: whether the loss is averaged per word.
    dtype: dtype applied to every layer params found in the config.

  Returns:
    A fully-populated MTDecoderV1 params object.
  """
  dp = decoder.MTDecoderV1.Params()
  dp.name = 'decoder'
  dp.source_dim = 4
  dp.rnn_cell_dim = 4
  dp.rnn_layers = 3
  dp.target_seq_len = 5
  dp.per_word_avg_loss = per_word_avg_loss
  dp.dtype = dtype
  # Keep the test tiny: small vocab/embedding and a single shard.
  dp.emb.vocab_size = 16
  dp.emb.embedding_dim = 4
  dp.emb.max_num_shards = 1
  dp.attention.hidden_dim = 2
  dp.softmax.num_classes = 16
  dp.softmax.num_shards = 1
  # Propagate the requested dtype to every nested layer params.
  for layer_p in base_layer.RecursiveFindLayerParams(dp):
    layer_p.dtype = dtype
  return dp
def _SetDefaults(p):
  """Applies deterministic defaults to a model config for tests.

  Fixes the random seed, zeroes dropout, disables variational noise,
  gives every layer a seeded Gaussian initializer, and configures a
  plain SGD training setup. Dumps the resulting config to stdout.

  Args:
    p: the model params to mutate.

  Returns:
    The same params object, mutated in place.
  """
  p.random_seed = 12345
  p.decoder.input_dropout_prob = 0.0
  merger_tpl = p.encoder.transformer_stack.transparent_merger_tpl
  merger_tpl.weighted_merger_dropout_prob = 0.0
  no_vn = py_utils.VariationalNoiseParams(1.0, False, False)
  for layer_p in base_layer.RecursiveFindLayerParams(p):
    # TODO(lepikhin): lp.dtype = dtype
    layer_p.params_init = py_utils.WeightInit.Gaussian(0.1, 12345)
    layer_p.vn = no_vn
  train_p = p.train
  assert train_p.l2_regularizer_weight is None
  train_p.clip_gradient_norm_to_value = False
  train_p.grad_norm_to_clip_to_zero = False
  train_p.optimizer = optimizer.SGD.Params()
  train_p.learning_rate = 1e-2
  train_p.lr_schedule = schedule.ContinuousSchedule.Params()
  # Dump the full config for debugging; line-by-line split/print emits
  # exactly the same bytes as printing the whole text.
  print(p.ToText())
  return p