Example #1
    def update(self):
        new_count = MPIUtil.reduce_sum(self.new_count)
        new_sum = MPIUtil.reduce_sum(self.new_sum)
        new_sum_sq = MPIUtil.reduce_sum(self.new_sum_sq)

        new_total = self.count + new_count
        if (self.count // self.CHECK_SYNC_COUNT !=
                new_total // self.CHECK_SYNC_COUNT):
            assert self.check_synced(), Logger.print(
                'Normalizer parameters desynchronized')

        if new_count > 0:
            new_mean = self._process_group_data(new_sum / new_count, self.mean)
            new_mean_sq = self._process_group_data(new_sum_sq / new_count,
                                                   self.mean_sq)
            w_old = float(self.count) / new_total
            w_new = float(new_count) / new_total

            self.mean = w_old * self.mean + w_new * new_mean
            self.mean_sq = w_old * self.mean_sq + w_new * new_mean_sq
            self.count = new_total
            self.std = self.calc_std(self.mean, self.mean_sq)

            self.new_count = 0
            self.new_sum.fill(0)
            self.new_sum_sq.fill(0)

        return

    def _update_mode_test(self):
        if (self.test_episode_count * MPIUtil.get_num_procs() >= self.test_episodes):
            global_return = MPIUtil.reduce_sum(self.test_return)
            global_count = MPIUtil.reduce_sum(self.test_episode_count)
            avg_return = global_return / global_count
            self.avg_test_return = avg_return

            if self.enable_training:
                self._init_mode_train()
        return
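
Example #1 combines locally accumulated sums into a global running mean and standard deviation. Below is a minimal, MPI-free sketch of that weighted update; the class and method names (RunningNormalizer, record) are illustrative, and the globally reduced sums are replaced by local accumulators.

import numpy as np

class RunningNormalizer:
    # Illustrative stand-in for the normalizer above, without MPI reductions.
    def __init__(self, size, eps=1e-8):
        self.count = 0
        self.mean = np.zeros(size)
        self.mean_sq = np.zeros(size)
        self.std = np.ones(size)
        self.eps = eps
        self.new_count = 0
        self.new_sum = np.zeros(size)
        self.new_sum_sq = np.zeros(size)

    def record(self, x):
        # Accumulate per-sample sums; these are what reduce_sum would aggregate.
        self.new_count += 1
        self.new_sum += x
        self.new_sum_sq += np.square(x)

    def update(self):
        if self.new_count == 0:
            return
        new_total = self.count + self.new_count
        new_mean = self.new_sum / self.new_count
        new_mean_sq = self.new_sum_sq / self.new_count
        w_old = self.count / new_total
        w_new = self.new_count / new_total
        self.mean = w_old * self.mean + w_new * new_mean
        self.mean_sq = w_old * self.mean_sq + w_new * new_mean_sq
        self.count = new_total
        # std from E[x^2] - E[x]^2, clamped to keep the sqrt well defined
        self.std = np.sqrt(np.maximum(self.mean_sq - np.square(self.mean), self.eps))
        self.new_count = 0
        self.new_sum.fill(0)
        self.new_sum_sq.fill(0)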
Example #3
    def _update_test_return(self, path):
        path_reward = path.calc_return()
        self.test_return += path_reward
        self.test_episode_count += 1

        if (self.test_episode_count * MPIUtil.get_num_procs() >=
                self.test_episodes):
            global_return = MPIUtil.reduce_sum(self.test_return)
            global_count = MPIUtil.reduce_sum(self.test_episode_count)
            avg_return = global_return / global_count
            self.avg_test_return = avg_return

            if self.enable_training:
                self._init_mode_train()

        return
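
The examples above rely on a project-specific MPIUtil wrapper. Assuming an mpi4py backend, rough equivalents of the helpers used here (get_num_procs, reduce_sum, reduce_min, reduce_avg) might look like the sketch below; this is an illustration, not the original wrapper, and it handles scalar values.

from mpi4py import MPI

def get_num_procs():
    # Number of MPI workers participating in training.
    return MPI.COMM_WORLD.Get_size()

def reduce_sum(x):
    # Sum a scalar across all workers; every worker receives the result.
    return MPI.COMM_WORLD.allreduce(x, op=MPI.SUM)

def reduce_min(x):
    # Minimum of a scalar across all workers.
    return MPI.COMM_WORLD.allreduce(x, op=MPI.MIN)

def reduce_avg(x):
    # Average of a scalar across all workers.
    return reduce_sum(x) / get_num_procs()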
Example #4
    def _update_disc(self):
        info = None

        num_procs = mpi_util.get_num_procs()
        local_expert_batch_size = int(np.ceil(self._disc_batchsize /
                                              num_procs))
        local_agent_batch_size = local_expert_batch_size

        steps_per_batch = self._disc_steps_per_batch
        local_sample_count = self.replay_buffer.get_current_size()
        global_sample_count = int(mpi_util.reduce_sum(local_sample_count))

        num_steps = int(
            np.ceil(steps_per_batch * global_sample_count /
                    (num_procs * local_expert_batch_size)))

        for b in range(num_steps):
            disc_expert_batch = self._disc_expert_buffer.sample(
                local_expert_batch_size)
            obs_expert = self._disc_expert_buffer.get(disc_expert_batch)

            disc_agent_batch = self._disc_agent_buffer.sample(
                local_agent_batch_size)
            obs_agent = self._disc_agent_buffer.get(disc_agent_batch)

            curr_info = self._step_disc(obs_expert=obs_expert,
                                        obs_agent=obs_agent)

            if (info is None):
                info = curr_info
            else:
                for k, v in curr_info.items():
                    info[k] += v

        for k in info.keys():
            info[k] /= num_steps

        return info
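
The number of discriminator steps in Example #4 is chosen so that, across all workers, roughly steps_per_batch passes are made over the replay buffer. A small worked example of that arithmetic, with made-up numbers:

import numpy as np

def disc_update_steps(steps_per_batch, global_sample_count, num_procs, local_batch_size):
    # Local steps per worker so that num_procs * local_batch_size samples are
    # consumed per step, covering the global buffer steps_per_batch times.
    return int(np.ceil(steps_per_batch * global_sample_count /
                       (num_procs * local_batch_size)))

# e.g. 1 pass over 4096 samples with 4 workers and 256 samples per local batch:
print(disc_update_steps(1, 4096, 4, 256))  # -> 4 local discriminator steps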
Example #5
    def _valid_train_step(self):
        samples = self.replay_buffer.get_current_size()
        exp_samples = self.replay_buffer.count_filtered(self.EXP_ACTION_FLAG)
        global_sample_count = int(MPIUtil.reduce_sum(samples))
        global_exp_min = int(MPIUtil.reduce_min(exp_samples))
        return (global_sample_count > self.batch_size) and (global_exp_min > 0)
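
A standalone sketch of the gating logic in Example #5: training only proceeds once the combined buffers exceed one batch and every worker holds at least one exploratory (EXP_ACTION_FLAG) sample. The list-based stand-ins below replace the MPI reductions, and the numbers are illustrative.

def valid_train_step(local_counts, local_exp_counts, batch_size):
    global_sample_count = sum(local_counts)   # stands in for reduce_sum
    global_exp_min = min(local_exp_counts)    # stands in for reduce_min
    return (global_sample_count > batch_size) and (global_exp_min > 0)

print(valid_train_step([600, 550], [12, 0], 1024))  # False: one worker has no exploratory samples
print(valid_train_step([600, 550], [12, 3], 1024))  # True: 1150 > 1024 and every worker has some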
Example #6
    def _train_step(self):
        adv_eps = 1e-5

        start_idx = self.replay_buffer.buffer_tail
        end_idx = self.replay_buffer.buffer_head
        assert (start_idx == 0)
        assert (self.replay_buffer.get_current_size() <=
                self.replay_buffer.buffer_size)  # must avoid overflow
        assert (start_idx < end_idx)

        idx = np.array(list(range(start_idx, end_idx)))
        end_mask = self.replay_buffer.is_path_end(idx)
        end_mask = np.logical_not(end_mask)

        vals = self._compute_batch_vals(start_idx, end_idx)
        new_vals = self._compute_batch_new_vals(start_idx, end_idx, vals)

        valid_idx = idx[end_mask]
        exp_idx = self.replay_buffer.get_idx_filtered(
            self.EXP_ACTION_FLAG).copy()
        num_valid_idx = valid_idx.shape[0]
        num_exp_idx = exp_idx.shape[0]
        exp_idx = np.column_stack(
            [exp_idx,
             np.array(list(range(0, num_exp_idx)), dtype=np.int32)])

        local_sample_count = valid_idx.size
        global_sample_count = int(MPIUtil.reduce_sum(local_sample_count))
        mini_batches = int(np.ceil(global_sample_count / self.mini_batch_size))

        adv = new_vals[exp_idx[:, 0]] - vals[exp_idx[:, 0]]
        new_vals = np.clip(new_vals, self.val_min, self.val_max)

        adv_mean = np.mean(adv)
        adv_std = np.std(adv)
        adv = (adv - adv_mean) / (adv_std + adv_eps)
        adv = np.clip(adv, -self.norm_adv_clip, self.norm_adv_clip)

        critic_loss = 0
        actor_loss = 0
        actor_clip_frac = 0

        for e in range(self.epochs):
            np.random.shuffle(valid_idx)
            np.random.shuffle(exp_idx)

            for b in range(mini_batches):
                batch_idx_beg = b * self._local_mini_batch_size
                batch_idx_end = batch_idx_beg + self._local_mini_batch_size

                critic_batch = np.array(range(batch_idx_beg, batch_idx_end),
                                        dtype=np.int32)
                actor_batch = critic_batch.copy()
                critic_batch = np.mod(critic_batch, num_valid_idx)
                actor_batch = np.mod(actor_batch, num_exp_idx)
                shuffle_actor = (actor_batch[-1] < actor_batch[0]) or (
                    actor_batch[-1] == num_exp_idx - 1)

                critic_batch = valid_idx[critic_batch]
                actor_batch = exp_idx[actor_batch]
                critic_batch_vals = new_vals[critic_batch]
                actor_batch_adv = adv[actor_batch[:, 1]]

                critic_s = self.replay_buffer.get('states', critic_batch)
                critic_g = self.replay_buffer.get(
                    'goals', critic_batch) if self.has_goal() else None
                curr_critic_loss = self._update_critic(critic_s, critic_g,
                                                       critic_batch_vals)

                actor_s = self.replay_buffer.get("states", actor_batch[:, 0])
                actor_g = self.replay_buffer.get(
                    "goals", actor_batch[:, 0]) if self.has_goal() else None
                actor_a = self.replay_buffer.get("actions", actor_batch[:, 0])
                actor_logp = self.replay_buffer.get("logps", actor_batch[:, 0])
                curr_actor_loss, curr_actor_clip_frac = self._update_actor(
                    actor_s, actor_g, actor_a, actor_logp, actor_batch_adv)

                critic_loss += curr_critic_loss
                actor_loss += np.abs(curr_actor_loss)
                actor_clip_frac += curr_actor_clip_frac

                if (shuffle_actor):
                    np.random.shuffle(exp_idx)

        total_batches = mini_batches * self.epochs
        critic_loss /= total_batches
        actor_loss /= total_batches
        actor_clip_frac /= total_batches

        critic_loss = MPIUtil.reduce_avg(critic_loss)
        actor_loss = MPIUtil.reduce_avg(actor_loss)
        actor_clip_frac = MPIUtil.reduce_avg(actor_clip_frac)

        critic_stepsize = self.critic_solver.get_stepsize()
        actor_stepsize = self.update_actor_stepsize(actor_clip_frac)

        self.logger.log_tabular('Critic_Loss', critic_loss)
        self.logger.log_tabular('Critic_Stepsize', critic_stepsize)
        self.logger.log_tabular('Actor_Loss', actor_loss)
        self.logger.log_tabular('Actor_Stepsize', actor_stepsize)
        self.logger.log_tabular('Clip_Frac', actor_clip_frac)
        self.logger.log_tabular('Adv_Mean', adv_mean)
        self.logger.log_tabular('Adv_Std', adv_std)

        self.replay_buffer.clear()

        return
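
One piece of Example #6 worth isolating is the advantage normalization: advantages are standardized and then clipped before the actor update. A minimal sketch, assuming norm_adv_clip is a small constant (e.g. 4.0) and using the adv_eps value from the code above:

import numpy as np

def normalize_advantages(new_vals, vals, norm_adv_clip=4.0, adv_eps=1e-5):
    adv = new_vals - vals                                 # raw advantage estimate
    adv = (adv - np.mean(adv)) / (np.std(adv) + adv_eps)  # standardize
    return np.clip(adv, -norm_adv_clip, norm_adv_clip)    # clip outliers

adv = normalize_advantages(np.array([1.2, 0.7, 2.5, 0.1]),
                           np.array([1.0, 1.0, 1.0, 1.0]))
print(adv)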
Example #7
    def _train(self):
        with self.sess.as_default(), self.graph.as_default():
            samples = self.replay_buffer.total_count
            self._total_sample_count = int(MPIUtil.reduce_sum(samples)) + self._init_total_sample_count
            end_training = False

            if (self.replay_buffer_initialized):
                if (self._valid_train_step()):
                    prev_iter = self.iter
                    iters = self._get_iters_per_update()
                    avg_train_return = MPIUtil.reduce_avg(self.train_return)
                    avg_train_mean_reward = MPIUtil.reduce_avg(self.train_mean_reward)
                    avg_train_pathlen = MPIUtil.reduce_avg(self.train_pathlen)
                    avg_train_pathlen /= 30  # policy is executed at 30 Hz

                    for _ in range(iters):
                        curr_iter = self.iter
                        wall_time = time.time() - self.start_time
                        wall_time /= 60 * 60 # store time in hours

                        has_goal = self.has_goal()
                        s_mean = np.mean(self.s_norm.mean)
                        s_std = np.mean(self.s_norm.std)
                        g_mean = np.mean(self.g_norm.mean) if has_goal else 0
                        g_std = np.mean(self.g_norm.std) if has_goal else 0

                        self.logger.log_tabular("Iteration", self.iter)
                        self.logger.log_tabular("Wall_Time", wall_time)
                        self.logger.log_tabular("Samples", self._total_sample_count)
                        self.logger.log_tabular("Train_Path_Length", avg_train_pathlen)
                        self.logger.log_tabular("Train_Mean_Reward", avg_train_mean_reward)
                        self.logger.log_tabular("Train_Return", avg_train_return)
                        self.logger.log_tabular("Test_Return", self.avg_test_return)
                        self.logger.log_tabular("State_Mean", s_mean)
                        self.logger.log_tabular("State_Std", s_std)
                        self.logger.log_tabular("Goal_Mean", g_mean)
                        self.logger.log_tabular("Goal_Std", g_std)
                        self._log_exp_params()

                        self._update_iter(self.iter + 1)

                        train_start_time = time.time()
                        self._train_step()
                        train_time = time.time() - train_start_time

                        if self.iter == 1:
                            iteration_time = time.time() - self.start_time
                        else:
                            iteration_time = time.time() - self.iter_start_time
                        self.iter_start_time = time.time()

                        self.logger.log_tabular("Train_time", train_time)
                        self.logger.log_tabular("Simulation_time",  iteration_time - train_time)

                        Logger.print("Agent " + str(self.id))
                        self.logger.print_tabular()
                        Logger.print("")

                        if (self._enable_output() and curr_iter % self.int_output_iters == 0):
                            self.logger.dump_tabular()


                    if (prev_iter // self.int_output_iters != self.iter // self.int_output_iters):
                        end_training = self.enable_testing()

            else:

                Logger.print("Agent " + str(self.id))
                Logger.print("Samples: " + str(self._total_sample_count))
                Logger.print("")

                if (self._total_sample_count >= self.init_samples):
                    self.replay_buffer_initialized = True
                    end_training = self.enable_testing()

            if self._need_normalizer_update and self._enable_update_normalizer:
                self._update_normalizers()
                self._need_normalizer_update = self.normalizer_samples > self._total_sample_count

            if end_training:
                self._init_mode_train_end()
        return
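
The timing bookkeeping in Example #7 splits each iteration into the training step itself and everything else (logged as simulation time), with wall time kept in hours. A small illustrative helper, not taken from the original code:

import time

class IterTimer:
    def __init__(self):
        self.start_time = time.time()
        self.iter_start_time = self.start_time

    def wall_time_hours(self):
        # Total elapsed time since training started, in hours.
        return (time.time() - self.start_time) / (60 * 60)

    def split(self, train_fn):
        # Run one training step and return (train_time, simulation_time).
        train_start = time.time()
        train_fn()
        train_time = time.time() - train_start
        iteration_time = time.time() - self.iter_start_time
        self.iter_start_time = time.time()
        return train_time, iteration_time - train_time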
Example #8
    def _train(self):
        samples = self.replay_buffer.total_count
        self._total_sample_count = int(MPIUtil.reduce_sum(samples))
        end_training = False

        if (self.replay_buffer_initialized):
            if (self._valid_train_step()):
                prev_iter = self.iter
                iters = self._get_iters_per_update()
                avg_train_return = MPIUtil.reduce_avg(self.train_return)

                for i in range(iters):
                    curr_iter = self.iter
                    wall_time = time.time() - self.start_time
                    wall_time /= 60 * 60  # store time in hours

                    has_goal = self.has_goal()
                    s_mean = np.mean(self._s_norm.mean)
                    s_std = np.mean(self._s_norm.std)
                    g_mean = np.mean(self._g_norm.mean) if has_goal else 0
                    g_std = np.mean(self._g_norm.std) if has_goal else 0

                    self.logger.log_tabular("Iteration", self.iter)
                    self.logger.log_tabular("Wall_Time", wall_time)
                    self.logger.log_tabular("Samples",
                                            self._total_sample_count)
                    self.logger.log_tabular("Train_Return", avg_train_return)
                    self.logger.log_tabular("Test_Return",
                                            self.avg_test_return)
                    self.logger.log_tabular("State_Mean", s_mean)
                    self.logger.log_tabular("State_Std", s_std)
                    self.logger.log_tabular("Goal_Mean", g_mean)
                    self.logger.log_tabular("Goal_Std", g_std)
                    self._log_exp_params()

                    self._update_iter(self.iter + 1)
                    self._train_step()

                    Logger.print("Agent " + str(self.id))
                    self.logger.print_tabular()
                    Logger.print("")

                    if (self._enable_output()
                            and curr_iter % self.int_output_iters == 0):
                        self.logger.dump_tabular()

                if (prev_iter // self.int_output_iters !=
                        self.iter // self.int_output_iters):
                    end_training = self.enable_testing()

        else:
            Logger.print("Agent " + str(self.id))
            Logger.print("Samples: " + str(self._total_sample_count))
            Logger.print("")

            if (self._total_sample_count >= self.init_samples):
                self.replay_buffer_initialized = True
                end_training = self.enable_testing()

        if self._need_normalizer_update:
            self._update_normalizers()
            self._need_normalizer_update = self.normalizer_samples > self._total_sample_count

        if end_training:
            self._init_mode_test()

        return
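
Both _train variants end with the same normalizer gating: the state and goal normalizers keep updating only until the total sample count reaches normalizer_samples, after which they are frozen. A tiny sketch with illustrative numbers:

def should_update_normalizer(total_sample_count, normalizer_samples):
    # True while more samples are still needed to estimate the normalizers.
    return normalizer_samples > total_sample_count

print(should_update_normalizer(50_000, 1_000_000))      # True: keep updating
print(should_update_normalizer(2_000_000, 1_000_000))   # False: normalizers frozen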