Exemplo n.º 1
0
def main():
    """Train the speech-enhancement ``Unet`` with the wSDR loss.

    Hyper-parameters come from ``utils.Params(args.model_dir)``. Training runs
    ``args.num_epochs`` epochs of Adam (lr=1e-3) with an exponential
    learning-rate decay of 0.95 per epoch; the final weights are saved to
    ``./final.pth.tar``.

    For listening checks, the first utterance of the last batch of every epoch
    is written (16 kHz) to ``mixed.wav``, ``clean.wav`` and ``out.wav``.
    """
    json_path = os.path.join(args.model_dir)
    params = utils.Params(json_path)

    net = Unet(params.model).cuda()
    # TODO - resume from './final.pth.tar' when it exists.

    train_dataset = AudioDataset(data_type='train')
    test_dataset = AudioDataset(data_type='val')
    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_size=args.batch_size,
                                   collate_fn=train_dataset.collate,
                                   shuffle=True,
                                   num_workers=4)
    # NOTE(review): the validation loader is built but never consumed below.
    test_data_loader = DataLoader(dataset=test_dataset,
                                  batch_size=args.batch_size,
                                  collate_fn=test_dataset.collate,
                                  shuffle=False,
                                  num_workers=4)

    torch.set_printoptions(precision=10, profile="full")

    # Optimizer
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    # Learning rate scheduler
    scheduler = ExponentialLR(optimizer, 0.95)

    for epoch in range(args.num_epochs):
        train_bar = tqdm(train_data_loader)
        last_batch = None
        for batch in train_bar:
            train_mixed, train_clean, seq_len = map(lambda x: x.cuda(), batch)
            mixed = stft(train_mixed).unsqueeze(dim=1)
            real, imag = mixed[..., 0], mixed[..., 1]
            out_real, out_imag = net(real, imag)
            out_real = torch.squeeze(out_real, 1)
            out_imag = torch.squeeze(out_imag, 1)
            out_audio = istft(out_real, out_imag, train_mixed.size(1))
            out_audio = torch.squeeze(out_audio, dim=1)
            # Zero every sample past the entry's seq_len (presumably batch padding).
            for i, length in enumerate(seq_len):
                out_audio[i, length:] = 0
            loss = wSDRLoss(train_mixed, train_clean, out_audio)
            # .item() prints the scalar value rather than the full tensor repr.
            print(epoch, loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            last_batch = (train_mixed, train_clean, out_audio.detach(), seq_len)
        scheduler.step()

        # The original rewrote these files on every step (three GPU->CPU copies
        # plus three disk writes per batch). Writing once per epoch from the last
        # batch leaves exactly the same files on disk when training finishes.
        if last_batch is not None:
            b_mixed, b_clean, b_out, b_len = last_batch
            n_samples = int(b_len[0])
            for wav_name, wav in (('mixed.wav', b_mixed),
                                  ('clean.wav', b_clean),
                                  ('out.wav', b_out)):
                librosa.output.write_wav(
                    wav_name, wav[0].detach().cpu().numpy()[:n_samples], 16000)
    torch.save(net.state_dict(), './final.pth.tar')
Exemplo n.º 2
0
def main():
    """Train the ``Unet`` with the wSDR loss and checkpoint it periodically.

    Hyper-parameters are read from ``args.conf``. Training runs for
    ``args.num_epochs`` epochs (numbered from 1) with Adam (lr=1e-2) and an
    exponential learning-rate decay of 0.996 per epoch. The mean training loss
    is printed after every epoch, and weights are saved to
    ``./ckpt/stepNNNNN.pth.tar`` every 20 epochs and after the final epoch.
    """
    json_path = os.path.join(args.conf)
    params = utils.Params(json_path)

    net = Unet(params.model).cuda()
    # TODO - resume from './ckpt/final.pth.tar' when it exists.

    train_dataset = AudioDataset(data_type='train')
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                                   collate_fn=train_dataset.collate, shuffle=True, num_workers=0)

    # Optimizer
    optimizer = optim.Adam(net.parameters(), lr=1e-2)
    # Learning rate scheduler
    scheduler = ExponentialLR(optimizer, 0.996)

    # Model save dir; exist_ok avoids the exists()/mkdir() check-then-create race.
    os.makedirs('ckpt', exist_ok=True)

    for epoch in range(1, args.num_epochs + 1):
        train_bar = tqdm(train_data_loader, ncols=60)
        loss_sum = 0.0
        step_cnt = 0
        for input_ in train_bar:
            train_mixed, train_clean, seq_len = map(lambda x: x.cuda(), input_)
            mixed = stft(train_mixed).unsqueeze(dim=1)
            real, imag = mixed[..., 0], mixed[..., 1]
            out_real, out_imag = net(real, imag)
            out_real, out_imag = torch.squeeze(out_real, 1), torch.squeeze(out_imag, 1)
            out_audio = istft(out_real, out_imag, train_mixed.size(1))
            out_audio = torch.squeeze(out_audio, dim=1)
            # Zero every sample past the entry's seq_len (presumably batch padding).
            for i, length in enumerate(seq_len):
                out_audio[i, length:] = 0
            loss = wSDRLoss(train_mixed, train_clean, out_audio)
            loss_sum += loss.item()
            step_cnt += 1
            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

        # An empty loader previously raised ZeroDivisionError here.
        avg_loss = loss_sum / step_cnt if step_cnt else float('nan')
        print('epoch %d> Avg_loss: %.6f.\n' % (epoch, avg_loss))
        scheduler.step()
        # Also save after the last epoch so trailing epochs are never lost.
        if epoch % 20 == 0 or epoch == args.num_epochs:
            torch.save(net.state_dict(), './ckpt/step%05d.pth.tar' % epoch)
Exemplo n.º 3
0
    def __init__(self,
                 input_channels=1,
                 num_classes=1,
                 num_filters=None,
                 latent_levels=1,
                 latent_dim=2,
                 initializers=None,
                 no_convs_fcomb=4,
                 image_size=(1, 128, 128),
                 beta=10.0,
                 reversible=False):
        """Assemble the U-Net backbone, prior/posterior encoders and the combiner.

        Args:
            input_channels: channels of the input image, passed to every submodule.
            num_classes: number of output classes.
            num_filters: filter counts shared by the U-Net, the two
                ``AxisAlignedConvGaussian`` encoders and ``Fcomb``.
            latent_levels: not used by this constructor.
            latent_dim: latent dimensionality passed to the encoders and ``Fcomb``.
            initializers: ``{'w': ..., 'b': ...}`` initializer names for the U-Net
                and both encoders. ``None`` (the default) selects
                ``{'w': 'he_normal', 'b': 'normal'}``; previously any value passed
                here was silently ignored.
            no_convs_fcomb: number of convolutions forwarded to ``Fcomb``.
            image_size: not used by this constructor.
            beta: not used by this constructor.
            reversible: forwarded to the U-Net.
        """
        super(ProbabilisticUnet, self).__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        if initializers is None:
            initializers = {'w': 'he_normal', 'b': 'normal'}
        self.initializers = initializers
        self.z_prior_sample = 0

        self.unet = Unet(self.input_channels,
                         self.num_classes,
                         self.num_filters,
                         initializers=self.initializers,
                         apply_last_layer=False,
                         padding=True,
                         reversible=reversible).to(device)
        self.prior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            initializers=self.initializers).to(device)
        self.posterior = AxisAlignedConvGaussian(
            self.input_channels,
            self.num_filters,
            self.no_convs_per_block,
            self.latent_dim,
            initializers=self.initializers,
            posterior=True).to(device)
        # The combiner deliberately uses its own orthogonal weight initializer.
        self.fcomb = Fcomb(self.num_filters,
                           self.latent_dim,
                           self.input_channels,
                           self.num_classes,
                           self.no_convs_fcomb,
                           initializers={
                               'w': 'orthogonal',
                               'b': 'normal'
                           },
                           use_tile=True).to(device)

        # NOTE(review): 32 input channels is hard-coded; presumably it should
        # match the U-Net's output width (num_filters[0]?) — TODO confirm.
        # Moved to `device` like every other submodule above.
        self.last_conv = Conv2D(32,
                                num_classes,
                                kernel_size=1,
                                activation=torch.nn.Identity,
                                norm=torch.nn.Identity).to(device)
Exemplo n.º 4
0
def unet_tf_pth(checkpoint_path, pth_output_path):
    """Print TF checkpoint variable names next to the PyTorch ``Unet`` state-dict keys.

    Both key listings are only printed (TF names sorted, then PyTorch keys) so the
    two can be matched by hand; ``pth_output_path`` is not used and nothing is
    written.
    """
    torch_state = Unet().eval().state_dict()

    tf_reader = tf.train.NewCheckpointReader(checkpoint_path)
    tf_var_names = sorted(tf_reader.get_variable_to_shape_map().keys())

    print(tf_var_names)
    print(torch_state.keys())
Exemplo n.º 5
0
def main(unused_argv):
    """Evaluate or train the Aracati ``Unet``, depending on ``config.mode``.

    * ``"evaluate"``: restore the model, then run it over every PNG in
      ``./datasets/aracati/test/input`` (sorted by path), saving the results with
      ``Aracati.save_data``.
    * ``"restore"``: restore the model, then continue training.
    * anything else: train from scratch.

    Exits with a non-zero status and an error message if the configuration
    cannot be processed.
    """
    # NOTE(review): 3 is a raw numeric verbosity level rather than a named
    # constant — confirm the intended amount of logging.
    tf.logging.set_verbosity(3)

    try:
        config = process_config()
    except Exception as err:
        # A bare `except: exit(0)` used to hide the failure and report success.
        raise SystemExit('Could not process config: {}'.format(err))

    create_dirs(config, [config.checkpoint_dir, config.evaluate_dir,
                         config.presentation_dir, config.summary_dir])

    session = tf.Session()
    K.set_session(session)

    if config.mode == "evaluate":
        model = Unet(config, is_evaluating=True)
        trainer = Trainer(config, None, None, model, session)

        image_files = sorted(glob.glob("./datasets/aracati/test/input/*.png"))
        sat_data = [Aracati.load_data(file, is_grayscale=False) for file in image_files]
        # Wrap each loaded image in its own one-element list.
        sat_data = [[sample] for sample in sat_data]

        model.load(session)
        trainer.evaluate_data(sat_data, Aracati.save_data)

    else:
        data = Aracati(config)
        model = Unet(config)
        logger = Logger(config, session)
        trainer = Trainer(config, data, logger, model, session)

        if config.mode == "restore":
            model.load(session)

        trainer.train()
def get_model(name, in_channels, out_channels, **kwargs):
    """Instantiate a denoising network by name.

    Args:
        name: one of "unet", "baby-unet", "dncnn", "convolution".
        in_channels: number of input channels passed to the network.
        out_channels: number of output channels passed to the network.
        **kwargs: "convolution" additionally requires ``width``.

    Raises:
        ValueError: if ``name`` is not a known model (previously ``None`` was
            returned silently, deferring the failure to the caller).
    """
    if name == "unet":
        return Unet(in_channels, out_channels)
    if name == "baby-unet":
        return BabyUnet(in_channels, out_channels)
    if name == "dncnn":
        return DnCNN(in_channels, out_channels)
    if name == "convolution":
        return SingleConvolution(in_channels, out_channels, kwargs["width"])
    raise ValueError("Unknown model name: {!r}".format(name))
Exemplo n.º 7
0
def get_model(name, in_channels, out_channels, **kwargs):
    """Instantiate a denoising network by name.

    Args:
        name: one of 'unet', 'baby-unet', 'dncnn', 'convolution'.
        in_channels: number of input channels passed to the network.
        out_channels: number of output channels passed to the network.
        **kwargs: 'convolution' additionally requires ``width``.

    Raises:
        ValueError: if ``name`` is not a known model (previously ``None`` was
            returned silently, deferring the failure to the caller).
    """
    if name == 'unet':
        return Unet(in_channels, out_channels)
    if name == 'baby-unet':
        return BabyUnet(in_channels, out_channels)
    if name == 'dncnn':
        return DnCNN(in_channels, out_channels)
    if name == 'convolution':
        return SingleConvolution(in_channels, out_channels, kwargs['width'])
    raise ValueError('Unknown model name: {!r}'.format(name))
Exemplo n.º 8
0
def test_unet_gradients(basic_image):
    """Every trainable variable of a depth-1 ``Unet`` receives a gradient.

    The input image doubles as the label tensor; only gradient connectivity is
    checked, not the loss value.
    """
    labels = basic_image
    model = Unet(num_classes=1, depth=1)
    loss_fn = tf.keras.losses.CategoricalCrossentropy()
    with tf.GradientTape() as tape:
        predictions = model(basic_image)
        loss = loss_fn(labels, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    # all() of an empty sequence is True, so require at least one gradient.
    assert gradients
    assert all(g is not None for g in gradients)
Exemplo n.º 9
0
def main():
    """Load the "v1" data, build the generator and both discriminators, then train.

    All three networks are moved to the GPU when the module-level ``cuda`` flag
    is truthy.
    """
    videos, audios = get_data("v1", 1)
    print("Data Loaded")

    # Disabled experiments: resuming from './Unet.pt' and wrapping the three
    # networks in nn.DataParallel when several GPUs are present.
    unet = Unet(debug=False)
    frame_discriminator = FrameDiscriminator()
    sequence_discriminator = SequenceDiscriminator()

    if cuda:
        print('SAHI JA RHA.......')
        unet, frame_discriminator, sequence_discriminator = (
            network.cuda()
            for network in (unet, frame_discriminator, sequence_discriminator)
        )

    train(audios, videos, unet, frame_discriminator, sequence_discriminator)
Exemplo n.º 10
0
def single_model_factory(model_name, C):
    """Instantiate a change-detection network by (case-insensitive) name.

    ``model_name`` is stripped and upper-cased before matching. Each supported
    model is imported lazily and built with ``C.num_feats_in`` input channels
    and 2 outputs.

    Raises:
        NotImplementedError: if ``model_name`` is not a supported architecture.
    """
    key = model_name.strip().upper()

    if key == 'SIAMUNET_CONC':
        from models.siamunet_conc import SiamUnet_conc as model_cls
    elif key == 'SIAMUNET_DIFF':
        from models.siamunet_diff import SiamUnet_diff as model_cls
    elif key == 'EF':
        from models.unet import Unet as model_cls
    else:
        raise NotImplementedError("{} is not a supported architecture".format(model_name))

    return model_cls(C.num_feats_in, 2)
Exemplo n.º 11
0
 def test_init(self):
     """A small ``Unet`` built from ``UnetConfig`` compiles and prints a summary."""
     config = UnetConfig(input_size=(16, 16, 3), filters=10,
                         dropout=0.6, batchnorm=False)
     model = Unet(config=config)
     model.compile(loss="binary_crossentropy", metrics=["accuracy"])
     model.summary()
Exemplo n.º 12
0
def main():
    """Enhance every validation utterance and save the results as WAV files.

    Loads hyper-parameters from ``args.model_dir`` and weights from
    ``args.ckpt``, runs the network on each utterance (batch size 1, loader
    order) and writes ``enhanced_testset/enhanced_001.wav``, ``..._002.wav``, ...
    as 16 kHz float32 audio trimmed to the utterance length.
    """
    json_path = os.path.join(args.model_dir)
    params = utils.Params(json_path)

    net = Unet(params.model).cuda()
    # TODO - check exists
    checkpoint = torch.load(args.ckpt)
    net.load_state_dict(checkpoint)
    # Inference: switch dropout / normalization layers (if any) to eval
    # behaviour; the original ran the network in training mode.
    net.eval()

    test_dataset = AudioDataset(data_type='val')
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=1,
                                  collate_fn=test_dataset.collate, shuffle=False, num_workers=0)

    torch.set_printoptions(precision=10, profile="full")

    # Create the output directory up front instead of relying on it existing.
    os.makedirs('enhanced_testset', exist_ok=True)

    test_bar = tqdm(test_data_loader, ncols=60)
    with torch.no_grad():
        for cnt, input_ in enumerate(test_bar, start=1):
            test_mixed, _clean, seq_len = map(lambda x: x.cuda(), input_)
            mixed = stft(test_mixed).unsqueeze(dim=1)
            real, imag = mixed[..., 0], mixed[..., 1]
            out_real, out_imag = net(real, imag)
            out_real, out_imag = torch.squeeze(out_real, 1), torch.squeeze(out_imag, 1)
            out_audio = istft(out_real, out_imag, test_mixed.size(1))
            out_audio = torch.squeeze(out_audio, dim=1)
            # Zero every sample past the entry's seq_len (presumably padding).
            for i, length in enumerate(seq_len):
                out_audio[i, length:] = 0
            n_samples = int(seq_len[0])
            enhanced = np.array(out_audio[0].cpu().numpy()[:n_samples], dtype=np.float32)
            sf.write('enhanced_testset/enhanced_%03d.wav' % cnt, enhanced, 16000)
Exemplo n.º 13
0
 def _build_model(self, model_config: Bunch) -> None:
     """Create the noise model, U-Net, optimizer and loss for the configured type.

     Args:
         model_config: must provide ``type`` (one of "care", "n2v", "pn2v"),
             ``noise_model_path``, ``num_classes``, ``depth``,
             ``initial_filters`` and ``model_path``.

     Raises:
         ValueError: if ``model_config.type`` is not supported (an ``assert``
             was used before, which is stripped under ``python -O``).
     """
     if model_config.type not in ("care", "n2v", "pn2v"):
         raise ValueError("Model type must be either care, n2v, or pn2v")
     self.model_type = model_config.type
     noise_histogram = np.load(model_config.noise_model_path)
     self.noise_model = NoiseModel(noise_histogram)
     self.unet = Unet(model_config.num_classes, model_config.depth,
                      model_config.initial_filters)
     self.optimizer = Adam()
     self.loss_fn = self.loss_pn2v
     self.model_path = model_config.model_path
     # exist_ok: rebuilding into an existing model directory should not crash.
     os.makedirs(self.model_path, exist_ok=True)
     if self.model_type in ("care", "n2v"):
         # Only `noise_fn` used to be set here, leaving `loss_fn` as the PN2V
         # loss for every model type. `noise_fn` is kept for any existing readers.
         self.loss_fn = self.loss_n2v
         self.noise_fn = self.loss_n2v
Exemplo n.º 14
0
def test(args):
    """Predict salt masks for the test split and write an RLE submission CSV.

    Dataset and checkpoint paths come from ``config.json`` (keys ``args.dataset``
    and ``args.model``). Each sigmoid mask is resized from
    ``args.img_size_target`` to ``args.img_size_ori``, thresholded at
    ``args.threshold``, run-length encoded and saved to
    ``./results/<args.model>_submission.csv``.
    """
    import os

    # Setup Dataloader
    with open('config.json') as config_file:
        data_json = json.load(config_file)
    data_path = data_json[args.dataset]['data_path']

    t_loader = SaltLoader(data_path, split="test")
    test_df = t_loader.test_df
    test_loader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=8)

    # load Model
    if args.arch == 'unet':
        model = Unet(start_fm=16)
    else:
        model = Unet_upsample(start_fm=16)
    model_path = data_json[args.model]['model_path']
    model.load_state_dict(torch.load(model_path)['model_state'])
    model.cuda()
    # Inference mode for dropout / batch-norm layers (if any); the original
    # predicted with the network still in training mode.
    model.eval()
    total = sum(param.nelement() for param in model.parameters())
    print('Number of params: %.2fM' % (total / 1e6))

    # test
    pred_list = []
    with torch.no_grad():
        for images in test_loader:
            y_preds = model(images.cuda())
            y_preds_shaped = y_preds.reshape(-1, args.img_size_target, args.img_size_target)
            # Slice instead of indexing range(args.batch_size): the final batch
            # can hold fewer images, which used to raise IndexError.
            for y_pred in y_preds_shaped[:args.batch_size]:
                pred = torch.sigmoid(y_pred).cpu().numpy()
                pred_ori = resize(pred, (args.img_size_ori, args.img_size_ori),
                                  mode='constant', preserve_range=True)
                pred_list.append(pred_ori)

    # submit the test image predictions.
    threshold_best = args.threshold
    pred_dict = {idx: RLenc(np.round(pred_list[i] > threshold_best)) for i, idx in
                 enumerate(tqdm_notebook(test_df.index.values))}
    sub = pd.DataFrame.from_dict(pred_dict, orient='index')
    sub.index.names = ['id']
    sub.columns = ['rle_mask']
    os.makedirs('./results', exist_ok=True)
    sub.to_csv('./results/{}_submission.csv'.format(args.model))
    print("The submission.csv saved in ./results")
Exemplo n.º 15
0
def test(args):
    """Render the forward graph of the selected U-Net variant for inspection.

    Builds ``Unet`` or ``Unet_upsample`` (per ``args.arch``) and loads the
    checkpoint at ``config.json[args.model]['model_path']``. It then prints the
    parameter count, feeds a random (32, 1, 128, 128) batch on the GPU and
    renders the output's graph with ``make_dot(...).render('k')``.
    """
    # Setup Data
    with open('config.json') as config_file:
        data_json = json.load(config_file)
    # torch.autograd.Variable is a no-op wrapper in modern PyTorch.
    x = torch.randn(32, 1, 128, 128).cuda()

    # load Model
    if args.arch == 'unet':
        model = Unet(start_fm=16)
    else:
        model = Unet_upsample(start_fm=16)
    model_path = data_json[args.model]['model_path']
    model.load_state_dict(torch.load(model_path)['model_state'])
    model.cuda()
    total = sum(param.nelement() for param in model.parameters())
    print('Number of params: %.2fM' % (total / 1e6))

    # visualize
    y = model(x)
    g = make_dot(y)
    g.render('k')
Exemplo n.º 16
0
    parser.add_argument('--case_test', type=int, default=7)
    parser.add_argument('--linear_factor', type=int, default=1)
    parser.add_argument('--weight_dir', type=str, default='weight_2nets')
    opt = {**vars(parser.parse_args())}

    os.environ['CUDA_VISIBLE_DEVICES'] = opt['gpu_id']
    t0 = time.time()

    device0 = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    device1 = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    rootDir = '/data/Jinwei/Bayesian_QSM/' + opt['weight_dir']

    # network
    unet3d = Unet(
        input_channels=1,
        output_channels=1,
        num_filters=[2**i for i in range(5, 10)],  # or range(3, 8)
        use_deconv=1,
        flag_rsa=0)
    unet3d.to(device0)
    weights_dict = torch.load(rootDir +
                              '/linear_factor=1_validation=6_test=7_unet3d.pt')
    unet3d.load_state_dict(weights_dict)

    resnet = ResBlock(
        input_dim=2,
        filter_dim=32,
        output_dim=1,
    )
    resnet.to(device1)
    weights_dict = torch.load(rootDir +
                              '/linear_factor=1_validation=6_test=7_resnet.pt')
Exemplo n.º 17
0
        print('Using simulated RDF')
    # parameters
    lr = 7e-4  # 7e-4 for HOBIT
    batch_size = 1
    B0_dir = (0, 0, 1)
    voxel_size = dataLoader_train.voxel_size
    volume_size = dataLoader_train.volume_size

    trainLoader = data.DataLoader(dataLoader_train,
                                  batch_size=batch_size,
                                  shuffle=True)

    # network of HOBIT
    unet3d = Unet(input_channels=1,
                  output_channels=1,
                  num_filters=[2**i for i in range(5, 10)],
                  use_deconv=1,
                  flag_rsa=0)
    unet3d.to(device0)
    weights_dict = torch.load(rootDir + '/weight_2nets/unet3d_fine.pt')
    # weights_dict = torch.load(rootDir+'/weight_2nets/linear_factor=1_validation=6_test=7_unet3d.pt')
    unet3d.load_state_dict(weights_dict)

    # QSMnet
    unet3d_ = Unet(input_channels=1,
                   output_channels=1,
                   num_filters=[2**i for i in range(5, 10)],
                   use_deconv=1,
                   flag_rsa=0)
    unet3d_.to(device0)
    weights_dict = torch.load(rootDir +
Exemplo n.º 18
0
    voxel_size = dataLoader_train.voxel_size
    volume_size = dataLoader_train.volume_size
    S = SMV_kernel(volume_size, voxel_size, radius=5)
    D = dipole_kernel(volume_size, voxel_size, B0_dir)
    D = np.real(S * D)

    trainLoader = data.DataLoader(dataLoader_train,
                                  batch_size=batch_size,
                                  shuffle=True)

    # # network
    unet3d = Unet(input_channels=1,
                  output_channels=2,
                  num_filters=[2**i for i in range(3, 8)],
                  bilateral=1,
                  use_deconv=1,
                  use_deconv2=1,
                  renorm=1,
                  flag_r_train=flag_r_train)
    # unet3d = Unet(
    #     input_channels=1,
    #     output_channels=2,
    #     num_filters=[2**i for i in range(3, 8)],
    #     flag_rsa=2
    # )
    unet3d.to(device)
    if flag_init == 0:
        weights_dict = torch.load(
            rootDir +
            '/weight/weights_sigma={0}_smv={1}_mv8'.format(sigma, 1) + '.pt')
    else:
Exemplo n.º 19
0
print(data.shape)

####################################################
#           PREPARE Noise Model
####################################################

histogram = np.load(path + args.histogram)

# Create a NoiseModel object from the histogram.
noiseModel = hist_noise_model.NoiseModel(histogram)

####################################################
#           CREATE AND TRAIN NETWORK
####################################################

net = Unet(800, depth=args.netDepth)

# Split training and validation image.
my_train_data = data[:-5].copy()
np.random.shuffle(my_train_data)
my_val_data = data[-5:].copy()
np.random.shuffle(my_val_data)

# Start training.
train_hist, val_hist = training.train_network(
    net=net,
    train_data=my_train_data,
    val_data=my_val_data,
    postfix=args.name,
    directory=path,
    noise_model=noiseModel,
Exemplo n.º 20
0
        D = dipole_kernel(patchSize_padding, voxel_size, B0_dir)
        # S = SMV_kernel(patchSize, voxel_size, radius=5)
        # D = np.real(S * D)
    else:
        B0_dir = (0, 0, 1)
        patchSize = (64, 64, 64)
        # patchSize_padding = (64, 64, 128)
        patchSize_padding = patchSize
        extraction_step = (21, 21, 21)
        voxel_size = (1, 1, 1)
        D = dipole_kernel(patchSize_padding, voxel_size, B0_dir)

    # network
    unet3d = Unet(
        input_channels=1,
        output_channels=1,
        num_filters=[2**i for i in range(5, 10)],  # or range(3, 8)
        use_deconv=1,
        flag_rsa=0)

    resnet = ResBlock(
        input_dim=2,
        filter_dim=32,
        output_dim=1,
    )

    unet3d.to(device)
    resnet.to(device)

    # optimizer
    optimizer = optim.Adam(list(unet3d.parameters()) +
                           list(resnet.parameters()),
Exemplo n.º 21
0
    fName = fName + '_probUnet'
    net = cFlowNet(input_channels=1,
                   num_classes=1,
                   num_filters=[32, 64, 128, 256],
                   latent_dim=6,
                   no_convs_fcomb=4,
                   norm=True,
                   flow=args.flow)
elif args.unet:
    print("Using Det. Unet")
    fName = fName + '_Unet'
    net = Unet(input_channels=1,
               num_classes=1,
               num_filters=[32, 64, 128, 256],
               apply_last_layer=True,
               padding=True,
               norm=True,
               initializers={
                   'w': 'he_normal',
                   'b': 'normal'
               })
    criterion = nn.BCELoss(size_average=False)
else:
    print("Choose a model.\nAborting....")
    sys.exit()

# Log directory; exist_ok avoids the exists()/mkdir() check-then-create race.
os.makedirs('logs', exist_ok=True)

logFile = 'logs/' + fName + '.txt'
makeLogFile(logFile)
Exemplo n.º 22
0
def build_Unet_model(C):
    """Return the project ``Unet`` with 6 input channels and 2 output classes.

    ``C`` is accepted but not read.
    """
    from models.unet import Unet

    in_channels, num_classes = 6, 2
    return Unet(in_channels, num_classes)
Exemplo n.º 23
0
import tensorflow as tf
import cv2
from models.unet import Unet
from data_augmentation.data_augmentation import DataAugmentation
import numpy as np

# Enable on-demand memory growth on every visible GPU. Indexing gpus[0] raised
# IndexError on CPU-only hosts; looping also keeps the setting uniform when
# several GPUs are visible.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Initialize
IMAGE_PATH = "dataset/Original/Testing/"
MASK_PATH = "dataset/MASKS/Testing/"
IMAGE_FILE = "Frame00314-org"

model = Unet(input_shape=(224, 224, 1)).build()
model.load_weights("models/model_weight.h5")
model.summary()
print("yeah")


def convert_to_tensor(numpy_image):
    """Add a channel axis at position 2 and a leading batch axis, then wrap in TF.

    A 2-D (H, W) array becomes a (1, H, W, 1) tensor.
    """
    batched = np.expand_dims(np.expand_dims(numpy_image, axis=2), axis=0)
    return tf.convert_to_tensor(batched)


def predict(image):
    """Preprocess ``image`` for the 224x224 model and convert it to a batched tensor.

    NOTE(review): as shown, the prepared tensor is never passed to ``model`` and
    the function implicitly returns None — the body looks truncated; confirm
    the remaining prediction/post-processing steps.
    """
    process_obj = DataAugmentation(input_size=224, output_size=224)
    image_processed = process_obj.data_process_test(image)
    tensor_image = convert_to_tensor(image_processed)
Exemplo n.º 24
0
    # x = torch.rand((256, 1024, block_mem)).cuda()
    # x = torch.rand((2, 2)).cuda()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    rootDir = '/data/Jinwei/Bayesian_QSM/'

    cfl = opt['flag_cfl']
    val = opt['case_validation']
    test = opt['case_test']

    # network
    if opt['flag_cfl'] == 0:
        unet3d = Unet(
            input_channels=1,
            output_channels=1,
            num_filters=[2**i for i in range(5, 10)],  # or range(3, 8)
            use_deconv=1)

    elif opt['flag_cfl'] == 1:
        unet3d = unetVggBNNAR1CLF(
            input_channels=1,
            output_channels=1,
            num_filters=[2**i for i in range(5, 10)],  # or range(3, 8)
            use_deconv=1)
    elif opt['flag_cfl'] == 2:
        unet3d = unetVggBNNAR1CLFRes(
            input_channels=1,
            output_channels=1,
            num_filters=[2**i for i in range(5, 10)],  # or range(3, 8)
            use_deconv=1)
        D = np.real(S * D)
    else:
        B0_dir = (0, 0, 1)
        patchSize = (64, 64, 64)
        # patchSize_padding = (64, 64, 128)
        patchSize_padding = patchSize
        extraction_step = (21, 21, 21)
        voxel_size = (1, 1, 1)
        D = dipole_kernel(patchSize_padding, voxel_size, B0_dir)

    # network
    unet3d = Unet(
        input_channels=1,
        output_channels=2,
        num_filters=[2**i for i in range(5, 10)],  # or range(3, 8)
        bilateral=1,
        use_deconv=1,
        use_deconv2=1,
        renorm=1,
        flag_r_train=0)
    unet3d.to(device)

    # optimizer
    optimizer = optim.Adam(unet3d.parameters(), lr=lr, betas=(0.5, 0.999))
    ms = [0.3, 0.5, 0.7, 0.9]
    ms = [np.floor(m * niter).astype(int) for m in ms]
    scheduler = MultiStepLR(optimizer, milestones=ms, gamma=0.2)

    # logger
    logger = Logger('logs', rootDir, opt['flag_rsa'], opt['case_validation'],
                    opt['case_test'])
Exemplo n.º 26
0
def get_model(model_name, in_channel, n_classes):
    """Build the segmentation network registered under ``model_name``.

    Only the requested network is constructed.

    Args:
        model_name: one of 'unetresnet', 'unet', 'unetplus'.
        in_channel: number of input channels.
        n_classes: number of output classes.

    Raises:
        KeyError: if ``model_name`` is not a registered model.
    """
    # Lambdas defer construction: the original dict literal instantiated all
    # three networks on every call and then discarded two of them.
    builders = {
        'unetresnet': lambda: UNetResNet(in_channel, n_classes),
        'unet': lambda: Unet(in_channel, n_classes),
        'unetplus': lambda: UnetPlus(in_channel, n_classes),
    }
    return builders[model_name]()
Exemplo n.º 27
0
def test_unet_various_depths(depth, basic_image):
    """A ``Unet`` of the given depth maps ``basic_image`` to a (2, 128, 128, 1) output.

    NOTE(review): everything from the ``cv2.resize`` line onward appears to be
    spliced in from a different (inference) snippet. ``image`` is read there
    before it is assigned in this function, so reaching that line would raise
    UnboundLocalError; ``cfg``, ``is_train``, ``model_restore_name`` and
    ``batch_size`` are not defined here either. Confirm and split this up.
    """
    model = Unet(1, depth)
    prediction = model(basic_image)
    assert prediction.get_shape() == (2, 128, 128, 1)
    image = cv2.resize(image, (cfg.input_shape[0], cfg.input_shape[1]))
    image = np.expand_dims(image, axis=0)
    # 3. GPU settings
    # NOTE(review): device_count={'GPU': 0} presumably hides all GPUs from the
    # session, which would make the gpu_options below moot — confirm intent.
    session_config = tf.ConfigProto(
        device_count={'GPU': 0},
        gpu_options={
            'allow_growth': 1,
            # 'per_process_gpu_memory_fraction': 0.1,
            'visible_device_list': '0'
        },
        allow_soft_placement=True)  ## This setting is mandatory; without it a cuDNN mismatch error is raised no matter what. The bug is very well hidden and very frustrating.

    with tf.Session(config=session_config) as sess:
        # 1. Define the model
        model = Unet(sess, cfg, is_train=is_train)

        # 2. Restore the model weights
        model.restore(model_restore_name)

        # 3. Predict (and time the prediction)
        since = time.time()
        pre = model.predict(image)
        seconds = time.time() - since

        # 4. Adjust dimensions and mask the input image with the prediction
        pre_list = np.split(pre[0], batch_size, axis=0)
        image = np.squeeze(image, axis=0)
        pres = np.squeeze(pre_list, axis=0)
        pres = np.expand_dims(pres, axis=-1)
        result = np.multiply(pres, image)
Exemplo n.º 29
0
def unet():
    """Build a two-class ``Unet`` of depth 3."""
    return Unet(num_classes=2, depth=3)
Exemplo n.º 30
0
    radius = 5
    sigma_sq = (3*10**(-5))**2
    Lambda_tv = 10
    D = dipole_kernel(patchSize, voxel_size, B0_dir)
    S = SMV_kernel(patchSize, voxel_size, radius)
    D = S*D

    GPU = 2

    if GPU == 1:
        rootDir = '/home/sdc/Jinwei/BayesianQSM'
    elif GPU == 2:
        rootDir = '/data/Jinwei/Bayesian_QSM/weight'

    # network
    unet3d = Unet(input_channels=1, output_channels=2, num_filters=[2**i for i in range(5, 10)])
    unet3d.to(device)
    # optimizer
    optimizer = optim.Adam(unet3d.parameters(), lr = lr, betas=(0.5, 0.999))
    # dataloader
    dataLoader = QSM_data_loader2(GPU=GPU, patchSize=patchSize)
    trainLoader = data.DataLoader(dataLoader, batch_size=1, shuffle=False)

    epoch = 0
    while epoch < niter:

        epoch += 1
        for idx, (input_RDFs, in_loss_RDFs, QSMs, Masks, \
            fidelity_Ws, gradient_Ws, flag_COSMOS) in enumerate(trainLoader):

            input_RDFs = input_RDFs[0, ...].to(device, dtype=torch.float)