def experiment(
    experiment_config,
    exp_prefix,
    variant,
    gpu_kwargs=None,
    log_to_wandb=False,
):
    """Set up and run one experiment (a single seed) described by ``variant``.

    In order: reset timers, set up logging, set the PyTorch GPU mode, seed
    all RNGs, build the exploration/evaluation environments and replay
    buffer, optionally load teacher transitions, build the
    experiment-specific config and path collectors, then optionally run
    offline pretraining followed by online training.

    Args:
        experiment_config: dict of factory callables. ``'get_config'`` is
            required; ``'load_config'``, ``'get_offline_algorithm'`` and
            ``'get_algorithm'`` are optional.
        exp_prefix: experiment name prefix forwarded to ``setup_logger``.
        variant: dict of hyperparameters. Always read: ``'seed'`` and
            ``'replay_buffer_size'``. Environments come from ``'env_name'``
            (+ optional ``'env_kwargs'``) or, if present, ``'envs_list'``
            (+ ``'switch_every'``). Optional keys include
            ``'teacher_data_files'``, ``'max_teacher_transitions'``,
            ``'algorithm_kwargs'``, ``'offline_kwargs'``,
            ``'collector_type'`` (default ``'step'``),
            ``'do_offline_training'`` (default False) and
            ``'do_online_training'`` (default True).
        gpu_kwargs: keyword arguments for ``ptu.set_gpu_mode``; defaults to
            ``{'mode': False}``.
        log_to_wandb: forwarded to ``setup_logger``.

    Raises:
        AttributeError: if ``variant['envs_list']`` is given but empty.
        NotImplementedError: if ``collector_type`` is not one of ``'step'``,
            ``'batch'``, ``'batch_latent'`` or ``'rf'``.
    """
    # Reset timers (useful if running multiple seeds from the same command).
    gt.reset()
    gt.start()

    # Logging; remember the tabular output so it can be restored after the
    # offline phase redirects it.
    seed = variant['seed']
    setup_logger(exp_prefix,
                 variant=variant,
                 seed=seed,
                 log_to_wandb=log_to_wandb)
    output_csv = logger.get_tabular_output()

    # GPU mode for PyTorch.
    if gpu_kwargs is None:
        gpu_kwargs = {'mode': False}
    ptu.set_gpu_mode(**gpu_kwargs)

    # Seed every RNG in use.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Environment setup.
    envs_list = variant.get('envs_list', None)
    if envs_list is None:
        expl_env, env_infos = make_env(variant['env_name'],
                                       **variant.get('env_kwargs', {}))
    else:
        # TODO: not sure if this is tested
        if len(envs_list) == 0:
            raise AttributeError('length of envs_list is zero')
        switch_every = variant['switch_every']
        expl_envs = []
        for env_params in envs_list:
            expl_env, env_infos = make_env(**env_params)
            expl_envs.append(expl_env)
        # env_infos (used below) is the one returned for the LAST env.
        expl_env = ContinualLifelongEnv(expl_envs[0], switch_every, expl_envs)

    obs_dim = get_dim(expl_env.observation_space)
    action_dim = get_dim(expl_env.action_space)

    if env_infos['mujoco']:
        replay_buffer = MujocoReplayBuffer(variant['replay_buffer_size'],
                                           expl_env)
    else:
        replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'],
                                        expl_env)

    eval_env = FollowerEnv(expl_env)

    # Import any teacher data into the replay buffer.
    if 'teacher_data_files' in variant:
        teacher_kwargs = {}
        if 'max_teacher_transitions' in variant:
            teacher_kwargs['max_transitions'] = (
                variant['max_teacher_transitions'])
        for data_file in variant['teacher_data_files']:
            add_transitions(replay_buffer, data_file, obs_dim, action_dim,
                            **teacher_kwargs)

    # Experiment-specific configuration.
    config = experiment_config['get_config'](
        variant,
        expl_env=expl_env,
        eval_env=eval_env,
        obs_dim=obs_dim,
        action_dim=action_dim,
        replay_buffer=replay_buffer,
    )

    if 'load_config' in experiment_config:
        experiment_config['load_config'](config, variant, gpu_kwargs)

    # Values already set by get_config/load_config take precedence.
    config.setdefault('algorithm_kwargs',
                      variant.get('algorithm_kwargs', dict()))
    config.setdefault('offline_kwargs', variant.get('offline_kwargs', dict()))

    # Path collectors for sampling from the environment.
    collector_type = variant.get('collector_type', 'step')
    exploration_policy = config['exploration_policy']
    if collector_type == 'step':
        expl_path_collector = MdpStepCollector(expl_env, exploration_policy)
    elif collector_type == 'batch':
        expl_path_collector = MdpPathCollector(expl_env, exploration_policy)
    elif collector_type == 'batch_latent':
        expl_path_collector = LatentPathCollector(
            sample_latent_every=None,
            env=expl_env,
            policy=exploration_policy,
        )
    elif collector_type == 'rf':
        expl_path_collector = RFCollector(expl_env, exploration_policy)
    else:
        raise NotImplementedError(
            'collector_type of experiment not recognized')

    # NOTE(review): 'gcr' is rejected by the exploration chain above, so this
    # branch is currently unreachable. Confirm which exploration collector
    # 'gcr' mode should use before relying on it.
    if collector_type == 'gcr':
        eval_path_collector = GoalConditionedReplayStepCollector(
            eval_env,
            config['evaluation_policy'],
            replay_buffer,
            variant['resample_goal_every'],
        )
    else:
        eval_path_collector = MdpPathCollector(
            eval_env,
            config['evaluation_policy'],
        )

    # Finish initialization timer.
    gt.stamp('initialization', unique=False)

    # Offline RL pretraining, logged to a separate CSV.
    if 'get_offline_algorithm' in experiment_config and variant.get(
            'do_offline_training', False):
        logger.set_tabular_output(
            os.path.join(logger.log_dir, 'offline_progress.csv'))
        try:
            offline_algorithm = experiment_config['get_offline_algorithm'](
                config,
                eval_path_collector=eval_path_collector,
            )
            offline_algorithm.to(ptu.device)
            offline_algorithm.train()
        finally:
            # Restore the original output even if offline training fails, so
            # the logger is never left writing to offline_progress.csv.
            logger.set_tabular_output(output_csv)

    # Online training.
    if 'get_algorithm' in experiment_config and variant.get(
            'do_online_training', True):
        algorithm = experiment_config['get_algorithm'](
            config,
            expl_path_collector=expl_path_collector,
            eval_path_collector=eval_path_collector,
        )
        algorithm.to(ptu.device)
        algorithm.train()
    def _async_evaluate_goals(self, goals):
        """Launch asynchronous evaluation of every goal in ``goals``.

        Goals are submitted to ``self.path_collector.async_evaluate`` in
        consecutive chunks of at most ``self.num_workers`` goals. A
        non-positive ``self.num_workers`` raises ValueError (from ``range``)
        instead of looping forever.

        Returns:
            list: the handles returned for every chunk, concatenated in goal
            order (resolved later with ``ray.get``).
        """
        obj_ids = []
        for start in range(0, len(goals), self.num_workers):
            obj_ids.extend(
                self.path_collector.async_evaluate(
                    goals[start:start + self.num_workers]))
        # Exactly one evaluation handle is expected per goal.
        assert len(obj_ids) == len(goals), f'{len(obj_ids)}, {len(goals)}'
        return obj_ids

    def _train(self):
        """Run epochs ``self._start_epoch`` .. ``self.num_epochs - 1``.

        Each epoch ships the current policy parameters to the path collector,
        launches asynchronous evaluation on the train and wd goals, runs
        ``self.num_train_loops_per_epoch`` training steps (prefetching the
        next batch while training on the current one), gathers the
        evaluation results into the ``*_train_*`` / ``*_wd_*`` attributes,
        and finally calls ``self._end_epoch(epoch)``.
        """
        batch_idxes = np.arange(self.num_tasks)

        gt.start()

        for epoch in gt.timed_for(
                trange(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
            # Distribute the evaluation: ship the CPU params of each policy
            # network to the remote path collector.
            self.path_collector.set_policy_params(
                [ptu.state_dict_cpu(net) for net in self.policy.networks])

            # Launched now, collected after training, so both overlap.
            evaluation_train_obj_id_list = self._async_evaluate_goals(
                self.train_goals)
            evaluation_wd_obj_id_list = self._async_evaluate_goals(
                self.wd_goals)
            # Evaluation on self.ood_goals is disabled; to re-enable, use
            # self._async_evaluate_goals(self.ood_goals) and collect below.

            gt.stamp('set_up_evaluation', unique=False)

            train_batch_obj_id = self.train_buffer.sample_training_data(
                batch_idxes)

            for _ in trange(self.num_train_loops_per_epoch):
                train_raw_batch = ray.get(train_batch_obj_id)

                gt.stamp('sample_training_data', unique=False)

                # Start sampling the next batch while training on this one.
                train_batch_obj_id = self.train_buffer.sample_training_data(
                    batch_idxes)

                gt.stamp('set_up_sampling', unique=False)

                train_data = self.construct_training_batch(train_raw_batch)
                gt.stamp('construct_training_batch', unique=False)

                self.policy.train(train_data)
                gt.stamp('training', unique=False)

            # Each evaluation result is indexed as (episode return, final
            # achieved value).
            eval_train_returns = ray.get(evaluation_train_obj_id_list)
            self.avg_train_episode_returns = [
                item[0] for item in eval_train_returns
            ]
            self.final_train_achieved = [
                item[1] for item in eval_train_returns
            ]
            self.train_avg_returns = np.mean(self.avg_train_episode_returns)

            eval_wd_returns = ray.get(evaluation_wd_obj_id_list)
            self.avg_wd_episode_returns = [item[0] for item in eval_wd_returns]
            self.final_wd_achieved = [item[1] for item in eval_wd_returns]
            self.wd_avg_returns = np.mean(self.avg_wd_episode_returns)

            gt.stamp('evaluation', unique=False)

            self._end_epoch(epoch)
        coreset = None  # not needed
        data.x_mean, data.x_std = data.get_statistics(
            data.X[data.index['train']])
    else:
        raise ValueError('Invalid coreset algorithm: {}'.format(args.coreset))

    test_performances = {
        'LL': [],
        'RMSE': [],
        'wt': [0.],
        'wt_batch': [0.],
        'num_samples': []
    }
    test_nll, test_performance = np.zeros(1, ), np.zeros(1, )

    gt.start()
    while len(data.index['train']) < args.init_num_labeled + args.budget:
        print('{}: Number of samples {}/{}'.format(
            args.seed,
            len(data.index['train']) - args.init_num_labeled, args.budget))

        optim_params = {
            'num_epochs': args.training_epochs,
            'batch_size': get_batch_size(args.dataset, data),
            'weight_decay': args.weight_decay,
            'initial_lr': args.initial_lr
        }
        nl = NeuralLinearTB(data, out_features=out_features, **kwargs)
        # nl = NeuralLinear(data, out_features=out_features, **kwargs)
        nl = utils.to_gpu(nl)
        nl.optimize(data, **optim_params)