Пример #1
0
def test_reset_returns_same_obj_and_goal():
    """Check that resetting an MT50 environment is repeatable for a fixed task.

    For every MT50 training environment one task is chosen at random and the
    environment is reset twice. Both resets must yield the same initial object
    position (``obs[3:9]``) and the same goal (``obs[-3:]``).
    """
    benchmark = metaworld.MT50()
    env_dict = benchmark.train_classes
    tasks = benchmark.train_tasks
    initial_obj_poses = {name: [] for name in env_dict}
    goal_poses = {name: [] for name in env_dict}

    # Execute rollout for each environment in benchmark.
    for env_name, env_cls in env_dict.items():
        # Create environment and set a randomly chosen task of that env.
        env = env_cls()
        env_tasks = [t for t in tasks if t.env_name == env_name]
        env.set_task(random.choice(env_tasks))

        # Reset a fixed number of times, recording the initial object
        # position and the goal observed right after each reset.
        for _ in range(2):
            obs = env.reset()
            goal_poses[env_name].append(obs[-3:])
            initial_obj_poses[env_name].append(obs[3:9])

    def _non_unique(poses_by_env):
        """Return names of environments whose recorded poses are not all identical."""
        return [
            env_name
            for env_name, poses in poses_by_env.items()
            if len(np.unique(np.array(poses), axis=0)) > 1
        ]

    violating_envs_obs = _non_unique(initial_obj_poses)
    violating_envs_goals = _non_unique(goal_poses)
    assert not violating_envs_obs, (
        f"Initial object positions differ across resets for: {violating_envs_obs}"
    )
    assert not violating_envs_goals, (
        f"Goals differ across resets for: {violating_envs_goals}"
    )
Пример #2
0
def test_identical_environments():
    """Check that MetaWorld benchmarks are reproducible under a fixed seed.

    Two instances of each benchmark built with the same seed must expose the
    same train-task ``rand_vec`` values, in the same order. Two MT50 instances
    built with different seeds must, for every task, have ``rand_vec`` values
    that differ in at least one component.
    """
    def _rand_vecs(benchmark):
        """Unpickle and return the ``rand_vec`` of each train task, in order."""
        return [pickle.loads(task.data)['rand_vec']
                for task in benchmark.train_tasks]

    def helper(env, env_2):
        """Assert both benchmarks have pairwise-equal train-task rand_vecs."""
        rand_vecs_1, rand_vecs_2 = _rand_vecs(env), _rand_vecs(env_2)
        # Compare lengths first: iterating by env's task count alone would
        # raise IndexError for a shorter env_2 or skip a longer env_2's extras.
        assert len(rand_vecs_1) == len(rand_vecs_2)
        for rand_vec_1, rand_vec_2 in zip(rand_vecs_1, rand_vecs_2):
            np.testing.assert_equal(rand_vec_1, rand_vec_2)

    def helper_neq(env, env_2):
        """Assert no train task has an identical rand_vec in both benchmarks."""
        rand_vecs_1, rand_vecs_2 = _rand_vecs(env), _rand_vecs(env_2)
        assert len(rand_vecs_1) == len(rand_vecs_2)
        for rand_vec_1, rand_vec_2 in zip(rand_vecs_1, rand_vecs_2):
            assert not (rand_vec_1 == rand_vec_2).all()

    # testing MT1
    mt1_1 = metaworld.MT1('sweep-into-v2', seed=10)
    mt1_2 = metaworld.MT1('sweep-into-v2', seed=10)
    helper(mt1_1, mt1_2)

    # testing ML1
    ml1_1 = metaworld.ML1('sweep-into-v2', seed=10)
    ml1_2 = metaworld.ML1('sweep-into-v2', seed=10)
    helper(ml1_1, ml1_2)

    # testing MT10
    mt10_1 = metaworld.MT10(seed=10)
    mt10_2 = metaworld.MT10(seed=10)
    helper(mt10_1, mt10_2)

    # testing ML10
    ml10_1 = metaworld.ML10(seed=10)
    ml10_2 = metaworld.ML10(seed=10)
    helper(ml10_1, ml10_2)

    # testing ML45
    ml45_1 = metaworld.ML45(seed=10)
    ml45_2 = metaworld.ML45(seed=10)
    helper(ml45_1, ml45_2)

    # testing MT50
    mt50_1 = metaworld.MT50(seed=10)
    mt50_2 = metaworld.MT50(seed=10)
    helper(mt50_1, mt50_2)

    # test that 2 benchmarks with different seeds have different goals
    mt50_3 = metaworld.MT50(seed=50)
    helper_neq(mt50_1, mt50_3)
Пример #3
0
    def __init__(
        self,
        experiment_name,
        use_gpu,
        trainer_args
    ):
        """Train MTSAC with metaworld_experiments environment.

        Sets up checkpoint paths, seeds, the MetaWorld multi-task training
        environments, the policy and Q networks, the replay buffer, the Ray
        sampler and the ``CustomMTSAC`` algorithm. If the checkpoint directory
        already exists, the saved experiment state is loaded on top of the
        freshly built objects.

        Args:
            experiment_name: experiment name to be used for logging and
                checkpointing
            use_gpu: boolean, defines whether or not to use the GPU for
                training (fp16 is only enabled when this is True)
            trainer_args: named tuple with args given by config

        Raises:
            AssertionError: if ``trainer_args.num_tasks`` is not a multiple
                of 10 or is larger than 500.
        """
        num_tasks = trainer_args.num_tasks

        # Validate the task count up front, before any directories are
        # created or environments built, so a bad config fails fast.
        # NOTE(review): presumably the task sampler splits ``num_tasks``
        # evenly across the benchmark's environment classes; with MT50
        # (num_tasks > 10) a multiple of 50 may actually be required — confirm.
        assert num_tasks % 10 == 0, "Number of tasks have to divisible by 10"
        assert num_tasks <= 500, "Number of tasks should be less or equal 500"

        # Define log and checkpoint dir
        self.checkpoint_dir = os.path.join(
            trainer_args.log_dir,
            f"{experiment_name}-{trainer_args.project_id}"
        )
        print(f"Checkpoint dir: {self.checkpoint_dir}")
        self.state_path = os.path.join(self.checkpoint_dir, "experiment_state.p")
        self.env_state_path = os.path.join(self.checkpoint_dir, "env_state.p")
        self.config_path = os.path.join(self.checkpoint_dir, "config.json")
        self.experiment_name = experiment_name

        # Only define viz_save_path if required to save visualizations locally
        self.viz_save_path = None
        if trainer_args.save_visualizations_local:
            self.viz_save_path = os.path.join(self.checkpoint_dir, "viz")

        # An already existing checkpoint dir means we resume; this must be
        # checked before makedirs creates the directory.
        self.loading_from_existing = os.path.exists(self.checkpoint_dir)
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # Save arguments for later retrieval
        self.init_config(trainer_args)

        # TODO: do we have to fix which GPU to use? run distributed across multiGPUs
        if use_gpu:
            set_gpu_mode(True, 0)

        if trainer_args.seed is not None:
            deterministic.set_seed(trainer_args.seed)

        # MT10 is used for up to 10 tasks, MT50 otherwise.
        mt_env = (
            metaworld.MT10(seed=trainer_args.env_seed) if num_tasks <= 10
            else metaworld.MT50(seed=trainer_args.env_seed)
        )

        train_task_sampler = MetaWorldTaskSampler(
            mt_env, "train", add_env_onehot=True
        )

        # TODO: do we have guarantees that in case seed is set, the tasks being sampled
        # are the same?
        mt_train_envs = train_task_sampler.sample(num_tasks)
        # One instantiated env is only needed to read the shared env spec.
        env = mt_train_envs[0]()

        if trainer_args.params_seed is not None:
            torch.manual_seed(trainer_args.params_seed)

        policy = create_policy_net(env_spec=env.spec, net_params=trainer_args)
        qf1 = create_qf_net(env_spec=env.spec, net_params=trainer_args)
        qf2 = create_qf_net(env_spec=env.spec, net_params=trainer_args)

        if trainer_args.params_seed is not None:
            calculate_mean_param("policy", policy)
            calculate_mean_param("qf1", qf1)
            calculate_mean_param("qf2", qf2)

        if trainer_args.override_weight_initialization:
            # logging.warning replaces the deprecated logging.warn alias.
            logging.warning("Overriding dendritic layer weight initialization")
            self.override_weight_initialization([policy, qf1, qf2])

        replay_buffer = PathBuffer(
            capacity_in_transitions=trainer_args.num_buffer_transitions
        )
        max_episode_length = env.spec.max_episode_length
        # One epoch collects one episode per task.
        self.env_steps_per_epoch = int(max_episode_length * num_tasks)
        self.num_epochs = trainer_args.timesteps // self.env_steps_per_epoch

        sampler = RaySampler(
            agent=policy,
            envs=mt_train_envs,
            max_episode_length=max_episode_length,
            cpus_per_worker=trainer_args.cpus_per_worker,
            gpus_per_worker=trainer_args.gpus_per_worker,
            workers_per_env=trainer_args.workers_per_env,
            seed=trainer_args.seed,
        )

        self._algo = CustomMTSAC(
            env_spec=env.spec,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            replay_buffer=replay_buffer,
            sampler=sampler,
            train_task_sampler=train_task_sampler,
            gradient_steps_per_itr=int(
                max_episode_length * trainer_args.num_grad_steps_scale
            ),
            task_update_frequency=trainer_args.task_update_frequency,
            num_tasks=num_tasks,
            min_buffer_size=max_episode_length * num_tasks,
            target_update_tau=trainer_args.target_update_tau,
            discount=trainer_args.discount,
            buffer_batch_size=trainer_args.buffer_batch_size,
            policy_lr=trainer_args.policy_lr,
            qf_lr=trainer_args.qf_lr,
            reward_scale=trainer_args.reward_scale,
            num_evaluation_episodes=trainer_args.eval_episodes,
            fp16=trainer_args.fp16 if use_gpu else False,
            log_per_task=trainer_args.log_per_task,
            share_train_eval_env=trainer_args.share_train_eval_env
        )

        # Override with loaded networks if existing experiment
        self.current_epoch = 0
        if self.loading_from_existing:
            self.load_experiment_state()

        # Move all networks within the model on device
        self._algo.to()
Пример #4
0
def te_ppo_mt50(ctxt, seed, n_epochs, batch_size_per_task, n_tasks):
    """Train Task Embedding PPO with the MetaWorld MT50 benchmark.

    Samples ``n_tasks`` normalized MT50 training environments, wraps them in a
    round-robin ``MultiEnvWrapper`` and trains ``TEPPO`` with a Gaussian MLP
    task encoder, a Gaussian MLP trajectory-inference network and a task
    embedding policy.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        n_epochs (int): Total number of epochs for training.
        batch_size_per_task (int): Batch size of samples for each task.
        n_tasks (int): Number of tasks to use. Should be a multiple of 50.

    Raises:
        AssertionError: if ``n_tasks`` is not a multiple of 50 or exceeds 2500.

    """
    # Validate before building the benchmark so bad arguments fail fast.
    assert n_tasks % 50 == 0
    assert n_tasks <= 2500

    set_seed(seed)
    mt50 = metaworld.MT50()
    task_sampler = MetaWorldTaskSampler(mt50,
                                        'train',
                                        lambda env, _: normalize(env),
                                        add_env_onehot=False)
    envs = [env_up() for env_up in task_sampler.sample(n_tasks)]
    env = MultiEnvWrapper(envs,
                          sample_strategy=round_robin_strategy,
                          mode='vanilla')

    latent_length = 6
    inference_window = 6
    # Total samples per epoch across all tasks.
    batch_size = batch_size_per_task * n_tasks
    policy_ent_coeff = 2e-2
    encoder_ent_coeff = 2e-4
    inference_ce_coeff = 5e-2
    embedding_init_std = 0.1
    embedding_max_std = 0.2
    embedding_min_std = 1e-6
    policy_init_std = 1.0
    policy_max_std = None
    policy_min_std = None

    with TFTrainer(snapshot_config=ctxt) as trainer:

        task_embed_spec = TEPPO.get_encoder_spec(env.task_space,
                                                 latent_dim=latent_length)

        task_encoder = GaussianMLPEncoder(
            name='embedding',
            embedding_spec=task_embed_spec,
            hidden_sizes=(20, 20),
            std_share_network=True,
            init_std=embedding_init_std,
            max_std=embedding_max_std,
            output_nonlinearity=tf.nn.tanh,
            min_std=embedding_min_std,
        )

        traj_embed_spec = TEPPO.get_infer_spec(
            env.spec,
            latent_dim=latent_length,
            inference_window_size=inference_window)

        inference = GaussianMLPEncoder(
            name='inference',
            embedding_spec=traj_embed_spec,
            hidden_sizes=(20, 10),
            std_share_network=True,
            init_std=2.0,
            output_nonlinearity=tf.nn.tanh,
            min_std=embedding_min_std,
        )

        policy = GaussianMLPTaskEmbeddingPolicy(
            name='policy',
            env_spec=env.spec,
            encoder=task_encoder,
            hidden_sizes=(32, 16),
            std_share_network=True,
            max_std=policy_max_std,
            init_std=policy_init_std,
            min_std=policy_min_std,
        )

        baseline = LinearMultiFeatureBaseline(
            env_spec=env.spec, features=['observations', 'tasks', 'latents'])

        algo = TEPPO(env_spec=env.spec,
                     policy=policy,
                     baseline=baseline,
                     inference=inference,
                     discount=0.99,
                     lr_clip_range=0.2,
                     policy_ent_coeff=policy_ent_coeff,
                     encoder_ent_coeff=encoder_ent_coeff,
                     inference_ce_coeff=inference_ce_coeff,
                     use_softplus_entropy=True,
                     optimizer_args=dict(
                         batch_size=32,
                         max_optimization_epochs=10,
                         learning_rate=1e-3,
                     ),
                     inference_optimizer_args=dict(
                         batch_size=32,
                         max_optimization_epochs=10,
                     ),
                     center_adv=True,
                     stop_ce_gradient=True)

        trainer.setup(algo,
                      env,
                      sampler_cls=LocalSampler,
                      sampler_args=None,
                      worker_class=TaskEmbeddingWorker)
        trainer.train(n_epochs=n_epochs, batch_size=batch_size, plot=False)
def mtsac_metaworld_mt50(
    ctxt=None, *, config_pth, seed, timesteps, use_wandb, wandb_project_name, gpu
):
    """Train MTSAC with the MetaWorld MT50 environment.

    Loads hyperparameters from ``config_pth`` and saves them to
    ``<ctxt.snapshot_dir>/params.json``. Optionally starts a wandb run. Then
    builds the MT50 training environments, networks, samplers and
    ``CustomMTSAC``, and trains for
    ``timesteps // (max_episode_length * 50)`` epochs.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        config_pth (str): Path of the experiment config read by ``get_params``.
        seed (int): Used to seed the random number generator to produce
            determinism.
        timesteps (int): Number of timesteps to run.
        use_wandb (bool or str): Log to wandb when ``True`` or ``"True"``.
        wandb_project_name (str): wandb project the run is logged to.
        gpu (int or None): The ID of the gpu to be used (used on multi-gpu
            machines). When ``None``, ``set_gpu_mode`` is not called.
    """
    # NOTE(review): ``t0`` is not defined in this function; presumably it is
    # a module-level start timestamp — confirm it exists at call time.
    print(f"Initiation took {time() - t0:.2f} secs")

    # Get experiment parameters (e.g. hyperparameters) and save the json file
    params = get_params(config_pth)

    with open(ctxt.snapshot_dir + "/params.json", "w") as json_file:
        json.dump(params, json_file)

    # Accept both a real bool and the string "True" (e.g. from a CLI flag);
    # comparing a bool against "True" directly would always be False.
    use_wandb = str(use_wandb) == "True"
    if use_wandb:
        wandb.init(
            name=params["experiment_name"],
            project=wandb_project_name,
            group="Baselines{}".format("mt50"),
            reinit=True,
            config=params,
        )

    num_tasks = 50
    deterministic.set_seed(seed)
    trainer = CustomTrainer(ctxt)
    mt50 = metaworld.MT50()

    train_task_sampler = MetaWorldTaskSampler(mt50, "train", add_env_onehot=True)

    mt50_train_envs = train_task_sampler.sample(num_tasks)
    # One instantiated env is only needed to read the env spec.
    env = mt50_train_envs[0]()

    # The config stores log-stds; the network factories expect plain stds.
    params["net"]["policy_min_std"] = np.exp(params["net"]["policy_min_log_std"])
    params["net"]["policy_max_std"] = np.exp(params["net"]["policy_max_log_std"])

    policy = create_policy_net(env_spec=env.spec, net_params=params["net"])
    qf1 = create_qf_net(env_spec=env.spec, net_params=params["net"])
    qf2 = create_qf_net(env_spec=env.spec, net_params=params["net"])

    replay_buffer = PathBuffer(
        capacity_in_transitions=int(params["general_setting"]["num_buffer_transitions"])
    )
    max_episode_length = env.spec.max_episode_length
    # Note: are the episode length the same among all tasks?

    sampler = RaySampler(
        agents=policy,
        envs=mt50_train_envs,
        max_episode_length=max_episode_length,
        # 1 sampler worker for each environment
        n_workers=num_tasks,
        worker_class=DefaultWorker
    )

    test_sampler = RaySampler(
        agents=policy,
        envs=mt50_train_envs,
        max_episode_length=max_episode_length,
        # 1 sampler worker for each environment
        n_workers=num_tasks,
        worker_class=EvalWorker
    )

    # Number of transitions before a set of gradient updates
    steps_between_updates = int(max_episode_length * num_tasks)

    # epoch: 1 cycle of data collection + gradient updates
    epochs = timesteps // steps_between_updates

    mtsac = CustomMTSAC(
        env_spec=env.spec,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        replay_buffer=replay_buffer,
        sampler=sampler,
        train_task_sampler=train_task_sampler,
        test_sampler=test_sampler,
        gradient_steps_per_itr=int(max_episode_length * params["training"]["num_grad_steps_scale"]),
        num_tasks=num_tasks,
        min_buffer_size=max_episode_length * num_tasks,
        target_update_tau=params["training"]["target_update_tau"],
        discount=params["general_setting"]["discount"],
        buffer_batch_size=params["training"]["buffer_batch_size"],
        policy_lr=params["training"]["policy_lr"],
        qf_lr=params["training"]["qf_lr"],
        reward_scale=params["training"]["reward_scale"],
        num_evaluation_episodes=params["general_setting"]["eval_episodes"],
        task_update_frequency=params["training"]["task_update_frequency"],
        wandb_logging=use_wandb,
        evaluation_frequency=params["general_setting"]["evaluation_frequency"]
    )

    if gpu is not None:
        set_gpu_mode(True, gpu)

    mtsac.to()
    trainer.setup(algo=mtsac, env=mt50_train_envs)
    trainer.train(n_epochs=epochs, batch_size=steps_between_updates)
Пример #6
0
def mtsac_metaworld_mt50(ctxt=None,
                         *,
                         seed,
                         use_gpu,
                         _gpu,
                         n_tasks,
                         timesteps):
    """Train MTSAC with MT50 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.
        use_gpu (bool): Used to enable usage of GPU in training.
        _gpu (int): The ID of the gpu (used on multi-gpu machines).
        n_tasks (int): Number of tasks to use. Should be a multiple of 50.
        timesteps (int): Number of timesteps to run.

    Raises:
        AssertionError: if ``n_tasks`` is not a multiple of 50 or exceeds 2500.

    """
    # Validate before building benchmarks so bad arguments fail fast.
    assert n_tasks % 50 == 0
    assert n_tasks <= 2500

    deterministic.set_seed(seed)
    trainer = Trainer(ctxt)
    mt50 = metaworld.MT50()  # pylint: disable=no-member
    mt50_test = metaworld.MT50()  # pylint: disable=no-member
    # Training envs normalize rewards; evaluation envs do not.
    train_task_sampler = MetaWorldTaskSampler(
        mt50,
        'train',
        lambda env, _: normalize(env, normalize_reward=True),
        add_env_onehot=True)
    test_task_sampler = MetaWorldTaskSampler(mt50_test,
                                             'train',
                                             lambda env, _: normalize(env),
                                             add_env_onehot=True)
    mt50_train_envs = train_task_sampler.sample(n_tasks)
    env = mt50_train_envs[0]()
    mt50_test_envs = [env_up() for env_up in test_task_sampler.sample(n_tasks)]

    policy = TanhGaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=[400, 400, 400],
        hidden_nonlinearity=nn.ReLU,
        output_nonlinearity=None,
        min_std=np.exp(-20.),
        max_std=np.exp(2.),
    )

    qf1 = ContinuousMLPQFunction(env_spec=env.spec,
                                 hidden_sizes=[400, 400, 400],
                                 hidden_nonlinearity=F.relu)

    qf2 = ContinuousMLPQFunction(env_spec=env.spec,
                                 hidden_sizes=[400, 400, 400],
                                 hidden_nonlinearity=F.relu)

    replay_buffer = PathBuffer(capacity_in_transitions=int(1e6), )

    sampler = LocalSampler(
        agents=policy,
        envs=mt50_train_envs,
        max_episode_length=env.spec.max_episode_length,
        # 1 sampler worker for each environment
        n_workers=50,
        worker_class=FragmentWorker,
        # increasing n_envs increases the vectorization of a sampler worker
        # which improves runtime performance, but you will need to adjust this
        # depending on your memory constraints. For reference, each worker by
        # default uses n_envs=8. Each environment is approximately ~50mb large
        # so creating 50 envs with 8 copies comes out to 20gb of memory. Many
        # users want to be able to run multiple seeds on 1 machine, so I have
        # reduced this to n_envs = 2 for 2 copies in the meantime.
        worker_args=dict(n_envs=2))

    batch_size = int(env.spec.max_episode_length * n_tasks)
    num_evaluation_points = 500
    epochs = timesteps // batch_size
    # Group the sampling/update cycles into roughly ``num_evaluation_points``
    # trainer epochs of ``epoch_cycles`` cycles each. Clamp to 1: with fewer
    # than ``num_evaluation_points`` cycles the integer division yields 0
    # and the next line would raise ZeroDivisionError.
    epoch_cycles = max(1, epochs // num_evaluation_points)
    epochs = epochs // epoch_cycles
    mtsac = MTSAC(policy=policy,
                  qf1=qf1,
                  qf2=qf2,
                  sampler=sampler,
                  gradient_steps_per_itr=env.spec.max_episode_length,
                  eval_env=mt50_test_envs,
                  env_spec=env.spec,
                  num_tasks=50,
                  steps_per_epoch=epoch_cycles,
                  replay_buffer=replay_buffer,
                  min_buffer_size=7500,
                  target_update_tau=5e-3,
                  discount=0.99,
                  buffer_batch_size=6400)
    set_gpu_mode(use_gpu, _gpu)
    mtsac.to()
    trainer.setup(algo=mtsac, env=mt50_train_envs)

    trainer.train(n_epochs=epochs, batch_size=batch_size)
Пример #7
0
    def __init__(
        self,
        benchmark_name: str,
        save_memory: bool = False,
        add_observability: bool = False,
    ) -> None:
        """Init function for environment wrapper.

        Builds the MetaWorld benchmark selected by ``benchmark_name`` and
        groups its tasks per environment. For each environment it stores
        either the environment class (``save_memory=True``) or an instance
        of it, then samples an initial environment.

        Args:
            benchmark_name: One of ``"MT1_<env>"``, ``"MT10"``, ``"MT50"``,
                ``"ML1_train_<env>"``, ``"ML10_train"``, ``"ML45_train"``,
                ``"ML1_test_<env>"``, ``"ML10_test"`` or ``"ML45_test"``.
            save_memory: If True, store environment classes instead of
                environment instances in ``self.envs_info``.
            add_observability: If True, every task is rebuilt with its
                ``"partially_observable"`` entry set to False.

        Raises:
            NotImplementedError: If ``benchmark_name`` is not supported.
        """

        # We import here so that we avoid importing metaworld if possible, since it is
        # dependent on mujoco.
        import metaworld
        from metaworld import Task

        # Set config for each benchmark.
        # resample_tasks: keep every task per env (True) or only the first (False).
        # augment_obs: True for the multi-environment benchmarks.
        if benchmark_name.startswith("MT1_"):
            env_name = benchmark_name[4:]
            benchmark = metaworld.MT1(env_name)
            env_dict = {env_name: benchmark.train_classes[env_name]}
            tasks = benchmark.train_tasks
            resample_tasks = False
            self.augment_obs = False

        elif benchmark_name == "MT10":
            benchmark = metaworld.MT10()
            env_dict = benchmark.train_classes
            tasks = benchmark.train_tasks
            resample_tasks = False
            self.augment_obs = True

        elif benchmark_name == "MT50":
            benchmark = metaworld.MT50()
            env_dict = benchmark.train_classes
            tasks = benchmark.train_tasks
            resample_tasks = False
            self.augment_obs = True

        elif benchmark_name.startswith("ML1_train_"):
            env_name = benchmark_name[10:]
            benchmark = metaworld.ML1(env_name)
            env_dict = {env_name: benchmark.train_classes[env_name]}
            tasks = benchmark.train_tasks
            resample_tasks = True
            self.augment_obs = False

        elif benchmark_name == "ML10_train":
            benchmark = metaworld.ML10()
            env_dict = benchmark.train_classes
            tasks = benchmark.train_tasks
            resample_tasks = True
            self.augment_obs = True

        elif benchmark_name == "ML45_train":
            benchmark = metaworld.ML45()
            env_dict = benchmark.train_classes
            tasks = benchmark.train_tasks
            resample_tasks = True
            self.augment_obs = True

        elif benchmark_name.startswith("ML1_test_"):
            env_name = benchmark_name[9:]
            benchmark = metaworld.ML1(env_name)
            env_dict = {env_name: benchmark.test_classes[env_name]}
            tasks = benchmark.test_tasks
            resample_tasks = True
            self.augment_obs = False

        elif benchmark_name == "ML10_test":
            benchmark = metaworld.ML10()
            env_dict = benchmark.test_classes
            tasks = benchmark.test_tasks
            resample_tasks = True
            self.augment_obs = True

        elif benchmark_name == "ML45_test":
            benchmark = metaworld.ML45()
            env_dict = benchmark.test_classes
            tasks = benchmark.test_tasks
            resample_tasks = True
            self.augment_obs = True

        else:
            raise NotImplementedError(
                f"Unsupported benchmark_name: {benchmark_name!r}"
            )

        # Construct list of tasks for each environment, adding observability to tasks if
        # necessary.
        env_tasks = {}
        for task in tasks:
            if add_observability:
                task_data = dict(pickle.loads(task.data))
                task_data["partially_observable"] = False
                task = Task(env_name=task.env_name,
                            data=pickle.dumps(task_data))

            if task.env_name in env_tasks:
                # Without resampling only the first task seen per env is kept.
                if resample_tasks:
                    env_tasks[task.env_name].append(task)
            else:
                env_tasks[task.env_name] = [task]

        # Construct list of environment classes or class instances.
        self.save_memory = save_memory
        if self.save_memory:
            self.envs_info = [{
                "env_name": env_name,
                "env_cls": env_cls,
                "tasks": env_tasks[env_name]
            } for (env_name, env_cls) in env_dict.items()]
        else:
            self.envs_info = [{
                "env_name": env_name,
                "env": env_cls(),
                "tasks": env_tasks[env_name]
            } for (env_name, env_cls) in env_dict.items()]

        # Number of distinct environments (not individual task variants).
        self.num_tasks = len(self.envs_info)

        # Sample environment.
        self._sample_environment()
Пример #8
0
def mtsac_metaworld_mt10(ctxt=None,
                         *,
                         experiment_name,
                         config_pth,
                         seed,
                         use_wandb,
                         gpu,
                         timesteps=15000000):
    """Train MTSAC with MT10 environment.

    Loads hyperparameters from ``config_pth`` and saves them to
    ``<ctxt.snapshot_dir>/params.json``. Optionally starts a wandb run. Then
    builds MT10 (or MT50 when ``num_tasks > 10``) training environments,
    networks, a Ray sampler and ``CustomMTSAC``, and trains.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by Trainer to create the snapshotter.
        experiment_name (str): Run name used for wandb.
        config_pth (str): Path of the experiment config read by ``get_params``.
        seed (int): Used to seed the random number generator to produce
            determinism.
        use_wandb (bool or str): Log to wandb when ``True`` or ``"True"``.
        gpu: Truthy to train on CUDA, falsy for CPU.
            NOTE(review): a GPU id of 0 is falsy and selects CPU — confirm
            callers pass a bool.
        timesteps (int): Number of timesteps to run (default 15,000,000).

    Raises:
        AssertionError: if ``num_tasks`` from the config is not a multiple of
            10 or exceeds 500.
        ValueError: if ``timesteps`` is smaller than one batch of
            ``max_episode_length * num_tasks`` transitions.
    """
    # NOTE(review): ``t0`` is not defined in this function; presumably it is
    # a module-level start timestamp — confirm it exists at call time.
    print(f"Initiation took {time()-t0:.2f} secs")

    device = torch.device("cuda") if gpu else torch.device("cpu")
    print(f"Using GPU: {gpu}, Device: {device}")

    # maybe overring other things - this is required, why?
    set_gpu_mode(bool(gpu))

    # Get experiment parameters (e.g. hyperparameters) and save the json file
    params = get_params(config_pth)

    with open(ctxt.snapshot_dir + "/params.json", "w") as json_file:
        json.dump(params, json_file)

    # Accept both a real bool and the string "True" (e.g. from a CLI flag);
    # comparing a bool against "True" directly would always be False.
    use_wandb = str(use_wandb) == "True"
    if use_wandb:
        wandb.init(
            name=experiment_name,
            project="mt10_debug",
            group="Baselines{}".format("mt10"),
            reinit=True,
            config=params,
        )

    num_tasks = params["net"]["num_tasks"]
    # Validate before building the benchmark so a bad config fails fast.
    assert num_tasks % 10 == 0, "Number of tasks have to divisible by 10"
    assert num_tasks <= 500, "Number of tasks should be less or equal 500"

    deterministic.set_seed(seed)
    trainer = Trainer(ctxt)

    # MT10 is used for up to 10 tasks, MT50 otherwise.
    if num_tasks <= 10:
        mt_env = metaworld.MT10()
    else:
        mt_env = metaworld.MT50()

    train_task_sampler = MetaWorldTaskSampler(mt_env,
                                              "train",
                                              add_env_onehot=True)

    mt_train_envs = train_task_sampler.sample(num_tasks)
    # One instantiated env is only needed to read the env spec.
    env = mt_train_envs[0]()

    # The config stores log-stds; the network factories expect plain stds.
    params["net"]["policy_min_std"] = np.exp(
        params["net"]["policy_min_log_std"])
    params["net"]["policy_max_std"] = np.exp(
        params["net"]["policy_max_log_std"])

    policy = create_policy_net(env_spec=env.spec, net_params=params["net"])
    print("Created policy")

    qf1 = create_qf_net(env_spec=env.spec, net_params=params["net"])
    qf2 = create_qf_net(env_spec=env.spec, net_params=params["net"])
    print("Created value functions")

    replay_buffer = PathBuffer(capacity_in_transitions=int(
        params["general_setting"]["num_buffer_transitions"]))
    max_episode_length = env.spec.max_episode_length
    # Note: are the episode length the same among all tasks?

    sampler = RaySampler(
        agents=policy,
        envs=mt_train_envs,
        max_episode_length=max_episode_length,
        cpus_per_worker=params["sampler"]["cpus_per_worker"],
        gpus_per_worker=params["sampler"]["gpus_per_worker"],
        seed=None,  # set to get_seed() to make it deterministic
    )

    # The training sampler is reused for evaluation. A dedicated sampler
    # would differ only in its worker class (EvalWorker uses the mean action,
    # agent_info["mean"]); a unified worker could cover both cases.
    test_sampler = sampler

    # Number of transitions before a set of gradient updates
    # Note: should we use avg episode length, if they are not same for all tasks?
    batch_size = int(max_episode_length * num_tasks)

    # One trainer epoch == one batch of samples + gradient updates.
    epochs = timesteps // batch_size
    if epochs < 1:
        raise ValueError(
            f"timesteps={timesteps} is smaller than one batch of "
            f"{batch_size} transitions"
        )
    epoch_cycles = 1

    mtsac = CustomMTSAC(
        env_spec=env.spec,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        replay_buffer=replay_buffer,
        sampler=sampler,
        train_task_sampler=train_task_sampler,
        test_sampler=test_sampler,
        gradient_steps_per_itr=1,
        num_tasks=num_tasks,
        steps_per_epoch=epoch_cycles,
        min_buffer_size=max_episode_length * num_tasks,
        target_update_tau=params["training"]["target_update_tau"],
        discount=params["general_setting"]["discount"],
        buffer_batch_size=params["training"]["buffer_batch_size"],
        policy_lr=params["training"]["policy_lr"],
        qf_lr=params["training"]["qf_lr"],
        reward_scale=params["training"]["reward_scale"],
        num_evaluation_episodes=params["general_setting"]["eval_episodes"],
        task_update_frequency=params["training"]["task_update_frequency"],
        wandb_logging=use_wandb,
        evaluation_frequency=params["general_setting"]["evaluation_frequency"],
    )

    print("Created algo")

    mtsac.to(device=device)
    print("Moved networks to device")

    trainer.setup(algo=mtsac, env=mt_train_envs)
    print("Setup trainer")

    trainer.train(n_epochs=epochs, batch_size=batch_size)