    def update_plots(self):
        all_model_names = list(self.ll_models.keys())
        speeds = update_speed_plots(self.training_times_it, all_model_names,
                                    self.main_viz)
        accs = update_avg_acc(self.all_perfs, all_model_names, self.main_viz,
                              'Average Accuracies')
        if self.norm_models:
            update_avg_acc(self.all_perfs_normalized, all_model_names,
                           self.main_viz, 'Normalized Average Accuracies')
        self.summary.update(speed=speeds.tolist(), accuracy=accs.tolist())
        update_summary(self.summary, self.main_viz)
        plot_heatmaps(all_model_names, self.all_perfs, self.main_viz)

        plot_speeds(self.training_times_it, all_model_names, self.main_viz)
        plot_accs(self.all_perfs, all_model_names, self.main_viz,
                  'Learning Accuracies')
        if self.norm_models:
            plot_accs(self.all_perfs_normalized, all_model_names, self.main_viz,
                      'Normalized Learning Accuracies')
        plot_times(self.training_times_s, all_model_names, self.main_viz)
        if isinstance(self.task_gen.strat, MoreDataStrategy):
            plot_accs_data(self.all_perfs, all_model_names,
                           self.task_gen.strat.n_samples_per_task_per_class,
                           self.main_viz)
        plot_speed_vs_tp(self.training_times_it, self.ideal_potentials, 'Ideal',
                         all_model_names, self.main_viz)
        plot_speed_vs_tp(self.training_times_it, self.current_potentials,
                         'Current',
                         all_model_names, self.main_viz)
        self.save_traces()
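
The plotting helpers above consume one container per lifelong-learning model, each growing by one entry per finished task. A minimal sketch of the assumed shapes; the model names and numbers are illustrative placeholders, not values from the source:

# Hypothetical contents after two models have each finished two tasks.
training_times_it = [[350, 410],     # model 0: iterations needed per task
                     [500, 620]]     # model 1
training_times_s = [[12.4, 15.1],    # same structure, wall-clock seconds
                    [18.0, 22.7]]
all_perfs = [                        # per model: rows of per-task accuracies
    [[0.91, 0.40], [0.89, 0.87]],    # model 0: after task 0, after task 1
    [[0.88, 0.38], [0.85, 0.84]],    # model 1
]
summary = {'model': ['model-a', 'model-b'], 'speed': [], 'accuracy': []}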
Example #2
    def train_model_on_task(self, task, task_viz, exp_dir, use_ray,
                            use_ray_logging, grace_period,
                            num_hp_samplings, local_mode,
                            redis_address, lca_n, **training_params):
        logger.info("Training dashboard: {}".format(get_env_url(task_viz)))
        t_id = task['id']

        trainable = self.get_trainable(use_ray_logging=use_ray_logging)
        past_tasks = training_params.pop('past_tasks')
        normalize = training_params.pop('normalize')
        augment_data = training_params.pop('augment_data')

        transformations = []
        if augment_data:
            transformations.extend([
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor()
            ])
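        # Standard CIFAR-style augmentation: random horizontal flip plus a
        # 32x32 crop with 4 pixels of padding. It is only attached to the
        # first split below (t_trans[0]), presumably the training split.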
        t_trans = [[] for _ in range(len(task['split_names']))]
        t_trans[0] = transformations
        datasets = trainable._load_datasets(task,
                                            task['loss_fn'],
                                            past_tasks, t_trans, normalize)
        train_loader, eval_loaders = get_classic_dataloaders(datasets,
                                                             training_params.pop(
                                                                 'batch_sizes'))
        model = self.get_model(task_id=t_id, x_dim=task['x_dim'],
                               n_classes=task['n_classes'],
                               descriptor=task['descriptor'],
                               dataset=eval_loaders[:2])

        if use_ray:
            if not ray.is_initialized():
                ray.init(address=redis_address)

            scheduler = None

            training_params['loss_fn'] = tune.function(
                training_params['loss_fn'])
            training_params['optim_func'] = tune.function(self.optim_func)

            init_model_path = os.path.join(exp_dir, 'model_initializations')
            model_file_name = '{}_init.pth'.format(training_params['name'])
            model_path = os.path.join(init_model_path, model_file_name)
            torch.save(model, model_path)

            training_params['model_path'] = model_path
            config = {**self.get_search_space(),
                      'training-params': training_params}
            if use_ray_logging:
                stop_condition = {'training_iteration':
                                      training_params['n_it_max']}
                checkpoint_at_end = False
                keep_checkpoints_num = 1
                checkpoint_score_attr = 'min-Val nll'
            else:
                stop_condition = None
                # loggers = [JsonLogger, MyCSVLogger]
                checkpoint_at_end = False
                keep_checkpoints_num = None
                checkpoint_score_attr = None

            trainable = rename_class(trainable, training_params['name'])
            experiment = Experiment(
                name=training_params['name'],
                run=trainable,
                stop=stop_condition,
                config=config,
                resources_per_trial=self.ray_resources,
                num_samples=num_hp_samplings,
                local_dir=exp_dir,
                loggers=(JsonLogger, CSVLogger),
                checkpoint_at_end=checkpoint_at_end,
                keep_checkpoints_num=keep_checkpoints_num,
                checkpoint_score_attr=checkpoint_score_attr)

            analysis = tune.run(experiment,
                                scheduler=scheduler,
                                verbose=1,
                                raise_on_failed_trial=True,
                                # max_failures=-1,
                                # with_server=True,
                                # server_port=4321
                                )
            os.remove(model_path)
            logger.info("Training dashboard: {}".format(get_env_url(task_viz)))

            all_trials = {t.logdir: t for t in analysis.trials}
            best_logdir = analysis.get_best_logdir('Val nll', 'min')
            best_trial = all_trials[best_logdir]

            # picked_metric = 'accuracy_0'
            # metric_names = {s: '{} {}'.format(s, picked_metric) for s in
            #                 ['Train', 'Val', 'Test']}

            logger.info('Best trial: {}'.format(best_trial))
            best_res = best_trial.checkpoint.result
            best_point = (best_res['training_iteration'], best_res['Val nll'])

            # y_keys = ['mean_loss' if use_ray_logging else 'Val nll', 'train_loss']
            y_keys = ['Val nll', 'Train nll']

            epoch_key = 'training_epoch'
            it_key = 'training_iteration'
            plot_res_dataframe(analysis, training_params['name'], best_point,
                               task_viz, epoch_key, it_key, y_keys)
            if 'entropy' in next(iter(analysis.trial_dataframes.values())):
                plot_res_dataframe(analysis, training_params['name'], None,
                                    task_viz, epoch_key, it_key, ['entropy'])
            best_model = self.get_model(task_id=t_id)
            best_model.load_state_dict(torch.load(best_trial.checkpoint.value))

            train_accs = analysis.trial_dataframes[best_logdir]['Train accuracy_0']
            best_t = best_res['training_iteration']
            t = best_trial.last_result['training_iteration']
        else:
            search_space = self.get_search_space()
            rand_config = list(generate_variants(search_space))[0][1]
            learner_params = rand_config.pop('learner-params', {})
            optim_params = rand_config.pop('optim')

            split_optims = training_params.pop('split_optims')
            if hasattr(model, 'set_h_params'):
                model.set_h_params(**learner_params)
            if hasattr(model, 'train_loader_wrapper'):
                train_loader = model.train_loader_wrapper(train_loader)

            loss_fn = task['loss_fn']
            if hasattr(model, 'loss_wrapper'):
                loss_fn = model.loss_wrapper(task['loss_fn'])

            prepare_batch = _prepare_batch
            if hasattr(model, 'prepare_batch_wrapper'):
                prepare_batch = model.prepare_batch_wrapper(prepare_batch, t_id)

            optim_fact = partial(set_optim_params,
                                 optim_func=self.optim_func,
                                 optim_params=optim_params,
                                 split_optims=split_optims)
            if hasattr(model, 'train_func'):
                f = model.train_func
                t, metrics, b_state_dict = f(train_loader=train_loader,
                                                eval_loaders=eval_loaders,
                                                optim_fact=optim_fact,
                                                loss_fn=loss_fn,
                                                split_names=task['split_names'],
                                                viz=task_viz,
                                                prepare_batch=prepare_batch,
                                                **training_params)
            else:
                optim = optim_fact(model=model)
                t, metrics, b_state_dict = train(model=model,
                                                 train_loader=train_loader,
                                                 eval_loaders=eval_loaders,
                                                 optimizer=optim,
                                                 loss_fn=loss_fn,
                                                 split_names=task['split_names'],
                                                 viz=task_viz,
                                                 prepare_batch=prepare_batch,
                                                 **training_params)
            train_accs = metrics['Train accuracy_0']
            best_t = b_state_dict['iter']
            if 'training_archs' in metrics:
                plot_trajectory(model.ssn.graph, metrics['training_archs'],
                                model.ssn.stochastic_node_ids, task_viz)
                weights = model.arch_sampler().squeeze()
                archs = model.ssn.get_top_archs(weights, 5)
                list_top_archs(archs, task_viz)
                list_arch_scores(self.arch_scores[t_id], task_viz)
                update_summary(self.arch_scores[t_id], task_viz, 'scores')

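        # LCA@n: mean of the training accuracy over the first lca_n + 1 logged
        # steps, a measure of how quickly the model learns the new task. Steps
        # missing from train_accs are skipped with a warning, so the mean may
        # cover fewer than lca_n + 1 points.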
        if len(train_accs) > lca_n:
            lca_accs = []
            for i in range(lca_n + 1):
                if i in train_accs:
                    lca_accs.append(train_accs[i])
                else:
                    logger.warning('Missing step for {}/{} for lca computation'
                                   .format(i, lca_n))
            lca = np.mean(lca_accs)
        else:
            lca = float('nan')
        stats = {}
        start = time.time()
        # train_idx = task['split_names'].index('Train')
        # train_path = task['data_path'][train_idx]
        # train_dataset = _load_datasets([train_path])[0]
        train_dataset = _load_datasets(task, 'Train')[0]
        stats.update(self.finish_task(train_dataset, t_id, task_viz,
                                      path='drawings'))
        stats['duration'] = {'iterations': t,
                             'finish': time.time() - start,
                             'best_iterations': best_t}
        stats['params'] = {'total': self.n_params(t_id),
                           'new': self.new_params(t_id)}
        stats['lca'] = lca
        return stats
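
An illustrative parameter set for the method above, assuming a plain classification task; the keys mirror what train_model_on_task pops or reads, but the values and the cross-entropy loss are placeholders rather than source-verified defaults:

import torch.nn.functional as F

# Hypothetical values, passed as **training_params together with task,
# task_viz, exp_dir, use_ray, lca_n and the other explicit arguments.
training_params = dict(
    name='model-a_T0',
    past_tasks=[],               # popped before the datasets are loaded
    normalize=True,
    augment_data=False,          # True adds the flip/crop transforms above
    batch_sizes=(128, 256),      # (train, eval) batch sizes
    split_optims=False,          # consumed only on the non-Ray code path
    loss_fn=F.cross_entropy,
    n_it_max=3000,               # the Ray path uses this as its stop condition
)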
Example #3
def train_single_task(t_id, task, tasks, vis_p, learner, config, transfer_matrix,
                      total_steps):

    training_params = config.pop('training-params')
    learner_params = config.pop('learner-params', {})
    assert 'model-params' not in config, "Can't have model-specific " \
                                         "parameters while tuning at the " \
                                         "stream level."

    if learner_params:
        learner.set_h_params(**learner_params)

    batch_sizes = training_params.pop('batch_sizes')
    # optim_func = training_params.pop('optim_func')
    optim_func = learner.optim_func
    optim_params = config.pop('optim')
    schedule_mode = training_params.pop('schedule_mode')
    split_optims = training_params.pop('split_optims')

    dropout = config.pop('dropout') if 'dropout' in config else None

    stream_setting = training_params.pop('stream_setting')
    plot_all = training_params.pop('plot_all')
    normalize = training_params.pop('normalize')
    augment_data = training_params.pop('augment_data')
    transformations = []
    if augment_data:
        transformations.extend([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor()
        ])
    lca_n = training_params.pop('lca')

    if plot_all:
        vis_p = get_training_vis_conf(vis_p, tune.get_trial_dir())
        # print('NEW vis: ', vis_p)
        task_vis = visdom.Visdom(**vis_p)
        # env = [env[0], env[-1]]
        # vis_p['env'] = '_'.join(env)
        # vis_p['log_to_filename'] = os.path.join(vis_logdir, vis_p['env'])
        # g_task_vis = visdom.Visdom(**vis_p)

        logger.info(get_env_url(task_vis))
    else:
        task_vis = None

    t_trans = [[] for _ in range(len(task['split_names']))]
    t_trans[0] = transformations.copy()

    datasets_p = dict(task=task,
                      transforms=t_trans,
                      normalize=normalize)
    datasets = _load_datasets(**datasets_p)
    train_loader, eval_loaders = get_classic_dataloaders(datasets,
                                                         batch_sizes)

    assert t_id == task['id']

    start1 = time.time()
    model = learner.get_model(task['id'], x_dim=task['x_dim'],
                              n_classes=task['n_classes'],
                              descriptor=task['descriptor'],
                              dataset=eval_loaders[:2])
    model_creation_time = time.time() - start1

    loss_fn = task['loss_fn']
    training_params['loss_fn'] = loss_fn

    prepare_batch = _prepare_batch
    if hasattr(model, 'prepare_batch_wrapper'):
        prepare_batch = model.prepare_batch_wrapper(prepare_batch, t_id)

    if hasattr(model, 'loss_wrapper'):
        training_params['loss_fn'] = \
            model.loss_wrapper(training_params['loss_fn'])

    # if hasattr(model, 'backward_hook'):
    #     training_params[]

    # optim = set_optim_params(optim_func, optim_params, model, split_optims)
    optim_fact = partial(set_optim_params,
                         optim_func=optim_func,
                         optim_params=optim_params,
                         split_optims=split_optims)
    # if schedule_mode == 'steps':
    #     lr_scheduler = torch.optim.lr_scheduler.\
    #         MultiStepLR(optim[0], milestones=[25, 40])
    # elif schedule_mode == 'cos':
    #     lr_scheduler = torch.optim.lr_scheduler.\
    #         CosineAnnealingLR(optim[0], T_max=200, eta_min=0.001)
    # elif schedule_mode is None:
    #     lr_scheduler = None
    # else:
    #     raise NotImplementedError()
    if dropout is not None:
        set_dropout(model, dropout)

    assert not config, config
    start2 = time.time()
    rescaled, t, metrics, b_state_dict = train_model(model, datasets_p,
                                                     batch_sizes, optim_fact,
                                                     prepare_batch, task,
                                                     train_loader, eval_loaders,
                                                     training_params, config)

    training_time = time.time() - start2
    start3 = time.time()
    if not isinstance(model, ExhaustiveSearch):
        # TODO: handle the state dict loading uniformly for all learners.
        # Right now only the exhaustive search models load the best state dict
        # themselves after training, so it is loaded manually here for the others.
        model.load_state_dict(b_state_dict['state_dict'])

    iterations = list(metrics.pop('training_iteration').values())
    epochs = list(metrics.pop('training_epoch').values())

    assert len(iterations) == len(epochs)
    index = dict(epochs=epochs, iterations=iterations)
    update_summary(index, task_vis, 'index', 0.5)

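    # Group the remaining metrics by the last token of their name (e.g.
    # "Train accuracy_0" and "Val accuracy_0" both fall under "accuracy_0"),
    # so each visdom window plots one quantity across all splits. The x-axis
    # of a group comes from its longest recorded series, after checking that
    # the shorter series are prefixes of it.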
    grouped_xs = dict()
    grouped_metrics = defaultdict(list)
    grouped_legends = defaultdict(list)
    for metric_n, metric_v in metrics.items():
        split_n = metric_n.split()
        if len(split_n) < 2:
            continue
        name = ' '.join(split_n[:-1])
        grouped_metrics[split_n[-1]].append(list(metric_v.values()))
        grouped_legends[split_n[-1]].append(name)
        if split_n[-1] in grouped_xs:
            if len(metric_v) > len(grouped_xs[split_n[-1]]):
                longer_xs = list(metric_v.keys())
                assert all(a == b for a, b in zip(longer_xs,
                                                  grouped_xs[split_n[-1]]))
                grouped_xs[split_n[-1]] = longer_xs
        else:
            grouped_xs[split_n[-1]] = list(metric_v.keys())

    for (plot_name, val), (_, legends) in sorted(zip(grouped_metrics.items(),
                                                     grouped_legends.items())):
        assert plot_name == _
        val = fill_matrix(val)
        if len(val) == 1:
            val = np.array(val[0])
        else:
            val = np.array(val).transpose()
        x = grouped_xs[plot_name]
        task_vis.line(val, X=x, win=plot_name,
                      opts={'title': plot_name, 'showlegend': True,
                            'width': 500, 'legend': legends,
                            'xlabel': 'iterations', 'ylabel': plot_name})

    avg_data_time = list(metrics['data time_ps'].values())[-1]
    avg_forward_time = list(metrics['forward time_ps'].values())[-1]
    avg_epoch_time = list(metrics['epoch time_ps'].values())[-1]
    avg_eval_time = list(metrics['eval time_ps'].values())[-1]
    total_time = list(metrics['total time'].values())[-1]

    entropies, ent_legend = [], []
    for metric_n, metric_v in metrics.items():
        if metric_n.startswith('Trainer entropy'):
            entropies.append(list(metric_v.values()))
            ent_legend.append(metric_n)

    if entropies:
        task_vis.line(np.array(entropies).transpose(), X=iterations,
                      win='ENT',
                      opts={'title': 'Arch entropy', 'showlegend': True,
                            'width': 500, 'legend': ent_legend,
                            'xlabel': 'Iterations', 'ylabel': 'Loss'})

    if hasattr(learner, 'arch_scores') and hasattr(learner, 'get_top_archs'):
        update_summary(learner.arch_scores[t_id], task_vis, 'scores')
        archs = model.get_top_archs(5)
        list_top_archs(archs, task_vis)

    if 'training_archs' in metrics:
        plot_trajectory(model.ssn.graph, metrics['training_archs'],
                        model.ssn.stochastic_node_ids, task_vis)

    postproc_time = time.time() - start3
    start4 = time.time()
    save_path = tune.get_trial_dir()
    finish_res = learner.finish_task(datasets[0], t_id,
                                     task_vis, save_path)
    finish_time = time.time() - start4

    start5 = time.time()
    eval_tasks = tasks
    # eval_tasks = tasks[:t_id + 1] if stream_setting else tasks
    evaluation = evaluate_on_tasks(eval_tasks, learner, batch_sizes[1],
                                   training_params['device'],
                                   ['Val', 'Test'], normalize,
                                   cur_task=t_id)
    assert evaluation['Val']['accuracy'][t_id] == b_state_dict['value']

    stats = {}
    eval_time = time.time() - start5

    stats.update(finish_res)

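    # LCA@n on the test accuracy recorded during training: when fewer than
    # lca_n + 1 steps were logged, the last accuracy is carried forward
    # (provided no intermediate steps are missing), so the mean still covers
    # steps 0..lca_n.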
    test_accs = metrics['Test accuracy_0']
    if not test_accs:
        lca = float('nan')
    else:
        if len(test_accs) <= lca_n:
            last_key = max(test_accs.keys())
            assert len(test_accs) == last_key + 1,\
                f"Can't compute LCA@{lca_n} if steps were skipped " \
                f"(got {list(test_accs.keys())})"
            test_accs = test_accs.copy()
            last_acc = test_accs[last_key]
            for i in range(last_key + 1, lca_n+1):
                test_accs[i] = last_acc
        lca = np.mean([test_accs[i] for i in range(lca_n + 1)])

    accs = {}
    key = 'accuracy'
    # logger.warning(evaluation)
    for split in evaluation.keys():
        transfer_matrix[split].append(evaluation[split][key])
        for i in range(len(tasks)):
            split_acc = evaluation[split][key]
            if i < len(split_acc):
                accs['{}_T{}'.format(split, i)] = split_acc[i]
            else:
                accs['{}_T{}'.format(split, i)] = float('nan')
    plot_heatmaps(list(transfer_matrix.keys()),
                  list(map(fill_matrix, transfer_matrix.values())),
                  task_vis)


    # logger.warning(t_id)
    # logger.warning(transfer_matrix)

    avg_val = np.mean(evaluation['Val']['accuracy'])
    avg_val_so_far = np.mean(evaluation['Val']['accuracy'][:t_id+1])
    avg_test = np.mean(evaluation['Test']['accuracy'])
    avg_test_so_far = np.mean(evaluation['Test']['accuracy'][:t_id+1])

    step_time_s = time.time() - start1
    step_sum = model_creation_time + training_time + postproc_time + \
               finish_time + eval_time
    best_it = b_state_dict.get('cum_best_iter', b_state_dict['iter'])
    tune.report(t=t_id,
                best_val=b_state_dict['value'],
                avg_acc_val=avg_val,
                avg_acc_val_so_far=avg_val_so_far,
                avg_acc_test_so_far=avg_test_so_far,
                lca=lca,
                avg_acc_test=avg_test,
                test_acc=evaluation['Test']['accuracy'][t_id],
                duration_seconds=step_time_s,
                duration_iterations=t,
                duration_best_it=best_it,
                duration_finish=finish_time,
                duration_model_creation=model_creation_time,
                duration_training=training_time,
                duration_postproc=postproc_time,
                duration_eval=eval_time,
                duration_sum=step_sum,
                # entropy=stats.pop('entropy'),
                new_params=learner.new_params(t_id),
                total_params=learner.n_params(t_id),
                total_steps=total_steps + t,
                fw_t=round(avg_forward_time * 1000) / 1000,
                data_t=round(avg_data_time * 1000) / 1000,
                epoch_t=round(avg_epoch_time * 1000) / 1000,
                eval_t=round(avg_eval_time * 1000) / 1000,
                total_t=round(total_time * 1000) / 1000,
                env_url=get_env_url(vis_p),
                **accs, **stats)
    return rescaled, t, metrics, b_state_dict, stats
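
A hedged sketch of the config layout train_single_task appears to expect: 'training-params', 'learner-params', 'optim' and the optional 'dropout' are all popped, after which the function asserts the dict is empty. Keys and values below are illustrative placeholders, not source-verified defaults:

# Hypothetical config; every top-level key is popped by train_single_task,
# which then requires the remaining dict to be empty (assert not config).
config = {
    'training-params': dict(
        batch_sizes=(128, 256),
        schedule_mode=None,
        split_optims=False,
        stream_setting=True,
        plot_all=False,
        normalize=True,
        augment_data=False,
        lca=5,                   # popped as lca_n
        device='cpu',            # read when evaluating on all tasks
    ),
    'learner-params': {},        # optional, forwarded to learner.set_h_params
    'optim': {'lr': 1e-3},       # consumed by set_optim_params
    'dropout': None,             # optional
}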
Example #4
    def __init__(self, task_gen, ll_models, cuda, n_it_max, n_tasks, patience,
                 grace_period, num_hp_samplings, visdom_traces_folder,
                 batch_sizes, plot_tasks, log_steps, log_epoch, name,
                 task_save_folder, use_ray, use_ray_logging, redis_address,
                 use_threads, local_mode, smoke_test, sacred_run, log_dir,
                 norm_models, resources, seed):
        self.task_gen = task_gen
        self.sims = None
        self.sims_comp = None
        self.name = name
        self.smoke_test = smoke_test

        assert isinstance(ll_models, dict)
        if 'finetune-mt-head' in ll_models:
            assert 'multitask-head' in ll_models and \
                   isinstance(ll_models['multitask-head'],
                              MultitaskHeadLLModel), \
                'Fine-tune should be associated with a multitask LLModel'
            ll_models['finetune-mt-head'].set_source_model(
                ll_models['multitask-head'])
        if 'finetune-mt-leg' in ll_models:
            assert 'multitask-leg' in ll_models and \
                   isinstance(ll_models['multitask-leg'], MultitaskLegLLModel), \
                'Fine-tune leg should be associated with a multitask Leg LLModel'
            ll_models['finetune-mt-leg'].set_source_model(
                ll_models['multitask-leg'])
        self.ll_models = ll_models
        self.norm_models = norm_models

        keys = list(self.ll_models.keys())
        self.norm_models_idx = [keys.index(nm) for nm in self.norm_models]

        if cuda and torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        self.use_ray = use_ray
        self.redis_address = redis_address
        self.use_threads = use_threads
        self.local_mode = local_mode
        self.use_ray_logging = use_ray_logging

        self.n_it_max = n_it_max
        self.n_tasks = n_tasks
        self.patience = patience
        self.grace_period = grace_period
        self.num_hp_samplings = num_hp_samplings
        self.resources = resources

        self.plot_tasks = plot_tasks
        self.batch_sizes = batch_sizes
        if os.path.isfile(VISDOM_CONF_PATH):
            self.visdom_conf = load_conf(VISDOM_CONF_PATH)
        else:
            self.visdom_conf = None

        self.log_steps = log_steps
        self.log_epoch = log_epoch

        self.sacred_run = sacred_run
        self.seed = seed

        self.exp_name = get_env_name(sacred_run.config, sacred_run._id)
        self.exp_dir = os.path.join(log_dir, self.exp_name)
        init_model_path = os.path.join(self.exp_dir, 'model_initializations')
        if not os.path.isdir(init_model_path):
            os.makedirs(init_model_path)
        self.visdom_traces_folder = os.path.join(visdom_traces_folder,
                                                 self.exp_name)

        self.task_save_folder = os.path.join(task_save_folder, self.exp_name)
        main_env = get_env_name(sacred_run.config, sacred_run._id, main=True)
        trace_file = os.path.join(self.visdom_traces_folder, main_env)
        self.main_viz = visdom.Visdom(env=main_env,
                                      log_to_filename=trace_file,
                                      **self.visdom_conf)
        task_env = '{}_tasks'.format(self.exp_name)
        trace_file = '{}/{}'.format(self.visdom_traces_folder,
                                    task_env)
        self.task_env = visdom.Visdom(env=task_env,
                                      log_to_filename=trace_file,
                                      **self.visdom_conf)

        self.summary = {'model': list(self.ll_models.keys()),
                        'speed': [],
                        'accuracy': []}
        update_summary(self.summary, self.main_viz)

        self.sacred_run.info['transfers'] = defaultdict(dict)
        self.task_envs_str = defaultdict(list)

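        # defaultdict whose default_factory is its own __len__: the first
        # lookup of a new label assigns it the next integer index, giving a
        # stable label -> index mapping for the plots.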
        self.plot_labels = defaultdict()
        self.plot_labels.default_factory = self.plot_labels.__len__

        self.tune_register_lock = Lock()
        self.eval_lock = Lock()

        # Init metrics
        self.metrics = defaultdict(lambda: [[] for _ in self.ll_models])
        self.training_times_it = [[] for _ in self.ll_models]
        self.training_times_s = [[] for _ in self.ll_models]
        self.all_perfs = [[] for _ in self.ll_models]
        self.all_perfs_normalized = [[] for _ in self.ll_models]
        self.ideal_potentials = [[] for _ in self.ll_models]
        self.current_potentials = [[] for _ in self.ll_models]
Example #5
    def __init__(self, task_gen, ll_models, cuda, n_it_max, n_ep_max,
                 augment_data, normalize, single_pass, n_tasks, patience,
                 grace_period, num_hp_samplings, visdom_traces_folder,
                 plot_all, batch_sizes, plot_tasks, lca, log_steps, log_epoch,
                 name, task_save_folder, load_tasks_from, use_ray,
                 use_ray_logging, redis_address, use_processes, local_mode,
                 smoke_test, stream_setting, sacred_run, log_dir, norm_models,
                 val_per_task, schedule_mode, split_optims, ref_params_id,
                 seed):
        self.task_gen = task_gen
        self.sims = None
        self.sims_comp = None
        self.name = name

        assert isinstance(ll_models, dict)

        self.ll_models = ll_models
        self.learner_names = list(self.ll_models.keys())
        self.norm_models = norm_models

        keys = list(self.ll_models.keys())
        self.norm_models_idx = [keys.index(nm) for nm in self.norm_models]

        if cuda and torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        self.use_ray = use_ray
        self.redis_address = redis_address
        self.use_processes = use_processes
        self.local_mode = local_mode
        self.use_ray_logging = use_ray_logging

        self.single_pass = single_pass
        self.n_it_max = n_it_max
        self.n_ep_max = n_ep_max
        self.augment_data = augment_data
        self.normalize = normalize
        self.schedule_mode = schedule_mode

        self.n_tasks = n_tasks
        self.patience = patience
        self.grace_period = grace_period
        self.num_hp_samplings = num_hp_samplings
        self.stream_setting = stream_setting
        self.val_per_task = val_per_task
        self.split_optims = split_optims

        self.plot_tasks = plot_tasks
        self.batch_sizes = batch_sizes
        if os.path.isfile(VISDOM_CONF_PATH):
            self.visdom_conf = load_conf(VISDOM_CONF_PATH)
        else:
            self.visdom_conf = None

        self.lca = lca
        self.log_steps = log_steps
        self.log_epoch = log_epoch

        self.sacred_run = sacred_run
        self.seed = seed

        self.exp_name = get_env_name(sacred_run.config, sacred_run._id)
        self.exp_dir = os.path.join(log_dir, self.exp_name)
        self.init_model_path = os.path.join(self.exp_dir,
                                            'model_initializations')
        if not os.path.isdir(self.init_model_path):
            os.makedirs(self.init_model_path)
        if ref_params_id is None:
            self.ref_params_path = None
        else:
            assert isinstance(ref_params_id, int)
            self.ref_params_path = os.path.join(log_dir, str(ref_params_id),
                                                'model_initializations',
                                                'ref.pth')

        self.visdom_traces_folder = os.path.join(visdom_traces_folder,
                                                 self.exp_name)

        self.load_tasks = load_tasks_from is not None
        if self.load_tasks:
            self.data_path = os.path.join(task_save_folder,
                                          str(load_tasks_from))
            assert os.path.isdir(self.data_path), \
                '{} doesn\'t exist'.format(self.data_path)
        else:
            self.data_path = os.path.join(task_save_folder, self.exp_name)
        main_env = get_env_name(sacred_run.config, sacred_run._id, main=True)
        trace_file = os.path.join(self.visdom_traces_folder, main_env)
        self.main_viz_params = {
            'env': main_env,
            'log_to_filename': trace_file,
            **self.visdom_conf
        }
        self.main_viz = visdom.Visdom(**self.main_viz_params)
        task_env = '{}_tasks'.format(self.exp_name)
        trace_file = '{}/{}'.format(self.visdom_traces_folder, task_env)
        self.task_env = visdom.Visdom(env=task_env,
                                      log_to_filename=trace_file,
                                      **self.visdom_conf)
        self.plot_all = plot_all

        self.summary = {
            'model': list(self.ll_models.keys()),
            'speed': [float('nan')] * len(self.ll_models),
            'accuracy_t': [float('nan')] * len(self.ll_models),
            'accuracy_now': [float('nan')] * len(self.ll_models)
        }
        update_summary(self.summary, self.main_viz)

        self.param_summary = defaultdict(list)
        self.param_summary['Task id'] = list(range(self.n_tasks))

        self.sacred_run.info['transfers'] = defaultdict(dict)
        self.task_envs_str = defaultdict(list)
        self.best_task_envs_str = defaultdict(list)

        # List of dicts. Each dict contains the parameters of a Visdom env for
        # the corresponding task per learner. In the current version these
        # envs are never used directly but are modified for each training to
        # contain the actual parameters used.
        self.training_envs = []
        self.task_envs = []

        self.plot_labels = defaultdict()
        self.plot_labels.default_factory = self.plot_labels.__len__

        self.tune_register_lock = threading.Lock()
        self.eval_lock = threading.Lock()

        # Init metrics
        self.metrics = defaultdict(lambda: [[] for _ in self.ll_models])
        self.training_times_it = [[] for _ in self.ll_models]
        self.training_times_s = [[] for _ in self.ll_models]
        self.all_perfs = [[] for _ in self.ll_models]
        self.all_perfs_normalized = [[] for _ in self.ll_models]
        self.ideal_potentials = [[] for _ in self.ll_models]
        self.current_potentials = [[] for _ in self.ll_models]
        self.n_params = [[] for _ in self.ll_models]