예제 #1
0
    def test_nas_decoder_resizing_output(self):
        """Decoder output is projected to the embedding depth only when enforced.

        Builds the same NAS decoder twice: once with
        `enforce_output_size=False` (output keeps the configured wrong size)
        and once with `enforce_output_size=True` (output is resized back to
        `_EMBEDDING_DEPTH`), then checks both output shapes.
        """
        hparams, wrong_size = self._get_wrong_output_dim_decoder_hparams()
        decoder_input = tf.zeros([_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH])
        self_attention_bias = common_attention.attention_bias_lower_triangle(
            _INPUT_LENGTH)
        encoder_outputs = [decoder_input] * hparams.encoder_num_cells

        def _build_decoder(scope_name):
            # Builds the NAS decoder under its own variable scope so the two
            # decoders do not share variables.
            with tf.variable_scope(scope_name):
                return translation_nas_net.nas_decoder(
                    decoder_input=decoder_input,
                    encoder_cell_outputs=encoder_outputs,
                    decoder_self_attention_bias=self_attention_bias,
                    encoder_decoder_attention_bias=None,
                    hparams=hparams)

        # Without enforcement the decoder emits the (wrong) raw output size.
        hparams.enforce_output_size = False
        unresized_output = _build_decoder("wrong")

        # With enforcement the output is resized to the embedding depth.
        hparams.enforce_output_size = True
        resized_output = _build_decoder("correct")

        with self.test_session() as session:
            session.run(tf.global_variables_initializer())
            unresized, resized = session.run([unresized_output, resized_output])
        self.assertEqual(unresized.shape,
                         (_BATCH_SIZE, _INPUT_LENGTH, wrong_size))
        self.assertEqual(resized.shape,
                         (_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH))
예제 #2
0
    def test_calculate_branching_model_parameters_decoder_resize(
            self, enforce_output_size):
        """Predicted decoder parameter count matches the built graph.

        Computes the analytically predicted number of parameters via
        `calculate_branching_model_parameters`, builds the actual decoder
        graph, and checks the prediction against the total element count of
        all trainable variables.
        """
        tf.reset_default_graph()

        hparams, _ = self._get_wrong_output_dim_decoder_hparams()
        hparams.enforce_output_size = enforce_output_size
        hparams.decoder_left_norms = [translation_nas_net.NO_NORM_KEY] * 5
        hparams.decoder_right_norms = [translation_nas_net.NO_NORM_KEY] * 5

        # Analytically predicted number of trainable parameters; only the
        # first element of the returned tuple is needed here.
        predicted_num_params = (
            translation_nas_net.calculate_branching_model_parameters(
                encoding_depth=_EMBEDDING_DEPTH,
                left_inputs=hparams.decoder_left_inputs,
                left_layers=hparams.decoder_left_layers,
                left_output_dims=hparams.decoder_left_output_dims,
                right_inputs=hparams.decoder_right_inputs,
                right_layers=hparams.decoder_right_layers,
                right_output_dims=hparams.decoder_right_output_dims,
                combiner_functions=hparams.decoder_combiner_functions,
                final_combiner_function=hparams.decoder_final_combiner_function,
                layer_registry=layers.DECODER_LAYERS,
                num_cells=hparams.decoder_num_cells,
                encoder_depth=_EMBEDDING_DEPTH,
                enforce_output_size=enforce_output_size)[0])

        # Build the decoder so its variables are added to the default graph.
        decoder_input = tf.zeros([_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH])
        self_attention_bias = common_attention.attention_bias_lower_triangle(
            _INPUT_LENGTH)
        translation_nas_net.nas_decoder(
            decoder_input=decoder_input,
            encoder_cell_outputs=[decoder_input] * hparams.encoder_num_cells,
            decoder_self_attention_bias=self_attention_bias,
            encoder_decoder_attention_bias=None,
            hparams=hparams,
            final_layer_norm=False)

        # Empirical count: total element count across all trainable variables.
        empirical_num_params = sum(
            _list_product(variable.shape.as_list())
            for variable in tf.trainable_variables())

        self.assertEqual(empirical_num_params, predicted_num_params)