def main(args):
    # Select GPU if available, otherwise fall back to CPU
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    args = utils.setup_experiment(args)
    utils.init_logging(args)

    # Loading models
    MODEL_PATH_LOAD = "../lidar_experiments/2d/lidar_unet2d/lidar-unet2d-Nov-08-16:29:49/checkpoints/checkpoint_best.pt"
    train_new_model = True

    # Build data loaders, a model and an optimizer
    if train_new_model:
        model = models.build_model(args).to(device)
    else:
        model = models.build_model(args)
        model.load_state_dict(torch.load(MODEL_PATH_LOAD)['model'][0])
        model.to(device)
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[5, 15, 30, 50, 100, 250], gamma=0.5)
    logging.info(
        f"Built a model consisting of {sum(p.numel() for p in model.parameters()):,} parameters")

    if args.resume_training:
        state_dict = utils.load_checkpoint(args, model, optimizer, scheduler)
        global_step = state_dict['last_step']
        start_epoch = int(state_dict['last_step'] / (403200 / state_dict['args'].batch_size)) + 1
    else:
        global_step = -1
        start_epoch = 0

    ## Load the pts files
    # Loads as a list of numpy arrays
    scan_line_tensor = torch.load(args.data_path + 'scan_line_tensor.pts')
    train_idx_list = torch.load(args.data_path + 'train_idx_list.pts')
    valid_idx_list = torch.load(args.data_path + 'valid_idx_list.pts')
    sc = torch.load(args.data_path + 'sc.pts')

    # Dataloaders
    train_dataset = LidarLstmDataset(scan_line_tensor, train_idx_list, args.seq_len, args.mask_pts_per_seq)
    valid_dataset = LidarLstmDataset(scan_line_tensor, valid_idx_list, args.seq_len, args.mask_pts_per_seq)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)

    # Track moving average of loss values
    train_meters = {name: utils.RunningAverageMeter(0.98) for name in ["train_loss"]}
    valid_meters = {name: utils.AverageMeter() for name in ["valid_loss"]}
    writer = SummaryWriter(log_dir=args.experiment_dir) if not args.no_visual else None

    ##################################################
    # TRAINING
    for epoch in range(start_epoch, args.num_epochs):
        if args.resume_training:
            if epoch % 1 == 0:
                optimizer.param_groups[0]["lr"] /= 2
                print('learning rate reduced by factor of 2')

        train_bar = utils.ProgressBar(train_loader, epoch)
        for meter in train_meters.values():
            meter.reset()

        # epoch_loss_sum = 0
        for batch_id, (clean, mask) in enumerate(train_bar):
            # dataloader returns a [clean, mask] list
            model.train()
            global_step += 1
            inputs = clean.to(device)
            mask_inputs = mask.to(device)

            # only use the masked part of the outputs
            raw_outputs = model(inputs, mask_inputs)
            outputs = (1 - mask_inputs[:, :3, :, :]) * raw_outputs + mask_inputs[:, :3, :, :] * inputs[:, :3, :, :]

            if args.wtd_loss:
                loss = weighted_MSELoss(outputs, inputs[:, :3, :, :], sc) / \
                    (inputs.size(0) * (args.mask_pts_per_seq ** 2))  # Regularization?
            else:
                # normalized by the number of masked points
                loss = F.mse_loss(outputs, inputs[:, :3, :, :], reduction="sum") / \
                    (inputs.size(0) * (args.mask_pts_per_seq ** 2))

            model.zero_grad()
            loss.backward()
            optimizer.step()
            # epoch_loss_sum += loss * inputs.size(0)
            train_meters["train_loss"].update(loss.item())
            train_bar.log(dict(**train_meters, lr=optimizer.param_groups[0]["lr"]), verbose=True)

            if writer is not None and global_step % args.log_interval == 0:
                writer.add_scalar("lr", optimizer.param_groups[0]["lr"], global_step)
                writer.add_scalar("loss/train", loss.item(), global_step)
                gradients = torch.cat([p.grad.view(-1) for p in model.parameters() if p.grad is not None], dim=0)
                writer.add_histogram("gradients", gradients, global_step)
                sys.stdout.flush()

        # epoch_loss = epoch_loss_sum / len(train_loader.dataset)
        if epoch % args.valid_interval == 0:
            model.eval()
            for meter in valid_meters.values():
                meter.reset()
            valid_bar = utils.ProgressBar(valid_loader)

            val_loss = 0
            for sample_id, (clean, mask) in enumerate(valid_bar):
                with torch.no_grad():
                    inputs = clean.to(device)
                    mask_inputs = mask.to(device)

                    # only use the masked part of the outputs
                    raw_output = model(inputs, mask_inputs)
                    output = (1 - mask_inputs[:, :3, :, :]) * raw_output + mask_inputs[:, :3, :, :] * inputs[:, :3, :, :]

                    # TODO: only run loss on masked part of output
                    if args.wtd_loss:
                        val_loss = weighted_MSELoss(output, inputs[:, :3, :, :], sc) / \
                            (inputs.size(0) * (args.mask_pts_per_seq ** 2))
                    else:
                        # normalized by the number of masked points
                        val_loss = F.mse_loss(output, inputs[:, :3, :, :], reduction="sum") / \
                            (inputs.size(0) * (args.mask_pts_per_seq ** 2))
                    valid_meters["valid_loss"].update(val_loss.item())

            if writer is not None:
                writer.add_scalar("loss/valid", valid_meters['valid_loss'].avg, global_step)
                sys.stdout.flush()

            logging.info(
                train_bar.print(dict(**train_meters, **valid_meters, lr=optimizer.param_groups[0]["lr"])))
            utils.save_checkpoint(args, global_step, model, optimizer,
                                  score=valid_meters["valid_loss"].avg, mode="min")

        scheduler.step()

    logging.info(
        f"Done training! Best Loss {utils.save_checkpoint.best_score:.3f} "
        f"obtained after step {utils.save_checkpoint.best_step}.")
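# NOTE (added sketch, not part of the original script): `weighted_MSELoss` is called above
# but not defined in this file. A minimal sketch of what it might compute, assuming `sc`
# holds per-channel scale weights broadcastable over (batch, 3, H, W); the real helper,
# and the exact role of `sc`, may differ.
import torch


def weighted_MSELoss(outputs, targets, sc):
    """Hypothetical: sum of squared errors, re-weighted per channel by `sc`."""
    weights = sc.to(outputs.device).view(1, -1, 1, 1)
    return (weights * (outputs - targets) ** 2).sum()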
if __name__ == '__main__':
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 1, regularization_fns).to(device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    # logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()

    # Get truth and reco data
    xTrain, xTest, uniformityWeightsTrain, uniformityWeightsTest, yTrain, yTest, inputScaler, outputScaler = prepMatchedData(
        args.h5Path, args.nTrainSamp, args.nTestSamp,
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError("Training on CPU is not supported.")
    utils.setup_experiment(args)
    utils.init_logging(args)

    train_loaders, valid_loaders = data.build_dataset(args.dataset, args.data_path, batch_size=args.batch_size)
    model = models.build_model(args).cuda()
    optimizer = optim.build_optimizer(args, model.parameters())
    logging.info(
        f"Built a model consisting of {sum(p.numel() for p in model.parameters() if p.requires_grad):,} parameters")

    meters = {name: utils.RunningAverageMeter(0.98) for name in ["loss", "context", "graph", "target"]}
    acc_names = ["overall"] + [f"task{idx}" for idx in range(len(valid_loaders))]
    acc_meters = {name: utils.AverageMeter() for name in acc_names}
    writer = SummaryWriter(log_dir=args.experiment_dir) if not args.no_visual else None

    global_step = -1
    for epoch in range(args.num_epochs):
        acc_tasks = {f"task{idx}": None for idx in range(len(valid_loaders))}
        for task_id, train_loader in enumerate(train_loaders):
            for repeat in range(args.num_repeats_per_task):
                train_bar = utils.ProgressBar(train_loader, epoch, prefix=f"task {task_id}")
                for meter in meters.values():
                    meter.reset()

                for batch_id, (images, labels) in enumerate(train_bar):
                    model.train()
                    global_step += 1
                    images, labels = images.cuda(), labels.cuda()
                    outputs = model(images, labels, task_id=task_id)
                    if global_step == 0:
                        # skip the optimization step for the very first batch
                        continue

                    loss = outputs["loss"]
                    model.zero_grad()
                    loss.backward()
                    optimizer.step()

                    meters["loss"].update(loss.item())
                    meters["context"].update(outputs["context_loss"].item())
                    meters["target"].update(outputs["target_loss"].item())
                    meters["graph"].update(outputs["graph_loss"].item())
                    train_bar.log(dict(**meters, lr=optimizer.get_lr()))

                    if writer is not None:
                        writer.add_scalar("loss/train", loss.item(), global_step)
                        gradients = torch.cat(
                            [p.grad.view(-1) for p in model.parameters() if p.grad is not None], dim=0)
                        writer.add_histogram("gradients", gradients, global_step)

            # Evaluate on every task after finishing the current one
            model.eval()
            for meter in acc_meters.values():
                meter.reset()

            for idx, valid_loader in enumerate(valid_loaders):
                valid_bar = utils.ProgressBar(valid_loader, epoch, prefix=f"task {task_id}")
                for batch_id, (images, labels) in enumerate(valid_bar):
                    model.eval()
                    with torch.no_grad():
                        images, labels = images.cuda(), labels.cuda()
                        outputs = model.predict(images, labels, task_id=idx)
                        correct = outputs["preds"].eq(labels).sum().item()
                        acc_meters[f"task{idx}"].update(100 * correct, n=len(images))
                acc_meters["overall"].update(acc_meters[f"task{idx}"].avg)

            acc_tasks[f"task{task_id}"] = acc_meters[f"task{task_id}"].avg
            if writer is not None:
                for name, meter in acc_meters.items():
                    writer.add_scalar(f"accuracy/{name}", meter.avg, global_step)
            logging.info(train_bar.print(dict(**meters, **acc_meters, lr=optimizer.get_lr())))
            utils.save_checkpoint(args, global_step, model, optimizer,
                                  score=acc_meters["overall"].avg, mode="max")

    # Backward transfer: average change in accuracy on earlier tasks once training is done
    bwt = sum(acc_meters[task].avg - acc for task, acc in acc_tasks.items()) / (len(valid_loaders) - 1)
    logging.info(
        f"Done training! Final accuracy {acc_meters['overall'].avg:.4f}, backward transfer {bwt:.4f}.")
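# NOTE (added sketch, not part of the original script): the `bwt` value above matches the
# usual continual-learning definition of backward transfer. A hypothetical standalone helper,
# for reference only, computing it from per-task accuracy records:
def backward_transfer(final_accs, accs_after_training):
    """BWT = mean over earlier tasks of (final accuracy) - (accuracy right after learning that task).

    Both arguments are dicts keyed by task name. The last task contributes zero,
    matching the division by (num_tasks - 1) used in the script above.
    """
    num_tasks = len(accs_after_training)
    diffs = [final_accs[t] - accs_after_training[t] for t in accs_after_training]
    return sum(diffs) / (num_tasks - 1)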
def main(args):
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    utils.setup_experiment(args)
    utils.init_logging(args)

    # Build data loaders, a model and an optimizer
    model = models.build_model(args).to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 60, 70, 80, 90, 100], gamma=0.5)
    logging.info(f"Built a model consisting of {sum(p.numel() for p in model.parameters()):,} parameters")

    if args.resume_training:
        state_dict = utils.load_checkpoint(args, model, optimizer, scheduler)
        global_step = state_dict['last_step']
        start_epoch = int(state_dict['last_step'] / (403200 / state_dict['args'].batch_size)) + 1
    else:
        global_step = -1
        start_epoch = 0

    train_loader, valid_loader, _ = data.build_dataset(args.dataset, args.data_path, batch_size=args.batch_size)

    # Track moving average of loss values
    train_meters = {name: utils.RunningAverageMeter(0.98) for name in ["train_loss", "train_psnr", "train_ssim"]}
    valid_meters = {name: utils.AverageMeter() for name in ["valid_psnr", "valid_ssim"]}
    writer = SummaryWriter(log_dir=args.experiment_dir) if not args.no_visual else None

    for epoch in range(start_epoch, args.num_epochs):
        if args.resume_training:
            if epoch % 10 == 0:
                optimizer.param_groups[0]["lr"] /= 2
                print('learning rate reduced by factor of 2')

        train_bar = utils.ProgressBar(train_loader, epoch)
        for meter in train_meters.values():
            meter.reset()

        for batch_id, inputs in enumerate(train_bar):
            model.train()
            global_step += 1
            inputs = inputs.to(device)
            noise = utils.get_noise(inputs, mode=args.noise_mode,
                                    min_noise=args.min_noise / 255., max_noise=args.max_noise / 255.,
                                    noise_std=args.noise_std / 255.)
            noisy_inputs = noise + inputs
            outputs = model(noisy_inputs)
            loss = F.mse_loss(outputs, inputs, reduction="sum") / (inputs.size(0) * 2)

            model.zero_grad()
            loss.backward()
            optimizer.step()

            train_psnr = utils.psnr(outputs, inputs)
            train_ssim = utils.ssim(outputs, inputs)
            train_meters["train_loss"].update(loss.item())
            train_meters["train_psnr"].update(train_psnr.item())
            train_meters["train_ssim"].update(train_ssim.item())
            train_bar.log(dict(**train_meters, lr=optimizer.param_groups[0]["lr"]), verbose=True)

            if writer is not None and global_step % args.log_interval == 0:
                writer.add_scalar("lr", optimizer.param_groups[0]["lr"], global_step)
                writer.add_scalar("loss/train", loss.item(), global_step)
                writer.add_scalar("psnr/train", train_psnr.item(), global_step)
                writer.add_scalar("ssim/train", train_ssim.item(), global_step)
                gradients = torch.cat([p.grad.view(-1) for p in model.parameters() if p.grad is not None], dim=0)
                writer.add_histogram("gradients", gradients, global_step)
                sys.stdout.flush()

        if epoch % args.valid_interval == 0:
            model.eval()
            for meter in valid_meters.values():
                meter.reset()
            valid_bar = utils.ProgressBar(valid_loader)

            for sample_id, sample in enumerate(valid_bar):
                with torch.no_grad():
                    sample = sample.to(device)
                    noise = utils.get_noise(sample, mode='S',
                                            noise_std=(args.min_noise + args.max_noise) / (2 * 255.))
                    noisy_inputs = noise + sample
                    output = model(noisy_inputs)
                    valid_psnr = utils.psnr(output, sample)
                    valid_meters["valid_psnr"].update(valid_psnr.item())
                    valid_ssim = utils.ssim(output, sample)
                    valid_meters["valid_ssim"].update(valid_ssim.item())

                    if writer is not None and sample_id < 10:
                        image = torch.cat([sample, noisy_inputs, output], dim=0)
                        image = torchvision.utils.make_grid(image.clamp(0, 1), nrow=3, normalize=False)
                        writer.add_image(f"valid_samples/{sample_id}", image, global_step)

            if writer is not None:
                writer.add_scalar("psnr/valid", valid_meters['valid_psnr'].avg, global_step)
                writer.add_scalar("ssim/valid", valid_meters['valid_ssim'].avg, global_step)
                sys.stdout.flush()

            logging.info(train_bar.print(dict(**train_meters, **valid_meters, lr=optimizer.param_groups[0]["lr"])))
            utils.save_checkpoint(args, global_step, model, optimizer,
                                  score=valid_meters["valid_psnr"].avg, mode="max")

        scheduler.step()

    logging.info(
        f"Done training! Best PSNR {utils.save_checkpoint.best_score:.3f} "
        f"obtained after step {utils.save_checkpoint.best_step}.")
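# NOTE (added sketch, not part of the original script): `utils.get_noise` is called above but
# not defined in this section. A minimal sketch under the assumption that mode 'S' means a
# fixed Gaussian sigma and 'B' means blind denoising with a per-image sigma drawn uniformly
# from [min_noise, max_noise]; the actual mode names and behaviour may differ.
import torch


def get_noise(x, mode='S', min_noise=0.0, max_noise=0.2, noise_std=0.1):
    """Hypothetical Gaussian noise generator matching the call sites above."""
    if mode == 'S':
        sigma = noise_std
    elif mode == 'B':
        sigma = torch.empty(x.size(0), 1, 1, 1, device=x.device).uniform_(min_noise, max_noise)
    else:
        raise ValueError(f"Unknown noise mode {mode}")
    return torch.randn_like(x) * sigma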
else:
    kernel_optimizer = optim.Adam(list(kernel_net.parameters()) + [log_bandwidth],
                                  lr=args.lr, betas=(.5, .9),
                                  weight_decay=args.critic_weight_decay)

if args.kernel == "neural":
    if args.k_dim == 1:
        encoder_fn = lambda x: kernel_net(x)[:, None]
    else:
        encoder_fn = kernel_net
else:
    encoder_fn = lambda x: x

time_meter = utils.RunningAverageMeter(0.98)
loss_meter = utils.RunningAverageMeter(0.98)
ebm_meter = utils.RunningAverageMeter(0.98)


def sample_data():
    if args.fixed_dataset:
        # shuffle indices into the fixed dataset (use a numpy array so torch.from_numpy works)
        inds = np.arange(args.batch_size)
        np.random.shuffle(inds)
        inds = torch.from_numpy(inds)
        return fixed_data[inds]
    else:
        return trueICA.sample(args.batch_size)


best_loss = float('inf')
modelICA.train()
end = time.time()
def _main(rank, world_size, args, savepath, logger):
    if rank == 0:
        logger.info(args)
        logger.info(f"Saving to {savepath}")
        tb_writer = SummaryWriter(os.path.join(savepath, "tb_logdir"))

    device = torch.device(f'cuda:{rank:d}' if torch.cuda.is_available() else 'cpu')

    if rank == 0:
        if device.type == 'cuda':
            logger.info('Found {} CUDA devices.'.format(torch.cuda.device_count()))
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                logger.info('{} \t Memory: {:.2f}GB'.format(props.name, props.total_memory / (1024**3)))
        else:
            logger.info('WARNING: Using device {}'.format(device))

    t0, t1 = map(lambda x: cast(x, device), get_t0_t1(args.data))

    train_set = load_data(args.data, split="train")
    val_set = load_data(args.data, split="val")
    test_set = load_data(args.data, split="test")

    train_epoch_iter = EpochBatchIterator(
        dataset=train_set,
        collate_fn=datasets.spatiotemporal_events_collate_fn,
        batch_sampler=train_set.batch_by_size(args.max_events),
        seed=args.seed + rank,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=args.test_bsz,
        shuffle=False,
        collate_fn=datasets.spatiotemporal_events_collate_fn,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.test_bsz,
        shuffle=False,
        collate_fn=datasets.spatiotemporal_events_collate_fn,
    )

    if rank == 0:
        logger.info(
            f"{len(train_set)} training examples, {len(val_set)} val examples, {len(test_set)} test examples")

    x_dim = get_dim(args.data)

    if args.model == "jumpcnf" and args.tpp == "neural":
        model = JumpCNFSpatiotemporalModel(
            dim=x_dim,
            hidden_dims=list(map(int, args.hdims.split("-"))),
            tpp_hidden_dims=list(map(int, args.tpp_hdims.split("-"))),
            actfn=args.actfn,
            tpp_cond=args.tpp_cond,
            tpp_style=args.tpp_style,
            tpp_actfn=args.tpp_actfn,
            share_hidden=args.share_hidden,
            solve_reverse=args.solve_reverse,
            tol=args.tol,
            otreg_strength=args.otreg_strength,
            tpp_otreg_strength=args.tpp_otreg_strength,
            layer_type=args.layer_type,
        ).to(device)
    elif args.model == "attncnf" and args.tpp == "neural":
        model = SelfAttentiveCNFSpatiotemporalModel(
            dim=x_dim,
            hidden_dims=list(map(int, args.hdims.split("-"))),
            tpp_hidden_dims=list(map(int, args.tpp_hdims.split("-"))),
            actfn=args.actfn,
            tpp_cond=args.tpp_cond,
            tpp_style=args.tpp_style,
            tpp_actfn=args.tpp_actfn,
            share_hidden=args.share_hidden,
            solve_reverse=args.solve_reverse,
            l2_attn=args.l2_attn,
            tol=args.tol,
            otreg_strength=args.otreg_strength,
            tpp_otreg_strength=args.tpp_otreg_strength,
            layer_type=args.layer_type,
            lowvar_trace=not args.naive_hutch,
        ).to(device)
    elif args.model == "cond_gmm" and args.tpp == "neural":
        model = JumpGMMSpatiotemporalModel(
            dim=x_dim,
            hidden_dims=list(map(int, args.hdims.split("-"))),
            tpp_hidden_dims=list(map(int, args.tpp_hdims.split("-"))),
            actfn=args.actfn,
            tpp_cond=args.tpp_cond,
            tpp_style=args.tpp_style,
            tpp_actfn=args.tpp_actfn,
            share_hidden=args.share_hidden,
            tol=args.tol,
            tpp_otreg_strength=args.tpp_otreg_strength,
        ).to(device)
    else:
        # Mix and match between spatial and temporal models.
        if args.tpp == "poisson":
            tpp_model = HomogeneousPoissonPointProcess()
        elif args.tpp == "hawkes":
            tpp_model = HawkesPointProcess()
        elif args.tpp == "correcting":
            tpp_model = SelfCorrectingPointProcess()
        elif args.tpp == "neural":
            tpp_hidden_dims = list(map(int, args.tpp_hdims.split("-")))
            tpp_model = NeuralPointProcess(
                cond_dim=x_dim,
                hidden_dims=tpp_hidden_dims,
                cond=args.tpp_cond,
                style=args.tpp_style,
                actfn=args.tpp_actfn,
                otreg_strength=args.tpp_otreg_strength,
                tol=args.tol)
        else:
            raise ValueError(f"Invalid tpp model {args.tpp}")

        if args.model == "gmm":
            model = CombinedSpatiotemporalModel(GaussianMixtureSpatialModel(), tpp_model).to(device)
        elif args.model == "cnf":
            model = CombinedSpatiotemporalModel(
                IndependentCNF(dim=x_dim,
                               hidden_dims=list(map(int, args.hdims.split("-"))),
                               layer_type=args.layer_type,
                               actfn=args.actfn,
                               tol=args.tol,
                               otreg_strength=args.otreg_strength,
                               squash_time=True),
                tpp_model).to(device)
        elif args.model == "tvcnf":
            model = CombinedSpatiotemporalModel(
                IndependentCNF(dim=x_dim,
                               hidden_dims=list(map(int, args.hdims.split("-"))),
                               layer_type=args.layer_type,
                               actfn=args.actfn,
                               tol=args.tol,
                               otreg_strength=args.otreg_strength),
                tpp_model).to(device)
        elif args.model == "jumpcnf":
            model = CombinedSpatiotemporalModel(
                JumpCNF(dim=x_dim,
                        hidden_dims=list(map(int, args.hdims.split("-"))),
                        layer_type=args.layer_type,
                        actfn=args.actfn,
                        tol=args.tol,
                        otreg_strength=args.otreg_strength),
                tpp_model).to(device)
        elif args.model == "attncnf":
            model = CombinedSpatiotemporalModel(
                SelfAttentiveCNF(dim=x_dim,
                                 hidden_dims=list(map(int, args.hdims.split("-"))),
                                 layer_type=args.layer_type,
                                 actfn=args.actfn,
                                 l2_attn=args.l2_attn,
                                 tol=args.tol,
                                 otreg_strength=args.otreg_strength),
                tpp_model).to(device)
        else:
            raise ValueError(f"Invalid model {args.model}")

    # Keep the self-attention parameters in their own optimizer group.
    params = []
    attn_params = []
    for name, p in model.named_parameters():
        if "self_attns" in name:
            attn_params.append(p)
        else:
            params.append(p)

    optimizer = torch.optim.AdamW(
        [{"params": params}, {"params": attn_params}],
        lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.98))

    if rank == 0:
        ema = utils.ExponentialMovingAverage(model)

    model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    if rank == 0:
        logger.info(model)

    begin_itr = 0
    checkpt_path = os.path.join(savepath, "model.pth")
    if os.path.exists(checkpt_path):
        # Restart from checkpoint if run is a restart.
        if rank == 0:
            logger.info(f"Resuming checkpoint from {checkpt_path}")
        checkpt = torch.load(checkpt_path, "cpu")
        model.module.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_state_dict"])
        begin_itr = checkpt["itr"] + 1
    elif args.resume:
        # Check the resume flag if run is new.
        if rank == 0:
            logger.info(f"Resuming model from {args.resume}")
        checkpt = torch.load(args.resume, "cpu")
        model.module.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_state_dict"])
        begin_itr = checkpt["itr"] + 1

    space_loglik_meter = utils.RunningAverageMeter(0.98)
    time_loglik_meter = utils.RunningAverageMeter(0.98)
    gradnorm_meter = utils.RunningAverageMeter(0.98)

    model.train()
    start_time = time.time()
    iteration_counter = itertools.count(begin_itr)
    begin_epoch = begin_itr // len(train_epoch_iter)

    for epoch in range(begin_epoch, math.ceil(args.num_iterations / len(train_epoch_iter))):
        batch_iter = train_epoch_iter.next_epoch_itr(shuffle=True)
        for batch in batch_iter:
            itr = next(iteration_counter)

            optimizer.zero_grad()

            event_times, spatial_locations, input_mask = map(lambda x: cast(x, device), batch)
            N, T = input_mask.shape
            num_events = input_mask.sum()
            if num_events == 0:
                raise RuntimeError("Got batch with no observations.")

            space_loglik, time_loglik = model(event_times, spatial_locations, input_mask, t0, t1)
            space_loglik = space_loglik.sum() / num_events
            time_loglik = time_loglik.sum() / num_events
            loglik = time_loglik + space_loglik

            space_loglik_meter.update(space_loglik.item())
            time_loglik_meter.update(time_loglik.item())

            loss = loglik.mul(-1.0).mean()
            loss.backward()

            # Set learning rate
            total_itrs = math.ceil(args.num_iterations / len(train_epoch_iter)) * len(train_epoch_iter)
            lr = learning_rate_schedule(itr, args.warmup_itrs, args.lr, total_itrs)
            set_learning_rate(optimizer, lr)

            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradclip).item()
            gradnorm_meter.update(grad_norm)

            optimizer.step()

            if rank == 0:
                if itr > 0.8 * args.num_iterations:
                    ema.apply()
                else:
                    ema.apply(decay=0.0)

            if rank == 0:
                tb_writer.add_scalar("train/lr", lr, itr)
                tb_writer.add_scalar("train/temporal_loss", time_loglik.item(), itr)
                tb_writer.add_scalar("train/spatial_loss", space_loglik.item(), itr)
                tb_writer.add_scalar("train/grad_norm", grad_norm, itr)

            if itr % args.logfreq == 0:
                elapsed_time = time.time() - start_time

                # Average NFE across devices.
                nfe = 0
                for m in model.modules():
                    if isinstance(m, TimeVariableCNF) or isinstance(m, TimeVariableODE):
                        nfe += m.nfe
                nfe = torch.tensor(nfe).to(device)
                dist.all_reduce(nfe, op=dist.ReduceOp.SUM)
                nfe = nfe // world_size

                # Sum memory usage across devices.
                mem = torch.tensor(memory_usage_psutil()).float().to(device)
                dist.all_reduce(mem, op=dist.ReduceOp.SUM)

                if rank == 0:
                    logger.info(
                        f"Iter {itr} | Epoch {epoch} | LR {lr:.5f} | Time {elapsed_time:.1f}"
                        f" | Temporal {time_loglik_meter.val:.4f}({time_loglik_meter.avg:.4f})"
                        f" | Spatial {space_loglik_meter.val:.4f}({space_loglik_meter.avg:.4f})"
                        f" | GradNorm {gradnorm_meter.val:.2f}({gradnorm_meter.avg:.2f})"
                        f" | NFE {nfe.item()}"
                        f" | Mem {mem.item():.2f} MB")
                    tb_writer.add_scalar("train/nfe", nfe, itr)
                    tb_writer.add_scalar("train/time_per_itr", elapsed_time / args.logfreq, itr)

                start_time = time.time()

            if rank == 0 and itr % args.testfreq == 0:
                # ema.swap()
                val_space_loglik, val_time_loglik = validate(model, val_loader, t0, t1, device)
                test_space_loglik, test_time_loglik = validate(model, test_loader, t0, t1, device)
                # ema.swap()
                logger.info(
                    f"[Test] Iter {itr} | Val Temporal {val_time_loglik:.4f} | Val Spatial {val_space_loglik:.4f}"
                    f" | Test Temporal {test_time_loglik:.4f} | Test Spatial {test_space_loglik:.4f}")

                tb_writer.add_scalar("val/temporal_loss", val_time_loglik, itr)
                tb_writer.add_scalar("val/spatial_loss", val_space_loglik, itr)
                tb_writer.add_scalar("test/temporal_loss", test_time_loglik, itr)
                tb_writer.add_scalar("test/spatial_loss", test_space_loglik, itr)

                torch.save({
                    "itr": itr,
                    "state_dict": model.module.state_dict(),
                    "optim_state_dict": optimizer.state_dict(),
                    "ema_params": ema.ema_params,
                }, checkpt_path)

                start_time = time.time()

    if rank == 0:
        tb_writer.close()
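# NOTE (added sketch, not part of the original script): the schedule helpers called in the
# training loop, `learning_rate_schedule(itr, warmup_itrs, base_lr, total_itrs)` and
# `set_learning_rate(optimizer, lr)`, are not defined in this section. A plausible sketch,
# assuming linear warmup followed by linear decay to zero (the decay shape is an assumption):
def learning_rate_schedule(itr, warmup_itrs, base_lr, total_itrs):
    """Hypothetical: linear warmup for `warmup_itrs` steps, then linear decay to zero."""
    if warmup_itrs > 0 and itr < warmup_itrs:
        return base_lr * (itr + 1) / warmup_itrs
    progress = (itr - warmup_itrs) / max(1, total_itrs - warmup_itrs)
    return base_lr * max(0.0, 1.0 - progress)


def set_learning_rate(optimizer, lr):
    """Apply the same learning rate to every parameter group."""
    for group in optimizer.param_groups:
        group["lr"] = lr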
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data',
                        choices=['swissroll', '8gaussians', 'pinwheel', 'circles',
                                 'moons', '2spirals', 'checkerboard', 'rings'],
                        type=str, default='moons')
    parser.add_argument('--niters', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--test_batch_size', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--critic_weight_decay', type=float, default=0)
    parser.add_argument('--save', type=str, default='/tmp/test_lsd')
    parser.add_argument('--mode', type=str, default="lsd", choices=['lsd', 'sm'])
    parser.add_argument('--viz_freq', type=int, default=100)
    parser.add_argument('--save_freq', type=int, default=10000)
    parser.add_argument('--log_freq', type=int, default=100)
    parser.add_argument('--base_dist', action="store_true")
    parser.add_argument('--c_iters', type=int, default=5)
    parser.add_argument('--l2', type=float, default=10.)
    parser.add_argument('--exact_trace', action="store_true")
    parser.add_argument('--n_steps', type=int, default=10)
    args = parser.parse_args()

    # logger
    utils.makedirs(args.save)
    logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                              filepath=os.path.abspath(__file__))
    logger.info(args)

    # fit a gaussian to the training data
    init_size = 1000
    init_batch = sample_data(args, init_size).requires_grad_()
    mu, std = init_batch.mean(0), init_batch.std(0)
    base_dist = distributions.Normal(mu, std)

    # neural netz
    critic = networks.SmallMLP(2, n_out=2)
    net = networks.SmallMLP(2)
    ebm = EBM(net, base_dist if args.base_dist else None)
    ebm.to(device)
    critic.to(device)

    # for sampling
    init_fn = lambda: base_dist.sample_n(args.test_batch_size)
    cov = utils.cov(init_batch)
    sampler = HMCSampler(ebm, .3, 5, init_fn, device=device, covariance_matrix=cov)

    logger.info(ebm)
    logger.info(critic)

    # optimizers
    optimizer = optim.Adam(ebm.parameters(), lr=args.lr,
                           weight_decay=args.weight_decay, betas=(.0, .999))
    critic_optimizer = optim.Adam(critic.parameters(), lr=args.lr,
                                  betas=(.0, .999), weight_decay=args.critic_weight_decay)

    time_meter = utils.RunningAverageMeter(0.98)
    loss_meter = utils.RunningAverageMeter(0.98)

    ebm.train()
    end = time.time()
    for itr in range(args.niters):
        optimizer.zero_grad()
        critic_optimizer.zero_grad()

        x = sample_data(args, args.batch_size)
        x.requires_grad_()

        if args.mode == "lsd":
            # our method

            # compute dlogp(x)/dx
            logp_u = ebm(x)
            sq = keep_grad(logp_u.sum(), x)

            fx = critic(x)
            # compute (dlogp(x)/dx)^T * f(x)
            sq_fx = (sq * fx).sum(-1)

            # compute/estimate Tr(df/dx)
            if args.exact_trace:
                tr_dfdx = exact_jacobian_trace(fx, x)
            else:
                tr_dfdx = approx_jacobian_trace(fx, x)

            stats = (sq_fx + tr_dfdx)
            loss = stats.mean()  # estimate of S(p, q)
            l2_penalty = (fx * fx).sum(1).mean() * args.l2  # penalty to enforce f \in F

            # adversarial!
            if args.c_iters > 0 and itr % (args.c_iters + 1) != 0:
                (-1. * loss + l2_penalty).backward()
                critic_optimizer.step()
            else:
                loss.backward()
                optimizer.step()

        elif args.mode == "sm":
            # score matching for reference
            fx = ebm(x)
            dfdx = torch.autograd.grad(fx.sum(), x, retain_graph=True, create_graph=True)[0]
            eps = torch.randn_like(dfdx)
            # use hutchinson here as well
            epsH = torch.autograd.grad(dfdx, x, grad_outputs=eps,
                                       create_graph=True, retain_graph=True)[0]

            trH = (epsH * eps).sum(1)
            norm_s = (dfdx * dfdx).sum(1)

            loss = (trH + .5 * norm_s).mean()
            loss.backward()
            optimizer.step()
        else:
            assert False

        loss_meter.update(loss.item())
        time_meter.update(time.time() - end)

        if itr % args.log_freq == 0:
            log_message = (
                'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.4f}({:.4f})'.format(
                    itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg))
            logger.info(log_message)

        if itr % args.save_freq == 0 or itr == args.niters:
            ebm.cpu()
            utils.makedirs(args.save)
            torch.save({
                'args': args,
                'state_dict': ebm.state_dict(),
            }, os.path.join(args.save, 'checkpt.pth'))
            ebm.to(device)

        if itr % args.viz_freq == 0:
            # plot dat
            plt.clf()
            npts = 100
            p_samples = toy_data.inf_train_gen(args.data, batch_size=npts**2)
            q_samples = sampler.sample(args.n_steps)

            ebm.cpu()

            x_enc = critic(x)
            xes = x_enc.detach().cpu().numpy()
            trans = xes.min()
            scale = xes.max() - xes.min()
            xes = (xes - trans) / scale * 8 - 4

            plt.figure(figsize=(4, 4))
            visualize_transform(
                [p_samples, q_samples.detach().cpu().numpy(), xes],
                ["data", "model", "embed"],
                [ebm], ["model"], npts=npts)

            fig_filename = os.path.join(args.save, 'figs', '{:04d}.png'.format(itr))
            utils.makedirs(os.path.dirname(fig_filename))
            plt.savefig(fig_filename)
            plt.close()

            ebm.to(device)

        end = time.time()

    logger.info('Training has finished, can I get a yeet?')
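# NOTE (added sketch, not part of the original script): the trace helpers used in the LSD
# branch, `approx_jacobian_trace(fx, x)` and `exact_jacobian_trace(fx, x)`, are not defined
# in this section. A sketch of how they are commonly implemented; the single-probe Hutchinson
# estimator and the per-dimension exact loop are assumptions about the actual code.
import torch


def approx_jacobian_trace(fx, x):
    """Hutchinson estimator of Tr(df/dx), one Gaussian probe per sample."""
    eps = torch.randn_like(fx)
    jvp = torch.autograd.grad(fx, x, grad_outputs=eps, create_graph=True, retain_graph=True)[0]
    return (jvp * eps).sum(-1)


def exact_jacobian_trace(fx, x):
    """Exact trace: sum the Jacobian diagonal, one backward pass per input dimension."""
    trace = 0.0
    for i in range(x.shape[1]):
        grad_i = torch.autograd.grad(fx[:, i].sum(), x, create_graph=True, retain_graph=True)[0]
        trace = trace + grad_i[:, i]
    return trace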