Example #1
    def train(self):
        """Train the agent."""
        # logger
        if self.is_log:
            self.set_wandb()
            # wandb.watch([self.actor, self.critic], log="parameters")

        # pre-training if needed
        self.pretrain()

        for self.i_episode in range(1, self.episode_num + 1):
            state = self.env.reset()
            done = False
            score = 0
            self.episode_step = 0
            losses = list()

            t_begin = time.time()

            while not done:
                if self.is_render and self.i_episode >= self.render_after:
                    self.env.render()

                action = self.select_action(state)
                next_state, reward, done, _ = self.step(action)
                self.total_step += 1
                self.episode_step += 1

                if len(self.memory) >= self.hyper_params.batch_size:
                    for _ in range(self.hyper_params.multiple_update):
                        experience = self.memory.sample()
                        demos = self.demo_memory.sample()
                        experience, demos = (
                            numpy2floattensor(experience, self.learner.device),
                            numpy2floattensor(demos, self.learner.device),
                        )
                        loss = self.learner.update_model(experience, demos)
                        losses.append(loss)  # for logging

                state = next_state
                score += reward

            t_end = time.time()
            avg_time_cost = (t_end - t_begin) / self.episode_step

            # logging
            if losses:
                avg_loss = np.vstack(losses).mean(axis=0)
                log_value = (self.i_episode, avg_loss, score, avg_time_cost)
                self.write_log(log_value)
                losses.clear()

            if self.i_episode % self.save_period == 0:
                self.learner.save_params(self.i_episode)
                self.interim_test()

        # termination
        self.env.close()
        self.learner.save_params(self.i_episode)
        self.interim_test()
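
Note: every example on this page calls numpy2floattensor. As a hedged sketch only (not the library's actual implementation), a helper consistent with these call sites would convert a NumPy array, or a tuple of arrays, into float32 tensors on an optional target device:

from typing import Tuple, Union

import numpy as np
import torch


def numpy2floattensor_sketch(
    data: Union[np.ndarray, Tuple[np.ndarray, ...]],
    device: torch.device = torch.device("cpu"),
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """Hedged sketch: convert an array (or tuple of arrays) to float tensors on device."""
    if isinstance(data, tuple):
        return tuple(
            torch.as_tensor(d, dtype=torch.float32, device=device) for d in data
        )
    return torch.as_tensor(data, dtype=torch.float32, device=device)

Passing self.learner.device, as most examples below do, keeps the converted batch on the same device as the networks being updated.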
Example #2
    def sample_experience(self) -> Tuple[torch.Tensor, ...]:
        experience_1 = self.memory.sample(self.per_beta)
        if self.use_n_step:
            indices = experience_1[-2]
            experience_n = self.memory_n.sample(indices)
            return numpy2floattensor(experience_1), numpy2floattensor(
                experience_n)

        return numpy2floattensor(experience_1)
Example #3
    def sample_experience(self) -> Tuple[torch.Tensor, ...]:
        """Sample experience from replay buffer."""
        experiences_1 = self.memory.sample(self.per_beta)
        experiences_1 = numpy2floattensor(experiences_1[:6]) + experiences_1[6:]

        if self.use_n_step:
            indices = experiences_1[-2]
            experiences_n = self.memory_n.sample(indices)
            return experiences_1, numpy2floattensor(experiences_n)

        return experiences_1
Example #4
 def _update_model(self):
     # training
     if len(self.memory) >= self.hyper_params.sac_batch_size:
         for _ in range(self.hyper_params.multiple_update):
             experience = self.memory.sample()
             demos = self.demo_memory.sample()
             experience, demo = (
                 numpy2floattensor(experience),
                 numpy2floattensor(demos),
             )
             loss = self.learner.update_model(experience, demo)
             self.loss_episode.append(loss)  # for logging
Example #5
    def sample_experience(self) -> Tuple[torch.Tensor, ...]:
        experiences_1 = self.memory.sample(self.per_beta)
        experiences_1 = (common_utils.numpy2floattensor(
            experiences_1[:6], self.learner.device) + experiences_1[6:])
        if self.use_n_step:
            indices = experiences_1[-2]
            experiences_n = self.memory_n.sample(indices)
            return (
                experiences_1,
                common_utils.numpy2floattensor(experiences_n,
                                               self.learner.device),
            )

        return experiences_1
Example #6
    def sample_experience(self) -> Tuple[torch.Tensor, ...]:
        experiences_1 = self.memory.sample(self.per_beta)
        experiences_1 = (numpy2floattensor(experiences_1[:3]) +
                         (experiences_1[3], ) +
                         numpy2floattensor(experiences_1[4:6]) +
                         (experiences_1[6:]))
        if self.use_n_step:
            indices = experiences_1[-2]
            experiences_n = self.memory_n.sample(indices)
            return (
                experiences_1,
                numpy2floattensor(experiences_n[:3]) + (experiences_n[3], ) +
                numpy2floattensor(experiences_n[4:]),
            )

        return experiences_1
Example #7
    def step(self, action: torch.Tensor) -> Tuple[np.ndarray, np.float64, bool, dict]:
        next_state, reward, done, info = self.env.step(action.detach().cpu().numpy())

        if not self.is_test:
            # if the episode ended only due to the time limit (not a true
            # terminal state), store done as False
            done_bool = done.copy()
            done_bool[np.where(self.episode_steps == self.max_episode_steps)] = False

            self.rewards.append(
                numpy2floattensor(reward, self.learner.device).unsqueeze(1)
            )
            self.masks.append(
                numpy2floattensor((1 - done_bool), self.learner.device).unsqueeze(1)
            )

        return next_state, reward, done, info
Example #8
    def run(self):
        """Run main training loop."""
        self.telapsed = 0
        while self.update_step < self.max_update_step:
            replay_data = self.recv_replay_data()
            if replay_data is not None:
                replay_data = (
                    numpy2floattensor(replay_data[:6], self.learner.device) +
                    replay_data[6:])
                info = self.update_model(replay_data)
                indices, new_priorities = info[-2:]
                step_info = info[:-2]
                self.update_step = self.update_step + 1

                self.send_new_priorities(indices, new_priorities)

                if self.update_step % self.worker_update_interval == 0:
                    state_dict = self.get_state_dict()
                    np_state_dict = state_dict2numpy(state_dict)
                    self.publish_params(self.update_step, np_state_dict)

                if self.update_step % self.logger_interval == 0:
                    state_dict = self.get_state_dict()
                    np_state_dict = state_dict2numpy(state_dict)
                    self.send_info_to_logger(np_state_dict, step_info)
                    self.learner.save_params(self.update_step)
Example #9
 def _preprocess_state(self, state: np.ndarray) -> torch.Tensor:
     """Preprocess state so that actor selects an action."""
     if self.hyper_params.use_her:
         self.desired_state = self.her.get_desired_state()
         state = np.concatenate((state, self.desired_state), axis=-1)
     state = numpy2floattensor(state, self.learner.device)
     return state
Example #10
    def add_expert_q(self):
        """Generate Q-values of gathered states using the loaded agent."""
        self.make_distillation_dir()
        file_name_list = []

        for _dir in self.hyper_params.dataset_path:
            data = os.listdir(_dir)
            file_name_list += ["./" + _dir + "/" + x for x in data]

        for i in tqdm(range(len(file_name_list))):
            with open(file_name_list[i], "rb") as f:
                state = pickle.load(f)[0]

            torch_state = numpy2floattensor(state, self.device)
            pred_q = self.learner.dqn(
                torch_state).squeeze().detach().cpu().numpy()

            with open(self.save_distillation_dir + "/" + str(i) + ".pkl",
                      "wb") as f:
                pickle.dump([state, pred_q],
                            f,
                            protocol=pickle.HIGHEST_PROTOCOL)
        print(
            f"Data containing expert Q has been saved at {self.save_distillation_dir}"
        )
Example #11
    def forward_(self, state: torch.Tensor,
                 n_tau_samples: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get quantile values and quantiles."""
        batch_size = np.prod(state.size()) // self.input_size

        state_tiled = state.repeat(n_tau_samples, 1)

        # torch.rand (CPU) may cause a segmentation fault on v0.4.1 because it is
        # not thread-safe; see: https://bit.ly/2TXlNbq
        quantiles = np.random.rand(n_tau_samples * batch_size, 1)
        quantiles = numpy2floattensor(quantiles, torch.device("cpu"))
        quantile_net = quantiles.repeat(1, self.quantile_embedding_dim)
        quantile_net = (torch.arange(
            1, self.quantile_embedding_dim + 1, dtype=torch.float) * math.pi *
                        quantile_net)
        quantile_net = torch.cos(quantile_net).to(device)
        quantile_net = F.relu(self.quantile_fc_layer(quantile_net))

        # Hadamard product
        quantile_net = state_tiled * quantile_net

        quantile_values = super(IQNMLP, self).forward(quantile_net)

        return quantile_values, quantiles.to(device)
Example #12
 def _update_model(self):
     if not self.args.test and len(
             self.memory) >= self.hyper_params.sac_batch_size:
         for _ in range(self.hyper_params.multiple_update):
             experience = self.memory.sample()
             experience = numpy2floattensor(experience)
             loss = self.learner.update_model(experience)
             self.loss_episode.append(loss)  # for logging
Example #13
 def select_action(self, state: np.ndarray) -> Tuple[int, torch.Tensor]:
     """Select action from input space."""
     state = numpy2floattensor(state, self.learner.device)
     with torch.no_grad():
         prob = F.softmax(self.learner.actor_target(state).squeeze(), 0) + 1e-8
     action_dist = Categorical(prob)
     selected_action = action_dist.sample().item()
     return selected_action, prob.cpu().numpy()
Example #14
    def train(self):
        """Train the agent."""
        # logger
        if self.args.log:
            self.set_wandb()
            # wandb.watch([self.actor, self.critic1, self.critic2], log="parameters")

        for self.i_episode in range(1, self.args.episode_num + 1):
            state = self.env.reset()
            done = False
            score = 0
            loss_episode = list()
            self.episode_step = 0

            t_begin = time.time()

            while not done:
                if self.args.render and self.i_episode >= self.args.render_after:
                    self.env.render()

                action = self.select_action(state)
                next_state, reward, done, _ = self.step(action)
                self.total_step += 1
                self.episode_step += 1

                state = next_state
                score += reward

                if len(self.memory) >= self.hyper_params.batch_size:
                    experience = self.memory.sample()
                    experience = numpy2floattensor(experience,
                                                   self.learner.device)
                    loss = self.learner.update_model(experience)
                    loss_episode.append(loss)  # for logging

            t_end = time.time()
            avg_time_cost = (t_end - t_begin) / self.episode_step

            # logging
            if loss_episode:
                avg_loss = np.vstack(loss_episode).mean(axis=0)
                log_value = (
                    self.i_episode,
                    avg_loss,
                    score,
                    self.hyper_params.policy_update_freq,
                    avg_time_cost,
                )
                self.write_log(log_value)
            if self.i_episode % self.args.save_period == 0:
                self.learner.save_params(self.i_episode)
                self.interim_test()

        # termination
        self.env.close()
        self.learner.save_params(self.i_episode)
        self.interim_test()
Example #15
 def synchronize(self, state_dict: Dict[str, np.ndarray]):
     """Copy parameters from numpy arrays."""
     param_name_list = list(state_dict.keys())
     for logger_named_param in self.brain.named_parameters():
         logger_param_name = logger_named_param[0]
         if logger_param_name in param_name_list:
             new_param = numpy2floattensor(state_dict[logger_param_name],
                                           self.device)
             logger_named_param[1].data.copy_(new_param)
Example #16
 def _synchronize(self, network: Brain, new_state_dict: Dict[str,
                                                             np.ndarray]):
     """Copy parameters from numpy arrays."""
     param_name_list = list(new_state_dict.keys())
     for worker_named_param in network.named_parameters():
         worker_param_name = worker_named_param[0]
         if worker_param_name in param_name_list:
             new_param = numpy2floattensor(
                 new_state_dict[worker_param_name], self.device)
             worker_named_param[1].data.copy_(new_param)
Example #17
    def compute_priorities(self, memory: Dict[str, np.ndarray]) -> np.ndarray:
        """Compute initial priority values of experiences in local memory."""
        states = numpy2floattensor(memory["states"], self.device)
        actions = numpy2floattensor(memory["actions"], self.device).long()
        rewards = numpy2floattensor(memory["rewards"].reshape(-1, 1),
                                    self.device)
        next_states = numpy2floattensor(memory["next_states"], self.device)
        dones = numpy2floattensor(memory["dones"].reshape(-1, 1), self.device)
        memory_tensors = (states, actions, rewards, next_states, dones)

        with torch.no_grad():
            dq_loss_element_wise, _ = self.loss_fn(
                self.dqn,
                self.dqn,
                memory_tensors,
                self.hyper_params.gamma,
                self.head_cfg,
            )
        loss_for_prior = dq_loss_element_wise.detach().cpu().numpy()
        new_priorities = loss_for_prior + self.hyper_params.per_eps
        return new_priorities
Example #18
    def select_action(self, state: np.ndarray) -> torch.Tensor:
        """Select an action from the input space."""
        state = numpy2floattensor(state, self.learner.device)

        selected_action, dist = self.learner.actor(state)

        if self.is_test:
            selected_action = dist.mean
        else:
            predicted_value = self.learner.critic(state)
            log_prob = dist.log_prob(selected_action).sum(dim=-1)
            self.transition = []
            self.transition.extend([log_prob, predicted_value])

        return selected_action
Example #19
    def select_action(self, state: np.ndarray) -> torch.Tensor:
        """Select an action from the input space."""
        state = numpy2floattensor(state, self.learner.device)
        selected_action, dist = self.learner.actor(state)

        if self.args.test and not self.is_discrete:
            selected_action = dist.mean

        if not self.args.test:
            value = self.learner.critic(state)
            self.states.append(state)
            self.actions.append(selected_action)
            self.values.append(value)
            self.log_probs.append(dist.log_prob(selected_action))

        return selected_action
Example #20
    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action from the input space."""
        # during the initial training steps, take random actions for exploration
        self.curr_state = state

        if (self.total_step < self.hyper_params.initial_random_action
                and not self.args.test):
            return np.array(self.env_info.action_space.sample())

        with torch.no_grad():
            state = numpy2floattensor(state, self.learner.device)
            selected_action = self.learner.actor(state).detach().cpu().numpy()

        if not self.args.test:
            noise = self.exploration_noise.sample()
            selected_action = np.clip(selected_action + noise, -1.0, 1.0)

        return selected_action
Example #21
    def update_model(self, experience: TensorTuple) -> TensorTuple:
        """Update A2C actor and critic networks"""

        log_prob, pred_value, next_state, reward, done = experience
        next_state = numpy2floattensor(next_state, self.device)

        # Q_t   = r + gamma * V(s_{t+1})  if state != Terminal
        #       = r                       otherwise
        mask = 1 - done
        next_value = self.critic(next_state).detach()
        q_value = reward + self.hyper_params.gamma * next_value * mask
        q_value = q_value.to(self.device)

        # advantage = Q_t - V(s_t)
        advantage = q_value - pred_value

        # calculate loss at the current step
        policy_loss = -advantage.detach() * log_prob  # adv. is not backpropagated
        policy_loss += self.hyper_params.w_entropy * -log_prob  # entropy
        value_loss = F.smooth_l1_loss(pred_value, q_value.detach())

        # train
        gradient_clip_ac = self.hyper_params.gradient_clip_ac
        gradient_clip_cr = self.hyper_params.gradient_clip_cr

        self.actor_optim.zero_grad()
        policy_loss.backward()
        clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
        self.actor_optim.step()

        self.critic_optim.zero_grad()
        value_loss.backward()
        clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
        self.critic_optim.step()

        return policy_loss.item(), value_loss.item()
Example #22
    def select_action(self, state: np.ndarray) -> torch.Tensor:
        """Select an action from the input space."""
        with torch.no_grad():
            state = numpy2floattensor(state, self.learner.device)
            selected_action, dist = self.learner.actor(state)
            selected_action = selected_action.detach()
            log_prob = dist.log_prob(selected_action)
            value = self.learner.critic(state)

            if self.is_test:
                selected_action = (dist.logits.argmax()
                                   if self.is_discrete else dist.mean)

            else:
                _selected_action = (selected_action.unsqueeze(1)
                                    if self.is_discrete else selected_action)
                _log_prob = log_prob.unsqueeze(
                    1) if self.is_discrete else log_prob
                self.states.append(state)
                self.actions.append(_selected_action)
                self.values.append(value)
                self.log_probs.append(_log_prob)

        return selected_action.detach().cpu().numpy()
Example #23
    def update_model(
            self, experience: Tuple[torch.Tensor,
                                    ...]) -> Tuple[torch.Tensor, ...]:
        """Update TD3 actor and critic networks."""
        self.update_step += 1

        states, actions, rewards, next_states, dones = experience
        masks = 1 - dones

        # get actions with noise
        noise = common_utils.numpy2floattensor(
            self.target_policy_noise.sample(), self.device)
        clipped_noise = torch.clamp(
            noise,
            -self.noise_cfg.target_policy_noise_clip,
            self.noise_cfg.target_policy_noise_clip,
        )
        next_actions = (self.actor_target(next_states) + clipped_noise).clamp(
            -1.0, 1.0)

        # min (Q_1', Q_2')
        next_states_actions = torch.cat((next_states, next_actions), dim=-1)
        next_values1 = self.critic_target1(next_states_actions)
        next_values2 = self.critic_target2(next_states_actions)
        next_values = torch.min(next_values1, next_values2)

        # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
        #       = r                       otherwise
        curr_returns = rewards + self.hyper_params.gamma * next_values * masks
        curr_returns = curr_returns.detach()

        # critic loss
        state_actions = torch.cat((states, actions), dim=-1)
        values1 = self.critic1(state_actions)
        values2 = self.critic2(state_actions)
        critic1_loss = F.mse_loss(values1, curr_returns)
        critic2_loss = F.mse_loss(values2, curr_returns)

        # train critic
        critic_loss = critic1_loss + critic2_loss
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        if self.update_step % self.hyper_params.policy_update_freq == 0:
            # policy loss
            actions = self.actor(states)
            state_actions = torch.cat((states, actions), dim=-1)
            actor_loss = -self.critic1(state_actions).mean()

            # train actor
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            # update target networks
            tau = self.hyper_params.tau
            common_utils.soft_update(self.critic1, self.critic_target1, tau)
            common_utils.soft_update(self.critic2, self.critic_target2, tau)
            common_utils.soft_update(self.actor, self.actor_target, tau)
        else:
            actor_loss = torch.zeros(1)

        return actor_loss.item(), critic1_loss.item(), critic2_loss.item()
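
Aside: the TD3 update above relies on common_utils.soft_update for its target networks. A hedged sketch, assuming standard Polyak averaging (not necessarily the library's exact implementation):

import torch.nn as nn


def soft_update_sketch(local: nn.Module, target: nn.Module, tau: float) -> None:
    """Hedged sketch of Polyak averaging: target <- tau * local + (1 - tau) * target."""
    for target_param, local_param in zip(target.parameters(), local.parameters()):
        target_param.data.copy_(
            tau * local_param.data + (1.0 - tau) * target_param.data
        )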
Example #24
 def _preprocess_state(state: np.ndarray,
                       device: torch.device) -> torch.Tensor:
     state = numpy2floattensor(state, device)
     return state
Example #25
 def _preprocess_state(self, state: np.ndarray) -> torch.Tensor:
     """Preprocess state so that actor selects an action."""
     state = numpy2floattensor(state, self.learner.device)
     return state
Example #26
    def update_model(self, experience: TensorTuple,
                     epsilon: float) -> TensorTuple:
        """Update PPO actor and critic networks"""
        states, actions, rewards, values, log_probs, next_state, masks = experience
        next_state = numpy2floattensor(next_state, self.device)
        next_value = self.critic(next_state)

        returns = ppo_utils.compute_gae(
            next_value,
            rewards,
            masks,
            values,
            self.hyper_params.gamma,
            self.hyper_params.tau,
        )

        states = torch.cat(states)
        actions = torch.cat(actions)
        returns = torch.cat(returns).detach()
        values = torch.cat(values).detach()
        log_probs = torch.cat(log_probs).detach()
        advantages = returns - values

        if self.is_discrete:
            actions = actions.unsqueeze(1)
            log_probs = log_probs.unsqueeze(1)

        if self.hyper_params.standardize_advantage:
            advantages = (advantages - advantages.mean()) / (advantages.std() +
                                                             1e-7)

        actor_losses, critic_losses, total_losses = [], [], []

        for state, action, old_value, old_log_prob, return_, adv in ppo_utils.ppo_iter(
                self.hyper_params.epoch,
                self.hyper_params.batch_size,
                states,
                actions,
                values,
                log_probs,
                returns,
                advantages,
        ):
            # calculate ratios
            _, dist = self.actor(state)
            log_prob = dist.log_prob(action)
            ratio = (log_prob - old_log_prob).exp()

            # actor_loss
            surr_loss = ratio * adv
            clipped_surr_loss = torch.clamp(ratio, 1.0 - epsilon,
                                            1.0 + epsilon) * adv
            actor_loss = -torch.min(surr_loss, clipped_surr_loss).mean()

            # critic_loss
            value = self.critic(state)
            if self.hyper_params.use_clipped_value_loss:
                value_pred_clipped = old_value + torch.clamp(
                    (value - old_value), -epsilon, epsilon)
                value_loss_clipped = (return_ - value_pred_clipped).pow(2)
                value_loss = (return_ - value).pow(2)
                critic_loss = 0.5 * torch.max(value_loss,
                                              value_loss_clipped).mean()
            else:
                critic_loss = 0.5 * (return_ - value).pow(2).mean()

            # entropy
            entropy = dist.entropy().mean()

            # total_loss
            w_value = self.hyper_params.w_value
            w_entropy = self.hyper_params.w_entropy

            critic_loss_ = w_value * critic_loss
            actor_loss_ = actor_loss - w_entropy * entropy
            total_loss = critic_loss_ + actor_loss_

            # train critic
            gradient_clip_ac = self.hyper_params.gradient_clip_ac
            gradient_clip_cr = self.hyper_params.gradient_clip_cr

            self.critic_optim.zero_grad()
            critic_loss_.backward(retain_graph=True)
            clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
            self.critic_optim.step()

            # train actor
            self.actor_optim.zero_grad()
            actor_loss_.backward()
            clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
            self.actor_optim.step()

            actor_losses.append(actor_loss.item())
            critic_losses.append(critic_loss.item())
            total_losses.append(total_loss.item())

        actor_loss = sum(actor_losses) / len(actor_losses)
        critic_loss = sum(critic_losses) / len(critic_losses)
        total_loss = sum(total_losses) / len(total_losses)

        return actor_loss, critic_loss, total_loss
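
Aside: the PPO update above delegates return computation to ppo_utils.compute_gae. A hedged sketch of a GAE routine consistent with that call signature (gamma is the discount and tau plays the role of the GAE lambda); it is illustrative, not the library's code:

from typing import List

import torch


def compute_gae_sketch(
    next_value: torch.Tensor,
    rewards: List[torch.Tensor],
    masks: List[torch.Tensor],
    values: List[torch.Tensor],
    gamma: float,
    tau: float,
) -> List[torch.Tensor]:
    """Hedged sketch of Generalized Advantage Estimation returning per-step returns."""
    values = values + [next_value]
    gae = torch.zeros_like(next_value)
    returns: List[torch.Tensor] = []
    for step in reversed(range(len(rewards))):
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) * mask_t - V(s_t)
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        # GAE: A_t = delta_t + gamma * tau * mask_t * A_{t+1}
        gae = delta + gamma * tau * masks[step] * gae
        # return_t = A_t + V(s_t), later used as the critic regression target
        returns.insert(0, gae + values[step])
    return returns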
Example #27
    def update_model(self, experience: TensorTuple,
                     epsilon: float) -> TensorTuple:
        """Update generator(actor), critic and discriminator networks."""
        states, actions, rewards, values, log_probs, next_state, masks = experience
        next_state = numpy2floattensor(next_state, self.device)
        with torch.no_grad():
            next_value = self.critic(next_state)

        returns = ppo_utils.compute_gae(
            next_value,
            rewards,
            masks,
            values,
            self.hyper_params.gamma,
            self.hyper_params.tau,
        )

        states = torch.cat(states)
        actions = torch.cat(actions)
        returns = torch.cat(returns).detach()
        values = torch.cat(values).detach()
        log_probs = torch.cat(log_probs).detach()
        advantages = (returns - values).detach()

        if self.hyper_params.standardize_advantage:
            advantages = (advantages - advantages.mean()) / (advantages.std() +
                                                             1e-7)

        actor_losses, critic_losses, total_losses, discriminator_losses = [], [], [], []

        for (
                state,
                action,
                old_value,
                old_log_prob,
                return_,
                adv,
                epoch,
        ) in ppo_utils.ppo_iter(
                self.hyper_params.epoch,
                self.hyper_params.batch_size,
                states,
                actions,
                values,
                log_probs,
                returns,
                advantages,
        ):

            # critic_loss
            value = self.critic(state)
            if self.hyper_params.use_clipped_value_loss:
                value_pred_clipped = old_value + torch.clamp(
                    (value - old_value), -epsilon, epsilon)
                value_loss_clipped = (return_ - value_pred_clipped).pow(2)
                value_loss = (return_ - value).pow(2)
                critic_loss = 0.5 * torch.max(value_loss,
                                              value_loss_clipped).mean()
            else:
                critic_loss = 0.5 * (return_ - value).pow(2).mean()
            critic_loss_ = self.hyper_params.w_value * critic_loss

            # train critic
            self.critic_optim.zero_grad()
            critic_loss_.backward()
            clip_grad_norm_(self.critic.parameters(),
                            self.hyper_params.gradient_clip_cr)
            self.critic_optim.step()

            # calculate ratios
            _, dist = self.actor(state)
            log_prob = dist.log_prob(action)
            ratio = (log_prob - old_log_prob).exp()

            # actor_loss
            surr_loss = ratio * adv
            clipped_surr_loss = torch.clamp(ratio, 1.0 - epsilon,
                                            1.0 + epsilon) * adv
            actor_loss = -torch.min(surr_loss, clipped_surr_loss).mean()

            # entropy
            entropy = dist.entropy().mean()
            actor_loss_ = actor_loss - self.hyper_params.w_entropy * entropy

            # train actor
            self.actor_optim.zero_grad()
            actor_loss_.backward()
            clip_grad_norm_(self.actor.parameters(),
                            self.hyper_params.gradient_clip_ac)
            self.actor_optim.step()

            # total_loss
            total_loss = critic_loss_ + actor_loss_

            # discriminator loss
            demo_state, demo_action = self.demo_memory.sample(len(state))
            exp_score = torch.sigmoid(
                self.discriminator.forward((state, action)))
            demo_score = torch.sigmoid(
                self.discriminator.forward((demo_state, demo_action)))
            discriminator_exp_acc = (exp_score > 0.5).float().mean().item()
            discriminator_demo_acc = (demo_score <= 0.5).float().mean().item()
            discriminator_loss = F.binary_cross_entropy(
                exp_score,
                torch.ones_like(exp_score)) + F.binary_cross_entropy(
                    demo_score, torch.zeros_like(demo_score))

            # train discriminator
            if (discriminator_exp_acc <
                    self.optim_cfg.discriminator_acc_threshold
                    or discriminator_demo_acc <
                    self.optim_cfg.discriminator_acc_threshold and epoch == 0):
                self.discriminator_optim.zero_grad()
                discriminator_loss.backward()
                self.discriminator_optim.step()

            actor_losses.append(actor_loss.item())
            critic_losses.append(critic_loss.item())
            total_losses.append(total_loss.item())
            discriminator_losses.append(discriminator_loss.item())

        actor_loss = sum(actor_losses) / len(actor_losses)
        critic_loss = sum(critic_losses) / len(critic_losses)
        total_loss = sum(total_losses) / len(total_losses)
        discriminator_loss = sum(discriminator_losses) / len(
            discriminator_losses)

        return (
            (actor_loss, critic_loss, total_loss, discriminator_loss),
            (discriminator_exp_acc, discriminator_demo_acc),
        )
Example #28
    def train(self):
        """Train the agent."""
        # logger
        if self.is_log:
            self.set_wandb()
            # wandb.watch([self.actor, self.critic], log="parameters")

        score = 0
        i_episode_prev = 0
        loss = [0.0, 0.0, 0.0, 0.0]
        discriminator_acc = [0.0, 0.0]
        state = self.env.reset()

        while self.i_episode <= self.episode_num:
            for _ in range(self.hyper_params.rollout_len):
                if self.is_render and self.i_episode >= self.render_after:
                    self.env.render()

                action = self.select_action(state)
                next_state, task_reward, done, _ = self.step(action)

                # gail reward (imitation reward)
                gail_reward = compute_gail_reward(
                    self.learner.discriminator((
                        numpy2floattensor(state, self.learner.device),
                        numpy2floattensor(action, self.learner.device),
                    )))

                # hybrid reward
                # Reference: https://arxiv.org/abs/1802.09564
                reward = (
                    self.hyper_params.gail_reward_weight * gail_reward +
                    (1.0 - self.hyper_params.gail_reward_weight) * task_reward)

                if not self.is_test:
                    # if the episode ended only due to the time limit (not a
                    # true terminal state), store done as False
                    done_bool = done.copy()
                    done_bool[np.where(
                        self.episode_steps == self.max_episode_steps)] = False

                    self.rewards.append(
                        numpy2floattensor(reward,
                                          self.learner.device).unsqueeze(1))
                    self.masks.append(
                        numpy2floattensor((1 - done_bool),
                                          self.learner.device).unsqueeze(1))

                self.episode_steps += 1

                state = next_state
                score += task_reward[0]
                i_episode_prev = self.i_episode
                self.i_episode += done.sum()

                if (self.i_episode // self.save_period) != (i_episode_prev //
                                                            self.save_period):
                    self.learner.save_params(self.i_episode)

                if done[0]:
                    n_step = self.episode_steps[0]
                    log_value = (
                        self.i_episode,
                        n_step,
                        score,
                        gail_reward,
                        loss[0],
                        loss[1],
                        loss[2],
                        loss[3],
                        discriminator_acc[0],
                        discriminator_acc[1],
                    )
                    self.write_log(log_value)
                    score = 0

                self.episode_steps[np.where(done)] = 0
            self.next_state = next_state
            loss, discriminator_acc = self.learner.update_model(
                (
                    self.states,
                    self.actions,
                    self.rewards,
                    self.values,
                    self.log_probs,
                    self.next_state,
                    self.masks,
                ),
                self.epsilon,
            )
            self.states, self.actions, self.rewards = [], [], []
            self.values, self.masks, self.log_probs = [], [], []
            self.decay_epsilon(self.i_episode)

        # termination
        self.env.close()
        self.learner.save_params(self.i_episode)
Example #29
    def scale_noise(size: int) -> torch.Tensor:
        """Scale noise for factorized Gaussian noise."""
        x = numpy2floattensor(np.random.normal(loc=0.0, scale=1.0, size=size), device)

        return x.sign().mul(x.abs().sqrt())
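
Aside: in factorized-Gaussian NoisyNet layers, two vectors scaled this way are typically combined by an outer product to form the weight noise. A hedged usage sketch (the function name and shapes are illustrative only):

import torch


def factorized_weight_noise(eps_in: torch.Tensor, eps_out: torch.Tensor) -> torch.Tensor:
    """Hedged sketch: combine two scaled noise vectors into an (out, in) weight-noise matrix."""
    # eps_in: (in_features,), eps_out: (out_features,), each produced by scale_noise above
    return torch.outer(eps_out, eps_in)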