def test_nas_decoder_resizing_output(self):
  """Checks that `enforce_output_size` pads the decoder output depth.

  Builds the same wrong-output-dim decoder twice (under separate variable
  scopes): once without enforcement, expecting the mismatched depth, and
  once with enforcement, expecting the output resized to the embedding
  depth.
  """
  hparams, wrong_size = self._get_wrong_output_dim_decoder_hparams()
  inputs = tf.zeros([_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH])
  self_attention_bias = common_attention.attention_bias_lower_triangle(
      _INPUT_LENGTH)

  def _build_decoder(scope_name):
    # Builds a decoder under its own variable scope, picking up the
    # current value of hparams.enforce_output_size.
    with tf.variable_scope(scope_name):
      return translation_nas_net.nas_decoder(
          decoder_input=inputs,
          encoder_cell_outputs=[inputs] * hparams.encoder_num_cells,
          decoder_self_attention_bias=self_attention_bias,
          encoder_decoder_attention_bias=None,
          hparams=hparams)

  hparams.enforce_output_size = False
  wrong_size_decoder_output = _build_decoder("wrong")

  # Now add the correction.
  hparams.enforce_output_size = True
  correct_size_decoder_output = _build_decoder("correct")

  with self.test_session() as session:
    session.run(tf.global_variables_initializer())
    wrong_output, correct_output = session.run(
        [wrong_size_decoder_output, correct_size_decoder_output])
    self.assertEqual(wrong_output.shape,
                     (_BATCH_SIZE, _INPUT_LENGTH, wrong_size))
    self.assertEqual(correct_output.shape,
                     (_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH))
def test_calculate_branching_model_parameters_decoder_resize(
    self, enforce_output_size):
  """Checks the predicted parameter count against the built decoder graph.

  Predicts the decoder's parameter count via
  `calculate_branching_model_parameters`, then constructs the actual
  decoder graph and confirms the total size of its trainable variables
  matches the prediction, for the given `enforce_output_size` setting.
  """
  tf.reset_default_graph()
  hparams, _ = self._get_wrong_output_dim_decoder_hparams()
  hparams.enforce_output_size = enforce_output_size
  # NOTE(review): norms are disabled here — presumably because their
  # variables are not part of the branching-parameter prediction; confirm
  # against calculate_branching_model_parameters.
  hparams.decoder_left_norms = [translation_nas_net.NO_NORM_KEY] * 5
  hparams.decoder_right_norms = [translation_nas_net.NO_NORM_KEY] * 5

  # Get predicted number of parameters.
  predicted_num_params, _, _, _ = (
      translation_nas_net.calculate_branching_model_parameters(
          encoding_depth=_EMBEDDING_DEPTH,
          left_inputs=hparams.decoder_left_inputs,
          left_layers=hparams.decoder_left_layers,
          left_output_dims=hparams.decoder_left_output_dims,
          right_inputs=hparams.decoder_right_inputs,
          right_layers=hparams.decoder_right_layers,
          right_output_dims=hparams.decoder_right_output_dims,
          combiner_functions=hparams.decoder_combiner_functions,
          final_combiner_function=hparams.decoder_final_combiner_function,
          layer_registry=layers.DECODER_LAYERS,
          num_cells=hparams.decoder_num_cells,
          encoder_depth=_EMBEDDING_DEPTH,
          enforce_output_size=enforce_output_size))

  # Build the decoder graph so its variables are registered.
  inputs = tf.zeros([_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH])
  self_attention_bias = common_attention.attention_bias_lower_triangle(
      _INPUT_LENGTH)
  _ = translation_nas_net.nas_decoder(
      decoder_input=inputs,
      encoder_cell_outputs=[inputs] * hparams.encoder_num_cells,
      decoder_self_attention_bias=self_attention_bias,
      encoder_decoder_attention_bias=None,
      hparams=hparams,
      final_layer_norm=False)

  # Count graph variables.
  empirical_num_params = sum(
      _list_product(variable.shape.as_list())
      for variable in tf.trainable_variables())

  self.assertEqual(empirical_num_params, predicted_num_params)