def predict_step(self, time_step: TimeStep, state: AgentState,
                 epsilon_greedy):
    """Predict for one step."""
    new_state = AgentState()
    observation = time_step.observation
    info = AgentInfo()

    if self._representation_learner is not None:
        repr_step = self._representation_learner.predict_step(
            time_step, state.repr)
        new_state = new_state._replace(repr=repr_step.state)
        info = info._replace(repr=repr_step.info)
        observation = repr_step.output

    if self._goal_generator is not None:
        goal_step = self._goal_generator.predict_step(
            time_step._replace(observation=observation),
            state.goal_generator, epsilon_greedy)
        new_state = new_state._replace(goal_generator=goal_step.state)
        info = info._replace(goal_generator=goal_step.info)
        observation = [observation, goal_step.output]

    rl_step = self._rl_algorithm.predict_step(
        time_step._replace(observation=observation), state.rl,
        epsilon_greedy)
    new_state = new_state._replace(rl=rl_step.state)
    info = info._replace(rl=rl_step.info)

    return AlgStep(output=rl_step.output, state=new_state, info=info)
def _predict_multi_step_cost(self, observation, actions):
    """Compute the total cost by unrolling multiple steps according to
    the given initial observation and multi-step actions.

    Args:
        observation: the current observation for predicting quantities
            of future time steps
        actions (Tensor): a set of action sequences of shape
            [B, population, unroll_steps, action_dim]
    Returns:
        cost (Tensor): negation of accumulated predicted reward, with
            the shape of [B, population]
    """
    batch_size, population_size, num_unroll_steps = actions.shape[0:3]

    state = self.get_initial_predict_state(batch_size)
    time_step = TimeStep()
    dyn_state = state.dynamics._replace(feature=observation)
    dyn_state = nest.map_structure(
        partial(self._expand_to_population,
                population_size=population_size), dyn_state)

    # expand to particles
    dyn_state = nest.map_structure(self._expand_to_particles, dyn_state)
    reward_state = state.reward
    reward = 0
    for i in range(num_unroll_steps):
        action = actions[:, :, i, ...].view(-1, actions.shape[3])
        action = self._expand_to_particles(action)
        time_step = time_step._replace(prev_action=action)
        time_step, dyn_state = self._predict_next_step(
            time_step, dyn_state)
        next_obs = time_step.observation
        # Note: currently using (next_obs, action), might need to
        # consider (obs, action) in order to be more compatible
        # with the conventional definition of the reward function
        reward_step, reward_state = self._calc_step_reward(
            next_obs, action, reward_state)
        reward = reward + reward_step
    cost = -reward

    # reshape cost
    # [B*par, n] -> [B, par*n]
    cost = cost.reshape(
        -1, self._particles_per_replica * self._num_dynamics_replicas)
    cost = cost.mean(-1)

    # reshape cost back to [batch_size, population_size]
    cost = torch.reshape(cost, [batch_size, -1])
    return cost
class DIAYNAlgorithmTest(alf.test.TestCase):
    def setUp(self):
        input_tensor_spec = TensorSpec((10, ))
        self._time_step = TimeStep(
            step_type=torch.tensor(StepType.MID, dtype=torch.int32),
            reward=0,
            discount=1,
            observation=input_tensor_spec.zeros(outer_dims=(1, )),
            prev_action=None,
            env_id=None)
        self._encoding_net = EncodingNetwork(
            input_tensor_spec=input_tensor_spec)

    def test_discrete_skill_loss(self):
        skill_spec = BoundedTensorSpec((),
                                       dtype=torch.int64,
                                       minimum=0,
                                       maximum=3)
        alg = DIAYNAlgorithm(skill_spec=skill_spec,
                             encoding_net=self._encoding_net)
        skill = state = torch.nn.functional.one_hot(
            skill_spec.zeros(outer_dims=(1, )),
            int(skill_spec.maximum - skill_spec.minimum + 1)).to(
                torch.float32)
        alg_step = alg.train_step(
            self._time_step._replace(
                observation=[self._time_step.observation, skill]), state)
        # the discriminator should predict a uniform distribution
        self.assertTensorClose(
            torch.sum(alg_step.info.loss),
            torch.as_tensor(
                math.log(skill_spec.maximum - skill_spec.minimum + 1)),
            epsilon=1e-4)

    def test_continuous_skill_loss(self):
        skill_spec = TensorSpec((4, ))
        alg = DIAYNAlgorithm(skill_spec=skill_spec,
                             encoding_net=self._encoding_net)
        skill = state = skill_spec.zeros(outer_dims=(1, ))
        alg_step = alg.train_step(
            self._time_step._replace(
                observation=[self._time_step.observation, skill]), state)
        # the discriminator should predict a zero skill vector
        self.assertTensorClose(torch.sum(alg_step.info.loss),
                               torch.as_tensor(0))
class ICMAlgorithmTest(alf.test.TestCase):
    def setUp(self):
        self._input_tensor_spec = TensorSpec((10, ))
        self._time_step = TimeStep(
            step_type=StepType.MID,
            reward=0,
            discount=1,
            observation=self._input_tensor_spec.zeros(outer_dims=(1, )),
            prev_action=None,
            env_id=None)
        self._hidden_size = 100

    def test_discrete_action(self):
        action_spec = BoundedTensorSpec((),
                                        dtype=torch.int64,
                                        minimum=0,
                                        maximum=3)
        alg = ICMAlgorithm(action_spec=action_spec,
                           observation_spec=self._input_tensor_spec,
                           hidden_size=self._hidden_size)
        state = self._input_tensor_spec.zeros(outer_dims=(1, ))
        alg_step = alg.train_step(
            self._time_step._replace(
                prev_action=action_spec.zeros(outer_dims=(1, ))), state)
        # the inverse net should predict a uniform distribution
        self.assertTensorClose(
            torch.sum(alg_step.info.loss.extra['inverse_loss']),
            torch.as_tensor(
                math.log(action_spec.maximum - action_spec.minimum + 1)),
            epsilon=1e-4)

    def test_continuous_action(self):
        action_spec = TensorSpec((4, ))
        alg = ICMAlgorithm(action_spec=action_spec,
                           observation_spec=self._input_tensor_spec,
                           hidden_size=self._hidden_size)
        state = self._input_tensor_spec.zeros(outer_dims=(1, ))
        alg_step = alg.train_step(
            self._time_step._replace(
                prev_action=action_spec.zeros(outer_dims=(1, ))), state)
        # the inverse net should predict a zero action vector
        self.assertTensorClose(
            torch.sum(alg_step.info.loss.extra['inverse_loss']),
            torch.as_tensor(0))
def _reset(self):
    self._num_moves.fill_(0)
    self._board.reset_board(self._B)
    self._previous_board = self._board.get_board()
    self._game_over.fill_(False)
    self._prev_action.fill_(self._pass_action)
    return TimeStep(
        observation=OrderedDict(
            board=self._board.get_board().detach().unsqueeze(1),
            prev_action=self._prev_action,
            valid_action_mask=self._get_valid_action_mask(),
            steps=self._num_moves,
            to_play=torch.zeros((self._batch_size), dtype=torch.int8)),
        step_type=torch.full((self._batch_size, ),
                             StepType.FIRST,
                             dtype=torch.int32),
        reward=torch.zeros((self._batch_size, )),
        discount=torch.ones((self._batch_size, )),
        prev_action=self._prev_action,
        env_id=self._env_ids,
        env_info={
            "player0_win": torch.zeros(self._batch_size),
            "player1_win": torch.zeros(self._batch_size),
            "player0_pass": torch.zeros(self._batch_size),
            "player1_pass": torch.zeros(self._batch_size),
            "draw": torch.zeros(self._batch_size),
            "invalid_move": torch.zeros(self._batch_size),
            "too_long": torch.zeros(self._batch_size),
            "bad_move": torch.zeros(self._batch_size),
        })
def rollout_step(self, time_step: TimeStep, state: AgentState):
    """Rollout for one step."""
    new_state = AgentState()
    info = AgentInfo()

    time_step = transform_nest(time_step, "observation",
                               self._observation_transformer)

    subtrajectory = self._skill_generator.update_disc_subtrajectory(
        time_step, state.skill_generator)

    skill_step = self._skill_generator.rollout_step(
        time_step, state.skill_generator)
    new_state = new_state._replace(skill_generator=skill_step.state)
    info = info._replace(skill_generator=skill_step.info)

    observation = self._make_low_level_observation(
        subtrajectory, skill_step.output, skill_step.info.switch_skill,
        skill_step.state.steps,
        skill_step.state.discriminator.first_observation)

    rl_step = self._rl_algorithm.rollout_step(
        time_step._replace(observation=observation), state.rl)
    new_state = new_state._replace(rl=rl_step.state)
    info = info._replace(rl=rl_step.info)

    skill_discount = (
        ((skill_step.state.steps == 1)
         & (time_step.step_type != StepType.LAST)).to(torch.float32) *
        (1 - self._skill_boundary_discount))
    info = info._replace(skill_discount=1 - skill_discount)

    return AlgStep(output=rl_step.output, state=new_state, info=info)
def rollout_step(self, time_step: TimeStep, state):
    if self._reward_normalizer is not None:
        self._reward_normalizer.update(time_step.reward)
        time_step = time_step._replace(
            reward=self._reward_normalizer.normalize(
                time_step.reward, self._reward_clip_value))
    return self._mcts.predict_step(time_step, state)
def predict_step(self, time_step: TimeStep, state: AgentState,
                 epsilon_greedy):
    """Predict for one step."""
    new_state = AgentState()

    time_step = transform_nest(time_step, "observation",
                               self._observation_transformer)

    subtrajectory = self._skill_generator.update_disc_subtrajectory(
        time_step, state.skill_generator)

    skill_step = self._skill_generator.predict_step(
        time_step, state.skill_generator, epsilon_greedy)
    new_state = new_state._replace(skill_generator=skill_step.state)

    observation = self._make_low_level_observation(
        subtrajectory, skill_step.output, skill_step.info.switch_skill,
        skill_step.state.steps,
        skill_step.state.discriminator.first_observation)

    rl_step = self._rl_algorithm.predict_step(
        time_step._replace(observation=observation), state.rl,
        epsilon_greedy)
    new_state = new_state._replace(rl=rl_step.state)

    return AlgStep(output=rl_step.output, state=new_state)
def _gen_time_step(self, s, action):
    step_type = StepType.MID
    discount = 1.0

    if s == 0:
        step_type = StepType.FIRST
    elif s == self._episode_length - 1:
        step_type = StepType.LAST
        discount = 0.0

    if s == 0:
        reward = torch.zeros(self.batch_size)
    else:
        prev_observation = self._current_time_step.observation
        reward = 1.0 - torch.abs(prev_observation -
                                 action.reshape(prev_observation.shape))
        reward = reward.reshape(self.batch_size)
        if self._reward_dim != 1:
            reward = reward.unsqueeze(-1).expand((-1, self._reward_dim))

    observation = torch.randint(0,
                                2,
                                size=(self.batch_size, 1),
                                dtype=torch.float32)

    return TimeStep(step_type=torch.full([self.batch_size],
                                         step_type,
                                         dtype=torch.int32),
                    reward=reward,
                    discount=torch.full([self.batch_size], discount),
                    observation=observation)
def step(self, action):
    prev_step_type = self._current_time_step.step_type
    is_first = prev_step_type == StepType.FIRST
    is_mid = prev_step_type == StepType.MID
    is_last = prev_step_type == StepType.LAST

    step_type = torch.where(
        is_mid & (torch.rand(self._batch_size) < 0.2),
        torch.full([self._batch_size], StepType.LAST, dtype=torch.int32),
        torch.full([self._batch_size], StepType.MID, dtype=torch.int32))
    step_type = torch.where(
        is_last,
        torch.full([self._batch_size], StepType.FIRST, dtype=torch.int32),
        step_type)
    step_type = torch.where(
        is_first,
        torch.full([self._batch_size], StepType.MID, dtype=torch.int32),
        step_type)

    self._current_time_step = TimeStep(
        observation=self._observation_spec.randn([self._batch_size]),
        step_type=step_type,
        reward=self._rewards[action],
        discount=torch.zeros(self._batch_size),
        prev_action=self._prev_action,
        env_id=torch.arange(self._batch_size, dtype=torch.int32))
    self._prev_action = action
    return self._current_time_step
def train_step(self, exp: Experience, state):
    time_step = TimeStep(step_type=exp.step_type,
                         reward=exp.reward,
                         discount=exp.discount,
                         observation=exp.observation,
                         prev_action=exp.prev_action,
                         env_id=exp.env_id)
    return self.rollout_step(time_step, state)
def rollout_step(self, time_step: TimeStep, state: AgentState):
    """Rollout for one step."""
    new_state = AgentState()
    info = AgentInfo()
    observation = time_step.observation

    if self._representation_learner is not None:
        repr_step = self._representation_learner.rollout_step(
            time_step, state.repr)
        new_state = new_state._replace(repr=repr_step.state)
        info = info._replace(repr=repr_step.info)
        observation = repr_step.output

    if self._goal_generator is not None:
        goal_step = self._goal_generator.rollout_step(
            time_step._replace(observation=observation),
            state.goal_generator)
        new_state = new_state._replace(goal_generator=goal_step.state)
        info = info._replace(goal_generator=goal_step.info)
        observation = [observation, goal_step.output]

    rl_step = self._rl_algorithm.rollout_step(
        time_step._replace(observation=observation), state.rl)
    new_state = new_state._replace(rl=rl_step.state)
    info = info._replace(rl=rl_step.info)

    if self._irm is not None:
        irm_step = self._irm.rollout_step(
            time_step._replace(observation=observation), state=state.irm)
        info = info._replace(irm=irm_step.info)
        new_state = new_state._replace(irm=irm_step.state)

    if self._entropy_target_algorithm:
        assert 'action_distribution' in rl_step.info._fields, (
            "AlgStep from rl_algorithm.rollout() does not contain "
            "`action_distribution`, which is required by "
            "`enforce_entropy_target`")
        et_step = self._entropy_target_algorithm.rollout_step(
            rl_step.info.action_distribution,
            step_type=time_step.step_type,
            on_policy_training=self.is_on_policy())
        info = info._replace(entropy_target=et_step.info)

    return AlgStep(output=rl_step.output, state=new_state, info=info)
def reset(self):
    self._prev_action = torch.zeros(self._batch_size, dtype=torch.int64)
    self._current_time_step = TimeStep(
        observation=self._observation_spec.randn([self._batch_size]),
        step_type=torch.full([self._batch_size],
                             StepType.FIRST,
                             dtype=torch.int32),
        reward=torch.zeros(self._batch_size),
        discount=torch.zeros(self._batch_size),
        prev_action=self._prev_action,
        env_id=torch.arange(self._batch_size, dtype=torch.int32))
    return self._current_time_step
def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
    mbp_step = self._mbp.predict_step(inputs=(time_step.observation,
                                              time_step.prev_action),
                                      state=state.mbp_state)
    mba_step = self._mba.predict_step(
        time_step=time_step._replace(observation=mbp_step.output),
        state=state.mba_state,
        epsilon_greedy=epsilon_greedy)
    return AlgStep(output=mba_step.output,
                   state=MerlinState(mbp_state=mbp_step.state,
                                     mba_state=mba_step.state),
                   info=())
def rollout_step(self, time_step: TimeStep, state):
    """Train one step."""
    mbp_step = self._mbp.train_step(
        inputs=(time_step.observation, time_step.prev_action),
        state=state.mbp_state)
    mba_step = self._mba.rollout_step(
        time_step=time_step._replace(observation=mbp_step.output),
        state=state.mba_state)
    return AlgStep(
        output=mba_step.output,
        state=MerlinState(
            mbp_state=mbp_step.state, mba_state=mba_step.state),
        info=MerlinInfo(mbp_info=mbp_step.info, mba_info=mba_step.info))
def _reset(self):
    self._boards = self._observation_spec.zeros((self._batch_size, ))
    self._game_over = torch.zeros((self._batch_size, ), dtype=torch.bool)
    self._prev_action = self._action_spec.zeros((self._batch_size, ))
    return TimeStep(
        observation=self._boards.clone().detach(),
        step_type=torch.full((self._batch_size, ), StepType.FIRST),
        reward=torch.zeros((self._batch_size, )),
        discount=torch.ones((self._batch_size, )),
        prev_action=self._action_spec.zeros((self._batch_size, )),
        env_id=self._env_ids,
        env_info={
            "play0_win": torch.zeros(self._batch_size),
            "play1_win": torch.zeros(self._batch_size),
            "draw": torch.zeros(self._batch_size),
            "invalid_move": torch.zeros(self._batch_size),
        })
def _calc_cost_for_action_sequence(self, time_step: TimeStep, state,
                                   ac_seqs):
    """
    Args:
        time_step (TimeStep): input data for next step prediction
        state (MbrlState): input state for next step prediction
        ac_seqs (Tensor): action sequences of shape
            [batch_size, population_size, solution_dim], where
            solution_dim = planning_horizon * num_actions
    Returns:
        cost (Tensor) with shape [batch_size, population_size]
    """
    obs = time_step.observation
    batch_size = obs.shape[0]
    ac_seqs = torch.reshape(
        ac_seqs,
        [batch_size, self._population_size, self._planning_horizon, -1])
    ac_seqs = ac_seqs.permute(2, 0, 1, 3)
    ac_seqs = torch.reshape(
        ac_seqs, (self._planning_horizon, -1, self._num_actions))

    state = state._replace(dynamics=state.dynamics._replace(feature=obs))
    init_obs = self._expand_to_population(obs)
    state = nest.map_structure(self._expand_to_population, state)

    obs = init_obs
    cost = 0
    for i in range(ac_seqs.shape[0]):
        action = ac_seqs[i]
        time_step = time_step._replace(prev_action=action)
        time_step, state = self._dynamics_func(time_step, state)
        next_obs = time_step.observation
        # Note: currently using (next_obs, action), might need to
        # consider (obs, action) in order to be more compatible
        # with the conventional definition of the reward function
        reward_step, state = self._reward_func(next_obs, action, state)
        cost = cost - reward_step
        obs = next_obs

    # reshape cost back to [batch_size, population_size]
    cost = torch.reshape(cost, [batch_size, -1])
    return cost
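# A minimal, shape-only sketch (hypothetical sizes, not from the source) of
# the reshape/permute done in `_calc_cost_for_action_sequence` above: a flat
# planner solution of length planning_horizon * num_actions is split into one
# action tensor per planning step, batched over batch_size * population_size.
import torch

batch_size, population_size, planning_horizon, num_actions = 2, 5, 3, 4
ac_seqs = torch.randn(batch_size, population_size,
                      planning_horizon * num_actions)
ac_seqs = torch.reshape(
    ac_seqs, [batch_size, population_size, planning_horizon, -1])
ac_seqs = ac_seqs.permute(2, 0, 1, 3)
ac_seqs = torch.reshape(ac_seqs, (planning_horizon, -1, num_actions))
assert ac_seqs.shape == (planning_horizon,
                         batch_size * population_size, num_actions)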
def _gen_time_step(self, s, action):
    """Return the current `TimeStep`."""
    step_type = StepType.MID
    discount = 1.0

    if s == 0:
        step_type = StepType.FIRST
    elif s == self._episode_length - 1:
        step_type = StepType.LAST
        discount = 0.0

    return TimeStep(step_type=torch.full([self.batch_size],
                                         step_type,
                                         dtype=torch.int32),
                    reward=torch.ones(self.batch_size),
                    discount=torch.full([self.batch_size], discount),
                    observation=torch.ones(self.batch_size))
def test_agent_steps(self):
    batch_size = 1
    observation_spec = TensorSpec((10, ))
    action_spec = BoundedTensorSpec((), dtype='int64')
    time_step = TimeStep(
        observation=observation_spec.zeros(outer_dims=(batch_size, )),
        prev_action=action_spec.zeros(outer_dims=(batch_size, )))

    actor_net = functools.partial(ActorDistributionNetwork,
                                  fc_layer_params=(100, ))
    value_net = functools.partial(ValueNetwork, fc_layer_params=(100, ))

    # TODO: add a goal generator and an entropy target algorithm once they
    # are implemented.
    agent = Agent(observation_spec=observation_spec,
                  action_spec=action_spec,
                  rl_algorithm_cls=functools.partial(
                      ActorCriticAlgorithm,
                      actor_network_ctor=actor_net,
                      value_network_ctor=value_net),
                  intrinsic_reward_module=ICMAlgorithm(
                      action_spec=action_spec,
                      observation_spec=observation_spec))

    predict_state = agent.get_initial_predict_state(batch_size)
    rollout_state = agent.get_initial_rollout_state(batch_size)
    train_state = agent.get_initial_train_state(batch_size)

    pred_step = agent.predict_step(time_step,
                                   predict_state,
                                   epsilon_greedy=0.1)
    self.assertEqual(pred_step.state.irm, ())

    rollout_step = agent.rollout_step(time_step, rollout_state)
    self.assertNotEqual(rollout_step.state.irm, ())

    exp = make_experience(time_step, rollout_step, rollout_state)
    train_step = agent.train_step(exp, train_state)
    self.assertNotEqual(train_step.state.irm, ())

    self.assertTensorEqual(rollout_step.state.irm, train_step.state.irm)
def testCreate(self):
    step_type = torch.tensor(0, dtype=torch.int32)
    reward = torch.tensor(1, dtype=torch.int32)
    discount = 0.99
    observation = torch.tensor(-1)
    prev_action = torch.tensor(-1)
    env_id = torch.tensor(0, dtype=torch.int32)
    time_step = TimeStep(step_type=step_type,
                         reward=reward,
                         discount=discount,
                         observation=observation,
                         prev_action=prev_action,
                         env_id=env_id)
    self.assertEqual(StepType.FIRST, time_step.step_type)
    self.assertEqual(reward, time_step.reward)
    self.assertEqual(discount, time_step.discount)
    self.assertEqual(observation, time_step.observation)
    self.assertEqual(prev_action, time_step.prev_action)
    self.assertEqual(env_id, time_step.env_id)
def _step(self, action):
    prev_game_over = self._game_over
    prev_action = action.clone()
    prev_action[prev_game_over] = 0
    self._boards[prev_game_over] = self._empty_board
    step_type = torch.full((self._batch_size, ), int(StepType.MID))
    player = self._get_current_player().to(torch.float32)
    x = action % 3
    y = action // 3
    valid = self._boards[self._B, y, x] == 0
    self._boards[self._B[valid], y[valid], x[valid]] = player[valid]
    won = self._check_player_win(player)
    reward = torch.where(won, -player, torch.tensor(0.))
    reward = torch.where(valid, reward, player)
    game_over = self._check_game_over()
    game_over = torch.max(game_over, ~valid)
    step_type[game_over] = int(StepType.LAST)
    step_type[prev_game_over] = int(StepType.FIRST)
    discount = torch.ones(self._batch_size)
    discount[game_over] = 0.
    self._boards[prev_game_over] = self._empty_board
    self._game_over = game_over
    self._prev_action = action
    player0_win = self._check_player_win(self._player_0)
    player1_win = self._check_player_win(self._player_1)
    draw = torch.min(game_over, reward == 0)
    return TimeStep(
        observation=self._boards.clone().detach(),
        reward=reward.detach(),
        step_type=step_type.detach(),
        discount=discount.detach(),
        prev_action=prev_action.detach(),
        env_id=self._env_ids,
        env_info={
            "play0_win": player0_win.to(torch.float32),
            "play1_win": player1_win.to(torch.float32),
            "draw": draw.to(torch.float32),
            "invalid_move": (~valid).to(torch.float32),
        })
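# A tiny sketch (not from the source) of the flat-action decoding used in
# `_step` above: a board index in [0, 8] maps to column x = action % 3 and
# row y = action // 3 on the 3x3 board.
import torch

action = torch.tensor([0, 4, 8])
x = action % 3   # columns -> tensor([0, 1, 2])
y = action // 3  # rows    -> tensor([0, 1, 2])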
def _gen_time_step(self, s, action):
    step_type = StepType.MID
    discount = 1.0
    obs_dim = self._obs_dim

    if s == 0:
        self._observation0 = 2. * torch.randint(
            0, 2, size=(self.batch_size, 1)) - 1.
        if obs_dim > 1:
            self._observation0 = torch.cat([
                self._observation0,
                torch.ones(self.batch_size, obs_dim - 1)
            ], dim=-1)
        step_type = StepType.FIRST
    elif s == self._episode_length - 1:
        step_type = StepType.LAST
        discount = 0.0

    if s <= self._gap:
        reward = torch.zeros(self.batch_size)
    else:
        obs0 = self._observation0[:, 0].reshape(self.batch_size, 1)
        reward = 1.0 - 0.5 * torch.abs(
            2 * action.reshape(obs0.shape) - 1 - obs0)
        reward = reward.reshape(self.batch_size)

    if s == 0:
        observation = self._observation0
    else:
        observation = torch.zeros(self.batch_size, obs_dim)

    return TimeStep(step_type=torch.full([self.batch_size],
                                         step_type,
                                         dtype=torch.int32),
                    reward=reward,
                    discount=torch.full([self.batch_size], discount),
                    observation=observation)
def test_same_action_prior_actor(self):
    action_spec = dict(a=BoundedTensorSpec(shape=()),
                       b=BoundedTensorSpec((3, ),
                                           minimum=(-1, 0, -2),
                                           maximum=(2, 2, 3)),
                       c=BoundedTensorSpec((2, 3), minimum=-1, maximum=1))
    actor = SameActionPriorActor(observation_spec=(),
                                 action_spec=action_spec)
    batch = TimeStep(step_type=torch.tensor([StepType.FIRST,
                                             StepType.MID]),
                     prev_action=dict(a=torch.tensor([0., 1.]),
                                      b=torch.tensor([[-1., 0., -2.],
                                                      [2., 2., 3.]]),
                                      c=action_spec['c'].sample((2, ))))
    alg_step = actor.predict_step(batch, ())
    self.assertAlmostEqual(
        alg_step.output['a'].log_prob(torch.tensor([0., 0.]))[0],
        alg_step.output['a'].log_prob(torch.tensor([1., 1.]))[0],
        delta=1e-6)
    self.assertAlmostEqual(
        alg_step.output['a'].log_prob(torch.tensor([0., 0.]))[1],
        alg_step.output['a'].log_prob(torch.tensor([0., 0.]))[0] +
        math.log(0.1),
        delta=1e-6)
    self.assertAlmostEqual(
        alg_step.output['b'].log_prob(torch.tensor([[-1., 0., -2.]] * 2))[0],
        alg_step.output['b'].log_prob(torch.tensor([[2., 2., 3.]] * 2))[0],
        delta=1e-6)
    self.assertAlmostEqual(
        alg_step.output['b'].log_prob(torch.tensor([[-1., 0., -2.]] * 2))[1],
        alg_step.output['b'].log_prob(torch.tensor([[-1., 0., -2.]] * 2))[0]
        + 3 * math.log(0.1),
        delta=1e-6)
def _gen_time_step(self, s, action):
    step_type = StepType.MID
    discount = 1.0
    reward = torch.zeros(self.batch_size)

    if s == 0:
        step_type = StepType.FIRST
    elif s == self._episode_length - 1:
        step_type = StepType.LAST
        discount = 0.0

    if s > 0:
        reward = (action[0] == (action[1].squeeze(-1) > 0.5).to(
            torch.int64)).to(torch.float32)

    observation = self._observation_spec.randn(
        outer_dims=(self.batch_size, ))

    return TimeStep(step_type=torch.full([self.batch_size],
                                         step_type,
                                         dtype=torch.int32),
                    reward=reward,
                    discount=torch.full([self.batch_size], discount),
                    observation=observation)
def test_uniform_prior_actor(self):
    action_spec = dict(a=BoundedTensorSpec(shape=()),
                       b=BoundedTensorSpec((3, ),
                                           minimum=(-1, 0, -2),
                                           maximum=(2, 2, 3)),
                       c=BoundedTensorSpec((2, 3), minimum=-1, maximum=1))
    actor = UniformPriorActor(observation_spec=(),
                              action_spec=action_spec)
    batch = TimeStep(step_type=torch.tensor([StepType.FIRST,
                                             StepType.MID]),
                     prev_action=dict(a=torch.tensor([0., 1.]),
                                      b=torch.tensor([[-1., 0., -2.],
                                                      [2., 2., 3.]]),
                                      c=action_spec['c'].sample((2, ))))
    alg_step = actor.predict_step(batch, ())
    self.assertEqual(
        alg_step.output['a'].log_prob(action_spec['a'].sample()),
        torch.tensor(0.))
    self.assertEqual(
        alg_step.output['b'].log_prob(action_spec['b'].sample()),
        -torch.tensor(30.).log())
    self.assertEqual(
        alg_step.output['c'].log_prob(action_spec['c'].sample()),
        -torch.tensor(64.).log())
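# A worked check (an assumption about UniformPriorActor, not from the source)
# of the constants in the test above: if the prior is uniform over each
# bounded action box, the log-prob of any sample is -log(box volume).
import math

vol_a = 1.0                                # 'a': scalar, assumed bounds [0, 1]
vol_b = (2 - (-1)) * (2 - 0) * (3 - (-2))  # 'b': 3 * 2 * 5 = 30
vol_c = (1 - (-1)) ** (2 * 3)              # 'c': 2**6 = 64
assert (-math.log(vol_a), vol_b, vol_c) == (0.0, 30, 64)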
def _create_timestep(reward, env_id, step_type, env_info):
    return TimeStep(step_type=to_tensor(step_type),
                    reward=to_tensor(reward),
                    env_info=env_info,
                    env_id=to_tensor(env_id))
def _should_switch_skills(self, time_step: TimeStep, state):
    should_switch_skills = ((state.steps % self._num_steps_per_skill) == 0)
    # is_last is only necessary for `rollout_step` because it marks an
    # episode end in the replay buffer for training the policy `self._rl`.
    return (should_switch_skills | time_step.is_first()
            | time_step.is_last())
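# A minimal, self-contained sketch (not from the source) of the switching
# schedule computed by `_should_switch_skills` above, assuming
# num_steps_per_skill == 3 and hand-written is_first/is_last masks: skills
# are re-sampled every 3 steps and whenever an episode starts or ends.
import torch

num_steps_per_skill = 3
steps = torch.tensor([0, 1, 2, 3, 4, 5])
is_first = torch.tensor([True, False, False, False, False, False])
is_last = torch.tensor([False, False, False, False, False, True])
switch = ((steps % num_steps_per_skill) == 0) | is_first | is_last
# switch -> tensor([ True, False, False,  True, False,  True])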
def test_mcts_algorithm(self):
    observation_spec = alf.TensorSpec((3, 3))
    action_spec = alf.BoundedTensorSpec((),
                                        dtype=torch.int64,
                                        minimum=0,
                                        maximum=8)
    model = TicTacToeModel()
    time_step = TimeStep(step_type=torch.tensor([StepType.MID]))

    # board situations and expected actions
    # yapf: disable
    cases = [
        ([[1, -1,  1],
          [1, -1, -1],
          [0,  0,  1]], 6),
        ([[0,  0,  0],
          [0, -1, -1],
          [0,  1,  0]], 3),
        ([[ 1, -1, -1],
          [-1, -1,  0],
          [ 0,  1,  1]], 6),
        ([[-1,  0,  1],
          [ 0, -1, -1],
          [ 0,  0,  1]], 3),
        ([[0, 0,  0],
          [0, 0,  0],
          [0, 0, -1]], 4),
        ([[0,  0, 0],
          [0, -1, 0],
          [0,  0, 0]], (0, 2, 6, 8)),
        ([[0,  0,  0],
          [0,  1, -1],
          [1, -1, -1]], 2),
    ]
    # yapf: enable

    def _create_mcts(observation_spec, action_spec, num_simulations):
        return MCTSAlgorithm(
            observation_spec,
            action_spec,
            discount=1.0,
            root_dirichlet_alpha=100.,
            root_exploration_fraction=0.25,
            num_simulations=num_simulations,
            pb_c_init=1.25,
            pb_c_base=19652,
            visit_softmax_temperature_fn=VisitSoftmaxTemperatureByMoves(
                [(0, 1.0), (10, 0.0001)]),
            known_value_bounds=(-1, 1),
            is_two_player_game=True)

    # test each case serially
    for observation, action in cases:
        observation = torch.tensor([observation], dtype=torch.float32)
        state = MCTSState(steps=(observation != 0).sum(dim=(1, 2)))
        # We use a varying num_simulations instead of a fixed large number
        # such as 2000 to make the test faster.
        num_simulations = int((observation == 0).sum().cpu()) * 200
        mcts = _create_mcts(
            observation_spec, action_spec, num_simulations=num_simulations)
        mcts.set_model(model)
        alg_step = mcts.predict_step(
            time_step._replace(observation=observation), state)
        print(observation, alg_step.output, alg_step.info)
        if type(action) == tuple:
            self.assertTrue(alg_step.output[0] in action)
        else:
            self.assertEqual(alg_step.output[0], action)

    # test batch predict
    observation = torch.tensor([case[0] for case in cases],
                               dtype=torch.float32)
    state = MCTSState(steps=(observation != 0).sum(dim=(1, 2)))
    mcts = _create_mcts(observation_spec, action_spec, num_simulations=2000)
    mcts.set_model(model)
    alg_step = mcts.predict_step(
        time_step._replace(
            step_type=torch.tensor([StepType.MID] * len(cases)),
            observation=observation), state)
    for i, (observation, action) in enumerate(cases):
        if type(action) == tuple:
            self.assertTrue(alg_step.output[i] in action)
        else:
            self.assertEqual(alg_step.output[i], action)