Code example #1
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields,
                            optim, data_type, model_saver=model_saver)

    def train_iter_fct(): return build_dataset_iter(
        lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct(): return build_dataset_iter(
        lazily_load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
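Most examples on this page funnel training data through the same two helpers: lazily_load_dataset yields the preprocessed dataset shards one at a time, and build_dataset_iter wraps them in a batch iterator. As a minimal sketch of the lazy loader, assuming OpenNMT-py's convention that preprocessing saves shards as <data>.<corpus_type>.[0-9]*.pt (treat the exact naming pattern as an assumption):

import glob

import torch

def lazily_load_dataset_sketch(corpus_type, opt):
    # Illustrative re-implementation only; the real helper ships with
    # OpenNMT-py. The shard naming below is an assumption.
    assert corpus_type in ["train", "valid"]
    pts = sorted(glob.glob(opt.data + '.' + corpus_type + '.[0-9]*.pt'))
    if pts:
        for pt in pts:
            # Load one shard at a time so the whole corpus never sits in memory.
            yield torch.load(pt)
    else:
        # Unsharded data: a single <data>.<corpus_type>.pt file.
        yield torch.load(opt.data + '.' + corpus_type + '.pt')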
Code example #2
def build_nmt_model(opt):
    # opt = training_opt_postprocessing(opt)
    # init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.model_path_G:
        logger.info('Loading checkpoint from %s' % opt.model_path_G)
        checkpoint = torch.load(opt.model_path_G, map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        raise AssertionError("no nmt model")

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    # NOTE: additional_device0 is presumably a device identifier defined
    # elsewhere in the source project; it is not shown in this snippet.
    model = build_model(model_opt, opt, fields, checkpoint, additional_device0)
    # model = build_model(model_opt, opt, fields, checkpoint, "cpu")

    return model
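A hedged usage sketch for the loader above: in practice opt is the full parsed OpenNMT-py option namespace; only the checkpoint attribute the example reads explicitly is shown here, and the path itself is hypothetical.

# Hypothetical call site; `opt` carries many more attributes in practice.
opt.model_path_G = 'models/nmt_step_10000.pt'  # hypothetical checkpoint path
model = build_nmt_model(opt)  # raises AssertionError when model_path_G is unset
model.eval()  # the loaded model is typically used for inference or fine-tuning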
Code example #3
File: train_single.py Project: vikibytes/OpenNMT-py
def main(opt):
    opt = training_opt_postprocessing(opt)

    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        print('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        # I don't like reassigning attributes of opt: it's not clear.
        opt.start_epoch = checkpoint['epoch'] + 1
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    _collect_report_features(fields)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    _tally_parameters(model)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt), fields,
                                  opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps, opt.save_checkpoint_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Code example #4
 def monitor_iter_fct():
     monitor_data = dict()
     for src, tgt in zip(opt.monitor_src, opt.monitor_tgt):
         fname = src.split("/" if "/" in src else "\\")[-1].split(
             ".")[0].replace("_src", "")
         monitor_data[fname] = build_dataset_iter(lazily_load_dataset(
             "monitor", opt, fname),
                                                  fields,
                                                  opt,
                                                  is_train=False)
     return monitor_data
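The fname key above strips the directory, the file extension, and a trailing _src marker from each monitor source path. For a hypothetical path:

# Hypothetical path illustrating the fname derivation used above.
src = "data/newstest2014_src.txt"
fname = src.split("/" if "/" in src else "\\")[-1].split(".")[0].replace("_src", "")
assert fname == "newstest2014"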
Code example #5
 def get_speech_iterator(self, name, lang1, lang2, is_train=True): # name: train or valid
     """
     Create a new iterator for a dataset.
     """
     key = ','.join([x for x in ['speech', name, lang1, lang2] if x is not None])
     logger.info("Creating new training %s iterator ..." % key)
     speech_direction = (lang1, lang2)
     iterator = build_dataset_iter(
         lazily_load_dataset(name, self.params.speech_dataset[speech_direction][0]),
         self.speech_fields[speech_direction], self.params, is_train)
     iterator = iter(iterator)
     self.iterators[key] = iterator
     return iterator
Code example #6
    def get_speech_iterator(self, data_type, lang1, lang2):
        """
        Create a new iterator for a dataset.
        """
        assert data_type in ['valid']
        speech_direction = (lang1, lang2)
        lang2_id = self.params.lang2id[lang2]

        iterator = build_dataset_iter(
            lazily_load_dataset(data_type, self.params.speech_dataset[speech_direction][0]),
            self.params.speech_fields[speech_direction], self.params, False)
        iterator = iter(iterator)
        for batch in iterator:
            bos_index = self.params.bos_index[lang2_id]
            batch.tgt[0] = bos_index
            yield batch
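The loop above overwrites the first time step of each target batch so a shared decoder starts from the BOS token of the requested output language. A minimal sketch of that override, assuming batch.tgt is a (tgt_len, batch_size) LongTensor (the shapes and ids below are made up for illustration):

import torch

tgt = torch.full((10, 4), 7, dtype=torch.long)  # hypothetical target batch
bos_index = 2  # hypothetical BOS id for lang2
tgt[0] = bos_index  # broadcasts the BOS id across the batch dimension
assert (tgt[0] == bos_index).all()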
Code example #7
File: train_single.py Project: fabiocuri/WALS
def main(opt, device_id):
    #TODO delete all these lines related to WALS features
    #begin
    SimulationLanguages = [opt.wals_src, opt.wals_tgt]

    print('Loading WALS features from databases...')

    cwd = os.getcwd()

    db = sqlite3.connect(cwd + '/onmt/WalsValues.db')
    cursor = db.cursor()
    cursor.execute('SELECT * FROM WalsValues')
    WalsValues = cursor.fetchall()

    db = sqlite3.connect(cwd + '/onmt/FeaturesList.db')
    cursor = db.cursor()
    cursor.execute('SELECT * FROM FeaturesList')
    FeaturesList = cursor.fetchall()

    db = sqlite3.connect(cwd + '/onmt/FTInfos.db')
    cursor = db.cursor()
    cursor.execute('SELECT * FROM FTInfos')
    FTInfos = cursor.fetchall()

    db = sqlite3.connect(cwd + '/onmt/FTList.db')
    cursor = db.cursor()
    cursor.execute('SELECT * FROM FTList')
    FTList = cursor.fetchall()

    ListLanguages = []
    for i in WalsValues:
        ListLanguages.append(i[0])

    FeatureTypes = []
    for i in FTList:
        FeatureTypes.append((i[0], i[1].split(',')))

    FeatureNames = []
    for i in FeatureTypes:
        FeatureNames += i[1]

    FeatureTypesNames = []
    for i in FeatureTypes:
        FeatureTypesNames.append(i[0])

    FeatureValues, FeatureTensors = get_feat_values(SimulationLanguages, WalsValues, FeaturesList, ListLanguages, FeatureTypes, FeatureNames) 

    print('WALS databases loaded!')
    #end
    #TODO: load wals features from command-line (wals.npz)

    # FeatureValues: defaultdict with feature values, per language.
    # FeatureTensors: tensor of possible outputs, per feature.

    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d'
                    % (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d'
                    % (j, len(fields[feat].vocab)))

    # Build model.
    #TODO: remove all parameters related to WALS features: FeatureValues, FeatureTensors, FeatureTypes, FeaturesList, FeatureNames, FTInfos, FeatureTypesNames, SimulationLanguages
    #TODO: include four parameter related to WALS features: the four numpy arrays separately
    model = build_model(model_opt, opt, fields, checkpoint, FeatureValues, FeatureTensors, FeatureTypes, FeaturesList, FeatureNames, FTInfos, FeatureTypesNames, SimulationLanguages)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt, device_id, model, fields,
                            optim, data_type, model_saver=model_saver)

    def train_iter_fct(): return build_dataset_iter(
        lazily_load_dataset("train", opt), fields, opt)

    def valid_iter_fct(): return build_dataset_iter(
        lazily_load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Code example #8
File: run_opennmt.py Project: jniehues-kit/vivisect
 def train_iter_fct():
     model._vivisect["iteration"] += 1
     model._vivisect["mode"] = "train"
     return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                               opt)
Code example #9
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful when re-training a model after
        # adding a new option (not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = load_fields(first_dataset, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    if not opt.no_base:
        trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                      opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()

    if opt.comparable:
        logger.info('')
        logger.info('Beginning comparable data extraction and training.')

        # 1. Initialize Comparable object
        comp = Comparable(model, trainer, fields, logger, opt)

        # 2. Infer similarity threshold from training data

        for epoch in range(opt.comp_epochs):
            # 3. Update threshold if dynamic
            if opt.threshold_dynamics != 'static' and epoch != 0:
                comp.update_threshold(opt.threshold_dynamics,
                                      opt.infer_threshold)

            # 4. Extract parallel data and train
            #if opt.match_articles:
            #    comparable_data = comp.match_articles(opt.match_articles)
            #    train_stats = comp.extract_and_train(comparable_data)
            #else:
            train_stats = comp.extract_and_train(opt.comparable_data)

            # 5. Validate on validation set
            if not opt.no_valid:
                valid_iter = build_dataset_iter(
                    lazily_load_dataset("valid", opt), fields, opt)
                valid_stats = comp.validate(valid_iter)

            # 6. Drop a checkpoint if needed
            comp.trainer.model_saver._save(epoch)
Code example #10
    def valid_iter_fct(): return build_dataset_iter(
        lazily_load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
Code example #11
File: trainslate.py Project: lvapeab/OpenNMT-py
def main(opt, device_id):
    # opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        raise Exception('You need to load a model')

    logger.info('Loading data from %s' % opt.data)
    dataset = next(lazily_load_dataset("train", opt))
    data_type = dataset.data_type
    logger.info('Data type %s' % data_type)

    # Load fields generated from preprocess phase.
    fields = _load_fields(dataset, data_type, opt, checkpoint)
    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    dataset_iter = build_dataset_iter(lazily_load_dataset("train", opt),
                                      fields, opt)
    out_file = codecs.open(opt.output, 'w+', 'utf-8')
    scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                             opt.coverage_penalty,
                                             opt.length_penalty)

    translation_builder = TranslationBuilder(dataset,
                                             fields,
                                             n_best=opt.n_best,
                                             replace_unk=opt.replace_unk,
                                             has_tgt=False)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    translator = Translator(trainer.model,
                            fields,
                            opt.beam_size,
                            global_scorer=scorer,
                            out_file=out_file,
                            report_score=False,
                            copy_attn=model_opt.copy_attn,
                            logger=logger)

    for i, batch in enumerate(dataset_iter):
        unprocessed_translations = translator.translate_batch(batch, dataset)
        translations = translation_builder.from_batch(unprocessed_translations)
        print "Translations: ", ' '.join(translations[0].pred_sents[0])
        trainer.train_from_data(batch, train_steps=1)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Code example #12
File: train_single.py Project: sarubi/AIW_MTL
 def meta_valid_iter_fct(task_id, is_log=False):
     return build_dataset_iter(
         lazily_load_dataset("meta_valid", opt, task_id=task_id, is_log=is_log), fields_list[task_id], opt)
Code example #13
    def __init__(self, encoder, decoder, discriminator, lm, data, params):
        """
        Initialize trainer.
        """
        super().__init__(device_ids=tuple(range(params.otf_num_processes)))
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator
        self.lm = lm
        self.data = data
        self.params = params
        self.train_iter = None
        self.speech_fields = dict()
        params.speech_fields = self.speech_fields
        # training variables
        self.best_metrics = {metric: -1e12 for metric in self.VALIDATION_METRICS}
        self.epoch = 0
        self.n_total_iter = 0
        self.freeze_enc_emb = self.params.freeze_enc_emb
        self.freeze_dec_emb = self.params.freeze_dec_emb


        self.reload_model()

        # initialization for on-the-fly generation/training
        if len(params.pivo_directions) > 0:
            self.otf_start_multiprocessing()

        # define encoder parameters (the ones shared with the
        # decoder are optimized by the decoder optimizer)
        enc_params = list(encoder.parameters())
        for i in range(params.n_langs):
            if params.share_lang_emb and i >= 0:
                break
            #assert enc_params[i].size() == (params.n_words[i], params.emb_dim)
        if self.params.share_encdec_emb:
            to_ignore = 1 if params.share_lang_emb else params.n_langs
            enc_params = enc_params[to_ignore:]

        # optimizers
        if params.dec_optimizer == 'enc_optimizer':
            params.dec_optimizer = params.enc_optimizer
        self.enc_optimizer = get_optimizer(enc_params, params.enc_optimizer) if len(enc_params) > 0 else None
        self.dec_optimizer = get_optimizer(decoder.parameters(), params.dec_optimizer)
        self.dis_optimizer = get_optimizer(discriminator.parameters(), params.dis_optimizer) if discriminator is not None else None
        self.lm_optimizer = get_optimizer(lm.parameters(), params.enc_optimizer) if lm is not None else None

        # models / optimizers
        self.model_opt = {
            'enc': (self.encoder, self.enc_optimizer),
            'dec': (self.decoder, self.dec_optimizer),
            'dis': (self.discriminator, self.dis_optimizer),
            'lm': (self.lm, self.lm_optimizer),
        }

        # define validation metrics / stopping criterion used for early stopping
        logger.info("Stopping criterion: %s" % params.stopping_criterion)
        if params.stopping_criterion == '':
            for lang1, lang2 in self.data['para'].keys():
                for data_type in ['valid', 'test']:
                    self.VALIDATION_METRICS.append('bleu_%s_%s_%s' % (lang1, lang2, data_type))
                    self.VALIDATION_METRICS.append('speech_bleu_%s_%s_%s' % (lang1, lang2, data_type))
            for lang1, lang2, lang3 in self.params.pivo_directions:
                if lang1 == lang3:
                    continue
                for data_type in ['valid', 'test']:
                    self.VALIDATION_METRICS.append('bleu_%s_%s_%s_%s' % (lang1, lang2, lang3, data_type))
            self.stopping_criterion = None
            self.best_stopping_criterion = None
        else:
            split = params.stopping_criterion.split(',')
            assert len(split) == 2 and split[1].isdigit()
            self.decrease_counts_max = int(split[1])
            self.decrease_counts = 0
            self.stopping_criterion = split[0]
            self.best_stopping_criterion = -1e12
            assert len(self.VALIDATION_METRICS) == 0
            self.VALIDATION_METRICS.append(self.stopping_criterion)

        # training statistics
        self.n_iter = 0
        self.n_sentences = 0
        self.stats = {
            'dis_costs': [],
            'processed_s': 0,
            'processed_w': 0,
        }
        for lang1, lang2 in params.speech_dataset.keys():
            self.stats['xe_costs_sp_%s_%s' % (lang1, lang2)] = []
        for speech_direction in params.speech_dataset.keys():
            # Peek the first dataset to determine the data_type.
            # (All datasets have the same data_type).
            first_dataset = next(lazily_load_dataset("train", params.speech_dataset[speech_direction][0]))
            data_type = first_dataset.data_type
            # Load fields generated from preprocess phase.
            fields = _load_fields_vocab(first_dataset, data_type, params.speech_vocabs[speech_direction[1]], None,
                                        'vocab')
            self.speech_fields[speech_direction] = fields

        for lang in params.mono_directions:
            self.stats['xe_costs_%s_%s' % (lang, lang)] = []
        for lang1, lang2 in params.para_directions:
            self.stats['xe_costs_%s_%s' % (lang1, lang2)] = []
        for lang1, lang2 in params.back_directions:
            self.stats['xe_costs_bt_%s_%s' % (lang1, lang2)] = []
        for lang1, lang2, lang3 in params.pivo_directions:
            self.stats['xe_costs_%s_%s_%s' % (lang1, lang2, lang3)] = []
        for lang in params.langs:
            self.stats['lme_costs_%s' % lang] = []
            self.stats['lmd_costs_%s' % lang] = []
            self.stats['lmer_costs_%s' % lang] = []
            self.stats['enc_norms_%s' % lang] = []
        self.last_time = time.time()
        if len(params.pivo_directions) > 0:
            self.gen_time = 0

        # data iterators
        self.iterators = {}

        # initialize BPE subwords
        self.init_bpe()

        # initialize lambda coefficients and their configurations
        parse_lambda_config(params, 'lambda_xe_mono')
        parse_lambda_config(params, 'lambda_xe_para')
        parse_lambda_config(params, 'lambda_xe_back')
        parse_lambda_config(params, 'lambda_xe_otfd')
        parse_lambda_config(params, 'lambda_xe_otfa')
        parse_lambda_config(params, 'lambda_dis')
        parse_lambda_config(params, 'lambda_lm')
        parse_lambda_config(params, 'lambda_speech')
Code example #14
    def extract_and_train(self, comparable_data_list):
        """ Manages the alternating extraction of parallel sentences and training.
        Args:
            comparable_data_list(str): path to list of mapped documents
        Returns:
            train_stats(:obj:'onmt.Trainer.Statistics'): epoch loss statistics
        """
        # Start first epoch
        self.trainer.next_epoch()
        self.accepted_file = \
                open('{}_accepted-e{}.txt'.format(self.comp_log, self.trainer.cur_epoch), 'w+', encoding='utf8')
        self.status_file = '{}_status-e{}.txt'.format(self.comp_log,
                                                      self.trainer.cur_epoch)
        if self.write_dual:
            self.embed_file = '{}_accepted_embed-e{}.txt'.format(
                self.comp_log, self.trainer.cur_epoch)
            self.hidden_file = '{}_accepted_hidden-e{}.txt'.format(
                self.comp_log, self.trainer.cur_epoch)
        epoch_similarities = []
        epoch_scores = []
        counter = 0
        src_sents = []
        tgt_sents = []
        src_embeds = []
        tgt_embeds = []

        # Go through comparable data
        with open(comparable_data_list, encoding='utf8') as c:
            comp_list = c.read().split('\n')
            num_articles = len(comp_list)
            cur_article = 0

            for article_pair in comp_list:
                cur_article += 1

                # Update status
                with open(self.status_file, 'a', encoding='utf8') as sf:
                    sf.write('{} / {}\n'.format(cur_article, num_articles))

                articles = article_pair.split('\t')
                # Discard misaligned documents
                if len(articles) != 2:
                    continue

                # Prepare iterator objects for current src/tgt document
                src_article = self._get_iterator(articles[0])
                tgt_article = self._get_iterator(articles[1])

                # Get sentence representations
                try:
                    if self.representations == 'embed-only':
                        # C_e
                        src_sents += self.get_article_coves(src_article,
                                                            'embed',
                                                            fast=self.fast)
                        tgt_sents += self.get_article_coves(tgt_article,
                                                            'embed',
                                                            fast=self.fast)
                    else:
                        # C_h
                        src_sents += self.get_article_coves(src_article,
                                                            fast=self.fast)
                        tgt_sents += self.get_article_coves(tgt_article,
                                                            fast=self.fast)
                        # C_e
                        src_embeds += self.get_article_coves(src_article,
                                                             'embed',
                                                             fast=self.fast)
                        tgt_embeds += self.get_article_coves(tgt_article,
                                                             'embed',
                                                             fast=self.fast)
                except Exception:
                    # Skip document pair in case of errors
                    src_sents = []
                    tgt_sents = []
                    src_embeds = []
                    tgt_embeds = []
                    continue

                # Ensure enough sentences are accumulated (otherwise scoring becomes unstable)
                if len(src_sents) < 15 or len(tgt_sents) < 15:
                    continue

                # Score src and tgt sentences
                src2tgt, tgt2src, similarities, scores = self.score_sents(
                    src_sents, tgt_sents)

                # Keep statistics
                epoch_similarities += similarities
                epoch_scores += scores
                src_sents = []
                tgt_sents = []

                # Filter candidates (primary filter)
                try:
                    if self.representations == 'dual':
                        # For dual representation systems, filter C_h...
                        candidates = self.filter_candidates(src2tgt,
                                                            tgt2src,
                                                            second=self.second)
                        # ...and C_e
                        comparison_pool, cand_embed = self.get_comparison_pool(
                            src_embeds, tgt_embeds)
                        src_embeds = []
                        tgt_embeds = []
                        if self.write_dual:
                            self.write_embed_only(candidates, cand_embed)
                    else:
                        # Filter C_e or C_h for single representation system
                        candidates = self.filter_candidates(src2tgt, tgt2src)
                        comparison_pool = None
                except Exception:
                    # Skip document pair in case of errors
                    print('Error occurred in: {}\n'.format(article_pair),
                          flush=True)
                    src_embeds = []
                    tgt_embeds = []
                    continue

                # Extract parallel samples (secondary filter)
                self.extract_parallel_sents(candidates, comparison_pool)

                # Check if enough parallel sentences were collected
                while self.similar_pairs.contains_batch():
                    # Get a batch of extracted parallel sentences and train
                    try:
                        training_batch = self.similar_pairs.yield_batch()
                    except Exception:
                        print('Error creating batch. Continuing...',
                              flush=True)
                        continue

                    # Statistics
                    train_stats = self.trainer.train(training_batch)
                    self.trainstep += 1

                    # Validate
                    if self.trainstep % self.valid_steps == 0:
                        if not self.no_valid:
                            valid_iter = build_dataset_iter(
                                lazily_load_dataset('valid', self.opt),
                                self.fields, self.opt)
                            valid_stats = self.validate(valid_iter)

                    # Create checkpoint
                    if self.trainstep % 5000 == 0:
                        self.trainer.model_saver._save(self.trainstep)

            # Train on remaining partial batch
            if len(self.similar_pairs.pairs) > 0:
                train_stats = self.trainer.train(
                    self.similar_pairs.yield_batch())
                self.trainstep += 1

        # Write epoch statistics
        self.write_similarities(epoch_similarities,
                                'e{}_comp'.format(self.trainer.cur_epoch))
        self.write_similarities(
            epoch_scores, 'e{}_comp_scores'.format(self.trainer.cur_epoch))
        self.trainer.report_epoch()
        self.logger.info(
            'Accepted parallel sentences from comparable data: %d / %d' %
            (self.accepted, self.total))
        self.logger.info(
            'Acceptable parallel sentences from comparable data (out of limit): %d / %d'
            % (self.accepted_limit, self.total))
        self.logger.info('Declined sentences from comparable data: %d / %d' %
                         (self.declined, self.total))

        # Reset epoch statistics
        self.accepted = 0
        self.accepted_limit = 0
        self.declined = 0
        self.total = 0
        self.accepted_file.close()
        return train_stats
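The control flow of extract_and_train reduces to: buffer extracted sentence pairs per document pair, drain full batches into the trainer as soon as they are available, and keep the remainder for the end of the epoch. A skeletal sketch of that pattern, as a simplified stand-in for the contains_batch()/yield_batch() loop rather than the project's actual API:

def drain_full_batches(pairs_buffer, batch_size, train_one_batch):
    # Simplified stand-in for the contains_batch()/yield_batch() loop above.
    while len(pairs_buffer) >= batch_size:
        batch = pairs_buffer[:batch_size]
        pairs_buffer = pairs_buffer[batch_size:]
        train_one_batch(batch)
    return pairs_buffer  # partial batch, trained on once the epoch ends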
Code example #15
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    if opt.train_from and opt.reset_optim != 'all':
        logger.info('* loading optimizer state from a checkpoint is not handled here yet')
    else:
        # warmup_steps and rnn_size are parameters for Noam decay (transformer):
        #    https://arxiv.org/pdf/1706.03762.pdf (Section 3)
        decay_method = opt.decay_method if opt.decay_method else "standard"
        logger.info(
            '* Opt: %s (rate %.5f, maxgnorm %.1f, %s decay, '
            'decay_rate %.1f, start_decay_at %d, decay_every %d, '
            'ab1 %.5f, ab2 %.5f, adagradaccum %.1f, '
            'warmupsteps %d, hiddensize %d)' %
            (opt.optim, opt.learning_rate, opt.max_grad_norm, decay_method,
             opt.learning_rate_decay, opt.start_decay_steps, opt.decay_steps,
             opt.adam_beta1, opt.adam_beta2, opt.adagrad_accumulator_init,
             opt.warmup_steps, opt.rnn_size))

    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    logger.info('* model_saver built, using it to build trainer')

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    #---------------------------------------------------------------------------
    # 1. lazily_load_dataset = for pt in pts: yield torch.load(pt)
    # 2. build_dataset_iter  = return DatasetLazyIter (train_iter_fct)
    # 3. train_iter_fct()    = iterator over torchtext.data.batch.Batches
    #---------------------------------------------------------------------------
    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
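The numbered comment block in this example summarizes the whole data pipeline. A minimal sketch of the lazy iterator it describes (DatasetLazyIter is OpenNMT-py's class; the stand-in below only illustrates the shard-by-shard batching idea):

class DatasetLazyIterSketch(object):
    # Simplified stand-in for OpenNMT-py's DatasetLazyIter.
    def __init__(self, dataset_gen, make_batches):
        self.dataset_gen = dataset_gen    # e.g. lazily_load_dataset("train", opt)
        self.make_batches = make_batches  # callable: dataset -> iterable of batches

    def __iter__(self):
        # Walk the shards lazily; only one shard is in memory at a time.
        for dataset in self.dataset_gen:
            for batch in self.make_batches(dataset):
                yield batch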
Code example #16
File: train_single.py Project: sarubi/AIW_MTL
 def train_iter_fct(task_id):
     return build_dataset_iter(
         lazily_load_dataset("train", opt, task_id=task_id), fields_list[task_id], opt)
Code example #17
File: run_opennmt.py Project: hltcoe/vivisect
        train_dataset_files = build_save_dataset('train', fields, opt, logger)

        logger.info("Building & saving vocabulary...")
        build_save_vocab(train_dataset_files, fields, opt, logger)

        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, opt, logger)

        opt = train_args
        opt = training_opt_postprocessing(opt)

        model_opt = opt
        checkpoint = None
        # Peek the first dataset to determine the data_type.
        # (All datasets have the same data_type).
        first_dataset = next(lazily_load_dataset("train", opt))
        data_type = first_dataset.data_type

        # Load fields generated from preprocess phase.
        fields = _load_fields(first_dataset, data_type, opt, checkpoint)

        # Report src/tgt features.
        _collect_report_features(fields)

        # Build model.
        model = build_model(model_opt, opt, fields, checkpoint)
        remove(args.host, args.port, "OpenNMT")

        probe(
            args.name,
            model,
Code example #18
File: train_single.py Project: sarubi/AIW_MTL
 def valid_iter_fct(task_id):
     return build_dataset_iter(
         lazily_load_dataset("valid", opt, task_id=task_id), fields_list[task_id], opt)
Code example #19
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    out_file = None
    best_test_score, best_ckpt = -10000, None
    dummy_parser = argparse.ArgumentParser(description='all_dev.py')
    opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
    for i in range(0, opt.train_epochs, 10):
        ckpt_path = '{}_epoch_{}.pt'.format(opt.save_model, i + 1)
        logger.info('Loading checkpoint from %s' % ckpt_path)
        checkpoint = torch.load(ckpt_path,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
        fields = load_fields_from_vocab(checkpoint['vocab'], data_type="text")

        # Build model.
        model = build_model(model_opt, opt, fields, checkpoint)

        assert opt.train_from == ''  # do not load optimizer state
        optim = build_optim(model, opt, checkpoint)
        # Build model saver, no need to create task dir for dev
        if not os.path.exists('experiments/all_dev'):
            os.mkdir('experiments/all_dev')
            os.mkdir('experiments/all_dev/' + opt.meta_dev_task)
        elif not os.path.exists('experiments/all_dev/' + opt.meta_dev_task):
            os.mkdir('experiments/all_dev/' + opt.meta_dev_task)
        model_saver = build_model_saver(
            model_opt, 'experiments/all_dev/' + opt.meta_dev_task + '/model',
            opt, model, fields, optim)

        trainer = build_trainer(opt,
                                device_id,
                                model,
                                fields,
                                optim,
                                "text",
                                model_saver=model_saver)

        train_iter = list(
            build_dataset_iter(lazily_load_dataset("train", opt), fields, opt))
        # do training on trainset of meta-dev task
        trainer.train(train_iter, opt.inner_iterations)

        # do evaluation on devset of meta-dev task
        best_dev_score, best_model_path = -10000, None
        for model_path in os.listdir('experiments/all_dev/' +
                                     opt.meta_dev_task):
            if model_path.find('.pt') == -1:
                continue
            if out_file is None:
                out_file = codecs.open(opt.output, 'w+', 'utf-8')

            fields, model, model_opt = onmt.model_builder.load_test_model(
                opt,
                dummy_opt.__dict__,
                model_path='experiments/all_dev/' + opt.meta_dev_task + '/' +
                model_path)

            scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                                     opt.coverage_penalty,
                                                     opt.length_penalty)

            kwargs = {
                k: getattr(opt, k)
                for k in [
                    "beam_size", "n_best", "max_length", "min_length",
                    "stepwise_penalty", "block_ngram_repeat",
                    "ignore_when_blocking", "dump_beam", "report_bleu",
                    "replace_unk", "gpu", "verbose", "fast", "mask_from"
                ]
            }
            fields['graph'] = torchtext.data.Field(sequential=False)
            translator = Translator(model,
                                    fields,
                                    global_scorer=scorer,
                                    out_file=out_file,
                                    report_score=False,
                                    copy_attn=model_opt.copy_attn,
                                    logger=logger,
                                    log_probs_out_file=None,
                                    **kwargs)
            # make translation and save result
            all_scores, all_predictions = translator.translate(
                src_path='processed_data/meta-dev/' + opt.meta_dev_task +
                '/src-dev.txt',
                tgt_path=None,
                src_dir=None,
                batch_size=opt.translate_batch_size,
                attn_debug=False)
            # dump predictions
            f = open('experiments/all_dev/' + opt.meta_dev_task +
                     '/dev_predictions.csv',
                     'w',
                     encoding='utf-8')
            f.write('smiles,property\n')
            for n_best_mols in all_predictions:
                for mol in n_best_mols:
                    f.write(mol.replace(' ', '') + ',0\n')
            f.close()
            # call chemprop to get scores
            test_path = '\"' + 'experiments/all_dev/' + opt.meta_dev_task + '/dev_predictions.csv' + '\"'
            checkpoint_path = '\"' + 'scorer_ckpts/' + opt.meta_dev_task + '/model.pt' + '\"'
            preds_path = '\"' + 'experiments/all_dev/' + opt.meta_dev_task + '/dev_scores.csv' + '\"'

            # In case all mols are invalid (chemprop will produce no output file),
            # the predictions are copied into the score file.
            cmd = 'cp {} {}'.format(test_path, preds_path)
            result = os.popen(cmd)
            result.close()

            cmd = 'python chemprop/predict.py --test_path {} --checkpoint_path {} --preds_path {} --num_workers 0'.format(
                test_path, checkpoint_path, preds_path)
            scorer_result = os.popen(cmd)
            scorer_result.close()
            # read score file and get score
            score = read_score_csv('experiments/all_dev/' + opt.meta_dev_task +
                                   '/dev_scores.csv')

            assert len(score) % opt.beam_size == 0
            # dev_scores = []
            # for i in range(0, len(score), opt.beam_size):
            #     dev_scores.append(sum([x[1] for x in score[i:i+opt.beam_size]]) / opt.beam_size)

            # report dev score
            dev_metrics = calculate_metrics(opt.meta_dev_task, 'dev', 'dev',
                                            score)
            logger.info('dev metrics: ' + str(dev_metrics))
            dev_score = dev_metrics['success_rate']
            if dev_score > best_dev_score:
                logger.info('New best dev success rate: {:.4f} by {}'.format(
                    dev_score, model_path))
                best_model_path = model_path
                best_dev_score = dev_score
            else:
                logger.info('dev success rate: {:.4f} by {}'.format(
                    dev_score, model_path))

            del fields
            del model
            del model_opt
            del scorer
            del translator
            gc.collect()

        assert best_model_path is not None
        # do testing on testset of meta-dev task
        if out_file is None:
            out_file = codecs.open(opt.output, 'w+', 'utf-8')
        fields, model, model_opt = onmt.model_builder.load_test_model(
            opt,
            dummy_opt.__dict__,
            model_path='experiments/all_dev/' + opt.meta_dev_task + '/' +
            best_model_path)

        scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta,
                                                 opt.coverage_penalty,
                                                 opt.length_penalty)

        kwargs = {
            k: getattr(opt, k)
            for k in [
                "beam_size", "n_best", "max_length", "min_length",
                "stepwise_penalty", "block_ngram_repeat",
                "ignore_when_blocking", "dump_beam", "report_bleu",
                "replace_unk", "gpu", "verbose", "fast", "mask_from"
            ]
        }
        fields['graph'] = torchtext.data.Field(sequential=False)
        translator = Translator(model,
                                fields,
                                global_scorer=scorer,
                                out_file=out_file,
                                report_score=False,
                                copy_attn=model_opt.copy_attn,
                                logger=logger,
                                log_probs_out_file=None,
                                **kwargs)
        # make translation and save result
        all_scores, all_predictions = translator.translate(
            src_path='processed_data/meta-dev/' + opt.meta_dev_task +
            '/src-test.txt',
            tgt_path=None,
            src_dir=None,
            batch_size=opt.translate_batch_size,
            attn_debug=False)
        # dump predictions
        f = open('experiments/all_dev/' + opt.meta_dev_task +
                 '/test_predictions.csv',
                 'w',
                 encoding='utf-8')
        f.write('smiles,property\n')
        for n_best_mols in all_predictions:
            for mol in n_best_mols:
                f.write(mol.replace(' ', '') + ',0\n')
        f.close()
        # call chemprop to get scores
        test_path = '\"' + 'experiments/all_dev/' + opt.meta_dev_task + '/test_predictions.csv' + '\"'
        checkpoint_path = '\"' + 'scorer_ckpts/' + opt.meta_dev_task + '/model.pt' + '\"'
        preds_path = '\"' + 'experiments/all_dev/' + opt.meta_dev_task + '/test_scores.csv' + '\"'

        # In case all mols are invalid (chemprop will produce no output file),
        # the predictions are copied into the score file.
        cmd = 'cp {} {}'.format(test_path, preds_path)
        result = os.popen(cmd)
        result.close()

        cmd = 'python chemprop/predict.py --test_path {} --checkpoint_path {} --preds_path {} --num_workers 0'.format(
            test_path, checkpoint_path, preds_path)
        scorer_result = os.popen(cmd)
        # logger.info('{}'.format('\n'.join(scorer_result.readlines())))
        scorer_result.close()
        # read score file and get score

        score = read_score_csv('experiments/all_dev/' + opt.meta_dev_task +
                               '/test_scores.csv')

        assert len(score) % opt.beam_size == 0
        # test_scores = []
        # for i in range(0, len(score), opt.beam_size):
        #     test_scores.append(sum([x[1] for x in score[i:i+opt.beam_size]]) / opt.beam_size)

        # report if it is the best on test
        test_metrics = calculate_metrics(opt.meta_dev_task, 'dev', 'test',
                                         score)
        logger.info('test metrics: ' + str(test_metrics))
        test_score = test_metrics['success_rate']
        if test_score > best_test_score:
            best_ckpt = ckpt_path
            logger.info('New best test success rate: {:.4f} by {}'.format(
                test_score, ckpt_path))
            best_test_score = test_score
        else:
            logger.info('test success rate: {:.4f} by {}'.format(
                test_score, ckpt_path))

        del model_opt
        del fields
        del checkpoint
        del model
        del optim
        del model_saver
        del trainer
        gc.collect()
Code example #20
File: train_single.py Project: sarubi/AIW_MTL
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Gather information related to the training script and commit version
    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(os.path.dirname(script_path))
    logger.info('Train script dir: %s' % script_dir)
    git_commit = str(subprocess.check_output(['bash', script_dir + '/cluster_scripts/git_version.sh']))
    logger.info("Git Commit: %s" % git_commit[2:-3])
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        # TODO: load MTL model
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts from
        # the checkpoint. This is useful when re-training a model after
        # adding a new option (not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
    else:
        checkpoint = None
        model_opt = opt


    num_tasks = len(opt.data.split(','))
    opt.num_tasks = num_tasks

    checkpoint_list = []
    if opt.warm_model:
        base_name = opt.warm_model
        for task_id in range(num_tasks):
            chkpt_path = base_name.replace("X", str(task_id))
            if not os.path.isfile(chkpt_path):
                chkpt_path = base_name.replace("X", str(0))
            logger.info('Loading a checkpoint from %s' % chkpt_path)

            checkpoint = torch.load(chkpt_path,
                                    map_location=lambda storage, loc: storage)
            checkpoint_list.append(checkpoint)
    else:
        for task_id in range(num_tasks):
            checkpoint_list.append(None)

    fields_list = []
    data_type = None
    for task_id in range(num_tasks):
        # Peek the first dataset to determine the data_type.
        # (All datasets have the same data_type).
        first_dataset = next(lazily_load_dataset("train", opt, task_id=task_id))
        data_type = first_dataset.data_type

        # Load fields generated from preprocess phase.
        if opt.mtl_shared_vocab and task_id > 0:
            logger.info(' * vocabulary size. Same as the main task!')
            fields = fields_list[0]
        else:
            fields = load_fields(first_dataset, opt, checkpoint_list[task_id], task_id=task_id)

        # Report src/tgt features.

        src_features, tgt_features = _collect_report_features(fields)
        for j, feat in enumerate(src_features):
            logger.info(' * (Task %d) src feature %d size = %d'
                        % (task_id, j, len(fields[feat].vocab)))
        for j, feat in enumerate(tgt_features):
            logger.info(' * (Task %d) tgt feature %d size = %d'
                        % (task_id, j, len(fields[feat].vocab)))
        fields_list.append(fields)

    if opt.epochs > -1:
        total_num_batch = 0
        for task_id in range(num_tasks):
            train_iter = build_dataset_iter(lazily_load_dataset("train", opt, task_id=task_id), fields_list[task_id], opt)
            for i, batch in enumerate(train_iter):
                num_batch = i
            total_num_batch += num_batch
            if opt.mtl_schedule < 10:
                break
        num_batch = total_num_batch
        opt.train_steps = (num_batch * opt.epochs) + 1
        # Do the validation and save after each epoch
        opt.valid_steps = num_batch
        opt.save_checkpoint_steps = 1

    # logger.info(opt_to_string(opt))
    logger.info(opt)

    # Build model(s).
    models_list = []
    for task_id in range(num_tasks):
        if opt.mtl_fully_share and task_id > 0:
            # Since we only have one model, copy the pointer to the model for all
            models_list.append(models_list[0])
        else:

            main_model = models_list[0] if task_id > 0 else None
            model = build_model(model_opt, opt, fields_list[task_id], checkpoint_list[task_id], main_model=main_model, task_id=task_id)
            n_params, enc, dec = _tally_parameters(model)
            logger.info('(Task %d) encoder: %d' % (task_id, enc))
            logger.info('(Task %d) decoder: %d' % (task_id, dec))
            logger.info('* number of parameters: %d' % n_params)
            _check_save_model_path(opt)
            models_list.append(model)

    # combine parameters of different models and consider shared parameters just once.
    def combine_named_parameters(named_params_list):
        observed_params = []
        for model_named_params in named_params_list:
            for name, p in model_named_params:
                is_observed = False
                # Check whether we observed this parameter before
                for param in observed_params:
                    if p is param:
                        is_observed = True
                        break
                if not is_observed:
                    observed_params.append(p)
                    yield name, p

    # Build optimizer.
    optims_list = []
    all_models_params = []
    for task_id in range(num_tasks):
        if not opt.mtl_shared_optimizer:
            optim = build_optim(models_list[task_id], opt, checkpoint)
            optims_list.append(optim)
        else:
            all_models_params.append(models_list[task_id].named_parameters())

    # Extract the list of shared parameters among the models of all tasks.
    observed_params = []
    shared_params = []
    for task_id in range(num_tasks):
        for name, p in models_list[task_id].named_parameters():
            is_observed = False
            # Check whether we observed this parameter before
            for param in observed_params:
                if p is param:
                    shared_params.append(name)
                    is_observed = True
                    break
            if not is_observed:
                observed_params.append(p)
    opt.shared_params = shared_params
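    # An equivalent identity check with a set of id()s - a faster sketch, not
    # part of the original code:
    #   seen = set()
    #   shared_params = []
    #   for m in models_list:
    #       for name, p in m.named_parameters():
    #           if id(p) in seen:
    #               shared_params.append(name)
    #           else:
    #               seen.add(id(p))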

    if opt.mtl_shared_optimizer:
        optim = build_optim_mtl_params(combine_named_parameters(all_models_params), opt, checkpoint_list[0])
        optims_list.append(optim)

    # Build model saver
    model_saver = build_mtl_model_saver(model_opt, opt, models_list, fields_list, optims_list)

    trainer = build_trainer(opt, device_id, models_list, fields_list,
                            optims_list, data_type, model_saver=model_saver)

    def train_iter_fct(task_id):
        return build_dataset_iter(
            lazily_load_dataset("train", opt, task_id=task_id), fields_list[task_id], opt)

    def valid_iter_fct(task_id):
        return build_dataset_iter(
            lazily_load_dataset("valid", opt, task_id=task_id), fields_list[task_id], opt)

    def meta_valid_iter_fct(task_id, is_log=False):
        return build_dataset_iter(
            lazily_load_dataset("meta_valid", opt, task_id=task_id, is_log=is_log), fields_list[task_id], opt)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps, meta_valid_iter_fct=meta_valid_iter_fct)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Code example #21
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    logger.info('* batch_size: %d' % opt.batch_size)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def data_iter_fct(data_stage):
        """data_stage: train / valid"""
        pt_file = opt.data + '.' + data_stage + '.pt'
        logger.info('Loading {} dataset'.format(data_stage))
        dataset = torch.load(pt_file)
        logger.info('Loaded {} dataset'.format(data_stage))
        dataset.fields = fields
        is_train = data_stage == "train"
        batch_size = opt.batch_size if is_train else opt.valid_batch_size
        repeat = is_train
        if opt.gpuid != -1:
            device = "cuda"
        else:
            device = "cpu"

        def sort_key(ex):
            """ Sort using the total number of tokens in the example. """
            return ex.total_tokens

        return torchtext.data.Iterator(dataset=dataset,
                                       batch_size=batch_size,
                                       device=device,
                                       train=is_train,
                                       sort=False,
                                       sort_key=sort_key,
                                       repeat=repeat)
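    # Hypothetical usage: data_iter_fct("valid") loads <opt.data>.valid.pt,
    # batches it by opt.valid_batch_size, and sets repeat=False so a single
    # pass visits each example exactly once.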

    # Do training.
    trainer.train(data_iter_fct, opt.train_steps, opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Code example #22
def valid_iter_fct():
    return build_dataset_iter(lazily_load_dataset("valid", opt, logger),
                              fields, opt)
Code example #23
File: train_single.py Project: USE-sum/usesum
def train_iter_fct():
    return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                              opt)
Code example #24
def main(opt, device_id):
    opt = training_opt_postprocessing(opt, device_id)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)

        # Load default opt values, then overwrite them with the opts stored in
        # the checkpoint. This is useful for re-training a model after adding
        # a new option (one not set in the checkpoint).
        dummy_parser = configargparse.ArgumentParser()
        opts.model_opts(dummy_parser)
        default_opt = dummy_parser.parse_known_args([])[0]

        model_opt = default_opt
        model_opt.__dict__.update(checkpoint['opt'].__dict__)
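        # Illustration (hypothetical): a checkpoint saved before a new flag
        # existed still loads cleanly - the flag keeps its parser default,
        # while every option recorded in checkpoint['opt'] overrides it.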
    else:
        checkpoint = None
        model_opt = opt

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = load_fields(first_dataset, opt, checkpoint)

    # Report src/tgt features.

    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    n_params, enc, dec = _tally_parameters(model)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('* number of parameters: %d' % n_params)
    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            device_id,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt),
                                  fields,
                                  opt,
                                  is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Code example #25
File: run_opennmt.py Project: jniehues-kit/vivisect
def valid_iter_fct():
    model._vivisect["mode"] = "dev"
    return build_dataset_iter(lazily_load_dataset("valid", opt), fields,
                              opt)
Code example #26
    def valid_iter_fct(): return build_dataset_iter(
        lazily_load_dataset("valid", opt), fields, opt, is_train=False)

    # Do training.
    if len(opt.gpu_ranks):
Code example #27
def main(opt):
    opt = training_opt_postprocessing(opt)
    init_logger(opt.log_file)
    # Load checkpoint if we resume from a previous training.
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        model_opt = checkpoint['opt']
    else:
        checkpoint = None
        model_opt = opt

    if opt.load_pretrained_selector_from:
        logger.info('Loading selector checkpoint from %s' %
                    opt.load_pretrained_selector_from)
        sel_checkpoint = torch.load(opt.load_pretrained_selector_from,
                                    map_location=lambda storage, loc: storage)
    else:
        sel_checkpoint = None

    if opt.load_pretrained_s2s_generator_from:
        logger.info('Loading s2s generator checkpoint from %s' %
                    opt.load_pretrained_s2s_generator_from)
        s2s_gen_checkpoint = torch.load(
            opt.load_pretrained_s2s_generator_from,
            map_location=lambda storage, loc: storage)
    else:
        s2s_gen_checkpoint = None

    # Peek the first dataset to determine the data_type.
    # (All datasets have the same data_type).
    first_dataset = next(lazily_load_dataset("train", opt))
    data_type = first_dataset.data_type

    # Load fields generated from preprocess phase.
    fields = _load_fields(first_dataset, data_type, opt, checkpoint)

    # Report src/tgt features.
    src_features, tgt_features = _collect_report_features(fields)
    for j, feat in enumerate(src_features):
        logger.info(' * src feature %d size = %d' %
                    (j, len(fields[feat].vocab)))
    for j, feat in enumerate(tgt_features):
        logger.info(' * tgt feature %d size = %d' %
                    (j, len(fields[feat].vocab)))

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint, sel_checkpoint,
                        s2s_gen_checkpoint)

    # Fix the pretrained selector parameters if needed
    if model_opt.fix_sel_all:
        assert opt.load_pretrained_selector_from
        assert opt.sel_lambda == 0.0
        assert not model_opt.fix_sel_classifier
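        # The asserts above encode the precondition: a pre-loaded selector,
        # sel_lambda == 0 (presumably the selector loss weight), and no
        # partial classifier-only freezing at the same time.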
        for name, param in model.named_parameters():
            if 'selector' in name:
                param.requires_grad = False
    # only fix the classifier of the selector
    if model_opt.fix_sel_classifier:
        assert opt.load_pretrained_selector_from
        assert not model_opt.fix_sel_all
        for name, param in model.named_parameters():
            if 'selector' in name and 'rnn' not in name and 'embeddings' not in name:
                param.requires_grad = False
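    # Net effect (illustrative): only the selector's classifier layers are
    # frozen; its 'rnn' and 'embeddings' submodules remain trainable.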

    n_params, sel, enc, dec, gen = _my_tally_parameters(model)
    logger.info('selector: %d' % sel)
    logger.info('encoder: %d' % enc)
    logger.info('decoder: %d' % dec)
    logger.info('generator: %d' % gen)
    logger.info('* number of parameters: %d' % n_params)

    print_trainable_parameters(model)

    _check_save_model_path(opt)

    # Build optimizer.
    optim = build_optim(model, opt, checkpoint)

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(opt,
                            model,
                            fields,
                            optim,
                            data_type,
                            model_saver=model_saver)

    def train_iter_fct():
        return build_dataset_iter(lazily_load_dataset("train", opt), fields,
                                  opt)

    def valid_iter_fct():
        return build_dataset_iter(lazily_load_dataset("valid", opt), fields,
                                  opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                  opt.valid_steps)

    if opt.tensorboard:
        trainer.report_manager.tensorboard_writer.close()
Code example #28
File: train_single.py Project: USE-sum/usesum
def valid_iter_fct():
    return build_dataset_iter(lazily_load_dataset("valid", opt),
                              fields,
                              opt,
                              is_train=False)
Code example #29
File: train.py Project: goodatlas/OpenNMT-py
    def valid_iter_fct(): return build_dataset_iter(
        lazily_load_dataset("valid", opt), fields, opt)

    # Do training.
    trainer.train(train_iter_fct, valid_iter_fct, opt.start_epoch, opt.epochs)