def test_calculate_branching_model_parameters_decoder_resize(
            self, enforce_output_size):
        """Checks the predicted decoder parameter count against a built graph.

        Uses a decoder config whose cell output dim differs from the
        embedding depth (so a resize may be required), predicts the number
        of parameters with calculate_branching_model_parameters(), then
        builds the decoder via nas_decoder() and counts the trainable
        variables actually created. The two counts must agree.
        """
        tf.reset_default_graph()

        hparams, _ = self._get_wrong_output_dim_decoder_hparams()
        hparams.enforce_output_size = enforce_output_size
        # Disable all layer norms so no extra norm variables skew the count.
        hparams.decoder_left_norms = [translation_nas_net.NO_NORM_KEY] * 5
        hparams.decoder_right_norms = [translation_nas_net.NO_NORM_KEY] * 5

        # Predicted number of parameters (first element of the 4-tuple).
        predicted_num_params = (
            translation_nas_net.calculate_branching_model_parameters(
                encoding_depth=_EMBEDDING_DEPTH,
                left_inputs=hparams.decoder_left_inputs,
                left_layers=hparams.decoder_left_layers,
                left_output_dims=hparams.decoder_left_output_dims,
                right_inputs=hparams.decoder_right_inputs,
                right_layers=hparams.decoder_right_layers,
                right_output_dims=hparams.decoder_right_output_dims,
                combiner_functions=hparams.decoder_combiner_functions,
                final_combiner_function=hparams.decoder_final_combiner_function,
                layer_registry=layers.DECODER_LAYERS,
                num_cells=hparams.decoder_num_cells,
                encoder_depth=_EMBEDDING_DEPTH,
                enforce_output_size=enforce_output_size))[0]

        # Build the decoder graph so its variables are registered.
        input_tensor = tf.zeros([_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH])
        self_attention_bias = (
            common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
        _ = translation_nas_net.nas_decoder(
            decoder_input=input_tensor,
            encoder_cell_outputs=[input_tensor] * hparams.encoder_num_cells,
            decoder_self_attention_bias=self_attention_bias,
            encoder_decoder_attention_bias=None,
            hparams=hparams,
            final_layer_norm=False)

        # Empirically count every trainable variable in the graph.
        empirical_num_params = sum(
            _list_product(var.shape.as_list())
            for var in tf.trainable_variables())

        self.assertEqual(empirical_num_params, predicted_num_params)
    def test_calculate_branching_model_parameters_output_size_last_two(self):
        """Checks output_size when the final combiner concatenates branches.

        The last combiner is CONCAT, so with enforce_output_size and
        enforce_fixed_output_sizes both disabled, the model's output size
        should be the sum of the last left/right branch output dims
        (1000 + 10000000) plus the multiplied pair feeding the final
        combiner — yielding 11001000 for these dims.
        """
        left_inputs = [0, 1, 2, 2]
        right_inputs = [0, 1, 2, 2]
        left_output_dims = [1, 10, 100, 1000]
        right_output_dims = [10000, 100000, 1000000, 10000000]
        right_layers = [
            layers.IDENTITY_REGISTRY_KEY,
            layers.STANDARD_CONV_1X1_REGISTRY_KEY,
            layers.STANDARD_CONV_1X1_REGISTRY_KEY, layers.IDENTITY_REGISTRY_KEY
        ]
        combiner_functions = [
            translation_nas_net.ADD_COMBINER_FUNC_KEY,
            translation_nas_net.ADD_COMBINER_FUNC_KEY,
            translation_nas_net.MULTIPLY_COMBINER_FUNC_KEY,
            translation_nas_net.CONCAT_COMBINER_FUNC_KEY
        ]

        # Only num_cells, left_layers, final_combiner_function, and
        # layer_registry are needed here; the activation/norm entries of the
        # config are irrelevant to parameter/size prediction, so they are
        # discarded instead of bound to unused names.
        (num_cells, _, left_layers, _, _, _, _, _, final_combiner_function,
         _, _, layer_registry,
         _) = _get_transformer_branching_encoder_config()

        # Get predicted output size (second element of the 4-tuple).
        (_, output_size, _,
         _) = translation_nas_net.calculate_branching_model_parameters(
             encoding_depth=_EMBEDDING_DEPTH,
             left_inputs=left_inputs,
             left_layers=left_layers,
             left_output_dims=left_output_dims,
             right_inputs=right_inputs,
             right_layers=right_layers,
             right_output_dims=right_output_dims,
             combiner_functions=combiner_functions,
             final_combiner_function=final_combiner_function,
             layer_registry=layer_registry,
             num_cells=num_cells,
             encoder_depth=_EMBEDDING_DEPTH,
             enforce_output_size=False,
             enforce_fixed_output_sizes=False)

        self.assertEqual(output_size, 11001000)
    def test_calculate_branching_model_parameters_transformer(
            self, get_config, expected_hidden_depths):
        """Validates predicted params, output size, and hidden depths.

        Given a branching-model config (encoder or decoder flavor, per
        get_config), predicts the parameter count, output size, and hidden
        depths, then builds the corresponding graph via apply_nas_layers()
        and checks the prediction against the trainable variables actually
        created and against expected_hidden_depths.
        """
        tf.reset_default_graph()

        (num_cells, left_inputs, left_layers, left_output_dims, right_inputs,
         right_layers, right_output_dims, combiner_functions,
         final_combiner_function, activations, norms, layer_registry,
         is_decoder) = get_config()

        # Predicted parameter count, output size, and per-layer hidden depths.
        predicted_num_params, output_size, hidden_depths, _ = (
            translation_nas_net.calculate_branching_model_parameters(
                encoding_depth=_EMBEDDING_DEPTH,
                left_inputs=left_inputs,
                left_layers=left_layers,
                left_output_dims=left_output_dims,
                right_inputs=right_inputs,
                right_layers=right_layers,
                right_output_dims=right_output_dims,
                combiner_functions=combiner_functions,
                final_combiner_function=final_combiner_function,
                layer_registry=layer_registry,
                num_cells=num_cells,
                encoder_depth=_EMBEDDING_DEPTH))

        # Build the model graph.
        input_tensor = tf.zeros([32, _INPUT_LENGTH, _EMBEDDING_DEPTH])
        hparams = transformer.transformer_small()

        # Decoder cells mask the future and attend over encoder outputs;
        # encoder cells instead receive a nonpadding mask.
        if is_decoder:
            nonpadding = None
            mask_future = True
            self_attention_bias = (
                common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
            cell_outputs = [input_tensor] * 6
        else:
            nonpadding = tf.ones([32, _INPUT_LENGTH])
            mask_future = False
            self_attention_bias = None
            cell_outputs = None

        translation_nas_net.apply_nas_layers(
            input_tensor=input_tensor,
            left_inputs=left_inputs,
            left_layers=left_layers,
            left_activations=activations,
            left_output_dims=left_output_dims,
            left_norms=norms,
            right_inputs=right_inputs,
            right_layers=right_layers,
            right_activations=activations,
            right_output_dims=right_output_dims,
            right_norms=norms,
            combiner_functions=combiner_functions,
            final_combiner_function=final_combiner_function,
            num_cells=num_cells,
            nonpadding=nonpadding,
            layer_registry=layer_registry,
            mask_future=mask_future,
            hparams=hparams,
            var_scope="test",
            encoder_decoder_attention_bias=None,
            encoder_cell_outputs=cell_outputs,
            decoder_self_attention_bias=self_attention_bias,
            final_layer_norm=False)

        # Empirically count every trainable variable in the graph.
        empirical_num_params = sum(
            _list_product(var.shape.as_list())
            for var in tf.trainable_variables())

        # Compare predictions against the realized graph.
        self.assertEqual(empirical_num_params, predicted_num_params)
        self.assertEqual(output_size, _EMBEDDING_DEPTH)
        self.assertEqual(hidden_depths, expected_hidden_depths)