def main():
    """Train a ``MetaPolicy`` with ``MetaPPO`` on a PyBullet task.

    Seeds numpy / torch (CPU and all CUDA devices), builds the vectorised
    environment, optionally restores weights from ``args.restart_model`` and
    then alternates rollout collection (``ppo_rollout``) with optimisation
    (``ppo_update``). Every update appends one row of losses to the
    ``LossWriter``. The model is checkpointed every ``save_interval`` updates
    and after the final one (only when ``log_dir`` is non-empty). A progress
    line is printed every ``log_interval`` updates.
    """
    args = get_args()

    # Reproducibility.
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.set_num_threads(1)

    device = torch.device(args.device)
    log_dir = args.log_dir
    utils.cleanup_log_dir(log_dir)

    env_fn = make_pybullet_env(args.task, log_dir=log_dir, frame_skip=args.frame_skip)
    envs = make_vec_envs(env_fn, args.num_processes, log_dir, device, args.frame_stack)
    actor_critic = MetaPolicy(envs.observation_space, envs.action_space)
    loss_fields = ('V_loss', 'action_loss', 'meta_action_loss',
                   'meta_value_loss', 'meta_loss', 'loss')
    loss_writer = LossWriter(log_dir, fieldnames=loss_fields)

    if args.restart_model:
        restored_state = torch.load(args.restart_model, map_location=device)
        actor_critic.load_state_dict(restored_state)

    actor_critic.to(device)

    agent = MetaPPO(
        actor_critic,
        args.clip_param,
        args.ppo_epoch,
        args.num_mini_batch,
        args.value_loss_coef,
        args.entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm,
    )

    first_obs = envs.reset()
    rollouts = RolloutStorage(
        args.num_steps,
        args.num_processes,
        first_obs,
        envs.action_space,
        actor_critic.recurrent_hidden_state_size,
    )
    # Storage is moved to the training device; observations come back from the
    # env wrapper already as torch tensors.
    rollouts.to(device)

    start = time.time()
    num_updates = int(args.num_env_steps) // args.num_steps // args.num_processes

    for update_idx in range(num_updates):
        ppo_rollout(args.num_steps, envs, actor_critic, rollouts)

        (value_loss, meta_value_loss, action_loss,
         meta_action_loss, loss, meta_loss) = ppo_update(
            agent, actor_critic, rollouts, args.use_gae, args.gamma, args.gae_lambda)

        loss_writer.write_row({
            'V_loss': value_loss.item(),
            'action_loss': action_loss.item(),
            'meta_action_loss': meta_action_loss.item(),
            'meta_value_loss': meta_value_loss.item(),
            'meta_loss': meta_loss.item(),
            'loss': loss.item(),
        })

        is_last_update = update_idx == num_updates - 1
        if (update_idx % args.save_interval == 0 or is_last_update) and log_dir != "":
            ppo_save_model(actor_critic, os.path.join(log_dir, "model.state_dict"), update_idx)

        if update_idx % args.log_interval == 0:
            total_num_steps = (update_idx + 1) * args.num_processes * args.num_steps
            fps = int(total_num_steps / (time.time() - start))
            header = "Update {}, num timesteps {}, FPS {} \n".format(
                update_idx, total_num_steps, fps)
            details = (
                "Loss {:.5f}, meta loss {:.5f}, value_loss {:.5f}, "
                "meta_value_loss {:.5f}, action_loss {:.5f}, meta action loss {:.5f}"
            ).format(
                loss.item(), meta_loss.item(), value_loss.item(),
                meta_value_loss.item(), action_loss.item(), meta_action_loss.item())
            print(header + details, flush=True)
Beispiel #2
0
def main():
    """Train an A2C / PPO / ACKTR agent on the environment from ``make_otc_env``.

    Options come from ``otc_arg_parser()``. The policy is one of two things:
    unpickled from ``<save_dir>/<exp_name>/<env>.pt`` when ``args.load`` is
    set, or otherwise built as a ``Policy`` with a ``CNNBase`` over the
    ``'visual'`` observation plus a vector input of length ``'vector'``.

    Each update does four things:
    - collect ``num_steps`` steps from ``num_processes`` environments;
    - if ``args.gail`` is set, replace rewards with discriminator rewards;
    - compute returns;
    - call ``agent.update``.

    Every ``save_interval`` updates (and after the last one) the
    ``[actor_critic, ob_rms]`` pair is saved. Every ``log_interval`` updates
    progress is printed. Once episodes have finished, episode statistics and
    losses are also written to TensorBoard under ``<log_dir>/<exp_name>``.
    """
    parser = otc_arg_parser()
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    assert args.algo in ['a2c', 'ppo', 'acktr']
    if args.recurrent_policy:
        assert args.algo in ['a2c', 'ppo'], \
            'Recurrent policy is not implemented for ACKTR'

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    tf_log_dir = os.path.join(log_dir, args.exp_name)
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(tf_log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=tf_log_dir)
    eval_log_dir = log_dir + "_eval"

    torch.set_num_threads(1)
    device = torch.device("cuda" if args.cuda else "cpu")

    envs = make_otc_env(args, device)

    save_path = os.path.join(args.save_dir, args.exp_name)
    if args.load:
        # The checkpoint is the pickled [actor_critic, ob_rms] list written by
        # the save step below (torch.load unpickles: only load trusted files).
        actor_critic, ob_rms = \
            torch.load(os.path.join(save_path, args.env + ".pt"))
        vec_norm = get_vec_normalize(envs)
        if vec_norm is not None:
            vec_norm.eval()
            vec_norm.ob_rms = ob_rms
    else:
        obs_shape = envs.observation_space.spaces['visual'].shape
        vector_obs_len = envs.observation_space.spaces['vector'].shape[0]
        actor_critic = Policy(obs_shape,
                              envs.action_space,
                              base=CNNBase,
                              base_kwargs={'recurrent': args.recurrent_policy},
                              vector_obs_len=vector_obs_len)
    if torch.cuda.device_count() > 1:
        # NOTE(review): the DataParallel wrapper is discarded immediately
        # (``.module`` is the original model), so no multi-GPU execution
        # happens here — confirm whether that is intended.
        actor_critic_parallel = nn.DataParallel(actor_critic,
                                                device_ids=[0, 1])
        actor_critic = actor_critic_parallel.module
    if args.half_precision:
        actor_critic.half()  # convert to half precision
        # Batch-norm layers are switched back to float32.
        for layer in actor_critic.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()
    actor_critic.to(device, non_blocking=True)
    from pytorch_wrappers import VecPyTorch
    envs = VecPyTorch(envs, device, half_precision=args.half_precision)

    if args.algo == 'a2c':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               lr=args.lr,
                               eps=args.eps,
                               alpha=args.alpha,
                               max_grad_norm=args.max_grad_norm)
    elif args.algo == 'ppo':
        agent = algo.PPO(actor_critic,
                         args.clip_param,
                         args.ppo_epoch,
                         args.num_mini_batch,
                         args.value_loss_coef,
                         args.entropy_coef,
                         lr=args.lr,
                         eps=args.eps,
                         max_grad_norm=args.max_grad_norm)
    elif args.algo == 'acktr':
        agent = algo.A2C_ACKTR(actor_critic,
                               args.value_loss_coef,
                               args.entropy_coef,
                               acktr=True)

    if args.gail:
        assert len(envs.observation_space.shape) == 1
        discr = gail.Discriminator(
            envs.observation_space.shape[0] + envs.action_space.shape[0], 100,
            device)
        file_name = os.path.join(
            args.gail_experts_dir,
            "trajs_{}.pt".format(args.env.split('-')[0].lower()))

        gail_train_loader = torch.utils.data.DataLoader(
            gail.ExpertDataset(file_name,
                               num_trajectories=4,
                               subsample_frequency=20),
            batch_size=args.gail_batch_size,
            shuffle=True,
            drop_last=True)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape,
                              ([envs.vector_obs_len]), envs.action_space,
                              actor_critic.recurrent_hidden_state_size)
    if args.half_precision:
        rollouts.half()
    obs, vector_obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.vector_obs[0].copy_(vector_obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=100)
    episode_floors = deque(maxlen=100)
    episode_times = deque(maxlen=100)

    progress_fmt = (
        "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: "
        "mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}"
    )

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    try:
        for j in range(num_updates):

            if args.use_linear_lr_decay:
                # decrease learning rate linearly
                utils.update_linear_schedule(
                    agent.optimizer, j, num_updates,
                    agent.optimizer.lr if args.algo == "acktr" else args.lr)

            for step in range(args.num_steps):
                # Sample actions
                with torch.no_grad():
                    value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                        rollouts.obs[step], rollouts.vector_obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step])

                # Observe reward and next obs
                obs, vector_obs, reward, done, infos = envs.step(action)

                for info in infos:
                    if 'episode' in info:
                        episode_rewards.append(info['episode']['r'])
                        episode_floors.append(int(info['episode']['floor']))
                        episode_times.append(info['episode']['l'])

                # If done then clean the history of observations.
                masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                           for done_ in done])
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info else [1.0]
                     for info in infos])
                if args.half_precision:
                    masks = masks.half()
                    bad_masks = bad_masks.half()
                rollouts.insert(obs, vector_obs, recurrent_hidden_states, action,
                                action_log_prob, value, reward, masks, bad_masks)

            with torch.no_grad():
                next_value = actor_critic.get_value(
                    rollouts.obs[-1], rollouts.vector_obs[-1],
                    rollouts.recurrent_hidden_states[-1],
                    rollouts.masks[-1]).detach()

            if args.gail:
                if j >= 10:
                    envs.venv.eval()

                gail_epoch = args.gail_epoch
                if j < 10:
                    gail_epoch = 100  # Warm up
                for _ in range(gail_epoch):
                    discr.update(gail_train_loader, rollouts,
                                 utils.get_vec_normalize(envs)._obfilt)

                for step in range(args.num_steps):
                    rollouts.rewards[step] = discr.predict_reward(
                        rollouts.obs[step], rollouts.vector_obs[step],
                        rollouts.actions[step], args.gamma, rollouts.masks[step])

            rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                     args.gae_lambda, args.use_proper_time_limits)

            value_loss, action_loss, dist_entropy = agent.update(rollouts)

            rollouts.after_update()

            # save for every interval-th episode or for the last epoch
            if (j % args.save_interval == 0
                    or j == num_updates - 1) and args.save_dir != "":
                os.makedirs(save_path, exist_ok=True)
                total_num_steps = (j + 1) * args.num_processes * args.num_steps
                print("Save at update {} / timestep {}".format(j, total_num_steps))
                torch.save([
                    actor_critic,
                    getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
                ], os.path.join(save_path, args.env + ".pt"))

            if j % args.log_interval == 0:
                total_num_steps = (j + 1) * args.num_processes * args.num_steps
                end = time.time()
                fps = int(total_num_steps / (end - start))
                if len(episode_rewards) == 0:
                    print(progress_fmt.format(j, total_num_steps, fps, 0, 0, 0, 0, 0))
                else:
                    writer.add_scalar('reward',
                                      np.average(episode_rewards),
                                      global_step=total_num_steps)
                    writer.add_scalar('floor',
                                      np.average(episode_floors),
                                      global_step=total_num_steps)
                    writer.add_scalar('reward.std',
                                      np.std(episode_rewards),
                                      global_step=total_num_steps)
                    writer.add_scalar('floor.std',
                                      np.std(episode_floors),
                                      global_step=total_num_steps)
                    writer.add_scalar('steps',
                                      np.average(episode_times),
                                      global_step=total_num_steps)
                    writer.add_scalar('FPS',
                                      fps,
                                      global_step=total_num_steps)
                    writer.add_scalar('value_loss',
                                      np.around(value_loss, 6),
                                      global_step=total_num_steps)
                    # Tag names (including the trailing ':') are kept unchanged
                    # so existing TensorBoard runs remain comparable.
                    writer.add_scalar("action_loss:",
                                      np.around(action_loss, 6),
                                      global_step=total_num_steps)
                    writer.add_scalar("dist_entropy:",
                                      np.around(dist_entropy, 6),
                                      global_step=total_num_steps)
                    print(progress_fmt.format(
                        j, total_num_steps, fps,
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards),
                        np.min(episode_rewards), np.max(episode_rewards)))
                print("value_loss:", np.around(value_loss, 6), "action_loss:",
                      np.around(action_loss, 6), "dist_entropy:",
                      np.around(dist_entropy, 6))

            if (args.eval_interval is not None and len(episode_rewards) > 1
                    and j % args.eval_interval == 0):
                # get_vec_normalize may return None (no normalisation wrapper);
                # pass ob_rms=None in that case, as the checkpoint code does,
                # instead of raising AttributeError.
                ob_rms = getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
                evaluate(actor_critic, ob_rms, args.env, args.seed,
                         args.num_processes, eval_log_dir, device)
    finally:
        # Flush buffered TensorBoard events even if training is interrupted.
        writer.close()
               batch_size=eval_batch_size,
               shuffle=False,
               collate_fn=collate_fn,
               num_workers=0)
}

# Environment: a MaximumIndependentSetEnv on the configured device.
env = MaximumIndependentSetEnv(
    device=device,
    max_num_nodes=max_num_nodes,
    max_epi_t=max_epi_t,
    hamming_reward_coef=hamming_reward_coef,
)

# Storage for collected rollouts.
rollout = RolloutStorage(
    num_samples=train_num_samples,
    batch_size=rollout_batch_size,
    max_t=max_rollout_t,
)

# Actor-critic built from the policy / value graph-conv network classes.
actor_critic = ActorCritic(
    actor_class=PolicyGraphConvNet,
    critic_class=ValueGraphConvNet,
    device=device,
    max_num_nodes=max_num_nodes,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
)

# construct PPO framework
framework = ProxPolicyOptimFramework(actor_critic=actor_critic,
                                     init_lr=init_lr,
                                     clip_value=clip_value,
                                     optim_num_samples=optim_num_samples,
Beispiel #4
0
    def update(self, rollouts: RolloutStorage):
        """Run ``self.ppo_epoch`` PPO epochs over the data in ``rollouts``.

        Advantages are ``returns[:-1] - value_preds[:-1]``. They are
        standardised when they contain more than one element. In each epoch
        the rollouts are split into ``self.num_mini_batch`` mini-batches,
        using the recurrent or feed-forward generator according to
        ``self.agent.is_recurrent``. One optimiser step is taken per
        mini-batch on::

            aux_loss + clipped_surrogate_action_loss + value_loss_coef * value_loss

        The action term is omitted when ``self.aux_loss_only`` is set. The
        value loss is clipped when ``self.use_clipped_value_loss`` is set.

        Returns:
            dict mapping each logged loss name (``"action_loss"`` and/or
            ``"value_loss"``) to its average over all mini-batch steps; empty
            when no mini-batch was produced.
        """
        advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        # Running sums of logged losses plus a mini-batch counter under "n".
        logger = collections.Counter()

        for _ in range(self.ppo_epoch):
            if self.agent.is_recurrent:
                data_generator = rollouts.recurrent_generator(
                    advantages, self.num_mini_batch
                )
            else:
                data_generator = rollouts.feed_forward_generator(
                    advantages, self.num_mini_batch
                )

            sample: Batch
            for sample in data_generator:
                act = self.agent(
                    inputs=sample.obs,
                    rnn_hxs=sample.recurrent_hidden_states,
                    masks=sample.masks,
                    action=sample.actions,
                )
                values = act.value
                action_log_probs = act.action_log_probs
                loss = act.aux_loss

                if not self.aux_loss_only:
                    ratio = torch.exp(action_log_probs - sample.old_action_log_probs)
                    surr1 = ratio * sample.adv
                    surr2 = (
                        torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param)
                        * sample.adv
                    )
                    action_loss = -torch.min(surr1, surr2).mean()
                    # Log a detached copy: summing the live tensor in the
                    # Counter would keep every mini-batch's autograd graph
                    # alive until this method returns.
                    logger.update(action_loss=action_loss.detach())
                    # Out-of-place add so ``act.aux_loss`` is never mutated.
                    loss = loss + action_loss

                if self.use_clipped_value_loss:

                    value_pred_clipped = sample.value_preds + (
                        values - sample.value_preds
                    ).clamp(-self.clip_param, self.clip_param)
                    value_losses = (values - sample.ret).pow(2)
                    value_losses_clipped = (value_pred_clipped - sample.ret).pow(2)
                    value_loss = (
                        0.5 * torch.max(value_losses, value_losses_clipped).mean()
                    )
                else:
                    value_loss = 0.5 * F.mse_loss(sample.ret, values)
                logger.update(value_loss=value_loss.detach())
                loss = loss + self.value_loss_coef * value_loss

                self.optimizer.zero_grad()
                loss.backward()

                nn.utils.clip_grad_norm_(self.agent.parameters(), self.max_grad_norm)
                self.optimizer.step()

                # noinspection PyTypeChecker
                logger.update(n=1.0)

        n = logger.pop("n", 0)
        # Any remaining key implies at least one mini-batch ran, so n > 0 here.
        return {k: v.mean().item() / n for k, v in logger.items()}
def main():
    """Train a policy on ``CholeskyTaskGraph`` with PPO or A2C.

    All settings come from ``config.config_enhanced``. TensorBoard logs and
    the pickled model (``model.pth``) are written to
    ``runs/<name_dir(config_enhanced)>``.

    Periodic work during training:
    - the model is saved every ``save_interval`` updates and after the last
      update;
    - a progress line is printed every ``log_interval`` updates once more
      than one episode has finished;
    - ``evaluate`` is called every ``evaluate_every`` updates when that
      setting is not None.

    Raises:
        ValueError: if ``config_enhanced['agent']`` is neither ``'PPO'`` nor
            ``'A2C'``.
    """
    from config import config_enhanced
    writer = SummaryWriter(os.path.join('runs', name_dir(config_enhanced)))

    torch.multiprocessing.freeze_support()

    print("Current config_enhanced is:")
    pprint(config_enhanced)
    writer.add_text("config", str(config_enhanced))

    save_path = str(writer.get_logdir())
    os.makedirs(save_path, exist_ok=True)

    torch.manual_seed(config_enhanced['seed'])
    torch.cuda.manual_seed_all(config_enhanced['seed'])

    use_cuda = torch.cuda.is_available()
    if use_cuda and config_enhanced['cuda_deterministic']:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    if use_cuda:
        device = torch.device('cuda')
        print("using GPU")
    else:
        device = torch.device('cpu')
        print("using CPU")

    if config_enhanced['num_processes'] == "num_cpu":
        # All but one core, but never fewer than one worker
        # (cpu_count() - 1 is 0 on a single-core machine).
        num_processes = max(1, multiprocessing.cpu_count() - 1)
    else:
        num_processes = config_enhanced['num_processes']

    env = CholeskyTaskGraph(**config_enhanced['env_settings'])
    envs = VectorEnv(env, num_processes)
    envs.reset()

    model = SimpleNet(**config_enhanced["network_parameters"])
    if config_enhanced["model_path"]:
        model.load_state_dict(torch.load(config_enhanced['model_path']))

    actor_critic = Policy(model, envs.action_space, config_enhanced)
    actor_critic = actor_critic.to(device)

    if config_enhanced['agent'] == 'PPO':
        print("using PPO")
        agent_settings = config_enhanced['PPO_settings']
        agent = PPO(actor_critic, **agent_settings)
    elif config_enhanced['agent'] == 'A2C':
        print("using A2C")
        agent_settings = config_enhanced['A2C_settings']
        agent = A2C_ACKTR(actor_critic, **agent_settings)
    else:
        # Fail early instead of hitting an unbound ``agent`` later on.
        raise ValueError(
            "Unknown agent {!r}; expected 'PPO' or 'A2C'".format(config_enhanced['agent']))

    trajectory_length = config_enhanced['trajectory_length']
    # ``env`` is the instance handed to VectorEnv above; its spaces give the
    # storage shapes.
    rollouts = RolloutStorage(trajectory_length, num_processes,
                              env.observation_space.shape, env.action_space)

    obs = envs.reset()
    obs = torch.tensor(obs, device=device)
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        config_enhanced['num_env_steps']) // trajectory_length // num_processes
    try:
        for j in range(num_updates):

            if config_enhanced['use_linear_lr_decay']:
                # decrease learning rate linearly
                utils.update_linear_schedule(
                    agent.optimizer, j, num_updates, config_enhanced['network']['lr'])

            for step in tqdm(range(trajectory_length)):
                # Sample actions
                with torch.no_grad():
                    value, action, action_log_prob = actor_critic.act(
                        rollouts.obs[step])
                actions = action.squeeze(-1).detach().cpu().numpy()

                # Observe reward and next obs
                obs, reward, done, infos = envs.step(actions)
                obs = torch.tensor(obs, device=device)
                reward = torch.tensor(reward, device=device).unsqueeze(-1)
                done = torch.tensor(done, device=device)

                # Global environment-step counter used as the TensorBoard x-axis.
                n_step = (j * trajectory_length + step) * num_processes
                for info in infos:
                    if 'episode' in info:
                        reward_episode = info['episode']['r']
                        episode_rewards.append(reward_episode)
                        writer.add_scalar('reward', reward_episode, n_step)
                        writer.add_scalar(
                            'solved',
                            int(info['episode']['length'] == envs.envs[0].max_steps),
                            n_step)

                # If done then clean the history of observations.
                masks = torch.FloatTensor(
                    [[0.0] if done_ else [1.0] for done_ in done])
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info else [1.0]
                     for info in infos])
                rollouts.insert(obs, action,
                                action_log_prob, value, reward, masks, bad_masks)

            with torch.no_grad():
                next_value = actor_critic.get_value(
                    rollouts.obs[-1]).detach()

            rollouts.compute_returns(next_value, config_enhanced["use_gae"], config_enhanced["gamma"],
                                     config_enhanced['gae_lambda'], config_enhanced['use_proper_time_limits'])

            value_loss, action_loss, dist_entropy = agent.update(rollouts)
            writer.add_scalar('value loss', value_loss, n_step)
            writer.add_scalar('action loss', action_loss, n_step)
            writer.add_scalar('dist_entropy', dist_entropy, n_step)

            rollouts.after_update()

            # save for every interval-th episode or for the last epoch
            if (j % config_enhanced['save_interval'] == 0
                    or j == num_updates - 1):
                save_path = str(writer.get_logdir())
                os.makedirs(save_path, exist_ok=True)
                torch.save(actor_critic, os.path.join(save_path, "model.pth"))

            if j % config_enhanced['log_interval'] == 0 and len(episode_rewards) > 1:
                end = time.time()
                print(
                    "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                    .format(j, n_step,
                            int(n_step / (end - start)),
                            len(episode_rewards), np.mean(episode_rewards),
                            np.median(episode_rewards), np.min(episode_rewards),
                            np.max(episode_rewards)))

            if (config_enhanced['evaluate_every'] is not None and len(episode_rewards) > 1
                    and j % config_enhanced['evaluate_every'] == 0):
                # NOTE(review): ``boxworld`` is not defined in this function; it
                # must exist at module level or this raises NameError — confirm
                # which environment evaluate() should receive.
                eval_reward = evaluate(actor_critic, boxworld, config_enhanced, device)
                writer.add_scalar("eval reward", eval_reward, n_step)
    finally:
        # Flush buffered TensorBoard events even if training is interrupted.
        writer.close()
Beispiel #6
0
    def setup(
        self,
        num_steps,
        eval_steps,
        num_processes,
        seed,
        cuda_deterministic,
        cuda,
        time_limit,
        gamma,
        normalize,
        log_interval,
        eval_interval,
        no_eval,
        use_gae,
        tau,
        ppo_args,
        agent_args,
        render,
        render_eval,
        load_path,
        synchronous,
        num_batch,
        env_args,
        success_reward,
        use_tqdm,
    ):
        """Configure threads, seeds, device, environments, agent and PPO.

        Rendering modes (``render`` or ``render_eval``) make three changes:
        they set ``ppo_epoch=0`` in ``ppo_args``, use a single process, and
        disable CUDA. ``render_eval`` without ``render`` also forces
        ``eval_interval=1``.

        Sets ``self.device``, ``self.envs``, ``self.make_eval_envs`` (a
        partial building evaluation envs), ``self.agent``, ``self.rollouts``,
        ``self.ppo``, ``self.counter``, ``self.i``,
        ``self.make_train_iterator`` and ``self.train_iterator``. Calls
        ``self._restore(load_path)`` when ``load_path`` is truthy.
        """
        # Properly restrict pytorch to not consume extra resources.
        #  - https://github.com/pytorch/pytorch/issues/975
        #  - https://github.com/ray-project/ray/issues/3609
        torch.set_num_threads(1)
        os.environ["OMP_NUM_THREADS"] = "1"

        rendering = render or render_eval
        if render_eval and not render:
            eval_interval = 1
        if rendering:
            ppo_args.update(ppo_epoch=0)
            num_processes = 1
            cuda = False

        # Reproducibility.
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        cuda = cuda & torch.cuda.is_available()
        if cuda and cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
        torch.set_num_threads(1)

        self.device = "cpu"
        if cuda:
            self.device = self.get_device()

        # Keyword arguments shared by the training and evaluation env
        # factories. Spreading them next to **env_args still raises TypeError
        # on a duplicated key.
        shared_env_kwargs = dict(
            seed=seed,
            gamma=gamma if normalize else None,
            render=render,
            synchronous=True if render else synchronous,
            num_processes=num_processes,
            time_limit=time_limit,
        )
        self.envs = self.make_vec_envs(
            **env_args, **shared_env_kwargs, evaluation=False
        )
        self.make_eval_envs = functools.partial(
            self.make_vec_envs, **env_args, **shared_env_kwargs, evaluation=True
        )

        self.envs.to(self.device)
        self.agent = self.build_agent(envs=self.envs, **agent_args)
        self.rollouts = RolloutStorage(
            num_steps=num_steps,
            num_processes=num_processes,
            obs_space=self.envs.observation_space,
            action_space=self.envs.action_space,
            recurrent_hidden_state_size=self.agent.recurrent_hidden_state_size,
            use_gae=use_gae,
            gamma=gamma,
            tau=tau,
        )

        # Move agent and storage to the GPU, timing the copy.
        if cuda:
            copy_started = time.time()
            self.agent.to(self.device)
            self.rollouts.to(self.device)
            print("Values copied to GPU in", time.time() - copy_started, "seconds")

        self.ppo = PPO(agent=self.agent, num_batch=num_batch, **ppo_args)
        self.counter = Counter()

        self.i = 0
        if load_path:
            self._restore(load_path)

        def make_train_iterator():
            # Looked up at call time, so a later override of train_generator
            # is honoured.
            return self.train_generator(
                num_steps=num_steps,
                num_processes=num_processes,
                eval_steps=eval_steps,
                log_interval=log_interval,
                eval_interval=eval_interval,
                no_eval=no_eval,
                use_tqdm=use_tqdm,
                success_reward=success_reward,
            )

        self.make_train_iterator = make_train_iterator
        self.train_iterator = self.make_train_iterator()
Beispiel #7
0
class TrainBase(abc.ABC):
    def setup(
        self,
        num_steps,
        eval_steps,
        num_processes,
        seed,
        cuda_deterministic,
        cuda,
        time_limit,
        gamma,
        normalize,
        log_interval,
        eval_interval,
        no_eval,
        use_gae,
        tau,
        ppo_args,
        agent_args,
        render,
        render_eval,
        load_path,
        synchronous,
        num_batch,
        env_args,
        success_reward,
        use_tqdm,
    ):
        """Build environments, agent, rollout storage and PPO optimizer.

        Seeds torch/numpy and picks the device. It uses ``self.get_device()``
        when ``cuda`` is requested and available, otherwise ``"cpu"``.
        It then creates the training vector env (``self.envs``) and a factory
        for evaluation envs (``self.make_eval_envs``). A checkpoint is
        restored from ``load_path`` when given. Finally it creates the
        training iterator that ``_train`` consumes.

        When ``render`` or ``render_eval`` is set, training is forced onto a
        single CPU process and PPO updates are disabled (``ppo_epoch=0``).
        ``render_eval`` without ``render`` additionally sets
        ``eval_interval=1``. In that case ``ppo_args`` is copied rather than
        mutated, so the caller's dict is left untouched.
        """
        # Properly restrict pytorch to not consume extra resources.
        #  - https://github.com/pytorch/pytorch/issues/975
        #  - https://github.com/ray-project/ray/issues/3609
        torch.set_num_threads(1)
        os.environ["OMP_NUM_THREADS"] = "1"

        if render_eval and not render:
            eval_interval = 1
        if render or render_eval:
            # Copy instead of .update() so the caller's dict is not modified.
            ppo_args = dict(ppo_args, ppo_epoch=0)
            num_processes = 1
            cuda = False

        # reproducibility
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        cuda &= torch.cuda.is_available()
        if cuda and cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True

        self.device = "cpu"
        if cuda:
            self.device = self.get_device()

        # Keyword arguments shared by the training and evaluation envs.
        vec_env_kwargs = dict(
            **env_args,
            seed=seed,
            gamma=(gamma if normalize else None),
            render=render,
            synchronous=True if render else synchronous,
            num_processes=num_processes,
            time_limit=time_limit,
        )
        self.envs = self.make_vec_envs(**vec_env_kwargs, evaluation=False)
        # NOTE(review): eval envs receive ``render``, not ``render_eval``;
        # confirm this is intended.
        self.make_eval_envs = functools.partial(
            self.make_vec_envs, **vec_env_kwargs, evaluation=True)

        self.envs.to(self.device)
        self.agent = self.build_agent(envs=self.envs, **agent_args)
        self.rollouts = RolloutStorage(
            num_steps=num_steps,
            num_processes=num_processes,
            obs_space=self.envs.observation_space,
            action_space=self.envs.action_space,
            recurrent_hidden_state_size=self.agent.recurrent_hidden_state_size,
            use_gae=use_gae,
            gamma=gamma,
            tau=tau,
        )

        # copy to device
        if cuda:
            tick = time.time()
            self.agent.to(self.device)
            self.rollouts.to(self.device)
            print("Values copied to GPU in", time.time() - tick, "seconds")

        self.ppo = PPO(agent=self.agent, num_batch=num_batch, **ppo_args)
        self.counter = Counter()

        self.i = 0
        if load_path:
            self._restore(load_path)

        self.make_train_iterator = lambda: self.train_generator(
            num_steps=num_steps,
            num_processes=num_processes,
            eval_steps=eval_steps,
            log_interval=log_interval,
            eval_interval=eval_interval,
            no_eval=no_eval,
            use_tqdm=use_tqdm,
            success_reward=success_reward,
        )
        self.train_iterator = self.make_train_iterator()

    def _train(self):
        try:
            return next(self.train_iterator)
        except StopIteration:
            self.train_iterator = self.make_train_iterator()
            return self._train()

    def train_generator(
        self,
        num_steps,
        num_processes,
        eval_steps,
        log_interval,
        eval_interval,
        no_eval,
        success_reward,
        use_tqdm,
    ):
        """Run training iterations up to the next evaluation boundary.

        Evaluation runs first when ``eval_interval`` is truthy and
        ``no_eval`` is false. It steps freshly built evaluation envs for
        ``eval_steps`` steps and prefixes the resulting statistics with
        ``eval_``. Those envs are always closed, even if evaluation raises.

        Then PPO iterations run: a ``num_steps`` rollout, return
        computation, update, and ``after_update``. They continue until
        ``self.i`` reaches the next multiple of ``eval_interval``, or forever
        when ``eval_interval`` is falsy. Whenever ``self.i`` is a multiple of
        ``log_interval``, a dict is yielded with the wall-clock ``tick``, the
        ``fps``, and the merged episode, PPO-update and eval statistics.
        """
        if eval_interval and not no_eval:
            eval_masks = torch.zeros(num_processes, 1, device=self.device)
            eval_counter = Counter()
            envs = self.make_eval_envs()
            try:
                envs.to(self.device)
                with self.agent.recurrent_module.evaluating(
                        envs.observation_space):
                    eval_recurrent_hidden_states = torch.zeros(
                        num_processes,
                        self.agent.recurrent_hidden_state_size,
                        device=self.device,
                    )

                    eval_result = self.run_epoch(
                        obs=envs.reset(),
                        rnn_hxs=eval_recurrent_hidden_states,
                        masks=eval_masks,
                        num_steps=eval_steps,
                        counter=eval_counter,
                        success_reward=success_reward,
                        use_tqdm=use_tqdm,
                        rollouts=None,
                        envs=envs,
                    )
            finally:
                # Release the eval envs even if evaluation raises.
                envs.close()
            eval_result = {f"eval_{k}": v for k, v in eval_result.items()}
        else:
            eval_result = {}

        obs = self.envs.reset()
        self.rollouts.obs[0].copy_(obs)
        tick = time.time()
        log_progress = None

        if eval_interval:
            # Resume mid-window when self.i is not aligned (e.g. after restore).
            eval_iterator = range(self.i % eval_interval, eval_interval)
            if use_tqdm:
                eval_iterator = tqdm(eval_iterator, desc="next eval")
        else:
            eval_iterator = itertools.count(self.i)

        for _ in eval_iterator:
            if self.i % log_interval == 0 and use_tqdm:
                log_progress = tqdm(total=log_interval, desc="next log")
            self.i += 1
            epoch_counter = self.run_epoch(
                obs=self.rollouts.obs[0],
                rnn_hxs=self.rollouts.recurrent_hidden_states[0],
                masks=self.rollouts.masks[0],
                num_steps=num_steps,
                counter=self.counter,
                success_reward=success_reward,
                use_tqdm=False,
                rollouts=self.rollouts,
                envs=self.envs,
            )

            with torch.no_grad():
                next_value = self.agent.get_value(
                    self.rollouts.obs[-1],
                    self.rollouts.recurrent_hidden_states[-1],
                    self.rollouts.masks[-1],
                ).detach()

            self.rollouts.compute_returns(next_value=next_value)
            train_results = self.ppo.update(self.rollouts)
            self.rollouts.after_update()
            if log_progress is not None:
                log_progress.update()
            if self.i % log_interval == 0:
                total_num_steps = log_interval * num_processes * num_steps
                fps = total_num_steps / (time.time() - tick)
                tick = time.time()
                yield dict(tick=tick,
                           fps=fps,
                           **epoch_counter,
                           **train_results,
                           **eval_result)

    def run_epoch(
        self,
        obs,
        rnn_hxs,
        masks,
        num_steps,
        counter,
        success_reward,
        use_tqdm,
        rollouts,
        envs,
    ):
        """Step ``envs`` with the current agent for ``num_steps`` steps.

        Args:
            obs, rnn_hxs, masks: agent inputs for the first step.
            num_steps: number of environment steps to take.
            counter: running per-process ``"reward"`` / ``"time_step"``
                totals. It is kept across calls, and entries belonging to
                finished episodes are reset to 0.
            success_reward: if not None, each finished episode also records
                whether its total reward is ``>= success_reward``.
            use_tqdm: show an "evaluating" progress bar.
            rollouts: storage each transition is inserted into, or None.
            envs: vectorized environment to step.

        Returns:
            A dict mapping each statistic name to a list of values. For
            finished episodes it holds ``"rewards"``, ``"time_steps"`` and,
            optionally, ``"success"``. It also holds every key contributed
            by the env infos and the agent's log.
        """
        # noinspection PyTypeChecker
        stats = defaultdict(list)
        steps = range(num_steps)
        if use_tqdm:
            steps = tqdm(steps, desc="evaluating")

        for _ in steps:
            with torch.no_grad():
                out = self.agent(inputs=obs, rnn_hxs=rnn_hxs,
                                 masks=masks)  # type: AgentValues

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(out.action)
            self.process_infos(stats, done, infos, **out.log)

            # Accumulate running totals. ``done`` is used as a boolean index
            # mask (presumably a numpy bool array -- confirm).
            counter["reward"] += reward.numpy()
            counter["time_step"] += np.ones_like(done)
            finished = counter["reward"][done]
            stats["rewards"].extend(finished)
            if success_reward is not None:
                # noinspection PyTypeChecker
                stats["success"].extend(finished >= success_reward)
            stats["time_steps"].extend(counter["time_step"][done])
            # Restart the totals of processes whose episode just ended.
            counter["reward"][done] = 0
            counter["time_step"][done] = 0

            # The mask is 0 for processes whose episode ended, so their
            # history is cleared on the next step.
            masks = torch.tensor(1 - done, dtype=torch.float32,
                                 device=obs.device).unsqueeze(1)
            rnn_hxs = out.rnn_hxs
            if rollouts is None:
                continue
            rollouts.insert(
                obs=obs,
                recurrent_hidden_states=out.rnn_hxs,
                actions=out.action,
                action_log_probs=out.action_log_probs,
                values=out.value,
                rewards=reward,
                masks=masks,
            )

        return dict(stats)

    @staticmethod
    def process_infos(episode_counter, done, infos, **act_log):
        for d in infos:
            for k, v in d.items():
                episode_counter[k] += v if type(v) is list else [float(v)]
        for k, v in act_log.items():
            episode_counter[k] += v if type(v) is list else [float(v)]

    @staticmethod
    def build_agent(envs, **agent_args):
        """Create the default ``Agent`` for ``envs``.

        The agent is built from the envs' observation shape and action
        space. Extra keyword arguments are forwarded to ``Agent`` unchanged.
        """
        observation_shape = envs.observation_space.shape
        action_space = envs.action_space
        return Agent(observation_shape, action_space, **agent_args)

    @staticmethod
    def make_env(env_id, seed, rank, add_timestep, time_limit, evaluation,
                 evaluating=None):
        """Create and wrap a single gym environment.

        Args:
            env_id: id passed to ``gym.make``.
            seed: base seed; the env is seeded with ``seed + rank``.
            rank: index of this env within the vector env.
            add_timestep: wrap with ``AddTimestep`` when the observation is
                1-D and the env's repr mentions ``TimeLimit``.
            time_limit: if not None, wrap with ``TimeLimit`` using this value
                as ``max_episode_steps``.
            evaluation: not used by this default implementation.
            evaluating: duplicate of ``evaluation`` that ``make_vec_envs``
                always passes. It is accepted so the default implementation
                does not fail with an unexpected-keyword ``TypeError``. Unused.

        Returns:
            The wrapped environment.
        """
        env = gym.make(env_id)
        is_atari = hasattr(gym.envs, "atari") and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        env.seed(seed + rank)
        obs_shape = env.observation_space.shape
        if add_timestep and len(obs_shape) == 1 and "TimeLimit" in str(env):
            env = AddTimestep(env)
        if is_atari and len(env.observation_space.shape) == 3:
            env = wrap_deepmind(env)

        # Observations are 3-D with a last dim of 1 or 3 (presumably
        # channel-last images). Transpose them for PyTorch convolutions.
        obs_shape = env.observation_space.shape
        if len(obs_shape) == 3 and obs_shape[2] in (1, 3):
            env = TransposeImage(env)

        if time_limit is not None:
            env = TimeLimit(env, max_episode_steps=time_limit)

        return env

    def make_vec_envs(
        self,
        num_processes,
        gamma,
        render,
        synchronous,
        env_id,
        add_timestep,
        seed,
        evaluation,
        time_limit,
        num_frame_stack=None,
        **env_args,
    ):
        """Build a PyTorch-wrapped vector env of ``num_processes`` copies.

        Each copy is a thunk around ``self.make_env`` with its own ``rank``.
        A ``DummyVecEnv`` is used for a single env, on macOS, or when
        ``synchronous`` is set; otherwise a ``SubprocVecEnv`` is used.
        The result is wrapped in ``VecPyTorch``. It is further wrapped in
        ``VecPyTorchFrameStack`` when ``num_frame_stack`` is given.

        ``gamma`` is currently unused because ``VecNormalize`` wrapping is
        disabled. ``render`` only reaches the ``DummyVecEnv`` path.
        """
        # Keyword arguments identical for every rank.
        common = dict(
            env_id=env_id,
            add_timestep=add_timestep,
            seed=seed,
            evaluation=evaluation,
            time_limit=time_limit,
            evaluating=evaluation,
        )
        thunks = [
            functools.partial(self.make_env, rank=rank, **common, **env_args)
            for rank in range(num_processes)
        ]

        use_dummy = synchronous or len(thunks) == 1 or sys.platform == "darwin"
        venv = DummyVecEnv(thunks, render=render) if use_dummy else SubprocVecEnv(thunks)
        venv = VecPyTorch(venv)
        if num_frame_stack is None:
            return venv
        return VecPyTorchFrameStack(venv, num_frame_stack)

    def _save(self, checkpoint_dir):
        modules = dict(optimizer=self.ppo.optimizer,
                       agent=self.agent)  # type: Dict[str, torch.nn.Module]
        # if isinstance(self.envs.venv, VecNormalize):
        #     modules.update(vec_normalize=self.envs.venv)
        state_dict = {
            name: module.state_dict()
            for name, module in modules.items()
        }
        save_path = Path(
            checkpoint_dir,
            f"{self.i if self.save_separate else 'checkpoint'}.pt")
        torch.save(dict(step=self.i, **state_dict), save_path)
        print(f"Saved parameters to {save_path}")
        return str(save_path)

    def _restore(self, checkpoint):
        load_path = checkpoint
        state_dict = torch.load(load_path, map_location=self.device)
        self.agent.load_state_dict(state_dict["agent"])
        self.ppo.optimizer.load_state_dict(state_dict["optimizer"])
        self.i = state_dict.get("step", -1) + 1
        # if isinstance(self.envs.venv, VecNormalize):
        #     self.envs.venv.load_state_dict(state_dict["vec_normalize"])
        print(f"Loaded parameters from {load_path}.")

    @abc.abstractmethod
    def get_device(self):
        """Return the torch device to use when CUDA is enabled.

        ``setup`` calls this only when ``cuda`` is requested and
        ``torch.cuda.is_available()`` is true; otherwise it uses ``"cpu"``.
        Subclasses must implement it.
        """
        raise NotImplementedError