Ejemplo n.º 1
0
    gc.collect()

    # print(model.parameters)
    print_params(model)

    start_epoch = 1
    pretrain_model = opt.pretrain_model
    lr = opt.lr
    model_name = opt.model_name

    if pretrain_model != '':
        chkpt = torch.load(pretrain_model, map_location=torch.device('cpu'))
        model.load_state_dict(chkpt['checkpoint'])
        logging('load checkpoint from {}'.format(pretrain_model))
    else:
        assert 1 == 2, 'please provide checkpoint to evaluate.'

    model = get_cuda(model)
    model.eval()

    f1, auc, pr_x, pr_y = test(model,
                               test_loader,
                               model_name,
                               id2rel=id2rel,
                               input_theta=opt.input_theta,
                               output=True,
                               test_prefix='test',
                               is_test=True,
                               ours=False)
    print('finished')
Ejemplo n.º 2
0
def train(opt, isbody=False):
    """Train a medical-extraction model, checkpoint it and test the best one.

    Args:
        opt: Run configuration. The attributes read here are ``train_data``,
            ``dev_data``, ``test_data``, ``batch_size``, ``dev_batch_size``,
            ``num_worker``, ``lr``, ``epochs``, ``pretrain_model``,
            ``model_name``, ``num_warmup_steps``, ``threshold``,
            ``checkpoint_dir``, ``patience`` and ``save_model_freq``. The
            whole object is also handed to the model constructor.
        isbody: When true, train ``MedicalExtractionModelForBody`` on the
            ``body_*`` batch fields; otherwise train
            ``MedicalExtractionModel`` on the subject / decorate / frequency
            targets. A pretrained checkpoint is only loaded when this is
            false.

    Side effects:
        Creates ``opt.checkpoint_dir`` if needed, writes a periodic checkpoint
        every ``opt.save_model_freq`` epochs, passes the best-model path to
        ``EarlyStopping``, reloads ``*_best.pt`` after training, evaluates it
        on the test set and finally moves the model back to the CPU.
    """
    train_ds = MedicalExtractionDataset(opt.train_data)
    dev_ds = MedicalExtractionDataset(opt.dev_data)
    test_ds = MedicalExtractionDataset(opt.test_data)

    # One shuffling loader is enough: it draws a new permutation every time
    # it is iterated, so there is no need to rebuild it each epoch.
    train_dl = DataLoader(train_ds,
                          batch_size=opt.batch_size,
                          shuffle=True,
                          num_workers=opt.num_worker)
    dev_dl = DataLoader(dev_ds,
                        batch_size=opt.dev_batch_size,
                        shuffle=False,
                        num_workers=opt.num_worker)
    test_dl = DataLoader(test_ds,
                         batch_size=opt.dev_batch_size,
                         shuffle=False,
                         num_workers=opt.num_worker)

    if isbody:
        logging('training for body')
        model = MedicalExtractionModelForBody(opt)
    else:
        logging('training for subject, decorate and body')
        model = MedicalExtractionModel(opt)
    print_params(model)

    start_epoch = 1
    learning_rate = opt.lr
    total_epochs = opt.epochs
    pretrain_model = opt.pretrain_model
    model_name = opt.model_name  # base name used for every saved checkpoint

    # Resume from a previous checkpoint (not supported for the body model).
    if pretrain_model != '' and not isbody:
        chkpt = torch.load(pretrain_model, map_location=torch.device('cpu'))
        model.load_state_dict(chkpt['checkpoints'])
        logging('load model from {}'.format(pretrain_model))
        start_epoch = chkpt['epoch'] + 1
        learning_rate = chkpt['learning_rate']
        logging('resume from epoch {} with learning_rate {}'.format(
            start_epoch, learning_rate))
    else:
        logging('training from scratch with learning_rate {}'.format(
            learning_rate))

    model = get_cuda(model)

    # The scheduler is stepped once per batch, so its horizon must count
    # batches (including a final partial batch), not len(dataset) / batch_size
    # truncated; otherwise the linear decay hits zero before training ends.
    num_train_steps = len(train_dl) * total_epochs

    # Biases and LayerNorm parameters are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    optimizer_parameters = [
        {'params': decay_params, 'weight_decay': 0.001},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    optimizer = optim.AdamW(optimizer_parameters, lr=learning_rate)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=opt.num_warmup_steps,
        num_training_steps=num_train_steps)
    threshold = opt.threshold
    # Per-element losses are kept so they can be masked before averaging.
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    checkpoint_dir = opt.checkpoint_dir
    # makedirs also handles nested paths and an already existing directory.
    os.makedirs(checkpoint_dir, exist_ok=True)

    best_suffix = '_body_best.pt' if isbody else '_best.pt'
    best_model_path = os.path.join(checkpoint_dir, model_name + best_suffix)

    es = EarlyStopping(patience=opt.patience, mode="min", criterion='val loss')
    for epoch in range(start_epoch, total_epochs + 1):
        train_loss = 0.0
        model.train()
        tk_train = tqdm(train_dl, total=len(train_dl))
        for batch in tk_train:
            optimizer.zero_grad()
            subject_target_ids = batch['subject_target_ids']
            decorate_target_ids = batch['decorate_target_ids']
            freq_target_ids = batch['freq_target_ids']
            body_target_ids = batch['body_target_ids']
            mask = batch['mask'].float().unsqueeze(-1)
            body_mask = batch['body_mask'].unsqueeze(-1)

            if isbody:
                body_logits = model(
                    input_ids=batch['body_input_ids'],
                    attention_mask=batch['body_mask'],
                    token_type_ids=batch['body_token_type_ids'])
                # Masked mean of the element-wise loss.
                loss = torch.sum(
                    criterion(body_logits, body_target_ids) *
                    body_mask) / torch.sum(body_mask)
            else:
                subject_logits, decorate_logits, freq_logits = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['mask'],
                    token_type_ids=batch['token_type_ids'])
                # Sum of the three task losses, averaged over unmasked
                # positions.
                loss = torch.sum(
                    (criterion(subject_logits, subject_target_ids) +
                     criterion(decorate_logits, decorate_target_ids) +
                     criterion(freq_logits, freq_target_ids)) *
                    mask) / torch.sum(mask)

            loss.backward()
            optimizer.step()
            scheduler.step()

            batch_loss = loss.item()
            tk_train.set_postfix(train_loss='{:5.3f} / 1000'.format(
                1000 * batch_loss),
                                 epoch='{:2d}'.format(epoch))
            # Weight by batch size so the epoch figure is a per-example mean.
            train_loss += batch_loss * subject_target_ids.shape[0]

        avg_train_loss = train_loss * 1000 / len(train_ds)
        print('train loss per example: {:5.3f} / 1000'.format(avg_train_loss))

        avg_val_loss = test(model,
                            dev_ds,
                            dev_dl,
                            criterion,
                            threshold,
                            'val',
                            isbody=isbody)

        # Keep the best model (by validation loss) for later evaluation.
        es(avg_val_loss,
           model,
           model_path=best_model_path,
           epoch=epoch,
           learning_rate=learning_rate)
        if es.early_stop:
            print("Early stopping")
            break

        # Periodic per-epoch checkpoint so training can be resumed.
        if epoch % opt.save_model_freq == 0:
            if isbody:
                save_model_path = os.path.join(
                    checkpoint_dir, model_name + '_body_{}.pt'.format(epoch))
            else:
                save_model_path = os.path.join(
                    checkpoint_dir, model_name + '_{}.pt'.format(epoch))
            torch.save(
                {
                    'epoch': epoch,
                    'learning_rate': learning_rate,
                    'checkpoints': model.state_dict()
                }, save_model_path)

    # Load the best model and evaluate it on the test set.
    chkpt = torch.load(best_model_path, map_location=torch.device('cpu'))
    model.load_state_dict(chkpt['checkpoints'])
    if isbody:
        logging('load best body model from {} and test ...'.format(
            best_model_path))
    else:
        logging('load best model from {} and test ...'.format(best_model_path))
    test(model, test_ds, test_dl, criterion, threshold, 'test', isbody)
    model.cpu()
Ejemplo n.º 3
0
def train(opt):
    """Train a GAIN relation-extraction model and checkpoint it.

    Args:
        opt: Run configuration. ``opt.use_model`` selects the variant
            (``'bert'`` or ``'bilstm'``). Also read here: ``train_set``,
            ``train_set_save``, ``dev_set``, ``dev_set_save``, ``batch_size``,
            ``test_batch_size``, ``negativa_alpha``, ``pretrain_model``,
            ``lr``, ``model_name``, ``weight_decay``, ``coslr``, ``epoch``,
            ``checkpoint_dir``, ``fig_result_dir``, ``relation_nums``,
            ``clip``, ``log_step``, ``test_epoch`` and ``save_model_freq``.

    Raises:
        AssertionError: if ``opt.use_model`` is not one of the two supported
            values.

    Side effects:
        Creates the checkpoint and figure directories, writes
        ``<model_name>_best.pt`` whenever the dev Ign F1 improves, writes a
        numbered checkpoint every ``opt.save_model_freq`` epochs, and saves a
        precision-recall figure.
    """
    if opt.use_model == 'bert':
        dataset_cls, model_cls = BERTDGLREDataset, GAIN_BERT
    elif opt.use_model == 'bilstm':
        dataset_cls, model_cls = DGLREDataset, GAIN_GloVe
    else:
        # Raised explicitly: an `assert` would be stripped under `python -O`.
        raise AssertionError('please choose a model from [bert, bilstm].')

    # datasets
    train_set = dataset_cls(opt.train_set, opt.train_set_save, word2id, ner2id, rel2id, dataset_type='train',
                            opt=opt)
    # The dev set is required by the periodic evaluation below.
    dev_set = dataset_cls(opt.dev_set, opt.dev_set_save, word2id, ner2id, rel2id, dataset_type='dev',
                          instance_in_train=train_set.instance_in_train, opt=opt)

    # dataloaders
    train_loader = DGLREDataloader(train_set, batch_size=opt.batch_size, shuffle=True,
                                   negativa_alpha=opt.negativa_alpha)
    dev_loader = DGLREDataloader(dev_set, batch_size=opt.test_batch_size, dataset_type='dev')

    model = model_cls(opt)

    print(model.parameters)
    print_params(model)

    start_epoch = 1
    pretrain_model = opt.pretrain_model
    lr = opt.lr
    model_name = opt.model_name

    # Optionally resume weights, epoch counter and learning rate.
    if pretrain_model != '':
        chkpt = torch.load(pretrain_model, map_location=torch.device('cpu'))
        model.load_state_dict(chkpt['checkpoint'])
        logging('load model from {}'.format(pretrain_model))
        start_epoch = chkpt['epoch'] + 1
        lr = chkpt['lr']
        logging('resume from epoch {} with lr {}'.format(start_epoch, lr))
    else:
        logging('training from scratch with lr {}'.format(lr))

    model = get_cuda(model)

    if opt.use_model == 'bert':
        # The BERT encoder gets a 100x smaller learning rate than the rest.
        # A set keeps the per-parameter membership test O(1).
        bert_param_ids = {id(p) for p in model.bert.parameters()}
        base_params = [
            p for p in model.parameters()
            if p.requires_grad and id(p) not in bert_param_ids
        ]

        optimizer = optim.AdamW([
            {'params': model.bert.parameters(), 'lr': lr * 0.01},
            {'params': base_params, 'weight_decay': opt.weight_decay}
        ], lr=lr)
    else:
        optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr,
                                weight_decay=opt.weight_decay)

    # Per-element losses are kept so they can be masked before averaging.
    BCE = nn.BCEWithLogitsLoss(reduction='none')

    if opt.coslr:
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(opt.epoch // 4) + 1)

    checkpoint_dir = opt.checkpoint_dir
    fig_result_dir = opt.fig_result_dir
    # makedirs also handles nested paths and already existing directories.
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(fig_result_dir, exist_ok=True)

    best_ign_auc = 0.0
    best_ign_f1 = 0.0
    best_epoch = 0

    model.train()

    global_step = 0
    total_loss = 0
    log_step = opt.log_step

    # One shared figure; a curve is added each time the dev score improves.
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim(0.0, 1.0)
    plt.xlim(0.0, 1.0)
    plt.title('Precision-Recall')
    plt.grid(True)

    acc_NA, acc_not_NA, acc_total = Accuracy(), Accuracy(), Accuracy()
    logging('begin..')

    for epoch in range(start_epoch, opt.epoch + 1):
        start_time = time.time()
        for acc in [acc_NA, acc_not_NA, acc_total]:
            acc.clear()

        for ii, d in enumerate(train_loader):
            relation_multi_label = d['relation_multi_label']
            relation_mask = d['relation_mask']
            relation_label = d['relation_label']

            predictions = model(words=d['context_idxs'],
                                src_lengths=d['context_word_length'],
                                mask=d['context_word_mask'],
                                entity_type=d['context_ner'],
                                entity_id=d['context_pos'],
                                mention_id=d['context_mention'],
                                distance=None,
                                entity2mention_table=d['entity2mention_table'],
                                graphs=d['graphs'],
                                h_t_pairs=d['h_t_pairs'],
                                relation_mask=relation_mask,
                                path_table=d['path_table'],
                                entity_graphs=d['entity_graphs'],
                                ht_pair_distance=d['ht_pair_distance']
                                )
            # Masked BCE, normalised by relation count and unmasked pairs.
            loss = torch.sum(BCE(predictions, relation_multi_label) * relation_mask.unsqueeze(2)) / (
                    opt.relation_nums * torch.sum(relation_mask))

            optimizer.zero_grad()
            loss.backward()

            # opt.clip == -1 disables gradient clipping.
            if opt.clip != -1:
                nn.utils.clip_grad_value_(model.parameters(), opt.clip)
            optimizer.step()
            if opt.coslr:
                scheduler.step(epoch)

            output = torch.argmax(predictions, dim=-1)
            output = output.data.cpu().numpy()
            relation_label = relation_label.data.cpu().numpy()

            # Running training accuracy; a negative label ends the row.
            for i in range(output.shape[0]):
                for j in range(output.shape[1]):
                    label = relation_label[i][j]
                    if label < 0:
                        break

                    is_correct = (output[i][j] == label)
                    if label == 0:
                        acc_NA.add(is_correct)
                    else:
                        acc_not_NA.add(is_correct)

                    acc_total.add(is_correct)

            global_step += 1
            total_loss += loss.item()

            if global_step % log_step == 0:
                cur_loss = total_loss / log_step
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:2d} | step {:4d} |  ms/b {:5.2f} | train loss {:5.3f} | NA acc: {:4.2f} | not NA acc: {:4.2f}  | tot acc: {:4.2f} '.format(
                        epoch, global_step, elapsed * 1000 / log_step, cur_loss * 1000, acc_NA.get(), acc_not_NA.get(),
                        acc_total.get()))
                total_loss = 0
                start_time = time.time()

        if epoch % opt.test_epoch == 0:
            logging('-' * 89)
            eval_start_time = time.time()
            model.eval()
            ign_f1, ign_auc, pr_x, pr_y = test(model, dev_loader, model_name, id2rel=id2rel)
            model.train()
            logging('| epoch {:3d} | time: {:5.2f}s'.format(epoch, time.time() - eval_start_time))
            logging('-' * 89)

            if ign_f1 > best_ign_f1:
                best_ign_f1 = ign_f1
                best_ign_auc = ign_auc
                best_epoch = epoch
                path = os.path.join(checkpoint_dir, model_name + '_best.pt')
                torch.save({
                    'epoch': epoch,
                    'checkpoint': model.state_dict(),
                    'lr': lr,
                    'best_ign_f1': ign_f1,
                    'best_ign_auc': ign_auc,
                    'best_epoch': epoch
                }, path)

                plt.plot(pr_x, pr_y, lw=2, label=str(epoch))
                plt.legend(loc="upper right")
                plt.savefig(os.path.join(fig_result_dir, model_name))

        if epoch % opt.save_model_freq == 0:
            path = os.path.join(checkpoint_dir, model_name + '_{}.pt'.format(epoch))
            torch.save({
                'epoch': epoch,
                'lr': lr,
                'checkpoint': model.state_dict()
            }, path)

    print("Finish training")
    print("Best epoch = %d | Best Ign F1 = %f" % (best_epoch, best_ign_f1))
    print("Storing best result...")
    print("Finish storing")
Ejemplo n.º 4
0
def train(opt):
    """Skeleton training loop for ``MedicalExtractionModel``.

    The forward pass, the loss and the dev-set evaluation are still TODO
    placeholders; as written the loop fails at ``loss.backward()`` until a
    real loss is filled in.

    Args:
        opt: Run configuration. Read here: ``train_data``, ``dev_data``,
            ``dev_batch_size``, ``batch_size``, ``lr``, ``epochs``,
            ``log_step``, ``pretrain_model``, ``model_name``,
            ``checkpoint_dir``, ``test_epoch`` and ``save_model_freq``. The
            whole object is also handed to the model constructor.

    Side effects:
        Creates ``opt.checkpoint_dir`` if needed and writes a checkpoint
        every ``opt.save_model_freq`` epochs.
    """
    train_ds = MedicalExtractionDataset(opt.train_data)
    dev_ds = MedicalExtractionDataset(opt.dev_data)

    dev_dl = DataLoader(dev_ds,
                        batch_size=opt.dev_batch_size,
                        shuffle=False,
                        num_workers=1
                        )

    model = MedicalExtractionModel(opt)
    print(model.parameters)
    print_params(model)

    start_epoch = 1
    learning_rate = opt.lr
    total_epochs = opt.epochs
    log_step = opt.log_step
    pretrain_model = opt.pretrain_model
    model_name = opt.model_name  # base name used for every saved checkpoint

    # load pretrained model
    if pretrain_model != '':
        chkpt = torch.load(pretrain_model, map_location=torch.device('cpu'))
        # This function saves weights under 'checkpoint'; 'checkpoints' is
        # what the loader originally expected. Accept both so files written
        # below can actually be resumed.
        state_key = 'checkpoints' if 'checkpoints' in chkpt else 'checkpoint'
        model.load_state_dict(chkpt[state_key])
        logging('load model from {}'.format(pretrain_model))
        start_epoch = chkpt['epoch'] + 1
        learning_rate = chkpt['learning_rate']
        logging('resume from epoch {} with learning_rate {}'.format(start_epoch, learning_rate))
    else:
        logging('training from scratch with learning_rate {}'.format(learning_rate))

    model = get_cuda(model)

    # TODO: switch to AdamW if a BERT encoder is used
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # TODO loss function
    # criterion =

    checkpoint_dir = opt.checkpoint_dir
    # makedirs also handles nested paths and an already existing directory.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # training process
    global_step = 0
    total_loss = 0
    # Start at start_epoch so a resumed run does not repeat finished epochs.
    for epoch in range(start_epoch, total_epochs + 1):
        start_time = time.time()
        train_dl = DataLoader(train_ds,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=8
                              )
        model.train()
        for batch in train_dl:
            optimizer.zero_grad()
            # TODO: feed the batch to the model

            # TODO: compute the loss (placeholder; must be replaced)
            loss = None

            loss.backward()
            optimizer.step()

            global_step += 1
            total_loss += loss.item()
            if global_step % log_step == 0:
                cur_loss = total_loss / log_step
                elapsed = time.time() - start_time
                logging(
                    '| epoch {:2d} | step {:4d} |  ms/b {:5.2f} | train loss {:5.3f} '.format(
                        epoch, global_step, elapsed * 1000 / log_step, cur_loss * 1000))
                total_loss = 0
                start_time = time.time()

        if epoch % opt.test_epoch == 0:
            model.eval()
            with torch.no_grad():
                for batch in dev_dl:
                    # TODO: evaluate on the validation set
                    pass

        # save model
        # TODO: could be changed to keep only the best model on dev
        if epoch % opt.save_model_freq == 0:
            path = os.path.join(checkpoint_dir, model_name + '_{}.pt'.format(epoch))
            torch.save({
                'epoch': epoch,
                'learning_rate': learning_rate,
                'checkpoint': model.state_dict()
            }, path)
Ejemplo n.º 5
0
                         num_workers=opt.num_worker)

    body_model = MedicalExtractionModelForBody(opt)
    model = MedicalExtractionModel(opt)

    # load best model and test
    checkpoint_dir = opt.checkpoint_dir
    model_name = opt.model_name
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    threshold = opt.threshold

    body_best_model_path = os.path.join(checkpoint_dir,
                                        model_name + '_body_best.pt')
    chkpt = torch.load(body_best_model_path, map_location=torch.device('cpu'))
    body_model.load_state_dict(chkpt['checkpoints'])

    best_model_path = os.path.join(checkpoint_dir, model_name + '_best.pt')
    chkpt = torch.load(best_model_path, map_location=torch.device('cpu'))
    model.load_state_dict(chkpt['checkpoints'])

    body_model = get_cuda(body_model)
    print('test body ...')
    test(body_model, test_ds, test_dl, criterion, threshold, 'test', True,
         True)
    body_model.cpu()

    model = get_cuda(model)
    print('test others ...')
    test(model, test_ds, test_dl, criterion, threshold, 'test', False, True)
    model.cpu()
Ejemplo n.º 6
0
    def forward(self, **params):
        '''
        Score every head/tail entity pair of each document in the batch.

        Steps: encode the words, average word states into mention nodes,
        run the GCN layers over the batched mention graph, average mentions
        into entities, gather head/tail entity vectors for each pair, add
        attention-pooled path information from the entity graph, and feed
        the concatenation to ``self.predict``.

        Entries read from ``params`` (shapes as annotated by the author):
            words: [batch_size, max_length]
            src_lengths: [batchs_size]
            mask: [batch_size, max_length]
            entity_type: [batch_size, max_length]
            entity_id: [batch_size, max_length]
            mention_id: [batch_size, max_length]
            entity2mention_table: list of [local_entity_num, local_mention_num]
            graphs: list of DGLHeteroGraph
            h_t_pairs: [batch_size, h_t_limit, 2]
            relation_mask: indexed as [document, pair]; may be None
                (the evaluation path)
            path_table: one mapping per document, keyed by
                (h + 1, t + 1) for the 0-based pair indices used here, whose
                values are intermediate entity ids that are shifted by -1
                before use
            entity_graphs: list of graphs batched with ``dgl.batch``

        ``distance`` ([batch_size, max_length]) is listed by the original
        docstring but is not read in this method.

        Returns:
            The output of ``self.predict`` for all pairs.

        Raises:
            AssertionError: if an unmasked pair has no entry in its
                document's path table.
        '''
        src = self.word_emb(params['words'])
        mask = params['mask']
        bsz, slen, _ = src.size()

        if self.config.use_entity_type:
            src = torch.cat(
                [src, self.entity_type_emb(params['entity_type'])], dim=-1)

        if self.config.use_entity_id:
            src = torch.cat([src, self.entity_id_emb(params['entity_id'])],
                            dim=-1)

        # src: [batch_size, slen, encoder_input_size]
        # src_lengths: [batchs_size]

        encoder_outputs, (output_h_t,
                          _) = self.encoder(src, params['src_lengths'])
        # Zero the states of masked-out positions.
        encoder_outputs[mask == 0] = 0
        # encoder_outputs: [batch_size, slen, 2*encoder_hid_size]
        # output_h_t: [batch_size, 2*encoder_hid_size]

        graphs = params['graphs']

        mention_id = params['mention_id']
        features = None

        for i in range(len(graphs)):
            encoder_output = encoder_outputs[i]  # [slen, 2*encoder_hid_size]
            mention_num = torch.max(mention_id[i])
            mention_index = get_cuda(
                (torch.arange(mention_num) + 1).unsqueeze(1).expand(
                    -1, slen))  # [mention_num, slen]
            mentions = mention_id[i].unsqueeze(0).expand(
                mention_num, -1)  # [mention_num, slen]
            select_metrix = (
                mention_index == mentions).float()  # [mention_num, slen]
            # average word -> mention
            word_total_numbers = torch.sum(select_metrix,
                                           dim=-1).unsqueeze(-1).expand(
                                               -1, slen)  # [mention_num, slen]
            # Rows with no words are left as all zeros instead of dividing
            # by zero.
            select_metrix = torch.where(word_total_numbers > 0,
                                        select_metrix / word_total_numbers,
                                        select_metrix)
            x = torch.mm(select_metrix,
                         encoder_output)  # [mention_num, 2*encoder_hid_size]

            # Prepend the encoder summary vector as the document's first node.
            x = torch.cat((output_h_t[i].unsqueeze(0), x), dim=0)

            if features is None:
                features = x
            else:
                features = torch.cat((features, x), dim=0)

        graph_big = dgl.batch_hetero(graphs)
        output_features = [features]

        for GCN_layer in self.GCN_layers:
            features = GCN_layer(
                graph_big,
                {"node": features})["node"]  # [total_mention_nums, gcn_dim]
            output_features.append(features)

        # Concatenate the input features with the output of every GCN layer.
        output_feature = torch.cat(output_features, dim=-1)

        graphs = dgl.unbatch_hetero(graph_big)

        # mention -> entity
        entity2mention_table = params[
            'entity2mention_table']  # list of [entity_num, mention_num]
        entity_num = torch.max(params['entity_id'])
        # Zero-filled: only the leading rows of each document are written
        # below, so the remaining rows must not hold uninitialised memory.
        entity_bank = get_cuda(
            torch.Tensor(bsz, entity_num, self.bank_size).zero_())
        global_info = get_cuda(torch.Tensor(bsz, self.bank_size))

        cur_idx = 0
        entity_graph_feature = None
        for i in range(len(graphs)):
            # average mention -> entity
            select_metrix = entity2mention_table[i].float(
            )  # [local_entity_num, mention_num]
            # Row 0 / column 0 correspond to the prepended document node.
            select_metrix[0][0] = 1
            mention_nums = torch.sum(select_metrix,
                                     dim=-1).unsqueeze(-1).expand(
                                         -1, select_metrix.size(1))
            select_metrix = torch.where(mention_nums > 0,
                                        select_metrix / mention_nums,
                                        select_metrix)
            node_num = graphs[i].number_of_nodes('node')
            entity_representation = torch.mm(
                select_metrix, output_feature[cur_idx:cur_idx + node_num])
            # Skip row 0 (document node) when filling the entity bank.
            entity_bank[i, :select_metrix.size(0) -
                        1] = entity_representation[1:]
            global_info[i] = output_feature[cur_idx]
            cur_idx += node_num

            # Only the last gcn_dim channels feed the entity graph.
            if entity_graph_feature is None:
                entity_graph_feature = entity_representation[
                    1:, -self.config.gcn_dim:]
            else:
                entity_graph_feature = torch.cat(
                    (entity_graph_feature,
                     entity_representation[1:, -self.config.gcn_dim:]),
                    dim=0)

        h_t_pairs = params['h_t_pairs']
        # Shift indices down by one, except zeros, which stay at 0.
        h_t_pairs = h_t_pairs + (h_t_pairs
                                 == 0).long() - 1  # [batch_size, h_t_limit, 2]
        h_t_limit = h_t_pairs.size(1)

        # [batch_size, h_t_limit, bank_size]
        h_entity_index = h_t_pairs[:, :, 0].unsqueeze(-1).expand(
            -1, -1, self.bank_size)
        t_entity_index = h_t_pairs[:, :, 1].unsqueeze(-1).expand(
            -1, -1, self.bank_size)

        # [batch_size, h_t_limit, bank_size]
        h_entity = torch.gather(input=entity_bank, dim=1, index=h_entity_index)
        t_entity = torch.gather(input=entity_bank, dim=1, index=t_entity_index)

        global_info = global_info.unsqueeze(1).expand(-1, h_t_limit, -1)

        entity_graphs = params['entity_graphs']
        entity_graph_big = dgl.batch(entity_graphs)
        # The edge layer's result is read back below from edata['h'].
        self.edge_layer(entity_graph_big, entity_graph_feature)
        entity_graphs = dgl.unbatch(entity_graph_big)
        path_info = get_cuda(torch.zeros((bsz, h_t_limit, self.gcn_dim * 4)))
        relation_mask = params['relation_mask']
        path_table = params['path_table']
        for i in range(len(entity_graphs)):
            path_t = path_table[i]
            for j in range(h_t_limit):
                # With a mask, the first masked-out pair ends this document.
                if relation_mask is not None and relation_mask[i,
                                                               j].item() == 0:
                    break

                h = h_t_pairs[i, j, 0].item()
                t = h_t_pairs[i, j, 1].item()
                # for evaluate
                if relation_mask is None and h == 0 and t == 0:
                    continue

                if (h + 1, t + 1) in path_t:
                    v = [val - 1 for val in path_t[(h + 1, t + 1)]]
                elif (t + 1, h + 1) in path_t:
                    v = [val - 1 for val in path_t[(t + 1, h + 1)]]
                else:
                    # Dump context, then fail. `v` is deliberately not
                    # printed: it is unbound (or stale) on this branch.
                    print(h, t)
                    print(entity_graphs[i].all_edges())
                    print(h_t_pairs)
                    print(relation_mask)
                    raise AssertionError(
                        'no path entry for entity pair ({}, {})'.format(h, t))

                middle_node_num = len(v)

                if middle_node_num == 0:
                    continue

                heads = [h] * middle_node_num
                tails = [t] * middle_node_num

                # forward
                edge_ids = get_cuda(entity_graphs[i].edge_ids(heads, v))
                forward_first = torch.index_select(entity_graphs[i].edata['h'],
                                                   dim=0,
                                                   index=edge_ids)
                edge_ids = get_cuda(entity_graphs[i].edge_ids(v, tails))
                forward_second = torch.index_select(
                    entity_graphs[i].edata['h'], dim=0, index=edge_ids)

                # backward
                edge_ids = get_cuda(entity_graphs[i].edge_ids(tails, v))
                backward_first = torch.index_select(
                    entity_graphs[i].edata['h'], dim=0, index=edge_ids)
                edge_ids = get_cuda(entity_graphs[i].edge_ids(v, heads))
                backward_second = torch.index_select(
                    entity_graphs[i].edata['h'], dim=0, index=edge_ids)

                tmp_path_info = torch.cat((forward_first, forward_second,
                                           backward_first, backward_second),
                                          dim=-1)
                # Attend over the candidate paths with the pair as the query.
                _, attn_value = self.attention(
                    torch.cat((h_entity[i, j], t_entity[i, j]), dim=-1),
                    tmp_path_info)
                path_info[i, j] = attn_value

            # Drop the edge features once this document is done.
            entity_graphs[i].edata.pop('h')

        path_info = self.dropout(
            self.activation(self.path_info_mapping(path_info)))

        predictions = self.predict(
            torch.cat((h_entity, t_entity, torch.abs(h_entity - t_entity),
                       torch.mul(h_entity, t_entity), global_info, path_info),
                      dim=-1))
        return predictions
Ejemplo n.º 7
0
                        default='train',
                        choices=['train', 'val', 'trainval', 'demo'])
    parser.add_argument("--gpu_id", type=int, default=-1)
    parser.add_argument("--backbone", type=str, default='vgg')
    parser.add_argument("--root_dataset",
                        type=str,
                        default='./data/Pascal_VOC')
    parser.add_argument("--resume", type=str, default='')
    parser.add_argument("--fcn",
                        type=str,
                        default='32s',
                        choices=['32s', '16s', '8s', '50', '101'])
    opts = parser.parse_args()

    # os.environ['CUDA_VISIBLE_DEVICES'] = str(opts.gpu_id)
    opts.cuda = get_cuda(torch.cuda.is_available() and opts.gpu_id != -1,
                         opts.gpu_id)
    print('Cuda', opts.cuda)
    cfg = get_config()[1]
    opts.cfg = cfg

    if opts.mode in ['train', 'trainval']:
        opts.out = get_log_dir('fcn' + opts.fcn, 1, cfg)
        print('Output logs: ', opts.out)

    data = get_loader(opts)

    trainer = Trainer(data, opts)
    if opts.mode == 'val':
        trainer.Test()
    elif opts.mode == 'demo':
        trainer.Demo()
                        default='vgg',
                        choices=['vgg', 'resnet'])
    parser.add_argument("--root_dataset",
                        type=str,
                        default='./data/cloud/cce/swimseg')
    parser.add_argument("--resume", type=str, default='')
    parser.add_argument("--fcn",
                        type=str,
                        default='32s',
                        choices=['32s', '16s', '8s', '50', '101'])
    opts = parser.parse_args()

    # os.environ['CUDA_VISIBLE_DEVICES'] = str(opts.gpu_id)
    cuda = torch.cuda.is_available()
    gpu_id = torch.cuda.current_device() if cuda else 99
    opts.cuda = get_cuda(cuda, gpu_id)

    print('CUDA', opts.cuda)

    cfg = get_config()[1]
    opts.cfg = cfg

    if opts.mode in ['train', 'trainval']:
        opts.out = get_log_dir('fcn' + opts.fcn, 1, cfg)
        print('Output logs: ', opts.out)

    data = get_loader(opts)

    trainer = Trainer(data, opts)
    if opts.mode == 'val':
        trainer.Test()
Ejemplo n.º 9
0
    def __getitem__(self, item):
        """Build the model inputs and token-level targets for one example.

        Two sequences are produced from ``self.data[item]``:

        * a "body" sequence: symptom tokens followed by the tokens of a text
          window around the symptom (``self.left_side_span`` characters to
          the left, ``self.right_side_span`` to the right), labelled with the
          ``body`` attribute;
        * a symptom-only sequence, labelled with the ``subject``,
          ``decorate`` and ``frequency`` attributes.

        Both are right-padded up to ``self.max_len``.

        NOTE(review): sequences longer than ``self.max_len`` are not
        truncated here, so they come back longer than the padded ones.

        Args:
            item: index into ``self.data``.

        Returns:
            dict with the ``body_*`` tensors, ``raw_text``, ``symptom_name``
            and the symptom-only tensors (``input_ids``, ``token_type_ids``,
            ``mask``, ``*_target_ids``, ``offsets``).
        """

        def token_labels(char_flags, token_offsets):
            # A token is labelled 1 as soon as the flags of the characters it
            # covers sum to something positive, otherwise 0.
            return [
                1 if sum(char_flags[off[0]:off[1]]) > 0 else 0
                for off in token_offsets
            ]

        def pad_to(seq, filler, count):
            # Right-pad only when there is room left; otherwise hand the
            # sequence back untouched.
            return seq + [filler] * count if count > 0 else seq

        def as_long(values):
            return torch.tensor(values, dtype=torch.long)

        def as_target(values):
            # Float targets with a trailing singleton dimension.
            return torch.tensor(values, dtype=torch.float).unsqueeze(-1)

        record = self.data[item]
        raw_text = record['raw_text']
        symptom_name = record['symptom_name']
        attr_dict = record['attr_dict']
        symptom_pos = record['symptom_pos']

        # Drop the first and last token the tokenizer adds around the text;
        # the special ids are re-inserted by hand below.
        symptom_enc = PLMConfig.tokenizer.encode(symptom_name)
        symptom_ids = symptom_enc.ids[1:-1]
        symptom_offsets = symptom_enc.offsets[1:-1]
        # Length of the "101 + symptom + 102" prefix shared by both sequences
        # (presumably [CLS]/[SEP] of the BERT vocabulary -- TODO confirm).
        prefix_len = len(symptom_ids) + 2

        subject = attr_dict['subject']
        body = attr_dict['body']
        decorate = attr_dict['decorate']
        freq = attr_dict['frequency']

        # ---- body: symptom tokens + text window around the symptom --------
        window_start = max(0, symptom_pos[0] - self.left_side_span)
        window_end = min(symptom_pos[1] + self.right_side_span + 1,
                         len(raw_text))
        window_enc = PLMConfig.tokenizer.encode(
            raw_text[window_start:window_end])
        body_text_ids = window_enc.ids[1:-1]
        window_offsets = window_enc.offsets[1:-1]
        # Offsets are relative to the window, so slice the labels the same way.
        window_labels = token_labels(body[window_start:window_end],
                                     window_offsets)

        body_input_ids = [101] + symptom_ids + [102] + body_text_ids + [102]
        body_token_type_ids = ([0] * prefix_len +
                               [1] * (len(body_text_ids) + 1))
        body_mask = [1] * len(body_token_type_ids)
        body_text_offsets = ([(0, 0)] * prefix_len + window_offsets +
                             [(0, 0)])
        body_target_ids = [0] * prefix_len + window_labels + [0]

        body_room = self.max_len - len(body_input_ids)
        body_input_ids = pad_to(body_input_ids, 0, body_room)
        body_token_type_ids = pad_to(body_token_type_ids, 0, body_room)
        body_mask = pad_to(body_mask, 0, body_room)
        body_text_offsets = pad_to(body_text_offsets, (0, 0), body_room)
        body_target_ids = pad_to(body_target_ids, 0, body_room)

        # ---- subject / decorate / frequency: symptom tokens only ----------
        symptom_slice = slice(symptom_pos[0], symptom_pos[1] + 1)
        subject_target_ids = token_labels(subject[symptom_slice],
                                          symptom_offsets)
        decorate_target_ids = token_labels(decorate[symptom_slice],
                                           symptom_offsets)
        freq_target_ids = token_labels(freq[symptom_slice], symptom_offsets)

        input_ids = [101] + symptom_ids + [102]
        token_type_ids = [0] * prefix_len
        mask = [1] * len(token_type_ids)
        offsets = [(0, 0)] + symptom_offsets + [(0, 0)]
        subject_target_ids = [0] + subject_target_ids + [0]
        decorate_target_ids = [0] + decorate_target_ids + [0]
        freq_target_ids = [0] + freq_target_ids + [0]

        room = self.max_len - len(input_ids)
        input_ids = pad_to(input_ids, 0, room)
        token_type_ids = pad_to(token_type_ids, 0, room)
        mask = pad_to(mask, 0, room)
        offsets = pad_to(offsets, (0, 0), room)
        subject_target_ids = pad_to(subject_target_ids, 0, room)
        decorate_target_ids = pad_to(decorate_target_ids, 0, room)
        freq_target_ids = pad_to(freq_target_ids, 0, room)

        # Offsets are returned as plain tensors (not passed through get_cuda),
        # exactly like the raw strings.
        return {
            'body_input_ids': get_cuda(as_long(body_input_ids)),
            'body_token_type_ids': get_cuda(as_long(body_token_type_ids)),
            'body_mask': get_cuda(as_long(body_mask)),
            'body_target_ids': get_cuda(as_target(body_target_ids)),
            'body_text_offsets': torch.tensor(body_text_offsets),
            'raw_text': raw_text,
            'symptom_name': symptom_name,
            'input_ids': get_cuda(as_long(input_ids)),
            'token_type_ids': get_cuda(as_long(token_type_ids)),
            'mask': get_cuda(as_long(mask)),
            'subject_target_ids': get_cuda(as_target(subject_target_ids)),
            'decorate_target_ids': get_cuda(as_target(decorate_target_ids)),
            'freq_target_ids': get_cuda(as_target(freq_target_ids)),
            'offsets': torch.tensor(offsets)
        }
Ejemplo n.º 10
0
    def __iter__(self):
        """Yield one padded mini-batch dictionary per step.

        The examples are (optionally) shuffled, cut into chunks of
        ``self.batch_size`` and written into a set of pre-allocated buffers
        that are reset and reused for every batch.

        For ``self.dataset_type == 'train'`` every labelled (head, tail) pair
        is emitted, followed by negative pairs taken from ``na_triple``; when
        ``self.negativa_alpha > 0`` the negatives are shuffled and capped at
        ``max(20, #positives * negativa_alpha)``. For any other dataset type
        all ordered pairs of distinct entities are emitted, and ``L_vertex``,
        ``titles`` and ``indexes`` are filled in.

        NOTE(review): the yielded tensors are slices of buffers that are
        reused on the next step; if ``get_cuda`` returns its argument
        unchanged (e.g. when running on CPU) a batch kept across iterations
        may be overwritten -- confirm against ``get_cuda``.

        Yields:
            dict with the context tensors (``context_idxs``, ``context_pos``,
            ``context_ner``, ``context_mention``, ``context_word_mask``,
            ``context_word_length``), the pair tensors (``h_t_pairs``,
            ``relation_label``, ``relation_multi_label``, ``relation_mask``,
            ``ht_pair_distance``) and the per-example python lists
            (``labels``, ``L_vertex``, ``titles``, ``indexes``, ``graphs``,
            ``entity2mention_table``, ``entity_graphs``, ``path_table``,
            ``overlaps``).
        """
        if self.shuffle:
            random.shuffle(self.order)
            self.data = [self.dataset[idx] for idx in self.order]
        else:
            self.data = self.dataset

        batch_num = math.ceil(self.length / self.batch_size)
        bounds = [(k * self.batch_size,
                   min(self.length, (k + 1) * self.batch_size))
                  for k in range(batch_num)]
        self.batches = [self.data[lo:hi] for lo, hi in bounds]
        self.batches_order = [self.order[lo:hi] for lo, hi in bounds]

        # Buffers shared by all batches; they are reset at the top of each
        # batch instead of being reallocated.
        context_word_ids = torch.LongTensor(self.batch_size,
                                            self.max_length).cpu()
        context_pos_ids = torch.LongTensor(self.batch_size,
                                           self.max_length).cpu()
        context_ner_ids = torch.LongTensor(self.batch_size,
                                           self.max_length).cpu()
        context_mention_ids = torch.LongTensor(self.batch_size,
                                               self.max_length).cpu()
        ht_pairs = torch.LongTensor(self.batch_size, self.h_t_limit, 2).cpu()
        relation_multi_label = torch.Tensor(self.batch_size, self.h_t_limit,
                                            self.relation_num).cpu()
        relation_mask = torch.Tensor(self.batch_size, self.h_t_limit).cpu()
        relation_label = torch.LongTensor(self.batch_size,
                                          self.h_t_limit).cpu()
        ht_pair_distance = torch.LongTensor(self.batch_size,
                                            self.h_t_limit).cpu()
        buffers = (context_word_ids, context_pos_ids, context_ner_ids,
                   context_mention_ids, ht_pairs, ht_pair_distance,
                   relation_multi_label, relation_mask, relation_label)

        for idx, minibatch in enumerate(self.batches):
            cur_bsz = len(minibatch)

            for buf in buffers:
                buf.zero_()
            relation_label.fill_(IGNORE_INDEX)

            # Largest number of (head, tail) pairs of any example in the
            # batch; the pair tensors are cut to this width when yielded.
            max_h_t_cnt = 0

            label_list = []
            L_vertex = []
            titles = []
            indexes = []
            graph_list = []
            entity_graph_list = []
            entity2mention_table = []
            path_table = []
            overlaps = []

            for i, example in enumerate(minibatch):
                title = example['title']
                entities = example['entities']
                labels = example['labels']
                na_triple = example['na_triple']
                word_id = example['word_id']
                pos_id = example['pos_id']
                ner_id = example['ner_id']
                mention_id = example['mention_id']
                entity2mention = example['entity2mention']
                graph_list.append(example['graph'])
                entity_graph_list.append(example['entity_graph'])
                path_table.append(example['path'])
                overlaps.append(example['overlap'])

                # Fill the table on the host and hand it to get_cuda once,
                # instead of writing element by element into the tensor
                # get_cuda returned.
                entity2mention_t = torch.zeros(
                    (pos_id.max() + 1, mention_id.max() + 1))
                for e, ms in entity2mention.items():
                    for m in ms:
                        entity2mention_t[e, m] = 1
                entity2mention_table.append(get_cuda(entity2mention_t))

                L = len(entities)
                word_num = word_id.shape[0]

                context_word_ids[i, :word_num].copy_(torch.from_numpy(word_id))
                context_pos_ids[i, :word_num].copy_(torch.from_numpy(pos_id))
                context_ner_ids[i, :word_num].copy_(torch.from_numpy(ner_id))
                context_mention_ids[i, :word_num].copy_(
                    torch.from_numpy(mention_id))

                # (head, tail) -> list of relations, and
                # (head, tail, relation) -> in_train flag.
                idx2label = defaultdict(list)
                label_set = {}
                for label in labels:
                    head, tail, relation = label['h'], label['t'], label['r']
                    idx2label[(head, tail)].append(relation)
                    label_set[(head, tail, relation)] = label['in_train']
                label_list.append(label_set)

                if self.dataset_type == 'train':
                    train_tripe = list(idx2label.keys())
                    for j, (h_idx, t_idx) in enumerate(train_tripe):
                        ht_pairs[i,
                                 j, :] = torch.Tensor([h_idx + 1, t_idx + 1])
                        ht_pair_distance[i, j] = self._distance_bucket(
                            entities[h_idx], entities[t_idx])

                        relations = idx2label[(h_idx, t_idx)]
                        for r in relations:
                            relation_multi_label[i, j, r] = 1
                        relation_mask[i, j] = 1
                        # The single-label target is one of the pair's
                        # relations, picked at random.
                        rt = np.random.randint(len(relations))
                        relation_label[i, j] = relations[rt]

                    lower_bound = len(na_triple)
                    if self.negativa_alpha > 0.0:
                        random.shuffle(na_triple)
                        lower_bound = int(
                            max(20,
                                len(train_tripe) * self.negativa_alpha))

                    # Negatives are written right after the positives.
                    for j, (h_idx, t_idx) in enumerate(na_triple[:lower_bound],
                                                       len(train_tripe)):
                        ht_pairs[i,
                                 j, :] = torch.Tensor([h_idx + 1, t_idx + 1])
                        ht_pair_distance[i, j] = self._distance_bucket(
                            entities[h_idx], entities[t_idx])

                        relation_multi_label[i, j, 0] = 1
                        relation_label[i, j] = 0
                        relation_mask[i, j] = 1

                    # Updated once per example, outside the loop above, so a
                    # document without negative pairs still widens the batch
                    # enough to keep its positive pairs.
                    max_h_t_cnt = max(max_h_t_cnt,
                                      len(train_tripe) + lower_bound)
                else:
                    j = 0
                    for h_idx in range(L):
                        for t_idx in range(L):
                            if h_idx == t_idx:
                                continue
                            ht_pairs[i, j, :] = torch.Tensor(
                                [h_idx + 1, t_idx + 1])
                            relation_mask[i, j] = 1
                            ht_pair_distance[i, j] = self._distance_bucket(
                                entities[h_idx], entities[t_idx])
                            j += 1

                    max_h_t_cnt = max(max_h_t_cnt, j)
                    L_vertex.append(L)
                    titles.append(title)
                    indexes.append(self.batches_order[idx][i])

            # Word id 0 is treated as padding.
            word_mask = context_word_ids > 0
            word_length = word_mask.sum(1)
            batch_max_length = word_length.max()

            yield {
                'context_idxs':
                get_cuda(context_word_ids[:cur_bsz, :batch_max_length].
                         contiguous()),
                'context_pos':
                get_cuda(
                    context_pos_ids[:cur_bsz, :batch_max_length].contiguous()),
                'context_ner':
                get_cuda(
                    context_ner_ids[:cur_bsz, :batch_max_length].contiguous()),
                'context_mention':
                get_cuda(context_mention_ids[:cur_bsz, :batch_max_length].
                         contiguous()),
                'context_word_mask':
                get_cuda(
                    word_mask[:cur_bsz, :batch_max_length].contiguous()),
                'context_word_length':
                get_cuda(word_length[:cur_bsz].contiguous()),
                'h_t_pairs':
                get_cuda(ht_pairs[:cur_bsz, :max_h_t_cnt, :2]),
                'relation_label':
                get_cuda(relation_label[:cur_bsz, :max_h_t_cnt]).contiguous(),
                'relation_multi_label':
                get_cuda(relation_multi_label[:cur_bsz, :max_h_t_cnt]),
                'relation_mask':
                get_cuda(relation_mask[:cur_bsz, :max_h_t_cnt]),
                'ht_pair_distance':
                get_cuda(ht_pair_distance[:cur_bsz, :max_h_t_cnt]),
                'labels':
                label_list,
                'L_vertex':
                L_vertex,
                'titles':
                titles,
                'indexes':
                indexes,
                'graphs':
                graph_list,
                'entity2mention_table':
                entity2mention_table,
                'entity_graphs':
                entity_graph_list,
                'path_table':
                path_table,
                'overlaps':
                overlaps
            }

    def _distance_bucket(self, head_mentions, tail_mentions):
        """Return the distance index between the first mentions of two entities.

        The signed gap between the ``global_pos`` starts of the first head
        and first tail mention is looked up in ``self.dis2idx`` by absolute
        value, given the sign of the gap, and shifted by
        ``self.dis_size // 2``.
        """
        delta_dis = (head_mentions[0]['global_pos'][0] -
                     tail_mentions[0]['global_pos'][0])
        half = self.dis_size // 2
        if delta_dis < 0:
            return -int(self.dis2idx[-delta_dis]) + half
        return int(self.dis2idx[delta_dis]) + half