Example #1
def get_fields(vocab):
    if old_style_vocab(vocab):
        return load_old_vocab(vocab,
                              opt.model_type,
                              dynamic_dict=opt.copy_attn)
    else:
        return vocab
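
All of the examples on this page follow the pattern in example #1: load a saved *.vocab.pt file and normalize it to a fields dict. The sketch below is a minimal, self-contained version of that pattern, assuming the OpenNMT-py helpers old_style_vocab and load_old_vocab are importable from onmt.inputters.inputter (adjust to your version); the function name and default arguments are illustrative, not taken from the example.

import torch
from onmt.inputters.inputter import old_style_vocab, load_old_vocab

def load_fields(vocab_path, model_type="text", copy_attn=False):
    """Load a saved vocab/fields file and normalize it to a fields dict."""
    vocab = torch.load(vocab_path)
    if old_style_vocab(vocab):
        # Older preprocessing saved a bare vocab; convert it to fields.
        return load_old_vocab(vocab, model_type, dynamic_dict=copy_attn)
    # Newer preprocessing already saves the fields dict.
    return vocab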
Example #2
    @classmethod
    def load(cls, path, args):
        vocab = torch.load(path)

        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            fields = load_old_vocab(
                vocab, args.model_type, dynamic_dict=args.copy_attn)
        else:
            fields = vocab

        # Report src and tgt vocab sizes, including for features
        for side in ['src', 'tgt']:
            f = fields[side]
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(side, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    print(f'| [{sn}] dictionary: {len(sf.vocab)} types')

        return cls(fields)
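
The vocab-size report in example #2 (repeated in the later examples) relies on one idiom: a side's field may be a multi-feature text field that iterates as (name, field) pairs, or a plain field object, and the iter()/TypeError check keeps a single code path for both. A standalone sketch of that idiom follows; the function name is illustrative.

def report_vocab_sizes(fields, log=print):
    # fields maps 'src'/'tgt' to either a multi-feature field (iterable of
    # (name, field) pairs) or a single field object.
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)        # multi-feature field: yields (name, field)
        except TypeError:
            f_iter = [(side, f)]    # plain field: wrap it to match that shape
        for name, sub_field in f_iter:
            if sub_field.use_vocab:
                log(' * %s vocab size = %d' % (name, len(sub_field.vocab)))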
Example #3
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        #vocab = torch.load(opt.data + '.vocab.pt')

    train_iters = OrderedDict()
    valid_iters = OrderedDict()

    encoders = OrderedDict()
    decoders = OrderedDict()

    generators = OrderedDict()
    src_vocabs = OrderedDict()
    tgt_vocabs = OrderedDict()
    Fields_dict = OrderedDict()

    # variables needed for sharing the same embedding matrix across encoders and decoders
    firstTime = True
    weightToShare = None

    # we share the word embedding space when source lang and target lang are the same
    mapLang2Emb = {}
    #for (src_tgt_lang), data_path in zip(opt.src_tgt, opt.data):
    for index in range(len(opt.src_tgt)):
        src_tgt_lang = opt.src_tgt[index]
        data_path = opt.data[index]
        local_enc_dec_opts = AttrDict({
            key: model_opt.__dict__[key]
            for key in model_opt.__dict__.keys()
        })
        local_enc_dec_opts.model_type = update_to_local_attr(
            model_opt.model_type, index)
        #local_enc_dec_opts.audio_enc_pooling = model_opt.audio_enc_pooling[index]
        local_enc_dec_opts.audio_enc_pooling = update_to_local_attr(
            model_opt.audio_enc_pooling, index)
        local_enc_dec_opts.enc_layers = update_to_local_attr(
            model_opt.enc_layers, index)
        local_enc_dec_opts.dec_layers = update_to_local_attr(
            model_opt.dec_layers, index)
        local_enc_dec_opts.rnn_type = update_to_local_attr(
            model_opt.rnn_type, index)
        local_enc_dec_opts.encoder_type = update_to_local_attr(
            model_opt.encoder_type, index)
        local_enc_dec_opts.batch_size = update_to_local_attr(
            model_opt.batch_size, index)
        local_enc_dec_opts.batch_type = update_to_local_attr(
            model_opt.batch_type, index)
        local_enc_dec_opts.normalization = update_to_local_attr(
            model_opt.normalization, index)
        #local_enc_dec_opts.dec_rnn_size = model_opt.dec_rnn_size[index]

        src_lang, tgt_lang = src_tgt_lang.split('-')

        vocab = torch.load(data_path + '.vocab.pt')

        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            fields = load_old_vocab(vocab,
                                    opt.model_type[0],
                                    dynamic_dict=opt.copy_attn)
        else:
            fields = vocab

        # Report src and tgt vocab sizes, including for features
        for side in ['src', 'tgt']:
            f = fields[side]
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(side, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

        # Build model.
        encoder, src_embeddings = build_embeddings_then_encoder(
            local_enc_dec_opts, fields)

        encoders[src_lang] = encoder

        decoder, generator, tgt_embeddings = build_decoder_and_generator(
            local_enc_dec_opts, fields)

        decoders[tgt_lang] = decoder

        # Share the embedding matrix across all the encoders and decoders - preprocess with share_vocab required.
        if model_opt.share_embeddings and firstTime:
            tgt_embeddings.word_lut.weight = src_embeddings.word_lut.weight
            weightToShare = src_embeddings.word_lut.weight
        if model_opt.share_embeddings and (not firstTime):
            tgt_embeddings.word_lut.weight = weightToShare
            src_embeddings.word_lut.weight = weightToShare
        firstTime = False

        #TEST
        #if src_lang in mapLang2Emb:
        if src_lang in mapLang2Emb and model_opt.model_type == "text":
            encoder.embeddings.word_lut.weight = mapLang2Emb.get(src_lang)
        #TEST
        #else:
        elif model_opt.model_type == "text":
            mapLang2Emb[src_lang] = src_embeddings.word_lut.weight
        if tgt_lang in mapLang2Emb:
            decoder.embeddings.word_lut.weight = mapLang2Emb.get(tgt_lang)
        else:
            mapLang2Emb[tgt_lang] = tgt_embeddings.word_lut.weight

        #TEST
        if model_opt.model_type == "text":
            src_vocabs[src_lang] = fields['src'].base_field.vocab
        tgt_vocabs[tgt_lang] = fields['tgt'].base_field.vocab

        generators[tgt_lang] = generator

        # add this dataset iterator to the training iterators
        train_iters[(src_lang, tgt_lang)] = build_dataset_iter_fct(
            'train', fields, data_path, local_enc_dec_opts)
        # add this dataset iterator to the validation iterators
        valid_iters[(src_lang,
                     tgt_lang)] = build_dataset_iter_fct('valid',
                                                         fields,
                                                         data_path,
                                                         local_enc_dec_opts,
                                                         is_train=False)

        Fields_dict[src_tgt_lang] = fields

    # Build model.
    model = build_model(model_opt, opt, fields, encoders, decoders, generators,
                        src_vocabs, tgt_vocabs, checkpoint)

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, Fields_dict, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            generators,
                            tgt_vocabs,
                            model_saver=model_saver)

    # TODO: not implemented yet
    #train_iterables = []
    #if len(opt.data_ids) > 1:
    #    for train_id in opt.data_ids:
    #        shard_base = "train_" + train_id
    #        iterable = build_dataset_iter(shard_base, fields, opt, multi=True)
    #        train_iterables.append(iterable)
    #    train_iter = MultipleDatasetIterator(train_iterables, device_id, opt)
    #else:
    #    train_iter = build_dataset_iter("train", fields, opt)

    #valid_iter = build_dataset_iter(
    #    "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iters, train_steps, opt.save_checkpoint_steps,
                  valid_iters, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
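
Example #3 builds one encoder per source language and one decoder per target language, and mapLang2Emb makes every model that handles a given language reuse the same word-embedding tensor. The sketch below isolates that sharing logic; the names lang2emb and share_embedding are illustrative, not from the example.

lang2emb = {}  # language code -> shared word-embedding weight tensor

def share_embedding(lang, embeddings):
    if lang in lang2emb:
        # A model for this language was already built: point this embedding
        # at the same parameter tensor so the weights are shared.
        embeddings.word_lut.weight = lang2emb[lang]
    else:
        # First model seen for this language: remember its weights for later.
        lang2emb[lang] = embeddings.word_lut.weight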
Example #4
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
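
Examples #3 through #6 all resume from checkpoints with map_location=lambda storage, loc: storage. That lambda keeps every tensor on CPU regardless of the device it was saved from, so the checkpoint can be loaded and the model rebuilt before anything is moved to the training device. A minimal illustration (the checkpoint path is hypothetical):

import torch

# Load all tensors onto CPU first; the trainer moves the model afterwards.
checkpoint = torch.load('model_step_1000.pt',  # hypothetical path
                        map_location=lambda storage, loc: storage)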
Example #5
def main(opt,
         device_id,
         batch_queue=None,
         semaphore=None,
         train_iter=None,
         passed_fields=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        if opt.use_opt_from_trained:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        else:
            model_opt = opt
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    aux_fields = None
    if passed_fields is not None:
        fields = passed_fields['main']
        aux_fields = passed_fields['crosslingual']
    elif old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt,
                        opt,
                        fields,
                        checkpoint,
                        aux_fields=aux_fields)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    if opt.almt_only:
        almt = model.encoder.embeddings.almt_layers['mapping']
        logger.info('Only training the alignment mapping.')
        optim = Optimizer.from_opt(almt, opt, checkpoint=checkpoint)
    else:
        optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt,
                                    opt,
                                    model,
                                    fields,
                                    optim,
                                    aux_fields=aux_fields)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver,
                            aux_fields=aux_fields)

    if train_iter is not None:
        pass  # NOTE Use the passed one.
    elif batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    cl_valid_iter = None
    if opt.crosslingual:
        valid_iter = build_dataset_iter("valid",
                                        fields,
                                        opt,
                                        is_train=False,
                                        task_cls=Eat2PlainMonoTask)
        if opt.crosslingual_dev_data:
            # NOTE I used 'train' to prepare this in `eat_prepare.sh`, so I use 'train' here as well.
            cl_valid_iter = build_dataset_iter(
                'train',
                fields,
                opt,
                is_train=False,
                data_attr='crosslingual_dev_data',
                task_cls=Eat2PlainCrosslingualTask)
        # NOTE This is for the second eat->plain task.
        aux_valid_iter = build_dataset_iter('valid',
                                            fields,
                                            opt,
                                            is_train=False,
                                            data_attr='aux_train_data',
                                            task_cls=Eat2PlainMonoTask)
        valid_iters = [valid_iter, aux_valid_iter]
    else:
        valid_iters = [
            build_dataset_iter("valid", fields, opt, is_train=False)
        ]

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iters=valid_iters,
                  valid_steps=opt.valid_steps,
                  cl_valid_iter=cl_valid_iter)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
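
Example #5 can also be fed from a batch_queue filled by a separate producer process; its _train_iter generator is reproduced below as a standalone sketch. Releasing the semaphore after each get() tells the producer it may enqueue another batch, which bounds how far ahead it can run.

def queue_train_iter(batch_queue, semaphore):
    # batch_queue: a multiprocessing queue of prepared training batches.
    # semaphore:   bounds the number of batches in flight at once.
    while True:
        batch = batch_queue.get()   # block until the producer sends a batch
        semaphore.release()         # free a slot so the producer can continue
        yield batch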
Example #6
def main(opt, device_id):
    import pickle
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    import json
    train_iters = []
    with open(opt.data) as json_file:
        data = json.load(json_file)
        vocab = data["vocab"]
        vocab2 = vocab

    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt

        vocab = torch.load(vocab)
    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    for key, value in data.items():
        if key == ("vocab"):
            continue
        elif key.startswith("valid"):
            valid_iter = (key.split("valid-")[1].split("-"),
                          build_dataset_iter(value,
                                             fields,
                                             opt,
                                             is_train=False))
        else:
            train_iters.append(
                (key.split("-"), build_dataset_iter(value, fields, opt)))
    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)

    model.critic = critic()
    model.critic.to(model.device)

    if model.decoder2 is not None:
        model.critic2 = critic()
        model.critic2.to(model.device)
    else:
        model.critic2 = None

    model.critic3 = None

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    vocab = torch.load(vocab2)

    #print (valid_iter is None)
    #valid_iter = valid_iter(datas[0][0],valid_iter)
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iters,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps,
                  smooth=opt.smooth)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
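
Example #6 drives its iterators from a JSON config instead of reading opt.data directly: a "vocab" entry names the saved vocab file, keys starting with "valid-" define validation sets, and every other "src-tgt" key defines a training corpus for that language pair. A hypothetical config in that shape (all paths and language codes are illustrative):

import json

data_config = {
    "vocab": "demo.vocab.pt",     # loaded with torch.load and turned into fields
    "valid-en-de": "demo.valid",  # validation iterator for the en-de pair
    "en-de": "demo.train",        # training iterator for the en-de pair
}
with open("data.json", "w") as f:
    json.dump(data_config, f, indent=2)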