예제 #1
0
def read_combine_data(corpus_files,
                      dev_files,
                      rebuild_maps=False,
                      mini_count=0):
    """Read paired corpus/dev files and parse each pair into tokens and labels.

    For every index ``i`` the lines of ``corpus_files[i]`` and ``dev_files[i]``
    are concatenated (with one blank line between them) and parsed as a
    single corpus.

    Args:
        corpus_files: paths of the main corpus files.
        dev_files: paths of the dev files, paired index-by-index with
            ``corpus_files``; must have the same length.
        rebuild_maps: when True, build token/tag/character index maps while
            parsing (``utils.generate_corpus_char``); when False, only parse
            (``utils.read_corpus``).
        mini_count: minimum character frequency kept in the character map;
            also forwarded to ``generate_corpus_char`` as ``c_thresholds``.

    Returns:
        ``(tokens, labels)`` when ``rebuild_maps`` is False, otherwise
        ``(tokens, labels, token2idx, tag2idx, chr2idx)``. ``tokens`` and
        ``labels`` contain one entry per corpus/dev pair.

    Raises:
        ValueError: if ``corpus_files`` and ``dev_files`` differ in length.
    """
    # An ``assert`` disappears under ``python -O``; afterwards a shorter
    # ``dev_files`` raises IndexError and a longer one is silently ignored.
    if len(corpus_files) != len(dev_files):
        raise ValueError(
            'corpus_files and dev_files must have the same length '
            '({} != {})'.format(len(corpus_files), len(dev_files)))

    corpus_data = []
    for corpus_f, dev_f in zip(corpus_files, dev_files):
        with codecs.open(corpus_f, 'r', 'utf-8') as f:
            curr_data = f.readlines()
        # Blank line so the last corpus sentence and the first dev sentence
        # are not merged into one.
        curr_data.append('\n')
        with codecs.open(dev_f, 'r', 'utf-8') as df:
            curr_data += df.readlines()
        corpus_data.append(curr_data)

    tokens = []
    labels = []

    token2idx = {}
    tag2idx = {}
    chr_cnt = {}

    for data in corpus_data:
        if rebuild_maps:
            print('constructing coding table')
            # token2idx, tag2idx and chr_cnt are threaded through every call,
            # so the maps accumulate over all corpora.
            curr_tokens, curr_labels, token2idx, tag2idx, chr_cnt = utils.generate_corpus_char(
                data,
                token2idx,
                tag2idx,
                chr_cnt,
                c_thresholds=mini_count,
                if_shrink_w_feature=False)
        else:
            curr_tokens, curr_labels = utils.read_corpus(data)
        tokens.append(curr_tokens)
        labels.append(curr_labels)

    if not rebuild_maps:
        # The character map is only part of the result when maps are rebuilt.
        return tokens, labels

    # Keep characters seen at least ``mini_count`` times, then append the
    # special symbols after them.
    kept_chars = [k for k, v in chr_cnt.items() if v >= mini_count]
    chr2idx = {ch: ind for ind, ch in enumerate(kept_chars)}

    chr2idx['<u>'] = len(chr2idx)  # unk for char
    chr2idx[' '] = len(chr2idx)  # concat for char
    chr2idx['\n'] = len(chr2idx)  # eof for char

    return tokens, labels, token2idx, tag2idx, chr2idx
예제 #2
0
    checkpoint_file = torch.load(args.load_check_point)
    f_map = checkpoint_file['f_map']
    CRF_l_map = checkpoint_file['CRF_l_map']
    c_map = checkpoint_file['c_map']
    in_doc_words = checkpoint_file['in_doc_words']
    SCRF_l_map = checkpoint_file['SCRF_l_map']
    ALLOW_SPANLEN = checkpoint_file['ALLOW_SPANLEN']

    with codecs.open(args.dev_file, 'r', 'utf-8') as f:
        dev_lines = f.readlines()

    with codecs.open(args.test_file, 'r', 'utf-8') as f:
        test_lines = f.readlines()

    dev_features, dev_labels = utils.read_corpus(dev_lines)
    test_features, test_labels = utils.read_corpus(test_lines)

    dev_dataset = utils.construct_bucket_mean_vb_wc(
        dev_features,
        dev_labels,
        CRF_l_map,
        SCRF_l_map,
        c_map,
        f_map,
        SCRF_stop_tag=SCRF_l_map['<STOP>'],
        train_set=False)
    test_dataset = utils.construct_bucket_mean_vb_wc(
        test_features,
        test_labels,
        CRF_l_map,
예제 #3
0
    args = parser.parse_args()

    CRF_l_map, SCRF_l_map = utils.get_crf_scrf_label()

    print('setting:')
    print(args)

    print('loading corpus')
    with codecs.open(args.train_file, 'r', 'utf-8') as f:
        lines = f.readlines()
    with codecs.open(args.dev_file, 'r', 'utf-8') as f:
        dev_lines = f.readlines()
    with codecs.open(args.test_file, 'r', 'utf-8') as f:
        test_lines = f.readlines()

    dev_features, dev_labels = utils.read_corpus(dev_lines)
    test_features, test_labels = utils.read_corpus(test_lines)

    if args.load_check_point:
        if os.path.isfile(args.load_check_point):
            print("loading checkpoint: '{}'".format(args.load_check_point))
            checkpoint_file = torch.load(args.load_check_point)
            args.start_epoch = checkpoint_file['epoch']
            f_map = checkpoint_file['f_map']
            c_map = checkpoint_file['c_map']
            in_doc_words = checkpoint_file['in_doc_words']
            train_features, train_labels = utils.read_corpus(lines)
        else:
            print("no checkpoint found at: '{}'".format(args.load_check_point))
            sys.exit()
    else:
예제 #4
0
    dev_dataset_loader = []
    test_dataset_loader = []
    f_map = dict()
    l_map = dict()
    char_count = dict()
    train_features = []
    dev_features = []
    test_features = []
    train_labels = []
    dev_labels = []
    test_labels = []
    train_features_tot = []
    test_word = []

    for i in range(file_num):
        dev_features0, dev_labels0 = utils.read_corpus(dev_lines[i])
        test_features0, test_labels0 = utils.read_corpus(test_lines[i])

        dev_features.append(dev_features0)
        test_features.append(test_features0)
        dev_labels.append(dev_labels0)
        test_labels.append(test_labels0)

        if args.output_annotation:  #NEW
            test_word0 = utils.read_features(test_lines[i])
            test_word.append(test_word0)

        if args.load_check_point:
            if os.path.isfile(args.load_check_point):
                print("loading checkpoint: '{}'".format(args.load_check_point))
                checkpoint_file = torch.load(args.load_check_point)
예제 #5
0
    else:
        if_cuda = False
        packer = CRFRepack_WC(len(l_map), False)

    decode_label = (args.decode_type == 'label')
    predictor = predict_wc(if_cuda, f_map, c_map, l_map, f_map['<eof>'], c_map['\n'], l_map['<pad>'], l_map['<start>'],
                           decode_label, args.batch_size, jd['caseless'])

    # loading corpus
    lines = []
    features = []
    labels = []
    with codecs.open(args.input_file, 'r', 'utf-8') as f:
        for line in f:
            if line == '\n':
                f, l = utils.read_corpus(lines)
                features.append([f])
                labels.append(l)
                lines = []
                continue
            lines.append(line)

    # for idx in range(args.dataset_no):
    #     print('annotating the entity type', idx)
    #     fout = open(args.output_file+str(idx)+'.txt', 'w')
    #     for feature in features:
    #         predictor.output_batch(ner_model, feature, fout, idx)
    #         # predictor.combined_output_batch(ner_model, feature, fout)
    #         fout.write('\n')
    #     fout.close()
예제 #6
0
파일: eval.py 프로젝트: xhuang28/NewBioNer
        # assert len(train_args['dev_file']) == len(train_args['test_file'])
        num_corpus = len(train_args['dev_file'])


        # construct the pred and eval dataloader
        dev_tokens = []
        dev_labels = []

        test_tokens = []
        test_labels = []

        for i in range(num_corpus):
            dev_lines = []
            with codecs.open(train_args['dev_file'][i], 'r', 'utf-8') as f:
                dev_lines = f.readlines()
            dev_features, dev_l = utils.read_corpus(dev_lines)
            dev_tokens.append(dev_features)
            dev_labels.append(dev_l)

            test_lines = []
            with codecs.open(train_args['test_file'][i], 'r', 'utf-8') as f:
                test_lines = f.readlines()
            test_features, test_l = utils.read_corpus(test_lines)
            test_tokens.append(test_features)
            test_labels.append(test_l)

        """
        try:
            print("Load from PICKLE")
            single_devset = pickle.load(open(args.pickle + "/single_dev.p", "rb" ))
            dev_dataset_loader = []
예제 #7
0
        torch.cuda.set_device(args.gpu)

    print('setting:')
    print(args)

    '''
    load corpus
    '''
    print('loading corpus')
    with codecs.open(args.train_file, 'r', 'utf-8') as f:
        train_lines = f.readlines()

    with codecs.open(args.test_file, 'r', 'utf-8') as f:
        test_lines = f.readlines()

    raw_train_words, raw_train_labels, raw_train_seg_labels, raw_train_ent_labels = utils.read_corpus(train_lines)

    len_train = int(len(raw_train_words) * 0.9)

    train_words, train_labels, train_seg_labels, train_ent_labels = raw_train_words[:len_train], raw_train_labels[:len_train], raw_train_seg_labels[:len_train], raw_train_ent_labels[:len_train]
    dev_words, dev_labels, dev_seg_labels, dev_ent_labels = raw_train_words[len_train:], raw_train_labels[len_train:], raw_train_seg_labels[len_train:], raw_train_ent_labels[len_train:]
    test_words, test_labels, test_seg_labels, test_ent_labels = utils.read_corpus(test_lines)

    '''
    generating str2id mapping
    '''
    print('constructing str to id maps')
    w_map, l_map, seg_l_map, ent_l_map = utils.generate_mappings(train_lines+test_lines, caseless = args.caseless, thresholds = args.mini_count, unknown = args.unk)

    c_map = utils.generate_charmapping(train_words+dev_words+test_words)
예제 #8
0
    bichar_f_map = checkpoint_file['bichar_f_map']
    is_bichar = checkpoint_file['bichar']

    if args.gpu >= 0:
        torch.cuda.set_device(args.gpu)

    if args.test_file:
        with codecs.open(args.test_file, 'r', 'utf-8') as f:
            test_lines = f.readlines()
    else:
        with codecs.open(jd['test_file'], 'r', 'utf-8') as f:
            test_lines = f.readlines()


    # converting format
    test_features, test_labels, test_bichar_features = utils.read_corpus(test_lines)

    with codecs.open(args.lexicon_test_file, 'r', 'utf-8') as f:
        lexicon_test_lines = f.readlines()
    lexicon_test_features, lexicon_feature_map = utils.read_corpus_lexicon(lexicon_test_lines, test_features,
                                                                           lexicon_f_map)
    lexicon_test_dataset = utils.padding_lexicon_bucket(lexicon_test_features, lexicon_f_map, args.gpu)

    # construct dataset
    test_dataset = utils.construct_bucket_mean_vb(test_features, test_labels, lexicon_test_dataset, f_map, l_map,
                                                  test_bichar_features, bichar_f_map, jd['caseless'])

    # build model
    ner_model = LSTM_CRF(len(f_map), len(bichar_f_map), len(lexicon_f_map), len(l_map), jd['embedding_dim'], jd['hidden'], jd['layers'], jd['drop_out'], args.gpu, is_bichar, large_CRF=jd['small_crf'])

    ner_model.load_state_dict(checkpoint_file['state_dict'])
예제 #9
0
    # load corpus
    print('loading corpus')
    with codecs.open(args.train_file, 'r', 'utf-8') as f:
        train_lines = f.readlines()
    with codecs.open(args.dev_file, 'r', 'utf-8') as f:
        dev_lines = f.readlines()
    with codecs.open(args.test_file, 'r', 'utf-8') as f:
        test_lines = f.readlines()
    with codecs.open(args.cotrain_file_1, 'r', 'utf-8') as f:
        cotrain_lines=f.readlines()
    
    #load prior knowledge
    with open('../Data/source_full.json') as f:
        knowledge_dict = json.load(f)

    train_features, train_labels=utils.read_corpus(train_lines)
    dev_features, dev_labels = utils.read_corpus(dev_lines)
    test_features, test_labels = utils.read_corpus(test_lines)
    co_features,co_labels=utils.read_corpus(cotrain_lines+dev_lines+test_lines)
    
    if args.load_check_point:
        if os.path.isfile(args.load_check_point):
            print("loading checkpoint: '{}'".format(args.load_check_point))
            checkpoint_file = torch.load(args.load_check_point)
            args.start_epoch = checkpoint_file['epoch']
            f_map = checkpoint_file['f_map']
            l_map = checkpoint_file['l_map']
            c_map = checkpoint_file['c_map']
            in_doc_words = checkpoint_file['in_doc_words']
            train_features, train_labels = utils.read_corpus(train_lines)
        else:
예제 #10
0
    if args.gpu >= 0:
        torch.cuda.set_device(args.gpu)

    print('setting:')
    print(args)

    # load corpus
    print('loading corpus')
    with codecs.open(args.train_file, 'r', 'utf-8') as f:
        lines = f.readlines()
    with codecs.open(args.dev_file, 'r', 'utf-8') as f:
        dev_lines = f.readlines()

    # converting format
    dev_features, dev_labels, dev_bichar_features = utils.read_corpus(
        dev_lines)
    with codecs.open(args.test_file, 'r', 'utf-8') as f:
        test_lines = f.readlines()
    test_features, test_labels, test_bichar_features = utils.read_corpus(
        test_lines)

    with codecs.open(args.lexicon_train_dir, 'r', 'utf-8') as f:
        lexicon_train_lines = f.readlines()
    with codecs.open(args.lexicon_dev_dir, 'r', 'utf-8') as f:
        lexicon_dev_lines = f.readlines()

    # converting format
    lexicon_f_map = dict()
    lexicon_dev_features, lexicon_feature_map = utils.read_corpus_lexicon(
        lexicon_dev_lines, dev_features, lexicon_f_map)
    with codecs.open(args.lexicon_test_dir, 'r', 'utf-8') as f:
예제 #11
0
def train(args):
    """Train an RNN language model with validation-driven LR decay and early stop.

    Reads the training and dev corpora from ``args.data_dir`` and
    ``args.valid_dir``, loads the vocabulary from ``vocab.json`` and optimizes
    an ``RNNLM`` with Adam. Every ``args.valid_freq`` iterations the dev
    perplexity is computed:

    * on a new best score the model and optimizer state are saved to
      ``args.save_path`` and ``args.save_path + '.optim'``;
    * otherwise patience grows; at ``max_patience`` the best checkpoint is
      reloaded and the learning rate is multiplied by ``learning_rate_decay``;
      after ``max_num_trial`` such decays training stops.

    The function does not return normally: it ends the process with
    ``sys.exit(0)`` on early stop or once ``max_epoch`` epochs have run.

    Args:
        args: parsed command-line namespace; reads ``data_dir``,
            ``valid_dir``, ``batch_size``, ``clip_grad``, ``valid_freq``,
            ``save_path``, ``cuda``, ``embed_size``, ``hidden_size``,
            ``dropout`` and ``tie_embed``.
    """
    train_data = read_corpus(args.data_dir)
    dev_data = read_corpus(args.valid_dir)
    train_batch_size = args.batch_size
    clip_grad = args.clip_grad
    valid_freq = args.valid_freq
    model_save_path = args.save_path
    vocab = Vocab.load('vocab.json')
    device = torch.device("cuda:0" if args.cuda else "cpu")
    max_patience = 5
    max_num_trial = 5
    learning_rate_decay = 0.5
    max_epoch = 5000

    model = RNNLM(embed_size=args.embed_size,
                  hidden_size=args.hidden_size,
                  vocab=vocab,
                  dropout_rate=args.dropout,
                  device=device,
                  tie_embed=args.tie_embed)

    model.train()

    # Uniform initialization in [-0.1, 0.1] (this is not Xavier/Glorot).
    for p in model.parameters():
        p.data.uniform_(-0.1, 0.1)

    model.to(device)

    # TODO: make the learning rate tunable.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = total_word = report_total_word = 0

    cum_examples = report_examples = epoch = valid_num = 0

    # Every validation metric seen so far (higher is better: -dev_ppl).
    hist_valid_scores = []
    train_time = begin_time = time.time()

    print("Begin training")

    while True:
        epoch += 1

        for sent_batch in batch_iter(train_data,
                                     batch_size=train_batch_size,
                                     shuffle=True):
            train_iter += 1

            optimizer.zero_grad()

            batch_size = len(sent_batch)

            # The model output is negated to obtain a loss (presumably it
            # returns per-sentence log-likelihoods — confirm in RNNLM).
            example_losses = -model(sent_batch)

            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

            optimizer.step()

            batch_losses_val = batch_loss.item()
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            # Every token after the first one is a prediction target.
            tgt_word_num_to_predict = sum(len(s[1:]) for s in sent_batch)
            total_word += tgt_word_num_to_predict

            report_total_word += tgt_word_num_to_predict
            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % 10 == 0:
                print('epoch %d, iter %d, avg.loss %.2f, avg. ppl %.2f, '
                      'cum. examples %d, speed %.2f words/sec, time elapsed %.2f sec'
                      % (epoch, train_iter, report_loss / report_examples,
                         math.exp(report_loss / report_total_word),
                         cum_examples,
                         report_total_word / (time.time() - train_time),
                         time.time() - begin_time),
                      file=sys.stderr)

                train_time = time.time()
                report_loss = report_total_word = report_examples = 0.

            # VALIDATION
            if train_iter % valid_freq == 0:
                print(
                    "epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f, cum. examples %d"
                    % (epoch, train_iter, cum_loss / cum_examples,
                       np.exp(cum_loss / total_word), cum_examples),
                    file=sys.stderr)

                cum_loss = cum_examples = total_word = 0

                valid_num += 1

                print("Begin validation", file=sys.stderr)

                dev_ppl = evaluate_ppl(model, dev_data, batch_size=128)
                valid_metric = -dev_ppl

                print("validation: iter %d, dev. ppl %f" %
                      (train_iter, dev_ppl),
                      file=sys.stderr)

                is_better = (not hist_valid_scores
                             or valid_metric > max(hist_valid_scores))
                # Record the score. Without this the history stays empty,
                # every validation counts as "better", and patience, LR
                # decay and early stopping never trigger.
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    print("Save currently best model", file=sys.stderr)
                    model.save(model_save_path)

                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')

                elif patience < max_patience:
                    patience += 1

                    print("hit patience %d" % patience, file=sys.stderr)
                    if patience == max_patience:
                        num_trial += 1

                        if num_trial == max_num_trial:
                            print("early stop!", file=sys.stderr)
                            sys.exit(0)

                        # Learning rate decay
                        lr = optimizer.param_groups[0]['lr'] * learning_rate_decay

                        # Reload the previous best model (mapped to CPU
                        # storage first, then moved to ``device``).
                        params = torch.load(
                            model_save_path,
                            map_location=lambda storage, loc: storage)

                        model.load_state_dict(params['state_dict'])

                        model = model.to(device)

                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim'))

                        # Apply the decayed learning rate to the restored optimizer.
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        patience = 0

                if epoch == max_epoch:
                    print("maximum epoch reached!", file=sys.stderr)
                    sys.exit(0)

        # Guarantee termination even when no validation happened during the
        # final epoch (e.g. an epoch shorter than ``valid_freq`` iterations).
        if epoch >= max_epoch:
            print("maximum epoch reached!", file=sys.stderr)
            sys.exit(0)
예제 #12
0
파일: eval_w.py 프로젝트: lzbgt/LM-LSTM-CRF
    if args.gpu >= 0:
        torch.cuda.set_device(args.gpu)


    # load corpus

    if args.test_file:
        with codecs.open(args.test_file, 'r', 'utf-8') as f:
            test_lines = f.readlines()
    else:
        with codecs.open(jd['test_file'], 'r', 'utf-8') as f:
            test_lines = f.readlines()

    # converting format

    test_features, test_labels = utils.read_corpus(test_lines)

    # construct dataset
    test_dataset = utils.construct_bucket_mean_vb(test_features, test_labels, f_map, l_map, jd['caseless'])
    
    test_dataset_loader = [torch.utils.data.DataLoader(tup, 50, shuffle=False, drop_last=False) for tup in test_dataset]

    # build model
    ner_model = LSTM_CRF(len(f_map), len(l_map), jd['embedding_dim'], jd['hidden'], jd['layers'], jd['drop_out'], large_CRF=jd['small_crf'])

    ner_model.load_state_dict(checkpoint_file['state_dict'])

    if args.gpu >= 0:
        if_cuda = True
        torch.cuda.set_device(args.gpu)
        ner_model.cuda()
예제 #13
0
        return vocab


class Vocab(object):
    """Wrapper around a ``Dictionary`` word map that can be built from a
    corpus and persisted as JSON."""

    def __init__(self, vocab):
        # Wrapped mapping object; ``save`` reads its ``word2id`` attribute
        # and ``__len__`` delegates to it.
        self.vocab = vocab

    def __len__(self):
        return len(self.vocab)

    @staticmethod
    def build(sents):
        """Build a ``Vocab`` from sentences via ``Dictionary.from_corpus``."""
        vocab = Dictionary.from_corpus(sents)
        return Vocab(vocab)

    def save(self, file_path):
        """Write the word-to-id map to *file_path* as JSON under ``src_word2id``."""
        # ``with`` guarantees the file is flushed and closed, even on error.
        with open(file_path, 'w', encoding='utf-8') as fout:
            json.dump(dict(src_word2id=self.vocab.word2id), fout, indent=2)

    @staticmethod
    def load(file_path):
        """Read a map written by ``save`` and wrap it in a ``Dictionary``.

        NOTE: returns a ``Dictionary``, not a ``Vocab`` (unlike ``build``);
        kept this way for backward compatibility with existing callers.
        """
        with open(file_path, 'r', encoding='utf-8') as fin:
            entry = json.load(fin)
        word2id = entry['src_word2id']
        return Dictionary(word2id)

if __name__ == '__main__':
    # Script entry point: read the corpus at ``args.data_dir``, build a
    # vocabulary from it and save it as JSON to ``vocab.json`` in the
    # current working directory.
    args = get_args()
    sents = read_corpus(args.data_dir)
    vocab = Vocab.build(sents)
    vocab.save('vocab.json')
예제 #14
0
    def read_dataset(self, file_dict, dataset_name, *args, **kwargs):
        """Load train/dev/test corpora for every configured file and build
        the feature/label/character maps and per-file DataLoaders.

        For each index ``i`` of ``self.args.train_file`` the matching entries
        of ``self.args.dev_file`` and ``self.args.test_file`` are read, parsed
        with ``utils.read_corpus`` and appended to the per-file lists on
        ``self``. The word/label maps come either from the checkpoint at
        ``self.args.load_check_point`` or from ``utils.generate_corpus_char``.
        At the end one list of DataLoaders per file is appended to
        ``self.dataset_loader``, ``self.dev_dataset_loader`` and
        ``self.test_dataset_loader``.

        ``file_dict``, ``dataset_name``, ``*args`` and ``**kwargs`` are not
        used in this body.

        Requires the list attributes it appends to (``self.lines``,
        ``self.dev_lines``, ``self.train_features``, ...) and the dicts
        ``self.f_map``/``self.l_map``/``self.char_count`` to already exist
        (presumably set up in ``__init__`` — confirm).
        """
        print('loading corpus')
        self.file_num = len(self.args.train_file)
        for i in range(self.file_num):
            with codecs.open(self.args.train_file[i], 'r', 'utf-8') as f:
                lines0 = f.readlines()
                # NOTE(review): only the first 2000 lines of each file are
                # kept (same below for dev/test) — looks like a debugging
                # limit; confirm it is intended.
                lines0 = lines0[0:2000]
                # print (len(lines0))
            self.lines.append(lines0)
        for i in range(self.file_num):
            with codecs.open(self.args.dev_file[i], 'r', 'utf-8') as f:
                dev_lines0 = f.readlines()
                dev_lines0 = dev_lines0[0:2000]
            self.dev_lines.append(dev_lines0)
        for i in range(self.file_num):
            with codecs.open(self.args.test_file[i], 'r', 'utf-8') as f:
                test_lines0 = f.readlines()
                test_lines0 = test_lines0[0:2000]
            self.test_lines.append(test_lines0)

        for i in range(self.file_num):
            dev_features0, dev_labels0 = utils.read_corpus(self.dev_lines[i])
            test_features0, test_labels0 = utils.read_corpus(
                self.test_lines[i])

            self.dev_features.append(dev_features0)
            self.test_features.append(test_features0)
            self.dev_labels.append(dev_labels0)
            self.test_labels.append(test_labels0)

            if self.args.output_annotation:  # NEW
                test_word0, test_word_tag0 = utils.read_features(
                    self.test_lines[i])
                self.test_word.append(test_word0)
                self.test_word_tag.append(test_word_tag0)
            #print (len(self.test_word), len(self.test_labels))
            if self.args.load_check_point:
                # NOTE(review): neither checkpoint branch assigns
                # train_features0 / train_labels0, so the appends after this
                # if/else fail on the first file (or reuse the previous
                # file's values). This branch also rebinds
                # self.train_features / self.train_labels to the parsed
                # corpus. ``c_map`` is loaded but unused: self.char_map is
                # rebuilt from self.char_count below.
                if os.path.isfile(self.args.load_check_point):
                    print("loading checkpoint: '{}'".format(
                        self.args.load_check_point))
                    self.checkpoint_file = torch.load(
                        self.args.load_check_point)
                    self.args.start_epoch = self.checkpoint_file['epoch']
                    self.f_map = self.checkpoint_file['f_map']
                    self.l_map = self.checkpoint_file['l_map']
                    c_map = self.checkpoint_file['c_map']
                    self.in_doc_words = self.checkpoint_file['in_doc_words']
                    self.train_features, self.train_labels = utils.read_corpus(
                        self.lines[i])
                else:
                    # Only prints; execution continues without a checkpoint.
                    print("no checkpoint found at: '{}'".format(
                        self.args.load_check_point))
            else:
                print('constructing coding table')
                # f_map, l_map and char_count are threaded through each call
                # so they accumulate across all training files.
                train_features0, train_labels0, self.f_map, self.l_map, self.char_count = utils.generate_corpus_char(
                    self.lines[i],
                    self.f_map,
                    self.l_map,
                    self.char_count,
                    c_thresholds=self.args.mini_count,
                    if_shrink_w_feature=False)
            self.train_features.append(train_features0)
            self.train_labels.append(train_labels0)

            # Training features of all files together, used to shrink f_map.
            self.train_features_tot += train_features0

        # Keep characters seen at least mini_count times; indices follow the
        # dict's iteration order, then the special symbols are appended.
        shrink_char_count = [
            k for (k, v) in iter(self.char_count.items())
            if v >= self.args.mini_count
        ]
        self.char_map = {
            shrink_char_count[ind]: ind
            for ind in range(0, len(shrink_char_count))
        }

        self.char_map['<u>'] = len(self.char_map)  # unk for char
        self.char_map[' '] = len(self.char_map)  # concat for char
        self.char_map['\n'] = len(self.char_map)  # eof for char

        # f_set holds the feature keys BEFORE shrinking; dt_f_set extends it
        # with every dev/test token below and is passed to the embedding loader.
        f_set = {v for v in self.f_map}
        dt_f_set = f_set
        self.f_map = utils.shrink_features(self.f_map, self.train_features_tot,
                                           self.args.mini_count)
        # Labels that appear in dev/test data.
        l_set = set()

        for i in range(self.file_num):
            dt_f_set = functools.reduce(
                lambda x, y: x | y, map(lambda t: set(t),
                                        self.dev_features[i]), dt_f_set)
            dt_f_set = functools.reduce(
                lambda x, y: x | y, map(lambda t: set(t),
                                        self.test_features[i]), dt_f_set)

            l_set = functools.reduce(lambda x, y: x | y,
                                     map(lambda t: set(t), self.dev_labels[i]),
                                     l_set)
            l_set = functools.reduce(
                lambda x, y: x | y, map(lambda t: set(t), self.test_labels[i]),
                l_set)

        if not self.args.rand_embedding:
            print("feature size: '{}'".format(len(self.f_map)))
            print('loading embedding')
            # NOTE(review): when this flag is true the corpus-derived f_map
            # is replaced by {'<eof>': 0} before load_embedding_wlm — confirm
            # the flag's polarity matches the trailing comment.
            if self.args.fine_tune:  # which means does not do fine-tune
                self.f_map = {'<eof>': 0}
            self.f_map, self.embedding_tensor, self.in_doc_words = utils.load_embedding_wlm(
                self.args.emb_file,
                ' ',
                self.f_map,
                dt_f_set,
                self.args.caseless,
                self.args.unk,
                self.args.word_dim,
                shrink_to_corpus=self.args.shrink_embedding)
            print("embedding size: '{}'".format(len(self.f_map)))

        # Give dev/test-only labels an index after the existing ones.
        for label in l_set:

            if label not in self.l_map:
                self.l_map[label] = len(self.l_map)

        print('constructing dataset')
        for i in range(self.file_num):
            # construct dataset
            dataset, forw_corp, back_corp = utils.construct_bucket_mean_vb_wc(
                self.train_features[i], self.train_labels[i], self.l_map,
                self.char_map, self.f_map, self.args.caseless)
            dev_dataset, forw_dev, back_dev = utils.construct_bucket_mean_vb_wc(
                self.dev_features[i], self.dev_labels[i], self.l_map,
                self.char_map, self.f_map, self.args.caseless)
            test_dataset, forw_test, back_test = utils.construct_bucket_mean_vb_wc(
                self.test_features[i], self.test_labels[i], self.l_map,
                self.char_map, self.f_map, self.args.caseless)
            # Training loaders are shuffled with the configured batch size;
            # dev/test loaders use a fixed batch size of 50 without shuffling.
            self.dataset_loader.append([
                torch.utils.data.DataLoader(tup,
                                            self.args.batch_size,
                                            shuffle=True,
                                            drop_last=False) for tup in dataset
            ])
            self.dev_dataset_loader.append([
                torch.utils.data.DataLoader(tup,
                                            50,
                                            shuffle=False,
                                            drop_last=False)
                for tup in dev_dataset
            ])
            self.test_dataset_loader.append([
                torch.utils.data.DataLoader(tup,
                                            50,
                                            shuffle=False,
                                            drop_last=False)
                for tup in test_dataset
            ])
예제 #15
0
    if args.gpu >= 0:
        torch.cuda.set_device(args.gpu)

    print('setting:')
    print(args)

    # load corpus
    print('loading corpus')
    with codecs.open(args.train_file, 'r', 'utf-8') as f:
        lines = f.readlines()

    with codecs.open(args.test_file, 'r', 'utf-8') as f:
        test_lines = f.readlines()

    test_features, test_labels = utils.read_corpus(test_lines)  #测试集

    if args.load_check_point:
        if os.path.isfile(args.load_check_point):
            print("loading checkpoint: '{}'".format(args.load_check_point))
            checkpoint_file = torch.load(args.load_check_point)
            args.start_epoch = checkpoint_file['epoch']
            f_map = checkpoint_file['f_map']
            l_map = checkpoint_file['l_map']
            train_features, train_labels = utils.read_corpus(lines)
        else:
            print("no checkpoint found at: '{}'".format(args.load_check_point))
    else:
        print('constructing coding table')

        # converting format