Example #1
def main():
    args = get_args()

    #torch.manual_seed(args.seed)
    #torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    utils.cleanup_log_dir(log_dir)
    torch.set_num_threads(1)

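    # Formula-generation-only mode: sample formulas and exit without training or testing.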
    if args.gen_formula_only:
        if args.test_in_domain or args.test_out_domain:
            formulas = sample_formulas_test(args)
        else:
            formulas = sample_formulas_train(args)
        exit()

    if args.train:
        args = setup_summary_writer(args)
        formulas = sample_formulas_train(args)
        args.formula, _, _ = formulas[0]
        train(args, formulas)
        print('Finished training')
    else:
        formulas = sample_formulas_test(args)
        args.formula, _, _ = formulas[0]
        n_successes, _ = test(args, formulas, model_name=args.save_model_name)
        print('Accuracy:', float(n_successes / len(formulas)))
Example #2
def main():
    args = get_args()
    args.num_processes = 16
    args.env_name = 'BreakoutNoFrameskip-v4'

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'a2c':
        agent = A2C_ACKTR(actor_critic,
                          args.value_loss_coef,
                          args.entropy_coef,
                          lr=args.lr,
                          eps=args.eps,
                          alpha=args.alpha,
                          max_grad_norm=args.max_grad_norm)
    else:
        raise ValueError("Unsupported algo: {}".format(args.algo))

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
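    # One update consumes num_steps transitions from each of the num_processes parallel envs.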
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    for j in range(num_updates):

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
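            # bad_masks flag time-limit ('bad_transition') terminations for proper-time-limit handling.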
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

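        # Bootstrap the return computation with the value of the last observation.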
        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n entropy {:.2f}, value loss {:.4f}, policy loss {:.4f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))
Example #3
def main(
    _run,
    _log,
    num_env_steps,
    env_name,
    seed,
    algorithm,
    dummy_vecenv,
    time_limit,
    wrappers,
    save_dir,
    eval_dir,
    loss_dir,
    log_interval,
    save_interval,
    eval_interval,
):

    if loss_dir:
        loss_dir = path.expanduser(loss_dir.format(id=str(_run._id)))
        utils.cleanup_log_dir(loss_dir)
        writer = SummaryWriter(loss_dir)
    else:
        writer = None

    eval_dir = path.expanduser(eval_dir.format(id=str(_run._id)))
    save_dir = path.expanduser(save_dir.format(id=str(_run._id)))

    utils.cleanup_log_dir(eval_dir)
    utils.cleanup_log_dir(save_dir)

    torch.set_num_threads(1)
    envs = make_vec_envs(
        env_name,
        seed,
        dummy_vecenv,
        algorithm["num_processes"],
        time_limit,
        wrappers,
        algorithm["device"],
    )

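    # One independent A2C learner per agent, each with its own observation and action space.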
    agents = [
        A2C(i, osp, asp)
        for i, (osp, asp) in enumerate(zip(envs.observation_space, envs.action_space))
    ]
    obs = envs.reset()

    for i in range(len(obs)):
        agents[i].storage.obs[0].copy_(obs[i])
        agents[i].storage.to(algorithm["device"])

    start = time.time()
    num_updates = (
        int(num_env_steps) // algorithm["num_steps"] // algorithm["num_processes"]
    )

    all_infos = deque(maxlen=10)

    for j in range(1, num_updates + 1):

        for step in range(algorithm["num_steps"]):
            # Sample actions
            with torch.no_grad():
                n_value, n_action, n_action_log_prob, n_recurrent_hidden_states = zip(
                    *[
                        agent.model.act(
                            agent.storage.obs[step],
                            agent.storage.recurrent_hidden_states[step],
                            agent.storage.masks[step],
                        )
                        for agent in agents
                    ]
                )
            # Observe reward and next obs
            obs, reward, done, infos = envs.step(n_action)
            # envs.envs[0].render()

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])

            bad_masks = torch.FloatTensor(
                [
                    [0.0] if info.get("TimeLimit.truncated", False) else [1.0]
                    for info in infos
                ]
            )
            for i in range(len(agents)):
                agents[i].storage.insert(
                    obs[i],
                    n_recurrent_hidden_states[i],
                    n_action[i],
                    n_action_log_prob[i],
                    n_value[i],
                    reward[:, i].unsqueeze(1),
                    masks,
                    bad_masks,
                )

            for info in infos:
                if info:
                    all_infos.append(info)

        # value_loss, action_loss, dist_entropy = agent.update(rollouts)
        for agent in agents:
            agent.compute_returns()

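        # Each agent's update receives every agent's storage (shared experience across learners).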
        for agent in agents:
            loss = agent.update([a.storage for a in agents])
            for k, v in loss.items():
                if writer:
                    writer.add_scalar(f"agent{agent.agent_id}/{k}", v, j)

        for agent in agents:
            agent.storage.after_update()

        if j % log_interval == 0 and len(all_infos) > 1:
            squashed = _squash_info(all_infos)

            total_num_steps = (
                j * algorithm["num_processes"] * algorithm["num_steps"]
            )
            end = time.time()
            _log.info(
                f"Updates {j}, num timesteps {total_num_steps}, FPS {int(total_num_steps / (end - start))}"
            )
            _log.info(
                f"Last {len(all_infos)} training episodes mean reward {squashed['episode_reward'].sum():.3f}"
            )

            for k, v in squashed.items():
                _run.log_scalar(k, v, j)
            all_infos.clear()

        if save_interval is not None and (
            j > 0 and j % save_interval == 0 or j == num_updates
        ):
            cur_save_dir = path.join(save_dir, f"u{j}")
            for agent in agents:
                save_at = path.join(cur_save_dir, f"agent{agent.agent_id}")
                os.makedirs(save_at, exist_ok=True)
                agent.save(save_at)
            archive_name = shutil.make_archive(cur_save_dir, "xztar", save_dir, f"u{j}")
            shutil.rmtree(cur_save_dir)
            _run.add_artifact(archive_name)

        if eval_interval is not None and (
            j > 0 and j % eval_interval == 0 or j == num_updates
        ):
            evaluate(
                agents, os.path.join(eval_dir, f"u{j}"),
            )
            videos = glob.glob(os.path.join(eval_dir, f"u{j}") + "/*.mp4")
            for i, v in enumerate(videos):
                _run.add_artifact(v, f"u{j}.{i}.mp4")
    envs.close()
Example #4
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    # TensorBoard writer for the Train/* scalars logged below (assumed, not shown in the original snippet).
    writer = SummaryWriter(log_dir=log_dir)

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False,
                         args.custom_gym)

    base = SEVN

    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base=base,
                          base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    if args.algo == 'ppo':
        agent = PPO(actor_critic,
                    args.clip_param,
                    args.ppo_epoch,
                    args.num_mini_batch,
                    args.value_loss_coef,
                    args.entropy_coef,
                    lr=args.lr,
                    eps=args.eps,
                    max_grad_norm=args.max_grad_norm)
    else:
        raise ValueError("Unsupported algo: {}".format(args.algo))

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    episode_length = deque(maxlen=10)
    episode_success_rate = deque(maxlen=100)
    episode_total = 0

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    for j in range(num_updates):

        if args.use_linear_lr_decay:
            # decrease learning rate linearly
            utils.update_linear_schedule(agent.optimizer, j, num_updates,
                                         args.lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                    episode_length.append(info['episode']['l'])
                    episode_success_rate.append(
                        info['was_successful_trajectory'])
                    episode_total += 1

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

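            # Save the policy together with the VecNormalize observation statistics (ob_rms).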
            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            writer.add_scalars('Train/Episode Reward', {
                "Reward Mean": np.mean(episode_rewards),
                "Reward Min": np.min(episode_rewards),
                "Reward Max": np.max(episode_rewards)
            },
                               global_step=total_num_steps)
            writer.add_scalars('Train/Episode Length', {
                "Episode Length Mean": np.mean(episode_length),
                "Episode Length Min": np.min(episode_length),
                "Episode Length Max": np.max(episode_length)
            },
                               global_step=total_num_steps)
            writer.add_scalar("Train/Episode Reward Mean",
                              np.mean(episode_rewards),
                              global_step=total_num_steps)
            writer.add_scalar("Train/Episode Length Mean",
                              np.mean(episode_length),
                              global_step=total_num_steps)
            writer.add_scalar("Train/Episode Success Rate",
                              np.mean(episode_success_rate),
                              global_step=total_num_steps)

            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n entropy {:.2f}, value loss {:.4f}, policy loss {:.4f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            ob_rms = utils.get_vec_normalize(envs).ob_rms
            evaluate(actor_critic, ob_rms, args.env_name, args.seed,
                     args.num_processes, eval_log_dir, device)
Example #5
def get_args():
    parser = argparse.ArgumentParser(description='Batch_PPO')
    parser.add_argument('--task-id',
                        type=str,
                        default='AntBulletEnv-v0',
                        help='task name (default: AntBulletEnv-v0)')
    parser.add_argument('--run-id',
                        type=str,
                        default='test',
                        help="name of the run")
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        help='random seed (default: 1)')
    parser.add_argument('--num-processes',
                        type=int,
                        default=12,
                        help='number of parallel processes (default: 12)')
    parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')

    # Training config
    parser.add_argument(
        '--num-env-steps',
        type=int,
        default=int(5e6),
        help='number of environment steps to train (default: 5e6)')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-3,
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--eps',
                        type=float,
                        default=1e-5,
                        help='Optimizer epsilon (default: 1e-5)')
    parser.add_argument('--use-linear-lr-decay',
                        type=bool,
                        default=True,
                        help='use a linear schedule on the learning rate')
    parser.add_argument('--max-grad-norm',
                        type=float,
                        default=0.5,
                        help='max norm of gradients (default: 0.5)')
    parser.add_argument(
        '--num-steps',
        type=int,
        default=2048,
        help='number of forward environment steps (default: 2048)')

    # PPO config
    parser.add_argument('--gamma',
                        type=float,
                        default=0.99,
                        help='reward discount coefficient (default: 0.99)')
    parser.add_argument('--gae-lambda',
                        type=float,
                        default=0.95,
                        help='gae lambda parameter (default: 0.95)')
    parser.add_argument('--entropy-coef',
                        type=float,
                        default=0.01,
                        help='entropy term coefficient (default: 0.01)')
    parser.add_argument('--value-loss-coef',
                        type=float,
                        default=0.5,
                        help='value loss coefficient (default: 0.5)')

    parser.add_argument('--ppo-epoch',
                        type=int,
                        default=10,
                        help='number of ppo epochs (default: 10)')
    parser.add_argument('--clip-param',
                        type=float,
                        default=0.2,
                        help='ppo clip parameter (default: 0.2)')
    parser.add_argument('--num-mini-batch',
                        type=int,
                        default=32,
                        help='number of mini batches (default: 32)')

    # Log config
    parser.add_argument(
        '--log-interval',
        type=int,
        default=1,
        help='log interval, one log per n updates (default: 1)')
    parser.add_argument('--log-dir',
                        type=str,
                        default='log/',
                        help='directory to save agent logs (default: log/)')
    parser.add_argument(
        '--monitor-dir',
        type=str,
        default='monitor_log/',
        help='directory to save monitor logs (default: monitor_log/)')
    parser.add_argument(
        '--result-dir',
        type=str,
        default='results/',
        help='directory to save plot results (default: results/)')

    # Evaluate performance
    parser.add_argument('--test-iters',
                        type=int,
                        default=int(1e4),
                        help='test iterations (default: 10000)')
    parser.add_argument('--video-width',
                        type=int,
                        default=720,
                        help='video resolution (default: 720)')
    parser.add_argument('--video-height',
                        type=int,
                        default=720,
                        help='video resolution (default: 720)')

    # Saving and restoring setup
    parser.add_argument(
        '--save-interval',
        type=int,
        default=100,
        help='save interval, one save per n updates (default: 100)')
    parser.add_argument(
        '--save-dir',
        type=str,
        default='./trained_models/',
        help='directory to save agent logs (default: ./trained_models/)')

    args = parser.parse_args()

    # Create directories
    args.save_path = os.path.join("saves", args.task_id, args.run_id)
    args.monitor_dir = os.path.join(args.monitor_dir, args.task_id,
                                    args.run_id)
    args.result_dir = os.path.join(args.result_dir, args.task_id)

    os.makedirs(args.save_path, exist_ok=True)
    os.makedirs(args.monitor_dir, exist_ok=True)
    os.makedirs(args.result_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    cleanup_log_dir(args.log_dir)

    # Setup device and random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and not args.disable_cuda:
        args.device = torch.device('cuda')
        torch.cuda.manual_seed(args.seed)
    else:
        args.device = torch.device('cpu')
    torch.set_num_threads(1)

    print(' ' * 26 + 'Options')
    for k, v in vars(args).items():
        print(' ' * 26 + k + ': ' + str(v))

    return args
Example #6
def main():
    args = get_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    base = SEVN

    # `save_dir` (the path of the [actor_critic, obs_rms] checkpoint written during training)
    # is not defined in this snippet and must be supplied before loading.
    actor_critic, obs_rms = torch.load(save_dir, map_location=device)
    actor_critic.to(device)
    actor_critic.max_eval_success_rate = 0
    print("Passed!")
    num_processes = args.num_processes
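    # Recurrent hidden states and masks reused across evaluation steps.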
    eval_recurrent_hidden_states = torch.zeros(
        args.num_processes, actor_critic.recurrent_hidden_state_size, device=device)
    eval_masks = torch.zeros(num_processes, 1, device=device)
    x = 0
    while x < 10:
        torch.manual_seed(args.seed + x)
        torch.cuda.manual_seed_all(args.seed + x)
        eval_envs = make_vec_envs(args.env_name, args.seed + x, args.num_processes,
                         args.gamma, args.log_dir, device, False, args.custom_gym)
        eval_episode_rewards = []
        eval_episode_length = []
        eval_episode_success_rate = []
        obs = eval_envs.reset()
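        # Collect num_processes * 100 evaluation episodes for this seed.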
        while len(eval_episode_rewards) < num_processes*100:
            with torch.no_grad():
                _, action, _, eval_recurrent_hidden_states = actor_critic.act(
                    obs,
                    eval_recurrent_hidden_states,
                    eval_masks,
                    deterministic=True)
            eval_envs.render()
            obs, _, done, infos = eval_envs.step(action)

            eval_masks = torch.tensor(
                [[0.0] if done_ else [1.0] for done_ in done],
                dtype=torch.float32,
                device=device)

            for info in infos:
                if 'episode' in info.keys():
                    if info['was_successful_trajectory'] and args.mod:
                        # Modified reward function: score successful episodes as 10
                        eval_episode_rewards.append(10)
                    else:
                        eval_episode_rewards.append(info['episode']['r'])
                    eval_episode_length.append(info['episode']['l'])
                    eval_episode_success_rate.append(info['was_successful_trajectory'])
        x += 1
        print(" Evaluation using {} episodes: mean reward {:.5f}, mean_length {:.2f}, mean_success {:.2f} \n".format(
            len(eval_episode_rewards), np.mean(eval_episode_rewards),
            np.mean(eval_episode_length), np.mean(eval_episode_success_rate)))

    eval_envs.close()

    print(eval_episode_rewards)
    print(eval_episode_success_rate)
Example #7
def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    utils.cleanup_log_dir(log_dir)

    with open(log_dir + 'extras.csv', "w") as file:
        file.write("n, value_loss\n")

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, device, False)

    model = Policy(envs.observation_space.shape,
                   envs.action_space.n,
                   extra_kwargs={'use_backpack': args.algo == 'tdprop'})
    model.to(device)

    if args.algo == 'tdprop':
        from algo.sarsa_tdprop import SARSA
        agent = SARSA(model,
                      lr=args.lr,
                      eps=args.eps,
                      max_grad_norm=args.max_grad_norm,
                      beta_1=args.beta_1,
                      beta_2=args.beta_2,
                      n=args.num_steps,
                      num_processes=args.num_processes,
                      gamma=args.gamma)
    else:
        from algo.sarsa import SARSA
        agent = SARSA(model,
                      lr=args.lr,
                      eps=args.eps,
                      max_grad_norm=args.max_grad_norm,
                      beta_1=args.beta_1,
                      beta_2=args.beta_2,
                      algo=args.algo)

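    # Epsilon-greedy exploration over the Q-values produced by the model.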
    explore_policy = utils.eps_greedy
    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.observation_space.shape, envs.action_space,
                              model.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)
    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes
    for j in range(num_updates):
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                qs = model(rollouts.obs[step])
                _, dist = explore_policy(qs, args.exploration)
                actions = dist.sample().unsqueeze(-1)
                value = qs.gather(-1, actions)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(actions)
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, torch.FloatTensor([0.0]), actions, value,
                            value, reward, masks, bad_masks)
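        # Expected-SARSA-style bootstrap: expectation of Q under the exploration policy.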
        with torch.no_grad():
            next_qs = model(rollouts.obs[-1])
            next_probs, _ = explore_policy(next_qs, args.exploration)
            next_value = (next_probs * next_qs).sum(-1).unsqueeze(-1)

        rollouts.compute_returns(next_value, args.gamma)

        value_loss = agent.update(rollouts, explore_policy, args.exploration)

        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0 or j == num_updates - 1):
            save_path = os.path.join(args.log_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass
            torch.save([
                list(model.parameters()),
                getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
            ], os.path.join(save_path, args.env_name + ".pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                ("Updates {}, num timesteps {}, FPS {}\n"
                 "Last {} training episodes: mean/median reward {:.1f}/{:.1f}"
                 ", min/max reward {:.1f}/{:.1f}\n"
                 "entropy {:.2f}, value loss {:.4f}")
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist.entropy().mean().item(), value_loss))
            with open(log_dir + 'extras.csv', "a") as file:
                file.write(
                    str(total_num_steps) + ", " + str(value_loss) + "\n")