Exemplo n.º 1
0
    def test_forward_success(self, forward_mock, module_dict_get, tensor_dict_get, tensor_dict_set):
        """TensorInserterForward reads the source tensor, runs the module on it and stores the output."""
        states = self.dummy_states_tensor

        # Arrange: a dummy network whose forward pass is replaced by the mock.
        net: nn.Module = DummyNet()
        forward_mock.return_value = states
        net.forward = forward_mock

        modules = ModuleDict()
        module_dict_get.return_value = net
        modules.get = module_dict_get

        tensors = TensorDict()
        tensor_dict_get.return_value = states
        tensor_dict_set.return_value = None
        tensors.get = tensor_dict_get
        tensors.set = tensor_dict_set

        arrays = ArrayDict(0)

        # Act
        inserter = TensorInserterForward(TensorKey.states_tensor, ModuleKey.scaler, TensorKey.next_states_tensor)
        inserter.insert_tensor(tensors, arrays, modules, np.arange(N_EXAMPLES))

        # Assert: one lookup of each kind, one forward pass, one store.
        tensor_dict_get.assert_called_once_with(TensorKey.states_tensor)
        module_dict_get.assert_called_once_with(ModuleKey.scaler)
        forward_mock.assert_called_once_with(states)
        tensor_dict_set.assert_called_once()
Exemplo n.º 2
0
    def test_composition_success(self, forward_mock, array_dict_get, module_dict_get, tensor_dict_get, tensor_dict_set):
        """Adding two tensor inserters runs both: tensorize first, then the forward pass."""
        # mock
        array_dict = ArrayDict(N_EXAMPLES)
        array_dict_get.return_value = self.dummy_states
        array_dict.get = array_dict_get

        tensor_dict = TensorDict()
        tensor_dict_get.return_value = self.dummy_states_tensor
        tensor_dict_set.return_value = None
        tensor_dict.get = tensor_dict_get
        tensor_dict.set = tensor_dict_set

        net: nn.Module = DummyNet()
        forward_mock.return_value = self.dummy_states_tensor
        net.forward = forward_mock

        module_dict = ModuleDict()
        module_dict_get.return_value = net
        module_dict.get = module_dict_get

        # run
        tensor_inserter1 = TensorInserterTensorize(ArrayKey.states, TensorKey.states_tensor)
        tensor_inserter2 = TensorInserterForward(TensorKey.states_tensor, ModuleKey.scaler, TensorKey.states_tensor)
        tensor_inserter = tensor_inserter1 + tensor_inserter2
        tensor_inserter.insert_tensor(tensor_dict, array_dict, module_dict, np.arange(N_EXAMPLES))

        # assert
        # Must be `assert_called_once_with`: a bare `called_once_with` is just an
        # auto-created Mock attribute and asserts nothing (Python 3.12+ raises
        # AttributeError for it).
        array_dict_get.assert_called_once_with(ArrayKey.states)
        # One `set` from the tensorize step plus one from the forward step; only
        # the forward step reads back from the tensor dict.
        self.assertEqual(tensor_dict_set.call_count, 2)
        self.assertEqual(tensor_dict_get.call_count, 1)
        tensor_dict_get.assert_called_once_with(TensorKey.states_tensor)
        module_dict_get.assert_called_once_with(ModuleKey.scaler)
        forward_mock.assert_called_once()
 def insert_tensor(self, tensor_dict: TensorDict, array_dict: ArrayDict,
                   module_dict: ModuleDict, batch_idx: np.ndarray):
     """Run the module under ``self.module_key`` on a stored tensor and store its output.

     The tensor under ``self.source_key`` and the module under
     ``self.module_key`` are both moved to ``device``. The module's
     ``forward`` output is written under ``self.target_key``.
     ``array_dict`` and ``batch_idx`` are not used by this inserter.

     :return: the same ``tensor_dict`` that was passed in (mutated in place).
     """
     inputs = tensor_dict.get(self.source_key)
     inputs = inputs.to(device)
     net = module_dict.get(self.module_key)
     net = net.to(device)
     tensor_dict.set(self.target_key, net.forward(inputs))
     return tensor_dict
 def insert_tensor(self, tensor_dict: TensorDict, array_dict: ArrayDict,
                   module_dict: ModuleDict, batch_idx: np.ndarray):
     """Apply ``self.f`` to the tensors under ``self.source_keys`` and store the result.

     Each source tensor is moved to ``device``. The tensors are then passed to
     ``self.f`` positionally, in ``self.source_keys`` order, and the return
     value is written under ``self.target_key``. ``array_dict``,
     ``module_dict`` and ``batch_idx`` are not used by this inserter.

     :return: the same ``tensor_dict`` that was passed in (mutated in place).
     """
     args = tuple(tensor_dict.get(key).to(device) for key in self.source_keys)
     tensor_dict.set(self.target_key, self.f(*args))
     return tensor_dict
 def insert_tensor(self, tensor_dict: TensorDict, array_dict: ArrayDict,
                   module_dict: ModuleDict, batch_idx: np.ndarray):
     """Convert the batch rows of a stored array into a tensor and store it.

     The array under ``self.array_key`` is indexed with ``batch_idx``,
     converted with ``torch.as_tensor`` using ``self.dtype``, and moved to
     ``device``. A 1-D result is turned into an ``(N, 1)`` column before being
     written under ``self.tensor_key``. ``module_dict`` is not used here.

     :return: the same ``tensor_dict`` that was passed in (mutated in place).
     """
     rows = array_dict.get(self.array_key)[batch_idx]
     tensor = torch.as_tensor(rows, dtype=self.dtype).to(device)
     # Promote a flat vector to a single-column matrix.
     if tensor.dim() == 1:
         tensor = tensor.unsqueeze(1)
     tensor_dict.set(self.tensor_key, tensor)
     return tensor_dict
Exemplo n.º 6
0
    def test_lambda_success(self, get_mock):
        """LossCalculatorLambda fetches every input key and applies the function to the values."""
        # Arrange: every lookup yields 1, so the summing lambda should produce 2.
        tensor_dict = TensorDict()
        get_mock.return_value = 1
        tensor_dict.get = get_mock
        keys = [TensorKey.states_tensor, TensorKey.next_states_tensor]

        # Act
        calculator: LossCalculator = LossCalculatorLambda(keys, lambda x, y: x + y)
        loss = calculator.calculate_loss(tensor_dict)

        # Assert: each key read, and the function applied to the fetched values.
        self.assertEqual(get_mock.call_count, len(keys))
        for key in keys:
            get_mock.assert_any_call(key)
        self.assertEqual(loss, 2)
Exemplo n.º 7
0
    def test_lambda_success(self, get_mock, set_mock):
        """TensorInserterLambda fetches each source tensor, applies the function and stores the result."""
        # mock
        array_dict = ArrayDict(N_EXAMPLES)
        module_dict = ModuleDict()
        tensor_dict = TensorDict()
        # insert_tensor calls `.to(device)` on every fetched value, so the stubbed
        # lookup must return a tensor; a plain int has no `.to` and would raise.
        get_mock.return_value = torch.as_tensor(1)
        tensor_dict.get = get_mock
        tensor_dict.set = set_mock

        # run
        tensor_inserter: TensorInserter = TensorInserterLambda([TensorKey.states_tensor, TensorKey.actions_tensor],
                                                               lambda x, y: x + y, TensorKey.dones_tensor)
        tensor_inserter.insert_tensor(tensor_dict, array_dict, module_dict, np.arange(N_EXAMPLES))

        # assert
        self.assertEqual(get_mock.call_count, 2)
        get_mock.assert_any_call(TensorKey.states_tensor)
        get_mock.assert_any_call(TensorKey.actions_tensor)
        set_mock.assert_called_once()
        # Compare the stored tensor by value rather than relying on tensor `==`
        # inside Mock's call comparison.
        target_key, output_tensor = set_mock.call_args[0]
        self.assertEqual(target_key, TensorKey.dones_tensor)
        self.assertEqual(output_tensor.item(), 2)
Exemplo n.º 8
0
    def test_mse_success(self, forward_mock, get_mock):
        """LossCalculatorInputTarget reads the input and target tensors and runs the loss module once."""
        # Arrange: an MSE module whose forward pass is stubbed to return 0.
        mse: nn.Module = nn.MSELoss()
        forward_mock.return_value = torch.as_tensor(0).float()
        mse.forward = forward_mock

        tensor_dict: TensorDict = TensorDict()
        tensor_dict.get = get_mock

        input_key, target_key = TensorKey.states_tensor, TensorKey.next_states_tensor
        loss_calculator: LossCalculator = LossCalculatorInputTarget(input_key, target_key, mse)

        # Act
        loss_calculator.calculate_loss(tensor_dict)

        # Assert: two lookups covering both keys, and a single loss evaluation.
        self.assertEqual(get_mock.call_count, 2)
        for key in (target_key, input_key):
            get_mock.assert_any_call(key)
        forward_mock.assert_called_once()
Exemplo n.º 9
0
 def train_one_cycle(self, module_dict: ModuleDict):
     """
     Run one training cycle over all trainees.

     A cycle refers to the cycle through the trainees.
     By default, one sample collection is done per cycle.

     For each trainee and each of its ``n_epochs`` epochs, the collected
     examples are shuffled and split into ``n_examples // self.batch_size``
     near-equal batches (at least one). Each batch goes through the trainee's
     tensor inserter, loss calculator and module updater in turn.

     :param module_dict: modules handed to each trainee's tensor inserter.
     :raises ValueError: if the sample collector returns no examples.
     """
     array_dict: ArrayDict = self.sample_collector.collect_samples_by_number()
     n_examples = array_dict.n_examples
     if n_examples == 0:
         raise ValueError("cannot train on an empty sample collection")
     # Clamp to one batch so that a collection smaller than `batch_size` is still
     # trained on; np.array_split rejects 0 sections.
     n_batches = max(1, n_examples // self.batch_size)
     for trainee in self.trainees:
         for _ in range(trainee.n_epochs):
             # Sampling without replacement over all indices = a shuffle.
             all_idxs = np.random.choice(n_examples, n_examples, replace=False)
             batch_idxs = np.array_split(all_idxs, n_batches)
             tensor_dict: TensorDict = TensorDict()
             for batch_idx in batch_idxs:
                 trainee.tensor_inserter.insert_tensor(
                     tensor_dict, array_dict, module_dict, batch_idx)
                 loss = trainee.loss_calculator.calculate_loss(tensor_dict)
                 trainee.module_updater.update_module(loss)
Exemplo n.º 10
0
    def test_tensorize_success(self, array_dict_get, tensor_dict_get, tensor_dict_set):
        """TensorInserterTensorize converts the selected array rows into a tensor and stores it."""
        # mock
        array_dict: ArrayDict = ArrayDict(N_EXAMPLES)
        array_dict_get.return_value = self.dummy_states
        array_dict.get = array_dict_get
        tensor_dict: TensorDict = TensorDict()
        tensor_dict_get.return_value = self.dummy_states_tensor
        tensor_dict_set.return_value = None
        tensor_dict.get = tensor_dict_get
        tensor_dict.set = tensor_dict_set

        # run
        batch_idx = np.arange(N_EXAMPLES)
        tensor_inserter: TensorInserter = TensorInserterTensorize(ArrayKey.states, TensorKey.states_tensor, torch.float)
        returned_dict = tensor_inserter.insert_tensor(tensor_dict, array_dict, self.dummy_module_dict, batch_idx)

        # assert
        self.assertIs(returned_dict, tensor_dict)
        array_dict_get.assert_called_once_with(ArrayKey.states)
        tensor_dict_set.assert_called_once()

        # `tensor_dict.get` is mocked to return a fixture, so reading it back would
        # only compare two fixtures. Inspect the tensor actually handed to `set`.
        key, inserted = tensor_dict_set.call_args[0]
        self.assertEqual(key, TensorKey.states_tensor)
        expected = np.asarray(self.dummy_states)[batch_idx]
        # 1-D arrays are stored as (N, 1) columns; higher-rank arrays keep their shape.
        expected_shape = expected.shape if expected.ndim > 1 else (expected.shape[0], 1)
        self.assertEqual(tuple(inserted.shape), expected_shape)
        np.testing.assert_array_almost_equal(inserted.detach().cpu().numpy().reshape(expected.shape), expected)
Exemplo n.º 11
0
    def test_composition_success(self, get_mock, forward_mock):
        """Adding two loss calculators evaluates both, each reading its own input/target pair."""
        # Arrange: one shared MSE module whose forward pass is stubbed to return 0.
        mse: nn.Module = nn.MSELoss()
        forward_mock.return_value = torch.as_tensor(0).float()
        mse.forward = forward_mock

        tensor_dict: TensorDict = TensorDict()
        tensor_dict.get = get_mock

        first: LossCalculator = LossCalculatorInputTarget(TensorKey.states_tensor,
                                                          TensorKey.next_states_tensor, mse)
        second: LossCalculator = LossCalculatorInputTarget(TensorKey.actions_tensor,
                                                           TensorKey.dones_tensor, mse)
        combined = first + second

        # Act
        combined.calculate_loss(tensor_dict)

        # Assert: four distinct keys read and two loss evaluations.
        self.assertEqual(get_mock.call_count, 4)
        expected_keys = (TensorKey.next_states_tensor, TensorKey.states_tensor,
                         TensorKey.actions_tensor, TensorKey.dones_tensor)
        for key in expected_keys:
            get_mock.assert_any_call(key)
        self.assertEqual(forward_mock.call_count, 2)
Exemplo n.º 12
0
 def calculate_loss(self, tensor_dict: TensorDict):
     """Compute the loss by applying ``self.f`` to the tensors under ``self.input_keys``.

     :param tensor_dict: source of the inputs, read in ``self.input_keys`` order
         and passed to ``self.f`` positionally.
     :return: whatever ``self.f`` returns.
     """
     return self.f(*(tensor_dict.get(key) for key in self.input_keys))
Exemplo n.º 13
0
 def calculate_loss(self, tensor_dict: TensorDict):
     """Compute ``self.weight`` times the loss module's output on an input/target pair.

     :param tensor_dict: provides the tensors under ``self.input_key`` and
         ``self.target_key``, passed to ``self.loss_module.forward`` in that order.
     :return: the weighted loss.
     """
     prediction = tensor_dict.get(self.input_key)
     target = tensor_dict.get(self.target_key)
     unweighted = self.loss_module.forward(prediction, target)
     return self.weight * unweighted