Code Example #1
 @classmethod
 def setUpClass(cls):
     """
     Avoid redundant, time-consuming, equivalent setups across the
     different test methods, which can share common instantiations.
     """
     cls.layer = MultiHeadAttention(
         n_attention_heads=N_ATTENTION_HEADS,
         token_representation_dimension=REPRESENTATION_DIMENSION,
         dropout_prob=DROPOUT_PROB)
     cls.forward_propagation_kwargs = {
         'query_tokens':
         torch_rand(size=(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH - 1,
                          REPRESENTATION_DIMENSION),
                    dtype=torch_float),
         'key_or_value_tokens':
         torch_rand(size=(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH,
                          REPRESENTATION_DIMENSION),
                    dtype=torch_float),
         'mask':
         torch_rand(size=(MINI_BATCH_SIZE, 1, MAX_SEQUENCE_LENGTH),
                    dtype=torch_float)
     }
     cls.expected_output_shapes = [
         (MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH - 1,
          REPRESENTATION_DIMENSION)
     ]
     cls.expected_output_dtypes = [torch_float]
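For context, a minimal sketch of a test method that could consume these shared attributes (the same pattern applies to the other setUpClass snippets below); the attribute names come from the snippet above, while the method name and body are only an assumption about how such a test might look:

 def test_forward_propagation(self):
     # Hypothetical test body: run a forward pass with the kwargs shared by
     # setUpClass and compare the output's shape and dtype against the
     # stored expectations.
     output = self.layer(**self.forward_propagation_kwargs)
     self.assertEqual(tuple(output.shape), self.expected_output_shapes[0])
     self.assertEqual(output.dtype, self.expected_output_dtypes[0])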
Code Example #2
 @classmethod
 def setUpClass(cls):
     """
     Avoid redundant, time-consuming, equivalent setups across the
     different test methods, which can share common instantiations.
     """
     feedforward_layer = PositionWiseFeedForward(
         token_representation_dimension=REPRESENTATION_DIMENSION,
         feedforward_dimension=FEEDFORWARD_DIMENSION,
         dropout_prob=DROPOUT_PROB)
     multi_head_attention_layer = MultiHeadAttention(
         n_attention_heads=N_ATTENTION_HEADS,
         token_representation_dimension=REPRESENTATION_DIMENSION,
         dropout_prob=DROPOUT_PROB)
     cls.layer = EncoderBlock(building_blocks=EncoderBlockBuildingBlocks(
         self_multi_head_attention_layer=deepcopy(
             multi_head_attention_layer),
         fully_connected_layer=feedforward_layer),
                              feature_dimension=REPRESENTATION_DIMENSION,
                              dropout_prob=DROPOUT_PROB)
     cls.forward_propagation_kwargs = {
         'src_features':
         torch_rand(size=(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH,
                          REPRESENTATION_DIMENSION),
                    dtype=torch_float),
         'src_mask':
         torch_rand(size=(MINI_BATCH_SIZE, 1, MAX_SEQUENCE_LENGTH),
                    dtype=torch_float)
     }
     cls.expected_output_shapes = [(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH,
                                    REPRESENTATION_DIMENSION)]
     cls.expected_output_dtypes = [torch_float]
Code Example #3
 def test_pong_duel_converter(self):
     mapper = lambda actions: torch_round(actions * 1.5 + 1).flatten(
     ).tolist()
     action_tensor = torch_rand((2, )) * 2 - 1
     action_list = mapper(action_tensor)
     self.assertEqual(len(action_list), 2,
                      "the action list is not of dim 2")
     self.assertTrue(all([0 <= i <= 2 for i in action_list]),
                     "the action list is not mapped to [0, 2]")
     print("Origin:", action_tensor)
     print("Action:", action_list)
Code Example #4
 def test_agent_step(self):
     wrapper = VectorizedMultiAgentEnvWrapper.MultiAgentEnvWrapper(
         env="PongDuel-v0",
         mapper=lambda actions: torch_round(actions * 1.5 + 1).flatten(
         ).tolist())
     wrapper.reset()
     mock_action = torch_rand(size=(2, ))
     states, reward, done, _ = wrapper.step(mock_action)
     self.assertTrue(isinstance(states, Tensor))
     self.assertTrue(isinstance(reward, Tensor))
     self.assertTrue(isinstance(done, list))
     self.assertEqual(states.shape, (2, 10))
     self.assertEqual(reward.shape, (1, 2))
Code Example #5
 @classmethod
 def setUpClass(cls):
     """
     Avoid redundant, time-consuming, equivalent setups across the
     different test methods, which can share common instantiations.
     """
     cls.layer = LayerNorm(feature_dimension=REPRESENTATION_DIMENSION)
     cls.forward_propagation_kwargs = {
         'features':
         torch_rand(size=(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH,
                          REPRESENTATION_DIMENSION),
                    dtype=torch_float)
     }
     cls.expected_output_shapes = [(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH,
                                    REPRESENTATION_DIMENSION)]
     cls.expected_output_dtypes = [torch_float]
Code Example #6
 @classmethod
 def setUpClass(cls):
     """
     Avoid redundant, time-consuming, equivalent setups across the
     different test methods, which can share common instantiations.
     """
     cls.layer = LogSoftmax(
         token_representation_dimension=REPRESENTATION_DIMENSION,
         vocabulary_dimension=TGT_VOCABULARY_DIMENSION)
     cls.forward_propagation_kwargs = {
         'logits':
         torch_rand(size=(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH,
                          REPRESENTATION_DIMENSION),
                    dtype=torch_float)
     }
     cls.expected_output_shapes = [(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH,
                                    TGT_VOCABULARY_DIMENSION)]
     cls.expected_output_dtypes = [torch_float]
Code Example #7
 @classmethod
 def setUpClass(cls):
     """
     Avoid redundant, time-consuming, equivalent setups across the
     different test methods, which can share common instantiations.
     """
     cls.layer = PositionWiseFeedForward(
         token_representation_dimension=REPRESENTATION_DIMENSION,
         feedforward_dimension=FEEDFORWARD_DIMENSION,
         dropout_prob=DROPOUT_PROB)
     cls.forward_propagation_kwargs = {
         'features':
         torch_rand(size=(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH,
                          REPRESENTATION_DIMENSION),
                    dtype=torch_float)
     }
     cls.expected_output_shapes = [(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH,
                                    REPRESENTATION_DIMENSION)]
     cls.expected_output_dtypes = [torch_float]
Code Example #8
 @classmethod
 def setUpClass(cls):
     """
     Avoid redundant, time-consuming, equivalent setups across the
     different test methods, which can share common instantiations.
     """
     cls.layer = PositionalEncoding(
         token_representation_dimension=REPRESENTATION_DIMENSION,
         dropout_prob=DROPOUT_PROB,
         max_sequence_length=MAX_SEQUENCE_LENGTH)
     cls.forward_propagation_kwargs = {
         'token_embeddings':
         torch_rand(size=(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH,
                          REPRESENTATION_DIMENSION),
                    dtype=torch_float)
     }
     cls.expected_output_shapes = [(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH,
                                    REPRESENTATION_DIMENSION)]
     cls.expected_output_dtypes = [torch_float]
Code Example #9
 @classmethod
 def setUpClass(cls):
     """
     Avoid redundant, time-consuming, equivalent setups across the
     different test methods, which can share common instantiations.
     """
     positional_encoding_layer = PositionalEncoding(
         token_representation_dimension=REPRESENTATION_DIMENSION,
         dropout_prob=DROPOUT_PROB,
         max_sequence_length=MAX_SEQUENCE_LENGTH)
     src_embedder = Sequential(
         Embedder(token_representation_dimension=REPRESENTATION_DIMENSION,
                  vocabulary_dimension=SRC_VOCABULARY_DIMENSION),
         deepcopy(positional_encoding_layer))
     tgt_embedder = Sequential(
         Embedder(token_representation_dimension=REPRESENTATION_DIMENSION,
                  vocabulary_dimension=TGT_VOCABULARY_DIMENSION),
         deepcopy(positional_encoding_layer))
     feedforward_layer = PositionWiseFeedForward(
         token_representation_dimension=REPRESENTATION_DIMENSION,
         feedforward_dimension=FEEDFORWARD_DIMENSION,
         dropout_prob=DROPOUT_PROB)
     multi_head_attention_layer = MultiHeadAttention(
         n_attention_heads=N_ATTENTION_HEADS,
         token_representation_dimension=REPRESENTATION_DIMENSION,
         dropout_prob=DROPOUT_PROB)
     encoder = Encoder(base_block=EncoderBlock(
         building_blocks=EncoderBlockBuildingBlocks(
             self_multi_head_attention_layer=deepcopy(
                 multi_head_attention_layer),
             fully_connected_layer=feedforward_layer),
         feature_dimension=REPRESENTATION_DIMENSION,
         dropout_prob=DROPOUT_PROB),
                       n_clones=N_ENCODER_BLOCKS)
     decoder = Decoder(base_block=DecoderBlock(
         building_blocks=DecoderBlockBuildingBlocks(
             self_multi_head_attention_layer=deepcopy(
                 multi_head_attention_layer),
             source_multi_head_attention_layer=deepcopy(
                 multi_head_attention_layer),
             fully_connected_layer=feedforward_layer),
         feature_dimension=REPRESENTATION_DIMENSION,
         dropout_prob=DROPOUT_PROB),
                       n_clones=N_DECODER_BLOCKS)
     log_softmax_layer = LogSoftmax(
         token_representation_dimension=REPRESENTATION_DIMENSION,
         vocabulary_dimension=TGT_VOCABULARY_DIMENSION)
     building_blocks = Seq2SeqBuildingBlocks(
         encoder=encoder,
         decoder=decoder,
         src_embedder=src_embedder,
         tgt_embedder=tgt_embedder,
         log_softmax_layer=log_softmax_layer)
     cls.layer = Seq2Seq(building_blocks=building_blocks)
     cls.forward_propagation_kwargs = {
         'src_tokens':
         torch_randint(low=0,
                       high=SRC_VOCABULARY_DIMENSION,
                       size=(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH),
                       dtype=torch_long),
         'tgt_tokens':
         torch_randint(low=0,
                       high=TGT_VOCABULARY_DIMENSION,
                       size=(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH - 1),
                       dtype=torch_long),
         'src_mask':
         torch_rand(size=(MINI_BATCH_SIZE, 1, MAX_SEQUENCE_LENGTH),
                    dtype=torch_float),
         'tgt_mask':
         torch_rand(size=(MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH - 1,
                          MAX_SEQUENCE_LENGTH - 1),
                    dtype=torch_float)
     }
     cls.expected_output_shapes = [
         (MINI_BATCH_SIZE, MAX_SEQUENCE_LENGTH - 1,
          REPRESENTATION_DIMENSION)
     ]
     cls.expected_output_dtypes = [torch_float]
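Note that tgt_mask above is just a random float tensor standing in for a real mask; in actual decoding a causal (subsequent-position) mask of the same (MAX_SEQUENCE_LENGTH - 1, MAX_SEQUENCE_LENGTH - 1) shape would typically be used. A minimal sketch, assuming plain torch is importable alongside the torch_* aliases used above:

 import torch

 # Hypothetical causal mask: position i may attend only to positions <= i;
 # the leading singleton dimension broadcasts over the mini-batch.
 tgt_len = MAX_SEQUENCE_LENGTH - 1
 causal_mask = torch.tril(torch.ones(1, tgt_len, tgt_len, dtype=torch.float))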