Example #1
def experiment(variant):
    common.initialise(variant)

    expl_envs, eval_envs = common.create_environments(variant)

    (
        obs_shape,
        obs_space,
        action_space,
        n,
        mlp,
        channels,
        fc_input,
    ) = common.get_spaces(expl_envs)

    # # CHANGE TO ORDINAL ACTION SPACE
    # action_space = gym.spaces.Box(-np.inf, np.inf, (8,))
    # expl_envs.action_space = action_space
    # eval_envs.action_space = action_space
    ANCILLARY_GOAL_SIZE = variant["ancillary_goal_size"]
    SYMBOLIC_ACTION_SIZE = 12

    base = common.create_networks(variant, n, mlp, channels, fc_input)
    control_base = common.create_networks(
        variant, n, mlp, channels,
        fc_input + SYMBOLIC_ACTION_SIZE)  # for uvfa goal representation

    dist = common.create_symbolic_action_distributions(variant["action_space"],
                                                       base.output_size)

    control_dist = distributions.Categorical(base.output_size, action_space.n)

    eval_learner = WrappedPolicy(
        obs_shape,
        action_space,
        ptu.device,
        base=base,
        deterministic=True,
        dist=dist,
        num_processes=variant["num_processes"],
        obs_space=obs_space,
    )

    planner = ENHSPPlanner()

    # multihead
    # eval_controller = CraftController(
    #     MultiPolicy(
    #         obs_shape,
    #         action_space,
    #         ptu.device,
    #         18,
    #         base=base,
    #         deterministic=True,
    #         num_processes=variant["num_processes"],
    #         obs_space=obs_space,
    #     )
    # )

    # expl_controller = CraftController(
    #     MultiPolicy(
    #         obs_shape,
    #         action_space,
    #         ptu.device,
    #         18,
    #         base=base,
    #         deterministic=False,
    #         num_processes=variant["num_processes"],
    #         obs_space=obs_space,
    #     )
    # )

    # uvfa
    eval_controller = CraftController(
        WrappedPolicy(
            obs_shape,
            action_space,
            ptu.device,
            base=control_base,
            dist=control_dist,
            deterministic=True,
            num_processes=variant["num_processes"],
            obs_space=obs_space,
            symbolic_action_size=SYMBOLIC_ACTION_SIZE,
        ),
        n=n,
    )

    expl_controller = CraftController(
        WrappedPolicy(
            obs_shape,
            action_space,
            ptu.device,
            base=control_base,
            dist=control_dist,
            deterministic=False,
            num_processes=variant["num_processes"],
            obs_space=obs_space,
            symbolic_action_size=SYMBOLIC_ACTION_SIZE,
        ),
        n=n,
    )
    function_env = gym.make(variant["env_name"])

    eval_policy = LearnPlanPolicy(
        eval_learner,
        planner,
        eval_controller,
        num_processes=variant["num_processes"],
        vectorised=True,
        env=function_env,
    )

    expl_learner = WrappedPolicy(
        obs_shape,
        action_space,
        ptu.device,
        base=base,
        deterministic=False,
        dist=dist,
        num_processes=variant["num_processes"],
        obs_space=obs_space,
    )

    expl_policy = LearnPlanPolicy(
        expl_learner,
        planner,
        expl_controller,
        num_processes=variant["num_processes"],
        vectorised=True,
        env=function_env,
    )

    eval_path_collector = ThreeTierStepCollector(
        eval_envs,
        eval_policy,
        ptu.device,
        ANCILLARY_GOAL_SIZE,
        SYMBOLIC_ACTION_SIZE,
        max_num_epoch_paths_saved=variant["algorithm_kwargs"]
        ["num_eval_steps_per_epoch"],
        num_processes=variant["num_processes"],
        render=variant["render"],
        gamma=1,
        no_plan_penalty=True,
        meta_num_epoch_paths=variant["meta_num_steps"],
    )
    expl_path_collector = ThreeTierStepCollector(
        expl_envs,
        expl_policy,
        ptu.device,
        ANCILLARY_GOAL_SIZE,
        SYMBOLIC_ACTION_SIZE,
        max_num_epoch_paths_saved=variant["num_steps"],
        num_processes=variant["num_processes"],
        render=variant["render"],
        gamma=variant["trainer_kwargs"]["gamma"],
        no_plan_penalty=variant.get("no_plan_penalty", False),
        meta_num_epoch_paths=variant["meta_num_steps"],
    )
    # added: created rollout(5,1,(4,84,84),Discrete(6),1), reset env and added obs to rollout[step]

    learn_trainer = PPOTrainer(actor_critic=expl_policy.learner,
                               **variant["trainer_kwargs"])
    control_trainer = PPOTrainer(actor_critic=expl_policy.controller.policy,
                                 **variant["trainer_kwargs"])
    trainer = MultiTrainer([control_trainer, learn_trainer])
    # missing: by this point, rollout back in sync.
    replay_buffer = EnvReplayBuffer(variant["replay_buffer_size"], expl_envs)
    # added: replay buffer is new
    algorithm = TorchIkostrikovRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_envs,
        evaluation_env=eval_envs,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"],
        # batch_size,
        # max_path_length,
        # num_epochs,
        # num_eval_steps_per_epoch,
        # num_expl_steps_per_train_loop,
        # num_trains_per_train_loop,
        # num_train_loops_per_epoch=1,
        # min_num_steps_before_training=0,
    )

    algorithm.to(ptu.device)

    algorithm.train()
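
The keys this entry point reads from variant can be gathered into a configuration sketch. This is a hypothetical minimal example: only the key names come from the code above, every value is a placeholder, and common.initialise, common.create_environments, PPOTrainer and the algorithm will generally require additional keys not shown here.

# Hypothetical variant for the three-tier experiment above (placeholder values).
variant = dict(
    env_name="YourCraftEnv-v0",        # placeholder; anything gym.make can resolve
    action_space="craft",              # assumed value; consumed by create_symbolic_action_distributions
    ancillary_goal_size=16,
    num_processes=8,
    num_steps=2048,
    meta_num_steps=256,
    no_plan_penalty=False,
    render=False,
    replay_buffer_size=int(1e6),
    trainer_kwargs=dict(gamma=0.99),                       # plus the PPOTrainer hyperparameters
    algorithm_kwargs=dict(num_eval_steps_per_epoch=1000),  # plus the remaining algorithm settings
)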
Example #2
def experiment(variant):
    # common.initialise(variant)

    setup_logger("name-of-experiment", variant=variant)
    ptu.set_gpu_mode(True)

    expl_env = gym.make(variant["env_name"], seed=5)
    eval_env = gym.make(variant["env_name"], seed=5)

    ANCILLARY_GOAL_SIZE = 16
    SYMBOLIC_ACTION_SIZE = 12  # size of the goal embedding (uvfa/multihead) passed to the controller
    GRID_SIZE = 31

    action_dim = ANCILLARY_GOAL_SIZE
    symbolic_action_space = gym.spaces.Discrete(ANCILLARY_GOAL_SIZE)
    symb_env = gym.make(variant["env_name"])
    symb_env.action_space = symbolic_action_space

    (
        obs_shape,
        obs_space,
        action_space,
        n,
        mlp,
        channels,
        fc_input,
    ) = common.get_spaces(expl_env)

    qf = Mlp(
        input_size=n,
        output_size=action_dim,
        hidden_sizes=[256, 256],
        init_w=variant["init_w"],
        b_init_value=variant["b_init_value"],
    )
    target_qf = Mlp(
        input_size=n,
        output_size=action_dim,
        hidden_sizes=[256, 256],
        init_w=variant["init_w"],
        b_init_value=variant["b_init_value"],
    )

    planner = ENHSPPlanner()

    # collect
    filepath = "/home/achester/anaconda3/envs/goal-gen/.guild/runs/e77c75eed02e4b38a0a308789fbfcbd8/data/params.pkl"  # collect
    with open(filepath, "rb") as openfile:
        while True:
            try:
                policies = pickle.load(openfile)
            except EOFError:
                break

    loaded_collect_policy = policies["exploration/policy"]
    loaded_collect_policy.rnn_hxs = loaded_collect_policy.rnn_hxs[0].unsqueeze(
        0)
    eval_collect = CraftController(loaded_collect_policy, n=GRID_SIZE)
    expl_collect = CraftController(loaded_collect_policy, n=GRID_SIZE)

    # other
    # filepath = "/home/achester/anaconda3/envs/goal-gen/.guild/runs/cf5c31afe0724acd8f6398d77a80443e/data/params.pkl"  # other (RC 28)
    filepath = "/home/achester/anaconda3/envs/goal-gen/.guild/runs/4989f4bcbadb4ac58c3668c068d63225/data/params.pkl"  # other (RC 55)
    # filepath = "/home/achester/Documents/misc/craft-model/params.pkl"
    with open(filepath, "rb") as openfile:
        while True:
            try:
                policies = pickle.load(openfile)
            except EOFError:
                break

    loaded_other_policy = policies["exploration/policy"]
    loaded_other_policy.rnn_hxs = loaded_other_policy.rnn_hxs[0].unsqueeze(0)
    eval_other = CraftController(loaded_other_policy, n=GRID_SIZE)
    expl_other = CraftController(loaded_other_policy, n=GRID_SIZE)

    eval_controller = PretrainedController([eval_collect, eval_other])
    expl_controller = PretrainedController([expl_collect, expl_other])

    function_env = gym.make(variant["env_name"])

    qf_criterion = nn.MSELoss()
    if variant["softmax"]:
        eval_learner = SoftmaxDiscretePolicy(qf, variant["temperature"])
    else:
        eval_learner = ArgmaxDiscretePolicy(qf)

    expl_learner = PolicyWrappedWithExplorationStrategy(
        LinearEpsilonGreedy(symbolic_action_space,
                            anneal_schedule=variant["anneal_schedule"]),
        eval_learner,
    )

    eval_policy = LearnPlanPolicy(
        eval_learner,
        planner,
        eval_controller,
        num_processes=1,
        vectorised=False,
        env=function_env,
    )

    expl_policy = LearnPlanPolicy(
        expl_learner,
        planner,
        expl_controller,
        num_processes=1,
        vectorised=False,
        env=function_env,
    )

    eval_path_collector = IntermediatePathCollector(
        eval_env,
        eval_policy,
        rollout=intermediate_rollout,
        gamma=1,
        render=variant["render"],
        single_plan_discounting=variant["trainer_kwargs"]
        ["single_plan_discounting"],
        experience_interval=variant["experience_interval"],
    )
    expl_path_collector = IntermediatePathCollector(
        expl_env,
        expl_policy,
        rollout=intermediate_rollout,
        gamma=variant["trainer_kwargs"]["discount"],
        render=variant["render"],
        single_plan_discounting=variant["trainer_kwargs"]
        ["single_plan_discounting"],
        experience_interval=variant["experience_interval"],
    )

    if variant["double_dqn"]:
        trainer = DoubleDQNTrainer(qf=qf,
                                   target_qf=target_qf,
                                   qf_criterion=qf_criterion,
                                   **variant["trainer_kwargs"])
    else:
        trainer = DQNTrainer(qf=qf,
                             target_qf=target_qf,
                             qf_criterion=qf_criterion,
                             **variant["trainer_kwargs"])
    replay_buffer = PlanReplayBuffer(variant["replay_buffer_size"], symb_env)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"])

    algorithm.to(ptu.device)

    algorithm.train()
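
The snapshot-loading loop used above (open params.pkl, call pickle.load until EOFError, keep the last object read) recurs in several of these examples. A small helper along these lines could factor it out; load_last_snapshot is a hypothetical name, not part of the codebase.

import pickle

def load_last_snapshot(filepath):
    """Read pickled objects from filepath until EOF and return the last one."""
    snapshot = None
    with open(filepath, "rb") as f:
        while True:
            try:
                snapshot = pickle.load(f)
            except EOFError:
                break
    return snapshot

# Usage matching the example above:
# policies = load_last_snapshot(filepath)
# loaded_collect_policy = policies["exploration/policy"]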
Example #3
def experiment(variant):
    common.initialise(variant)

    expl_envs, eval_envs = common.create_environments(variant)

    (
        obs_shape,
        obs_space,
        action_space,
        n,
        mlp,
        channels,
        fc_input,
    ) = common.get_spaces(expl_envs)

    base = common.create_networks(variant, n, mlp, channels, fc_input)

    dist = create_output_distribution(action_space, base.output_size)

    NUM_OPTIONS = 14
    eval_policy = MultiPolicy(
        obs_shape,
        action_space,
        ptu.device,
        base=base,
        deterministic=True,
        dist=dist,
        num_processes=variant["num_processes"],
        obs_space=obs_space,
        num_options=NUM_OPTIONS,
    )
    expl_policy = MultiPolicy(
        obs_shape,
        action_space,
        ptu.device,
        base=base,
        deterministic=False,
        dist=dist,
        num_processes=variant["num_processes"],
        obs_space=obs_space,
        num_options=NUM_OPTIONS,
    )

    if action_space.__class__.__name__ == "Tuple":
        action_space = gym.spaces.Box(-np.inf, np.inf, (8, ))
        expl_envs.action_space = action_space
        eval_envs.action_space = action_space
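        # The policies above were built with the original Tuple action space;
        # this flattened Box(8) view is only installed on the vectorised envs.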

    eval_path_collector = RolloutStepCollector(
        eval_envs,
        eval_policy,
        ptu.device,
        max_num_epoch_paths_saved=variant["algorithm_kwargs"]
        ["num_eval_steps_per_epoch"],
        num_processes=variant["num_processes"],
        render=variant["render"],
    )
    expl_path_collector = RolloutStepCollector(
        expl_envs,
        expl_policy,
        ptu.device,
        max_num_epoch_paths_saved=variant["num_steps"],
        num_processes=variant["num_processes"],
        render=variant["render"],
    )
    # added: created rollout(5,1,(4,84,84),Discrete(6),1), reset env and added obs to rollout[step]

    trainer = PPOTrainer(actor_critic=expl_policy, **variant["trainer_kwargs"])
    # missing: by this point, rollout back in sync.
    replay_buffer = EnvReplayBuffer(variant["replay_buffer_size"], expl_envs)
    # added: replay buffer is new

    algorithm = TorchIkostrikovRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_envs,
        evaluation_env=eval_envs,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"],
        # batch_size,
        # max_path_length,
        # num_epochs,
        # num_eval_steps_per_epoch,
        # num_expl_steps_per_train_loop,
        # num_trains_per_train_loop,
        # num_train_loops_per_epoch=1,
        # min_num_steps_before_training=0,
    )

    algorithm.to(ptu.device)
    # missing: device back in sync
    algorithm.train()
Example #4
def experiment(variant):
    setup_logger("name-of-experiment", variant=variant)
    ptu.set_gpu_mode(True)

    expl_env = gym.make(variant["env_name"])
    eval_env = gym.make(variant["env_name"])

    # OLD - Taxi image env
    # if isinstance(expl_env.observation_space, Json):
    #     expl_env = BoxWrapper(expl_env)
    #     eval_env = BoxWrapper(eval_env)
    #     # obs_shape = expl_env.observation_space.image.shape

    # obs_shape = expl_env.observation_space.shape
    # if len(obs_shape) == 3 and obs_shape[2] in [1, 3]:  # convert WxHxC into CxWxH
    #     expl_env = TransposeImage(expl_env, op=[2, 0, 1])
    #     eval_env = TransposeImage(eval_env, op=[2, 0, 1])

    # obs_shape = expl_env.observation_space.shape
    # channels, obs_width, obs_height = obs_shape
    # action_dim = eval_env.action_space.n

    # qf = CNN(
    #     input_width=obs_width,
    #     input_height=obs_height,
    #     input_channels=channels,
    #     output_size=action_dim,
    #     kernel_sizes=[8, 4],
    #     n_channels=[16, 32],
    #     strides=[4, 2],
    #     paddings=[0, 0],
    #     hidden_sizes=[256],
    # )
    # target_qf = CNN(
    #     input_width=obs_width,
    #     input_height=obs_height,
    #     input_channels=channels,
    #     output_size=action_dim,
    #     kernel_sizes=[8, 4],
    #     n_channels=[16, 32],
    #     strides=[4, 2],
    #     paddings=[0, 0],
    #     hidden_sizes=[256],
    # )

    (
        obs_shape,
        obs_space,
        action_space,
        n,
        mlp,
        channels,
        fc_input,
    ) = common.get_spaces(expl_env)

    qf = Mlp(
        input_size=n,
        output_size=action_space.n,
        hidden_sizes=[256, 256],
        init_w=variant["init_w"],
        b_init_value=variant["b_init_value"],
    )
    target_qf = Mlp(
        input_size=n,
        output_size=action_space.n,
        hidden_sizes=[256, 256],
        init_w=variant["init_w"],
        b_init_value=variant["b_init_value"],
    )

    qf_criterion = nn.MSELoss()

    if variant["softmax"]:
        eval_policy = SoftmaxDiscretePolicy(qf, variant["temperature"])
    else:
        eval_policy = ArgmaxDiscretePolicy(qf)

    expl_policy = PolicyWrappedWithExplorationStrategy(
        LinearEpsilonGreedy(action_space,
                            anneal_schedule=variant["anneal_schedule"]),
        eval_policy,
    )
    eval_path_collector = MdpPathCollector(eval_env,
                                           eval_policy,
                                           render=variant["render"])
    expl_path_collector = MdpPathCollector(expl_env,
                                           expl_policy,
                                           render=variant["render"])
    trainer = DQNTrainer(qf=qf,
                         target_qf=target_qf,
                         qf_criterion=qf_criterion,
                         **variant["trainer_kwargs"])
    replay_buffer = EnvReplayBuffer(variant["replay_buffer_size"], expl_env)
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"])
    algorithm.to(ptu.device)
    algorithm.train()
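
As with the PPO examples, this DQN baseline pulls its configuration from variant. A hypothetical minimal configuration, again with placeholder values and only the keys referenced directly in the code:

# Hypothetical variant for the DQN baseline above (placeholder values).
variant = dict(
    env_name="YourEnv-v0",      # placeholder; any env with a discrete action space
    init_w=3e-3,
    b_init_value=0.1,
    softmax=False,              # True selects SoftmaxDiscretePolicy(qf, temperature)
    temperature=1.0,
    anneal_schedule=None,       # placeholder; forwarded to LinearEpsilonGreedy
    render=False,
    replay_buffer_size=int(1e6),
    trainer_kwargs=dict(),      # forwarded to DQNTrainer
    algorithm_kwargs=dict(),    # forwarded to TorchBatchRLAlgorithm
)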
Example #5
def experiment(variant):
    common.initialise(variant)

    expl_envs, eval_envs = common.create_environments(variant)

    (
        obs_shape,
        obs_space,
        action_space,
        n,
        mlp,
        channels,
        fc_input,
    ) = common.get_spaces(expl_envs)

    obs_dim = obs_shape[1]

    qf = CNN(
        input_width=obs_dim,
        input_height=obs_dim,
        input_channels=channels,
        output_size=8,
        kernel_sizes=[8, 4],
        n_channels=[16, 32],
        strides=[4, 2],
        paddings=[0, 0],
        hidden_sizes=[256],
    )
    # CHANGE TO ORDINAL ACTION SPACE
    action_space = gym.spaces.Box(-np.inf, np.inf, (8, ))
    expl_envs.action_space = action_space
    eval_envs.action_space = action_space

    base = common.create_networks(variant, n, mlp, channels, fc_input)

    bernoulli_dist = distributions.Bernoulli(base.output_size, 4)
    passenger_dist = distributions.Categorical(base.output_size, 5)
    delivered_dist = distributions.Categorical(base.output_size, 5)
    continuous_dist = distributions.DiagGaussian(base.output_size, 2)
    dist = distributions.DistributionGeneratorTuple(
        (bernoulli_dist, continuous_dist, passenger_dist, delivered_dist))
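    # Presumably the flattened "ordinal" action is 8-dimensional: 4 Bernoulli flags,
    # 2 continuous coordinates, and one index from each 5-way Categorical (4 + 2 + 1 + 1 = 8).
    # Note that dist is not used below; the scripted policies act without it.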

    eval_policy = LearnPlanPolicy(
        ScriptedPolicy(qf, variant["always_return"]),
        num_processes=variant["num_processes"],
        vectorised=True,
        json_to_screen=expl_envs.observation_space.converter,
    )
    expl_policy = LearnPlanPolicy(
        ScriptedPolicy(qf, variant["always_return"]),
        num_processes=variant["num_processes"],
        vectorised=True,
        json_to_screen=expl_envs.observation_space.converter,
    )

    eval_path_collector = HierarchicalStepCollector(
        eval_envs,
        eval_policy,
        ptu.device,
        max_num_epoch_paths_saved=variant["algorithm_kwargs"]
        ["num_eval_steps_per_epoch"],
        num_processes=variant["num_processes"],
        render=variant["render"],
        gamma=1,
        no_plan_penalty=variant.get("no_plan_penalty", False),
    )
    expl_path_collector = HierarchicalStepCollector(
        expl_envs,
        expl_policy,
        ptu.device,
        max_num_epoch_paths_saved=variant["num_steps"],
        num_processes=variant["num_processes"],
        render=variant["render"],
        gamma=variant["trainer_kwargs"]["gamma"],
        no_plan_penalty=variant.get("no_plan_penalty", False),
    )
    # added: created rollout(5,1,(4,84,84),Discrete(6),1), reset env and added obs to rollout[step]

    trainer = PPOTrainer(actor_critic=expl_policy.learner,
                         **variant["trainer_kwargs"])
    # missing: by this point, rollout back in sync.
    replay_buffer = EnvReplayBuffer(variant["replay_buffer_size"], expl_envs)
    # added: replay buffer is new
    algorithm = TorchIkostrikovRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_envs,
        evaluation_env=eval_envs,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"],
        # batch_size,
        # max_path_length,
        # num_epochs,
        # num_eval_steps_per_epoch,
        # num_expl_steps_per_train_loop,
        # num_trains_per_train_loop,
        # num_train_loops_per_epoch=1,
        # min_num_steps_before_training=0,
    )

    algorithm.to(ptu.device)
    # missing: device back in sync
    algorithm.evaluate()
Example #6
def experiment(variant):
    common.initialise(variant)

    expl_envs, eval_envs = common.create_environments(variant)

    (
        obs_shape,
        obs_space,
        action_space,
        n,
        mlp,
        channels,
        fc_input,
    ) = common.get_spaces(expl_envs)

    # CHANGE TO FACTORED ACTION SPACE
    action_space = gym.spaces.Box(-np.inf, np.inf, (10, ))
    expl_envs.action_space = action_space
    eval_envs.action_space = action_space

    base = common.create_networks(variant, n, mlp, channels, fc_input)

    bernoulli_dist = distributions.Bernoulli(base.output_size, 4)
    continuous_dist = distributions.DiagGaussian(base.output_size, 6)
    dist = distributions.DistributionGeneratorTuple(
        (bernoulli_dist, continuous_dist))
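    # The factored action is 10-dimensional: 4 Bernoulli flags plus a 6-dim DiagGaussian
    # sample, matching the Box(-inf, inf, (10,)) action space installed above.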

    eval_policy = LearnPlanPolicy(
        WrappedPolicy(
            obs_shape,
            action_space,
            ptu.device,
            base=base,
            deterministic=True,
            dist=dist,
            num_processes=variant["num_processes"],
            obs_space=obs_space,
        ),
        num_processes=variant["num_processes"],
        vectorised=True,
        json_to_screen=expl_envs.observation_space.converter,
    )
    expl_policy = LearnPlanPolicy(
        WrappedPolicy(
            obs_shape,
            action_space,
            ptu.device,
            base=base,
            deterministic=False,
            dist=dist,
            num_processes=variant["num_processes"],
            obs_space=obs_space,
        ),
        num_processes=variant["num_processes"],
        vectorised=True,
        json_to_screen=expl_envs.observation_space.converter,
    )

    eval_path_collector = HierarchicalStepCollector(
        eval_envs,
        eval_policy,
        ptu.device,
        max_num_epoch_paths_saved=variant["algorithm_kwargs"]
        ["num_eval_steps_per_epoch"],
        num_processes=variant["num_processes"],
        render=variant["render"],
    )
    expl_path_collector = HierarchicalStepCollector(
        expl_envs,
        expl_policy,
        ptu.device,
        max_num_epoch_paths_saved=variant["num_steps"],
        num_processes=variant["num_processes"],
        render=variant["render"],
    )
    # added: created rollout(5,1,(4,84,84),Discrete(6),1), reset env and added obs to rollout[step]

    trainer = PPOTrainer(actor_critic=expl_policy.learner,
                         **variant["trainer_kwargs"])
    # missing: by this point, rollout back in sync.
    replay_buffer = EnvReplayBuffer(variant["replay_buffer_size"], expl_envs)
    # added: replay buffer is new
    algorithm = TorchIkostrikovRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_envs,
        evaluation_env=eval_envs,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"],
        # batch_size,
        # max_path_length,
        # num_epochs,
        # num_eval_steps_per_epoch,
        # num_expl_steps_per_train_loop,
        # num_trains_per_train_loop,
        # num_train_loops_per_epoch=1,
        # min_num_steps_before_training=0,
    )

    algorithm.to(ptu.device)
    # missing: device back in sync
    algorithm.train()
Example #7
def experiment(variant):
    common.initialise(variant)

    expl_envs, eval_envs = common.create_environments(variant)

    (
        obs_shape,
        obs_space,
        action_space,
        n,
        mlp,
        channels,
        fc_input,
    ) = common.get_spaces(expl_envs)

    # # CHANGE TO ORDINAL ACTION SPACE
    # action_space = gym.spaces.Box(-np.inf, np.inf, (8,))
    # expl_envs.action_space = action_space
    # eval_envs.action_space = action_space
    ANCILLARY_GOAL_SIZE = variant["ancillary_goal_size"]  # length of the action space for the learner
    SYMBOLIC_ACTION_SIZE = 12
    GRID_SIZE = 31

    base = common.create_networks(variant, n, mlp, channels, fc_input)
    control_base = common.create_networks(
        variant, n, mlp, channels,
        fc_input + SYMBOLIC_ACTION_SIZE)  # for uvfa goal representation

    dist = common.create_symbolic_action_distributions(variant["action_space"],
                                                       base.output_size)

    control_dist = distributions.Categorical(base.output_size, action_space.n)

    eval_learner = WrappedPolicy(
        obs_shape,
        action_space,
        ptu.device,
        base=base,
        deterministic=True,
        dist=dist,
        num_processes=variant["num_processes"],
        obs_space=obs_space,
    )

    planner = ENHSPPlanner()

    # collect

    filepath = "/home/achester/anaconda3/envs/goal-gen/.guild/runs/e77c75eed02e4b38a0a308789fbfcbd8/data/params.pkl"  # collect
    with open(filepath, "rb") as openfile:
        while True:
            try:
                policies = pickle.load(openfile)
            except EOFError:
                break

    loaded_collect_policy = policies["exploration/policy"]
    loaded_collect_policy.rnn_hxs = loaded_collect_policy.rnn_hxs[0].unsqueeze(
        0)

    eval_collect = CraftController(loaded_collect_policy, n=GRID_SIZE)

    expl_collect = CraftController(loaded_collect_policy, n=GRID_SIZE)

    # other
    filepath = "/home/achester/anaconda3/envs/goal-gen/.guild/runs/cf5c31afe0724acd8f6398d77a80443e/data/params.pkl"  # other
    # filepath = "/home/achester/Documents/symbolic-goal-generation/data/params.pkl"
    with open(filepath, "rb") as openfile:
        while True:
            try:
                policies = pickle.load(openfile)
            except EOFError:
                break

    loaded_other_policy = policies["exploration/policy"]
    loaded_other_policy.rnn_hxs = loaded_other_policy.rnn_hxs[0].unsqueeze(0)

    eval_other = CraftController(loaded_other_policy, n=GRID_SIZE)
    expl_other = CraftController(loaded_other_policy, n=GRID_SIZE)

    eval_controller = PretrainedController([eval_collect, eval_other])
    expl_controller = PretrainedController([expl_collect, expl_other])

    function_env = gym.make(variant["env_name"])

    eval_policy = LearnPlanPolicy(
        eval_learner,
        planner,
        eval_controller,
        num_processes=variant["num_processes"],
        vectorised=True,
        env=function_env,
    )

    expl_learner = WrappedPolicy(
        obs_shape,
        action_space,
        ptu.device,
        base=base,
        deterministic=False,
        dist=dist,
        num_processes=variant["num_processes"],
        obs_space=obs_space,
    )

    expl_policy = LearnPlanPolicy(
        expl_learner,
        planner,
        expl_controller,
        num_processes=variant["num_processes"],
        vectorised=True,
        env=function_env,
    )

    eval_path_collector = ThreeTierStepCollector(
        eval_envs,
        eval_policy,
        ptu.device,
        ANCILLARY_GOAL_SIZE,
        SYMBOLIC_ACTION_SIZE,
        max_num_epoch_paths_saved=variant["algorithm_kwargs"]
        ["num_eval_steps_per_epoch"],
        num_processes=variant["num_processes"],
        render=variant["render"],
        gamma=1,
        no_plan_penalty=True,
        meta_num_epoch_paths=variant["meta_num_steps"],
    )
    expl_path_collector = ThreeTierStepCollector(
        expl_envs,
        expl_policy,
        ptu.device,
        ANCILLARY_GOAL_SIZE,
        SYMBOLIC_ACTION_SIZE,
        max_num_epoch_paths_saved=variant["num_steps"],
        num_processes=variant["num_processes"],
        render=variant["render"],
        gamma=variant["trainer_kwargs"]["gamma"],
        no_plan_penalty=variant.get("no_plan_penalty", False),
        meta_num_epoch_paths=variant["meta_num_steps"],
    )

    learn_trainer = PPOTrainer(actor_critic=expl_policy.learner,
                               **variant["trainer_kwargs"])
    control_trainer = DummyTrainer()
    trainer = MultiTrainer([control_trainer, learn_trainer])
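    # DummyTrainer presumably performs no updates, so only the goal-proposing learner is
    # trained here; the pretrained CraftController policies loaded above stay frozen.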

    replay_buffer = EnvReplayBuffer(variant["replay_buffer_size"], expl_envs)

    algorithm = TorchIkostrikovRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_envs,
        evaluation_env=eval_envs,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"],
        # batch_size,
        # max_path_length,
        # num_epochs,
        # num_eval_steps_per_epoch,
        # num_expl_steps_per_train_loop,
        # num_trains_per_train_loop,
        # num_train_loops_per_epoch=1,
        # min_num_steps_before_training=0,
    )

    algorithm.to(ptu.device)

    algorithm.train()