Exemplo n.º 1
0
def main(opt, device_id, batch_queue=None, semaphore=None):
    """Train a multi-source model on one device.

    Loads the vocab/fields (from the checkpoint when resuming via
    ``opt.train_from``, otherwise from ``opt.data + '.vocab.pt'``), builds the
    model, optimizer, model saver and trainer, then runs training. When
    ``opt.tensorboard`` is set, the TensorBoard writer is closed afterwards,
    even if training raises.

    Args:
        opt: Training options. They must already have been validated and
            updated by the caller.
        device_id: Device identifier forwarded to ``configure_process`` and
            to the trainer builder.
        batch_queue: Optional queue that supplies training batches. When
            ``None``, the training iterator is built here from the ``train``
            shard(s) named by ``opt.data_ids``.
        semaphore: Released once for every batch taken from ``batch_queue``;
            required whenever ``batch_queue`` is given.
    """
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        # The identity map_location keeps every tensor on the (CPU) storage
        # it was deserialized to, whatever device it was saved from.
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        # When resuming, model hyper-parameters come from the checkpoint.
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features. Every source
    # type has its own ``src.<type>`` field.
    data_keys = [f"src.{src_type}" for src_type in opt.src_types] + ["tgt"]
    for side in data_keys:
        f = fields[side]
        # A field may iterate over (name, sub-field) pairs (presumably one per
        # feature); a non-iterable field is reported under its own key.
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    logger.info('Building model...')
    if opt.type_append:
        model = MultiSourceS2STypeAppendedModelBuilder.build_model(
            opt.src_types, model_opt, opt, fields, checkpoint)
    else:
        model = MultiSourceModelBuilder.build_model(opt.src_types, model_opt,
                                                    opt, fields, checkpoint)

    logger.info(model)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = MultiSourceModelSaver.build_model_saver(
        opt.src_types, model_opt, opt, model, fields, optim)

    # Pick the trainer flavour: consistency regularisation takes precedence
    # over type-appended training.
    if opt.consist_reg:
        trainer = MultiSourceCRTrainer.build_trainer(opt.src_types,
                                                     opt,
                                                     device_id,
                                                     model,
                                                     fields,
                                                     optim,
                                                     model_saver=model_saver)
    elif opt.type_append:
        trainer = MultiSourceTypeAppendedTrainer.build_trainer(
            opt.src_types,
            opt,
            device_id,
            model,
            fields,
            optim,
            model_saver=model_saver)
    else:
        trainer = MultiSourceTrainer.build_trainer(opt.src_types,
                                                   opt,
                                                   device_id,
                                                   model,
                                                   fields,
                                                   optim,
                                                   model_saver=model_saver)

    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = ["train_" + train_id for train_id in opt.data_ids]
            train_iter = MultiSourceInputter.build_dataset_iter_multiple(
                opt.src_types, train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = MultiSourceInputter.build_dataset_iter(
                opt.src_types, shard_base, fields, opt)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            # Endless generator over the queue; releasing the semaphore after
            # each get() signals the producer that a slot is free.
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    valid_iter = MultiSourceInputter.build_dataset_iter(opt.src_types,
                                                        "valid",
                                                        fields,
                                                        opt,
                                                        is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    try:
        trainer.train(train_iter,
                      train_steps,
                      save_checkpoint_steps=opt.save_checkpoint_steps,
                      valid_iter=valid_iter,
                      valid_steps=opt.valid_steps)
    finally:
        # Close the writer even when training raises, so pending TensorBoard
        # events are written out instead of being lost.
        if opt.tensorboard:
            trainer.report_manager.tensorboard_writer.close()
Exemplo n.º 2
0
    def train_single(self,
                     output_model_dir: Path,
                     opt,
                     device_id,
                     batch_queue=None,
                     semaphore=None):
        """Train one multi-source model on a single device and dump metrics.

        The source types come from ``self.config.get_src_types()``. After
        training, two files are written into ``output_model_dir``:
        ``train-metrics.json`` (timing plus the joint training history) and
        ``best-step.json`` (the step with the lowest ``val_xent``).

        Args:
            output_model_dir: Directory that receives the metric dumps.
            opt: Training options, already validated by the caller.
            device_id: Device identifier forwarded to ``configure_process``
                and to the trainer builder.
            batch_queue: Optional queue supplying training batches. When
                ``None``, the iterator is built from the ``train`` shard(s)
                named by ``opt.data_ids``.
            semaphore: Released after each batch taken from ``batch_queue``;
                required whenever ``batch_queue`` is given.

        Raises:
            ValueError: If the training history holds no non-NaN
                ``val_xent`` value to pick a best step from.
        """
        import math

        from roosterize.ml.onmt.MultiSourceInputter import MultiSourceInputter
        from roosterize.ml.onmt.MultiSourceModelBuilder import MultiSourceModelBuilder
        from roosterize.ml.onmt.MultiSourceModelSaver import MultiSourceModelSaver
        from roosterize.ml.onmt.MultiSourceTrainer import MultiSourceTrainer
        from onmt.inputters.inputter import load_old_vocab, old_style_vocab
        from onmt.train_single import configure_process, _tally_parameters, _check_save_model_path
        from onmt.utils.optimizers import Optimizer
        from onmt.utils.parse import ArgumentParser

        configure_process(opt, device_id)
        assert len(opt.accum_count) == len(
            opt.accum_steps
        ), 'Number of accum_count values must match number of accum_steps'
        # Load checkpoint if we resume from a previous training.
        if opt.train_from:
            self.logger.info('Loading checkpoint from %s' % opt.train_from)
            # The identity map_location keeps tensors on the (CPU) storage
            # they were deserialized to, whatever device they were saved from.
            checkpoint = torch.load(opt.train_from,
                                    map_location=lambda storage, loc: storage)
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
            self.logger.info('Loading vocab from checkpoint at %s.' %
                             opt.train_from)
            vocab = checkpoint['vocab']
        else:
            checkpoint = None
            model_opt = opt
            vocab = torch.load(opt.data + '.vocab.pt')
        # end if

        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            fields = load_old_vocab(vocab,
                                    opt.model_type,
                                    dynamic_dict=opt.copy_attn)
        else:
            fields = vocab
        # end if

        src_types = self.config.get_src_types()

        # Report src and tgt vocab sizes, including for features
        data_keys = [f"src.{src_type}" for src_type in src_types] + ["tgt"]
        for side in data_keys:
            f = fields[side]
            # A field may iterate over (name, sub-field) pairs; a
            # non-iterable field is reported under its own key.
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(side, f)]
            # end try
            for sn, sf in f_iter:
                if sf.use_vocab:
                    self.logger.info(' * %s vocab size = %d' %
                                     (sn, len(sf.vocab)))
            # end for

        # Build model
        model = MultiSourceModelBuilder.build_model(
            src_types, model_opt, opt, fields, checkpoint)
        n_params, enc, dec = _tally_parameters(model)
        self.logger.info('encoder: %d' % enc)
        self.logger.info('decoder: %d' % dec)
        self.logger.info('* number of parameters: %d' % n_params)
        _check_save_model_path(opt)

        # Build optimizer.
        optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

        # Build model saver
        model_saver = MultiSourceModelSaver.build_model_saver(
            src_types, model_opt, opt, model, fields, optim)

        trainer = MultiSourceTrainer.build_trainer(src_types,
                                                   opt,
                                                   device_id,
                                                   model,
                                                   fields,
                                                   optim,
                                                   model_saver=model_saver)

        if batch_queue is None:
            if len(opt.data_ids) > 1:
                train_shards = [
                    "train_" + train_id for train_id in opt.data_ids
                ]
                train_iter = MultiSourceInputter.build_dataset_iter_multiple(
                    src_types, train_shards, fields, opt)
            else:
                if opt.data_ids[0] is not None:
                    shard_base = "train_" + opt.data_ids[0]
                else:
                    shard_base = "train"
                # end if
                train_iter = MultiSourceInputter.build_dataset_iter(
                    src_types, shard_base, fields, opt)
            # end if
        else:
            assert semaphore is not None, "Using batch_queue requires semaphore as well"

            def _train_iter():
                # Endless generator over the queue; releasing the semaphore
                # after each get() tells the producer a slot is free.
                while True:
                    batch = batch_queue.get()
                    semaphore.release()
                    yield batch
                # end while

            # end def

            train_iter = _train_iter()
        # end if

        valid_iter = MultiSourceInputter.build_dataset_iter(
            src_types, "valid", fields, opt, is_train=False)

        if len(opt.gpu_ranks):
            self.logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
        else:
            self.logger.info('Starting training on CPU, could be very slow')
        # end if
        train_steps = opt.train_steps
        if opt.single_pass and train_steps > 0:
            self.logger.warning(
                "Option single_pass is enabled, ignoring train_steps.")
            train_steps = 0
        # end if

        try:
            trainer.train(train_iter,
                          train_steps,
                          save_checkpoint_steps=opt.save_checkpoint_steps,
                          valid_iter=valid_iter,
                          valid_steps=opt.valid_steps)
            time_begin = trainer.report_manager.start_time
            time_end = time.time()
        finally:
            # Close the writer even when training raises, so pending
            # TensorBoard events are not lost.
            if opt.tensorboard:
                trainer.report_manager.tensorboard_writer.close()
            # end if
        # end try

        # Dump train metrics
        train_history = trainer.report_manager.get_joint_history()
        train_metrics = {
            "time_begin": time_begin,
            "time_end": time_end,
            "time": time_end - time_begin,
            "train_history": train_history,
        }
        IOUtils.dump(output_model_dir / "train-metrics.json", train_metrics,
                     IOUtils.Format.jsonNoSort)

        # Get the best step: the lowest val_xent (cross entropy), taking the
        # LAST step on ties. NaN losses (e.g. from a diverged run) are
        # skipped explicitly. With min() + equality, a leading NaN made min()
        # return NaN, and since NaN != NaN no step matched, so [-1] raised
        # IndexError.
        best_loss = None
        best_step = None
        for th in train_history:
            loss = th["val_xent"]
            if math.isnan(loss):
                continue
            # end if
            if best_loss is None or loss <= best_loss:
                best_loss, best_step = loss, th["step"]
            # end if
        # end for
        if best_loss is None:
            raise ValueError(
                "No non-NaN val_xent in training history; cannot select the best step")
        # end if
        IOUtils.dump(output_model_dir / "best-step.json", best_step,
                     IOUtils.Format.json)
        return
def validate(opt, device_id=0):
    """Evaluate a saved checkpoint on the ``valid`` dataset and print metrics.

    Loads the model options, vocab/fields and weights from ``opt.train_from``.
    It then runs the model over the validation iterator with gradients
    disabled and prints the number of target words, perplexity, accuracy and
    average attention entropy.

    Args:
        opt: Options; ``opt.train_from`` must name a checkpoint file.
        device_id: Device identifier forwarded to ``configure_process``.

    Raises:
        ValueError: If ``opt.train_from`` is not set. Without a checkpoint,
            ``model_opt``, ``fields`` and ``checkpoint`` would be undefined.
    """
    # Fail fast with a clear message instead of an UnboundLocalError later on.
    if not opt.train_from:
        raise ValueError(
            'validate() requires opt.train_from to name a checkpoint')
    configure_process(opt, device_id)

    logger.info('Loading checkpoint from %s' % opt.train_from)
    # The identity map_location keeps tensors on the (CPU) storage they were
    # deserialized to, whatever device they were saved from.
    checkpoint = torch.load(opt.train_from,
                            map_location=lambda storage, loc: storage)
    model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)
    logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
    vocab = checkpoint['vocab']

    # Older checkpoints store a raw vocab instead of fields.
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    tgt_field = dict(fields)["tgt"].base_field
    valid_loss = onmt.utils.loss.build_loss_compute(model,
                                                    tgt_field,
                                                    opt,
                                                    train=False)

    # Evaluation mode (e.g. dropout off) and no autograd bookkeeping.
    model.eval()

    with torch.no_grad():
        stats = onmt.utils.Statistics()

        for batch in valid_iter:
            # ``batch.src`` is either a (tokens, lengths) tuple or tokens only.
            if isinstance(batch.src, tuple):
                src, src_lengths = batch.src
            else:
                src, src_lengths = batch.src, None
            tgt = batch.tgt

            # F-prop through the model.
            outputs, attns = model(src, tgt, src_lengths)

            # Compute loss.
            _, batch_stats = valid_loss(batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

    print('n words:  %d' % stats.n_words)
    print('Validation perplexity: %g' % stats.ppl())
    print('Validation accuracy: %g' % stats.accuracy())
    print('Validation avg attention entropy: %g' % stats.attn_entropy())