def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False):
    # load model
    if self.boost_model is not None:
        self.model.load_model(self.boost_model)

    self.best_valid_score = 0.
    self.best_valid_result = 0.

    for epoch_idx in range(self.epochs):
        self._train_at_once(train_data, valid_data)

        if (epoch_idx + 1) % self.eval_step == 0:
            # evaluate
            valid_start_time = time()
            valid_result, valid_score = self._valid_epoch(valid_data)
            valid_end_time = time()
            valid_score_output = (set_color("epoch %d evaluating", 'green') + " ["
                                  + set_color("time", 'blue') + ": %.2fs, "
                                  + set_color("valid_score", 'blue') + ": %f]") % \
                                 (epoch_idx, valid_end_time - valid_start_time, valid_score)
            valid_result_output = set_color('valid result', 'blue') + ': \n' + dict2str(valid_result)
            if verbose:
                self.logger.info(valid_score_output)
                self.logger.info(valid_result_output)
            self.best_valid_score = valid_score
            self.best_valid_result = valid_result

    return self.best_valid_score, self.best_valid_result
def run_recbole(model=None, dataset=None, config_file_list=None, config_dict=None, saved=True):
    r"""A fast running api, which includes the complete process of
    training and testing a model on a specified dataset.

    Args:
        model (str): model name
        dataset (str): dataset name
        config_file_list (list): config files used to modify experiment parameters
        config_dict (dict): parameters dictionary used to modify experiment parameters
        saved (bool): whether to save the model
    """
    # configurations initialization
    config = Config(model=model, dataset=dataset, config_file_list=config_file_list, config_dict=config_dict)
    init_seed(config['seed'], config['reproducibility'])

    # logger initialization
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    # dataset filtering
    dataset = create_dataset(config)
    logger.info(dataset)

    # dataset splitting
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # model loading and initialization
    model = get_model(config['model'])(config, train_data).to(config['device'])
    logger.info(model)

    # trainer loading and initialization
    trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)

    # model training (optionally profiled; profiler.profile.__enter__ returns
    # None when enabled=False, hence the `prof is not None` checks below)
    with profiler.profile(enabled=config["monitor"], with_stack=True, profile_memory=True, use_cuda=True) as prof:
        best_valid_score, best_valid_result = trainer.fit(
            train_data, valid_data, saved=saved, show_progress=config['show_progress']
        )
    if prof is not None:
        print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total'))

    # model evaluation
    with profiler.profile(enabled=config["monitor_eval"], with_stack=True, profile_memory=True, use_cuda=True) as prof:
        test_result = trainer.evaluate(
            test_data, load_best_model=saved, show_progress=config['show_progress'], cold_warm_distinct_eval=True
        )
    if prof is not None:
        print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cpu_time_total'))

    logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
    logger.info(set_color('test result', 'yellow') + f': {test_result}')

    return {
        'best_valid_score': best_valid_score,
        'valid_score_bigger': config['valid_metric_bigger'],
        'best_valid_result': best_valid_result,
        'test_result': test_result
    }
def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses):
    des = self.config['loss_decimal_place'] or 4
    train_loss_output = (set_color('epoch %d training', 'green') + ' ['
                         + set_color('time', 'blue') + ': %.2fs, ') % (epoch_idx, e_time - s_time)
    if isinstance(losses, tuple):
        des = (set_color('train_loss%d', 'blue') + ': %.' + str(des) + 'f')
        train_loss_output += ', '.join(des % (idx + 1, loss) for idx, loss in enumerate(losses))
    else:
        des = '%.' + str(des) + 'f'
        train_loss_output += set_color('train loss', 'blue') + ': ' + des % losses
    return train_loss_output + ']'
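# Illustrative rendering of _generate_train_loss_output (ANSI color codes
# omitted), assuming loss_decimal_place = 4; the numbers are made up:
#
#   single loss:  epoch 3 training [time: 12.34s, train loss: 0.5678]
#   tuple loss:   epoch 3 training [time: 12.34s, train_loss1: 0.5678, train_loss2: 0.0123]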
def _get_field_from_config(self):
    super()._get_field_from_config()

    self.source_field = self.config['SOURCE_ID_FIELD']
    self.target_field = self.config['TARGET_ID_FIELD']
    self._check_field('source_field', 'target_field')

    self.logger.debug(set_color('source_id_field', 'blue') + f': {self.source_field}')
    self.logger.debug(set_color('target_id_field', 'blue') + f': {self.target_field}')
def _get_field_from_config(self):
    super()._get_field_from_config()

    self.head_entity_field = self.config['HEAD_ENTITY_ID_FIELD']
    self.tail_entity_field = self.config['TAIL_ENTITY_ID_FIELD']
    self.relation_field = self.config['RELATION_ID_FIELD']
    self.entity_field = self.config['ENTITY_ID_FIELD']
    self._check_field('head_entity_field', 'tail_entity_field', 'relation_field', 'entity_field')
    self.set_field_property(self.entity_field, FeatureType.TOKEN, FeatureSource.KG, 1)

    self.logger.debug(set_color('relation_field', 'blue') + f': {self.relation_field}')
    self.logger.debug(set_color('entity_field', 'blue') + f': {self.entity_field}')
def run_recbole(model=None, dataset=None, config_file_list=None, config_dict=None, saved=True):
    r"""A fast running api, which includes the complete process of
    training and testing a model on a specified dataset.

    Args:
        model (str): model name
        dataset (str): dataset name
        config_file_list (list): config files used to modify experiment parameters
        config_dict (dict): parameters dictionary used to modify experiment parameters
        saved (bool): whether to save the model
    """
    # configurations initialization
    config = Config(model=model, dataset=dataset, config_file_list=config_file_list, config_dict=config_dict)
    init_seed(config['seed'], config['reproducibility'])

    # logger initialization
    init_logger(config)
    logger = getLogger()
    logger.info(config)

    # dataset filtering
    dataset = create_dataset(config)
    logger.info(dataset)

    # dataset splitting
    train_data, valid_data, test_data = data_preparation(config, dataset)

    # model loading and initialization
    model = get_model(config['model'])(config, train_data).to(config['device'])
    logger.info(model)

    # trainer loading and initialization
    trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)

    # model training
    best_valid_score, best_valid_result = trainer.fit(
        train_data, valid_data, saved=saved, show_progress=config['show_progress']
    )

    # model evaluation
    test_result = trainer.evaluate(test_data, load_best_model=saved, show_progress=config['show_progress'])

    logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
    logger.info(set_color('test result', 'yellow') + f': {test_result}')

    return {
        'best_valid_score': best_valid_score,
        'valid_score_bigger': config['valid_metric_bigger'],
        'best_valid_result': best_valid_result,
        'test_result': test_result
    }
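# A minimal usage sketch for run_recbole. The dataset name and the config
# overrides below are illustrative assumptions, not values taken from this code.
if __name__ == '__main__':
    result = run_recbole(
        model='BPR',
        dataset='ml-100k',
        config_dict={'epochs': 50, 'show_progress': False},
    )
    print(result['best_valid_score'])
    print(result['test_result'])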
def pretrain(self, train_data, verbose=True, show_progress=False):
    for epoch_idx in range(self.start_epoch, self.epochs):
        # train
        training_start_time = time()
        train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
        self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
        training_end_time = time()
        train_loss_output = \
            self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
        if verbose:
            self.logger.info(train_loss_output)

        if (epoch_idx + 1) % self.config['save_step'] == 0:
            saved_model_file = os.path.join(
                self.checkpoint_dir,
                '{}-{}-{}.pth'.format(self.config['model'], self.config['dataset'], str(epoch_idx + 1))
            )
            self.save_pretrained_model(epoch_idx, saved_model_file)
            update_output = set_color('Saving current', 'blue') + ': %s' % saved_model_file
            if verbose:
                self.logger.info(update_output)

    return self.best_valid_score, self.best_valid_result
def _train_epoch(self, train_data, epoch_idx, n_epochs, optimizer, encoder_flag, loss_func=None, show_progress=False):
    self.model.train()
    loss_func = loss_func or self.model.calculate_loss
    total_loss = None

    for epoch in range(n_epochs):
        # rebuild the iterator on every pass: an enumerate generator is
        # exhausted after one traversal, so creating it once outside this loop
        # would make every pass after the first a no-op
        iter_data = (
            tqdm(
                enumerate(train_data),
                total=len(train_data),
                desc=set_color(f"Train {epoch_idx:>5}", 'pink'),
            ) if show_progress else enumerate(train_data)
        )
        for batch_idx, interaction in iter_data:
            interaction = interaction.to(self.device)
            optimizer.zero_grad()
            losses = loss_func(interaction, encoder_flag=encoder_flag)
            if isinstance(losses, tuple):
                loss = sum(losses)
                loss_tuple = tuple(per_loss.item() for per_loss in losses)
                total_loss = loss_tuple if total_loss is None else tuple(map(sum, zip(total_loss, loss_tuple)))
            else:
                loss = losses
                total_loss = losses.item() if total_loss is None else total_loss + losses.item()
            self._check_nan(loss)
            loss.backward()
            if self.clip_grad_norm:
                clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
            optimizer.step()
    return total_loss
def __str__(self):
    """Model prints with number of trainable parameters."""
    model_parameters = filter(lambda p: p.requires_grad, self.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return super().__str__() + set_color('\nTrainable parameters', 'blue') + f': {params}'
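# What the __str__ override above buys you (a sketch; module names and the
# count shown are made-up examples, assuming 64-dim embeddings on ml-100k):
#
#   >>> print(model)
#   BPR(
#     (user_embedding): Embedding(944, 64)
#     (item_embedding): Embedding(1683, 64)
#   )
#   Trainable parameters: 168128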
def _load_kg(self, token, dataset_path):
    self.logger.debug(set_color(f'Loading kg from [{dataset_path}].', 'green'))
    kg_path = os.path.join(dataset_path, f'{token}.kg')
    if not os.path.isfile(kg_path):
        raise ValueError(f'[{token}.kg] not found in [{dataset_path}].')
    df = self._load_feat(kg_path, FeatureSource.KG)
    self._check_kg(df)
    return df
def evaluate(self, eval_data, load_best_model=True, model_file=None, show_progress=False):
    r"""Evaluate the model based on the eval data.

    Args:
        eval_data (DataLoader): the eval data
        load_best_model (bool, optional): whether to load the best model found during training, default: True.
            It should be set to ``True`` if users want to test the model after training.
        model_file (str, optional): the saved model file, default: None. If users want to test a
            previously trained model file, they can set this parameter.
        show_progress (bool): Show the progress of the evaluate epoch. Defaults to ``False``.

    Returns:
        dict: eval result, key is the eval metric and value is the corresponding metric value.
    """
    if not eval_data:
        return

    if load_best_model:
        if model_file:
            checkpoint_file = model_file
        else:
            checkpoint_file = self.saved_model_file
        checkpoint = torch.load(checkpoint_file)
        self.model.load_state_dict(checkpoint['state_dict'])
        message_output = 'Loading model structure and parameters from {}'.format(checkpoint_file)
        self.logger.info(message_output)

    self.model.eval()

    if eval_data.dl_type == DataLoaderType.FULL:
        if self.item_tensor is None:
            self.item_tensor = eval_data.get_item_feature().to(self.device).repeat(eval_data.step)
        self.tot_item_num = eval_data.dataset.item_num

    batch_matrix_list = []
    iter_data = (
        tqdm(
            enumerate(eval_data),
            total=len(eval_data),
            desc=set_color(f"Evaluate   ", 'pink'),
        ) if show_progress else enumerate(eval_data)
    )
    for batch_idx, batched_data in iter_data:
        if eval_data.dl_type == DataLoaderType.FULL:
            interaction, scores = self._full_sort_batch_eval(batched_data)
        else:
            interaction = batched_data
            batch_size = interaction.length
            if batch_size <= self.test_batch_size:
                scores = self.model.predict(interaction.to(self.device))
            else:
                scores = self._spilt_predict(interaction, batch_size)

        batch_matrix = self.evaluator.collect(interaction, scores)
        batch_matrix_list.append(batch_matrix)
    result = self.evaluator.evaluate(batch_matrix_list, eval_data)

    return result
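# Hedged usage sketch for evaluate(): re-scoring a previously trained model
# from its checkpoint file. The path below is a hypothetical example.
test_result = trainer.evaluate(
    test_data,
    load_best_model=True,
    model_file='saved/BPR-ml-100k.pth',  # hypothetical checkpoint path
    show_progress=True,
)
print(test_result)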
def _get_ent_fields_in_same_space(self):
    """Return ``field_set`` that should be remapped together with entities."""
    fields_in_same_space = super()._get_fields_in_same_space()

    ent_fields = {self.head_entity_field, self.tail_entity_field}
    for field_set in fields_in_same_space:
        if self._contain_ent_field(field_set):
            field_set = self._remove_ent_field(field_set)
            ent_fields.update(field_set)
    self.logger.debug(set_color('ent_fields', 'blue') + f': {ent_fields}')
    return ent_fields
def _load_link(self, token, dataset_path):
    self.logger.debug(set_color(f'Loading link from [{dataset_path}].', 'green'))
    link_path = os.path.join(dataset_path, f'{token}.link')
    if not os.path.isfile(link_path):
        raise ValueError(f'[{token}.link] not found in [{dataset_path}].')
    df = self._load_feat(link_path, 'link')
    self._check_link(df)

    item2entity, entity2item = {}, {}
    for item_id, entity_id in zip(df[self.iid_field].values, df[self.entity_field].values):
        item2entity[item_id] = entity_id
        entity2item[entity_id] = item_id
    return item2entity, entity2item
def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
    r"""Train the model in an epoch.

    Args:
        train_data (DataLoader): The train data.
        epoch_idx (int): The current epoch id.
        loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be
            :attr:`self.model.calculate_loss`. Defaults to ``None``.
        show_progress (bool): Show the progress of training epoch. Defaults to ``False``.

    Returns:
        float/tuple: The sum of loss returned by all batches in this epoch. If the loss in each batch contains
        multiple parts and the model returns these parts instead of their sum, this method returns a tuple
        holding the per-part sums over the epoch.
    """
    self.model.train()
    loss_func = loss_func or self.model.calculate_loss
    total_loss = None
    iter_data = (
        tqdm(
            enumerate(train_data),
            total=len(train_data),
            desc=set_color(f"Train {epoch_idx:>5}", 'pink'),
        ) if show_progress else enumerate(train_data)
    )
    for batch_idx, interaction in iter_data:
        interaction = interaction.to(self.device)
        self.optimizer.zero_grad()
        losses = loss_func(interaction)
        if isinstance(losses, tuple):
            loss = sum(losses)
            loss_tuple = tuple(per_loss.item() for per_loss in losses)
            total_loss = loss_tuple if total_loss is None else tuple(map(sum, zip(total_loss, loss_tuple)))
        else:
            loss = losses
            total_loss = losses.item() if total_loss is None else total_loss + losses.item()
        self._check_nan(loss)
        loss.backward()
        if self.clip_grad_norm:
            clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
        self.optimizer.step()
    return total_loss
def save_split_dataloaders(config, dataloaders):
    """Save split dataloaders.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataloaders (tuple of AbstractDataLoader): The split dataloaders.
    """
    save_path = config['checkpoint_dir']
    saved_dataloaders_file = f'{config["dataset"]}-for-{config["model"]}-dataloader.pth'
    file_path = os.path.join(save_path, saved_dataloaders_file)
    logger = getLogger()
    logger.info(set_color('Saved split dataloaders', 'blue') + f': {file_path}')
    with open(file_path, 'wb') as f:
        pickle.dump(dataloaders, f)
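# A minimal load counterpart, sketched under the assumption that the pickle
# written above holds the (train, valid, test) tuple; the name
# load_split_dataloaders is our own here, not guaranteed to exist in this repo.
def load_split_dataloaders(saved_dataloaders_file):
    with open(saved_dataloaders_file, 'rb') as f:
        dataloaders = pickle.load(f)
    return dataloaders  # (train_data, valid_data, test_data)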
def __str__(self):
    info = [set_color('Evaluation Setting:', 'pink')]

    if self.group_field:
        info.append(set_color('Group by', 'blue') + f' {self.group_field}')
    else:
        info.append(set_color('No Grouping', 'yellow'))

    if self.ordering_args is not None and self.ordering_args['strategy'] != 'none':
        info.append(set_color('Ordering', 'blue') + f': {self.ordering_args}')
    else:
        info.append(set_color('No Ordering', 'yellow'))

    if self.split_args is not None and self.split_args['strategy'] != 'none':
        info.append(set_color('Splitting', 'blue') + f': {self.split_args}')
    else:
        info.append(set_color('No Splitting', 'yellow'))

    if self.neg_sample_args is not None and self.neg_sample_args['strategy'] != 'none':
        info.append(set_color('Negative Sampling', 'blue') + f': {self.neg_sample_args}')
    else:
        info.append(set_color('No Negative Sampling', 'yellow'))

    return '\n\t'.join(info)
def __str__(self):
    args_info = '\n'
    for category in self.parameters:
        args_info += set_color(category + ' Hyper Parameters:\n', 'pink')
        args_info += '\n'.join([
            (set_color("{}", 'cyan') + " =" + set_color(" {}", 'yellow')).format(arg, value)
            for arg, value in self.final_config_dict.items() if arg in self.parameters[category]
        ])
        args_info += '\n\n'

    args_info += set_color('Other Hyper Parameters: \n', 'pink')
    args_info += '\n'.join([
        (set_color("{}", 'cyan') + " = " + set_color("{}", 'yellow')).format(arg, value)
        for arg, value in self.final_config_dict.items()
        if arg not in {_ for args in self.parameters.values() for _ in args}.union({'model', 'dataset', 'config_files'})
    ])
    args_info += '\n\n'
    return args_info
def data_preparation(config, dataset, save=False):
    """Split the dataset by :attr:`config['eval_setting']` and call :func:`dataloader_construct` to create
    corresponding dataloaders.

    Args:
        config (Config): An instance object of Config, used to record parameter information.
        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
        save (bool, optional): If ``True``, it will call :func:`save_split_dataloaders` to save the
            split dataloaders. Defaults to ``False``.

    Returns:
        tuple:
            - train_data (AbstractDataLoader): The dataloader for training.
            - valid_data (AbstractDataLoader): The dataloader for validation.
            - test_data (AbstractDataLoader): The dataloader for testing.
    """
    model_type = config['MODEL_TYPE']
    es = EvalSetting(config)

    built_datasets = dataset.build(es)
    train_dataset, valid_dataset, test_dataset = built_datasets
    phases = ['train', 'valid', 'test']
    sampler = None
    logger = getLogger()
    train_neg_sample_args = config['train_neg_sample_args']
    eval_neg_sample_args = es.neg_sample_args

    # Training
    train_kwargs = {
        'config': config,
        'dataset': train_dataset,
        'batch_size': config['train_batch_size'],
        'dl_format': config['MODEL_INPUT_TYPE'],
        'shuffle': True,
    }

    if train_neg_sample_args['strategy'] != 'none':
        if dataset.label_field in dataset.inter_feat:
            raise ValueError(
                f'`training_neg_sample_num` should be 0 '
                f'if inter_feat have label_field [{dataset.label_field}].'
            )
        if model_type != ModelType.SEQUENTIAL:
            sampler = Sampler(phases, built_datasets, train_neg_sample_args['distribution'])
        else:
            sampler = RepeatableSampler(phases, dataset, train_neg_sample_args['distribution'])
        train_kwargs['sampler'] = sampler.set_phase('train')
        train_kwargs['neg_sample_args'] = train_neg_sample_args
        if model_type == ModelType.KNOWLEDGE:
            kg_sampler = KGSampler(dataset, train_neg_sample_args['distribution'])
            train_kwargs['kg_sampler'] = kg_sampler

    dataloader = get_data_loader('train', config, train_neg_sample_args)
    logger.info(
        set_color('Build', 'pink') + set_color(f' [{dataloader.__name__}]', 'yellow') + ' for ' +
        set_color('[train]', 'yellow') + ' with format ' + set_color(f'[{train_kwargs["dl_format"]}]', 'yellow')
    )
    if train_neg_sample_args['strategy'] != 'none':
        logger.info(
            set_color('[train]', 'pink') + set_color(' Negative Sampling', 'blue') + f': {train_neg_sample_args}'
        )
    else:
        logger.info(set_color('[train]', 'pink') + set_color(' No Negative Sampling', 'yellow'))
    logger.info(
        set_color('[train]', 'pink') + set_color(' batch_size', 'cyan') + ' = ' +
        set_color(f'[{train_kwargs["batch_size"]}]', 'yellow') + ', ' + set_color('shuffle', 'cyan') + ' = ' +
        set_color(f'[{train_kwargs["shuffle"]}]\n', 'yellow')
    )
    train_data = dataloader(**train_kwargs)

    # Evaluation
    eval_kwargs = {
        'config': config,
        'batch_size': config['eval_batch_size'],
        'dl_format': InputType.POINTWISE,
        'shuffle': False,
    }
    valid_kwargs = {'dataset': valid_dataset}
    test_kwargs = {'dataset': test_dataset}

    if eval_neg_sample_args['strategy'] != 'none':
        if dataset.label_field in dataset.inter_feat:
            raise ValueError(
                f'It can not validate with `{es.es_str[1]}` '
                f'when inter_feat have label_field [{dataset.label_field}].'
            )
        if sampler is None:
            if model_type != ModelType.SEQUENTIAL:
                sampler = Sampler(phases, built_datasets, eval_neg_sample_args['distribution'])
            else:
                sampler = RepeatableSampler(phases, dataset, eval_neg_sample_args['distribution'])
        else:
            sampler.set_distribution(eval_neg_sample_args['distribution'])
        eval_kwargs['neg_sample_args'] = eval_neg_sample_args
        valid_kwargs['sampler'] = sampler.set_phase('valid')
        test_kwargs['sampler'] = sampler.set_phase('test')
    valid_kwargs.update(eval_kwargs)
    test_kwargs.update(eval_kwargs)

    dataloader = get_data_loader('evaluation', config, eval_neg_sample_args)
    logger.info(
        set_color('Build', 'pink') + set_color(f' [{dataloader.__name__}]', 'yellow') + ' for ' +
        set_color('[evaluation]', 'yellow') + ' with format ' + set_color(f'[{eval_kwargs["dl_format"]}]', 'yellow')
    )
    logger.info(es)
    logger.info(
        set_color('[evaluation]', 'pink') + set_color(' batch_size', 'cyan') + ' = ' +
        set_color(f'[{eval_kwargs["batch_size"]}]', 'yellow') + ', ' + set_color('shuffle', 'cyan') + ' = ' +
        set_color(f'[{eval_kwargs["shuffle"]}]\n', 'yellow')
    )

    valid_data = dataloader(**valid_kwargs)
    test_data = dataloader(**test_kwargs)

    if save:
        save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data))

    return train_data, valid_data, test_data
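# Hedged sketch of where data_preparation sits in the pipeline, mirroring
# run_recbole above; the arguments to Config are illustrative.
config = Config(model='BPR', dataset='ml-100k')
dataset = create_dataset(config)
train_data, valid_data, test_data = data_preparation(config, dataset, save=True)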
def run_trial(model_name, dataset_name, hp_config=None, save_flag=False):
    if not hp_config:
        hp_config = {}
        tuning = False
    else:
        tuning = True

    commons.init_seeds()
    verbose = not tuning
    model_class = statics.model_name_map[model_name]
    try:
        default_config = model_class.default_params
    except AttributeError:
        default_config = {}
    assert model_name in statics.recbole_models

    default_config.update(statics.datasets_params[dataset_name])
    default_config.update(hp_config)

    config = Config(model=model_class, dataset=dataset_name, config_dict=default_config)
    init_seed(config['seed'], config['reproducibility'])

    # logger initialization
    init_logger(config)
    logger = logging.getLogger()
    if verbose:
        logger.info(config)

    # dataset filtering
    dataset = create_dataset(config)
    train_data, valid_data, test_data = data_preparation(config, dataset)
    train_data = add_graph(train_data)
    if verbose:
        logger.info(dataset)

    model = model_class(config, train_data).to(commons.device)
    trainer = utils.get_trainer(config)(config, model)

    best_valid_score, best_valid_result = trainer.fit(train_data, valid_data, verbose=verbose, show_progress=verbose)
    test_result = trainer.evaluate(test_data)

    if verbose:
        logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}')
        logger.info(set_color('test result', 'yellow') + f': {test_result}')

    metric = config['valid_metric'].lower()
    if save_flag:
        os.makedirs(os.path.join("bestmodels", dataset_name, str(config["topk"])), exist_ok=True)
        save_path = os.path.join("bestmodels", dataset_name, str(config["topk"]), "{}.pth".format(model_name))
    else:
        save_path = None
    if save_path:
        shutil.copyfile(trainer.saved_model_file, save_path)

    return {
        'metric': config['valid_metric'],
        'best_valid_score': best_valid_score,
        'valid_score_bigger': config['valid_metric_bigger'],
        'best_valid_result': best_valid_result,
        'test_score': test_result[metric]
    }
def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None):
    r"""Train the model based on the train data and the valid data.

    Args:
        train_data (DataLoader): the train data
        valid_data (DataLoader, optional): the valid data, default: None.
            If it is ``None``, early stopping is disabled.
        verbose (bool, optional): whether to write training and evaluation information to logger, default: True
        saved (bool, optional): whether to save the model parameters, default: True
        show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
        callback_fn (callable): Optional callback function executed at end of epoch.
            Includes (epoch_idx, valid_score) input arguments.

    Returns:
        (float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
    """
    if saved and self.start_epoch >= self.epochs:
        self._save_checkpoint(-1)

    for epoch_idx in range(self.start_epoch, self.epochs):
        # train
        training_start_time = time()
        train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
        self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
        training_end_time = time()
        train_loss_output = \
            self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
        if verbose:
            self.logger.info(train_loss_output)

        # eval
        if self.eval_step <= 0 or not valid_data:
            if saved:
                self._save_checkpoint(epoch_idx)
                update_output = set_color('Saving current', 'blue') + ': %s' % self.saved_model_file
                if verbose:
                    self.logger.info(update_output)
            continue
        if (epoch_idx + 1) % self.eval_step == 0:
            valid_start_time = time()
            valid_score, valid_result = self._valid_epoch(valid_data, show_progress=show_progress)
            self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                valid_score,
                self.best_valid_score,
                self.cur_step,
                max_step=self.stopping_step,
                bigger=self.valid_metric_bigger
            )
            valid_end_time = time()
            valid_score_output = (set_color("epoch %d evaluating", 'green') + " ["
                                  + set_color("time", 'blue') + ": %.2fs, "
                                  + set_color("valid_score", 'blue') + ": %f]") % \
                                 (epoch_idx, valid_end_time - valid_start_time, valid_score)
            valid_result_output = set_color('valid result', 'blue') + ': \n' + dict2str(valid_result)
            if verbose:
                self.logger.info(valid_score_output)
                self.logger.info(valid_result_output)
            # additionally render the valid result as a table
            names = list(valid_result.keys())
            values = [round(v, 3) for v in valid_result.values()]
            my_table = PrettyTable()
            my_table.field_names = names
            my_table.add_row(values)
            print(my_table)

            if update_flag:
                if saved:
                    self._save_checkpoint(epoch_idx)
                    update_output = set_color('Saving current best', 'blue') + ': %s' % self.saved_model_file
                    if verbose:
                        self.logger.info(update_output)
                self.best_valid_result = valid_result

            if callback_fn:
                callback_fn(epoch_idx, valid_score)

            if stop_flag:
                stop_output = 'Finished training, best eval result in epoch %d' % \
                              (epoch_idx - self.cur_step * self.eval_step)
                if verbose:
                    self.logger.info(stop_output)
                break

    if self.draw_loss_pic:
        save_path = '{}-{}-train_loss.pdf'.format(self.config['model'], get_local_time())
        self.plot_train_loss(save_path=save_path)
    return self.best_valid_score, self.best_valid_result
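# Reading of the early_stopping helper as used above (inferred from its call
# sites in this file, not from the helper's own documentation):
#
#   best, cur_step, stop_flag, update_flag = early_stopping(
#       value, best, cur_step, max_step=..., bigger=True)
#
# `update_flag` is True when `value` improves on `best` (direction given by
# `bigger`), `cur_step` counts evaluations since the last improvement, and
# `stop_flag` turns True once `cur_step` exceeds `max_step`.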
def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None):
    if saved and self.start_epoch >= self.epochs:
        self._save_checkpoint(-1)

    encoder_params = set(self.model.encoder.parameters())
    decoder_params = set(self.model.decoder.parameters())
    optimizer_encoder = self._build_optimizer(encoder_params)
    optimizer_decoder = self._build_optimizer(decoder_params)

    for epoch_idx in range(self.start_epoch, self.epochs):
        # alternate training: first the encoder, then (after updating the prior)
        # the decoder; the loss recorded for the epoch is the decoder-phase loss
        training_start_time = time()
        train_loss = self._train_epoch(
            train_data, epoch_idx, show_progress=show_progress,
            n_epochs=self.n_enc_epochs, encoder_flag=True, optimizer=optimizer_encoder
        )
        self.model.update_prior()
        train_loss = self._train_epoch(
            train_data, epoch_idx, show_progress=show_progress,
            n_epochs=self.n_dec_epochs, encoder_flag=False, optimizer=optimizer_decoder
        )
        self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
        training_end_time = time()
        train_loss_output = \
            self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
        if verbose:
            self.logger.info(train_loss_output)

        # eval
        if self.eval_step <= 0 or not valid_data:
            if saved:
                self._save_checkpoint(epoch_idx)
                update_output = 'Saving current: %s' % self.saved_model_file
                if verbose:
                    self.logger.info(update_output)
            continue
        if (epoch_idx + 1) % self.eval_step == 0:
            valid_start_time = time()
            valid_score, valid_result = self._valid_epoch(valid_data, show_progress=show_progress)
            self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                valid_score,
                self.best_valid_score,
                self.cur_step,
                max_step=self.stopping_step,
                bigger=self.valid_metric_bigger
            )
            valid_end_time = time()
            valid_score_output = (set_color("epoch %d evaluating", 'green') + " ["
                                  + set_color("time", 'blue') + ": %.2fs, "
                                  + set_color("valid_score", 'blue') + ": %f]") % \
                                 (epoch_idx, valid_end_time - valid_start_time, valid_score)
            valid_result_output = set_color('valid result', 'blue') + ': \n' + dict2str(valid_result)
            if verbose:
                self.logger.info(valid_score_output)
                self.logger.info(valid_result_output)

            if update_flag:
                if saved:
                    self._save_checkpoint(epoch_idx)
                    update_output = set_color('Saving current best', 'blue') + ': %s' % self.saved_model_file
                    if verbose:
                        self.logger.info(update_output)
                self.best_valid_result = valid_result

            if callback_fn:
                callback_fn(epoch_idx, valid_score)

            if stop_flag:
                stop_output = 'Finished training, best eval result in epoch %d' % \
                              (epoch_idx - self.cur_step * self.eval_step)
                if verbose:
                    self.logger.info(stop_output)
                break

    return self.best_valid_score, self.best_valid_result