# Example 1 (0 votes)
def train_model(model, datasets_p, batch_sizes, optim_fact, prepare_batch,
                task, train_loader, eval_loaders, training_params, config):
    """Train `model`, delegating to its own `train_func` when it has one.

    Returns a 4-tuple ``(rescaled, t, metrics, b_state_dict)`` where
    `rescaled` is the metric used for model selection, `t` the number of
    training iterations, `metrics` the full metrics mapping and
    `b_state_dict` the best model state.
    """
    if hasattr(model, 'train_func'):
        # Models with a custom training entry point handle everything
        # themselves; an extra hyper-parameter `config` is unsupported there.
        assert not config, config
        t, metrics, b_state_dict = model.train_func(
            datasets_p=datasets_p,
            b_sizes=batch_sizes,
            optim_fact=optim_fact,
            # lr_scheduler=lr_scheduler,
            # viz=task_vis,
            prepare_batch=prepare_batch,
            split_names=task['split_names'],
            **training_params)
        # Selection metric: the first entry whose name mentions 'rescaled'
        # (raises IndexError if none is present, as before).
        rescaled = [v for k, v in metrics.items() if 'rescaled' in k][0]
    else:
        optim = optim_fact(model=model)
        if hasattr(model, 'train_loader_wrapper'):
            train_loader = model.train_loader_wrapper(train_loader)
        t, metrics, b_state_dict = train(model, train_loader, eval_loaders,
                                         optimizer=optim,
                                         # lr_scheduler=lr_scheduler,
                                         # viz=task_vis,
                                         prepare_batch=prepare_batch,
                                         split_names=task['split_names'],
                                         **training_params)
        rescaled = metrics['Val accuracy_0']

    return rescaled, t, metrics, b_state_dict
    def _train(self):
        """Run a full training pass and reshape metrics for Ray Tune.

        Stores the best model weights on ``self.best_state_dict`` and
        returns a Tune-compatible result dict: per-iteration metric
        dicts are unrolled into plain lists (their names collected in
        'unroll_columns'), scalar metrics are passed through, and the
        trial is marked done and checkpointable.
        """
        t, accs, self.best_state_dict = train(self.model,
                                              train_loader=self.train_loader,
                                              eval_loaders=self.eval_loaders,
                                              optimizer=self.optim,
                                              loss_fn=self.loss_fn,
                                              n_it_max=self.n_it_max,
                                              patience=self.patience,
                                              split_names=self.split_names,
                                              device=self.device,
                                              name=self.name,
                                              log_steps=self.log_steps,
                                              log_epoch=False)
        keys = []  # names of metrics that were unrolled into lists
        res = {}
        for k, v in accs.items():
            if isinstance(v, dict):
                # Every per-iteration series must be aligned on the same
                # iteration axis.
                # NOTE(review): this compares v's *keys* against the *values*
                # of accs['training_iterations'] — presumably that entry maps
                # positions to iteration numbers, so keys==values holds by
                # construction; confirm this asymmetry is intentional.
                assert list(v.keys()) == list(
                    accs['training_iterations'].values())
                res[k] = list(v.values())
                keys.append(k)
            else:
                res[k] = v

        return {'unroll_columns': keys,
                'should_checkpoint': True,
                'done': True,
                **res}
# Example 3 (0 votes)
def wrap(*args, idx=None, uid=None, optim_fact, datasets_p, b_sizes, **kwargs):
    """Build data loaders and an optimizer for ``kwargs['model']``, then train.

    `idx` and `uid` identify the evaluated option; they are accepted so
    callers can pass them but are currently only used by the disabled
    logging line below. Remaining args/kwargs are forwarded to ``train``.
    """
    net = kwargs['model']
    optimizer = optim_fact(model=net)

    loaded = _load_datasets(**datasets_p)
    tr_loader, ev_loaders = get_classic_dataloaders(loaded, b_sizes, 0)
    if hasattr(net, 'train_loader_wrapper'):
        tr_loader = net.train_loader_wrapper(tr_loader)

    # logger.warning('{}=Received option {} results'.format(uid, idx))
    return train(*args, train_loader=tr_loader, eval_loaders=ev_loaders,
                 optimizer=optimizer, **kwargs)
# Example 4 (0 votes)
    def train_model_on_task(self, task, task_viz, exp_dir, use_ray,
                            use_ray_logging, grace_period,
                            num_hp_samplings, local_mode,
                            redis_address, lca_n, **training_params):
        """Train this learner's model on a single task.

        Either launches a Ray Tune hyper-parameter search (``use_ray``) or
        performs one local training run with a randomly sampled config.

        Returns a ``stats`` dict containing timing ('duration'), parameter
        counts ('params') and the learning-curve average over the first
        ``lca_n + 1`` logged training accuracies ('lca').
        """
        logger.info("Training dashboard: {}".format(get_env_url(task_viz)))
        t_id = task['id']

        trainable = self.get_trainable(use_ray_logging=use_ray_logging)
        past_tasks = training_params.pop('past_tasks')
        normalize = training_params.pop('normalize')
        augment_data = training_params.pop('augment_data')

        # Data augmentation is applied to the training split only (index 0);
        # evaluation splits get empty transform lists.
        transformations = []
        if augment_data:
            transformations.extend([
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor()
            ])
        t_trans = [[] for _ in range(len(task['split_names']))]
        t_trans[0] = transformations
        datasets = trainable._load_datasets(task,
                                            task['loss_fn'],
                                            past_tasks, t_trans, normalize)
        train_loader, eval_loaders = get_classic_dataloaders(datasets,
                                                             training_params.pop(
                                                                 'batch_sizes'))
        model = self.get_model(task_id=t_id, x_dim=task['x_dim'],
                               n_classes=task['n_classes'],
                               descriptor=task['descriptor'],
                               dataset=eval_loaders[:2])

        if use_ray:
            if not ray.is_initialized():
                ray.init(address=redis_address)

            scheduler = None

            # Functions must be wrapped so Tune doesn't try to serialize
            # them as config values.
            training_params['loss_fn'] = tune.function(
                training_params['loss_fn'])
            training_params['optim_func'] = tune.function(self.optim_func)

            # Save the initial weights so every trial starts from the same
            # initialization; the file is removed after the search.
            init_model_path = os.path.join(exp_dir, 'model_initializations')
            model_file_name = '{}_init.pth'.format(training_params['name'])
            model_path = os.path.join(init_model_path, model_file_name)
            torch.save(model, model_path)

            training_params['model_path'] = model_path
            config = {**self.get_search_space(),
                      'training-params': training_params}
            if use_ray_logging:
                stop_condition = {'training_iteration':
                                      training_params['n_it_max']}
                checkpoint_at_end = False
                keep_checkpoints_num = 1
                checkpoint_score_attr = 'min-Val nll'
            else:
                stop_condition = None
                # loggers = [JsonLogger, MyCSVLogger]
                checkpoint_at_end = False
                keep_checkpoints_num = None
                checkpoint_score_attr = None

            # Rename the trainable per task to avoid collisions in Tune's
            # global registry.
            trainable = rename_class(trainable, training_params['name'])
            experiment = Experiment(
                name=training_params['name'],
                run=trainable,
                stop=stop_condition,
                config=config,
                resources_per_trial=self.ray_resources,
                num_samples=num_hp_samplings,
                local_dir=exp_dir,
                loggers=(JsonLogger, CSVLogger),
                checkpoint_at_end=checkpoint_at_end,
                keep_checkpoints_num=keep_checkpoints_num,
                checkpoint_score_attr=checkpoint_score_attr)

            analysis = tune.run(experiment,
                                scheduler=scheduler,
                                verbose=1,
                                raise_on_failed_trial=True,
                                # max_failures=-1,
                                # with_server=True,
                                # server_port=4321
                                )
            os.remove(model_path)
            logger.info("Training dashboard: {}".format(get_env_url(task_viz)))

            all_trials = {t.logdir: t for t in analysis.trials}
            best_logdir = analysis.get_best_logdir('Val nll', 'min')
            best_trial = all_trials[best_logdir]

            # picked_metric = 'accuracy_0'
            # metric_names = {s: '{} {}'.format(s, picked_metric) for s in
            #                 ['Train', 'Val', 'Test']}

            logger.info('Best trial: {}'.format(best_trial))
            best_res = best_trial.checkpoint.result
            best_point = (best_res['training_iteration'], best_res['Val nll'])

            # y_keys = ['mean_loss' if use_ray_logging else 'Val nll', 'train_loss']
            y_keys = ['Val nll', 'Train nll']

            epoch_key = 'training_epoch'
            it_key = 'training_iteration'
            plot_res_dataframe(analysis, training_params['name'], best_point,
                               task_viz, epoch_key, it_key, y_keys)
            if 'entropy' in next(iter(analysis.trial_dataframes.values())):
                plot_res_dataframe(analysis, training_params['name'], None,
                                   task_viz, epoch_key, it_key, ['entropy'])
            best_model = self.get_model(task_id=t_id)
            best_model.load_state_dict(torch.load(best_trial.checkpoint.value))

            train_accs = analysis.trial_dataframes[best_logdir]['Train accuracy_0']
            best_t = best_res['training_iteration']
            t = best_trial.last_result['training_iteration']
        else:
            # Local single-run path: sample one configuration at random.
            search_space = self.get_search_space()
            rand_config = list(generate_variants(search_space))[0][1]
            learner_params = rand_config.pop('learner-params', {})
            optim_params = rand_config.pop('optim')


            split_optims = training_params.pop('split_optims')
            if hasattr(model, 'set_h_params'):
                model.set_h_params(**learner_params)
            if hasattr(model, 'train_loader_wrapper'):
                train_loader = model.train_loader_wrapper(train_loader)

            loss_fn = task['loss_fn']
            if hasattr(model, 'loss_wrapper'):
                loss_fn = model.loss_wrapper(task['loss_fn'])

            prepare_batch = _prepare_batch
            if hasattr(model, 'prepare_batch_wrapper'):
                prepare_batch = model.prepare_batch_wrapper(prepare_batch, t_id)

            optim_fact = partial(set_optim_params,
                                 optim_func=self.optim_func,
                                 optim_params=optim_params,
                                 split_optims=split_optims)
            # Models may expose their own training entry point; otherwise
            # fall back to the generic train() loop.
            if hasattr(model, 'train_func'):
                f = model.train_func
                t, metrics, b_state_dict = f(train_loader=train_loader,
                                             eval_loaders=eval_loaders,
                                             optim_fact=optim_fact,
                                             loss_fn=loss_fn,
                                             split_names=task['split_names'],
                                             viz=task_viz,
                                             prepare_batch=prepare_batch,
                                             **training_params)
            else:
                optim = optim_fact(model=model)
                t, metrics, b_state_dict = train(model=model,
                                                 train_loader=train_loader,
                                                 eval_loaders=eval_loaders,
                                                 optimizer=optim,
                                                 loss_fn=loss_fn,
                                                 split_names=task['split_names'],
                                                 viz=task_viz,
                                                 prepare_batch=prepare_batch,
                                                 **training_params)
            train_accs = metrics['Train accuracy_0']
            best_t = b_state_dict['iter']
            if 'training_archs' in metrics:
                plot_trajectory(model.ssn.graph, metrics['training_archs'],
                                model.ssn.stochastic_node_ids, task_viz)
                weights = model.arch_sampler().squeeze()
                archs = model.ssn.get_top_archs(weights, 5)
                list_top_archs(archs, task_viz)
                list_arch_scores(self.arch_scores[t_id], task_viz)
                update_summary(self.arch_scores[t_id], task_viz, 'scores')

        # Learning-curve average over the first lca_n+1 logged steps; only
        # computable when more than lca_n accuracies were recorded.
        if len(train_accs) > lca_n:
            lca_accs = []
            for i in range(lca_n + 1):
                if i in train_accs:
                    lca_accs.append(train_accs[i])
                else:
                    logger.warning('Missing step for {}/{} for lca computation'
                                   .format(i, lca_n))
            lca = np.mean(lca_accs)
        else:
            # FIX: np.float was deprecated in NumPy 1.20 and removed in 1.24;
            # it was merely an alias for the builtin float anyway.
            lca = float('nan')
        stats = {}
        start = time.time()
        # train_idx = task['split_names'].index('Train')
        # train_path = task['data_path'][train_idx]
        # train_dataset = _load_datasets([train_path])[0]
        train_dataset = _load_datasets(task, 'Train')[0]
        stats.update(self.finish_task(train_dataset, t_id, task_viz,
                                      path='drawings'))
        stats['duration'] = {'iterations': t,
                             'finish': time.time() - start,
                             'best_iterations': best_t}
        stats['params'] = {'total': self.n_params(t_id),
                           'new': self.new_params(t_id)}
        stats['lca'] = lca
        return stats
# Example 5 (0 votes)
    def train_model_on_task(self, task, task_viz, exp_dir, use_ray,
                            use_ray_logging, smoke_test, n_it_max, grace_period,
                            num_hp_samplings, local_mode, tune_register_lock,
                            resources, **training_params):
        """Train this learner's model on `task`.

        With ``use_ray`` a Ray Tune hyper-parameter search is run and the
        best checkpoint's weights are loaded back into a fresh model;
        otherwise a single local training run is performed.

        Returns the number of training iterations of the (best) run.
        """
        logger.info("Training dashboard: {}".format(get_env_url(task_viz)))

        model = self.get_model(task_id=task.id)
        trainable = self.get_trainable(use_ray_logging=use_ray_logging)

        self.prepare_task(task, training_params)

        if use_ray:
            # Required to avoid collisions in Tune's global Registry:
            # https://github.com/ray-project/ray/blob/master/python/ray/tune/registry.py
            trainable = rename_class(trainable, training_params['name'])

            scheduler = None


            # Wrap callables so Tune does not try to serialize them as
            # plain config values.
            training_params['loss_fn'] = tune.function(
                training_params['loss_fn'])
            training_params['optim_func'] = tune.function(self.optim_func)
            training_params['n_it_max'] = n_it_max

            # Save the initial weights so every trial starts from the same
            # initialization; the file is deleted after the search.
            init_model_path = os.path.join(exp_dir, 'model_initializations')
            model_file_name = '{}_init.pth'.format(training_params['name'])
            model_path = os.path.join(init_model_path, model_file_name)
            torch.save(model, model_path)

            training_params['model_path'] = model_path
            config = {'hyper-params': self.get_search_space(smoke_test),
                      'tp': training_params}
            if use_ray_logging:
                stop_condition = {'training_iteration': n_it_max}
                loggers = None
            else:
                stop_condition = None
                loggers = [JsonLogger, MyCSVLogger]

            # We need to create the experiment using a lock here to avoid issues
            # with Tune's global registry, more specifically with the
            # `_to_flush` dict that may change during the iteration over it.
            # https://github.com/ray-project/ray/blob/e3c9f7e83a6007ded7ae7e99fcbe9fcaa371bad3/python/ray/tune/registry.py#L91-L93
            tune_register_lock.acquire()
            experiment = Experiment(
                name=training_params['name'],
                run=trainable,
                stop=stop_condition,
                config=config,
                resources_per_trial=resources,
                num_samples=num_hp_samplings,
                local_dir=exp_dir,
                loggers=loggers,
                keep_checkpoints_num=1,
                checkpoint_score_attr='min-mean_loss')
            tune_register_lock.release()

            analysis = tune.run(experiment,
                                scheduler=scheduler,
                                verbose=1,
                                raise_on_failed_trial=True,
                                # max_failures=-1,
                                # with_server=True,
                                # server_port=4321
                                )
            os.remove(model_path)
            logger.info("Training dashboard: {}".format(get_env_url(task_viz)))

            all_trials = {t.logdir: t for t in analysis.trials}
            best_logdir = analysis.get_best_logdir('mean_loss', 'min')
            best_trial = all_trials[best_logdir]

            # picked_metric = 'accuracy_0'
            # metric_names = {s: '{} {}'.format(s, picked_metric) for s in
            #                 ['Train', 'Val', 'Test']}

            logger.info('Best trial: {}'.format(best_trial))
            # NOTE(review): `_checkpoint` is a private Trial attribute and has
            # been renamed/removed in newer Ray releases — pin the Ray version
            # or migrate to the public checkpoint API.
            best_res = best_trial._checkpoint.last_result
            best_point = (best_res['training_iteration'], best_res['mean_loss'])

            y_keys = ['mean_loss' if use_ray_logging else 'Val nll', 'train_loss']
            epoch_key = 'training_epoch'
            it_key = 'training_iteration' if use_ray_logging else 'training_iterations'
            plot_res_dataframe(analysis, training_params['name'], best_point,
                               task_viz, epoch_key, it_key, y_keys)
            # Reload the best trial's weights into a fresh model instance.
            best_model = self.get_model(task_id=task.id)
            best_model.load_state_dict(torch.load(best_trial._checkpoint.value))

            t = best_trial._checkpoint.last_result['training_iteration']
        else:
            # Local single-run path: build loaders and train directly.
            data_path = training_params.pop('data_path')
            past_tasks = training_params.pop('past_tasks')
            datasets = trainable._load_datasets(data_path,
                                                training_params['loss_fn'],
                                                past_tasks)
            train_loader, eval_loaders = get_classic_dataloaders(datasets,
                                                                 training_params.pop('batch_sizes'))
            optim = self.optim_func(model.parameters())

            t, accs, best_state_dict = train(model, train_loader, eval_loaders,
                                             optimizer=optim, viz=task_viz,
                                             n_it_max=n_it_max, **training_params)
        logger.info('Finishing task ...')
        t1 = time.time()
        self.finish_task(task.datasets[0])
        logger.info('done in {}s'.format(time.time() - t1))

        return t