コード例 #1
0
def main():
    """Run a trained model on one evaluation sample and export the result as video.

    Hyper-parameters stored in ``params.json`` (located in the same directory as
    ``args.ckpt``) are merged into the parsed command-line arguments first; any
    key present in that file overwrites the CLI value of the same name.
    """
    args = get_args()
    ckpt_dir = os.path.dirname(args.ckpt)
    with open(os.path.join(ckpt_dir, "params.json"), 'r') as saved_params:
        args.__dict__.update(json.load(saved_params))
    print(args)

    # Evaluation dataset. return_hres=True because the high-resolution frames
    # are passed to export_video below.
    dataset = loader.RB2DataLoader(
        data_dir=args.data_folder,
        data_filename=args.eval_dataset,
        nx=args.eval_xres,
        nz=args.eval_zres,
        nt=args.eval_tres,
        n_samp_pts_per_crop=1,
        lres_interp=args.lres_interp,
        lres_filter=args.lres_filter,
        downsamp_xz=args.eval_downsamp_xz,
        downsamp_t=args.eval_downsamp_t,
        normalize_output=args.normalize_channels,
        return_hres=True,
    )

    # Only the first sample is used; its last two items are not needed here.
    hres, lres, _, _ = dataset[0]

    # PDE layer for the RB2 equations; channel statistics are only supplied
    # when the dataset normalizes its outputs.
    if args.normalize_channels:
        mean, std = dataset.channel_mean, dataset.channel_std
    else:
        mean, std = None, None
    pde_layer = get_rb2_pde_layer(mean=mean, std=std,
                                  prandtl=args.prandtl, rayleigh=args.rayleigh)

    # Run the model to obtain the high-res spatio-temporal sequence, then
    # write it out together with the reference data.
    res_dict = model_inference(args, lres, pde_layer)
    export_video(args, res_dict, hres, lres, dataset)
コード例 #2
0
def main_ddp(rank, world_size, args):
    """Train the UNet3d + ImNet model as one rank of a distributed job.

    Each process drives a single CUDA device whose index equals ``rank``.
    Shared artifacts (file snapshots, ``params.json``, checkpoints, the timing
    CSV) are written by rank 0 only; every rank creates its own logger and
    TensorBoard writer.

    Args:
        rank: Index of this process in the process group; also used as its
            CUDA device index.
        world_size: Number of processes in the group.
        args: Parsed argument namespace. Mutated in place: ``rank`` and
            ``batch_size`` are set, and ``use_apex`` is cleared when Apex was
            requested but could not be imported.
    """
    # Offset forwarded to ``setup``; assumes apex_optim_level has the form
    # "O<digit>" (e.g. "O1") -- TODO confirm.
    offset = int(args.apex_optim_level[1]) * world_size + rank
    setup(rank, world_size, offset=offset)

    args.rank = rank
    if args.use_apex and not HASAPEX:
        if rank == 0:
            print(import_error)
            # warnings.warn requires a Warning subclass as the category;
            # passing ImportError raised TypeError instead of warning.
            warnings.warn(
                "Failed to import Apex. Falling back to PyTorch DistributedDataParallel.",
                RuntimeWarning)
        args.use_apex = False
    DDP = ADDP if args.use_apex else TDDP

    # One GPU per process.
    device_ids = [args.rank]
    torch.cuda.set_device(args.rank)

    kwargs = {'num_workers': 1, 'pin_memory': True}
    device = torch.device(device_ids[0])
    # no need to adjust batch size. batch size = batch_size_per_gpu
    args.batch_size = args.batch_size_per_gpu

    # log and create snapshots
    os.makedirs(args.log_dir, exist_ok=True)
    if args.rank == 0:
        # Only one rank writes shared files; concurrent truncate-and-write of
        # the same paths from every rank could corrupt them.
        filenames_to_snapshot = glob("*.py") + glob("*.sh")
        utils.snapshot_files(filenames_to_snapshot, args.log_dir)
    logger = utils.get_logger(log_dir=args.log_dir)
    if args.rank == 0:
        with open(os.path.join(args.log_dir, "params.json"), 'w') as fh:
            json.dump(args.__dict__, fh, indent=2)
        logger.info("%s", repr(args))

    logger.info(f"[Rank] {rank:2d} [Cuda IDs] {device_ids}")

    # tensorboard writer
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, 'tensorboard'))

    # random seed for reproducability
    # NOTE(review): for rank 0 this evaluates to 0 regardless of args.seed.
    torch.manual_seed((args.seed + 1) * rank)
    np.random.seed((args.seed + 1) * rank)

    # create dataloaders
    trainset = loader.RB2DataLoader(
        data_dir=args.data_folder, data_filename="rb2d_ra1e6_s42.npz",
        nx=args.nx, nz=args.nz, nt=args.nt, n_samp_pts_per_crop=args.n_samp_pts_per_crop,
        downsamp_xz=args.downsamp_xz, downsamp_t=args.downsamp_t,
        normalize_output=args.normalize_channels, return_hres=False,
        lres_filter=args.lres_filter, lres_interp=args.lres_interp
    )
    evalset = loader.RB2DataLoader(
        data_dir=args.data_folder, data_filename="rb2d_ra1e6_s42.npz",
        nx=args.nx, nz=args.nz, nt=args.nt, n_samp_pts_per_crop=args.n_samp_pts_per_crop,
        downsamp_xz=args.downsamp_xz, downsamp_t=args.downsamp_t,
        normalize_output=args.normalize_channels, return_hres=True,
        lres_filter=args.lres_filter, lres_interp=args.lres_interp
    )

    # Each process draws its share of the pseudo-epoch.
    nsamp_per_proc = args.pseudo_epoch_size // args.nprocs
    train_sampler = RandomSampler(trainset, replacement=True, num_samples=nsamp_per_proc)
    eval_sampler = RandomSampler(evalset, replacement=True, num_samples=args.num_log_images)

    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=False, drop_last=True,
                              sampler=train_sampler, **kwargs)
    eval_loader = DataLoader(evalset, batch_size=args.batch_size, shuffle=False, drop_last=False,
                             sampler=eval_sampler, **kwargs)

    # setup model
    unet = UNet3d(in_features=4, out_features=args.lat_dims, igres=trainset.scale_lres,
                  nf=args.unet_nf, mf=args.unet_mf)
    imnet = ImNet(dim=3, in_features=args.lat_dims, out_features=4, nf=args.imnet_nf)

    # Training-state defaults. These must be set BEFORE the resume block;
    # previously they were reset afterwards, discarding the resumed values.
    start_ep = 0
    global_step = np.zeros(1, dtype=np.uint32)
    tracked_stats = np.inf

    if args.resume:
        # Remap tensors saved from rank 0's devices onto this rank's devices.
        rank0_devices = [x - rank * len(device_ids) for x in device_ids]
        device_pairs = zip(rank0_devices, device_ids)
        map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}

        resume_dict = torch.load(args.resume, map_location=map_location)
        start_ep = resume_dict["epoch"]
        global_step = resume_dict["global_step"]
        tracked_stats = resume_dict["tracked_stats"]
        unet.load_state_dict(resume_dict["unet_state_dict"])
        imnet.load_state_dict(resume_dict["imnet_state_dict"])

    unet.to(device)
    imnet.to(device)

    all_model_params = list(unet.parameters()) + list(imnet.parameters())

    if args.optim == "sgd":
        optimizer = optim.SGD(all_model_params, lr=args.lr)
    else:
        optimizer = optim.Adam(all_model_params, lr=args.lr)

    if args.use_apex:
        (unet, imnet), optimizer = amp.initialize([unet, imnet], optimizer, opt_level=args.apex_optim_level)
        unet = DDP(unet)
        imnet = DDP(imnet)
    else:
        unet = DDP(unet, device_ids=device_ids)
        imnet = DDP(imnet, device_ids=device_ids)

    if args.resume:
        optimizer.load_state_dict(resume_dict["optim_state_dict"])
        # Move restored optimizer buffers onto this rank's device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    model_param_count = lambda model: sum(x.numel() for x in model.parameters())
    if args.rank == 0:
        logger.info("{}(unet) + {}(imnet) paramerters in total".format(
            model_param_count(unet), model_param_count(imnet)))

    checkpoint_path = os.path.join(args.log_dir, "checkpoint_latest.pth.tar")

    # get pdelayer for the RB2 equations
    if args.normalize_channels:
        mean = trainset.channel_mean
        std = trainset.channel_std
    else:
        mean = std = None
    pde_layer = get_rb2_pde_layer(mean=mean, std=std,
        t_crop=args.nt*0.125, z_crop=args.nz*(1./128), x_crop=args.nx*(1./128))

    if args.lr_scheduler:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

    # training loop
    for epoch in range(start_ep + 1, args.epochs + 1):
        t0 = time.time()
        loss = train(args, unet, imnet, train_loader, epoch, global_step, device, logger, writer,
                     optimizer, pde_layer)
        t1 = time.time()
        eval(args, unet, imnet, eval_loader, epoch, global_step, device, logger, writer,
             optimizer, pde_layer)
        t2 = time.time()
        if args.lr_scheduler:
            scheduler.step(loss)
        is_best = loss < tracked_stats
        if is_best:
            tracked_stats = loss
        if args.rank == 0:
            utils.save_checkpoint({
                "epoch": epoch,
                "unet_state_dict": unet.module.state_dict(),
                "imnet_state_dict": imnet.module.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "tracked_stats": tracked_stats,
                "global_step": global_step,
            }, is_best, epoch, checkpoint_path, "_pdenet", logger)
        t3 = time.time()
        if args.rank == 0:
            logger.info(f"Total time per epoch: {datetime.timedelta(seconds=t3-t0)} ({t3-t0:.2f} secs)")
            logger.info(f"Train time per epoch: {datetime.timedelta(seconds=t1-t0)} ({t1-t0:.2f} secs)")
            logger.info(f"Eval  time per epoch: {datetime.timedelta(seconds=t2-t1)} ({t2-t1:.2f} secs)")
            if epoch == 1 and args.output_timing:
                # Write the CSV header only when creating the file.
                newfile = not os.path.exists(args.output_timing)
                with open(args.output_timing, "a") as fh:
                    if newfile:
                        fh.write("num_gpu,opt_level,total_time_per_epoch,train_time_per_epoch,eval_time_per_epoch\n")
                    fh.write(("{num_gpu},{opt_level},{tot_time},{train_time},{eval_time}\n"
                              .format(num_gpu=args.nprocs, opt_level=args.apex_optim_level,
                                      tot_time=t3-t0, train_time=t1-t0, eval_time=t2-t1)))

    # Flush pending TensorBoard events before tearing down the process group.
    writer.close()
    cleanup()
コード例 #3
0
ファイル: train.py プロジェクト: bradfordlynch/space_time_pde
def main():
    """Train the UNet3d + ImNet model on one node, optionally on multiple GPUs.

    Models are wrapped in ``nn.DataParallel``. Logs, a snapshot of the
    ``*.py``/``*.sh`` files, ``params.json``, TensorBoard events and checkpoints
    are written under ``args.log_dir``. Training resumes from ``args.resume``
    when it is set.
    """
    args = get_args()

    use_cuda = (not args.no_cuda) and torch.cuda.is_available()
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    device = torch.device("cuda" if use_cuda else "cpu")
    # Scale the batch size by the number of devices in use. Without CUDA,
    # torch.cuda.device_count() may be 0, which previously produced
    # batch_size == 0 and made DataLoader raise.
    n_devices = int(torch.cuda.device_count()) if use_cuda else 1
    args.batch_size = n_devices * args.batch_size_per_gpu

    # log and create snapshots
    os.makedirs(args.log_dir, exist_ok=True)
    filenames_to_snapshot = glob("*.py") + glob("*.sh")
    utils.snapshot_files(filenames_to_snapshot, args.log_dir)
    logger = utils.get_logger(log_dir=args.log_dir)
    with open(os.path.join(args.log_dir, "params.json"), 'w') as fh:
        json.dump(args.__dict__, fh, indent=2)
    logger.info("%s", repr(args))

    # tensorboard writer
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, 'tensorboard'))

    # random seed for reproducability
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # create dataloaders
    trainset = loader.RB2DataLoader(
        data_dir=args.data_folder,
        data_filename=args.train_data,
        nx=args.nx,
        nz=args.nz,
        nt=args.nt,
        n_samp_pts_per_crop=args.n_samp_pts_per_crop,
        downsamp_xz=args.downsamp_xz,
        downsamp_t=args.downsamp_t,
        normalize_output=args.normalize_channels,
        return_hres=False,
        lres_filter=args.lres_filter,
        lres_interp=args.lres_interp)
    evalset = loader.RB2DataLoader(
        data_dir=args.data_folder,
        data_filename=args.eval_data,
        nx=args.nx,
        nz=args.nz,
        nt=args.nt,
        n_samp_pts_per_crop=args.n_samp_pts_per_crop,
        downsamp_xz=args.downsamp_xz,
        downsamp_t=args.downsamp_t,
        normalize_output=args.normalize_channels,
        return_hres=True,
        lres_filter=args.lres_filter,
        lres_interp=args.lres_interp)

    train_sampler = RandomSampler(trainset,
                                  replacement=True,
                                  num_samples=args.pseudo_epoch_size)
    eval_sampler = RandomSampler(evalset,
                                 replacement=True,
                                 num_samples=args.num_log_images)

    train_loader = DataLoader(trainset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              drop_last=True,
                              sampler=train_sampler,
                              **kwargs)
    eval_loader = DataLoader(evalset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             drop_last=False,
                             sampler=eval_sampler,
                             **kwargs)

    # setup model
    unet = UNet3d(in_features=4,
                  out_features=args.lat_dims,
                  igres=trainset.scale_lres,
                  nf=args.unet_nf,
                  mf=args.unet_mf)
    imnet = ImNet(dim=3,
                  in_features=args.lat_dims,
                  out_features=4,
                  nf=args.imnet_nf,
                  activation=NONLINEARITIES[args.nonlin])
    all_model_params = list(unet.parameters()) + list(imnet.parameters())

    if args.optim == "sgd":
        optimizer = optim.SGD(all_model_params, lr=args.lr)
    else:
        optimizer = optim.Adam(all_model_params, lr=args.lr)

    start_ep = 0
    global_step = np.zeros(1, dtype=np.uint32)
    tracked_stats = np.inf

    if args.resume:
        # map_location lets a checkpoint saved on GPU be resumed on CPU (and
        # vice versa) instead of failing inside torch.load.
        resume_dict = torch.load(args.resume, map_location=device)
        start_ep = resume_dict["epoch"]
        global_step = resume_dict["global_step"]
        tracked_stats = resume_dict["tracked_stats"]
        unet.load_state_dict(resume_dict["unet_state_dict"])
        imnet.load_state_dict(resume_dict["imnet_state_dict"])
        optimizer.load_state_dict(resume_dict["optim_state_dict"])
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    unet = nn.DataParallel(unet)
    unet.to(device)
    imnet = nn.DataParallel(imnet)
    imnet.to(device)

    model_param_count = lambda model: sum(x.numel()
                                          for x in model.parameters())
    logger.info("{}(unet) + {}(imnet) paramerters in total".format(
        model_param_count(unet), model_param_count(imnet)))

    checkpoint_path = os.path.join(args.log_dir, "checkpoint_latest.pth.tar")

    # get pdelayer for the RB2 equations
    if args.normalize_channels:
        mean = trainset.channel_mean
        std = trainset.channel_std
    else:
        mean = std = None
    pde_layer = get_rb2_pde_layer(mean=mean,
                                  std=std,
                                  t_crop=args.nt * 0.125,
                                  z_crop=args.nz * (1. / 128),
                                  x_crop=args.nx * (1. / 128),
                                  prandtl=args.prandtl,
                                  rayleigh=args.rayleigh,
                                  use_continuity=args.use_continuity)

    if args.lr_scheduler:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

    # training loop
    for epoch in range(start_ep + 1, args.epochs + 1):
        loss = train(args, unet, imnet, train_loader, epoch, global_step,
                     device, logger, writer, optimizer, pde_layer)
        eval(args, unet, imnet, eval_loader, epoch, global_step, device,
             logger, writer, optimizer, pde_layer)
        if args.lr_scheduler:
            scheduler.step(loss)
        is_best = loss < tracked_stats
        if is_best:
            tracked_stats = loss

        utils.save_checkpoint(
            {
                "epoch": epoch,
                "unet_state_dict": unet.module.state_dict(),
                "imnet_state_dict": imnet.module.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "tracked_stats": tracked_stats,
                "global_step": global_step,
            }, is_best, epoch, checkpoint_path, "_pdenet", logger)

    # Flush pending TensorBoard events.
    writer.close()