コード例 #1
0
ファイル: reformer_test.py プロジェクト: wangleiphy/trax
 def test_reformer_lm_forward_shape(self):
     """Check the output shapes reported for a small ReformerLM.

     The model is given two identical (1, 8) int32 signatures and both
     reported output shapes are expected to be (1, 8, vocab_size).
     """
     vocab_size = 16
     token_sd = ShapeDtype((1, 8), np.int32)
     hparams = dict(
         d_model=32,
         d_ff=64,
         d_attention_key=16,
         d_attention_value=16,
         n_layers=1,
         n_heads=2,
         max_len=16,
         n_chunks=2,
         n_attention_chunks=1,
     )
     model = reformer.ReformerLM(vocab_size, **hparams)
     output_shapes = tl.check_shape_agreement(model, (token_sd, token_sd))
     per_output = (1, 8, vocab_size)
     self.assertEqual((per_output, per_output), output_shapes)
コード例 #2
0
 def _test_transformer_forward_shape(self, input_vocab_size,
                                     output_vocab_size):
     """Build a small Transformer and check the shape of its first output.

     Args:
       input_vocab_size: Vocabulary size passed to the model as its first
         argument.
       output_vocab_size: Vocabulary size passed as the second argument;
         when it is None the expected last dimension is input_vocab_size.
     """
     batch_and_length = (3, 5)
     model = transformer.Transformer(
         input_vocab_size,
         output_vocab_size,
         d_model=32,
         d_ff=64,
         n_layers=2,
         n_heads=2)
     output_shapes = tl.check_shape_agreement(
         model, (batch_and_length, batch_and_length), integer_inputs=True)
     if output_vocab_size is None:
         last_dim = input_vocab_size
     else:
         last_dim = output_vocab_size
     # Only the first reported shape is compared against the expectation.
     self.assertEqual(batch_and_length + (last_dim,), output_shapes[0])
コード例 #3
0
 def test_rnnlm_forward_shape(self):
   """Check that RNNLM maps a (3, 28) int32 signature to (3, 28, vocab)."""
   vocab_size = 20
   token_sd = ShapeDtype((3, 28), dtype=math.numpy.int32)
   model = rnn.RNNLM(vocab_size=vocab_size, d_model=16)
   self.assertEqual((3, 28, vocab_size),
                    tl.check_shape_agreement(model, token_sd))
コード例 #4
0
ファイル: mlp_test.py プロジェクト: zzszmyf/trax
 def test_mlp_forward_shape(self):
     """Check that MLP maps a (3, 28, 28, 1) signature to shape (3, 10)."""
     n_classes = 10
     image_sd = ShapeDtype((3, 28, 28, 1))
     model = mlp.MLP(d_hidden=32, n_output_classes=n_classes)
     self.assertEqual((3, n_classes),
                      tl.check_shape_agreement(model, image_sd))
コード例 #5
0
 def test_wide_resnet(self):
     """Check that WideResnet maps a (3, 32, 32, 3) input to shape (3, 10)."""
     n_classes = 10
     model = resnet.WideResnet(n_blocks=1, n_output_classes=n_classes)
     output_shape = tl.check_shape_agreement(model, (3, 32, 32, 3))
     self.assertEqual((3, n_classes), output_shape)
コード例 #6
0
 def test_resnet(self):
     """Check that Resnet50 maps a (3, 256, 256, 3) input to shape (3, 10)."""
     n_classes = 10
     model = resnet.Resnet50(d_hidden=8, n_output_classes=n_classes)
     output_shape = tl.check_shape_agreement(model, (3, 256, 256, 3))
     self.assertEqual((3, n_classes), output_shape)
コード例 #7
0
ファイル: mlp_test.py プロジェクト: youngjt/trax
 def test_pure_mlp_forward_shape(self):
     """Check that PureMLP maps a (7, 28, 28, 3) signature to shape (7, 8)."""
     hidden_dims = (32, 16, 8)
     image_sd = ShapeDtype((7, 28, 28, 3))
     model = mlp.PureMLP(hidden_dims=hidden_dims)
     output_shape = tl.check_shape_agreement(model, image_sd)
     # The expected trailing dimension equals the last entry of hidden_dims.
     self.assertEqual((7, hidden_dims[-1]), output_shape)
コード例 #8
0
 def test_image_fec(self):
     """Check that ImageFEC maps a (3, 256, 256, 3) signature to (3, 10)."""
     n_classes = 10
     image_sd = ShapeDtype((3, 256, 256, 3))
     model = ImageFEC(d_hidden=8, n_output_classes=n_classes)
     self.assertEqual((3, n_classes),
                      tl.check_shape_agreement(model, image_sd))