Example 1
def load_mrc(path, standardize=False):
    with open(path, 'rb') as f:
        content = f.read()
    image, header, extended_header = mrc.parse(content)
    if standardize:
        # standardize using the mean and RMS recorded in the MRC header
        image = image - header.amean
        image /= header.rms
    return Image.fromarray(image)
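
A minimal usage sketch for the helper above; the imports and the file path are assumptions not shown in the excerpt, with `mrc` taken to be topaz's MRC parsing module and `Image` PIL's:

import numpy as np
from PIL import Image
from topaz import mrc  # assumed import; adjust to however the excerpt resolves `mrc`

img = load_mrc('micrograph.mrc', standardize=True)  # hypothetical path
arr = np.asarray(img)  # back to a NumPy array if needed
print(arr.shape, arr.dtype, float(arr.mean()), float(arr.std()))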
Example 2
def get(self, *args, **kwargs):
    ext = self.pathspec.format(*args, **kwargs) + '.' + self.format
    path = os.path.join(self.rootdir, ext)
    if self.format == 'mrc':
        with open(path, 'rb') as f:
            content = f.read()
        image, header, extended_header = mrc.parse(content)
        if self.standardize:
            # standardize using the mean and RMS recorded in the MRC header
            image = image - header.amean
            image /= header.rms
    else:
        image = Image.open(path)
        fp = image.fp
        image.load()  # force PIL to read the image data
        fp.close()    # then close the underlying file handle
        image = np.array(image, copy=False)
        if self.standardize:
            image = (image - image.mean())/image.std()
    return Image.fromarray(image)
Example 3
def main(args):

    # set the number of threads
    num_threads = args.num_threads
    from topaz.torch import set_num_threads
    set_num_threads(num_threads)

    ## set the device
    use_cuda = topaz.cuda.set_device(args.device)
    print('# using device={} with cuda={}'.format(args.device, use_cuda),
          file=sys.stderr)

    cutoff = args.pixel_cutoff  # pixel truncation limit

    do_train = (args.dir_a is not None
                and args.dir_b is not None) or (args.hdf is not None)
    if do_train:

        method = args.method
        paired = (method == 'noise2noise')
        preload = args.preload
        holdout = args.holdout  # fraction of image pairs to holdout for validation

        if args.hdf is None:  # use dir_a/dir_b
            crop = args.crop
            dir_as = args.dir_a
            dir_bs = args.dir_b

            dset_train = []
            dset_val = []

            for dir_a, dir_b in zip(dir_as, dir_bs):
                random = np.random.RandomState(44444)
                if paired:
                    dataset_train, dataset_val = make_paired_images_datasets(
                        dir_a,
                        dir_b,
                        crop,
                        random=random,
                        holdout=holdout,
                        preload=preload,
                        cutoff=cutoff)
                else:
                    dataset_train, dataset_val = make_images_datasets(
                        dir_a,
                        dir_b,
                        crop,
                        cutoff=cutoff,
                        random=random,
                        holdout=holdout)
                dset_train.append(dataset_train)
                dset_val.append(dataset_val)

            dataset_train = dset_train[0]
            for i in range(1, len(dset_train)):
                dataset_train.x += dset_train[i].x
                if paired:
                    dataset_train.y += dset_train[i].y

            dataset_val = dset_val[0]
            for i in range(1, len(dset_val)):
                dataset_val.x += dset_val[i].x
                if paired:
                    dataset_val.y += dset_val[i].y

            shuffle = True
        else:  # make HDF datasets
            dataset_train, dataset_val = make_hdf5_datasets(args.hdf,
                                                            paired=paired,
                                                            cutoff=cutoff,
                                                            holdout=holdout,
                                                            preload=preload)
            shuffle = preload

        # initialize the model
        arch = args.arch
        if arch == 'unet':
            model = dn.UDenoiseNet()
        elif arch == 'unet-small':
            model = dn.UDenoiseNetSmall()
        elif arch == 'unet2':
            model = dn.UDenoiseNet2()
        elif arch == 'unet3':
            model = dn.UDenoiseNet3()
        elif arch == 'fcnet':
            model = dn.DenoiseNet(32)
        elif arch == 'fcnet2':
            model = dn.DenoiseNet2(64)
        elif arch == 'affine':
            model = dn.AffineDenoise()
        else:
            raise Exception('Unknown architecture: ' + arch)

        if use_cuda:
            model = model.cuda()

        # train
        optim = args.optim
        lr = args.lr
        batch_size = args.batch_size
        num_epochs = args.num_epochs
        digits = int(np.ceil(np.log10(num_epochs)))

        num_workers = args.num_workers

        print('epoch', 'loss_train', 'loss_val')
        #criteria = nn.L1Loss()
        criteria = args.criteria

        if method == 'noise2noise':
            iterator = dn.train_noise2noise(model,
                                            dataset_train,
                                            lr=lr,
                                            optim=optim,
                                            batch_size=batch_size,
                                            criteria=criteria,
                                            num_epochs=num_epochs,
                                            dataset_val=dataset_val,
                                            use_cuda=use_cuda,
                                            num_workers=num_workers,
                                            shuffle=shuffle)
        elif method == 'masked':
            iterator = dn.train_mask_denoise(model,
                                             dataset_train,
                                             lr=lr,
                                             optim=optim,
                                             batch_size=batch_size,
                                             criteria=criteria,
                                             num_epochs=num_epochs,
                                             dataset_val=dataset_val,
                                             use_cuda=use_cuda,
                                             num_workers=num_workers,
                                             shuffle=shuffle)
        else:
            raise ValueError('unknown method: ' + method)

        for epoch, loss_train, loss_val in iterator:
            print(epoch, loss_train, loss_val)
            sys.stdout.flush()

            # save the model
            if args.save_prefix is not None:
                path = args.save_prefix + ('_epoch{:0' + str(digits) +
                                           '}.sav').format(epoch)
                #path = args.save_prefix + '_epoch{}.sav'.format(epoch)
                model.cpu()
                model.eval()
                torch.save(model, path)
                if use_cuda:
                    model.cuda()

        models = [model]

    else:  # load the saved model(s)
        models = []
        for arg in args.model:
            if arg == 'none':
                print('# Warning: no denoising model will be used',
                      file=sys.stderr)
            else:
                print('# Loading model:', arg, file=sys.stderr)
            model = dn.load_model(arg)

            model.eval()
            if use_cuda:
                model.cuda()

            models.append(model)

    # using trained model
    # denoise the images

    normalize = args.normalize
    if args.format_ == 'png' or args.format_ == 'jpg':
        # always normalize png and jpg format
        normalize = True

    format_ = args.format_
    suffix = args.suffix

    lowpass = args.lowpass
    gaus = args.gaussian
    if gaus > 0:
        gaus = dn.GaussianDenoise(gaus)
        if use_cuda:
            gaus.cuda()
    else:
        gaus = None
    inv_gaus = args.inv_gaussian
    if inv_gaus > 0:
        inv_gaus = dn.InvGaussianFilter(inv_gaus)
        if use_cuda:
            inv_gaus.cuda()
    else:
        inv_gaus = None
    deconvolve = args.deconvolve
    deconv_patch = args.deconv_patch

    ps = args.patch_size
    padding = args.patch_padding

    count = 0

    # we are denoising a single MRC stack
    if args.stack:
        with open(args.micrographs[0], 'rb') as f:
            content = f.read()
        stack, _, _ = mrc.parse(content)
        print('# denoising stack with shape:', stack.shape, file=sys.stderr)
        total = len(stack)

        denoised = np.zeros_like(stack)
        for i in range(len(stack)):
            mic = stack[i]
            # process and denoise the micrograph
            mic = denoise_image(mic,
                                models,
                                lowpass=lowpass,
                                cutoff=cutoff,
                                gaus=gaus,
                                inv_gaus=inv_gaus,
                                deconvolve=deconvolve,
                                deconv_patch=deconv_patch,
                                patch_size=ps,
                                padding=padding,
                                normalize=normalize,
                                use_cuda=use_cuda)
            denoised[i] = mic

            count += 1
            print('# {} of {} completed.'.format(count, total),
                  file=sys.stderr,
                  end='\r')

        print('', file=sys.stderr)
        # write the denoised stack
        path = args.output
        print('# writing', path, file=sys.stderr)
        with open(path, 'wb') as f:
            mrc.write(f, denoised)

    else:
        # stream the micrographs and denoise them
        total = len(args.micrographs)

        # make the output directory if it doesn't exist
        if args.output and not os.path.exists(args.output):
            os.makedirs(args.output)

        for path in args.micrographs:
            name, _ = os.path.splitext(os.path.basename(path))
            mic = np.array(load_image(path), copy=False).astype(np.float32)

            # process and denoise the micrograph
            mic = denoise_image(mic,
                                models,
                                lowpass=lowpass,
                                cutoff=cutoff,
                                gaus=gaus,
                                inv_gaus=inv_gaus,
                                deconvolve=deconvolve,
                                deconv_patch=deconv_patch,
                                patch_size=ps,
                                padding=padding,
                                normalize=normalize,
                                use_cuda=use_cuda)

            # write the micrograph
            if not args.output:
                if suffix == '' or suffix is None:
                    suffix = '.denoised'
                # write the file to the same location as input
                no_ext, ext = os.path.splitext(path)
                outpath = no_ext + suffix + '.' + format_
            else:
                outpath = args.output + os.sep + name + suffix + '.' + format_
            save_image(mic, outpath)  #, mi=None, ma=None)

            count += 1
            print('# {} of {} completed.'.format(count, total),
                  file=sys.stderr,
                  end='\r')
        print('', file=sys.stderr)
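
One detail in Example 3 that is easy to misread is the checkpoint naming: the epoch number is zero-padded to a width derived from `num_epochs`. A small self-contained illustration of that format string (all values hypothetical):

import numpy as np

num_epochs = 100
digits = int(np.ceil(np.log10(num_epochs)))          # 2 for 100 epochs
save_prefix = 'saved_models/denoise'                 # hypothetical prefix
path = save_prefix + ('_epoch{:0' + str(digits) + '}.sav').format(7)
print(path)  # saved_models/denoise_epoch07.sav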
Example 4
def denoise(model, path, outdir, patch_size=128, padding=128, batch_size=1):
    with open(path, 'rb') as f:
        content = f.read()
    tomo,header,_ = mrc.parse(content)
    tomo = tomo.astype(np.float32)  # work in float32 so standardization does not truncate
    name = os.path.basename(path)

    mu = tomo.mean()
    std = tomo.std()

    # denoise in patches
    d = next(iter(model.parameters())).device
    denoised = np.zeros_like(tomo)

    with torch.no_grad():
        if patch_size < 1:
            x = (tomo - mu)/std
            x = torch.from_numpy(x).to(d)
            x = model(x.unsqueeze(0).unsqueeze(0)).squeeze().cpu().numpy()
            x = std*x + mu
            denoised[:] = x
        else:
            patch_data = PatchDataset(tomo, patch_size, padding)
            total = len(patch_data)
            count = 0

            batch_iterator = torch.utils.data.DataLoader(patch_data, batch_size=batch_size)
            for index,x in batch_iterator:
                x = x.to(d)
                x = (x - mu)/std
                x = x.unsqueeze(1) # batch x channel

                # denoise
                x = model(x).squeeze(1).cpu().numpy()

                # restore the original statistics
                x = std*x + mu

                # stitch into denoised volume
                for b in range(len(x)):
                    i,j,k = index[b]
                    xb = x[b]

                    patch = denoised[i:i+patch_size,j:j+patch_size,k:k+patch_size]
                    pz,py,px = patch.shape

                    xb = xb[padding:padding+pz,padding:padding+py,padding:padding+px]
                    denoised[i:i+patch_size,j:j+patch_size,k:k+patch_size] = xb

                    count += 1
                    print('# [{}/{}] {:.2%}'.format(count, total, count/total), name, file=sys.stderr, end='\r')

            print(' '*100, file=sys.stderr, end='\r')


    ## save the denoised tomogram
    outpath = outdir + os.sep + name

    # use the read header except for a few fields
    header = header._replace(mode=2) # 32-bit real
    header = header._replace(amin=denoised.min())
    header = header._replace(amax=denoised.max())
    header = header._replace(amean=denoised.mean())

    with open(outpath, 'wb') as f:
        mrc.write(f, denoised, header=header)
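
The patch loop above iterates over a `PatchDataset` that yields `(index, patch)` pairs, where each patch is the `patch_size` region starting at `index` plus `padding` voxels of context on every side; the stitching step then keeps only the central region (Example 6 below uses the same convention). Topaz's actual `PatchDataset` is not shown here; the following is a hypothetical stand-in, sketched only to make the index/padding convention concrete:

import numpy as np
import torch

class SimplePatchDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in for the PatchDataset used above: yields
    ((i, j, k), padded_patch) pairs covering the volume on a regular grid."""

    def __init__(self, tomo, patch_size=128, padding=128):
        self.tomo = tomo
        self.patch_size = patch_size
        self.padding = padding
        nz, ny, nx = tomo.shape
        # top-front-left corners of the unpadded output patches
        self.starts = [(i, j, k)
                       for i in range(0, nz, patch_size)
                       for j in range(0, ny, patch_size)
                       for k in range(0, nx, patch_size)]

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        i, j, k = self.starts[idx]
        p, pad = self.patch_size, self.padding
        nz, ny, nx = self.tomo.shape
        # crop the patch plus its padding, clipped to the volume bounds
        zs, ze = max(0, i - pad), min(nz, i + p + pad)
        ys, ye = max(0, j - pad), min(ny, j + p + pad)
        xs, xe = max(0, k - pad), min(nx, k + p + pad)
        crop = self.tomo[zs:ze, ys:ye, xs:xe]
        # zero-pad so every item has shape (p + 2*pad,)**3 and the voxel at
        # [pad, pad, pad] lines up with tomo[i, j, k]
        out = np.zeros((p + 2*pad,)*3, dtype=np.float32)
        oz, oy, ox = pad - (i - zs), pad - (j - ys), pad - (k - xs)
        out[oz:oz + crop.shape[0], oy:oy + crop.shape[1], ox:ox + crop.shape[2]] = crop
        return np.array((i, j, k)), out

Under this convention, `xb[padding:padding+pz, ...]` in the stitching loop recovers exactly the `tomo[i:i+patch_size, ...]` region (truncated at the volume edge), which is why the padded model output can simply be cropped and written back.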
Example 5
def load_mrc(self, path):
    with open(path, 'rb') as f:
        content = f.read()
    tomo,_,_ = mrc.parse(content)
    return tomo
Example 6
def denoise(model,
            path,
            outdir,
            suffix,
            patch_size=128,
            padding=128,
            batch_size=1,
            volume_num=1,
            total_volumes=1):
    with open(path, 'rb') as f:
        content = f.read()
    tomo, header, extended_header = mrc.parse(content)
    tomo = tomo.astype(np.float32)
    name = os.path.basename(path)

    mu = tomo.mean()
    std = tomo.std()
    # denoise in patches
    d = next(iter(model.parameters())).device
    denoised = np.zeros_like(tomo)

    with torch.no_grad():
        if patch_size < 1:
            x = (tomo - mu) / std
            x = torch.from_numpy(x).to(d)
            x = model(x.unsqueeze(0).unsqueeze(0)).squeeze().cpu().numpy()
            x = std * x + mu
            denoised[:] = x
        else:
            patch_data = PatchDataset(tomo, patch_size, padding)
            total = len(patch_data)
            count = 0

            batch_iterator = torch.utils.data.DataLoader(patch_data,
                                                         batch_size=batch_size)
            for index, x in batch_iterator:
                x = x.to(d)
                x = (x - mu) / std
                x = x.unsqueeze(1)  # batch x channel

                # denoise
                x = model(x)
                x = x.squeeze(1).cpu().numpy()

                # restore original statistics
                x = std * x + mu

                # stitch into denoised volume
                for b in range(len(x)):
                    i, j, k = index[b]
                    xb = x[b]

                    patch = denoised[i:i + patch_size, j:j + patch_size,
                                     k:k + patch_size]
                    pz, py, px = patch.shape

                    xb = xb[padding:padding + pz, padding:padding + py,
                            padding:padding + px]
                    denoised[i:i + patch_size, j:j + patch_size,
                             k:k + patch_size] = xb

                    count += 1
                    print('# [{}/{}] {:.2%}'.format(volume_num, total_volumes,
                                                    count / total),
                          name,
                          file=sys.stderr,
                          end='\r')

            print(' ' * 100, file=sys.stderr, end='\r')

    ## save the denoised tomogram
    if outdir is None:
        # write denoised tomogram to same location as input, but add the suffix
        if suffix is None:  # use default
            suffix = '.denoised'
        no_ext, ext = os.path.splitext(path)
        outpath = no_ext + suffix + ext
    else:
        if suffix is None:
            suffix = ''
        no_ext, ext = os.path.splitext(name)
        outpath = outdir + os.sep + no_ext + suffix + ext

    # use the read header except for a few fields
    header = header._replace(mode=2)  # 32-bit real
    header = header._replace(amin=denoised.min())
    header = header._replace(amax=denoised.max())
    header = header._replace(amean=denoised.mean())

    with open(outpath, 'wb') as f:
        mrc.write(f, denoised, header=header, extended_header=extended_header)
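
A hypothetical invocation of the function above; the checkpoint file, tomogram path, and parameter values are assumptions for illustration only:

import torch

model = torch.load('denoise3d.sav', map_location='cpu')  # hypothetical checkpoint
model.eval()

# writes tomogram.denoised.mrc next to the input, denoising in 96^3 patches
# with 48 voxels of padding, two patches per batch
denoise(model, 'tomogram.mrc', outdir=None, suffix=None,
        patch_size=96, padding=48, batch_size=2)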
Example 7
def load_mrc(self, path):
    with open(path, 'rb') as f:
        content = f.read()
    tomo, _, _ = mrc.parse(content)
    tomo = tomo.astype(np.float32)
    return tomo
Example 8
def main(args):

    ## set the device
    use_cuda = False
    if args.device >= 0:
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.set_device(args.device)
    print('# using device={} with cuda={}'.format(args.device, use_cuda),
          file=sys.stderr)

    do_train = (args.dir_a is not None
                and args.dir_b is not None) or (args.hdf is not None)
    if do_train:
        if args.hdf is None:  # use dir_a/dir_b
            crop = args.crop
            dir_a = args.dir_a
            dir_b = args.dir_b
            random = np.random.RandomState(44444)

            dataset_train, dataset_val = make_paired_images_datasets(
                dir_a, dir_b, crop, random=random)
            shuffle = True
        else:  # make HDF datasets
            dataset_train, dataset_val = make_hdf5_datasets(args.hdf)
            shuffle = False

        # initialize the model
        #model = dn.DenoiseNet(32)
        model = dn.UDenoiseNet()
        if use_cuda:
            model = model.cuda()

        # train
        lr = args.lr
        batch_size = args.batch_size
        num_epochs = args.num_epochs

        num_workers = args.num_workers

        print('epoch', 'loss_train', 'loss_val')
        #criteria = nn.L1Loss()
        criteria = args.criteria

        for epoch, loss_train, loss_val in dn.train_noise2noise(
                model,
                dataset_train,
                lr=lr,
                batch_size=batch_size,
                criteria=criteria,
                num_epochs=num_epochs,
                dataset_val=dataset_val,
                use_cuda=use_cuda,
                num_workers=num_workers,
                shuffle=shuffle):
            print(epoch, loss_train, loss_val)
            sys.stdout.flush()

            # save the model
            if args.save_prefix is not None:
                path = args.save_prefix + '_epoch{}.sav'.format(epoch)
                model.cpu()
                model.eval()
                torch.save(model, path)
                if use_cuda:
                    model.cuda()

    else:  # load the saved model
        if args.model in ['L0', 'L1', 'L2']:
            if args.model in ['L0', 'L1']:
                print(
                    'ERROR: L0 and L1 models are not implemented in the current version',
                    file=sys.stderr)
                sys.exit(1)
            model = dn.load_model(args.model)
        else:
            model = torch.load(args.model)
        print('# using model:', args.model, file=sys.stderr)
        model.eval()
        if use_cuda:
            model.cuda()

    if args.stack:
        # we are denoising a single MRC stack
        with open(args.micrographs[0], 'rb') as f:
            content = f.read()
        stack, _, _ = mrc.parse(content)
        print('# denoising stack with shape:', stack.shape, file=sys.stderr)

        denoised = dn.denoise_stack(model, stack, use_cuda=use_cuda)

        # write the denoised stack
        path = args.output
        print('# writing', path, file=sys.stderr)
        with open(path, 'wb') as f:
            mrc.write(f, denoised)

    else:
        # using trained model
        # stream the micrographs and denoise as we go

        normalize = args.normalize
        if args.format_ == 'png' or args.format_ == 'jpg':
            # always normalize png and jpg format
            normalize = True

        format_ = args.format_

        count = 0
        total = len(args.micrographs)

        bin_ = args.bin
        ps = args.patch_size
        padding = args.patch_padding

        # now, stream the micrographs and denoise them
        for path in args.micrographs:
            name, _ = os.path.splitext(os.path.basename(path))
            mic = np.array(load_image(path), copy=False)
            if bin_ > 1:
                mic = downsample(mic, bin_)
            mu = mic.mean()
            std = mic.std()

            # denoise
            mic = (mic - mu) / std
            mic = dn.denoise(model,
                             mic,
                             patch_size=ps,
                             padding=padding,
                             use_cuda=use_cuda)

            if normalize:
                mic = (mic - mic.mean()) / mic.std()
            else:
                # add back std. dev. and mean
                mic = std * mic + mu

            # write the micrograph
            outpath = args.output + os.sep + name + '.' + format_
            save_image(mic, outpath)

            count += 1
            print('# {} of {} completed.'.format(count, total),
                  file=sys.stderr,
                  end='\r')

        print('', file=sys.stderr)
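
Example 8 optionally bins the micrograph with `downsample(mic, bin_)` before denoising. As a rough illustration of what binning by an integer factor does (this sketch uses simple mean pooling, which is an assumption; topaz's `downsample` may be implemented differently, e.g. by Fourier cropping):

import numpy as np

def bin_mean(mic, factor):
    # mean-pool by an integer factor, cropping any remainder rows/columns
    h = (mic.shape[0] // factor) * factor
    w = (mic.shape[1] // factor) * factor
    mic = mic[:h, :w]
    return mic.reshape(h // factor, factor, w // factor, factor).mean(axis=(1, 3))

mic = np.random.rand(1024, 1000).astype(np.float32)
print(bin_mean(mic, 4).shape)  # (256, 250)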
Example 9
def main(args):
    particles = pd.read_csv(args.file, sep='\t')

    print('#', 'Loaded', len(particles), 'particles', file=sys.stderr)

    # threshold the particles
    if 'score' in particles:
        particles = particles.loc[particles['score'] >= args.threshold]
        print('#', 'Thresholding at', args.threshold, file=sys.stderr)

    print('#', 'Extracting', len(particles), 'particles', file=sys.stderr)

    N = len(particles)
    size = args.size
    resize = args.resize
    if resize < 0:
        resize = size

    #
    wrote_header = False
    read_metadata = False
    metadata = []

    # write the particles iteratively
    i = 0
    with open(args.output, 'wb') as f:
        for image_name, coords in particles.groupby('image_name'):

            print('#', image_name, len(coords), 'particles', file=sys.stderr)

            # load the micrograph
            image_name = image_name + args.image_ext
            path = os.path.join(args.image_root, image_name)
            with open(path, 'rb') as fm:
                content = fm.read()
            micrograph, header, extended_header = mrc.parse(content)
            if len(micrograph.shape) < 3:
                # add a z dimension if the micrograph is a single 2D image
                micrograph = micrograph[np.newaxis]

            if not wrote_header:  # load a/px and angles from micrograph header and write the stack header
                mz = micrograph.shape[0]

                dtype = micrograph.dtype

                cella = (header.xlen, header.ylen, header.zlen)
                cellb = (header.alpha, header.beta, header.gamma)
                shape = (N * mz, resize, resize)

                header = mrc.make_header(shape,
                                         cella,
                                         cellb,
                                         mz=mz,
                                         dtype=dtype)

                buf = mrc.header_struct.pack(*list(header))
                f.write(buf)
                wrote_header = True

            _, n, m = micrograph.shape

            x_coord = coords['x_coord'].values
            y_coord = coords['y_coord'].values
            scores = None
            if 'score' in coords:
                scores = coords['score'].values

            # crop out the particles
            for j in range(len(coords)):
                x = x_coord[j]
                y = y_coord[j]

                if scores is not None:
                    metadata.append((image_name, x, y, scores[j]))
                else:
                    metadata.append((image_name, x, y))

                left = x - size // 2
                upper = y - size // 2
                right = left + size
                lower = upper + size

                c = micrograph[:,
                               max(0, upper):min(n, lower),
                               max(0, left):min(m, right)]

                c = (c - c.mean()) / c.std()
                stack = np.zeros((mz, size, size), dtype=dtype)

                #stack = np.zeros((mz, size, size), dtype=dtype) + c.mean().astype(dtype)
                stack[:,
                      max(0, -upper):min(size + n - lower, size),
                      max(0, -left):min(size + m - right, size)] = c

                # write particle to mrc file
                if resize != size:
                    restack = downsample(stack, 0, shape=(resize, resize))
                    #print(restack.shape, restack.mean(), restack.std())
                    restack = (restack - restack.mean()) / restack.std()
                    f.write(restack.tobytes())
                else:
                    f.write(stack.tobytes())

                i += 1
                #print('# wrote', i, 'out of', N, 'particles', end='\r', flush=True)

    ## write the particle stack mrcs
    #with open(args.output, 'wb') as f:
    #    mrc.write(f, stack, ax=ax, ay=ay, az=az, alpha=alpha, beta=beta, gamma=gamma)

    image_name = os.path.basename(args.output)
    star_path = os.path.splitext(args.output)[0] + '.star'

    ## create the star file
    columns = ['MicrographName', star.X_COLUMN_NAME, star.Y_COLUMN_NAME]
    if 'score' in particles:
        columns.append(star.SCORE_COLUMN_NAME)
    metadata = pd.DataFrame(metadata, columns=columns)
    metadata['ImageName'] = [
        str(i + 1) + '@' + image_name for i in range(len(metadata))
    ]
    if mz > 1:
        metadata['NrOfFrames'] = mz

    micrograph_metadata = None
    if args.metadata is not None:
        with open(args.metadata, 'r') as f:
            micrograph_metadata = star.parse_star(f)
        metadata = pd.merge(metadata,
                            micrograph_metadata,
                            on='MicrographName',
                            how='left')

    if resize != size:
        # rescale the detector pixel size
        pix = metadata['DetectorPixelSize'].values.astype(float)
        metadata['DetectorPixelSize'] = pix * (size / resize)

    ## write the star file
    with open(star_path, 'w') as f:
        star.write(metadata, f)
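
The edge handling in the particle crop of Example 9 is easy to misread, so here is a small worked check with hypothetical values showing how the source slice in the micrograph and the destination slice in the zero-padded box line up:

size = 64                      # box size
n, m = 100, 100                # micrograph height, width (hypothetical)
x, y = 10, 20                  # particle near the top-left corner

left, upper = x - size // 2, y - size // 2     # -22, -12 (partly off the image)
right, lower = left + size, upper + size       #  42,  52

# region actually available in the micrograph (rows, then columns)
src = (max(0, upper), min(n, lower), max(0, left), min(m, right))
# where that region lands inside the zero-padded size x size box
dst = (max(0, -upper), min(size + n - lower, size),
       max(0, -left), min(size + m - right, size))

print(src)  # (0, 52, 0, 42)   -> 52 x 42 pixels copied from the micrograph
print(dst)  # (12, 64, 22, 64) -> the same 52 x 42 region, shifted into the box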