Esempio n. 1
0
def process_args():
    """Parse the command-line arguments.

    The parser is obtained from ``config.get_argparse()``; no extra options
    are registered on it here.

    Returns:
        The object returned by the parser's ``parse_args()`` call.
    """
    return config.get_argparse().parse_args()
Esempio n. 2
0
 def _load_predict(self) -> Predict:
     """Create the ``Predict`` object for the configured checkpoint.

     Command-line arguments are parsed and stored on ``self.args``, with
     ``device`` and ``bert_path`` filled in. ``self.use_cut`` selects which
     of the two checkpoint files is used. The resolved checkpoint path is
     printed before the predictor is built.

     Returns:
         ``Predict(self.args, model_path)`` for the selected checkpoint.
     """
     self.args = get_argparse().parse_args()
     self.args.device = device

     # Only the checkpoint file name differs between the two variants.
     checkpoint_file = (
         'bert_base_chinese-lstm-crf-cut-redundant-epoch_9.bin'
         if self.use_cut else
         'bert_base_chinese-lstm-crf-no_cut-redundant-epoch_9.bin')
     model = ModelInfo('bert_base_chinese', checkpoint_file,
                       '/home/data_ti5_d/maoyl/GeoQA/save_model')

     self.args.bert_path = os.path.join('./bert', model.bert_type)
     model_path = os.path.join(model.path, model.name)
     print('model_path: ' + model_path)
     return Predict(self.args, model_path)
Esempio n. 3
0
                                       real_labels=out_label_ids,
                                       predict_labels=tags)

    # Output the metrics
    print("***** %s results ***** " % data_type)
    eval_loss = eval_loss / nb_eval_steps
    performance.res_dict['eval_loss'] = eval_loss
    return performance


if __name__ == "__main__":
    # Script entry point: fetch the label list from CnerProcessor.
    processor = CnerProcessor()
    label_list = processor.get_labels()
    num_labels = len(label_list)
    # Parse the CLI arguments, then attach the id <-> label mappings to them.
    args = get_argparse().parse_args()
    args.id2label = {i: label for i, label in enumerate(label_list)}
    args.label2id = {label: i for i, label in enumerate(label_list)}

    # Option lists; only bert_names is iterated in the visible part of this
    # block.
    bert_names = [
        'bert_base_chinese', 'chinese_bert_wwm_ext',
        'chinese_roberta_wwm_ext_large'
    ]
    data_process_types = ['data_no_graph']
    cuts = ['cut', 'no_cut']
    redundants = ['redundant', 'no_redundant']

    # Set once before the loop, so they hold for every iteration.
    args.use_lstm = True
    args.max_seq_length = 256
    for name in bert_names:
        # Point the arguments at this model's directory under ./bert/.
        args.bert_path = './bert/' + name
Esempio n. 4
0
        train_loss_history.append(training_loss)

        train_acc = training_corrects / len(train_dataset)

        train_acc_history.append(train_acc)

        print(
            f'Training loss: {training_loss:.4f}\taccuracy: {train_acc:.4f}\n')

        if train_acc > best_acc:
            best_acc = train_acc
            best_model_params = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_params)
    torch.save(model, args.model_path)

    with open(args.info_path, 'wb') as f:
        pickle.dump(
            {
                'train_loss_history': train_loss_history,
                'train_acc_history': train_acc_history,
            }, f)


if __name__ == "__main__":
    # Script entry point: parse the options declared in config, then train.
    from config import get_argparse

    args = get_argparse().parse_args()
    train(args)
Esempio n. 5
0
    total_correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(data_loader,
                                   total=len(data_loader),
                                   ascii=True):
            inputs, labels = inputs.to(device), labels.to(device)

            output = model(inputs)

            _, predicted = torch.max(output.detach(), 1)
            total += labels.size(0)

            total_correct += (predicted == labels).sum().item()

        print(total)
        acc = total_correct / total

        print("Accuracy: {:.4f}".format(acc))

    return acc


if __name__ == "__main__":
    # Script entry point: parse the options declared in config, then run
    # the evaluation.
    from config import get_argparse

    args = get_argparse().parse_args()
    test(args)
Esempio n. 6
0
        valid_loss = sum(valid_loss) / float(len(valid_corrects))
        valid_real_acc = sum(valid_real_corrects) / float(
            len(valid_real_corrects))
        valid_fake_acc = sum(valid_fake_corrects) / float(
            len(valid_fake_corrects))

    return valid_loss, valid_acc, valid_real_acc, valid_fake_acc, valid_real_corrects, valid_fake_corrects


if __name__ == '__main__':
    # Seed the torch (CPU and CUDA) and NumPy random generators with one
    # fixed value.
    random_seed = 219373241
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    np.random.seed(random_seed)

    # NOTE(review): the value returned by get_argparse() is used directly as
    # the options object (attributes are read from it below) -- presumably it
    # returns already-parsed arguments in this project; confirm.
    opt = get_argparse()

    # Load the pretrained BERT model and tokenizer from opt.bert_dir, move the
    # model to opt.device and switch it to eval mode.
    bert = BertModel.from_pretrained(opt.bert_dir)
    bert_tokenizer = BertTokenizer.from_pretrained(opt.bert_dir)
    bert.to(opt.device)
    bert.eval()

    # NOTE(review): eval() executes opt.data as a Python expression to obtain
    # the callable used to build the datasets below. This is unsafe if
    # opt.data can come from an untrusted source -- consider an explicit
    # name-to-class mapping instead.
    DSET = eval(opt.data)

    if not opt.test:
        # os.makedirs is called without exist_ok, so it raises
        # FileExistsError when opt.results_dir already exists.
        os.makedirs(opt.results_dir)
        trn_dset = DSET(opt, bert_tokenizer, 'train')
        val_dset = DSET(opt, bert_tokenizer, 'validate')
        tst_dset = DSET(opt, bert_tokenizer, 'test')
    else:
        # When opt.test is set, only the test split is built.
        tst_dset = DSET(opt, bert_tokenizer, 'test')