Example #1
File: denoise.py Project: tbepler/topaz
def main(args):

    # set the number of threads
    num_threads = args.num_threads
    from topaz.torch import set_num_threads
    set_num_threads(num_threads)

    ## set the device
    use_cuda = topaz.cuda.set_device(args.device)
    print('# using device={} with cuda={}'.format(args.device, use_cuda),
          file=sys.stderr)

    cutoff = args.pixel_cutoff  # pixel truncation limit

    do_train = (args.dir_a is not None
                and args.dir_b is not None) or (args.hdf is not None)
    if do_train:

        method = args.method
        paired = (method == 'noise2noise')
        preload = args.preload
        holdout = args.holdout  # fraction of image pairs to holdout for validation

        if args.hdf is None:  # use dir_a/dir_b directories
            crop = args.crop
            dir_as = args.dir_a
            dir_bs = args.dir_b

            dset_train = []
            dset_val = []

            for dir_a, dir_b in zip(dir_as, dir_bs):
                random = np.random.RandomState(44444)  # fixed seed for a reproducible train/val split
                if paired:
                    dataset_train, dataset_val = make_paired_images_datasets(
                        dir_a,
                        dir_b,
                        crop,
                        random=random,
                        holdout=holdout,
                        preload=preload,
                        cutoff=cutoff)
                else:
                    dataset_train, dataset_val = make_images_datasets(
                        dir_a,
                        dir_b,
                        crop,
                        cutoff=cutoff,
                        random=random,
                        holdout=holdout)
                dset_train.append(dataset_train)
                dset_val.append(dataset_val)

            dataset_train = dset_train[0]
            for i in range(1, len(dset_train)):
                dataset_train.x += dset_train[i].x
                if paired:
                    dataset_train.y += dset_train[i].y

            dataset_val = dset_val[0]
            for i in range(1, len(dset_val)):
                dataset_val.x += dset_val[i].x
                if paired:
                    dataset_val.y += dset_val[i].y

            shuffle = True
        else:  # make HDF datasets
            dataset_train, dataset_val = make_hdf5_datasets(args.hdf,
                                                            paired=paired,
                                                            cutoff=cutoff,
                                                            holdout=holdout,
                                                            preload=preload)
            shuffle = preload  # random access into HDF5 is slow, so only shuffle when preloaded into memory

        # initialize the model
        arch = args.arch
        if arch == 'unet':
            model = dn.UDenoiseNet()
        elif arch == 'unet-small':
            model = dn.UDenoiseNetSmall()
        elif arch == 'unet2':
            model = dn.UDenoiseNet2()
        elif arch == 'unet3':
            model = dn.UDenoiseNet3()
        elif arch == 'fcnet':
            model = dn.DenoiseNet(32)
        elif arch == 'fcnet2':
            model = dn.DenoiseNet2(64)
        elif arch == 'affine':
            model = dn.AffineDenoise()
        else:
            raise Exception('Unknown architecture: ' + arch)

        if use_cuda:
            model = model.cuda()

        # train
        optim = args.optim
        lr = args.lr
        batch_size = args.batch_size
        num_epochs = args.num_epochs
        digits = int(np.ceil(np.log10(num_epochs + 1)))  # zero-pad width; +1 handles exact powers of ten

        num_workers = args.num_workers

        print('epoch', 'loss_train', 'loss_val')
        #criteria = nn.L1Loss()
        criteria = args.criteria

        if method == 'noise2noise':
            iterator = dn.train_noise2noise(model,
                                            dataset_train,
                                            lr=lr,
                                            optim=optim,
                                            batch_size=batch_size,
                                            criteria=criteria,
                                            num_epochs=num_epochs,
                                            dataset_val=dataset_val,
                                            use_cuda=use_cuda,
                                            num_workers=num_workers,
                                            shuffle=shuffle)
        elif method == 'masked':
            iterator = dn.train_mask_denoise(model,
                                             dataset_train,
                                             lr=lr,
                                             optim=optim,
                                             batch_size=batch_size,
                                             criteria=criteria,
                                             num_epochs=num_epochs,
                                             dataset_val=dataset_val,
                                             use_cuda=use_cuda,
                                             num_workers=num_workers,
                                             shuffle=shuffle)
        else:
            raise Exception('Unknown method: ' + method)  # otherwise iterator would be undefined below

        for epoch, loss_train, loss_val in iterator:
            print(epoch, loss_train, loss_val)
            sys.stdout.flush()

            # save the model
            if args.save_prefix is not None:
                path = args.save_prefix + ('_epoch{:0' + str(digits) +
                                           '}.sav').format(epoch)
                #path = args.save_prefix + '_epoch{}.sav'.format(epoch)
                model.cpu()
                model.eval()
                torch.save(model, path)
                if use_cuda:
                    model.cuda()

        models = [model]

    else:  # load the saved model(s)
        models = []
        for arg in args.model:
            if arg == 'none':
                print('# Warning: no denoising model will be used',
                      file=sys.stderr)
            else:
                print('# Loading model:', arg, file=sys.stderr)
            model = dn.load_model(arg)

            model.eval()
            if use_cuda:
                model.cuda()

            models.append(model)

    # using trained model
    # denoise the images

    normalize = args.normalize
    if args.format_ in ('png', 'jpg'):
        # always normalize png and jpg format
        normalize = True

    format_ = args.format_
    suffix = args.suffix

    lowpass = args.lowpass
    gaus = args.gaussian
    if gaus > 0:
        gaus = dn.GaussianDenoise(gaus)
        if use_cuda:
            gaus.cuda()
    else:
        gaus = None
    inv_gaus = args.inv_gaussian
    if inv_gaus > 0:
        inv_gaus = dn.InvGaussianFilter(inv_gaus)
        if use_cuda:
            inv_gaus.cuda()
    else:
        inv_gaus = None
    deconvolve = args.deconvolve
    deconv_patch = args.deconv_patch

    ps = args.patch_size
    padding = args.patch_padding

    count = 0

    # we are denoising a single MRC stack
    if args.stack:
        with open(args.micrographs[0], 'rb') as f:
            content = f.read()
        stack, _, _ = mrc.parse(content)
        print('# denoising stack with shape:', stack.shape, file=sys.stderr)
        total = len(stack)

        denoised = np.zeros_like(stack)
        for i in range(len(stack)):
            mic = stack[i]
            # process and denoise the micrograph
            mic = denoise_image(mic,
                                models,
                                lowpass=lowpass,
                                cutoff=cutoff,
                                gaus=gaus,
                                inv_gaus=inv_gaus,
                                deconvolve=deconvolve,
                                deconv_patch=deconv_patch,
                                patch_size=ps,
                                padding=padding,
                                normalize=normalize,
                                use_cuda=use_cuda)
            denoised[i] = mic

            count += 1
            print('# {} of {} completed.'.format(count, total),
                  file=sys.stderr,
                  end='\r')

        print('', file=sys.stderr)
        # write the denoised stack
        path = args.output
        print('# writing', path, file=sys.stderr)
        with open(path, 'wb') as f:
            mrc.write(f, denoised)

    else:
        # stream the micrographs and denoise them
        total = len(args.micrographs)

        # make the output directory if it doesn't exist
        if not os.path.exists(args.output):
            os.makedirs(args.output)

        for path in args.micrographs:
            name, _ = os.path.splitext(os.path.basename(path))
            mic = np.array(load_image(path), copy=False).astype(np.float32)

            # process and denoise the micrograph
            mic = denoise_image(mic,
                                models,
                                lowpass=lowpass,
                                cutoff=cutoff,
                                gaus=gaus,
                                inv_gaus=inv_gaus,
                                deconvolve=deconvolve,
                                deconv_patch=deconv_patch,
                                patch_size=ps,
                                padding=padding,
                                normalize=normalize,
                                use_cuda=use_cuda)

            # write the micrograph
            if not args.output:
                if suffix == '' or suffix is None:
                    suffix = '.denoised'
                # write the file to the same location as input
                no_ext, ext = os.path.splitext(path)
                outpath = no_ext + suffix + '.' + format_
            else:
                outpath = args.output + os.sep + name + suffix + '.' + format_
            save_image(mic, outpath)  #, mi=None, ma=None)

            count += 1
            print('# {} of {} completed.'.format(count, total),
                  file=sys.stderr,
                  end='\r')
        print('', file=sys.stderr)
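
A minimal sketch of driving the inference path of main() above with a pretrained model. This invocation is hypothetical, not part of the source: the attribute names mirror how args is read in the function, the values are illustrative, and the module-level imports of denoise.py are assumed to be in place.

import types

# Hypothetical stand-in for the argparse result; every attribute is one
# that main() reads on the inference (non-training) path.
args = types.SimpleNamespace(
    num_threads=0,            # assumption: non-positive means "use all cores"
    device=-1,                # CPU; >= 0 would select a CUDA device
    pixel_cutoff=0,
    dir_a=None, dir_b=None, hdf=None,  # no training data -> load saved model(s)
    model=['unet'],           # assumed name of a bundled pretrained model
    normalize=False, format_='mrc', suffix='',
    lowpass=1, gaussian=0, inv_gaussian=0,
    deconvolve=False, deconv_patch=1,
    patch_size=1024, patch_padding=500,
    stack=False,
    micrographs=['mic_0001.mrc'],      # illustrative input path
    output='denoised',                 # created by main() if it does not exist
)
main(args)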
Example #2
def denoise(model, path, outdir, patch_size=128, padding=128, batch_size=1):
    with open(path, 'rb') as f:
        content = f.read()
    tomo, header, _ = mrc.parse(content)
    tomo = tomo.astype(np.float32)  # work in float32; the header is rewritten as mode 2 below
    name = os.path.basename(path)

    mu = tomo.mean()
    std = tomo.std()

    # denoise in patches
    d = next(iter(model.parameters())).device
    denoised = np.zeros_like(tomo)

    with torch.no_grad():
        if patch_size < 1:
            x = (tomo - mu) / std
            x = torch.from_numpy(x).to(d)
            x = model(x.unsqueeze(0).unsqueeze(0)).squeeze().cpu().numpy()
            x = std * x + mu
            denoised[:] = x
        else:
            patch_data = PatchDataset(tomo, patch_size, padding)
            total = len(patch_data)
            count = 0

            batch_iterator = torch.utils.data.DataLoader(patch_data,
                                                         batch_size=batch_size)
            for index, x in batch_iterator:
                x = x.to(d)
                x = (x - mu) / std
                x = x.unsqueeze(1)  # batch x channel

                # denoise
                x = model(x).squeeze(1).cpu().numpy()

                # restore original statistics
                x = std * x + mu

                # stitch into denoised volume
                for b in range(len(x)):
                    i, j, k = index[b]
                    xb = x[b]

                    patch = denoised[i:i + patch_size, j:j + patch_size,
                                     k:k + patch_size]
                    pz, py, px = patch.shape

                    xb = xb[padding:padding + pz, padding:padding + py,
                            padding:padding + px]
                    denoised[i:i + patch_size, j:j + patch_size,
                             k:k + patch_size] = xb

                    count += 1
                    print('# [{}/{}] {:.2%}'.format(count, total, count / total),
                          name,
                          file=sys.stderr,
                          end='\r')

            print(' ' * 100, file=sys.stderr, end='\r')


    ## save the denoised tomogram
    outpath = outdir + os.sep + name

    # use the read header except for a few fields
    header = header._replace(mode=2) # 32-bit real
    header = header._replace(amin=denoised.min())
    header = header._replace(amax=denoised.max())
    header = header._replace(amean=denoised.mean())

    with open(outpath, 'wb') as f:
        mrc.write(f, denoised, header=header)
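
PatchDataset is referenced above but not shown. A minimal sketch of the interface the loop consumes might look like the following; this is an assumption reconstructed from usage (corner indices into the output volume, plus padding voxels of context on every side), not the actual topaz implementation, and assumes numpy is imported as np.

class PatchDataset:
    # Yields (corner, patch) pairs: corner = (i, j, k) indexes the output
    # volume, and patch is the (patch_size + 2*padding)^3 cube of context
    # around it, zero-padded at the volume borders.
    def __init__(self, tomo, patch_size, padding):
        self.tomo = tomo
        self.patch_size = patch_size
        self.padding = padding
        nz, ny, nx = tomo.shape
        # corner coordinates of each non-overlapping output patch
        self.corners = [(i, j, k)
                        for i in range(0, nz, patch_size)
                        for j in range(0, ny, patch_size)
                        for k in range(0, nx, patch_size)]

    def __len__(self):
        return len(self.corners)

    def __getitem__(self, idx):
        i, j, k = self.corners[idx]
        p, s = self.padding, self.patch_size
        size = s + 2 * p
        patch = np.zeros((size, size, size), dtype=np.float32)
        # source region, clipped to the volume bounds
        zlo, zhi = max(0, i - p), min(self.tomo.shape[0], i + s + p)
        ylo, yhi = max(0, j - p), min(self.tomo.shape[1], j + s + p)
        xlo, xhi = max(0, k - p), min(self.tomo.shape[2], k + s + p)
        patch[zlo - (i - p):zhi - (i - p),
              ylo - (j - p):yhi - (j - p),
              xlo - (k - p):xhi - (k - p)] = self.tomo[zlo:zhi, ylo:yhi, xlo:xhi]
        return np.array((i, j, k)), patch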
Example #3
File: image.py Project: zruan/topaz
def save_mrc(x, path):
    with open(path, 'wb') as f:
        x = x[np.newaxis] # need to add z-axis for mrc write
        mrc.write(f, x)
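
A brief usage sketch with illustrative values: save_mrc expects a 2D array and prepends the z-axis itself, so a single image writes as a one-section volume.

img = np.random.randn(512, 512).astype(np.float32)  # placeholder 2D image
save_mrc(img, 'example.mrc')                        # written as a 1 x 512 x 512 volume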
Example #4
File: denoise3d.py Project: zruan/topaz
def denoise(model,
            path,
            outdir,
            suffix,
            patch_size=128,
            padding=128,
            batch_size=1,
            volume_num=1,
            total_volumes=1):
    with open(path, 'rb') as f:
        content = f.read()
    tomo, header, extended_header = mrc.parse(content)
    tomo = tomo.astype(np.float32)
    name = os.path.basename(path)

    mu = tomo.mean()
    std = tomo.std()
    # denoise in patches
    d = next(iter(model.parameters())).device
    denoised = np.zeros_like(tomo)

    with torch.no_grad():
        if patch_size < 1:
            x = (tomo - mu) / std
            x = torch.from_numpy(x).to(d)
            x = model(x.unsqueeze(0).unsqueeze(0)).squeeze().cpu().numpy()
            x = std * x + mu
            denoised[:] = x
        else:
            patch_data = PatchDataset(tomo, patch_size, padding)
            total = len(patch_data)
            count = 0

            batch_iterator = torch.utils.data.DataLoader(patch_data,
                                                         batch_size=batch_size)
            for index, x in batch_iterator:
                x = x.to(d)
                x = (x - mu) / std
                x = x.unsqueeze(1)  # batch x channel

                # denoise
                x = model(x)
                x = x.squeeze(1).cpu().numpy()

                # restore original statistics
                x = std * x + mu

                # stitch into denoised volume
                for b in range(len(x)):
                    i, j, k = index[b]
                    xb = x[b]

                    patch = denoised[i:i + patch_size, j:j + patch_size,
                                     k:k + patch_size]
                    pz, py, px = patch.shape

                    xb = xb[padding:padding + pz, padding:padding + py,
                            padding:padding + px]
                    denoised[i:i + patch_size, j:j + patch_size,
                             k:k + patch_size] = xb

                    count += 1
                    print('# [{}/{}] {:.2%}'.format(volume_num, total_volumes,
                                                    count / total),
                          name,
                          file=sys.stderr,
                          end='\r')

            print(' ' * 100, file=sys.stderr, end='\r')

    ## save the denoised tomogram
    if outdir is None:
        # write denoised tomogram to same location as input, but add the suffix
        if suffix is None:  # use default
            suffix = '.denoised'
        no_ext, ext = os.path.splitext(path)
        outpath = no_ext + suffix + ext
    else:
        if suffix is None:
            suffix = ''
        no_ext, ext = os.path.splitext(name)
        outpath = outdir + os.sep + no_ext + suffix + ext

    # use the read header except for a few fields
    header = header._replace(mode=2)  # 32-bit real
    header = header._replace(amin=denoised.min())
    header = header._replace(amax=denoised.max())
    header = header._replace(amean=denoised.mean())

    with open(outpath, 'wb') as f:
        mrc.write(f, denoised, header=header, extended_header=extended_header)
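
A hypothetical call to denoise() above; paths and sizes are illustrative, and note the output directory must already exist, since the function does not create it.

model = torch.load('model_epoch10.sav')  # a model saved with torch.save, as in example #1
model.eval()
denoise(model, 'tomo.mrc', outdir='denoised', suffix='.denoised',
        patch_size=96, padding=48, batch_size=2,
        volume_num=1, total_volumes=1)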
Example #5
def main(args):

    ## set the device
    use_cuda = False
    if args.device >= 0:
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.set_device(args.device)
    print('# using device={} with cuda={}'.format(args.device, use_cuda),
          file=sys.stderr)

    do_train = (args.dir_a is not None
                and args.dir_b is not None) or (args.hdf is not None)
    if do_train:
        if args.hdf is None:  # use dir_a/dir_b directories
            crop = args.crop
            dir_a = args.dir_a
            dir_b = args.dir_b
            random = np.random.RandomState(44444)

            dataset_train, dataset_val = make_paired_images_datasets(
                dir_a, dir_b, crop, random=random)
            shuffle = True
        else:  # make HDF datasets
            dataset_train, dataset_val = make_hdf5_datasets(args.hdf)
            shuffle = False

        # initialize the model
        #model = dn.DenoiseNet(32)
        model = dn.UDenoiseNet()
        if use_cuda:
            model = model.cuda()

        # train
        lr = args.lr
        batch_size = args.batch_size
        num_epochs = args.num_epochs

        num_workers = args.num_workers

        print('epoch', 'loss_train', 'loss_val')
        #criteria = nn.L1Loss()
        criteria = args.criteria

        for epoch, loss_train, loss_val in dn.train_noise2noise(
                model,
                dataset_train,
                lr=lr,
                batch_size=batch_size,
                criteria=criteria,
                num_epochs=num_epochs,
                dataset_val=dataset_val,
                use_cuda=use_cuda,
                num_workers=num_workers,
                shuffle=shuffle):
            print(epoch, loss_train, loss_val)
            sys.stdout.flush()

            # save the model
            if args.save_prefix is not None:
                path = args.save_prefix + '_epoch{}.sav'.format(epoch)
                model.cpu()
                model.eval()
                torch.save(model, path)
                if use_cuda:
                    model.cuda()

    else:  # load the saved model
        if args.model in ['L0', 'L1', 'L2']:
            if args.model in ['L0', 'L1']:
                print(
                    'ERROR: L0 and L1 models are not implemented in the current version',
                    file=sys.stderr)
                sys.exit(1)
            model = dn.load_model(args.model)
        else:
            model = torch.load(args.model)
        print('# using model:', args.model, file=sys.stderr)
        model.eval()
        if use_cuda:
            model.cuda()

    if args.stack:
        # we are denoising a single MRC stack
        with open(args.micrographs[0], 'rb') as f:
            content = f.read()
        stack, _, _ = mrc.parse(content)
        print('# denoising stack with shape:', stack.shape, file=sys.stderr)

        denoised = dn.denoise_stack(model, stack, use_cuda=use_cuda)

        # write the denoised stack
        path = args.output
        print('# writing', path, file=sys.stderr)
        with open(path, 'wb') as f:
            mrc.write(f, denoised)

    else:
        # using trained model
        # stream the micrographs and denoise as we go

        normalize = args.normalize
        if args.format_ in ('png', 'jpg'):
            # always normalize png and jpg format
            normalize = True

        format_ = args.format_

        count = 0
        total = len(args.micrographs)

        bin_ = args.bin
        ps = args.patch_size
        padding = args.patch_padding

        # now, stream the micrographs and denoise them
        for path in args.micrographs:
            name, _ = os.path.splitext(os.path.basename(path))
            mic = np.array(load_image(path), copy=False)
            if bin_ > 1:
                mic = downsample(mic, bin_)
            mu = mic.mean()
            std = mic.std()

            # denoise
            mic = (mic - mu) / std
            mic = dn.denoise(model,
                             mic,
                             patch_size=ps,
                             padding=padding,
                             use_cuda=use_cuda)

            if normalize:
                mic = (mic - mic.mean()) / mic.std()
            else:
                # add back std. dev. and mean
                mic = std * mic + mu

            # write the micrograph
            outpath = args.output + os.sep + name + '.' + format_
            save_image(mic, outpath)

            count += 1
            print('# {} of {} completed.'.format(count, total),
                  file=sys.stderr,
                  end='\r')

        print('', file=sys.stderr)
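
The downsample call above (bin_ > 1) is not shown in this excerpt. One simple way to bin a micrograph by an integer factor is mean pooling, sketched below as an illustrative stand-in; the real topaz downsample may be implemented differently.

def bin_micrograph(mic, factor):
    # Mean-pool a 2D micrograph by an integer factor (hypothetical helper,
    # not the topaz implementation). Trailing rows/columns that do not
    # fill a complete bin are discarded.
    ny, nx = mic.shape
    ny, nx = ny - ny % factor, nx - nx % factor
    mic = mic[:ny, :nx]
    return mic.reshape(ny // factor, factor, nx // factor, factor).mean(axis=(1, 3))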