def __getitem__(self, index):

        cfg = TrainOptions().parse()
        fn = self.imgs[index]  # fn is the list of filenames of all matrices in one bag
        imgdata = np.zeros((cfg.depth, 192, 192))
        labeldata = np.zeros((cfg.depth, 192, 192))
        if len(fn) == 0:
            print("item is all zero")
        for i in range(len(fn)):
            imi = loadmat(os.path.join(self.img_dir, fn[i]),
                          verify_compressed_data_integrity=False)['data']
            labeli = loadmat(os.path.join(self.mask_dir, fn[i]),
                             verify_compressed_data_integrity=False)['data']
            msk_img = (imi >= 0)  # mask out negative values (left as zero)

            imgdata[i, msk_img] = imi[msk_img]
            labeldata[i] = labeli
        # pad the remaining depth by repeating the last slice of the bag
        i = len(fn)
        while i < cfg.depth:
            imgdata[i] = imi
            labeldata[i] = labeli
            i += 1

        img = torch.from_numpy(imgdata)
        label = torch.from_numpy(labeldata)
        img = torch.unsqueeze(img, 0).float()  # add channel dimension
        label = label.long()
        if DEBUG:
            print(img.size(), label.size())
            #label = torch.squeeze(label).long()
        return img, label
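
A minimal consumption sketch, assuming the UNetDataset class name and the cfg fields that appear in the other snippets on this page; everything else is illustrative only:

from torch.utils.data import DataLoader

dataset = UNetDataset(img_dir=cfg.img_dir, mask_dir=cfg.mask_dir, is_train=True)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
for img, label in loader:
    # img: (1, 1, cfg.depth, 192, 192) float; label: (1, cfg.depth, 192, 192) long
    break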
Example 2
    def __init__(self, img_dir, mask_dir, is_train=True):
        print('Dataset')
        cfg = TrainOptions().parse()
        sheet_train = "TrainSliceIndex0607.xlsx"
        sheet_test = "TrainSliceIndex0607.xlsx"
        train_s = 0
        test_s = 2

        sheet_name = sheet_train
        s = train_s
        if not is_train:
            sheet_name = sheet_test
            s = test_s
        sheet = ReadXlsx(
            '/home/yangtingyang/yty/HeartVessel/Dataset/GroupIndex/' +
            sheet_name, s)  # train or test sheet, depending on is_train

        row_num = sheet.nrows
        col_num = sheet.ncols

        print("sheet has {} rows".format(row_num))

        self.imgs = sheet.col_values(0)

        self.img_dir = img_dir
        self.mask_dir = mask_dir
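
ReadXlsx itself is not shown anywhere on this page; judging from the nrows / ncols / col_values calls, it is presumably a thin xlrd wrapper. A sketch under that assumption:

import xlrd  # assumed backend, inferred from the sheet API used above

def ReadXlsx(path, sheet_index):
    # return the sheet so callers can use .nrows, .ncols and .col_values(...)
    return xlrd.open_workbook(path).sheet_by_index(sheet_index)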
Example 3
    def __getitem__(self, index):

        cfg = TrainOptions().parse()
        fn = self.imgs[index]  # fn is the .mat filename for this sample
        imi = loadmat(os.path.join(self.img_dir, fn),
                      verify_compressed_data_integrity=False)['data']
        labeli = loadmat(os.path.join(self.mask_dir, fn),
                         verify_compressed_data_integrity=False)['data']
        msk_img = (imi >= 0)  # computed but unused in this variant

        img = torch.from_numpy(imi)
        label = torch.from_numpy(labeli)
        img = torch.unsqueeze(img, 0).float()  # add channel dimension
        label = label.long()
        if DEBUG:
            print(img.size(), label.size())
            #label = torch.squeeze(label).long()
        return img, label
Example 4
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
from torch.optim import lr_scheduler
from .Sparse_conv import SparseConv
from config import TrainOptions

opt = TrainOptions().parse()


def weights_init_xavier(m):
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1 and \
       class_name not in ['AllOneConv2d', 'SparseConv']:
        init.xavier_normal_(m.weight.data)
    elif class_name.find('Linear') != -1:
        init.xavier_normal_(m.weight.data)
    elif class_name.find('BatchNorm2d') != -1:
        # affine BatchNorm init: mean 1.0, std 0.02; the original
        # uniform(1.0, 0.02) has from > to and fails in current PyTorch
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


# weight initialization
def init_weights(net, init_type='normal'):
    if init_type == 'xavier':
        net.apply(weights_init_xavier)
    else:
        raise NotImplementedError(
            'initialization method [%s] is not implemented' % init_type)
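
A usage sketch, borrowing the UNet constructor from the test script below; note that only 'xavier' is implemented here, so it must be passed explicitly:

from unet import UNet  # the module used elsewhere on this page

model = UNet(n_channels=1, n_classes=2)
init_weights(model, init_type='xavier')  # the default 'normal' would raise NotImplementedError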
Example 5
from torch.utils.data.dataloader import DataLoader
import os
from config import TrainOptions
from unet import UNet
from dataset import UNetDataset
from testdataset import UNetTestDataset

from trainer import UNetTrainer
from tester import UNetTester

import sys

os.environ['CUDA_VISIBLE_DEVICES'] = '3'

if __name__ == '__main__':
    cfg = TrainOptions().parse()
    print('lr: ', cfg.base_lr)
    unet_dataset = UNetTestDataset(img_dir=cfg.img_dir,
                                   mask_dir=cfg.mask_dir,
                                   is_train=False)
    unet_test_dataset = UNetTestDataset(img_dir=cfg.img_dir,
                                        mask_dir=cfg.mask_dir,
                                        is_train=False)
    unet_dataloader = DataLoader(dataset=unet_dataset,
                                 batch_size=1,
                                 shuffle=True,
                                 num_workers=cfg.num_workers)
    model = UNet(n_channels=1, n_classes=2)
    tester = UNetTester(dataset=unet_test_dataset,
                        dataloader=unet_dataloader,
                        model=model)
Example 6
import torch
import torch.nn as nn
import numpy as np
import random
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models
import os
from trainer import train_model
from visualization import visualize_model

from losses.cos_face_loss import CosineMarginProduct
from losses.arc_face_loss import ArcMarginProduct
from losses.linear_loss import InnerProduct

from config import TrainOptions
args = TrainOptions().parse()

from datasets import CreateDataloader
dataloaders, dataset_sizes, class_names = CreateDataloader(args)

print('class_names: ', class_names)

save_model = args.save_model
if not os.path.exists(save_model):
    os.mkdir(save_model)


def setup_seed(seed):
    # seed python, numpy and torch (CPU and all GPUs) for reproducibility
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
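
A usage sketch for the seeding helper; the cuDNN flags below are a common companion to this kind of setup but are an assumption here, not part of the original snippet:

setup_seed(42)
torch.backends.cudnn.deterministic = True  # assumed addition: trades speed for determinism
torch.backends.cudnn.benchmark = False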
Example 7

    def __init__(self, img_dir, mask_dir, is_train=True):
        print('Dataset')
        cfg = TrainOptions().parse()
        sheet_train = ReadXlsx(
            '/home/yangtingyang/HeartVessel/Dataset/HeartVeesel/GroupIndex/' +
            'TrainSliceIndex0607.xlsx', 0)  # train set
        sheet_test = ReadXlsx(
            '/home/yangtingyang/HeartVessel/Dataset/HeartVeesel/GroupIndex/' +
            'SelVolum0604.xlsx', 0)  # test set
        sheet = sheet_train
        split_path = cfg.split_train_Dataset_path

        if not is_train:
            sheet = sheet_test
            split_path = cfg.split_test_Dataset_path
        record = open(split_path, 'w')

        fh = sheet.col_values(0)
        row_num = sheet.nrows

        imgs = []
        imgbag = []
        labelbag = []
        cnt = 0
        personN = 0
        i = 0
        lastinstance = "1000000"
        if DEBUG:
            print("len img list:", row_num)
        while i < row_num:
            lineimg = fh[i].strip('\n')
            # filename: five '_'-separated fields before the extension;
            # the last two are the person code and the slice instance number
            body, fix = lineimg.split(".")
            gro, _, _, code, instance = body.split("_")

            if abs(int(instance) -
                   int(lastinstance)) >= 15:  # the slice index jumps by 15+, so frames are missing
                personN = code
                if len(imgbag) != 0:
                    imgs.append(imgbag)  # pack up this bag; do not merge it with what follows
                    record.write(" " + str(len(imgbag)) + "\n")
                    imgbag = []
                    cnt = 0

            elif cnt >= cfg.depth or code != personN:  # bag is full, or a new person starts
                if code == personN and i >= cfg.depth:
                    i -= (cfg.depth - 3)  # step back so consecutive bags of one person overlap
                personN = code

                if len(imgbag) != 0:
                    imgs.append(imgbag)
                    record.write(" " + str(len(imgbag)) + "\n")
                imgbag = []
                cnt = 0
            else:  # append the slice to the current bag
                imgbag.append(lineimg)
                record.write(lineimg + " ")
                i += 1
                cnt += 1
            lastinstance = instance

        self.imgs = imgs
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        record.close()
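
To make the assumed filename layout concrete (the name below is made up for illustration):

lineimg = "G1_a_b_0042_000117.mat"         # hypothetical filename
body, fix = lineimg.split(".")             # "G1_a_b_0042_000117", "mat"
gro, _, _, code, instance = body.split("_")
print(code, instance)                      # -> 0042 000117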
Example 8
            labels.append(int(label.detach().cpu().numpy()))
            fwrite.write('{0} {1} {2:0.4f} {3}'.format(
                utt_id_list[0], utt_id_list[-1], result, labels[-1]) + '\n')
    embedding_mean_eer, embedding_mean_thresh = processDataTable2(
        np.array(labels), np.array(embedding_mean_probs))
    fwrite.write('embedding_mean_eer {} embedding_mean_thresh {}'.format(
        embedding_mean_eer, embedding_mean_thresh))
    fwrite.close()
    logging.info("embedding_mean_EER : %0.4f (thres:%0.4f)" %
                 (embedding_mean_eer, embedding_mean_thresh))
    eer = embedding_mean_eer
    return eer


# Prepare the parameters
opt = TrainOptions().parse()

## set seed ##
if opt.manual_seed is None:
    opt.manual_seed = random.randint(1, 10000)
print('manual_seed = %d' % opt.manual_seed)
random.seed(opt.manual_seed)
torch.manual_seed(opt.manual_seed)

# Configure the distributed training
opt.num_gpus = int(
    os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
opt.distributed = opt.num_gpus > 1
if opt.cuda:
    opt.device = torch.device("cuda")
else:
    opt.device = torch.device("cpu")
Example 9
import random
import psutil
import time
import numpy as np
from tqdm import tqdm
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

from model import DeepSpeakerModel, DeepSpeakerSeqModel, DeepSpeakerCnnModel, DeepSpeakerCnnSeqModel
from model import similarity, loss_cal, normalize, penalty_loss_cal, similarity_segment, loss_cal_segment, penalty_seq_loss_cal
from config import TrainOptions
from data_loader import DeepSpeakerDataset, DeepSpeakerDataLoader, DeepSpeakerSeqDataset, DeepSpeakerSeqDataLoader
import utils

opt = TrainOptions().parse()
manualSeed = random.randint(1, 10000)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.cuda.manual_seed(manualSeed)
logging = utils.create_output_dir(opt)

print(opt.gpu_ids)
device = torch.device("cuda:{}".format(opt.gpu_ids[0]) if len(opt.gpu_ids) > 0
                      and torch.cuda.is_available() else "cpu")

# data
logging.info("Building dataset.")
if opt.seq_training == 'true':
    opt.data_type = 'train'
    train_dataset = DeepSpeakerSeqDataset(opt,