Esempio n. 1
0
def main(args):
    """Decode one or more latent codes into volumes and write them as MRC.

    A ``HetOnlyVAE`` is rebuilt from the (optionally config-augmented)
    arguments, its weights are loaded from ``args.weights`` and its decoder
    is evaluated on the lattice.

    Two output modes exist:

    * multiple z -- taken when ``args.z_start`` or ``args.zfile`` is set.
      ``args.o`` is used as an output directory and one
      ``<args.prefix>NNN.mrc`` file is written per latent code. With
      ``args.z_start`` the codes are ``args.n`` evenly spaced points between
      ``args.z_start`` and ``args.z_end``; otherwise they are read from the
      text file ``args.zfile`` and reshaped to ``(-1, args.zdim)``.
    * single z -- otherwise. ``args.z`` is decoded and written to the file
      path ``args.o``.

    Args:
        args: argument namespace. Validated by ``check_inputs`` and, when
            ``args.config`` is given, replaced by the result of
            ``config.load_config``.

    Raises:
        AssertionError: if ``args.downsample`` is odd or not smaller than
            ``args.D``.
    """
    check_inputs(args)
    t1 = dt.now()

    ## set the device
    use_cuda = torch.cuda.is_available()
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.config is not None:
        args = config.load_config(args.config, args)
    log(args)

    if args.downsample:
        assert args.downsample % 2 == 0, "Boxsize must be even"
        assert args.downsample < args.D, "Must be smaller than original box size"
    D = args.D + 1
    lattice = Lattice(D, extent=args.l_extent)
    if args.enc_mask:
        args.enc_mask = lattice.get_circular_mask(args.enc_mask)
        in_dim = args.enc_mask.sum()
    else:
        in_dim = lattice.D**2
    model = HetOnlyVAE(lattice,
                       args.qlayers,
                       args.qdim,
                       args.players,
                       args.pdim,
                       in_dim,
                       args.zdim,
                       encode_mode=args.encode_mode,
                       enc_mask=args.enc_mask,
                       enc_type=args.pe_type,
                       enc_dim=args.pe_dim,
                       domain=args.domain)

    log('Loading weights from {}'.format(args.weights))
    checkpoint = torch.load(args.weights)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()

    def decode_volume(zval):
        """Evaluate the decoder at ``zval``; return the float32 volume."""
        if args.downsample:
            # shrink the lattice extent in proportion to the new box size
            extent = lattice.extent * (args.downsample / args.D)
            vol = model.decoder.eval_volume(
                lattice.get_downsample_coords(args.downsample + 1),
                args.downsample + 1, extent, args.norm, zval)
        else:
            vol = model.decoder.eval_volume(lattice.coords, lattice.D,
                                            lattice.extent, args.norm, zval)
        if args.flip:
            vol = vol[::-1]
        return vol.astype(np.float32)

    ### Multiple z ###
    if args.z_start or args.zfile:

        ### Get z values
        if args.z_start:
            args.z_start = np.array(args.z_start)
            args.z_end = np.array(args.z_end)
            z = np.repeat(np.arange(args.n, dtype=np.float32),
                          args.zdim).reshape((args.n, args.zdim))
            # with a single point there is no interval to divide; dividing
            # by (n - 1) == 0 would turn every coordinate into NaN
            n_intervals = max(args.n - 1, 1)
            z *= ((args.z_end - args.z_start) / n_intervals)
            z += args.z_start
        else:
            z = np.loadtxt(args.zfile).reshape(-1, args.zdim)

        os.makedirs(args.o, exist_ok=True)

        log(f'Generating {len(z)} volumes')
        for i, zz in enumerate(z):
            log(zz)
            out_mrc = '{}/{}{:03d}.mrc'.format(args.o, args.prefix, i)
            mrc.write(out_mrc, decode_volume(zz), Apix=args.Apix)

    ### Single z ###
    else:
        z = np.array(args.z)
        log(z)
        mrc.write(args.o, decode_volume(z), Apix=args.Apix)

    td = dt.now() - t1
    log('Finished in {}'.format(td))
Esempio n. 2
0
def main(args):
    """Train a ``HetOnlyVAE`` on a particle stack with known poses.

    The function, in order:

    1. creates ``args.outdir`` and logs to ``<outdir>/run.log``;
    2. seeds numpy and torch with ``args.seed``;
    3. loads the dataset (``LazyMRCData``, ``PreprocessedMRCData``,
       ``MRCData`` or ``TiltMRCData`` depending on the arguments), the poses
       and, optionally, the CTF parameters;
    4. builds the model and Adam optimizer, optionally with mixed precision
       (``args.amp``) and ``nn.DataParallel`` (``args.multigpu``);
    5. optionally resumes from the checkpoint ``args.load``;
    6. trains for epochs ``start_epoch .. args.num_epochs - 1``, writing
       ``weights.<epoch>.pkl`` / ``z.<epoch>.pkl`` (and ``pose.<epoch>.pkl``
       when pose SGD is active) every ``args.checkpoint`` epochs;
    7. writes the final ``weights.pkl`` / ``z.pkl`` (and ``pose.pkl``).

    Args:
        args: argument namespace. Several attributes are normalised in place
            (``beta``, ``use_real``, ``enc_mask``, ``batch_size``).

    Raises:
        AssertionError: on inconsistent arguments (beta schedule without
            ``beta_control``, wrong encode mode for tilt data, image size for
            the convolutional encoder, pose SGD outside the hartley domain,
            CTF shape mismatch, AMP divisibility constraints).
        NotImplementedError: for tilt series with ``--lazy`` or
            ``--preprocessed``, and for CTF with a real-space encoder.
        RuntimeError: for an invalid encoder mask radius.
    """
    t1 = dt.now()
    if args.outdir is not None and not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    LOG = f'{args.outdir}/run.log'

    def flog(msg):  # HACK: switch to logging module
        return utils.flog(msg, LOG)

    if args.load == 'latest':
        args = get_latest(args, flog)
    flog(' '.join(sys.argv))
    flog(args)

    # set the random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    flog('Use cuda {}'.format(use_cuda))
    if not use_cuda:
        log('WARNING: No GPUs detected')

    # set beta schedule
    if args.beta is None:
        args.beta = 1. / args.zdim
    try:
        args.beta = float(args.beta)
    except ValueError:
        # a non-numeric beta names a schedule, which needs a control weight
        assert args.beta_control, "Need to set beta control weight for schedule {}".format(
            args.beta)
    beta_schedule = get_beta_schedule(args.beta)

    # load index filter
    if args.ind is not None:
        flog('Filtering image dataset with {}'.format(args.ind))
        # NOTE: pickle must only be used on trusted index files
        with open(args.ind, 'rb') as f:
            ind = pickle.load(f)
    else:
        ind = None

    # load dataset
    flog(f'Loading dataset from {args.particles}')
    if args.tilt is None:
        tilt = None
        args.use_real = args.encode_mode == 'conv'

        if args.lazy:
            data = dataset.LazyMRCData(args.particles,
                                       norm=args.norm,
                                       invert_data=args.invert_data,
                                       ind=ind,
                                       keepreal=args.use_real,
                                       window=args.window,
                                       datadir=args.datadir,
                                       window_r=args.window_r,
                                       flog=flog)
        elif args.preprocessed:
            flog(
                'Using preprocessed inputs. Ignoring any --window/--invert-data options'
            )
            data = dataset.PreprocessedMRCData(args.particles,
                                               norm=args.norm,
                                               ind=ind,
                                               flog=flog)
        else:
            data = dataset.MRCData(args.particles,
                                   norm=args.norm,
                                   invert_data=args.invert_data,
                                   ind=ind,
                                   keepreal=args.use_real,
                                   window=args.window,
                                   datadir=args.datadir,
                                   max_threads=args.max_threads,
                                   window_r=args.window_r,
                                   flog=flog)

    # Tilt series data -- lots of unsupported features
    else:
        assert args.encode_mode == 'tilt'
        if args.lazy: raise NotImplementedError
        if args.preprocessed: raise NotImplementedError
        data = dataset.TiltMRCData(args.particles,
                                   args.tilt,
                                   norm=args.norm,
                                   invert_data=args.invert_data,
                                   ind=ind,
                                   window=args.window,
                                   keepreal=args.use_real,
                                   datadir=args.datadir,
                                   window_r=args.window_r,
                                   flog=flog)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32),
                            device=device)
    Nimg = data.N
    D = data.D

    if args.encode_mode == 'conv':
        assert D - 1 == 64, "Image size must be 64x64 for convolutional encoder"

    # load poses
    if args.do_pose_sgd:
        assert args.domain == 'hartley', "Need to use --domain hartley if doing pose SGD"
    do_pose_sgd = args.do_pose_sgd
    posetracker = PoseTracker.load(args.poses,
                                   Nimg,
                                   D,
                                   's2s2' if do_pose_sgd else None,
                                   ind,
                                   device=device)
    pose_optimizer = torch.optim.SparseAdam(
        list(posetracker.parameters()),
        lr=args.pose_lr) if do_pose_sgd else None

    # load ctf
    if args.ctf is not None:
        if args.use_real:
            raise NotImplementedError(
                "Not implemented with real-space encoder. Use phase-flipped images instead"
            )
        flog('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        if args.ind is not None: ctf_params = ctf_params[ind]
        assert ctf_params.shape == (Nimg, 8)
        ctf_params = torch.tensor(ctf_params, device=device)
    else:
        ctf_params = None

    # instantiate model
    lattice = Lattice(D, extent=0.5, device=device)
    if args.enc_mask is None:
        args.enc_mask = D // 2
    if args.enc_mask > 0:
        assert args.enc_mask <= D // 2
        enc_mask = lattice.get_circular_mask(args.enc_mask)
        in_dim = enc_mask.sum()
    elif args.enc_mask == -1:
        # -1 disables the mask: the encoder sees the full image
        enc_mask = None
        in_dim = lattice.D**2 if not args.use_real else (lattice.D - 1)**2
    else:
        raise RuntimeError(
            "Invalid argument for encoder mask radius {}".format(
                args.enc_mask))
    activation = {"relu": nn.ReLU, "leaky_relu": nn.LeakyReLU}[args.activation]
    model = HetOnlyVAE(lattice,
                       args.qlayers,
                       args.qdim,
                       args.players,
                       args.pdim,
                       in_dim,
                       args.zdim,
                       encode_mode=args.encode_mode,
                       enc_mask=enc_mask,
                       enc_type=args.pe_type,
                       enc_dim=args.pe_dim,
                       domain=args.domain,
                       activation=activation,
                       feat_sigma=args.feat_sigma)
    model.to(device)
    flog(model)
    flog('{} parameters in model'.format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    flog('{} parameters in encoder'.format(
        sum(p.numel() for p in model.encoder.parameters() if p.requires_grad)))
    flog('{} parameters in decoder'.format(
        sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)))

    # save configuration
    out_config = '{}/config.pkl'.format(args.outdir)
    save_config(args, data, lattice, model, out_config)

    optim = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             weight_decay=args.wd)

    # Mixed precision training
    scaler = None
    if args.amp:
        assert args.batch_size % 8 == 0, "Batch size must be divisible by 8 for AMP training"
        assert (D - 1) % 8 == 0, "Image size must be divisible by 8 for AMP training"
        assert args.pdim % 8 == 0, "Decoder hidden layer dimension must be divisible by 8 for AMP training"
        assert args.qdim % 8 == 0, "Encoder hidden layer dimension must be divisible by 8 for AMP training"
        # Also check zdim, enc_mask dim? Add them as warnings for now.
        if args.zdim % 8 != 0:
            log('Warning: z dimension is not a multiple of 8 -- AMP training speedup is not optimized'
                )
        if in_dim % 8 != 0:
            log('Warning: Masked input image dimension is not a multiple of 8 -- AMP training speedup is not optimized'
                )
        try:  # Mixed precision with apex.amp
            model, optim = amp.initialize(model, optim, opt_level='O1')
        except Exception:  # Mixed precision with pytorch (v1.6+)
            # any failure of the apex path (including apex not being
            # available) falls back to the native scaler; unlike a bare
            # except this lets KeyboardInterrupt/SystemExit propagate
            scaler = torch.cuda.amp.GradScaler()

    # restart from checkpoint
    if args.load:
        flog('Loading checkpoint from {}'.format(args.load))
        checkpoint = torch.load(args.load)
        model.load_state_dict(checkpoint['model_state_dict'])
        optim.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        model.train()
    else:
        start_epoch = 0

    # parallelize
    if args.multigpu and torch.cuda.device_count() > 1:
        log(f'Using {torch.cuda.device_count()} GPUs!')
        args.batch_size *= torch.cuda.device_count()
        log(f'Increasing batch size to {args.batch_size}')
        model = nn.DataParallel(model)
    elif args.multigpu:
        log(f'WARNING: --multigpu selected, but {torch.cuda.device_count()} GPUs detected'
            )

    # training loop
    data_generator = DataLoader(data, batch_size=args.batch_size, shuffle=True)
    num_epochs = args.num_epochs
    # keep `epoch` defined for the final save even if the loop body never
    # runs (e.g. resuming a run that already reached num_epochs)
    epoch = start_epoch - 1
    for epoch in range(start_epoch, num_epochs):
        t2 = dt.now()
        gen_loss_accum = 0
        loss_accum = 0
        kld_accum = 0
        batch_it = 0
        for minibatch in data_generator:
            ind = minibatch[-1].to(device)
            y = minibatch[0].to(device)
            yt = minibatch[1].to(device) if tilt is not None else None
            B = len(ind)
            batch_it += B
            global_it = Nimg * epoch + batch_it

            beta = beta_schedule(global_it)

            # `ind` may live on the GPU; Tensor.numpy() needs a host tensor
            yr = torch.from_numpy(data.particles_real[ind.cpu().numpy()]).to(
                device) if args.use_real else None
            if do_pose_sgd:
                pose_optimizer.zero_grad()
            rot, tran = posetracker.get_pose(ind)
            ctf_param = ctf_params[ind] if ctf_params is not None else None
            loss, gen_loss, kld = train_batch(model,
                                              lattice,
                                              y,
                                              yt,
                                              rot,
                                              tran,
                                              optim,
                                              beta,
                                              args.beta_control,
                                              tilt,
                                              ctf_params=ctf_param,
                                              yr=yr,
                                              use_amp=args.amp,
                                              scaler=scaler)
            # poses are only updated once the pretraining epochs are over
            if do_pose_sgd and epoch >= args.pretrain:
                pose_optimizer.step()

            # logging
            gen_loss_accum += gen_loss * B
            kld_accum += kld * B
            loss_accum += loss * B

            if batch_it % args.log_interval == 0:
                log('# [Train Epoch: {}/{}] [{}/{} images] gen loss={:.6f}, kld={:.6f}, beta={:.6f}, loss={:.6f}'
                    .format(epoch + 1, num_epochs, batch_it, Nimg, gen_loss,
                            kld, beta, loss))
        flog(
            '# =====> Epoch: {} Average gen loss = {:.6}, KLD = {:.6f}, total loss = {:.6f}; Finished in {}'
            .format(epoch + 1, gen_loss_accum / Nimg, kld_accum / Nimg,
                    loss_accum / Nimg,
                    dt.now() - t2))

        if args.checkpoint and epoch % args.checkpoint == 0:
            out_weights = '{}/weights.{}.pkl'.format(args.outdir, epoch)
            out_z = '{}/z.{}.pkl'.format(args.outdir, epoch)
            model.eval()
            with torch.no_grad():
                z_mu, z_logvar = eval_z(model, lattice, data, args.batch_size,
                                        device, posetracker.trans, tilt
                                        is not None, ctf_params, args.use_real)
                save_checkpoint(model, optim, epoch, z_mu, z_logvar,
                                out_weights, out_z)
            if args.do_pose_sgd and epoch >= args.pretrain:
                out_pose = '{}/pose.{}.pkl'.format(args.outdir, epoch)
                posetracker.save(out_pose)

    # save model weights, latent encoding, and evaluate the model on 3D lattice
    out_weights = '{}/weights.pkl'.format(args.outdir)
    out_z = '{}/z.pkl'.format(args.outdir)
    model.eval()
    with torch.no_grad():
        z_mu, z_logvar = eval_z(model, lattice, data, args.batch_size, device,
                                posetracker.trans, tilt is not None,
                                ctf_params, args.use_real)
        save_checkpoint(model, optim, epoch, z_mu, z_logvar, out_weights,
                        out_z)

    if args.do_pose_sgd and epoch >= args.pretrain:
        out_pose = '{}/pose.pkl'.format(args.outdir)
        posetracker.save(out_pose)
    td = dt.now() - t1
    # guard the per-epoch average against a run that trained zero epochs
    flog('Finished in {} ({} per epoch)'.format(
        td, td / max(num_epochs - start_epoch, 1)))
Esempio n. 3
0
def main(args):
    """Backproject a particle stack into a volume and write it as MRC.

    Every selected image is masked with a circular mask, optionally
    sign-corrected by its CTF and translated, then accumulated into a
    ``(D, D, D)`` grid with ``add_slice`` at the coordinates given by its
    rotation. For tilt-series data the tilted image of each pair is inserted
    as well, using the additional ``tilt`` rotation. The accumulated grid is
    divided by the per-voxel counts, transformed with ``fft.ihtn_center`` and
    written to ``args.o``.

    The images processed are, in order of precedence: the indices pickled in
    ``args.ind``, the first ``args.first`` images, or all images.

    Args:
        args: argument namespace; ``args.o`` must end in ``.mrc``.

    Raises:
        AssertionError: if ``args.o`` does not end with ``.mrc``.
    """
    assert args.o.endswith('.mrc')

    t1 = time.time()
    log(args)
    # dirname is '' for a bare filename, and os.makedirs('') would raise
    out_dir = os.path.dirname(args.o)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    ## set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        log('WARNING: No GPUs detected')

    # load the particles
    if args.tilt is None:
        data = dataset.LazyMRCData(args.particles,
                                   norm=(0, 1),
                                   invert_data=args.invert_data,
                                   datadir=args.datadir)
        tilt = None
    else:
        data = dataset.TiltMRCData(args.particles,
                                   args.tilt,
                                   norm=(0, 1),
                                   invert_data=args.invert_data,
                                   datadir=args.datadir)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32))
    D = data.D
    Nimg = data.N

    lattice = Lattice(D, extent=D // 2)

    posetracker = PoseTracker.load(args.poses, Nimg, D, None, None)

    if args.ctf is not None:
        log('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        ctf_params = torch.tensor(ctf_params)
    else:
        ctf_params = None
    # the first CTF column of the first image is used as the pixel size
    Apix = ctf_params[0, 0] if ctf_params is not None else 1

    V = torch.zeros((D, D, D))
    counts = torch.zeros((D, D, D))

    mask = lattice.get_circular_mask(D // 2)

    if args.ind:
        # NOTE: pickle must only be used on trusted index files
        with open(args.ind, 'rb') as f:
            iterator = pickle.load(f)
    elif args.first:
        args.first = min(args.first, Nimg)
        iterator = range(args.first)
    else:
        iterator = range(Nimg)

    for ii in iterator:
        if ii % 100 == 0: log('image {}'.format(ii))
        r, t = posetracker.get_pose(ii)
        ff = data.get(ii)
        if tilt is not None:
            ff, ff_tilt = ff  # EW
        ff = torch.tensor(ff)
        ff = ff.view(-1)[mask]
        if ctf_params is not None:
            freqs = lattice.freqs2d / ctf_params[ii, 0]
            c = ctf.compute_ctf(freqs, *ctf_params[ii, 1:]).view(-1)[mask]
            # only the sign of the CTF is applied (phase flipping)
            ff *= c.sign()
        if t is not None:
            ff = lattice.translate_ht(ff.view(1, -1), t.view(1, 1, 2),
                                      mask).view(-1)
        ff_coord = lattice.coords[mask] @ r
        add_slice(V, counts, ff_coord, ff, D)

        # tilt series
        if tilt is not None:
            ff_tilt = torch.tensor(ff_tilt)
            ff_tilt = ff_tilt.view(-1)[mask]
            if ctf_params is not None:
                ff_tilt *= c.sign()
            if t is not None:
                ff_tilt = lattice.translate_ht(ff_tilt.view(1, -1),
                                               t.view(1, 1, 2), mask).view(-1)
            ff_coord = lattice.coords[mask] @ tilt @ r
            add_slice(V, counts, ff_coord, ff_tilt, D)

    td = time.time() - t1
    # average over the images actually processed, not the whole dataset,
    # so the figure stays correct with --ind / --first
    n_processed = len(iterator)
    log('Backprojected {} images in {}s ({}s per image)'.format(
        n_processed, td, td / max(n_processed, 1)))
    # avoid dividing by zero in voxels that received no contribution
    counts[counts == 0] = 1
    V /= counts
    V = fft.ihtn_center(V[0:-1, 0:-1, 0:-1].cpu().numpy())
    mrc.write(args.o, V.astype('float32'), Apix=Apix)
Esempio n. 4
0
def main(args):
    """Evaluate a trained ``HetOnlyVAE`` on a particle stack.

    The model is rebuilt from the (optionally config-augmented) arguments,
    its weights are loaded from ``args.weights`` and ``eval_batch`` is run
    over the dataset in order (no shuffling). The per-batch latent means and
    log-variances are stacked and pickled to ``args.o`` as three consecutive
    objects: ``z_mu``, ``z_logvar`` and the list
    ``[loss_accum, gen_loss_accum, kld_accum]``.

    Args:
        args: argument namespace; when ``args.config`` is given it is
            replaced by the result of ``config.load_config``.

    Raises:
        AssertionError: on wrong encode mode for tilt data, wrong image size
            for the convolutional encoder, or a too-large encoder mask.
        NotImplementedError: for tilt series with ``--lazy`` and for CTF
            with a real-space encoder.
        RuntimeError: for an invalid encoder mask radius.
    """
    t1 = dt.now()

    # set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.config is not None:
        args = config.load_config(args.config, args)
    log(args)
    beta = args.beta

    # load the particles
    if args.ind is not None:
        log('Filtering image dataset with {}'.format(args.ind))
        # NOTE: pickle must only be used on trusted index files
        with open(args.ind, 'rb') as f:
            ind = pickle.load(f)
    else:
        ind = None

    if args.tilt is None:
        if args.encode_mode == 'conv':
            args.use_real = True
        if args.lazy:
            data = dataset.LazyMRCData(args.particles,
                                       norm=args.norm,
                                       invert_data=args.invert_data,
                                       ind=ind,
                                       keepreal=args.use_real,
                                       window=args.window,
                                       datadir=args.datadir)
        else:
            data = dataset.MRCData(args.particles,
                                   norm=args.norm,
                                   invert_data=args.invert_data,
                                   ind=ind,
                                   keepreal=args.use_real,
                                   window=args.window,
                                   datadir=args.datadir)
        tilt = None
    else:
        assert args.encode_mode == 'tilt'
        if args.lazy: raise NotImplementedError
        data = dataset.TiltMRCData(args.particles,
                                   args.tilt,
                                   norm=args.norm,
                                   invert_data=args.invert_data,
                                   ind=ind,
                                   window=args.window,
                                   keepreal=args.use_real,
                                   datadir=args.datadir)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32))
    Nimg = data.N
    D = data.D

    if args.encode_mode == 'conv':
        assert D - 1 == 64, "Image size must be 64x64 for convolutional encoder"

    # load poses
    posetracker = PoseTracker.load(args.poses, Nimg, D, None, ind)

    # load ctf
    if args.ctf is not None:
        if args.use_real:
            raise NotImplementedError(
                "Not implemented with real-space encoder. Use phase-flipped images instead"
            )
        log('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        if args.ind is not None: ctf_params = ctf_params[ind]
        ctf_params = torch.tensor(ctf_params)
    else:
        ctf_params = None

    # instantiate model
    lattice = Lattice(D, extent=0.5)
    if args.enc_mask is None:
        args.enc_mask = D // 2
    if args.enc_mask > 0:
        assert args.enc_mask <= D // 2
        enc_mask = lattice.get_circular_mask(args.enc_mask)
        in_dim = enc_mask.sum()
    elif args.enc_mask == -1:
        # -1 disables the mask: the encoder sees the full image
        enc_mask = None
        in_dim = lattice.D**2 if not args.use_real else (lattice.D - 1)**2
    else:
        raise RuntimeError(
            "Invalid argument for encoder mask radius {}".format(
                args.enc_mask))
    model = HetOnlyVAE(lattice,
                       args.qlayers,
                       args.qdim,
                       args.players,
                       args.pdim,
                       in_dim,
                       args.zdim,
                       encode_mode=args.encode_mode,
                       enc_mask=enc_mask,
                       enc_type=args.pe_type,
                       enc_dim=args.pe_dim,
                       domain=args.domain)

    log('Loading weights from {}'.format(args.weights))
    checkpoint = torch.load(args.weights)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()
    z_mu_all = []
    z_logvar_all = []
    gen_loss_accum = 0
    kld_accum = 0
    loss_accum = 0
    batch_it = 0
    data_generator = DataLoader(data,
                                batch_size=args.batch_size,
                                shuffle=False)
    for minibatch in data_generator:
        ind = minibatch[-1].to(device)
        y = minibatch[0].to(device)
        yt = minibatch[1].to(device) if tilt is not None else None
        B = len(ind)
        batch_it += B

        # `ind` may live on the GPU; convert to a host array before using it
        # to index the real-space particle array
        yr = torch.from_numpy(data.particles_real[ind.cpu().numpy()]).to(
            device) if args.use_real else None
        rot, tran = posetracker.get_pose(ind)
        ctf_param = ctf_params[ind] if ctf_params is not None else None

        z_mu, z_logvar, loss, gen_loss, kld = eval_batch(model,
                                                         lattice,
                                                         y,
                                                         yt,
                                                         rot,
                                                         tran,
                                                         beta,
                                                         tilt,
                                                         ctf_params=ctf_param,
                                                         yr=yr)

        z_mu_all.append(z_mu)
        z_logvar_all.append(z_logvar)

        # logging
        gen_loss_accum += gen_loss * B
        kld_accum += kld * B
        loss_accum += loss * B

        if batch_it % args.log_interval == 0:
            log('# [{}/{} images] gen loss={:.4f}, kld={:.4f}, beta={:.4f}, loss={:.4f}'
                .format(batch_it, Nimg, gen_loss, kld, beta, loss))
    log('# =====> Average gen loss = {:.6}, KLD = {:.6f}, total loss = {:.6f}'.
        format(gen_loss_accum / Nimg, kld_accum / Nimg, loss_accum / Nimg))

    z_mu_all = np.vstack(z_mu_all)
    z_logvar_all = np.vstack(z_logvar_all)

    # three consecutive pickles in one file; readers must call pickle.load
    # three times in the same order
    with open(args.o, 'wb') as f:
        pickle.dump(z_mu_all, f)
        pickle.dump(z_logvar_all, f)
        pickle.dump([loss_accum, gen_loss_accum, kld_accum], f)

    log('Finished in {}'.format(dt.now() - t1))
Esempio n. 5
0
def main(args):
    """Train a ``HetOnlyVAE`` on a particle stack with known poses.

    The function, in order:

    1. creates ``args.outdir`` and logs to ``<outdir>/run.log``;
    2. seeds numpy and torch with ``args.seed``;
    3. loads the dataset (``LazyMRCData``, ``MRCData`` or ``TiltMRCData``
       depending on the arguments), the poses and, optionally, the CTF
       parameters;
    4. builds the model and Adam optimizer and optionally resumes from the
       checkpoint ``args.load``;
    5. trains for epochs ``start_epoch .. args.num_epochs - 1``, writing
       ``weights.<epoch>.pkl`` / ``z.<epoch>.pkl`` (and ``pose.<epoch>.pkl``
       when pose SGD is active) every ``args.checkpoint`` epochs;
    6. writes the final ``weights.pkl`` / ``z.pkl`` (and ``pose.pkl``).

    Args:
        args: argument namespace. Several attributes are normalised in place
            (``beta``, ``use_real``, ``enc_mask``).

    Raises:
        AssertionError: on inconsistent arguments (beta schedule without
            ``beta_control``, wrong encode mode for tilt data, image size for
            the convolutional encoder, pose SGD outside the hartley domain,
            too-large encoder mask).
        NotImplementedError: for tilt series with ``--lazy`` or
            ``--relion31``, and for CTF with a real-space encoder.
        RuntimeError: for an invalid encoder mask radius.
    """
    t1 = dt.now()
    if args.outdir is not None and not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    LOG = f'{args.outdir}/run.log'

    def flog(msg):  # HACK: switch to logging module
        return utils.flog(msg, LOG)

    if args.load == 'latest':
        args = get_latest(args)
    flog(' '.join(sys.argv))
    flog(args)

    # set the random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    flog('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # set beta schedule
    try:
        args.beta = float(args.beta)
    except ValueError:
        # a non-numeric beta names a schedule, which needs a control weight
        assert args.beta_control, "Need to set beta control weight for schedule {}".format(
            args.beta)
    beta_schedule = get_beta_schedule(args.beta)

    # load the particles
    if args.ind is not None:
        flog('Filtering image dataset with {}'.format(args.ind))
        # NOTE: pickle must only be used on trusted index files
        with open(args.ind, 'rb') as f:
            ind = pickle.load(f)
    else:
        ind = None
    if args.tilt is None:
        if args.encode_mode == 'conv':
            args.use_real = True
        if args.lazy:
            data = dataset.LazyMRCData(args.particles,
                                       norm=args.norm,
                                       invert_data=args.invert_data,
                                       ind=ind,
                                       keepreal=args.use_real,
                                       window=args.window,
                                       datadir=args.datadir,
                                       relion31=args.relion31)
        else:
            data = dataset.MRCData(args.particles,
                                   norm=args.norm,
                                   invert_data=args.invert_data,
                                   ind=ind,
                                   keepreal=args.use_real,
                                   window=args.window,
                                   datadir=args.datadir,
                                   relion31=args.relion31)
        tilt = None
    else:
        assert args.encode_mode == 'tilt'
        if args.lazy: raise NotImplementedError
        if args.relion31: raise NotImplementedError
        data = dataset.TiltMRCData(args.particles,
                                   args.tilt,
                                   norm=args.norm,
                                   invert_data=args.invert_data,
                                   ind=ind,
                                   window=args.window,
                                   keepreal=args.use_real,
                                   datadir=args.datadir)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32))
    Nimg = data.N
    D = data.D

    if args.encode_mode == 'conv':
        assert D - 1 == 64, "Image size must be 64x64 for convolutional encoder"

    # load poses
    if args.do_pose_sgd:
        assert args.domain == 'hartley', "Need to use --domain hartley if doing pose SGD"
    do_pose_sgd = args.do_pose_sgd
    posetracker = PoseTracker.load(args.poses, Nimg, D,
                                   's2s2' if do_pose_sgd else None, ind)
    pose_optimizer = torch.optim.SparseAdam(
        posetracker.parameters(), lr=args.pose_lr) if do_pose_sgd else None

    # load ctf
    if args.ctf is not None:
        if args.use_real:
            raise NotImplementedError(
                "Not implemented with real-space encoder. Use phase-flipped images instead"
            )
        flog('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        if args.ind is not None: ctf_params = ctf_params[ind]
        ctf_params = torch.tensor(ctf_params)
    else:
        ctf_params = None

    # instantiate model
    lattice = Lattice(D, extent=0.5)
    if args.enc_mask is None:
        args.enc_mask = D // 2
    if args.enc_mask > 0:
        assert args.enc_mask <= D // 2
        enc_mask = lattice.get_circular_mask(args.enc_mask)
        in_dim = enc_mask.sum()
    elif args.enc_mask == -1:
        # -1 disables the mask: the encoder sees the full image
        enc_mask = None
        in_dim = lattice.D**2 if not args.use_real else (lattice.D - 1)**2
    else:
        raise RuntimeError(
            "Invalid argument for encoder mask radius {}".format(
                args.enc_mask))
    model = HetOnlyVAE(lattice,
                       args.qlayers,
                       args.qdim,
                       args.players,
                       args.pdim,
                       in_dim,
                       args.zdim,
                       encode_mode=args.encode_mode,
                       enc_mask=enc_mask,
                       enc_type=args.pe_type,
                       enc_dim=args.pe_dim,
                       domain=args.domain)
    flog(model)
    flog('{} parameters in model'.format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # save configuration
    out_config = '{}/config.pkl'.format(args.outdir)
    save_config(args, data, lattice, model, out_config)

    optim = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             weight_decay=args.wd)

    # restart from checkpoint
    if args.load:
        flog('Loading checkpoint from {}'.format(args.load))
        checkpoint = torch.load(args.load)
        model.load_state_dict(checkpoint['model_state_dict'])
        optim.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        model.train()
    else:
        start_epoch = 0

    # training loop
    data_generator = DataLoader(data, batch_size=args.batch_size, shuffle=True)
    num_epochs = args.num_epochs
    # keep `epoch` defined for the final save even if the loop body never
    # runs (e.g. resuming a run that already reached num_epochs)
    epoch = start_epoch - 1
    for epoch in range(start_epoch, num_epochs):
        t2 = dt.now()
        gen_loss_accum = 0
        loss_accum = 0
        kld_accum = 0
        batch_it = 0
        for minibatch in data_generator:
            ind = minibatch[-1].to(device)
            y = minibatch[0].to(device)
            yt = minibatch[1].to(device) if tilt is not None else None
            B = len(ind)
            batch_it += B
            global_it = Nimg * epoch + batch_it

            beta = beta_schedule(global_it)

            # `ind` may live on the GPU; Tensor.numpy() needs a host tensor
            yr = torch.from_numpy(data.particles_real[ind.cpu().numpy()]).to(
                device) if args.use_real else None
            if do_pose_sgd:
                pose_optimizer.zero_grad()
            rot, tran = posetracker.get_pose(ind)
            ctf_param = ctf_params[ind] if ctf_params is not None else None
            loss, gen_loss, kld = train_batch(model,
                                              lattice,
                                              y,
                                              yt,
                                              rot,
                                              tran,
                                              optim,
                                              beta,
                                              args.beta_control,
                                              tilt,
                                              ctf_params=ctf_param,
                                              yr=yr)
            # poses are only updated once the pretraining epochs are over
            if do_pose_sgd and epoch >= args.pretrain:
                pose_optimizer.step()

            # logging
            gen_loss_accum += gen_loss * B
            kld_accum += kld * B
            loss_accum += loss * B

            if batch_it % args.log_interval == 0:
                log('# [Train Epoch: {}/{}] [{}/{} images] gen loss={:.6f}, kld={:.6f}, beta={:.6f}, loss={:.6f}'
                    .format(epoch + 1, num_epochs, batch_it, Nimg, gen_loss,
                            kld, beta, loss))
        flog(
            '# =====> Epoch: {} Average gen loss = {:.6}, KLD = {:.6f}, total loss = {:.6f}; Finished in {}'
            .format(epoch + 1, gen_loss_accum / Nimg, kld_accum / Nimg,
                    loss_accum / Nimg,
                    dt.now() - t2))

        if args.checkpoint and epoch % args.checkpoint == 0:
            out_weights = '{}/weights.{}.pkl'.format(args.outdir, epoch)
            out_z = '{}/z.{}.pkl'.format(args.outdir, epoch)
            model.eval()
            with torch.no_grad():
                z_mu, z_logvar = eval_z(model, lattice, data, args.batch_size,
                                        device, posetracker.trans, tilt
                                        is not None, ctf_params, args.use_real)
                save_checkpoint(model, optim, epoch, z_mu, z_logvar,
                                out_weights, out_z)
            if args.do_pose_sgd and epoch >= args.pretrain:
                out_pose = '{}/pose.{}.pkl'.format(args.outdir, epoch)
                posetracker.save(out_pose)

    # save model weights, latent encoding, and evaluate the model on 3D lattice
    out_weights = '{}/weights.pkl'.format(args.outdir)
    out_z = '{}/z.pkl'.format(args.outdir)
    model.eval()
    with torch.no_grad():
        z_mu, z_logvar = eval_z(model, lattice, data, args.batch_size, device,
                                posetracker.trans, tilt is not None,
                                ctf_params, args.use_real)
        save_checkpoint(model, optim, epoch, z_mu, z_logvar, out_weights,
                        out_z)

    if args.do_pose_sgd and epoch >= args.pretrain:
        out_pose = '{}/pose.pkl'.format(args.outdir)
        posetracker.save(out_pose)
    td = dt.now() - t1
    # guard the per-epoch average against a run that trained zero epochs
    flog('Finished in {} ({} per epoch)'.format(
        td, td / max(num_epochs - start_epoch, 1)))