def train_step(self, exp: Experience, state, trainable=True):
    """This function trains the discriminator or generates intrinsic rewards.

    If ``trainable=True``, it only generates and returns the prediction loss;
    otherwise it only generates rewards with no grad.
    """
    # Either train the discriminator from its own replay buffer, or compute
    # intrinsic rewards for training the lower-level RL algorithm.
    untrans_observation, prev_skill, switch_skill, steps = exp.observation
    observation = self._observation_transformer(untrans_observation)
    loss = self._predict_skill_loss(observation, exp.prev_action, prev_skill,
                                    steps, state)

    first_observation = self._update_state_if_necessary(
        switch_skill, observation, state.first_observation)
    subtrajectory = self._clear_subtrajectory_if_necessary(
        state.subtrajectory, switch_skill)
    new_state = DiscriminatorState(
        first_observation=first_observation,
        untrans_observation=untrans_observation,
        subtrajectory=subtrajectory)

    valid_masks = (exp.step_type != StepType.FIRST)
    if self._sparse_reward:
        # Only give intrinsic rewards at the last step of the skill
        valid_masks &= switch_skill
    loss *= valid_masks.to(torch.float32)

    if trainable:
        info = LossInfo(loss=loss, extra=dict(discriminator_loss=loss))
        return AlgStep(state=new_state, info=info)
    else:
        intrinsic_reward = -loss.detach() / self._skill_dim
        return AlgStep(state=common.detach(new_state), info=intrinsic_reward)
def predict_step(self, inputs=None, noise=None, batch_size=None,
                 training=False, state=None):
    """Generate outputs given inputs.

    Args:
        inputs (nested Tensor): if None, the outputs are generated only
            from noise.
        noise (Tensor): input to the generator.
        batch_size (int): batch size. Must be provided if inputs is None;
            it is ignored if inputs is not None.
        training (bool): whether to train the generator.
        state: not used
    Returns:
        AlgorithmStep: outputs with shape (batch_size, output_dim)
    """
    outputs, _ = self._predict(
        inputs=inputs, noise=noise, batch_size=batch_size, training=training)
    return AlgStep(output=outputs, state=(), info=())
def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
    switch_action = self._should_switch_action(time_step, state)

    @torch.no_grad()
    def _generate_new_action(time_step, state):
        repr_state = ()
        if self._repr_learner is not None:
            repr_step = self._repr_learner.predict_step(
                time_step, state.repr)
            time_step = time_step._replace(observation=repr_step.output)
            repr_state = repr_step.state

        rl_step = self._rl.predict_step(time_step, state.rl, epsilon_greedy)
        steps, action = rl_step.output
        return ActionRepeatState(
            action=action,
            steps=steps + 1,  # [0, K-1] -> [1, K]
            rl=rl_step.state,
            repr=repr_state)

    new_state = conditional_update(
        target=state,
        cond=switch_action,
        func=_generate_new_action,
        time_step=time_step,
        state=state)
    new_state = new_state._replace(steps=new_state.steps - 1)

    return AlgStep(
        output=new_state.action,
        state=new_state,
        # plot steps and action when rendering video
        info=dict(action=(new_state.action, new_state.steps)))
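# A minimal sketch (not ALF's actual implementation) of the
# ``conditional_update`` pattern used in ``predict_step`` above: run ``func``
# on the whole batch, then keep its result only for batch entries where
# ``cond`` is True. It assumes every leaf of ``target`` is a tensor whose
# first dim is the batch dim and that ``cond`` is a bool tensor of shape [B].
import torch
import alf


def conditional_update_sketch(target, cond, func, **kwargs):
    new = func(**kwargs)

    def _select(n, o):
        # Reshape cond to [B, 1, ..., 1] so it broadcasts against each leaf.
        c = cond.reshape(cond.shape + (1, ) * (n.ndim - 1))
        return torch.where(c, n, o)

    return alf.nest.map_structure(_select, new, target)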
def rollout_step(self, time_step: TimeStep, state: AgentState):
    """Rollout for one step."""
    new_state = AgentState()
    info = AgentInfo()

    time_step = transform_nest(time_step, "observation",
                               self._observation_transformer)

    subtrajectory = self._skill_generator.update_disc_subtrajectory(
        time_step, state.skill_generator)

    skill_step = self._skill_generator.rollout_step(time_step,
                                                    state.skill_generator)
    new_state = new_state._replace(skill_generator=skill_step.state)
    info = info._replace(skill_generator=skill_step.info)

    observation = self._make_low_level_observation(
        subtrajectory, skill_step.output, skill_step.info.switch_skill,
        skill_step.state.steps,
        skill_step.state.discriminator.first_observation)

    rl_step = self._rl_algorithm.rollout_step(
        time_step._replace(observation=observation), state.rl)
    new_state = new_state._replace(rl=rl_step.state)
    info = info._replace(rl=rl_step.info)

    skill_discount = ((
        (skill_step.state.steps == 1)
        & (time_step.step_type != StepType.LAST)).to(torch.float32) *
                      (1 - self._skill_boundary_discount))
    info = info._replace(skill_discount=1 - skill_discount)

    return AlgStep(output=rl_step.output, state=new_state, info=info)
def train_step(self, exp: Experience, state):
    def _hook(grad, name):
        alf.summary.scalar("MCTS_state_grad_norm/" + name, grad.norm())

    model_output = self._model.initial_inference(exp.observation)
    if alf.summary.should_record_summaries():
        model_output.state.register_hook(partial(_hook, name="s0"))
    model_output_spec = dist_utils.extract_spec(model_output)
    model_outputs = [dist_utils.distributions_to_params(model_output)]
    info = exp.rollout_info

    for i in range(self._num_unroll_steps):
        model_output = self._model.recurrent_inference(
            model_output.state, info.action[:, i, ...])
        if alf.summary.should_record_summaries():
            model_output.state.register_hook(
                partial(_hook, name="s" + str(i + 1)))
        model_output = model_output._replace(
            state=scale_gradient(model_output.state,
                                 self._recurrent_gradient_scaling_factor))
        model_outputs.append(dist_utils.distributions_to_params(model_output))

    model_outputs = alf.nest.utils.stack_nests(model_outputs, dim=1)
    model_outputs = dist_utils.params_to_distributions(
        model_outputs, model_output_spec)
    return AlgStep(info=self._model.calc_loss(model_outputs, info.target))
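# ``scale_gradient`` above applies the MuZero-style gradient scaling trick:
# the forward value is unchanged, but the gradient flowing back through the
# recurrent state is multiplied by ``scale``. A one-line sketch of the trick:
def scale_gradient_sketch(x, scale):
    # x * scale carries the scaled gradient; x.detach() * (1 - scale)
    # restores the full forward value without contributing any gradient.
    return x * scale + x.detach() * (1 - scale)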
def train_step(self, exp: Experience, state):
    new_state = AgentState()
    info = AgentInfo()

    skill_generator_info = exp.rollout_info.skill_generator

    subtrajectory = self._skill_generator.update_disc_subtrajectory(
        exp, state.skill_generator)

    skill_step = self._skill_generator.train_step(
        exp._replace(rollout_info=skill_generator_info),
        state.skill_generator)
    info = info._replace(skill_generator=skill_step.info)
    new_state = new_state._replace(skill_generator=skill_step.state)

    exp = transform_nest(exp, "observation", self._observation_transformer)

    observation = self._make_low_level_observation(
        subtrajectory, skill_step.output, skill_generator_info.switch_skill,
        skill_generator_info.steps,
        skill_step.state.discriminator.first_observation)

    rl_step = self._rl_algorithm.train_step(
        exp._replace(observation=observation,
                     rollout_info=exp.rollout_info.rl), state.rl)
    new_state = new_state._replace(rl=rl_step.state)
    info = info._replace(rl=rl_step.info)

    return AlgStep(output=rl_step.output, state=new_state, info=info)
def predict_step(self, time_step: TimeStep, state: AgentState,
                 epsilon_greedy):
    """Predict for one step."""
    new_state = AgentState()

    time_step = transform_nest(time_step, "observation",
                               self._observation_transformer)

    subtrajectory = self._skill_generator.update_disc_subtrajectory(
        time_step, state.skill_generator)

    skill_step = self._skill_generator.predict_step(
        time_step, state.skill_generator, epsilon_greedy)
    new_state = new_state._replace(skill_generator=skill_step.state)

    observation = self._make_low_level_observation(
        subtrajectory, skill_step.output, skill_step.info.switch_skill,
        skill_step.state.steps,
        skill_step.state.discriminator.first_observation)

    rl_step = self._rl_algorithm.predict_step(
        time_step._replace(observation=observation), state.rl,
        epsilon_greedy)
    new_state = new_state._replace(rl=rl_step.state)

    return AlgStep(output=rl_step.output, state=new_state)
def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
    """Every ``self._num_steps_per_skill`` steps, this function calls
    ``self._rl`` to generate a new skill.
    """
    switch_skill = self._should_switch_skills(time_step, state)

    discriminator_step = self._discriminator_predict_step(
        time_step, state, switch_skill)

    def _generate_new_skills(time_step, state):
        rl_step = self._rl.predict_step(time_step, state.rl, epsilon_greedy)
        return SkillGeneratorState(
            skill=rl_step.output,
            steps=torch.zeros_like(state.steps),
            rl=rl_step.state)

    new_state = conditional_update(
        target=state,
        cond=switch_skill,
        func=_generate_new_skills,
        time_step=time_step,
        state=state)
    new_state = new_state._replace(
        steps=new_state.steps + 1, discriminator=discriminator_step.state)

    return AlgStep(
        output=new_state.skill,
        state=new_state,
        info=SkillGeneratorInfo(switch_skill=switch_skill))
def train_step(self, time_step: TimeStep, state: DynamicsState):
    """
    Args:
        time_step (TimeStep): input data for dynamics learning
        state (DynamicsState): state for dynamics learning (previous
            observation)
    Returns:
        AlgStep:
            output: empty tuple ()
            state (DynamicsState): state for training
            info (DynamicsInfo):
    """
    feature = time_step.observation
    dynamics_step = self.predict_step(time_step, state)
    forward_pred = dynamics_step.output
    forward_loss = (feature - forward_pred)**2
    forward_loss = 0.5 * forward_loss.mean(list(range(1, forward_loss.ndim)))

    # we mask out FIRST as its state is invalid
    valid_masks = (time_step.step_type != StepType.FIRST).to(torch.float32)
    forward_loss = forward_loss * valid_masks

    info = DynamicsInfo(
        loss=LossInfo(
            loss=forward_loss, extra=dict(forward_loss=forward_loss)))
    state = state._replace(feature=feature)

    return AlgStep(output=(), state=state, info=info)
def predict_step(self, time_step: TimeStep, state, epsilon_greedy=1.):
    action, state = self._actor_network(
        time_step.observation, state=state.actor.actor)
    empty_state = nest.map_structure(lambda x: (), self.train_state_spec)

    def _sample(a, ou):
        if epsilon_greedy == 0:
            return a
        elif epsilon_greedy >= 1.0:
            return a + ou()
        else:
            ind_explore = torch.where(
                torch.rand(a.shape[:1]) < epsilon_greedy)
            noisy_a = a + ou()
            a[ind_explore[0], :] = noisy_a[ind_explore[0], :]
            return a

    noisy_action = nest.map_structure(_sample, action, self._ou_process)
    noisy_action = nest.map_structure(spec_utils.clip_to_spec, noisy_action,
                                      self._action_spec)
    state = empty_state._replace(
        actor=DdpgActorState(actor=state, critics=()))
    return AlgStep(
        output=noisy_action,
        state=state,
        info=DdpgInfo(action_distribution=action))
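# The ``self._ou_process`` callables above supply temporally correlated
# exploration noise. A minimal Ornstein-Uhlenbeck sketch (the class name and
# the theta/sigma values are illustrative, not ALF's defaults):
import torch


class OUProcessSketch(object):
    def __init__(self, shape, theta=0.15, sigma=0.3):
        self._x = torch.zeros(shape)
        self._theta = theta
        self._sigma = sigma

    def __call__(self):
        # x <- x - theta * x + sigma * N(0, I): the state decays toward zero
        # while accumulating Gaussian noise, so successive samples correlate.
        self._x = (self._x - self._theta * self._x +
                   self._sigma * torch.randn_like(self._x))
        return self._x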
def train_step(self, exp: TimeStep, state):
    # [B, num_unroll_steps + 1]
    info = exp.rollout_info
    targets = common.as_list(info.target)
    batch_size = exp.step_type.shape[0]
    latent, state = self._encoding_net(exp.observation, state)
    sim_latent = self._multi_step_latent_rollout(
        latent, self._num_unroll_steps, info.action, state)

    loss = 0
    for i, decoder in enumerate(self._decoders):
        # [(num_unroll_steps + 1) * B, ...]
        train_info = decoder.train_step(sim_latent).info
        train_info_spec = dist_utils.extract_spec(train_info)
        train_info = dist_utils.distributions_to_params(train_info)
        train_info = alf.nest.map_structure(
            lambda x: x.reshape(self._num_unroll_steps + 1, batch_size,
                                *x.shape[1:]), train_info)
        # [num_unroll_steps + 1, B, ...]
        train_info = dist_utils.params_to_distributions(
            train_info, train_info_spec)
        target = alf.nest.map_structure(lambda x: x.transpose(0, 1),
                                        targets[i])
        loss_info = decoder.calc_loss(target, train_info, info.mask.t())
        loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0),
                                           loss_info)
        loss += loss_info.loss

    loss_info = LossInfo(loss=loss, extra=loss)

    return AlgStep(output=latent, state=state, info=loss_info)
def rollout_step(self, time_step: TimeStep, state: SarsaState):
    if self._on_policy:
        return self._train_step(time_step, state)

    if not self._is_rnn:
        critic_states = state.critics
    else:
        _, critic_states = self._critic_networks(
            (state.prev_observation, time_step.prev_action), state.critics)

        not_first_step = time_step.step_type != StepType.FIRST

        critic_states = common.reset_state_if_necessary(
            state.critics, critic_states, not_first_step)

    action_distribution, action, actor_state, noise_state = self._get_action(
        self._rollout_actor_network, time_step, state)

    if not self._is_rnn:
        target_critic_states = state.target_critics
    else:
        _, target_critic_states = self._target_critic_networks(
            (time_step.observation, action), state.target_critics)

    info = SarsaInfo(action_distribution=action_distribution)

    rl_state = SarsaState(
        noise=noise_state,
        prev_observation=time_step.observation,
        prev_step_type=time_step.step_type,
        actor=actor_state,
        critics=critic_states,
        target_critics=target_critic_states)

    return AlgStep(action, rl_state, info)
def predict_step(self, time_step: TimeStep, state: AgentState,
                 epsilon_greedy):
    """Predict for one step."""
    new_state = AgentState()
    observation = time_step.observation
    info = AgentInfo()

    if self._representation_learner is not None:
        repr_step = self._representation_learner.predict_step(
            time_step, state.repr)
        new_state = new_state._replace(repr=repr_step.state)
        info = info._replace(repr=repr_step.info)
        observation = repr_step.output

    if self._goal_generator is not None:
        goal_step = self._goal_generator.predict_step(
            time_step._replace(observation=observation),
            state.goal_generator, epsilon_greedy)
        new_state = new_state._replace(goal_generator=goal_step.state)
        info = info._replace(goal_generator=goal_step.info)
        observation = [observation, goal_step.output]

    rl_step = self._rl_algorithm.predict_step(
        time_step._replace(observation=observation), state.rl,
        epsilon_greedy)
    new_state = new_state._replace(rl=rl_step.state)
    info = info._replace(rl=rl_step.info)

    return AlgStep(output=rl_step.output, state=new_state, info=info)
def testCreateWithDefaultInfo(self):
    action = torch.tensor(1)
    state = torch.tensor(2)
    info = ()
    step = AlgStep(output=action, state=state)
    self.assertEqual(step.output, action)
    self.assertEqual(step.state, state)
    self.assertEqual(step.info, info)
def testCreate(self):
    action = torch.tensor(1)
    state = torch.tensor(2)
    info = torch.tensor(3)
    step = AlgStep(output=action, state=state, info=info)
    self.assertEqual(step.output, action)
    self.assertEqual(step.state, state)
    self.assertEqual(step.info, info)
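# The two tests above exercise the ``AlgStep`` structure returned by every
# ``predict_step``/``rollout_step``/``train_step`` in this file. A minimal
# stand-in with the same semantics, assuming plain namedtuple fields that
# all default to ``()`` (ALF's real AlgStep lives in alf.data_structures):
from collections import namedtuple

AlgStepSketch = namedtuple('AlgStepSketch', ['output', 'state', 'info'])
AlgStepSketch.__new__.__defaults__ = ((), (), ())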
def predict_step(self, time_step, state):
    observation, switch_skill = time_step.observation
    first_observation = self._update_state_if_necessary(
        switch_skill, observation, state.first_observation)
    subtrajectory = self._clear_subtrajectory_if_necessary(
        state.subtrajectory, switch_skill)
    return AlgStep(
        state=DiscriminatorState(
            first_observation=first_observation,
            subtrajectory=subtrajectory))
def _step(self, time_step: TimeStep, state, calc_rewards=True):
    """This step is for both `rollout_step` and `train_step`.

    Args:
        time_step (TimeStep): input time_step data for ICM
        state (Tensor): state for ICM (previous observation)
        calc_rewards (bool): whether to calculate rewards

    Returns:
        AlgStep:
            output: empty tuple ()
            state: observation
            info (ICMInfo):
    """
    feature = time_step.observation
    prev_action = time_step.prev_action.detach()

    # normalize observation for easier prediction
    if self._observation_normalizer is not None:
        feature = self._observation_normalizer.normalize(feature)

    if self._encoding_net is not None:
        feature, _ = self._encoding_net(feature)
    prev_feature = state

    forward_pred, _ = self._forward_net(
        inputs=[prev_feature.detach(),
                self._encode_action(prev_action)])
    # nn.MSELoss doesn't support reducing along a dim
    forward_loss = 0.5 * torch.mean(
        math_ops.square(forward_pred - feature.detach()), dim=-1)

    action_pred, _ = self._inverse_net([prev_feature, feature])

    if self._action_spec.is_discrete:
        inverse_loss = torch.nn.CrossEntropyLoss(reduction='none')(
            input=action_pred, target=prev_action.to(torch.int64))
    else:
        # nn.MSELoss doesn't support reducing along a dim
        inverse_loss = 0.5 * torch.mean(
            math_ops.square(action_pred - prev_action), dim=-1)

    intrinsic_reward = ()
    if calc_rewards:
        intrinsic_reward = forward_loss.detach()
        intrinsic_reward = self._reward_normalizer.normalize(
            intrinsic_reward)

    return AlgStep(
        output=(),
        state=feature,
        info=ICMInfo(
            reward=intrinsic_reward,
            loss=LossInfo(
                loss=forward_loss + inverse_loss,
                extra=dict(
                    forward_loss=forward_loss,
                    inverse_loss=inverse_loss))))
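# The core of ICM's ``_step`` above, in isolation: the forward model's
# prediction error is the intrinsic reward, while the inverse model shapes
# the feature space. A self-contained sketch with toy linear networks and
# illustrative sizes (continuous actions assumed):
import torch
import torch.nn as nn

feat_dim, act_dim = 16, 4
forward_net = nn.Linear(feat_dim + act_dim, feat_dim)  # predicts next feature
inverse_net = nn.Linear(2 * feat_dim, act_dim)  # predicts the taken action

prev_feat = torch.randn(8, feat_dim)
feat = torch.randn(8, feat_dim)
action = torch.randn(8, act_dim)

pred_feat = forward_net(torch.cat([prev_feat.detach(), action], dim=-1))
forward_loss = 0.5 * (pred_feat - feat.detach()).square().mean(dim=-1)

pred_action = inverse_net(torch.cat([prev_feat, feat], dim=-1))
inverse_loss = 0.5 * (pred_action - action).square().mean(dim=-1)

# Curiosity: reward transitions the forward model cannot yet predict.
intrinsic_reward = forward_loss.detach()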
def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
    action_distribution, action, actor_state, noise_state = self._get_action(
        self._rollout_actor_network, time_step, state, epsilon_greedy)
    return AlgStep(
        output=action,
        state=SarsaState(
            noise=noise_state,
            actor=actor_state,
            prev_observation=time_step.observation,
            prev_step_type=time_step.step_type),
        info=SarsaInfo(action_distribution=action_distribution))
def train_step(self, time_step: TimeStep, state=()):
    """
    Args:
        time_step (TimeStep): input data for reward learning
        state: state for reward learning
    Returns:
        AlgStep
    """
    return AlgStep(output=(), state=state, info=())
def train_step(self, time_step: TimeStep, state: DynamicsState):
    """
    Args:
        time_step (TimeStep): time step structure. The ``prev_action``
            from time_step will be used for predicting feature of the
            next step. It should be a Tensor of the shape [B, ...] or
            [B, n, ...] when n > 1, where n denotes the number of
            dynamics network replicas. When the input tensor has the
            shape of [B, ...] and n > 1, it will be first expanded to
            [B, n, ...] to match the number of dynamics network replicas.
        state (DynamicsState): state for dynamics learning with the
            following fields:
            - feature (Tensor): features of the previous observation of
                the shape [B, ...] or [B, n, ...] when n > 1. When
                ``state.feature`` has the shape of [B, ...] and n > 1,
                it will be first expanded to [B, n, ...] to match the
                number of dynamics network replicas. It is used for
                predicting the feature of the next step together with
                ``time_step.prev_action``.
            - network: the input state of the dynamics network
    Returns:
        AlgStep:
            outputs: empty tuple ()
            state (DynamicsState): with the following fields
            - feature (Tensor): [B, ...] (or [B, n, ...] when n > 1)
                shape tensor representing the predicted feature of the
                next step
            - network: the updated state of the dynamics network
            info (DynamicsInfo): with the following fields being updated:
            - loss (LossInfo):
            - dist (td.Distribution): the predictive distribution which
                can be used for further calculation or summarization.
    """
    feature = time_step.observation
    feature = self._expand_to_replica(feature, self._feature_spec)
    dynamics_step = self.predict_step(time_step, state)
    dist = dynamics_step.info.dist
    forward_loss = -dist.log_prob(feature - state.feature)

    if forward_loss.ndim > 2:
        # [B, n, ...] -> [B, ...]
        forward_loss = forward_loss.sum(1)
    if forward_loss.ndim > 1:
        forward_loss = forward_loss.mean(list(range(1, forward_loss.ndim)))

    valid_masks = (time_step.step_type != StepType.FIRST).to(torch.float32)
    forward_loss = forward_loss * valid_masks

    info = DynamicsInfo(
        loss=LossInfo(
            loss=forward_loss, extra=dict(forward_loss=forward_loss)),
        dist=dist)
    state = state._replace(feature=feature)

    return AlgStep(output=(), state=state, info=info)
def predict_step(self, time_step: TimeStep, state: ActorCriticState,
                 epsilon_greedy):
    """Predict for one step."""
    action_dist, actor_state = self._actor_network(
        time_step.observation, state=state.actor)

    action = dist_utils.epsilon_greedy_sample(action_dist, epsilon_greedy)
    return AlgStep(
        output=action,
        state=ActorCriticState(actor=actor_state),
        info=ActorCriticInfo(action_distribution=action_dist))
def predict_step(self, time_step: TimeStep, state):
    flat_prev_action = alf.nest.flatten(time_step.prev_action)
    dists = [
        self._make_dist(time_step.step_type, prev_action, spec)
        for prev_action, spec in zip(flat_prev_action, self._prepared_specs)
    ]
    return AlgStep(
        output=alf.nest.pack_sequence_as(self._action_spec, dists),
        state=(),
        info=())
def _predict_with_planning(self, time_step: TimeStep, state,
                           epsilon_greedy):
    # The planner receives the full MBRL state, not just state.planner.
    action = self._planner_module.generate_plan(time_step, state,
                                                epsilon_greedy)
    dynamics_state = self._dynamics_module.update_state(
        time_step, state.dynamics)

    return AlgStep(
        output=action,
        state=state._replace(dynamics=dynamics_state),
        info=MbrlInfo())
def rollout_step(self, time_step, state):
    """This function updates the discriminator state."""
    (observation, _, switch_skill, _) = time_step.observation
    first_observation = self._update_state_if_necessary(
        switch_skill, observation, state.first_observation)
    subtrajectory = self._clear_subtrajectory_if_necessary(
        state.subtrajectory, switch_skill)
    return AlgStep(
        state=DiscriminatorState(
            first_observation=first_observation,
            untrans_observation=time_step.untransformed.observation,
            subtrajectory=subtrajectory))
def train_step(self, exp: Experience, state: MbrlState):
    action = exp.action
    dynamics_step = self._dynamics_module.train_step(exp, state.dynamics)
    reward_step = self._reward_module.train_step(exp, state.reward)
    plan_step = self._planner_module.train_step(exp, state.planner)
    state = MbrlState(
        dynamics=dynamics_step.state,
        reward=reward_step.state,
        planner=plan_step.state)
    info = MbrlInfo(
        dynamics=dynamics_step.info,
        reward=reward_step.info,
        planner=plan_step.info)
    return AlgStep(action, state, info)
def predict_step(self, time_step: TimeStep, state: SacState,
                 epsilon_greedy=1.0):
    action_dist, action, _, action_state = self._predict_action(
        time_step.observation,
        state=state.action,
        epsilon_greedy=epsilon_greedy,
        eps_greedy_sampling=True)
    return AlgStep(
        output=action,
        state=SacState(action=action_state),
        info=SacInfo(action_distribution=action_dist))
def train_step(self, time_step: TimeStep, state):
    """
    Args:
        time_step (TimeStep): input data for planning
        state: state for planning (previous observation)
    Returns:
        AlgStep:
            output: empty tuple ()
            state: state for training
            info: empty tuple ()
    """
    return AlgStep(output=(), state=state, info=())
def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
    mbp_step = self._mbp.predict_step(
        inputs=(time_step.observation, time_step.prev_action),
        state=state.mbp_state)
    mba_step = self._mba.predict_step(
        time_step=time_step._replace(observation=mbp_step.output),
        state=state.mba_state,
        epsilon_greedy=epsilon_greedy)

    return AlgStep(
        output=mba_step.output,
        state=MerlinState(
            mbp_state=mbp_step.state, mba_state=mba_step.state),
        info=())
def predict_step(self, time_step: TimeStep, state: DynamicsState):
    """Predict the next observation given the current time_step.

    The next step is predicted using the ``prev_action`` from time_step
    and the ``feature`` from state.

    Args:
        time_step (TimeStep): time step structure. The ``prev_action``
            from time_step will be used for predicting feature of the
            next step. It should be a Tensor of the shape [B, ...], or
            [B, n, ...] when n > 1, where n denotes the number of
            dynamics network replicas. When the input tensor has the
            shape of [B, ...] and n > 1, it will be first expanded to
            [B, n, ...] to match the number of dynamics network replicas.
        state (DynamicsState): state for dynamics learning with the
            following fields:
            - feature (Tensor): features of the previous observation of
                the shape [B, ...], or [B, n, ...] when n > 1. When
                ``state.feature`` has the shape of [B, ...] and n > 1,
                it will be first expanded to [B, n, ...] to match the
                number of dynamics network replicas. It is used for
                predicting the feature of the next step together with
                ``time_step.prev_action``.
            - network: the input state of the dynamics network
    Returns:
        AlgStep:
            outputs (Tensor): predicted feature of the next step, of the
                shape [B, ...], or [B, n, ...] when n > 1.
            state (DynamicsState): with the following fields
            - feature (Tensor): [B, ...] (or [B, n, ...] when n > 1)
                shape tensor representing the predicted feature of the
                next step
            - network: the updated state of the dynamics network
            info (DynamicsInfo): with the following fields being updated:
            - dist (td.Distribution): the predictive distribution which
                can be used for further calculation or summarization.
    """
    action = self._encode_action(time_step.prev_action)
    obs = state.feature

    # perform preprocessing
    observations = self._expand_to_replica(obs, self._feature_spec)
    actions = self._expand_to_replica(action, self._action_spec)

    dist, network_states = self._dynamics_network((observations, actions),
                                                  state=state.network)
    forward_deltas = dist.sample()
    forward_preds = observations + forward_deltas

    state = state._replace(feature=forward_preds, network=network_states)
    return AlgStep(
        output=forward_preds, state=state, info=DynamicsInfo(dist=dist))
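# Note the delta parameterization shared by this ``predict_step`` and the
# matching dynamics ``train_step`` earlier in this file: prediction adds a
# sampled delta to the current feature (obs + dist.sample()), and training
# scores the true delta under the same distribution
# (-dist.log_prob(feature - state.feature)). A toy illustration with a
# Gaussian head (shapes and values are illustrative):
import torch
import torch.distributions as td

obs = torch.randn(8, 16)
next_obs = torch.randn(8, 16)
dist = td.Independent(td.Normal(torch.zeros(8, 16), torch.ones(8, 16)), 1)

pred_next = obs + dist.sample()  # as in predict_step
nll = -dist.log_prob(next_obs - obs)  # as in train_step, shape [8]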
def _step(self, time_step: TimeStep, state, calc_rewards=True):
    """
    Args:
        time_step (TimeStep): input time step data, where the observation
            is the skill-augmented observation. The skill should be a
            one-hot vector.
        state (Tensor): state for DIAYN (previous skill) which should be
            a one-hot vector.
        calc_rewards (bool): if False, only return the losses.

    Returns:
        AlgStep:
            output: empty tuple ()
            state: skill
            info (DIAYNInfo):
    """
    observations_aug = time_step.observation
    step_type = time_step.step_type
    observation, skill = observations_aug
    prev_skill = state.detach()

    # normalize observation for easier prediction
    if self._observation_normalizer is not None:
        observation = self._observation_normalizer.normalize(observation)

    if self._encoding_net is not None:
        feature, _ = self._encoding_net(observation)

    skill_pred, _ = self._discriminator_net(feature)

    if self._skill_spec.is_discrete:
        loss = torch.nn.CrossEntropyLoss(reduction='none')(
            input=skill_pred, target=torch.argmax(prev_skill, dim=-1))
    else:
        # nn.MSELoss doesn't support reducing along a dim
        loss = torch.sum(math_ops.square(skill_pred - prev_skill), dim=-1)

    valid_masks = (step_type != to_tensor(StepType.FIRST)).to(torch.float32)
    loss *= valid_masks

    intrinsic_reward = ()
    if calc_rewards:
        intrinsic_reward = -loss.detach()
        intrinsic_reward = self._reward_normalizer.normalize(
            intrinsic_reward)

    return AlgStep(
        output=(),
        state=skill,
        info=DIAYNInfo(reward=intrinsic_reward, loss=loss))
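# DIAYN's reward in isolation: a discriminator tries to recover the active
# skill from the observation, and the negative prediction loss becomes the
# intrinsic reward, pushing skills toward distinguishable states. A toy
# sketch for discrete skills (the network and sizes are illustrative):
import torch
import torch.nn as nn

num_skills, obs_dim = 8, 16
discriminator = nn.Linear(obs_dim, num_skills)

obs = torch.randn(32, obs_dim)
one_hot = nn.functional.one_hot(
    torch.randint(num_skills, (32, )), num_skills).float()

logits = discriminator(obs)
# Recover skill ids from the one-hot conditioning vector, as in _step above.
loss = nn.CrossEntropyLoss(reduction='none')(logits,
                                             torch.argmax(one_hot, dim=-1))
intrinsic_reward = -loss.detach()  # high when the skill is identifiable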