def shard_iterator(srcs, tgts, ids, aligns, existing_shards,
                   existing_fields, corpus_type, opt):
    """Builds a single iterator yielding every shard of every corpus."""
    for src, tgt, maybe_id, maybe_align in zip(srcs, tgts, ids, aligns):
        if maybe_id in existing_shards:
            if opt.overwrite:
                logger.warning("Overwrite shards for corpus {}"
                               .format(maybe_id))
            else:
                if corpus_type == "train":
                    assert existing_fields is not None,\
                        ("A 'vocab.pt' file should be passed to "
                         "`-src_vocab` when adding a corpus to "
                         "a set of already existing shards.")
                logger.warning("Ignore corpus {} because "
                               "shards already exist".format(maybe_id))
                continue
        if ((corpus_type == "train" or opt.filter_valid)
                and tgt is not None):
            filter_pred = partial(
                inputters.filter_example,
                use_src_len=opt.data_type == "text",
                max_src_len=opt.src_seq_length,
                max_tgt_len=opt.tgt_seq_length)
        else:
            filter_pred = None
        src_shards = split_corpus(src, opt.shard_size)
        tgt_shards = split_corpus(tgt, opt.shard_size)
        align_shards = split_corpus(maybe_align, opt.shard_size)
        for i, (ss, ts, a_s) in enumerate(
                zip(src_shards, tgt_shards, align_shards)):
            yield (i, (ss, ts, a_s, maybe_id, filter_pred))
def _get_subword_kwargs(self, side='src'):
    """Return a dict containing kwargs related to `side` subwords."""
    subword_type = self.tgt_subword_type if side == 'tgt' \
        else self.src_subword_type
    subword_model = self.tgt_subword_model if side == 'tgt' \
        else self.src_subword_model
    subword_nbest = self.tgt_subword_nbest if side == 'tgt' \
        else self.src_subword_nbest
    subword_alpha = self.tgt_subword_alpha if side == 'tgt' \
        else self.src_subword_alpha
    kwopts = dict()
    if subword_type == 'bpe':
        kwopts['bpe_model_path'] = subword_model
        kwopts['bpe_dropout'] = subword_alpha
    elif subword_type == 'sentencepiece':
        kwopts['sp_model_path'] = subword_model
        kwopts['sp_nbest_size'] = subword_nbest
        kwopts['sp_alpha'] = subword_alpha
    else:
        logger.warning('No subword method will be applied.')
    vocabulary_threshold = self.tgt_vocab_threshold if side == 'tgt' \
        else self.src_vocab_threshold
    vocabulary_path = self.tgt_subword_vocab if side == 'tgt' \
        else self.src_subword_vocab
    if vocabulary_threshold > 0 and vocabulary_path != "":
        kwopts['vocabulary_path'] = vocabulary_path
        kwopts['vocabulary_threshold'] = vocabulary_threshold
    return kwopts
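# Usage sketch (hypothetical option values; a standalone stub stands in for
# the real transform object): for a sentencepiece-configured side, the
# returned kwargs carry the model path plus the subword-sampling parameters.
from types import SimpleNamespace

_stub = SimpleNamespace(
    src_subword_type='sentencepiece', src_subword_model='sp.model',
    src_subword_nbest=64, src_subword_alpha=0.1,
    src_vocab_threshold=0, src_subword_vocab="",
    tgt_subword_type='none', tgt_subword_model=None,
    tgt_subword_nbest=1, tgt_subword_alpha=0.0,
    tgt_vocab_threshold=0, tgt_subword_vocab="")
assert _get_subword_kwargs(_stub, side='src') == {
    'sp_model_path': 'sp.model', 'sp_nbest_size': 64, 'sp_alpha': 0.1}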
def prepare_fields_transforms(opt):
    """Prepare or dump fields & transforms before training."""
    transforms_cls = get_transforms_cls(opt._all_transform)
    specials = get_specials(opt, transforms_cls)

    fields = build_dynamic_fields(
        opt, src_specials=specials['src'], tgt_specials=specials['tgt'])

    # maybe prepare pretrained embeddings, if any
    prepare_pretrained_embeddings(opt, fields)

    if opt.dump_fields:
        save_fields(fields, opt.save_data, overwrite=opt.overwrite)
    if opt.dump_transforms or opt.n_sample != 0:
        transforms = make_transforms(opt, transforms_cls, fields)
    if opt.dump_transforms:
        save_transforms(transforms, opt.save_data, overwrite=opt.overwrite)
    if opt.n_sample != 0:
        logger.warning(
            "`-n_sample` != 0: Training will not be started. "
            f"Stopping after saving {opt.n_sample} samples/corpus.")
        save_transformed_sample(opt, transforms, n_sample=opt.n_sample)
        logger.info("Sample saved, please check it before restarting "
                    "training.")
        sys.exit()
    return fields, transforms_cls
def _validate_data(cls, opt):
    """Parse corpora specified in the data field of the YAML file."""
    import yaml
    default_transforms = opt.transforms
    if len(default_transforms) != 0:
        logger.info(f"Default transforms: {default_transforms}.")
    corpora = yaml.safe_load(opt.data)
    for cname, corpus in corpora.items():
        # Check transforms
        _transforms = corpus.get('transforms', None)
        if _transforms is None:
            logger.info(f"Missing transforms field for {cname} data, "
                        f"set to default: {default_transforms}.")
            corpus['transforms'] = default_transforms
        # Check path
        path_src = corpus.get('path_src', None)
        path_tgt = corpus.get('path_tgt', None)
        if path_src is None:
            raise ValueError(f'Corpus {cname} src path is required. '
                             'tgt path is also required for non language'
                             ' modeling tasks.')
        else:
            opt.data_task = ModelTask.SEQ2SEQ
            if path_tgt is None:
                logger.warning(
                    "path_tgt is None, it should be set unless the task"
                    " is language modeling")
                opt.data_task = ModelTask.LANGUAGE_MODEL
                # tgt is src for LM task
                corpus["path_tgt"] = path_src
                corpora[cname] = corpus
                path_tgt = path_src
        cls._validate_file(path_src, info=f'{cname}/path_src')
        cls._validate_file(path_tgt, info=f'{cname}/path_tgt')
        path_align = corpus.get('path_align', None)
        if path_align is None:
            if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0:
                raise ValueError(f'Corpus {cname} alignment file path is '
                                 'required when lambda_align > 0.0')
            corpus['path_align'] = None
        else:
            cls._validate_file(path_align, info=f'{cname}/path_align')
        # Check prefix: will be used when using the prefix transform
        src_prefix = corpus.get('src_prefix', None)
        tgt_prefix = corpus.get('tgt_prefix', None)
        if src_prefix is None or tgt_prefix is None:
            if 'prefix' in corpus['transforms']:
                raise ValueError(f'Corpus {cname} prefixes are required.')
        # Check weight
        weight = corpus.get('weight', None)
        if weight is None:
            if cname != CorpusName.VALID:
                logger.warning(f"Corpus {cname}'s weight should be given."
                               " We default it to 1 for you.")
            corpus['weight'] = 1
    logger.info(f"Parsed {len(corpora)} corpora from -data.")
    opt.data = corpora
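# Sketch of a `-data` YAML value this validator accepts (hypothetical
# paths). A training corpus inherits the default transforms and weight 1
# when they are omitted; `valid` is the special validation entry that
# needs no weight.
#
#     corpus_1:
#         path_src: data/src-train.txt
#         path_tgt: data/tgt-train.txt
#         transforms: [filtertoolong]
#         weight: 1
#     valid:
#         path_src: data/src-val.txt
#         path_tgt: data/tgt-val.txt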
def validate_model_opts(cls, model_opt):
    assert model_opt.model_type in \
        ["text", "img", "audio", "imgvec", "none"], \
        "Unsupported model type %s" % model_opt.model_type

    # this check is here because audio allows the encoder and decoder to
    # be different sizes, but other model types do not yet
    same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size
    assert model_opt.model_type == 'audio' or same_size, \
        "The encoder and decoder rnns must be the same size for now"

    assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, \
        "Using SRU requires -gpu_ranks set."
    if model_opt.share_embeddings:
        if model_opt.model_type != "text":
            raise AssertionError(
                "--share_embeddings requires --model_type text.")
    if model_opt.model_dtype == "fp16":
        logger.warning(
            "FP16 is experimental, the generated checkpoints may "
            "be incompatible with a future version")
    if model_opt.share_position_embeddings and \
            not model_opt.position_encoding_learned:
        raise AssertionError(
            'It does not make sense to share position embeddings if '
            'they are not learned')
    if int(model_opt.use_GPT_version_psa) \
            + int(model_opt.use_GPT_version_unconditional) \
            + int(model_opt.use_GPT_version_ctxattn) \
            + int(model_opt.decoder_type == 'multi_src_transformer') > 1:
        raise AssertionError(
            'At most one of use_GPT_version_psa, '
            'use_GPT_version_unconditional, use_GPT_version_ctxattn, or '
            'decoder_type == multi_src_transformer can be set at the same '
            'time (multi_src_transformer has its own psa version).')
    if model_opt.simple_fusion and model_opt.gpt2_params_path is None:
        raise AssertionError(
            'Simple fusion requires setting the gpt2_params_path option')
    if model_opt.attn_hidden > 0:
        raise NotImplementedError
    if model_opt.GPT_representation_mode != 'none' and (
            model_opt.gpt2_init_embanddec or model_opt.simple_fusion
            or model_opt.gpt2_init_embandenc):
        raise AssertionError(
            'loading GPT weights for seq2seq initialization AND GPT '
            'probably does not make sense')
    if model_opt.decoder_type == 'multi_src_transformer' \
            and model_opt.num_src < 1:
        raise AssertionError(
            'When using GPT with multiple sources there is no point in '
            'passing num_src < 1; use num_src=1 only when debugging.')
def validate_model_opts(cls, model_opt):
    # this check is here because audio allows the encoder and decoder to
    # be different sizes, but other model types do not yet
    same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size

    assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, \
        "Using SRU requires -gpu_ranks set."
    if model_opt.model_dtype == "fp16":
        logger.warning(
            "FP16 is experimental, the generated checkpoints may "
            "be incompatible with a future version")
def make_transforms(opts, transforms_cls, fields):
    """Build transforms in `transforms_cls` with vocab of `fields`."""
    vocabs = get_vocabs(fields) if fields is not None else None
    transforms = {}
    for name, transform_cls in transforms_cls.items():
        if transform_cls.require_vocab() and vocabs is None:
            logger.warning(
                f"{transform_cls.__name__} requires vocab to apply, "
                f"skipping it.")
            continue
        transform_obj = transform_cls(opts)
        transform_obj.warm_up(vocabs)
        transforms[name] = transform_obj
    return transforms
def apply(self, example, is_train=False, stats=None, **kwargs):
    """Return None if any source feature length mismatches the source."""
    if 'src_feats' not in example:
        # Do nothing
        return example
    for feat_name, feat_values in example['src_feats'].items():
        if len(example['src']) != len(feat_values):
            logger.warning(f"Skipping example due to mismatch "
                           f"between source and feature {feat_name}")
            return None
    return example
def _add_index(self, stream):
    for i, item in enumerate(stream):
        example = item[0]
        line_number = i * self.stride + self.offset
        example['indices'] = line_number
        if (len(example['src']) == 0 or len(example['tgt']) == 0 or
                ('align' in example and len(example['align']) == 0)):
            # empty example: skip
            empty_msg = f"Empty line exists in {self.cid}#{line_number}."
            if self.skip_empty_level == 'error':
                raise IOError(empty_msg)
            elif self.skip_empty_level == 'warning':
                logger.warning(empty_msg)
            continue
        yield item
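# Sketch of the stride/offset indexing above (hypothetical values): with
# stride=3 and offset=1, this worker tags its i-th example with global line
# numbers 1, 4, 7, ... so each data-parallel rank indexes the corpus
# consistently.
_stride, _offset = 3, 1
assert [i * _stride + _offset for i in range(4)] == [1, 4, 7, 10]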
def parse_align_idx(align_pharaoh):
    """Parse Pharaoh alignment into [[<src>, <tgt>], ...]."""
    align_list = align_pharaoh.strip().split(' ')
    flatten_align_idx = []
    for align in align_list:
        try:
            src_idx, tgt_idx = align.split('-')
        except ValueError:
            logger.warning("{} in `{}`".format(align, align_pharaoh))
            logger.warning("Bad alignment line exists. Please check file!")
            raise
        flatten_align_idx.append([int(src_idx), int(tgt_idx)])
    return flatten_align_idx
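# Usage sketch: Pharaoh alignments are whitespace-separated "<src>-<tgt>"
# index pairs. A malformed token such as "3" (missing the dash) logs the
# two warnings above and re-raises the ValueError to the caller.
assert parse_align_idx("0-0 1-2 2-1") == [[0, 0], [1, 2], [2, 1]]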
def validate_model_opts(cls, model_opt):
    assert model_opt.model_type in ["text"], \
        "Unsupported model type %s" % model_opt.model_type

    same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size
    assert same_size, \
        "The encoder and decoder rnns must be the same size for now"

    if model_opt.share_embeddings:
        if model_opt.model_type != "text":
            raise AssertionError(
                "--share_embeddings requires --model_type text.")
    if model_opt.model_dtype == "fp16":
        logger.warning(
            "FP16 is experimental, the generated checkpoints may "
            "be incompatible with a future version")
def validate_model_opts(cls, model_opt):
    assert model_opt.model_type in ["text", "img", "audio"], \
        "Unsupported model type %s" % model_opt.model_type

    # this check is here because audio allows the encoder and decoder to
    # be different sizes, but other model types do not yet
    same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size
    assert model_opt.model_type == 'audio' or same_size, \
        "The encoder and decoder rnns must be the same size for now"

    assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, \
        "Using SRU requires -gpu_ranks set."
    if model_opt.share_embeddings:
        if model_opt.model_type != "text":
            raise AssertionError(
                "--share_embeddings requires --model_type text.")
    if model_opt.model_dtype == "fp16":
        logger.warning(
            "FP16 is experimental, the generated checkpoints may "
            "be incompatible with a future version")
def check_existing_pt_files(opt, corpus_type, ids, existing_fields):
    """Check if there are existing .pt files to avoid overwriting them."""
    existing_shards = []
    for maybe_id in ids:
        if maybe_id:
            shard_base = corpus_type + "_" + maybe_id
        else:
            shard_base = corpus_type
        pattern = opt.save_data + '.{}.*.pt'.format(shard_base)
        if glob.glob(pattern):
            if opt.overwrite:
                maybe_overwrite = ("will be overwritten because "
                                   "`-overwrite` option is set.")
            else:
                maybe_overwrite = ("won't be overwritten, pass the "
                                   "`-overwrite` option if you want to.")
            logger.warning("Shards for corpus {} already exist, {}".format(
                shard_base, maybe_overwrite))
            existing_shards += [maybe_id]
    return existing_shards
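# Sketch of the glob pattern being checked (hypothetical paths): with
# save_data='data/demo', corpus_type='train' and maybe_id='news', the
# shards matched are data/demo.train_news.*.pt.
assert 'data/demo' + '.{}.*.pt'.format('train_news') \
    == 'data/demo.train_news.*.pt'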
def batch_iter(data, batch_size, batch_size_fn=None, batch_size_multiple=1):
    """Yield elements from data in chunks of batch_size, where each chunk size
    is a multiple of batch_size_multiple.

    This is an extended version of torchtext.data.batch.
    """
    if batch_size_fn is None:
        def batch_size_fn(new, count, sofar):
            return count
    minibatch, size_so_far = [], 0
    for ex in data:
        minibatch.append(ex)
        size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
        if size_so_far >= batch_size:
            overflowed = 0
            if size_so_far > batch_size:
                overflowed += 1
            if batch_size_multiple > 1:
                overflowed += (
                    (len(minibatch) - overflowed) % batch_size_multiple)
            if overflowed == 0:
                yield minibatch
                minibatch, size_so_far = [], 0
            else:
                if overflowed == len(minibatch):
                    logger.warning(
                        "An example was ignored: it contains more tokens"
                        " than allowed by the token batch size")
                else:
                    yield minibatch[:-overflowed]
                    minibatch = minibatch[-overflowed:]
                    size_so_far = 0
                    for i, ex in enumerate(minibatch):
                        size_so_far = batch_size_fn(ex, i + 1, size_so_far)
    if minibatch:
        yield minibatch
def _init_train(opt):
    """Common initialization for all training processes."""
    ArgumentParser.validate_prepare_opts(opt)

    if opt.train_from:
        # Load checkpoint if we resume from a previous training.
        checkpoint = load_checkpoint(ckpt_path=opt.train_from)
        fields = load_fields(opt.save_data, checkpoint)
        transforms_cls = get_transforms_cls(opt._all_transform)
        if (hasattr(checkpoint["opt"], '_all_transform') and
                len(opt._all_transform.symmetric_difference(
                    checkpoint["opt"]._all_transform)) != 0):
            _msg = "configured transforms differ from checkpoint:"
            new_transf = opt._all_transform.difference(
                checkpoint["opt"]._all_transform)
            old_transf = checkpoint["opt"]._all_transform.difference(
                opt._all_transform)
            if len(new_transf) != 0:
                _msg += f" +{new_transf}"
            if len(old_transf) != 0:
                _msg += f" -{old_transf}."
            logger.warning(_msg)
        if opt.update_vocab:
            logger.info("Updating checkpoint vocabulary with new vocabulary")
            fields, transforms_cls = prepare_fields_transforms(opt)
    else:
        checkpoint = None
        fields, transforms_cls = prepare_fields_transforms(opt)

    # Report src and tgt vocab sizes
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))
    return checkpoint, fields, transforms_cls
def batch_iter(data, batch_size, batch_size_fn=None, batch_size_multiple=1):
    """Yield elements from data in chunks of batch_size, where each chunk size
    is a multiple of batch_size_multiple.

    This is an extended version of torchtext.data.batch.
    """
    if batch_size_fn is None:
        def batch_size_fn(new, count, sofar):
            return count
    minibatch, size_so_far = [], 0
    for ex in data:
        minibatch.append(ex)
        size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
        if size_so_far >= batch_size:
            overflowed = 0
            if size_so_far > batch_size:
                overflowed += 1
            if batch_size_multiple > 1:
                overflowed += (
                    (len(minibatch) - overflowed) % batch_size_multiple)
            if overflowed == 0:
                yield minibatch
                minibatch, size_so_far = [], 0
            else:
                if overflowed == len(minibatch):
                    logger.warning(
                        "The batch will be filled until we reach %d, "
                        "its size may exceed %d tokens"
                        % (batch_size_multiple, batch_size))
                else:
                    yield minibatch[:-overflowed]
                    minibatch = minibatch[-overflowed:]
                    size_so_far = 0
                    for i, ex in enumerate(minibatch):
                        size_so_far = batch_size_fn(ex, i + 1, size_so_far)
    if minibatch:
        yield minibatch
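# Usage sketch (toy data; a real batch_size_fn such as OpenNMT's
# max_tok_len tracks the longest sequence seen so far rather than a plain
# token sum): counting total tokens mimics `-batch_type tokens`.
def _tokens_so_far(new, count, sofar):
    # Accumulate token counts instead of counting examples.
    return sofar + len(new)

_toy_data = [["a"] * 3, ["b"] * 4, ["c"] * 2, ["d"] * 5]
for _batch in batch_iter(_toy_data, batch_size=8,
                         batch_size_fn=_tokens_so_far):
    print([len(_ex) for _ex in _batch])  # -> [3, 4] then [2, 5]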
def __init__(self, opt, generator, indicator_vocab, tgt_vocab, eps=1e-20):
    super(E2ELossCompute, self).__init__()
    self.key_model = opt.key_model
    self.tgt_vocab = tgt_vocab
    self.cur_dataset = None
    self.force_copy = opt.copy_attn_force
    self.top_k = opt.top_k
    self.sel_report_topk = opt.top_k
    self.sel_normalize_by_length = opt.sel_normalize_by_length
    self.gen_normalize_by_length = opt.gen_normalize_by_length
    self.incons_normalize_by_length = opt.incons_normalize_by_length
    self.pos_weight = opt.pos_weight  # default 9.0
    self.sel_threshold = opt.sel_threshold  # default 0.9
    self.sel_lambda = opt.sel_lambda  # default 0.5
    self.sel_train_ratio = opt.sel_train_ratio  # default 1.0
    self.gen_lambda = opt.gen_lambda  # default 0.5
    self.incons_lambda = opt.incons_lambda  # default 0.5
    self.generator = generator
    if opt.key_model != 'key_selector':
        self.tgt_padding_idx = tgt_vocab.stoi[inputters.PAD_WORD]
    if opt.key_model != 'key_generator':
        assert len(indicator_vocab) == 3
        self.src_unk_idx = inputters.UNK
        self.sel_padding_idx = indicator_vocab.stoi[inputters.PAD_WORD]
        self.pos_idx = indicator_vocab.stoi['I']
        self.neg_idx = indicator_vocab.stoi['O']
    # BCEWithLogits loss for extraction (selector)
    if self.key_model == 'key_selector' or \
            (self.key_model == 'key_end2end' and self.sel_lambda != 0.0):
        self.bcewithlogits_criterion = nn.BCEWithLogitsLoss(
            reduction='none',
            pos_weight=torch.tensor([self.pos_weight]))
    else:
        self.bcewithlogits_criterion = None
    # CopyGenerator loss for generation (generator)
    # (len(tgt_vocab), force_copy, self.padding_idx)
    if self.key_model in ('key_generator', 'key_end2end'):
        self.copynmtloss_criterion = onmt.modules.CopyGeneratorCriterion(
            len(tgt_vocab), self.force_copy, self.tgt_padding_idx)
    else:
        self.copynmtloss_criterion = 0.0
    # inconsistency loss between extraction and generation attention
    if self.key_model == 'key_end2end' and self.incons_lambda != 0.0:
        self.inconsistloss_criterion = InconsistencyLoss(top_k=self.top_k)
    if not self.sel_normalize_by_length:
        logger.warning("Selector losses will not be normalized by length "
                       "since opt.sel_normalize_by_length=False!")
    if not self.gen_normalize_by_length:
        logger.warning("Generator losses will not be normalized by length "
                       "since opt.gen_normalize_by_length=False!")
    if not self.incons_normalize_by_length:
        logger.warning("Inconsistency losses will not be normalized by "
                       "length since opt.incons_normalize_by_length=False!")
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    logger.info('Loading alignment.')
    lemma_aligns = open(model_opt.lemma_align, 'rb').readlines()
    src_stoi = vocab['src'].base_field.vocab.stoi
    lemma_stoi = vocab['word_topic'].base_field.vocab.stoi
    w2l = {}
    word_to_lemma = []
    for pair in lemma_aligns:
        pair = pair.strip().split()
        w2l[src_stoi[pair[0].decode('utf-8')]] = \
            lemma_stoi[pair[1].decode('utf-8')]
    w2l[src_stoi['unk']] = lemma_stoi['unk']
    for index in range(len(vocab['src'].base_field.vocab.itos)):
        if index in w2l:
            word_to_lemma.append(w2l[index])
        else:
            word_to_lemma.append(w2l[lemma_stoi['unk']])
    word_to_lemma = torch.tensor(word_to_lemma)

    logger.info('Loading topic matrix')
    if device_id >= 0:
        topic_matrix = torch.load(opt.topic_matrix,
                                  map_location=torch.device(device_id))
    else:
        topic_matrix = torch.load(opt.topic_matrix)
    if opt.model_dtype == 'fp16':
        topic_matrix = topic_matrix.half()

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(topic_matrix, word_to_lemma, train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, fields, transforms_cls, checkpoint, device_id,
         batch_queue=None, semaphore=None):
    """Start training on `device_id`."""
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    model_opt = _get_model_opts(opt, checkpoint=checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    model.count_parameters(log=logger.info)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    if batch_queue is None:
        _train_iter = _build_train_iter(opt, fields, transforms_cls)
        train_iter = IterOnDevice(_train_iter, device_id)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                # Move batch to specified device
                IterOnDevice.batch_to_device(batch, device_id)
                yield batch

        train_iter = _train_iter()

    valid_iter = _build_valid_iter(opt, fields, transforms_cls)
    if valid_iter is not None:
        valid_iter = IterOnDevice(valid_iter, device_id)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
    """
    Args:
        model_opt: the option loaded from checkpoint.
        fields: `Field` objects for the model.
        gpu (bool): whether to use gpu.
        checkpoint: the model generated by train phase, or a resumed snapshot
                    model from a stopped training.
        gpu_id (int or NoneType): Which GPU to use.

    Returns:
        the NMTModel.
    """
    assert model_opt.model_type in ["text", "img", "audio"], \
        "Unsupported model type %s" % model_opt.model_type

    # for backward compatibility
    if model_opt.rnn_size != -1:
        model_opt.enc_rnn_size = model_opt.rnn_size
        model_opt.dec_rnn_size = model_opt.rnn_size

    # Build embeddings.
    if model_opt.model_type == "text":
        src_fields = [f for n, f in fields['src']]
        assert len(src_fields) == 1
        src_field = src_fields[0]
        src_emb = build_embeddings(model_opt, src_field)
    else:
        src_emb = None

    # Build encoder.
    encoder = build_encoder(model_opt, src_emb)

    # Build decoder.
    tgt_fields = [f for n, f in fields['tgt']]
    assert len(tgt_fields) == 1
    tgt_field = tgt_fields[0]
    tgt_emb = build_embeddings(model_opt, tgt_field, for_encoder=False)

    # Share the embedding matrix - preprocess with share_vocab required.
    if model_opt.share_embeddings:
        # src/tgt vocab should be the same if `-share_vocab` is specified.
        assert src_field.base_field.vocab == tgt_field.base_field.vocab, \
            "preprocess with -share_vocab if you use share_embeddings"
        tgt_emb.word_lut.weight = src_emb.word_lut.weight

    decoder = build_decoder(model_opt, tgt_emb)

    # Build NMTModel(= encoder + decoder).
    if gpu and gpu_id is not None:
        device = torch.device("cuda", gpu_id)
    elif gpu and not gpu_id:
        device = torch.device("cuda")
    elif not gpu:
        device = torch.device("cpu")
    model = onmt.models.NMTModel(encoder, decoder)

    # Build Generator.
    if not model_opt.copy_attn:
        if model_opt.generator_function == "sparsemax":
            gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1)
        else:
            gen_func = nn.LogSoftmax(dim=-1)
        generator = nn.Sequential(
            nn.Linear(model_opt.dec_rnn_size,
                      len(fields["tgt"][0][1].base_field.vocab)),
            Cast(torch.float32),
            gen_func)
        if model_opt.share_decoder_embeddings:
            generator[0].weight = decoder.embeddings.word_lut.weight
    else:
        assert len(fields["tgt"]) == 1
        tgt_base_field = fields["tgt"][0][1].base_field
        vocab_size = len(tgt_base_field.vocab)
        pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token]
        generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx)

    # Load the model states from checkpoint or initialize them.
    if checkpoint is not None:
        # This preserves backward-compat for models using custom layernorm
        def fix_key(s):
            s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.b_2',
                       r'\1.layer_norm\2.bias', s)
            s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.a_2',
                       r'\1.layer_norm\2.weight', s)
            return s

        checkpoint['model'] = {fix_key(k): v
                               for k, v in checkpoint['model'].items()}
        # end of patch for backward compatibility

        model.load_state_dict(checkpoint['model'], strict=False)
        generator.load_state_dict(checkpoint['generator'], strict=False)
    else:
        if model_opt.param_init != 0.0:
            for p in model.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
            for p in generator.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
        if model_opt.param_init_glorot:
            for p in model.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
            for p in generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        if hasattr(model.encoder, 'embeddings'):
            model.encoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_enc)
        if hasattr(model.decoder, 'embeddings'):
            model.decoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_dec)

    model.generator = generator
    model.to(device)
    if model_opt.model_dtype == 'fp16':
        logger.warning('FP16 is experimental, the generated checkpoints may '
                       'be incompatible with a future version')
        model.half()

    return model
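# Sketch of what fix_key rewrites (a standalone copy of the same regexes):
# legacy layer-norm parameter names `.a_2`/`.b_2` map onto the
# `.weight`/`.bias` names current checkpoints expect.
import re


def _fix_key_demo(s):
    s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.b_2',
               r'\1.layer_norm\2.bias', s)
    s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.a_2',
               r'\1.layer_norm\2.weight', s)
    return s


assert _fix_key_demo('decoder.layers.0.layer_norm_1.b_2') \
    == 'decoder.layers.0.layer_norm_1.bias'
assert _fix_key_demo('encoder.layer_norm.a_2') == 'encoder.layer_norm.weight'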
def warm_up(self, vocabs):
    self.vocab = vocabs
    if vocabs is None:
        logger.warning(
            "SwitchOut disabled as there is no vocab; "
            "this shouldn't happen in training!")
    self.temperature = self.opts.switchout_temperature
def warm_up(self, vocabs):
    super().warm_up(None)
    self.vocabs = vocabs
    if vocabs is None:
        logger.warning(
            "SwitchOut disabled as there is no vocab; "
            "this shouldn't happen in training!")
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, nontrainable = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('non-trainable parameters (tgt_out_emb): %d' % nontrainable)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # save training settings
    if opt.log_file:
        shutil.copy2(opt.config, opt.exp_dir)
    logger.info(vars(opt))

    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        # added by @memray for multiple datasets
        if opt.vocab and opt.vocab != 'none':
            vocab = torch.load(opt.vocab)
        elif opt.encoder_type == 'pretrained':
            vocab = None
        else:
            vocab = None

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train.py line 43
    if opt.model_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(
                    use_vocab=False, dtype=torch.long,
                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    tokenizer = None
    if opt.pretrained_tokenizer:
        tokenizer = load_pretrained_tokenizer(
            opt.pretrained_tokenizer, opt.cache_dir, opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(fields, opt, tokenizer)

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            # added by @memray, for loading multiple datasets
            if opt.multi_dataset:
                shard_base = "train"
                train_iter = build_dataset_iter(shard_base, fields, opt,
                                                tokenizer=tokenizer)
            else:
                train_shards = []
                for train_id in opt.data_ids:
                    shard_base = "train_" + train_id
                    train_shards.append(shard_base)
                train_iter = build_dataset_iter_multiple(
                    train_shards, fields, opt, tokenizer=tokenizer)
        else:
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    if opt.valid:
        valid_iter = build_dataset_iter(
            "valid", fields, opt, is_train=False)
    else:
        valid_iter = None

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with opts from
        # the checkpoint. It's useful in order to re-train a model
        # after adding a new option (not set in checkpoint)
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        data_type = opt.model_type
        fields = load_old_vocab(vocab, data_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        for name, f in fields[side]:
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(name, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d'
                                % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    # this line is kind of a temporary kludge because different objects expect
    # fields to have a different structure
    dataset_fields = dict(chain.from_iterable(fields.values()))

    train_iter = build_dataset_iter("train", dataset_fields, opt)
    valid_iter = build_dataset_iter(
        "valid", dataset_fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    logger.info(opt)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        load_vocab = torch.load(opt.data + '.vocab.pt')
        vocab = checkpoint['vocab']
        load_vocab['src'].fields[0][1].vocab = vocab['src'].fields[0][1].vocab
        load_vocab['tgt'].fields[0][1].vocab = vocab['tgt'].fields[0][1].vocab
        vocab = load_vocab
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    if opt.pretrain_from:
        check = torch.load(opt.pretrain_from,
                           map_location=lambda storage, loc: storage)
        model.load_state_dict(check['model'], strict=False)
        model.load_state_dict(check['generator'], strict=False)
        if 'dom_classifier' in check:
            model.load_state_dict(check['dom_classifier'], strict=False)

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    translator = None
    if not opt.domain_cls_enc:
        translator = train_build_translator(opt, model, model_opt, fields,
                                            report_score=True)

    trainer = build_trainer(translator, opt, device_id, model, fields,
                            optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    load_str = opt.train_from if opt.train_from else opt.load_uncond_from
    if load_str:
        logger.info('Loading checkpoint from %s' % load_str)
        checkpoint = torch.load(load_str,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % load_str)
        vocab = checkpoint['vocab']

        if opt.train_from:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
        else:
            model_opt = opt
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    if opt.gpt2_params_path is not None:
        import tensorflow as tf
        import numpy as np
        # Taken from pytorch-pretrained-BERT:
        # Load weights from TF model
        logger.info("Loading TF GPT weights...")
        init_vars = tf.train.list_variables(opt.gpt2_params_path)
        names = []
        arrays = []
        for name, shape in init_vars:
            if opt.gpt_emb_only and ('wpe' not in name
                                     and 'wte' not in name):
                continue
            if opt.gpt_wpe_only and 'wpe' not in name:
                continue
            array = tf.train.load_variable(opt.gpt2_params_path, name)
            names.append(name)
            arrays.append(array.squeeze())
        logger.info("Done.")

        if checkpoint is None:
            checkpoint = {'gpt2_params': zip(names, arrays)}
        else:
            checkpoint['gpt2_params'] = zip(names, arrays)

    if opt.encoder_from is not None:
        logger.info('Loading checkpoint with encoder from %s'
                    % opt.encoder_from)
        enc_checkpoint = torch.load(
            opt.encoder_from, map_location=lambda storage, loc: storage)
        enc_vocab = enc_checkpoint['vocab']
        if vocab['src'].base_field.vocab != \
                enc_vocab['src'].base_field.vocab:
            raise ValueError(
                'encoder vocab and model vocab need to be identical when '
                'using a pretrained encoder')
        if checkpoint is None:
            checkpoint = {}
        checkpoint['enc_model'] = enc_checkpoint['model']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    sides = ['tgt'] if opt.model_type == 'none' else ['src', 'tgt']
    for side in sides:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, lm_dec = _tally_parameters(model)
    n_params_t, enc_t, dec_t, lm_dec_t = _tally_parameters(
        model, only_trainable=True)
    logger.info('encoder: %d (%d)' % (enc, enc_t))
    logger.info('decoder: %d (%d)' % (dec, dec_t))
    if opt.simple_fusion:
        logger.info('lm decoder: %d (%d)' % (lm_dec, lm_dec_t))
    logger.info('* number of parameters: %d (%d)' % (n_params, n_params_t))
    _check_save_model_path(opt)

    if not opt.train_from and opt.gpt2_params_path is not None:
        checkpoint = None

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    if opt.local_rank != -1:
        torch.cuda.set_device(opt.local_rank)
        device = torch.device("cuda", opt.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        device_id = opt.local_rank
        world_size = torch.distributed.get_world_size()
    else:
        if device_id == -1:
            device = torch.device("cpu")
        else:
            device = torch.device("cuda", device_id)
    if opt.local_rank > 0:
        logger.disabled = True

    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields, optim,
                            model_saver=model_saver)

    if opt.bert_kd:
        src_vocab = vocab['src'].fields[0][1].vocab.stoi
        tgt_vocab = vocab['tgt'].fields[0][1].vocab.stoi
        assert 0 < opt.kd_topk <= 128
        train_dataset = BertKdDataset(opt.data_db, opt.bert_dump,
                                      src_vocab, tgt_vocab,
                                      max_len=150, k=opt.kd_topk)
        BUCKET_SIZE = 8192
        if True or opt.local_rank == -1 and opt.world_size == 1:
            train_sampler = TokenBucketSampler(
                train_dataset.keys, BUCKET_SIZE, opt.batch_size,
                batch_multiple=1)
        else:
            assert False  # seems like it's handled in training loop
            train_sampler = DistributedTokenBucketSampler(
                world_size, device_id,
                train_dataset.keys, BUCKET_SIZE, opt.batch_size,
                batch_multiple=1)
        train_loader = DataLoader(train_dataset,
                                  batch_sampler=train_sampler,
                                  num_workers=4,
                                  collate_fn=BertKdDataset.pad_collate)
        train_iter = cycle_loader(train_loader, device)
    else:
        train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter, train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if opt.tensorboard:
        if trainer.report_manager.tensorboard_writer:
            trainer.report_manager.tensorboard_writer.close()
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    if opt.teacher_model_path:
        logger.info('Loading teacher model from {path}'.format(
            path=opt.teacher_model_path))
        teacher_model_ckpt = torch.load(
            opt.teacher_model_path,
            map_location=lambda storage, loc: storage)

        teacher_model_opt = ArgumentParser.ckpt_model_opts(
            teacher_model_ckpt['opt'])
        ArgumentParser.update_model_opts(teacher_model_opt)
        ArgumentParser.validate_model_opts(teacher_model_opt)
        logger.info('Loading vocab from checkpoint at {path}'.format(
            path=opt.teacher_model_path))
        teacher_vocab = teacher_model_ckpt['vocab']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab, opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    teacher_fields = teacher_vocab if opt.teacher_model_path else None

    # patch for fields that may be missing in old data/model
    # patch_fields(opt, fields)

    # Report src and tgt vocab sizes, including for features
    report_vocab_size(fields)
    if teacher_fields is not None:
        report_vocab_size(teacher_fields)

    # Build model.
    fields_opt = {"original": fields, "teacher": teacher_fields}
    model = custom_builder.build_model(model_opt, opt, fields_opt, checkpoint)
    # model = build_model(model_opt, opt, fields, checkpoint)
    teacher_model = build_model(
        teacher_model_opt, teacher_model_opt,
        teacher_fields, teacher_model_ckpt) \
        if opt.teacher_model_path else None

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    if teacher_model is not None:
        n_params, enc, dec = _tally_parameters(teacher_model)
        logger.info('encoder: %d' % enc)
        logger.info('decoder: %d' % dec)
        logger.info('* number of parameters: %d' % n_params)
        _check_save_model_path(teacher_model_opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    # model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    model_saver = custom_model_saver.build_model_saver(
        model_opt, opt, model, fields_opt, optim)

    tgt_field = dict(teacher_fields)["tgt"].base_field \
        if teacher_model is not None else dict(fields)["tgt"].base_field
    sos_id = tgt_field.vocab.stoi[tgt_field.init_token]

    if teacher_model is not None and opt.word_sampling:
        sampler = Emulator(teacher_model, teacher_fields, device_id,
                           max_length=50, random_sampling_topk=5)
    else:
        sampler = None

    if teacher_model is not None:
        trainer = build_trainer(opt, device_id, model, teacher_fields,
                                optim, model_saver,
                                teacher_model=teacher_model,
                                emulator=sampler)
    else:
        trainer = build_trainer(opt, device_id, model, fields,
                                optim, model_saver,
                                teacher_model=teacher_model,
                                emulator=sampler)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter, train_steps, sos_id=sos_id,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter, valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()