def setUp(self):
  super().setUp()
  self.params = model_params.TINY_PARAMS
  self.params["batch_size"] = 16
  self.params["hidden_size"] = 12
  self.params["num_hidden_layers"] = 2
  self.params["filter_size"] = 14
  self.params["num_heads"] = 2
  self.params["vocab_size"] = 41
  self.params["extra_decode_length"] = 2
  self.params["beam_size"] = 3
  self.params["dtype"] = tf.float32

def test_create_model_train(self):
  model = seq2seq_transformer.create_model(self.params, is_train=True)
  inputs, outputs = model.inputs, model.outputs
  # The training model takes (inputs, targets) and returns per-token logits
  # over the vocabulary.
  self.assertLen(inputs, 2)
  self.assertLen(outputs, 1)
  self.assertEqual(inputs[0].shape.as_list(), [None, None])
  self.assertEqual(inputs[0].dtype, tf.int64)
  self.assertEqual(inputs[1].shape.as_list(), [None, None])
  self.assertEqual(inputs[1].dtype, tf.int64)
  self.assertEqual(outputs[0].shape.as_list(), [None, None, 41])
  self.assertEqual(outputs[0].dtype, tf.float32)
def test_create_model_not_train(self):
  model = seq2seq_transformer.create_model(self.params, is_train=False)
  inputs, outputs = model.inputs, model.outputs
  # The inference model takes only the source inputs and returns
  # (decoded token ids, scores).
  self.assertLen(inputs, 1)
  self.assertLen(outputs, 2)
  self.assertEqual(inputs[0].shape.as_list(), [None, None])
  self.assertEqual(inputs[0].dtype, tf.int64)
  self.assertEqual(outputs[0].shape.as_list(), [None, None])
  self.assertEqual(outputs[0].dtype, tf.int32)
  self.assertEqual(outputs[1].shape.as_list(), [None])
  self.assertEqual(outputs[1].dtype, tf.float32)
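# NOTE: `_count_params`, used by the forward-pass tests below, is assumed to be
# a module-level helper defined elsewhere in this file. A minimal sketch,
# assuming it simply totals the model's trainable weights, would be:
#
#   def _count_params(model):
#     """Returns the total number of trainable parameters in `model`."""
#     return int(
#         np.sum([tf.keras.backend.count_params(w)
#                 for w in model.trainable_weights]))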
def test_forward_pass_not_train(self):
  inputs = np.asarray([[5, 2, 1], [7, 5, 0], [1, 4, 0], [7, 5, 11]])
  # src_model is the original model before the refactoring.
  src_model = transformer.create_model(self.params, False)
  src_num_weights = _count_params(src_model)
  src_weights = src_model.get_weights()
  src_model_output = src_model([inputs], training=False)
  # dest_model is the refactored model.
  dest_model = seq2seq_transformer.create_model(self.params, False)
  dest_num_weights = _count_params(dest_model)
  self.assertEqual(src_num_weights, dest_num_weights)
  # With identical weights, both models must produce identical decoded
  # token ids and scores.
  dest_model.set_weights(src_weights)
  dest_model_output = dest_model([inputs], training=False)
  self.assertAllEqual(src_model_output[0], dest_model_output[0])
  self.assertAllEqual(src_model_output[1], dest_model_output[1])
def test_forward_pass_train(self):
  # Set input_len different from target_len.
  inputs = np.asarray([[5, 2, 1], [7, 5, 0], [1, 4, 0], [7, 5, 11]])
  targets = np.asarray([[4, 3, 4, 0], [13, 19, 17, 8], [20, 14, 1, 2],
                        [5, 7, 3, 0]])
  # src_model is the original model before the refactoring.
  src_model = transformer.create_model(self.params, True)
  src_num_weights = _count_params(src_model)
  src_weights = src_model.get_weights()
  src_model_output = src_model([inputs, targets], training=True)
  # dest_model is the refactored model.
  dest_model = seq2seq_transformer.create_model(self.params, True)
  dest_num_weights = _count_params(dest_model)
  self.assertEqual(src_num_weights, dest_num_weights)
  # With identical weights, the refactored model must produce identical logits.
  dest_model.set_weights(src_weights)
  dest_model_output = dest_model([inputs, targets], training=True)
  self.assertAllEqual(src_model_output, dest_model_output)