コード例 #1
0
    def test_create_model(self):
        """Checks the symbolic input/output signatures of the built models.

        Builds a tiny model with ``is_train=True`` and again with
        ``is_train=False`` and asserts the number, static shapes and dtypes
        of each Keras model's ``inputs`` and ``outputs``.
        """
        vocab_size = 41
        # Work on a copy: assigning TINY_PARAMS directly and then overriding
        # keys would mutate the shared module-level params object and leak
        # these overrides into every other test that reads it. ``.copy()``
        # preserves the mapping's own type (a defaultdict keeps its default
        # factory), which ``dict(...)`` would not.
        self.params = model_params.TINY_PARAMS.copy()
        self.params.update({
            "batch_size": 16,
            "hidden_size": 12,
            "num_hidden_layers": 2,
            "filter_size": 14,
            "num_heads": 2,
            "vocab_size": vocab_size,
            "extra_decode_length": 2,
            "beam_size": 3,
            "dtype": tf.float32,
        })

        # Training model: two int64 inputs, one float32 output whose last
        # dimension equals the configured vocab size.
        model = seq2seq_transformer.create_model(self.params, is_train=True)
        inputs, outputs = model.inputs, model.outputs
        self.assertLen(inputs, 2)
        self.assertLen(outputs, 1)
        self.assertEqual(inputs[0].shape.as_list(), [None, None])
        self.assertEqual(inputs[0].dtype, tf.int64)
        self.assertEqual(inputs[1].shape.as_list(), [None, None])
        self.assertEqual(inputs[1].dtype, tf.int64)
        self.assertEqual(outputs[0].shape.as_list(), [None, None, vocab_size])
        self.assertEqual(outputs[0].dtype, tf.float32)

        # Inference model: one int64 input, an int32 rank-2 output and a
        # float32 rank-1 output.
        model = seq2seq_transformer.create_model(self.params, is_train=False)
        inputs, outputs = model.inputs, model.outputs
        self.assertLen(inputs, 1)
        self.assertLen(outputs, 2)
        self.assertEqual(inputs[0].shape.as_list(), [None, None])
        self.assertEqual(inputs[0].dtype, tf.int64)
        self.assertEqual(outputs[0].shape.as_list(), [None, None])
        self.assertEqual(outputs[0].dtype, tf.int32)
        self.assertEqual(outputs[1].shape.as_list(), [None])
        self.assertEqual(outputs[1].dtype, tf.float32)
コード例 #2
0
 def test_create_model_not_train(self):
     """Checks the inference-mode model's symbolic inputs and outputs.

     Builds the model with ``is_train=False`` and asserts it exposes one
     input and two outputs, then checks each tensor's static shape and dtype.
     """
     inference_model = seq2seq_transformer.create_model(self.params, False)
     model_inputs = inference_model.inputs
     model_outputs = inference_model.outputs
     self.assertEqual(len(model_inputs), 1)
     self.assertEqual(len(model_outputs), 2)

     # (tensor, expected static shape, expected dtype), checked in order.
     expectations = (
         (model_inputs[0], [None, None], tf.int64),
         (model_outputs[0], [None, None], tf.int32),
         (model_outputs[1], [None], tf.float32),
     )
     for tensor, want_shape, want_dtype in expectations:
         self.assertEqual(tensor.shape.as_list(), want_shape)
         self.assertEqual(tensor.dtype, want_dtype)
コード例 #3
0
    def test_forward_pass_not_train(self):
        """Refactored inference model must reproduce the original's outputs.

        Builds the pre-refactor ``transformer`` model and the refactored
        ``seq2seq_transformer`` model in inference mode. It checks that they
        have the same parameter count and copies the original's weights into
        the refactored model. It then requires the first two output elements
        to be exactly equal on the same batch.
        """
        batch = np.asarray([[5, 2, 1], [7, 5, 0], [1, 4, 0], [7, 5, 11]])

        # Reference: the model as it was before the refactor.
        reference = transformer.create_model(self.params, False)
        reference_param_count = _count_params(reference)
        reference_weights = reference.get_weights()
        expected = reference([batch], training=False)

        # Candidate: the refactored model, loaded with the reference weights.
        candidate = seq2seq_transformer.create_model(self.params, False)
        candidate_param_count = _count_params(candidate)
        self.assertEqual(reference_param_count, candidate_param_count)
        candidate.set_weights(reference_weights)
        actual = candidate([batch], training=False)

        for position in (0, 1):
            self.assertAllEqual(expected[position], actual[position])
コード例 #4
0
    def test_forward_pass_train(self):
        """Refactored training model must reproduce the original's outputs.

        Builds the pre-refactor ``transformer`` model and the refactored
        ``seq2seq_transformer`` model in training mode. It checks that they
        have the same parameter count and transplants the original's weights
        into the refactored model. It then requires the two forward passes on
        the same (inputs, targets) pair to be exactly equal.
        """
        # Source sequences have length 3 and targets length 4 on purpose, so
        # the two sequence lengths differ.
        source_ids = np.asarray([[5, 2, 1], [7, 5, 0], [1, 4, 0], [7, 5, 11]])
        target_ids = np.asarray([[4, 3, 4, 0], [13, 19, 17, 8], [20, 14, 1, 2],
                                 [5, 7, 3, 0]])

        # Reference: the model as it was before the refactor.
        reference = transformer.create_model(self.params, True)
        reference_param_count = _count_params(reference)
        reference_weights = reference.get_weights()
        expected = reference([source_ids, target_ids], training=True)

        # Candidate: the refactored model, loaded with the reference weights.
        candidate = seq2seq_transformer.create_model(self.params, True)
        candidate_param_count = _count_params(candidate)
        self.assertEqual(reference_param_count, candidate_param_count)
        candidate.set_weights(reference_weights)
        actual = candidate([source_ids, target_ids], training=True)

        self.assertAllEqual(expected, actual)