import os
import json
import time
import datetime
import warnings
from glob import glob

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as TDDP

# Optional NVIDIA Apex support; fall back to native DDP when unavailable.
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ADDP
    HASAPEX = True
except ImportError as err:
    import_error = err
    ADDP = None
    HASAPEX = False

# Project-local modules are assumed importable from this repo: the RB2 data
# loader (`loader.RB2DataLoader`), `utils` (logging/checkpointing helpers),
# the `UNet3d` and `ImNet` networks, `get_rb2_pde_layer`, `NONLINEARITIES`,
# and the `get_args`, `train`, `eval`, `model_inference`, `export_video`
# routines used below.


def main():
    args = get_args()
    # restore the training-time hyperparameters saved next to the checkpoint
    param_file = os.path.join(os.path.dirname(args.ckpt), "params.json")
    with open(param_file, 'r') as fh:
        args.__dict__.update(json.load(fh))
    print(args)

    # prepare dataset
    dataset = loader.RB2DataLoader(
        data_dir=args.data_folder, data_filename=args.eval_dataset,
        nx=args.eval_xres, nz=args.eval_zres, nt=args.eval_tres,
        n_samp_pts_per_crop=1,
        lres_interp=args.lres_interp, lres_filter=args.lres_filter,
        downsamp_xz=args.eval_downsamp_xz, downsamp_t=args.eval_downsamp_t,
        normalize_output=args.normalize_channels, return_hres=True)

    # extract data
    hres, lres, _, _ = dataset[0]

    # get the PDE layer for the RB2 equations
    if args.normalize_channels:
        mean = dataset.channel_mean
        std = dataset.channel_std
    else:
        mean = std = None
    pde_layer = get_rb2_pde_layer(mean=mean, std=std,
                                  prandtl=args.prandtl, rayleigh=args.rayleigh)
    # pde_layer = get_rb2_pde_layer(mean=mean, std=std)

    # evaluate model to get the high-res spatiotemporal sequence
    res_dict = model_inference(args, lres, pde_layer)

    # save video
    export_video(args, res_dict, hres, lres, dataset)
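
# The DDP entry point below calls `setup`/`cleanup` helpers that are not shown
# in this section. What follows is a minimal sketch of what they typically look
# like, assuming an NCCL backend with an env:// rendezvous; the port offset
# keeps concurrent runs (e.g. with different Apex opt levels) from colliding
# on the same rendezvous port. The repo's actual implementation may differ.
def setup(rank, world_size, offset=0):
    """Illustrative sketch: initialize the default process group."""
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", str(12355 + offset))
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """Illustrative sketch: tear down the default process group."""
    dist.destroy_process_group()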
def main_ddp(rank, world_size, args):
    # offset the rendezvous port by opt level and rank so that concurrent runs
    # with different Apex opt levels do not collide
    offset = int(args.apex_optim_level[1]) * world_size + rank
    setup(rank, world_size, offset=offset)
    args.rank = rank

    if args.use_apex and (not HASAPEX):
        if rank == 0:
            print(import_error)
        warnings.warn(
            "Failed to import Apex. Falling back to PyTorch DistributedDataParallel.",
            ImportWarning)
        args.use_apex = False
    DDP = ADDP if args.use_apex else TDDP

    # n_per_rank = torch.cuda.device_count() // world_size
    # device_ids = list(range(rank * n_per_rank, (rank + 1) * n_per_rank))
    device_ids = [args.rank]
    torch.cuda.set_device(args.rank)
    kwargs = {'num_workers': 1, 'pin_memory': True}
    device = torch.device(device_ids[0])

    # no need to adjust batch size: batch_size = batch_size_per_gpu
    args.batch_size = args.batch_size_per_gpu

    # log and create snapshots
    os.makedirs(args.log_dir, exist_ok=True)
    filenames_to_snapshot = glob("*.py") + glob("*.sh")
    utils.snapshot_files(filenames_to_snapshot, args.log_dir)
    logger = utils.get_logger(log_dir=args.log_dir)
    with open(os.path.join(args.log_dir, "params.json"), 'w') as fh:
        json.dump(args.__dict__, fh, indent=2)
    if args.rank == 0:
        logger.info("%s", repr(args))
    logger.info(f"[Rank] {rank:2d} [Cuda IDs] {device_ids}")

    # tensorboard writer
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, 'tensorboard'))

    # random seed for reproducibility (distinct per rank)
    torch.manual_seed((args.seed + 1) * rank)
    np.random.seed((args.seed + 1) * rank)

    # create dataloaders
    trainset = loader.RB2DataLoader(
        data_dir=args.data_folder, data_filename="rb2d_ra1e6_s42.npz",
        nx=args.nx, nz=args.nz, nt=args.nt,
        n_samp_pts_per_crop=args.n_samp_pts_per_crop,
        downsamp_xz=args.downsamp_xz, downsamp_t=args.downsamp_t,
        normalize_output=args.normalize_channels, return_hres=False,
        lres_filter=args.lres_filter, lres_interp=args.lres_interp)
    evalset = loader.RB2DataLoader(
        data_dir=args.data_folder, data_filename="rb2d_ra1e6_s42.npz",
        nx=args.nx, nz=args.nz, nt=args.nt,
        n_samp_pts_per_crop=args.n_samp_pts_per_crop,
        downsamp_xz=args.downsamp_xz, downsamp_t=args.downsamp_t,
        normalize_output=args.normalize_channels, return_hres=True,
        lres_filter=args.lres_filter, lres_interp=args.lres_interp)

    # each rank draws its share of samples from the pseudo-epoch
    nsamp_per_proc = args.pseudo_epoch_size // args.nprocs
    train_sampler = RandomSampler(trainset, replacement=True, num_samples=nsamp_per_proc)
    eval_sampler = RandomSampler(evalset, replacement=True, num_samples=args.num_log_images)
    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=False,
                              drop_last=True, sampler=train_sampler, **kwargs)
    eval_loader = DataLoader(evalset, batch_size=args.batch_size, shuffle=False,
                             drop_last=False, sampler=eval_sampler, **kwargs)

    # setup model
    unet = UNet3d(in_features=4, out_features=args.lat_dims,
                  igres=trainset.scale_lres, nf=args.unet_nf, mf=args.unet_mf)
    imnet = ImNet(dim=3, in_features=args.lat_dims, out_features=4, nf=args.imnet_nf)

    # default training state; set before the resume block so that resumed
    # values are not clobbered afterwards
    start_ep = 0
    global_step = np.zeros(1, dtype=np.uint32)
    tracked_stats = np.inf

    if args.resume:
        # map checkpoint tensors saved from rank-0 devices onto this rank's devices
        rank0_devices = [x - rank * len(device_ids) for x in device_ids]
        device_pairs = zip(rank0_devices, device_ids)
        map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
        resume_dict = torch.load(args.resume, map_location=map_location)
        start_ep = resume_dict["epoch"]
        global_step = resume_dict["global_step"]
        tracked_stats = resume_dict["tracked_stats"]
        unet.load_state_dict(resume_dict["unet_state_dict"])
        imnet.load_state_dict(resume_dict["imnet_state_dict"])

    unet.to(device)
    imnet.to(device)

    all_model_params = list(unet.parameters()) + list(imnet.parameters())
    if args.optim == "sgd":
        optimizer = optim.SGD(all_model_params, lr=args.lr)
    else:
        optimizer = optim.Adam(all_model_params, lr=args.lr)

    if args.use_apex:
        (unet, imnet), optimizer = amp.initialize([unet, imnet], optimizer,
                                                  opt_level=args.apex_optim_level)
        unet = DDP(unet)
        imnet = DDP(imnet)
    else:
        unet = DDP(unet, device_ids=device_ids)
        imnet = DDP(imnet, device_ids=device_ids)

    if args.resume:
        optimizer.load_state_dict(resume_dict["optim_state_dict"])
        # move optimizer state tensors to the current device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    model_param_count = lambda model: sum(x.numel() for x in model.parameters())
    if args.rank == 0:
        logger.info("{}(unet) + {}(imnet) parameters in total".format(
            model_param_count(unet), model_param_count(imnet)))

    checkpoint_path = os.path.join(args.log_dir, "checkpoint_latest.pth.tar")

    # get the PDE layer for the RB2 equations
    if args.normalize_channels:
        mean = trainset.channel_mean
        std = trainset.channel_std
    else:
        mean = std = None
    pde_layer = get_rb2_pde_layer(
        mean=mean, std=std,
        t_crop=args.nt * 0.125, z_crop=args.nz * (1. / 128), x_crop=args.nx * (1. / 128))

    if args.lr_scheduler:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

    # training loop
    for epoch in range(start_ep + 1, args.epochs + 1):
        t0 = time.time()
        loss = train(args, unet, imnet, train_loader, epoch, global_step, device,
                     logger, writer, optimizer, pde_layer)
        t1 = time.time()
        eval(args, unet, imnet, eval_loader, epoch, global_step, device,
             logger, writer, optimizer, pde_layer)
        t2 = time.time()
        if args.lr_scheduler:
            scheduler.step(loss)

        if loss < tracked_stats:
            tracked_stats = loss
            is_best = True
        else:
            is_best = False

        if args.rank == 0:
            utils.save_checkpoint({
                "epoch": epoch,
                "unet_state_dict": unet.module.state_dict(),
                "imnet_state_dict": imnet.module.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "tracked_stats": tracked_stats,
                "global_step": global_step,
            }, is_best, epoch, checkpoint_path, "_pdenet", logger)

        t3 = time.time()
        if args.rank == 0:
            logger.info(f"Total time per epoch: {datetime.timedelta(seconds=t3-t0)} ({t3-t0:.2f} secs)")
            logger.info(f"Train time per epoch: {datetime.timedelta(seconds=t1-t0)} ({t1-t0:.2f} secs)")
            logger.info(f"Eval time per epoch: {datetime.timedelta(seconds=t2-t1)} ({t2-t1:.2f} secs)")

        # optionally append first-epoch timing stats to a CSV file
        if epoch == 1 and args.output_timing:
            newfile = not os.path.exists(args.output_timing)
            with open(args.output_timing, "a") as fh:
                if newfile:
                    fh.write("num_gpu,opt_level,total_time_per_epoch,"
                             "train_time_per_epoch,eval_time_per_epoch\n")
                fh.write("{num_gpu},{opt_level},{tot_time},{train_time},{eval_time}\n"
                         .format(num_gpu=args.nprocs, opt_level=args.apex_optim_level,
                                 tot_time=t3 - t0, train_time=t1 - t0, eval_time=t2 - t1))

    cleanup()
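
# A minimal launcher sketch for `main_ddp` (not part of the original section):
# one worker process per visible GPU via torch.multiprocessing. It assumes
# get_args() exposes an `nprocs` attribute, which main_ddp above consumes;
# the repo's actual entry point may differ.
def launch_ddp():
    args = get_args()
    args.nprocs = torch.cuda.device_count()
    # mp.spawn calls main_ddp(rank, *args) with rank in [0, nprocs)
    mp.spawn(main_ddp, args=(args.nprocs, args), nprocs=args.nprocs, join=True)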
def main():
    args = get_args()
    use_cuda = (not args.no_cuda) and torch.cuda.is_available()
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    device = torch.device("cuda" if use_cuda else "cpu")

    # adjust batch size based on the number of gpus available
    args.batch_size = int(torch.cuda.device_count()) * args.batch_size_per_gpu

    # log and create snapshots
    os.makedirs(args.log_dir, exist_ok=True)
    filenames_to_snapshot = glob("*.py") + glob("*.sh")
    utils.snapshot_files(filenames_to_snapshot, args.log_dir)
    logger = utils.get_logger(log_dir=args.log_dir)
    with open(os.path.join(args.log_dir, "params.json"), 'w') as fh:
        json.dump(args.__dict__, fh, indent=2)
    logger.info("%s", repr(args))

    # tensorboard writer
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, 'tensorboard'))

    # random seed for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # create dataloaders
    trainset = loader.RB2DataLoader(
        data_dir=args.data_folder, data_filename=args.train_data,
        nx=args.nx, nz=args.nz, nt=args.nt,
        n_samp_pts_per_crop=args.n_samp_pts_per_crop,
        downsamp_xz=args.downsamp_xz, downsamp_t=args.downsamp_t,
        normalize_output=args.normalize_channels, return_hres=False,
        lres_filter=args.lres_filter, lres_interp=args.lres_interp)
    evalset = loader.RB2DataLoader(
        data_dir=args.data_folder, data_filename=args.eval_data,
        nx=args.nx, nz=args.nz, nt=args.nt,
        n_samp_pts_per_crop=args.n_samp_pts_per_crop,
        downsamp_xz=args.downsamp_xz, downsamp_t=args.downsamp_t,
        normalize_output=args.normalize_channels, return_hres=True,
        lres_filter=args.lres_filter, lres_interp=args.lres_interp)

    train_sampler = RandomSampler(trainset, replacement=True,
                                  num_samples=args.pseudo_epoch_size)
    eval_sampler = RandomSampler(evalset, replacement=True,
                                 num_samples=args.num_log_images)
    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=False,
                              drop_last=True, sampler=train_sampler, **kwargs)
    eval_loader = DataLoader(evalset, batch_size=args.batch_size, shuffle=False,
                             drop_last=False, sampler=eval_sampler, **kwargs)

    # setup model
    unet = UNet3d(in_features=4, out_features=args.lat_dims,
                  igres=trainset.scale_lres, nf=args.unet_nf, mf=args.unet_mf)
    imnet = ImNet(dim=3, in_features=args.lat_dims, out_features=4,
                  nf=args.imnet_nf, activation=NONLINEARITIES[args.nonlin])
    all_model_params = list(unet.parameters()) + list(imnet.parameters())

    if args.optim == "sgd":
        optimizer = optim.SGD(all_model_params, lr=args.lr)
    else:
        optimizer = optim.Adam(all_model_params, lr=args.lr)

    start_ep = 0
    global_step = np.zeros(1, dtype=np.uint32)
    tracked_stats = np.inf

    if args.resume:
        resume_dict = torch.load(args.resume)
        start_ep = resume_dict["epoch"]
        global_step = resume_dict["global_step"]
        tracked_stats = resume_dict["tracked_stats"]
        unet.load_state_dict(resume_dict["unet_state_dict"])
        imnet.load_state_dict(resume_dict["imnet_state_dict"])
        optimizer.load_state_dict(resume_dict["optim_state_dict"])
        # move optimizer state tensors to the target device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    unet = nn.DataParallel(unet)
    unet.to(device)
    imnet = nn.DataParallel(imnet)
    imnet.to(device)

    model_param_count = lambda model: sum(x.numel() for x in model.parameters())
    logger.info("{}(unet) + {}(imnet) parameters in total".format(
        model_param_count(unet), model_param_count(imnet)))

    checkpoint_path = os.path.join(args.log_dir, "checkpoint_latest.pth.tar")

    # get the PDE layer for the RB2 equations
    if args.normalize_channels:
        mean = trainset.channel_mean
        std = trainset.channel_std
    else:
        mean = std = None
    pde_layer = get_rb2_pde_layer(
        mean=mean, std=std,
        t_crop=args.nt * 0.125, z_crop=args.nz * (1. / 128), x_crop=args.nx * (1. / 128),
        prandtl=args.prandtl, rayleigh=args.rayleigh,
        use_continuity=args.use_continuity)

    if args.lr_scheduler:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

    # training loop
    for epoch in range(start_ep + 1, args.epochs + 1):
        loss = train(args, unet, imnet, train_loader, epoch, global_step, device,
                     logger, writer, optimizer, pde_layer)
        eval(args, unet, imnet, eval_loader, epoch, global_step, device,
             logger, writer, optimizer, pde_layer)
        if args.lr_scheduler:
            scheduler.step(loss)

        if loss < tracked_stats:
            tracked_stats = loss
            is_best = True
        else:
            is_best = False

        utils.save_checkpoint({
            "epoch": epoch,
            "unet_state_dict": unet.module.state_dict(),
            "imnet_state_dict": imnet.module.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
            "tracked_stats": tracked_stats,
            "global_step": global_step,
        }, is_best, epoch, checkpoint_path, "_pdenet", logger)
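
# Standard script entry point, added here as a hedged sketch since the original
# guard is not shown in this section; if these functions live in separate
# scripts in the repo, each file would carry its own guard.
if __name__ == '__main__':
    main()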