Example #1
 def shard_iterator(srcs, tgts, ids, aligns, existing_shards,
                    existing_fields, corpus_type, opt):
     """
     Builds a single iterator yielding every shard of every corpus.
     """
     for src, tgt, maybe_id, maybe_align in zip(srcs, tgts, ids, aligns):
         if maybe_id in existing_shards:
             if opt.overwrite:
                 logger.warning(
                     "Overwrite shards for corpus {}".format(maybe_id))
             else:
                 if corpus_type == "train":
                     assert existing_fields is not None,\
                         ("A 'vocab.pt' file should be passed to "
                          "`-src_vocab` when adding a corpus to "
                          "a set of already existing shards.")
                 logger.warning("Ignore corpus {} because "
                                "shards already exist".format(maybe_id))
                 continue
         if ((corpus_type == "train" or opt.filter_valid)
                 and tgt is not None):
             filter_pred = partial(inputters.filter_example,
                                   use_src_len=opt.data_type == "text",
                                   max_src_len=opt.src_seq_length,
                                   max_tgt_len=opt.tgt_seq_length)
         else:
             filter_pred = None
         src_shards = split_corpus(src, opt.shard_size)
         tgt_shards = split_corpus(tgt, opt.shard_size)
         align_shards = split_corpus(maybe_align, opt.shard_size)
         for i, (ss, ts,
                 a_s) in enumerate(zip(src_shards, tgt_shards,
                                       align_shards)):
             yield (i, (ss, ts, a_s, maybe_id, filter_pred))
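For orientation, `split_corpus` used above yields a corpus in chunks of `shard_size` lines (typically the whole corpus at once when sharding is disabled). A minimal stand-alone sketch of that chunking behaviour, not OpenNMT-py's implementation, assuming the lines are already in memory:

def split_lines(lines, shard_size):
    # Hypothetical helper illustrating split_corpus-style chunking.
    if shard_size <= 0:
        # No sharding requested: emit everything as a single shard.
        yield lines
        return
    for i in range(0, len(lines), shard_size):
        yield lines[i:i + shard_size]

print(list(split_lines(['a', 'b', 'c', 'd', 'e'], 2)))
# -> [['a', 'b'], ['c', 'd'], ['e']]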
Example #2
 def _get_subword_kwargs(self, side='src'):
     """Return a dict containing kwargs relate to `side` subwords."""
     subword_type = self.tgt_subword_type if side == 'tgt' \
         else self.src_subword_type
     subword_model = self.tgt_subword_model if side == 'tgt' \
         else self.src_subword_model
     subword_nbest = self.tgt_subword_nbest if side == 'tgt' \
         else self.src_subword_nbest
     subword_alpha = self.tgt_subword_alpha if side == 'tgt' \
         else self.src_subword_alpha
     kwopts = dict()
     if subword_type == 'bpe':
         kwopts['bpe_model_path'] = subword_model
         kwopts['bpe_dropout'] = subword_alpha
     elif subword_type == 'sentencepiece':
         kwopts['sp_model_path'] = subword_model
         kwopts['sp_nbest_size'] = subword_nbest
         kwopts['sp_alpha'] = subword_alpha
     else:
         logger.warning('No subword method will be applied.')
     vocabulary_threshold = self.tgt_vocab_threshold if side == 'tgt' \
         else self.src_vocab_threshold
     vocabulary_path = self.tgt_subword_vocab if side == 'tgt' \
         else self.src_subword_vocab
     if vocabulary_threshold > 0 and vocabulary_path != "":
         kwopts['vocabulary_path'] = vocabulary_path
         kwopts['vocabulary_threshold'] = vocabulary_threshold
     return kwopts
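As a rough illustration of the return value, a BPE-only source configuration would produce just the two BPE entries. A minimal sketch with a made-up config object standing in for `self` (attribute names mirror the snippet above, values are hypothetical):

from types import SimpleNamespace

cfg = SimpleNamespace(
    src_subword_type='bpe', tgt_subword_type='none',
    src_subword_model='codes.bpe', tgt_subword_model=None,
    src_subword_nbest=1, tgt_subword_nbest=1,
    src_subword_alpha=0.1, tgt_subword_alpha=0.0,
    src_vocab_threshold=0, tgt_vocab_threshold=0,
    src_subword_vocab='', tgt_subword_vocab='')

# Calling the method with cfg in place of `self`:
# _get_subword_kwargs(cfg, side='src')
# -> {'bpe_model_path': 'codes.bpe', 'bpe_dropout': 0.1}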
Example #3
def prepare_fields_transforms(opt):
    """Prepare or dump fields & transforms before training."""
    transforms_cls = get_transforms_cls(opt._all_transform)
    specials = get_specials(opt, transforms_cls)

    fields = build_dynamic_fields(opt,
                                  src_specials=specials['src'],
                                  tgt_specials=specials['tgt'])

    # maybe prepare pretrained embeddings, if any
    prepare_pretrained_embeddings(opt, fields)

    if opt.dump_fields:
        save_fields(fields, opt.save_data, overwrite=opt.overwrite)
    if opt.dump_transforms or opt.n_sample != 0:
        transforms = make_transforms(opt, transforms_cls, fields)
    if opt.dump_transforms:
        save_transforms(transforms, opt.save_data, overwrite=opt.overwrite)
    if opt.n_sample != 0:
        logger.warning("`-n_sample` != 0: Training will not be started. "
                       f"Stop after saving {opt.n_sample} samples/corpus.")
        save_transformed_sample(opt, transforms, n_sample=opt.n_sample)
        logger.info("Sample saved, please check it before restart training.")
        sys.exit()
    return fields, transforms_cls
Example #4
    def _validate_data(cls, opt):
        """Parse corpora specified in data field of YAML file."""
        import yaml
        default_transforms = opt.transforms
        if len(default_transforms) != 0:
            logger.info(f"Default transforms: {default_transforms}.")
        corpora = yaml.safe_load(opt.data)

        for cname, corpus in corpora.items():
            # Check Transforms
            _transforms = corpus.get('transforms', None)
            if _transforms is None:
                logger.info(f"Missing transforms field for {cname} data, "
                            f"set to default: {default_transforms}.")
                corpus['transforms'] = default_transforms
            # Check path
            path_src = corpus.get('path_src', None)
            path_tgt = corpus.get('path_tgt', None)
            if path_src is None:
                raise ValueError(f'Corpus {cname} src path is required. '
                                 'tgt path is also required for non language'
                                 ' modeling tasks.')
            else:
                opt.data_task = ModelTask.SEQ2SEQ
                if path_tgt is None:
                    logger.warning(
                        "path_tgt is None, it should be set unless the task"
                        " is language modeling")
                    opt.data_task = ModelTask.LANGUAGE_MODEL
                    # tgt is src for LM task
                    corpus["path_tgt"] = path_src
                    corpora[cname] = corpus
                    path_tgt = path_src
                cls._validate_file(path_src, info=f'{cname}/path_src')
                cls._validate_file(path_tgt, info=f'{cname}/path_tgt')
            path_align = corpus.get('path_align', None)
            if path_align is None:
                if hasattr(opt, 'lambda_align') and opt.lambda_align > 0.0:
                    raise ValueError(f'Corpus {cname} alignment file path is '
                                     'required when lambda_align > 0.0')
                corpus['path_align'] = None
            else:
                cls._validate_file(path_align, info=f'{cname}/path_align')
            # Check prefix: will be used when the prefix transform is applied
            src_prefix = corpus.get('src_prefix', None)
            tgt_prefix = corpus.get('tgt_prefix', None)
            if src_prefix is None or tgt_prefix is None:
                if 'prefix' in corpus['transforms']:
                    raise ValueError(f'Corpus {cname} prefixes are required.')
            # Check weight
            weight = corpus.get('weight', None)
            if weight is None:
                if cname != CorpusName.VALID:
                    logger.warning(f"Corpus {cname}'s weight should be given."
                                   " We default it to 1 for you.")
                corpus['weight'] = 1
        logger.info(f"Parsed {len(corpora)} corpora from -data.")
        opt.data = corpora
Example #5
    def validate_model_opts(cls, model_opt):
        assert model_opt.model_type in ["text", "img", "audio", 'imgvec', 'none'], \
            "Unsupported model type %s" % model_opt.model_type

        # this check is here because audio allows the encoder and decoder to
        # be different sizes, but other model types do not yet
        same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size
        assert model_opt.model_type == 'audio' or same_size, \
            "The encoder and decoder rnns must be the same size for now"

        assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, \
            "Using SRU requires -gpu_ranks set."
        if model_opt.share_embeddings:
            if model_opt.model_type != "text":
                raise AssertionError(
                    "--share_embeddings requires --model_type text.")
        if model_opt.model_dtype == "fp16":
            logger.warning(
                "FP16 is experimental, the generated checkpoints may "
                "be incompatible with a future version")

        if model_opt.share_position_embeddings and not model_opt.position_encoding_learned:
            raise AssertionError(
                'It does not make sense to share position embeddings if '
                'they are not learned')
        if int(model_opt.use_GPT_version_psa) + int(model_opt.use_GPT_version_unconditional) + \
           int(model_opt.use_GPT_version_ctxattn) + int(model_opt.decoder_type == 'multi_src_transformer') > 1:
            raise AssertionError(
                'At most one of use_GPT_version, use_GPT_version_alt, '
                'use_GPT_version_psa, use_GPT_version_unconditional, '
                'use_GPT_version_ctxattn can be true at the same time. '
                'Or, multi_src_transformer has a specific psa version')

        if model_opt.simple_fusion and model_opt.gpt2_params_path is None:
            raise AssertionError(
                'Simple fusion requires setting the gpt2_params_path option')

        if model_opt.attn_hidden > 0:
            raise NotImplementedError

        if model_opt.GPT_representation_mode != 'none' and (
                model_opt.gpt2_init_embanddec or model_opt.simple_fusion
                or model_opt.gpt2_init_embandenc):
            raise AssertionError(
                'loading GPT weights for seq2seq initialization AND GPT '
                'probably does not make sense')

        if model_opt.decoder_type == 'multi_src_transformer' and model_opt.num_src < 1:
            raise AssertionError(
                'using GPT with multiple sources. no point in passing num_src < 1. '
                'You really want to use num_src=1 only when debugging.')
Example #6
    def validate_model_opts(cls, model_opt):

        # this check is here because audio allows the encoder and decoder to
        # be different sizes, but other model types do not yet
        same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size

        assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, \
            "Using SRU requires -gpu_ranks set."

        if model_opt.model_dtype == "fp16":
            logger.warning(
                "FP16 is experimental, the generated checkpoints may "
                "be incompatible with a future version")
Example #7
def make_transforms(opts, transforms_cls, fields):
    """Build transforms in `transforms_cls` with vocab of `fields`."""
    vocabs = get_vocabs(fields) if fields is not None else None
    transforms = {}
    for name, transform_cls in transforms_cls.items():
        if transform_cls.require_vocab() and vocabs is None:
            logger.warning(
                f"{transform_cls.__name__} require vocab to apply, skip it.")
            continue
        transform_obj = transform_cls(opts)
        transform_obj.warm_up(vocabs)
        transforms[name] = transform_obj
    return transforms
Example #8
    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Return None if mismatch"""

        if 'src_feats' not in example:
            # Do nothing
            return example

        for feat_name, feat_values in example['src_feats'].items():
            if len(example['src']) != len(feat_values):
                logger.warning(f"Skipping example due to mismatch "
                               f"between source and feature {feat_name}")
                return None
        return example
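The check above compares only token counts per feature, illustrated here with hypothetical data:

kept = {'src': ['the', 'cat'], 'src_feats': {'pos': ['DET', 'NOUN']}}  # lengths match -> example returned
dropped = {'src': ['the', 'cat'], 'src_feats': {'pos': ['DET']}}       # length mismatch -> None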
Example #9
 def _add_index(self, stream):
     for i, item in enumerate(stream):
         example = item[0]
         line_number = i * self.stride + self.offset
         example['indices'] = line_number
         if (len(example['src']) == 0 or len(example['tgt']) == 0
                 or ('align' in example and example['align'] == 0)):
             # empty example: skip
             empty_msg = f"Empty line exists in {self.cid}#{line_number}."
             if self.skip_empty_level == 'error':
                 raise IOError(empty_msg)
             elif self.skip_empty_level == 'warning':
                 logger.warning(empty_msg)
             continue
         yield item
Example #10
def parse_align_idx(align_pharaoh):
    """
    Parse Pharaoh alignment into [[<src>, <tgt>], ...]
    """
    align_list = align_pharaoh.strip().split(' ')
    flatten_align_idx = []
    for align in align_list:
        try:
            src_idx, tgt_idx = align.split('-')
        except ValueError:
            logger.warning("{} in `{}`".format(align, align_pharaoh))
            logger.warning("Bad alignement line exists. Please check file!")
            raise
        flatten_align_idx.append([int(src_idx), int(tgt_idx)])
    return flatten_align_idx
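For example (derived directly from the parsing above), a Pharaoh-format string maps to index pairs like this:

# parse_align_idx('0-0 1-2 2-1')
# -> [[0, 0], [1, 2], [2, 1]]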
Example #11
    def validate_model_opts(cls, model_opt):
        assert model_opt.model_type in ["text"], \
            "Unsupported model type %s" % model_opt.model_type

        same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size
        assert same_size, \
            "The encoder and decoder rnns must be the same size for now"

        if model_opt.share_embeddings:
            if model_opt.model_type != "text":
                raise AssertionError(
                    "--share_embeddings requires --model_type text.")
        if model_opt.model_dtype == "fp16":
            logger.warning(
                "FP16 is experimental, the generated checkpoints may "
                "be incompatible with a future version")
Example #12
    def validate_model_opts(cls, model_opt):
        assert model_opt.model_type in ["text", "img", "audio"], \
            "Unsupported model type %s" % model_opt.model_type

        # this check is here because audio allows the encoder and decoder to
        # be different sizes, but other model types do not yet
        same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size
        assert model_opt.model_type == 'audio' or same_size, \
            "The encoder and decoder rnns must be the same size for now"

        assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, \
            "Using SRU requires -gpu_ranks set."
        if model_opt.share_embeddings:
            if model_opt.model_type != "text":
                raise AssertionError(
                    "--share_embeddings requires --model_type text.")
        if model_opt.model_dtype == "fp16":
            logger.warning(
                "FP16 is experimental, the generated checkpoints may "
                "be incompatible with a future version")
Example #13
    def validate_model_opts(cls, model_opt):
        assert model_opt.model_type in ["text", "img", "audio"], \
            "Unsupported model type %s" % model_opt.model_type

        # this check is here because audio allows the encoder and decoder to
        # be different sizes, but other model types do not yet
        same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size
        assert model_opt.model_type == 'audio' or same_size, \
            "The encoder and decoder rnns must be the same size for now"

        assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, \
            "Using SRU requires -gpu_ranks set."
        if model_opt.share_embeddings:
            if model_opt.model_type != "text":
                raise AssertionError(
                    "--share_embeddings requires --model_type text.")
        if model_opt.model_dtype == "fp16":
            logger.warning(
                "FP16 is experimental, the generated checkpoints may "
                "be incompatible with a future version")
Example #14
def check_existing_pt_files(opt, corpus_type, ids, existing_fields):
    """ Check if there are existing .pt files to avoid overwriting them """
    existing_shards = []
    for maybe_id in ids:
        if maybe_id:
            shard_base = corpus_type + "_" + maybe_id
        else:
            shard_base = corpus_type
        pattern = opt.save_data + '.{}.*.pt'.format(shard_base)
        if glob.glob(pattern):
            if opt.overwrite:
                maybe_overwrite = ("will be overwritten because "
                                   "`-overwrite` option is set.")
            else:
                maybe_overwrite = ("won't be overwritten, pass the "
                                   "`-overwrite` option if you want to.")
            logger.warning("Shards for corpus {} already exist, {}".format(
                shard_base, maybe_overwrite))
            existing_shards += [maybe_id]
    return existing_shards
Example #15
def batch_iter(data, batch_size, batch_size_fn=None, batch_size_multiple=1):
    """Yield elements from data in chunks of batch_size, where each chunk size
    is a multiple of batch_size_multiple.

    This is an extended version of torchtext.data.batch.
    """
    # print("==================satrt batch_iter function======================")
    if batch_size_fn is None:

        def batch_size_fn(new, count, sofar):
            return count

    minibatch, size_so_far = [], 0
    for ex in data:
        minibatch.append(ex)
        size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
        if size_so_far >= batch_size:
            overflowed = 0
            if size_so_far > batch_size:
                overflowed += 1
            if batch_size_multiple > 1:
                overflowed += ((len(minibatch) - overflowed) %
                               batch_size_multiple)
            if overflowed == 0:
                yield minibatch
                minibatch, size_so_far = [], 0
            else:
                if overflowed == len(minibatch):
                    logger.warning("An example was ignored, more tokens"
                                   " than allowed by tokens batch_size")
                else:
                    yield minibatch[:-overflowed]
                    minibatch = minibatch[-overflowed:]
                    size_so_far = 0
                    for i, ex in enumerate(minibatch):
                        size_so_far = batch_size_fn(ex, i + 1, size_so_far)
    # print("====================through batch_iter yield minibatch=============================")
    if minibatch:
        yield minibatch
Example #16
def _init_train(opt):
    """Common initilization stuff for all training process."""
    ArgumentParser.validate_prepare_opts(opt)

    if opt.train_from:
        # Load checkpoint if we resume from a previous training.
        checkpoint = load_checkpoint(ckpt_path=opt.train_from)
        fields = load_fields(opt.save_data, checkpoint)
        transforms_cls = get_transforms_cls(opt._all_transform)
        if (hasattr(checkpoint["opt"], '_all_transform') and
                len(opt._all_transform.symmetric_difference(
                    checkpoint["opt"]._all_transform)) != 0):
            _msg = "configured transforms is different from checkpoint:"
            new_transf = opt._all_transform.difference(
                checkpoint["opt"]._all_transform)
            old_transf = checkpoint["opt"]._all_transform.difference(
                opt._all_transform)
            if len(new_transf) != 0:
                _msg += f" +{new_transf}"
            if len(old_transf) != 0:
                _msg += f" -{old_transf}."
            logger.warning(_msg)
        if opt.update_vocab:
            logger.info("Updating checkpoint vocabulary with new vocabulary")
            fields, transforms_cls = prepare_fields_transforms(opt)
    else:
        checkpoint = None
        fields, transforms_cls = prepare_fields_transforms(opt)

    # Report src and tgt vocab sizes
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))
    return checkpoint, fields, transforms_cls
Example #17
def batch_iter(data, batch_size, batch_size_fn=None, batch_size_multiple=1):
    """Yield elements from data in chunks of batch_size, where each chunk size
    is a multiple of batch_size_multiple.

    This is an extended version of torchtext.data.batch.
    """
    if batch_size_fn is None:

        def batch_size_fn(new, count, sofar):
            return count

    minibatch, size_so_far = [], 0
    for ex in data:
        minibatch.append(ex)
        size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
        if size_so_far >= batch_size:
            overflowed = 0
            if size_so_far > batch_size:
                overflowed += 1
            if batch_size_multiple > 1:
                overflowed += ((len(minibatch) - overflowed) %
                               batch_size_multiple)
            if overflowed == 0:
                yield minibatch
                minibatch, size_so_far = [], 0
            else:
                if overflowed == len(minibatch):
                    logger.warning(
                        "The batch will be filled until we reach %d,"
                        "its size may exceed %d tokens" %
                        (batch_size_multiple, batch_size))
                else:
                    yield minibatch[:-overflowed]
                    minibatch = minibatch[-overflowed:]
                    size_so_far = 0
                    for i, ex in enumerate(minibatch):
                        size_so_far = batch_size_fn(ex, i + 1, size_so_far)
    if minibatch:
        yield minibatch
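A minimal usage sketch of `batch_iter` with a token-count batching function (hypothetical data, no torchtext involved):

examples = [['tok'] * n for n in (3, 4, 2, 5, 1)]  # five examples of 3, 4, 2, 5 and 1 tokens

def token_count_fn(new, count, sofar):
    # Running total of tokens in the current minibatch.
    return sofar + len(new)

batches = list(batch_iter(examples, batch_size=6, batch_size_fn=token_count_fn))
# With these inputs, token counts per yielded batch are [3], [4, 2], [5, 1].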
Example #18
    def __init__(self, opt, generator, indicator_vocab, tgt_vocab, eps=1e-20):
        super(E2ELossCompute, self).__init__()

        self.key_model = opt.key_model
        self.tgt_vocab = tgt_vocab
        self.cur_dataset = None
        self.force_copy = opt.copy_attn_force
        self.top_k = opt.top_k
        self.sel_report_topk = opt.top_k
        self.sel_normalize_by_length = opt.sel_normalize_by_length
        self.gen_normalize_by_length = opt.gen_normalize_by_length
        self.incons_normalize_by_length = opt.incons_normalize_by_length
        self.pos_weight = opt.pos_weight    # default 9.0
        self.sel_threshold = opt.sel_threshold  # default 0.9
        self.sel_lambda = opt.sel_lambda    # default 0.5
        self.sel_train_ratio = opt.sel_train_ratio # default 1.0
        self.gen_lambda = opt.gen_lambda    # default 0.5
        self.incons_lambda = opt.incons_lambda  # default 0.5

        self.generator = generator
        if opt.key_model != 'key_selector':
            self.tgt_padding_idx = tgt_vocab.stoi[inputters.PAD_WORD]
        if opt.key_model != 'key_generator':
            assert len(indicator_vocab) == 3
            self.src_unk_idx = inputters.UNK
            self.sel_padding_idx = indicator_vocab.stoi[inputters.PAD_WORD]
            self.pos_idx = indicator_vocab.stoi['I']
            self.neg_idx = indicator_vocab.stoi['O']

        # BCEWithLogits loss for extraction (selector)
        if self.key_model == 'key_selector' or (self.key_model == 'key_end2end' and self.sel_lambda != 0.0):
            self.bcewithlogits_criterion =\
                nn.BCEWithLogitsLoss(reduction='none', pos_weight=torch.tensor([self.pos_weight]))
        else:
            self.bcewithlogits_criterion = None

        # CopyGenerator loss for generation (generator)
        # (len(tgt_vocab), force_copy, self.padding_idx)
        if self.key_model == 'key_generator' or self.key_model == 'key_end2end':
            self.copynmtloss_criterion =\
                onmt.modules.CopyGeneratorCriterion(len(tgt_vocab), self.force_copy, self.tgt_padding_idx)
        else:
            self.copynmtloss_criterion = 0.0

        # inconsistency loss for extraction attention and generating attention
        if self.key_model == 'key_end2end' and self.incons_lambda != 0.0:
            self.inconsistloss_criterion = InconsistencyLoss(top_k=self.top_k)

        if not self.sel_normalize_by_length:
            logger.warning("These selector losses will not be normalized by length since opt.sel_normalize_by_length=False!")
        if not self.gen_normalize_by_length:
            logger.warning("These generator losses will not be normalized by length since opt.gen_normalize_by_length=False!")
        if not self.incons_normalize_by_length:
            logger.warning("These inconsisitency losses will not be normalized by length since opt.incons_normalize_by_length=False!")
Example #19
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    logger.info('Loading alignment.')
    lemma_aligns = open(model_opt.lemma_align, 'rb').readlines()
    src_stoi = vocab['src'].base_field.vocab.stoi
    lemma_stoi = vocab['word_topic'].base_field.vocab.stoi
    w2l = {}
    word_to_lemma = []
    for pair in lemma_aligns:
        pair = pair.strip().split()
        w2l[src_stoi[pair[0].decode('utf-8')]] = \
            lemma_stoi[pair[1].decode('utf-8')]
    w2l[src_stoi['unk']] = lemma_stoi['unk']
    for index in range(len(vocab['src'].base_field.vocab.itos)):
        if index in w2l:
            word_to_lemma.append(w2l[index])
        else:
            word_to_lemma.append(w2l[lemma_stoi['unk']])
    word_to_lemma = torch.tensor(word_to_lemma)
    logger.info('Loading topic matrix')
    if device_id >= 0:
        topic_matrix = torch.load(opt.topic_matrix,
                                  map_location=torch.device(device_id))
    else:
        topic_matrix = torch.load(opt.topic_matrix)
    if opt.model_dtype == 'fp16':
        topic_matrix = topic_matrix.half()
    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(topic_matrix,
                  word_to_lemma,
                  train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #20
def main(opt,
         fields,
         transforms_cls,
         checkpoint,
         device_id,
         batch_queue=None,
         semaphore=None):
    """Start training on `device_id`."""
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    model_opt = _get_model_opts(opt, checkpoint=checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    model.count_parameters(log=logger.info)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    if batch_queue is None:
        _train_iter = _build_train_iter(opt, fields, transforms_cls)
        train_iter = IterOnDevice(_train_iter, device_id)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                # Move batch to specified device
                IterOnDevice.batch_to_device(batch, device_id)
                yield batch

        train_iter = _train_iter()

    valid_iter = _build_valid_iter(opt, fields, transforms_cls)
    if valid_iter is not None:
        valid_iter = IterOnDevice(valid_iter, device_id)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
Example #21
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #22
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #23
def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
    """
    Args:
        model_opt: the option loaded from checkpoint.
        fields: `Field` objects for the model.
        gpu (bool): whether to use gpu.
        checkpoint: the model generated by the training phase, or a resumed snapshot
                    model from a stopped training.
        gpu_id (int or NoneType): Which GPU to use.
    Returns:
        the NMTModel.
    """

    assert model_opt.model_type in ["text", "img", "audio"], \
        "Unsupported model type %s" % model_opt.model_type

    # for backward compatibility
    if model_opt.rnn_size != -1:
        model_opt.enc_rnn_size = model_opt.rnn_size
        model_opt.dec_rnn_size = model_opt.rnn_size

    # Build embeddings.
    if model_opt.model_type == "text":
        src_fields = [f for n, f in fields['src']]
        assert len(src_fields) == 1
        src_field = src_fields[0]
        src_emb = build_embeddings(model_opt, src_field)
    else:
        src_emb = None

    # Build encoder.
    encoder = build_encoder(model_opt, src_emb)

    # Build decoder.
    tgt_fields = [f for n, f in fields['tgt']]
    assert len(tgt_fields) == 1
    tgt_field = tgt_fields[0]
    tgt_emb = build_embeddings(model_opt, tgt_field, for_encoder=False)

    # Share the embedding matrix - preprocess with share_vocab required.
    if model_opt.share_embeddings:
        # src/tgt vocab should be the same if `-share_vocab` is specified.
        assert src_field.base_field.vocab == tgt_field.base_field.vocab, \
            "preprocess with -share_vocab if you use share_embeddings"

        tgt_emb.word_lut.weight = src_emb.word_lut.weight

    decoder = build_decoder(model_opt, tgt_emb)

    # Build NMTModel(= encoder + decoder).
    if gpu and gpu_id is not None:
        device = torch.device("cuda", gpu_id)
    elif gpu and not gpu_id:
        device = torch.device("cuda")
    elif not gpu:
        device = torch.device("cpu")
    model = onmt.models.NMTModel(encoder, decoder)

    # Build Generator.
    if not model_opt.copy_attn:
        if model_opt.generator_function == "sparsemax":
            gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1)
        else:
            gen_func = nn.LogSoftmax(dim=-1)
        generator = nn.Sequential(
            nn.Linear(model_opt.dec_rnn_size,
                      len(fields["tgt"][0][1].base_field.vocab)),
            Cast(torch.float32), gen_func)
        if model_opt.share_decoder_embeddings:
            generator[0].weight = decoder.embeddings.word_lut.weight
    else:
        assert len(fields["tgt"]) == 1
        tgt_base_field = fields["tgt"][0][1].base_field
        vocab_size = len(tgt_base_field.vocab)
        pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token]
        generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx)

    # Load the model states from checkpoint or initialize them.
    if checkpoint is not None:
        # This preserves backward-compat for models using custom layer norm
        def fix_key(s):
            s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.b_2',
                       r'\1.layer_norm\2.bias', s)
            s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.a_2',
                       r'\1.layer_norm\2.weight', s)
            return s

        checkpoint['model'] = {
            fix_key(k): v
            for k, v in checkpoint['model'].items()
        }
        # end of patch for backward compatibility

        model.load_state_dict(checkpoint['model'], strict=False)
        generator.load_state_dict(checkpoint['generator'], strict=False)
    else:
        if model_opt.param_init != 0.0:
            for p in model.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
            for p in generator.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
        if model_opt.param_init_glorot:
            for p in model.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
            for p in generator.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        if hasattr(model.encoder, 'embeddings'):
            model.encoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_enc)
        if hasattr(model.decoder, 'embeddings'):
            model.decoder.embeddings.load_pretrained_vectors(
                model_opt.pre_word_vecs_dec)

    model.generator = generator
    model.to(device)
    if model_opt.model_dtype == 'fp16':
        logger.warning('FP16 is experimental, the generated checkpoints may '
                       'be incompatible with a future version')
        model.half()

    return model
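The `fix_key` helper above only renames old-style layer-norm parameter names; for instance (illustrative keys):

# fix_key('decoder.transformer_layers.0.layer_norm_1.a_2')
# -> 'decoder.transformer_layers.0.layer_norm_1.weight'
# fix_key('encoder.layer_norm.b_2')
# -> 'encoder.layer_norm.bias'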
Example #24
 def warm_up(self, vocabs):
     self.vocab = vocabs
     if vocabs is None:
         logger.warning(
             "Switchout disable as no vocab, shouldn't happen in training!")
     self.temperature = self.opts.switchout_temperature
Example #25
 def warm_up(self, vocabs):
     super().warm_up(None)
     self.vocabs = vocabs
     if vocabs is None:
         logger.warning(
             "Switchout disable as no vocab, shouldn't happen in training!")
Example #26
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, nontrainable = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('non-trainable parameters (tgt_out_emb): %d' % nontrainable)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
Example #27
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    # save training settings
    if opt.log_file:
        shutil.copy2(opt.config, opt.exp_dir)
    logger.info(vars(opt))

    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        # added by @memray for multiple datasets
        if opt.vocab and opt.vocab != 'none':
            vocab = torch.load(opt.vocab)
        elif opt.encoder_type == 'pretrained':
            vocab = None
        else:
            vocab = None

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train.py line 43
    if opt.model_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(
                    use_vocab=False, dtype=torch.long,
                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    tokenizer = None
    if opt.pretrained_tokenizer:
        tokenizer = load_pretrained_tokenizer(opt.pretrained_tokenizer, opt.cache_dir, opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(fields, opt, tokenizer)


    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            # added by @memray, for loading multiple datasets
            if opt.multi_dataset:
                shard_base = "train"
                train_iter = build_dataset_iter(shard_base, fields, opt, tokenizer=tokenizer)
            else:
                train_shards = []
                for train_id in opt.data_ids:
                    shard_base = "train_" + train_id
                    train_shards.append(shard_base)
                train_iter = build_dataset_iter_multiple(train_shards, fields, opt, tokenizer=tokenizer)
        else:
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    if opt.valid:
        valid_iter = build_dataset_iter(
            "valid", fields, opt, is_train=False)
    else:
        valid_iter = None

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
Example #28
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with opts from
        # the checkpoint. This is useful when re-training a model
        # after adding a new option (not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        data_type = opt.model_type
        fields = load_old_vocab(vocab, data_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        for name, f in fields[side]:
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(name, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    # this line is kind of a temporary kludge because different objects expect
    # fields to have a different structure
    dataset_fields = dict(chain.from_iterable(fields.values()))

    train_iter = build_dataset_iter("train", dataset_fields, opt)
    valid_iter = build_dataset_iter("valid",
                                    dataset_fields,
                                    opt,
                                    is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #29
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.

    configure_process(opt, device_id)
    init_logger(opt.log_file)
    logger.info(opt)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)

        # Re-load the preprocessed fields, then swap in the vocab objects
        # stored in the checkpoint so training resumes with the same vocab.
        load_vocab = torch.load(opt.data + '.vocab.pt')
        vocab = checkpoint['vocab']
        load_vocab['src'].fields[0][1].vocab = vocab['src'].fields[0][1].vocab
        load_vocab['tgt'].fields[0][1].vocab = vocab['tgt'].fields[0][1].vocab
        vocab = load_vocab
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    if opt.pretrain_from:
        # Non-strict loading: only keys that match the model's state dict are
        # copied; mismatched or extra keys are silently ignored.
        check = torch.load(opt.pretrain_from,
                           map_location=lambda storage, loc: storage)
        model.load_state_dict(check['model'], strict=False)
        model.load_state_dict(check['generator'], strict=False)
        if 'dom_classifier' in check:
            model.load_state_dict(check['dom_classifier'], strict=False)

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    translator = None
    if not opt.domain_cls_enc:
        translator = train_build_translator(opt,
                                            model,
                                            model_opt,
                                            fields,
                                            report_score=True)

    trainer = build_trainer(translator,
                            opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
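The `pretrain_from` branch above relies on `load_state_dict(..., strict=False)`, which copies only the keys that match the model's own state dict and reports everything else. A small self-contained sketch of that behaviour (the module and checkpoint keys here are made up for illustration):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
# Pretend checkpoint: one key matches the model, one does not.
pretrained = {"0.weight": torch.zeros(4, 4), "other.weight": torch.zeros(3)}
result = model.load_state_dict(pretrained, strict=False)
print(result.missing_keys)     # model keys that were not supplied
print(result.unexpected_keys)  # supplied keys the model ignored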
Example #30
0
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    load_str = opt.train_from if opt.train_from else opt.load_uncond_from
    if load_str:
        logger.info('Loading checkpoint from %s' % load_str)
        checkpoint = torch.load(load_str,
                                map_location=lambda storage, loc: storage)

        logger.info('Loading vocab from checkpoint at %s.' % load_str)
        vocab = checkpoint['vocab']

        if opt.train_from:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
        else:
            model_opt = opt
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    if opt.gpt2_params_path is not None:
        import tensorflow as tf
        import numpy as np
        # Taken from pytorch-pretrained-BERT:
        # Load weights from TF model
        logger.info("Loading TF GPT weights...")
        init_vars = tf.train.list_variables(opt.gpt2_params_path)
        names = []
        arrays = []
        for name, shape in init_vars:
            if opt.gpt_emb_only and ('wpe' not in name and 'wte' not in name):
                continue
            if opt.gpt_wpe_only and 'wpe' not in name:
                continue
            #print("Loading TF weight {} with shape {}".format(name, shape))
            array = tf.train.load_variable(opt.gpt2_params_path, name)
            names.append(name)
            arrays.append(array.squeeze())
        logger.info("Done.")

        if checkpoint is None:
            checkpoint = {'gpt2_params': zip(names, arrays)}
        else:
            checkpoint['gpt2_params'] = zip(names, arrays)

    if opt.encoder_from is not None:
        logger.info('Loading checkpoint with encoder from %s' %
                    opt.encoder_from)
        enc_checkpoint = torch.load(opt.encoder_from,
                                    map_location=lambda storage, loc: storage)
        enc_vocab = enc_checkpoint['vocab']
        if vocab['src'].base_field.vocab != enc_vocab['src'].base_field.vocab:
            raise ValueError(
                'encoder vocab and model vocab need to be identical when using a pretrained encoder'
            )
        if checkpoint is None:
            checkpoint = {}
        checkpoint['enc_model'] = enc_checkpoint['model']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    sides = ['tgt'] if opt.model_type == 'none' else ['src', 'tgt']
    for side in sides:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, lm_dec = _tally_parameters(model)
    n_params_t, enc_t, dec_t, lm_dec_t = _tally_parameters(model,
                                                           only_trainable=True)
    logger.info('encoder: %d (%d)' % (enc, enc_t))
    logger.info('decoder: %d (%d)' % (dec, dec_t))
    if opt.simple_fusion:
        logger.info('lm decoder: %d (%d)' % (lm_dec, lm_dec_t))

    logger.info('* number of parameters: %d (%d)' % (n_params, n_params_t))
    _check_save_model_path(opt)

    # The synthetic GPT-2 "checkpoint" built above only carries pretrained
    # weights, so don't let the optimizer try to resume training state from it.
    if not opt.train_from and opt.gpt2_params_path is not None:
        checkpoint = None

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
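In the GPT-2 loading branch above, `gpt_emb_only` keeps only the token (`wte`) and position (`wpe`) embedding variables, and `gpt_wpe_only` keeps the position embeddings alone. A toy illustration of those filters, using variable names in the style of the public GPT-2 TF checkpoints (the exact names are an assumption here):

names = ["model/wte", "model/wpe", "model/h0/attn/c_attn/w", "model/ln_f/g"]

emb_only = [n for n in names if "wte" in n or "wpe" in n]
wpe_only = [n for n in names if "wpe" in n]
print(emb_only)  # ['model/wte', 'model/wpe']
print(wpe_only)  # ['model/wpe']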
Example #31
0
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    if opt.local_rank != -1:
        torch.cuda.set_device(opt.local_rank)
        device = torch.device("cuda", opt.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        device_id = opt.local_rank
        world_size = torch.distributed.get_world_size()
    else:
        if device_id == -1:
            device = torch.device("cpu")
        else:
            device = torch.device("cuda", device_id)
    if opt.local_rank > 0:
        logger.disabled = True
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    if opt.bert_kd:
        src_vocab = vocab['src'].fields[0][1].vocab.stoi
        tgt_vocab = vocab['tgt'].fields[0][1].vocab.stoi
        assert 0 < opt.kd_topk <= 128
        train_dataset = BertKdDataset(opt.data_db,
                                      opt.bert_dump,
                                      src_vocab,
                                      tgt_vocab,
                                      max_len=150,
                                      k=opt.kd_topk)
        BUCKET_SIZE = 8192
        # NOTE: the leading ``True or`` forces the single-process sampler;
        # the distributed branch below is effectively disabled (see assert).
        if True or opt.local_rank == -1 and opt.world_size == 1:
            train_sampler = TokenBucketSampler(train_dataset.keys,
                                               BUCKET_SIZE,
                                               opt.batch_size,
                                               batch_multiple=1)
        else:
            assert False  # seems like it's handled in training loop
            train_sampler = DistributedTokenBucketSampler(world_size,
                                                          device_id,
                                                          train_dataset.keys,
                                                          BUCKET_SIZE,
                                                          opt.batch_size,
                                                          batch_multiple=1)
        train_loader = DataLoader(train_dataset,
                                  batch_sampler=train_sampler,
                                  num_workers=4,
                                  collate_fn=BertKdDataset.pad_collate)
        train_iter = cycle_loader(train_loader, device)
    else:
        train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        if trainer.report_manager.tensorboard_writer:
            trainer.report_manager.tensorboard_writer.close()
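`cycle_loader` above is a repo-specific helper whose source is not shown here. A plausible minimal sketch, assuming it simply re-iterates the DataLoader forever and moves tensors onto the target device (the real implementation may differ):

import torch

def cycle_loader_sketch(loader, device):
    """Yield batches from ``loader`` endlessly, moving tensors to ``device``."""
    while True:
        for batch in loader:
            yield tuple(x.to(device) if torch.is_tensor(x) else x
                        for x in batch)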
Example #32
0
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    if opt.teacher_model_path:
        logger.info('Loading teacher model from {path}'.format(
            path=opt.teacher_model_path))
        teacher_model_ckpt = torch.load(
            opt.teacher_model_path, map_location=lambda storage, loc: storage)

        teacher_model_opt = ArgumentParser.ckpt_model_opts(
            teacher_model_ckpt['opt'])
        ArgumentParser.update_model_opts(teacher_model_opt)
        ArgumentParser.validate_model_opts(teacher_model_opt)
        logger.info('Loading vocab from checkpoint at {path}'.format(
            path=opt.teacher_model_path))
        teacher_vocab = teacher_model_ckpt['vocab']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    # Define teacher_fields outside the else-branch so it also exists (as
    # None) when an old-style vocab was loaded.
    teacher_fields = teacher_vocab if opt.teacher_model_path else None

    # patch for fields that may be missing in old data/model
    # patch_fields(opt, fields)

    # Report src and tgt vocab sizes, including for features
    report_vocab_size(fields)
    if teacher_fields is not None:
        report_vocab_size(teacher_fields)

    # Build model.
    fields_opt = {"original": fields, "teacher": teacher_fields}
    model = custom_builder.build_model(model_opt, opt, fields_opt, checkpoint)
    # model = build_model(model_opt, opt, fields, checkpoint)
    teacher_model = build_model(
        teacher_model_opt, teacher_model_opt, teacher_fields,
        teacher_model_ckpt) if opt.teacher_model_path else None

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    if teacher_model is not None:
        n_params, enc, dec = _tally_parameters(teacher_model)
        logger.info('encoder: %d' % enc)
        logger.info('decoder: %d' % dec)
        logger.info('* number of parameters: %d' % n_params)
        _check_save_model_path(teacher_model_opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    # model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    model_saver = custom_model_saver.build_model_saver(model_opt, opt, model,
                                                       fields_opt, optim)

    tgt_field = dict(teacher_fields)["tgt"].base_field if teacher_model is not None \
        else dict(fields)["tgt"].base_field
    sos_id = tgt_field.vocab.stoi[tgt_field.init_token]

    if teacher_model is not None and opt.word_sampling:
        sampler = Emulator(teacher_model,
                           teacher_fields,
                           device_id,
                           max_length=50,
                           random_sampling_topk=5)
    else:
        sampler = None

    if teacher_model is not None:
        trainer = build_trainer(opt,
                                device_id,
                                model,
                                teacher_fields,
                                optim,
                                model_saver,
                                teacher_model=teacher_model,
                                emulator=sampler)
    else:
        trainer = build_trainer(opt,
                                device_id,
                                model,
                                fields,
                                optim,
                                model_saver,
                                teacher_model=teacher_model,
                                emulator=sampler)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter,
                  train_steps,
                  sos_id=sos_id,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
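This last variant consumes batches from a `batch_queue` guarded by a `semaphore` (released once per batch taken) instead of building its own training iterator. The producer side is not part of this example; a hedged sketch of how such a producer could look, with plain-Python stand-ins where the repo likely uses multiprocessing primitives:

import queue
import threading

def produce(batches, batch_queue, semaphore):
    # Block once the semaphore's initial count of batches is already queued;
    # the consumer releases one slot per batch it pulls, bounding memory use.
    for batch in batches:
        semaphore.acquire()
        batch_queue.put(batch)

batch_queue = queue.Queue()
semaphore = threading.Semaphore(4)  # at most 4 batches in flight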