Example #1
# NOTE: imports are not shown in the original snippet; these are assumed.
import os
import pickle
from pathlib import Path

import torch
# Config, CheckpointManager, PadSequence, PreProcessor, PairwiseClassifier,
# BertConfig, and SentencepieceTokenizer are project-local / library helpers
# whose module paths are not shown here.


def predict(sentence1, sentence2):
    ptr_dir = "C:/Users/aaaaa/workspace/fact-check/BERT_pairwise_text_classification/pretrained"
    data_dir = "C:/Users/aaaaa/workspace/fact-check/BERT_pairwise_text_classification/data"
    caseType = "skt"
    model_dir = "C:/Users/aaaaa/workspace/fact-check/BERT_pairwise_text_classification/experiments/base_model"
    checkpoint_model_file = "best_skt.tar"
    
    # ptr_dir = "BERT_pairwise_text_classification/pretrained"
    # data_dir = "BERT_pairwise_text_classification/data"
    # caseType = "skt"
    # model_dir = "BERT_pairwise_text_classification/experiments/base_model"
    # checkpoint_model_file = "best_skt.tar"
    
    # ptr_dir = "pretrained"
    # data_dir = "data"
    # caseType = "skt"
    # model_dir = "experiments/base_model"
    # checkpoint_model_file = "best_skt.tar"
    
    ptr_dir = Path(ptr_dir)
    data_dir = Path(data_dir)
    model_dir = Path(model_dir)
    checkpoint_model_file = Path(checkpoint_model_file)
    
    ptr_config = Config(ptr_dir / 'config_skt.json')
    data_config = Config(data_dir / 'config.json')
    model_config = Config(model_dir / 'config.json')
    
    # vocab
    with open(os.path.join(ptr_dir, ptr_config.vocab), mode='rb') as io:
        vocab = pickle.load(io)
    
    
    ptr_tokenizer = SentencepieceTokenizer(os.path.join(ptr_dir, ptr_config.tokenizer))
    pad_sequence = PadSequence(length=model_config.length, pad_val=vocab.to_indices(vocab.padding_token))
    preprocessor = PreProcessor(vocab=vocab, split_fn=ptr_tokenizer, pad_fn=pad_sequence)
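    # preprocessor.preprocess: tokenize -> map tokens to vocab indices -> pad/truncate to model_config.length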
    
    # model (restore)
    checkpoint_manager = CheckpointManager(model_dir)
    checkpoint = checkpoint_manager.load_checkpoint(checkpoint_model_file)
    config = BertConfig(os.path.join(ptr_dir, ptr_config.config))
    model = PairwiseClassifier(config, num_classes=model_config.num_classes, vocab=preprocessor.vocab)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    device = torch.device('cpu')
    model.to(device)
    
    transform = preprocessor.preprocess
    model.eval()  # ensure inference mode (disables dropout)
        
    # add a batch dimension of size 1 to each encoded output (token indices, token type ids)
    indices, token_types = [torch.tensor([elm]) for elm in transform(sentence1, sentence2)]

    with torch.no_grad():
        label = model(indices, token_types)
    label = label.max(dim=1)[1]  # argmax over class logits
    label = label.numpy()[0]

    return label
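
A minimal usage sketch for this function; the sentence pair is illustrative, and the meaning of the returned class index depends on the training data.

if __name__ == '__main__':
    # hypothetical inputs; predict returns the argmax class index
    label = predict("sentence one", "sentence two")
    print(f"predicted class index: {label}")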
Example #2
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)
    ptr_config_info = Config(f"conf/pretrained/{model_config.type}.json")

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
        f"_weight_decay_{args.weight_decay}"
    )

    preprocessor = get_preprocessor(ptr_config_info, model_config)

    with open(ptr_config_info.config, mode="r") as io:
        ptr_config = json.load(io)

    # model (restore)
    checkpoint_manager = CheckpointManager(exp_dir)
    checkpoint = checkpoint_manager.load_checkpoint('best.tar')
    config = BertConfig()
    config.update(ptr_config)
    model = PairwiseClassifier(config, num_classes=model_config.num_classes, vocab=preprocessor.vocab)
    model.load_state_dict(checkpoint['model_state_dict'])

    # evaluation
    filepath = getattr(dataset_config, args.data)
    ds = Corpus(filepath, preprocessor.preprocess)
    dl = DataLoader(ds, batch_size=args.batch_size, num_workers=4)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    summary_manager = SummaryManager(exp_dir)
    summary = evaluate(model, dl, {'loss': nn.CrossEntropyLoss(), 'acc': acc}, device)

    summary_manager.load('summary.json')
    summary_manager.update({args.data: summary})
    summary_manager.save('summary.json')

    print('loss: {:.3f}, acc: {:.2%}'.format(summary['loss'], summary['acc']))
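
A hedged invocation sketch for this evaluation entry point; the flag names mirror the args attributes read above, but the script name and values are assumptions since the CLI definition is not shown.

# python evaluate.py --dataset_config conf/dataset/sample.json --model_config conf/model/skt.json \
#     --epochs 5 --batch_size 32 --learning_rate 5e-5 --weight_decay 0.01 --data test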
Example #3
    # (snippet begins mid-function, inside the tokenizer-selection if/elif chain)
    elif args.type == 'skt':
        ptr_tokenizer = SentencepieceTokenizer(ptr_config.tokenizer)
        pad_sequence = PadSequence(length=model_config.length,
                                   pad_val=vocab.to_indices(
                                       vocab.padding_token))
        preprocessor = PreProcessor(vocab=vocab,
                                    split_fn=ptr_tokenizer,
                                    pad_fn=pad_sequence)

    # model (restore)
    checkpoint_manager = CheckpointManager(model_dir)
    checkpoint = checkpoint_manager.load_checkpoint(f'best_{args.type}.tar')
    config = BertConfig(ptr_config.config)
    model = PairwiseClassifier(config,
                               num_classes=model_config.num_classes,
                               vocab=preprocessor.vocab)
    model.load_state_dict(checkpoint['model_state_dict'])

    # evaluation
    filepath = getattr(data_config, args.dataset)
    ds = Corpus(filepath, preprocessor.preprocess)
    dl = DataLoader(ds, batch_size=model_config.batch_size, num_workers=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    summary_manager = SummaryManager(model_dir)
    summary = evaluate(model, dl, {'loss': nn.CrossEntropyLoss(), 'acc': acc},
                       device)  # call completed as in Example #2; the original snippet was truncated here
Example #4
    # vocab (the enclosing `with open(...)` is inferred from Example #1; the snippet begins mid-function)
    with open(os.path.join(ptr_dir, ptr_config.vocab), mode='rb') as io:
        vocab = pickle.load(io)

    # tokenizer: both branches yield the same PreProcessor interface; only the subword tokenizer differs
    if args.type == 'etri':
        ptr_tokenizer = ETRITokenizer.from_pretrained(ptr_config.tokenizer, do_lower_case=False)
        pad_sequence = PadSequence(length=model_config.length, pad_val=vocab.to_indices(vocab.padding_token))
        preprocessor = PreProcessor(vocab=vocab, split_fn=ptr_tokenizer.tokenize, pad_fn=pad_sequence)
    elif args.type == 'skt':
        ptr_tokenizer = SentencepieceTokenizer(ptr_config.tokenizer)
        pad_sequence = PadSequence(length=model_config.length, pad_val=vocab.to_indices(vocab.padding_token))
        preprocessor = PreProcessor(vocab=vocab, split_fn=ptr_tokenizer, pad_fn=pad_sequence)

    # model
    config = BertConfig(ptr_config.config)
    model = PairwiseClassifier(config, num_classes=model_config.num_classes, vocab=preprocessor.vocab)
    bert_pretrained = torch.load(ptr_config.bert)
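    # strict=False: only the pretrained BERT weights load; the classifier head keeps its fresh initialization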
    model.load_state_dict(bert_pretrained, strict=False)

    # training
    tr_ds = Corpus(data_config.train, preprocessor.preprocess)
    tr_dl = DataLoader(tr_ds, batch_size=model_config.batch_size, shuffle=True, num_workers=4, drop_last=True)
    val_ds = Corpus(data_config.validation, preprocessor.preprocess)
    val_dl = DataLoader(val_ds, batch_size=model_config.batch_size, num_workers=4)

    loss_fn = nn.CrossEntropyLoss()
    # discriminative fine-tuning: BERT body at 1/100th of the classifier's learning rate
    opt = optim.Adam(
        [
            {"params": model.bert.parameters(), "lr": model_config.learning_rate / 100},
            {"params": model.classifier.parameters(), "lr": model_config.learning_rate},
        ]
    )  # closing brackets restored; the original snippet was truncated here
Example #5
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)
    ptr_config_info = Config(f"conf/pretrained/{model_config.type}.json")

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
        f"_weight_decay_{args.weight_decay}")

    if not exp_dir.exists():
        exp_dir.mkdir(parents=True)

    if args.fix_seed:
        torch.manual_seed(777)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    preprocessor = get_preprocessor(ptr_config_info, model_config)

    with open(ptr_config_info.config, mode="r") as io:
        ptr_config = json.load(io)

    # model
    config = BertConfig()
    config.update(ptr_config)
    model = PairwiseClassifier(config,
                               num_classes=model_config.num_classes,
                               vocab=preprocessor.vocab)
    bert_pretrained = torch.load(ptr_config_info.bert)
    model.load_state_dict(bert_pretrained, strict=False)

    tr_dl, val_dl = get_data_loaders(dataset_config, preprocessor,
                                     args.batch_size)

    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(
        [
            {"params": model.bert.parameters(), "lr": args.learning_rate / 100},
            {"params": model.classifier.parameters(), "lr": args.learning_rate},
        ],
        weight_decay=args.weight_decay,
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    writer = SummaryWriter(f'{exp_dir}/runs')
    checkpoint_manager = CheckpointManager(exp_dir)
    summary_manager = SummaryManager(exp_dir)
    best_val_loss = 1e+10

    for epoch in tqdm(range(args.epochs), desc='epochs'):

        tr_loss = 0
        tr_acc = 0

        model.train()
        for step, mb in tqdm(enumerate(tr_dl), desc='steps', total=len(tr_dl)):
            x_mb, x_types_mb, y_mb = map(lambda elm: elm.to(device), mb)
            opt.zero_grad()
            y_hat_mb = model(x_mb, x_types_mb)
            mb_loss = loss_fn(y_hat_mb, y_mb)
            mb_loss.backward()
            opt.step()

            with torch.no_grad():
                mb_acc = acc(y_hat_mb, y_mb)

            tr_loss += mb_loss.item()
            tr_acc += mb_acc.item()

            if (epoch * len(tr_dl) + step) % args.summary_step == 0:
                val_loss = evaluate(model, val_dl, {'loss': loss_fn}, device)['loss']
                writer.add_scalars(
                    'loss',
                    {'train': tr_loss / (step + 1), 'val': val_loss},
                    epoch * len(tr_dl) + step,
                )
                model.train()
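        # for/else: this block runs once after the step loop completes, i.e. at the end of each epoch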
        else:
            tr_loss /= (step + 1)
            tr_acc /= (step + 1)

            tr_summary = {'loss': tr_loss, 'acc': tr_acc}
            val_summary = evaluate(model, val_dl, {'loss': loss_fn, 'acc': acc}, device)
            tqdm.write(
                f"epoch: {epoch+1}\n"
                f"tr_loss: {tr_summary['loss']:.3f}, val_loss: {val_summary['loss']:.3f}\n"
                f"tr_acc: {tr_summary['acc']:.2%}, val_acc: {val_summary['acc']:.2%}"
            )

            val_loss = val_summary['loss']
            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'opt_state_dict': opt.state_dict()
                }
                summary = {'train': tr_summary, 'validation': val_summary}

                summary_manager.update(summary)
                summary_manager.save('summary.json')
                checkpoint_manager.save_checkpoint(state, 'best.tar')

                best_val_loss = val_loss
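
A possible argparse wiring for main above, inferred from the args attributes it reads; the flag names match the code, while the defaults are illustrative assumptions.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_config', type=str, required=True)
    parser.add_argument('--model_config', type=str, required=True)
    parser.add_argument('--epochs', type=int, default=5)              # illustrative default
    parser.add_argument('--batch_size', type=int, default=32)         # illustrative default
    parser.add_argument('--learning_rate', type=float, default=5e-5)  # illustrative default
    parser.add_argument('--weight_decay', type=float, default=0.01)   # illustrative default
    parser.add_argument('--summary_step', type=int, default=100)      # illustrative default
    parser.add_argument('--fix_seed', action='store_true')
    main(parser.parse_args())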