Example #1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='train', help='train or eval')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--epochs', type=int, default=10, help='number of training epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='number of examples to process in a batch')
    parser.add_argument('--num_classes', type=int, default=6, help='number of target classes')
    parser.add_argument('--max_norm', type=float, default=5.0, help='max norm of gradient')
    # NOTE: argparse's type=bool treats any non-empty string (even 'False') as True,
    # so parse boolean flags explicitly.
    parser.add_argument('--embed_trainable', type=lambda s: s.lower() == 'true', default=True, help='finetune pre-trained embeddings')
    parser.add_argument('--kernel_sizes', nargs='+', type=int, default=[2, 3, 4], help='kernel sizes for the convolution layer')
    # Keep the default a plain string so args.device always has the same type.
    parser.add_argument('--device', type=str, default='cuda:0' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--p', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--c_out', type=int, default=32, help='output channel size of the convolution layer')
    args = parser.parse_args()

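    # Route to train or eval; Logger is assumed to tee stdout into a log file.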
    if args.mode == 'train':
        sys.stdout = Logger(TRAIN_LOG_LOC)
        print_statement('HYPERPARAMETER SETTING')
        print_flags(args)
        train(args, MODEL_LOC)
    else:
        sys.stdout = Logger(TEST_LOG_LOC)
        print_statement('HYPERPARAMETER SETTING')
        print_flags(args)
        test(args, MODEL_LOC, LABEL_JSON_LOC)
Example #2
def test(args, MODEL_LOC, LABEL_JSON_LOC):
    print_statement('LOAD EMBEDDINGS')
    label_map = load_json(LABEL_JSON_LOC, reverse=True, name='Label Mapping')
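    # Vocabulary mappings and the pre-trained embedding matrix, presumably
    # pickled by an earlier preprocessing step.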
    with open('dataset/ind2token', 'rb') as f:
        ind2token = pickle.load(f)
    with open('dataset/token2ind', 'rb') as f:
        token2ind = pickle.load(f)
    with open('dataset/embeddings_vector', 'rb') as f:
        embeddings_vector = pickle.load(f)
    print_value('Embed shape', embeddings_vector.shape)
    print_value('Vocab size', len(ind2token))
    batch_size = args.batch_size
    embedding_size = embeddings_vector.shape[1]
    model = TextCNN(batch_size=batch_size,
                    c_out=args.c_out,
                    output_size=args.num_classes,
                    vocab_size=len(ind2token),
                    embedding_size=embedding_size,
                    embeddings_vector=torch.from_numpy(embeddings_vector),
                    kernel_sizes=args.kernel_sizes,
                    trainable=args.embed_trainable,
                    p=args.p)
    model.to(args.device)
    ckpt = torch.load(MODEL_LOC, map_location=args.device)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    print_statement('MODEL TESTING')
    qcdataset = QCDataset(token2ind, ind2token, split='test', batch_first=True)
    dataloader_test = DataLoader(qcdataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,  # no need to shuffle at test time
                                 collate_fn=qcdataset.collate_fn)
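    # ClassificationTool is assumed to accumulate per-class precision/recall/F1
    # statistics across batches.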
    ct = ClassificationTool(len(label_map))
    accs = []
    length = []
    for batch_inputs, batch_targets in dataloader_test:
        batch_inputs = batch_inputs.to(args.device)
        batch_targets = batch_targets.to(args.device)
        with torch.no_grad():
            output = model(batch_inputs)
        # .item() turns the per-batch count into a Python int so the sums
        # below also work when the tensors live on the GPU.
        acc = torch.sum(output.argmax(dim=1) == batch_targets).item()
        accs.append(acc)
        length.append(len(batch_targets))
        ct.update(output, batch_targets)
    test_acc = float(sum(accs)) / sum(length)
    print('Testing on {} examples:'.format(sum(length)))
    print('+ Overall ACC: {:.3f}'.format(test_acc))
    PREC, REC, F1 = ct.get_result()
    for i, classname in enumerate(label_map.values()):
        print('* {} PREC: {:.3f}, {} REC: {:.3f}, {} F1: {:.3f}'.format(
            classname[:3], PREC[i], classname[:3], REC[i], classname[:3],
            F1[i]))
Example #3
def main():
    # Load parameters.
    parser = argparse.ArgumentParser()
    parser.add_argument('--classifier',
                        type=str,
                        default='TextCNN',
                        help='classifier to use (LSTM or TextCNN)')
    # NOTE: argparse's type=bool treats any non-empty string (even 'False')
    # as True, so parse boolean flags explicitly.
    parser.add_argument('--pretrained',
                        type=lambda s: s.lower() == 'true',
                        default=False,
                        help='finetune pre-trained classifier')
    parser.add_argument('--mode',
                        type=str,
                        default='train',
                        help='train or eval')
    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        help='number of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='number of examples to process in a batch')
    parser.add_argument('--max_norm',
                        type=float,
                        default=5.0,
                        help='max norm of gradient')
    parser.add_argument('--embed_trainable',
                        type=lambda s: s.lower() == 'true',
                        default=True,
                        help='finetune pre-trained embeddings')
    # Keep the default a plain string so args.device always has the same type.
    parser.add_argument(
        '--device',
        type=str,
        default='cuda:0' if torch.cuda.is_available() else 'cpu')
    # rationale specific parameters.
    parser.add_argument('--lr_enc',
                        type=float,
                        default=1e-3,
                        help='learning rate for the encoder')
    parser.add_argument('--lr_gen',
                        type=float,
                        default=1e-3,
                        help='learning rate for the generator')
    parser.add_argument(
        '--num_hidden_rationale',
        type=int,
        default=64,
        help='number of hidden units for the PreGenerator LSTM for rationale')
    parser.add_argument(
        '--lstm_layer_rationale',
        type=int,
        default=2,
        help='number of layers for the PreGenerator LSTM for rationale')
    parser.add_argument(
        '--lstm_bidirectional_rationale',
        type=lambda s: s.lower() == 'true',  # type=bool mishandles 'False'
        default=True,
        help='use a bidirectional PreGenerator LSTM for the rationale')
    parser.add_argument('--lambda_1',
                        type=float,
                        default=1e-2,
                        help='regularizer of the length of selected words')
    parser.add_argument('--lambda_2',
                        type=float,
                        default=1e-3,
                        help='regularizer of the local coherency of words')
    parser.add_argument(
        '--agg_mode',
        type=str,
        default='fc',
        help='aggregation mode chosen after the pregenerator LSTM layer')
    # LSTM specific parameters.
    parser.add_argument('--num_hidden',
                        type=int,
                        default=256,
                        help='number of hidden units in the LSTM classifier')
    parser.add_argument('--lstm_layer',
                        type=int,
                        default=2,
                        help='number of LSTM layers')
    parser.add_argument('--lstm_bidirectional',
                        type=lambda s: s.lower() == 'true',  # type=bool mishandles 'False'
                        default=True,
                        help='bidirectional LSTM')
    # TextCNN specific parameters.
    parser.add_argument('--num_classes',
                        type=int,
                        default=6,
                        help='number of target classes')
    parser.add_argument('--kernel_sizes',
                        nargs='+',
                        type=int,
                        default=[2, 3, 4],
                        help='kernel sizes for the convolution layer')
    parser.add_argument('--p', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--c_out',
                        type=int,
                        default=32,
                        help='output channel size of the convolution layer')

    args = parser.parse_args()

    # Create log object.
    if args.mode == 'train':
        sys.stdout = Logger(TRAIN_LOG_LOC)
        print_statement('HYPERPARAMETER SETTING')
        print_flags(args)
        train(args, GEN_MODEL_LOC, LSTM_MODEL_LOC, TCN_MODEL_LOC,
              LABEL_JSON_LOC)

    else:
        sys.stdout = Logger(TEST_LOG_LOC)
        print_statement('HYPERPARAMETER SETTING')
        print_flags(args)
        test(args, GEN_MODEL_LOC, LSTM_MODEL_LOC, TCN_MODEL_LOC,
             LABEL_JSON_LOC)
Example #4
def train(args, MODEL_LOC):
    print_statement('LOAD EMBEDDINGS')
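    # Vocabulary mappings and the pre-trained embedding matrix, presumably
    # pickled by an earlier preprocessing step.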
    with open('dataset/ind2token', 'rb') as f:
        ind2token = pickle.load(f)
    with open('dataset/token2ind', 'rb') as f:
        token2ind = pickle.load(f)
    with open('dataset/embeddings_vector', 'rb') as f:
        embeddings_vector = pickle.load(f)
    print_value('Embed shape', embeddings_vector.shape)
    print_value('Vocab size', len(ind2token))
    print_statement('MODEL TRAINING')
    batch_size = args.batch_size
    embedding_size = embeddings_vector.shape[1]
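    # Build train and validation dataloaders; the dataset's collate_fn is
    # assumed to pad each batch.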
    qcdataset = QCDataset(token2ind, ind2token, batch_first=True)
    dataloader_train = DataLoader(qcdataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  collate_fn=qcdataset.collate_fn)
    qcdataset = QCDataset(token2ind, ind2token, split='val', batch_first=True)
    dataloader_validate = DataLoader(qcdataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     collate_fn=qcdataset.collate_fn)
    model = TextCNN(batch_size=batch_size,
                    c_out=args.c_out,
                    output_size=args.num_classes,
                    vocab_size=len(ind2token),
                    embedding_size=embedding_size,
                    embeddings_vector=torch.from_numpy(embeddings_vector),
                    kernel_sizes=args.kernel_sizes,
                    trainable=args.embed_trainable,
                    p=args.p)
    model.to(args.device)
    criterion = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), lr=args.lr)
    best_eval = 0
    iteration = 0
    max_iterations = args.epochs * len(dataloader_train)
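    # Standard cross-entropy training loop with gradient clipping and
    # periodic validation.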
    for i in range(args.epochs):
        for batch_inputs, batch_targets in dataloader_train:
            iteration += 1
            batch_inputs = batch_inputs.to(args.device)
            batch_targets = batch_targets.to(args.device)
            model.train()
            optim.zero_grad()
            output = model(batch_inputs)
            loss = criterion(output, batch_targets)
            accuracy = float(torch.sum(output.argmax(
                dim=1) == batch_targets)) / len(batch_targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=args.max_norm)
            optim.step()
            if iteration % 10 == 0:
                print(
                    'Train step: {:d}/{:d}, Train loss: {:.3f}, Train accuracy: {:.3f}'
                    .format(iteration, max_iterations, loss.item(), accuracy))
            if iteration % 100 == 0 and iteration > 0:
                print_statement('MODEL VALIDATING')
                model.eval()
                accs = []
                length = []
                for batch_inputs, batch_targets in dataloader_validate:
                    batch_inputs = batch_inputs.to(args.device)
                    batch_targets = batch_targets.to(args.device)
                    with torch.no_grad():
                        output = model(batch_inputs)
                    # .item() keeps plain Python ints so the sum below also
                    # works when the tensors live on the GPU.
                    acc = torch.sum(output.argmax(dim=1) == batch_targets).item()
                    length.append(len(batch_targets))
                    accs.append(acc)
                validate_acc = float(sum(accs)) / sum(length)
                print('Validating on {} examples:'.format(sum(length)))
                print('+ Validation accuracy: {:.3f}'.format(validate_acc))
                # save best model parameters
                if validate_acc > best_eval:
                    print("New highscore! Saving model...")
                    best_eval = validate_acc
                    ckpt = {
                        "state_dict": model.state_dict(),
                        "optimizer_state_dict": optim.state_dict(),
                        "best_eval": best_eval
                    }
                    torch.save(ckpt, MODEL_LOC)
Example #5
                        type=str,
                        default='softmax',
                        help='loss function (softmax, ns, hs)')
    parser.add_argument('--verbose',
                        type=int,
                        default=2,
                        help='silent: 0, progress bar: 1, detailed: 2')
    args = parser.parse_args()

    # Create log object.
    if args.mode == 'train':
        sys.stdout = Logger(TRAIN_LOG_LOC)
    else:
        sys.stdout = Logger(TEST_LOG_LOC)

    print_statement('HYPERPARAMETER SETTING', verbose=args.verbose)
    print_flags(args, verbose=args.verbose)

    # Load data.
    print_statement('DATA PROCESSING', verbose=args.verbose)
    label_map = load_json(LABEL_JSON_LOC,
                          reverse=True,
                          name='Label Mapping',
                          verbose=args.verbose)
    train_data = load_json(TRAIN_JSON_LOC,
                           label_map,
                           name='Training Set',
                           verbose=args.verbose)
    val_data = load_json(VAL_JSON_LOC,
                         label_map,
                         name='Validation Set',
Example #6
                        default=True,
                        help='bi-direction of lstm')
    parser.add_argument('--embed_trainable',
                        type=lambda s: s.lower() == 'true',  # type=bool mishandles 'False'
                        default=True,
                        help='finetune pre-trained embeddings')

    args = parser.parse_args()

    # Create log object.
    if args.mode == 'train':
        sys.stdout = Logger(TRAIN_LOG_LOC)
    else:
        sys.stdout = Logger(TEST_LOG_LOC)

    print_statement('HYPERPARAMETER SETTING')
    print_flags(args)

    # Load data.
    print_statement('DATA PROCESSING')
    label_map = load_json(LABEL_JSON_LOC, reverse=True, name='Label Mapping')
    train_data = load_json(TRAIN_JSON_LOC, label_map, name='Training Set')
    val_data = load_json(VAL_JSON_LOC, label_map, name='Validation Set')
    test_data = load_json(TEST_JSON_LOC, label_map, name='Test Set')
    print_statement('LOAD EMBEDDINGS')
    with open('dataset/ind2token', 'rb') as f:
        ind2token = pickle.load(f)
    with open('dataset/token2ind', 'rb') as f:
        token2ind = pickle.load(f)
Example #7
def train(args, GEN_MODEL_LOC, LSTM_MODEL_LOC, TCN_MODEL_LOC, LABEL_JSON_LOC):
    print_statement('LOAD EMBEDDINGS')
    label_map = load_json(LABEL_JSON_LOC, reverse=True, name='Label Mapping')
    with open('dataset/ind2token', 'rb') as f:
        ind2token = pickle.load(f)
    with open('dataset/token2ind', 'rb') as f:
        token2ind = pickle.load(f)
    with open('dataset/embeddings_vector', 'rb') as f:
        embeddings_vector = pickle.load(f)
    print_value('Embed shape', embeddings_vector.shape)
    print_value('Vocab size', len(ind2token))

    batch_size = args.batch_size
    embedding_size = embeddings_vector.shape[1]
    qcdataset = QCDataset(token2ind, ind2token, batch_first=True)
    dataloader_train = DataLoader(qcdataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  collate_fn=qcdataset.collate_fn)
    qcdataset = QCDataset(token2ind, ind2token, split='val', batch_first=True)
    dataloader_validate = DataLoader(qcdataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     collate_fn=qcdataset.collate_fn)

    if args.classifier == 'LSTM':
        classifier = LSTMClassifier(
            output_size=args.num_classes,
            hidden_size=args.num_hidden,
            embedding_length=embedding_size,
            embeddings_vector=torch.from_numpy(embeddings_vector),
            lstm_layer=args.lstm_layer,
            lstm_dirc=args.lstm_bidirectional,
            trainable=args.embed_trainable,
            device=args.device)
        ckpt_path = 'LSTM/model/best_model.pt'
        ENC_MODEL_LOC = LSTM_MODEL_LOC

    elif args.classifier == 'TextCNN':
        classifier = TextCNN(
            batch_size=batch_size,
            c_out=args.c_out,
            output_size=args.num_classes,
            vocab_size=len(ind2token),
            embedding_size=embedding_size,
            embeddings_vector=torch.from_numpy(embeddings_vector),
            kernel_sizes=args.kernel_sizes,
            trainable=args.embed_trainable,
            p=args.p)
        ckpt_path = 'TextCNN/model/best_model.pt'
        ENC_MODEL_LOC = TCN_MODEL_LOC
    else:
        raise ValueError('Unknown classifier: {}'.format(args.classifier))
    if args.pretrained:
        ckpt = torch.load(ckpt_path, map_location=args.device)
        classifier.load_state_dict(ckpt['state_dict'])
    # Optionally freeze the pre-trained classifier:
    # for parameter in classifier.parameters():
    #     parameter.requires_grad = False
    # classifier.eval()
    classifier.to(args.device)

    pregen = PreGenerator(
        hidden_size=args.num_hidden_rationale,
        embedding_size=embedding_size,
        lstm_layer=args.lstm_layer_rationale,
        lstm_dirc=args.lstm_bidirectional_rationale,
        embeddings_vector=torch.from_numpy(embeddings_vector),
        trainable=args.embed_trainable,
        agg_mode=args.agg_mode)
    pregen.to(args.device)

    print_statement('MODEL TRAINING')
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    gen_optimizer = torch.optim.Adam(pregen.parameters(), lr=args.lr_gen)
    enc_optimizer = torch.optim.Adam(classifier.parameters(), lr=args.lr_enc)
    best_eval = 0
    iteration = 0
    max_iterations = args.epochs * len(dataloader_train)
    for i in range(args.epochs):
        for batch_inputs, batch_targets in dataloader_train:
            iteration += 1
            batch_inputs = batch_inputs.to(args.device)
            batch_targets = batch_targets.to(args.device)
            pregen.train()
            classifier.train()
            gen_optimizer.zero_grad()
            enc_optimizer.zero_grad()
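            # The generator predicts a keep-probability per token; sample a
            # binary mask and replace dropped tokens with index 1 (assumed to
            # be the <pad> index).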
            p_z_x = pregen(batch_inputs)
            dist = D.Bernoulli(probs=p_z_x)
            pregen_output = dist.sample()
            batch_inputs_masked = batch_inputs.clone()
            batch_inputs_masked[torch.eq(pregen_output, 0.)] = 1
            classifier_output = classifier(batch_inputs_masked)
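            # Rationale regularizers: lambda_1 penalizes the number of selected
            # tokens, lambda_2 penalizes fragmented (non-contiguous) selections.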
            selection_loss = args.lambda_1 * pregen_output.sum(dim=-1)
            transition_loss = args.lambda_2 * (
                pregen_output[:, 1:] - pregen_output[:, :-1]).abs().sum(dim=-1)
            classify_loss = criterion(classifier_output, batch_targets)
            cost = selection_loss + transition_loss + classify_loss
            enc_loss = cost.mean()
            enc_loss.backward()
            torch.nn.utils.clip_grad_norm_(classifier.parameters(),
                                           max_norm=args.max_norm)
            enc_optimizer.step()
            # REINFORCE: log_prob must score the *sampled* mask (pregen_output),
            # not the probabilities, and since cost is a quantity to minimize
            # the surrogate loss is +cost * log_prob.
            gen_loss = (cost.detach() *
                        dist.log_prob(pregen_output).sum(dim=-1)).mean()
            gen_loss.backward()
            torch.nn.utils.clip_grad_norm_(pregen.parameters(),
                                           max_norm=args.max_norm)
            gen_optimizer.step()
            accuracy = float(
                torch.sum(classifier_output.argmax(
                    dim=1) == batch_targets)) / len(batch_targets)
            keep = compute_keep_rate(batch_inputs, pregen_output)
            if iteration % 10 == 0:
                print(
                    'Train step: {:d}/{:d}, GEN Train loss: {:.3f}, ENC Train loss: {:.3f}, '
                    'Train accuracy: {:.3f}, Keep percentage: {:.2f}'.format(
                        iteration, max_iterations, gen_loss, enc_loss,
                        accuracy, keep))
            if iteration % 100 == 0 and iteration > 0:
                print_statement('MODEL VALIDATING')
                pregen.eval()
                classifier.eval()  # also disable dropout in the classifier during validation
                accs = []
                keeps = []
                length = []
                elements = []
                org_pads = []
                pads_kept = []
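                # Keep-rate bookkeeping: count kept tokens while excluding
                # padding positions (pad index assumed to be 1) from both the
                # numerator and the denominator.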
                for batch_inputs, batch_targets in dataloader_validate:
                    batch_inputs = batch_inputs.to(args.device)
                    batch_targets = batch_targets.to(args.device)
                    with torch.no_grad():
                        p_z_x = pregen(batch_inputs)
                        dist = D.Bernoulli(probs=p_z_x)
                        pregen_output = dist.sample()
                        batch_inputs_masked = batch_inputs.clone()
                        batch_inputs_masked[torch.eq(pregen_output, 0.)] = 1
                        classifier_output = classifier(batch_inputs_masked)
                    acc = torch.sum(
                        classifier_output.argmax(dim=1) == batch_targets)
                    keep = torch.sum(pregen_output)
                    org_pad = torch.eq(batch_inputs, 1).sum()
                    num_pads_kept = (torch.eq(batch_inputs, 1) *
                                     torch.eq(pregen_output, 1.)).sum()
                    length.append(len(batch_targets))
                    accs.append(acc)
                    keeps.append(keep)
                    org_pads.append(org_pad)
                    pads_kept.append(num_pads_kept)
                    elements.append(pregen_output.numel())
                validate_acc = float(sum(accs)) / sum(length)
                validate_keep = float(sum(keeps) - sum(pads_kept)) / float(
                    sum(elements) - sum(org_pads))
                extract_rationale(batch_inputs, batch_inputs_masked, ind2token,
                                  validate_acc, validate_keep, args.classifier,
                                  batch_targets, label_map)
                print('Validating on {} examples:'.format(sum(length)))
                print('+ Validation accuracy: {:.3f}'.format(validate_acc))
                print('+ Keep percentage: {:.2f}'.format(validate_keep))
                # save best model parameters
                if validate_acc > best_eval:
                    print("New highscore! Saving model...")
                    best_eval = validate_acc
                    gen_ckpt = {
                        "state_dict": pregen.state_dict(),
                        "optimizer_state_dict": gen_optimizer.state_dict(),
                        "best_eval": best_eval
                    }
                    torch.save(gen_ckpt, GEN_MODEL_LOC)
                    enc_ckpt = {
                        "state_dict": classifier.state_dict(),
                        "optimizer_state_dict": enc_optimizer.state_dict(),
                        "best_eval": validate_keep
                    }
                    torch.save(enc_ckpt, ENC_MODEL_LOC)
Example #8
def test(args, GEN_MODEL_LOC, LSTM_MODEL_LOC, TCN_MODEL_LOC, LABEL_JSON_LOC):
    print_statement('LOAD EMBEDDINGS')
    label_map = load_json(LABEL_JSON_LOC, reverse=True, name='Label Mapping')
    with open('dataset/ind2token', 'rb') as f:
        ind2token = pickle.load(f)
    with open('dataset/token2ind', 'rb') as f:
        token2ind = pickle.load(f)
    with open('dataset/embeddings_vector', 'rb') as f:
        embeddings_vector = pickle.load(f)
    print_value('Embed shape', embeddings_vector.shape)
    print_value('Vocab size', len(ind2token))

    batch_size = args.batch_size
    embedding_size = embeddings_vector.shape[1]

    if args.classifier == 'LSTM':
        classifier = LSTMClassifier(
            output_size=args.num_classes,
            hidden_size=args.num_hidden,
            embedding_length=embedding_size,
            embeddings_vector=torch.from_numpy(embeddings_vector),
            lstm_layer=args.lstm_layer,
            lstm_dirc=args.lstm_bidirectional,
            trainable=args.embed_trainable,
            device=args.device)
        ckpt_path = LSTM_MODEL_LOC

    elif args.classifier == 'TextCNN':
        classifier = TextCNN(
            batch_size=batch_size,
            c_out=args.c_out,
            output_size=args.num_classes,
            vocab_size=len(ind2token),
            embedding_size=embedding_size,
            embeddings_vector=torch.from_numpy(embeddings_vector),
            kernel_sizes=args.kernel_sizes,
            trainable=args.embed_trainable,
            p=args.p)
        ckpt_path = TCN_MODEL_LOC
    else:
        raise ValueError('Unknown classifier: {}'.format(args.classifier))

    ckpt = torch.load(ckpt_path, map_location=args.device)
    classifier.load_state_dict(ckpt['state_dict'])
    classifier.to(args.device)
    classifier.eval()

    pregen = PreGenerator(
        hidden_size=args.num_hidden_rationale,
        embedding_size=embedding_size,
        lstm_layer=args.lstm_layer_rationale,
        lstm_dirc=args.lstm_bidirectional_rationale,
        embeddings_vector=torch.from_numpy(embeddings_vector),
        trainable=args.embed_trainable,
        agg_mode=args.agg_mode)
    ckpt = torch.load(GEN_MODEL_LOC, map_location=args.device)
    pregen.load_state_dict(ckpt['state_dict'])
    pregen.to(args.device)
    pregen.eval()

    print_statement('MODEL TESTING')

    qcdataset = QCDataset(token2ind, ind2token, split='test', batch_first=True)
    dataloader_test = DataLoader(qcdataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,  # no need to shuffle at test time
                                 collate_fn=qcdataset.collate_fn)

    ct = ClassificationTool(len(label_map))
    accs = []
    keeps = []
    length = []
    elements = []
    org_pads = []
    pads_kept = []

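    # For each batch: sample a rationale mask, blank out dropped tokens with
    # the (assumed) pad index 1, then classify the masked input.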
    for batch_inputs, batch_targets in dataloader_test:
        batch_inputs = batch_inputs.to(args.device)
        batch_targets = batch_targets.to(args.device)
        with torch.no_grad():
            p_z_x = pregen(batch_inputs)
            dist = D.Bernoulli(probs=p_z_x)
            pregen_output = dist.sample()
            batch_inputs_masked = batch_inputs.clone()
            batch_inputs_masked[torch.eq(pregen_output, 0.)] = 1
            classifier_output = classifier(batch_inputs_masked)
        # .item() converts each count to a Python number so the aggregation
        # below also works when the tensors live on the GPU.
        acc = torch.sum(classifier_output.argmax(dim=1) == batch_targets).item()
        keep = torch.sum(pregen_output).item()
        org_pad = torch.eq(batch_inputs, 1).sum().item()
        num_pads_kept = (torch.eq(batch_inputs, 1) *
                         torch.eq(pregen_output, 1.)).sum().item()
        accs.append(acc)
        keeps.append(keep)
        org_pads.append(org_pad)
        pads_kept.append(num_pads_kept)
        elements.append(pregen_output.numel())
        length.append(len(batch_targets))
        ct.update(classifier_output, batch_targets)
    test_acc = float(sum(accs)) / sum(length)
    test_keep = float(sum(keeps) - sum(pads_kept)) / float(
        sum(elements) - sum(org_pads))
    extract_rationale(batch_inputs, batch_inputs_masked, ind2token, test_acc,
                      test_keep, args.classifier, batch_targets, label_map)
    print('Testing on {} examples:'.format(sum(length)))
    print('+ Overall ACC: {:.3f}'.format(test_acc))
    print('+ Overall KEEP: {:.3f}'.format(test_keep))
    PREC, REC, F1 = ct.get_result()
    for i, classname in enumerate(label_map.values()):
        print('* {} PREC: {:.3f}, {} REC: {:.3f}, {} F1: {:.3f}'.format(
            classname[:3], PREC[i], classname[:3], REC[i], classname[:3],
            F1[i]))