Code example #1
import importlib
import os

# The helper functions used below (make_model_dir, set_logger, fix_seed,
# save_config, load_model, save_predictions_txt) are assumed to be
# imported from the project's own utility modules.


def run(config_dict):

    # Fetch all relevant modules.
    data_module = config_dict['data-module']
    model_module = config_dict['model-module']
    training_module = config_dict['training-module']
    evaluation_module = config_dict.get('evaluation-module', None)
    mode = config_dict['mode']

    # Load the modules
    DataClass = importlib.import_module(data_module).component
    ModelClass = importlib.import_module(model_module).component
    TrainingClass = importlib.import_module(training_module).component
    EvaluationClass = importlib.import_module(
        evaluation_module).component if evaluation_module else None

    model_dirname = make_model_dir(config_dict)
    logger = set_logger(config_dict["log_level"],
                        os.path.join(model_dirname, "log.txt"))

    # Setup the data
    data = DataClass(config_dict["data_params"])
    data.setup()

    # Setup the model
    fix_seed(config_dict['random_seed'])  # fix the random seed generators for reproducibility
    model = ModelClass(config_dict["model_params"])
    model.setup(data)  # some params are data-specific, so pass the data as an argument

    if mode == "train":
        training_params = config_dict['training_params']
        trainer = TrainingClass(training_params)
        trainer.training_start(model, data)
        save_config(config_dict, os.path.join(model_dirname, 'config.json'))
    elif mode == "predict":
        assert evaluation_module is not None, "No evaluation module -- check config file!"
        evaluator = EvaluationClass(config_dict)
        model_fname = config_dict["model_fn"]
        load_model(model, model_fname)
        id2word = data.vocab.id2tok
        # predict on dev set
        if 'dev' in data.fnames:
            logger.info("Predicting on dev data")
            predicted_ids, attention_weights = evaluator.evaluate_model(
                model, data.dev[0])
            data_lexicalizations = data.lexicalizations['dev']
            predicted_snts = evaluator.lexicalize_predictions(
                predicted_ids, data_lexicalizations, id2word)
            save_predictions_txt(predicted_snts,
                                 '%s.devset.predictions.txt' % model_fname)
        # predict on test set
        if 'test' in data.fnames:
            logger.info("Predicting on test data")
            predicted_ids, attention_weights = evaluator.evaluate_model(
                model, data.test[0])
            data_lexicalizations = data.lexicalizations['test']
            predicted_snts = evaluator.lexicalize_predictions(
                predicted_ids, data_lexicalizations, id2word)
            save_predictions_txt(predicted_snts,
                                 '%s.testset.predictions.txt' % model_fname)
    else:
        logger.warning("Check the 'mode' field in the config file: %s" % mode)

    logger.info('DONE')
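
The run() function above is driven entirely by a config dictionary. Below is a minimal sketch of such a config as a Python dict: only the keys mirror what the code reads, while the module paths and values are hypothetical placeholders, not the project's actual settings.

# A minimal config sketch for run(); module paths and values are
# hypothetical -- only the keys are taken from the code above.
config_dict = {
    'data-module': 'components.data',             # must expose `component`
    'model-module': 'components.model',           # must expose `component`
    'training-module': 'components.trainer',      # must expose `component`
    'evaluation-module': 'components.evaluator',  # optional; required for predict
    'mode': 'train',                              # 'train' or 'predict'
    'log_level': 'INFO',
    'random_seed': 1,
    'data_params': {},       # forwarded to DataClass
    'model_params': {},      # forwarded to ModelClass
    'training_params': {},   # forwarded to TrainingClass in train mode
    'model_fn': 'model.pt',  # checkpoint path, used in predict mode
}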
Code example #2
File: run_gg.py  Project: templeblock/prag_generation
import importlib
import os

# As in example #1, the helper functions (make_model_dir, set_logger,
# fix_seed, save_config, load_model, save_predictions_txt) are assumed
# to come from the project's utility modules; the model object is
# assumed to be a PyTorch module (it is moved to CUDA below).


def run(config_dict):

    # Fetch all relevant modules.
    data_module = config_dict['data-module']
    model_module = config_dict['model-module']
    training_module = config_dict['training-module']
    evaluation_module = config_dict.get('evaluation-module', None)
    mode = config_dict['mode']

    # Load the modules
    DataClass = importlib.import_module(data_module).component
    ModelClass = importlib.import_module(model_module).component
    TrainingClass = importlib.import_module(training_module).component
    EvaluationClass = importlib.import_module(
        evaluation_module).component if evaluation_module else None

    model_dirname = make_model_dir(config_dict)
    logger = set_logger(config_dict["log_level"],
                        os.path.join(model_dirname, "log.txt"))

    # Setup the data
    data = DataClass(config_dict["data_params"])
    data.setup()

    # Setup the model
    fix_seed(config_dict['random_seed'])  # fix the random seed generators for reproducibility
    model = ModelClass(config_dict["model_params"])
    print("build model done")
    model.setup(
        data)  # there are some data-specific params => pass data as arg
    print("setup data done")
    #print(len(data.lexicalizations['test']))
    if mode == "train":
        training_params = config_dict['training_params']
        trainer = TrainingClass(training_params)
        trainer.training_start(model, data)
        save_config(config_dict, os.path.join(model_dirname, 'config.json'))

    elif mode == "predict":
        assert evaluation_module is not None, "No evaluation module -- check config file!"
        evaluator = EvaluationClass(config_dict)
        model_fname = config_dict["model_fn"]
        load_model(model, model_fname)
        model = model.to('cuda')  # assumes a CUDA device is available
        id2word = data.vocab.id2tok
        beam_size = None  # e.g. 10 to enable beam search in the evaluator
        alpha = 0.3  # decoding hyperparameter forwarded to evaluate_model
        if 'dev' in data.fnames:
            logger.info("Predicting on dev data")
            logger.debug("%d dev MRs, %d instances, %d lexicalizations" %
                         (len(data.uni_mr['dev']), len(data.dev[0]),
                          len(data.lexicalizations['dev'])))
            predicted_ids, fw_beam = evaluator.evaluate_model(
                model,
                data.dev[0],
                data.uni_mr['dev'],
                beam_size=beam_size,
                alpha=alpha)
            data_lexicalizations = data.lexicalizations['dev']
            logger.debug("%d predictions, %d lexicalizations" %
                         (len(predicted_ids), len(data_lexicalizations)))
            predicted_snts = evaluator.lexicalize_predictions(
                predicted_ids, data_lexicalizations, id2word)
            save_predictions_txt(
                predicted_snts,
                '%s.devset.predictions.txt_incre_%.1f' % (model_fname, alpha))
            # Stop after the dev predictions; the test branch below only
            # runs when there is no dev split.
            return

        if 'test' in data.fnames:
            logger.info("Predicting on test data")
            logger.debug("%d test instances" % len(data.test[0]))
            predicted_ids, attention_weights = evaluator.evaluate_model(
                model,
                data.test[0],
                data.uni_mr['test'],
                beam_size=beam_size,
                alpha=alpha)
            data_lexicalizations = data.lexicalizations['test']
            logger.debug("%d predictions, %d lexicalizations" %
                         (len(predicted_ids), len(data_lexicalizations)))
            predicted_snts = evaluator.lexicalize_predictions(
                predicted_ids, data_lexicalizations, id2word)

            save_predictions_txt(
                predicted_snts,
                '%s.testset.predictions.txt_incre_%.1f' % (model_fname, alpha))

    else:
        logger.warning("Check the 'mode' field in the config file: %s" % mode)

    logger.info('DONE')
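
Each *-module entry in the config names a Python module that run() imports with importlib.import_module() and from which it reads a single attribute, component. The sketch below shows the minimal shape such a pluggable module must have; the class name and method bodies are hypothetical, only the component attribute and the constructor/setup() shape are taken from the code above.

# Minimal sketch of a pluggable data module for run(). Class name and
# method bodies are hypothetical; run() only requires that the module
# bind its class to the attribute name `component`.

class MyData:
    def __init__(self, data_params):
        # Receives config_dict["data_params"] from run().
        self.params = data_params

    def setup(self):
        # Load vocabulary, splits, lexicalizations, etc.
        pass


component = MyData  # importlib.import_module(<path>).component resolves to this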