コード例 #1
0
    def training_start(self, model, data):
        """Train ``model`` on ``data`` for ``self.n_epochs`` epochs.

        Per epoch: train on the training batches, compute the dev loss,
        decode the dev set, write the lexicalized dev predictions to
        ``<model_dir>/predictions.epoch<N>`` and record both losses.
        Optionally it also runs the external evaluation
        (``self.evaluate_prediction``) and saves the weights to
        ``<model_dir>/weights.epoch<N>`` (``self.save_model``).

        After the last epoch the learning curve is plotted. When
        ``self.evaluate_prediction`` is set, the collected scores are also
        written to ``<model_dir>/scores.csv`` and plotted.

        Args:
            model: model to train; moved to the GPU when ``self.use_cuda``.
            data: dataset object providing ``train``/``dev`` splits,
                ``vocab``, ``lexicalizations``, ``fnames`` and ``uni_mr``.
        """
        training_start_time = time.time()
        logger.info("Start training")

        # Print a model summary to make sure everything is ok with it
        logger.debug(torch_summarize(model))

        evaluator = BaseEvaluator(self.config)
        logger.debug("Preparing training data")

        train_batches = data.prepare_training_data(data.train, self.batch_size)
        dev_batches = data.prepare_training_data(data.dev, self.batch_size)

        id2word = data.vocab.id2tok
        dev_lexicalizations = data.lexicalizations['dev']
        dev_multi_ref_fn = '%s.multi-ref' % data.fnames['dev']

        # Move the model to the GPU *before* the optimizer is built, as the
        # torch.optim documentation advises. The optimizer is then constructed
        # over the parameters that will actually be trained.
        if self.use_cuda:
            model = model.cuda()

        self.set_optimizer(model, self.config['optimizer'])
        self.set_train_criterion(len(id2word), PAD_ID)

        for epoch_idx in range(1, self.n_epochs + 1):
            epoch_start = time.time()
            pred_fn = os.path.join(self.model_dir, 'predictions.epoch%d' % epoch_idx)

            train_loss = self.train_epoch(epoch_idx, model, train_batches)
            dev_loss = self.compute_val_loss(model, dev_batches)

            # Only the predicted ids are needed; the second return value is unused.
            predicted_ids, _ = evaluator.evaluate_model(
                model, data.dev[0], data.uni_mr['dev'])
            predicted_tokens = evaluator.lexicalize_predictions(
                predicted_ids, dev_lexicalizations, id2word)

            save_predictions_txt(predicted_tokens, pred_fn)
            self.record_loss(train_loss, dev_loss)

            if self.evaluate_prediction:
                self.run_external_eval(dev_multi_ref_fn, pred_fn)

            if self.save_model:
                save_model(model, os.path.join(self.model_dir, 'weights.epoch%d' % epoch_idx))

            logger.info('Epoch %d/%d: time=%s', epoch_idx, self.n_epochs,
                        asMinutes(time.time() - epoch_start))

        self.plot_lcurve()

        if self.evaluate_prediction:
            score_fname = os.path.join(self.model_dir, 'scores.csv')
            save_scores(self.get_scores_to_save(), self.score_file_header, score_fname)
            self.plot_training_results()

        logger.info('End training time=%s', asMinutes(time.time() - training_start_time))
コード例 #2
0
def run(config_dict):
    """Build the data and model described by ``config_dict``, then train or predict.

    The config names the Python modules to load: ``data-module``,
    ``model-module``, ``training-module`` and, optionally,
    ``evaluation-module``. Each module must expose a ``component`` class.

    ``config_dict['mode']`` selects the action:

    * ``"train"``: train the model, then save the config as ``config.json``
      in the model directory.
    * ``"predict"``: load weights from ``config_dict["model_fn"]``. For each
      of the ``dev``/``test`` splits listed in ``data.fnames``, write the
      lexicalized predictions to ``<model_fn>.<split>set.predictions.txt``.

    Any other mode only logs a warning.
    """
    # Fetch all relevant modules.
    data_module = config_dict['data-module']
    model_module = config_dict['model-module']
    training_module = config_dict['training-module']
    evaluation_module = config_dict.get('evaluation-module', None)
    mode = config_dict['mode']

    # Load the modules
    DataClass = importlib.import_module(data_module).component
    ModelClass = importlib.import_module(model_module).component
    TrainingClass = importlib.import_module(training_module).component
    EvaluationClass = (importlib.import_module(evaluation_module).component
                       if evaluation_module else None)

    model_dirname = make_model_dir(config_dict)
    logger = set_logger(config_dict["log_level"],
                        os.path.join(model_dirname, "log.txt"))

    # Setup the data
    data = DataClass(config_dict["data_params"])
    data.setup()

    # Setup the model. Seeds are fixed before the model is constructed.
    # (This previously referenced the undefined name `config_d` -> NameError.)
    fix_seed(config_dict['random_seed'])
    model = ModelClass(config_dict["model_params"])
    model.setup(data)  # there are some data-specific params => pass data as arg

    if mode == "train":
        trainer = TrainingClass(config_dict['training_params'])
        trainer.training_start(model, data)
        save_config(config_dict, os.path.join(model_dirname, 'config.json'))
    elif mode == "predict":
        assert evaluation_module is not None, "No evaluation module -- check config file!"
        evaluator = EvaluationClass(config_dict)
        model_fname = config_dict["model_fn"]
        load_model(model, model_fname)
        id2word = data.vocab.id2tok
        # Predict on the dev set first, then on the test set, when present.
        for split in ('dev', 'test'):
            if split not in data.fnames:
                continue
            logger.info("Predicting on %s data", split)
            # Only the predicted ids are needed; the second return value is unused.
            predicted_ids, _ = evaluator.evaluate_model(
                model, getattr(data, split)[0])
            predicted_snts = evaluator.lexicalize_predictions(
                predicted_ids, data.lexicalizations[split], id2word)
            save_predictions_txt(
                predicted_snts,
                '%s.%sset.predictions.txt' % (model_fname, split))
    else:
        logger.warning("Check the 'mode' field in the config file: %s", mode)

    logger.info('DONE')
コード例 #3
0
ファイル: run_gg.py プロジェクト: templeblock/prag_generation
def run(config_dict):
    """Build the data and model described by ``config_dict``, then train or predict.

    The config names the Python modules to load: ``data-module``,
    ``model-module``, ``training-module`` and, optionally,
    ``evaluation-module``. Each module must expose a ``component`` class.

    ``config_dict['mode']`` selects the action:

    * ``"train"``: train the model, then save the config as ``config.json``
      in the model directory.
    * ``"predict"``: load weights from ``config_dict["model_fn"]`` and decode
      every ``dev``/``test`` split listed in ``data.fnames``. Decoding uses
      ``beam_size`` and ``alpha``, which are read from ``config_dict`` and
      default to ``None`` and ``0.3``. Lexicalized predictions are written
      next to the weight file.

    Any other mode only logs a warning.
    """
    import torch  # local import: only needed to choose the prediction device

    # Fetch all relevant modules.
    data_module = config_dict['data-module']
    model_module = config_dict['model-module']
    training_module = config_dict['training-module']
    evaluation_module = config_dict.get('evaluation-module', None)
    mode = config_dict['mode']

    # Load the modules
    DataClass = importlib.import_module(data_module).component
    ModelClass = importlib.import_module(model_module).component
    TrainingClass = importlib.import_module(training_module).component
    EvaluationClass = (importlib.import_module(evaluation_module).component
                       if evaluation_module else None)

    model_dirname = make_model_dir(config_dict)
    logger = set_logger(config_dict["log_level"],
                        os.path.join(model_dirname, "log.txt"))

    # Setup the data
    data = DataClass(config_dict["data_params"])
    data.setup()

    # Setup the model. Seeds are fixed before the model is constructed.
    # (This previously referenced the undefined name `config_d` -> NameError.)
    fix_seed(config_dict['random_seed'])
    model = ModelClass(config_dict["model_params"])
    logger.debug("build model done")
    model.setup(data)  # there are some data-specific params => pass data as arg
    logger.debug("setup data done")

    if mode == "train":
        trainer = TrainingClass(config_dict['training_params'])
        trainer.training_start(model, data)
        save_config(config_dict, os.path.join(model_dirname, 'config.json'))
    elif mode == "predict":
        assert evaluation_module is not None, "No evaluation module -- check config file!"
        evaluator = EvaluationClass(config_dict)
        model_fname = config_dict["model_fn"]
        load_model(model, model_fname)
        # Use the GPU when one is present instead of crashing on CPU-only hosts.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = model.to(device)
        id2word = data.vocab.id2tok
        # Decoding settings; the defaults are the previously hard-coded values.
        # NOTE(review): beam_size=None presumably disables beam search; confirm
        # in the evaluator.
        beam_size = config_dict.get('beam_size', None)
        alpha = config_dict.get('alpha', 0.3)

        # (split, output filename template). The filenames are unchanged.
        # NOTE(review): the test suffix 'inre' (vs. 'incre' for dev) looks like
        # a typo. It is kept as-is so downstream scripts still find the file.
        output_templates = (
            ('dev', '%s.devset.predictions.txt_incre_%.1f'),
            ('test', '%s.testset.predictions.txt_inre_%.1f'),
        )
        # The dev split is no longer followed by exit(), so the test split
        # is decoded too and 'DONE' is logged.
        for split, out_template in output_templates:
            if split not in data.fnames:
                continue
            logger.info("Predicting on %s data", split)
            # Only the predicted ids are needed; the second return value is unused.
            predicted_ids, _ = evaluator.evaluate_model(
                model,
                getattr(data, split)[0],
                data.uni_mr[split],
                beam_size=beam_size,
                alpha=alpha)
            data_lexicalizations = data.lexicalizations[split]
            logger.debug("%s: %d predictions, %d lexicalizations", split,
                         len(predicted_ids), len(data_lexicalizations))
            predicted_snts = evaluator.lexicalize_predictions(
                predicted_ids, data_lexicalizations, id2word)
            save_predictions_txt(predicted_snts,
                                 out_template % (model_fname, alpha))
    else:
        logger.warning("Check the 'mode' field in the config file: %s", mode)

    logger.info('DONE')