Example No. 1
    def optimize_agent(self, itr, samples):
        """
        Train the agent, for multiple epochs over minibatches taken from the
        input samples.  Organizes agent inputs from the training data, and
        moves them to device (e.g. GPU) up front, so that minibatches are
        formed within device, without further data transfer.
        """
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(  # Move inputs to device once, index there.
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
        if hasattr(self.agent, "update_obs_rms"):
            self.agent.update_obs_rms(agent_inputs.observation)
        return_, advantage, valid = self.process_returns(samples, self.normalize_rewards)
        loss_inputs = LossInputs(  # So can slice all.
            agent_inputs=agent_inputs,
            action=samples.agent.action,
            return_=return_,
            advantage=advantage,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info,
        )
        if recurrent:
            # Leave in [B,N,H] for slicing to minibatches.
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
        T, B = samples.env.reward.shape[:2]
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
        batch_size = B if self.agent.recurrent else T * B
        mb_size = batch_size // self.minibatches
        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                self.optimizer.zero_grad()
                rnn_state = init_rnn_state[B_idxs] if recurrent else None
                # NOTE: if not recurrent, will lose leading T dim, should be OK.
                loss, entropy, perplexity = self.loss(
                    *loss_inputs[T_idxs, B_idxs], rnn_state)
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.optimizer.step()

                opt_info.loss.append(loss.item())
                opt_info.gradNorm.append(grad_norm)
                opt_info.entropy.append(entropy.item())
                opt_info.perplexity.append(perplexity.item())
                self.update_counter += 1
        if self.linear_lr_schedule:
            self.lr_scheduler.step()
            self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr
        # if self.vae_lr_scheduler:
        #     self.vae_lr_scheduler.step()

        return opt_info
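
All of the examples below drive their minibatch loops with iterate_mb_idxs. A minimal sketch of a compatible generator is given here for reference; it follows the behavior implied by the call sites (yield a contiguous slice when shuffle=False, an array of permuted indices when shuffle=True, dropping any remainder smaller than the minibatch) and is an illustration rather than the library's exact source.

import numpy as np


def iterate_mb_idxs(data_length, minibatch_size, shuffle=False):
    """Yield minibatch indices covering a dataset of length `data_length`.

    With shuffle=True, yields arrays of permuted indices; otherwise yields
    contiguous slices.  A trailing remainder smaller than `minibatch_size`
    is dropped, which matches how the callers above size their minibatches.
    """
    if shuffle:
        indexes = np.random.permutation(data_length)
    for start in range(0, data_length - minibatch_size + 1, minibatch_size):
        batch = slice(start, start + minibatch_size)
        if shuffle:
            batch = indexes[batch]
        yield batch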
Example No. 2
    def optimize_agent(self, itr, samples):
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(  # Move inputs to device once, index there.
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
        return_, advantage, valid = self.process_returns(samples)
        loss_inputs = LossInputs(  # So can slice all.
            agent_inputs=agent_inputs,
            action=samples.agent.action,
            return_=return_,
            advantage=advantage,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info,
        )
        if recurrent:
            # Leave in [B,N,H] for slicing to minibatches.
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
        T, B = samples.env.reward.shape[:2]
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
        batch_size = B if self.agent.recurrent else T * B
        mb_size = batch_size // self.minibatches
        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                rnn_state = init_rnn_state[B_idxs] if recurrent else None
                # NOTE: if not recurrent, will lose leading T dim, should be OK.
                pi_loss, value_loss, entropy, perplexity = self.loss(
                    *loss_inputs[T_idxs, B_idxs], rnn_state)

                self.optimizer.zero_grad()
                pi_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.optimizer.step()

                self.v_optimizer.zero_grad()
                value_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.v_optimizer.step()

                opt_info.loss.append(pi_loss.item())
                opt_info.gradNorm.append(grad_norm)
                opt_info.entropy.append(entropy.item())
                opt_info.perplexity.append(perplexity.item())
                self.update_counter += 1
        if self.linear_lr_schedule:
            self.lr_scheduler.step()
            self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr

        return opt_info
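
In the non-recurrent case the [T, B] leading dimensions are flattened, and the time and batch coordinates are recovered with idxs % T and idxs // T. The toy check below (self-contained, plain PyTorch) illustrates the convention these examples rely on: flat index k addresses element (k % T, k // T), i.e. the batch dimension is the outer one in the flattening.

import torch

T, B = 4, 3
x = torch.arange(T * B).reshape(T, B)      # A toy [T, B] buffer.

idxs = torch.randperm(T * B)[:6]           # A shuffled minibatch of flat indices.
T_idxs = idxs % T                          # Time coordinate of each sample.
B_idxs = idxs // T                         # Batch (environment) coordinate.

# Indexing with (T_idxs, B_idxs) selects the same elements as flattening the
# buffer with the batch dimension outermost (flat index k -> b = k // T, t = k % T).
flat = x.transpose(0, 1).reshape(-1)       # Shape [B * T].
assert torch.equal(x[T_idxs, B_idxs], flat[idxs])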
Example No. 3
    def curiosity_step(self, curiosity_type, *args):
        curiosity_model = self.model.module.curiosity_model if isinstance(
            self.model, torch.nn.parallel.DistributedDataParallel
        ) else self.model.curiosity_model
        curiosity_step_minibatches = self.model_kwargs[
            'curiosity_step_kwargs']['curiosity_step_minibatches']
        T, B = args[0].shape[:2]  # either observation or next_observation
        batch_size = B
        mb_size = batch_size // curiosity_step_minibatches

        if curiosity_type in {'icm', 'micm', 'disagreement'}:
            observation, next_observation, actions = args
            actions = self.distribution.to_onehot(actions)
            curiosity_agent_inputs = IcmAgentCuriosityStepInputs(
                observation=observation,
                next_observation=next_observation,
                actions=actions)
            curiosity_agent_inputs = buffer_to(curiosity_agent_inputs,
                                               device=self.device)
            agent_curiosity_info = IcmInfo()
        elif curiosity_type == 'ndigo':
            observation, prev_actions, actions = args
            actions = self.distribution.to_onehot(actions)
            prev_actions = self.distribution.to_onehot(prev_actions)
            curiosity_agent_inputs = NdigoAgentCuriosityStepInputs(
                observations=observation,
                prev_actions=prev_actions,
                actions=actions)
            curiosity_agent_inputs = buffer_to(curiosity_agent_inputs,
                                               device=self.device)
            agent_curiosity_info = NdigoInfo(prev_gru_state=None)
        elif curiosity_type == 'rnd':
            next_observation, done = args
            curiosity_agent_inputs = RndAgentCuriosityStepInputs(
                next_observation=next_observation, done=done)
            curiosity_agent_inputs = buffer_to(curiosity_agent_inputs,
                                               device=self.device)
            agent_curiosity_info = RndInfo()

        # Split the intrinsic reward computation into minibatches along B;
        # otherwise the full batch can run out of GPU memory.
        r_ints = []
        for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=False):
            T_idxs = slice(None)  # Keep the full time dimension.
            B_idxs = idxs
            mb_r_int = curiosity_model.compute_bonus(
                *curiosity_agent_inputs[T_idxs, B_idxs])
            r_ints.append(mb_r_int)
        r_int = torch.cat(r_ints, dim=1)

        r_int, agent_curiosity_info = buffer_to((r_int, agent_curiosity_info),
                                                device="cpu")

        return AgentCuriosityStep(r_int=r_int,
                                  agent_curiosity_info=agent_curiosity_info)
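
The loop above is an instance of a generic pattern: evaluate a model over a [T, B, ...] buffer in chunks along B and concatenate the per-chunk outputs, so the whole batch never has to fit in GPU memory at once. A hypothetical standalone helper is sketched below; compute_in_chunks and bonus_fn are illustrative names, not part of the codebase, and the torch.no_grad() wrapper assumes the bonus is computed inference-only (drop it if gradients are needed).

import torch


def compute_in_chunks(bonus_fn, *tensors, chunk_size):
    """Apply `bonus_fn` to [T, B, ...] tensors in chunks along B and
    concatenate the per-chunk outputs along the B dimension."""
    B = tensors[0].shape[1]
    outs = []
    with torch.no_grad():  # Assumed inference-only; remove if gradients are needed.
        for start in range(0, B, chunk_size):
            chunk = tuple(t[:, start:start + chunk_size] for t in tensors)
            outs.append(bonus_fn(*chunk))
    return torch.cat(outs, dim=1)

Above, curiosity_model.compute_bonus plays the role of bonus_fn and the fields of curiosity_agent_inputs are the tensors.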
Example No. 4
    def optimize_agent(self, itr, samples):
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(observation=samples.env.observation,
                                   prev_action=samples.agent.prev_action,
                                   prev_reward=samples.env.prev_reward)
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
        return_, advantage, valid = self.process_returns(samples)

        loss_inputs = LossInputs(
            agent_inputs=agent_inputs,
            action=samples.agent.action,
            return_=return_,
            advantage=advantage,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info,
        )

        if recurrent:
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]
        T, B = samples.env.reward.shape[:2]
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))

        batch_size = B if self.agent.recurrent else T * B
        mb_size = batch_size // self.minibatches

        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                self.optimizer.zero_grad()
                rnn_state = init_rnn_state[B_idxs] if recurrent else None

                loss, entropy, perplexity = self.loss(
                    *loss_inputs[T_idxs, B_idxs], rnn_state)
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.optimizer.step()

                opt_info.loss.append(loss.item())
                opt_info.gradNorm.append(grad_norm)
                opt_info.entropy.append(entropy.item())
                opt_info.perplexity.append(perplexity.item())
        if self.linear_lr_schedule:
            self.lr_scheduler.step()
            self.ratio_clip = self._ratio_clip * (self.n_itr -
                                                  itr) / self.n_itr

        return opt_info
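
The linear_lr_schedule branch steps a learning-rate scheduler once per iteration and anneals the PPO ratio clip with the same linear slope. A minimal sketch of how such a scheduler could be built with the standard torch.optim.lr_scheduler.LambdaLR API follows; the helper name and its pairing with self.lr_scheduler are assumptions based on how it is used above.

import torch


def build_linear_lr_schedule(optimizer, n_itr):
    """LR multiplier decays linearly from 1 at itr=0 toward 0 at itr=n_itr;
    call scheduler.step() once per optimize_agent iteration."""
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda itr: (n_itr - itr) / n_itr)

The ratio clip follows the same slope in closed form: ratio_clip = initial_ratio_clip * (n_itr - itr) / n_itr.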
Example No. 5
    def optimize_agent(self, itr, samples=None, sampler_itr=None):
        """
        Train the agent, for multiple epochs over minibatches taken from the
        input samples.  Organizes agent inputs from the training data, and
        moves them to device (e.g. GPU) up front, so that minibatches are
        formed within device, without further data transfer.
        """
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        agent_inputs = AgentInputs(  # Move inputs to device once, index there.
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
        init_rnn_states = buffer_to(samples.agent.agent_info.prev_rnn_state[0],
                                    device=self.agent.device)
        T, B = samples.env.reward.shape[:2]
        mb_size = B // self.minibatches
        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(B, mb_size, shuffle=True):
                self.optimizer.zero_grad()
                init_rnn_state = buffer_method(init_rnn_states[idxs],
                                               "transpose", 0, 1)
                dist_info, value, _ = self.agent(*agent_inputs[:, idxs],
                                                 init_rnn_state)
                loss, opt_info = self.process_returns(
                    samples.env.reward[:, idxs],
                    done=samples.env.done[:, idxs],
                    value_prediction=value.cpu(),
                    action=samples.agent.action[:, idxs],
                    dist_info=dist_info,
                    old_dist_info=samples.agent.agent_info.dist_info[:, idxs],
                    opt_info=opt_info)
                loss.backward()
                self.optimizer.step()
                self.clamp_lagrange_multipliers()
                opt_info.loss.append(loss.item())
                self.update_counter += 1
        return opt_info
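
The recurrent branches keep the stored initial RNN state in [B, N, H] layout so it can be sliced per environment, then transpose to [N, B, H] before handing it to cuDNN (Example No. 5 does this with buffer_method(..., "transpose", 0, 1)). The self-contained check below illustrates that layout contract with a plain torch LSTM; the sizes are arbitrary.

import torch

B, N, H = 8, 1, 16                       # Envs, RNN layers, hidden size.
lstm = torch.nn.LSTM(input_size=4, hidden_size=H, num_layers=N)

stored_h = torch.zeros(B, N, H)          # Stored as [B, N, H]: easy to index per env.
idxs = torch.tensor([0, 3, 5])           # A minibatch of environment indices.
h0 = stored_h[idxs].transpose(0, 1).contiguous()  # -> [N, mb, H] for cuDNN.
c0 = torch.zeros_like(h0)

x = torch.randn(10, len(idxs), 4)        # [T, mb, input] sequence.
out, (hn, cn) = lstm(x, (h0, c0))
assert out.shape == (10, len(idxs), H)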
Example No. 6
    def optimize_agent(self, itr, samples):
        """
        Override to provide additional flexibility in what enters the combined_loss function.
        """
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
        if hasattr(self.agent, "update_obs_rms"):
            self.agent.update_obs_rms(agent_inputs.observation)

        # Process extrinsic returns and advantages
        ext_rew, done, ext_val, ext_bv = (samples.env.reward, samples.env.done,
                                          samples.agent.agent_info.ext_value,
                                          samples.agent.bootstrap_value)
        done = done.type(ext_rew.dtype)
        if self.ext_rew_clip:  # Clip extrinsic reward if specified.
            rew_min, rew_max = self.ext_rew_clip
            ext_rew = ext_rew.clamp(rew_min, rew_max)
        ext_return, ext_adv, valid = self.process_extrinsic_returns(
            ext_rew, done, ext_val, ext_bv)

        # Gather next observations, or fall back to a dummy placeholder (the
        # current obs).  The agent decides what it extracts and feeds to its
        # model, so the dummy placeholder has no effect in that case.
        next_obs = (samples.env.next_observation
                    if "next_observation" in samples.env
                    else samples.env.observation)

        # First call to bonus model, generates intrinsic rewards for samples batch
        # [T, B] leading dims are flattened, and the resulting returns are unflattened
        batch_shape = samples.env.observation.shape[:2]
        bonus_model_inputs = self.agent.extract_bonus_inputs(
            observation=samples.env.observation.flatten(end_dim=1),
            next_observation=next_obs.flatten(
                end_dim=1
            ),  # May be same as observation (dummy placeholder) if algo set next_obs=False
            action=samples.agent.action.flatten(end_dim=1))
        # Bonus model will update any normalization models where applicable.
        self.agent.set_norm_update(True)
        with torch.no_grad():
            int_rew, _ = self.agent.bonus_call(bonus_model_inputs)
        int_rew = int_rew.view(batch_shape)

        # Process intrinsic returns and advantages (updating intrinsic reward normalization model, if applicable)
        int_val, int_bv = samples.agent.agent_info.int_value, samples.agent.int_bootstrap_value
        int_return, int_adv = self.process_intrinsic_returns(
            int_rew, int_val, int_bv)

        # Avoid repeating any norm updates on same data in subsequent loss forward calls
        self.agent.set_norm_update(False)

        # Add front-processed optimizer data to logging buffer
        # Flattened to match elsewhere, though the ultimate statistics summarize over all dims anyway
        opt_info.extrinsicValue.extend(ext_val.flatten().tolist())
        opt_info.intrinsicValue.extend(int_val.flatten().tolist())
        opt_info.intrinsicReward.extend(int_rew.flatten().tolist())
        opt_info.discountedIntrinsicReturn.extend(
            int_return.flatten().tolist())
        opt_info.meanObsRmsModel.extend(
            self.agent.bonus_model.obs_rms.mean.flatten().tolist())
        opt_info.varObsRmsModel.extend(
            self.agent.bonus_model.obs_rms.var.flatten().tolist())
        opt_info.meanIntRetRmsModel.extend(
            self.agent.bonus_model.int_rff_rms.mean.flatten().tolist())
        opt_info.varIntRetRmsModel.extend(
            self.agent.bonus_model.int_rff_rms.var.flatten().tolist())

        loss_inputs = LossInputs(  # So can slice all.
            agent_inputs=agent_inputs,
            action=samples.agent.action,
            next_obs=next_obs,
            ext_return=ext_return,
            ext_adv=ext_adv,
            int_return=int_return,
            int_adv=int_adv,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info)
        if recurrent:
            # Leave in [B,N,H] for slicing to minibatches.
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
        # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
        T, B = samples.env.reward.shape[:2]
        batch_size = B if self.agent.recurrent else T * B
        mb_size = batch_size // self.minibatches
        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                self.optimizer.zero_grad()
                rnn_state = init_rnn_state[B_idxs] if recurrent else None
                # NOTE: if not recurrent, will lose leading T dim, should be OK.
                # Combined loss produces single loss for both actor and bonus model
                loss, entropy, perplexity, pi_loss, value_loss, entropy_loss, bonus_loss = \
                    self.combined_loss(*loss_inputs[T_idxs, B_idxs], rnn_state)
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.optimizer.step()

                opt_info.loss.append(loss.item())
                opt_info.policyLoss.append(pi_loss.item())
                opt_info.valueLoss.append(value_loss.item())
                opt_info.entropyLoss.append(entropy_loss.item())
                opt_info.bonusLoss.append(bonus_loss.item())
                opt_info.gradNorm.append(grad_norm)
                opt_info.entropy.append(entropy.item())
                opt_info.perplexity.append(perplexity.item())
                self.update_counter += 1
        if self.linear_lr_schedule:
            self.lr_scheduler.step()
            self.ratio_clip = self._ratio_clip * (self.n_itr -
                                                  itr) / self.n_itr

        return opt_info
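
Example No. 6 logs statistics from obs_rms and int_rff_rms, which are running mean/variance trackers used to normalize observations and intrinsic returns. A minimal sketch of such a tracker, using the standard parallel-variance combination of batch and running statistics; the class and field names here are illustrative, not the repository's exact implementation.

import torch


class RunningMeanStd:
    """Tracks a running mean and variance over batches of samples."""

    def __init__(self, shape=()):
        self.mean = torch.zeros(shape)
        self.var = torch.ones(shape)
        self.count = 1e-4  # Small prior count avoids division by zero.

    def update(self, x):
        # x: [batch, *shape]; fold the batch statistics into the running ones.
        batch_mean = x.mean(dim=0)
        batch_var = x.var(dim=0, unbiased=False)
        batch_count = x.shape[0]

        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m2 = (self.var * self.count + batch_var * batch_count
              + delta.pow(2) * self.count * batch_count / total)
        self.mean, self.var, self.count = new_mean, m2 / total, total

Observations are then typically normalized as (obs - rms.mean) / (rms.var.sqrt() + eps), and the intrinsic-return tracker is used the same way to rescale intrinsic rewards.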
Example No. 7
    def compute_beta_kl(self, loss_inputs, init_rnn_state, batch_size, mb_size,
                        T):
        """Ratio of KL divergences from reward-only vs cost-only updates."""
        self.agent.beta_r_model.load_state_dict(
            strip_ddp_state_dict(self.agent.model.state_dict()))
        self.agent.beta_c_model.load_state_dict(
            strip_ddp_state_dict(self.agent.model.state_dict()))
        self.beta_r_optimizer.load_state_dict(self.optimizer.state_dict())
        self.beta_c_optimizer.load_state_dict(self.optimizer.state_dict())

        recurrent = self.agent.recurrent
        for _ in range(self.beta_kl_epochs):
            for idxs in iterate_mb_idxs(batch_size,
                                        mb_size,
                                        shuffle=batch_size > mb_size):
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                rnn_state = init_rnn_state[B_idxs] if recurrent else None
                self.beta_r_optimizer.zero_grad()
                self.beta_c_optimizer.zero_grad()

                beta_r_loss, beta_c_loss = self.beta_kl_losses(
                    *loss_inputs[T_idxs, B_idxs], rnn_state)

                beta_r_loss.backward()
                _ = torch.nn.utils.clip_grad_norm_(
                    self.agent.beta_r_model.parameters(), self.clip_grad_norm)
                self.beta_r_optimizer.step()

                beta_c_loss.backward()
                _ = torch.nn.utils.clip_grad_norm_(
                    self.agent.beta_c_model.parameters(), self.clip_grad_norm)
                self.beta_c_optimizer.step()

        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        with torch.no_grad():
            r_dist_info, c_dist_info = self.agent.beta_dist_infos(
                *loss_inputs.agent_inputs, init_rnn_state)

        dist = self.agent.distribution
        beta_r_KL = dist.mean_kl(new_dist_info=r_dist_info,
                                 old_dist_info=loss_inputs.old_dist_info,
                                 valid=loss_inputs.valid)
        beta_c_KL = dist.mean_kl(new_dist_info=c_dist_info,
                                 old_dist_info=loss_inputs.old_dist_info,
                                 valid=loss_inputs.valid)

        if self._ddp:
            beta_KLs = torch.stack([beta_r_KL, beta_c_KL])
            beta_KLs = beta_KLs.to(self.agent.device)
            torch.distributed.all_reduce(beta_KLs)
            beta_KLs = beta_KLs.to("cpu")
            beta_KLs /= torch.distributed.get_world_size()
            beta_r_KL, beta_c_KL = beta_KLs[0], beta_KLs[1]

        raw_beta_KL = float(beta_r_KL / max(beta_c_KL, 1e-8))

        return raw_beta_KL, float(beta_r_KL), float(beta_c_KL)
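
The _ddp branch averages the two KL statistics across data-parallel workers by summing with torch.distributed.all_reduce and dividing by the world size. Stripped of the surrounding logic, the pattern looks like the sketch below; the distributed calls are the standard torch.distributed API, and the helper assumes an already-initialized process group.

import torch
import torch.distributed as dist


def average_across_workers(scalars, device):
    """All-reduce a list of scalar values and return their per-worker means."""
    stacked = torch.stack([torch.as_tensor(s, dtype=torch.float32)
                           for s in scalars]).to(device)
    dist.all_reduce(stacked)                  # Default reduce op is SUM.
    stacked = stacked.cpu() / dist.get_world_size()
    return [float(v) for v in stacked]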
Example No. 8
    def optimize_agent(self, itr, samples):
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(  # Move inputs to device once, index there.
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
        # return_, advantage, valid = self.process_returns(samples)
        (return_, advantage, valid, c_return, c_advantage,
         ep_cost_avg) = self.process_returns(itr, samples)
        loss_inputs = LossInputs(  # So can slice all.
            agent_inputs=agent_inputs,
            action=samples.agent.action,
            return_=return_,
            advantage=advantage,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info,
            c_return=c_return,  # Can be None.
            c_advantage=c_advantage,
        )
        opt_info = OptInfoCost(*([] for _ in range(len(OptInfoCost._fields))))

        if (self.step_cost_limit_itr is not None
                and self.step_cost_limit_itr == itr):
            self.cost_limit = self.step_cost_limit_value
        opt_info.costLimit.append(self.cost_limit)

        # PID update here:
        delta = float(ep_cost_avg - self.cost_limit)  # ep_cost_avg: tensor
        self.pid_i = max(0., self.pid_i + delta * self.pid_Ki)
        if self.diff_norm:
            self.pid_i = max(0., min(1., self.pid_i))
        a_p = self.pid_delta_p_ema_alpha
        self._delta_p *= a_p
        self._delta_p += (1 - a_p) * delta
        a_d = self.pid_delta_d_ema_alpha
        self._cost_d *= a_d
        self._cost_d += (1 - a_d) * float(ep_cost_avg)
        pid_d = max(0., self._cost_d - self.cost_ds[0])
        pid_o = (self.pid_Kp * self._delta_p + self.pid_i +
                 self.pid_Kd * pid_d)
        self.cost_penalty = max(0., pid_o)
        if self.diff_norm:
            self.cost_penalty = min(1., self.cost_penalty)
        if not (self.diff_norm or self.sum_norm):
            self.cost_penalty = min(self.cost_penalty, self.penalty_max)
        self.cost_ds.append(self._cost_d)
        opt_info.pid_i.append(self.pid_i)
        opt_info.pid_p.append(self._delta_p)
        opt_info.pid_d.append(pid_d)
        opt_info.pid_o.append(pid_o)

        opt_info.costPenalty.append(self.cost_penalty)

        if hasattr(self.agent, "update_obs_rms"):
            self.agent.update_obs_rms(agent_inputs.observation)
            if itr == 0:
                return opt_info  # Sacrifice the first batch to get obs stats.

        if recurrent:
            # Leave in [B,N,H] for slicing to minibatches.
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
        T, B = samples.env.reward.shape[:2]
        # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
        batch_size = B if self.agent.recurrent else T * B
        mb_size = batch_size // self.minibatches

        if self.use_beta_kl or self.record_beta_kl:
            raw_beta_kl, beta_r_kl, beta_c_kl = self.compute_beta_kl(
                loss_inputs, init_rnn_state, batch_size, mb_size, T)
            beta_KL = min(self.beta_max, max(self.beta_min, raw_beta_kl))
            self._beta_kl *= self.beta_ema_alpha
            self._beta_kl += (1 - self.beta_ema_alpha) * beta_KL
            opt_info.betaKlRaw.append(raw_beta_kl)
            opt_info.betaKL.append(self._beta_kl)
            opt_info.betaKlR.append(beta_r_kl)
            opt_info.betaKlC.append(beta_c_kl)
            # print("raw_beta_kl: ", raw_beta_kl)
            # print("self._beta_kl: ", self._beta_kl, "\n\n")

        if self.use_beta_grad or self.record_beta_grad:
            raw_beta_grad = self.compute_beta_grad(loss_inputs, init_rnn_state)
            beta_grad = min(self.beta_max, max(self.beta_min, raw_beta_grad))
            self._beta_grad *= self.beta_ema_alpha
            self._beta_grad += (1 - self.beta_ema_alpha) * beta_grad
            opt_info.betaGradRaw.append(raw_beta_grad)
            opt_info.betaGrad.append(self._beta_grad)

        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                self.optimizer.zero_grad()
                rnn_state = init_rnn_state[B_idxs] if recurrent else None
                # NOTE: if not recurrent, will lose leading T dim, should be OK.
                loss, entropy, perplexity, value_errors, abs_value_errors = self.loss(
                    *loss_inputs[T_idxs, B_idxs], rnn_state)
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.optimizer.step()

                opt_info.loss.append(loss.item())
                opt_info.gradNorm.append(grad_norm)
                opt_info.entropy.append(entropy.item())
                opt_info.perplexity.append(perplexity.item())
                opt_info.valueError.extend(value_errors[0][::10].numpy())
                opt_info.cvalueError.extend(value_errors[1][::10].numpy())
                opt_info.valueAbsError.extend(
                    abs_value_errors[0][::10].numpy())
                opt_info.cvalueAbsError.extend(
                    abs_value_errors[1][::10].numpy())

                self.update_counter += 1
        if self.linear_lr_schedule:
            self.lr_scheduler.step()
            self.ratio_clip = self._ratio_clip * \
                (self.n_itr - itr) / self.n_itr

        return opt_info
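
The PID block in Example No. 8 turns the gap between the average episode cost and the cost limit into a non-negative Lagrange-style penalty coefficient. A condensed, self-contained restatement of that controller is sketched below; state is held in a small class purely for illustration, the diff_norm/sum_norm clamping is omitted, and where the example keeps a deque of past cost EMAs (self.cost_ds) for a configurable derivative delay, this sketch uses a one-step delay.

class CostPenaltyPID:
    """PID controller on (episode cost - cost limit), producing a penalty >= 0."""

    def __init__(self, kp, ki, kd, p_ema_alpha, d_ema_alpha):
        self.kp, self.ki, self.kd = kp, ki, kd
        self.p_ema_alpha = p_ema_alpha
        self.d_ema_alpha = d_ema_alpha
        self.i_term = 0.0
        self.delta_p = 0.0       # EMA of the cost error (proportional input).
        self.cost_d = 0.0        # EMA of the raw cost (derivative input).
        self.prev_cost_d = 0.0

    def update(self, ep_cost_avg, cost_limit):
        delta = float(ep_cost_avg - cost_limit)
        self.i_term = max(0.0, self.i_term + delta * self.ki)
        self.delta_p = (self.p_ema_alpha * self.delta_p
                        + (1.0 - self.p_ema_alpha) * delta)
        self.cost_d = (self.d_ema_alpha * self.cost_d
                       + (1.0 - self.d_ema_alpha) * float(ep_cost_avg))
        d_term = max(0.0, self.cost_d - self.prev_cost_d)
        self.prev_cost_d = self.cost_d
        return max(0.0, self.kp * self.delta_p + self.i_term + self.kd * d_term)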
Example No. 9
    def optimize_agent(self, itr, samples):
        """
        Train the agent, for multiple epochs over minibatches taken from the
        input samples.  Organizes agent inputs from the training data, and
        moves them to device (e.g. GPU) up front, so that minibatches are
        formed within device, without further data transfer.
        """
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(  # Move inputs to device once, index there.
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)

        if hasattr(self.agent, "update_obs_rms"):
            self.agent.update_obs_rms(agent_inputs.observation)

        if self.agent.dual_model:
            return_, advantage, valid, return_int_, advantage_int = self.process_returns(
                samples)
        else:
            return_, advantage, valid = self.process_returns(samples)

        if self.curiosity_type in {'icm', 'micm', 'disagreement'}:
            agent_curiosity_inputs = IcmAgentCuriosityInputs(
                observation=samples.env.observation.clone(),
                next_observation=samples.env.next_observation.clone(),
                action=samples.agent.action.clone(),
                valid=valid)
            agent_curiosity_inputs = buffer_to(agent_curiosity_inputs,
                                               device=self.agent.device)
        elif self.curiosity_type == 'ndigo':
            agent_curiosity_inputs = NdigoAgentCuriosityInputs(
                observation=samples.env.observation.clone(),
                prev_actions=samples.agent.prev_action.clone(),
                actions=samples.agent.action.clone(),
                valid=valid)
            agent_curiosity_inputs = buffer_to(agent_curiosity_inputs,
                                               device=self.agent.device)
        elif self.curiosity_type == 'rnd':
            agent_curiosity_inputs = RndAgentCuriosityInputs(
                next_observation=samples.env.next_observation.clone(),
                valid=valid)
            agent_curiosity_inputs = buffer_to(agent_curiosity_inputs,
                                               device=self.agent.device)
        elif self.curiosity_type == 'none':
            agent_curiosity_inputs = None

        if self.policy_loss_type == 'dual':
            loss_inputs = LossInputsTwin(  # So can slice all.
                agent_inputs=agent_inputs,
                agent_curiosity_inputs=agent_curiosity_inputs,
                action=samples.agent.action,
                return_=return_,
                advantage=advantage,
                valid=valid,
                old_dist_info=samples.agent.agent_info.dist_info,
                return_int_=return_int_,
                advantage_int=advantage_int,
                old_dist_int_info=samples.agent.agent_info.dist_int_info,
            )
        else:
            loss_inputs = LossInputs(  # So can slice all.
                agent_inputs=agent_inputs,
                agent_curiosity_inputs=agent_curiosity_inputs,
                action=samples.agent.action,
                return_=return_,
                advantage=advantage,
                valid=valid,
                old_dist_info=samples.agent.agent_info.dist_info,
            )

        if recurrent:
            # Leave in [B,N,H] for slicing to minibatches.
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
            if self.agent.dual_model:
                init_int_rnn_state = samples.agent.agent_info.prev_int_rnn_state[
                    0]  # T=0.

        T, B = samples.env.reward.shape[:2]

        if self.policy_loss_type == 'dual':
            opt_info = OptInfoTwin(*([]
                                     for _ in range(len(OptInfoTwin._fields))))
        else:
            opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))

        # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
        batch_size = B if self.agent.recurrent else T * B
        mb_size = batch_size // self.minibatches

        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
                T_idxs = slice(None) if recurrent else idxs % T
                B_idxs = idxs if recurrent else idxs // T
                self.optimizer.zero_grad()
                rnn_state = init_rnn_state[B_idxs] if recurrent else None

                # NOTE: if not recurrent, will lose leading T dim, should be OK.
                if self.policy_loss_type == 'dual':
                    int_rnn_state = (init_int_rnn_state[B_idxs]
                                     if recurrent else None)
                    loss_inputs_batch = loss_inputs[T_idxs, B_idxs]
                    (loss, pi_loss, value_loss, entropy_loss, entropy,
                     perplexity, int_pi_loss, int_value_loss, int_entropy_loss,
                     int_entropy, int_perplexity, curiosity_losses) = self.loss(
                        agent_inputs=loss_inputs_batch.agent_inputs,
                        agent_curiosity_inputs=loss_inputs_batch.agent_curiosity_inputs,
                        action=loss_inputs_batch.action,
                        return_=loss_inputs_batch.return_,
                        advantage=loss_inputs_batch.advantage,
                        valid=loss_inputs_batch.valid,
                        old_dist_info=loss_inputs_batch.old_dist_info,
                        return_int_=loss_inputs_batch.return_int_,
                        advantage_int=loss_inputs_batch.advantage_int,
                        old_dist_int_info=loss_inputs_batch.old_dist_int_info,
                        init_rnn_state=rnn_state,
                        init_int_rnn_state=int_rnn_state)
                else:
                    loss, pi_loss, value_loss, entropy_loss, entropy, perplexity, curiosity_losses = self.loss(
                        *loss_inputs[T_idxs, B_idxs], rnn_state)

                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.parameters(), self.clip_grad_norm)
                self.optimizer.step()

                # Tensorboard summaries
                opt_info.loss.append(loss.item())
                opt_info.pi_loss.append(pi_loss.item())
                opt_info.value_loss.append(value_loss.item())
                opt_info.entropy_loss.append(entropy_loss.item())

                if self.policy_loss_type == 'dual':
                    opt_info.int_pi_loss.append(int_pi_loss.item())
                    opt_info.int_value_loss.append(int_value_loss.item())
                    opt_info.int_entropy_loss.append(int_entropy_loss.item())

                if self.curiosity_type in {'icm', 'micm'}:
                    inv_loss, forward_loss = curiosity_losses
                    opt_info.inv_loss.append(inv_loss.item())
                    opt_info.forward_loss.append(forward_loss.item())
                    opt_info.intrinsic_rewards.append(
                        np.mean(self.intrinsic_rewards))
                    opt_info.extint_ratio.append(np.mean(self.extint_ratio))
                elif self.curiosity_type in {'disagreement', 'ndigo', 'rnd'}:
                    forward_loss = curiosity_losses
                    opt_info.forward_loss.append(forward_loss.item())
                    opt_info.intrinsic_rewards.append(
                        np.mean(self.intrinsic_rewards))
                    opt_info.extint_ratio.append(np.mean(self.extint_ratio))

                if self.normalize_reward:
                    opt_info.reward_total_std.append(self.reward_rms.var**0.5)
                    if self.policy_loss_type == 'dual':
                        opt_info.int_reward_total_std.append(
                            self.int_reward_rms.var**0.5)

                opt_info.entropy.append(entropy.item())
                opt_info.perplexity.append(perplexity.item())

                if self.policy_loss_type == 'dual':
                    opt_info.int_entropy.append(int_entropy.item())
                    opt_info.int_perplexity.append(int_perplexity.item())
                self.update_counter += 1

        opt_info.return_.append(return_.mean().item())
        opt_info.advantage.append(advantage.mean().item())
        opt_info.valpred.append(samples.agent.agent_info.value.mean().item())

        if self.policy_loss_type == 'dual':
            opt_info.return_int_.append(return_int_.mean().item())
            opt_info.advantage_int.append(advantage_int.mean().item())
            opt_info.int_valpred.append(
                samples.agent.agent_info.int_value.mean().item())

        if self.linear_lr_schedule:
            self.lr_scheduler.step()
            self.ratio_clip = self._ratio_clip * (self.n_itr -
                                                  itr) / self.n_itr

        # Empty dict to store model layer weights for tensorboard visualizations.
        layer_info = dict()
        return opt_info, layer_info
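
Example No. 9 dispatches over several curiosity modules (ICM, MICM, disagreement, NDIGO, RND) and logs their inverse/forward losses. For orientation, the core of an ICM-style bonus is the prediction error of a learned forward model in feature space, with an inverse model shaping the features; the stripped-down sketch below uses placeholder architectures and is not the repository's model.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyICM(nn.Module):
    def __init__(self, obs_dim, act_dim, feat_dim=32, eta=0.5):
        super().__init__()
        self.eta = eta
        self.encoder = nn.Sequential(nn.Linear(obs_dim, feat_dim), nn.ReLU())
        self.forward_model = nn.Linear(feat_dim + act_dim, feat_dim)
        self.inverse_model = nn.Linear(2 * feat_dim, act_dim)

    def compute_bonus(self, obs, next_obs, onehot_action):
        phi, phi_next = self.encoder(obs), self.encoder(next_obs)
        phi_pred = self.forward_model(torch.cat([phi, onehot_action], dim=-1))
        # Intrinsic reward: forward-model prediction error in feature space.
        return self.eta * 0.5 * (phi_pred - phi_next).pow(2).sum(dim=-1)

    def compute_loss(self, obs, next_obs, onehot_action):
        phi, phi_next = self.encoder(obs), self.encoder(next_obs)
        phi_pred = self.forward_model(torch.cat([phi, onehot_action], dim=-1))
        forward_loss = 0.5 * (phi_pred - phi_next.detach()).pow(2).sum(-1).mean()
        act_logits = self.inverse_model(torch.cat([phi, phi_next], dim=-1))
        inverse_loss = F.cross_entropy(act_logits, onehot_action.argmax(dim=-1))
        return inverse_loss, forward_loss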
Example No. 10
    def compute_minibatch_gradients(self, samples):
        """
        Train the agent, for multiple epochs over minibatches taken from the
        input samples.  Organizes agent inputs from the training data, and
        moves them to device (e.g. GPU) up front, so that minibatches are
        formed within device, without further data transfer.
        """
        recurrent = self.agent.recurrent
        agent_inputs = AgentInputs(  # Move inputs to device once, index there.
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
        return_, advantage, valid, value, reward, pre_reward = self.process_returns(
            samples)
        loss_inputs = LossInputs(  # So can slice all.
            agent_inputs=agent_inputs,
            action=samples.agent.action,
            return_=return_,
            old_value=value,
            advantage=advantage,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info,
        )
        if recurrent:
            # Leave in [B,N,H] for slicing to minibatches.
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
        T, B = samples.env.reward.shape[:2]

        # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
        batch_size = B if self.agent.recurrent else T * B

        policy_gradients = []
        value_gradients = []
        all_value_diffs = []
        all_ratios = []

        for idxs in iterate_mb_idxs(batch_size, batch_size, shuffle=True):
            T_idxs = slice(None) if recurrent else idxs % T
            B_idxs = idxs if recurrent else idxs // T
            self.optimizer.zero_grad()
            rnn_state = init_rnn_state[B_idxs] if recurrent else None
            # NOTE: if not recurrent, will lose leading T dim, should be OK.
            loss, pi_loss, value_loss, entropy, perplexity, value_diffs, ratio = self.loss(
                *loss_inputs[T_idxs, B_idxs], rnn_state)
            loss.backward()
            # for i, p in enumerate(self.agent.parameters()):
            #     print(i, p.grad)
            # print([(i, p.shape) for i, p in enumerate(self.agent.parameters())])
            # The first 7 parameter tensors belong to the policy network;
            # the remaining ones belong to the value network.
            params = [
                p.grad.data.cpu().numpy().flatten()
                for p in self.agent.parameters()
            ]
            pg = np.concatenate(params[:7]).ravel()
            vg = np.concatenate(params[7:]).ravel()
            policy_gradients.append(pg)
            value_gradients.append(vg)
            # gradient = np.concatenate([p.grad.data.cpu().numpy().flatten() for p in self.agent.parameters()]).ravel()
            # gradients.append(gradient)
            all_value_diffs.extend(value_diffs.detach().numpy().flatten())
            all_ratios.extend(ratio.detach().numpy().flatten())
            self.update_counter += 1

        return policy_gradients, value_gradients, all_value_diffs, all_ratios, reward, pre_reward
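
compute_minibatch_gradients returns flattened per-minibatch policy and value gradients; a natural downstream use is to measure how consistent those gradients are across minibatches. The helper below is a hypothetical analysis utility (its use is an assumption about intent, not something the function above prescribes); it expects at least two gradient vectors of equal length, e.g. the policy gradients collected over several calls.

import numpy as np


def mean_pairwise_cosine_similarity(gradients):
    """Mean pairwise cosine similarity among flattened gradient vectors."""
    g = np.stack(gradients)                                  # [num_grads, num_params]
    g = g / (np.linalg.norm(g, axis=1, keepdims=True) + 1e-12)
    sims = g @ g.T                                           # Cosine similarity matrix.
    upper = np.triu_indices(len(gradients), k=1)
    return float(sims[upper].mean())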