def main(opt, device_id):
    """Single-device training entry point (explicit .pt dataset variant).

    Builds the model (optionally resuming from ``opt.train_from``), then
    trains with a data-iterator factory that loads ``<opt.data>.<stage>.pt``
    files directly via ``torch.load``.

    Args:
        opt: parsed training options namespace.
        device_id: device index used by ``training_opt_postprocessing`` /
            ``build_trainer``.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Model hyper-parameters come from the checkpoint, not the CLI.
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type  # concat, query, hier

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)

    def data_iter_fct(data_stage):
        """data_stage: train / valid"""
        pt_file = opt.data + '.' + data_stage + '.pt'
        logger.info('Loading {} dataset'.format(data_stage))
        dataset = torch.load(pt_file)
        logger.info('Loaded {} dataset'.format(data_stage))
        dataset.fields = fields
        # Was `True if data_stage == "train" else False`: the comparison is
        # already a bool, so use it directly for both flags.
        is_train = data_stage == "train"
        batch_size = opt.batch_size if is_train else opt.valid_batch_size
        # Training cycles through the data; validation runs a single pass.
        repeat = is_train
        device = "cuda" if opt.gpuid != -1 else "cpu"

        def sort_key(ex):
            """Sort using length of source sentences."""
            return ex.total_tokens

        return torchtext.data.Iterator(dataset=dataset,
                                       batch_size=batch_size,
                                       device=device,
                                       train=is_train,
                                       sort=False,
                                       sort_key=sort_key,
                                       repeat=repeat)

    # Do training.
    trainer.train(data_iter_fct, opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Run single-device training: restore state, build the model, train."""
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    # Resume from a checkpoint when requested; otherwise start fresh.
    ckpt = None
    model_options = opt
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        ckpt = torch.load(opt.train_from,
                          map_location=lambda storage, loc: storage)
        model_options = ckpt['opt']

    # The first training shard determines the data_type shared by all shards.
    first_shard = next(lazily_load_dataset("train", opt))
    data_type = first_shard.data_type

    # Vocabulary/fields produced by the preprocessing step.
    vocab_fields = _load_fields(first_shard, data_type, opt, ckpt)

    # Log the vocabulary size of every extra source/target feature.
    src_feats, tgt_feats = _collect_report_features(vocab_fields)
    for j, feat in enumerate(src_feats):
        logger.info(' * src feature %d size = %d'
                    % (j, len(vocab_fields[feat].vocab)))
    for j, feat in enumerate(tgt_feats):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(vocab_fields[feat].vocab)))

    # Model construction and parameter accounting.
    model = build_model(model_options, opt, vocab_fields, ckpt)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Optimizer, checkpoint saver, and trainer.
    optim = build_optim(model, opt, ckpt)
    model_saver = build_model_saver(model_options, opt, model,
                                    vocab_fields, optim)
    trainer = build_trainer(opt, device_id, model, vocab_fields, optim,
                            data_type, model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("train", opt), vocab_fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("valid", opt), vocab_fields, opt,
            is_train=False)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct,
                  opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt):
    """Single-process training entry point for the selector-augmented model.

    In addition to the usual resume checkpoint, this variant can warm-start
    from a pretrained selector and/or a pretrained seq2seq generator
    checkpoint, and can freeze (parts of) the selector's parameters.

    Args:
        opt: parsed training options namespace.
    """
    opt = training_opt_postprocessing(opt)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Model hyper-parameters come from the checkpoint, not the CLI.
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Optional pretrained selector weights (loaded onto CPU via map_location).
    if opt.load_pretrained_selector_from:
        logger.info('Loading selector checkpoint from %s'
                    % opt.load_pretrained_selector_from)
        sel_checkpoint = torch.load(opt.load_pretrained_selector_from,
                                    map_location=lambda storage, loc: storage)
    else:
        sel_checkpoint = None

    # Optional pretrained seq2seq generator weights.
    if opt.load_pretrained_s2s_generator_from:
        logger.info('Loading s2s generator checkpoint from %s'
                    % opt.load_pretrained_s2s_generator_from)
        s2s_gen_checkpoint = torch.load(
            opt.load_pretrained_s2s_generator_from,
            map_location=lambda storage, loc: storage)
    else:
        s2s_gen_checkpoint = None

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint, sel_checkpoint,
                        s2s_gen_checkpoint)

    # Fix (freeze) ALL pretrained selector parameters if requested.
    # NOTE(review): these asserts validate option combinations; they are
    # stripped under `python -O`.
    if model_opt.fix_sel_all:
        assert opt.load_pretrained_selector_from
        assert opt.sel_lambda == 0.0
        assert not model_opt.fix_sel_classifier
        for name, param in model.named_parameters():
            if 'selector' in name:
                param.requires_grad = False

    # Only fix the classifier of the selector: freeze every selector
    # parameter except its RNN and embedding weights.
    if model_opt.fix_sel_classifier:
        assert opt.load_pretrained_selector_from
        assert not model_opt.fix_sel_all
        for name, param in model.named_parameters():
            if 'selector' in name and 'rnn' not in name \
                    and 'embeddings' not in name:
                param.requires_grad = False

    # Parameter accounting per sub-module.
    n_params, sel, enc, dec, gen = _my_tally_parameters(model)
    logger.info('selector: %d' % sel)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('generator: %d' % gen)
    logger.info('* number of parameters: %d' % n_params)
    print_trainable_parameters(model)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, model, fields, optim, data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("train", opt), fields, opt)

    # NOTE(review): unlike sibling variants, this one does not pass
    # is_train=False for validation — confirm against build_dataset_iter.
    def valid_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("valid", opt), fields, opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Single-device training entry point for the knowledge-grounded model.

    This variant reports an extra "knl" (knowledge) feature set alongside
    src/tgt, and merges checkpoint options over freshly-parsed defaults so
    that options added after the checkpoint was written still get values.

    Args:
        opt: parsed training options namespace.
        device_id: device index used by ``training_opt_postprocessing`` /
            ``build_trainer``.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opts values then overwrite it with opts from
        # the checkpoint. It's usefull in order to re-train a model
        # after adding a new option (not set in checkpoint)
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    # Debug output left from development; consider removing once stable.
    print("[onmt.train_single.py] first_dataset.examples[0].src_da_label: {}".
          format(first_dataset.examples[0].src_da_label))
    print(
        "[onmt.train_single.py] first_dataset.examples[0].__dict__.keys(): {}".
        format(first_dataset.examples[0].__dict__.keys()))
    data_type = first_dataset.data_type
    print("[onmt.train_single.py] first_dataset.data_type: {}".format(
        first_dataset.data_type))

    # Load fields generated from preprocess phase.
    fields = load_fields(first_dataset, opt, checkpoint)

    # Report knl/src/tgt features.
    knl_features, src_features, tgt_features = _collect_report_features(fields)
    # BUG FIX: the knowledge-feature report previously enumerated
    # src_features, so knl vocab sizes were never actually reported.
    for j, feat in enumerate(knl_features):
        logger.info(' * knl feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Single-device training with verbose optimizer-configuration logging."""
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    # Resume from a checkpoint when requested; otherwise start fresh.
    ckpt = None
    model_options = opt
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        ckpt = torch.load(opt.train_from,
                          map_location=lambda storage, loc: storage)
        model_options = ckpt['opt']

    # The first training shard determines the data_type shared by all shards.
    first_shard = next(lazily_load_dataset("train", opt))
    data_type = first_shard.data_type

    # Vocabulary/fields produced by the preprocessing step.
    vocab_fields = _load_fields(first_shard, data_type, opt, ckpt)

    # Log the vocabulary size of every extra source/target feature.
    src_feats, tgt_feats = _collect_report_features(vocab_fields)
    for j, feat in enumerate(src_feats):
        logger.info(' * src feature %d size = %d'
                    % (j, len(vocab_fields[feat].vocab)))
    for j, feat in enumerate(tgt_feats):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(vocab_fields[feat].vocab)))

    # Model construction and parameter accounting.
    model = build_model(model_options, opt, vocab_fields, ckpt)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer, logging its full configuration for fresh runs.
    if opt.train_from and opt.reset_optim != 'all':
        logger.info('* checkpoint training not considered by me yet')
    else:
        # warmup_steps and rnn_size are parameters for Noam decay
        # (transformer): https://arxiv.org/pdf/1706.03762.pdf (Section 3)
        decay_method = opt.decay_method or "standard"
        logger.info(
            '* Opt: %s (rate %.5f, maxgnorm %.1f, %s decay, '
            'decay_rate %.1f, start_decay_at %d, decay_every %d, '
            'ab1 %.5f, ab2 %.5f, adagradaccum %.1f, '
            'warmupsteps %d, hiddensize %d)' %
            (opt.optim, opt.learning_rate, opt.max_grad_norm, decay_method,
             opt.learning_rate_decay, opt.start_decay_steps, opt.decay_steps,
             opt.adam_beta1, opt.adam_beta2, opt.adagrad_accumulator_init,
             opt.warmup_steps, opt.rnn_size))
    optim = build_optim(model, opt, ckpt)

    # Checkpoint saver, then the trainer that uses it.
    model_saver = build_model_saver(model_options, opt, model,
                                    vocab_fields, optim)
    logger.info('* model_saver built, using it to build trainer with ')
    trainer = build_trainer(opt, device_id, model, vocab_fields, optim,
                            data_type, model_saver=model_saver)

    # ------------------------------------------------------------------
    # 1. lazily_load_dataset = for pt in pts: yield torch.load(pt)
    # 2. build_dataset_iter  = return DatasetLazyIter (train_iter_fct)
    # 3. train_iter_fct()    = iterator over torchtext.data.batch.Batches
    # ------------------------------------------------------------------
    def train_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("train", opt), vocab_fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("valid", opt), vocab_fields, opt,
            is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct,
                  opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
opt = train_args opt = training_opt_postprocessing(opt) model_opt = opt checkpoint = None # Peek the fisrt dataset to determine the data_type. # (All datasets have the same data_type). first_dataset = next(lazily_load_dataset("train", opt)) data_type = first_dataset.data_type # Load fields generated from preprocess phase. fields = _load_fields(first_dataset, data_type, opt, checkpoint) # Report src/tgt features. _collect_report_features(fields) # Build model. model = build_model(model_opt, opt, fields, checkpoint) remove(args.host, args.port, "OpenNMT") probe( args.name, model, args.host, args.port, when=lambda m, o: m._v.state == "dev", which=lambda m, o: True, #o._v.operation_name in ["encoder", "decoder"], parameters=False, forward=True,
def main(opt, device_id):
    """Single-device training entry point with extra monitoring datasets.

    Besides the usual train/valid iterators, builds one iterator per
    (monitor_src, monitor_tgt) pair so the trainer can periodically report
    on held-out "monitor" sets.

    Args:
        opt: parsed training options namespace.
        device_id: device index used by ``training_opt_postprocessing`` /
            ``build_trainer``.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Model hyper-parameters come from the checkpoint, not the CLI.
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)

    # NOTE(review): the third positional argument (True / a name) is extra
    # compared with the two-argument calls below — presumably a flag or
    # corpus id understood by this fork's lazily_load_dataset; confirm.
    def train_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("train", opt, True), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("valid", opt), fields, opt, is_train=False)

    def monitor_iter_fct():
        # One evaluation iterator per monitor corpus, keyed by a name
        # derived from the source file: basename without extension and
        # without a "_src" suffix (handles both / and \ separators).
        monitor_data = dict()
        for src, tgt in zip(opt.monitor_src, opt.monitor_tgt):
            fname = src.split("/" if "/" in src else "\\")[-1].split(
                ".")[0].replace("_src", "")
            monitor_data[fname] = build_dataset_iter(
                lazily_load_dataset("monitor", opt, fname),
                fields, opt, is_train=False)
        return monitor_data

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, monitor_iter_fct,
                  opt.train_steps, opt.valid_steps, opt.monitor_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    """Epoch-based training on a single pickled dataset.

    Loads ``processed_data/all-train/train.pt`` with ``pickle`` (not
    ``torch.load``), builds model/optimizer/saver, and trains for
    ``opt.train_epochs`` epochs over a fully materialized iterator.

    Args:
        opt: parsed training options namespace.
        device_id: device index used by ``training_opt_postprocessing`` /
            ``build_trainer``.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # FIX: use a context manager so the file handle is always closed
    # (previously `pickle.load(open(...))` leaked the handle).
    with open('processed_data/all-train/train.pt', 'rb') as f:
        first_dataset = pickle.load(f)
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    optim = build_optim(model, opt, checkpoint)  # opt.train_from == ''

    # Build model saver.
    # FIX: makedirs(..., exist_ok=True) replaces the racy
    # exists()-then-mkdir pair and also creates missing parents.
    os.makedirs('experiments/all_train', exist_ok=True)
    model_saver = build_model_saver(model_opt, opt.save_model, opt, model,
                                    fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, "text",
                            model_saver=model_saver)

    def _lazy_dataset_loader(pt_file):
        """Yield the single pickled dataset stored at *pt_file*.

        Simplified from a needlessly nested inner-function version; still a
        generator because build_dataset_iter expects an iterable of datasets.
        """
        with open(pt_file, 'rb') as f:
            yield pickle.load(f)

    train_iter = list(
        build_dataset_iter(
            _lazy_dataset_loader('processed_data/all-train/train.pt'),
            fields, opt))
    trainer.train(train_iter, opt.train_epochs)
def main(opt, device_id):
    """Single-device training with optional comparable-corpora extraction.

    After (optionally) standard training, when ``opt.comparable`` is set it
    runs epochs of parallel-data extraction from comparable corpora via the
    external ``Comparable`` helper, validating and checkpointing per epoch.

    Args:
        opt: parsed training options namespace.
        device_id: device index used by ``training_opt_postprocessing`` /
            ``build_trainer``.
    """
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opts values then overwrite it with opts from
        # the checkpoint. It's usefull in order to re-train a model
        # after adding a new option (not set in checkpoint)
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = load_fields(first_dataset, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    # Standard supervised training is skipped when --no_base is set.
    if opt.no_base == False:
        trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                      opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()

    if opt.comparable:
        logger.info('')
        logger.info('Beginning comparable data extraction and training.')

        # 1. Initialize Comparable object
        comp = Comparable(model, trainer, fields, logger, opt)

        # 2. Infer similarity threshold from training data
        for epoch in range(opt.comp_epochs):
            # 3. Update threshold if dynamic
            if opt.threshold_dynamics != 'static' and epoch != 0:
                comp.update_threshold(opt.threshold_dynamics,
                                      opt.infer_threshold)
            # 4. Extract parallel data and train
            #if opt.match_articles:
            #    comparable_data = comp.match_articles(opt.match_articles)
            #    train_stats = comp.extract_and_train(comparable_data)
            #else:
            train_stats = comp.extract_and_train(opt.comparable_data)
            # 5. Validate on validation set
            if opt.no_valid == False:
                valid_iter = build_dataset_iter(
                    lazily_load_dataset("valid", opt), fields, opt)
                valid_stats = comp.validate(valid_iter)
            # 6. Drop a checkpoint if needed
            # NOTE(review): calls the saver's private _save directly —
            # presumably to force a checkpoint each comparable epoch.
            comp.trainer.model_saver._save(epoch)
def main(opt, device_id):
    """Run single-device training: restore state, build the model, train."""
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    # Resume from a checkpoint when requested; otherwise start fresh.
    ckpt = None
    model_options = opt
    if opt.train_from:
        logger.info("Loading checkpoint from %s" % opt.train_from)
        ckpt = torch.load(opt.train_from,
                          map_location=lambda storage, loc: storage)
        # Load default opts values then overwrite it with opts from
        # the checkpoint. It's usefull in order to re-train a model
        # after adding a new option (not set in checkpoint)
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        model_options = dummy_parser.parse_known_args([])[0]
        model_options.__dict__.update(ckpt["opt"].__dict__)

    # The first training shard determines the data_type shared by all shards.
    first_shard = next(lazily_load_dataset("train", opt))
    data_type = first_shard.data_type

    # Vocabulary/fields produced by the preprocessing step.
    vocab_fields = _load_fields(first_shard, data_type, opt, ckpt)

    # Log the vocabulary size of every extra source/target feature.
    src_feats, tgt_feats = _collect_report_features(vocab_fields)
    for j, feat in enumerate(src_feats):
        logger.info(" * src feature %d size = %d"
                    % (j, len(vocab_fields[feat].vocab)))
    for j, feat in enumerate(tgt_feats):
        logger.info(" * tgt feature %d size = %d"
                    % (j, len(vocab_fields[feat].vocab)))

    # Model construction and parameter accounting.
    model = build_model(model_options, opt, vocab_fields, ckpt)
    n_params, enc, dec = _tally_parameters(model)
    logger.info("encoder: %d" % enc)
    logger.info("decoder: %d" % dec)
    logger.info("* number of parameters: %d" % n_params)
    _check_save_model_path(opt)

    # Optimizer, checkpoint saver, and trainer.
    optim = build_optim(model, opt, ckpt)
    model_saver = build_model_saver(model_options, opt, model,
                                    vocab_fields, optim)
    trainer = build_trainer(opt, device_id, model, vocab_fields, optim,
                            data_type, model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("train", opt), vocab_fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("valid", opt), vocab_fields, opt,
            is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info("Starting training on GPU: %s" % opt.gpu_ranks)
    else:
        logger.info("Starting training on CPU, could be very slow")
    trainer.train(train_iter_fct, valid_iter_fct,
                  opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def training_main(opt):
    """Epoch-based training with a vivisect probe attached to the model."""
    opt = training_opt_postprocessing(opt)

    # Resume from a checkpoint when requested; otherwise start fresh.
    ckpt = None
    model_options = opt
    if opt.train_from:
        print('Loading checkpoint from %s' % opt.train_from)
        ckpt = torch.load(opt.train_from,
                          map_location=lambda storage, loc: storage)
        model_options = ckpt['opt']
        # I don't like reassigning attributes of opt: it's not clear.
        opt.start_epoch = ckpt['epoch'] + 1

    # The first training shard determines the data_type shared by all shards.
    first_shard = next(lazily_load_dataset("train", opt))
    data_type = first_shard.data_type

    # Vocabulary/fields produced by the preprocessing step.
    vocab_fields = _load_fields(first_shard, data_type, opt, ckpt)

    # Report src/tgt features.
    _collect_report_features(vocab_fields)

    # Model construction and bookkeeping.
    model = build_model(model_options, opt, vocab_fields, ckpt)
    _tally_parameters(model)
    _check_save_model_path(opt)

    # Attach vivisect monitoring state and register the probe endpoint.
    model._vivisect = {
        "iteration": 0,
        "model_name": "OpenNMT Model",
        "framework": "pytorch",
        "mode": "train",
    }
    probe(model, "localhost", 8082)

    # Optimizer, checkpoint saver, and trainer.
    optim = build_optim(model, opt, ckpt)
    model_saver = build_model_saver(model_options, opt, model,
                                    vocab_fields, optim)
    trainer = build_trainer(opt, model, vocab_fields, optim, data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        # Each call marks a new training iteration for the probe.
        model._vivisect["iteration"] += 1
        model._vivisect["mode"] = "train"
        return build_dataset_iter(
            lazily_load_dataset("train", opt), vocab_fields, opt)

    def valid_iter_fct():
        model._vivisect["mode"] = "dev"
        return build_dataset_iter(
            lazily_load_dataset("valid", opt), vocab_fields, opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct,
                  opt.start_epoch, opt.epochs)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()