Ejemplo n.º 1
0
 def optimize_agent(self, itr, samples):
     """
     Take a single gradient step on the agent using the whole sample batch.

     If the agent exposes ``update_obs_rms`` it is called with the raw
     observations first.  The combined loss is then backpropagated, gradients
     are clipped to ``self.clip_grad_norm``, and the optimizer is stepped.
     When ``self.linear_lr_schedule`` is set, the LR scheduler advances once.

     Args:
         itr: current training iteration (not used directly here).
         samples: sample batch forwarded unchanged to ``self.loss``.

     Returns:
         OptInfo holding the scalar values of this step's losses, gradient
         norm, entropy and perplexity.
     """
     agent = self.agent
     if hasattr(agent, "update_obs_rms"):
         # NOTE: suboptimal--obs sent to device here and in agent(*inputs).
         agent.update_obs_rms(samples.env.observation)

     optimizer = self.optimizer
     optimizer.zero_grad()
     total_loss, policy_loss, critic_loss, ent, perp = self.loss(samples)
     total_loss.backward()
     norm = torch.nn.utils.clip_grad_norm_(agent.parameters(),
                                           self.clip_grad_norm)
     optimizer.step()

     info = OptInfo(
         loss=total_loss.item(),
         pi_loss=policy_loss.item(),
         value_loss=critic_loss.item(),
         # Copy to a plain Python scalar (backwards compatible with older
         # PyTorch versions and safe for CPU-side logging).
         gradNorm=norm.clone().detach().item(),
         entropy=ent.item(),
         perplexity=perp.item(),
     )
     self.update_counter += 1
     if self.linear_lr_schedule:
         self.lr_scheduler.step()
     return info
Ejemplo n.º 2
0
    def optimize_agent(self, itr, samples):
        """
        Train the agent for multiple epochs over minibatches of ``samples``.

        Agent inputs are moved to the agent's device once up front, so that
        minibatches are formed on-device without further data transfer.  If
        the agent is recurrent, whole trajectories are kept and only the batch
        (B) dimension is shuffled; otherwise all T*B samples are shuffled.

        Args:
            itr: current training iteration; drives the linear ratio-clip
                schedule when ``self.linear_lr_schedule`` is set.
            samples: sample buffer with ``env`` and ``agent`` fields.

        Returns:
            OptInfo whose fields are lists with one Python-scalar entry per
            minibatch update.
        """
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(  # Move inputs to device once, index there.
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
        if hasattr(self.agent, "update_obs_rms"):
            self.agent.update_obs_rms(agent_inputs.observation)
        return_, advantage, valid = self.process_returns(samples, self.normalize_rewards)
        loss_inputs = LossInputs(  # So can slice all.
            agent_inputs=agent_inputs,
            action=samples.agent.action,
            return_=return_,
            advantage=advantage,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info,
        )
        if recurrent:
            # Leave in [B,N,H] for slicing to minibatches.
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
        T, B = samples.env.reward.shape[:2]
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
        batch_size = B if self.agent.recurrent else T * B
        mb_size = batch_size // self.minibatches
        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                self.optimizer.zero_grad()
                rnn_state = init_rnn_state[B_idxs] if recurrent else None
                # NOTE: if not recurrent, will lose leading T dim, should be OK.
                loss, entropy, perplexity = self.loss(
                    *loss_inputs[T_idxs, B_idxs], rnn_state)
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.optimizer.step()

                opt_info.loss.append(loss.item())
                # clip_grad_norm_ returns a tensor (possibly on GPU) in recent
                # PyTorch; store a plain float like every other field so the
                # list can be averaged/logged on the CPU.
                opt_info.gradNorm.append(float(grad_norm))
                opt_info.entropy.append(entropy.item())
                opt_info.perplexity.append(perplexity.item())
                self.update_counter += 1
        if self.linear_lr_schedule:
            self.lr_scheduler.step()
            # Anneal the PPO ratio clip linearly to zero over n_itr iterations.
            self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr

        return opt_info
    def optimize_agent(self, itr, samples):
        """
        Train the agent with separate policy and value optimizer steps.

        For every minibatch (over ``self.epochs`` epochs) the policy loss is
        minimized with ``self.optimizer`` and the value loss with
        ``self.v_optimizer``; each step clips gradients to
        ``self.clip_grad_norm``.  Agent inputs are moved to the agent's device
        once up front so minibatches are sliced on-device.

        Args:
            itr: current training iteration; drives the linear ratio-clip
                schedule when ``self.linear_lr_schedule`` is set.
            samples: sample buffer with ``env`` and ``agent`` fields.

        Returns:
            OptInfo of lists with one entry per minibatch update; ``loss`` and
            ``gradNorm`` both describe the policy step.
        """
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(  # Move inputs to device once, index there.
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
        return_, advantage, valid = self.process_returns(samples)
        loss_inputs = LossInputs(  # So can slice all.
            agent_inputs=agent_inputs,
            action=samples.agent.action,
            return_=return_,
            advantage=advantage,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info,
        )
        if recurrent:
            # Leave in [B,N,H] for slicing to minibatches.
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
        T, B = samples.env.reward.shape[:2]
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
        batch_size = B if self.agent.recurrent else T * B
        mb_size = batch_size // self.minibatches
        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                rnn_state = init_rnn_state[B_idxs] if recurrent else None
                # NOTE: if not recurrent, will lose leading T dim, should be OK.
                pi_loss, value_loss, entropy, perplexity = self.loss(
                    *loss_inputs[T_idxs, B_idxs], rnn_state)

                # Policy step.
                self.optimizer.zero_grad()
                pi_loss.backward()
                pi_grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.optimizer.step()

                # Value step.  The policy gradients have already been applied;
                # clear them so they do not inflate the norm used below to
                # clip the value gradients.
                self.optimizer.zero_grad()
                self.v_optimizer.zero_grad()
                value_loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.v_optimizer.step()

                opt_info.loss.append(pi_loss.item())
                # Log the policy-step norm (matching ``loss`` above) as a
                # plain float; clip_grad_norm_ may return a device tensor.
                opt_info.gradNorm.append(float(pi_grad_norm))
                opt_info.entropy.append(entropy.item())
                opt_info.perplexity.append(perplexity.item())
                self.update_counter += 1
        if self.linear_lr_schedule:
            self.lr_scheduler.step()
            # Anneal the PPO ratio clip linearly to zero over n_itr iterations.
            self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr

        return opt_info
Ejemplo n.º 4
0
 def optimize_agent(self, itr, samples):
     """
     Take one gradient step on the agent using all of ``samples``.

     Args:
         itr: current training iteration (unused here).
         samples: sample batch passed straight to ``self.loss``.

     Returns:
         OptInfo with this step's loss, gradient norm, entropy and
         perplexity as Python scalars.
     """
     self.optimizer.zero_grad()
     loss, entropy, perplexity = self.loss(samples)
     loss.backward()
     grad_norm = torch.nn.utils.clip_grad_norm_(self.agent.parameters(),
                                                self.clip_grad_norm)
     self.optimizer.step()
     opt_info = OptInfo(
         loss=loss.item(),
         # clip_grad_norm_ returns a tensor (possibly on GPU) in recent
         # PyTorch; store a plain float like the other fields.
         gradNorm=float(grad_norm),
         entropy=entropy.item(),
         perplexity=perplexity.item(),
     )
     return opt_info
Ejemplo n.º 5
0
    def optimize_agent(self, itr, samples):
        """
        Train the agent for multiple epochs over minibatches of ``samples``.

        Agent inputs are moved to the agent's device once, then sliced into
        minibatches on-device.  Recurrent agents keep whole trajectories and
        shuffle only the batch (B) dimension; otherwise all T*B samples are
        shuffled.

        Args:
            itr: current training iteration; drives the linear ratio-clip
                schedule when ``self.linear_lr_schedule`` is set.
            samples: sample buffer with ``env`` and ``agent`` fields.

        Returns:
            OptInfo whose fields are lists with one Python-scalar entry per
            minibatch update.
        """
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(observation=samples.env.observation,
                                   prev_action=samples.agent.prev_action,
                                   prev_reward=samples.env.prev_reward)
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
        return_, advantage, valid = self.process_returns(samples)

        loss_inputs = LossInputs(
            agent_inputs=agent_inputs,
            action=samples.agent.action,
            return_=return_,
            advantage=advantage,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info,
        )

        if recurrent:
            # Initial RNN state at T=0, sliced per minibatch along B below.
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]
        T, B = samples.env.reward.shape[:2]
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))

        batch_size = B if self.agent.recurrent else T * B
        mb_size = batch_size // self.minibatches

        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
                # Flat index -> (t, b) when not recurrent; whole time axis
                # when recurrent.
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                self.optimizer.zero_grad()
                rnn_state = init_rnn_state[B_idxs] if recurrent else None

                loss, entropy, perplexity = self.loss(
                    *loss_inputs[T_idxs, B_idxs], rnn_state)
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.optimizer.step()

                opt_info.loss.append(loss.item())
                # clip_grad_norm_ returns a tensor (possibly on GPU) in recent
                # PyTorch; store a plain float like every other field.
                opt_info.gradNorm.append(float(grad_norm))
                opt_info.entropy.append(entropy.item())
                opt_info.perplexity.append(perplexity.item())
        if self.linear_lr_schedule:
            self.lr_scheduler.step()
            # Anneal the PPO ratio clip linearly to zero over n_itr iterations.
            self.ratio_clip = self._ratio_clip * (self.n_itr -
                                                  itr) / self.n_itr

        return opt_info
Ejemplo n.º 6
0
 def optimize_agent(self, itr, samples):
     """
     Train the agent on input samples, by one gradient step.

     Args:
         itr: current training iteration (unused here).
         samples: sample batch passed straight to ``self.loss``.

     Returns:
         OptInfo with this step's loss, gradient norm, entropy and
         perplexity as Python scalars.
     """
     self.optimizer.zero_grad()
     loss, entropy, perplexity = self.loss(samples)
     loss.backward()
     grad_norm = torch.nn.utils.clip_grad_norm_(self.agent.parameters(),
                                                self.clip_grad_norm)
     self.optimizer.step()
     opt_info = OptInfo(
         loss=loss.item(),
         # clip_grad_norm_ returns a tensor (possibly on GPU) in recent
         # PyTorch; store a plain float like the other fields.
         gradNorm=float(grad_norm),
         entropy=entropy.item(),
         perplexity=perplexity.item(),
     )
     self.update_counter += 1
     return opt_info
Ejemplo n.º 7
0
 def optimize_agent(self, itr, samples=None, sampler_itr=None):
     """
     Run ``self.epochs`` passes of minibatch updates over whole trajectories.

     Only the batch (B) dimension is shuffled: each minibatch keeps the full
     time axis and starts from the recorded RNN state at T=0.  Agent inputs
     and initial RNN states are moved to the agent's device once, so slicing
     happens on-device.  The loss (and diagnostics) for each minibatch come
     from ``self.process_returns``.  After every optimizer step
     ``self.clamp_lagrange_multipliers()`` is called.

     Returns:
         The ``opt_info`` threaded through ``process_returns``, with each
         minibatch's scalar loss appended to ``opt_info.loss``.
     """
     device = self.agent.device
     opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
     inputs_on_device = buffer_to(
         AgentInputs(  # Move inputs to device once, index there.
             observation=samples.env.observation,
             prev_action=samples.agent.prev_action,
             prev_reward=samples.env.prev_reward,
         ),
         device=device,
     )
     rnn_states_t0 = buffer_to(samples.agent.agent_info.prev_rnn_state[0],
                               device=device)
     _, B = samples.env.reward.shape[:2]
     minibatch_size = B // self.minibatches

     for _epoch in range(self.epochs):
         for b_idxs in iterate_mb_idxs(B, minibatch_size, shuffle=True):
             self.optimizer.zero_grad()
             # Swap dims 0 and 1 of the sliced state (presumably
             # [B, N, H] -> [N, B, H] as the RNN expects) -- TODO confirm.
             rnn_state = buffer_method(rnn_states_t0[b_idxs],
                                       "transpose", 0, 1)
             dist_info, value, _ = self.agent(*inputs_on_device[:, b_idxs],
                                              rnn_state)
             loss, opt_info = self.process_returns(
                 samples.env.reward[:, b_idxs],
                 done=samples.env.done[:, b_idxs],
                 value_prediction=value.cpu(),
                 action=samples.agent.action[:, b_idxs],
                 dist_info=dist_info,
                 old_dist_info=samples.agent.agent_info.dist_info[:, b_idxs],
                 opt_info=opt_info,
             )
             loss.backward()
             self.optimizer.step()
             self.clamp_lagrange_multipliers()
             opt_info.loss.append(loss.item())
             self.update_counter += 1
     return opt_info
Ejemplo n.º 8
0
    def optimize_agent(self, itr, samples):
        """
        Train the agent, for multiple epochs over minibatches taken from the
        input samples.

        Agent inputs (and, when curiosity is enabled, curiosity-model inputs)
        are moved to the agent's device once up front, so minibatches are
        formed on-device without further data transfer.  Supports a single
        extrinsic head or, with ``policy_loss_type == 'dual'``, separate
        extrinsic/intrinsic heads.  Supported ``curiosity_type`` values are
        'icm', 'micm', 'disagreement', 'ndigo', 'rnd' and 'none'.

        Args:
            itr: current training iteration; drives the linear LR /
                ratio-clip schedule when ``self.linear_lr_schedule`` is set.
            samples: sample buffer with ``env`` and ``agent`` fields.

        Returns:
            ``(opt_info, layer_info)``.  ``opt_info`` holds per-minibatch
            diagnostics plus one batch-level entry each for returns,
            advantages and value predictions.  ``layer_info`` is currently
            always an empty dict.

        Raises:
            ValueError: if ``self.curiosity_type`` is not a supported value.
        """
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(  # Move inputs to device once, index there.
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)

        if hasattr(self.agent, "update_obs_rms"):
            self.agent.update_obs_rms(agent_inputs.observation)

        if self.agent.dual_model:
            return_, advantage, valid, return_int_, advantage_int = self.process_returns(
                samples)
        else:
            return_, advantage, valid = self.process_returns(samples)

        # Build the curiosity-model inputs for the configured curiosity type.
        if self.curiosity_type in {'icm', 'micm', 'disagreement'}:
            agent_curiosity_inputs = IcmAgentCuriosityInputs(
                observation=samples.env.observation.clone(),
                next_observation=samples.env.next_observation.clone(),
                action=samples.agent.action.clone(),
                valid=valid)
        elif self.curiosity_type == 'ndigo':
            agent_curiosity_inputs = NdigoAgentCuriosityInputs(
                observation=samples.env.observation.clone(),
                prev_actions=samples.agent.prev_action.clone(),
                actions=samples.agent.action.clone(),
                valid=valid)
        elif self.curiosity_type == 'rnd':
            agent_curiosity_inputs = RndAgentCuriosityInputs(
                next_observation=samples.env.next_observation.clone(),
                valid=valid)
        elif self.curiosity_type == 'none':
            agent_curiosity_inputs = None
        else:
            # Previously this fell through and failed later with an
            # unbound-local error; fail fast with a clear message instead.
            raise ValueError(
                "Unsupported curiosity_type: {!r}".format(self.curiosity_type))
        if agent_curiosity_inputs is not None:
            agent_curiosity_inputs = buffer_to(agent_curiosity_inputs,
                                               device=self.agent.device)

        if self.policy_loss_type == 'dual':
            loss_inputs = LossInputsTwin(  # So can slice all.
                agent_inputs=agent_inputs,
                agent_curiosity_inputs=agent_curiosity_inputs,
                action=samples.agent.action,
                return_=return_,
                advantage=advantage,
                valid=valid,
                old_dist_info=samples.agent.agent_info.dist_info,
                return_int_=return_int_,
                advantage_int=advantage_int,
                old_dist_int_info=samples.agent.agent_info.dist_int_info,
            )
        else:
            loss_inputs = LossInputs(  # So can slice all.
                agent_inputs=agent_inputs,
                agent_curiosity_inputs=agent_curiosity_inputs,
                action=samples.agent.action,
                return_=return_,
                advantage=advantage,
                valid=valid,
                old_dist_info=samples.agent.agent_info.dist_info,
            )

        if recurrent:
            # Leave in [B,N,H] for slicing to minibatches.
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
            if self.agent.dual_model:
                init_int_rnn_state = samples.agent.agent_info.prev_int_rnn_state[
                    0]  # T=0.

        T, B = samples.env.reward.shape[:2]

        if self.policy_loss_type == 'dual':
            opt_info = OptInfoTwin(*([]
                                     for _ in range(len(OptInfoTwin._fields))))
        else:
            opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))

        # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
        batch_size = B if self.agent.recurrent else T * B
        mb_size = batch_size // self.minibatches

        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                self.optimizer.zero_grad()
                rnn_state = init_rnn_state[B_idxs] if recurrent else None

                # NOTE: if not recurrent, will lose leading T dim, should be OK.
                if self.policy_loss_type == 'dual':
                    int_rnn_state = init_int_rnn_state[
                        B_idxs] if recurrent else None
                    loss_inputs_batch = loss_inputs[T_idxs, B_idxs]
                    loss, pi_loss, value_loss, entropy_loss, entropy, perplexity, \
                        int_pi_loss, int_value_loss, int_entropy_loss, int_entropy, int_perplexity, \
                        curiosity_losses = self.loss(
                            agent_inputs=loss_inputs_batch.agent_inputs,
                            agent_curiosity_inputs=loss_inputs_batch.agent_curiosity_inputs,
                            action=loss_inputs_batch.action,
                            return_=loss_inputs_batch.return_,
                            advantage=loss_inputs_batch.advantage,
                            valid=loss_inputs_batch.valid,
                            old_dist_info=loss_inputs_batch.old_dist_info,
                            return_int_=loss_inputs_batch.return_int_,
                            advantage_int=loss_inputs_batch.advantage_int,
                            old_dist_int_info=loss_inputs_batch.old_dist_int_info,
                            init_rnn_state=rnn_state, init_int_rnn_state=int_rnn_state)
                else:
                    loss, pi_loss, value_loss, entropy_loss, entropy, perplexity, curiosity_losses = self.loss(
                        *loss_inputs[T_idxs, B_idxs], rnn_state)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.optimizer.step()

                # Tensorboard summaries
                opt_info.loss.append(loss.item())
                opt_info.pi_loss.append(pi_loss.item())
                opt_info.value_loss.append(value_loss.item())
                opt_info.entropy_loss.append(entropy_loss.item())

                if self.policy_loss_type == 'dual':
                    opt_info.int_pi_loss.append(int_pi_loss.item())
                    opt_info.int_value_loss.append(int_value_loss.item())
                    opt_info.int_entropy_loss.append(int_entropy_loss.item())

                if self.curiosity_type != 'none':
                    # ICM-style models return (inverse, forward) losses; the
                    # others ('disagreement', 'ndigo', 'rnd') return only a
                    # forward loss.
                    if self.curiosity_type in {'icm', 'micm'}:
                        inv_loss, forward_loss = curiosity_losses
                        opt_info.inv_loss.append(inv_loss.item())
                    else:
                        forward_loss = curiosity_losses
                    opt_info.forward_loss.append(forward_loss.item())
                    opt_info.intrinsic_rewards.append(
                        np.mean(self.intrinsic_rewards))
                    opt_info.extint_ratio.append(np.mean(self.extint_ratio))

                if self.normalize_reward:
                    opt_info.reward_total_std.append(self.reward_rms.var**0.5)
                    if self.policy_loss_type == 'dual':
                        opt_info.int_reward_total_std.append(
                            self.int_reward_rms.var**0.5)

                opt_info.entropy.append(entropy.item())
                opt_info.perplexity.append(perplexity.item())

                if self.policy_loss_type == 'dual':
                    opt_info.int_entropy.append(int_entropy.item())
                    opt_info.int_perplexity.append(int_perplexity.item())
                self.update_counter += 1

        # Batch-level statistics (one entry per optimize_agent call).
        opt_info.return_.append(torch.mean(return_.detach()).item())
        opt_info.advantage.append(torch.mean(advantage.detach()).item())
        opt_info.valpred.append(
            torch.mean(samples.agent.agent_info.value.detach()).item())

        if self.policy_loss_type == 'dual':
            opt_info.return_int_.append(
                torch.mean(return_int_.detach()).item())
            opt_info.advantage_int.append(
                torch.mean(advantage_int.detach()).item())
            opt_info.int_valpred.append(
                torch.mean(samples.agent.agent_info.int_value.detach()).item())

        if self.linear_lr_schedule:
            self.lr_scheduler.step()
            # Anneal the PPO ratio clip linearly to zero over n_itr iterations.
            self.ratio_clip = self._ratio_clip * (self.n_itr -
                                                  itr) / self.n_itr

        # Placeholder for model layer weights for tensorboard visualizations.
        layer_info = {}

        return opt_info, layer_info