Example No. 1
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type  # concat, query, hier

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def data_iter_fct(data_stage):
        """data_stage: train / valid"""
        pt_file = opt.data + '.' + data_stage + '.pt'
        logger.info('Loading {} dataset'.format(data_stage))
        dataset = torch.load(pt_file)
        logger.info('Loaded {} dataset'.format(data_stage))
        dataset.fields = fields
        is_train = data_stage == "train"
        batch_size = opt.batch_size if is_train else opt.valid_batch_size
        repeat = is_train
        if opt.gpuid != -1:
            device = "cuda"
        else:
            device = "cpu"

        def sort_key(ex):
            """ Sort using length of source sentences. """
            return ex.total_tokens

        return torchtext.data.Iterator(dataset=dataset,
                                       batch_size=batch_size,
                                       device=device,
                                       train=is_train,
                                       sort=False,
                                       sort_key=sort_key,
                                       repeat=repeat)

    # Do training.
    trainer.train(data_iter_fct, opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
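
Example No. 1 (like most of the examples below) restores checkpoints with `map_location=lambda storage, loc: storage`. A minimal standalone sketch of what that idiom does; the checkpoint path is made up for illustration. The lambda returns the already-deserialized CPU storage unchanged, so tensors stay on CPU instead of being moved back to the device they were saved on.

# Standalone sketch of the map_location idiom used above. The path is
# hypothetical; the point is that mapping each storage to itself at load
# time keeps every tensor on CPU, so a GPU-trained checkpoint can be
# restored on a CPU-only machine.
import torch

ckpt_path = 'demo_checkpoint.pt'
torch.save({'weights': torch.randn(2, 3)}, ckpt_path)

checkpoint = torch.load(ckpt_path,
                        map_location=lambda storage, loc: storage)
print(checkpoint['weights'].device)  # cpu, regardless of the saving device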
Example No. 2
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    #pdb.set_trace()
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example No. 3
def main(opt):
    opt = training_opt_postprocessing(opt)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    if opt.load_pretrained_selector_from:
        logger.info('Loading selector checkpoint from %s' %
                    opt.load_pretrained_selector_from)
        sel_checkpoint = torch.load(opt.load_pretrained_selector_from,
                                    map_location=lambda storage, loc: storage)
    else:
        sel_checkpoint = None

    if opt.load_pretrained_s2s_generator_from:
        logger.info('Loading s2s generator checkpoint from %s' %
                    opt.load_pretrained_s2s_generator_from)
        s2s_gen_checkpoint = torch.load(
            opt.load_pretrained_s2s_generator_from,
            map_location=lambda storage, loc: storage)
    else:
        s2s_gen_checkpoint = None

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint, sel_checkpoint,
                        s2s_gen_checkpoint)

    # Fix the pretrained selector parameters if needed
    if model_opt.fix_sel_all:
        assert opt.load_pretrained_selector_from
        assert opt.sel_lambda == 0.0
        assert not model_opt.fix_sel_classifier
        for name, param in model.named_parameters():
            if 'selector' in name:
                param.requires_grad = False
    # only fix the classifier of the selector
    if model_opt.fix_sel_classifier:
        assert opt.load_pretrained_selector_from
        assert not model_opt.fix_sel_all
        for name, param in model.named_parameters():
            if 'selector' in name and 'rnn' not in name and 'embeddings' not in name:
                param.requires_grad = False

    n_params, sel, enc, dec, gen = _my_tally_parameters(model)
    logger.info('selector: %d' % sel)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('generator: %d' % gen)
    logger.info('* number of parameters: %d' % n_params)

    print_trainable_parameters(model)

    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt), fields,
                                  opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example No. 4
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful for re-training a model after
        # adding a new option that is not set in the checkpoint.
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    # print("[onmt.train_single.py] first_dataset.examples[0]: {}".format(first_dataset.examples[0]))
    # print("[onmt.train_single.py] first_dataset.examples[0].src[:10]: {}".format(first_dataset.examples[0].src[:10]))
    print("[onmt.train_single.py] first_dataset.examples[0].src_da_label: {}".
          format(first_dataset.examples[0].src_da_label))
    print(
        "[onmt.train_single.py] first_dataset.examples[0].__dict__.keys(): {}".
        format(first_dataset.examples[0].__dict__.keys()))
    data_type = first_dataset.data_type
    print("[onmt.train_single.py] first_dataset.data_type: {}".format(
        first_dataset.data_type))

    # Load fields generated from preprocess phase.
    fields = load_fields(first_dataset, opt, checkpoint)

    # Report src/tgt features.

    knl_features, src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(knl_features):
        logger.info(' * knl feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
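
Example No. 4 (and Examples No. 9 and 10 below) rebuilds `model_opt` by parsing default option values and then overwriting them with the options stored in the checkpoint, so options added after the checkpoint was written still get a value. A minimal, self-contained sketch of that merge with plain argparse-style namespaces; the option names are made up for illustration:

from argparse import Namespace

# Stand-in for dummy_parser.parse_known_args([])[0]: freshly parsed defaults,
# including an option that did not exist when the checkpoint was saved.
default_opt = Namespace(rnn_size=500, brand_new_option='default value')
# Stand-in for checkpoint['opt']: the options recorded at training time.
checkpoint_opt = Namespace(rnn_size=1024)

model_opt = default_opt
model_opt.__dict__.update(checkpoint_opt.__dict__)
print(model_opt.rnn_size)          # 1024 -- taken from the checkpoint
print(model_opt.brand_new_option)  # 'default value' -- kept from the defaults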
Example No. 5
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    if opt.train_from and opt.reset_optim != 'all':
        logger.info('* resuming training from a checkpoint is not handled here yet')
    else:
        # warmup_steps and rnn_size are parameters for Noam decay (transformer):
        #    https://arxiv.org/pdf/1706.03762.pdf (Section 3)
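        # For reference, the Noam schedule from that paper sets, at step t,
        #     lr(t) = rnn_size ** -0.5 * min(t ** -0.5, t * warmup_steps ** -1.5)
        # (possibly scaled by a constant factor such as opt.learning_rate,
        # depending on the implementation): linear warm-up for warmup_steps
        # steps, then inverse-square-root decay.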
        decay_method = opt.decay_method if opt.decay_method else "standard"
        logger.info(
            '* Opt: %s (rate %.5f, maxgnorm %.1f, %s decay, '
            'decay_rate %.1f, start_decay_at %d, decay_every %d, '
            'ab1 %.5f, ab2 %.5f, adagradaccum %.1f, '
            'warmupsteps %d, hiddensize %d)' %
            (opt.optim, opt.learning_rate, opt.max_grad_norm, decay_method,
             opt.learning_rate_decay, opt.start_decay_steps, opt.decay_steps,
             opt.adam_beta1, opt.adam_beta2, opt.adagrad_accumulator_init,
             opt.warmup_steps, opt.rnn_size))

    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    logger.info('* model_saver built, using it to build the trainer')

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    #---------------------------------------------------------------------------
    # 1. lazily_load_dataset = for pt in pts: yield torch.load(pt)
    # 2. build_dataset_iter  = return DatasetLazyIter (train_iter_fct)
    # 3. train_iter_fct()    = iterator over torchtext.data.batch.Batches
    #---------------------------------------------------------------------------
    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
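
The boxed comment in Example No. 5 summarizes the data pipeline: `lazily_load_dataset` yields one `torch.load`-ed shard at a time, and `build_dataset_iter` wraps the shards in a lazy iterator. A minimal, self-contained sketch of that generator pattern; the shard-naming glob is an assumption for illustration, not the exact OpenNMT implementation:

import glob
import torch

def lazily_load_shards(prefix, corpus_type):
    # Yield one preprocessed shard at a time -- the "for pt in pts:
    # yield torch.load(pt)" pattern from the comment above -- so only a
    # single shard is resident in memory while it is being consumed.
    pattern = '%s.%s*.pt' % (prefix, corpus_type)  # assumed naming scheme
    for pt_file in sorted(glob.glob(pattern)):
        yield torch.load(pt_file)

# Usage sketch: iterate shards lazily and build a batch iterator per shard.
# for dataset in lazily_load_shards('data/demo', 'train'):
#     ...build an iterator over `dataset` and train on its batches...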
Example No. 6
        opt = train_args
        opt = training_opt_postprocessing(opt)

        model_opt = opt
        checkpoint = None
        # Peek the first dataset to determine the data_type.
        # (All datasets have the same data_type).
        first_dataset = next(lazily_load_dataset("train", opt))
        data_type = first_dataset.data_type

        # Load fields generated from preprocess phase.
        fields = _load_fields(first_dataset, data_type, opt, checkpoint)

        # Report src/tgt features.
        _collect_report_features(fields)

        # Build model.
        model = build_model(model_opt, opt, fields, checkpoint)
        remove(args.host, args.port, "OpenNMT")

        probe(
            args.name,
            model,
            args.host,
            args.port,
            when=lambda m, o: m._v.state == "dev",
            which=lambda m, o: True,  # o._v.operation_name in ["encoder", "decoder"]
            parameters=False,
            forward=True)
Example No. 7
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt, True),
                                  fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

    def monitor_iter_fct():
        monitor_data = dict()
        for src, tgt in zip(opt.monitor_src, opt.monitor_tgt):
            fname = src.split("/" if "/" in src else "\\")[-1].split(
                ".")[0].replace("_src", "")
            monitor_data[fname] = build_dataset_iter(lazily_load_dataset(
                "monitor", opt, fname),
                                                     fields,
                                                     opt,
                                                     is_train=False)
        return monitor_data

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, monitor_iter_fct,
                  opt.train_steps, opt.valid_steps, opt.monitor_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
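
The filename parsing inside `monitor_iter_fct` in Example No. 7 is dense; an equivalent sketch using `os.path.basename` (which handles the platform's path separator) reads more directly. `monitor_key` is a hypothetical helper name, not part of the original code:

import os

def monitor_key(src_path):
    # Same result as the split/replace chain above for typical paths:
    # strip the directory, keep everything before the first '.', and
    # drop the '_src' marker.
    base = os.path.basename(src_path)
    return base.split('.')[0].replace('_src', '')

print(monitor_key('data/news_src.txt'))  # -> 'news'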
Example No. 8
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    with open('processed_data/all-train/train.pt', 'rb') as f:
        first_dataset = pickle.load(f)
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)
    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)

    optim = build_optim(model, opt, checkpoint)  # opt.train_from == ''
    # Build model saver
    if not os.path.exists('experiments/all_train'):
        os.mkdir('experiments/all_train')
    model_saver = build_model_saver(model_opt, opt.save_model, opt, model,
                                    fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            "text",
                            model_saver=model_saver)

    def _lazy_dataset_loader(pt_file):
        # dataset = torch.load(pt_file)
        def dataset_loader(pt_file):
            with open(pt_file, 'rb') as f:
                dataset = pickle.load(f)
            # logger.info('Loading task from <{}>, number of examples: {}'.format(pt_file, len(dataset)))
            return dataset

        yield dataset_loader(pt_file)

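    # Note: wrapping the iterator in list() materializes every batch up front,
    # trading the lazy loading above for in-memory access to all batches.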
    train_iter = list(
        build_dataset_iter(
            _lazy_dataset_loader('processed_data/all-train/train.pt'), fields,
            opt))

    trainer.train(train_iter, opt.train_epochs)
Example No. 9
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful for re-training a model after
        # adding a new option that is not set in the checkpoint.
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = load_fields(first_dataset, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    if not opt.no_base:
        trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                      opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()

    if opt.comparable:
        logger.info('')
        logger.info('Beginning comparable data extraction and training.')

        # 1. Initialize Comparable object
        comp = Comparable(model, trainer, fields, logger, opt)

        # 2. Infer similarity threshold from training data

        for epoch in range(opt.comp_epochs):
            # 3. Update threshold if dynamic
            if opt.threshold_dynamics != 'static' and epoch != 0:
                comp.update_threshold(opt.threshold_dynamics,
                                      opt.infer_threshold)

            # 4. Extract parallel data and train
            #if opt.match_articles:
            #    comparable_data = comp.match_articles(opt.match_articles)
            #    train_stats = comp.extract_and_train(comparable_data)
            #else:
            train_stats = comp.extract_and_train(opt.comparable_data)

            # 5. Validate on validation set
            if not opt.no_valid:
                valid_iter = build_dataset_iter(
                    lazily_load_dataset("valid", opt), fields, opt)
                valid_stats = comp.validate(valid_iter)

            # 6. Drop a checkpoint if needed
            comp.trainer.model_saver._save(epoch)
Example No. 10
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info("Loading checkpoint from %s" % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful for re-training a model after
        # adding a new option that is not set in the checkpoint.
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint["opt"].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(" * src feature %d size = %d" %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(" * tgt feature %d size = %d" %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info("encoder: %d" % enc)
    logger.info("decoder: %d" % dec)
    logger.info("* number of parameters: %d" % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info("Starting training on GPU: %s" % opt.gpu_ranks)
    else:
        logger.info("Starting training on CPU, could be very slow")

    # import cProfile
    # with cProfile.Profile() as pr:
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)
    # pr.print_stats()
    # pr.dump_stats('/home/philhc/OpenNMT-evidential-softmax/exp/prof.stats')

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example No. 11
def training_main(opt):
    opt = training_opt_postprocessing(opt)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        print('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        # I don't like reassigning attributes of opt: it's not clear.
        opt.start_epoch = checkpoint['epoch'] + 1
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    _collect_report_features(fields)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    _tally_parameters(model)
    _check_save_model_path(opt)

    model._vivisect = {
        "iteration": 0,
        "model_name": "OpenNMT Model",
        "framework": "pytorch",
        "mode": "train"
    }
    probe(model, "localhost", 8082)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        model._vivisect["iteration"] += 1
        model._vivisect["mode"] = "train"
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        model._vivisect["mode"] = "dev"
        return build_dataset_iter(lazily_load_dataset("valid", opt), fields,
                                  opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.start_epoch, opt.epochs)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()