Exemplo n.º 1
0
def _eval_volume(model, lattice, D, norm, z, downsample=None, flip=False,
                 invert=False):
    """Decode one latent vector ``z`` into a volume with ``model.decoder``.

    Args:
        model: loaded model whose ``decoder.eval_volume`` produces the volume.
        lattice: lattice returned alongside the model by ``HetOnlyVAE.load``.
        D: lattice size from the config (image size + 1).
        norm: dataset normalization passed through to ``eval_volume``.
        z: latent vector to decode.
        downsample: if truthy, decode on a ``downsample + 1`` lattice instead
            of the full one.
        flip: if True, reverse the volume along its first axis.
        invert: if True, negate the volume values.

    Returns:
        The decoded volume. NOTE(review): presumably a numpy array, since the
        caller calls ``.astype`` on it — confirm against ``eval_volume``.
    """
    if downsample:
        # Scale the full-lattice extent by downsample / (D - 1) to match the
        # smaller box.
        extent = lattice.extent * (downsample / (D - 1))
        vol = model.decoder.eval_volume(
            lattice.get_downsample_coords(downsample + 1),
            downsample + 1, extent, norm, z)
    else:
        vol = model.decoder.eval_volume(lattice.coords, lattice.D,
                                        lattice.extent, norm, z)
    if flip:
        vol = vol[::-1]
    if invert:
        vol *= -1
    return vol


def main(args):
    """Decode volumes from a trained HetOnlyVAE checkpoint and write MRC files.

    Two modes:

    * If ``args.z_start`` or ``args.zfile`` is set, build a batch of latent
      vectors. They are either ``args.n`` points linearly spaced from
      ``args.z_start`` to ``args.z_end``, or rows read from ``args.zfile``.
      Each decoded volume is written to ``{args.o}/{args.prefix}{i:03d}.mrc``,
      and the directory ``args.o`` is created if needed.
    * Otherwise decode the single latent ``args.z`` and write it to ``args.o``.

    Raises:
        AssertionError: if ``args.downsample`` is odd or larger than the
            original box size.
    """
    check_inputs(args)
    t1 = dt.now()

    ## set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if not use_cuda:
        log('WARNING: No GPUs detected')

    log(args)
    cfg = config.overwrite_config(args.config, args)
    log('Loaded configuration:')
    pprint.pprint(cfg)

    D = cfg['lattice_args']['D']  # image size + 1
    zdim = cfg['model_args']['zdim']
    norm = cfg['dataset_args']['norm']

    if args.downsample:
        assert args.downsample % 2 == 0, "Boxsize must be even"
        assert args.downsample <= D - 1, "Must be smaller than original box size"

    model, lattice = HetOnlyVAE.load(cfg, args.weights, device=device)
    model.eval()

    def render(z):
        # Shared decode path for both the multi-z and single-z modes.
        return _eval_volume(model, lattice, D, norm, z,
                            downsample=args.downsample,
                            flip=args.flip, invert=args.invert)

    ### Multiple z ###
    if args.z_start or args.zfile:

        ### Get z values
        if args.z_start:
            z_start = np.array(args.z_start)
            z_end = np.array(args.z_end)
            # linspace returns z_start when n == 1. Dividing by (n - 1)
            # would produce NaN/inf there. broadcast_to keeps the original
            # (n, zdim) shape even if a single start/end value was given.
            z = np.broadcast_to(np.linspace(z_start, z_end, args.n),
                                (args.n, zdim)).astype(np.float32)
        else:
            z = np.loadtxt(args.zfile).reshape(-1, zdim)

        os.makedirs(args.o, exist_ok=True)

        log(f'Generating {len(z)} volumes')
        for i, zz in enumerate(z):
            log(zz)
            vol = render(zz)
            out_mrc = '{}/{}{:03d}.mrc'.format(args.o, args.prefix, i)
            mrc.write(out_mrc, vol.astype(np.float32), Apix=args.Apix)

    ### Single z ###
    else:
        z = np.array(args.z)
        log(z)
        vol = render(z)
        mrc.write(args.o, vol.astype(np.float32), Apix=args.Apix)

    td = dt.now() - t1
    log('Finished in {}'.format(td))
Exemplo n.º 2
0
def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if it is missing.

    ``os.path.dirname`` of a bare filename is ``''`` (the current directory).
    Passing that to ``os.makedirs`` raises, so that case is skipped.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _load_particles(args, ind):
    """Load the particle image dataset selected by ``args``.

    Args:
        args: parsed command-line arguments.
        ind: optional index array used to filter the images, or None.

    Returns:
        ``(data, tilt)``. ``tilt`` is None when ``args.tilt`` is None.
        Otherwise it is the float32 tensor built from
        ``utils.xrot(args.tilt_deg)``.

    Side effect: when no tilt series is given and ``args.encode_mode`` is
    ``'conv'``, sets ``args.use_real = True``. Later code reads that flag.

    Raises:
        AssertionError: a tilt series is given but encode_mode is not 'tilt'.
        NotImplementedError: lazy loading was requested with a tilt series.
    """
    # TODO: extract dataset arguments from cfg
    if args.tilt is None:
        if args.encode_mode == 'conv':
            args.use_real = True
        loader = dataset.LazyMRCData if args.lazy else dataset.MRCData
        data = loader(args.particles,
                      norm=args.norm,
                      invert_data=args.invert_data,
                      ind=ind,
                      keepreal=args.use_real,
                      window=args.window,
                      datadir=args.datadir,
                      window_r=args.window_r)
        return data, None

    assert args.encode_mode == 'tilt'
    if args.lazy:
        raise NotImplementedError
    data = dataset.TiltMRCData(args.particles,
                               args.tilt,
                               norm=args.norm,
                               invert_data=args.invert_data,
                               ind=ind,
                               window=args.window,
                               keepreal=args.use_real,
                               datadir=args.datadir,
                               window_r=args.window_r)
    tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32))
    return data, tilt


def main(args):
    """Evaluate a trained HetOnlyVAE on a particle stack.

    Every image is encoded and losses are accumulated. The results are saved
    as follows:

    * ``args.out_z``: two consecutive pickles, first the stacked ``z_mu``,
      then the stacked ``z_logvar``.
    * ``args.o``: a pickled dict with keys ``'loss'``, ``'recon'`` and
      ``'kld'``, each accumulated over the dataset and divided by the number
      of images.

    Parent directories of both output paths are created if missing.
    """
    t1 = dt.now()

    # make output directories
    _ensure_parent_dir(args.o)
    _ensure_parent_dir(args.out_z)

    # set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if not use_cuda:
        log('WARNING: No GPUs detected')

    log(args)
    cfg = config.overwrite_config(args.config, args)
    log('Loaded configuration:')
    pprint.pprint(cfg)

    zdim = cfg['model_args']['zdim']
    beta = 1. / zdim if args.beta is None else args.beta

    # load the particles
    if args.ind is not None:
        log('Filtering image dataset with {}'.format(args.ind))
        # Use a context manager so the file handle is closed.
        with open(args.ind, 'rb') as f:
            ind = pickle.load(f)
    else:
        ind = None

    data, tilt = _load_particles(args, ind)
    Nimg = data.N
    D = data.D

    if args.encode_mode == 'conv':
        assert D - 1 == 64, "Image size must be 64x64 for convolutional encoder"

    # load poses
    posetracker = PoseTracker.load(args.poses,
                                   Nimg,
                                   D,
                                   None,
                                   ind,
                                   device=device)

    # load ctf
    if args.ctf is not None:
        if args.use_real:
            raise NotImplementedError(
                "Not implemented with real-space encoder. Use phase-flipped images instead"
            )
        log('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        if args.ind is not None:
            ctf_params = ctf_params[ind]
        ctf_params = torch.tensor(ctf_params, device=device)
    else:
        ctf_params = None

    # instantiate model
    model, lattice = HetOnlyVAE.load(cfg, args.weights, device=device)
    model.eval()
    z_mu_all = []
    z_logvar_all = []
    gen_loss_accum = 0
    kld_accum = 0
    loss_accum = 0
    batch_it = 0
    data_generator = DataLoader(data,
                                batch_size=args.batch_size,
                                shuffle=False)
    for minibatch in data_generator:
        # The last element of each minibatch is the image index. Keep a CPU
        # copy for indexing the numpy array `data.particles_real`. A tensor
        # already moved to a CUDA device cannot be converted to numpy.
        batch_ind_cpu = minibatch[-1].cpu()
        batch_ind = batch_ind_cpu.to(device)
        y = minibatch[0].to(device)
        yt = minibatch[1].to(device) if tilt is not None else None
        B = len(batch_ind)
        batch_it += B

        yr = (torch.from_numpy(
            data.particles_real[batch_ind_cpu.numpy()]).to(device)
              if args.use_real else None)
        rot, tran = posetracker.get_pose(batch_ind)
        ctf_param = ctf_params[batch_ind] if ctf_params is not None else None

        z_mu, z_logvar, loss, gen_loss, kld = eval_batch(model,
                                                         lattice,
                                                         y,
                                                         yt,
                                                         rot,
                                                         tran,
                                                         beta,
                                                         tilt,
                                                         ctf_params=ctf_param,
                                                         yr=yr)

        z_mu_all.append(z_mu)
        z_logvar_all.append(z_logvar)

        # logging: weight by batch size so dividing by Nimg below gives a
        # per-image average (assumes eval_batch returns per-batch means).
        gen_loss_accum += gen_loss * B
        kld_accum += kld * B
        loss_accum += loss * B

        if batch_it % args.log_interval == 0:
            log('# [{}/{} images] gen loss={:.4f}, kld={:.4f}, beta={:.4f}, loss={:.4f}'
                .format(batch_it, Nimg, gen_loss, kld, beta, loss))
    log('# =====> Average gen loss = {:.6f}, KLD = {:.6f}, total loss = {:.6f}'.
        format(gen_loss_accum / Nimg, kld_accum / Nimg, loss_accum / Nimg))

    z_mu_all = np.vstack(z_mu_all)
    z_logvar_all = np.vstack(z_logvar_all)

    with open(args.out_z, 'wb') as f:
        pickle.dump(z_mu_all, f)
        pickle.dump(z_logvar_all, f)
    with open(args.o, 'wb') as f:
        pickle.dump(
            {
                'loss': loss_accum / Nimg,
                'recon': gen_loss_accum / Nimg,
                'kld': kld_accum / Nimg
            }, f)

    log('Finished in {}'.format(dt.now() - t1))