예제 #1
0
    def __init__(self,
                 db: FeverDocDB,
                 sentence_level=False,
                 wiki_tokenizer: Tokenizer = None,
                 claim_tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 filtering: str = None) -> None:
        """Set up the tokenizers, token indexers and FEVER helper objects.

        Args:
            db: Document database; the ids returned by ``get_doc_ids()`` are
                handed to the gold formatter as a set.
            sentence_level: Flag stored on the instance as ``_sentence_level``.
            wiki_tokenizer: Stored as ``_wiki_tokenizer``; a ``WordTokenizer``
                is created when a falsy value is given.
            claim_tokenizer: Stored as ``_claim_tokenizer``; a
                ``WordTokenizer`` is created when a falsy value is given.
            token_indexers: Mapping of indexer names to indexers; when falsy,
                a single ``SingleIdTokenIndexer`` under ``'tokens'`` is used.
            filtering: Forwarded unchanged to ``FEVERGoldFormatter``.
        """
        self._sentence_level = sentence_level

        # Substitute a default for every collaborator the caller left out.
        if not wiki_tokenizer:
            wiki_tokenizer = WordTokenizer()
        self._wiki_tokenizer = wiki_tokenizer

        if not claim_tokenizer:
            claim_tokenizer = WordTokenizer()
        self._claim_tokenizer = claim_tokenizer

        if not token_indexers:
            token_indexers = {'tokens': SingleIdTokenIndexer()}
        self._token_indexers = token_indexers

        self.db = db

        # The formatter receives the ids of all documents present in the db.
        known_doc_ids = set(self.db.get_doc_ids())
        label_schema = FEVERLabelSchema()
        self.formatter = FEVERGoldFormatter(known_doc_ids,
                                            label_schema,
                                            filtering=filtering)
        self.reader = JSONLineReader()
예제 #2
0
    mname = args.model
    logger.info("Model name is {0}".format(mname))

    # Choose the term-frequency feature function for the requested granularity;
    # the model name doubles as the feature function's naming key.
    if args.sentence:
        logger.info("Model is Sentence level")
        ffn_cls = SentenceLevelTermFrequencyFeatureFunction
    else:
        logger.info("Model is Document level")
        ffn_cls = TermFrequencyFeatureFunction
    ffns = [ffn_cls(db, naming=mname)]

    f = Features(mname, ffns)

    # One reader/formatter pair is shared by every data split.
    jlr = JSONLineReader()
    formatter = FEVERGoldFormatter(None,
                                   FEVERLabelSchema(),
                                   filtering=args.filtering)
    shared = dict(reader=jlr, formatter=formatter)

    train_ds = DataSet(file=args.train, **shared)
    dev_ds = DataSet(file=args.dev, **shared)
    for split in (train_ds, dev_ds):
        split.read()

    # The test split is optional; f.load receives None when it is absent.
    if args.test is None:
        test_ds = None
    else:
        test_ds = DataSet(file=args.test, **shared)
        test_ds.read()

    train_feats, dev_feats, test_feats = f.load(train_ds, dev_ds, test_ds)
    f.save_vocab(mname)
예제 #3
0
 def __init__(self,
              db: FeverDocDB) -> None:
     """Store the document database and build the FEVER formatter and reader.

     Args:
         db: Document database; the ids returned by ``get_doc_ids()`` are
             passed to the gold formatter as a set.
     """
     self.db = db

     doc_ids = set(self.db.get_doc_ids())
     label_schema = FEVERLabelSchema()
     self.formatter = FEVERGoldFormatter(doc_ids, label_schema)
     self.reader = JSONLineReader()
예제 #4
0
                        help=("String option specifying tokenizer type to use "
                              "(e.g. 'corenlp')"))

    parser.add_argument('--num-workers',
                        type=int,
                        default=None,
                        help='Number of CPU processes (for tokenizing, etc)')
    args = parser.parse_args()

    # doc_freqs stays None unless args.use_precomputed is set, in which case
    # it is taken from the metadata stored alongside args.model.
    doc_freqs = None
    if args.use_precomputed:
        _, metadata = utils.load_sparse_csr(args.model)
        doc_freqs = metadata['doc_freqs'].squeeze()

    # NOTE(review): db, doc_freqs and formatter are not read in the code visible
    # here; presumably they are consumed as module-level names by helpers such
    # as tf_idf_claim -- confirm before removing any of them.
    db = FeverDocDB("data/fever/fever.db")
    jlr = JSONLineReader()
    formatter = FEVERGoldFormatter(set(), FEVERLabelSchema())

    # The output file name encodes, in order: split, precomputed flag,
    # max_page and max_sent.
    out_path = "data/fever/{0}.sentences.{3}.p{1}.s{2}.jsonl".format(
        args.split, args.max_page, args.max_sent,
        "precomputed" if args.use_precomputed else "not_precomputed")

    with open(args.in_file, "r") as f, open(out_path, "w+") as out_file:
        lines = jlr.process(f)
        #lines = tf_idf_claims_batch(lines)

        # Write one JSON object per line, one per processed input record.
        for line in tqdm(lines):
            line = tf_idf_claim(line)
            out_file.write(json.dumps(line) + "\n")
예제 #5
0
    logger.info("Model name is {0}".format(mname))

    # Choose the term-frequency feature function for the requested granularity;
    # the model name doubles as the feature function's naming key.
    if args.sentence:
        logger.info("Model is Sentence level")
        ffn_cls = SentenceLevelTermFrequencyFeatureFunction
    else:
        logger.info("Model is Document level")
        ffn_cls = TermFrequencyFeatureFunction
    ffns = [ffn_cls(db, naming=mname)]

    # Load the vocabulary keyed by the model name before looking up features.
    f = Features(mname, ffns)
    f.load_vocab(mname)

    jlr = JSONLineReader()
    formatter = FEVERGoldFormatter(None, FEVERLabelSchema())

    test_ds = DataSet(file=args.test, reader=jlr, formatter=formatter)
    test_ds.read()
    feats = f.lookup(test_ds)

    # The model's input width is the second dimension of the first feature
    # matrix. NOTE(review): 100 and 3 are presumably the hidden size and the
    # number of output classes -- confirm against SimpleMLP.
    input_shape = feats[0].shape[1]
    model = SimpleMLP(input_shape, 100, 3)

    if gpu():
        model.cuda()

    # NOTE(review): torch.load is called without map_location; a checkpoint
    # saved on GPU may not load on a CPU-only host -- confirm if that matters.
    checkpoint_path = "models/{0}.model".format(mname)
    model.load_state_dict(torch.load(checkpoint_path))
    print_evaluation(model, feats, FEVERLabelSchema(), args.log)