def main(args):
    t1 = dt.now()
    if args.outdir is not None and not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    LOG = f'{args.outdir}/run.log'

    def flog(msg):  # HACK: switch to logging module
        return utils.flog(msg, LOG)

    if args.load == 'latest':
        args = get_latest(args, flog)
    flog(' '.join(sys.argv))
    flog(args)

    # set the random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    flog('Use cuda {}'.format(use_cuda))
    if not use_cuda:
        log('WARNING: No GPUs detected')

    # set beta schedule
    if args.beta is None:
        args.beta = 1. / args.zdim
    try:
        args.beta = float(args.beta)
    except ValueError:
        assert args.beta_control, "Need to set beta control weight for schedule {}".format(args.beta)
    beta_schedule = get_beta_schedule(args.beta)

    # load index filter
    if args.ind is not None:
        flog('Filtering image dataset with {}'.format(args.ind))
        ind = pickle.load(open(args.ind, 'rb'))
    else:
        ind = None

    # load dataset
    flog(f'Loading dataset from {args.particles}')
    if args.tilt is None:
        tilt = None
        args.use_real = args.encode_mode == 'conv'
        if args.lazy:
            data = dataset.LazyMRCData(args.particles, norm=args.norm, invert_data=args.invert_data,
                                       ind=ind, keepreal=args.use_real, window=args.window,
                                       datadir=args.datadir, window_r=args.window_r, flog=flog)
        elif args.preprocessed:
            flog(f'Using preprocessed inputs. Ignoring any --window/--invert-data options')
            data = dataset.PreprocessedMRCData(args.particles, norm=args.norm, ind=ind, flog=flog)
        else:
            data = dataset.MRCData(args.particles, norm=args.norm, invert_data=args.invert_data,
                                   ind=ind, keepreal=args.use_real, window=args.window,
                                   datadir=args.datadir, max_threads=args.max_threads,
                                   window_r=args.window_r, flog=flog)
    # Tilt series data -- lots of unsupported features
    else:
        assert args.encode_mode == 'tilt'
        if args.lazy:
            raise NotImplementedError
        if args.preprocessed:
            raise NotImplementedError
        data = dataset.TiltMRCData(args.particles, args.tilt, norm=args.norm,
                                   invert_data=args.invert_data, ind=ind, window=args.window,
                                   keepreal=args.use_real, datadir=args.datadir,
                                   window_r=args.window_r, flog=flog)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32), device=device)
    Nimg = data.N
    D = data.D

    if args.encode_mode == 'conv':
        assert D - 1 == 64, "Image size must be 64x64 for convolutional encoder"

    # load poses
    if args.do_pose_sgd:
        assert args.domain == 'hartley', "Need to use --domain hartley if doing pose SGD"
    do_pose_sgd = args.do_pose_sgd
    posetracker = PoseTracker.load(args.poses, Nimg, D, 's2s2' if do_pose_sgd else None,
                                   ind, device=device)
    pose_optimizer = torch.optim.SparseAdam(
        list(posetracker.parameters()), lr=args.pose_lr) if do_pose_sgd else None

    # load ctf
    if args.ctf is not None:
        if args.use_real:
            raise NotImplementedError(
                "Not implemented with real-space encoder. Use phase-flipped images instead")
        flog('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        if args.ind is not None:
            ctf_params = ctf_params[ind]
        assert ctf_params.shape == (Nimg, 8)
        ctf_params = torch.tensor(ctf_params, device=device)
    else:
        ctf_params = None

    # instantiate model
    lattice = Lattice(D, extent=0.5, device=device)
    if args.enc_mask is None:
        args.enc_mask = D // 2
    if args.enc_mask > 0:
        assert args.enc_mask <= D // 2
        enc_mask = lattice.get_circular_mask(args.enc_mask)
        in_dim = enc_mask.sum()
    elif args.enc_mask == -1:
        enc_mask = None
        in_dim = lattice.D**2 if not args.use_real else (lattice.D - 1)**2
    else:
        raise RuntimeError("Invalid argument for encoder mask radius {}".format(args.enc_mask))
    activation = {"relu": nn.ReLU, "leaky_relu": nn.LeakyReLU}[args.activation]
    model = HetOnlyVAE(lattice, args.qlayers, args.qdim, args.players, args.pdim, in_dim, args.zdim,
                       encode_mode=args.encode_mode, enc_mask=enc_mask, enc_type=args.pe_type,
                       enc_dim=args.pe_dim, domain=args.domain, activation=activation,
                       feat_sigma=args.feat_sigma)
    model.to(device)
    flog(model)
    flog('{} parameters in model'.format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    flog('{} parameters in encoder'.format(
        sum(p.numel() for p in model.encoder.parameters() if p.requires_grad)))
    flog('{} parameters in decoder'.format(
        sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)))

    # save configuration
    out_config = '{}/config.pkl'.format(args.outdir)
    save_config(args, data, lattice, model, out_config)

    optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # Mixed precision training
    scaler = None
    if args.amp:
        assert args.batch_size % 8 == 0, "Batch size must be divisible by 8 for AMP training"
        assert (D - 1) % 8 == 0, "Image size must be divisible by 8 for AMP training"
        assert args.pdim % 8 == 0, "Decoder hidden layer dimension must be divisible by 8 for AMP training"
        assert args.qdim % 8 == 0, "Encoder hidden layer dimension must be divisible by 8 for AMP training"
        # Also check zdim, enc_mask dim? Add them as warnings for now.
        if args.zdim % 8 != 0:
            log('Warning: z dimension is not a multiple of 8 -- AMP training speedup is not optimized')
        if in_dim % 8 != 0:
            log('Warning: Masked input image dimension is not a multiple of 8 -- AMP training speedup is not optimized')
        try:  # Mixed precision with apex.amp
            model, optim = amp.initialize(model, optim, opt_level='O1')
        except:  # Mixed precision with pytorch (v1.6+)
            scaler = torch.cuda.amp.GradScaler()

    # restart from checkpoint
    if args.load:
        flog('Loading checkpoint from {}'.format(args.load))
        checkpoint = torch.load(args.load)
        model.load_state_dict(checkpoint['model_state_dict'])
        optim.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        model.train()
    else:
        start_epoch = 0

    # parallelize
    if args.multigpu and torch.cuda.device_count() > 1:
        log(f'Using {torch.cuda.device_count()} GPUs!')
        args.batch_size *= torch.cuda.device_count()
        log(f'Increasing batch size to {args.batch_size}')
        model = nn.DataParallel(model)
    elif args.multigpu:
        log(f'WARNING: --multigpu selected, but {torch.cuda.device_count()} GPUs detected')

    # training loop
    data_generator = DataLoader(data, batch_size=args.batch_size, shuffle=True)
    num_epochs = args.num_epochs
    for epoch in range(start_epoch, num_epochs):
        t2 = dt.now()
        gen_loss_accum = 0
        loss_accum = 0
        kld_accum = 0
        eq_loss_accum = 0
        batch_it = 0
        for minibatch in data_generator:
            ind = minibatch[-1].to(device)
            y = minibatch[0].to(device)
            yt = minibatch[1].to(device) if tilt is not None else None
            B = len(ind)
            batch_it += B
            global_it = Nimg * epoch + batch_it
            beta = beta_schedule(global_it)

            yr = torch.from_numpy(data.particles_real[ind.numpy()]).to(device) if args.use_real else None
            if do_pose_sgd:
                pose_optimizer.zero_grad()
            rot, tran = posetracker.get_pose(ind)
            ctf_param = ctf_params[ind] if ctf_params is not None else None
            loss, gen_loss, kld = train_batch(model, lattice, y, yt, rot, tran, optim, beta,
                                              args.beta_control, tilt, ctf_params=ctf_param,
                                              yr=yr, use_amp=args.amp, scaler=scaler)
            if do_pose_sgd and epoch >= args.pretrain:
                pose_optimizer.step()

            # logging
            gen_loss_accum += gen_loss * B
            kld_accum += kld * B
            loss_accum += loss * B
            if batch_it % args.log_interval == 0:
                log('# [Train Epoch: {}/{}] [{}/{} images] gen loss={:.6f}, kld={:.6f}, beta={:.6f}, loss={:.6f}'
                    .format(epoch + 1, num_epochs, batch_it, Nimg, gen_loss, kld, beta, loss))
        flog('# =====> Epoch: {} Average gen loss = {:.6}, KLD = {:.6f}, total loss = {:.6f}; Finished in {}'
             .format(epoch + 1, gen_loss_accum / Nimg, kld_accum / Nimg, loss_accum / Nimg, dt.now() - t2))

        if args.checkpoint and epoch % args.checkpoint == 0:
            out_weights = '{}/weights.{}.pkl'.format(args.outdir, epoch)
            out_z = '{}/z.{}.pkl'.format(args.outdir, epoch)
            model.eval()
            with torch.no_grad():
                z_mu, z_logvar = eval_z(model, lattice, data, args.batch_size, device,
                                        posetracker.trans, tilt is not None, ctf_params, args.use_real)
                save_checkpoint(model, optim, epoch, z_mu, z_logvar, out_weights, out_z)
            if args.do_pose_sgd and epoch >= args.pretrain:
                out_pose = '{}/pose.{}.pkl'.format(args.outdir, epoch)
                posetracker.save(out_pose)

    # save model weights, latent encoding, and evaluate the model on 3D lattice
    out_weights = '{}/weights.pkl'.format(args.outdir)
    out_z = '{}/z.pkl'.format(args.outdir)
    model.eval()
    with torch.no_grad():
        z_mu, z_logvar = eval_z(model, lattice, data, args.batch_size, device,
                                posetracker.trans, tilt is not None, ctf_params, args.use_real)
        save_checkpoint(model, optim, epoch, z_mu, z_logvar, out_weights, out_z)
    if args.do_pose_sgd and epoch >= args.pretrain:
        out_pose = '{}/pose.pkl'.format(args.outdir)
        posetracker.save(out_pose)

    td = dt.now() - t1
    flog('Finished in {} ({} per epoch)'.format(td, td / (num_epochs - start_epoch)))
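# --- Illustrative sketch (not part of the original script) -------------------
# train_batch() above receives use_amp and an optional torch.cuda.amp.GradScaler
# when apex.amp is unavailable. The standard native-AMP update pattern, shown
# here with a generic loss_fn placeholder (the real train_batch also runs the
# VAE forward pass, CTF correction, and tilt handling), is:
def amp_update_sketch(model, optim, loss_fn, use_amp=False, scaler=None):
    optim.zero_grad()
    if use_amp and scaler is not None:
        with torch.cuda.amp.autocast():
            loss = loss_fn(model)          # forward pass in mixed precision
        scaler.scale(loss).backward()      # scale loss to avoid fp16 gradient underflow
        scaler.step(optim)                 # unscales gradients, then applies the update
        scaler.update()                    # adjust the loss scale for the next iteration
    else:
        loss = loss_fn(model)
        loss.backward()
        optim.step()
    return loss.item()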
def main(args):
    t1 = dt.now()
    if args.outdir is not None and not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    LOG = f'{args.outdir}/run.log'

    def flog(msg):  # HACK: switch to logging module
        return utils.flog(msg, LOG)

    if args.load == 'latest':
        args = get_latest(args, flog)
    flog(' '.join(sys.argv))
    flog(args)

    # set the random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ## set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    flog('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        flog('WARNING: No GPUs detected')

    # load the particles
    if args.ind is not None:
        flog('Filtering image dataset with {}'.format(args.ind))
        ind = pickle.load(open(args.ind, 'rb'))
    else:
        ind = None
    if args.lazy:
        data = dataset.LazyMRCData(args.particles, norm=args.norm, invert_data=args.invert_data,
                                   ind=ind, window=args.window, datadir=args.datadir,
                                   relion31=args.relion31)
    else:
        data = dataset.MRCData(args.particles, norm=args.norm, invert_data=args.invert_data,
                               ind=ind, window=args.window, datadir=args.datadir,
                               relion31=args.relion31)
    D = data.D
    Nimg = data.N

    # instantiate model
    if args.pe_type != 'none':
        assert args.l_extent == 0.5
    lattice = Lattice(D, extent=args.l_extent)
    activation = {"relu": nn.ReLU, "leaky_relu": nn.LeakyReLU}[args.activation]
    model = models.get_decoder(3, D, args.layers, args.dim, args.domain, args.pe_type,
                               enc_dim=args.pe_dim, activation=activation)
    flog(model)
    flog('{} parameters in model'.format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # optimizer
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # load weights
    if args.load:
        flog('Loading model weights from {}'.format(args.load))
        checkpoint = torch.load(args.load)
        model.load_state_dict(checkpoint['model_state_dict'])
        optim.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        assert start_epoch < args.num_epochs
    else:
        start_epoch = 0

    # load poses
    if args.do_pose_sgd:
        assert args.domain == 'hartley', "Need to use --domain hartley if doing pose SGD"
        posetracker = PoseTracker.load(args.poses, Nimg, D, args.emb_type, ind)
        pose_optimizer = torch.optim.SparseAdam(posetracker.parameters(), lr=args.pose_lr)
    else:
        posetracker = PoseTracker.load(args.poses, Nimg, D, None, ind)

    # load CTF
    if args.ctf is not None:
        flog('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        if args.ind is not None:
            ctf_params = ctf_params[ind]
        ctf_params = torch.tensor(ctf_params)
    else:
        ctf_params = None
    Apix = ctf_params[0, 0] if ctf_params is not None else 1

    # save configuration
    out_config = f'{args.outdir}/config.pkl'
    save_config(args, data, lattice, model, out_config)

    # Mixed precision training with AMP
    if args.amp:
        assert args.batch_size % 8 == 0
        assert (D - 1) % 8 == 0
        assert args.dim % 8 == 0
        # Also check zdim, enc_mask dim?
        model, optim = amp.initialize(model, optim, opt_level='O1')

    # parallelize
    if args.multigpu and torch.cuda.device_count() > 1:
        flog(f'Using {torch.cuda.device_count()} GPUs!')
        args.batch_size *= torch.cuda.device_count()
        flog(f'Increasing batch size to {args.batch_size}')
        model = nn.DataParallel(model)
    elif args.multigpu:
        flog(f'WARNING: --multigpu selected, but {torch.cuda.device_count()} GPUs detected')

    # train
    data_generator = DataLoader(data, batch_size=args.batch_size, shuffle=True)
    for epoch in range(start_epoch, args.num_epochs):
        t2 = dt.now()
        loss_accum = 0
        batch_it = 0
        for batch, ind in data_generator:
            batch_it += len(ind)
            y = batch.to(device)
            ind = ind.to(device)
            if args.do_pose_sgd:
                pose_optimizer.zero_grad()
            r, t = posetracker.get_pose(ind)
            c = ctf_params[ind] if ctf_params is not None else None
            loss_item = train(model, lattice, optim, batch.to(device), r, t, c, use_amp=args.amp)
            if args.do_pose_sgd and epoch >= args.pretrain:
                pose_optimizer.step()
            loss_accum += loss_item * len(ind)
            if batch_it % args.log_interval == 0:
                flog('# [Train Epoch: {}/{}] [{}/{} images] loss={:.6f}'.format(
                    epoch + 1, args.num_epochs, batch_it, Nimg, loss_item))
        flog('# =====> Epoch: {} Average loss = {:.6}; Finished in {}'.format(
            epoch + 1, loss_accum / Nimg, dt.now() - t2))
        if args.checkpoint and epoch % args.checkpoint == 0:
            out_mrc = '{}/reconstruct.{}.mrc'.format(args.outdir, epoch)
            out_weights = '{}/weights.{}.pkl'.format(args.outdir, epoch)
            save_checkpoint(model, lattice, optim, epoch, data.norm, Apix, out_mrc, out_weights)
            if args.do_pose_sgd and epoch >= args.pretrain:
                out_pose = '{}/pose.{}.pkl'.format(args.outdir, epoch)
                posetracker.save(out_pose)

    ## save model weights and evaluate the model on 3D lattice
    out_mrc = '{}/reconstruct.mrc'.format(args.outdir)
    out_weights = '{}/weights.pkl'.format(args.outdir)
    save_checkpoint(model, lattice, optim, epoch, data.norm, Apix, out_mrc, out_weights)
    if args.do_pose_sgd and epoch >= args.pretrain:
        out_pose = '{}/pose.pkl'.format(args.outdir)
        posetracker.save(out_pose)

    td = dt.now() - t1
    flog('Finished in {} ({} per epoch)'.format(td, td / (args.num_epochs - start_epoch)))
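# --- Illustrative sketch (not part of the original script) -------------------
# save_checkpoint() above receives (model, lattice, optim, epoch, data.norm,
# Apix, out_mrc, out_weights). The weights file must at least contain the keys
# read back in the "# load weights" block ('epoch', 'model_state_dict',
# 'optimizer_state_dict'); how the .mrc volume is rendered from the decoder is
# implementation-specific and only indicated by a comment here.
def save_checkpoint_sketch(model, lattice, optim, epoch, norm, Apix, out_mrc, out_weights):
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optim.state_dict()}, out_weights)
    # Rendering the reconstruction on the 3D lattice and writing it with
    # mrc.write(out_mrc, vol, Apix=Apix) is decoder-specific and omitted.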
def main(args):
    t1 = dt.now()

    # set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.config is not None:
        args = config.load_config(args.config, args)
    log(args)

    beta = args.beta

    # load the particles
    if args.ind is not None:
        log('Filtering image dataset with {}'.format(args.ind))
        ind = pickle.load(open(args.ind, 'rb'))
    else:
        ind = None
    if args.tilt is None:
        if args.encode_mode == 'conv':
            args.use_real = True
        if args.lazy:
            data = dataset.LazyMRCData(args.particles, norm=args.norm, invert_data=args.invert_data,
                                       ind=ind, keepreal=args.use_real, window=args.window,
                                       datadir=args.datadir)
        else:
            data = dataset.MRCData(args.particles, norm=args.norm, invert_data=args.invert_data,
                                   ind=ind, keepreal=args.use_real, window=args.window,
                                   datadir=args.datadir)
        tilt = None
    else:
        assert args.encode_mode == 'tilt'
        if args.lazy:
            raise NotImplementedError
        data = dataset.TiltMRCData(args.particles, args.tilt, norm=args.norm,
                                   invert_data=args.invert_data, ind=ind, window=args.window,
                                   keepreal=args.use_real, datadir=args.datadir)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32))
    Nimg = data.N
    D = data.D

    if args.encode_mode == 'conv':
        assert D - 1 == 64, "Image size must be 64x64 for convolutional encoder"

    # load poses
    posetracker = PoseTracker.load(args.poses, Nimg, D, None, ind)

    # load ctf
    if args.ctf is not None:
        if args.use_real:
            raise NotImplementedError(
                "Not implemented with real-space encoder. Use phase-flipped images instead")
        log('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        if args.ind is not None:
            ctf_params = ctf_params[ind]
        ctf_params = torch.tensor(ctf_params)
    else:
        ctf_params = None

    # instantiate model
    lattice = Lattice(D, extent=0.5)
    if args.enc_mask is None:
        args.enc_mask = D // 2
    if args.enc_mask > 0:
        assert args.enc_mask <= D // 2
        enc_mask = lattice.get_circular_mask(args.enc_mask)
        in_dim = enc_mask.sum()
    elif args.enc_mask == -1:
        enc_mask = None
        in_dim = lattice.D**2 if not args.use_real else (lattice.D - 1)**2
    else:
        raise RuntimeError("Invalid argument for encoder mask radius {}".format(args.enc_mask))
    model = HetOnlyVAE(lattice, args.qlayers, args.qdim, args.players, args.pdim, in_dim, args.zdim,
                       encode_mode=args.encode_mode, enc_mask=enc_mask, enc_type=args.pe_type,
                       enc_dim=args.pe_dim, domain=args.domain)

    log('Loading weights from {}'.format(args.weights))
    checkpoint = torch.load(args.weights)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    z_mu_all = []
    z_logvar_all = []
    gen_loss_accum = 0
    kld_accum = 0
    loss_accum = 0
    batch_it = 0
    data_generator = DataLoader(data, batch_size=args.batch_size, shuffle=False)
    for minibatch in data_generator:
        ind = minibatch[-1].to(device)
        y = minibatch[0].to(device)
        yt = minibatch[1].to(device) if tilt is not None else None
        B = len(ind)
        batch_it += B

        yr = torch.from_numpy(data.particles_real[ind]).to(device) if args.use_real else None
        rot, tran = posetracker.get_pose(ind)
        ctf_param = ctf_params[ind] if ctf_params is not None else None
        z_mu, z_logvar, loss, gen_loss, kld = eval_batch(model, lattice, y, yt, rot, tran, beta,
                                                         tilt, ctf_params=ctf_param, yr=yr)
        z_mu_all.append(z_mu)
        z_logvar_all.append(z_logvar)

        # logging
        gen_loss_accum += gen_loss * B
        kld_accum += kld * B
        loss_accum += loss * B
        if batch_it % args.log_interval == 0:
            log('# [{}/{} images] gen loss={:.4f}, kld={:.4f}, beta={:.4f}, loss={:.4f}'
                .format(batch_it, Nimg, gen_loss, kld, beta, loss))
    log('# =====> Average gen loss = {:.6}, KLD = {:.6f}, total loss = {:.6f}'.format(
        gen_loss_accum / Nimg, kld_accum / Nimg, loss_accum / Nimg))

    z_mu_all = np.vstack(z_mu_all)
    z_logvar_all = np.vstack(z_logvar_all)
    with open(args.o, 'wb') as f:
        pickle.dump(z_mu_all, f)
        pickle.dump(z_logvar_all, f)
        pickle.dump([loss_accum, gen_loss_accum, kld_accum], f)

    log('Finished in {}'.format(dt.now() - t1))
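# --- Illustrative usage (not part of the original script) --------------------
# The output written above is three consecutive pickle records in one file, so
# it has to be read back in the same order it was dumped ('output.pkl' is a
# hypothetical stand-in for args.o):
with open('output.pkl', 'rb') as f:
    z_mu_all = pickle.load(f)         # (Nimg, zdim) array of latent means
    z_logvar_all = pickle.load(f)     # (Nimg, zdim) array of latent log-variances
    loss_accum, gen_loss_accum, kld_accum = pickle.load(f)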
def main(args):
    t1 = dt.now()

    # make output directories
    if not os.path.exists(os.path.dirname(args.o)):
        os.makedirs(os.path.dirname(args.o))
    if not os.path.exists(os.path.dirname(args.out_z)):
        os.makedirs(os.path.dirname(args.out_z))

    # set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if not use_cuda:
        log('WARNING: No GPUs detected')

    log(args)
    cfg = config.overwrite_config(args.config, args)
    log('Loaded configuration:')
    pprint.pprint(cfg)

    zdim = cfg['model_args']['zdim']
    beta = 1. / zdim if args.beta is None else args.beta

    # load the particles
    if args.ind is not None:
        log('Filtering image dataset with {}'.format(args.ind))
        ind = pickle.load(open(args.ind, 'rb'))
    else:
        ind = None

    # TODO: extract dataset arguments from cfg
    if args.tilt is None:
        if args.encode_mode == 'conv':
            args.use_real = True
        if args.lazy:
            data = dataset.LazyMRCData(args.particles, norm=args.norm, invert_data=args.invert_data,
                                       ind=ind, keepreal=args.use_real, window=args.window,
                                       datadir=args.datadir, window_r=args.window_r)
        else:
            data = dataset.MRCData(args.particles, norm=args.norm, invert_data=args.invert_data,
                                   ind=ind, keepreal=args.use_real, window=args.window,
                                   datadir=args.datadir, window_r=args.window_r)
        tilt = None
    else:
        assert args.encode_mode == 'tilt'
        if args.lazy:
            raise NotImplementedError
        data = dataset.TiltMRCData(args.particles, args.tilt, norm=args.norm,
                                   invert_data=args.invert_data, ind=ind, window=args.window,
                                   keepreal=args.use_real, datadir=args.datadir,
                                   window_r=args.window_r)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32))
    Nimg = data.N
    D = data.D

    if args.encode_mode == 'conv':
        assert D - 1 == 64, "Image size must be 64x64 for convolutional encoder"

    # load poses
    posetracker = PoseTracker.load(args.poses, Nimg, D, None, ind, device=device)

    # load ctf
    if args.ctf is not None:
        if args.use_real:
            raise NotImplementedError(
                "Not implemented with real-space encoder. Use phase-flipped images instead")
        log('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        if args.ind is not None:
            ctf_params = ctf_params[ind]
        ctf_params = torch.tensor(ctf_params, device=device)
    else:
        ctf_params = None

    # instantiate model
    model, lattice = HetOnlyVAE.load(cfg, args.weights, device=device)
    model.eval()

    z_mu_all = []
    z_logvar_all = []
    gen_loss_accum = 0
    kld_accum = 0
    loss_accum = 0
    batch_it = 0
    data_generator = DataLoader(data, batch_size=args.batch_size, shuffle=False)
    for minibatch in data_generator:
        ind = minibatch[-1].to(device)
        y = minibatch[0].to(device)
        yt = minibatch[1].to(device) if tilt is not None else None
        B = len(ind)
        batch_it += B

        yr = torch.from_numpy(data.particles_real[ind]).to(device) if args.use_real else None
        rot, tran = posetracker.get_pose(ind)
        ctf_param = ctf_params[ind] if ctf_params is not None else None
        z_mu, z_logvar, loss, gen_loss, kld = eval_batch(model, lattice, y, yt, rot, tran, beta,
                                                         tilt, ctf_params=ctf_param, yr=yr)
        z_mu_all.append(z_mu)
        z_logvar_all.append(z_logvar)

        # logging
        gen_loss_accum += gen_loss * B
        kld_accum += kld * B
        loss_accum += loss * B
        if batch_it % args.log_interval == 0:
            log('# [{}/{} images] gen loss={:.4f}, kld={:.4f}, beta={:.4f}, loss={:.4f}'
                .format(batch_it, Nimg, gen_loss, kld, beta, loss))
    log('# =====> Average gen loss = {:.6}, KLD = {:.6f}, total loss = {:.6f}'.format(
        gen_loss_accum / Nimg, kld_accum / Nimg, loss_accum / Nimg))

    z_mu_all = np.vstack(z_mu_all)
    z_logvar_all = np.vstack(z_logvar_all)

    with open(args.out_z, 'wb') as f:
        pickle.dump(z_mu_all, f)
        pickle.dump(z_logvar_all, f)
    with open(args.o, 'wb') as f:
        pickle.dump({
            'loss': loss_accum / Nimg,
            'recon': gen_loss_accum / Nimg,
            'kld': kld_accum / Nimg
        }, f)

    log('Finished in {}'.format(dt.now() - t1))
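# --- Illustrative usage (not part of the original script) --------------------
# Reading back the two outputs written above; 'z.pkl' and 'losses.pkl' are
# hypothetical stand-ins for args.out_z and args.o:
with open('z.pkl', 'rb') as f:
    z_mu_all = pickle.load(f)
    z_logvar_all = pickle.load(f)
with open('losses.pkl', 'rb') as f:
    summary = pickle.load(f)          # {'loss': ..., 'recon': ..., 'kld': ...}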
def main(args):
    log(args)
    t1 = dt.now()
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # set the random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ## set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # load the particles
    if args.ind is not None:
        log('Filtering image dataset with {}'.format(args.ind))
        ind = pickle.load(open(args.ind, 'rb'))
    else:
        ind = None
    if args.lazy:
        data = dataset.LazyMRCData(args.particles, norm=args.norm, invert_data=args.invert_data,
                                   ind=ind, window=args.window, datadir=args.datadir)
    else:
        data = dataset.MRCData(args.particles, norm=args.norm, invert_data=args.invert_data,
                               ind=ind, window=args.window, datadir=args.datadir)
    D = data.D
    Nimg = data.N

    # instantiate model
    if args.pe_type != 'none':
        assert args.l_extent == 0.5
    lattice = Lattice(D, extent=args.l_extent)
    model = models.get_decoder(3, D, args.layers, args.dim, args.domain, args.pe_type, nn.ReLU)
    log(model)
    log('{} parameters in model'.format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # optimizer
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # load weights
    if args.load:
        log('Loading model weights from {}'.format(args.load))
        checkpoint = torch.load(args.load)
        model.load_state_dict(checkpoint['model_state_dict'])
        optim.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        assert start_epoch < args.num_epochs
    else:
        start_epoch = 0

    # load poses
    if args.do_pose_sgd:
        posetracker = PoseTracker.load(args.poses, Nimg, D, args.emb_type, ind)
        pose_optimizer = torch.optim.SparseAdam(posetracker.parameters(), lr=args.pose_lr)
    else:
        posetracker = PoseTracker.load(args.poses, Nimg, D, None, ind)

    # load CTF
    if args.ctf is not None:
        log('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        if args.ind is not None:
            ctf_params = ctf_params[ind]
        ctf_params = torch.tensor(ctf_params)
    else:
        ctf_params = None
    Apix = ctf_params[0, 0] if ctf_params is not None else 1

    # train
    data_generator = DataLoader(data, batch_size=args.batch_size, shuffle=True)
    for epoch in range(start_epoch, args.num_epochs):
        t2 = dt.now()
        loss_accum = 0
        batch_it = 0
        for batch, ind in data_generator:
            batch_it += len(ind)
            y = batch.to(device)
            ind = ind.to(device)
            if args.do_pose_sgd:
                pose_optimizer.zero_grad()
            r, t = posetracker.get_pose(ind)
            c = ctf_params[ind] if ctf_params is not None else None
            loss_item = train(model, lattice, optim, batch.to(device), r, t, c)
            if args.do_pose_sgd and epoch >= args.pretrain:
                pose_optimizer.step()
            loss_accum += loss_item * len(ind)
            if batch_it % args.log_interval == 0:
                log('# [Train Epoch: {}/{}] [{}/{} images] loss={:.6f}'.format(
                    epoch + 1, args.num_epochs, batch_it, Nimg, loss_item))
        log('# =====> Epoch: {} Average loss = {:.6}; Finished in {}'.format(
            epoch + 1, loss_accum / Nimg, dt.now() - t2))
        if args.checkpoint and epoch % args.checkpoint == 0:
            out_mrc = '{}/reconstruct.{}.mrc'.format(args.outdir, epoch)
            out_weights = '{}/weights.{}.pkl'.format(args.outdir, epoch)
            save_checkpoint(model, lattice, optim, epoch, data.norm, Apix, out_mrc, out_weights)
            if args.do_pose_sgd and epoch >= args.pretrain:
                out_pose = '{}/pose.{}.pkl'.format(args.outdir, epoch)
                posetracker.save(out_pose)

    ## save model weights and evaluate the model on 3D lattice
    out_mrc = '{}/reconstruct.mrc'.format(args.outdir)
    out_weights = '{}/weights.pkl'.format(args.outdir)
    save_checkpoint(model, lattice, optim, epoch, data.norm, Apix, out_mrc, out_weights)
    if args.do_pose_sgd and epoch >= args.pretrain:
        out_pose = '{}/pose.pkl'.format(args.outdir)
        posetracker.save(out_pose)

    td = dt.now() - t1
    log('Finished in {} ({} per epoch)'.format(td, td / (args.num_epochs - start_epoch)))
def main(args):
    assert args.o.endswith('.mrc')

    t1 = time.time()
    log(args)
    if not os.path.exists(os.path.dirname(args.o)):
        os.makedirs(os.path.dirname(args.o))

    ## set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        log('WARNING: No GPUs detected')

    # load the particles
    if args.tilt is None:
        data = dataset.LazyMRCData(args.particles, norm=(0, 1), invert_data=args.invert_data,
                                   datadir=args.datadir)
        tilt = None
    else:
        data = dataset.TiltMRCData(args.particles, args.tilt, norm=(0, 1),
                                   invert_data=args.invert_data, datadir=args.datadir)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32))
    D = data.D
    Nimg = data.N

    lattice = Lattice(D, extent=D // 2)

    posetracker = PoseTracker.load(args.poses, Nimg, D, None, None)

    if args.ctf is not None:
        log('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        ctf_params = torch.tensor(ctf_params)
    else:
        ctf_params = None
    Apix = ctf_params[0, 0] if ctf_params is not None else 1

    V = torch.zeros((D, D, D))
    counts = torch.zeros((D, D, D))

    mask = lattice.get_circular_mask(D // 2)

    if args.ind:
        iterator = pickle.load(open(args.ind, 'rb'))
    elif args.first:
        args.first = min(args.first, Nimg)
        iterator = range(args.first)
    else:
        iterator = range(Nimg)

    for ii in iterator:
        if ii % 100 == 0:
            log('image {}'.format(ii))
        r, t = posetracker.get_pose(ii)
        ff = data.get(ii)
        if tilt is not None:
            ff, ff_tilt = ff  # EW
        ff = torch.tensor(ff)
        ff = ff.view(-1)[mask]
        if ctf_params is not None:
            freqs = lattice.freqs2d / ctf_params[ii, 0]
            c = ctf.compute_ctf(freqs, *ctf_params[ii, 1:]).view(-1)[mask]
            ff *= c.sign()
        if t is not None:
            ff = lattice.translate_ht(ff.view(1, -1), t.view(1, 1, 2), mask).view(-1)
        ff_coord = lattice.coords[mask] @ r
        add_slice(V, counts, ff_coord, ff, D)

        # tilt series
        if args.tilt is not None:
            ff_tilt = torch.tensor(ff_tilt)
            ff_tilt = ff_tilt.view(-1)[mask]
            if ctf_params is not None:
                ff_tilt *= c.sign()
            if t is not None:
                ff_tilt = lattice.translate_ht(ff_tilt.view(1, -1), t.view(1, 1, 2), mask).view(-1)
            ff_coord = lattice.coords[mask] @ tilt @ r
            add_slice(V, counts, ff_coord, ff_tilt, D)

    td = time.time() - t1
    log('Backprojected {} images in {}s ({}s per image)'.format(len(iterator), td, td / Nimg))
    counts[counts == 0] = 1
    V /= counts
    V = fft.ihtn_center(V[0:-1, 0:-1, 0:-1].cpu().numpy())
    mrc.write(args.o, V.astype('float32'), Apix=Apix)
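# --- Illustrative sketch (not part of the original script) -------------------
# add_slice() is called above but not defined in this file. A simplified
# nearest-neighbor version consistent with the call signature is shown below;
# the real implementation may interpolate onto neighboring voxels instead.
# ff_coord holds rotated 3D frequency coordinates spanning roughly
# [-D//2, D//2], which are shifted to non-negative volume indices before the
# slice values are accumulated into V and the hit counts into counts.
def add_slice_sketch(V, counts, ff_coord, ff, D):
    idx = torch.round(ff_coord).long() + D // 2       # nearest voxel per coordinate
    keep = ((idx >= 0) & (idx < D)).all(dim=-1)       # drop out-of-bounds points
    idx, vals = idx[keep], ff[keep]
    for (x, y, z), v in zip(idx.tolist(), vals.tolist()):
        V[z, y, x] += v                               # axis order is an assumption
        counts[z, y, x] += 1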
def main(args):
    t1 = dt.now()
    if args.outdir is not None and not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    LOG = f'{args.outdir}/run.log'

    def flog(msg):  # HACK: switch to logging module
        return utils.flog(msg, LOG)

    if args.load == 'latest':
        args = get_latest(args)
    flog(' '.join(sys.argv))
    flog(args)

    # set the random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    flog('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # set beta schedule
    try:
        args.beta = float(args.beta)
    except ValueError:
        assert args.beta_control, "Need to set beta control weight for schedule {}".format(args.beta)
    beta_schedule = get_beta_schedule(args.beta)

    # load the particles
    if args.ind is not None:
        flog('Filtering image dataset with {}'.format(args.ind))
        ind = pickle.load(open(args.ind, 'rb'))
    else:
        ind = None
    if args.tilt is None:
        if args.encode_mode == 'conv':
            args.use_real = True
        if args.lazy:
            data = dataset.LazyMRCData(args.particles, norm=args.norm, invert_data=args.invert_data,
                                       ind=ind, keepreal=args.use_real, window=args.window,
                                       datadir=args.datadir, relion31=args.relion31)
        else:
            data = dataset.MRCData(args.particles, norm=args.norm, invert_data=args.invert_data,
                                   ind=ind, keepreal=args.use_real, window=args.window,
                                   datadir=args.datadir, relion31=args.relion31)
        tilt = None
    else:
        assert args.encode_mode == 'tilt'
        if args.lazy:
            raise NotImplementedError
        if args.relion31:
            raise NotImplementedError
        data = dataset.TiltMRCData(args.particles, args.tilt, norm=args.norm,
                                   invert_data=args.invert_data, ind=ind, window=args.window,
                                   keepreal=args.use_real, datadir=args.datadir)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32))
    Nimg = data.N
    D = data.D

    if args.encode_mode == 'conv':
        assert D - 1 == 64, "Image size must be 64x64 for convolutional encoder"

    # load poses
    if args.do_pose_sgd:
        assert args.domain == 'hartley', "Need to use --domain hartley if doing pose SGD"
    do_pose_sgd = args.do_pose_sgd
    posetracker = PoseTracker.load(args.poses, Nimg, D, 's2s2' if do_pose_sgd else None, ind)
    pose_optimizer = torch.optim.SparseAdam(
        posetracker.parameters(), lr=args.pose_lr) if do_pose_sgd else None

    # load ctf
    if args.ctf is not None:
        if args.use_real:
            raise NotImplementedError(
                "Not implemented with real-space encoder. Use phase-flipped images instead")
        flog('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        if args.ind is not None:
            ctf_params = ctf_params[ind]
        ctf_params = torch.tensor(ctf_params)
    else:
        ctf_params = None

    # instantiate model
    lattice = Lattice(D, extent=0.5)
    if args.enc_mask is None:
        args.enc_mask = D // 2
    if args.enc_mask > 0:
        assert args.enc_mask <= D // 2
        enc_mask = lattice.get_circular_mask(args.enc_mask)
        in_dim = enc_mask.sum()
    elif args.enc_mask == -1:
        enc_mask = None
        in_dim = lattice.D**2 if not args.use_real else (lattice.D - 1)**2
    else:
        raise RuntimeError("Invalid argument for encoder mask radius {}".format(args.enc_mask))
    model = HetOnlyVAE(lattice, args.qlayers, args.qdim, args.players, args.pdim, in_dim, args.zdim,
                       encode_mode=args.encode_mode, enc_mask=enc_mask, enc_type=args.pe_type,
                       enc_dim=args.pe_dim, domain=args.domain)
    flog(model)
    flog('{} parameters in model'.format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # save configuration
    out_config = '{}/config.pkl'.format(args.outdir)
    save_config(args, data, lattice, model, out_config)

    optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # restart from checkpoint
    if args.load:
        flog('Loading checkpoint from {}'.format(args.load))
        checkpoint = torch.load(args.load)
        model.load_state_dict(checkpoint['model_state_dict'])
        optim.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        model.train()
    else:
        start_epoch = 0

    # training loop
    data_generator = DataLoader(data, batch_size=args.batch_size, shuffle=True)
    num_epochs = args.num_epochs
    for epoch in range(start_epoch, num_epochs):
        t2 = dt.now()
        gen_loss_accum = 0
        loss_accum = 0
        kld_accum = 0
        eq_loss_accum = 0
        batch_it = 0
        for minibatch in data_generator:
            ind = minibatch[-1].to(device)
            y = minibatch[0].to(device)
            yt = minibatch[1].to(device) if tilt is not None else None
            B = len(ind)
            batch_it += B
            global_it = Nimg * epoch + batch_it
            beta = beta_schedule(global_it)

            yr = torch.from_numpy(data.particles_real[ind.numpy()]).to(device) if args.use_real else None
            if do_pose_sgd:
                pose_optimizer.zero_grad()
            rot, tran = posetracker.get_pose(ind)
            ctf_param = ctf_params[ind] if ctf_params is not None else None
            loss, gen_loss, kld = train_batch(model, lattice, y, yt, rot, tran, optim, beta,
                                              args.beta_control, tilt, ctf_params=ctf_param, yr=yr)
            if do_pose_sgd and epoch >= args.pretrain:
                pose_optimizer.step()

            # logging
            gen_loss_accum += gen_loss * B
            kld_accum += kld * B
            loss_accum += loss * B
            if batch_it % args.log_interval == 0:
                log('# [Train Epoch: {}/{}] [{}/{} images] gen loss={:.6f}, kld={:.6f}, beta={:.6f}, loss={:.6f}'
                    .format(epoch + 1, num_epochs, batch_it, Nimg, gen_loss, kld, beta, loss))
        flog('# =====> Epoch: {} Average gen loss = {:.6}, KLD = {:.6f}, total loss = {:.6f}; Finished in {}'
             .format(epoch + 1, gen_loss_accum / Nimg, kld_accum / Nimg, loss_accum / Nimg, dt.now() - t2))

        if args.checkpoint and epoch % args.checkpoint == 0:
            out_weights = '{}/weights.{}.pkl'.format(args.outdir, epoch)
            out_z = '{}/z.{}.pkl'.format(args.outdir, epoch)
            model.eval()
            with torch.no_grad():
                z_mu, z_logvar = eval_z(model, lattice, data, args.batch_size, device,
                                        posetracker.trans, tilt is not None, ctf_params, args.use_real)
                save_checkpoint(model, optim, epoch, z_mu, z_logvar, out_weights, out_z)
            if args.do_pose_sgd and epoch >= args.pretrain:
                out_pose = '{}/pose.{}.pkl'.format(args.outdir, epoch)
                posetracker.save(out_pose)

    # save model weights, latent encoding, and evaluate the model on 3D lattice
    out_weights = '{}/weights.pkl'.format(args.outdir)
    out_z = '{}/z.pkl'.format(args.outdir)
    model.eval()
    with torch.no_grad():
        z_mu, z_logvar = eval_z(model, lattice, data, args.batch_size, device,
                                posetracker.trans, tilt is not None, ctf_params, args.use_real)
        save_checkpoint(model, optim, epoch, z_mu, z_logvar, out_weights, out_z)
    if args.do_pose_sgd and epoch >= args.pretrain:
        out_pose = '{}/pose.pkl'.format(args.outdir)
        posetracker.save(out_pose)

    td = dt.now() - t1
    flog('Finished in {} ({} per epoch)'.format(td, td / (num_epochs - start_epoch)))
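# --- Illustrative sketch (not part of the original script) -------------------
# get_beta_schedule() is used by both VAE training scripts above but is not
# defined in this file. Consistent with the calling code -- a float beta gives a
# constant value, anything else is treated as a named schedule evaluated at the
# global iteration -- a minimal version could look like this. The 'linear' ramp
# and the total_it parameter are hypothetical examples, not the original API.
def get_beta_schedule_sketch(beta, total_it=100000):
    try:
        value = float(beta)
        return lambda it: value                       # constant schedule
    except (TypeError, ValueError):
        pass
    if beta == 'linear':                              # hypothetical named schedule
        return lambda it: min(it / total_it, 1.0)     # ramp from 0 up to 1
    raise RuntimeError('Unrecognized beta schedule: {}'.format(beta))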