Example #1
def main(opt, device_id):
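    """Single-device training entry point: restore an optional checkpoint,
    load the fields produced by the preprocessing phase, build the model,
    optimizer, saver and trainer, then train on lazily loaded train/valid
    dataset iterators."""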
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful when re-training a model after
        # adding a new option that is not set in the checkpoint.
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #2
def main(opt,
         device_id,
         batch_queue=None,
         semaphore=None,
         train_iter=None,
         passed_fields=None):
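    """Run training on one device. Besides the usual checkpoint/vocab
    handling, this variant accepts a pre-built ``train_iter``, a
    ``batch_queue`` fed by a producer process (guarded by ``semaphore``),
    externally passed fields, and extra cross-lingual validation iterators."""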
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        if opt.use_opt_from_trained:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        else:
            model_opt = opt
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    aux_fields = None
    if passed_fields is not None:
        fields = passed_fields['main']
        aux_fields = passed_fields['crosslingual']
    elif old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt,
                        opt,
                        fields,
                        checkpoint,
                        aux_fields=aux_fields)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    if opt.almt_only:
        almt = model.encoder.embeddings.almt_layers['mapping']
        logger.info('Only training the alignment mapping.')
        optim = Optimizer.from_opt(almt, opt, checkpoint=checkpoint)
    else:
        optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt,
                                    opt,
                                    model,
                                    fields,
                                    optim,
                                    aux_fields=aux_fields)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver,
                            aux_fields=aux_fields)

    if train_iter is not None:
        pass  # NOTE Use the passed one.
    elif batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()

    cl_valid_iter = None
    if opt.crosslingual:
        valid_iter = build_dataset_iter("valid",
                                        fields,
                                        opt,
                                        is_train=False,
                                        task_cls=Eat2PlainMonoTask)
        if opt.crosslingual_dev_data:
            # NOTE I used 'train' to prepare this in `eat_prepare.sh`, so I use 'train' here as well.
            cl_valid_iter = build_dataset_iter(
                'train',
                fields,
                opt,
                is_train=False,
                data_attr='crosslingual_dev_data',
                task_cls=Eat2PlainCrosslingualTask)
        # NOTE This is for the second eat->plain task.
        aux_valid_iter = build_dataset_iter('valid',
                                            fields,
                                            opt,
                                            is_train=False,
                                            data_attr='aux_train_data',
                                            task_cls=Eat2PlainMonoTask)
        valid_iters = [valid_iter, aux_valid_iter]
    else:
        valid_iters = [
            build_dataset_iter("valid", fields, opt, is_train=False)
        ]

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iters=valid_iters,
                  valid_steps=opt.valid_steps,
                  cl_valid_iter=cl_valid_iter)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #3
def main(opt, device_id):
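    """Single-device training entry point: load an optional checkpoint and
    its vocab (handling old-style vocab files), build the model, optimizer,
    saver and trainer, then train from "train"/"valid" dataset iterators."""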
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful when re-training a model after
        # adding a new option that is not set in the checkpoint.
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        data_type = opt.model_type
        fields = load_old_vocab(vocab, data_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        for name, f in fields[side]:
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(name, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    # this line is kind of a temporary kludge because different objects expect
    # fields to have a different structure
    dataset_fields = dict(chain.from_iterable(fields.values()))

    train_iter = build_dataset_iter("train",
                                    dataset_fields,
                                    opt,
                                    is_train=True)
    valid_iter = build_dataset_iter("valid",
                                    dataset_fields,
                                    opt,
                                    is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter, valid_iter, opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #4
    def __init__(self, model_dir):
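        """Load a trained OpenNMT model from ``model_dir`` and wrap it with
        both a translator and a trainer (checkpoint, fields, optimizer)."""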

        # Model dir
        self._model_dir = os.path.abspath(model_dir)
        if not os.path.isdir(self._model_dir):
            msg = f"{model_dir} doesn't exists'"
            raise ValueError(msg)

        # Extended model
        self._extended_model = ExtendedModel(model_dir)

        # Config
        self._config = self._extended_model.config

        # Options
        self._opts = self._config.opts

        # Get the model options
        model_path = self._opts.models[0]
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        self._model_opts = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
        ArgumentParser.update_model_opts(self._model_opts)
        ArgumentParser.validate_model_opts(self._model_opts)

        # Train_steps
        self._train_steps = self._model_opts.train_steps

        # Extract vocabulary
        vocab = checkpoint['vocab']
        if inputters.old_style_vocab(vocab):
            self._fields = inputters.load_old_vocab(
                vocab,
                self._opts.data_type,
                dynamic_dict=self._model_opts.copy_attn)
        else:
            self._fields = vocab

        # Build model
        self._model = build_base_model(self._model_opts, self._fields,
                                       use_gpu(self._opts), checkpoint,
                                       self._opts.gpu)

        if self._opts.fp32:
            self._model.float()

        # Translator
        scorer = GNMTGlobalScorer.from_opt(self._opts)

        self.translator = OnmtxTranslator.from_opt(
            self._model,
            self._fields,
            self._opts,
            self._model_opts,
            global_scorer=scorer,
            out_file=None,
            report_score=False,
            logger=None,
        )

        # Create trainer
        self._optim = Optimizer.from_opt(self._model,
                                         self._opts,
                                         checkpoint=checkpoint)

        device_id = -1  # TODO Handle GPU
        self.trainer = build_trainer(self._opts, device_id, self._model,
                                     self._fields, self._optim)
Example #5
def main(opt, device_id):
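    """Multilingual training entry point: for every language pair in
    ``opt.src_tgt``, build per-pair fields, encoder, decoder and generator
    (optionally sharing embedding matrices across languages), then train a
    single multi-encoder/multi-decoder model over all pair iterators."""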
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        #vocab = torch.load(opt.data + '.vocab.pt')

    train_iters = OrderedDict()
    valid_iters = OrderedDict()

    encoders = OrderedDict()
    decoders = OrderedDict()

    generators = OrderedDict()
    src_vocabs = OrderedDict()
    tgt_vocabs = OrderedDict()
    Fields_dict = OrderedDict()

    # variables needed for sharing the same embedding matrix across encoders and decoders
    firstTime = True
    weightToShare = None

    # we share the word embedding space when source lang and target lang are the same
    mapLang2Emb = {}
    #for (src_tgt_lang), data_path in zip(opt.src_tgt, opt.data):
    for index in range(len(opt.src_tgt)):
        src_tgt_lang = opt.src_tgt[index]
        data_path = opt.data[index]
        local_enc_dec_opts = AttrDict({
            key: model_opt.__dict__[key]
            for key in model_opt.__dict__.keys()
        })
        local_enc_dec_opts.model_type = update_to_local_attr(
            model_opt.model_type, index)
        #local_enc_dec_opts.audio_enc_pooling = model_opt.audio_enc_pooling[index]
        local_enc_dec_opts.audio_enc_pooling = update_to_local_attr(
            model_opt.audio_enc_pooling, index)
        local_enc_dec_opts.enc_layers = update_to_local_attr(
            model_opt.enc_layers, index)
        local_enc_dec_opts.dec_layers = update_to_local_attr(
            model_opt.dec_layers, index)
        local_enc_dec_opts.rnn_type = update_to_local_attr(
            model_opt.rnn_type, index)
        local_enc_dec_opts.encoder_type = update_to_local_attr(
            model_opt.encoder_type, index)
        local_enc_dec_opts.batch_size = update_to_local_attr(
            model_opt.batch_size, index)
        local_enc_dec_opts.batch_type = update_to_local_attr(
            model_opt.batch_type, index)
        local_enc_dec_opts.normalization = update_to_local_attr(
            model_opt.normalization, index)
        #local_enc_dec_opts.dec_rnn_size = model_opt.dec_rnn_size[index]

        src_lang, tgt_lang = src_tgt_lang.split('-')

        vocab = torch.load(data_path + '.vocab.pt')

        # check for code where vocab is saved instead of fields
        # (in the future this will be done in a smarter way)
        if old_style_vocab(vocab):
            fields = load_old_vocab(vocab,
                                    opt.model_type[0],
                                    dynamic_dict=opt.copy_attn)
        else:
            fields = vocab

        # Report src and tgt vocab sizes, including for features
        for side in ['src', 'tgt']:
            f = fields[side]
            try:
                f_iter = iter(f)
            except TypeError:
                f_iter = [(side, f)]
            for sn, sf in f_iter:
                if sf.use_vocab:
                    logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

        # Build model.
        encoder, src_embeddings = build_embeddings_then_encoder(
            local_enc_dec_opts, fields)

        encoders[src_lang] = encoder

        decoder, generator, tgt_embeddings = build_decoder_and_generator(
            local_enc_dec_opts, fields)

        decoders[tgt_lang] = decoder

        # Share the embedding matrix across all encoders and decoders;
        # preprocessing with share_vocab is required.
        if model_opt.share_embeddings and firstTime:
            tgt_embeddings.word_lut.weight = src_embeddings.word_lut.weight
            weightToShare = src_embeddings.word_lut.weight
        if model_opt.share_embeddings and (not firstTime):
            tgt_embeddings.word_lut.weight = weightToShare
            src_embeddings.word_lut.weight = weightToShare
        firstTime = False

        #TEST
        #if src_lang in mapLang2Emb:
        if src_lang in mapLang2Emb and model_opt.model_type == "text":
            encoder.embeddings.word_lut.weight = mapLang2Emb.get(src_lang)
        #TEST
        #else:
        elif model_opt.model_type == "text":
            mapLang2Emb[src_lang] = src_embeddings.word_lut.weight
        if tgt_lang in mapLang2Emb:
            decoder.embeddings.word_lut.weight = mapLang2Emb.get(tgt_lang)
        else:
            mapLang2Emb[tgt_lang] = tgt_embeddings.word_lut.weight

        #TEST
        if model_opt.model_type == "text":
            src_vocabs[src_lang] = fields['src'].base_field.vocab
        tgt_vocabs[tgt_lang] = fields['tgt'].base_field.vocab

        generators[tgt_lang] = generator

        # add this dataset iterator to the training iterators
        train_iters[(src_lang, tgt_lang)] = build_dataset_iter_fct(
            'train', fields, data_path, local_enc_dec_opts)
        # add this dataset iterator to the validation iterators
        valid_iters[(src_lang,
                     tgt_lang)] = build_dataset_iter_fct('valid',
                                                         fields,
                                                         data_path,
                                                         local_enc_dec_opts,
                                                         is_train=False)

        Fields_dict[src_tgt_lang] = fields

    # Build model.
    model = build_model(model_opt, opt, fields, encoders, decoders, generators,
                        src_vocabs, tgt_vocabs, checkpoint)

    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, Fields_dict, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            generators,
                            tgt_vocabs,
                            model_saver=model_saver)

    # TODO: not implemented yet
    #train_iterables = []
    #if len(opt.data_ids) > 1:
    #    for train_id in opt.data_ids:
    #        shard_base = "train_" + train_id
    #        iterable = build_dataset_iter(shard_base, fields, opt, multi=True)
    #        train_iterables.append(iterable)
    #    train_iter = MultipleDatasetIterator(train_iterables, device_id, opt)
    #else:
    #    train_iter = build_dataset_iter("train", fields, opt)

    #valid_iter = build_dataset_iter(
    #    "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_iters, train_steps, opt.save_checkpoint_steps,
                  valid_iters, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #6
def main(opt,
         fields,
         transforms_cls,
         checkpoint,
         device_id,
         batch_queue=None,
         semaphore=None):
    """Start training on `device_id`."""
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)  # Set an appropriate device id
    init_logger(opt.log_file)

    model_opt = _get_model_opts(opt, checkpoint=checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    model.count_parameters(log=logger.info)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

    if batch_queue is None:
        _train_iter = _build_train_iter(opt, fields, transforms_cls)
        train_iter = IterOnDevice(_train_iter, device_id)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                # Move batch to specified device
                IterOnDevice.batch_to_device(batch, device_id)
                yield batch

        train_iter = _train_iter()

    valid_iter = _build_valid_iter(opt, fields, transforms_cls)
    if valid_iter is not None:
        valid_iter = IterOnDevice(valid_iter, device_id)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    # Convert string flags ('true'/'false') to booleans.
    opt.variational = opt.variational.lower() == 'true'
    opt.only_src = opt.only_src.lower() == 'true'

    if opt.variational:
        variational_staff = build_variational_staff(opt, fields, device_id)
    else:
        variational_staff = None

    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps,
                  opts=opt,
                  variational_staff=variational_staff)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
Example #7
def main(opt):
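    """Single-process training for the morphology task: load fields and an
    optional checkpoint, build the model and a plain ``torch.optim``
    optimizer, and train with ``OrderedIterator`` batches for
    ``opt.epochs`` epochs."""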
    if opt.gpuid:
        raise AssertionError(
            "gpuid is deprecated, see world_size and gpu_ranks")

    assert opt.world_size <= 1, "you don't need multi-gpu for morphology"

    device_id = 0 if len(opt.gpu_ranks) == 1 else -1

    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful when re-training a model after
        # adding a new option that is not set in the checkpoint.
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        fields = torch.load(opt.data + '.vocab.pt')

    for key, values in fields.items():
        for name, f in values:
            if f.use_vocab:
                logger.info(' * %s vocab size = %d' % (name, len(f.vocab)))

    # Build model.
    logger.info('Building model...')
    model = build_model(model_opt, fields, use_gpu(opt), checkpoint)
    logger.info(model)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    params = model.parameters()
    optim_args = {"lr": opt.learning_rate}
    if opt.optim == "adam":
        # no need to mess with the default betas
        optim_args["eps"] = 1e-9
    elif opt.optim == "adagrad":
        optim_args["initial_accumulator_value"] = opt.adagrad_accumulator_init
    optim = getattr(torch.optim, opt.optim.title())(params, **optim_args)
    print(optim)

    trainer = build_trainer(opt, model_opt, device_id, model, fields, optim)

    # this line is kind of a temporary kludge because different objects expect
    # fields to have a different structure
    dataset_fields = dict(chain.from_iterable(fields.values()))

    device = "cuda" if opt.gpu_ranks else "cpu"

    train_dataset = torch.load(opt.data + '.train.pt')
    train_dataset.fields = dataset_fields
    train_iter = OrderedIterator(train_dataset,
                                 opt.batch_size,
                                 sort_within_batch=True,
                                 device=device,
                                 repeat=False,
                                 shuffle=not opt.no_shuffle)

    valid_dataset = torch.load(opt.data + '.valid.pt')
    valid_dataset.fields = dataset_fields
    valid_iter = OrderedIterator(valid_dataset,
                                 opt.valid_batch_size,
                                 train=False,
                                 sort_within_batch=True,
                                 device=device)

    logger.info('Starting training on {}'.format(device))
    trainer.train(train_iter, valid_iter, opt.epochs)
Example #8
def main(opt):
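    """Minimal single-process training loop: optional checkpoint restore,
    fields loaded from the preprocessing phase, then model, optimizer,
    saver, trainer and step-based training."""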
    opt = training_opt_postprocessing(opt)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt), fields,
                                  opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #9
def main(opt, device_id):
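    """Meta-test routine: fine-tune a pre-trained checkpoint on the training
    split of ``opt.meta_test_task``, translate the dev split with every
    saved model, score the generated molecules with an external chemprop
    model to pick the best checkpoint, and report its test-set success
    rate."""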
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    ckpt_path = '{}_epoch_{}.pt'.format(
        opt.save_model, opt.load_meta_step)  # opt.train_from != ''
    logger.info('Loading checkpoint from %s' %
                ckpt_path)  # opt.train_from != ''
    checkpoint = torch.load(
        ckpt_path,
        map_location=lambda storage, loc: storage)  # opt.train_from != ''
    model_opt = checkpoint['opt']  # opt.train_from != ''

    fields = load_fields_from_vocab(checkpoint['vocab'],
                                    data_type="text")  # opt.train_from != ''
    # first_dataset = pickle.load(open(opt.data + '/train.pt', 'rb'))  # opt.train_from == ''
    # data_type = first_dataset.data_type  # opt.train_from == ''
    # fields = _load_fields(first_dataset, data_type, opt, checkpoint=None)  # opt.train_from == ''
    # model_opt = opt  # opt.train_from == ''

    # Build model.
    model = build_model(model_opt, opt, fields,
                        checkpoint)  # opt.train_from != ''
    # model = build_model(model_opt, opt, fields, checkpoint=None)  # opt.train_from == ''

    optim = build_optim(model, opt, checkpoint)  # opt.train_from != ''
    # optim = build_optim(model, opt, checkpoint=None)  # opt.train_from == ''
    # Build model saver
    if not os.path.exists('experiments/meta_test'):
        os.mkdir('experiments/meta_test')
        os.mkdir('experiments/meta_test/' + opt.meta_test_task)
    elif not os.path.exists('experiments/meta_test/' + opt.meta_test_task):
        os.mkdir('experiments/meta_test/' + opt.meta_test_task)
    model_saver = build_model_saver(
        model_opt, 'experiments/meta_test/' + opt.meta_test_task + '/model',
        opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            "text",
                            model_saver=model_saver)

    train_iter = list(
        build_dataset_iter(lazily_load_dataset("train", opt), fields, opt))
    # do training on trainset of meta-test task
    trainer.train(train_iter, opt.inner_iterations)

    # del first_dataset
    del model_opt
    del fields
    del checkpoint  # opt.train_from != ''
    del model
    del optim
    del model_saver
    del trainer
    gc.collect()

    # do evaluation on devset of meta-test task
    best_dev_score, best_model_path = -10000, None
    out_file = None
    dummy_parser = argparse.ArgumentParser(description='meta_test.py')
    opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
    for model_path in os.listdir('experiments/meta_test/' +
                                 opt.meta_test_task):
        if model_path.find('.pt') == -1:
            continue
        if out_file is None:
            out_file = codecs.open(opt.output, 'w', 'utf-8')
        fields, model, model_opt = onmt.model_builder.load_test_model(
            opt,
            dummy_opt.__dict__,
            model_path='experiments/meta_test/' + opt.meta_test_task + '/' +
            model_path)

        scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                                 opt.coverage_penalty,
                                                 opt.length_penalty)

        kwargs = {
            k: getattr(opt, k)
            for k in [
                "beam_size", "n_best", "max_length", "min_length",
                "stepwise_penalty", "block_ngram_repeat",
                "ignore_when_blocking", "dump_beam", "report_bleu",
                "replace_unk", "gpu", "verbose", "fast", "mask_from"
            ]
        }
        fields['graph'] = torchtext.data.Field(sequential=False)
        translator = Translator(model,
                                fields,
                                global_scorer=scorer,
                                out_file=out_file,
                                report_score=False,
                                copy_attn=model_opt.copy_attn,
                                logger=logger,
                                log_probs_out_file=None,
                                **kwargs)
        # make translation and save result
        all_scores, all_predictions = translator.translate(
            src_path='processed_data/meta-test/' + opt.meta_test_task +
            '/src-dev.txt',
            tgt_path=None,
            src_dir=None,
            batch_size=opt.translate_batch_size,
            attn_debug=False)
        # dump predictions
        f = open('experiments/meta_test/' + opt.meta_test_task +
                 '/dev_predictions.csv',
                 'w',
                 encoding='utf-8')
        f.write('smiles,property\n')
        for n_best_mols in all_predictions:
            for mol in n_best_mols:
                f.write(mol.replace(' ', '') + ',0\n')
        f.close()
        # call chemprop to get scores
        test_path = '\"' + 'experiments/meta_test/' + opt.meta_test_task + '/dev_predictions.csv' + '\"'
        checkpoint_path = '\"scorer_ckpts/' + opt.meta_test_task + '/model.pt' + '\"'
        preds_path = '\"' + 'experiments/meta_test/' + opt.meta_test_task + '/dev_scores.csv' + '\"'

        # In case all mols are invalid (chemprop will then produce no output
        # file), the predictions are copied into the score file.
        cmd = 'cp {} {}'.format(test_path, preds_path)
        result = os.popen(cmd)
        result.close()

        cmd = 'python chemprop/predict.py --test_path {} --checkpoint_path {} --preds_path {} --num_workers 0'.format(
            test_path, checkpoint_path, preds_path)
        scorer_result = os.popen(cmd)
        scorer_result.close()
        # read score file and get score

        score = read_score_csv('experiments/meta_test/' + opt.meta_test_task +
                               '/dev_scores.csv')

        assert len(score) % opt.beam_size == 0

        # report if it is the best on devset
        dev_metrics = calculate_metrics(opt.meta_test_task, 'test', 'dev',
                                        score)
        logger.info('dev metrics: ' + str(dev_metrics))
        dev_score = dev_metrics['success_rate']
        if dev_score > best_dev_score:
            logger.info('New best dev success rate: {:.4f} by {}'.format(
                dev_score, model_path))
            best_model_path = model_path
            best_dev_score = dev_score
        else:
            logger.info('dev success rate: {:.4f} by {}'.format(
                dev_score, model_path))

        del fields
        del model
        del model_opt
        del scorer
        del translator
        gc.collect()

    # do testing on testset of meta-test task
    out_file.close()
    out_file = codecs.open(opt.output, 'w', 'utf-8')
    fields, model, model_opt = onmt.model_builder.load_test_model(
        opt,
        dummy_opt.__dict__,
        model_path='experiments/meta_test/' + opt.meta_test_task + '/' +
        best_model_path)

    scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                             opt.coverage_penalty,
                                             opt.length_penalty)

    kwargs = {
        k: getattr(opt, k)
        for k in [
            "beam_size", "n_best", "max_length", "min_length",
            "stepwise_penalty", "block_ngram_repeat", "ignore_when_blocking",
            "dump_beam", "report_bleu", "replace_unk", "gpu", "verbose",
            "fast", "mask_from"
        ]
    }
    kwargs['beam_size'] = 100
    kwargs['n_best'] = 100
    fields['graph'] = torchtext.data.Field(sequential=False)
    translator = Translator(model,
                            fields,
                            global_scorer=scorer,
                            out_file=out_file,
                            report_score=False,
                            copy_attn=model_opt.copy_attn,
                            logger=logger,
                            log_probs_out_file=None,
                            **kwargs)
    # make translation and save result
    all_scores, all_predictions = translator.translate(
        src_path='processed_data/meta-test/' + opt.meta_test_task +
        '/src-test.txt',
        tgt_path=None,
        src_dir=None,
        batch_size=opt.translate_batch_size,
        attn_debug=False)
    # dump predictions
    f = open('experiments/meta_test/' + opt.meta_test_task +
             '/test_predictions.csv',
             'w',
             encoding='utf-8')
    f.write('smiles,property\n')
    for n_best_mols in all_predictions:
        for mol in n_best_mols:
            f.write(mol.replace(' ', '') + ',0\n')
    f.close()
    # call chemprop to get scores
    test_path = '\"' + 'experiments/meta_test/' + opt.meta_test_task + '/test_predictions.csv' + '\"'
    checkpoint_path = '\"scorer_ckpts/' + opt.meta_test_task + '/model.pt' + '\"'
    preds_path = '\"' + 'experiments/meta_test/' + opt.meta_test_task + '/test_scores.csv' + '\"'

    # In case all mols are invalid (chemprop will then produce no output
    # file), the predictions are copied into the score file.
    cmd = 'cp {} {}'.format(test_path, preds_path)
    result = os.popen(cmd)
    result.close()

    cmd = 'python chemprop/predict.py --test_path {} --checkpoint_path {} --preds_path {} --num_workers 0'.format(
        test_path, checkpoint_path, preds_path)
    scorer_result = os.popen(cmd)
    # logger.info('{}'.format('\n'.join(scorer_result.readlines())))
    scorer_result.close()
    # read score file and get score

    score = read_score_csv('experiments/meta_test/' + opt.meta_test_task +
                           '/test_scores.csv')

    assert len(score) % opt.beam_size == 0

    # report if it is the best on dev
    test_metrics = calculate_metrics(opt.meta_test_task, 'test', 'test', score)
    logger.info('test metrics: ' + str(test_metrics))
    test_score = test_metrics['success_rate']
    # logger.info('test success rate: {:.4f} by {}'.format(test_score, ckpt_path))
    logger.info('test success rate: {:.4f} by {}'.format(
        test_score, 'no pretrain'))
Example #10
def main(opt, device_id):
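    """Training entry point with an optional comparable-data phase: after
    (or instead of) standard training, extract parallel data from comparable
    corpora with a ``Comparable`` helper and keep training on it for
    ``opt.comp_epochs`` epochs."""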
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful when re-training a model after
        # adding a new option that is not set in the checkpoint.
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = load_fields(first_dataset, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    if not opt.no_base:
        trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                      opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()

    if opt.comparable:
        logger.info('')
        logger.info('Beginning comparable data extraction and training.')

        # 1. Initialize Comparable object
        comp = Comparable(model, trainer, fields, logger, opt)

        # 2. Infer similarity threshold from training data

        for epoch in range(opt.comp_epochs):
            # 3. Update threshold if dynamic
            if opt.threshold_dynamics != 'static' and epoch != 0:
                comp.update_threshold(opt.threshold_dynamics,
                                      opt.infer_threshold)

            # 4. Extract parallel data and train
            #if opt.match_articles:
            #    comparable_data = comp.match_articles(opt.match_articles)
            #    train_stats = comp.extract_and_train(comparable_data)
            #else:
            train_stats = comp.extract_and_train(opt.comparable_data)

            # 5. Validate on validation set
            if not opt.no_valid:
                valid_iter = build_dataset_iter(
                    lazily_load_dataset("valid", opt), fields, opt)
                valid_stats = comp.validate(valid_iter)

            # 6. Drop a checkpoint if needed
            comp.trainer.model_saver._save(epoch)
Example #11
def main(opt, device_id):
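    """Training entry point that can start from a regular checkpoint, an
    unconditional LM checkpoint, pre-trained GPT-2 TensorFlow weights, or a
    separately trained encoder, with selected multi-task options merged into
    the restored ``model_opt``."""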
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    load_str = opt.train_from if opt.train_from else opt.load_uncond_from
    if load_str:
        logger.info('Loading checkpoint from %s' % load_str)
        checkpoint = torch.load(load_str,
                                map_location=lambda storage, loc: storage)

        logger.info('Loading vocab from checkpoint at %s.' % load_str)
        vocab = checkpoint['vocab']

        if opt.train_from:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
            if opt.multi_task:
                # manually set up the learning rate
                model_opt.__setattr__("multi_task", opt.multi_task)
                model_opt.__setattr__("multi_task_lr", opt.multi_task_lr)
                model_opt.__setattr__("data", opt.data)
                model_opt.__setattr__("save_model", opt.save_model)
                model_opt.__setattr__("multi_task_finish", opt.multi_task_finish)
                model_opt.__setattr__("clf_task", opt.clf_task)
                model_opt.__setattr__("valid_steps", opt.valid_steps)
                model_opt.__setattr__("report_every", opt.report_every)
                model_opt.__setattr__("clf_task", opt.clf_task)

            if opt.multi_task_finish:
                model_opt.__setattr__("data", opt.query_data)
            if opt.only_query:
                model_opt.__setattr__("data", opt.query_data)

        else:
            model_opt = opt
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    if opt.clf_task:
        vocab = torch.load(opt.data + '.vocab.pt')
    if opt.gpt2_params_path is not None:
        import tensorflow as tf
        import numpy as np
        # Taken from pytorch-pretrained-BERT:
        # Load weights from TF model
        logger.info("Loading TF GPT weights...")
        init_vars = tf.train.list_variables(opt.gpt2_params_path)
        names = []
        arrays = []
        for name, shape in init_vars:
            if opt.gpt_emb_only and ('wpe' not in name and 'wte' not in name):
                continue
            if opt.gpt_wpe_only and 'wpe' not in name:
                continue
            #print("Loading TF weight {} with shape {}".format(name, shape))
            array = tf.train.load_variable(opt.gpt2_params_path, name)
            names.append(name)
            arrays.append(array.squeeze())
        logger.info("Done.")
        
        if checkpoint is None:
            checkpoint = {'gpt2_params': zip(names, arrays)}
        else:
            checkpoint['gpt2_params'] = zip(names, arrays)

    if opt.encoder_from is not None:
        logger.info('Loading checkpoint with encoder from %s' % opt.encoder_from)
        enc_checkpoint = torch.load(opt.encoder_from,
                                map_location=lambda storage, loc: storage)
        enc_vocab = enc_checkpoint['vocab']
        if vocab['src'].base_field.vocab != enc_vocab['src'].base_field.vocab:
            raise ValueError('encoder vocab and model vocab need to be '
                             'identical when using a pretrained encoder')
        if checkpoint is None:
            checkpoint = {}
        checkpoint['enc_model'] = enc_checkpoint['model']
        
    
    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    sides = ['tgt'] if opt.model_type == 'none' else ['src', 'tgt']
    for side in sides:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec, lm_dec = _tally_parameters(model)
    n_params_t, enc_t, dec_t, lm_dec_t = _tally_parameters(model, only_trainable=True)
    logger.info('encoder: %d (%d)' % (enc, enc_t))
    logger.info('decoder: %d (%d)' % (dec, dec_t))
    if opt.simple_fusion:
        logger.info('lm decoder: %d (%d)' % (lm_dec, lm_dec_t))

    logger.info('* number of parameters: %d (%d)' % (n_params, n_params_t))
    _check_save_model_path(opt)

    if not opt.train_from and opt.gpt2_params_path is not None:
        checkpoint = None

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #12
def training_main(opt):
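    """Epoch-based training with model introspection: the model is
    registered with a vivisect ``probe`` monitor before training, and the
    iteration counter and mode are updated by the train/valid iterator
    factories."""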
    opt = training_opt_postprocessing(opt)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        print('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        # I don't like reassigning attributes of opt: it's not clear.
        opt.start_epoch = checkpoint['epoch'] + 1
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    _collect_report_features(fields)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    _tally_parameters(model)
    _check_save_model_path(opt)

    model._vivisect = {
        "iteration": 0,
        "model_name": "OpenNMT Model",
        "framework": "pytorch",
        "mode": "train"
    }
    probe(model, "localhost", 8082)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        model._vivisect["iteration"] += 1
        model._vivisect["mode"] = "train"
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        model._vivisect["mode"] = "dev"
        return build_dataset_iter(lazily_load_dataset("valid", opt), fields,
                                  opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.start_epoch, opt.epochs)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #13
def main(opt, device_id):
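    """Training entry point for single or multi-component models
    (nmt/tts/asr/vocoder): restores the validation loss and test score from
    the checkpoint, builds an adversarial optimizer and loss, and trains
    directly on the loaded train/valid/test splits."""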
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
    else:
        checkpoint = None
        model_opt = opt
    vocab = torch.load(opt.data + '/vocab.pt')

    if opt.train_from and opt.reset_optim not in {'score', 'all'}:
        valid_loss = checkpoint['valid_loss']
        test_score = checkpoint['test_score']
    else:
        valid_loss = 100000
        test_score = -1

    fields = vocab
    # Report src and tgt vocab sizes, including for features
    for idx in fields.keys():
        logger.info(' * %s feat size = %d' % (idx, len(fields[idx])))

    # Detect device
    gpu = use_gpu(opt)
    if gpu:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    # Build model.
    if opt.model_type in {'single', 'nmt', 'tts', 'asr', 'vocoder'}:
        model = build_model(model_opt, opt, fields, device, checkpoint)
        n_params, enc, dec = _tally_parameters(model)
        logger.info('encoder: %d' % enc)
        logger.info('decoder: %d' % dec)
        logger.info('* number of parameters: %d' % n_params)
    else:
        model = build_multi_model(model_opt, opt, fields, device, checkpoint)

    #_check_save_model_path(opt)
    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)
    #if opt.adversarial:
    optim_ad, loss_ad = adversarial(opt.adversarial, model, model_opt, vocab)
    trainer = build_trainer(fields,
                            opt,
                            device_id,
                            model,
                            optim,
                            model_saver=model_saver,
                            optim_ad=optim_ad,
                            loss_ad=loss_ad)

    train_set, valid_set, test_set = load_dataset(opt.data)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(train_set,
                  opt.batch_size,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_set=valid_set,
                  valid_batch_size=min(opt.valid_batch_size, opt.batch_size),
                  valid_steps=opt.valid_steps,
                  test_set=test_set,
                  valid_loss=valid_loss,
                  test_score=test_score)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #14
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
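        # Multi-encoder setup: load one vocabulary file per encoder.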
        vocab = []
        for i in range(int(opt.encoder_num)):
            vocab.append(torch.load(opt.data + '.' + str(i) + '.vocab.pt'))

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab[0]):
        fields = [
            load_old_vocab(vocab[i],
                           opt.model_type,
                           dynamic_dict=opt.copy_attn)
            for i in range(int(opt.encoder_num))
        ]
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[opt.encoder_id][side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields[opt.encoder_id],
                            optim,
                            model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields[opt.encoder_id], opt)
    valid_iter = build_dataset_iter("valid",
                                    fields[opt.encoder_id],
                                    opt,
                                    is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    logger.info("Training the encoder: %d" % opt.encoder_id)
    logger.info("Training the generator: %d " % opt.generator_id)
    trainer.train(train_iter,
                  train_steps,
                  train_interval_steps=opt.train_interval_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps,
                  encoder_id=opt.encoder_id,
                  generator_id=opt.generator_id)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #15
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)

    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    first_dataset = pickle.load(open('processed_data/all-train/train.pt',
                                     'rb'))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)
    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)

    optim = build_optim(model, opt, checkpoint)  # opt.train_from == ''
    # Build model saver
    if not os.path.exists('experiments/all_train'):
        os.mkdir('experiments/all_train')
    model_saver = build_model_saver(model_opt, opt.save_model, opt, model,
                                    fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            "text",
                            model_saver=model_saver)

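    # Wrap a single pickled dataset in a one-shot generator so it matches the
    # lazy-loading interface that build_dataset_iter expects.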
    def _lazy_dataset_loader(pt_file):
        # dataset = torch.load(pt_file)
        def dataset_loader(pt_file):
            with open(pt_file, 'rb') as f:
                dataset = pickle.load(f)
            # logger.info('Loading task from <{}>, number of examples: {}'.format(pt_file, len(dataset)))
            return dataset

        yield dataset_loader(pt_file)

    train_iter = list(
        build_dataset_iter(
            _lazy_dataset_loader('processed_data/all-train/train.pt'), fields,
            opt))

    trainer.train(train_iter, opt.train_epochs)
Example #16
    def __init__(self, model_dir):

        # Model dir
        self._model_dir = os.path.abspath(model_dir)
        if not os.path.isdir(self._model_dir):
            msg = f"{model_dir} doesn't exists'"
            raise ValueError(msg)

        # Extended model
        self._extended_model = ExtendedModel(model_dir)

        # Config
        self._config = self._extended_model.config

        # Options
        self._opts = self._config.opts

        # Get the model options
        model_path = self._opts.models[0]
        checkpoint = torch.load(
            model_path, map_location=lambda storage, loc: storage
        )
        self._model_opts = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(self._model_opts)
        ArgumentParser.validate_model_opts(self._model_opts)

        # Extract vocabulary
        vocab = checkpoint["vocab"]
        if inputters.old_style_vocab(vocab):
            self._fields = inputters.load_old_vocab(
                vocab, "text", dynamic_dict=False
            )
        else:
            self._fields = vocab

        # Train_steps
        self._train_steps = self._model_opts.train_steps

        # Build the OpenNMT model
        self._opennmt_model = build_base_model(
            self._model_opts,
            self._fields,
            use_gpu(self._opts),
            checkpoint,
            self._opts.gpu,
        )

        # Translator: decoding options, with fallbacks for configurations
        # that do not define them
        min_length = getattr(self._opts, "min_length", 0)
        max_length = getattr(self._opts, "max_length", 100)
        beam_size = getattr(self._opts, "beam_size", 5)
        replace_unk = getattr(self._opts, "replace_unk", 0)

        self._translator = Translator(
            self._opennmt_model,
            self._fields,
            TextDataReader(),
            TextDataReader(),
            gpu=self._opts.gpu,
            min_length=min_length,
            max_length=max_length,
            beam_size=beam_size,
            replace_unk=replace_unk,
            copy_attn=self._model_opts.copy_attn,
            global_scorer=GNMTGlobalScorer(0.0, -0.0, "none", "none"),
            seed=self.SEED,
        )

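        # If online learning is enabled, assemble a minimal optimizer and
        # trainer from stub option objects so the model can keep being
        # updated after deployment.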
        online_learning = self._config.online_learning
        if online_learning:
            # Optim
            optimizer_opt = type("", (), {})()
            optimizer_opt.optim = "sgd"
            optimizer_opt.learning_rate = self._opts.learning_rate
            optimizer_opt.train_from = ""
            optimizer_opt.adam_beta1 = 0
            optimizer_opt.adam_beta2 = 0
            optimizer_opt.model_dtype = "fp32"
            optimizer_opt.decay_method = "none"
            optimizer_opt.start_decay_steps = 100000
            optimizer_opt.learning_rate_decay = 1.0
            optimizer_opt.decay_steps = 100000
            optimizer_opt.max_grad_norm = 5
            self._optim = Optimizer.from_opt(
                self._opennmt_model, optimizer_opt, checkpoint=None
            )

            trainer_opt = type("", (), {})()
            trainer_opt.lambda_coverage = 0.0
            trainer_opt.copy_attn = False
            trainer_opt.label_smoothing = 0.0
            trainer_opt.truncated_decoder = 0
            trainer_opt.model_dtype = "fp32"
            trainer_opt.max_generator_batches = 32
            trainer_opt.normalization = "sents"
            trainer_opt.accum_count = [1]
            trainer_opt.accum_steps = [0]
            trainer_opt.world_size = 1
            trainer_opt.average_decay = 0
            trainer_opt.average_every = 1
            trainer_opt.dropout = 0
            trainer_opt.dropout_steps = (0,)
            trainer_opt.gpu_verbose_level = 0
            trainer_opt.early_stopping = 0
            trainer_opt.early_stopping_criteria = (None,)
            trainer_opt.tensorboard = False
            trainer_opt.report_every = 50
            trainer_opt.gpu_ranks = []
            if self._opts.gpu != -1:
                trainer_opt.gpu_ranks = [self._opts.gpu]

            self._trainer = build_trainer(
                trainer_opt,
                self._opts.gpu,
                self._opennmt_model,
                self._fields,
                self._optim,
            )
        else:
            self._trainer = None
Example #17
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt, True),
                                  fields, opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

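    # Build one held-out iterator per monitored src/tgt pair, keyed by the
    # source file's stem (with any "_src" suffix stripped).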
    def monitor_iter_fct():
        monitor_data = dict()
        for src, tgt in zip(opt.monitor_src, opt.monitor_tgt):
            fname = src.split("/" if "/" in src else "\\")[-1].split(
                ".")[0].replace("_src", "")
            monitor_data[fname] = build_dataset_iter(lazily_load_dataset(
                "monitor", opt, fname),
                                                     fields,
                                                     opt,
                                                     is_train=False)
        return monitor_data

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, monitor_iter_fct,
                  opt.train_steps, opt.valid_steps, opt.monitor_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #18
def main(opt, device_id):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(
            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    train_iter = build_dataset_iter("train", fields, opt)
    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #19
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    if opt.train_from and opt.reset_optim != 'all':
        logger.info('* resuming the optimizer state from a checkpoint '
                    'is not handled here yet')
    else:
        # warmup_steps and rnn_size are parameters for Noam decay (transformer):
        #    https://arxiv.org/pdf/1706.03762.pdf (Section 3)
        decay_method = opt.decay_method if opt.decay_method else "standard"
        logger.info(
            '* Opt: %s (rate %.5f, maxgnorm %.1f, %s decay, '
            'decay_rate %.1f, start_decay_at %d, decay_every %d, '
            'ab1 %.5f, ab2 %.5f, adagradaccum %.1f, '
            'warmupsteps %d, hiddensize %d)' %
            (opt.optim, opt.learning_rate, opt.max_grad_norm, decay_method,
             opt.learning_rate_decay, opt.start_decay_steps, opt.decay_steps,
             opt.adam_beta1, opt.adam_beta2, opt.adagrad_accumulator_init,
             opt.warmup_steps, opt.rnn_size))

    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    logger.info('* model_saver built, using it to build the trainer')

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    #---------------------------------------------------------------------------
    # 1. lazily_load_dataset = for pt in pts: yield torch.load(pt)
    # 2. build_dataset_iter  = return DatasetLazyIter (train_iter_fct)
    # 3. train_iter_fct()    = iterator over torchtext.data.batch.Batches
    #---------------------------------------------------------------------------
    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Example #20
def main(opt, device_id, batch_queue=None, semaphore=None):
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        if 'opt' in checkpoint:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
        else:
            model_opt = opt

        if 'vocab' in checkpoint:
            logger.info('Loading vocab from checkpoint at %s.', opt.train_from)
            vocab = checkpoint['vocab']
        else:
            vocab = torch.load(opt.data + '.vocab.pt')
    else:
        checkpoint = None
        model_opt = opt
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab

    # Report src and tgt vocab sizes, including for features
    for side in ['src', 'tgt']:
        f = fields[side]
        try:
            f_iter = iter(f)
        except TypeError:
            f_iter = [(side, f)]
        for sn, sf in f_iter:
            if sf.use_vocab:
                logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            model_saver=model_saver)

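    # Batches come either from preprocessed shards on disk (single-process
    # case) or from a producer process feeding a shared batch_queue.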
    if batch_queue is None:
        if len(opt.data_ids) > 1:
            train_shards = []
            for train_id in opt.data_ids:
                shard_base = "train_" + train_id
                train_shards.append(shard_base)
            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
        else:
            if opt.data_ids[0] is not None:
                shard_base = "train_" + opt.data_ids[0]
            else:
                shard_base = "train"
            train_iter = build_dataset_iter(shard_base, fields, opt)

    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                yield batch

        train_iter = _train_iter()
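        # For reference, a producer process would pair with this consumer
        # roughly as follows (hypothetical sketch; the names mirror the
        # batch_queue/semaphore arguments above):
        #
        #     for batch in produce_batches():
        #         semaphore.acquire()    # wait until the trainer frees a slot
        #         batch_queue.put(batch)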

    valid_iter = build_dataset_iter("valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(train_iter,
                  train_steps,
                  save_checkpoint_steps=opt.save_checkpoint_steps,
                  valid_iter=valid_iter,
                  valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
Example #21
def main(opt, device_id):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)


    init_logger(opt.log_file)
    assert len(opt.accum_count) == len(opt.accum_steps), \
        'Number of accum_count values must match number of accum_steps'
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        # Load the model parameters onto the CPU
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
    else:
        checkpoint = None
        model_opt = opt
        
    # Build Tokenizer


################################
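    # NOTE: the original example elides the vocab/fields construction here.
    # The lines below are a sketch that mirrors the other examples in this
    # collection, so that the later references to ``fields`` resolve.
    vocab = torch.load(opt.data + '.vocab.pt')
    if old_style_vocab(vocab):
        fields = load_old_vocab(vocab,
                                opt.model_type,
                                dynamic_dict=opt.copy_attn)
    else:
        fields = vocab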
    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    # Report the number of parameters
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    # Check the model save path and create it if it does not exist
    _check_save_model_path(opt)

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

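    # With several data_ids, interleave the corresponding shards through
    # MultipleDatasetIterator; otherwise build a single train iterator.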
    train_iterables = []
    if len(opt.data_ids) > 1:
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            iterable = build_dataset_iter(shard_base, fields, opt, multi=True)
            train_iterables.append(iterable)
        train_iter = MultipleDatasetIterator(train_iterables, device_id, opt)
    else:
        train_iter = build_dataset_iter("train", fields, opt)

    valid_iter = build_dataset_iter(
        "valid", fields, opt, is_train=False)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0
    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()

########################
    if opt.use_gpu:
        if opt.gpus == []:
            pass
        else:
            pass
    else:
        # case only CPU
        single_main(opt, 0)