Example #1
def main(opt, device_id, data):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    checkpoint = None

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt', 'tgt_label']:
        logger.info(' * %s vocab size = %d' % (side, len(data["dict"][side])))

    # Build model.
    model = build_model(opt, data, checkpoint, device_id)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(opt, opt, model, data, optim)

    trainer = build_trainer(
        opt, device_id, model, data, optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", data, opt)
    valid_iter = build_dataset_iter(
        "valid", data, opt, is_train=False)

    if opt.gpu:
        logger.info('Starting training on GPU: %s' % opt.gpu)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
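Unlike the later examples, this variant receives a prebuilt data argument instead of loading a vocab file. A hypothetical minimal shape, inferred only from the accesses above (the real structure certainly carries more entries):

# Hypothetical shape of ``data`` for Example #1 (inferred, not from the source).
src_vocab = {"<unk>": 0, "<pad>": 1}       # any mapping that supports len()
tgt_vocab = {"<unk>": 0, "<pad>": 1}
tgt_label_vocab = {"O": 0, "I": 1}

data = {
    "dict": {
        "src": src_vocab,
        "tgt": tgt_vocab,
        "tgt_label": tgt_label_vocab,
    },
    # ... plus whatever build_model and build_dataset_iter expect
}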
Example #2
 def monitor_iter_fct():
     monitor_data = dict()
     for src, tgt in zip(opt.monitor_src, opt.monitor_tgt):
         fname = src.split("/" if "/" in src else "\\")[-1].split(
             ".")[0].replace("_src", "")
         monitor_data[fname] = build_dataset_iter(
             lazily_load_dataset("monitor", opt, fname),
             fields, opt, is_train=False)
     return monitor_data
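The manual split on "/" or "\\" mirrors what os.path already does. A sketch of an equivalent helper (assuming paths use the platform's separator and single-extension filenames; note the original splits on the first dot, not the last):

import os

def monitor_name(src_path):
    # "data/foo_src.txt" -> "foo": strip directory, extension, and "_src".
    base = os.path.splitext(os.path.basename(src_path))[0]
    return base.replace("_src", "")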
Example #3
 def get_speech_iterator(self, name, lang1, lang2, is_train=True): # name: train or valid
     """
     Create a new iterator for a dataset.
     """
     key = ','.join([x for x in ['speech', name, lang1, lang2] if x is not None])
     logger.info("Creating new training %s iterator ..." % key)
     speech_direction = (lang1, lang2)
     iterator = build_dataset_iter(
         lazily_load_dataset(name, self.params.speech_dataset[speech_direction][0]),
                                     self.speech_fields[speech_direction], self.params, is_train)
     iterator = iter(iterator)
     self.iterators[key] = iterator
     return iterator
Example #4
    def get_speech_iterator(self, data_type, lang1, lang2):
        """
        Create a new iterator for a dataset.
        """
        assert data_type in ['valid']
        speech_direction = (lang1, lang2)
        lang2_id = self.params.lang2id[lang2]

        iterator = build_dataset_iter(
            lazily_load_dataset(data_type, self.params.speech_dataset[speech_direction][0]),
            self.params.speech_fields[speech_direction], self.params, False)
        iterator = iter(iterator)
        bos_index = self.params.bos_index[lang2_id]
        for batch in iterator:
            # Force the first target token to the BOS index of lang2.
            batch.tgt[0] = bos_index
            yield batch
Example #5
 def train_iter_fct():
     return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                               opt)
Example #6
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #7
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, nontrainable = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('non-trainable parameters (tgt_out_emb): %d' % nontrainable)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
Example #8
def train(opt):
    ArgumentParser.validate_train_opts(opt)

    set_random_seed(opt.seed, False)

    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    patch_fields(opt, fields)

    if len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes
        # (as in Example #19); without these, error_queue and
        # error_handler below are undefined.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)

        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:
        single_main(opt, 0)
    else:
        single_main(opt, -1)
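train() hands train_iter to a batch_producer process whose body is not shown in this example. A minimal sketch of what it might do, given the (train_iter, queues, semaphore, opt) signature used above; the round-robin fan-out is an assumption:

def batch_producer(train_iter, queues, semaphore, opt):
    # Fan batches out to the per-GPU queues. The semaphore, released by the
    # consumers' _train_iter (see Example #7), bounds how far the producer
    # can run ahead of training.
    for i, batch in enumerate(train_iter):
        semaphore.acquire()
        queues[i % len(queues)].put(batch)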
Example #9
def main(opt, device_id):
    # opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        raise Exception('You need to load a model')

    logger.info('Loading data from %s' % opt.data)
    dataset = next(lazily_load_dataset("train", opt))
    data_type = dataset.data_type
    logger.info('Data type %s' % data_type)

    # Load fields generated from preprocess phase.
    fields = _load_fields(dataset, data_type, opt, checkpoint)
    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    dataset_iter = build_dataset_iter(lazily_load_dataset("train", opt),
                                      fields, opt)
    out_file = codecs.open(opt.output, 'w+', 'utf-8')
    scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                             opt.coverage_penalty,
                                             opt.length_penalty)

    translation_builder = TranslationBuilder(dataset,
                                             fields,
                                             n_best=opt.n_best,
                                             replace_unk=opt.replace_unk,
                                             has_tgt=False)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    translator = Translator(trainer.model,
                            fields,
                            opt.beam_size,
                            global_scorer=scorer,
                            out_file=out_file,
                            report_score=False,
                            copy_attn=model_opt.copy_attn,
                            logger=logger)

    for i, batch in enumerate(dataset_iter):
        unprocessed_translations = translator.translate_batch(batch, dataset)
        translations = translation_builder.from_batch(unprocessed_translations)
        print "Translations: ", ' '.join(translations[0].pred_sents[0])
        trainer.train_from_data(batch, train_steps=1)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #10
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    if opt.teacher_model_path:
        logger.info('Loading teacher model from {path}'.format(
            path=opt.teacher_model_path))
        teacher_model_ckpt = torch.load(
            opt.teacher_model_path, map_location=lambda storage, loc: storage)

        teacher_model_opt = ArgumentParser.ckpt_model_opts(
            teacher_model_ckpt['opt'])
        ArgumentParser.update_model_opts(teacher_model_opt)
        ArgumentParser.validate_model_opts(teacher_model_opt)
        logger.info('Loading vocab from checkpoint at {path}'.format(
            path=opt.teacher_model_path))
        teacher_vocab = teacher_model_ckpt['vocab']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    # Define teacher_fields on both paths; otherwise the report below
    # raises NameError when an old-style vocab is loaded.
    teacher_fields = teacher_vocab if opt.teacher_model_path else None

    # patch for fields that may be missing in old data/model
    # patch_fields(opt, fields)

    # Report src and tgt vocab sizes, including for features
    report_vocab_size(fields)
    if teacher_fields is not None:
        report_vocab_size(teacher_fields)

    # Build model.
    fields_opt = {"original": fields, "teacher": teacher_fields}
    model = custom_builder.build_model(model_opt, opt, fields_opt, checkpoint)
    # model = build_model(model_opt, opt, fields, checkpoint)
    teacher_model = build_model(
        teacher_model_opt, teacher_model_opt, teacher_fields,
        teacher_model_ckpt) if opt.teacher_model_path else None

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    if teacher_model is not None:
        n_params, enc, dec = _tally_parameters(teacher_model)
        logger.info('encoder: %d' % enc)
        logger.info('decoder: %d' % dec)
        logger.info('* number of parameters: %d' % n_params)
        _check_save_model_path(teacher_model_opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    # model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    model_saver = custom_model_saver.build_model_saver(model_opt, opt, model,
                                                       fields_opt, optim)

    tgt_field = dict(teacher_fields)["tgt"].base_field if teacher_model is not None \
        else dict(fields)["tgt"].base_field
    sos_id = tgt_field.vocab.stoi[tgt_field.init_token]

    if teacher_model is not None and opt.word_sampling:
        sampler = Emulator(teacher_model,
                           teacher_fields,
                           device_id,
                           max_length=50,
                           random_sampling_topk=5)
    else:
        sampler = None

    if teacher_model is not None:
        trainer = build_trainer(opt,
                                device_id,
                                model,
                                teacher_fields,
                                optim,
                                model_saver,
                                teacher_model=teacher_model,
                                emulator=sampler)
    else:
        trainer = build_trainer(opt,
                                device_id,
                                model,
                                fields,
                                optim,
                                model_saver,
                                teacher_model=teacher_model,
                                emulator=sampler)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter,
                  train_steps,
                  sos_id=sos_id,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
Example #11
 def meta_valid_iter_fct(task_id, is_log=False):
     return build_dataset_iter(
         lazily_load_dataset("meta_valid", opt, task_id=task_id, is_log=is_log), fields_list[task_id], opt)
Example #12
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful for re-training a model after a
        # new option has been added (and so is not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        data_type = opt.model_type
        fields = load_old_vocab(vocab, data_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        for name, f in fields[side]:
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(name, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    # this line is kind of a temporary kludge because different objects expect
    # fields to have a different structure
    dataset_fields = dict(chain.from_iterable(fields.values()))

    train_iter = build_dataset_iter("train", dataset_fields, opt)
    valid_iter = build_dataset_iter("valid",
                                    dataset_fields,
                                    opt,
                                    is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #13
def validate(opt, device_id=0):
    configure_process(opt, device_id)
    # NOTE: this entry point assumes ``opt.train_from`` is set; fields,
    # model_opt and checkpoint below are only defined in that branch.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']

        if old_style_vocab(vocab):
            fields = load_old_vocab(vocab,
                                    opt.model_type,
                                    dynamic_dict=opt.copy_attn)
        else:
            fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    #_check_save_model_path(opt)

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    tgt_field = dict(fields)["tgt"].base_field
    valid_loss = onmt.utils.loss.build_loss_compute(model,
                                                    tgt_field,
                                                    opt,
                                                    train=False)

    model.eval()

    with torch.no_grad():
        stats = onmt.utils.Statistics()

        for batch in valid_iter:

            src, src_lengths = batch.src if isinstance(batch.src, tuple) \
                                   else (batch.src, None)
            tgt = batch.tgt

            # F-prop through the model.
            outputs, attns = model(src, tgt, src_lengths)

            # Compute loss.
            _, batch_stats = valid_loss(batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

    print('n words:  %d' % stats.n_words)
    print('Validation perplexity: %g' % stats.ppl())
    print('Validation accuracy: %g' % stats.accuracy())
    print('Validation avg attention entropy: %g' % stats.attn_entropy())
Example #14
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    # save training settings
    if opt.log_file:
        shutil.copy2(opt.config, opt.exp_dir)
    logger.info(vars(opt))

    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        # added by @memray for multiple datasets
        if opt.vocab and opt.vocab != 'none':
            vocab = torch.load(opt.vocab)
        else:
            # No separate vocab file (e.g. a pretrained encoder brings
            # its own); both original branches set None.
            vocab = None

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # @memray: a temporary workaround, as well as train.py line 43
    if opt.model_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(
                    use_vocab=False, dtype=torch.long,
                    postprocessing=make_tgt, sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    tokenizer = None
    if opt.pretrained_tokenizer:
        tokenizer = load_pretrained_tokenizer(opt.pretrained_tokenizer, opt.cache_dir, opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(fields, opt, tokenizer)

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            # added by @memray, for loading multiple datasets
            if opt.multi_dataset:
                shard_base = "train"
                train_iter = build_dataset_iter(shard_base, fields, opt, tokenizer=tokenizer)
            else:
                train_shards = []
                for train_id in opt.data_ids:
                    shard_base = "train_" + train_id
                    train_shards.append(shard_base)
                train_iter = build_dataset_iter_multiple(train_shards, fields, opt, tokenizer=tokenizer)
        else:
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    if opt.valid:
        valid_iter = build_dataset_iter(
            "valid", fields, opt, is_train=False)
    else:
        valid_iter = None

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
Example #15
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #16
 def valid_iter_fct(task_id):
     return build_dataset_iter(
         lazily_load_dataset("valid", opt, task_id=task_id), fields_list[task_id], opt)
Example #17
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful for re-training a model after a
        # new option has been added (and so is not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # Load a shard dataset to determine the data_type.
    # (All datasets have the same data_type).
    # this should be refactored out of existence reasonably soon
    first_dataset = torch.load(glob.glob(opt.data + '.train*.pt')[0])
    data_type = first_dataset.data_type

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_fields_from_vocab(vocab, data_type)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        for name, f in fields[side]:
            if f.use_vocab:
                logger.info(' * %s vocab size = %d' % (name, len(f.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    # this line is kind of a temporary kludge because different objects expect
    # fields to have a different structure
    dataset_fields = dict(chain.from_iterable(fields.values()))

    train_iter = build_dataset_iter("train", dataset_fields, opt)
    valid_iter = build_dataset_iter("valid",
                                    dataset_fields,
                                    opt,
                                    is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter, valid_iter, opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #18
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Gather information related to the training script and commit version
    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(os.path.dirname(script_path))
    logger.info('Train script dir: %s' % script_dir)
    git_commit = str(subprocess.check_output(['bash', script_dir + '/cluster_scripts/git_version.sh']))
    logger.info("Git Commit: %s" % git_commit[2:-3])
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        # TODO: load MTL model
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful for re-training a model after a
        # new option has been added (and so is not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt


    num_tasks = len(opt.data.split(','))
    opt.num_tasks = num_tasks

    checkpoint_list = []
    if opt.warm_model:
        base_name = opt.warm_model
        for task_id in range(num_tasks):
            chkpt_path = base_name.replace("X", str(task_id))
            if not os.path.isfile(chkpt_path):
                chkpt_path = base_name.replace("X", str(0))
            logger.info('Loading a checkpoint from %s' % chkpt_path)

            checkpoint = torch.load(chkpt_path,
                                    map_location=lambda storage, loc: storage)
            checkpoint_list.append(checkpoint)
    else:
        for task_id in range(num_tasks):
            checkpoint_list.append(None)

    fields_list = []
    data_type = None
    for task_id in range(num_tasks):
        # Peek the first dataset to determine the data_type.
        # (All datasets have the same data_type).
        first_dataset = next(lazily_load_dataset("train", opt, task_id=task_id))
        data_type = first_dataset.data_type

        # Load fields generated from preprocess phase.
        if opt.mtl_shared_vocab and task_id > 0:
            logger.info(' * vocabulary size. Same as the main task!')
            fields = fields_list[0]
        else:
            fields = load_fields(first_dataset, opt, checkpoint_list[task_id], task_id=task_id)

        # Report src/tgt features.

        src_features, tgt_features = _collect_report_features(fields)
        for j, feat in enumerate(src_features):
            logger.info(' * (Task %d) src feature %d size = %d'
                        % (task_id, j, len(fields[feat].vocab)))
        for j, feat in enumerate(tgt_features):
            logger.info(' * (Task %d) tgt feature %d size = %d'
                        % (task_id, j, len(fields[feat].vocab)))
        fields_list.append(fields)

    if opt.epochs > -1:
        total_num_batch = 0
        for task_id in range(num_tasks):
            train_iter = build_dataset_iter(lazily_load_dataset("train", opt, task_id=task_id), fields_list[task_id], opt)
            # num_batch ends up as the index of the last batch, i.e. one
            # less than the actual batch count.
            for i, batch in enumerate(train_iter):
                num_batch = i
            total_num_batch += num_batch
            if opt.mtl_schedule < 10:
                break
        num_batch = total_num_batch
        opt.train_steps = (num_batch * opt.epochs) + 1
        # Do the validation and save after each epoch
        opt.valid_steps = num_batch
        opt.save_checkpoint_steps = 1

    # logger.info(opt_to_string(opt))
    logger.info(opt)

    # Build model(s).
    models_list = []
    for task_id in range(num_tasks):
        if opt.mtl_fully_share and task_id > 0:
            # Since we only have one model, copy the pointer to the model for all
            models_list.append(models_list[0])
        else:

            main_model = models_list[0] if task_id > 0 else None
            model = build_model(model_opt, opt, fields_list[task_id], checkpoint_list[task_id], main_model=main_model, task_id=task_id)
            n_params, enc, dec = _tally_parameters(model)
            logger.info('(Task %d) encoder: %d' % (task_id, enc))
            logger.info('(Task %d) decoder: %d' % (task_id, dec))
            logger.info('* number of parameters: %d' % n_params)
            _check_save_model_path(opt)
            models_list.append(model)

    # combine parameters of different models and consider shared parameters just once.
    def combine_named_parameters(named_params_list):
        observed_params = []
        for model_named_params in named_params_list:
            for name, p in model_named_params:
                is_observed = False
                # Check whether we observed this parameter before
                for param in observed_params:
                    if p is param:
                        is_observed = True
                        break
                if not is_observed:
                    observed_params.append(p)
                    yield name, p

    # Build optimizer.
    optims_list = []
    all_models_params = []
    for task_id in range(num_tasks):
        if not opt.mtl_shared_optimizer:
            optim = build_optim(models_list[task_id], opt, checkpoint)
            optims_list.append(optim)
        else:
            all_models_params.append(models_list[task_id].named_parameters())

    # Extract the list of shared parameters among the models of all tasks.
    observed_params = []
    shared_params = []
    for task_id in range(num_tasks):
        for name, p in models_list[task_id].named_parameters():
            is_observed = False
            # Check whether we observed this parameter before
            for param in observed_params:
                if p is param:
                    shared_params.append(name)
                    is_observed = True
                    break
            if not is_observed:
                observed_params.append(p)
    opt.shared_params = shared_params

    if opt.mtl_shared_optimizer:
        optim = build_optim_mtl_params(combine_named_parameters(all_models_params), opt, checkpoint)
        optims_list.append(optim)

    # Build model saver
    model_saver = build_mtl_model_saver(model_opt, opt, models_list, fields_list, optims_list)

    trainer = build_trainer(opt, device_id, models_list, fields_list,
                            optims_list, data_type, model_saver=model_saver)

    def train_iter_fct(task_id):
        return build_dataset_iter(
            lazily_load_dataset("train", opt, task_id=task_id), fields_list[task_id], opt)

    def valid_iter_fct(task_id):
        return build_dataset_iter(
            lazily_load_dataset("valid", opt, task_id=task_id), fields_list[task_id], opt)

    def meta_valid_iter_fct(task_id, is_log=False):
        return build_dataset_iter(
            lazily_load_dataset("meta_valid", opt, task_id=task_id, is_log=is_log), fields_list[task_id], opt)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps, meta_valid_iter_fct=meta_valid_iter_fct)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
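combine_named_parameters above deduplicates shared parameters with a linear identity scan over observed_params, which is quadratic in the number of parameters. A set of id(p) values gives the same result in linear time (a sketch, not from the source):

def combine_named_parameters(named_params_list):
    # Track ids of parameters already yielded; ids are stable here because
    # every parameter stays referenced by its model.
    seen = set()
    for model_named_params in named_params_list:
        for name, p in model_named_params:
            if id(p) not in seen:
                seen.add(id(p))
                yield name, p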
Example #19
def train(opt):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # @Memray, check the dir existence beforehand to avoid path conflicting errors,
    #   and set save_model, tensorboard_log_dir, wandb_log_dir if not exist
    train_single._check_save_model_path(opt)
    if not os.path.exists(opt.tensorboard_log_dir):
        os.makedirs(opt.tensorboard_log_dir)

    # Scan previous checkpoint to resume training
    latest_step = 0
    latest_ckpt = None
    for subdir, dirs, filenames in os.walk(opt.exp_dir):
        for filename in sorted(filenames):
            if not filename.endswith('.pt'):
                continue
            step = int(filename[filename.rfind('_') + 1:filename.rfind('.pt')])
            if step > latest_step:
                latest_ckpt = os.path.join(subdir, filename)
                latest_step = step
    # if not saved in the exp folder, check opt.save_model
    if latest_ckpt is None and opt.save_model is not None:
        save_model_dir = os.path.dirname(os.path.abspath(opt.save_model))
        model_prefix = opt.save_model[opt.save_model.rfind(os.path.sep) + 1:]
        for subdir, dirs, filenames in os.walk(save_model_dir):
            for filename in sorted(filenames):
                if not filename.endswith('.pt'):
                    continue
                if not filename.startswith(model_prefix):
                    continue
                step = int(filename[filename.rfind('_') +
                                    1:filename.rfind('.pt')])
                if step > latest_step:
                    latest_ckpt = os.path.join(subdir, filename)
                    latest_step = step
    if latest_ckpt is not None:
        logger.info("A previous checkpoint is found, train from it: %s" %
                    latest_ckpt)
        setattr(opt, 'train_from', latest_ckpt)
        setattr(opt, 'reset_optim', 'none')

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    elif opt.vocab and opt.vocab != 'none':
        # added by @memray for multiple datasets
        vocab = torch.load(opt.vocab)
        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            vocab = load_old_vocab(vocab,
                                   opt.model_type,
                                   dynamic_dict=opt.copy_attn)
    else:
        # No separate vocab file (e.g. a pretrained encoder brings its
        # own); both original branches set None.
        vocab = None

    fields = vocab

    # @memray: a temporary workaround, as well as train_single.py line 78
    if fields and opt.data_type == "keyphrase":
        if opt.tgt_type in ["one2one", "multiple"]:
            if 'sep_indices' in fields:
                del fields['sep_indices']
        else:
            if 'sep_indices' not in fields:
                sep_indices = Field(use_vocab=False,
                                    dtype=torch.long,
                                    postprocessing=make_tgt,
                                    sequential=False)
                fields["sep_indices"] = sep_indices
        if 'src_ex_vocab' not in fields:
            src_ex_vocab = RawField()
            fields["src_ex_vocab"] = src_ex_vocab

    # @memray reload fields for news dataset and pretrained models
    tokenizer = None
    if opt.pretrained_tokenizer is not None:
        tokenizer = load_pretrained_tokenizer(opt.pretrained_tokenizer,
                                              opt.cache_dir,
                                              opt.special_vocab_path)
        setattr(opt, 'vocab_size', len(tokenizer))
    if opt.data_type == 'news':
        fields = reload_news_fields(opt, tokenizer=tokenizer)
    # elif opt.data_type == 'keyphrase':
    #     fields = reload_keyphrase_fields(opt, tokenizer=tokenizer)

    if len(opt.data_ids) > 1:
        # added by @memray, for loading multiple datasets
        if opt.multi_dataset:
            shard_base = "train"
            train_iter = build_dataset_iter(shard_base,
                                            fields,
                                            opt,
                                            multi=True)
        else:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
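The two hand-rolled os.walk scans above can be collapsed with glob. A compact sketch (assuming every *.pt file under the root follows the <prefix>_<step>.pt naming; files like vocab.pt would need filtering first):

import glob
import os

def find_latest_checkpoint(root):
    ckpts = glob.glob(os.path.join(root, '**', '*.pt'), recursive=True)

    def step_of(path):
        name = os.path.basename(path)
        return int(name[name.rfind('_') + 1:name.rfind('.pt')])

    return max(ckpts, key=step_of) if ckpts else None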
Example #20
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    if opt.local_rank != -1:
        torch.cuda.set_device(opt.local_rank)
        device = torch.device("cuda", opt.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        device_id = opt.local_rank
        world_size = torch.distributed.get_world_size()
    else:
        if device_id == -1:
            device = torch.device("cpu")
        else:
            device = torch.device("cuda", device_id)
    if opt.local_rank > 0:
        logger.disabled = True
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    if opt.bert_kd:
        src_vocab = vocab['src'].fields[0][1].vocab.stoi
        tgt_vocab = vocab['tgt'].fields[0][1].vocab.stoi
        assert 0 < opt.kd_topk <= 128
        train_dataset = BertKdDataset(opt.data_db,
                                      opt.bert_dump,
                                      src_vocab,
                                      tgt_vocab,
                                      max_len=150,
                                      k=opt.kd_topk)
        BUCKET_SIZE = 8192
        # NOTE: `True or ...` short-circuits, so the single-process sampler is
        # always chosen; the distributed branch below is unreachable.
        if True or (opt.local_rank == -1 and opt.world_size == 1):
            train_sampler = TokenBucketSampler(train_dataset.keys,
                                               BUCKET_SIZE,
                                               opt.batch_size,
                                               batch_multiple=1)
        else:
            # dead code: distributed batching seems to be handled in the
            # training loop instead
            assert False
            train_sampler = DistributedTokenBucketSampler(world_size,
                                                          device_id,
                                                          train_dataset.keys,
                                                          BUCKET_SIZE,
                                                          opt.batch_size,
                                                          batch_multiple=1)
        train_loader = DataLoader(train_dataset,
                                  batch_sampler=train_sampler,
                                  num_workers=4,
                                  collate_fn=BertKdDataset.pad_collate)
        train_iter = cycle_loader(train_loader, device)
    else:
        train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        if trainer.report_manager.tensorboard_writer:
            trainer.report_manager.tensorboard_writer.close()
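
TokenBucketSampler above batches by a token budget rather than a fixed sentence count. The sketch below illustrates the general idea with invented names (token_bucket_batches, max_tokens); it is not the sampler's actual implementation:

def token_bucket_batches(lengths, max_tokens):
    # Visit examples in length order so each batch groups similar sizes,
    # and cut a batch whenever padding everything to the longest member
    # would exceed the token budget.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batch, batch_max = [], 0
    for i in order:
        batch_max = max(batch_max, lengths[i])
        if batch and batch_max * (len(batch) + 1) > max_tokens:
            yield batch
            batch, batch_max = [], lengths[i]
        batch.append(i)
    if batch:
        yield batch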
Example #21
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    #     import pdb
    #     _check_ = torch.load("/home/irteam/users/kaist/ginalee/clean_data/baselines/9-domain5-185pre_step_2500.pt")
    #     model_encoder = [i for i in _check_['model'].keys() if "encoder" in i.split(".")]
    #     encoder = {}
    #     pdb.set_trace()
    #     for i, param in enumerate(model_encoder):
    #         if i == 0:
    #             encoder['embeddings.word_embeddings.weight'] = _check_['model'][param]
    #             continue
    #         param_ = ".".join(param.split(".")[1:])
    # #         if param.split(".")[1] == 'encoder':
    # #             param_ = ".".join(param.split(".")[2:])
    # #         else:
    # #             param_ = ".".join(param.split(".")[1:])
    #         encoder[param_] = _check_['model'][param]
    #     pdb.set_trace()

    configure_process(opt, device_id)
    init_logger(opt.log_file)
    logger.info(opt)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)

        load_vocab = torch.load(opt.data + '.vocab.pt')
        vocab = checkpoint['vocab']
        load_vocab['src'].fields[0][1].vocab = vocab['src'].fields[0][1].vocab
        load_vocab['tgt'].fields[0][1].vocab = vocab['tgt'].fields[0][1].vocab
        vocab = load_vocab
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    if opt.pretrain_from:
        check = torch.load(opt.pretrain_from,
                           map_location=lambda storage, loc: storage)
        model.load_state_dict(check['model'], strict=False)
        model.load_state_dict(check['generator'], strict=False)
        if 'dom_classifier' in check:
            model.load_state_dict(check['dom_classifier'], strict=False)

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    translator = None
    if not opt.domain_cls_enc:
        translator = train_build_translator(opt,
                                            model,
                                            model_opt,
                                            fields,
                                            report_score=True)

    trainer = build_trainer(translator,
                            opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
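
A note on the strict=False loads in this example: load_state_dict returns the keys it could not match, so a warm start can be checked instead of failing silently. A self-contained toy demonstration:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
state = {'weight': torch.zeros(2, 4), 'stale_key': torch.zeros(1)}
incompatible = model.load_state_dict(state, strict=False)
print(incompatible.missing_keys)     # ['bias'] -- left at its current value
print(incompatible.unexpected_keys)  # ['stale_key'] -- ignored by the load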
Example #22
 def valid_iter_fct():
     return build_dataset_iter(lazily_load_dataset("valid", opt, logger),
                               fields, opt)
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    load_str = opt.train_from if opt.train_from else opt.load_uncond_from
    if load_str:
        logger.info('Loading checkpoint from %s' % load_str)
        checkpoint = torch.load(load_str,
                                map_location=lambda storage, loc: storage)

        logger.info('Loading vocab from checkpoint at %s.' % load_str)
        vocab = checkpoint['vocab']

        if opt.train_from:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
        else:
            model_opt = opt
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    if opt.gpt2_params_path is not None:
        import tensorflow as tf
        import numpy as np
        # Taken from pytorch-pretrained-BERT:
        # Load weights from TF model
        logger.info("Loading TF GPT weights...")
        init_vars = tf.train.list_variables(opt.gpt2_params_path)
        names = []
        arrays = []
        for name, shape in init_vars:
            if opt.gpt_emb_only and ('wpe' not in name and 'wte' not in name):
                continue
            if opt.gpt_wpe_only and 'wpe' not in name:
                continue
            #print("Loading TF weight {} with shape {}".format(name, shape))
            array = tf.train.load_variable(opt.gpt2_params_path, name)
            names.append(name)
            arrays.append(array.squeeze())
        logger.info("Done.")

        # materialize the pairs: zip() is a one-shot iterator in Python 3, so
        # storing it raw would break if the checkpoint is read more than once
        if checkpoint is None:
            checkpoint = {'gpt2_params': list(zip(names, arrays))}
        else:
            checkpoint['gpt2_params'] = list(zip(names, arrays))

    if opt.encoder_from is not None:
        logger.info('Loading checkpoint with encoder from %s' %
                    opt.encoder_from)
        enc_checkpoint = torch.load(opt.encoder_from,
                                    map_location=lambda storage, loc: storage)
        enc_vocab = enc_checkpoint['vocab']
        if vocab['src'].base_field.vocab != enc_vocab['src'].base_field.vocab:
            raise ValueError(
                'encoder vocab and model vocab need to be identical if using a pretrained encoder'
            )
        if checkpoint is None:
            checkpoint = {}
        checkpoint['enc_model'] = enc_checkpoint['model']

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    sides = ['tgt'] if opt.model_type == 'none' else ['src', 'tgt']
    for side in sides:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, lm_dec = _tally_parameters(model)
    n_params_t, enc_t, dec_t, lm_dec_t = _tally_parameters(model,
                                                           only_trainable=True)
    logger.info('encoder: %d (%d)' % (enc, enc_t))
    logger.info('decoder: %d (%d)' % (dec, dec_t))
    if opt.simple_fusion:
        logger.info('lm decoder: %d (%d)' % (lm_dec, lm_dec_t))

    logger.info('* number of parameters: %d (%d)' % (n_params, n_params_t))
    _check_save_model_path(opt)

    if not opt.train_from and opt.gpt2_params_path is not None:
        checkpoint = None

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
    def valid_iter_fct(): return build_dataset_iter(
        lazily_load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)  # last argument assumed; the source is truncated here
Example #25
    def valid_iter_fct(): return build_dataset_iter(
        lazily_load_dataset("valid", opt), fields, opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.start_epoch, opt.epochs)
Example #26
def train(opt):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        if 'vocab' in checkpoint:
            logger.info('Loading vocab from checkpoint at %s.' %
                        opt.train_from)
            vocab = checkpoint['vocab']
        else:
            vocab = torch.load(opt.data + '.vocab.pt')
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    if len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
Example #27
def train(opt):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    if opt.train_from != '':
        raise Exception(
            'train_from will be set automatically to the latest model; '
            'you should not set it manually')

    # set gpu ranks automatically if not specified
    if len(opt.gpu_ranks) == 0:
        opt.gpu_ranks = list(range(opt.world_size))

    # Set train_from to latest checkpoint if it exists
    file_list = glob.glob(opt.save_model + '*.pt')
    if len(os.listdir(os.path.dirname(
            opt.save_model))) > 0 and len(file_list) == 0:
        raise Exception(
            'save_model directory is not empty but no pretrained models found')
    if len(file_list) > 0:
        ckpt_nos = list(
            map(lambda x: int(x.split('_')[-1].split('.')[0]), file_list))
        ckpt_no = max(ckpt_nos)
        opt.train_from = opt.save_model + '_' + str(ckpt_no) + '.pt'
        print(opt.train_from)
        assert os.path.exists(opt.train_from)

    set_random_seed(opt.seed, False)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    if len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d  " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(
                                  train_iter,
                                  queues,
                                  semaphore,
                                  opt,
                              ),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()

    elif nb_gpu == 1:  # case 1 GPU only
        single_main(opt, 0)
    else:  # case only CPU
        single_main(opt, -1)
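
The auto-resume logic at the top of this example can be read as one small helper: glob the save_model prefix, parse the step out of each filename, and take the maximum. A refactored sketch of the same idea, assuming the '<save_model>_<step>.pt' naming used above:

import glob

def latest_checkpoint(save_model):
    # Return the checkpoint path with the highest step, or None.
    files = glob.glob(save_model + '*.pt')
    if not files:
        return None
    return max(files, key=lambda p: int(p.split('_')[-1].split('.')[0]))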
Example #28
    def valid_iter_fct(): return build_dataset_iter(
        lazily_load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
Example #29
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    logger.info('Loading alignment.')
    with open(model_opt.lemma_align, 'rb') as f:
        lemma_aligns = f.readlines()
    src_stoi = vocab['src'].base_field.vocab.stoi
    lemma_stoi = vocab['word_topic'].base_field.vocab.stoi
    w2l = {}
    word_to_lemma = []
    for pair in lemma_aligns:
        pair = pair.strip().split()
        w2l[src_stoi[pair[0].decode('utf-8')]] = \
            lemma_stoi[pair[1].decode('utf-8')]
    w2l[src_stoi['unk']] = lemma_stoi['unk']
    for index in range(len(vocab['src'].base_field.vocab.itos)):
        # fall back to the unknown lemma for words without an alignment
        # (the original indexed w2l with a lemma id here, a likely bug)
        word_to_lemma.append(w2l.get(index, lemma_stoi['unk']))
    word_to_lemma = torch.tensor(word_to_lemma)
    logger.info('Loading topic matrix')
    if device_id >= 0:
        topic_matrix = torch.load(opt.topic_matrix,
                                  map_location=torch.device(device_id))
    else:
        topic_matrix = torch.load(opt.topic_matrix)
    if opt.model_dtype == 'fp16':
        topic_matrix = topic_matrix.half()
    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab
    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(topic_matrix,
                  word_to_lemma,
                  train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
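
The word_to_lemma tensor built above is a lookup table: indexing it with a batch of word ids returns lemma ids of the same shape in one vectorized operation. A toy, self-contained illustration (all values made up):

import torch

word_to_lemma = torch.tensor([0, 2, 2, 1])  # toy mapping: word id -> lemma id
src_ids = torch.tensor([[3, 1], [0, 2]])    # toy batch of word ids
lemma_ids = word_to_lemma[src_ids]          # tensor([[1, 2], [0, 2]])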
Example #30
 def valid_iter_fct():
     return build_dataset_iter(lazily_load_dataset("valid", opt),
                               fields,
                               opt,
                               is_train=False)
Example #31
 def train_iter_fct(task_id):
     return build_dataset_iter(
         lazily_load_dataset("train", opt, task_id=task_id), fields_list[task_id], opt)
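
train_iter_fct(task_id) above rebuilds a fresh iterator for a single task. One common way to alternate tasks during multi-task training (an assumption here, not shown in the source) is to keep one live iterator per task, round-robin between them, and restart a task's iterator when it runs dry:

import itertools

def multitask_batches(num_tasks, train_iter_fct):
    # Infinite round-robin over tasks; the training loop decides when to stop.
    iters = [iter(train_iter_fct(t)) for t in range(num_tasks)]
    for t in itertools.cycle(range(num_tasks)):
        try:
            yield t, next(iters[t])
        except StopIteration:
            iters[t] = iter(train_iter_fct(t))  # restart an exhausted task
            yield t, next(iters[t])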