def test_calculate_branching_model_parameters_decoder_resize(
    self, enforce_output_size):
  """Predicted decoder parameter count must match the built graph.

  Builds a decoder whose hparams deliberately have a wrong output dim,
  with `enforce_output_size` toggled by the parameterized argument, and
  checks that the analytic count from
  `calculate_branching_model_parameters` equals the number of trainable
  variables actually created.
  """
  tf.reset_default_graph()

  hparams, _ = self._get_wrong_output_dim_decoder_hparams()
  hparams.enforce_output_size = enforce_output_size
  # Disable normalization so only the branching layers contribute params.
  hparams.decoder_left_norms = [translation_nas_net.NO_NORM_KEY] * 5
  hparams.decoder_right_norms = [translation_nas_net.NO_NORM_KEY] * 5

  # Analytic prediction of the trainable-parameter count.
  (predicted_num_params, _, _,
   _) = translation_nas_net.calculate_branching_model_parameters(
       encoding_depth=_EMBEDDING_DEPTH,
       left_inputs=hparams.decoder_left_inputs,
       left_layers=hparams.decoder_left_layers,
       left_output_dims=hparams.decoder_left_output_dims,
       right_inputs=hparams.decoder_right_inputs,
       right_layers=hparams.decoder_right_layers,
       right_output_dims=hparams.decoder_right_output_dims,
       combiner_functions=hparams.decoder_combiner_functions,
       final_combiner_function=hparams.decoder_final_combiner_function,
       layer_registry=layers.DECODER_LAYERS,
       num_cells=hparams.decoder_num_cells,
       encoder_depth=_EMBEDDING_DEPTH,
       enforce_output_size=enforce_output_size)

  # Materialize the decoder graph so its variables exist.
  zeros_input = tf.zeros([_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH])
  self_attention_bias = (
      common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
  _ = translation_nas_net.nas_decoder(
      decoder_input=zeros_input,
      encoder_cell_outputs=[zeros_input] * hparams.encoder_num_cells,
      decoder_self_attention_bias=self_attention_bias,
      encoder_decoder_attention_bias=None,
      hparams=hparams,
      final_layer_norm=False)

  # Empirical count: total elements across all trainable variables.
  empirical_num_params = sum(
      _list_product(variable.shape.as_list())
      for variable in tf.trainable_variables())

  self.assertEqual(empirical_num_params, predicted_num_params)
def test_calculate_branching_model_parameters_output_size_last_two(self):
  """Output size comes from the final combined branches.

  With `enforce_output_size=False` and `enforce_fixed_output_sizes=False`,
  the reported output size should be the concatenation of the last
  combiner's operands: the MULTIPLY output (dim 1000000, the larger of
  the two operands) is consumed at index 2 twice, and the final CONCAT
  joins left dim 1000 with right dim 10000000 after the multiply branch —
  yielding 11001000.
  """
  left_inputs = [0, 1, 2, 2]
  right_inputs = [0, 1, 2, 2]
  # Power-of-ten dims make the expected concatenated size easy to read off.
  left_output_dims = [1, 10, 100, 1000]
  right_output_dims = [10000, 100000, 1000000, 10000000]
  right_layers = [
      layers.IDENTITY_REGISTRY_KEY,
      layers.STANDARD_CONV_1X1_REGISTRY_KEY,
      layers.STANDARD_CONV_1X1_REGISTRY_KEY,
      layers.IDENTITY_REGISTRY_KEY
  ]
  combiner_functions = [
      translation_nas_net.ADD_COMBINER_FUNC_KEY,
      translation_nas_net.ADD_COMBINER_FUNC_KEY,
      translation_nas_net.MULTIPLY_COMBINER_FUNC_KEY,
      translation_nas_net.CONCAT_COMBINER_FUNC_KEY
  ]
  # Only the pieces of the reference config that this test reuses are
  # bound; previously `dummy_activations`/`dummy_norms` were unpacked by
  # name but never used, so they are now discarded like the other slots.
  (num_cells, _, left_layers, _, _, _, _, _, final_combiner_function, _, _,
   layer_registry, _) = _get_transformer_branching_encoder_config()

  # Only the reported output size matters for this test.
  (_, output_size, _,
   _) = translation_nas_net.calculate_branching_model_parameters(
       encoding_depth=_EMBEDDING_DEPTH,
       left_inputs=left_inputs,
       left_layers=left_layers,
       left_output_dims=left_output_dims,
       right_inputs=right_inputs,
       right_layers=right_layers,
       right_output_dims=right_output_dims,
       combiner_functions=combiner_functions,
       final_combiner_function=final_combiner_function,
       layer_registry=layer_registry,
       num_cells=num_cells,
       encoder_depth=_EMBEDDING_DEPTH,
       enforce_output_size=False,
       enforce_fixed_output_sizes=False)

  self.assertEqual(output_size, 11001000)
def test_calculate_branching_model_parameters_transformer(
    self, get_config, expected_hidden_depths):
  """Parameter prediction matches the graph for a transformer config.

  Runs `calculate_branching_model_parameters` on the config returned by
  `get_config`, then builds the same cell stack with `apply_nas_layers`
  and verifies: (1) predicted vs. empirical trainable-parameter counts,
  (2) the reported output size equals the embedding depth, and (3) the
  reported hidden depths match `expected_hidden_depths`.
  """
  tf.reset_default_graph()

  (num_cells, left_inputs, left_layers, left_output_dims, right_inputs,
   right_layers, right_output_dims, combiner_functions,
   final_combiner_function, dummy_activations, dummy_norms, layer_registry,
   is_decoder) = get_config()

  # Analytic prediction.
  (predicted_num_params, output_size, hidden_depths,
   _) = translation_nas_net.calculate_branching_model_parameters(
       encoding_depth=_EMBEDDING_DEPTH,
       left_inputs=left_inputs,
       left_layers=left_layers,
       left_output_dims=left_output_dims,
       right_inputs=right_inputs,
       right_layers=right_layers,
       right_output_dims=right_output_dims,
       combiner_functions=combiner_functions,
       final_combiner_function=final_combiner_function,
       layer_registry=layer_registry,
       num_cells=num_cells,
       encoder_depth=_EMBEDDING_DEPTH)

  # Build the corresponding graph.
  zeros_input = tf.zeros([32, _INPUT_LENGTH, _EMBEDDING_DEPTH])
  hparams = transformer.transformer_small()

  # Decoder cells mask the future and attend over encoder outputs;
  # encoder cells instead receive an explicit nonpadding mask.
  if is_decoder:
    nonpadding = None
    mask_future = True
    self_attention_bias = (
        common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
    encoder_cell_outputs = [zeros_input] * 6
  else:
    nonpadding = tf.ones([32, _INPUT_LENGTH])
    mask_future = False
    self_attention_bias = None
    encoder_cell_outputs = None

  translation_nas_net.apply_nas_layers(
      input_tensor=zeros_input,
      left_inputs=left_inputs,
      left_layers=left_layers,
      left_activations=dummy_activations,
      left_output_dims=left_output_dims,
      left_norms=dummy_norms,
      right_inputs=right_inputs,
      right_layers=right_layers,
      right_activations=dummy_activations,
      right_output_dims=right_output_dims,
      right_norms=dummy_norms,
      combiner_functions=combiner_functions,
      final_combiner_function=final_combiner_function,
      num_cells=num_cells,
      nonpadding=nonpadding,
      layer_registry=layer_registry,
      mask_future=mask_future,
      hparams=hparams,
      var_scope="test",
      encoder_decoder_attention_bias=None,
      encoder_cell_outputs=encoder_cell_outputs,
      decoder_self_attention_bias=self_attention_bias,
      final_layer_norm=False)

  # Empirical count: total elements across all trainable variables.
  empirical_num_params = sum(
      _list_product(variable.shape.as_list())
      for variable in tf.trainable_variables())

  # Compare prediction against the realized graph.
  self.assertEqual(empirical_num_params, predicted_num_params)
  self.assertEqual(output_size, _EMBEDDING_DEPTH)
  self.assertEqual(hidden_depths, expected_hidden_depths)