def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt), fields, opt,
                                  is_train=False)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt):
    opt = training_opt_postprocessing(opt)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        print('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        # I don't like reassigning attributes of opt: it's not clear.
        opt.start_epoch = checkpoint['epoch'] + 1
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    _collect_report_features(fields)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    _tally_parameters(model)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, model, fields, optim, data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt), fields, opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps, opt.save_checkpoint_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def load_pre_train(path):
    logger.info('Loading pre-train model from %s' % path)
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
    model_opt = opt
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)
    # log the path actually being loaded, not opt.train_from
    logger.info('Loading vocab from checkpoint at %s.' % path)
    fields = checkpoint['vocab']
    model = build_model(model_opt, opt, fields, checkpoint)
    return model
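# Hedged usage sketch for load_pre_train above. The path is a placeholder, and
# we assume the checkpoint was written by OpenNMT-py's model saver, so that it
# carries the 'opt', 'vocab', and 'model' entries load_pre_train expects.
#
#     pretrained_model = load_pre_train('models/pretrain_step_10000.pt')
#     pretrained_model.eval()  # e.g. freeze dropout before reusing the weights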
def main(opt):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    segment_token_idx = None
    if opt.use_segments:
        segment_token_idx = vocab['tgt'].base_field.vocab.stoi['.']
    opt.segment_token_idx = segment_token_idx

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
def main(opt, device_id, data):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    checkpoint = None

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt', 'tgt_label']:
        logger.info(' * %s vocab size = %d' % (side, len(data["dict"][side])))

    # Build model.
    model = build_model(opt, data, checkpoint, device_id)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(opt, opt, model, data, optim)

    trainer = build_trainer(opt, device_id, model, data, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", data, opt)
    valid_iter = build_dataset_iter("valid", data, opt, is_train=False)

    if opt.gpu:
        logger.info('Starting training on GPU: %s' % opt.gpu)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt):
    if opt.gpuid:
        raise AssertionError("gpuid is deprecated, see world_size and gpu_ranks")
    assert opt.world_size <= 1, "you don't need multi-gpu for morphology"
    device_id = 0 if len(opt.gpu_ranks) == 1 else -1
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opts values, then overwrite them with the opts from
        # the checkpoint. This is useful in order to re-train a model after
        # adding a new option (not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        fields = torch.load(opt.data + '.vocab.pt')

    for key, values in fields.items():
        for name, f in values:
            if hasattr(f, 'use_vocab') and f.use_vocab:
                logger.info(' * %s vocab size = %d' % (name, len(f.vocab)))

    # Build model.
    logger.info('Building model...')
    model = build_model(model_opt, fields, use_gpu(opt), checkpoint)
    logger.info(model)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    params = model.parameters()
    optim_args = {"lr": opt.learning_rate}
    if opt.optim == "adam":
        # no need to mess with the default betas
        optim_args["eps"] = 1e-9
    elif opt.optim == "adagrad":
        optim_args["initial_accumulator_value"] = opt.adagrad_accumulator_init
    optim = getattr(torch.optim, opt.optim.title())(params, **optim_args)
    logger.info(optim)

    trainer = build_trainer(opt, model_opt, device_id, model, fields, optim)

    # this line is kind of a temporary kludge because different objects expect
    # fields to have a different structure
    dataset_fields = dict(chain.from_iterable(fields.values()))

    device = "cuda" if opt.gpu_ranks else "cpu"

    train_dataset = torch.load(opt.data + '.train.pt')
    train_dataset.fields = dataset_fields
    train_iter = OrderedIterator(train_dataset, opt.batch_size,
                                 sort_within_batch=True, device=device,
                                 repeat=False, shuffle=not opt.no_shuffle)

    valid_dataset = torch.load(opt.data + '.valid.pt')
    valid_dataset.fields = dataset_fields
    valid_iter = OrderedIterator(valid_dataset, opt.valid_batch_size,
                                 train=False, sort_within_batch=True,
                                 device=device)

    logger.info('Starting training on {}'.format(device))
    trainer.train(train_iter, valid_iter, opt.epochs)
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    # Load checkpoint if we resume from a previous training.
    load_str = opt.train_from if opt.train_from else opt.load_uncond_from
    if load_str:
        logger.info('Loading checkpoint from %s' % load_str)
        checkpoint = torch.load(load_str,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % load_str)
        vocab = checkpoint['vocab']
        if opt.train_from:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
        else:
            model_opt = opt
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    if opt.gpt2_params_path is not None:
        import tensorflow as tf
        import numpy as np
        # Taken from pytorch-pretrained-BERT:
        # Load weights from TF model
        logger.info("Loading TF GPT weights...")
        init_vars = tf.train.list_variables(opt.gpt2_params_path)
        names = []
        arrays = []
        for name, shape in init_vars:
            if opt.gpt_emb_only and ('wpe' not in name and 'wte' not in name):
                continue
            if opt.gpt_wpe_only and 'wpe' not in name:
                continue
            array = tf.train.load_variable(opt.gpt2_params_path, name)
            names.append(name)
            arrays.append(array.squeeze())
        logger.info("Done.")

        if checkpoint is None:
            checkpoint = {'gpt2_params': zip(names, arrays)}
        else:
            checkpoint['gpt2_params'] = zip(names, arrays)

    if opt.encoder_from is not None:
        logger.info('Loading checkpoint with encoder from %s' % opt.encoder_from)
        enc_checkpoint = torch.load(opt.encoder_from,
                                    map_location=lambda storage, loc: storage)
        enc_vocab = enc_checkpoint['vocab']
        if vocab['src'].base_field.vocab != enc_vocab['src'].base_field.vocab:
            raise ValueError(
                'encoder vocab and model vocab need to be identical '
                'if using a pretrained encoder')
        if checkpoint is None:
            checkpoint = {}
        checkpoint['enc_model'] = enc_checkpoint['model']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    sides = ['tgt'] if opt.model_type == 'none' else ['src', 'tgt']
    for side in sides:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, lm_dec = _tally_parameters(model)
    n_params_t, enc_t, dec_t, lm_dec_t = _tally_parameters(model,
                                                           only_trainable=True)
    logger.info('encoder: %d (%d)' % (enc, enc_t))
    logger.info('decoder: %d (%d)' % (dec, dec_t))
    if opt.simple_fusion:
        logger.info('lm decoder: %d (%d)' % (lm_dec, lm_dec_t))
    logger.info('* number of parameters: %d (%d)' % (n_params, n_params_t))
    _check_save_model_path(opt)

    if not opt.train_from and opt.gpt2_params_path is not None:
        checkpoint = None

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id, batch_queue=None, semaphore=None, train_iter=None,
         passed_fields=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        if opt.use_opt_from_trained:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        else:
            model_opt = opt
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    aux_fields = None
    if passed_fields is not None:
        fields = passed_fields['main']
        aux_fields = passed_fields['crosslingual']
    elif old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint, aux_fields=aux_fields)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    if opt.almt_only:
        almt = model.encoder.embeddings.almt_layers['mapping']
        logger.info('Only training the alignment mapping.')
        optim = Optimizer.from_opt(almt, opt, checkpoint=checkpoint)
    else:
        optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim,
                                    aux_fields=aux_fields)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver, aux_fields=aux_fields)

    if train_iter is not None:
        pass  # NOTE Use the passed one.
    elif batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    cl_valid_iter = None
    if opt.crosslingual:
        valid_iter = build_dataset_iter("valid", fields, opt, is_train=False,
                                        task_cls=Eat2PlainMonoTask)
        if opt.crosslingual_dev_data:
            # NOTE I used 'train' to prepare this in `eat_prepare.sh`,
            # so I use 'train' here as well.
            cl_valid_iter = build_dataset_iter(
                'train', fields, opt, is_train=False,
                data_attr='crosslingual_dev_data',
                task_cls=Eat2PlainCrosslingualTask)
        # NOTE This is for the second eat->plain task.
        aux_valid_iter = build_dataset_iter('valid', fields, opt, is_train=False,
                                            data_attr='aux_train_data',
                                            task_cls=Eat2PlainMonoTask)
        valid_iters = [valid_iter, aux_valid_iter]
    else:
        valid_iters = [build_dataset_iter("valid", fields, opt, is_train=False)]

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iters=valid_iters, valid_steps=opt.valid_steps,
                  cl_valid_iter=cl_valid_iter)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
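# Hedged, self-contained sketch of the batch_queue/semaphore handshake used in
# the _train_iter generator above: a producer process fills the queue while the
# semaphore caps how many batches are in flight. All names here are
# illustrative, not part of the OpenNMT-py API.
import multiprocessing as mp

def _producer(batch_queue, semaphore, batches):
    for b in batches:
        semaphore.acquire()      # wait until the consumer frees a slot
        batch_queue.put(b)

def _consumer_iter(batch_queue, semaphore):
    while True:
        batch = batch_queue.get()
        semaphore.release()      # free a slot for the producer
        yield batch

if __name__ == '__main__':
    q, sem = mp.Queue(), mp.Semaphore(4)   # at most 4 batches in flight
    mp.Process(target=_producer, args=(q, sem, range(8)), daemon=True).start()
    it = _consumer_iter(q, sem)
    print([next(it) for _ in range(8)])    # consumes the 8 dummy "batches"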
def main(opt, fields, transforms_cls, checkpoint, device_id,
         batch_queue=None, semaphore=None):
    """Start training on `device_id`."""
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    model_opt = _get_model_opts(opt, checkpoint=checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    model.count_parameters(log=logger.info)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    if batch_queue is None:
        _train_iter = _build_train_iter(opt, fields, transforms_cls)
        train_iter = IterOnDevice(_train_iter, device_id)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                # Move batch to specified device
                IterOnDevice.batch_to_device(batch, device_id)
                yield batch

        train_iter = _train_iter()

    valid_iter = _build_valid_iter(opt, fields, transforms_cls)
    if valid_iter is not None:
        valid_iter = IterOnDevice(valid_iter, device_id)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    logger.info(opt)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        # Reuse the preprocessed fields, but overwrite their vocabularies with
        # the ones stored in the checkpoint.
        load_vocab = torch.load(opt.data + '.vocab.pt')
        vocab = checkpoint['vocab']
        load_vocab['src'].fields[0][1].vocab = vocab['src'].fields[0][1].vocab
        load_vocab['tgt'].fields[0][1].vocab = vocab['tgt'].fields[0][1].vocab
        vocab = load_vocab
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    if opt.pretrain_from:
        check = torch.load(opt.pretrain_from,
                           map_location=lambda storage, loc: storage)
        model.load_state_dict(check['model'], strict=False)
        model.load_state_dict(check['generator'], strict=False)
        if 'dom_classifier' in check:
            model.load_state_dict(check['dom_classifier'], strict=False)

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    translator = None
    if not opt.domain_cls_enc:
        translator = train_build_translator(opt, model, model_opt, fields,
                                            report_score=True)

    trainer = build_trainer(translator, opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    # Gather information related to the training script and commit version
    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(os.path.dirname(script_path))
    logger.info('Train script dir: %s' % script_dir)
    git_commit = str(subprocess.check_output(
        ['bash', script_dir + '/cluster_scripts/git_version.sh']))
    logger.info("Git Commit: %s" % git_commit[2:-3])

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        # TODO: load MTL model
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opts values, then overwrite them with the opts from
        # the checkpoint. This is useful in order to re-train a model after
        # adding a new option (not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    num_tasks = len(opt.data.split(','))
    opt.num_tasks = num_tasks

    checkpoint_list = []
    if opt.warm_model:
        base_name = opt.warm_model
        for task_id in range(num_tasks):
            chkpt_path = base_name.replace("X", str(task_id))
            if not os.path.isfile(chkpt_path):
                chkpt_path = base_name.replace("X", str(0))
            logger.info('Loading a checkpoint from %s' % chkpt_path)
            checkpoint = torch.load(chkpt_path,
                                    map_location=lambda storage, loc: storage)
            checkpoint_list.append(checkpoint)
    else:
        for task_id in range(num_tasks):
            checkpoint_list.append(None)

    fields_list = []
    data_type = None
    for task_id in range(num_tasks):
        # Peek the first dataset to determine the data_type.
        # (All datasets have the same data_type).
        first_dataset = next(lazily_load_dataset("train", opt, task_id=task_id))
        data_type = first_dataset.data_type

        # Load fields generated from preprocess phase.
        if opt.mtl_shared_vocab and task_id > 0:
            logger.info(' * vocabulary size. Same as the main task!')
            fields = fields_list[0]
        else:
            fields = load_fields(first_dataset, opt, checkpoint_list[task_id],
                                 task_id=task_id)

        # Report src/tgt features.
        src_features, tgt_features = _collect_report_features(fields)
        for j, feat in enumerate(src_features):
            logger.info(' * (Task %d) src feature %d size = %d'
                        % (task_id, j, len(fields[feat].vocab)))
        for j, feat in enumerate(tgt_features):
            logger.info(' * (Task %d) tgt feature %d size = %d'
                        % (task_id, j, len(fields[feat].vocab)))
        fields_list.append(fields)

    if opt.epochs > -1:
        total_num_batch = 0
        for task_id in range(num_tasks):
            train_iter = build_dataset_iter(
                lazily_load_dataset("train", opt, task_id=task_id),
                fields_list[task_id], opt)
            for i, batch in enumerate(train_iter):
                num_batch = i
            total_num_batch += num_batch
            if opt.mtl_schedule < 10:
                break
        num_batch = total_num_batch
        opt.train_steps = (num_batch * opt.epochs) + 1
        # Do the validation and save after each epoch
        opt.valid_steps = num_batch
        opt.save_checkpoint_steps = 1

    logger.info(opt)

    # Build model(s).
    models_list = []
    for task_id in range(num_tasks):
        if opt.mtl_fully_share and task_id > 0:
            # Since we only have one model, copy the pointer to the model for all
            models_list.append(models_list[0])
        else:
            main_model = models_list[0] if task_id > 0 else None
            model = build_model(model_opt, opt, fields_list[task_id],
                                checkpoint_list[task_id],
                                main_model=main_model, task_id=task_id)
            n_params, enc, dec = _tally_parameters(model)
            logger.info('(Task %d) encoder: %d' % (task_id, enc))
            logger.info('(Task %d) decoder: %d' % (task_id, dec))
            logger.info('* number of parameters: %d' % n_params)
            _check_save_model_path(opt)
            models_list.append(model)

    # combine parameters of different models and consider shared parameters
    # just once.
    def combine_named_parameters(named_params_list):
        observed_params = []
        for model_named_params in named_params_list:
            for name, p in model_named_params:
                is_observed = False
                # Check whether we observed this parameter before
                for param in observed_params:
                    if p is param:
                        is_observed = True
                        break
                if not is_observed:
                    observed_params.append(p)
                    yield name, p

    # Build optimizer.
    optims_list = []
    all_models_params = []
    for task_id in range(num_tasks):
        if not opt.mtl_shared_optimizer:
            optim = build_optim(models_list[task_id], opt, checkpoint)
            optims_list.append(optim)
        else:
            all_models_params.append(models_list[task_id].named_parameters())

    # Extract the list of shared parameters among the models of all tasks.
    observed_params = []
    shared_params = []
    for task_id in range(num_tasks):
        for name, p in models_list[task_id].named_parameters():
            is_observed = False
            # Check whether we observed this parameter before
            for param in observed_params:
                if p is param:
                    shared_params.append(name)
                    is_observed = True
                    break
            if not is_observed:
                observed_params.append(p)
    opt.shared_params = shared_params

    if opt.mtl_shared_optimizer:
        optim = build_optim_mtl_params(
            combine_named_parameters(all_models_params), opt, checkpoint)
        optims_list.append(optim)

    # Build model saver.
    model_saver = build_mtl_model_saver(model_opt, opt, models_list,
                                        fields_list, optims_list)

    trainer = build_trainer(opt, device_id, models_list, fields_list,
                            optims_list, data_type, model_saver=model_saver)

    def train_iter_fct(task_id):
        return build_dataset_iter(
            lazily_load_dataset("train", opt, task_id=task_id),
            fields_list[task_id], opt)

    def valid_iter_fct(task_id):
        return build_dataset_iter(
            lazily_load_dataset("valid", opt, task_id=task_id),
            fields_list[task_id], opt)

    def meta_valid_iter_fct(task_id, is_log=False):
        return build_dataset_iter(
            lazily_load_dataset("meta_valid", opt, task_id=task_id,
                                is_log=is_log),
            fields_list[task_id], opt)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps, meta_valid_iter_fct=meta_valid_iter_fct)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
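# Hedged, self-contained check of the dedup idea behind
# combine_named_parameters above: identity comparison (`is`) keeps a parameter
# shared between two models from being handed to the optimizer twice. The toy
# "models" below are just lists of (name, parameter) pairs.
import torch

shared = torch.nn.Parameter(torch.zeros(2))
own = torch.nn.Parameter(torch.ones(2))
model_a = [('emb.weight', shared), ('a.weight', own)]
model_b = [('emb.weight', shared)]  # same tensor object as model_a's embedding

def dedup(named_params_list):
    seen = []
    for named_params in named_params_list:
        for name, p in named_params:
            if not any(p is q for q in seen):
                seen.append(p)
                yield name, p

print([name for name, _ in dedup([model_a, model_b])])
# -> ['emb.weight', 'a.weight']; the shared parameter appears only once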
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    if opt.local_rank != -1:
        torch.cuda.set_device(opt.local_rank)
        device = torch.device("cuda", opt.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        device_id = opt.local_rank
        world_size = torch.distributed.get_world_size()
    else:
        if device_id == -1:
            device = torch.device("cpu")
        else:
            device = torch.device("cuda", device_id)
    if opt.local_rank > 0:
        logger.disabled = True

    configure_process(opt, device_id)
    init_logger(opt.log_file)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    if opt.bert_kd:
        src_vocab = vocab['src'].fields[0][1].vocab.stoi
        tgt_vocab = vocab['tgt'].fields[0][1].vocab.stoi
        assert 0 < opt.kd_topk <= 128
        train_dataset = BertKdDataset(opt.data_db, opt.bert_dump,
                                      src_vocab, tgt_vocab,
                                      max_len=150, k=opt.kd_topk)
        BUCKET_SIZE = 8192
        if True or opt.local_rank == -1 and opt.world_size == 1:
            train_sampler = TokenBucketSampler(train_dataset.keys, BUCKET_SIZE,
                                               opt.batch_size, batch_multiple=1)
        else:
            assert False  # seems like it's handled in training loop
            train_sampler = DistributedTokenBucketSampler(
                world_size, device_id, train_dataset.keys, BUCKET_SIZE,
                opt.batch_size, batch_multiple=1)
        train_loader = DataLoader(train_dataset, batch_sampler=train_sampler,
                                  num_workers=4,
                                  collate_fn=BertKdDataset.pad_collate)
        train_iter = cycle_loader(train_loader, device)
    else:
        train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        if trainer.report_manager.tensorboard_writer:
            trainer.report_manager.tensorboard_writer.close()
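# `cycle_loader` is called above but not defined in this snippet. A plausible
# minimal sketch (an assumption, not the original implementation): endlessly
# re-iterate a DataLoader so step-based training never exhausts it, moving each
# batch to the target device when the batch supports `.to()`.
def cycle_loader(loader, device):
    while True:
        for batch in loader:
            yield batch.to(device) if hasattr(batch, 'to') else batch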
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt

    train_iters = OrderedDict()
    valid_iters = OrderedDict()
    encoders = OrderedDict()
    decoders = OrderedDict()
    generators = OrderedDict()
    src_vocabs = OrderedDict()
    tgt_vocabs = OrderedDict()
    Fields_dict = OrderedDict()

    # variables needed for sharing the same embedding matrix across encoders
    # and decoders
    firstTime = True
    weightToShare = None
    # we share the word embedding space when source lang and target lang are
    # the same
    mapLang2Emb = {}

    for index in range(len(opt.src_tgt)):
        src_tgt_lang = opt.src_tgt[index]
        data_path = opt.data[index]
        local_enc_dec_opts = AttrDict({key: model_opt.__dict__[key]
                                       for key in model_opt.__dict__.keys()})
        local_enc_dec_opts.model_type = update_to_local_attr(
            model_opt.model_type, index)
        local_enc_dec_opts.audio_enc_pooling = update_to_local_attr(
            model_opt.audio_enc_pooling, index)
        local_enc_dec_opts.enc_layers = update_to_local_attr(
            model_opt.enc_layers, index)
        local_enc_dec_opts.dec_layers = update_to_local_attr(
            model_opt.dec_layers, index)
        local_enc_dec_opts.rnn_type = update_to_local_attr(
            model_opt.rnn_type, index)
        local_enc_dec_opts.encoder_type = update_to_local_attr(
            model_opt.encoder_type, index)
        local_enc_dec_opts.batch_size = update_to_local_attr(
            model_opt.batch_size, index)
        local_enc_dec_opts.batch_type = update_to_local_attr(
            model_opt.batch_type, index)
        local_enc_dec_opts.normalization = update_to_local_attr(
            model_opt.normalization, index)

        src_lang, tgt_lang = src_tgt_lang.split('-')

        vocab = torch.load(data_path + '.vocab.pt')

        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            fields = load_old_vocab(vocab, opt.model_type[0],
                                    dynamic_dict=opt.copy_attn)
        else:
            fields = vocab

        # Report src and tgt vocab sizes, including for features
        for side in ['src', 'tgt']:
            f = fields[side]
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(side, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

        # Build model.
        encoder, src_embeddings = build_embeddings_then_encoder(
            local_enc_dec_opts, fields)
        encoders[src_lang] = encoder

        decoder, generator, tgt_embeddings = build_decoder_and_generator(
            local_enc_dec_opts, fields)
        decoders[tgt_lang] = decoder

        # Share the embedding matrix across all the encoders and decoders -
        # preprocess with share_vocab required.
        if model_opt.share_embeddings and firstTime:
            tgt_embeddings.word_lut.weight = src_embeddings.word_lut.weight
            weightToShare = src_embeddings.word_lut.weight
        if model_opt.share_embeddings and (not firstTime):
            tgt_embeddings.word_lut.weight = weightToShare
            src_embeddings.word_lut.weight = weightToShare
        firstTime = False

        if src_lang in mapLang2Emb and model_opt.model_type == "text":
            encoder.embeddings.word_lut.weight = mapLang2Emb.get(src_lang)
        elif model_opt.model_type == "text":
            mapLang2Emb[src_lang] = src_embeddings.word_lut.weight
        if tgt_lang in mapLang2Emb:
            decoder.embeddings.word_lut.weight = mapLang2Emb.get(tgt_lang)
        else:
            mapLang2Emb[tgt_lang] = tgt_embeddings.word_lut.weight

        if model_opt.model_type == "text":
            src_vocabs[src_lang] = fields['src'].base_field.vocab
        tgt_vocabs[tgt_lang] = fields['tgt'].base_field.vocab

        generators[tgt_lang] = generator

        # add this dataset iterator to the training iterators
        train_iters[(src_lang, tgt_lang)] = build_dataset_iter_fct(
            'train', fields, data_path, local_enc_dec_opts)
        # add this dataset iterator to the validation iterators
        valid_iters[(src_lang, tgt_lang)] = build_dataset_iter_fct(
            'valid', fields, data_path, local_enc_dec_opts, is_train=False)

        Fields_dict[src_tgt_lang] = fields

    # Build model.
    model = build_model(model_opt, opt, fields, encoders, decoders, generators,
                        src_vocabs, tgt_vocabs, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, Fields_dict, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, generators,
                            tgt_vocabs, model_saver=model_saver)

    # TODO: multi-shard training iterators (MultipleDatasetIterator over
    # opt.data_ids) are not implemented yet for this multilingual setup.

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iters, train_steps, opt.save_checkpoint_steps,
                  valid_iters, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opts values, then overwrite them with the opts from
        # the checkpoint. This is useful in order to re-train a model after
        # adding a new option (not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = load_fields(first_dataset, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt), fields, opt,
                                  is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # save training settings
    if opt.log_file:
        shutil.copy2(opt.config, opt.exp_dir)
    logger.info(vars(opt))

    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        # added by @memray for multiple datasets
        if opt.vocab and opt.vocab != 'none':
            vocab = torch.load(opt.vocab)
        else:
            # pretrained encoders (and the default case) resolve their vocab
            # later, e.g. from a pretrained tokenizer below
            vocab = None

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train.py line 43
    if opt.model_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(use_vocab=False, dtype=torch.long,
                                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    tokenizer = None
    if opt.pretrained_tokenizer:
        tokenizer = load_pretrained_tokenizer(opt.pretrained_tokenizer,
                                              opt.cache_dir,
                                              opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(fields, opt, tokenizer)

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            # added by @memray, for loading multiple datasets
            if opt.multi_dataset:
                shard_base = "train"
                train_iter = build_dataset_iter(shard_base, fields, opt,
                                                tokenizer=tokenizer)
            else:
                train_shards = []
                for train_id in opt.data_ids:
                    shard_base = "train_" + train_id
                    train_shards.append(shard_base)
                train_iter = build_dataset_iter_multiple(train_shards, fields,
                                                         opt, tokenizer=tokenizer)
        else:
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    if opt.valid:
        valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)
    else:
        valid_iter = None

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
def validate(opt, device_id=0):
    configure_process(opt, device_id)
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']

    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    #_check_save_model_path(opt)

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    tgt_field = dict(fields)["tgt"].base_field
    valid_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt,
                                                    train=False)

    model.eval()
    with torch.no_grad():
        stats = onmt.utils.Statistics()
        for batch in valid_iter:
            src, src_lengths = batch.src if isinstance(batch.src, tuple) \
                else (batch.src, None)
            tgt = batch.tgt

            # F-prop through the model.
            outputs, attns = model(src, tgt, src_lengths)

            # Compute loss.
            _, batch_stats = valid_loss(batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

    print('n words: %d' % stats.n_words)
    print('Validation perplexity: %g' % stats.ppl())
    print('Validation accuracy: %g' % stats.accuracy())
    print('Validation avg attention entropy: %g' % stats.attn_entropy())
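# Hedged usage sketch for validate above. It assumes a parsed opt namespace
# whose train_from points at an OpenNMT-py checkpoint and whose data options
# match those used at preprocessing time; paths are placeholders.
#
#     parser = ArgumentParser()
#     opts.config_opts(parser)
#     opts.model_opts(parser)
#     opts.train_opts(parser)
#     opt = parser.parse_args(
#         ['-data', 'data/demo', '-train_from', 'model_step_10000.pt'])
#     validate(opt, device_id=0)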
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opts values, then overwrite them with the opts from
        # the checkpoint. This is useful in order to re-train a model after
        # adding a new option (not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        data_type = opt.model_type
        fields = load_old_vocab(vocab, data_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Workaround: drop the four trailing special tokens (<s>, </s>, <t>, </t>)
    # from a 30801-entry src vocab so it matches the model's embedding size.
    if len(fields['src'][0][1].base_field.vocab.itos) == 30801:
        for _ in range(4):
            fields['src'][0][1].base_field.vocab.stoi.pop(
                fields['src'][0][1].base_field.vocab.itos.pop(-1))
        vocab['src'][0][1].base_field.vocab.freqs.pop('<s>')
        vocab['src'][0][1].base_field.vocab.freqs.pop('</s>')
        vocab['src'][0][1].base_field.vocab.freqs.pop('<t>')
        vocab['src'][0][1].base_field.vocab.freqs.pop('</t>')

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        for name, f in fields[side]:
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(name, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    # this line is kind of a temporary kludge because different objects expect
    # fields to have a different structure
    dataset_fields = dict(chain.from_iterable(fields.values()))

    train_iter = build_dataset_iter("train", dataset_fields, opt)
    valid_iter = build_dataset_iter("valid", dataset_fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    trainer.train(train_iter, opt.train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps,
                  bert=opt.bert)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    out_file = None
    best_test_score, best_ckpt = -10000, None
    dummy_parser = argparse.ArgumentParser(description='all_dev.py')
    opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
    for i in range(0, opt.train_epochs, 10):
        ckpt_path = '{}_epoch_{}.pt'.format(opt.save_model, i + 1)
        logger.info('Loading checkpoint from %s' % ckpt_path)
        checkpoint = torch.load(ckpt_path,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        fields = load_fields_from_vocab(checkpoint['vocab'], data_type="text")

        # Build model.
        model = build_model(model_opt, opt, fields, checkpoint)
        assert opt.train_from == ''  # do not load optimizer state
        optim = build_optim(model, opt, checkpoint)

        # Build model saver; create the dev experiment dirs if needed.
        if not os.path.exists('experiments/all_dev'):
            os.mkdir('experiments/all_dev')
            os.mkdir('experiments/all_dev/' + opt.meta_dev_task)
        elif not os.path.exists('experiments/all_dev/' + opt.meta_dev_task):
            os.mkdir('experiments/all_dev/' + opt.meta_dev_task)
        model_saver = build_model_saver(
            model_opt, 'experiments/all_dev/' + opt.meta_dev_task + '/model',
            opt, model, fields, optim)

        trainer = build_trainer(opt, device_id, model, fields, optim, "text",
                                model_saver=model_saver)
        train_iter = list(
            build_dataset_iter(lazily_load_dataset("train", opt), fields, opt))

        # Do training on the train set of the meta-dev task.
        trainer.train(train_iter, opt.inner_iterations)

        # Do evaluation on the dev set of the meta-dev task.
        best_dev_score, best_model_path = -10000, None
        for model_path in os.listdir('experiments/all_dev/' +
                                     opt.meta_dev_task):
            if model_path.find('.pt') == -1:
                continue
            if out_file is None:
                out_file = codecs.open(opt.output, 'w+', 'utf-8')
            fields, model, model_opt = onmt.model_builder.load_test_model(
                opt, dummy_opt.__dict__,
                model_path='experiments/all_dev/' + opt.meta_dev_task +
                           '/' + model_path)
            scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                                     opt.coverage_penalty,
                                                     opt.length_penalty)
            kwargs = {k: getattr(opt, k)
                      for k in ["beam_size", "n_best", "max_length",
                                "min_length", "stepwise_penalty",
                                "block_ngram_repeat", "ignore_when_blocking",
                                "dump_beam", "report_bleu", "replace_unk",
                                "gpu", "verbose", "fast", "mask_from"]}
            fields['graph'] = torchtext.data.Field(sequential=False)
            translator = Translator(model, fields, global_scorer=scorer,
                                    out_file=out_file, report_score=False,
                                    copy_attn=model_opt.copy_attn,
                                    logger=logger, log_probs_out_file=None,
                                    **kwargs)

            # Make translations and save the result.
            all_scores, all_predictions = translator.translate(
                src_path='processed_data/meta-dev/' + opt.meta_dev_task +
                         '/src-dev.txt',
                tgt_path=None, src_dir=None,
                batch_size=opt.translate_batch_size, attn_debug=False)

            # Dump predictions.
            f = open('experiments/all_dev/' + opt.meta_dev_task +
                     '/dev_predictions.csv', 'w', encoding='utf-8')
            f.write('smiles,property\n')
            for n_best_mols in all_predictions:
                for mol in n_best_mols:
                    f.write(mol.replace(' ', '') + ',0\n')
            f.close()

            # Call chemprop to get scores.
            test_path = '"' + 'experiments/all_dev/' + opt.meta_dev_task + \
                        '/dev_predictions.csv' + '"'
            checkpoint_path = '"' + 'scorer_ckpts/' + opt.meta_dev_task + \
                              '/model.pt' + '"'
            preds_path = '"' + 'experiments/all_dev/' + opt.meta_dev_task + \
                         '/dev_scores.csv' + '"'

            # In case all mols are invalid (chemprop will produce no output
            # file), the predictions are copied into the score file first.
            cmd = 'cp {} {}'.format(test_path, preds_path)
            result = os.popen(cmd)
            result.close()
            cmd = ('python chemprop/predict.py --test_path {} '
                   '--checkpoint_path {} --preds_path {} '
                   '--num_workers 0').format(test_path, checkpoint_path,
                                             preds_path)
            scorer_result = os.popen(cmd)
            scorer_result.close()

            # Read the score file and get the score.
            score = read_score_csv('experiments/all_dev/' +
                                   opt.meta_dev_task + '/dev_scores.csv')
            assert len(score) % opt.beam_size == 0
            # dev_scores = []
            # for i in range(0, len(score), opt.beam_size):
            #     dev_scores.append(sum([x[1] for x in
            #         score[i:i + opt.beam_size]]) / opt.beam_size)

            # Report the dev score.
            dev_metrics = calculate_metrics(opt.meta_dev_task, 'dev', 'dev',
                                            score)
            logger.info('dev metrics: ' + str(dev_metrics))
            dev_score = dev_metrics['success_rate']
            if dev_score > best_dev_score:
                logger.info('New best dev success rate: {:.4f} by {}'.format(
                    dev_score, model_path))
                best_model_path = model_path
                best_dev_score = dev_score
            else:
                logger.info('dev success rate: {:.4f} by {}'.format(
                    dev_score, model_path))
            del fields
            del model
            del model_opt
            del scorer
            del translator
            gc.collect()

        assert best_model_path is not None

        # Do testing on the test set of the meta-dev task.
        if out_file is None:
            out_file = codecs.open(opt.output, 'w+', 'utf-8')
        fields, model, model_opt = onmt.model_builder.load_test_model(
            opt, dummy_opt.__dict__,
            model_path='experiments/all_dev/' + opt.meta_dev_task + '/' +
                       best_model_path)
        scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                                 opt.coverage_penalty,
                                                 opt.length_penalty)
        kwargs = {k: getattr(opt, k)
                  for k in ["beam_size", "n_best", "max_length", "min_length",
                            "stepwise_penalty", "block_ngram_repeat",
                            "ignore_when_blocking", "dump_beam", "report_bleu",
                            "replace_unk", "gpu", "verbose", "fast",
                            "mask_from"]}
        fields['graph'] = torchtext.data.Field(sequential=False)
        translator = Translator(model, fields, global_scorer=scorer,
                                out_file=out_file, report_score=False,
                                copy_attn=model_opt.copy_attn, logger=logger,
                                log_probs_out_file=None, **kwargs)

        # Make translations and save the result.
        all_scores, all_predictions = translator.translate(
            src_path='processed_data/meta-dev/' + opt.meta_dev_task +
                     '/src-test.txt',
            tgt_path=None, src_dir=None,
            batch_size=opt.translate_batch_size, attn_debug=False)

        # Dump predictions.
        f = open('experiments/all_dev/' + opt.meta_dev_task +
                 '/test_predictions.csv', 'w', encoding='utf-8')
        f.write('smiles,property\n')
        for n_best_mols in all_predictions:
            for mol in n_best_mols:
                f.write(mol.replace(' ', '') + ',0\n')
        f.close()

        # Call chemprop to get scores.
        test_path = '"' + 'experiments/all_dev/' + opt.meta_dev_task + \
                    '/test_predictions.csv' + '"'
        checkpoint_path = '"' + 'scorer_ckpts/' + opt.meta_dev_task + \
                          '/model.pt' + '"'
        preds_path = '"' + 'experiments/all_dev/' + opt.meta_dev_task + \
                     '/test_scores.csv' + '"'

        # In case all mols are invalid (chemprop will produce no output
        # file), the predictions are copied into the score file first.
        cmd = 'cp {} {}'.format(test_path, preds_path)
        result = os.popen(cmd)
        result.close()
        cmd = ('python chemprop/predict.py --test_path {} '
               '--checkpoint_path {} --preds_path {} '
               '--num_workers 0').format(test_path, checkpoint_path,
                                         preds_path)
        scorer_result = os.popen(cmd)
        # logger.info('{}'.format('\n'.join(scorer_result.readlines())))
        scorer_result.close()

        # Read the score file and get the score.
        score = read_score_csv('experiments/all_dev/' + opt.meta_dev_task +
                               '/test_scores.csv')
        assert len(score) % opt.beam_size == 0
        # test_scores = []
        # for i in range(0, len(score), opt.beam_size):
        #     test_scores.append(sum([x[1] for x in
        #         score[i:i + opt.beam_size]]) / opt.beam_size)

        # Report if it is the best on test.
        test_metrics = calculate_metrics(opt.meta_dev_task, 'dev', 'test',
                                         score)
        logger.info('test metrics: ' + str(test_metrics))
        test_score = test_metrics['success_rate']
        if test_score > best_test_score:
            best_ckpt = ckpt_path
            logger.info('New best test success rate: {:.4f} by {}'.format(
                test_score, ckpt_path))
            best_test_score = test_score
        else:
            logger.info('test success rate: {:.4f} by {}'.format(
                test_score, ckpt_path))
        del model_opt
        del fields
        del checkpoint
        del model
        del optim
        del model_saver
        del trainer
        gc.collect()
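The helper read_score_csv used above is referenced but not defined in this snippet. A minimal sketch of a compatible implementation, assuming the chemprop output keeps the 'smiles,property' header and one row per prediction; the (smiles, score) tuple layout is inferred from the score[i][1] usage in the commented-out averaging code:

import csv

def read_score_csv(path):
    """Hypothetical sketch: return (smiles, score) pairs in file order."""
    rows = []
    with open(path, newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader)  # skip the 'smiles,property' header
        for row in reader:
            try:
                rows.append((row[0], float(row[1])))
            except (IndexError, ValueError):
                # Rows copied verbatim from the predictions file (the
                # all-invalid fallback above) carry no usable score.
                rows.append((row[0], 0.0))
    return rows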
model_opt = opt
checkpoint = None

# Peek the first dataset to determine the data_type.
# (All datasets have the same data_type).
first_dataset = next(lazily_load_dataset("train", opt))
data_type = first_dataset.data_type

# Load fields generated from preprocess phase.
fields = _load_fields(first_dataset, data_type, opt, checkpoint)

# Report src/tgt features.
_collect_report_features(fields)

# Build model.
model = build_model(model_opt, opt, fields, checkpoint)

remove(args.host, args.port, "OpenNMT")
probe(args.name, model, args.host, args.port,
      when=lambda m, o: m._v.state == "dev",
      which=lambda m, o: True,  # o._v.operation_name in ["encoder", "decoder"]
      parameters=False,
      forward=True,
      backward=False,
      batch_axis=1)
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    logger.info('Loading alignment.')
    lemma_aligns = open(model_opt.lemma_align, 'rb').readlines()
    src_stoi = vocab['src'].base_field.vocab.stoi
    lemma_stoi = vocab['word_topic'].base_field.vocab.stoi
    w2l = {}
    word_to_lemma = []
    for pair in lemma_aligns:
        pair = pair.strip().split()
        w2l[src_stoi[pair[0].decode('utf-8')]] = \
            lemma_stoi[pair[1].decode('utf-8')]
    w2l[src_stoi['unk']] = lemma_stoi['unk']
    for index in range(len(vocab['src'].base_field.vocab.itos)):
        if index in w2l:
            word_to_lemma.append(w2l[index])
        else:
            word_to_lemma.append(w2l[lemma_stoi['unk']])
    word_to_lemma = torch.tensor(word_to_lemma)

    logger.info('Loading topic matrix')
    if device_id >= 0:
        topic_matrix = torch.load(opt.topic_matrix,
                                  map_location=torch.device(device_id))
    else:
        topic_matrix = torch.load(opt.topic_matrix)
    if opt.model_dtype == 'fp16':
        topic_matrix = topic_matrix.half()

    # Check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way).
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features.
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(topic_matrix, word_to_lemma, train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
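The word_to_lemma tensor built above is a dense lookup table from source vocabulary indices to lemma indices. A toy demonstration of how such a table maps a whole batch of source token ids in one indexing operation (values illustrative; the real lookup happens inside this fork's trainer.train):

import torch

word_to_lemma = torch.tensor([0, 2, 2, 1])  # src vocab id -> lemma vocab id
src_batch = torch.tensor([[1, 3], [2, 0]])  # (seq_len, batch) of src ids
print(word_to_lemma[src_batch])             # tensor([[2, 1], [2, 0]])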
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # Check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way).
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features.
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    # TODO: delete all these lines related to WALS features
    # begin
    SimulationLanguages = [opt.wals_src, opt.wals_tgt]
    print('Loading WALS features from databases...')
    cwd = os.getcwd()
    db = sqlite3.connect(cwd + '/onmt/WalsValues.db')
    cursor = db.cursor()
    cursor.execute('SELECT * FROM WalsValues')
    WalsValues = cursor.fetchall()
    db = sqlite3.connect(cwd + '/onmt/FeaturesList.db')
    cursor = db.cursor()
    cursor.execute('SELECT * FROM FeaturesList')
    FeaturesList = cursor.fetchall()
    db = sqlite3.connect(cwd + '/onmt/FTInfos.db')
    cursor = db.cursor()
    cursor.execute('SELECT * FROM FTInfos')
    FTInfos = cursor.fetchall()
    db = sqlite3.connect(cwd + '/onmt/FTList.db')
    cursor = db.cursor()
    cursor.execute('SELECT * FROM FTList')
    FTList = cursor.fetchall()
    ListLanguages = []
    for i in WalsValues:
        ListLanguages.append(i[0])
    FeatureTypes = []
    for i in FTList:
        FeatureTypes.append((i[0], i[1].split(',')))
    FeatureNames = []
    for i in FeatureTypes:
        FeatureNames += i[1]
    FeatureTypesNames = []
    for i in FeatureTypes:
        FeatureTypesNames.append(i[0])
    FeatureValues, FeatureTensors = get_feat_values(
        SimulationLanguages, WalsValues, FeaturesList, ListLanguages,
        FeatureTypes, FeatureNames)
    print('WALS databases loaded!')
    # end
    # TODO: load WALS features from the command line (wals.npz)
    # FeatureValues: defaultdict with feature values, per language.
    # FeatureTensors: tensor of possible outputs, per feature.
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    # TODO: remove all parameters related to WALS features: FeatureValues,
    # FeatureTensors, FeatureTypes, FeaturesList, FeatureNames, FTInfos,
    # FeatureTypesNames, SimulationLanguages
    # TODO: include four parameters related to WALS features: the four
    # numpy arrays, passed separately
    model = build_model(model_opt, opt, fields, checkpoint, FeatureValues,
                        FeatureTensors, FeatureTypes, FeaturesList,
                        FeatureNames, FTInfos, FeatureTypesNames,
                        SimulationLanguages)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
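The TODOs above ask for the four WALS-related arrays to be loaded from a single wals.npz archive instead of the four sqlite databases. A minimal sketch of what that could look like; the archive keys and the helper name are assumptions, not part of the original code:

import numpy as np

def load_wals_arrays(path='wals.npz'):
    # Hypothetical: one .npz archive holding the four arrays to be passed
    # to build_model separately, replacing the sqlite queries above.
    archive = np.load(path, allow_pickle=True)
    return (archive['feature_values'], archive['feature_tensors'],
            archive['feature_types'], archive['feature_names'])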
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # Check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way).
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features.
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    if opt.teacher_model_path:
        logger.info('Loading teacher model from {path}'.format(
            path=opt.teacher_model_path))
        teacher_model_ckpt = torch.load(
            opt.teacher_model_path, map_location=lambda storage, loc: storage)
        teacher_model_opt = ArgumentParser.ckpt_model_opts(
            teacher_model_ckpt['opt'])
        ArgumentParser.update_model_opts(teacher_model_opt)
        ArgumentParser.validate_model_opts(teacher_model_opt)
        logger.info('Loading vocab from checkpoint at {path}'.format(
            path=opt.teacher_model_path))
        teacher_vocab = teacher_model_ckpt['vocab']

    # Check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way).
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    teacher_fields = teacher_vocab if opt.teacher_model_path else None

    # patch for fields that may be missing in old data/model
    # patch_fields(opt, fields)

    # Report src and tgt vocab sizes, including for features.
    report_vocab_size(fields)
    if teacher_fields is not None:
        report_vocab_size(teacher_fields)

    # Build model.
    fields_opt = {"original": fields, "teacher": teacher_fields}
    model = custom_builder.build_model(model_opt, opt, fields_opt, checkpoint)
    # model = build_model(model_opt, opt, fields, checkpoint)
    teacher_model = build_model(
        teacher_model_opt, teacher_model_opt, teacher_fields,
        teacher_model_ckpt) if opt.teacher_model_path else None

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    if teacher_model is not None:
        n_params, enc, dec = _tally_parameters(teacher_model)
        logger.info('encoder: %d' % enc)
        logger.info('decoder: %d' % dec)
        logger.info('* number of parameters: %d' % n_params)
        _check_save_model_path(teacher_model_opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    # model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    model_saver = custom_model_saver.build_model_saver(model_opt, opt, model,
                                                       fields_opt, optim)

    tgt_field = dict(teacher_fields)["tgt"].base_field \
        if teacher_model is not None else dict(fields)["tgt"].base_field
    sos_id = tgt_field.vocab.stoi[tgt_field.init_token]

    if teacher_model is not None and opt.word_sampling:
        sampler = Emulator(teacher_model, teacher_fields, device_id,
                           max_length=50, random_sampling_topk=5)
    else:
        sampler = None

    if teacher_model is not None:
        trainer = build_trainer(opt, device_id, model, teacher_fields, optim,
                                model_saver, teacher_model=teacher_model,
                                emulator=sampler)
    else:
        trainer = build_trainer(opt, device_id, model, fields, optim,
                                model_saver, teacher_model=teacher_model,
                                emulator=sampler)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps, sos_id=sos_id,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
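The _train_iter generator above consumes batches from batch_queue and releases the semaphore once per batch. A minimal producer-side sketch of the same protocol, assuming a torch.multiprocessing setup; the function and variable names are illustrative, not from the original code:

import torch.multiprocessing as mp

def produce_batches(batch_iter, batch_queue, semaphore):
    # Mirror image of _train_iter: acquire a slot before enqueuing, so the
    # producer never runs more than the semaphore's initial value ahead
    # of the consumer.
    for batch in batch_iter:
        semaphore.acquire()
        batch_queue.put(batch)

# Hypothetical wiring: main(opt, device_id, batch_queue=q, semaphore=sem)
q, sem = mp.Queue(), mp.Semaphore(4)  # allow up to 4 in-flight batches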
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful in order to re-train a model after
        # adding a new option (not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # Load a shard dataset to determine the data_type.
    # (All datasets have the same data_type).
    # This should be refactored out of existence reasonably soon.
    first_dataset = torch.load(glob.glob(opt.data + '.train*.pt')[0])
    data_type = first_dataset.data_type

    # Check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way).
    if old_style_vocab(vocab):
        fields = load_fields_from_vocab(vocab, data_type)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features.
    for side in ['src', 'tgt']:
        for name, f in fields[side]:
            if f.use_vocab:
                logger.info(' * %s vocab size = %d' % (name, len(f.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)

    # This line is kind of a temporary kludge because different objects
    # expect fields to have a different structure.
    dataset_fields = dict(chain.from_iterable(fields.values()))

    train_iter = build_dataset_iter("train", dataset_fields, opt)
    valid_iter = build_dataset_iter("valid", dataset_fields, opt,
                                    is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')

    trainer.train(train_iter, valid_iter, opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    # opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        raise Exception('You need to load a model')

    logger.info('Loading data from %s' % opt.data)
    dataset = next(lazily_load_dataset("train", opt))
    data_type = dataset.data_type
    logger.info('Data type %s' % data_type)

    # Load fields generated from preprocess phase.
    fields = _load_fields(dataset, data_type, opt, checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    dataset_iter = build_dataset_iter(lazily_load_dataset("train", opt),
                                      fields, opt)
    out_file = codecs.open(opt.output, 'w+', 'utf-8')
    scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                             opt.coverage_penalty,
                                             opt.length_penalty)
    translation_builder = TranslationBuilder(dataset, fields,
                                             n_best=opt.n_best,
                                             replace_unk=opt.replace_unk,
                                             has_tgt=False)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt),
                                  fields, opt)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)
    translator = Translator(trainer.model, fields, opt.beam_size,
                            global_scorer=scorer, out_file=out_file,
                            report_score=False,
                            copy_attn=model_opt.copy_attn, logger=logger)

    for i, batch in enumerate(dataset_iter):
        unprocessed_translations = translator.translate_batch(batch, dataset)
        translations = translation_builder.from_batch(unprocessed_translations)
        print("Translations: ", ' '.join(translations[0].pred_sents[0]))
        trainer.train_from_data(batch, train_steps=1)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        print("load weight success")
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # Check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way).
    if old_style_vocab(vocab):
        print("old style vocab")
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        print("not old style")
        fields = vocab

    # Report src and tgt vocab sizes, including for features.
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    # Added and then removed by zhengquan:
    # model = torch.nn.parallel.DistributedDataParallel(
    #     model, device_ids=[opt.local_rank], output_device=opt.local_rank)
    # Removed because of the availability of CUDA devices. The
    # DistributedDataParallel doc says:
    # "DistributedDataParallel with multi-device module only works
    # with CUDA devices, but module parameters locate in {}."
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    # build_dataset_iter() uses the dataset_paths stored in opt to load
    # the data.
    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # Check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way).
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features.
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, nontrainable = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('non-trainable parameters (tgt_out_emb): %d' % nontrainable)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful in order to re-train a model after
        # adding a new option (not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]
        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # Check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way).
    if old_style_vocab(vocab):
        data_type = opt.model_type
        fields = load_old_vocab(vocab, data_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features.
    for side in ['src', 'tgt']:
        for name, f in fields[side]:
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(name, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    # This line is kind of a temporary kludge because different objects
    # expect fields to have a different structure.
    dataset_fields = dict(chain.from_iterable(fields.values()))

    train_iter = build_dataset_iter("train", dataset_fields, opt)
    valid_iter = build_dataset_iter("valid", dataset_fields, opt,
                                    is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    logger.info('* batch_size: %d' % opt.batch_size)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim, data_type,
                            model_saver=model_saver)

    def data_iter_fct(data_stage):
        """data_stage: train / valid"""
        pt_file = opt.data + '.' + data_stage + '.pt'
        logger.info('Loading {} dataset'.format(data_stage))
        dataset = torch.load(pt_file)
        logger.info('Loaded {} dataset'.format(data_stage))
        dataset.fields = fields
        is_train = data_stage == "train"
        batch_size = opt.batch_size if is_train else opt.valid_batch_size
        repeat = data_stage == "train"
        if opt.gpuid != -1:
            device = "cuda"
        else:
            device = "cpu"

        def sort_key(ex):
            """Sort using the example's total token count."""
            return ex.total_tokens

        return torchtext.data.Iterator(dataset=dataset,
                                       batch_size=batch_size,
                                       device=device,
                                       train=is_train,
                                       sort=False,
                                       sort_key=sort_key,
                                       repeat=repeat)

    # Do training.
    trainer.train(data_iter_fct, opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
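A usage sketch for the factory above (assumed; the actual calls live inside this fork's trainer.train): one iterator is built per stage, and batches expose attributes named after the fields attached to the dataset. Note that the train iterator repeats forever (repeat=True), so only the valid iterator is safe to exhaust in a plain loop:

valid_iter = data_iter_fct("valid")  # finite: repeat=False for valid
for batch in valid_iter:
    pass  # batch.<field_name> attributes follow the fields set above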
def main(opt):
    opt = training_opt_postprocessing(opt)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    if opt.load_pretrained_selector_from:
        logger.info('Loading selector checkpoint from %s'
                    % opt.load_pretrained_selector_from)
        sel_checkpoint = torch.load(opt.load_pretrained_selector_from,
                                    map_location=lambda storage, loc: storage)
    else:
        sel_checkpoint = None
    if opt.load_pretrained_s2s_generator_from:
        logger.info('Loading s2s generator checkpoint from %s'
                    % opt.load_pretrained_s2s_generator_from)
        s2s_gen_checkpoint = torch.load(
            opt.load_pretrained_s2s_generator_from,
            map_location=lambda storage, loc: storage)
    else:
        s2s_gen_checkpoint = None

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint,
                        sel_checkpoint, s2s_gen_checkpoint)

    # Fix the pretrained selector parameters if needed.
    if model_opt.fix_sel_all:
        assert opt.load_pretrained_selector_from
        assert opt.sel_lambda == 0.0
        assert not model_opt.fix_sel_classifier
        for name, param in model.named_parameters():
            if 'selector' in name:
                param.requires_grad = False
    # Only fix the classifier of the selector.
    if model_opt.fix_sel_classifier:
        assert opt.load_pretrained_selector_from
        assert not model_opt.fix_sel_all
        for name, param in model.named_parameters():
            if 'selector' in name and 'rnn' not in name \
                    and 'embeddings' not in name:
                param.requires_grad = False

    n_params, sel, enc, dec, gen = _my_tally_parameters(model)
    logger.info('selector: %d' % sel)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('generator: %d' % gen)
    logger.info('* number of parameters: %d' % n_params)
    print_trainable_parameters(model)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, model, fields, optim, data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(
            lazily_load_dataset("valid", opt), fields, opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()