Example #1
    def __init__(self, params, use_cuda=False):
        super(MModel, self).__init__()
        self.params = params
        # Source-mask branch: predicts 11-channel mask deltas.
        self.src_mask_delta_UNet = UNet(params, 3, [64] * 2 + [128] * 9,
                                        [128] * 4 + [32])
        self.src_mask_delta_Conv = nn.Conv2d(32, 11, kernel_size=3, stride=1,
                                             padding=1, padding_mode='replicate')
        # Foreground branch: predicts the target RGB image and a 1-channel mask.
        self.fg_UNet = UNet(params, 30, [64] * 2 + [128] * 9, [128] * 4 + [64])
        self.fg_tgt_Conv = nn.Conv2d(64, 3, kernel_size=3, stride=1,
                                     padding=1, padding_mode='replicate')
        self.fg_mask_Conv = nn.Conv2d(64, 1, kernel_size=3, stride=1,
                                      padding=1, padding_mode='replicate')
        # Background branch: predicts the 3-channel target background.
        self.bg_UNet = UNet(params, 4, [64] * 2 + [128] * 9, [128] * 4 + [64])
        self.bg_tgt_Conv = nn.Conv2d(64, 3, kernel_size=3, stride=1,
                                     padding=1, padding_mode='replicate')
        self.use_cuda = use_cuda
Example #2
    def build_model(self):
        """Build generator and discriminator."""
        if self.model_type == 'UNet':
            self.unet = UNet(n_channels=1, n_classes=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)  # TODO: changed for green image channel
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=1, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'Iternet':
            self.unet = Iternet(n_channels=1, n_classes=1)
        elif self.model_type == 'AttUIternet':
            self.unet = AttUIternet(n_channels=1, n_classes=1)
        elif self.model_type == 'R2UIternet':
            self.unet = R2UIternet(n_channels=3, n_classes=1)
        elif self.model_type == 'NestedUNet':
            self.unet = NestedUNet(in_ch=1, out_ch=1)
        elif self.model_type == "AG_Net":
            self.unet = AG_Net(n_classes=1, bn=True, BatchNorm=False)
        else:
            # Fail fast; otherwise self.unet is undefined below.
            raise ValueError('Unknown model_type: %s' % self.model_type)

        self.optimizer = optim.Adam(list(self.unet.parameters()),
                                    self.lr,
                                    betas=tuple(self.beta_list))
        self.unet.to(self.device)
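An if/elif chain like this can also be driven from a dict. A minimal sketch of that dispatch, assuming the same model constructors are importable (the registry contents shown are illustrative, not part of the original project):

from typing import Callable, Dict

def build_from_registry(name: str, registry: Dict[str, Callable[[], object]]):
    # Look up a constructor by name and fail loudly on unknown keys.
    try:
        return registry[name]()
    except KeyError:
        raise ValueError('Unknown model type: %r; expected one of %s'
                         % (name, sorted(registry)))

# Hypothetical usage with the constructors from the example above:
# registry = {'UNet': lambda: UNet(n_channels=1, n_classes=1),
#             'AttU_Net': lambda: AttU_Net(img_ch=1, output_ch=1)}
# self.unet = build_from_registry(self.model_type, registry)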
Example #3
def log_param_and_grad(net: UNet, writer: tensorboardX.SummaryWriter, step):
    for name, param in net.named_parameters():
        if param.grad is None:  # frozen layers, or called before backward()
            continue
        writer.add_histogram(f"grad/{name}",
                             param.grad.detach().cpu().numpy(), step)
        # The norm is a single scalar, so log it with add_scalar; note the
        # original took the norm of the parameter here, not of its gradient.
        writer.add_scalar(f"grad_norm/{name}",
                          param.grad.detach().norm().item(), step)
        writer.add_histogram(f"param/{name}",
                             param.detach().cpu().numpy(), step)
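A minimal usage sketch, assuming the function above is in scope: gradients only exist after backward(), so the logger must run before the next zero_grad(). The stand-in two-layer model and log directory are hypothetical, not from the original code.

import torch
import torch.nn as nn
import tensorboardX

net = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.Conv2d(4, 1, 3, padding=1))
writer = tensorboardX.SummaryWriter(log_dir='runs/param_grad_demo')
x = torch.randn(2, 1, 32, 32)

loss = net(x).mean()
loss.backward()                          # populates param.grad
log_param_and_grad(net, writer, step=0)  # the UNet annotation is not enforced
writer.close()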
Example #4
def get_net(params):
    if params['network'].lower() == 'pan':
        net = PAN(params)
    elif params['network'].lower() == 'shortres':
        net = ShortRes(params)
    elif params['network'].lower() == 'unet':
        net = UNet(params)
    else:
        # Avoid returning an undefined name for unrecognized networks.
        raise ValueError("Unknown network: %s" % params['network'])

    return net
Example #5
def my_app():

    batch_size_app = 1
    loader_ap = loaders(batch_size_app, 2)

    output = UNet()
    output.cuda()

    #output.load_state_dict(torch.load('/home/daisylabs/aritra_project/results/output.pth'))

    output.load_state_dict(
        torch.load('/home/daisylabs/aritra_project/results/output_best.pth'))

    output.eval()

    with torch.set_grad_enabled(False):
        for u, (inputs, targets) in enumerate(loader_ap):
            if (u == 0):
                inputs = inputs.reshape((batch_size_app, 3, 256, 256)).cuda()  # move to the same device as the model
                targets = targets.reshape((batch_size_app, 256, 256, 256))

                out_1, out_2 = output(inputs)

                out_1 = out_1.reshape((batch_size_app, 256, 256, 256))

                targets = targets.cpu().numpy()
                out_1 = out_1.cpu().numpy()

                tgt_shp = targets.shape[1]

                for slice_number in range(tgt_shp):
                    targets_1 = targets[0][slice_number].reshape((256, 256))
                    out_1_1 = out_1[0][slice_number].reshape((256, 256))

                    plt.figure()
                    plt.subplot(1, 2, 1)
                    plt.title('Original Slice')
                    plt.imshow(targets_1,
                               cmap=plt.get_cmap('gray'),
                               vmin=0,
                               vmax=1)
                    plt.subplot(1, 2, 2)
                    plt.title('Reconstructed Slice')
                    plt.imshow(out_1_1,
                               cmap=plt.get_cmap('gray'),
                               vmin=0,
                               vmax=1)

                    plt.savefig(
                        '/home/daisylabs/aritra_project/results/slices/%d.png'
                        % (slice_number + 1, ))
                    plt.close()  # free each figure; 256 open figures exhaust memory

            else:
                break
Example #6
def main():

    # filename = 'mixture2.wav'
    filename = 'aimer/1-02 花の唄.wav'
    # filename = 'amazarashi/03 季節は次々死んでいく.wav'
    batch_length = 512
    fs = 44100
    frame_size = 4096
    shift_size = 2048
    modelname = 'model/fs%d_frame%d_shift%d_batch%d.model' % (
        fs, frame_size, shift_size, batch_length)
    statname = 'stat/fs%d_frame%d_shift%d_batch%d.npy' % (
        fs, frame_size, shift_size, batch_length)
    max_norm = float(np.load(statname))

    # load network
    model = UNet()
    model.load_state_dict(torch.load(modelname))
    model.eval()
    torch.backends.cudnn.benchmark = True

    # gpu
    if torch.cuda.is_available():
        model.cuda()
    else:
        print('GPU is not available.')
        sys.exit(1)

    # load wave file
    wave = load(filename, sr=fs)[0]
    spec = stft(wave, frame_size, shift_size)
    soft_vocal, soft_accom, hard_vocal, hard_accom = extract(
        spec, model, max_norm, fs, frame_size, shift_size)
    base = os.path.splitext(os.path.basename(filename))[0]
    for suffix, signal in [('original', wave),
                           ('soft_vocal', soft_vocal),
                           ('soft_accom', soft_accom),
                           ('hard_vocal', hard_vocal),
                           ('hard_accom', hard_accom)]:
        write_wav(base + '_' + suffix + '.wav', signal, fs)
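The stft, extract, and write_wav helpers here are project-local. As a rough sketch of the underlying idea, soft-mask separation with librosa would look like the following; the function name and the assumption that the mask lies in [0, 1] with the spectrogram's shape are mine, not the repo's API:

import numpy as np
import librosa

def separate_with_mask(wave, mask, frame_size=4096, shift_size=2048):
    # Complex STFT of the mixture; mask is assumed to match its shape.
    spec = librosa.stft(wave, n_fft=frame_size, hop_length=shift_size)
    vocal = librosa.istft(spec * mask, hop_length=shift_size)
    accom = librosa.istft(spec * (1.0 - mask), hop_length=shift_size)
    return vocal, accom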
Example #7
def TrainUNet(X,
              Y,
              model_=None,
              optimizer_=None,
              epoch=40,
              alpha=0.001,
              gpu_id=0,
              loop=1,
              earlystop=True):
    assert len(X) == len(Y), "X and Y must have the same length"
    d_time = datetime.datetime.now().strftime("%m-%d-%H-%M-%S")

    # 1. Model load.

    # print(sum(p.data.size for p in model.unet.params()))
    if model_ is not None:
        model = Regressor(model_)
        print("## model loaded.")
    else:
        model = Regressor(UNet())

    model.compute_accuracy = False

    if gpu_id >= 0:
        model.to_gpu(gpu_id)

    # 2. optimizer load.

    if optimizer_ is not None:
        opt = optimizer_
        print("## optimizer loaded.")
    else:
        opt = optimizers.Adam(alpha=alpha)
        opt.setup(model)

    # 3. Data Split.
    dataset = Unet_DataSet(X, Y)
    print("# number of patterns", len(dataset))

    train, valid = \
        split_dataset_random(dataset, int(len(dataset) * 0.8), seed=0)

    # 4. Iterator
    train_iter = SerialIterator(train, batch_size=C.BATCH_SIZE)
    test_iter = SerialIterator(valid,
                               batch_size=C.BATCH_SIZE,
                               repeat=False,
                               shuffle=False)

    # 5. config train, enable backprop
    chainer.config.train = True
    chainer.config.enable_backprop = True

    # 6. UnetUpdater
    updater = UnetUpdater(train_iter, opt, model, device=gpu_id)

    # 7. EarlyStopping
    if earlystop:
        stop_trigger = triggers.EarlyStoppingTrigger(
            monitor='validation/main/loss',
            max_trigger=(epoch, 'epoch'),
            patients=5)
    else:
        stop_trigger = (epoch, 'epoch')

    # 8. Trainer
    trainer = training.Trainer(updater, stop_trigger, out=C.PATH_TRAINRESULT)

    # 8.1. UnetEvaluator
    trainer.extend(UnetEvaluator(test_iter, model, device=gpu_id))

    trainer.extend(SaveRestore(),
                   trigger=triggers.MinValueTrigger('validation/main/loss'))

    # 8.2. Extensions LogReport
    trainer.extend(extensions.LogReport())

    # 8.3. Extension Snapshot
    # trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
    # trainer.extend(extensions.snapshot_object(model.unet, filename='loop' + str(loop) + '.model'))

    # 8.4. Print Report
    trainer.extend(extensions.observe_lr())
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss', 'elapsed_time', 'lr'
        ]))

    # 8.5. Extension Graph
    trainer.extend(
        extensions.PlotReport(['main/loss', 'validation/main/loss'],
                              x_key='epoch',
                              file_name='loop-' + str(loop) + '-loss' +
                              d_time + '.png'))
    # trainer.extend(extensions.dump_graph('main/loss'))

    # 8.6. Progress Bar
    trainer.extend(extensions.ProgressBar())

    # 9. Trainer run
    trainer.run()

    chainer.serializers.save_npz(C.PATH_TRAINRESULT / ('loop' + str(loop)),
                                 model.unet)
    return model.unet, opt
Example #8
                        required=True,
                        help='Image Directory')
    parser.add_argument('-g',
                        '--gpu',
                        type=int,
                        default=0,
                        help='GPU selection')
    parser.add_argument('-r',
                        '--resolution',
                        type=int,
                        required=True,
                        help='Resolution for Square Image')
    args = parser.parse_args()

    # Height
    model_h = HUNet(128)
    pretrained_model_h = torch.load(
        '/content/drive/My Drive/Colab Notebooks/AI_Australia/Models/model_ep_48.pth.tar'
    )

    # Weight
    model_w = UNet(128, 32, 32)
    pretrained_model_w = torch.load(
        '/content/drive/My Drive/Colab Notebooks/AI_Australia/Models/model_ep_37.pth.tar'
    )

    model_h.load_state_dict(pretrained_model_h["state_dict"])
    model_w.load_state_dict(pretrained_model_w["state_dict"])

    if torch.cuda.is_available():
        model = model_w.cuda(args.gpu)
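These checkpoints were saved from a CUDA run; loading them on a CPU-only machine raises an error unless map_location is given. A small hedged helper, not from the original script:

import torch

def load_checkpoint(path):
    # Map CUDA-saved tensors onto whatever device is actually available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.load(path, map_location=device)

# e.g. pretrained_model_h = load_checkpoint('model_ep_48.pth.tar')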
Example #9
def main(args):

    modelname = os.path.join(args.dst_dir, os.path.splitext(args.src_file)[0])

    if not os.path.exists(args.dst_dir):
        os.makedirs(args.dst_dir)

    # define transforms
    max_norm = float(np.load(args.stats_file))
    transform = transforms.Compose([
        lambda x: x / max_norm])

    # load data
    with open(args.src_file, 'r') as f:
        files = f.readlines()
    filelist = [file.replace('\n', '') for file in files]

    # define sampler
    index = list(range(len(filelist)))
    train_index = sample(index, round(len(index) * args.ratio))
    valid_index = list(set(index) - set(train_index))
    train_sampler = SubsetRandomSampler(train_index)
    valid_sampler = SubsetRandomSampler(valid_index)

    # define dataloader
    trainset = MagSpecDataset(filelist, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        dataset=trainset,
        batch_size=args.batch_size, shuffle=False, sampler=train_sampler,
        num_workers=args.num_worker)
    valid_loader = torch.utils.data.DataLoader(
        dataset=trainset,
        batch_size=1, shuffle=False, sampler=valid_sampler,
        num_workers=args.num_worker)

    # fix seed
    torch.manual_seed(args.seed)

    # define network
    model = UNet()

    # define optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # define loss
    criterion = nn.L1Loss(reduction='sum')  # size_average=False is deprecated

    # gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
    else:
        print('GPU is not available.')
        sys.exit(1)

    # training
    for epoch in range(args.num_epoch):

        model.train()
        train_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, targets, silence = data

            # wrap them in Variable
            inputs = Variable(inputs[:, None, ...]).cuda()
            targets = Variable(targets[:, None, ...]).cuda()
            silence = Variable(silence[:, None, ...]).cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            outputs = model(inputs)
            batch_train_loss = criterion(
                inputs * outputs * silence, targets * silence)
            batch_train_loss.backward()
            optimizer.step()

            # print statistics
            train_loss += batch_train_loss.item()

        model.eval()
        valid_loss = 0.0
        with torch.no_grad():  # validation needs no gradients
            for i, data in enumerate(valid_loader):
                inputs, targets, silence = data

                # wrap them in Variable
                inputs = Variable(inputs[:, None, ...]).cuda()
                targets = Variable(targets[:, None, ...]).cuda()
                silence = Variable(silence[:, None, ...]).cuda()

                outputs = model(inputs)
                batch_valid_loss = criterion(
                    inputs * outputs * silence, targets * silence)

                # print statistics
                valid_loss += batch_valid_loss.item()
        print('[{}/{}] training loss: {:.3f}; validation loss: {:.3f}'.format(
              epoch + 1, args.num_epoch, train_loss, valid_loss))

        # save model
        if epoch % args.num_interval == args.num_interval - 1:
            torch.save(
                model.state_dict(),
                modelname + '_batch{}_ep{}.model'.format(
                    args.batch_length, epoch + 1))

    torch.save(model.state_dict(), modelname + '.model')
Example #10
    elif args.loss == 'mae':
        height_loss = nn.L1Loss()
    elif args.loss == 'huber':
        height_loss = nn.SmoothL1Loss()
    
    train = DataLoader(Images(args.dataset, 'TRAINING.csv', True), 
                       batch_size=args.batch_size, num_workers=8, shuffle=True)
    
    valid = DataLoader(Images(args.dataset, 'VAL.csv', True), 
                       batch_size=1, num_workers=8, shuffle=False)
    
    
    print("Training on " + str(len(train)*args.batch_size) + " images.")
    print("Validating on " + str(len(valid)) + " images.")

    net = UNet(args.min_neuron)
    start_epoch = 0
    
    #pretrained_model = torch.load(glob('models/IMDB_MODEL_06102019_121502/*')[0])
    #state_dict = pretrained_model["state_dict"]
    
    #own_state = net.state_dict()
    
    #for name, param in state_dict.items():
    #    if name not in own_state:
    #         continue
    #    if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
    #        param = param.data
            
    #    if not (("height_1" in name) or ("height_2" in name)):
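The commented-out block above is cut off. A self-contained sketch of what it appears to do, copying only matching parameters and skipping the height_1/height_2 heads; hedged, since the original's tail is missing:

import torch

def load_matching_params(net, state_dict, skip_substrings=('height_1', 'height_2')):
    own_state = net.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            continue
        if any(s in name for s in skip_substrings):
            continue  # leave the height heads at their fresh initialization
        if isinstance(param, torch.nn.Parameter):
            param = param.data  # backwards compatibility for serialized parameters
        own_state[name].copy_(param)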
Example #11
def main(args):
    torch.backends.cudnn.benchmark = True
    seed_all(args.seed)

    num_classes = 1

    d = Dataset(train_set_size=args.train_set_sz, num_cls=num_classes)
    train = d.train_set
    valid = d.test_set

    net = UNet(in_dim=1, out_dim=4).cuda()
    snake_approx_net = UNet(in_dim=1,
                            out_dim=1,
                            wf=3,
                            padding=True,
                            first_layer_pad=None,
                            depth=4,
                            last_layer_resize=True).cuda()
    best_val_dice = -np.inf

    optimizer = torch.optim.Adam(params=net.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    snake_approx_optimizer = torch.optim.Adam(
        params=snake_approx_net.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay)
    scheduler_warmup = GradualWarmupScheduler(optimizer,
                                              multiplier=10,
                                              total_epoch=50,
                                              after_scheduler=None)

    # load model
    if args.ckpt:
        with open(args.ckpt, 'rb') as f:
            loaded = _pickle.load(f)
        net.load_state_dict(loaded[0])
        optimizer.load_state_dict(loaded[1])
        snake_approx_net.load_state_dict(loaded[2])
        snake_approx_optimizer.load_state_dict(loaded[3])

    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir, exist_ok=True)

    writer = tensorboardX.SummaryWriter(log_dir=args.log_dir)
    snake = SnakePytorch(args.delta, args.batch_sz * args.num_samples,
                         args.num_lines, args.radius)
    snake_eval = SnakePytorch(args.delta, args.batch_sz, args.num_lines,
                              args.radius)
    noises = torch.zeros(
        (args.batch_sz, args.num_samples, args.num_lines, args.radius)).cuda()

    step = 1
    start = timeit.default_timer()
    for epoch in range(1, args.n_epochs + 1):
        for iteration in range(
                1,
                int(np.ceil(train.dataset_sz() / args.batch_sz)) + 1):

            scheduler_warmup.step()

            imgs, masks, onehot_masks, centers, dts_modified, dts_original, jitter_radius, bboxes = \
                train.next_batch(args.batch_sz)

            xs = make_batch_input(imgs)
            xs = torch.cuda.FloatTensor(xs)

            net.train()
            unet_logits = net(xs)

            center_jitters, angle_jitters = [], []
            for img, mask, center in zip(imgs, masks, centers):
                c_j, a_j = get_random_jitter_by_mask(mask, center, [1],
                                                     args.theta_jitter)
                if not args.use_center_jitter:
                    c_j = np.zeros_like(c_j)
                center_jitters.append(c_j)
                angle_jitters.append(a_j)

            center_jitters = np.asarray(center_jitters)
            angle_jitters = np.asarray(angle_jitters)

            # args.radius + 1 because we need additional outermost points for the gradient
            gs_logits_whole_img = unet_logits[:, 3, ...]
            gs_logits, coords_r, coords_c = get_star_pattern_values(
                gs_logits_whole_img,
                None,
                centers,
                args.num_lines,
                args.radius + 1,
                center_jitters=center_jitters,
                angle_jitters=angle_jitters)

            # currently only class 1 is foreground
            # if there's multiple foreground classes use a for loop
            gs = gs_logits[:, :, 1:] - gs_logits[:, :, :-1]  # compute the gradient

            noises.normal_(0, 1)  # used only for random exploration, so mirrored sampling is unnecessary
            gs_noisy = torch.unsqueeze(gs, 1) + noises

            def batch_eval_snake(snake, inputs, batch_sz):
                n_inputs = len(inputs)
                assert n_inputs % batch_sz == 0
                n_batches = int(np.ceil(n_inputs / batch_sz))
                ind_sets = []
                for j in range(n_batches):
                    inps = inputs[j * batch_sz:(j + 1) * batch_sz]
                    batch_ind_sets = snake(inps).data.cpu().numpy()
                    ind_sets.append(batch_ind_sets)
                ind_sets = np.concatenate(ind_sets, 0)
                return ind_sets

            gs_noisy = gs_noisy.reshape((args.batch_sz * args.num_samples,
                                         args.num_lines, args.radius))
            ind_sets = batch_eval_snake(snake, gs_noisy,
                                        args.batch_sz * args.num_samples)
            ind_sets = ind_sets.reshape(
                (args.batch_sz * args.num_samples, args.num_lines))
            ind_sets = np.expand_dims(
                smooth_ind(ind_sets, args.smoothing_window), -1)

            # loss layers
            m = torch.nn.LogSoftmax(dim=1)
            loss = torch.nn.NLLLoss()

            # ===========================================================================
            # Inner loop: Train dice loss prediction network
            snake_approx_net.train()
            for _ in range(args.dice_approx_train_steps):

                snake_approx_logits = snake_approx_net(
                    gs_noisy.reshape(args.batch_sz * args.num_samples, 1,
                                     args.num_lines, args.radius).detach())
                snake_approx_train_loss = loss(
                    m(snake_approx_logits.squeeze().transpose(2, 1)),
                    torch.cuda.LongTensor(ind_sets.squeeze()))
                snake_approx_optimizer.zero_grad()
                snake_approx_train_loss.backward()
                snake_approx_optimizer.step()
            # ===========================================================================

            # ===========================================================================
            # Now, minimize the approximate dice loss
            snake_approx_net.eval()

            gt_indices = []
            for mask, center, cj, aj in zip(masks, centers, center_jitters,
                                            angle_jitters):
                gt_ind = mask_to_indices(mask, center, args.radius,
                                         args.num_lines, cj, aj)
                gt_indices.append(gt_ind)
            gt_indices = np.asarray(gt_indices).astype(int)

            gt_indices = gt_indices.reshape((args.batch_sz, args.num_lines))
            gt_indices = torch.cuda.LongTensor(gt_indices)

            snake_approx_logits = snake_approx_net(
                gs.reshape((args.batch_sz, 1, args.num_lines, args.radius)))
            nll_approx_loss = loss(
                m(snake_approx_logits.squeeze().transpose(2, 1)), gt_indices)

            total_loss = nll_approx_loss
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            # ===========================================================================

            snake_approx_train_loss = snake_approx_train_loss.data.cpu().numpy()
            nll_approx_loss = nll_approx_loss.data.cpu().numpy()
            total_loss = snake_approx_train_loss + nll_approx_loss

            if step % args.log_freq == 0:
                stop = timeit.default_timer()
                print(f"step={step}\tepoch={epoch}\titer={iteration}"
                      f"\tloss={total_loss}"
                      f"\tsnake_approx_train_loss={snake_approx_train_loss}"
                      f"\tnll_approx_loss={nll_approx_loss}"
                      f"\tlr={optimizer.param_groups[0]['lr']}"
                      f"\ttime={stop-start}")
                start = stop
                writer.add_scalar("total_loss", total_loss, step)
                writer.add_scalar("nll_approx_loss", nll_approx_loss, step)
                writer.add_scalar("lr", optimizer.param_groups[0]["lr"], step)

            if step % args.train_eval_freq == 0:
                train_dice = do_eval(
                    net,
                    snake_eval,
                    train.images,
                    train.masks,
                    train.centers,
                    args.batch_sz,
                    args.num_lines,
                    args.radius,
                    smoothing_window=args.smoothing_window).data.cpu().numpy()
                writer.add_scalar("train_dice", train_dice, step)
                print(
                    f"step={step}\tepoch={epoch}\titer={iteration}\ttrain_eval: train_dice={train_dice}"
                )

            if step % args.val_eval_freq == 0:
                val_dice = do_eval(
                    net,
                    snake_eval,
                    valid.images,
                    valid.masks,
                    valid.centers,
                    args.batch_sz,
                    args.num_lines,
                    args.radius,
                    smoothing_window=args.smoothing_window).data.cpu().numpy()
                writer.add_scalar("val_dice", val_dice, step)
                print(
                    f"step={step}\tepoch={epoch}\titer={iteration}\tvalid_dice={val_dice}"
                )
                if val_dice > best_val_dice:
                    best_val_dice = val_dice
                    with open(os.path.join(args.log_dir, 'best_model.pth.tar'),
                              'wb') as f:
                        _pickle.dump([
                            net.state_dict(),
                            optimizer.state_dict(),
                            snake_approx_net.state_dict(),
                            snake_approx_optimizer.state_dict()
                        ], f)
                    with open(os.path.join(args.log_dir,
                                           f"best_val_dice{step}.txt"), 'w') as f:
                        f.write(str(best_val_dice))
                    print("better val dice detected.")

            step += 1

    return best_val_dice
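batch_eval_snake above is a general "evaluate in fixed-size chunks and concatenate" pattern; pulled out as a standalone helper it is simply:

import numpy as np

def batch_apply(fn, inputs, batch_sz):
    # fn must accept a batch and return something np.concatenate understands.
    assert len(inputs) % batch_sz == 0
    outs = [fn(inputs[j * batch_sz:(j + 1) * batch_sz])
            for j in range(len(inputs) // batch_sz)]
    return np.concatenate(outs, 0)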
Example #12
    ####### CREATE/LOAD TRAINING INFO JSON
    if trainMode.lower() == 'start':
        print('\nSTARTING TRAINING...')
        trainCFG = dict()
        init_epoch = 0
        trainCFG['trainTFRdir'] = trainTFRdir
        trainCFG['valTFRdir'] = valTFRdir
        trainCFG['input_shape'] = input_shape  # [int(x) for x in args.inputShape.split(',')]
        trainCFG['batch_size'] = batch_size
        trainCFG['train_epochs'] = epochs
        trainCFG['initLR'] = initLR
        trainCFG['best_val_loss'] = np.inf

        ##### INITIALIZE MODEL
        model = UNet(input_shape=input_shape)
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=initLR),
                      loss=tf.losses.CategoricalCrossentropy())

    elif trainMode.lower() == 'resume':
        print('RESUMING TRAINING FROM: ' + trainInfoPath)
        trainCFG = load_json(trainInfoPath)
        init_epoch = trainCFG['last_epoch'] + 1
        if init_epoch >= trainCFG['train_epochs']:
            raise ValueError(
                'Initial training epoch is higher than the maximum number of '
                'training epochs specified')

        model = tf.keras.models.load_model(lastModelPath)

    dataset_info = load_json(os.path.join(dataDir, 'data_info.json'))
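load_json here is project-local; it is presumably a thin wrapper over the standard library, along the lines of this sketch (the matching writer is included only because the resume path implies one exists):

import json

def load_json(path):
    with open(path) as f:
        return json.load(f)

def save_json(path, cfg):
    with open(path, 'w') as f:
        json.dump(cfg, f, indent=2)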
Example #13
                new_file = pad(new_file)
                #new_file = new_file/torch.max(new_file)

                dataset = torch.stack((dataset, new_file))
            else:
                new_file, fs = torchaudio.load(filepath)
                new_file = pad(new_file)
                #new_file = new_file/torch.max(new_file)

                dataset = torch.cat((dataset, new_file.unsqueeze(0)))

            n_files = n_files + 1

print("finished loading: {} files loaded, Total Time: {}".format(n_files, time.time()-start_time))

G = UNet(1,2)
G.cuda()

#load the model
G.load_state_dict(torch.load("g_param.pth"))

G.eval()

results_path = "val_out"
os.makedirs(results_path, exist_ok=True)  # ensure the output directory exists

for j in range(dataset.size()[0]):
    input_stereo = dataset[j,:,:].cuda()
    input_wav = torch.mean(input_stereo, dim=0).unsqueeze(0)
    output_wav = G(input_wav.unsqueeze(0)).cpu().detach()
    torchaudio.save(os.path.join(results_path, "test_output_" + str(j) + ".wav"),
                    output_wav.squeeze(), fs)
Example #14
                dataset = torch.stack((dataset, new_file))
            else:
                new_file, fs = torchaudio.load(filepath)
                new_file = pad(new_file)
                new_file = new_file / torch.max(new_file)

                dataset = torch.cat((dataset, new_file.unsqueeze(0)))

            n_files = n_files + 1

print("finished loading: {} files loaded".format(n_files))

#set up the network (refer to the github page)
#network input/output (N,C,H,W) = (1,1,256,256) => (1,2,256,256)
model = UNet(1, 2)
model.cuda()
model.train()

# Adam is an optimizer, not a loss criterion; the original bound it to the
# wrong name.
optimizer = torch.optim.Adam(model.parameters(), lr=.0001, betas=(.5, .999))

#for each epoch:
keep_training = True
training_losses = []
counter = 1

results_path = "output"

print("training start!")
while keep_training:
    epoch_losses = []
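For reference, a generic sketch of the update the rest of this (truncated) loop would perform, pairing the optimizer above with an actual loss function; the names are illustrative, not the repo's code:

import torch
import torch.nn as nn

def train_step(model, optimizer, loss_fn, x, y):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)   # e.g. loss_fn = nn.MSELoss()
    loss.backward()
    optimizer.step()
    return loss.item()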
Example #15
def main():
    global args, best_result, output_directory

    # set random seed
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        args.batch_size = args.batch_size * torch.cuda.device_count()
    else:
        print("Let's use GPU ", torch.cuda.current_device())

    train_loader, val_loader = create_loader(args)

    if args.resume:
        assert os.path.isfile(args.resume), \
            "=> no checkpoint found at '{}'".format(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        start_epoch = 0
        # start_epoch = checkpoint['epoch'] + 1
        # best_result = checkpoint['best_result']
        # optimizer = checkpoint['optimizer']

        # solve 'out of memory'
        model = checkpoint['model']
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     betas=(0.9, 0.999))
        # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

        # clear memory
        del checkpoint
        # del model_dict
        torch.cuda.empty_cache()
    else:
        print("=> creating Model")
        # input_shape = [args.batch_size,3,256,512]
        model = UNet(3, 1)
        print("=> model created.")
        start_epoch = 0

        print('Number of model parameters: {}'.format(
            sum([p.data.nelement() for p in model.parameters()])))

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     betas=(0.9, 0.999))
        # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        # DataParallel() works whether or not multiple GPUs are present
        model = nn.DataParallel(model).cuda()

    # during training, use ReduceLROnPlateau to reduce the learning rate
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               patience=args.lr_patience)

    # loss function
    criterion = criteria.myL1Loss()
    # criterion = nn.SmoothL1Loss()
    # create directory path
    output_directory = utils.get_output_directory(args)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    best_txt = os.path.join(output_directory, 'best.txt')
    config_txt = os.path.join(output_directory, 'config.txt')

    # write training parameters to config file
    if not os.path.exists(config_txt):
        with open(config_txt, 'w') as txtfile:
            args_ = vars(args)
            args_str = ''
            for k, v in args_.items():
                args_str = args_str + str(k) + ':' + str(v) + ',\t\n'
            txtfile.write(args_str)

    for epoch in range(start_epoch, args.epochs):

        # remember change of the learning rate
        old_lr = 0.0
        # adjust_learning_rate(optimizer,epoch)
        for i, param_group in enumerate(optimizer.param_groups):
            old_lr = float(param_group['lr'])
        print("lr: %f" % old_lr)

        train(train_loader, model, criterion, optimizer,
              epoch)  # train for one epoch
        result, img_merge = validate(val_loader, model,
                                     epoch)  # evaluate on validation set

        # remember best mae and save checkpoint
        is_best = result.mae < best_result.mae
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write("epoch={}, mae={:.3f}, "
                              "t_gpu={:.4f}".format(epoch, result.mae,
                                                    result.gpu_time))
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)

        # save checkpoint for each epoch
        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)

        # when mae doesn't fall, reduce learning rate
        scheduler.step(result.mae)
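This script pickles whole model and optimizer objects into its checkpoints (hence the "out of memory" workaround on resume). The usual state_dict-based alternative, as a hedged sketch with illustrative names:

import torch

def save_ckpt(path, model, optimizer, epoch):
    torch.save({'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict()}, path)

def load_ckpt(path, model, optimizer):
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model_state'])
    optimizer.load_state_dict(ckpt['optimizer_state'])
    return ckpt['epoch'] + 1  # epoch to resume from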
Example #16
def main(args):
    torch.backends.cudnn.benchmark = True
    seed_all(args.seed)

    d = Dataset(train_set_size=args.train_set_sz,
                num_cls=args.num_cls,
                remove_nan_center=False)
    train = d.train_set
    valid = d.test_set

    num_cls = args.num_cls + 1  # +1 for background
    net = UNet(in_dim=1, out_dim=num_cls).cuda()
    best_net = UNet(in_dim=1, out_dim=num_cls)
    best_val_dice = -np.inf
    best_cls_val_dices = None

    optimizer = torch.optim.Adam(params=net.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler_warmup = GradualWarmupScheduler(optimizer,
                                              multiplier=10,
                                              total_epoch=50,
                                              after_scheduler=None)

    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir, exist_ok=True)

    writer = tensorboardX.SummaryWriter(log_dir=args.log_dir)

    step = 1
    for epoch in range(1, args.n_epochs + 1):
        for iteration in range(
                1,
                int(np.ceil(train.dataset_sz() / args.batch_sz)) + 1):

            net.train()

            imgs, masks, one_hot_masks, centers, _, _, _, _ = train.next_batch(
                args.batch_sz)
            imgs = make_batch_input(imgs)
            imgs = torch.cuda.FloatTensor(imgs)
            one_hot_masks = torch.cuda.FloatTensor(one_hot_masks)

            pred_logit = net(imgs)
            pred_softmax = F.softmax(pred_logit, dim=1)

            if args.use_ce:
                ce = torch.nn.CrossEntropyLoss()
                loss = ce(pred_logit, torch.cuda.LongTensor(masks))
            else:
                loss = dice_loss(pred_softmax,
                                 one_hot_masks,
                                 keep_background=False).mean()

            scheduler_warmup.step()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % args.log_freq == 0:
                print(
                    f"step={step}\tepoch={epoch}\titer={iteration}\tloss={loss.data.cpu().numpy()}"
                )
                writer.add_scalar("cnn_dice_loss",
                                  loss.data.cpu().numpy(), step)
                writer.add_scalar("lr", optimizer.param_groups[0]["lr"], step)

            if step % args.train_eval_freq == 0:
                train_dice, cls_train_dices = do_eval(net, train.images,
                                                      train.onehot_masks,
                                                      args.batch_sz, num_cls)
                train_dice = train_dice.cpu().numpy()
                cls_train_dices = cls_train_dices.cpu().numpy()
                writer.add_scalar("train_dice", train_dice, step)
                # lr_sched.step(1-train_dice)
                for j, cls_train_dice in enumerate(cls_train_dices):
                    writer.add_scalar(f"train_dice/{j}", cls_train_dice, step)
                print(
                    f"step={step}\tepoch={epoch}\titer={iteration}\ttrain_eval: train_dice={train_dice}"
                )

            if step % args.val_eval_freq == 0:
                with open(os.path.join(args.log_dir, 'model.pth.tar'),
                          'wb') as f:
                    _pickle.dump(net.state_dict(), f)
                val_dice, cls_val_dices = do_eval(net, valid.images,
                                                  valid.onehot_masks,
                                                  args.batch_sz, num_cls)
                val_dice = val_dice.cpu().numpy()
                cls_val_dices = cls_val_dices.cpu().numpy()
                writer.add_scalar("val_dice", val_dice, step)
                for j, cls_val_dice in enumerate(cls_val_dices):
                    writer.add_scalar(f"val_dice/{j}", cls_val_dice, step)
                print(
                    f"step={step}\tepoch={epoch}\titer={iteration}\tvalid_dice={val_dice}"
                )
                if val_dice > best_val_dice:
                    best_val_dice = val_dice
                    best_cls_val_dices = cls_val_dices
                    best_net.load_state_dict(net.state_dict().copy())
                    with open(os.path.join(args.log_dir, 'best_model.pth.tar'),
                              'wb') as f:
                        _pickle.dump(best_net.state_dict(), f)
                    with open(os.path.join(args.log_dir,
                                           f"best_val_dice{step}.txt"), 'w') as f:
                        f.write(str(best_val_dice) + "\n")
                        f.write(" ".join(
                            str(dice_score) for dice_score in best_cls_val_dices))
                    print("better val dice detected.")
                # if step % 5000 == 0:
                #     _pickle.dump(net.state_dict(), open(os.path.join(args.log_dir, '{}.pth.tar'.format(step)),
                #                                         'wb'))

            step += 1

    return best_val_dice, best_cls_val_dices
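dice_loss here comes from the project. A common soft-Dice formulation matching the keep_background flag looks like the sketch below; details such as the epsilon and the per-class reduction may differ from the real implementation:

import torch

def soft_dice_loss(probs, one_hot, keep_background=False, eps=1e-6):
    # probs and one_hot have shape (N, C, H, W); returns one loss per class.
    if not keep_background:
        probs, one_hot = probs[:, 1:], one_hot[:, 1:]
    dims = (0, 2, 3)
    inter = (probs * one_hot).sum(dims)
    denom = probs.sum(dims) + one_hot.sum(dims)
    return 1.0 - (2.0 * inter + eps) / (denom + eps)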
Example #17
        ax2.title.set_text('probability prediction')

        plt.show(block=False)
        plt.pause(0.01)

    def show_gamma(self):
        plt.figure(3)
        plt.subplot(1, 1, 1)
        plt.imshow(self.gamma[0])
        plt.title('Gamma')
        plt.show(block=False)
        plt.pause(0.01)

    def show_s(self):
        plt.figure(4)
        plt.subplot(1, 1, 1)
        plt.imshow(self.s[0])
        plt.show(block=False)
        plt.pause(0.01)


if __name__ == "__main__":
    net = UNet(num_classes=2)
    net_ = networks(net, 10, 100)
    for i in range(10):  # xrange is Python 2-only
        # print(net_)
        limage = torch.randn(1, 1, 256, 256)
        uimage = torch.randn(1, 1, 256, 256)
        lmask = torch.randint(0, 2, (1, 256, 256), dtype=torch.long)
        net_.update((limage, lmask), uimage)
Example #18
File: main.py Project: ted-17/unet
util.mix_voice_noise(voicedir, noisedir, mixeddir, num_data, fs=16000)

# get list each of which the path name is written
voicepath_list = util.get_wavlist(voicedir)
mixedpath_list = util.get_wavlist(mixeddir)

# make spectrogram (n x F x T x 1)
V = util.make_dataset(voicepath_list, fftsize, hopsize, nbit)
X = util.make_dataset(mixedpath_list, fftsize, hopsize, nbit)

#%% model training
height, width = fftsize // 2, fftsize // 2  #CNN height x width
X_train = X[:, :height, :width, ...]  #mixture (voice + noise)
Y_train = V[:, :height, :width, ...]  #clean voice (target)
num_filt_first = 16
unet = UNet(height, width, num_filt_first)
model = unet.get_model()
model.compile(optimizer='adam', loss='mean_squared_error')
history = model.fit(X_train, Y_train, epochs=5, batch_size=32)

#%% model testing
absY, phsY, max_Y, min_Y = util.make_spectrogram(mixedpath_list[0],  # was mixpath_list (NameError)
                                                 fftsize,
                                                 hopsize,
                                                 nbit,
                                                 istest=True)
P = np.squeeze(model.predict(absY[np.newaxis, :height, :width, ...]))
P = np.hstack((P, absY[:height, width:]))  #t-axis
P = np.vstack((P, absY[height, :]))  #f-axis
Y = (absY * (max_Y - min_Y) + min_Y) * phsY
y = librosa.core.istft(absY * phsY, hop_length=hopsize, win_length=fftsize)
Example #19
print("debug:", DEBUG)
if DEBUG:
    task = 'DEBUG' + task
    num_train = 10
    num_val = 2
    save_model_freq = 1 

# set up the model and define the graph
with tf.variable_scope(tf.get_variable_scope()):
    input = tf.placeholder(tf.float32, shape=[None, None, None, 5])
    reflection = tf.placeholder(tf.float32, shape=[None, None, None, 5])
    target = tf.placeholder(tf.float32, shape=[None, None, None, 5])
    overexp_mask = utils.tf_overexp_mask(input)
    tf_input, tf_reflection, tf_target, real_input = utils.prepare_real_input(
        input, target, reflection, overexp_mask, ARGS)
    reflection_layer = UNet(real_input, ext='Ref_')  #real_reflect = build_one_hyper(reflection_layer[...,4:5])
    transmission_layer = UNet(
        tf.concat([real_input, reflection_layer], axis=3), ext='Tran_')
    lossDict = {}

    lossDict["percep_t"] = 0.2 * loss.compute_percep_loss(
        0.5 * tf_target[..., 4:5], 0.5 * transmission_layer[..., 4:5],
        overexp_mask, reuse=False)
    lossDict["percep_r"] = 0.2 * loss.compute_percep_loss(
        0.5 * tf_reflection[..., 4:5], 0.5 * reflection_layer[..., 4:5],
        overexp_mask, reuse=True)

    lossDict["pncc"] = 6 * loss.compute_percep_ncc_loss(
        tf.multiply(0.5 * transmission_layer[..., 4:5], overexp_mask),
        tf.multiply(0.5 * reflection_layer[..., 4:5], overexp_mask))

    lossDict["reconstruct"] = loss.mask_reconstruct_loss(
        tf_input[..., 4:5], transmission_layer[..., 4:5],
        reflection_layer[..., 4:5], overexp_mask)

    lossDict["reflection"] = lossDict["percep_r"]
    lossDict["transmission"] = lossDict["percep_t"]
    lossDict["all_loss"] = (lossDict["reflection"] + lossDict["transmission"] +
                            lossDict["pncc"])
Example #20
def score_data(input_folder,
               output_folder,
               model_path,
               args,
               do_postprocessing=False,
               gt_exists=True,
               evaluate_all=False,
               random_center_ratio=None):
    num_classes = args.num_cls
    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    num_channels = num_classes + 1

    net = UNet(in_dim=1, out_dim=4).cuda()
    ckpt_path = os.path.join(model_path, 'best_model.pth.tar')
    with open(ckpt_path, 'rb') as f:
        net.load_state_dict(_pickle.load(f)[0])
    if args.unet_ckpt:
        pretrained_unet = UNet(in_dim=1, out_dim=4).cuda()
        with open(args.unet_ckpt, 'rb') as f:
            pretrained_unet.load_state_dict(_pickle.load(f))

    snake = SnakePytorch(args.delta, 1, args.num_lines, args.radius)

    evaluate_test_set = not gt_exists

    total_time = 0
    total_volumes = 0

    for folder in os.listdir(input_folder):

        folder_path = os.path.join(input_folder, folder)

        if os.path.isdir(folder_path):

            if evaluate_test_set or evaluate_all:
                train_test = 'test'  # always test
            else:
                train_test = 'test' if (int(folder[-3:]) % 5 == 0) else 'train'

            if train_test == 'test':

                infos = {}
                for line in open(os.path.join(folder_path, 'Info.cfg')):
                    label, value = line.split(':')
                    infos[label] = value.rstrip('\n').lstrip(' ')

                patient_id = folder.lstrip('patient')
                ED_frame = int(infos['ED'])
                ES_frame = int(infos['ES'])

                for file in glob.glob(
                        os.path.join(folder_path,
                                     'patient???_frame??.nii.gz')):

                    logging.info(
                        ' ----- Doing image: -------------------------')
                    logging.info('Doing: %s' % file)
                    logging.info(
                        ' --------------------------------------------')

                    file_base = file.split('.nii.gz')[0]

                    frame = int(file_base.split('frame')[-1])
                    img_dat = utils.load_nii(file)
                    img = img_dat[0].copy()
                    img = image_utils.normalise_image(img)

                    if gt_exists:
                        file_mask = file_base + '_gt.nii.gz'
                        mask_dat = utils.load_nii(file_mask)
                        mask = mask_dat[0]

                    start_time = time.time()

                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2])
                    scale_vector = (pixel_size[0] /
                                    exp_config.target_resolution[0],
                                    pixel_size[1] /
                                    exp_config.target_resolution[1])

                    predictions = []

                    for zz in range(img.shape[2]):

                        slice_img = np.squeeze(img[:, :, zz])
                        slice_rescaled = transform.rescale(slice_img,
                                                           scale_vector,
                                                           order=1,
                                                           preserve_range=True,
                                                           multichannel=False,
                                                           mode='constant')

                        x, y = slice_rescaled.shape
                        slice_cropped, x_s, y_s, x_c, y_c = get_slice(
                            slice_rescaled, nx, ny)
                        # GET PREDICTION
                        network_input = np.float32(
                            np.tile(np.reshape(slice_cropped, (nx, ny, 1)),
                                    (batch_size, 1, 1, 1)))
                        network_input = np.transpose(network_input,
                                                     [0, 3, 1, 2])
                        network_input = torch.cuda.FloatTensor(network_input)
                        with torch.no_grad():
                            net.eval()
                            logit = net(network_input)

                        # get the center
                        if args.unet_ckpt != '':
                            unet_mask = torch.argmax(
                                pretrained_unet(network_input),
                                dim=1).data.cpu().numpy()[0]
                        else:
                            assert gt_exists
                            mask_copy = mask[:, :, zz].copy()
                            unet_mask = get_slice(mask_copy, nx, ny)[0]
                        unet_mask = image_utils.keep_largest_connected_components(
                            unet_mask)
                        from data_iterator import get_center_of_mass
                        if num_classes == 2:
                            lv_center = get_center_of_mass(unet_mask, [3])
                            mo_center = get_center_of_mass(unet_mask, [2])
                        else:
                            lv_center = get_center_of_mass(unet_mask, [3])
                            mo_center = np.asarray([[np.nan, np.nan]])
                        lv_center = np.asarray(lv_center)
                        mo_center = np.asarray(mo_center)

                        lv_mask = np.zeros((nx, ny))
                        if not np.isnan(lv_center[0, 0]):
                            if random_center_ratio:
                                dt, _ = get_distance_transform(
                                    unet_mask == 3, None)
                                max_radius = dt[0,
                                                int(lv_center[0][0]),
                                                int(lv_center[0][1])]
                                radius = int(max_radius * random_center_ratio)
                                c_j, _ = get_random_jitter(radius, 0)
                            else:
                                c_j = None

                            lv_logit, _, _ = get_star_pattern_values(
                                logit[:, 3, ...],
                                None,
                                lv_center,
                                args.num_lines,
                                args.radius + 1,
                                center_jitters=c_j)
                            lv_gs = lv_logit[:, :, 1:] - lv_logit[:, :, :-1]  # compute the gradient
                            # run DP algo
                            # can only put batch with fixed shape into the snake algorithm
                            lv_ind = snake(lv_gs).data.cpu().numpy()
                            lv_ind = np.expand_dims(
                                smooth_ind(lv_ind.squeeze(-1),
                                           args.smoothing_window), -1)
                            lv_mask = star_pattern_ind_to_mask(
                                lv_ind, lv_center, nx, ny, args.num_lines,
                                args.radius)

                        if num_classes == 1:
                            pred_mask = lv_mask * 3
                        else:
                            mo_mask = np.zeros((nx, ny))
                            if not np.isnan(mo_center[0, 0]):  # index as for lv_center; isnan on a 2-vector is ambiguous in a bool context
                                c_j = None
                                mo_logit, _, _ = get_star_pattern_values(
                                    logit[:, 2, ...],
                                    None,
                                    lv_center,
                                    args.num_lines,
                                    args.radius + 1,
                                    center_jitters=c_j)
                                mo_gs = mo_logit[:, :, 1:] - mo_logit[:, :, :-1]  # compute the gradient
                                mo_ind = snake(mo_gs).data.cpu().numpy()
                                mo_ind = mo_ind[:len(mo_gs), ...]
                                mo_ind = np.expand_dims(
                                    smooth_ind(mo_ind.squeeze(-1),
                                               args.smoothing_window), -1)
                                mo_mask = star_pattern_ind_to_mask(
                                    mo_ind, lv_center, nx, ny, args.num_lines,
                                    args.radius)
                            pred_mask = lv_mask * 3 + (
                                1 - lv_mask
                            ) * mo_mask * 2  # 3 is lv class, 2 is mo class

                        prediction_cropped = pred_mask.squeeze()
                        # ASSEMBLE BACK THE SLICES
                        prediction = np.zeros((x, y))
                        # insert cropped region into original image again
                        if x > nx and y > ny:
                            prediction[x_s:x_s + nx,
                                       y_s:y_s + ny] = prediction_cropped
                        elif x <= nx and y > ny:
                            prediction[:, y_s:y_s + ny] = \
                                prediction_cropped[x_c:x_c + x, :]
                        elif x > nx and y <= ny:
                            prediction[x_s:x_s + nx, :] = \
                                prediction_cropped[:, y_c:y_c + y]
                        else:
                            prediction[:, :] = \
                                prediction_cropped[x_c:x_c + x, y_c:y_c + y]

                        # RESCALING ON THE LOGITS
                        if gt_exists:
                            prediction = transform.resize(
                                prediction, (mask.shape[0], mask.shape[1]),
                                order=0,
                                preserve_range=True,
                                mode='constant')
                        else:  # This can occasionally lead to wrong volume size, therefore if gt_exists
                            # we use the gt mask size for resizing.
                            prediction = transform.rescale(
                                prediction,
                                (1.0 / scale_vector[0], 1.0 / scale_vector[1]),
                                order=0,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant')

                        # prediction = np.uint8(np.argmax(prediction, axis=-1))
                        prediction = np.uint8(prediction)
                        predictions.append(prediction)

                        gt_binary = (mask[..., zz] == 3) * 1
                        pred_binary = (prediction == 3) * 1
                        from medpy.metric.binary import hd, dc, assd
                        lv_center = lv_center[0]
                        # i=0;  plt.imshow(network_input[0, 0]); plt.plot(lv_center[1], lv_center[0], 'ro'); plt.show(); plt.imshow(unet_mask); plt.plot(lv_center[1], lv_center[0], 'ro'); plt.show(); plt.imshow(logit[0, 0]); plt.plot(lv_center[1], lv_center[0], 'ro'); plt.show(); plt.imshow(lv_logit[0]); plt.show();  plt.imshow(lv_gs[0]); plt.show(); plt.imshow(prediction_cropped); plt.plot(lv_center[1], lv_center[0], 'r.'); plt.show();

                    prediction_arr = np.transpose(
                        np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

                    # This is the same for 2D and 3D again
                    if do_postprocessing:
                        prediction_arr = image_utils.keep_largest_connected_components(
                            prediction_arr)

                    elapsed_time = time.time() - start_time
                    total_time += elapsed_time
                    total_volumes += 1

                    logging.info('Evaluation of volume took %f secs.' %
                                 elapsed_time)

                    if frame == ED_frame:
                        frame_suffix = '_ED'
                    elif frame == ES_frame:
                        frame_suffix = '_ES'
                    else:
                        raise ValueError(
                            "Frame doesn't correspond to ED or ES. frame = %d, ED = %d, ES = %d"
                            % (frame, ED_frame, ES_frame))

                    # Save predicted mask
                    out_file_name = os.path.join(
                        output_folder, 'prediction',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    if gt_exists:
                        out_affine = mask_dat[1]
                        out_header = mask_dat[2]
                    else:
                        out_affine = img_dat[1]
                        out_header = img_dat[2]

                    logging.info('saving to: %s' % out_file_name)
                    utils.save_nii(out_file_name, prediction_arr, out_affine,
                                   out_header)

                    # Save image data to the same folder for convenience
                    image_file_name = os.path.join(
                        output_folder, 'image',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % image_file_name)
                    utils.save_nii(image_file_name, img_dat[0], out_affine,
                                   out_header)

                    if gt_exists:
                        # Save GT image
                        gt_file_name = os.path.join(
                            output_folder, 'ground_truth',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % gt_file_name)
                        utils.save_nii(gt_file_name, mask, out_affine,
                                       out_header)

                        # Save difference mask between predictions and ground truth
                        # Binary mask of voxels where prediction and GT disagree
                        difference_mask = np.asarray(prediction_arr != mask,
                                                     dtype=np.uint8)
                        diff_file_name = os.path.join(
                            output_folder, 'difference',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % diff_file_name)
                        utils.save_nii(diff_file_name, difference_mask,
                                       out_affine, out_header)

    logging.info('Average time per volume: %f' % (total_time / total_volumes))

    return None
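The loop above builds `gt_binary` and `pred_binary` and imports medpy's metrics without ever showing the call site; a minimal sketch of how those metrics are typically evaluated per slice (the function name and the empty-mask guard are assumptions, not part of the original snippet):

from medpy.metric.binary import assd, dc, hd
import numpy as np

def lv_slice_metrics(pred_binary, gt_binary, pixel_size):
    """Dice, Hausdorff and average symmetric surface distance for one slice."""
    dice = dc(pred_binary, gt_binary)
    if pred_binary.any() and gt_binary.any():
        hausdorff = hd(pred_binary, gt_binary, voxelspacing=pixel_size)
        surf_dist = assd(pred_binary, gt_binary, voxelspacing=pixel_size)
    else:
        # distance metrics are undefined when either mask is empty
        hausdorff, surf_dist = np.nan, np.nan
    return dice, hausdorff, surf_dist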
Example #21
0
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from sklearn import cluster
from sklearn.cluster import KMeans
import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23

import utils  # project-local helpers (parse_command, ...)
from network import UNet
# KittiFolder (used below) comes from a project-specific dataloader module
# that the snippet does not show.

cmap = plt.cm.jet
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use a single GPU

args = utils.parse_command()
print(args)
model = UNet(3, 1)
model = nn.DataParallel(model, device_ids=[0])
model.cuda()
# if a GPU id is set above, run in single-GPU mode
print('Single GPU Mode.')


def create_loader(args):
    root_dir = ''
    test_set = KittiFolder(root_dir, mode='test', size=(256, 512))
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    return test_loader
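
A hedged sketch of how the loader and DataParallel model above are typically used at test time; the (rgb, target) batch structure and the output shape are assumptions about KittiFolder rather than anything the snippet shows:

test_loader = create_loader(args)
model.eval()
with torch.no_grad():
    for i, (rgb, target) in enumerate(test_loader):
        rgb = rgb.cuda(non_blocking=True)
        pred = model(rgb)                      # assumed (1, 1, 256, 512) map
        depth = pred.squeeze().cpu().numpy()
        plt.imsave('pred_%04d.png' % i, depth, cmap=cmap)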
Example #22
0
# NOTE: the snippet iterates over `train_loader` below but its definition is
# not shown; this reconstruction mirrors the loaders that are shown, and the
# train directory names and args.batch_size are assumptions.
train_loader = data.DataLoader(
    dataset=DataFolder('dataset/train_images_256/', 'dataset/train_masks_256/', 'train'),
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=2
)

valid_loader = data.DataLoader(
    dataset=DataFolder('dataset/valid_images_256/', 'dataset/valid_masks_256/', 'validation'),
    batch_size=args.eval_batch_size,
    shuffle=False,
    num_workers=2
)

eval_loader = data.DataLoader(
    dataset=DataFolder('dataset/eval_images_256/', 'dataset/eval_masks_256/', 'evaluate'),
    batch_size=args.eval_batch_size,
    shuffle=False,
    num_workers=2
)

model = UNet(1, shrink=1).cuda()
nets = [model]
params = [{'params': net.parameters()} for net in nets]
solver = optim.Adam(params, lr=args.lr)

criterion = nn.CrossEntropyLoss()
es = EarlyStopping(min_delta=args.min_delta, patience=args.patience)

for epoch in range(1, args.epochs+1):

    train_loss = []
    valid_loss = []

    for batch_idx, (img, mask, _) in enumerate(train_loader):

        solver.zero_grad()
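        # Hypothetical continuation of the truncated snippet: the forward /
        # backward step implied by the unpacking above. The mask dtype and
        # EarlyStopping's step() API are assumptions.
        img, mask = img.cuda(), mask.cuda().long()
        logits = model(img)
        loss = criterion(logits, mask)  # CrossEntropyLoss expects class indices
        loss.backward()
        solver.step()
        train_loss.append(loss.item())

    # Hedged validation pass + early stopping at the end of each epoch
    model.eval()
    with torch.no_grad():
        for img, mask, _ in valid_loader:
            valid_loss.append(
                criterion(model(img.cuda()), mask.cuda().long()).item())
    model.train()
    if es.step(sum(valid_loss) / len(valid_loss)):
        break  # early stopping triggered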
Example #23
0
def score_data(input_folder,
               output_folder,
               model_path,
               num_classes=3,
               do_postprocessing=False,
               gt_exists=True,
               evaluate_all=False):

    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    num_channels = num_classes + 1

    net = UNet(in_dim=1, out_dim=num_classes + 1).cuda()
    ckpt_path = os.path.join(model_path, 'best_model.pth.tar')
    with open(ckpt_path, 'rb') as f:
        net.load_state_dict(_pickle.load(f))

    evaluate_test_set = not gt_exists

    total_time = 0
    total_volumes = 0

    for folder in os.listdir(input_folder):

        folder_path = os.path.join(input_folder, folder)

        if os.path.isdir(folder_path):

            if evaluate_test_set or evaluate_all:
                train_test = 'test'  # always test
            else:
                train_test = 'test' if (int(folder[-3:]) % 5 == 0) else 'train'

            if train_test == 'test':

                infos = {}
                with open(os.path.join(folder_path, 'Info.cfg')) as f:
                    for line in f:
                        label, value = line.split(':')
                        infos[label] = value.rstrip('\n').lstrip(' ')
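                # A typical ACDC Info.cfg holds entries such as
                # (illustrative values):
                #   ED: 1
                #   ES: 12
                #   Group: DCM
                #   Height: 184.0
                #   NbFrame: 30
                #   Weight: 95.0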

                patient_id = folder[len('patient'):]  # lstrip('patient') would strip a char *set*, not the prefix
                ED_frame = int(infos['ED'])
                ES_frame = int(infos['ES'])

                for file in glob.glob(
                        os.path.join(folder_path,
                                     'patient???_frame??.nii.gz')):

                    logging.info(
                        ' ----- Doing image: -------------------------')
                    logging.info('Doing: %s' % file)
                    logging.info(
                        ' --------------------------------------------')

                    file_base = file.split('.nii.gz')[0]

                    frame = int(file_base.split('frame')[-1])
                    img_dat = utils.load_nii(file)
                    img = img_dat[0].copy()
                    img = image_utils.normalise_image(img)

                    if gt_exists:
                        file_mask = file_base + '_gt.nii.gz'
                        mask_dat = utils.load_nii(file_mask)
                        mask = mask_dat[0]

                    start_time = time.time()

                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2])
                    scale_vector = (pixel_size[0] /
                                    exp_config.target_resolution[0],
                                    pixel_size[1] /
                                    exp_config.target_resolution[1])

                    predictions = []

                    for zz in range(img.shape[2]):

                        slice_img = np.squeeze(img[:, :, zz])
                        slice_rescaled = transform.rescale(slice_img,
                                                           scale_vector,
                                                           order=1,
                                                           preserve_range=True,
                                                           multichannel=False,
                                                           mode='constant')

                        x, y = slice_rescaled.shape

                        x_s = (x - nx) // 2
                        y_s = (y - ny) // 2
                        x_c = (nx - x) // 2
                        y_c = (ny - y) // 2

                        # Crop section of image for prediction
                        if x > nx and y > ny:
                            slice_cropped = slice_rescaled[x_s:x_s + nx,
                                                           y_s:y_s + ny]
                        else:
                            slice_cropped = np.zeros((nx, ny))
                            if x <= nx and y > ny:
                                slice_cropped[x_c:x_c +
                                              x, :] = slice_rescaled[:,
                                                                     y_s:y_s +
                                                                     ny]
                            elif x > nx and y <= ny:
                                slice_cropped[:, y_c:y_c +
                                              y] = slice_rescaled[x_s:x_s +
                                                                  nx, :]
                            else:
                                slice_cropped[x_c:x_c + x, y_c:y_c +
                                              y] = slice_rescaled[:, :]

                        # GET PREDICTION
                        network_input = np.float32(
                            np.tile(np.reshape(slice_cropped, (nx, ny, 1)),
                                    (batch_size, 1, 1, 1)))
                        network_input = np.transpose(network_input,
                                                     [0, 3, 1, 2])
                        network_input = torch.from_numpy(network_input).cuda()
                        with torch.no_grad():
                            net.eval()
                            logits_out = net(network_input)
                            softmax_out = F.softmax(logits_out, dim=1)
                            # mask_out = torch.argmax(logits_out, dim=1)
                            softmax_out = softmax_out.data.cpu().numpy()
                            softmax_out = np.transpose(softmax_out,
                                                       [0, 2, 3, 1])
                        # prediction_cropped = np.squeeze(softmax_out[0,...])
                        prediction_cropped = np.squeeze(softmax_out)

                        # ASSEMBLE BACK THE SLICES
                        slice_predictions = np.zeros((x, y, num_channels))
                        # insert cropped region into original image again
                        if x > nx and y > ny:
                            slice_predictions[x_s:x_s + nx, y_s:y_s +
                                              ny, :] = prediction_cropped
                        else:
                            if x <= nx and y > ny:
                                slice_predictions[:, y_s:y_s +
                                                  ny, :] = prediction_cropped[
                                                      x_c:x_c + x, :, :]
                            elif x > nx and y <= ny:
                                slice_predictions[
                                    x_s:x_s +
                                    nx, :, :] = prediction_cropped[:, y_c:y_c +
                                                                   y, :]
                            else:
                                slice_predictions[:, :, :] = prediction_cropped[
                                    x_c:x_c + x, y_c:y_c + y, :]

                        # RESCALE THE SOFTMAX OUTPUT BACK TO THE ORIGINAL RESOLUTION
                        if gt_exists:
                            prediction = transform.resize(
                                slice_predictions,
                                (mask.shape[0], mask.shape[1], num_channels),
                                order=1,
                                preserve_range=True,
                                mode='constant')
                        else:  # rescale() can occasionally produce a wrong
                            # volume size, which is why the gt mask shape is
                            # used for resizing whenever gt_exists.
                            prediction = transform.rescale(
                                slice_predictions, (1.0 / scale_vector[0],
                                                    1.0 / scale_vector[1], 1),
                                order=1,
                                preserve_range=True,
                                multichannel=False,
                                mode='constant')

                        prediction = np.uint8(np.argmax(prediction, axis=-1))
                        # Map predicted foreground classes onto the label
                        # values used by the ACDC ground truth (LV = 3).
                        if num_classes == 1:
                            prediction[prediction == 1] = 3
                        elif num_classes == 2:
                            prediction[prediction == 2] = 3
                            prediction[prediction == 1] = 2
                        predictions.append(prediction)

                    prediction_arr = np.transpose(
                        np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

                    # This is the same for 2D and 3D again
                    if do_postprocessing:
                        # Single-class case only: locate the ground-truth LV
                        # centroid and keep the largest connected component of
                        # the prediction, using the centroid as reference.
                        assert num_classes == 1
                        from skimage.measure import regionprops
                        lv_obj = (mask_dat[0] == 3).astype(np.uint8)
                        prop = regionprops(lv_obj)
                        assert len(prop) == 1
                        prop = prop[0]
                        centroid = tuple(int(c) for c in prop.centroid)
                        prediction_arr = image_utils.keep_largest_connected_components(
                            prediction_arr, centroid)

                    elapsed_time = time.time() - start_time
                    total_time += elapsed_time
                    total_volumes += 1

                    logging.info('Evaluation of volume took %f secs.' %
                                 elapsed_time)

                    if frame == ED_frame:
                        frame_suffix = '_ED'
                    elif frame == ES_frame:
                        frame_suffix = '_ES'
                    else:
                        raise ValueError(
                            "Frame doesn't correspond to ED or ES. frame = %d, ED = %d, ES = %d"
                            % (frame, ED_frame, ES_frame))

                    # Save predicted mask
                    out_file_name = os.path.join(
                        output_folder, 'prediction',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    if gt_exists:
                        out_affine = mask_dat[1]
                        out_header = mask_dat[2]
                    else:
                        out_affine = img_dat[1]
                        out_header = img_dat[2]

                    logging.info('saving to: %s' % out_file_name)
                    utils.save_nii(out_file_name, prediction_arr, out_affine,
                                   out_header)

                    # Save image data to the same folder for convenience
                    image_file_name = os.path.join(
                        output_folder, 'image',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % image_file_name)
                    utils.save_nii(image_file_name, img_dat[0], out_affine,
                                   out_header)

                    if gt_exists:

                        # Save GT image
                        gt_file_name = os.path.join(
                            output_folder, 'ground_truth',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % gt_file_name)
                        utils.save_nii(gt_file_name, mask, out_affine,
                                       out_header)

                        # Save difference mask between predictions and ground truth
                        # Binary mask of voxels where prediction and GT disagree
                        difference_mask = np.asarray(prediction_arr != mask,
                                                     dtype=np.uint8)
                        diff_file_name = os.path.join(
                            output_folder, 'difference',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % diff_file_name)
                        utils.save_nii(diff_file_name, difference_mask,
                                       out_affine, out_header)

    logging.info('Average time per volume: %f' % (total_time / total_volumes))

    return None
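
Both variants above call `image_utils.keep_largest_connected_components` without showing it; a minimal sketch of what such a post-processing helper usually does (the real implementation, and how it uses the optional centroid argument, may differ):

import numpy as np
from skimage import measure

def keep_largest_connected_components(segmentation):
    """Keep, per foreground label, only the largest 3-D connected component."""
    out = np.zeros_like(segmentation)
    for lab in np.unique(segmentation):
        if lab == 0:
            continue  # background
        components = measure.label(segmentation == lab)
        if components.max() == 0:
            continue  # label absent from this volume
        sizes = np.bincount(components.ravel())
        sizes[0] = 0  # ignore the background component
        out[components == sizes.argmax()] = lab
    return out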
Example #24
0
# NOTE: the opening of this snippet (imports of os, subprocess and numpy,
# plus the `if` guarding GPU selection) did not survive extraction; the
# condition below is a hypothetical reconstruction.
if ARGS.use_gpu:  # hypothetical flag
    # pick the GPU with the most free memory, as reported by nvidia-smi
    free_mem = [
        int(x.split()[2])
        for x in subprocess.Popen(
            "nvidia-smi -q -d Memory | grep -A4 GPU | grep Free",
            shell=True,
            stdout=subprocess.PIPE).stdout.readlines()
    ]
    os.environ["CUDA_VISIBLE_DEVICES"] = str(np.argmax(free_mem))
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = ''


test_names = sorted(glob(ARGS.test_dir + "/*png"))
print('Data load succeed!')

# set up the model and define the graph
with tf.variable_scope(tf.get_variable_scope()):
    input = tf.placeholder(tf.float32, shape=[None, None, None, 5])
    reflection = tf.placeholder(tf.float32, shape=[None, None, None, 5])
    target = tf.placeholder(tf.float32, shape=[None, None, None, 5])
    overexp_mask = utils.tf_overexp_mask(input)
    tf_input, tf_reflection, tf_target, real_input = utils.prepare_real_input(
        input, target, reflection, overexp_mask, ARGS)
    reflection_layer = UNet(real_input, ext='Ref_')
    # the transmission branch sees the input together with the predicted reflection
    transmission_layer = UNet(tf.concat([real_input, reflection_layer], axis=3),
                              ext='Tran_')

######### Session #########
saver = tf.train.Saver(max_to_keep=10)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
var_restore = [v for v in tf.trainable_variables()]
saver_restore = tf.train.Saver(var_restore)
print("Listing trainable variables ... ")
for var in tf.trainable_variables():
    print(var)
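
The graph and `saver_restore` above are built, but no weights are loaded in the visible portion; a hedged continuation (the checkpoint directory is an assumption):

ckpt = tf.train.get_checkpoint_state('./checkpoints')  # assumed path
if ckpt and ckpt.model_checkpoint_path:
    print('Restoring %s' % ckpt.model_checkpoint_path)
    saver_restore.restore(sess, ckpt.model_checkpoint_path)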
Example #25
0
        print(f'Dice loss in step {step} is {dice_loss}')
        for i in range(len(label)):
            mask = pred[i, 0, :, :].data.cpu().numpy()
            mask = np.where(mask > 0.4, 1, 0)
            name = batch['name'][i]
            img = sitk.ReadImage(os.path.join(img_path, train_phrase, name))
            img = sitk.GetArrayFromImage(img)
            display(name.split('.')[0], mask, mask)
            # display(name.split('.')[0], mask, img)

    aver_dice = aver_dice / len(dataloader)
    print(f'average dice is {aver_dice}.')


if __name__ == "__main__":
    dataset = EmbDataset(train_phrase='train')
    channels_in = len(dataset.model_set) + 1
    dataloader = DataLoader(dataset, batch_size=3, shuffle=False)
    state = torch.load(
        '/home/zhangqianru/data/ly/ckpt_folder/retrain_2/epoch4.pth')
    epoch = state['epoch']
    print(f'Load epoch {epoch}.')
    net = UNet(channels_in, 1)
    net.load_state_dict(state['net'])
    eval_net(net,
             dataloader,
             dataset.train_phrase,
             save_path='/home/zhangqianru/')
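
A hedged sketch of the per-batch Dice computation implied by `dice_loss` and `aver_dice` above (the original helper is not shown; the 0.4 threshold mirrors the binarisation used in the loop):

import torch

def dice_coeff(pred, target, eps=1e-6):
    """Dice between a thresholded sigmoid output and a binary target."""
    pred = (pred > 0.4).float()
    inter = (pred * target).sum()
    return (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)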
Example #26
0
from data_loader import loaders
from train import my_train
from eval import my_eval
from visualize import my_vis
from app import my_app
import numpy as np
import torch.optim as optim  # needed for optim.Adam below
import ray
# UNet (used below) comes from a project-specific module not shown here.

# data loading
batch_size = 2
loader_tr = loaders(batch_size, 0)
loader_vl = loaders(batch_size, 1)

# networks
output = UNet()
output.cuda()

# optimizer
optimizer = optim.Adam(output.parameters(), lr=3e-5, weight_decay=1e-4)

# training
metric_values, metric1_values, val_metric_values, val_metric1_values, \
    epoch_values, loss_values, val_loss_values = ([] for _ in range(7))

no_of_epochs = 1000
no_of_batches = len(loader_tr)
no_of_batches_1 = len(loader_vl)
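
The snippet is cut off after this bookkeeping; a hedged sketch of the epoch loop it prepares for (in the original, the imported my_train/my_eval helpers presumably do this work; the loss choice and batch structure below are assumptions):

import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()  # assumed loss for a 1-channel UNet output
for epoch in range(1, no_of_epochs + 1):
    output.train()
    running = 0.0
    for img, mask in loader_tr:  # assumed (image, mask) batches
        img, mask = img.cuda(), mask.cuda()
        optimizer.zero_grad()
        loss = criterion(output(img), mask)
        loss.backward()
        optimizer.step()
        running += loss.item()
    epoch_values.append(epoch)
    loss_values.append(running / no_of_batches)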