示例#1
0
def evaluate_model(lc_quad, args, assignments):
    """Train and test a ``Runner`` with one set of hyper-parameter values.

    The selected entries of ``assignments`` are copied onto ``args``
    (mutating it in place) before a fresh ``Runner`` is built, trained and
    tested.

    Args:
        lc_quad: Dataset object handed to ``Runner`` and to its ``train`` and
            ``test`` methods.
        args: Argument object whose ``gamma``, ``negetive_reward``, ``lr``
            and ``dropout`` attributes are overwritten.
        assignments: Mapping that provides the keys ``'gamma'``,
            ``'negetive_reward'``, ``'lr'`` and ``'dropout'``.

    Returns:
        Whatever ``runner.test(lc_quad, args)`` returns.
    """
    print(assignments)

    # The misspelt 'negetive_reward' is the actual key/attribute name in use.
    # 'positive_reward' is left untouched (its assignment is disabled).
    for name in ('gamma', 'negetive_reward', 'lr', 'dropout'):
        setattr(args, name, assignments[name])

    model_runner = Runner(lc_quad, args)
    model_runner.train(lc_quad, args)
    return model_runner.test(lc_quad, args)
示例#2
0
    except RuntimeError as expt:
        logger.error(expt)
        return flask.jsonify({'error': str(expt)}), 408
    except Exception as expt:
        logger.error(expt)
        return flask.jsonify({'error': str(expt)}), 422


@app.errorhandler(404)
def not_found(error):
    """Build the JSON response returned for 404 errors.

    Registered through ``app.errorhandler(404)``. The ``error`` argument is
    accepted but not used.
    """
    payload = flask.jsonify({'error': 'Command Not found'})
    return flask.make_response(payload, 404)


if __name__ == '__main__':
    # Script entry point: restore a runner from its checkpoint, print the
    # result of one sample query, then serve `app` over HTTP.
    logger = logging.getLogger(__name__)
    Utils.setup_logging()
    args = parse_args()

    lc_quad_paths = config['lc_quad']
    dataset = LC_QuAD(
        lc_quad_paths['train'],
        lc_quad_paths['test'],
        lc_quad_paths['vocab'],
        False,
        args.remove_stop_words,
    )

    runner = Runner(dataset, args)
    runner.load_checkpoint()
    # Clear both linkers on the environment before any query is answered.
    runner.environment.entity_linker = None
    runner.environment.relation_linker = None

    # One sample query, printed before the server starts.
    sample_question = "Who has been married to both Penny Lancaster and Alana Stewart?"
    print(runner.link(sample_question, k=10, e=0.1))

    logger.info("Starting the HTTP server")
    # An empty host string is passed as the listen address.
    http_server = WSGIServer(('', args.port), app)
    http_server.serve_forever()
示例#3
0
    # Remember when the run began so the total runtime can be printed at the end.
    start = time.time()
    args = parse_args()

    # Logger 'main' writes to a stream handler using "LEVEL:name:message".
    logger = logging.getLogger('main')
    logger.setLevel(logging.INFO)
    ch = logging.StreamHandler()
    ch.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
    logger.addHandler(ch)
    logger.info(args)

    # Choose the dataset wrapper; any other name leaves `dataset` as None.
    dataset = None
    if args.dataset == 'lcquad':
        paths = config['lc_quad']
        dataset = LC_QuAD(paths['train'], paths['test'], paths['vocab'],
                          False, args.remove_stop_words)
    elif args.dataset == 'qald_7_ml':
        paths = config['qald_7_ml']
        dataset = Qald_7_ml(paths['train'], paths['test'], paths['vocab'],
                            False, False)
    runner = Runner(dataset, args)

    # Train unless running in 'test' mode, where a checkpoint is loaded instead.
    if args.mode != 'test':
        runner.train(dataset, args)
    else:
        runner.load_checkpoint()

    # Raise logger verbosity for the test pass.
    logger.setLevel(logging.DEBUG)
    runner.test(dataset, args, use_elastic=True)
    finish = time.time()
    print('total runtime:', finish - start)
示例#4
0
    # Start from previously saved evaluation results, if any, so that entries
    # for other checkpoints are kept when the file is rewritten below.
    results_path = 'eval-{}.json'.format(args.dataset)
    try:
        with open(results_path, 'rt') as json_file:
            eval_results = json.load(json_file)
    except (OSError, ValueError):
        # OSError: the file is missing or unreadable.
        # ValueError: the content is not valid JSON (json.JSONDecodeError is
        # a ValueError subclass).
        # A bare ``except`` here would also swallow KeyboardInterrupt and
        # SystemExit, so only these two expected failures fall back to empty.
        eval_results = {}

    if args.mode == 'test':
        prefix_len = len(args.dataset)
        for file_name in os.listdir(config['chk_path']):
            # Only checkpoints of the current dataset whose name contains
            # 'bilstm' are evaluated.
            if not (file_name.startswith(args.dataset)
                    and 'bilstm' in file_name):
                continue
            try:
                # The run settings are recovered positionally from the file
                # name: the policy sits between the dataset prefix (plus one
                # separator character) and the last six characters, and `b`
                # is the single character at index -4.
                args.policy = file_name[prefix_len + 1:-6]
                args.b = int(file_name[-4:-3])
                args.k = 1
                args.checkpoint = file_name
                runner = Runner(dataset, args)
                runner.load_checkpoint(
                    os.path.join(config['chk_path'], args.checkpoint))
                print(args)
                results = runner.test(dataset, args,
                                      use_elastic=True)  # use_EARL=True)
                eval_results[file_name] = results
                finish = time.time()
                print('total runtime:', finish - start)
            except Exception as e:
                # Best effort: one broken checkpoint must not abort the
                # sweep, so report it and record a placeholder score.
                print(e)
                print(file_name)
                eval_results[file_name] = [0, 0]

        # Persist the merged results for the next invocation.
        with open(results_path, 'wt') as json_file:
            json.dump(eval_results, json_file)
示例#5
0
from common.dataset.qald_7_ml import Qald_7_ml
from scripts.config_args import parse_args

# Script entry point: restore a runner from a checkpoint and run
# `runner.link` on every question of the dataset's test set, collecting the
# outputs in `results` keyed by question text.
if __name__ == '__main__':

    args = parse_args()

    dataset = LC_QuAD(config['lc_quad']['train'], config['lc_quad']['test'],
                      config['lc_quad']['vocab'], False,
                      args.remove_stop_words)
    # Alternative datasets, currently disabled:
    # dataset = Qald_6_ml(config['qald_6_ml']['train'], config['qald_6_ml']['test'], config['qald_6_ml']['vocab'],
    #                     False, False)
    # dataset = Qald_7_ml(config['qald_7_ml']['train'], config['qald_7_ml']['test'], config['qald_7_ml']['vocab'],
    #                           False, False)

    runner = Runner(dataset, args)
    # NOTE(review): hard-coded absolute path on a developer machine; this
    # will not resolve elsewhere -- consider taking it from config or args.
    runner.load_checkpoint(
        checkpoint_filename=
        '/Users/hamid/workspace/DeepShallowParsingQA/data/checkpoints/lctmp')
    # Clear both linkers on the environment before linking questions.
    runner.environment.entity_linker = None
    runner.environment.relation_linker = None

    # Options handed positionally to runner.link() in the loop below.
    # NOTE(review): `connecting_relations` and `connecting_relation` differ
    # only by a trailing "s" -- confirm both are intended parameters.
    connecting_relations = False
    free_relation_match = False
    connecting_relation = False
    k = 10
    # Maps each question string to the value returned by runner.link().
    results = {}
    # `idx` is not used in the visible code.
    # NOTE(review): tqdm wraps enumerate(), which has no length, so the
    # progress bar presumably cannot show a total -- confirm if that matters.
    for idx, qarow in tqdm(enumerate(dataset.test_set)):
        # 0.1, k and the trailing True are passed positionally; their
        # parameter names are not visible in this chunk.
        result = runner.link(qarow.question, 0.1, k, connecting_relations,
                             free_relation_match, connecting_relation, True)
        results[qarow.question] = result