# Example #1
# 0
def main():
    """Fine-tune and/or evaluate a BERT sequence classifier on MedLit and i2b2 data.

    All settings are read from the module-level ``args`` namespace (defined
    outside this function); ``args.textbook_data_sets`` and
    ``args.train_batch_size`` are overwritten here as a side effect.

    With ``args.do_train`` the model is fine-tuned on the sets named in
    ``args.train_data``, scored after every epoch on up to 500 shuffled dev
    examples per entry of ``args.tuning_set``, checkpointed to
    ``pytorch_model<epoch>.bin`` whenever the dev F1 improves, and the
    best-scoring weights are finally written to
    ``<output_dir>/pytorch_model.bin``.  Without training, weights are loaded
    from ``<bert_model>/pytorch_model.bin``.  With ``args.do_eval`` the
    resulting model is evaluated on the sets named in ``args.test_set``.

    Raises:
        ValueError: if ``gradient_accumulation_steps < 1`` or neither
            ``do_train`` nor ``do_eval`` is set.
        ImportError: if apex is needed (distributed or fp16) but not installed.
    """
    # args.data_sets == ['EHR'] means: use no MedLit textbook data sets.
    if args.data_sets == ['EHR']:
        args.textbook_data_sets = []
    else:
        args.textbook_data_sets = args.data_sets

    data = DataUtil(data_dir=args.data_dir, vocab_dir=args.vocab_dir,
                    split_by_sentence=not args.split_by_section, skip_list=args.skip_list)
    if args.reload_data:
        medlit_dir = args.ref_data_dir + 'medlit/' + args.data_type
        i2b2_dir = args.ref_data_dir + 'i2b2_ehr/' + args.data_type
        for ds in args.textbook_data_sets:
            data.load_textbook_train_dev_data(medlit_dir + '/train/' + ds,
                                              medlit_dir + '/dev/' + ds)
        # train
        data.load_i2b2_train_data(train_base_dir=i2b2_dir)
        # test
        data.load_test_data(ref_base_dir=i2b2_dir, i2b2=True)
        # dev
        data.load_test_data(ref_base_dir=i2b2_dir, i2b2=True, type='dev')
    else:
        data.load_split_data()

    logger.info("MedLit Training data: " + str(len(data.textbook_train_data)))
    logger.info("MedLit Dev data: " + str(len(data.textbook_dev_data)))
    logger.info("i2b2 Training data: " + str(len(data.i2b2_train_data)))
    logger.info("i2b2 Dev data: " + str(len(data.i2b2_dev_data)))
    logger.info("i2b2 Test data: " + str(len(data.i2b2_test_data)))

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    # Per-step batch size; gradients are accumulated to reach the requested size.
    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    label_list = ['Allergies', 'Assessment and Plan', 'Chief Complaint', 'Examination', 'Family History', 'Findings',
                  'Medications', 'Past Medical History', 'Personal and Social history', 'Procedures',
                  'Review of Systems']
    # Derived from the label list so the two can never disagree.
    num_labels = len(label_list)

    logger.info("Num Labels: " + str(num_labels))
    logger.info("Labels: " + str(label_list))

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    train_examples = []
    num_train_steps = None
    if args.do_train:
        for data_name in args.train_data:
            if data_name == "i2b2" or data_name == "ALL":
                if args.i2b2_data_ratio != 1:
                    train_examples.extend(data.get_data_subset(data.i2b2_train_data, args.i2b2_data_ratio))
                else:
                    train_examples.extend(data.i2b2_train_data)
            if data_name == "MedLit" or data_name == "ALL":
                train_examples.extend(data.textbook_train_data)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    logger.info("Combined Train data: " + str(len(train_examples)))

    # Prepare model
    model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                          cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(
                                                              args.local_rank),
                                                          num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    t_total = num_train_steps
    # t_total is None when not training; dividing it would raise a TypeError.
    if args.local_rank != -1 and t_total is not None:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:
        # BertAdam uses t_total=-1 to mean "no LR schedule" (eval-only runs).
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total if t_total is not None else -1)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    best_f1 = 0
    best_model = model
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num train examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        dev_examples = []
        name = ""
        for data_name in args.tuning_set:
            name += data_name + "_"
            if data_name == "i2b2" or data_name == "ALL":
                random.shuffle(data.i2b2_dev_data)
                dev_examples.extend(data.i2b2_dev_data[:500])
            if data_name == "MedLit" or data_name == "ALL":
                random.shuffle(data.textbook_dev_data)
                dev_examples.extend(data.textbook_dev_data[:500])
        dev_features = convert_examples_to_features(
            dev_examples, label_list, args.max_seq_length, tokenizer)

        logger.info(" Num dev examples: " + str(len(dev_examples)))

        logger.info("EVAL on Pretrained model only: " + args.bert_model)
        run_eval(args, model, device, dev_examples, dev_features, 0, global_step,
                 name, label_list, save_results=False)

        # CPU copy of the best epoch's weights. Keeping a reference to `model`
        # is not enough: later epochs keep mutating the same parameters.
        best_state_dict = None
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            # Re-enter training mode every epoch: run_eval may have switched
            # the model to eval mode (disabling dropout) at the end of the last one.
            model.train()
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # FusedAdam has no built-in schedule, so apply BERT's
                        # warmup schedule by hand. BertAdam (the non-fp16 path)
                        # already schedules the LR from warmup/t_total; doing it
                        # here as well would apply the schedule twice.
                        lr_this_step = args.learning_rate * warmup_linear(global_step / t_total,
                                                                          args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # max(..., 1) guards against an empty dataloader.
            loss = tr_loss / max(nb_tr_steps, 1)
            f1 = run_eval(args, model, device, dev_examples, dev_features, loss, global_step,
                          name, label_list, save_results=False)
            logger.info(str(epoch) + "/" + str(args.num_train_epochs) + ". loss: " + str(loss) + ", F1: " + str(f1))
            if f1 > best_f1:
                best_f1 = f1
                model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
                best_state_dict = {k: v.detach().cpu().clone() for k, v in model_to_save.state_dict().items()}
                output_model_file = os.path.join(args.output_dir, "pytorch_model" + str(epoch) + ".bin")
                logger.info("Saving best model with F1: " + str(best_f1))
                torch.save(model_to_save.state_dict(), output_model_file)

        if best_state_dict is not None:
            # Roll the live model back to the best-scoring epoch before the
            # final save and evaluation.
            model_to_restore = model.module if hasattr(model, 'module') else model
            model_to_restore.load_state_dict(best_state_dict)
        best_model = model

    # Save a trained model
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")

    if args.do_train:
        logger.info("Saving best model with F1: " + str(best_f1))
        model_to_save = best_model.module if hasattr(best_model, 'module') else best_model  # Only save the model it-self
        torch.save(model_to_save.state_dict(), output_model_file)
    else:
        model_state_dict = torch.load(os.path.join(args.bert_model, "pytorch_model.bin"))
        best_model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict,
                                                                   num_labels=num_labels)
    best_model.to(device)

    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Mean training loss of the last epoch (None when not training).
        loss = tr_loss / max(nb_tr_steps, 1) if args.do_train else None

        if "ALL" in args.test_set or "MedLit" in args.test_set:
            eval_examples = data.textbook_dev_data
            eval_features = convert_examples_to_features(
                eval_examples, label_list, args.max_seq_length, tokenizer)

            run_eval(args, best_model, device, eval_examples, eval_features, loss, global_step, "medlit_dev",
                     label_list, print_examples=True)

        if "ALL" in args.test_set or "i2b2" in args.test_set:
            eval_examples = data.i2b2_test_data
            eval_features = convert_examples_to_features(
                eval_examples, label_list, args.max_seq_length, tokenizer)

            run_eval(args, best_model, device, eval_examples, eval_features, loss, global_step, "i2b2_test",
                     label_list, print_examples=True)
# Example #2
# 0
class Experiment:
    """Train and evaluate a GRU, attention-GRU or CNN text classifier.

    Wraps a ``DataUtil`` instance (MedLit textbook and i2b2 EHR data), its
    vocabulary, a model selected by ``config.model_type`` and a logger that
    writes to a timestamped file in ``config.output_dir``.
    """

    def __init__(self, config, sequence_length=20, reload_data=True):
        """Load data and vocabulary, set up logging and build the model.

        Args:
            config: configuration namespace. Attributes read here include
                data_dir, vocab_dir, split_by_section, skip_list, filtered,
                output_dir, textbook_data_sets, model_type, embedding_size
                and dropout.
            sequence_length: stored as ``self.sequence_length``.
            reload_data: when True, reload the raw MedLit/i2b2 files and
                rebuild the vocabulary; otherwise call ``load_split_data`` and
                reuse ``<vocab_dir>/NaturalLang.pkl`` if it exists.

        No model is built when ``config.filtered`` is set, and ``self.model``
        also stays None for an unrecognised ``model_type``; in both cases
        ``test``/``test_one`` later restore it from
        ``config.reload_model_file``.
        """
        # Hyper Parameters
        self.sequence_length = sequence_length
        self.hidden_size = 128
        self.num_layers = 1

        self.config = config
        self.data = DataUtil(data_dir=config.data_dir,
                             vocab_dir=config.vocab_dir,
                             split_by_sentence=not config.split_by_section,
                             skip_list=config.skip_list)

        if not self.config.filtered:
            self.data.make_dir(self.config.output_dir + "/models/")

        if reload_data:
            for ds in self.config.textbook_data_sets:
                self.data.load_textbook_train_dev_data(
                    config.data_dir + 'medlit/train/' + ds,
                    config.data_dir + 'medlit/dev/' + ds)
            # NOTE(review): the i2b2 path adds a leading '/', the MedLit
            # paths do not; confirm which form data_dir is expected to have.
            i2b2_dir = config.data_dir + '/i2b2_ehr/'
            # train
            self.data.load_i2b2_train_data(train_base_dir=i2b2_dir)
            # test
            self.data.load_test_data(ref_base_dir=i2b2_dir)
            # dev
            self.data.load_test_data(ref_base_dir=i2b2_dir, type='dev')
        else:
            self.data.load_split_data()

        self.data.make_dir(self.config.output_dir)

        log_file_name = strftime("log_%Y_%m_%d_%H_%M_%S", localtime())
        self.logger = self.setup_logger(self.config.output_dir +
                                        '/%s.txt' % log_file_name)

        if exists(config.vocab_dir + "/NaturalLang.pkl") and not reload_data:
            print("Loading vocab")
            self.data.load_vocab()
        else:
            print("Building vocab")
            self.data.build_vocab(self.data.textbook_train_data,
                                  pretrain=False)

        self.model = None
        self.use_cuda = torch.cuda.is_available()

        if not self.config.filtered:
            if self.config.model_type == 'gru_rnn':
                self.model = GRURNN(
                    self.config.embedding_size, self.hidden_size,
                    self.data.input_lang, self.data.pretrained_embeddings,
                    self.num_layers, self.data.input_lang.n_words,
                    self.data.output_lang.n_words, self.config.dropout)
            elif self.config.model_type == 'attn_gru_rnn':
                self.model = AttentionGRURNN(
                    self.config.embedding_size, self.hidden_size,
                    self.data.input_lang, self.data.pretrained_embeddings,
                    self.num_layers, self.data.input_lang.n_words,
                    self.data.output_lang.n_words, self.config.dropout)
            elif self.config.model_type == 'cnn':
                self.model = CNN(self.data.input_lang.n_words,
                                 self.data.output_lang.n_words,
                                 self.config.embedding_size,
                                 self.data.input_lang,
                                 self.data.pretrained_embeddings,
                                 self.config.dropout)

            self.epoch_start = 1

            # Guard against an unrecognised model_type, which leaves
            # self.model as None (the CPU path already tolerated that).
            if self.use_cuda and self.model is not None:
                self.model = self.model.cuda()

    def setup_logger(self, log_file, level=logging.INFO):
        """Attach a file handler for *log_file* to the root logger and return it.

        Each call adds another handler to the root logger.
        """
        logger = logging.getLogger()
        logger.setLevel(level)
        handler = logging.FileHandler(log_file)
        formatter = logging.Formatter('%(asctime)s %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        return logger

    def log(self, info):
        """Print *info* and also send it to the logger when one is configured."""
        print(info)
        if self.logger is not None:
            self.logger.info(info)

    def as_minutes(self, s):
        """Format a duration in seconds as ``'<m>m <s>s'``."""
        m = math.floor(s / 60)
        s -= m * 60
        return '%dm %ds' % (m, s)

    def time_since(self, since, percent):
        """Return elapsed time since *since* and the estimated time remaining.

        *percent* is the completed fraction in (0, 1]; the remaining time is
        extrapolated linearly from it.
        """
        now = time.time()
        s = now - since
        es = s / percent
        rs = es - s
        return '%s (- %s)' % (self.as_minutes(s), self.as_minutes(rs))

    def train(self,
              data_setup,
              save_model_dir,
              print_every=20,
              plot_every=100,
              learning_rate=0.001):
        """Train the model, tuning on dev data after every epoch.

        Args:
            data_setup: mapping of dataset name -> sampling ratio passed to
                ``self.data.get_data_subset``.
            save_model_dir: directory whose ``models/`` subdirectory receives
                ``best_model.pt`` and ``epoch_<n>.pt`` checkpoints.
            print_every: batches between average-loss log lines.
            plot_every: batches between entries of the running loss history.
            learning_rate: optimizer learning rate.

        Raises:
            ValueError: if ``config.optimizer`` is neither 'sgd' nor 'adam'.
        """
        start = time.time()

        plot_losses = []
        print_loss_total = 0
        plot_loss_total = 0

        if self.config.model_type == 'cnn' and self.config.transfer_learning:
            self.model.output_size = self.data.output_lang.n_words
            if self.config.reuse_embedding_layer_only:
                self.model.init_conv1_layer()
                self.model.init_conv2_layer()
                self.model.init_fc_layers()
            if self.config.reuse_embedding_conv1_layers:
                self.model.init_conv2_layer()
                self.model.init_fc_layers()
            if self.use_cuda:
                self.model = self.model.cuda()
        elif self.config.transfer_learning:
            self.model.freeze_layer("fc1")

        # Only parameters left trainable (e.g. not frozen above) are optimized.
        trainable = filter(lambda p: p.requires_grad, self.model.parameters())
        if self.config.optimizer == 'sgd':
            optimizer = optim.SGD(trainable, lr=learning_rate, momentum=0.9)
        elif self.config.optimizer == 'adam':
            optimizer = optim.Adam(trainable, lr=learning_rate)
        else:
            raise ValueError('Unsupported optimizer: %s' % self.config.optimizer)

        self.log('data_setup:' + str(data_setup))

        train_data = []

        for data_set in data_setup:
            data_ratio = data_setup[data_set]
            data = self.data.get_dataset(data_set)
            train_data += self.data.get_data_subset(data, data_ratio)
        print('len train_data:', len(train_data))
        print('training data examples:', train_data[:5])

        if self.config.downsampling:
            train_data = self.data.downsampling(
                train_data, number_samples=self.config.downsampling_size)

        num_train_data = len(train_data)
        print('num_train_data:', num_train_data)
        print('train_data:', train_data[:10])
        num_batches = int(
            np.ceil(num_train_data / float(self.config.batch_size)))
        self.log('num_batches: ' + str(num_batches))

        if self.config.weighted_loss:
            loss_weight = self.data.get_label_weight(train_data)
            if self.use_cuda:
                loss_weight = loss_weight.cuda()
        else:
            loss_weight = None

        max_dev_acc = 0
        fixed_length = self.config.model_type == 'cnn'

        for epoch in range(self.epoch_start, self.config.num_train_epochs + 1):
            batch_start = time.time()
            # Plain Python ints: accumulating tensors truncates the accuracy
            # division on older PyTorch versions.
            correct = 0
            total = 0

            random.shuffle(train_data)

            self.model.train()

            # Visit the batches in a random order.
            for cnt, i in enumerate(random.sample(range(num_batches),
                                                  num_batches),
                                    start=1):
                inputs, seq_lengths, targets, batch = self.data.construct_batch(
                    self.config.batch_size * i,
                    self.config.batch_size * (i + 1),
                    train_data,
                    fixed_length=fixed_length)

                if self.use_cuda:
                    inputs = inputs.cuda()
                    targets = targets.cuda()

                optimizer.zero_grad()

                if self.config.model_type == 'cnn':
                    outputs = self.model(inputs)  # for CNN
                elif self.config.model_type == 'attn_gru_rnn':
                    outputs = self.model(inputs, self.data.input_lang,
                                         seq_lengths)
                else:
                    outputs = self.model(inputs, seq_lengths)

                _, predicted = torch.max(outputs.data, dim=1)

                batch_size = targets.data.size(0)
                batch_correct = (predicted == targets.data).sum().item()
                total += batch_size
                correct += batch_correct
                batch_train_acc = 100.0 * batch_correct / batch_size

                loss = F.cross_entropy(outputs, targets, weight=loss_weight)
                loss.backward()
                optimizer.step()
                loss_value = loss.data.item()
                self.log(
                    "Epoch %d, batch %d / %d: train loss = %f, train accuracy = %f %%"
                    % (epoch, cnt, num_batches, loss_value,
                       batch_train_acc))

                print_loss_total += loss_value
                plot_loss_total += loss_value

                if cnt % print_every == 0:
                    print_loss_avg = print_loss_total / print_every
                    print_loss_total = 0
                    self.log('Average batch loss: %s' % str(print_loss_avg))
                    self.log(
                        self.time_since(batch_start, cnt * 1.0 / num_batches))

                if cnt % plot_every == 0:
                    plot_loss_avg = plot_loss_total / plot_every
                    plot_losses.append(plot_loss_avg)
                    plot_loss_total = 0
            self.log('Epoch %d is done' % epoch)
            self.log('Epoch %d Train Accuracy: %f %%' %
                     (epoch, 100.0 * correct / total if total else 0.0))
            self.log(
                self.time_since(start,
                                epoch * 1.0 / self.config.num_train_epochs))

            datasets = []
            print("TUNING SET IS: " + str(self.config.tuning_set))
            if 'ALL' in self.config.tuning_set or 'MedLit' in self.config.tuning_set:
                self.log("Test on MedLit Dev: ")
                datasets.append(self.data.TEXTBOOK_DEV)
            if 'ALL' in self.config.tuning_set or 'i2b2' in self.config.tuning_set:
                self.log("Test on i2b2 EHR Dev: ")
                datasets.append(self.data.i2b2_DEV)

            self.log("Tuning on:")
            self.log(datasets)
            dev_acc = self.test(datasets=datasets,
                                epoch=epoch,
                                calc_confusion_matrix=True)

            # save intermediate training results
            if dev_acc > max_dev_acc:
                save_path = save_model_dir + "/models/best_model.pt"
                torch.save(self.model, save_path)
                self.log('Best Model saved in file: %s' % save_path)
                max_dev_acc = dev_acc

                if 'i2b2' in self.config.test_set:
                    self.log("Test on i2b2 Test:")
                    self.test(datasets=[self.data.i2b2_TEST],
                              epoch=epoch,
                              print_test_results=True)

            save_path = save_model_dir + "/models/epoch_" + str(epoch) + ".pt"
            torch.save(self.model, save_path)
            self.log('Model saved in file: %s' % save_path)

    def test(self,
             datasets,
             epoch=-1,
             calc_confusion_matrix=True,
             generate_reports=True,
             print_test_results=False,
             print_examples=False):
        """Evaluate the model on the concatenation of *datasets*.

        Logs loss, accuracy and micro precision/recall/F1, and optionally
        writes a confusion matrix, a classification report, per-example
        predictions (TSV) and test results (JSON) under ``config.output_dir``.
        The model is restored from ``config.reload_model_file`` if none is
        loaded yet.

        Returns:
            float: accuracy in percent (0.0 when there is no data).
        """
        if self.model is None:
            self.log('Restoring model from ' + self.config.reload_model_file)

            if torch.cuda.is_available():
                self.model = torch.load(self.config.reload_model_file)
            else:
                self.model = torch.load(self.config.reload_model_file,
                                        map_location='cpu')

            self.log('Model is restored')

        self.model.eval()

        start = time.time()

        dataset_name = '_'.join(datasets)

        data = []
        for dataset in datasets:
            data.extend(self.data.get_dataset(dataset))
            if self.config.downsampling:
                # NOTE(review): this downsamples the running concatenation
                # after each dataset, so earlier datasets are thinned more
                # than later ones; confirm this is intended.
                data = self.data.downsampling(data, number_samples=500)

        num_test_data = len(data)
        self.log("num_test_data: " + str(num_test_data))
        if num_test_data == 0:
            # Nothing to score; avoid dividing by zero below.
            self.log('No test data, skipping evaluation')
            return 0.0
        num_batches = int(
            np.ceil(num_test_data / float(self.config.batch_size)))
        self.log('num_batches: ' + str(num_batches))

        correct = 0
        total = 0
        loss = 0.0  # sum (not mean) of per-batch losses
        labels = []
        predictions = []
        examples = []
        fixed_length = self.config.model_type == 'cnn'

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for i in range(num_batches):

                inputs, seq_lengths, targets, batch = self.data.construct_batch(
                    self.config.batch_size * i,
                    self.config.batch_size * (i + 1),
                    data,
                    fixed_length=fixed_length)

                if self.use_cuda:
                    inputs = inputs.cuda()
                    targets = targets.cuda()

                if self.config.model_type == 'cnn':
                    outputs = self.model(inputs)  # for CNN
                elif self.config.model_type == 'attn_gru_rnn':
                    outputs = self.model(inputs, self.data.input_lang, seq_lengths)
                else:
                    outputs = self.model(inputs, seq_lengths)

                _, predicted = torch.max(outputs.data, dim=1)

                # Ascending sort; the last columns hold the top-scoring classes.
                ordered = torch.sort(outputs.data)

                total += targets.data.size(0)
                correct += (predicted == targets.data).sum().item()
                labels.extend(targets.cpu().data.numpy().tolist())
                predictions.extend(predicted.cpu().numpy().tolist())

                loss += F.cross_entropy(outputs, targets).data.item()

                if print_examples or print_test_results:
                    num_classes = outputs.data.shape[1]
                    for k, d in enumerate(batch):
                        examples.append([
                            d[0],
                            d[1].replace('\r',
                                         ' ').replace('\n',
                                                      ' ').replace('\t', ' '),
                            d[2], d[3],
                            str(d[4]),
                            str(d[5]),
                            self.data.output_lang.get_word(
                                predicted[k].cpu().data.item()),
                            self.data.output_lang.get_word(
                                int(ordered[1][k][num_classes - 2])),
                            self.data.output_lang.get_word(
                                int(ordered[1][k][num_classes - 3]))
                        ])

        if print_examples:
            self.data.make_dir(self.config.output_dir + '/test_saved')
            self.log("Save examples to: " + self.config.output_dir +
                     '/test_saved')
            with open(
                    self.config.output_dir + '/test_saved/' + dataset_name +
                    'epoch_%d.txt' % epoch, 'w') as f:
                f.write(
                    "#\tSentence\tTrue\tHeader String\tLocation\tLine\tPrediction 1\tPrediction 2\tPrediction 3\n"
                )
                for e in examples:
                    f.write('\t'.join(e) + '\n')

        accuracy = 100.0 * correct / total
        self.log('Epoch %d ' % epoch + 'Time used: ' +
                 str(time.time() - start))
        self.log('Epoch %d ' % epoch + 'Test loss: %f' % loss)
        self.log('Epoch %d ' % epoch + 'Test Accuracy: %f %%' % accuracy)
        self.log(
            'Epoch %d ' % epoch + 'Test Precision: %f %%' %
            (100.0 * precision_score(labels, predictions, average='micro')))
        self.log('Epoch %d ' % epoch + 'Test Recall: %f %%' %
                 (100.0 * recall_score(labels, predictions, average='micro')))
        self.log('Epoch %d ' % epoch + 'Test F1 Score: %f %%' %
                 (100.0 * f1_score(labels, predictions, average='micro')))

        text_labels = [self.data.output_lang.get_word(l) for l in labels]
        text_preds = [self.data.output_lang.get_word(l) for l in predictions]
        label_set = sorted(set(text_labels))
        if calc_confusion_matrix:
            cm = confusion_matrix(text_labels, text_preds, labels=label_set)
            self.log('confusion_matrix for epoch %d: ' % epoch)
            header = '\t'.join(label_set)
            self.log(header)
            for i, row in enumerate(list(cm)):
                row = [str(num) for num in row]
                self.log('\t'.join([label_set[i]] + row))
            np.savetxt(self.config.output_dir + '/' + dataset_name +
                       '_confusion_matrix_epoch_%d.csv' % epoch,
                       cm,
                       fmt='%d',
                       header=header,
                       delimiter=',')
            self.log('Saved confusion matrix!')

        if generate_reports:
            reports = classification_report(text_labels,
                                            text_preds,
                                            labels=label_set,
                                            target_names=label_set,
                                            digits=4)
            self.log(reports)
            with open(
                    self.config.output_dir + '/' + dataset_name +
                    '_report_epoch_%d.txt' % epoch, 'w') as f:
                f.write(reports)
            self.log('Saved report!')

        if print_test_results:
            with open(
                    self.config.output_dir + '/' + dataset_name +
                    '_predictions_epoch_%d.json' % epoch, 'w') as f:
                json.dump(examples, f, indent=4, sort_keys=True)
        return accuracy

    def test_one(self, header, text):
        """Classify one (header, text) pair built by ``self.data.construct_one``.

        Returns:
            bool: True when the predicted label equals the target label.
        """
        if self.model is None:
            self.log('Restoring model from ' + self.config.reload_model_file)

            if torch.cuda.is_available():
                self.model = torch.load(self.config.reload_model_file)
            else:
                self.model = torch.load(self.config.reload_model_file,
                                        map_location='cpu')
            self.log('Model is restored')
            if self.use_cuda:
                self.model = self.model.cuda()
        # Always switch to inference mode: after train() the model is still in
        # training mode, which would keep train-only layers (e.g. dropout) active.
        self.model.eval()

        inputs, seq_lengths, targets = self.data.construct_one(
            header,
            text,
            fixed_length=self.config.model_type == 'cnn')

        if self.use_cuda:
            inputs = inputs.cuda()
            targets = targets.cuda()

        with torch.no_grad():
            if self.config.model_type == 'cnn':
                outputs = self.model(inputs)  # for CNN
            elif self.config.model_type == 'attn_gru_rnn':
                outputs = self.model(inputs, self.data.input_lang, seq_lengths)
            else:
                outputs = self.model(inputs, seq_lengths)

        _, predicted = torch.max(outputs.data, dim=1)

        return predicted.cpu().numpy().tolist() == targets.cpu().data.numpy(
        ).tolist()