Code example #1
File: main.py Project: hsack6/AGATE
def main(opt):
    train_dataset = BADataset(opt.dataroot, opt.L, True, False, False)
    train_dataloader = BADataloader(train_dataset, batch_size=opt.batchSize,
                                    shuffle=True, num_workers=opt.workers, drop_last=True)

    valid_dataset = BADataset(opt.dataroot, opt.L, False, True, False)
    valid_dataloader = BADataloader(valid_dataset, batch_size=opt.batchSize,
                                    shuffle=True, num_workers=opt.workers, drop_last=True)

    test_dataset = BADataset(opt.dataroot, opt.L, False, False, True)
    test_dataloader = BADataloader(test_dataset, batch_size=opt.batchSize,
                                   shuffle=True, num_workers=opt.workers, drop_last=True)

    all_dataset = BADataset(opt.dataroot, opt.L, False, False, False)
    all_dataloader = BADataloader(all_dataset, batch_size=opt.batchSize,
                                  shuffle=False, num_workers=opt.workers, drop_last=False)

    opt.n_edge_types = train_dataset.n_edge_types
    opt.n_node = train_dataset.n_node
    opt.n_existing_node = all_node_num  # all_node_num is a module-level constant in the original file

    net = GCN(opt, kernel_size=2, n_blocks=1, state_dim_bottleneck=opt.state_dim, annotation_dim_bottleneck=opt.annotation_dim)
    net.double()
    print(net)

    criterion = nn.CosineSimilarity(dim=1, eps=1e-6)

    if opt.cuda:
        net.cuda()
        criterion.cuda()

    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    early_stopping = EarlyStopping(patience=opt.patience, verbose=True)

    os.makedirs(OutputDir, exist_ok=True)  # OutputDir is a module-level constant in the original file
    train_loss_ls = []
    valid_loss_ls = []
    test_loss_ls = []

    for epoch in range(0, opt.niter):
        train_loss = train(epoch, train_dataloader, net, criterion, optimizer, opt)
        valid_loss = valid(valid_dataloader, net, criterion, opt)
        test_loss = test(test_dataloader, net, criterion, opt)

        train_loss_ls.append(train_loss)
        valid_loss_ls.append(valid_loss)
        test_loss_ls.append(test_loss)

        early_stopping(valid_loss, net, OutputDir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    df = pd.DataFrame({'epoch': list(range(1, len(train_loss_ls) + 1)),
                       'train_loss': train_loss_ls,
                       'valid_loss': valid_loss_ls,
                       'test_loss': test_loss_ls})
    df.to_csv(OutputDir + '/loss.csv', index=False)

    net.load_state_dict(torch.load(OutputDir + '/checkpoint.pt'))
    inference(all_dataloader, net, criterion, opt, OutputDir)
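
EarlyStopping here is not a PyTorch built-in. A minimal sketch with the interface used above (called with the validation loss, the model, and an output directory; writes OutputDir + '/checkpoint.pt' and sets early_stop), assuming a pytorchtools-style implementation; the project's own class may differ:

import numpy as np
import torch

class EarlyStopping:
    """Stop training when the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=7, verbose=False, delta=0.0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, val_loss, model, output_dir):
        if val_loss < self.best_loss - self.delta:
            # improvement: checkpoint the model and reset the patience counter
            if self.verbose:
                print('Validation loss improved (%.6f -> %.6f); saving model.'
                      % (self.best_loss, val_loss))
            torch.save(model.state_dict(), output_dir + '/checkpoint.pt')
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True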
Code example #2
File: test.py Project: ieee820/IDRiD
def save_output(root_dir, output_dir):
    # dataset
    dataset = save_predict_dataset(root_dir)

    #model
    # use_gpu, save_dir and model_name are module-level globals in the original file
    model = GCN(4, 512)
    if use_gpu:
        model = model.cuda()
        #model = torch.nn.DataParallel(model).cuda()
    model.load_state_dict(torch.load(os.path.join(save_dir, model_name)))
    model.train(False)  # equivalent to model.eval()
    for n in range(int(len(dataset) / (6 * 9))):
        #test
        full_output = np.zeros((4, 2848, 4288), dtype='float32')  #(C, H, W)
        title = ''
        for idx in range(6 * 9 * n, 6 * 9 * (n + 1)):
            image, name = dataset[idx]
            r = (idx % (6 * 9)) // 9  # row
            c = (idx % (6 * 9)) % 9  # column
            title = name

            if use_gpu:
                image = image.cuda()
            image = Variable(image, volatile=True)  # legacy (pre-0.4) PyTorch; use torch.no_grad() on newer versions

            #forward
            output = model(image.unsqueeze(0))
            output = F.sigmoid(output)
            output = output[0]

            # width 4288 = 8*512 + 192, so the ninth column (c == 8) does not
            # fit and is skipped; height 2848 = 6*512 - 224, so the bottom row
            # (r == 5) drops its last 224 rows of padding
            if c < 8:
                if r == 5:
                    full_output[:, r * 512:r * 512 + 512 - 224,
                                c * 512:c * 512 +
                                512] = output.cpu().data.numpy()[:, :-224, :]
                else:
                    full_output[:, r * 512:r * 512 + 512, c * 512:c * 512 +
                                512] = output.cpu().data.numpy()

        for i, d in enumerate(['MA', 'EX', 'HE', 'SE']):
            if not os.path.exists(os.path.join(output_dir, d)):
                os.makedirs(os.path.join(output_dir, d))
            im = full_output[i] * 255
            im = np.uint8(im)
            im = Image.fromarray(im)
            im.save(os.path.join(output_dir, d, title + '.jpg'))
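
The indexing above assumes each 2848x4288 image was tiled into a 6x9 grid of 512x512 crops: the bottom row is padded (6 * 512 = 3072 = 2848 + 224, hence the r == 5 branch drops the last 224 rows) and the ninth column does not fit and is skipped (8 * 512 = 4096 < 4288 < 9 * 512). A minimal sketch of the index arithmetic, with a hypothetical function name:

def crop_position(idx, rows=6, cols=9):
    """Map a flat crop index to (image, row, column) for a rows x cols tiling."""
    per_image = rows * cols
    n = idx // per_image           # image index
    r = (idx % per_image) // cols  # row, 0..rows-1
    c = (idx % per_image) % cols   # column, 0..cols-1
    return n, r, c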
Code example #3
File: recon_cossm.py Project: hzl1216/gdapgcn
def load_trained_vector(epoch, number, n2i_f, file_homes):
    global node2index
    node2index = cPickle.load(n2i_f)
    node_count = len(node2index)
    node_dim = 128
    n_repr = 128
    gcn = GCN(node_count, node_dim, n_repr)
    gcn.load_state_dict(
        torch.load(file_homes + '/networks/GCN_%d_%d.pth' % (number, epoch),
                   map_location='cpu'))
    f = open(file_homes + '/networks/adj_matrix_%d_full' % (number), 'rb')
    full_adj_matrix = cPickle.load(f)
    full_adj_matrix = sparse_mx_to_torch_sparse_tensor(full_adj_matrix)
    init_input = torch.LongTensor([j for j in range(0, node_count)])
    gcn.eval()

    rp_matrix = gcn(init_input, full_adj_matrix)
    #gcn.to(device)
    return rp_matrix.double()
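
sparse_mx_to_torch_sparse_tensor is a project helper; a minimal sketch, assuming the usual pygcn-style conversion from a SciPy sparse matrix:

import numpy as np
import torch

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a SciPy sparse matrix to a torch sparse FloatTensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)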
Code example #4
File: train.py Project: ieee820/IDRiD
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    print('Best F1 score: {:.4f}'.format(best_f1))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


if __name__ == '__main__':
    # dataset
    dataloaders = make_dataloaders(batch_size=batch_size)

    #model
    model = GCN(4, 512)
    if use_gpu:
        model = model.cuda()
        #model = torch.nn.DataParallel(model).cuda()
    model.load_state_dict(torch.load(os.path.join(save_dir, 'gcn_v5.pth')))
    #training
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True)
    model = train_model(model, num_epochs, dataloaders, optimizer, scheduler)

    #save
    save_model(model, save_dir, model_name)
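
save_model is defined elsewhere in the project; a plausible minimal sketch (its exact behavior is an assumption):

import os
import torch

def save_model(model, save_dir, model_name):
    """Persist the model weights under save_dir/model_name."""
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, model_name))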
Code example #5
def test():
    device = torch.device('cuda:0')
    print('using device:', device)
    config = Config()
    test_loader = get_loader(config, config.TEST_LIST, 2)
    model = GCN().to(device)
    criterions = [nn.CrossEntropyLoss() for _ in range(11)]  # one loss per prediction head
    '''
    RES152 'model--epoch:10-D-acc:0.9812--L-acc:0.9932.ckpt'
    res50  'model--epoch:11-D-acc:0.9801--L-acc:0.9937.ckpt'

    '''

    for i in range(1, 199):
        ckpt_path = os.path.join(config.MODEL_PATH,
                                 'model_epoch_' + str(i) + '.ckpt')
        print('loading checkpoint from', ckpt_path)
        checkpoint = torch.load(ckpt_path)
        model.load_state_dict(checkpoint)
        params = list(model.named_parameters())

        # print(params[0][1])
        data = params[0][1].detach().cpu().numpy()
        print(data.shape)
        plt.matshow(data, cmap=plt.cm.hot)
        plt.colorbar()
        plt.savefig('./Adjacent/epoch_' + str(i) + '.jpg')
        plt.close()
        del params
        del data

    exit(0)  # NOTE: debug early-exit; everything below is unreachable until this line is removed

    ckpt_path = os.path.join(config.MODEL_PATH, 'model_epoch_63.ckpt')
    print('loading checkpoint from', ckpt_path)
    checkpoint = torch.load(ckpt_path)
    model.load_state_dict(checkpoint)
    params = list(model.named_parameters())
    test_accs = []
    test_loss = []

    disease_nums = 11
    # label names (Chinese): main disease; nodule echo type (hyper/iso/hypo);
    # solitary/multiple; margin; shape; aspect ratio; cortex type;
    # hilum structure; calcification; cystic area; blood flow
    disease_names = [
        '主疾病', '结节类型--高/等/低回声', '结节类型--单/多发', '边界', '形态', '纵横比', '皮质类型',
        '淋巴门结构', '钙化', '囊性区', '血流'
    ]

    ################
    ### Test
    ################
    label_preds = [[] for _ in range(disease_nums)]
    label_gts = [[] for _ in range(disease_nums)]

    print('begin to predict')
    with torch.no_grad():
        for ii, (inputs, gt_labels) in enumerate(test_loader):
            print(str(ii))
            for i in range(len(inputs)):
                # _input = _input.to(torch.float)
                inputs[i] = inputs[i].type(torch.FloatTensor)
                inputs[i] = inputs[i].to(device)

            # for _input in inputs:
            #     print(_input.device)

            for i in range(len(gt_labels)):
                gt_labels[i] = gt_labels[i].to(torch.long)
                gt_labels[i] = gt_labels[i].to(device)

            pred_outputs = model(inputs[0], inputs[1])

            # loss = sum of per-head cross-entropy losses, main task weighted x10:
            # losses = [criterions[k](pred_outputs[k], gt_labels[k])
            #           for k in range(disease_nums)]
            # loss = 10 * losses[0] + sum(losses[1:])

            # collect predictions and ground-truth labels for this batch
            for i in range(disease_nums):
                print(pred_outputs[i])
                _, predict = torch.max(pred_outputs[i], 1)
                print(predict)
                predict = list(predict.cpu().numpy())
                label = list(gt_labels[i].cpu().numpy())
                # store
                label_preds[i] += predict
                label_gts[i] += label


########################
### Evaluation metrics
########################
    for i in range(disease_nums):
        print('-' * 40 + '↓disease↓' + '-' * 40)
        # sklearn metrics expect (y_true, y_pred)
        accuracy = accuracy_score(label_gts[i], label_preds[i])
        precision = precision_score(label_gts[i],
                                    label_preds[i],
                                    average='macro')
        recall = recall_score(label_gts[i], label_preds[i], average='macro')
        f1 = f1_score(label_gts[i], label_preds[i], average='macro')

        print(
            ' {} : accuracy:{:.4f}, precision:{:.4f}, recall:{:.4f}, f1:{:.4f}'.
            format(disease_names[i], accuracy, precision, recall, f1))

    # confusion matrix for the main task (sklearn expects (y_true, y_pred))
    from sklearn.metrics import confusion_matrix
    con_matrix = confusion_matrix(label_gts[0], label_preds[0])
    print('confusion_matrix: ')
    for i, row in enumerate(con_matrix):
        print(str(i) + '\t' + str(row))

    class_num = len(con_matrix)

    print('classnum', class_num)
    for i in range(class_num):
        print(config.DISEASE_LABELS[i],
              ' acc :{:.4f}'.format(con_matrix[i][i] / np.sum(con_matrix[i])))
    print('-' * 80)
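
The plotting loop above takes params[0][1], the first named parameter, as the learnable adjacency matrix. Looking it up by name is more robust; the parameter name 'adj' below is an assumption, not taken from the project:

param_dict = dict(model.named_parameters())
data = param_dict['adj'].detach().cpu().numpy()  # 'adj' is a hypothetical name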
Code example #6
File: test.py Project: ieee820/IDRiD
def run_statistic(threshold):
    '''
        evaluate on small images result
    '''
    # dataset
    dataset = IDRiD_sub1_dataset(data_dir)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=4)
    #print('Data: %d'%(len(dataset)))

    #model
    model = GCN(4, 512)
    if use_gpu:
        model = model.cuda()
        #model = torch.nn.DataParallel(model).cuda()
    model.load_state_dict(torch.load(os.path.join(save_dir, model_name)))
    model.train(False)
    for i in range(4):
        y_pred_list = []
        y_true_list = []
        for idx, data in enumerate(dataloader):
            images, masks, names = data

            if use_gpu:
                images = images.cuda()
                masks = masks.cuda()
            images, masks = Variable(images,
                                     volatile=True), Variable(masks,
                                                              volatile=True)

            #forward
            outputs = model(images)

            # statistics
            outputs = F.sigmoid(
                outputs).cpu().data  # remember to apply sigmoid before use
            masks = masks.cpu().data
            # select the i-th lesion channel across the whole batch
            y_pred = outputs[:, i]
            y_true = masks[:, i]
            y_pred = y_pred.numpy().flatten()
            y_pred = np.where(y_pred > threshold, 1, 0)
            y_true = y_true.numpy().flatten()
            y_pred_list.append(y_pred)
            y_true_list.append(y_true)

            #verbose
            if idx % 5 == 0 and idx != 0:
                print('\r{:.2f}%'.format(100 * idx / len(dataloader)),
                      end='\r')
        #print()
        type_list = ['MA', 'EX', 'HE', 'SE']
        precision, recall, f1, _ = precision_recall_fscore_support(
            np.concatenate(y_true_list),
            np.concatenate(y_pred_list),
            average='binary')
        print(
            '{}    \nThreshold: {:.2f}\nPrecision: {:.4f}\nRecall: {:.4f}\nF1: {:.4f}'
            .format(type_list[i], threshold, precision, recall, f1))
Code example #7
File: test.py Project: ieee820/IDRiD
def show_image_sample():
    # dataset
    dataset = IDRiD_sub1_dataset(data_dir)

    #model
    model = GCN(4, 512)
    if use_gpu:
        model = model.cuda()
        #model = torch.nn.DataParallel(model).cuda()
    model.load_state_dict(torch.load(os.path.join(save_dir, model_name)))
    model.train(False)
    for n in range(12):
        #test
        full_image = np.zeros((3, 2848, 4288), dtype='float32')
        full_mask = np.zeros((4, 2848, 4288), dtype='float32')
        full_output = np.zeros((4, 2848, 4288), dtype='float32')  #(C, H, W)
        title = ''
        for idx in range(9 * 6 * n, 9 * 6 * (n + 1)):
            image, mask, name = dataset[idx]
            r = (idx % (6 * 9)) // 9  # row
            c = (idx % (6 * 9)) % 9  # column
            title = name[:-8]

            if use_gpu:
                image = image.cuda()
                mask = mask.cuda()
            image, mask = Variable(image,
                                   volatile=True), Variable(mask,
                                                            volatile=True)

            #forward
            output = model(image.unsqueeze(0))
            output = F.sigmoid(output)
            output = output[0]
            if c < 8:
                if r == 5:
                    full_output[:, r * 512:r * 512 + 512 - 224,
                                c * 512:c * 512 +
                                512] = output.cpu().data.numpy()[:, :-224, :]
                    full_mask[:, r * 512:r * 512 + 512 - 224, c * 512:c * 512 +
                              512] = mask.cpu().data.numpy()[:, :-224, :]
                    full_image[:, r * 512:r * 512 + 512 - 224,
                               c * 512:c * 512 +
                               512] = image.cpu().data.numpy()[:, :-224, :]

                else:
                    full_output[:, r * 512:r * 512 + 512, c * 512:c * 512 +
                                512] = output.cpu().data.numpy()
                    full_mask[:, r * 512:r * 512 + 512,
                              c * 512:c * 512 + 512] = mask.cpu().data.numpy()
                    full_image[:, r * 512:r * 512 + 512, c * 512:c * 512 +
                               512] = image.cpu().data.numpy()

        full_image = full_image.transpose(1, 2, 0)
        MA = full_output[0]
        EX = full_output[1]
        HE = full_output[2]
        SE = full_output[3]

        plt.figure()
        plt.axis('off')
        plt.suptitle(title)
        plt.subplot(331)
        plt.title('image')
        fig = plt.imshow(full_image)
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
        plt.subplot(332)
        plt.title('ground truth MA')
        fig = plt.imshow(full_mask[0])
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
        plt.subplot(333)
        plt.title('ground truth EX')
        fig = plt.imshow(full_mask[1])
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
        plt.subplot(334)
        plt.title('ground truth HE')
        fig = plt.imshow(full_mask[2])
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
        plt.subplot(335)
        plt.title('ground truth SE')
        fig = plt.imshow(full_mask[3])
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
        plt.subplot(336)
        plt.title('predict MA')
        fig = plt.imshow(MA)
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
        plt.subplot(337)
        plt.title('predict EX')
        fig = plt.imshow(EX)
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
        plt.subplot(338)
        plt.title('predict HE')
        fig = plt.imshow(HE)
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)
        plt.subplot(339)
        plt.title('predict SE')
        fig = plt.imshow(SE)
        fig.axes.get_xaxis().set_visible(False)
        fig.axes.get_yaxis().set_visible(False)

        plt.show()
Code example #8
File: testset_evaluation.py Project: hzl1216/gdapgcn
def main(files_home):
    starttime = datetime.now()
    print('start test model ', starttime)

    number = args.number

    f = open(os.path.join(files_home, files_name['node_index']), 'rb')
    node2index = cPickle.load(f)

    f_train = os.path.join(files_home, files_name['train_file'])
    f_test = os.path.join(files_home, files_name['test_file'])  ###

    #    testset = Sample_Set_Test(f_test,f_train, node2index)
    #   testloader = DataLoader(testset, batch_size=32, shuffle=False)

    node_count = len(node2index)
    node_dim = 128
    n_repr = 128
    gcn = GCN(node_count, node_dim, n_repr, dropout=args.dropout)
    lp = Link_Prediction(n_repr, dropout=args.dropout)
    if args.cuda:
        gcn.cuda()
        lp = nn.DataParallel(lp)
        lp.cuda()

    init_input = torch.LongTensor([j for j in range(0, node_count)])
    if args.cuda:
        init_input = init_input.cuda()

    dis2gene_test_true, dis2gene_test_all = get_rank_test_samples(
        f_train, f_test)

    f = open(files_home + '/networks/adj_matrix_%d_full' % (number), 'rb')
    full_adj_matrix = cPickle.load(f)
    full_adj_matrix = sparse_mx_to_torch_sparse_tensor(full_adj_matrix)
    if args.cuda:
        full_adj_matrix = full_adj_matrix.cuda()

    for epoch in tqdm(range(0, args.epochs)):
        if epoch % 9 != 0 and epoch < args.epochs - 5:
            continue
        gcn.load_state_dict(
            torch.load(files_home + '/networks/GCN_%d_%d.pth' %
                       (number, epoch)))
        lp.load_state_dict(
            torch.load(files_home + '/networks/Link_Prediction_%d_%d.pth' %
                       (number, epoch)))
        gcn.eval()

        feature_matrix = gcn(init_input, full_adj_matrix)
        #       test_link(testloader, feature_matrix, lp)

        if 0:  # toggle: set to 1 to rank with the GCN features alone
            print('use gcn for prediction')
            ap, prec, recall, f1score = test_priorization_gcn(
                dis2gene_test_true, dis2gene_test_all, feature_matrix, lp)
        else:
            print('use gcn and word2vec for prediction')
            ap, prec, recall, f1score = test_priorization_word_gcn(
                dis2gene_test_true, dis2gene_test_all, feature_matrix, lp)
        print('Performance for number=%d epoch=%d' % (number, epoch))
        print('AP: ', ap)
        print('Prec: ', prec)
        print('Recall: ', recall)
        print('F1score: ', f1score)

    endtime = datetime.now()
    print('finish test model! run spend ', endtime - starttime)
Code example #9
File: train.py Project: furljq/CUB200
from graph import build_graph_skeleton as build_graph
from dataLoader import cropCUB as CUB

CD = True

G = build_graph()

net = GCN(2048, [1024, 512], 200)
if CD:
    net = net.cuda(0)

batch_size = 8
start_epoch = 0
if True:  # resume from a saved checkpoint
    ckpt = torch.load('/Disk5/junqi/CUB/early_skeleton_93.ckpt')
    net.load_state_dict(ckpt['net_state_dict'])
    start_epoch = ckpt['epoch'] + 1

datas = CUB()
dataLoader = torch.utils.data.DataLoader(datas,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=4)
optimizer = torch.optim.SGD(net.parameters(),
                            lr=1e-2,
                            momentum=0.9,
                            weight_decay=1e-4)

net.train()
for epoch in range(start_epoch, 201):
    epoch_loss = 0
Code example #10
def train_gcn(training_features, training_adjs, training_labels, eval_features,
              eval_adjs, eval_labels, params, class_weights, activations,
              unary, coeffs, graph_params):
    device = torch.device(
        'cuda:1') if torch.cuda.is_available() else torch.device('cpu')
    gcn = GCN(n_features=training_features[0].shape[1],
              hidden_layers=params["hidden_layers"],
              dropout=params["dropout"],
              activations=activations,
              p=graph_params["probability"],
              normalization=params["edge_normalization"])
    gcn.to(device)
    opt = params["optimizer"](gcn.parameters(),
                              lr=params["lr"],
                              weight_decay=params["regularization"])

    n_training_graphs = len(training_labels)
    graph_size = graph_params["vertices"]
    n_eval_graphs = len(eval_labels)

    counter = 0  # For early stopping
    min_loss = None
    for epoch in range(params["epochs"]):
        # -------------------------- TRAINING --------------------------
        training_graphs_order = np.arange(n_training_graphs)
        np.random.shuffle(training_graphs_order)
        for i, idx in enumerate(training_graphs_order):
            training_mat = torch.tensor(training_features[idx], device=device)
            training_adj, training_lbs = map(
                lambda x: torch.tensor(
                    data=x[idx], dtype=torch.double, device=device),
                [training_adjs, training_labels])
            gcn.train()
            opt.zero_grad()
            output_train = gcn(training_mat, training_adj)
            output_matrix_flat = (
                torch.mm(output_train, output_train.transpose(0, 1)) +
                1 / 2).flatten()
            training_criterion = gcn_build_weighted_loss(
                unary, class_weights, training_lbs)
            loss_train = coeffs[0] * training_criterion(output_train.view(output_train.shape[0]), training_lbs) + \
                coeffs[1] * gcn_pairwise_loss(output_matrix_flat, training_adj.flatten()) + \
                coeffs[2] * gcn_binomial_reg(output_train, graph_params)
            loss_train.backward()
            opt.step()

        # -------------------------- EVALUATION --------------------------
        graphs_order = np.arange(n_eval_graphs)
        np.random.shuffle(graphs_order)
        outputs = torch.zeros(graph_size * n_eval_graphs, dtype=torch.double)
        output_xs = torch.zeros(graph_size**2 * n_eval_graphs,
                                dtype=torch.double)
        adj_flattened = torch.tensor(
            np.hstack([eval_adjs[idx].flatten() for idx in graphs_order]))
        for i, idx in enumerate(graphs_order):
            eval_mat = torch.tensor(eval_features[idx], device=device)
            eval_adj, eval_lbs = map(
                lambda x: torch.tensor(
                    data=x[idx], dtype=torch.double, device=device),
                [eval_adjs, eval_labels])
            gcn.eval()
            output_eval = gcn(eval_mat, eval_adj)
            output_matrix_flat = (
                torch.mm(output_eval, output_eval.transpose(0, 1)) +
                1 / 2).flatten()
            output_xs[i * graph_size**2:(i + 1) *
                      graph_size**2] = output_matrix_flat.cpu()
            outputs[i * graph_size:(i + 1) * graph_size] = output_eval.view(
                output_eval.shape[0]).cpu()
        all_eval_labels = torch.tensor(np.hstack(
            [eval_labels[idx] for idx in graphs_order]),
                                       dtype=torch.double)
        eval_criterion = gcn_build_weighted_loss(unary, class_weights,
                                                 all_eval_labels)
        loss_eval = (
            coeffs[0] * eval_criterion(outputs, all_eval_labels) +
            coeffs[1] * gcn_pairwise_loss(output_xs, adj_flattened) +
            coeffs[2] * gcn_binomial_reg(outputs, graph_params)).item()

        if min_loss is None:
            current_min_loss = loss_eval
        else:
            current_min_loss = min(min_loss, loss_eval)

        if epoch >= 10 and params[
                "early_stop"]:  # Check for early stopping during training.
            if min_loss is None:
                min_loss = current_min_loss
                torch.save(gcn.state_dict(),
                           "tmp_time.pt")  # Save the best state.
            elif loss_eval < min_loss:
                min_loss = current_min_loss
                torch.save(gcn.state_dict(),
                           "tmp_time.pt")  # Save the best state.
                counter = 0
            else:
                counter += 1
                if counter >= 40:  # Patience for learning
                    break
    # After stopping early, our model is the one with the best eval loss.
    if os.path.exists("tmp_time.pt"):  # the file exists only if early stopping saved a state
        gcn.load_state_dict(torch.load("tmp_time.pt"))
        os.remove("tmp_time.pt")
    return gcn
Code example #11
# Data
if args.prepro:
    tmp = torch.load(args.prepro)
    adj_i, adj_v, adj_s, feats, labels, idx_train, idx_val, idx_test = tmp
    adj = torch.sparse.FloatTensor(adj_i, adj_v, adj_s)
else:
    adj, feats, labels, idx_train, idx_val, idx_test = load_cora()

# Model
model = GCN(num_layers=args.layers,
            in_size=feats.shape[1],
            h_size=args.h_size,
            out_size=labels.max().item() + 1,
            dropout=args.dropout)
if args.model:
    model.load_state_dict(torch.load(args.model))

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

# GPU
if args.gpu:
    tmp = model, adj, feats, labels, idx_train, idx_val, idx_test
    tmp = [x.cuda() for x in tmp]
    model, adj, feats, labels, idx_train, idx_val, idx_test = tmp

# Train/Validate
print('Loaded data in {:.2f}s'.format(time.time() - start))
if args.test:
    assert args.model is not None, 'No model to evaluate'
    loss, acc = validate(model, adj, feats, labels, idx_test)
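
validate is defined elsewhere in the script; a minimal sketch consistent with the call above (returns loss and accuracy over the given index set), assuming the model emits log-probabilities:

import torch
import torch.nn.functional as F

def validate(model, adj, feats, labels, idx):
    """Evaluate on the nodes in idx; return (loss, accuracy)."""
    model.eval()
    with torch.no_grad():
        output = model(feats, adj)
        loss = F.nll_loss(output[idx], labels[idx]).item()
        acc = (output[idx].argmax(dim=1) == labels[idx]).float().mean().item()
    return loss, acc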
Code example #12
def main(args):

    # 0. initial setting

    # set environmet
    cudnn.benchmark = True

    for d in ['ckpt', 'results',
              os.path.join('ckpt', args.name),
              os.path.join('results', args.name),
              os.path.join('results', args.name, 'log')]:
        if not os.path.isdir(os.path.join(args.path, d)):
            os.mkdir(os.path.join(args.path, d))

    # set logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    handler = logging.FileHandler(
        os.path.join(
            args.path, "results/{}/log/{}.log".format(
                args.name, time.strftime('%c', time.localtime(time.time())))))
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.addHandler(logging.StreamHandler())
    args.logger = logger

    # set cuda
    if torch.cuda.is_available():
        args.logger.info("running on cuda")
        args.device = torch.device("cuda")
        args.use_cuda = True
    else:
        args.logger.info("running on cpu")
        args.device = torch.device("cpu")
        args.use_cuda = False

    args.logger.info("[{}] starts".format(args.name))

    # 1. load data

    adj, features, labels, idx_train, idx_val, idx_test = load_data()

    # 2. setup
    CORA_NODES = 2708
    CORA_FEATURES = 1433
    CORA_CLASSES = 7
    CITESEER_NODES = 3327
    CITESEER_FEATURES = 3703
    CITESEER_CLASSES = 6

    (num_nodes, feature_dim,
     classes) = (CORA_NODES, CORA_FEATURES,
                 CORA_CLASSES) if args.dataset == 'cora' else (
                     CITESEER_NODES, CITESEER_FEATURES, CITESEER_CLASSES)
    args.logger.info("setting up...")
    model = GCN(args, feature_dim, args.hidden, classes,
                args.dropout) if args.model == 'gcn' else SpGAT(
                    args, feature_dim, args.hidden, classes, args.dropout,
                    args.alpha, args.n_heads)
    model.to(args.device)
    loss_fn = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)

    if args.load:
        loaded_data = load(args, args.ckpt)
        model.load_state_dict(loaded_data['model'])
        optimizer.load_state_dict(loaded_data['optimizer'])

    # 3. train / test

    if not args.test:
        # train
        args.logger.info("starting training")
        train_loss_meter = AverageMeter(args,
                                        name="Loss",
                                        save_all=True,
                                        x_label="epoch")
        val_acc_meter = AverageMeter(args,
                                     name="Val Acc",
                                     save_all=True,
                                     x_label="epoch")
        earlystop_listener = val_acc_meter.attach_combo_listener(
            (lambda prev, new: prev.max >= new.max), threshold=args.patience)
        steps = 1
        for epoch in range(1, 1 + args.epochs):
            spent_time = time.time()
            model.train()
            train_loss_tmp_meter = AverageMeter(args)

            if args.start_from_step is not None:
                if steps < args.start_from_step:
                    steps += 1
                    continue
            optimizer.zero_grad()
            batch = len(idx_train)
            output = model(features.to(args.device), adj.to(args.device))
            loss = loss_fn(output[idx_train],
                           labels[idx_train].to(args.device))
            loss.backward()
            optimizer.step()
            train_loss_tmp_meter.update(loss, weight=batch)
            steps += 1

            train_loss_meter.update(train_loss_tmp_meter.avg)
            spent_time = time.time() - spent_time
            args.logger.info(
                "[{}] train loss: {:.3f} took {:.1f} seconds".format(
                    epoch, train_loss_tmp_meter.avg, spent_time))

            model.eval()
            spent_time = time.time()
            if not args.fastmode:
                with torch.no_grad():
                    output = model(features.to(args.device),
                                   adj.to(args.device))
            acc = accuracy(output[idx_val], labels[idx_val]) * 100.0
            val_acc_meter.update(acc)
            earlystop = earlystop_listener.listen()
            spent_time = time.time() - spent_time
            args.logger.info(
                "[{}] val acc: {:2.1f} % took {:.1f} seconds".format(
                    epoch, acc, spent_time))
            if steps % args.save_period == 0:
                save(args, "epoch{}".format(epoch),
                     {'model': model.state_dict()})
                train_loss_meter.plot(scatter=False)
                val_acc_meter.plot(scatter=False)
                val_acc_meter.save()

            if earlystop:
                break

    else:
        # test
        args.logger.info("starting test")
        model.eval()
        spent_time = time.time()
        with torch.no_grad():
            output = model(features.to(args.device), adj.to(args.device))
        acc = accuracy(output[idx_test], labels[idx_test]) * 100
        spent_time = time.time() - spent_time
        args.logger.info("test acc: {:2.1f} % took {:.1f} seconds".format(
            acc, spent_time))
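
accuracy and load_data come from the surrounding project; a minimal sketch of accuracy in the usual pygcn style (an assumption, not the project's exact code):

import torch

def accuracy(output, labels):
    """Fraction of rows whose arg-max class matches the label."""
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    return correct.sum() / len(labels)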