Example #1
def wrap(*args, idx=None, uid=None, optim_fact, datasets_p, b_sizes, **kwargs):
    model = kwargs['model']
    optim = optim_fact(model=model)
    datasets = _load_datasets(**datasets_p)
    train_loader, eval_loaders = get_classic_dataloaders(datasets, b_sizes, 0)
    if hasattr(model, 'train_loader_wrapper'):
        train_loader = model.train_loader_wrapper(train_loader)

    res = train(*args, train_loader=train_loader, eval_loaders=eval_loaders,
                optimizer=optim, **kwargs)
    # logger.warning('{}=Received option {} results'.format(uid, idx))
    return res
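The hasattr(model, 'train_loader_wrapper') check above is a duck-typed hook: a model can optionally wrap the training loader before the generic train() loop consumes it. Below is a minimal sketch of that pattern; PlainModel, AugmentingModel and build_train_loader are illustrative stand-ins, not part of the original code.

class PlainModel:
    pass                                    # no hook: the loader is used as-is

class AugmentingModel:
    def train_loader_wrapper(self, loader):
        # e.g. re-batch, inject noise, or interleave replay samples
        return ((x, y) for x, y in loader)

def build_train_loader(model, loader):
    if hasattr(model, 'train_loader_wrapper'):
        loader = model.train_loader_wrapper(loader)
    return loader

batches = [([1, 2], [0, 1]), ([3, 4], [1, 0])]
print(list(build_train_loader(PlainModel(), batches)))       # unchanged batches
print(list(build_train_loader(AugmentingModel(), batches)))  # wrapped generator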
Example #2
    def train_model_on_task(self, task, task_viz, exp_dir, use_ray,
                            use_ray_logging, grace_period,
                            num_hp_samplings, local_mode,
                            redis_address, lca_n, **training_params):
        logger.info("Training dashboard: {}".format(get_env_url(task_viz)))
        t_id = task['id']

        trainable = self.get_trainable(use_ray_logging=use_ray_logging)
        past_tasks = training_params.pop('past_tasks')
        normalize = training_params.pop('normalize')
        augment_data = training_params.pop('augment_data')

        transformations = []
        if augment_data:
            transformations.extend([
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor()
            ])
        t_trans = [[] for _ in range(len(task['split_names']))]
        t_trans[0] = transformations
        datasets = trainable._load_datasets(task,
                                            task['loss_fn'],
                                            past_tasks, t_trans, normalize)
        train_loader, eval_loaders = get_classic_dataloaders(datasets,
                                                             training_params.pop(
                                                                 'batch_sizes'))
        model = self.get_model(task_id=t_id, x_dim=task['x_dim'],
                               n_classes=task['n_classes'],
                               descriptor=task['descriptor'],
                               dataset=eval_loaders[:2])

        if use_ray:
            if not ray.is_initialized():
                ray.init(address=redis_address)

            scheduler = None

            training_params['loss_fn'] = tune.function(
                training_params['loss_fn'])
            training_params['optim_func'] = tune.function(self.optim_func)

            init_model_path = os.path.join(exp_dir, 'model_initializations')
            model_file_name = '{}_init.pth'.format(training_params['name'])
            model_path = os.path.join(init_model_path, model_file_name)
            torch.save(model, model_path)

            training_params['model_path'] = model_path
            config = {**self.get_search_space(),
                      'training-params': training_params}
            if use_ray_logging:
                stop_condition = {'training_iteration':
                                      training_params['n_it_max']}
                checkpoint_at_end = False
                keep_checkpoints_num = 1
                checkpoint_score_attr = 'min-Val nll'
            else:
                stop_condition = None
                # loggers = [JsonLogger, MyCSVLogger]
                checkpoint_at_end = False
                keep_checkpoints_num = None
                checkpoint_score_attr = None

            trainable = rename_class(trainable, training_params['name'])
            experiment = Experiment(
                name=training_params['name'],
                run=trainable,
                stop=stop_condition,
                config=config,
                resources_per_trial=self.ray_resources,
                num_samples=num_hp_samplings,
                local_dir=exp_dir,
                loggers=(JsonLogger, CSVLogger),
                checkpoint_at_end=checkpoint_at_end,
                keep_checkpoints_num=keep_checkpoints_num,
                checkpoint_score_attr=checkpoint_score_attr)

            analysis = tune.run(experiment,
                                scheduler=scheduler,
                                verbose=1,
                                raise_on_failed_trial=True,
                                # max_failures=-1,
                                # with_server=True,
                                # server_port=4321
                                )
            os.remove(model_path)
            logger.info("Training dashboard: {}".format(get_env_url(task_viz)))

            all_trials = {t.logdir: t for t in analysis.trials}
            best_logdir = analysis.get_best_logdir('Val nll', 'min')
            best_trial = all_trials[best_logdir]

            # picked_metric = 'accuracy_0'
            # metric_names = {s: '{} {}'.format(s, picked_metric) for s in
            #                 ['Train', 'Val', 'Test']}

            logger.info('Best trial: {}'.format(best_trial))
            best_res = best_trial.checkpoint.result
            best_point = (best_res['training_iteration'], best_res['Val nll'])

            # y_keys = ['mean_loss' if use_ray_logging else 'Val nll', 'train_loss']
            y_keys = ['Val nll', 'Train nll']

            epoch_key = 'training_epoch'
            it_key = 'training_iteration'
            plot_res_dataframe(analysis, training_params['name'], best_point,
                               task_viz, epoch_key, it_key, y_keys)
            if 'entropy' in next(iter(analysis.trial_dataframes.values())):
                plot_res_dataframe(analysis, training_params['name'], None,
                                    task_viz, epoch_key, it_key, ['entropy'])
            best_model = self.get_model(task_id=t_id)
            best_model.load_state_dict(torch.load(best_trial.checkpoint.value))

            train_accs = analysis.trial_dataframes[best_logdir]['Train accuracy_0']
            best_t = best_res['training_iteration']
            t = best_trial.last_result['training_iteration']
        else:
            search_space = self.get_search_space()
            rand_config = list(generate_variants(search_space))[0][1]
            learner_params = rand_config.pop('learner-params', {})
            optim_params = rand_config.pop('optim')

            split_optims = training_params.pop('split_optims')
            if hasattr(model, 'set_h_params'):
                model.set_h_params(**learner_params)
            if hasattr(model, 'train_loader_wrapper'):
                train_loader = model.train_loader_wrapper(train_loader)

            loss_fn = task['loss_fn']
            if hasattr(model, 'loss_wrapper'):
                loss_fn = model.loss_wrapper(task['loss_fn'])

            prepare_batch = _prepare_batch
            if hasattr(model, 'prepare_batch_wrapper'):
                prepare_batch = model.prepare_batch_wrapper(prepare_batch, t_id)

            optim_fact = partial(set_optim_params,
                                 optim_func=self.optim_func,
                                 optim_params=optim_params,
                                 split_optims=split_optims)
            if hasattr(model, 'train_func'):
                f = model.train_func
                t, metrics, b_state_dict = f(train_loader=train_loader,
                                             eval_loaders=eval_loaders,
                                             optim_fact=optim_fact,
                                             loss_fn=loss_fn,
                                             split_names=task['split_names'],
                                             viz=task_viz,
                                             prepare_batch=prepare_batch,
                                             **training_params)
            else:
                optim = optim_fact(model=model)
                t, metrics, b_state_dict = train(model=model,
                                                 train_loader=train_loader,
                                                 eval_loaders=eval_loaders,
                                                 optimizer=optim,
                                                 loss_fn=loss_fn,
                                                 split_names=task['split_names'],
                                                 viz=task_viz,
                                                 prepare_batch=prepare_batch,
                                                 **training_params)
            train_accs = metrics['Train accuracy_0']
            best_t = b_state_dict['iter']
            if 'training_archs' in metrics:
                plot_trajectory(model.ssn.graph, metrics['training_archs'],
                                model.ssn.stochastic_node_ids, task_viz)
                weights = model.arch_sampler().squeeze()
                archs = model.ssn.get_top_archs(weights, 5)
                list_top_archs(archs, task_viz)
                list_arch_scores(self.arch_scores[t_id], task_viz)
                update_summary(self.arch_scores[t_id], task_viz, 'scores')

        if len(train_accs) > lca_n:
            lca_accs = []
            for i in range(lca_n + 1):
                if i in train_accs:
                    lca_accs.append(train_accs[i])
                else:
                    logger.warning('Missing step for {}/{} for lca computation'
                                   .format(i, lca_n))
            lca = np.mean(lca_accs)
        else:
            lca = float('nan')
        stats = {}
        start = time.time()
        # train_idx = task['split_names'].index('Train')
        # train_path = task['data_path'][train_idx]
        # train_dataset = _load_datasets([train_path])[0]
        train_dataset = _load_datasets(task, 'Train')[0]
        stats.update(self.finish_task(train_dataset, t_id, task_viz,
                                      path='drawings'))
        stats['duration'] = {'iterations': t,
                             'finish': time.time() - start,
                             'best_iterations': best_t}
        stats['params'] = {'total': self.n_params(t_id),
                           'new': self.new_params(t_id)}
        stats['lca'] = lca
        return stats
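In the non-Ray branch above, optim_fact is built with functools.partial so that the sampled optimizer hyper-parameters are bound once and the model is supplied later, either by the model's own train_func or by the generic train() call. A minimal sketch of that factory pattern follows; make_optim is a hypothetical stand-in for the project-specific set_optim_params.

from functools import partial

import torch
import torch.nn as nn

def make_optim(model, optim_func, optim_params):
    # Stand-in for set_optim_params: bind the sampled hyper-parameters
    # to whatever model is provided at call time.
    return optim_func(model.parameters(), **optim_params)

optim_fact = partial(make_optim,
                     optim_func=torch.optim.SGD,
                     optim_params={'lr': 0.01, 'momentum': 0.9})

model = nn.Linear(10, 2)
optim = optim_fact(model=model)   # same call site as in the example above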
Example #3
    def train_model_on_task(self, task, task_viz, exp_dir, use_ray,
                            use_ray_logging, smoke_test, n_it_max, grace_period,
                            num_hp_samplings, local_mode, tune_register_lock,
                            resources, **training_params):
        logger.info("Training dashboard: {}".format(get_env_url(task_viz)))

        model = self.get_model(task_id=task.id)
        trainable = self.get_trainable(use_ray_logging=use_ray_logging)

        self.prepare_task(task, training_params)

        if use_ray:
            # Required to avoid collisions in Tune's global Registry:
            # https://github.com/ray-project/ray/blob/master/python/ray/tune/registry.py
            trainable = rename_class(trainable, training_params['name'])

            scheduler = None

            training_params['loss_fn'] = tune.function(
                training_params['loss_fn'])
            training_params['optim_func'] = tune.function(self.optim_func)
            training_params['n_it_max'] = n_it_max

            init_model_path = os.path.join(exp_dir, 'model_initializations')
            model_file_name = '{}_init.pth'.format(training_params['name'])
            model_path = os.path.join(init_model_path, model_file_name)
            torch.save(model, model_path)

            training_params['model_path'] = model_path
            config = {'hyper-params': self.get_search_space(smoke_test),
                      'tp': training_params}
            if use_ray_logging:
                stop_condition = {'training_iteration': n_it_max}
                loggers = None
            else:
                stop_condition = None
                loggers = [JsonLogger, MyCSVLogger]

            # We need to create the experiment using a lock here to avoid issues
            # with Tune's global registry, more specifically with the
            # `_to_flush` dict that may change during the iteration over it.
            # https://github.com/ray-project/ray/blob/e3c9f7e83a6007ded7ae7e99fcbe9fcaa371bad3/python/ray/tune/registry.py#L91-L93
            tune_register_lock.acquire()
            experiment = Experiment(
                name=training_params['name'],
                run=trainable,
                stop=stop_condition,
                config=config,
                resources_per_trial=resources,
                num_samples=num_hp_samplings,
                local_dir=exp_dir,
                loggers=loggers,
                keep_checkpoints_num=1,
                checkpoint_score_attr='min-mean_loss')
            tune_register_lock.release()

            analysis = tune.run(experiment,
                                scheduler=scheduler,
                                verbose=1,
                                raise_on_failed_trial=True,
                                # max_failures=-1,
                                # with_server=True,
                                # server_port=4321
                                )
            os.remove(model_path)
            logger.info("Training dashboard: {}".format(get_env_url(task_viz)))

            all_trials = {t.logdir: t for t in analysis.trials}
            best_logdir = analysis.get_best_logdir('mean_loss', 'min')
            best_trial = all_trials[best_logdir]

            # picked_metric = 'accuracy_0'
            # metric_names = {s: '{} {}'.format(s, picked_metric) for s in
            #                 ['Train', 'Val', 'Test']}

            logger.info('Best trial: {}'.format(best_trial))
            best_res = best_trial._checkpoint.last_result
            best_point = (best_res['training_iteration'], best_res['mean_loss'])

            y_keys = ['mean_loss' if use_ray_logging else 'Val nll', 'train_loss']
            epoch_key = 'training_epoch'
            it_key = 'training_iteration' if use_ray_logging else 'training_iterations'
            plot_res_dataframe(analysis, training_params['name'], best_point,
                               task_viz, epoch_key, it_key, y_keys)
            best_model = self.get_model(task_id=task.id)
            best_model.load_state_dict(torch.load(best_trial._checkpoint.value))

            t = best_trial._checkpoint.last_result['training_iteration']
        else:
            data_path = training_params.pop('data_path')
            past_tasks = training_params.pop('past_tasks')
            datasets = trainable._load_datasets(data_path,
                                                training_params['loss_fn'],
                                                past_tasks)
            train_loader, eval_loaders = get_classic_dataloaders(datasets,
                                                                 training_params.pop('batch_sizes'))
            optim = self.optim_func(model.parameters())

            t, accs, best_state_dict = train(model, train_loader, eval_loaders,
                                             optimizer=optim, viz=task_viz,
                                             n_it_max=n_it_max, **training_params)
        logger.info('Finishing task ...')
        t1 = time.time()
        self.finish_task(task.datasets[0])
        logger.info('done in {}s'.format(time.time() - t1))

        return t
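rename_class is used above so that each task's trainable registers under a unique name in Tune's global registry (see the linked registry.py comment). The repository's helper is not shown in these examples; one plausible implementation, given only as an assumption, derives a subclass carrying the requested name.

def rename_class(cls, new_name):
    # Hypothetical sketch: build a subclass whose __name__ (and therefore the
    # key used when Tune registers the trainable) is unique per task.
    return type(new_name, (cls,), {'__module__': cls.__module__})

class Trainable:
    pass

TaskTrainable = rename_class(Trainable, 'task_3_trainable')
print(TaskTrainable.__name__)                 # task_3_trainable
print(issubclass(TaskTrainable, Trainable))   # True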
Example #4
def train_single_task(t_id, task, tasks, vis_p, learner, config, transfer_matrix,
                      total_steps):

    training_params = config.pop('training-params')
    learner_params = config.pop('learner-params', {})
    assert 'model-params' not in config, "Can't have model-specific " \
                                         "parameters while tuning at the " \
                                         "stream level."

    if learner_params:
        learner.set_h_params(**learner_params)

    batch_sizes = training_params.pop('batch_sizes')
    # optim_func = training_params.pop('optim_func')
    optim_func = learner.optim_func
    optim_params = config.pop('optim')
    schedule_mode = training_params.pop('schedule_mode')
    split_optims = training_params.pop('split_optims')

    dropout = config.pop('dropout') if 'dropout' in config else None

    stream_setting = training_params.pop('stream_setting')
    plot_all = training_params.pop('plot_all')
    normalize = training_params.pop('normalize')
    augment_data = training_params.pop('augment_data')
    transformations = []
    if augment_data:
        transformations.extend([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor()
        ])
    lca_n = training_params.pop('lca')

    if plot_all:
        vis_p = get_training_vis_conf(vis_p, tune.get_trial_dir())
        # print('NEW vis: ', vis_p)
        task_vis = visdom.Visdom(**vis_p)
        # env = [env[0], env[-1]]
        # vis_p['env'] = '_'.join(env)
        # vis_p['log_to_filename'] = os.path.join(vis_logdir, vis_p['env'])
        # g_task_vis = visdom.Visdom(**vis_p)

        logger.info(get_env_url(task_vis))
    else:
        task_vis = None

    t_trans = [[] for _ in range(len(task['split_names']))]
    t_trans[0] = transformations.copy()

    datasets_p = dict(task=task,
                      transforms=t_trans,
                      normalize=normalize)
    datasets = _load_datasets(**datasets_p)
    train_loader, eval_loaders = get_classic_dataloaders(datasets,
                                                         batch_sizes)

    assert t_id == task['id']

    start1 = time.time()
    model = learner.get_model(task['id'], x_dim=task['x_dim'],
                              n_classes=task['n_classes'],
                              descriptor=task['descriptor'],
                              dataset=eval_loaders[:2])
    model_creation_time = time.time() - start1

    loss_fn = task['loss_fn']
    training_params['loss_fn'] = loss_fn

    prepare_batch = _prepare_batch
    if hasattr(model, 'prepare_batch_wrapper'):
        prepare_batch = model.prepare_batch_wrapper(prepare_batch, t_id)

    if hasattr(model, 'loss_wrapper'):
        training_params['loss_fn'] = \
            model.loss_wrapper(training_params['loss_fn'])

    # if hasattr(model, 'backward_hook'):
    #     training_params[]

    # optim = set_optim_params(optim_func, optim_params, model, split_optims)
    optim_fact = partial(set_optim_params,
                         optim_func=optim_func,
                         optim_params=optim_params,
                         split_optims=split_optims)
    # if schedule_mode == 'steps':
    #     lr_scheduler = torch.optim.lr_scheduler.\
    #         MultiStepLR(optim[0], milestones=[25, 40])
    # elif schedule_mode == 'cos':
    #     lr_scheduler = torch.optim.lr_scheduler.\
    #         CosineAnnealingLR(optim[0], T_max=200, eta_min=0.001)
    # elif schedule_mode is None:
    #     lr_scheduler = None
    # else:
    #     raise NotImplementedError()
    if dropout is not None:
        set_dropout(model, dropout)

    assert not config, config
    start2 = time.time()
    rescaled, t, metrics, b_state_dict = train_model(model, datasets_p,
                                                     batch_sizes, optim_fact,
                                                     prepare_batch, task,
                                                     train_loader, eval_loaders,
                                                     training_params, config)

    training_time = time.time() - start2
    start3 = time.time()
    if not isinstance(model, ExhaustiveSearch):
        # TODO: handle state-dict loading uniformly for all learners; right now
        # only the exhaustive-search models load the best state dict after
        # training, so other learners need to load it here.
        model.load_state_dict(b_state_dict['state_dict'])

    iterations = list(metrics.pop('training_iteration').values())
    epochs = list(metrics.pop('training_epoch').values())

    assert len(iterations) == len(epochs)
    index = dict(epochs=epochs, iterations=iterations)
    update_summary(index, task_vis, 'index', 0.5)

    grouped_xs = dict()
    grouped_metrics = defaultdict(list)
    grouped_legends = defaultdict(list)
    for metric_n, metric_v in metrics.items():
        split_n = metric_n.split()
        if len(split_n) < 2:
            continue
        name = ' '.join(split_n[:-1])
        grouped_metrics[split_n[-1]].append(list(metric_v.values()))
        grouped_legends[split_n[-1]].append(name)
        if split_n[-1] in grouped_xs:
            if len(metric_v) > len(grouped_xs[split_n[-1]]):
                longer_xs = list(metric_v.keys())
                assert all(a == b for a, b in zip(longer_xs,
                                                  grouped_xs[split_n[-1]]))
                grouped_xs[split_n[-1]] = longer_xs
        else:
            grouped_xs[split_n[-1]] = list(metric_v.keys())

    for (plot_name, val), (_, legends) in sorted(zip(grouped_metrics.items(),
                                                     grouped_legends.items())):
        assert plot_name == _
        val = fill_matrix(val)
        if len(val) == 1:
            val = np.array(val[0])
        else:
            val = np.array(val).transpose()
        x = grouped_xs[plot_name]
        task_vis.line(val, X=x, win=plot_name,
                      opts={'title': plot_name, 'showlegend': True,
                            'width': 500, 'legend': legends,
                            'xlabel': 'iterations', 'ylabel': plot_name})

    avg_data_time = list(metrics['data time_ps'].values())[-1]
    avg_forward_time = list(metrics['forward time_ps'].values())[-1]
    avg_epoch_time = list(metrics['epoch time_ps'].values())[-1]
    avg_eval_time = list(metrics['eval time_ps'].values())[-1]
    total_time = list(metrics['total time'].values())[-1]

    entropies, ent_legend = [], []
    for metric_n, metric_v in metrics.items():
        if metric_n.startswith('Trainer entropy'):
            entropies.append(list(metric_v.values()))
            ent_legend.append(metric_n)

    if entropies:
        task_vis.line(np.array(entropies).transpose(), X=iterations,
                      win='ENT',
                      opts={'title': 'Arch entropy', 'showlegend': True,
                            'width': 500, 'legend': ent_legend,
                            'xlabel': 'Iterations', 'ylabel': 'Loss'})

    if hasattr(learner, 'arch_scores') and hasattr(learner, 'get_top_archs'):
        update_summary(learner.arch_scores[t_id], task_vis, 'scores')
        archs = model.get_top_archs(5)
        list_top_archs(archs, task_vis)

    if 'training_archs' in metrics:
        plot_trajectory(model.ssn.graph, metrics['training_archs'],
                        model.ssn.stochastic_node_ids, task_vis)

    postproc_time = time.time() - start3
    start4 = time.time()
    save_path = tune.get_trial_dir()
    finish_res = learner.finish_task(datasets[0], t_id,
                                     task_vis, save_path)
    finish_time = time.time() - start4

    start5 = time.time()
    eval_tasks = tasks
    # eval_tasks = tasks[:t_id + 1] if stream_setting else tasks
    evaluation = evaluate_on_tasks(eval_tasks, learner, batch_sizes[1],
                                   training_params['device'],
                                   ['Val', 'Test'], normalize,
                                   cur_task=t_id)
    assert evaluation['Val']['accuracy'][t_id] == b_state_dict['value']

    stats = {}
    eval_time = time.time() - start5

    stats.update(finish_res)

    test_accs = metrics['Test accuracy_0']
    if not test_accs:
        lca = float('nan')
    else:
        if len(test_accs) <= lca_n:
            last_key = max(test_accs.keys())
            assert len(test_accs) == last_key + 1,\
                f"Can't compute LCA@{lca_n} if steps were skipped " \
                f"(got {list(test_accs.keys())})"
            test_accs = test_accs.copy()
            last_acc = test_accs[last_key]
            for i in range(last_key + 1, lca_n+1):
                test_accs[i] = last_acc
        lca = np.mean([test_accs[i] for i in range(lca_n + 1)])

    accs = {}
    key = 'accuracy'
    # logger.warning(evaluation)
    for split in evaluation.keys():
        transfer_matrix[split].append(evaluation[split][key])
        for i in range(len(tasks)):
            split_acc = evaluation[split][key]
            if i < len(split_acc):
                accs['{}_T{}'.format(split, i)] = split_acc[i]
            else:
                accs['{}_T{}'.format(split, i)] = float('nan')
    plot_heatmaps(list(transfer_matrix.keys()),
                  list(map(fill_matrix, transfer_matrix.values())),
                  task_vis)

    # logger.warning(t_id)
    # logger.warning(transfer_matrix)

    avg_val = np.mean(evaluation['Val']['accuracy'])
    avg_val_so_far = np.mean(evaluation['Val']['accuracy'][:t_id+1])
    avg_test = np.mean(evaluation['Test']['accuracy'])
    avg_test_so_far = np.mean(evaluation['Test']['accuracy'][:t_id+1])

    step_time_s = time.time() - start1
    step_sum = model_creation_time + training_time + postproc_time + \
               finish_time + eval_time
    best_it = b_state_dict.get('cum_best_iter', b_state_dict['iter'])
    tune.report(t=t_id,
                best_val=b_state_dict['value'],
                avg_acc_val=avg_val,
                avg_acc_val_so_far=avg_val_so_far,
                avg_acc_test_so_far=avg_test_so_far,
                lca=lca,
                avg_acc_test=avg_test,
                test_acc=evaluation['Test']['accuracy'][t_id],
                duration_seconds=step_time_s,
                duration_iterations=t,
                duration_best_it=best_it,
                duration_finish=finish_time,
                duration_model_creation=model_creation_time,
                duration_training=training_time,
                duration_postproc=postproc_time,
                duration_eval=eval_time,
                duration_sum=step_sum,
                # entropy=stats.pop('entropy'),
                new_params=learner.new_params(t_id),
                total_params=learner.n_params(t_id),
                total_steps=total_steps + t,
                fw_t=round(avg_forward_time * 1000) / 1000,
                data_t=round(avg_data_time * 1000) / 1000,
                epoch_t=round(avg_epoch_time * 1000) / 1000,
                eval_t=round(avg_eval_time * 1000) / 1000,
                total_t=round(total_time * 1000) / 1000,
                env_url=get_env_url(vis_p),
                **accs, **stats)
    return rescaled, t, metrics, b_state_dict, stats
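The LCA block above averages the first lca_n + 1 logged test accuracies, forward-filling the last measured value when training stopped before lca_n iterations. The same computation in isolation, on a toy accuracy dict keyed by iteration (the values are arbitrary):

import numpy as np

def lca_at_n(accs, n):
    # accs maps iteration index -> accuracy; returns the mean over steps 0..n.
    if not accs:
        return float('nan')
    accs = dict(accs)
    last_key = max(accs)
    assert len(accs) == last_key + 1, "can't compute LCA if steps were skipped"
    for i in range(last_key + 1, n + 1):
        accs[i] = accs[last_key]          # forward-fill the last measurement
    return np.mean([accs[i] for i in range(n + 1)])

print(lca_at_n({0: 0.1, 1: 0.3, 2: 0.5}, n=5))   # mean of [0.1, 0.3, 0.5, 0.5, 0.5, 0.5]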
Example #5
def _get_dataloaders(datasets, batch_sizes):
    return get_classic_dataloaders(datasets, batch_sizes)
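get_classic_dataloaders itself is not shown in these examples; from the call sites it takes the task's datasets plus per-split batch sizes and returns one training loader and a list of evaluation loaders. Below is a minimal stand-in built on torch.utils.data under that assumption; the real helper may differ, e.g. in shuffling, seeding or worker handling.

import torch
from torch.utils.data import DataLoader, TensorDataset

def classic_dataloaders(datasets, batch_sizes):
    # Assumed contract: the first dataset and batch size form the training
    # split; every dataset additionally gets its own evaluation loader.
    train_loader = DataLoader(datasets[0], batch_size=batch_sizes[0], shuffle=True)
    eval_loaders = [DataLoader(ds, batch_size=batch_sizes[1]) for ds in datasets]
    return train_loader, eval_loaders

toy = [TensorDataset(torch.randn(32, 3), torch.randint(0, 2, (32,))) for _ in range(3)]
train_loader, eval_loaders = classic_dataloaders(toy, batch_sizes=(8, 16))
print(len(eval_loaders), next(iter(train_loader))[0].shape)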