Example #1
def main(repeat_num):
    args = get_args()
    print("start the train function")

    args.init_sigma = 0.6
    args.lr = 0.001

    device = torch.device("cpu")

    # plot_weight_histogram(parameters)
    actor_critic_policy = torch.load(
        "/Users/djrg/code/instincts/modular_rl_safety_gym/trained_models/pulled_from_server/double_rl_experiments/policy_plus_instinct/ba00287951_0_dense_seesaw_phase/model_rl_policy.pt")
    actor_critic_instinct = torch.load(
        "/Users/djrg/code/instincts/modular_rl_safety_gym/trained_models/pulled_from_server/double_rl_experiments/policy_plus_instinct/ba00287951_0_dense_seesaw_phase/model_rl_instinct.pt")

    # Init the environment
    env_name = "Safexp-PointGoal1-v0"
    eval_envs = make_vec_envs(env_name, np.random.randint(2 ** 32), 1,
                              args.gamma, None, device, allow_early_resets=True, normalize=args.norm_vectors)
    ob_rms = utils.get_vec_normalize(eval_envs)
    if ob_rms is not None:
        ob_rms = ob_rms.ob_rms

    for _ in range(repeat_num):
        fits, info = evaluate(EvalActorCritic(actor_critic_policy, actor_critic_instinct), ob_rms, eval_envs, 1,
                              device, instinct_on=True,
                              visualise=True)

    print(f"fitness = {fits.item()}, cost = {info['cost']}")
Example #2
def instinct_loop_ppo(
        args,
        learning_rate,
        num_steps,
        num_updates,
        inst_on,
        visualize,
        save_dir
):
    torch.set_num_threads(1)
    log_writer = SummaryWriter(save_dir, max_queue=1, filename_suffix="log")
    device = torch.device("cpu")

    env_name = ENV_NAME_BOX #"Safexp-PointGoal1-v0"
    envs = make_vec_envs(env_name, np.random.randint(2 ** 32), NUM_PROC,
                         args.gamma, None, device, allow_early_resets=True, normalize=args.norm_vectors)
    eval_envs = make_vec_envs(env_name, np.random.randint(2 ** 32), 1,
                              args.gamma, None, device, allow_early_resets=True, normalize=args.norm_vectors)

    actor_critic_policy = init_default_ppo(envs, log(args.init_sigma))

    # Prepare modified observation shape for instinct
    # (note: these shapes go unused below, since a pretrained instinct is loaded instead)
    obs_shape = envs.observation_space.shape
    inst_action_space = deepcopy(envs.action_space)
    inst_obs_shape = list(obs_shape)
    inst_obs_shape[0] = inst_obs_shape[0] + envs.action_space.shape[0]
    # Prepare modified action space for instinct
    inst_action_space.shape = list(inst_action_space.shape)
    inst_action_space.shape[0] = inst_action_space.shape[0] + 1
    inst_action_space.shape = tuple(inst_action_space.shape)
    actor_critic_instinct = torch.load("pretrained_instinct_h100.pt")

    actor_critic_policy.to(device)
    actor_critic_instinct.to(device)

    agent_policy = algo.PPO(
        actor_critic_policy,
        args.clip_param,
        args.ppo_epoch,
        args.num_mini_batch,
        args.value_loss_coef,
        args.entropy_coef,
        lr=learning_rate,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm)

    rollouts = RolloutStorage(num_steps, NUM_PROC,
                              obs_shape, envs.action_space,
                              actor_critic_policy.recurrent_hidden_state_size)

    obs = envs.reset()
    i_obs = make_instinct_input(obs, torch.zeros((NUM_PROC, envs.action_space.shape[0])))  # Add zero action to the observation
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    fitnesses = []
    best_fitness_so_far = float("-Inf")

    masks = torch.ones(num_steps + 1, NUM_PROC, 1)
    instinct_recurrent_hidden_states = torch.zeros(num_steps + 1, NUM_PROC, actor_critic_instinct.recurrent_hidden_state_size)

    for j in range(num_updates):
        training_collisions_current_update = 0
        for step in range(num_steps):
            # Sample actions
            with torch.no_grad():
                # (value, action, action_log_probs, rnn_hxs), (instinct_value, instinct_action, instinct_outputs_log_prob, i_rnn_hxs), final_action
                value, action, action_log_probs, recurrent_hidden_states = actor_critic_policy.act(
                    rollouts.obs[step],
                    rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step],
                    deterministic=False
                )
                instinct_value, instinct_action, instinct_outputs_log_prob, instinct_recurrent_hidden_states = actor_critic_instinct.act(
                    i_obs,
                    instinct_recurrent_hidden_states,
                    masks,
                    deterministic=False,
                )

            # Combine two networks
            final_action, i_control = policy_instinct_combinator(action, instinct_action)
            obs, reward, done, infos = envs.step(final_action)
            #envs.render()

            training_collisions_current_update += sum([i['cost'] for i in infos])
            modded_reward, violation_cost = reward_cost_combinator(reward, infos, NUM_PROC, i_control)

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor([[0.0] if 'bad_transition' in info.keys() else [1.0] for info in infos])
            # i_obs = torch.cat([obs, action], dim=1)
            i_obs = make_instinct_input(obs, action)
            rollouts.insert(obs, recurrent_hidden_states, action, action_log_probs,
                            value, modded_reward, masks, bad_masks)

        with torch.no_grad():
            next_value_policy = actor_critic_policy.get_value(
                rollouts.obs[-1],
                rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value_policy, args.use_gae, args.gamma,
                                args.gae_lambda, args.use_proper_time_limits)

        print("training policy")
        # Policy training phase: the instinct stays frozen, so snapshot it and
        # verify it is unchanged after the policy update.
        p_before = deepcopy(actor_critic_instinct)
        val_loss, action_loss, dist_entropy = agent_policy.update(rollouts)
        p_after = deepcopy(actor_critic_instinct)
        assert compare_two_models(p_before, p_after), "instinct changed when it shouldn't"

        rollouts.after_update()

        ob_rms = utils.get_vec_normalize(envs)
        if ob_rms is not None:
            ob_rms = ob_rms.ob_rms

        fits, info = evaluate(EvalActorCritic(actor_critic_policy, actor_critic_instinct), ob_rms, eval_envs, NUM_PROC,
                                    reward_cost_combinator, device, instinct_on=inst_on, visualise=visualize)
        instinct_reward = info['instinct_reward']
        hazard_collisions = info['hazard_collisions']
        print(
            f"Step {j}, Fitness {fits.item()}, value_loss instinct = {val_loss}, action_loss instinct= {action_loss}, "
            f"dist_entropy instinct = {dist_entropy}")
        print(
            f"Step {j}, Cost {instinct_reward}")
        print("-----------------------------------------------------------------")

        # Tensorboard logging
        log_writer.add_scalar("Task reward", fits.item(), j)
        log_writer.add_scalar("cost/Training hazard collisions", training_collisions_current_update, j)
        log_writer.add_scalar("cost/Instinct reward", instinct_reward, j)
        log_writer.add_scalar("cost/Eval hazard collisions", hazard_collisions, j)
        log_writer.add_scalar("value loss", val_loss, j)
        log_writer.add_scalar("action loss", action_loss, j)
        log_writer.add_scalar("dist entropy", dist_entropy, j)

        fitnesses.append(fits)
        if fits.item() > best_fitness_so_far:
            best_fitness_so_far = fits.item()
            torch.save(actor_critic_instinct, join(save_dir, "model_rl_instinct.pt"))
            torch.save(actor_critic_policy, join(save_dir, "model_rl_policy.pt"))
        torch.save(actor_critic_instinct, join(save_dir, "model_rl_instinct_latest.pt"))
        torch.save(actor_critic_policy, join(save_dir, "model_rl_policy_latest.pt"))
        torch.save(actor_critic_policy, join(save_dir, f"model_rl_policy_latest_{j}.pt"))
        pickle.dump(ob_rms, open(join(save_dir, "ob_rms.p"), "wb"))
    return (fitnesses[-1]), 0, 0
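
Note: make_instinct_input and compare_two_models are not shown on this page. The commented-out line `# i_obs = torch.cat([obs, action], dim=1)` in the rollout loop suggests the former simply appends the policy action to the observation, and the assert only needs the latter to check that no parameter moved. Minimal sketches under those assumptions (not necessarily the project's actual definitions):

import torch

def make_instinct_input(obs, action):
    # The instinct network observes the environment state plus the policy's
    # proposed action, concatenated along the feature dimension.
    return torch.cat([obs, action], dim=1)

def compare_two_models(model_a, model_b):
    # True if every corresponding parameter tensor is identical.
    return all(torch.equal(p_a, p_b)
               for p_a, p_b in zip(model_a.parameters(), model_b.parameters()))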
Example #3
def main(repeat_num):
    args = get_args()
    print("start the train function")
    args.init_sigma = 0.6
    args.lr = 0.001
    device = torch.device("cpu")

    # Init the environment
    # env_name = "Safexp-PointGoal1-v0"
    eval_envs = make_vec_envs(env_name,
                              np.random.randint(2**32),
                              1,
                              args.gamma,
                              None,
                              device,
                              allow_early_resets=True,
                              normalize=args.norm_vectors)
    obs_shape = eval_envs.observation_space.shape
    actor_critic_policy = init_default_ppo(eval_envs, log(args.init_sigma))

    # Prepare modified action space for instinct
    inst_action_space = deepcopy(eval_envs.action_space)
    inst_obs_shape = list(obs_shape)
    inst_obs_shape[0] = inst_obs_shape[0] + eval_envs.action_space.shape[0]

    inst_action_space.shape = list(inst_action_space.shape)
    inst_action_space.shape[0] = inst_action_space.shape[0] + 1
    inst_action_space.shape = tuple(inst_action_space.shape)
    actor_critic_instinct = Policy(tuple(inst_obs_shape),
                                   inst_action_space,
                                   init_log_std=log(args.init_sigma),
                                   base_kwargs={'recurrent': False})

    title = "baseline_pretrained_hh_10"
    # f = open(f"/Users/djgr/pulled_from_server/evaluate_instinct_all_inputs_task_switch_button/real_safety_tasks_easier/sweep_eval_hazard_param_BUTTON_more_space/{title}.csv", "w")
    # The freshly initialized networks above are immediately replaced by pretrained checkpoints
    actor_critic_policy = torch.load(
        # f"/Users/djgr/pulled_from_server/evaluate_instinct_all_inputs_task_switch_button/real_safety_tasks_easier/sweep_eval_hazard_param_BOX_more_space_more_time/hh_10_baseline_centered_noHaz/model_rl_policy_latest.pt"
        "/home/calavera/pulled_from_server/evaluate_instinct_all_inputs_task_switch_button/real_safety_tasks_easier/sweep_eval_hazard_param_BOX_more_space/hh_10/model_rl_policy_latest.pt"
        # "/home/calavera/code/ITU_work/IR2L_master/pretrained_policy.pt"
    )
    actor_critic_instinct = torch.load(
        f"/home/calavera/pulled_from_server/evaluate_instinct_all_inputs_task_switch_button/real_safety_tasks_easier/sweep_eval_hazard_param_BOX_more_space/hh_10/model_rl_instinct_latest.pt"
    )

    ob_rms = utils.get_vec_normalize(eval_envs)

    if ob_rms is not None:
        ob_rms = ob_rms.ob_rms
    # Override the live normalization stats with the ones saved at training time
    ob_rms = pickle.load(
        open(
            f"/home/calavera/pulled_from_server/evaluate_instinct_all_inputs_task_switch_button/real_safety_tasks_easier/sweep_eval_hazard_param_BOX_more_space/hh_10/ob_rms.p",
            "rb"))

    for _ in range(repeat_num):
        fits, info = evaluate(
            # EvalActorCritic(actor_critic_policy, actor_critic_instinct, det_policy=True, det_instinct=True),
            EvalActorCritic(actor_critic_policy, actor_critic_instinct),
            ob_rms,
            eval_envs,
            1,
            reward_cost_combinator,
            device,
            instinct_on=True,
            visualise=True)
        visualise_values_over_path(info['plot_info'])

        # f.write(f"fitness; {fits.item()}; hazard_collisions; {info['hazard_collisions']}\n")
        # f.flush()

        print(f"{info['hazard_collisions']}")
        print(
            f"fitness; {fits.item()}; hazard_collisions; {info['hazard_collisions']}\n"
        )
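
Note: init_default_ppo, used in several of these examples, is not shown either. Given how the instinct network is built right above it with Policy(...), a plausible sketch is that it constructs the same kind of Policy directly from the environment's own spaces (an assumption, not the project's actual helper):

def init_default_ppo(envs, init_log_std):
    # Default non-recurrent actor-critic over the env's observation/action spaces.
    return Policy(envs.observation_space.shape,
                  envs.action_space,
                  init_log_std=init_log_std,
                  base_kwargs={'recurrent': False})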
Example #4
def train_maml_like_ppo_(
    init_model,
    args,
    learning_rate,
    num_episodes=20,
    num_updates=1,
    vis=False,
    run_idx=0,
    use_linear_lr_decay=False,
):
    num_steps = num_episodes * 100

    torch.set_num_threads(1)
    device = torch.device("cpu")

    envs = make_vec_envs(ENV_NAME, seeding.create_seed(None), NUM_PROC,
                         args.gamma, None, device, allow_early_resets=True, normalize=args.norm_vectors)
    raw_env = navigation_2d.unpeele_navigation_env(envs, 0)

    # raw_env.set_arguments(args.rm_nogo, args.reduce_goals, True, args.large_nogos)
    new_task = raw_env.sample_tasks(run_idx)
    raw_env.reset_task(new_task[0])

    # actor_critic = Policy(
    #     envs.observation_space.shape,
    #     envs.action_space,
    #     base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic = copy.deepcopy(init_model)
    actor_critic.to(device)

    agent = algo.PPO(
        actor_critic,
        args.clip_param,
        args.ppo_epoch,
        args.num_mini_batch,
        args.value_loss_coef,
        args.entropy_coef,
        lr=learning_rate,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm)

    rollouts = RolloutStorage(num_steps, NUM_PROC,
                              envs.observation_space.shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    fitnesses = []

    for j in range(num_updates):

        # if args.use_linear_lr_decay:
        #    # decrease learning rate linearly
        #    utils.update_linear_schedule(
        #        agent.optimizer, j, num_updates,
        #        agent.optimizer.lr if args.algo == "acktr" else args.lr)
        min_c_rew = float("inf")
        vis = []
        offending = []
        for step in range(num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)
            if done[0]:
                c_rew = infos[0]["cummulative_reward"]
                vis.append((infos[0]['path'], infos[0]['goal']))
                offending.extend(infos[0]['offending'])
                if c_rew < min_c_rew:
                    min_c_rew = c_rew
            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        ob_rms = utils.get_vec_normalize(envs)
        if ob_rms is not None:
            ob_rms = ob_rms.ob_rms

        fits, info = evaluate(actor_critic, ob_rms, envs, NUM_PROC, device)
        print(f"fitness {fits} update {j+1}")
        if (j + 1) % 1 == 0:  # plot every update; raise the modulus to plot less often
            vis_path(vis, eval_path_rec=info['path'], offending=offending)
        fitnesses.append(fits)

    return fitnesses[-1], info[0]['reached'], None
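
Note: vis_path, which plots the collected (path, goal) tuples, the evaluation path and any offending (no-go) points, is not shown on this page. A hypothetical matplotlib sketch of such a plotting helper (argument handling and styling are assumptions):

import matplotlib.pyplot as plt

def vis_path(vis, eval_path_rec=None, offending=()):
    # vis: list of (path, goal) tuples, each path a sequence of (x, y) points.
    for path, goal in vis:
        xs, ys = zip(*path)
        plt.plot(xs, ys, alpha=0.4)
        plt.scatter(*goal, marker="*", s=100)
    if eval_path_rec is not None:
        xs, ys = zip(*eval_path_rec)
        plt.plot(xs, ys, color="black", linewidth=2, label="eval path")
    if offending:
        xs, ys = zip(*offending)
        plt.scatter(xs, ys, color="red", marker="x", label="offending")
    plt.legend()
    plt.show()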
Example #5
def inner_loop_ppo(args, learning_rate, num_steps, num_updates, inst_on,
                   visualize, save_dir):
    torch.set_num_threads(1)
    log_writer = SummaryWriter(save_dir, max_queue=1, filename_suffix="log")
    device = torch.device("cpu")

    env_name = ENV_NAME  # "Safexp-PointGoal1-v0"
    envs = make_vec_envs(env_name,
                         np.random.randint(2**32),
                         NUM_PROC,
                         args.gamma,
                         None,
                         device,
                         allow_early_resets=True,
                         normalize=args.norm_vectors)
    eval_envs = make_vec_envs(env_name,
                              np.random.randint(2**32),
                              1,
                              args.gamma,
                              None,
                              device,
                              allow_early_resets=True,
                              normalize=args.norm_vectors)

    actor_critic_policy = init_default_ppo(envs, log(args.init_sigma))

    # Prepare modified observation shape for instinct
    obs_shape = envs.observation_space.shape
    inst_action_space = deepcopy(envs.action_space)
    inst_obs_shape = list(obs_shape)
    inst_obs_shape[0] = inst_obs_shape[0] + envs.action_space.shape[0]
    # Prepare modified action space for instinct
    inst_action_space.shape = list(inst_action_space.shape)
    inst_action_space.shape[0] = inst_action_space.shape[0] + 1
    inst_action_space.shape = tuple(inst_action_space.shape)
    actor_critic_instinct = Policy(tuple(inst_obs_shape),
                                   inst_action_space,
                                   init_log_std=log(args.init_sigma),
                                   base_kwargs={'recurrent': False})
    actor_critic_policy.to(device)
    actor_critic_instinct.to(device)

    agent_policy = algo.PPO(actor_critic_policy,
                            args.clip_param,
                            args.ppo_epoch,
                            args.num_mini_batch,
                            args.value_loss_coef,
                            args.entropy_coef,
                            lr=learning_rate,
                            eps=args.eps,
                            max_grad_norm=args.max_grad_norm)

    agent_instinct = algo.PPO(actor_critic_instinct,
                              args.clip_param,
                              args.ppo_epoch,
                              args.num_mini_batch,
                              args.value_loss_coef,
                              args.entropy_coef,
                              lr=learning_rate,
                              eps=args.eps,
                              max_grad_norm=args.max_grad_norm)

    rollouts_rewards = RolloutStorage(
        num_steps, NUM_PROC, envs.observation_space.shape, envs.action_space,
        actor_critic_policy.recurrent_hidden_state_size)

    rollouts_cost = RolloutStorage(
        num_steps, NUM_PROC, inst_obs_shape, inst_action_space,
        actor_critic_instinct.recurrent_hidden_state_size)

    obs = envs.reset()
    i_obs = torch.cat(
        [obs, torch.zeros((NUM_PROC, envs.action_space.shape[0]))],
        dim=1)  # Add zero action to the observation
    rollouts_rewards.obs[0].copy_(obs)
    rollouts_rewards.to(device)
    rollouts_cost.obs[0].copy_(i_obs)
    rollouts_cost.to(device)

    fitnesses = []
    best_fitness_so_far = float("-Inf")
    is_instinct_training = False
    for j in range(num_updates):
        is_instinct_training_old = is_instinct_training
        is_instinct_training = phase_shifter(
            j, PHASE_LENGTH,
            len(TrainPhases)) == TrainPhases.INSTINCT_TRAIN_PHASE.value
        # Exactly one network explores at a time: the one currently being trained.
        is_instinct_deterministic = not is_instinct_training
        is_policy_deterministic = is_instinct_training
        for step in range(num_steps):
            # Sample actions
            with torch.no_grad():
                # (value, action, action_log_probs, rnn_hxs), (instinct_value, instinct_action, instinct_outputs_log_prob, i_rnn_hxs), final_action
                value, action, action_log_probs, recurrent_hidden_states = actor_critic_policy.act(
                    rollouts_rewards.obs[step],
                    rollouts_rewards.recurrent_hidden_states[step],
                    rollouts_rewards.masks[step],
                    deterministic=is_policy_deterministic)
                instinct_value, instinct_action, instinct_outputs_log_prob, instinct_recurrent_hidden_states = actor_critic_instinct.act(
                    rollouts_cost.obs[step],
                    rollouts_cost.recurrent_hidden_states[step],
                    rollouts_cost.masks[step],
                    deterministic=is_instinct_deterministic,
                )

            # Combine two networks
            final_action, i_control = policy_instinct_combinator(
                action, instinct_action)
            obs, reward, done, infos = envs.step(final_action)
            # envs.render()

            reward, violation_cost = reward_cost_combinator(
                reward, infos, NUM_PROC, i_control)

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts_rewards.insert(obs, recurrent_hidden_states, action,
                                    action_log_probs, value, reward, masks,
                                    bad_masks)
            i_obs = torch.cat([obs, action], dim=1)
            rollouts_cost.insert(i_obs, instinct_recurrent_hidden_states,
                                 instinct_action, instinct_outputs_log_prob,
                                 instinct_value, violation_cost, masks,
                                 bad_masks)

        with torch.no_grad():
            next_value_policy = actor_critic_policy.get_value(
                rollouts_rewards.obs[-1],
                rollouts_rewards.recurrent_hidden_states[-1],
                rollouts_rewards.masks[-1]).detach()
            next_value_instinct = actor_critic_instinct.get_value(
                rollouts_cost.obs[-1],
                rollouts_cost.recurrent_hidden_states[-1],
                rollouts_cost.masks[-1]).detach()

        rollouts_rewards.compute_returns(next_value_policy, args.use_gae,
                                         args.gamma, args.gae_lambda,
                                         args.use_proper_time_limits)
        rollouts_cost.compute_returns(next_value_instinct, args.use_gae,
                                      args.gamma, args.gae_lambda,
                                      args.use_proper_time_limits)

        if not is_instinct_training:
            print("training policy")
            # Policy training phase: snapshot the instinct to verify it stays frozen
            p_before = deepcopy(agent_instinct.actor_critic)
            value_loss, action_loss, dist_entropy = agent_policy.update(
                rollouts_rewards)
            val_loss_i, action_loss_i, dist_entropy_i = 0, 0, 0
            p_after = deepcopy(agent_instinct.actor_critic)
            assert compare_two_models(
                p_before, p_after), "instinct changed when it shouldn't"
        else:
            print("training instinct")
            # Instinct training phase
            value_loss, action_loss, dist_entropy = 0, 0, 0
            p_before = deepcopy(agent_policy.actor_critic)
            val_loss_i, action_loss_i, dist_entropy_i = agent_instinct.update(
                rollouts_cost)
            p_after = deepcopy(agent_policy.actor_critic)
            assert compare_two_models(
                p_before, p_after), "policy changed when it shouldn't"

        rollouts_rewards.after_update()
        rollouts_cost.after_update()

        ob_rms = utils.get_vec_normalize(envs)
        if ob_rms is not None:
            ob_rms = ob_rms.ob_rms

        fits, info = evaluate(EvalActorCritic(actor_critic_policy,
                                              actor_critic_instinct),
                              ob_rms,
                              eval_envs,
                              NUM_PROC,
                              reward_cost_combinator,
                              device,
                              instinct_on=inst_on,
                              visualise=visualize)
        instinct_reward = info['instinct_reward']
        eval_hazard_collisions = info['hazard_collisions']
        print(
            f"Step {j}, Fitness {fits.item()}, value_loss = {value_loss}, action_loss = {action_loss}, "
            f"dist_entropy = {dist_entropy}")
        print(
            f"Step {j}, Instinct reward {instinct_reward}, value_loss instinct = {val_loss_i}, action_loss instinct= {action_loss_i}, "
            f"dist_entropy instinct = {dist_entropy_i} hazard_collisions = {eval_hazard_collisions}"
        )
        print(
            "-----------------------------------------------------------------"
        )

        # Tensorboard logging
        log_writer.add_scalar("fitness", fits.item(), j)
        log_writer.add_scalar("value loss", value_loss, j)
        log_writer.add_scalar("action loss", action_loss, j)
        log_writer.add_scalar("dist entropy", dist_entropy, j)

        log_writer.add_scalar("cost/instinct_reward", instinct_reward, j)
        log_writer.add_scalar("cost/hazard_collisions", eval_hazard_collisions,
                              j)
        log_writer.add_scalar("value loss instinct", val_loss_i, j)
        log_writer.add_scalar("action loss instinct", action_loss_i, j)
        log_writer.add_scalar("dist entropy instinct", dist_entropy_i, j)

        fitnesses.append(fits)
        if fits.item() > best_fitness_so_far:
            best_fitness_so_far = fits.item()
            torch.save(actor_critic_policy, join(save_dir,
                                                 "model_rl_policy.pt"))
            torch.save(actor_critic_instinct,
                       join(save_dir, "model_rl_instinct.pt"))
        if is_instinct_training != is_instinct_training_old:
            torch.save(actor_critic_policy,
                       join(save_dir, f"model_rl_policy_update_{j}.pt"))
            torch.save(actor_critic_instinct,
                       join(save_dir, f"model_rl_instinct_update_{j}.pt"))
        torch.save(actor_critic_policy,
                   join(save_dir, "model_rl_policy_latest.pt"))
        torch.save(actor_critic_instinct,
                   join(save_dir, "model_rl_instinct_latest.pt"))
    return (fitnesses[-1]), 0, 0
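
Note: the loop above alternates between policy and instinct training via phase_shifter, which is not defined on this page. A minimal sketch of a round-robin phase schedule consistent with how it is called (an assumption; the actual phase_shifter may differ):

def phase_shifter(update_idx, phase_length, num_phases):
    # Stay in each phase for `phase_length` consecutive updates,
    # cycling through all phases in order.
    return (update_idx // phase_length) % num_phases

With PHASE_LENGTH updates per phase and len(TrainPhases) phases, the instinct is trained only while the returned index equals TrainPhases.INSTINCT_TRAIN_PHASE.value, and the policy otherwise.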
Example #6
def inner_loop_ppo(
    weights,
    args,
    learning_rate,
    num_steps,
    num_updates,
    run_idx,
    input_envs,
):

    torch.set_num_threads(1)
    device = torch.device("cpu")
    #print(input_envs.venv.spec._kwargs['config']['goal_locations'])
    #env_name = register_set_goal(run_idx)

    #envs = make_vec_envs(env_name, np.random.randint(2**32), NUM_PROC,
    #                     args.gamma, None, device, allow_early_resets=True, normalize=args.norm_vectors)
    actor_critic = init_ppo(input_envs, log(args.init_sigma))
    actor_critic.to(device)

    # apply the weights to the model
    apply_from_list(weights, actor_critic)


    agent = algo.PPO(
        actor_critic,
        args.clip_param,
        args.ppo_epoch,
        args.num_mini_batch,
        args.value_loss_coef,
        args.entropy_coef,
        lr=learning_rate,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm)

    rollouts = RolloutStorage(num_steps, NUM_PROC,
                              input_envs.observation_space.shape, input_envs.action_space,
                              actor_critic.recurrent_hidden_state_size)

    obs = input_envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    fitnesses = []
    violation_cost = 0

    for j in range(num_updates):

        episode_step_counter = 0
        for step in range(num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states, (final_action, _) = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])
            # Observe reward and next obs
            obs, reward, done, infos = input_envs.step(final_action)
            episode_step_counter += 1

            # Accumulate the safety cost and subtract it from the reward
            total_reward = reward
            for info in infos:
                violation_cost += info['cost']
                total_reward -= info['cost']

            # If done then clean the history of observations.
            masks = torch.FloatTensor(
                [[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if 'bad_transition' in info.keys() else [1.0]
                 for info in infos])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, total_reward, masks, bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda, args.use_proper_time_limits)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        ob_rms = utils.get_vec_normalize(input_envs)
        if ob_rms is not None:
            ob_rms = ob_rms.ob_rms

        fits, info = evaluate(actor_critic, ob_rms, input_envs, NUM_PROC, device)
        fitnesses.append(fits)

    return (fitnesses[-1]), 0, 0
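
Note: Example #6 folds the safety cost directly into the task reward, while Examples #2 and #5 delegate that to reward_cost_combinator, which also returns a separate violation-cost signal used as the instinct's reward. A minimal sketch, assuming each parallel env reports its cost in info['cost'] (how i_control enters the instinct's reward is not shown in these examples and is omitted here):

import torch

def reward_cost_combinator(reward, infos, num_proc, i_control):
    # One cost entry per parallel environment, shaped like the (num_proc, 1) reward tensor.
    costs = torch.tensor([[float(info.get("cost", 0.0))] for info in infos])
    modded_reward = reward - costs    # task reward penalized by the safety cost
    violation_cost = -costs           # negative cost as the instinct's reward signal
    return modded_reward, violation_cost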