Example #1
        def test(task=0):
            # Load a trained model that you have fine-tuned
            model_state_dict = torch.load(rx_output_model_file)
            model.load_state_dict(model_state_dict)
            model.to(device)

            model.eval()
            y_preds = []
            y_trues = []
            for test_input in tqdm(test_dataloader, desc="Testing"):
                test_input = tuple(t.to(device) for t in test_input)
                input_ids, dx_labels, rx_labels = test_input
                input_ids = input_ids.squeeze()
                dx_labels = dx_labels.squeeze()
                rx_labels = rx_labels.squeeze(dim=0)
                with torch.no_grad():
                    loss, rx_logits = model(input_ids,
                                            dx_labels=dx_labels,
                                            rx_labels=rx_labels)
                    y_preds.append(t2n(torch.sigmoid(rx_logits)))
                    y_trues.append(t2n(rx_labels))

            print('')
            acc_container = metric_report(np.concatenate(y_preds, axis=0),
                                          np.concatenate(y_trues, axis=0),
                                          args.therhold)

            # save report
            if args.do_train:
                for k, v in acc_container.items():
                    writer.add_scalar('test/{}'.format(k), v, 0)

            return acc_container
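
Both `t2n` and `metric_report` are imported from the surrounding G-Bert codebase and are not shown in these examples. A minimal sketch of what they plausibly look like, assuming `t2n` is a tensor-to-NumPy helper and `metric_report` computes thresholded multi-label metrics with the keys the examples read out (the real implementations may differ):

    import numpy as np
    from sklearn.metrics import (average_precision_score, f1_score,
                                 jaccard_score, roc_auc_score)

    def t2n(x):
        # assumed helper: detach a tensor and move it to NumPy on the CPU
        return x.detach().cpu().numpy()

    def metric_report(y_pred, y_true, threshold=0.5):
        # assumed helper: binarize the predicted probabilities and report
        # the multi-label metrics used below ('jaccard', 'f1', 'prauc', 'auc')
        y_label = (y_pred > threshold).astype(int)
        return {
            'jaccard': jaccard_score(y_true, y_label, average='samples'),
            'f1': f1_score(y_true, y_label, average='samples'),
            'prauc': average_precision_score(y_true, y_pred, average='samples'),
            'auc': roc_auc_score(y_true, y_pred, average='samples'),
        }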
Example #2
    def eval(self, eval_type='test'):

        self.model.eval()

        if eval_type == 'valid':
            test_task_dict = self.meta_valid_task_entity_to_triplets
            test_task_pool = list(
                self.meta_valid_task_entity_to_triplets.keys())
        elif eval_type == 'test':
            test_task_dict = self.meta_test_task_entity_to_triplets
            test_task_pool = list(
                self.meta_test_task_entity_to_triplets.keys())
        else:
            raise ValueError("Invalid eval_type: <{}>".format(eval_type))

        y_probs = []
        ys = []

        for unseen_entity in test_task_pool:

            triplets = test_task_dict[unseen_entity]
            triplets = np.array(triplets)
            heads, relations, tails = triplets.transpose()

            train_triplets = triplets[:self.args.few]
            test_triplets = triplets[self.args.few:]

            if len(triplets) - self.args.few < 1:
                continue

            samples = test_triplets
            samples = torch.LongTensor(samples)

            if self.use_cuda:
                samples = samples.cuda()

            unseen_entity_embedding = self.model(unseen_entity, train_triplets,
                                                 self.use_cuda)
            y_prob, y = self.model.predict(unseen_entity,
                                           unseen_entity_embedding,
                                           samples,
                                           target=None,
                                           use_cuda=self.use_cuda)
            y_probs.append(y_prob.detach().cpu())
            ys.append(y.detach().cpu())

        y = torch.cat(ys, dim=0).detach().cpu().numpy()
        y_prob = torch.cat(y_probs, dim=0).detach().cpu().numpy()

        results = utils.metric_report(y, y_prob)

        return results
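
Here `utils.metric_report(y, y_prob)` scores flat binary predictions, with the labels passed first (unlike the thresholded variant sketched above). A minimal sketch under that assumption; the metric names are guesses:

    from sklearn.metrics import average_precision_score, roc_auc_score

    def metric_report(y, y_prob):
        # assumed helper: AUROC and AUPRC over flat binary labels
        # and predicted probabilities
        return {
            'auc': roc_auc_score(y, y_prob),
            'prauc': average_precision_score(y, y_prob),
        }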
Example #3
    def mc_score_inference(self, eval_type='test'):

        self.model.eval()

        if eval_type == 'valid':
            test_task_dict = self.meta_valid_task_entity_to_triplets
            test_task_pool = list(
                self.meta_valid_task_entity_to_triplets.keys())

        elif eval_type == 'test':
            test_task_dict = self.meta_test_task_entity_to_triplets
            test_task_pool = list(
                self.meta_test_task_entity_to_triplets.keys())

        else:
            raise ValueError("Invalid eval_type: <{}>".format(eval_type))

        total_task_entity = []
        total_task_entity_embeddings = []
        total_train_task_triplets = []
        total_test_task_triplets = []
        total_test_task_triplets_dict = dict()

        for task_entity in tqdm(test_task_pool):

            task_triplets = test_task_dict[task_entity]
            task_triplets = np.array(task_triplets)
            task_heads, task_relations, task_tails = task_triplets.transpose()

            train_task_triplets = task_triplets[:self.args.few]
            test_task_triplets = task_triplets[self.args.few:]

            if len(task_triplets) - self.args.few < 1:
                continue

            # Train (Inductive)
            task_entity_embedding = torch.cat([
                self.model(task_entity,
                           train_task_triplets,
                           use_cuda=self.use_cuda,
                           is_trans=False) for _ in range(self.args.mc_times)
            ]).view(-1, self.embedding_size)

            total_task_entity.append(task_entity)
            total_task_entity_embeddings.append(task_entity_embedding)
            total_train_task_triplets.extend(train_task_triplets)
            total_test_task_triplets.extend(test_task_triplets)
            total_test_task_triplets_dict[task_entity] = torch.LongTensor(
                test_task_triplets)

        # Train (Transductive)
        total_task_entity = np.array(total_task_entity)
        total_task_entity_embeddings = torch.cat(
            total_task_entity_embeddings).view(-1, self.args.mc_times,
                                               self.embedding_size)
        total_train_task_triplets = np.array(total_train_task_triplets)
        total_test_task_triplets = torch.LongTensor(total_test_task_triplets)

        self.model.train()

        task_entity_embeddings = torch.cat([
            self.model(total_task_entity,
                       total_train_task_triplets,
                       use_cuda=self.use_cuda,
                       is_trans=True,
                       total_unseen_entity_embedding=total_task_entity_embeddings[:, i])[0]
            for i in range(self.args.mc_times)
        ]).view(self.args.mc_times, -1, self.embedding_size)

        # Test
        total_task_entity = torch.from_numpy(total_task_entity)

        if self.use_cuda:
            total_task_entity = total_task_entity.cuda()

        my_total_triplets = []
        my_induc_triplets = []
        my_trans_triplets = []

        for task_entity, test_triplets in total_test_task_triplets_dict.items():

            if self.use_cuda:
                test_triplets = test_triplets.cuda()

            for test_triplet in test_triplets:

                is_trans = self.is_trans(total_task_entity, test_triplet)

                my_total_triplets.append(test_triplet)

                if is_trans:
                    my_trans_triplets.append(test_triplet)
                else:
                    my_induc_triplets.append(test_triplet)

        my_total_triplets = torch.stack(my_total_triplets, dim=0)

        for mc_index in range(self.args.mc_times):

            y_prob, y = self.model.predict(total_task_entity,
                                           task_entity_embeddings[mc_index],
                                           my_total_triplets,
                                           target=None,
                                           use_cuda=self.use_cuda)

            if mc_index == 0:
                y_prob_mean = y_prob
            else:
                y_prob_mean += y_prob

        y_prob_mean = y_prob_mean / self.args.mc_times

        y_prob = y_prob_mean.detach().cpu().numpy()
        y = y.detach().cpu().numpy()
        total_results = utils.metric_report(y, y_prob)

        my_induc_triplets = torch.stack(my_induc_triplets, dim=0)

        for mc_index in range(self.args.mc_times):

            y_prob, y = self.model.predict(total_task_entity,
                                           task_entity_embeddings[mc_index],
                                           my_induc_triplets,
                                           target=None,
                                           use_cuda=self.use_cuda)

            if mc_index == 0:
                y_prob_mean = y_prob
            else:
                y_prob_mean += y_prob

        y_prob_mean = y_prob_mean / self.args.mc_times

        y_prob = y_prob_mean.detach().cpu().numpy()
        y = y.detach().cpu().numpy()
        total_induc_results = utils.metric_report(y, y_prob)

        my_trans_triplets = torch.stack(my_trans_triplets, dim=0)

        for mc_index in range(self.args.mc_times):

            y_prob, y = self.model.predict(total_task_entity,
                                           task_entity_embeddings[mc_index],
                                           my_trans_triplets,
                                           target=None,
                                           use_cuda=self.use_cuda)

            if mc_index == 0:
                y_prob_mean = y_prob
            else:
                y_prob_mean += y_prob

        y_prob_mean = y_prob_mean / self.args.mc_times

        y_prob = y_prob_mean.detach().cpu().numpy()
        y = y.detach().cpu().numpy()
        total_trans_results = utils.metric_report(y, y_prob)

        return total_results, total_induc_results, total_trans_results
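
The three scoring passes above repeat the same Monte-Carlo averaging loop for the total, inductive, and transductive triplet sets. The pattern can be factored into one helper; a sketch that keeps the `model.predict` call exactly as used above (the helper itself is not part of the original code):

    def mc_average_predict(model, entities, embeddings, triplets, mc_times,
                           use_cuda):
        # average predicted probabilities over mc_times stochastic forward
        # passes; the labels are identical across passes, so return the last
        y_prob_mean = None
        for mc_index in range(mc_times):
            y_prob, y = model.predict(entities, embeddings[mc_index],
                                      triplets, target=None,
                                      use_cuda=use_cuda)
            y_prob_mean = y_prob if y_prob_mean is None else y_prob_mean + y_prob
        return y_prob_mean / mc_times, y

With it, each of the three blocks reduces to a single call such as `mc_average_predict(self.model, total_task_entity, task_entity_embeddings, my_total_triplets, self.args.mc_times, self.use_cuda)`.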
Example #4
def train_predict(batch_size=100, epochs=10, topk=30, L2=1e-8):
    patients = getTrainData(4000000)  # patients × visits × medical_code

    patients_num = len(patients)
    train_patient_num = int(patients_num * 0.8)
    patients_train = patients[0:train_patient_num]
    test_patient_num = patients_num - train_patient_num
    patients_test = patients[train_patient_num:]

    train_batch_num = int(np.ceil(float(train_patient_num) / batch_size))
    test_batch_num = int(np.ceil(float(test_patient_num) / batch_size))

    model = Dipole(input_dim=3393,
                   day_dim=200,
                   rnn_hiddendim=300,
                   output_dim=283)

    params = list(model.parameters())
    total_params = 0
    for p in params:
        layer_params = 1
        print("Layer shape: " + str(list(p.size())))
        for dim in p.size():
            layer_params *= dim
        print("Layer parameter count: " + str(layer_params))
        total_params += layer_params
    print("Total parameter count: " + str(total_params))

    optimizer = Adadelta(model.parameters(), lr=1, weight_decay=L2)
    loss_mce = nn.BCELoss(reduction='sum')
    model = model.cuda(device=1)

    for epoch in range(epochs):
        starttime = time.time()
        # Training
        model.train()
        all_loss = 0.0
        for batch_index in range(train_batch_num):
            patients_batch = patients_train[batch_index *
                                            batch_size:(batch_index + 1) *
                                            batch_size]
            patients_batch_reshape, patients_lengths = model.padTrainMatrix(
                patients_batch)  # maxlen × n_samples × inputDimSize
            batch_x = patients_batch_reshape[0:-1]  # use the first n-1 visits as x to predict the following n-1 days
            # batch_y = patients_batch_reshape[1:]
            batch_y = patients_batch_reshape[1:, :, :283]  # keep the medication codes as y
            optimizer.zero_grad()
            # h0 = model.initHidden(batch_x.shape[1])
            batch_x = torch.tensor(batch_x, device=torch.device('cuda:1'))
            batch_y = torch.tensor(batch_y, device=torch.device('cuda:1'))
            y_hat = model(batch_x)
            mask = out_mask2(y_hat, patients_lengths)  # mask to zero the padded outputs
            # zero the network outputs beyond each sequence's true length
            y_hat = y_hat.mul(mask)
            batch_y = batch_y.mul(mask)
            # (seq_len, batch_size, out_dim)->(seq_len*batch_size*out_dim, 1)->(seq_len*batch_size*out_dim, )
            y_hat = y_hat.view(-1, 1).squeeze()
            batch_y = batch_y.view(-1, 1).squeeze()

            loss = loss_mce(y_hat, batch_y)
            loss.backward()
            optimizer.step()
            all_loss += loss.item()
        print("Train:Epoch-" + str(epoch) + ":" + str(all_loss) +
              " Train Time:" + str(time.time() - starttime))

        # Testing
        model.eval()
        NDCG = 0.0
        RECALL = 0.0
        DAYNUM = 0.0
        all_loss = 0.0
        gbert_pred = []
        gbert_true = []
        gbert_len = []

        for batch_index in range(test_batch_num):
            patients_batch = patients_test[batch_index *
                                           batch_size:(batch_index + 1) *
                                           batch_size]
            patients_batch_reshape, patients_lengths = model.padTrainMatrix(
                patients_batch)
            batch_x = patients_batch_reshape[0:-1]
            batch_y = patients_batch_reshape[1:, :, :283]
            batch_x = torch.tensor(batch_x, device=torch.device('cuda:1'))
            batch_y = torch.tensor(batch_y, device=torch.device('cuda:1'))
            y_hat = model(batch_x)
            mask = out_mask2(y_hat, patients_lengths)
            loss = loss_mce(y_hat.mul(mask), batch_y.mul(mask))

            all_loss += loss.item()
            y_hat = y_hat.detach().cpu().numpy()
            ndcg, recall, daynum = validation(y_hat, patients_batch,
                                              patients_lengths, topk)
            NDCG += ndcg
            RECALL += recall
            DAYNUM += daynum
            gbert_pred.append(y_hat)
            gbert_true.append(batch_y.cpu())
            gbert_len.append(patients_lengths)

        avg_NDCG = NDCG / DAYNUM
        avg_RECALL = RECALL / DAYNUM
        y_pred_all, y_true_all = batch_squeeze(gbert_pred, gbert_true,
                                               gbert_len)
        acc_container = metric_report(y_pred_all, y_true_all, 0.2)
        print("Test:Epoch-" + str(epoch) + " Loss:" + str(all_loss) +
              " Test Time:" + str(time.time() - starttime))
        print("Test:Epoch-" + str(epoch) + " NDCG:" + str(avg_NDCG) +
              " RECALL:" + str(avg_RECALL))
        print("Test:Epoch-" + str(epoch) + " Jaccard:" +
              str(acc_container['jaccard']) + " f1:" +
              str(acc_container['f1']) + " prauc:" +
              str(acc_container['prauc']) + " roauc:" +
              str(acc_container['auc']))

        print("")
Example #5
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--model_name",
                        default='GBert-predict',
                        type=str,
                        required=False,
                        help="model name")
    parser.add_argument("--data_dir",
                        default='../data',
                        type=str,
                        required=False,
                        help="The input data dir.")
    parser.add_argument("--pretrain_dir",
                        default='../saved/GBert-pretraining',
                        type=str,
                        required=False,
                        help="pretraining model")
    parser.add_argument("--train_file",
                        default='data-multi-visit.pkl',
                        type=str,
                        required=False,
                        help="training data file.")
    parser.add_argument(
        "--output_dir",
        default='../saved/',
        type=str,
        required=False,
        help="The output directory where the model checkpoints will be written."
    )

    # Other parameters
    parser.add_argument("--use_pretrain",
                        default=False,
                        action='store_true',
                        help="is use pretrain")
    parser.add_argument("--graph",
                        default=False,
                        action='store_true',
                        help="if use ontology embedding")
    parser.add_argument("--therhold",
                        default=0.3,
                        type=float,
                        help="therhold.")
    parser.add_argument(
        "--max_seq_length",
        default=55,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=True,
                        action='store_true',
                        help="Whether to run on the dev set.")
    parser.add_argument("--do_test",
                        default=True,
                        action='store_true',
                        help="Whether to run on the test set.")
    parser.add_argument("--train_batch_size",
                        default=1,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=5e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=20.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=1203,
                        help="random seed for initialization")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")

    args = parser.parse_args()
    args.output_dir = os.path.join(args.output_dir, args.model_name)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    # if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
    #     raise ValueError(
    #         "Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    print("Loading Dataset")
    tokenizer, (train_dataset, eval_dataset, test_dataset) = load_dataset(args)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=RandomSampler(train_dataset),
                                  batch_size=1)
    eval_dataloader = DataLoader(eval_dataset,
                                 sampler=SequentialSampler(eval_dataset),
                                 batch_size=1)
    test_dataloader = DataLoader(test_dataset,
                                 sampler=SequentialSampler(test_dataset),
                                 batch_size=1)

    print('Loading Model: ' + args.model_name)
    # config = BertConfig(vocab_size_or_config_json_file=len(tokenizer.vocab.word2idx), side_len=train_dataset.side_len)
    # config.graph = args.graph
    # model = SeperateBertTransModel(config, tokenizer.dx_voc, tokenizer.rx_voc)
    if args.use_pretrain:
        logger.info("Use Pretraining model")
        model = GBERT_Predict.from_pretrained(args.pretrain_dir,
                                              tokenizer=tokenizer)
    else:
        config = BertConfig(
            vocab_size_or_config_json_file=len(tokenizer.vocab.word2idx))
        config.graph = args.graph
        model = GBERT_Predict(config, tokenizer)
    logger.info('# of model parameters: ' + str(get_n_params(model)))

    model.to(device)

    model_to_save = model.module if hasattr(
        model, 'module') else model  # Only save the model it-self
    rx_output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")

    # Prepare optimizer
    # num_train_optimization_steps = int(
    #     len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    # param_optimizer = list(model.named_parameters())
    # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # optimizer_grouped_parameters = [
    #     {'params': [p for n, p in param_optimizer if not any(
    #         nd in n for nd in no_decay)], 'weight_decay': 0.01},
    #     {'params': [p for n, p in param_optimizer if any(
    #         nd in n for nd in no_decay)], 'weight_decay': 0.0}
    # ]

    # optimizer = BertAdam(optimizer_grouped_parameters,
    #                      lr=args.learning_rate,
    #                      warmup=args.warmup_proportion,
    #                      t_total=num_train_optimization_steps)
    optimizer = Adam(model.parameters(), lr=args.learning_rate)

    global_step = 0
    if args.do_train:
        writer = SummaryWriter(args.output_dir)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", 1)

        dx_acc_best, rx_acc_best = 0, 0
        acc_name = 'prauc'
        dx_history = {'prauc': []}
        rx_history = {'prauc': []}

        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            print('')
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            prog_iter = tqdm(train_dataloader, leave=False, desc='Training')
            model.train()
            for _, batch in enumerate(prog_iter):
                batch = tuple(t.to(device) for t in batch)
                input_ids, dx_labels, rx_labels = batch
                input_ids = input_ids.squeeze(dim=0)
                dx_labels = dx_labels.squeeze(dim=0)
                rx_labels = rx_labels.squeeze(dim=0)
                loss, rx_logits = model(input_ids,
                                        dx_labels=dx_labels,
                                        rx_labels=rx_labels,
                                        epoch=global_step)
                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += 1
                nb_tr_steps += 1

                # Display loss
                prog_iter.set_postfix(loss='%.4f' % (tr_loss / nb_tr_steps))

                optimizer.step()
                optimizer.zero_grad()

            writer.add_scalar('train/loss', tr_loss / nb_tr_steps, global_step)
            global_step += 1

            if args.do_eval:
                print('')
                logger.info("***** Running eval *****")
                model.eval()
                dx_y_preds = []
                dx_y_trues = []
                rx_y_preds = []
                rx_y_trues = []
                for eval_input in tqdm(eval_dataloader, desc="Evaluating"):
                    eval_input = tuple(t.to(device) for t in eval_input)
                    input_ids, dx_labels, rx_labels = eval_input
                    input_ids = input_ids.squeeze()
                    dx_labels = dx_labels.squeeze()
                    rx_labels = rx_labels.squeeze(dim=0)
                    with torch.no_grad():
                        loss, rx_logits = model(input_ids,
                                                dx_labels=dx_labels,
                                                rx_labels=rx_labels)
                        rx_y_preds.append(t2n(torch.sigmoid(rx_logits)))
                        rx_y_trues.append(t2n(rx_labels))
                        # dx_y_preds.append(t2n(torch.sigmoid(dx_logits)))
                        # dx_y_trues.append(
                        #     t2n(dx_labels.view(-1, len(tokenizer.dx_voc.word2idx))))
                        # rx_y_preds.append(t2n(torch.sigmoid(rx_logits))[
                        #                   :, tokenizer.rx_singe2multi])
                        # rx_y_trues.append(
                        #     t2n(rx_labels)[:, tokenizer.rx_singe2multi])

                print('')
                # dx_acc_container = metric_report(np.concatenate(dx_y_preds, axis=0), np.concatenate(dx_y_trues, axis=0),
                #                                  args.therhold)
                rx_acc_container = metric_report(
                    np.concatenate(rx_y_preds, axis=0),
                    np.concatenate(rx_y_trues, axis=0), args.therhold)
                for k, v in rx_acc_container.items():
                    writer.add_scalar('eval/{}'.format(k), v, global_step)

                if rx_acc_container[acc_name] > rx_acc_best:
                    rx_acc_best = rx_acc_container[acc_name]
                    # save model
                    torch.save(model_to_save.state_dict(),
                               rx_output_model_file)

        with open(os.path.join(args.output_dir, 'bert_config.json'),
                  'w',
                  encoding='utf-8') as fout:
            fout.write(model.config.to_json_string())

    if args.do_test:
        logger.info("***** Running test *****")
        logger.info("  Num examples = %d", len(test_dataset))
        logger.info("  Batch size = %d", 1)

        def test(task=0):
            # Load a trained model that you have fine-tuned
            model_state_dict = torch.load(rx_output_model_file)
            model.load_state_dict(model_state_dict)
            model.to(device)

            model.eval()
            y_preds = []
            y_trues = []
            for test_input in tqdm(test_dataloader, desc="Testing"):
                test_input = tuple(t.to(device) for t in test_input)
                input_ids, dx_labels, rx_labels = test_input
                input_ids = input_ids.squeeze()
                dx_labels = dx_labels.squeeze()
                rx_labels = rx_labels.squeeze(dim=0)
                with torch.no_grad():
                    loss, rx_logits = model(input_ids,
                                            dx_labels=dx_labels,
                                            rx_labels=rx_labels)
                    y_preds.append(t2n(torch.sigmoid(rx_logits)))
                    y_trues.append(t2n(rx_labels))

            print('')
            acc_container = metric_report(np.concatenate(y_preds, axis=0),
                                          np.concatenate(y_trues, axis=0),
                                          args.therhold)

            # save report
            if args.do_train:
                for k, v in acc_container.items():
                    writer.add_scalar('test/{}'.format(k), v, 0)

            return acc_container

        test(task=0)
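
`get_n_params`, used when logging the model size, is defined elsewhere; a minimal sketch of the usual implementation:

    def get_n_params(model):
        # assumed helper: total number of parameters in a torch.nn.Module
        return sum(p.numel() for p in model.parameters())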
Example #6
def train_predict(batch_size=100, epochs=10, topk=30):
    patients = getTrainData(4000000)  # patients × visits × medical_code

    patients_num = len(patients)
    train_patient_num = int(patients_num * 0.8)
    patients_train = patients[0:train_patient_num]
    test_patient_num = patients_num - train_patient_num
    patients_test = patients[train_patient_num:]

    train_batch_num = int(np.ceil(float(train_patient_num) / batch_size))
    test_batch_num = int(np.ceil(float(test_patient_num) / batch_size))

    retain = Retain(inputDimSize=3393, embDimSize=300, alphaHiddenDimSize=200, betaHiddenDimSize=200, outputDimSize=283)

    for epoch in range(epochs):
        starttime = time.time()
        # Training
        loss = 0.0
        for batch_index in range(train_batch_num):
            patients_batch = patients_train[batch_index * batch_size:(batch_index + 1) * batch_size]
            patients_batch_reshape, patients_lengths = retain.padTrainMatrix(
                patients_batch)  # maxlen × n_samples × inputDimSize
            batch_x = patients_batch_reshape[0:-1]  # use the first n-1 visits as x to predict the following n-1 days
            # batch_y = patients_batch_reshape[1:]
            batch_y = patients_batch_reshape[1:, :, :283]

            loss += retain.startTrain(batch_x, batch_y, patients_lengths)
        print("Train:Epoch-" + str(epoch) + ":" + str(loss) + " Train Time:" + str(time.time() - starttime))

        # Testing
        NDCG = 0.0
        RECALL = 0.0
        DAYNUM = 0.0
        all_loss = 0.0
        gbert_pred = []
        gbert_true = []
        gbert_len = []

        for batch_index in range(test_batch_num):
            patients_batch = patients_test[batch_index * batch_size:(batch_index + 1) * batch_size]
            patients_batch_reshape, patients_lengths = retain.padTrainMatrix(patients_batch)
            batch_x = patients_batch_reshape[0:-1]
            # batch_y = patients_batch_reshape[1:]
            batch_y = patients_batch_reshape[1:, :, :283]

            loss, y_hat = retain.get_reslut(batch_x, batch_y, patients_lengths)
            all_loss += loss
            ndcg, recall, daynum = validation(y_hat, patients_batch, patients_lengths, topk)
            # acc_container = metric_report(y_hat, patients_batch)
            NDCG += ndcg
            RECALL += recall
            DAYNUM += daynum
            gbert_pred.append(y_hat)
            gbert_true.append(batch_y)
            gbert_len.append(patients_lengths)

        avg_NDCG = NDCG / DAYNUM
        avg_RECALL = RECALL / DAYNUM
        y_pred_all, y_true_all = batch_squeeze(gbert_pred, gbert_true, gbert_len)
        acc_container = metric_report(y_pred_all, y_true_all, 0.2)
        print("Test:Epoch-" + str(epoch) + " Loss:" + str(all_loss) + " Test Time:" + str(time.time() - starttime))
        print("Test:Epoch-" + str(epoch) + " NDCG:" + str(avg_NDCG) + " RECALL:" + str(avg_RECALL))
        print("Test:Epoch-" + str(epoch) + " Jaccard:" + str(acc_container['jaccard']) +
              " f1:" + str(acc_container['f1']) + " prauc:" + str(acc_container['prauc']) + " auc:" + str(
            acc_container['auc']))

        print("")
Example #7
    def eval(self, eval_type='test'):

        self.model.eval()

        if eval_type == 'valid':
            test_task_dict = self.meta_valid_task_entity_to_triplets
            test_task_pool = list(
                self.meta_valid_task_entity_to_triplets.keys())
        elif eval_type == 'test':
            test_task_dict = self.meta_test_task_entity_to_triplets
            test_task_pool = list(
                self.meta_test_task_entity_to_triplets.keys())
        else:
            raise ValueError("Invalid eval_type: <{}>".format(eval_type))

        total_unseen_entity = []
        total_unseen_entity_embedding = []
        total_train_triplets = []
        total_test_triplets = []
        total_test_triplets_dict = dict()

        for unseen_entity in test_task_pool:

            triplets = test_task_dict[unseen_entity]
            triplets = np.array(triplets)
            heads, relations, tails = triplets.transpose()

            train_triplets = triplets[:self.args.few]
            test_triplets = triplets[self.args.few:]

            if len(triplets) - self.args.few < 1:
                continue

            # Train (Inductive)
            unseen_entity_embedding = self.model(unseen_entity,
                                                 train_triplets,
                                                 use_cuda=self.use_cuda,
                                                 is_trans=False)
            total_unseen_entity.append(unseen_entity)
            total_unseen_entity_embedding.append(unseen_entity_embedding)
            total_train_triplets.extend(train_triplets)
            total_test_triplets.extend(test_triplets)
            total_test_triplets_dict[unseen_entity] = torch.LongTensor(
                test_triplets)

        # Train (Transductive)
        total_unseen_entity = np.array(total_unseen_entity)
        total_unseen_entity_embedding = torch.cat(
            total_unseen_entity_embedding).view(-1, self.embedding_size)
        total_train_triplets = np.array(total_train_triplets)

        samples = total_test_triplets
        samples = torch.LongTensor(samples)

        if self.use_cuda:
            samples = samples.cuda()

        unseen_entity_embeddings, _, _ = self.model(
            total_unseen_entity,
            total_train_triplets,
            use_cuda=self.use_cuda,
            is_trans=True,
            total_unseen_entity_embedding=total_unseen_entity_embedding)
        y_prob, y = self.model.predict(total_unseen_entity,
                                       unseen_entity_embeddings,
                                       samples,
                                       target=None,
                                       use_cuda=self.use_cuda)

        y_prob = y_prob.detach().cpu().numpy()
        y = y.detach().cpu().numpy()

        results = utils.metric_report(y, y_prob)

        return results
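
For completeness, a hedged usage sketch (the `trainer` name is an assumption; `eval_type` must be 'valid' or 'test'):

    # hypothetical usage: `trainer` is an instance of the class defining eval()
    valid_results = trainer.eval(eval_type='valid')
    test_results = trainer.eval(eval_type='test')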