def main(opt, device_id):
    """Single-device training entry point.

    Args:
        opt: command-line options namespace.
        device_id: device index this training process runs on.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        # map_location keeps tensors on CPU regardless of where they were saved.
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Rebuild the model with the options it was originally trained with.
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields,
                            optim, data_type, model_saver=model_saver)

    # Iterator factories: each call yields a fresh pass over the data.
    def train_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def bulid_nmt_model(opt):
    """Load a pretrained NMT model from the checkpoint at ``opt.model_path_G``.

    NOTE(review): the name is misspelled ("bulid"); kept as-is so existing
    callers do not break.

    Raises:
        AssertionError: when no checkpoint path is given -- this builder
            never creates a model from scratch.
    """
    # opt = training_opt_postprocessing(opt)
    # init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.model_path_G:
        logger.info('Loading checkpoint from %s' % opt.model_path_G)
        checkpoint = torch.load(opt.model_path_G,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        raise(AssertionError("no nmt model"))

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    # NOTE(review): ``additional_device0`` is not defined in this function;
    # presumably a module-level device constant -- verify it is in scope.
    model = build_model(model_opt, opt, fields, checkpoint, additional_device0)
    # model = build_model(model_opt, opt, fields, checkpoint, "cpu")
    return model
def main(opt): opt = training_opt_postprocessing(opt) # Load checkpoint if we resume from a previous training. if opt.train_from: print('Loading checkpoint from %s' % opt.train_from) checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage) model_opt = checkpoint['opt'] # I don't like reassigning attributes of opt: it's not clear. opt.start_epoch = checkpoint['epoch'] + 1 else: checkpoint = None model_opt = opt # Peek the fisrt dataset to determine the data_type. # (All datasets have the same data_type). first_dataset = next(lazily_load_dataset("train", opt)) data_type = first_dataset.data_type # Load fields generated from preprocess phase. fields = _load_fields(first_dataset, data_type, opt, checkpoint) # Report src/tgt features. _collect_report_features(fields) # Build model. model = build_model(model_opt, opt, fields, checkpoint) _tally_parameters(model) _check_save_model_path(opt) # Build optimizer. optim = build_optim(model, opt, checkpoint) # Build model saver model_saver = build_model_saver(model_opt, opt, model, fields, optim) trainer = build_trainer(opt, model, fields, optim, data_type, model_saver=model_saver) def train_iter_fct(): return build_dataset_iter(lazily_load_dataset("train", opt), fields, opt) def valid_iter_fct(): return build_dataset_iter(lazily_load_dataset("valid", opt), fields, opt) # Do training. trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps, opt.valid_steps, opt.save_checkpoint_steps) if opt.tensorboard: trainer.report_manager.tensorboard_writer.close()
def monitor_iter_fct():
    """Build one evaluation iterator per monitored src/tgt file pair.

    Returns a dict keyed by the basename of each monitor source file
    (path, extension and trailing "_src" stripped), mapping to a
    non-training dataset iterator.
    """
    iterators = {}
    for src_path, _tgt_path in zip(opt.monitor_src, opt.monitor_tgt):
        separator = "/" if "/" in src_path else "\\"
        stem = src_path.split(separator)[-1].split(".")[0]
        key = stem.replace("_src", "")
        dataset = lazily_load_dataset("monitor", opt, key)
        iterators[key] = build_dataset_iter(dataset, fields, opt,
                                            is_train=False)
    return iterators
def get_speech_iterator(self, name, lang1, lang2, is_train=True):
    """Create, cache, and return a fresh speech dataset iterator.

    Args:
        name: dataset split name ("train" or "valid").
        lang1: source language code.
        lang2: target language code.
        is_train: whether the iterator is built in training mode.
    """
    parts = ['speech', name, lang1, lang2]
    key = ','.join(p for p in parts if p is not None)
    logger.info("Creating new training %s iterator ..." % key)
    direction = (lang1, lang2)
    dataset = lazily_load_dataset(
        name, self.params.speech_dataset[direction][0])
    batches = build_dataset_iter(
        dataset, self.speech_fields[direction], self.params, is_train)
    fresh_iter = iter(batches)
    # Cache so callers can resume this iterator later by key.
    self.iterators[key] = fresh_iter
    return fresh_iter
def get_speech_iterator(self, data_type, lang1, lang2):
    """Yield validation speech batches with the target-side BOS forced.

    Generator: each yielded batch has its first target row overwritten
    with the BOS index of ``lang2``.
    """
    assert data_type in ['valid']
    speech_direction = (lang1, lang2)
    lang2_id = self.params.lang2id[lang2]
    iterator = build_dataset_iter(
        lazily_load_dataset(
            data_type, self.params.speech_dataset[speech_direction][0]),
        self.params.speech_fields[speech_direction], self.params, False)
    iterator = iter(iterator)
    for batch in iterator:
        bos_index = self.params.bos_index[lang2_id]
        # Force the language-specific BOS token as the first target symbol.
        batch.tgt[0] = bos_index
        yield batch
def _fetch_table_rows(db_path, table):
    """Return all rows of *table* from the SQLite DB at *db_path*.

    FIX: the original opened four connections and never closed any of
    them; this helper guarantees the connection is released. ``table``
    is always one of our hard-coded table names, never user input.
    """
    db = sqlite3.connect(db_path)
    try:
        cursor = db.cursor()
        cursor.execute('SELECT * FROM ' + table)
        return cursor.fetchall()
    finally:
        db.close()


def main(opt, device_id):
    """Train a model, first loading WALS typological features for the
    src/tgt languages from local SQLite databases.

    Args:
        opt: command-line options namespace (incl. ``wals_src``/``wals_tgt``).
        device_id: device index this training process runs on.
    """
    #TODO delete all these lines related to WALS features
    #begin
    SimulationLanguages = [opt.wals_src, opt.wals_tgt]
    print('Loading WALS features from databases...')
    cwd = os.getcwd()
    WalsValues = _fetch_table_rows(cwd + '/onmt/WalsValues.db', 'WalsValues')
    FeaturesList = _fetch_table_rows(cwd + '/onmt/FeaturesList.db',
                                     'FeaturesList')
    FTInfos = _fetch_table_rows(cwd + '/onmt/FTInfos.db', 'FTInfos')
    FTList = _fetch_table_rows(cwd + '/onmt/FTList.db', 'FTList')

    # First column of WalsValues is the language identifier.
    ListLanguages = [row[0] for row in WalsValues]
    # Each FTList row: (feature type, comma-separated feature names).
    FeatureTypes = [(row[0], row[1].split(',')) for row in FTList]
    FeatureNames = []
    for _ftype, names in FeatureTypes:
        FeatureNames += names
    FeatureTypesNames = [ftype for ftype, _names in FeatureTypes]

    FeatureValues, FeatureTensors = get_feat_values(
        SimulationLanguages, WalsValues, FeaturesList, ListLanguages,
        FeatureTypes, FeatureNames)

    print('WALS databases loaded!')
    #end

    #TODO: load wals features from command-line (wals.npz)
    # FeatureValues: defaultdict with feature values, per language.
    # FeatureTensors: tensor of possible outputs, per feature.

    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    #TODO: remove all parameters related to WALS features: FeatureValues, FeatureTensors, FeatureTypes, FeaturesList, FeatureNames, FTInfos, FeatureTypesNames, SimulationLanguages
    #TODO: include four parameter related to WALS features: the four numpy arrays separately
    model = build_model(model_opt, opt, fields, checkpoint, FeatureValues,
                        FeatureTensors, FeatureTypes, FeaturesList,
                        FeatureNames, FTInfos, FeatureTypesNames,
                        SimulationLanguages)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields,
                            optim, data_type, model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def train_iter_fct():
    """Advance the vivisect probe state, then return a fresh training iterator."""
    probe_state = model._vivisect
    probe_state["iteration"] = probe_state["iteration"] + 1
    probe_state["mode"] = "train"
    dataset = lazily_load_dataset("train", opt)
    return build_dataset_iter(dataset, fields, opt)
def main(opt, device_id):
    """Train a model; optionally follow up with comparable-data
    extraction-and-training epochs when ``opt.comparable`` is set."""
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opts values then overwrite it with opts from
        # the checkpoint. It's useful in order to re-train a model
        # after adding a new option (not set in checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = load_fields(first_dataset, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields,
                            optim, data_type, model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt),
                                  fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields, opt, is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    if opt.no_base == False:
        # Baseline training pass (skipped when --no_base is set).
        trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                      opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()

    if opt.comparable:
        logger.info('')
        logger.info('Beginning comparable data extraction and training.')
        # 1. Initialize Comparable object
        comp = Comparable(model, trainer, fields, logger, opt)
        # 2. Infer similarity threshold from training data
        for epoch in range(opt.comp_epochs):
            # 3. Update threshold if dynamic
            if opt.threshold_dynamics != 'static' and epoch != 0:
                comp.update_threshold(opt.threshold_dynamics,
                                      opt.infer_threshold)
            # 4. Extract parallel data and train
            #if opt.match_articles:
            #    comparable_data = comp.match_articles(opt.match_articles)
            #    train_stats = comp.extract_and_train(comparable_data)
            #else:
            train_stats = comp.extract_and_train(opt.comparable_data)
            # 5. Validate on validation set
            if opt.no_valid == False:
                valid_iter = build_dataset_iter(
                    lazily_load_dataset("valid", opt), fields, opt)
                valid_stats = comp.validate(valid_iter)
            # 6. Drop a checkpoint if needed
            comp.trainer.model_saver._save(epoch)
def valid_iter_fct(): return build_dataset_iter( lazily_load_dataset("valid", opt), fields, opt, is_train=False) # Do training. trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
def main(opt, device_id):
    """Translate each training batch with the loaded model, print the top
    hypothesis, then take a single training step on that batch.

    Args:
        opt: command-line options; ``opt.train_from`` must point at an
            existing checkpoint.
        device_id: device index this process runs on.

    Raises:
        Exception: if no checkpoint path is provided.
    """
    # opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        raise Exception('You need to load a model')
    logger.info('Loading data from %s' % opt.data)
    dataset = next(lazily_load_dataset("train", opt))
    data_type = dataset.data_type
    logger.info('Data type %s' % data_type)

    # Load fields generated from preprocess phase.
    fields = _load_fields(dataset, data_type, opt, checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    dataset_iter = build_dataset_iter(lazily_load_dataset("train", opt),
                                      fields, opt)
    out_file = codecs.open(opt.output, 'w+', 'utf-8')
    scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                             opt.coverage_penalty,
                                             opt.length_penalty)
    translation_builder = TranslationBuilder(dataset, fields,
                                             n_best=opt.n_best,
                                             replace_unk=opt.replace_unk,
                                             has_tgt=False)

    # NOTE(review): defined but never used below -- kept for interface parity.
    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt),
                                  fields, opt)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)
    translator = Translator(trainer.model, fields, opt.beam_size,
                            global_scorer=scorer, out_file=out_file,
                            report_score=False,
                            copy_attn=model_opt.copy_attn, logger=logger)
    for batch in dataset_iter:
        unprocessed_translations = translator.translate_batch(batch, dataset)
        translations = translation_builder.from_batch(unprocessed_translations)
        # FIX: was a Python 2 ``print`` statement, a SyntaxError under
        # Python 3; the function-call form produces the same output.
        print("Translations: ", ' '.join(translations[0].pred_sents[0]))
        trainer.train_from_data(batch, train_steps=1)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def meta_valid_iter_fct(task_id, is_log=False):
    """Return a meta-validation iterator for the task ``task_id``."""
    dataset = lazily_load_dataset("meta_valid", opt,
                                  task_id=task_id, is_log=is_log)
    task_fields = fields_list[task_id]
    return build_dataset_iter(dataset, task_fields, opt)
def __init__(self, encoder, decoder, discriminator, lm, data, params):
    """
    Initialize trainer.

    Sets up optimizers for encoder/decoder/discriminator/LM, validation
    metrics and the stopping criterion, per-direction training statistics,
    speech dataset fields, BPE, and lambda coefficient schedules.
    """
    super().__init__(device_ids=tuple(range(params.otf_num_processes)))
    self.encoder = encoder
    self.decoder = decoder
    self.discriminator = discriminator
    self.lm = lm
    self.data = data
    self.params = params
    self.train_iter = None
    # Per-direction torchtext fields for speech data; also exposed on params.
    self.speech_fields = dict()
    params.speech_fields = self.speech_fields

    # training variables
    self.best_metrics = {metric: -1e12 for metric in self.VALIDATION_METRICS}
    self.epoch = 0
    self.n_total_iter = 0
    self.freeze_enc_emb = self.params.freeze_enc_emb
    self.freeze_dec_emb = self.params.freeze_dec_emb
    self.reload_model()

    # initialization for on-the-fly generation/training
    if len(params.pivo_directions) > 0:
        self.otf_start_multiprocessing()

    # define encoder parameters (the ones shared with the
    # decoder are optimized by the decoder optimizer)
    enc_params = list(encoder.parameters())
    for i in range(params.n_langs):
        # NOTE(review): with shared language embeddings this breaks on the
        # first iteration; the per-language size check below is disabled.
        if params.share_lang_emb and i >= 0:
            break
        #assert enc_params[i].size() == (params.n_words[i], params.emb_dim)
    if self.params.share_encdec_emb:
        # Embeddings shared with the decoder belong to the decoder
        # optimizer, so drop them from the encoder parameter list.
        to_ignore = 1 if params.share_lang_emb else params.n_langs
        enc_params = enc_params[to_ignore:]

    # optimizers
    if params.dec_optimizer == 'enc_optimizer':
        params.dec_optimizer = params.enc_optimizer
    self.enc_optimizer = get_optimizer(enc_params, params.enc_optimizer) if len(enc_params) > 0 else None
    self.dec_optimizer = get_optimizer(decoder.parameters(), params.dec_optimizer)
    self.dis_optimizer = get_optimizer(discriminator.parameters(), params.dis_optimizer) if discriminator is not None else None
    self.lm_optimizer = get_optimizer(lm.parameters(), params.enc_optimizer) if lm is not None else None

    # models / optimizers
    self.model_opt = {
        'enc': (self.encoder, self.enc_optimizer),
        'dec': (self.decoder, self.dec_optimizer),
        'dis': (self.discriminator, self.dis_optimizer),
        'lm': (self.lm, self.lm_optimizer),
    }

    # define validation metrics / stopping criterion used for early stopping
    logger.info("Stopping criterion: %s" % params.stopping_criterion)
    if params.stopping_criterion == '':
        # No explicit criterion: track BLEU for every direction/split.
        for lang1, lang2 in self.data['para'].keys():
            for data_type in ['valid', 'test']:
                self.VALIDATION_METRICS.append('bleu_%s_%s_%s' % (lang1, lang2, data_type))
                self.VALIDATION_METRICS.append('speech_bleu_%s_%s_%s' % (lang1, lang2, data_type))
        for lang1, lang2, lang3 in self.params.pivo_directions:
            if lang1 == lang3:
                continue
            for data_type in ['valid', 'test']:
                self.VALIDATION_METRICS.append('bleu_%s_%s_%s_%s' % (lang1, lang2, lang3, data_type))
        self.stopping_criterion = None
        self.best_stopping_criterion = None
    else:
        # Criterion format "metric_name,N": stop after N decreases.
        split = params.stopping_criterion.split(',')
        assert len(split) == 2 and split[1].isdigit()
        self.decrease_counts_max = int(split[1])
        self.decrease_counts = 0
        self.stopping_criterion = split[0]
        self.best_stopping_criterion = -1e12
        assert len(self.VALIDATION_METRICS) == 0
        self.VALIDATION_METRICS.append(self.stopping_criterion)

    # training statistics
    self.n_iter = 0
    self.n_sentences = 0
    self.stats = {
        'dis_costs': [],
        'processed_s': 0,
        'processed_w': 0,
    }
    for lang1, lang2 in params.speech_dataset.keys():
        self.stats['xe_costs_sp_%s_%s' % (lang1, lang2)] = []
    for speech_direction in params.speech_dataset.keys():
        # Peek the first dataset to determine the data_type.
        # (All datasets have the same data_type).
        first_dataset = next(lazily_load_dataset("train", params.speech_dataset[speech_direction][0]))
        data_type = first_dataset.data_type
        # Load fields generated from preprocess phase.
        fields = _load_fields_vocab(first_dataset, data_type, params.speech_vocabs[speech_direction[1]], None, 'vocab')
        self.speech_fields[speech_direction] = fields
    for lang in params.mono_directions:
        self.stats['xe_costs_%s_%s' % (lang, lang)] = []
    for lang1, lang2 in params.para_directions:
        self.stats['xe_costs_%s_%s' % (lang1, lang2)] = []
    for lang1, lang2 in params.back_directions:
        self.stats['xe_costs_bt_%s_%s' % (lang1, lang2)] = []
    for lang1, lang2, lang3 in params.pivo_directions:
        self.stats['xe_costs_%s_%s_%s' % (lang1, lang2, lang3)] = []
    for lang in params.langs:
        self.stats['lme_costs_%s' % lang] = []
        self.stats['lmd_costs_%s' % lang] = []
        self.stats['lmer_costs_%s' % lang] = []
        self.stats['enc_norms_%s' % lang] = []
    self.last_time = time.time()
    if len(params.pivo_directions) > 0:
        self.gen_time = 0

    # data iterators
    self.iterators = {}

    # initialize BPE subwords
    self.init_bpe()

    # initialize lambda coefficients and their configurations
    parse_lambda_config(params, 'lambda_xe_mono')
    parse_lambda_config(params, 'lambda_xe_para')
    parse_lambda_config(params, 'lambda_xe_back')
    parse_lambda_config(params, 'lambda_xe_otfd')
    parse_lambda_config(params, 'lambda_xe_otfa')
    parse_lambda_config(params, 'lambda_dis')
    parse_lambda_config(params, 'lambda_lm')
    parse_lambda_config(params, 'lambda_speech')
def extract_and_train(self, comparable_data_list):
    """
    Manages the alternating extraction of parallel sentences and training.
    Args:
        comparable_data_list(str): path to list of mapped documents
    Returns:
        train_stats(:obj:'onmt.Trainer.Statistics'): epoch loss statistics
    """
    # Start first epoch
    self.trainer.next_epoch()
    self.accepted_file = \
        open('{}_accepted-e{}.txt'.format(self.comp_log,
                                          self.trainer.cur_epoch),
             'w+', encoding='utf8')
    self.status_file = '{}_status-e{}.txt'.format(self.comp_log,
                                                  self.trainer.cur_epoch)
    if self.write_dual:
        self.embed_file = '{}_accepted_embed-e{}.txt'.format(
            self.comp_log, self.trainer.cur_epoch)
        self.hidden_file = '{}_accepted_hidden-e{}.txt'.format(
            self.comp_log, self.trainer.cur_epoch)
    epoch_similarities = []
    epoch_scores = []
    counter = 0
    src_sents = []
    tgt_sents = []
    src_embeds = []
    tgt_embeds = []
    # Go through comparable data
    with open(comparable_data_list, encoding='utf8') as c:
        comp_list = c.read().split('\n')
        num_articles = len(comp_list)
        cur_article = 0
        for article_pair in comp_list:
            cur_article += 1
            # Update status
            with open(self.status_file, 'a', encoding='utf8') as sf:
                sf.write('{} / {}\n'.format(cur_article, num_articles))
            articles = article_pair.split('\t')
            # Discard malaligned documents
            if len(articles) != 2:
                continue
            # Prepare iterator objects for current src/tgt document
            src_article = self._get_iterator(articles[0])
            tgt_article = self._get_iterator(articles[1])
            # Get sentence representations
            try:
                if self.representations == 'embed-only':
                    # C_e
                    src_sents += self.get_article_coves(src_article, 'embed',
                                                        fast=self.fast)
                    tgt_sents += self.get_article_coves(tgt_article, 'embed',
                                                        fast=self.fast)
                else:
                    # C_h
                    src_sents += self.get_article_coves(src_article,
                                                        fast=self.fast)
                    tgt_sents += self.get_article_coves(tgt_article,
                                                        fast=self.fast)
                    # C_e
                    src_embeds += self.get_article_coves(src_article, 'embed',
                                                         fast=self.fast)
                    tgt_embeds += self.get_article_coves(tgt_article, 'embed',
                                                         fast=self.fast)
            except:
                # Skip document pair in case of errors.
                # NOTE(review): bare except is deliberate best-effort
                # skipping here, but it also hides real bugs -- consider
                # narrowing the exception types.
                src_sents = []
                tgt_sents = []
                src_embeds = []
                tgt_embeds = []
                continue
            # Ensure enough sentences are accumulated (otherwise scoring
            # becomes unstable)
            if len(src_sents) < 15 or len(tgt_sents) < 15:
                continue
            # Score src and tgt sentences
            src2tgt, tgt2src, similarities, scores = self.score_sents(
                src_sents, tgt_sents)
            # Keep statistics
            epoch_similarities += similarities
            epoch_scores += scores
            src_sents = []
            tgt_sents = []
            # Filter candidates (primary filter)
            try:
                if self.representations == 'dual':
                    # For dual representation systems, filter C_h...
                    candidates = self.filter_candidates(src2tgt, tgt2src,
                                                        second=self.second)
                    # ...and C_e
                    comparison_pool, cand_embed = self.get_comparison_pool(
                        src_embeds, tgt_embeds)
                    src_embeds = []
                    tgt_embeds = []
                    if self.write_dual:
                        self.write_embed_only(candidates, cand_embed)
                else:
                    # Filter C_e or C_h for single representation system
                    candidates = self.filter_candidates(src2tgt, tgt2src)
                    comparison_pool = None
            except:
                # Skip document pair in case of errors
                print('Error occured in: {}\n'.format(article_pair),
                      flush=True)
                src_embeds = []
                tgt_embeds = []
                continue
            # Extract parallel samples (secondary filter)
            self.extract_parallel_sents(candidates, comparison_pool)
            # Check if enough parallel sentences were collected
            while self.similar_pairs.contains_batch():
                # Get a batch of extracted parallel sentences and train
                try:
                    training_batch = self.similar_pairs.yield_batch()
                except:
                    print('Error creating batch. Continuing...',
                          flush=True)
                    continue
                # Statistics
                train_stats = self.trainer.train(training_batch)
                self.trainstep += 1
                # Validate
                if self.trainstep % self.valid_steps == 0:
                    if self.no_valid == False:
                        valid_iter = build_dataset_iter(
                            lazily_load_dataset('valid', self.opt),
                            self.fields, self.opt)
                        valid_stats = self.validate(valid_iter)
                # Create checkpoint
                if self.trainstep % 5000 == 0:
                    self.trainer.model_saver._save(self.trainstep)
    # Train on remaining partial batch
    if len((self.similar_pairs.pairs)) > 0:
        train_stats = self.trainer.train(
            self.similar_pairs.yield_batch())
        self.trainstep += 1
    # Write epoch statistics
    self.write_similarities(epoch_similarities,
                            'e{}_comp'.format(self.trainer.cur_epoch))
    self.write_similarities(
        epoch_scores, 'e{}_comp_scores'.format(self.trainer.cur_epoch))
    self.trainer.report_epoch()
    self.logger.info(
        'Accepted parrallel sentences from comparable data: %d / %d'
        % (self.accepted, self.total))
    self.logger.info(
        'Acceptable parrallel sentences from comparable data (out of limit): %d / %d'
        % (self.accepted_limit, self.total))
    self.logger.info('Declined sentences from comparable data: %d / %d'
                     % (self.declined, self.total))
    # Reset epoch statistics
    self.accepted = 0
    self.accepted_limit = 0
    self.declined = 0
    self.total = 0
    self.accepted_file.close()
    return train_stats
def main(opt, device_id):
    """Train a model, logging the optimizer hyperparameters before building it."""
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    if opt.train_from and opt.reset_optim != 'all':
        logger.info('* checkpoint training not considered by me yet')
    else:
        # warmup_steps and rnn_size are parameters for Noam decay (transformer):
        # https://arxiv.org/pdf/1706.03762.pdf (Section 3)
        decay_method = opt.decay_method if opt.decay_method else "standard"
        logger.info(
            '* Opt: %s (rate %.5f, maxgnorm %.1f, %s decay, '
            'decay_rate %.1f, start_decay_at %d, decay_every %d, '
            'ab1 %.5f, ab2 %.5f, adagradaccum %.1f, '
            'warmupsteps %d, hiddensize %d)' %
            (opt.optim, opt.learning_rate, opt.max_grad_norm, decay_method,
             opt.learning_rate_decay, opt.start_decay_steps, opt.decay_steps,
             opt.adam_beta1, opt.adam_beta2, opt.adagrad_accumulator_init,
             opt.warmup_steps, opt.rnn_size))
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    logger.info('* model_saver built, using it to build trainer with ')
    trainer = build_trainer(opt, device_id, model, fields,
                            optim, data_type, model_saver=model_saver)
    #---------------------------------------------------------------------------
    # 1. lazily_load_dataset = for pt in pts: yield torch.load(pt)
    # 2. build_dataset_iter = return DatasetLazyIter (train_iter_fct)
    # 3. train_iter_fct() = iterator over torchtext.data.batch.Batches
    #---------------------------------------------------------------------------
    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt),
                                  fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields, opt, is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def train_iter_fct(task_id):
    """Return a fresh training iterator for the task ``task_id``."""
    dataset = lazily_load_dataset("train", opt, task_id=task_id)
    task_fields = fields_list[task_id]
    return build_dataset_iter(dataset, task_fields, opt)
train_dataset_files = build_save_dataset('train', fields, opt, logger) logger.info("Building & saving vocabulary...") build_save_vocab(train_dataset_files, fields, opt, logger) logger.info("Building & saving validation data...") build_save_dataset('valid', fields, opt, logger) opt = train_args opt = training_opt_postprocessing(opt) model_opt = opt checkpoint = None # Peek the fisrt dataset to determine the data_type. # (All datasets have the same data_type). first_dataset = next(lazily_load_dataset("train", opt)) data_type = first_dataset.data_type # Load fields generated from preprocess phase. fields = _load_fields(first_dataset, data_type, opt, checkpoint) # Report src/tgt features. _collect_report_features(fields) # Build model. model = build_model(model_opt, opt, fields, checkpoint) remove(args.host, args.port, "OpenNMT") probe( args.name, model,
def valid_iter_fct(task_id):
    """Return a fresh validation iterator for the task ``task_id``."""
    dataset = lazily_load_dataset("valid", opt, task_id=task_id)
    task_fields = fields_list[task_id]
    return build_dataset_iter(dataset, task_fields, opt)
def main(opt, device_id):
    """Meta-dev loop: for every 10th saved epoch checkpoint, fine-tune on the
    meta-dev task's train split, pick the best fine-tuned model by dev
    success rate, then score that model on the test split and track the best
    checkpoint overall.

    NOTE(review): relies on external scripts (`chemprop/predict.py`) invoked
    via os.popen and on fixed on-disk layouts under experiments/ and
    processed_data/ — confirm these exist before running.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    out_file = None  # shared translation output file, opened lazily below
    best_test_score, best_ckpt = -10000, None
    # Dummy parser only exists to obtain the default model options namespace.
    dummy_parser = argparse.ArgumentParser(description='all_dev.py')
    opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
    # Evaluate every 10th epoch checkpoint (i + 1 gives epochs 1, 11, 21, ...).
    for i in range(0, opt.train_epochs, 10):
        ckpt_path = '{}_epoch_{}.pt'.format(opt.save_model, i + 1)
        logger.info('Loading checkpoint from %s' % ckpt_path)
        # map_location keeps tensors on CPU regardless of where they were saved.
        checkpoint = torch.load(ckpt_path,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        fields = load_fields_from_vocab(checkpoint['vocab'], data_type="text")
        # Build model.
        model = build_model(model_opt, opt, fields, checkpoint)
        assert opt.train_from == ''  # do not load optimizer state
        optim = build_optim(model, opt, checkpoint)
        # Build model saver, no need to create task dir for dev
        if not os.path.exists('experiments/all_dev'):
            os.mkdir('experiments/all_dev')
            os.mkdir('experiments/all_dev/' + opt.meta_dev_task)
        elif not os.path.exists('experiments/all_dev/' + opt.meta_dev_task):
            os.mkdir('experiments/all_dev/' + opt.meta_dev_task)
        model_saver = build_model_saver(
            model_opt, 'experiments/all_dev/' + opt.meta_dev_task + '/model',
            opt, model, fields, optim)
        trainer = build_trainer(opt, device_id, model, fields, optim, "text",
                                model_saver=model_saver)
        # Materialize the iterator: the same batch list is reused each inner iteration.
        train_iter = list(
            build_dataset_iter(lazily_load_dataset("train", opt), fields, opt))
        # do training on trainset of meta-dev task
        trainer.train(train_iter, opt.inner_iterations)
        # do evaluation on devset of meta-dev task
        best_dev_score, best_model_path = -10000, None
        for model_path in os.listdir('experiments/all_dev/' + opt.meta_dev_task):
            # Only saved .pt models are evaluated; everything else is skipped.
            if model_path.find('.pt') == -1:
                continue
            if out_file is None:
                out_file = codecs.open(opt.output, 'w+', 'utf-8')
            fields, model, model_opt = onmt.model_builder.load_test_model(
                opt, dummy_opt.__dict__,
                model_path='experiments/all_dev/' + opt.meta_dev_task + '/'
                + model_path)
            scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                                     opt.coverage_penalty,
                                                     opt.length_penalty)
            # Forward only the translation-related options to the Translator.
            kwargs = {
                k: getattr(opt, k)
                for k in [
                    "beam_size", "n_best", "max_length", "min_length",
                    "stepwise_penalty", "block_ngram_repeat",
                    "ignore_when_blocking", "dump_beam", "report_bleu",
                    "replace_unk", "gpu", "verbose", "fast", "mask_from"
                ]
            }
            fields['graph'] = torchtext.data.Field(sequential=False)
            translator = Translator(model, fields, global_scorer=scorer,
                                    out_file=out_file, report_score=False,
                                    copy_attn=model_opt.copy_attn,
                                    logger=logger,
                                    log_probs_out_file=None, **kwargs)
            # make translation and save result
            all_scores, all_predictions = translator.translate(
                src_path='processed_data/meta-dev/' + opt.meta_dev_task
                + '/src-dev.txt',
                tgt_path=None, src_dir=None,
                batch_size=opt.translate_batch_size, attn_debug=False)
            # dump predictions (SMILES strings, spaces removed, dummy property column)
            f = open('experiments/all_dev/' + opt.meta_dev_task
                     + '/dev_predictions.csv', 'w', encoding='utf-8')
            f.write('smiles,property\n')
            for n_best_mols in all_predictions:
                for mol in n_best_mols:
                    f.write(mol.replace(' ', '') + ',0\n')
            f.close()
            # call chemprop to get scores
            test_path = '\"' + 'experiments/all_dev/' + opt.meta_dev_task \
                + '/dev_predictions.csv' + '\"'
            checkpoint_path = '\"' + 'scorer_ckpts/' + opt.meta_dev_task \
                + '/model.pt' + '\"'
            preds_path = '\"' + 'experiments/all_dev/' + opt.meta_dev_task \
                + '/dev_scores.csv' + '\"'
            # in case of all mols are invalid (will produce not output file by chemprop)
            # the predictions are copied into score file
            cmd = 'cp {} {}'.format(test_path, preds_path)
            result = os.popen(cmd)
            result.close()
            cmd = 'python chemprop/predict.py --test_path {} --checkpoint_path {} --preds_path {} --num_workers 0'.format(
                test_path, checkpoint_path, preds_path)
            scorer_result = os.popen(cmd)
            scorer_result.close()
            # read score file and get score
            score = read_score_csv('experiments/all_dev/' + opt.meta_dev_task
                                   + '/dev_scores.csv')
            # Scores come in groups of beam_size per source molecule.
            assert len(score) % opt.beam_size == 0
            # dev_scores = []
            # for i in range(0, len(score), opt.beam_size):
            #     dev_scores.append(sum([x[1] for x in score[i:i+opt.beam_size]]) / opt.beam_size)
            # report dev score
            dev_metrics = calculate_metrics(opt.meta_dev_task, 'dev', 'dev',
                                            score)
            logger.info('dev metrics: ' + str(dev_metrics))
            dev_score = dev_metrics['success_rate']
            if dev_score > best_dev_score:
                logger.info('New best dev success rate: {:.4f} by {}'.format(
                    dev_score, model_path))
                best_model_path = model_path
                best_dev_score = dev_score
            else:
                logger.info('dev success rate: {:.4f} by {}'.format(
                    dev_score, model_path))
            # Free large per-model objects before loading the next candidate.
            del fields
            del model
            del model_opt
            del scorer
            del translator
            gc.collect()
        assert best_model_path != None
        # do testing on testset of meta-dev task
        if out_file is None:
            out_file = codecs.open(opt.output, 'w+', 'utf-8')
        fields, model, model_opt = onmt.model_builder.load_test_model(
            opt, dummy_opt.__dict__,
            model_path='experiments/all_dev/' + opt.meta_dev_task + '/'
            + best_model_path)
        scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                                 opt.coverage_penalty,
                                                 opt.length_penalty)
        kwargs = {
            k: getattr(opt, k)
            for k in [
                "beam_size", "n_best", "max_length", "min_length",
                "stepwise_penalty", "block_ngram_repeat",
                "ignore_when_blocking", "dump_beam", "report_bleu",
                "replace_unk", "gpu", "verbose", "fast", "mask_from"
            ]
        }
        fields['graph'] = torchtext.data.Field(sequential=False)
        translator = Translator(model, fields, global_scorer=scorer,
                                out_file=out_file, report_score=False,
                                copy_attn=model_opt.copy_attn, logger=logger,
                                log_probs_out_file=None, **kwargs)
        # make translation and save result
        all_scores, all_predictions = translator.translate(
            src_path='processed_data/meta-dev/' + opt.meta_dev_task
            + '/src-test.txt',
            tgt_path=None, src_dir=None,
            batch_size=opt.translate_batch_size, attn_debug=False)
        # dump predictions
        f = open('experiments/all_dev/' + opt.meta_dev_task
                 + '/test_predictions.csv', 'w', encoding='utf-8')
        f.write('smiles,property\n')
        for n_best_mols in all_predictions:
            for mol in n_best_mols:
                f.write(mol.replace(' ', '') + ',0\n')
        f.close()
        # call chemprop to get scores
        test_path = '\"' + 'experiments/all_dev/' + opt.meta_dev_task \
            + '/test_predictions.csv' + '\"'
        checkpoint_path = '\"' + 'scorer_ckpts/' + opt.meta_dev_task \
            + '/model.pt' + '\"'
        preds_path = '\"' + 'experiments/all_dev/' + opt.meta_dev_task \
            + '/test_scores.csv' + '\"'
        # in case of all mols are invalid (will produce not output file by chemprop)
        # the predictions are copied into score file
        cmd = 'cp {} {}'.format(test_path, preds_path)
        result = os.popen(cmd)
        result.close()
        cmd = 'python chemprop/predict.py --test_path {} --checkpoint_path {} --preds_path {} --num_workers 0'.format(
            test_path, checkpoint_path, preds_path)
        scorer_result = os.popen(cmd)
        # logger.info('{}'.format('\n'.join(scorer_result.readlines())))
        scorer_result.close()
        # read score file and get score
        score = read_score_csv('experiments/all_dev/' + opt.meta_dev_task
                               + '/test_scores.csv')
        assert len(score) % opt.beam_size == 0
        # test_scores = []
        # for i in range(0, len(score), opt.beam_size):
        #     test_scores.append(sum([x[1] for x in score[i:i+opt.beam_size]]) / opt.beam_size)
        # report if it is the best on test
        test_metrics = calculate_metrics(opt.meta_dev_task, 'dev', 'test',
                                         score)
        logger.info('test metrics: ' + str(test_metrics))
        test_score = test_metrics['success_rate']
        if test_score > best_test_score:
            best_ckpt = ckpt_path
            logger.info('New best test success rate: {:.4f} by {}'.format(
                test_score, ckpt_path))
            best_test_score = test_score
        else:
            logger.info('test success rate: {:.4f} by {}'.format(
                test_score, ckpt_path))
        # Drop everything tied to this epoch checkpoint before the next round.
        del model_opt
        del fields
        del checkpoint
        del model
        del optim
        del model_saver
        del trainer
        gc.collect()
def main(opt, device_id):
    """Multi-task training entry point.

    Builds one model, one field set, and (optionally) one optimizer per task
    listed in ``opt.data`` (comma separated), wires them into a single
    multi-task trainer, and runs training.

    Fix: the per-task tgt-feature log line used the format string
    ``' * (Task %) tgt feature %d size = %d'`` — ``%)`` is an invalid
    conversion and raises ``ValueError: unsupported format character ')'``
    whenever a task has target features. It now uses ``%d`` like the
    matching src-feature line.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    # Gather information related to the training script and commit version
    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(os.path.dirname(script_path))
    logger.info('Train script dir: %s' % script_dir)
    git_commit = str(subprocess.check_output(
        ['bash', script_dir + '/cluster_scripts/git_version.sh']))
    # [2:-3] strips the b'...\n' wrapper left by str() on the bytes output.
    logger.info("Git Commit: %s" % git_commit[2:-3])

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        # TODO: load MTL model
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opts values then overwrite it with opts from
        # the checkpoint. It's usefull in order to re-train a model
        # after adding a new option (not set in checkpoint)
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # One task per comma-separated data path.
    num_tasks = len(opt.data.split(','))
    opt.num_tasks = num_tasks

    # Per-task warm-start checkpoints: "X" in the pattern is replaced by the
    # task id; missing files fall back to task 0's checkpoint.
    checkpoint_list = []
    if opt.warm_model:
        base_name = opt.warm_model
        for task_id in range(num_tasks):
            chkpt_path = base_name.replace("X", str(task_id))
            if not os.path.isfile(chkpt_path):
                chkpt_path = base_name.replace("X", str(0))
            logger.info('Loading a checkpoint from %s' % chkpt_path)
            checkpoint = torch.load(chkpt_path,
                                    map_location=lambda storage, loc: storage)
            checkpoint_list.append(checkpoint)
    else:
        for task_id in range(num_tasks):
            checkpoint_list.append(None)

    fields_list = []
    data_type = None
    for task_id in range(num_tasks):
        # Peek the first dataset to determine the data_type.
        # (All datasets have the same data_type).
        first_dataset = next(lazily_load_dataset("train", opt,
                                                 task_id=task_id))
        data_type = first_dataset.data_type
        # Load fields generated from preprocess phase.
        if opt.mtl_shared_vocab and task_id > 0:
            logger.info(' * vocabulary size. Same as the main task!')
            fields = fields_list[0]
        else:
            fields = load_fields(first_dataset, opt,
                                 checkpoint_list[task_id], task_id=task_id)
        # Report src/tgt features.
        src_features, tgt_features = _collect_report_features(fields)
        for j, feat in enumerate(src_features):
            logger.info(' * (Task %d) src feature %d size = %d' %
                        (task_id, j, len(fields[feat].vocab)))
        for j, feat in enumerate(tgt_features):
            # BUGFIX: was '(Task %)' — invalid conversion, raised ValueError.
            logger.info(' * (Task %d) tgt feature %d size = %d' %
                        (task_id, j, len(fields[feat].vocab)))
        fields_list.append(fields)

    if opt.epochs > -1:
        # Convert an epoch budget into step counts by counting batches once.
        total_num_batch = 0
        for task_id in range(num_tasks):
            train_iter = build_dataset_iter(
                lazily_load_dataset("train", opt, task_id=task_id),
                fields_list[task_id], opt)
            for i, batch in enumerate(train_iter):
                num_batch = i
            total_num_batch += num_batch
            # Schedules below 10 only consider the main task's batch count.
            if opt.mtl_schedule < 10:
                break
        num_batch = total_num_batch
        opt.train_steps = (num_batch * opt.epochs) + 1
        # Do the validation and save after each epoch
        opt.valid_steps = num_batch
        opt.save_checkpoint_steps = 1

    # logger.info(opt_to_string(opt))
    logger.info(opt)

    # Build model(s).
    models_list = []
    for task_id in range(num_tasks):
        if opt.mtl_fully_share and task_id > 0:
            # Since we only have one model, copy the pointer to the model for all
            models_list.append(models_list[0])
        else:
            main_model = models_list[0] if task_id > 0 else None
            model = build_model(model_opt, opt, fields_list[task_id],
                                checkpoint_list[task_id],
                                main_model=main_model, task_id=task_id)
            n_params, enc, dec = _tally_parameters(model)
            logger.info('(Task %d) encoder: %d' % (task_id, enc))
            logger.info('(Task %d) decoder: %d' % (task_id, dec))
            logger.info('* number of parameters: %d' % n_params)
            _check_save_model_path(opt)
            models_list.append(model)

    # combine parameters of different models and consider shared parameters just once.
    def combine_named_parameters(named_params_list):
        observed_params = []
        for model_named_params in named_params_list:
            for name, p in model_named_params:
                is_observed = False
                # Check whether we observed this parameter before
                # (identity check: shared modules yield the same tensor object).
                for param in observed_params:
                    if p is param:
                        is_observed = True
                        break
                if not is_observed:
                    observed_params.append(p)
                    yield name, p

    # Build optimizer.
    optims_list = []
    all_models_params = []
    for task_id in range(num_tasks):
        if not opt.mtl_shared_optimizer:
            optim = build_optim(models_list[task_id], opt, checkpoint)
            optims_list.append(optim)
        else:
            all_models_params.append(models_list[task_id].named_parameters())

    # Extract the list of shared parameters among the models of all tasks.
    observed_params = []
    shared_params = []
    for task_id in range(num_tasks):
        for name, p in models_list[task_id].named_parameters():
            is_observed = False
            # Check whether we observed this parameter before
            for param in observed_params:
                if p is param:
                    shared_params.append(name)
                    is_observed = True
                    break
            if not is_observed:
                observed_params.append(p)
    opt.shared_params = shared_params

    if opt.mtl_shared_optimizer:
        # Single optimizer over the de-duplicated union of all task parameters.
        optim = build_optim_mtl_params(
            combine_named_parameters(all_models_params), opt, checkpoint)
        optims_list.append(optim)

    # Build model saver
    model_saver = build_mtl_model_saver(model_opt, opt, models_list,
                                        fields_list, optims_list)

    trainer = build_trainer(opt, device_id, models_list, fields_list,
                            optims_list, data_type, model_saver=model_saver)

    def train_iter_fct(task_id):
        return build_dataset_iter(
            lazily_load_dataset("train", opt, task_id=task_id),
            fields_list[task_id], opt)

    def valid_iter_fct(task_id):
        return build_dataset_iter(
            lazily_load_dataset("valid", opt, task_id=task_id),
            fields_list[task_id], opt)

    def meta_valid_iter_fct(task_id, is_log=False):
        return build_dataset_iter(
            lazily_load_dataset("meta_valid", opt, task_id=task_id,
                                is_log=is_log),
            fields_list[task_id], opt)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps, meta_valid_iter_fct=meta_valid_iter_fct)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Single-task training entry point using a non-lazy, whole-file dataset
    iterator (``data_iter_fct`` loads one ``.pt`` file per stage instead of
    streaming shards)."""
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        # map_location keeps tensors on CPU regardless of the saving device.
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    logger.info('* batch_size: %d' % opt.batch_size)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)

    def data_iter_fct(data_stage):
        """data_stage: train / valid"""
        pt_file = opt.data + '.' + data_stage + '.pt'
        logger.info('Loading {} dataset'.format(data_stage))
        dataset = torch.load(pt_file)
        logger.info('Loaded {} dataset'.format(data_stage))
        dataset.fields = fields
        is_train = True if data_stage == "train" else False
        batch_size = opt.batch_size if is_train else opt.valid_batch_size
        # Only the training iterator repeats indefinitely.
        repeat = True if data_stage == "train" else False
        if opt.gpuid != -1:
            device = "cuda"
        else:
            device = "cpu"

        def sort_key(ex):
            """ Sort using length of source sentences. """
            return ex.total_tokens

        return torchtext.data.Iterator(dataset=dataset,
                                       batch_size=batch_size,
                                       device=device,
                                       train=is_train,
                                       sort=False,
                                       sort_key=sort_key,
                                       repeat=repeat)

    # Do training.
    trainer.train(data_iter_fct, opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def valid_iter_fct():
    """Return a fresh iterator over the validation shards."""
    # The logger is threaded through to the lazy loader in this variant.
    shard_stream = lazily_load_dataset("valid", opt, logger)
    return build_dataset_iter(shard_stream, fields, opt)
def train_iter_fct():
    """Return a fresh iterator over the training shards."""
    shard_stream = lazily_load_dataset("train", opt)
    return build_dataset_iter(shard_stream, fields, opt)
def main(opt, device_id):
    """Single-task training entry point.

    Resumes from ``opt.train_from`` if given (merging checkpoint options over
    freshly parsed defaults so newly added options keep their defaults),
    builds model/optimizer/saver/trainer, and runs training.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        # map_location keeps tensors on CPU regardless of the saving device.
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opts values then overwrite it with opts from
        # the checkpoint. It's usefull in order to re-train a model
        # after adding a new option (not set in checkpoint)
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = load_fields(first_dataset, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)

    # Iterator factories: a fresh lazy shard stream is created per call.
    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt), fields,
                                  opt, is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def valid_iter_fct():
    """Flag the instrumented model as being in dev mode, then return a fresh
    validation iterator."""
    # Side effect first: downstream vivisect hooks read this mode marker.
    model._vivisect["mode"] = "dev"
    shard_stream = lazily_load_dataset("valid", opt)
    return build_dataset_iter(shard_stream, fields, opt)
def valid_iter_fct(): return build_dataset_iter( lazily_load_dataset("valid", opt), fields, opt, is_train=False) # Do training. if len(opt.gpu_ranks):
def main(opt):
    """Training entry point for the selector + s2s-generator model variant.

    Optionally warm-starts from a full checkpoint, a pretrained selector
    checkpoint, and/or a pretrained s2s generator checkpoint, and can freeze
    selector parameters (entirely or classifier-only) before training.
    """
    opt = training_opt_postprocessing(opt)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        # map_location keeps tensors on CPU regardless of the saving device.
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    if opt.load_pretrained_selector_from:
        logger.info('Loading selector checkpoint from %s' %
                    opt.load_pretrained_selector_from)
        sel_checkpoint = torch.load(opt.load_pretrained_selector_from,
                                    map_location=lambda storage, loc: storage)
    else:
        sel_checkpoint = None

    if opt.load_pretrained_s2s_generator_from:
        logger.info('Loading s2s generator checkpoint from %s' %
                    opt.load_pretrained_s2s_generator_from)
        s2s_gen_checkpoint = torch.load(
            opt.load_pretrained_s2s_generator_from,
            map_location=lambda storage, loc: storage)
    else:
        s2s_gen_checkpoint = None

    # Peek the fisrt dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint, sel_checkpoint,
                        s2s_gen_checkpoint)

    # Fix the pretrained selector parameters if needed
    if model_opt.fix_sel_all:
        # Freezing everything only makes sense with a pretrained selector,
        # no selector loss, and not also in classifier-only mode.
        assert opt.load_pretrained_selector_from
        assert opt.sel_lambda == 0.0
        assert not model_opt.fix_sel_classifier
        for name, param in model.named_parameters():
            if 'selector' in name:
                param.requires_grad = False

    # only fix the classifier of the selector
    if model_opt.fix_sel_classifier:
        assert opt.load_pretrained_selector_from
        assert not model_opt.fix_sel_all
        for name, param in model.named_parameters():
            # Keep the selector's rnn and embeddings trainable.
            if 'selector' in name and 'rnn' not in name \
                    and 'embeddings' not in name:
                param.requires_grad = False

    n_params, sel, enc, dec, gen = _my_tally_parameters(model)
    logger.info('selector: %d' % sel)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('generator: %d' % gen)
    logger.info('* number of parameters: %d' % n_params)
    print_trainable_parameters(model)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, model, fields, optim, data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt), fields,
                                  opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def valid_iter_fct():
    """Return a fresh validation iterator (non-training mode)."""
    shard_stream = lazily_load_dataset("valid", opt)
    # is_train=False selects validation batching/shuffling behavior.
    return build_dataset_iter(shard_stream, fields, opt, is_train=False)
def valid_iter_fct(): return build_dataset_iter( lazily_load_dataset("valid", opt), fields, opt) # Do training. trainer.train(train_iter_fct, valid_iter_fct, opt.start_epoch, opt.epochs)