Example #1
    def fit(self,
            train_data,
            valid_data=None,
            verbose=True,
            saved=True,
            show_progress=False):
        # load model
        if self.boost_model is not None:
            self.model.load_model(self.boost_model)

        self.best_valid_score = 0.
        self.best_valid_result = 0.

        for epoch_idx in range(self.epochs):
            self._train_at_once(train_data, valid_data)

            if (epoch_idx + 1) % self.eval_step == 0:
                # evaluate
                valid_start_time = time()
                valid_score, valid_result = self._valid_epoch(valid_data)
                valid_end_time = time()
                valid_score_output = (set_color("epoch %d evaluating", 'green') + " [" + set_color("time", 'blue')
                                    + ": %.2fs, " + set_color("valid_score", 'blue') + ": %f]") % \
                                     (epoch_idx, valid_end_time - valid_start_time, valid_score)
                valid_result_output = set_color(
                    'valid result', 'blue') + ': \n' + dict2str(valid_result)
                if verbose:
                    self.logger.info(valid_score_output)
                    self.logger.info(valid_result_output)

                self.best_valid_score = valid_score
                self.best_valid_result = valid_result

        return self.best_valid_score, self.best_valid_result
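
A minimal usage sketch for a fit loop like this one, assuming the usual RecBole objects (config, model, train_data, valid_data) have already been built as in Example #18; the variable names are placeholders.

# Hypothetical usage: obtain the trainer for the configured model and run fit().
from recbole.utils import get_trainer

trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)
best_valid_score, best_valid_result = trainer.fit(train_data, valid_data, verbose=True)
print(best_valid_score, best_valid_result)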
Example #2
def run_recbole(model=None, dataset=None, config_file_list=None, config_dict=None, saved=True):
    r""" A fast running api, which includes the complete process of
    training and testing a model on a specified dataset

    Args:
        model (str): model name
        dataset (str): dataset name
        config_file_list (list): config files used to modify experiment parameters
        config_dict (dict): parameters dictionary used to modify experiment parameters
        saved (bool): whether to save the model
    """
    # configurations initialization
    config = Config(model=model, dataset=dataset, config_file_list=config_file_list, config_dict=config_dict)
    init_seed(config['seed'], config['reproducibility'])

    # logger initialization
    init_logger(config)
    logger = getLogger()

    logger.info(config)

    # dataset filtering
    dataset = create_dataset(config)
    logger.info(dataset)

    # dataset splitting
    train_data, valid_data, test_data = data_preparation(config, dataset)
    print(train_data.dataset.item_feat)
    print(valid_data.dataset.item_feat)

    # model loading and initialization
    model = get_model(config['model'])(config, train_data).to(config['device'])
    logger.info(model)

    # trainer loading and initialization
    trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)

    # model training
    with profiler.profile(enabled=config["monitor"], with_stack=True, profile_memory=True, use_cuda=True) as prof:
        best_valid_score, best_valid_result = trainer.fit(
            train_data, valid_data, saved=saved, show_progress=config['show_progress']
        )
    if prof is not None:
        print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total'))

    # model evaluation
    with profiler.profile(enabled=config["monitor_eval"], with_stack=True, profile_memory=True, use_cuda=True) as prof:
        test_result = trainer.evaluate(test_data, load_best_model=saved, show_progress=config['show_progress'], cold_warm_distinct_eval=True)
    if prof is not None:
        print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total'))

    logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
    logger.info(set_color('test result', 'yellow') + f': {test_result}')

    return {
        'best_valid_score': best_valid_score,
        'valid_score_bigger': config['valid_metric_bigger'],
        'best_valid_result': best_valid_result,
        'test_result': test_result
    }
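
A hedged sketch of calling this profiled variant; 'monitor' and 'monitor_eval' are custom config keys read only by this version of run_recbole, so they are supplied through config_dict (values are illustrative).

# Hypothetical call; 'monitor' / 'monitor_eval' toggle the torch profiler blocks above.
result = run_recbole(
    model='BPR',
    dataset='ml-100k',
    config_dict={'monitor': True, 'monitor_eval': False, 'epochs': 10},
)
print(result['best_valid_score'], result['test_result'])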
Example #3
 def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses):
     des = self.config['loss_decimal_place'] or 4
     train_loss_output = (set_color('epoch %d training', 'green') + ' [' + set_color('time', 'blue') +
                          ': %.2fs, ') % (epoch_idx, e_time - s_time)
     if isinstance(losses, tuple):
         des = (set_color('train_loss%d', 'blue') + ': %.' + str(des) + 'f')
         train_loss_output += ', '.join(des % (idx + 1, loss) for idx, loss in enumerate(losses))
     else:
         des = '%.' + str(des) + 'f'
         train_loss_output += set_color('train loss', 'blue') + ': ' + des % losses
     return train_loss_output + ']'
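
For reference, a standalone sketch of the message this helper produces, ignoring the ANSI colors added by set_color; the values are illustrative.

# Plain-text approximation of the generated message, without set_color's ANSI codes.
epoch_idx, s_time, e_time = 3, 100.0, 112.34
losses = (0.5678, 0.1234)  # a tuple of per-part losses
msg = 'epoch %d training [time: %.2fs, ' % (epoch_idx, e_time - s_time)
msg += ', '.join('train_loss%d: %.4f' % (i + 1, loss) for i, loss in enumerate(losses))
print(msg + ']')  # epoch 3 training [time: 12.34s, train_loss1: 0.5678, train_loss2: 0.1234]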
Example #4
    def _get_field_from_config(self):
        super()._get_field_from_config()

        self.source_field = self.config['SOURCE_ID_FIELD']
        self.target_field = self.config['TARGET_ID_FIELD']
        self._check_field('source_field', 'target_field')

        self.logger.debug(
            set_color('source_id_field', 'blue') + f': {self.source_field}')
        self.logger.debug(
            set_color('target_id_field', 'blue') + f': {self.target_field}')
Example #5
    def _get_field_from_config(self):
        super()._get_field_from_config()

        self.head_entity_field = self.config['HEAD_ENTITY_ID_FIELD']
        self.tail_entity_field = self.config['TAIL_ENTITY_ID_FIELD']
        self.relation_field = self.config['RELATION_ID_FIELD']
        self.entity_field = self.config['ENTITY_ID_FIELD']
        self._check_field('head_entity_field', 'tail_entity_field', 'relation_field', 'entity_field')
        self.set_field_property(self.entity_field, FeatureType.TOKEN, FeatureSource.KG, 1)

        self.logger.debug(set_color('relation_field', 'blue') + f': {self.relation_field}')
        self.logger.debug(set_color('entity_field', 'blue') + f': {self.entity_field}')
Example #6
def run_recbole(model=None, dataset=None, config_file_list=None, config_dict=None, saved=True):
    r""" A fast running api, which includes the complete process of
    training and testing a model on a specified dataset

    Args:
        model (str): model name
        dataset (str): dataset name
        config_file_list (list): config files used to modify experiment parameters
        config_dict (dict): parameters dictionary used to modify experiment parameters
        saved (bool): whether to save the model
    """
    # configurations initialization
    config = Config(model=model, dataset=dataset, config_file_list=config_file_list, config_dict=config_dict)
    init_seed(config['seed'], config['reproducibility'])

    # logger initialization
    init_logger(config)
    logger = getLogger()

    logger.info(config)

    # dataset filtering
    dataset = create_dataset(config)
    logger.info(dataset)

    # dataset splitting
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # model loading and initialization
    model = get_model(config['model'])(config, train_data).to(config['device'])
    logger.info(model)

    # trainer loading and initialization
    trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)

    # model training
    best_valid_score, best_valid_result = trainer.fit(
        train_data, valid_data, saved=saved, show_progress=config['show_progress']
    )

    # model evaluation
    test_result = trainer.evaluate(test_data, load_best_model=saved, show_progress=config['show_progress'])

    logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
    logger.info(set_color('test result', 'yellow') + f': {test_result}')

    return {
        'best_valid_score': best_valid_score,
        'valid_score_bigger': config['valid_metric_bigger'],
        'best_valid_result': best_valid_result,
        'test_result': test_result
    }
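
Quick-start usage of run_recbole as defined above; BPR and ml-100k are standard RecBole names used here only as an example, and the import path assumes the function lives in recbole.quick_start as in stock RecBole.

from recbole.quick_start import run_recbole

result = run_recbole(model='BPR', dataset='ml-100k', config_dict={'epochs': 50})
print(result['best_valid_result'])
print(result['test_result'])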
Example #7
    def pretrain(self, train_data, verbose=True, show_progress=False):

        for epoch_idx in range(self.start_epoch, self.epochs):
            # train
            training_start_time = time()
            train_loss = self._train_epoch(train_data,
                                           epoch_idx,
                                           show_progress=show_progress)
            self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(
                train_loss, tuple) else train_loss
            training_end_time = time()
            train_loss_output = \
                self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
            if verbose:
                self.logger.info(train_loss_output)

            if (epoch_idx + 1) % self.config['save_step'] == 0:
                saved_model_file = os.path.join(
                    self.checkpoint_dir,
                    '{}-{}-{}.pth'.format(self.config['model'],
                                          self.config['dataset'],
                                          str(epoch_idx + 1)))
                self.save_pretrained_model(epoch_idx, saved_model_file)
                update_output = set_color('Saving current',
                                          'blue') + ': %s' % saved_model_file
                if verbose:
                    self.logger.info(update_output)

        return self.best_valid_score, self.best_valid_result
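
A sketch of how this pretraining loop might be driven, assuming `trainer` is an instance of the class defining pretrain() above and that config['save_step'] is set; every save_step epochs a checkpoint named <model>-<dataset>-<epoch>.pth is written to the checkpoint directory.

# Hypothetical pretraining call (trainer and train_data built elsewhere).
trainer.pretrain(train_data, verbose=True, show_progress=False)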
Example #8
    def _train_epoch(
        self, train_data, epoch_idx, n_epochs, optimizer, encoder_flag, loss_func=None, show_progress=False
    ):
        self.model.train()
        loss_func = loss_func or self.model.calculate_loss
        total_loss = None
        for epoch in range(n_epochs):
            # rebuild the iterator on every pass; a single enumerate/tqdm object
            # would be exhausted after the first of the n_epochs inner passes
            iter_data = (
                tqdm(
                    enumerate(train_data),
                    total=len(train_data),
                    desc=set_color(f"Train {epoch_idx:>5}", 'pink'),
                ) if show_progress else enumerate(train_data)
            )
            for batch_idx, interaction in iter_data:
                interaction = interaction.to(self.device)
                optimizer.zero_grad()
                losses = loss_func(interaction, encoder_flag=encoder_flag)
                if isinstance(losses, tuple):
                    loss = sum(losses)
                    loss_tuple = tuple(per_loss.item() for per_loss in losses)
                    total_loss = loss_tuple if total_loss is None else tuple(map(sum, zip(total_loss, loss_tuple)))
                else:
                    loss = losses
                    total_loss = losses.item() if total_loss is None else total_loss + losses.item()
                self._check_nan(loss)
                loss.backward()
                if self.clip_grad_norm:
                    clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
                optimizer.step()

        return total_loss
Example #9
 def __str__(self):
     """
     Model prints with number of trainable parameters
     """
     model_parameters = filter(lambda p: p.requires_grad, self.parameters())
     params = sum([np.prod(p.size()) for p in model_parameters])
     return super().__str__() + set_color('\nTrainable parameters',
                                          'blue') + f': {params}'
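
The parameter count used in this __str__ can be reproduced for any torch module; a minimal standalone sketch:

import numpy as np
import torch.nn as nn

def count_trainable_parameters(module: nn.Module) -> int:
    # Sum of element counts over all parameters that require gradients.
    return int(sum(np.prod(p.size()) for p in module.parameters() if p.requires_grad))

print(count_trainable_parameters(nn.Linear(10, 5)))  # 55 = 10 * 5 weights + 5 biases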
Example #10
 def _load_kg(self, token, dataset_path):
     self.logger.debug(set_color(f'Loading kg from [{dataset_path}].', 'green'))
     kg_path = os.path.join(dataset_path, f'{token}.kg')
     if not os.path.isfile(kg_path):
         raise ValueError(f'[{token}.kg] not found in [{dataset_path}].')
     df = self._load_feat(kg_path, FeatureSource.KG)
     self._check_kg(df)
     return df
Example #11
    def evaluate(self, eval_data, load_best_model=True, model_file=None, show_progress=False):
        r"""Evaluate the model based on the eval data.

        Args:
            eval_data (DataLoader): the eval data
            load_best_model (bool, optional): whether to load the best model saved during training, default: True.
                                              It should be set to True if users want to test the model after training.
            model_file (str, optional): the saved model file, default: None. If users want to test the previously
                                        trained model file, they can set this parameter.
            show_progress (bool): Show the progress of evaluate epoch. Defaults to ``False``.

        Returns:
            dict: eval result, where the key is the eval metric and the value is the corresponding metric value.
        """
        if not eval_data:
            return

        if load_best_model:
            if model_file:
                checkpoint_file = model_file
            else:
                checkpoint_file = self.saved_model_file
            checkpoint = torch.load(checkpoint_file)
            self.model.load_state_dict(checkpoint['state_dict'])
            message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file)
            self.logger.info(message_output)

        self.model.eval()

        if eval_data.dl_type == DataLoaderType.FULL:
            if self.item_tensor is None:
                self.item_tensor = eval_data.get_item_feature().to(self.device).repeat(eval_data.step)
            self.tot_item_num = eval_data.dataset.item_num

        batch_matrix_list = []
        iter_data = (
            tqdm(
                enumerate(eval_data),
                total=len(eval_data),
                desc=set_color(f"Evaluate   ", 'pink'),
            ) if show_progress else enumerate(eval_data)
        )
        for batch_idx, batched_data in iter_data:
            if eval_data.dl_type == DataLoaderType.FULL:
                interaction, scores = self._full_sort_batch_eval(batched_data)
            else:
                interaction = batched_data
                batch_size = interaction.length
                if batch_size <= self.test_batch_size:
                    scores = self.model.predict(interaction.to(self.device))
                else:
                    scores = self._spilt_predict(interaction, batch_size)

            batch_matrix = self.evaluator.collect(interaction, scores)
            batch_matrix_list.append(batch_matrix)
        result = self.evaluator.evaluate(batch_matrix_list, eval_data)

        return result
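
Typical calls to this evaluate method, either reloading the best checkpoint produced by fit() in the current run or pointing at a previously saved model file (the .pth path below is a placeholder):

# Re-load the best checkpoint saved during training and evaluate on the test set.
test_result = trainer.evaluate(test_data, load_best_model=True)

# Or evaluate a previously trained model file (placeholder path).
test_result = trainer.evaluate(test_data, load_best_model=True, model_file='saved/BPR-ml-100k.pth')
print(test_result)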
Example #12
    def _get_ent_fields_in_same_space(self):
        """Return ``field_set`` that should be remapped together with entities.
        """
        fields_in_same_space = super()._get_fields_in_same_space()

        ent_fields = {self.head_entity_field, self.tail_entity_field}
        for field_set in fields_in_same_space:
            if self._contain_ent_field(field_set):
                field_set = self._remove_ent_field(field_set)
                ent_fields.update(field_set)
        self.logger.debug(set_color('ent_fields', 'blue') + f': {ent_fields}')
        return ent_fields
Example #13
    def _load_link(self, token, dataset_path):
        self.logger.debug(set_color(f'Loading link from [{dataset_path}].', 'green'))
        link_path = os.path.join(dataset_path, f'{token}.link')
        if not os.path.isfile(link_path):
            raise ValueError(f'[{token}.link] not found in [{dataset_path}].')
        df = self._load_feat(link_path, 'link')
        self._check_link(df)

        item2entity, entity2item = {}, {}
        for item_id, entity_id in zip(df[self.iid_field].values, df[self.entity_field].values):
            item2entity[item_id] = entity_id
            entity2item[entity_id] = item_id
        return item2entity, entity2item
Example #14
    def _train_epoch(self,
                     train_data,
                     epoch_idx,
                     loss_func=None,
                     show_progress=False):
        r"""Train the model in an epoch

        Args:
            train_data (DataLoader): The train data.
            epoch_idx (int): The current epoch id.
            loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be
                :attr:`self.model.calculate_loss`. Defaults to ``None``.
            show_progress (bool): Show the progress of training epoch. Defaults to ``False``.

        Returns:
            float/tuple: The sum of loss returned by all batches in this epoch. If the loss of each batch contains
            multiple parts and the model returns these parts instead of their sum, this method returns a
            tuple containing the per-part sums over the epoch.
        """
        self.model.train()
        loss_func = loss_func or self.model.calculate_loss
        total_loss = None
        iter_data = (tqdm(
            enumerate(train_data),
            total=len(train_data),
            desc=set_color(f"Train {epoch_idx:>5}", 'pink'),
        ) if show_progress else enumerate(train_data))
        for batch_idx, interaction in iter_data:
            interaction = interaction.to(self.device)
            self.optimizer.zero_grad()
            losses = loss_func(interaction)

            if isinstance(losses, tuple):
                loss = sum(losses)
                loss_tuple = tuple(per_loss.item() for per_loss in losses)
                total_loss = loss_tuple if total_loss is None else tuple(
                    map(sum, zip(total_loss, loss_tuple)))
            else:
                loss = losses
                total_loss = losses.item(
                ) if total_loss is None else total_loss + losses.item()
            self._check_nan(loss)
            loss.backward()
            if self.clip_grad_norm:
                clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
            self.optimizer.step()
        return total_loss
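
The tuple branch above accumulates per-part losses element-wise across batches; a minimal sketch of that accumulation with plain floats:

# Element-wise accumulation of per-part losses across batches (illustrative values).
batches = [(0.5, 0.1), (0.4, 0.2), (0.3, 0.3)]
total_loss = None
for loss_tuple in batches:
    total_loss = loss_tuple if total_loss is None else tuple(map(sum, zip(total_loss, loss_tuple)))
print(total_loss)  # approximately (1.2, 0.6)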
Example #15
def save_split_dataloaders(config, dataloaders):
    """Save split dataloaders.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataloaders (tuple of AbstractDataLoader): The split dataloaders.
    """
    save_path = config['checkpoint_dir']
    saved_dataloaders_file = f'{config["dataset"]}-for-{config["model"]}-dataloader.pth'
    file_path = os.path.join(save_path, saved_dataloaders_file)
    logger = getLogger()
    logger.info(
        set_color('Saved split dataloaders', 'blue') + f': {file_path}')
    with open(file_path, 'wb') as f:
        pickle.dump(dataloaders, f)
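
The saved file is a plain pickle, so the split dataloaders can be restored with the mirror-image load; a sketch assuming the same naming scheme as above:

import os
import pickle

def load_split_dataloaders(config):
    # Mirror-image of save_split_dataloaders above (hypothetical helper).
    file_path = os.path.join(config['checkpoint_dir'], f'{config["dataset"]}-for-{config["model"]}-dataloader.pth')
    with open(file_path, 'rb') as f:
        train_data, valid_data, test_data = pickle.load(f)
    return train_data, valid_data, test_data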
Example #16
    def __str__(self):
        info = [set_color('Evaluation Setting:', 'pink')]

        if self.group_field:
            info.append(set_color('Group by', 'blue') + f' {self.group_field}')
        else:
            info.append(set_color('No Grouping', 'yellow'))

        if self.ordering_args is not None and self.ordering_args['strategy'] != 'none':
            info.append(set_color('Ordering', 'blue') + f': {self.ordering_args}')
        else:
            info.append(set_color('No Ordering', 'yellow'))

        if self.split_args is not None and self.split_args['strategy'] != 'none':
            info.append(set_color('Splitting', 'blue') + f': {self.split_args}')
        else:
            info.append(set_color('No Splitting', 'yellow'))

        if self.neg_sample_args is not None and self.neg_sample_args['strategy'] != 'none':
            info.append(set_color('Negative Sampling', 'blue') + f': {self.neg_sample_args}')
        else:
            info.append(set_color('No Negative Sampling', 'yellow'))

        return '\n\t'.join(info)
Example #17
    def __str__(self):
        args_info = '\n'
        for category in self.parameters:
            args_info += set_color(category + ' Hyper Parameters:\n', 'pink')
            args_info += '\n'.join([(set_color("{}", 'cyan') + " =" + set_color(" {}", 'yellow')).format(arg, value)
                                    for arg, value in self.final_config_dict.items()
                                    if arg in self.parameters[category]])
            args_info += '\n\n'

        args_info += set_color('Other Hyper Parameters: \n', 'pink')
        args_info += '\n'.join([
            (set_color("{}", 'cyan') + " = " + set_color("{}", 'yellow')).format(arg, value)
            for arg, value in self.final_config_dict.items()
            if arg not in {
                _ for args in self.parameters.values() for _ in args
            }.union({'model', 'dataset', 'config_files'})
        ])
        args_info += '\n\n'
        return args_info
Example #18
def data_preparation(config, dataset, save=False):
    """Split the dataset by :attr:`config['eval_setting']` and call :func:`dataloader_construct` to create
    corresponding dataloader.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        save (bool, optional): If ``True``, it will call :func:`save_split_dataloaders` to save the split dataloaders.
            Defaults to ``False``.

    Returns:
        tuple:
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    model_type = config['MODEL_TYPE']

    es = EvalSetting(config)

    built_datasets = dataset.build(es)
    train_dataset, valid_dataset, test_dataset = built_datasets
    phases = ['train', 'valid', 'test']
    sampler = None
    logger = getLogger()
    train_neg_sample_args = config['train_neg_sample_args']
    eval_neg_sample_args = es.neg_sample_args

    # Training
    train_kwargs = {
        'config': config,
        'dataset': train_dataset,
        'batch_size': config['train_batch_size'],
        'dl_format': config['MODEL_INPUT_TYPE'],
        'shuffle': True,
    }
    if train_neg_sample_args['strategy'] != 'none':
        if dataset.label_field in dataset.inter_feat:
            raise ValueError(
                f'`training_neg_sample_num` should be 0 '
                f'if inter_feat have label_field [{dataset.label_field}].')
        if model_type != ModelType.SEQUENTIAL:
            sampler = Sampler(phases, built_datasets,
                              train_neg_sample_args['distribution'])
        else:
            sampler = RepeatableSampler(phases, dataset,
                                        train_neg_sample_args['distribution'])
        train_kwargs['sampler'] = sampler.set_phase('train')
        train_kwargs['neg_sample_args'] = train_neg_sample_args
        if model_type == ModelType.KNOWLEDGE:
            kg_sampler = KGSampler(dataset,
                                   train_neg_sample_args['distribution'])
            train_kwargs['kg_sampler'] = kg_sampler

    dataloader = get_data_loader('train', config, train_neg_sample_args)
    logger.info(
        set_color('Build', 'pink') +
        set_color(f' [{dataloader.__name__}]', 'yellow') + ' for ' +
        set_color('[train]', 'yellow') + ' with format ' +
        set_color(f'[{train_kwargs["dl_format"]}]', 'yellow'))
    if train_neg_sample_args['strategy'] != 'none':
        logger.info(
            set_color('[train]', 'pink') +
            set_color(' Negative Sampling', 'blue') +
            f': {train_neg_sample_args}')
    else:
        logger.info(
            set_color('[train]', 'pink') +
            set_color(' No Negative Sampling', 'yellow'))
    logger.info(
        set_color('[train]', 'pink') + set_color(' batch_size', 'cyan') +
        ' = ' + set_color(f'[{train_kwargs["batch_size"]}]', 'yellow') + ', ' +
        set_color('shuffle', 'cyan') + ' = ' +
        set_color(f'[{train_kwargs["shuffle"]}]\n', 'yellow'))
    train_data = dataloader(**train_kwargs)

    # Evaluation
    eval_kwargs = {
        'config': config,
        'batch_size': config['eval_batch_size'],
        'dl_format': InputType.POINTWISE,
        'shuffle': False,
    }
    valid_kwargs = {'dataset': valid_dataset}
    test_kwargs = {'dataset': test_dataset}
    if eval_neg_sample_args['strategy'] != 'none':
        if dataset.label_field in dataset.inter_feat:
            raise ValueError(
                f'It can not validate with `{es.es_str[1]}` '
                f'when inter_feat have label_field [{dataset.label_field}].')
        if sampler is None:
            if model_type != ModelType.SEQUENTIAL:
                sampler = Sampler(phases, built_datasets,
                                  eval_neg_sample_args['distribution'])
            else:
                sampler = RepeatableSampler(
                    phases, dataset, eval_neg_sample_args['distribution'])
        else:
            sampler.set_distribution(eval_neg_sample_args['distribution'])
        eval_kwargs['neg_sample_args'] = eval_neg_sample_args
        valid_kwargs['sampler'] = sampler.set_phase('valid')
        test_kwargs['sampler'] = sampler.set_phase('test')
    valid_kwargs.update(eval_kwargs)
    test_kwargs.update(eval_kwargs)

    dataloader = get_data_loader('evaluation', config, eval_neg_sample_args)
    logger.info(
        set_color('Build', 'pink') +
        set_color(f' [{dataloader.__name__}]', 'yellow') + ' for ' +
        set_color('[evaluation]', 'yellow') + ' with format ' +
        set_color(f'[{eval_kwargs["dl_format"]}]', 'yellow'))
    logger.info(es)
    logger.info(
        set_color('[evaluation]', 'pink') + set_color(' batch_size', 'cyan') +
        ' = ' + set_color(f'[{eval_kwargs["batch_size"]}]', 'yellow') + ', ' +
        set_color('shuffle', 'cyan') + ' = ' +
        set_color(f'[{eval_kwargs["shuffle"]}]\n', 'yellow'))

    valid_data = dataloader(**valid_kwargs)
    test_data = dataloader(**test_kwargs)

    if save:
        save_split_dataloaders(config,
                               dataloaders=(train_data, valid_data, test_data))

    return train_data, valid_data, test_data
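
End-to-end usage of data_preparation in a typical RecBole script; imports follow stock RecBole, and the model/dataset names are placeholders.

from recbole.config import Config
from recbole.data import create_dataset, data_preparation
from recbole.utils import init_seed

config = Config(model='BPR', dataset='ml-100k')
init_seed(config['seed'], config['reproducibility'])
dataset = create_dataset(config)
train_data, valid_data, test_data = data_preparation(config, dataset, save=True)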
Example #19
def run_trial(model_name, dataset_name, hp_config=None, save_flag=False):

    if not hp_config:
        hp_config = {}
        tuning = False
    else:
        tuning = True

    commons.init_seeds()
    verbose = not tuning
    model_class = statics.model_name_map[model_name]
    try:
        default_config = dict(model_class.default_params)  # copy to avoid mutating the shared class attribute
    except AttributeError:
        default_config = {}
        assert model_name in statics.recbole_models

    default_config.update(statics.datasets_params[dataset_name])
    default_config.update(hp_config)

    config = Config(model=model_class, dataset=dataset_name, config_dict=default_config)
    init_seed(config['seed'], config['reproducibility'])

    init_logger(config)
    logger = logging.getLogger()

    # logger initialization
    if verbose:
        logger.info(config)

    # dataset filtering
    dataset = create_dataset(config)
    train_data, valid_data, test_data = data_preparation(config, dataset)
    train_data = add_graph(train_data)

    if verbose:
        logger.info(dataset)

    model = model_class(config, train_data).to(commons.device)
    trainer = utils.get_trainer(config)(config, model)

    best_valid_score, best_valid_result = trainer.fit(train_data, valid_data, verbose=verbose, show_progress=verbose)
    test_result = trainer.evaluate(test_data)

    if verbose:
        logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
        logger.info(set_color('test result', 'yellow') + f': {test_result}')

    metric = config['valid_metric'].lower()

    if save_flag:
        os.makedirs(os.path.join("bestmodels", dataset_name, str(config["topk"])), exist_ok=True)
        save_path = os.path.join("bestmodels", dataset_name, str(config["topk"]), "{}.pth".format(model_name))
    else:
        save_path = None

    if save_path:
        shutil.copyfile(trainer.saved_model_file, save_path)

    return {
        'metric' : config['valid_metric'],
        'best_valid_score': best_valid_score,
        'valid_score_bigger': config['valid_metric_bigger'],
        'best_valid_result': best_valid_result,
        'test_score': test_result[metric]
    }
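
A hypothetical invocation of run_trial; the model and dataset names must be keys of statics.model_name_map and statics.datasets_params in this project, so the values below are placeholders.

# Placeholder call; adjust names to whatever statics.model_name_map defines.
outcome = run_trial('LightGCN', 'ml-100k', hp_config={'learning_rate': 1e-3}, save_flag=False)
print(outcome['metric'], outcome['best_valid_score'], outcome['test_score'])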
Example #20
    def fit(self,
            train_data,
            valid_data=None,
            verbose=True,
            saved=True,
            show_progress=False,
            callback_fn=None):
        r"""Train the model based on the train data and the valid data.

        Args:
            train_data (DataLoader): the train data
            valid_data (DataLoader, optional): the valid data, default: None.
                                               If it is None, early stopping is disabled.
            verbose (bool, optional): whether to write training and evaluation information to logger, default: True
            saved (bool, optional): whether to save the model parameters, default: True
            show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
            callback_fn (callable): Optional callback function executed at the end of each epoch.
                                    It is called with (epoch_idx, valid_score) as input arguments.

        Returns:
             (float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
        """
        if saved and self.start_epoch >= self.epochs:
            self._save_checkpoint(-1)

        for epoch_idx in range(self.start_epoch, self.epochs):
            # train
            training_start_time = time()
            train_loss = self._train_epoch(train_data,
                                           epoch_idx,
                                           show_progress=show_progress)
            self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(
                train_loss, tuple) else train_loss
            training_end_time = time()
            train_loss_output = \
                self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
            if verbose:
                self.logger.info(train_loss_output)

            # eval
            if self.eval_step <= 0 or not valid_data:
                if saved:
                    self._save_checkpoint(epoch_idx)
                    update_output = set_color(
                        'Saving current',
                        'blue') + ': %s' % self.saved_model_file
                    if verbose:
                        self.logger.info(update_output)
                continue
            if (epoch_idx + 1) % self.eval_step == 0:
                valid_start_time = time()
                valid_score, valid_result = self._valid_epoch(
                    valid_data, show_progress=show_progress)
                self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                    valid_score,
                    self.best_valid_score,
                    self.cur_step,
                    max_step=self.stopping_step,
                    bigger=self.valid_metric_bigger)
                valid_end_time = time()
                valid_score_output = (set_color("epoch %d evaluating", 'green') + " [" + set_color("time", 'blue')
                                    + ": %.2fs, " + set_color("valid_score", 'blue') + ": %f]") % \
                                     (epoch_idx, valid_end_time - valid_start_time, valid_score)
                valid_result_output = set_color(
                    'valid result', 'blue') + ': \n' + dict2str(valid_result)
                if verbose:
                    self.logger.info(valid_score_output)
                    names = [k for k, _ in valid_result.items()]
                    values = [round(v, 3) for _, v in valid_result.items()]
                    my_table = PrettyTable()
                    my_table.field_names = names
                    my_table.add_row(values)
                    print(my_table)

                if update_flag:
                    if saved:
                        self._save_checkpoint(epoch_idx)
                        update_output = set_color(
                            'Saving current best',
                            'blue') + ': %s' % self.saved_model_file
                        if verbose:
                            self.logger.info(update_output)
                    self.best_valid_result = valid_result

                if callback_fn:
                    callback_fn(epoch_idx, valid_score)

                if stop_flag:
                    stop_output = 'Finished training, best eval result in epoch %d' % \
                                  (epoch_idx - self.cur_step * self.eval_step)
                    if verbose:
                        self.logger.info(stop_output)
                    break
        if self.draw_loss_pic:
            save_path = '{}-{}-train_loss.pdf'.format(self.config['model'],
                                                      get_local_time())
            self.plot_train_loss(save_path=os.path.join(save_path))
        return self.best_valid_score, self.best_valid_result
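
The callback_fn hook is convenient for hyper-parameter tuning; a sketch that reports each validation score to an external tuner (the report function is a placeholder).

# Hypothetical callback: called once per evaluation step with (epoch_idx, valid_score).
def report_to_tuner(epoch_idx, valid_score):
    print(f'epoch {epoch_idx}: valid_score={valid_score:.4f}')  # replace with a real tuner's report call

best_valid_score, best_valid_result = trainer.fit(
    train_data, valid_data, saved=True, show_progress=False, callback_fn=report_to_tuner
)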
Example #21
    def fit(self,
            train_data,
            valid_data=None,
            verbose=True,
            saved=True,
            show_progress=False,
            callback_fn=None):
        if saved and self.start_epoch >= self.epochs:
            self._save_checkpoint(-1)

        encoder_params = set(self.model.encoder.parameters())
        decoder_params = set(self.model.decoder.parameters())

        optimizer_encoder = self._build_optimizer(encoder_params)
        optimizer_decoder = self._build_optimizer(decoder_params)

        for epoch_idx in range(self.start_epoch, self.epochs):
            # alternate training
            training_start_time = time()
            train_loss = self._train_epoch(train_data,
                                           epoch_idx,
                                           show_progress=show_progress,
                                           n_epochs=self.n_enc_epochs,
                                           encoder_flag=True,
                                           optimizer=optimizer_encoder)
            self.model.update_prior()
            train_loss = self._train_epoch(train_data,
                                           epoch_idx,
                                           show_progress=show_progress,
                                           n_epochs=self.n_dec_epochs,
                                           encoder_flag=False,
                                           optimizer=optimizer_decoder)
            self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(
                train_loss, tuple) else train_loss
            training_end_time = time()
            train_loss_output = \
                self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
            if verbose:
                self.logger.info(train_loss_output)

            # eval
            if self.eval_step <= 0 or not valid_data:
                if saved:
                    self._save_checkpoint(epoch_idx)
                    update_output = 'Saving current: %s' % self.saved_model_file
                    if verbose:
                        self.logger.info(update_output)
                continue
            if (epoch_idx + 1) % self.eval_step == 0:
                valid_start_time = time()
                valid_score, valid_result = self._valid_epoch(
                    valid_data, show_progress=show_progress)
                self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                    valid_score,
                    self.best_valid_score,
                    self.cur_step,
                    max_step=self.stopping_step,
                    bigger=self.valid_metric_bigger)
                valid_end_time = time()
                valid_score_output = (set_color("epoch %d evaluating", 'green') + " [" + set_color("time", 'blue')
                                    + ": %.2fs, " + set_color("valid_score", 'blue') + ": %f]") % \
                                     (epoch_idx, valid_end_time - valid_start_time, valid_score)
                valid_result_output = set_color(
                    'valid result', 'blue') + ': \n' + dict2str(valid_result)
                if verbose:
                    self.logger.info(valid_score_output)
                    self.logger.info(valid_result_output)
                if update_flag:
                    if saved:
                        self._save_checkpoint(epoch_idx)
                        update_output = set_color(
                            'Saving current best',
                            'blue') + ': %s' % self.saved_model_file
                        if verbose:
                            self.logger.info(update_output)
                    self.best_valid_result = valid_result

                if callback_fn:
                    callback_fn(epoch_idx, valid_score)

                if stop_flag:
                    stop_output = 'Finished training, best eval result in epoch %d' % \
                                  (epoch_idx - self.cur_step * self.eval_step)
                    if verbose:
                        self.logger.info(stop_output)
                    break
        return self.best_valid_score, self.best_valid_result