def test_step_with_done(self, batch):
    """A terminal step records eval metrics only when inputs are batched.

    One environment (batched) or the single environment (unbatched) is
    marked done with a random reward.  In the batched case ``metrics.add``
    must be called exactly twice: first with ``('eval_episode', 1)``, then
    with ``('eval_reward', reward)``.  In the unbatched case it must not be
    called at all.
    """
    output = make_output()
    self.network.infer = MagicMock(return_value=output)
    self.metrics.add = MagicMock()
    self.metrics.get = MagicMock(return_value=1)
    reward = np.random.random()

    if not batch:
        inpt = list(make_input())
        inpt[2] = 1.0
        inpt[3]['reward'] = reward
    else:
        inpt = list(make_input(batch_size=4, batch=True))
        done_index = np.random.randint(4)
        inpt[2] = np.zeros((4, ))
        inpt[2][done_index] = 1.0
        inpt[3][done_index]['reward'] = reward

    self.controller.step(*inpt)

    if not batch:
        self.metrics.add.assert_not_called()
        return

    recorded = self.metrics.add.mock_calls
    assert self.metrics.add.call_count == 2
    # Each mock call unpacks as (name, args, kwargs); only args matter here.
    _, reward_args, _ = recorded[1]
    _, episode_args, _ = recorded[0]
    assert reward_args == ('eval_reward', reward)
    assert episode_args == ('eval_episode', 1)
def test_batches(self):
    """``_batches`` yields 32-row minibatches with the expected per-key shapes.

    The rollout is filled with 129 batched steps (4 environments each).
    Then, for every training key, the test checks four things:
    - every minibatch contains the key;
    - every minibatch has 32 rows;
    - the trailing shape matches the step input or action (or the value
      is 1-D for scalar keys);
    - exactly ``128 * 4 // 32`` minibatches are produced.
    """
    output = make_output(batch_size=4, batch=True)
    self.network._infer_arguments = MagicMock(return_value=['obs_t'])
    # The inference stub never changes, so install it once rather than on
    # every step.
    self.network._infer = MagicMock(return_value=output)
    for _ in range(129):
        inpt = make_input(batch_size=4, batch=True)
        action = self.controller.step(*inpt)

    # Keys whose per-sample value is a scalar, so each minibatch is 1-D.
    scalar_keys = {'log_probs_t', 'returns_t', 'advantages_t', 'values_t'}
    keys = ['obs_t', 'actions_t', 'log_probs_t', 'returns_t',
            'advantages_t', 'values_t']
    for key in keys:
        count = 0
        for batch in self.controller._batches():
            count += 1
            assert key in batch
            assert batch[key].shape[0] == 32
            if key == 'obs_t':
                # `inpt` / `action` are those of the last step taken above.
                assert batch[key].shape[1:] == inpt[0].shape[1:]
            elif key == 'actions_t':
                assert batch[key].shape[1] == action.shape[1]
            elif key in scalar_keys:
                assert len(batch[key].shape) == 1
        assert count == 128 * 4 // 32
def test_batch_with_short_trajectory_error(self):
    """``_batches`` raises ``AssertionError`` when the rollout has one step."""
    output = make_output(batch_size=4, batch=True)
    self.network._infer_arguments = MagicMock(return_value=['obs_t'])
    self.network._infer = MagicMock(return_value=output)
    inpt = make_input(batch_size=4, batch=True)
    self.controller.step(*inpt)
    with pytest.raises(AssertionError):
        # Consume the result: if ``_batches`` is a generator function, its
        # body (and the assertion under test) only runs once iterated.  If it
        # asserts eagerly, the error is raised before ``list`` is reached.
        list(self.controller._batches())
def test_update_with_should_update_false(self):
    """``update`` raises ``AssertionError`` after only 20 batched steps."""
    inpt = make_input(batch_size=4, batch=True)
    output = make_output(batch_size=4, batch=True)
    self.network._infer = MagicMock(return_value=output)
    self.network._infer_arguments = MagicMock(return_value=['obs_t'])

    n_steps = 20
    for _ in range(n_steps):
        self.controller.step(*inpt)

    with pytest.raises(AssertionError):
        self.controller.update()
def test_should_update(self):
    """``should_update`` stays False for 128 steps and turns True on the 129th."""
    output = make_output(batch_size=4, batch=True)
    self.network._infer = MagicMock(return_value=output)
    self.network._infer_arguments = MagicMock(return_value=['obs_t'])
    inpt = make_input(batch_size=4, batch=True)

    for _ in range(128):
        self.controller.step(*inpt)
        assert not self.controller.should_update()

    self.controller.step(*inpt)
    assert self.controller.should_update()
def test_step_with_eval_episode_over_limit(self):
    """A finished episode records nothing once ``metrics.get`` reports 10.

    ``metrics.get`` is stubbed to return 10 (presumably the eval-episode
    count, which the test name suggests has reached its limit).  A batched
    step that marks one environment done must then not call ``metrics.add``.
    """
    output = make_output()
    self.network.infer = MagicMock(return_value=output)
    # Raising from ``add`` makes an unexpected call fail at the call site.
    # The explicit check at the end also catches the call if the controller
    # happens to swallow the exception.
    self.metrics.add = MagicMock(side_effect=Exception)
    self.metrics.get = MagicMock(return_value=10)
    inpt = list(make_input(batch_size=4, batch=True))
    index = np.random.randint(4)
    reward = np.random.random()
    inpt[2] = np.zeros((4, ))
    inpt[2][index] = 1.0
    inpt[3][index]['reward'] = reward
    self.controller.step(*inpt)
    self.metrics.add.assert_not_called()
def test_update_success(self):
    """A full rollout trains, returns the stubbed loss, and empties the rollout."""
    inpt = make_input(batch_size=4, batch=True)
    output = make_output(batch_size=4, batch=True)
    loss = np.random.random()
    update_arguments = ['obs_t', 'actions_t', 'returns_t', 'advantages_t',
                        'log_probs_t']

    network = self.network
    network._infer = MagicMock(return_value=output)
    network._infer_arguments = MagicMock(return_value=['obs_t'])
    network._update_arguments = MagicMock(return_value=update_arguments)
    network._update = MagicMock(return_value=loss)

    for _ in range(129):
        self.controller.step(*inpt)

    result = self.controller.update()
    assert np.allclose(result, loss)
    assert self.rollout.size() == 0
    # 128 * 4 samples in minibatches of 32, times a factor of 4
    # (presumably the number of optimisation epochs -- confirm in controller).
    expected_update_calls = 128 * 4 * 4 // 32
    assert network._update.call_count == expected_update_calls
def test_step(self):
    """One batched step returns the inferred action and stores one rollout entry.

    The stored observation, reward and terminal must equal the step inputs.
    The stored action, value and log-prob must equal the network output.
    """
    output = make_output(batch_size=4, batch=True)
    self.network._infer = MagicMock(return_value=output)
    self.network._infer_arguments = MagicMock(return_value=['obs_t'])
    inpt = make_input(batch_size=4, batch=True)

    action = self.controller.step(*inpt)
    assert np.all(action == output.action)

    rollout = self.rollout
    assert rollout.size() == 1
    expected_and_stored = [
        (inpt[0], rollout.obs_t[0]),
        (inpt[1], rollout.rewards_t[0]),
        (inpt[2], rollout.terminals_t[0]),
        (output.action, rollout.actions_t[0]),
        (output.value, rollout.values_t[0]),
        (output.log_prob, rollout.log_probs_t[0]),
    ]
    for expected, stored in expected_and_stored:
        assert np.all(expected == stored)
def test_step(self):
    """Each step stores a transition; the first also counts a step and hits noise.

    After one step the buffer holds one item, the returned action equals the
    network output, ``metrics.add('step', 1)`` was called once, and the noise
    stub was invoked once.  A second step grows the buffer to two.
    """
    output = make_output()
    self.noise.mock = MagicMock()
    self.network._infer = MagicMock(return_value=output)
    self.network._infer_arguments = MagicMock(return_value=['obs_t'])
    self.metrics.add = MagicMock()
    inpt = make_input()

    first_action = self.controller.step(*inpt)
    assert self.buffer.size() == 1
    assert np.all(output.action == first_action)
    self.metrics.add.assert_called_once_with('step', 1)
    assert self.noise.mock.call_count == 1

    self.controller.step(*inpt)
    assert self.buffer.size() == 2
def test_step(self, batch):
    """A non-terminal step returns the inferred action object unchanged.

    Inference runs exactly once.  ``metrics.get`` is queried 4 times with
    batched inputs and never with unbatched ones.
    """
    output = make_output()
    self.network.infer = MagicMock(return_value=output)
    self.metrics.get = MagicMock(return_value=0)

    raw = make_input(batch_size=4, batch=True) if batch else make_input()
    inpt = list(raw)
    inpt[2] = np.zeros((4, )) if batch else 0.0

    step_output = self.controller.step(*inpt)
    assert step_output is output.action
    assert self.network.infer.call_count == 1
    expected_get_calls = 4 if batch else 0
    assert self.metrics.get.call_count == expected_get_calls
def test_update(self):
    """After 33 steps, one ``update`` trains once and records both losses."""
    critic_loss = np.random.random()
    actor_loss = np.random.random()
    output = make_output()
    update_arguments = ['obs_t', 'actions_t', 'rewards_tp1', 'obs_tp1',
                        'dones_tp1']

    network = self.network
    network._update = MagicMock(return_value=(critic_loss, actor_loss))
    network._update_arguments = MagicMock(return_value=update_arguments)
    network._infer_arguments = MagicMock(return_value=['obs_t'])
    network._infer = MagicMock(return_value=output)
    self.metrics.add = MagicMock()
    self.controller._record_update_metrics = MagicMock()

    for _ in range(33):
        self.controller.step(*make_input())

    self.controller.update()

    assert network._update.call_count == 1
    self.controller._record_update_metrics.assert_called_once_with(
        critic_loss, actor_loss)