def _process(self):
    f = osp.join(self.processed_dir, 'pre_transform.pkl')
    if osp.exists(f) and jt.load(f) != __repr__(self.pre_transform):
        logging.warning(
            'The `pre_transform` argument differs from the one used in '
            'the pre-processed version of this dataset. If you really '
            'want to make use of another pre-processing technique, make '
            'sure to delete `{}` first.'.format(self.processed_dir))
    f = osp.join(self.processed_dir, 'pre_filter.pkl')
    if osp.exists(f) and jt.load(f) != __repr__(self.pre_filter):
        logging.warning(
            'The `pre_filter` argument differs from the one used in the '
            'pre-processed version of this dataset. If you really want to '
            'make use of another pre-filtering technique, make sure to '
            'delete `{}` first.'.format(self.processed_dir))

    if files_exist(self.processed_paths):  # pragma: no cover
        return

    print('Processing...')

    makedirs(self.processed_dir)
    self.process()

    path = osp.join(self.processed_dir, 'pre_transform.pkl')
    jt.save(__repr__(self.pre_transform), path)
    path = osp.join(self.processed_dir, 'pre_filter.pkl')
    jt.save(__repr__(self.pre_filter), path)

    print('Done!')
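# A minimal sketch of the module-level `__repr__` helper used above to fingerprint
# `pre_transform`/`pre_filter` (an assumption modeled on the PyTorch Geometric
# convention; the actual helper in this codebase may differ).
import re

def __repr__(obj):
    if obj is None:
        return 'None'
    # Strip memory addresses such as "<MyTransform object at 0x7f...>" so the
    # saved fingerprint stays stable across runs.
    return re.sub('(<.*?)\\s.*(>)', r'\1\2', obj.__repr__())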
def test_save(self):
    pp = [1, 2, jt.array([1, 2, 3]), {"a": [1, 2, 3], "b": jt.array([1, 2, 3])}]
    jt.save(pp, "/tmp/xx.pkl")
    x = jt.load("/tmp/xx.pkl")
    assert x[:2] == [1, 2]
    assert (x[2] == np.array([1, 2, 3])).all()
    assert x[3]['a'] == [1, 2, 3]
    assert (x[3]['b'] == np.array([1, 2, 3])).all()
def save_checkpoint(checkpoint_path, model, _optimizers, logger, cfg, **kwargs):
    state = {
        'state_dict': model.state_dict(),
        'optimizer': _optimizers.state_dict(),
        'cfg': cfg
    }
    state.update(kwargs)
    jittor.save(state, checkpoint_path)
    logger.info('models saved to %s' % checkpoint_path)
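# A possible counterpart for restoring the state written by `save_checkpoint`
# above (a sketch, not part of the original code; it assumes both `model` and
# `_optimizers` expose `load_state_dict`).
def load_checkpoint(checkpoint_path, model, _optimizers, logger):
    state = jittor.load(checkpoint_path)
    model.load_state_dict(state['state_dict'])
    _optimizers.load_state_dict(state['optimizer'])
    logger.info('models loaded from %s' % checkpoint_path)
    return state.get('cfg')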
def train(network: model.Siren, optim: jittor.optim.Optimizer, loss_fn, epochs,
          coords, gt, save_path):
    network.train()
    min_loss = np.inf
    for epoch in tqdm(range(epochs)):
        output = network(coords)
        loss = loss_fn(output, gt)
        optim.step(loss)
        loss = loss.item()
        if loss < min_loss:
            min_loss = loss
            jittor.save(network.state_dict(), save_path)
        if epoch % 10 == 0:
            tqdm.write(f"epoch: {epoch}, loss: {loss}, min_loss: {min_loss}")
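# A small usage sketch (not from the original file): the best weights written
# with `jittor.save` during training can be restored with `jittor.load` plus
# `load_state_dict` before evaluation.
def load_best(network, save_path):
    network.load_state_dict(jittor.load(save_path))  # restore best checkpoint
    network.eval()                                   # switch to inference mode
    return network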
def test_save(self):
    pp = [
        1, 2,
        jt.array([1, 2, 3]),
        {
            "a": [1, 2, 3],
            "b": jt.array([1, 2, 3])
        }
    ]
    name = jt.flags.cache_path + "/xx.pkl"
    jt.save(pp, name)
    x = jt.load(name)
    assert x[:2] == [1, 2]
    assert (x[2] == np.array([1, 2, 3])).all()
    assert x[3]['a'] == [1, 2, 3]
    assert (x[3]['b'] == np.array([1, 2, 3])).all()
def save(self, name, **kwargs):
    if not self.save_dir:
        return
    if not self.save_to_disk:
        return

    data = {}
    data["model"] = self.model.state_dict()
    if self.optimizer is not None:
        data["optimizer"] = self.optimizer.state_dict()
    if self.scheduler is not None:
        data["scheduler"] = self.scheduler.state_dict()
    data.update(kwargs)

    save_file = os.path.join(self.save_dir, "{}.pth".format(name))
    self.logger.info("Saving checkpoint to {}".format(save_file))
    jt.save(data, save_file)
    self.tag_last_checkpoint(save_file)
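# A hedged sketch of a matching `load` method for this checkpointer (an
# assumption, not the original implementation; `tag_last_checkpoint` is presumed
# to record the path of the most recent checkpoint file).
def load(self, f=None):
    if f is None or not os.path.exists(f):
        self.logger.info("No checkpoint found.")
        return {}
    self.logger.info("Loading checkpoint from {}".format(f))
    checkpoint = jt.load(f)
    self.model.load_state_dict(checkpoint.pop("model"))
    if "optimizer" in checkpoint and self.optimizer is not None:
        self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
    if "scheduler" in checkpoint and self.scheduler is not None:
        self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
    # Remaining keys are whatever extra kwargs were passed to save().
    return checkpoint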
def train(hyp, opt, tb_writer=None):
    logger.info(colorstr('hyperparameters: ') +
                ', '.join(f'{k}={v}' for k, v in hyp.items()))
    save_dir, epochs, batch_size, weights = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / 'last.pkl'
    best = wdir / 'best.pkl'
    results_file = save_dir / 'results.txt'

    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    plots = not opt.evolve  # create plots
    cuda = not opt.no_cuda
    if cuda:
        jt.flags.use_cuda = 1
    init_seeds(1)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
    check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 \
        else data_dict['names']  # class names
    assert len(names) == nc, \
        '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check

    # Model
    model = Model(opt.cfg, ch=3, nc=nc)  # create
    pretrained = weights.endswith('.pkl')
    if pretrained:
        model.load(weights)  # load

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
    logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, jt.Var):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, jt.Var):
            pg1.append(v.weight)  # apply decay

    if opt.adam:
        optimizer = optim.Adam(pg0, lr=hyp['lr0'],
                               betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'],
                              momentum=hyp['momentum'], nesterov=True)

    optimizer.add_param_group({
        'params': pg1,
        'weight_decay': hyp['weight_decay']
    })  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' %
                (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = optim.LambdaLR(optimizer, lr_lambda=lf)
    # plot_lr_scheduler(optimizer, scheduler, epochs)

    loggers = {}  # loggers dict
    start_epoch, best_fitness = 0, 0.0

    # Image sizes
    gs = int(model.stride.max())  # grid size (max stride)
    nl = model.model[-1].nl  # number of detection layers (used for scaling hyp['obj'])
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples

    # EMA
    ema = ModelEMA(model)

    # Trainloader
    dataloader = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                   hyp=hyp, augment=True,
                                   cache=opt.cache_images, rect=opt.rect,
                                   workers=opt.workers,
                                   image_weights=opt.image_weights,
                                   quad=opt.quad,
                                   prefix=colorstr('train: '))
    mlc = np.concatenate(dataloader.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, \
        'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % \
        (mlc, nc, opt.data, nc - 1)

    ema.updates = start_epoch * nb // accumulate  # set EMA updates
    testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,  # testloader
                                   hyp=hyp,
                                   cache=opt.cache_images and not opt.notest,
                                   rect=True, workers=opt.workers, pad=0.5,
                                   prefix=colorstr('val: '))

    labels = np.concatenate(dataloader.labels, 0)
    c = jt.array(labels[:, 0])  # classes
    # cf = torch.bincount(c.int(), minlength=nc) + 1.  # frequency
    # model._initialize_biases(cf)
    if plots:
        plot_labels(labels, save_dir, loggers)
        if tb_writer:
            tb_writer.add_histogram('classes', c.numpy(), 0)

    # Anchors
    if not opt.noautoanchor:
        check_anchors(dataloader, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

    # Model parameters
    hyp['box'] *= 3. / nl  # scale to layers
    hyp['cls'] *= nc / 80. * 3. / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl  # scale to image size and layers
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(dataloader.labels, nc) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@0.5, mAP@0.5-0.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
                f'Using {dataloader.num_workers} dataloader workers\n'
                f'Logging results to {save_dir}\n'
                f'Starting training for {epochs} epochs...')

    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            cw = model.class_weights.numpy() * (1 - maps) ** 2 / nc  # class weights
            iw = labels_to_image_weights(dataloader.labels, nc=nc, class_weights=cw)  # image weights
            dataloader.indices = random.choices(range(dataloader.n), weights=iw,
                                                k=dataloader.n)  # rand weighted idx

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = jt.zeros((4,))  # mean losses
        pbar = enumerate(dataloader)
        logger.info(('\n' + '%10s' * 7) %
                    ('Epoch', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
        pbar = tqdm(pbar, total=nb)  # progress bar
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                # accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi,
                                        [hyp['warmup_bias_lr'] if j == 2 else 0.0,
                                         x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi,
                                                  [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = nn.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            # Forward
            pred = model(imgs)  # forward
            loss, loss_items = compute_loss(pred, targets, model)  # loss scaled by batch_size
            if opt.quad:
                loss *= 4.

            # Optimize
            optimizer.step(loss)
            if ema:
                ema.update(model)

            # Print
            mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
            s = ('%10s' + '%10.4g' * 6) % ('%g/%g' % (epoch, epochs - 1), *mloss,
                                           targets.shape[0], imgs.shape[-1])
            pbar.set_description(s)

            # Plot
            if plots and ni < 3:
                f = save_dir / f'train_batch{ni}.jpg'  # filename
                Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
                # if tb_writer:
                #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                #     tb_writer.add_graph(model, imgs)  # add model to tensorboard

            # end batch ------------------------------------------------------------------------------------------------
        # end epoch ----------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # mAP
        if ema:
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names',
                                            'stride', 'class_weights'])
        final_epoch = epoch + 1 == epochs
        if not opt.notest or final_epoch:  # Calculate mAP
            results, maps, times = test.test(data=opt.data,
                                             batch_size=batch_size,
                                             imgsz=imgsz_test,
                                             model=ema.ema,
                                             single_cls=opt.single_cls,
                                             dataloader=testloader,
                                             save_dir=save_dir,
                                             plots=plots and final_epoch)

        # Write
        with open(results_file, 'a') as f:
            f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP@0.5, mAP@0.5-0.95, val_loss(box, obj, cls)
        if len(opt.name) and opt.bucket:
            os.system('gsutil cp %s gs://%s/results/results%s.txt' %
                      (results_file, opt.bucket, opt.name))

        # Log
        tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                'metrics/precision', 'metrics/recall',
                'metrics/mAP_0.5', 'metrics/mAP_0.5-0.95',
                'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                'x/lr0', 'x/lr1', 'x/lr2']  # params
        for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
            if tb_writer:
                if hasattr(x, "numpy"):
                    x = x.numpy()
                tb_writer.add_scalar(tag, x, epoch)  # tensorboard

        # Update best mAP
        fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@0.5, mAP@0.5-0.95]
        if fi > best_fitness:
            best_fitness = fi

        # Save model
        save = (not opt.nosave) or (final_epoch and not opt.evolve)
        if save:
            # Save last, best and delete
            jt.save(ema.ema.state_dict(), last)
            if best_fitness == fi:
                jt.save(ema.ema.state_dict(), best)
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    # Strip optimizers
    final = best if best.exists() else last  # final model
    if opt.bucket:
        os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload

    # Plots
    if plots:
        plot_results(save_dir=save_dir)  # save as results.png

    # Test best.pkl
    logger.info('%g epochs completed in %.3f hours.\n' %
                (epoch - start_epoch + 1, (time.time() - t0) / 3600))
    best_model = Model(opt.cfg)
    best_model.load(str(final))
    best_model = best_model.fuse()
    if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
        for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]):  # speed, mAP tests
            results, _, _ = test.test(opt.data,
                                      batch_size=total_batch_size,
                                      imgsz=imgsz_test,
                                      conf_thres=conf,
                                      iou_thres=iou,
                                      model=best_model,
                                      single_cls=opt.single_cls,
                                      dataloader=testloader,
                                      save_dir=save_dir,
                                      save_json=save_json,
                                      plots=False)

    return results
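# A brief usage sketch (an assumption, not part of the original training script):
# weights saved with `jt.save(ema.ema.state_dict(), best)` can be loaded back
# into a freshly built model via `Model.load`, mirroring the `best_model` block
# above.
def load_trained(cfg_path, weights_path):
    m = Model(cfg_path)   # rebuild the architecture from its yaml config
    m.load(weights_path)  # restore the state_dict written with jt.save(...)
    m = m.fuse()          # fuse Conv+BN, as done for best_model above
    m.eval()
    return m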
def process(self):
    data = read_planetoid_data(self.raw_dir, self.name)
    data = data if self.pre_transform is None else self.pre_transform(data)
    jt.save(self.collate([data]), self.processed_paths[0])
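# A hedged sketch of how the collated object saved above is typically read back
# in the dataset's __init__ (an assumption following the common InMemoryDataset
# pattern; attribute names may differ in this codebase).
def _load_processed(self):
    self.data, self.slices = jt.load(self.processed_paths[0])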
def train():
    parser = config_parser()
    args = parser.parse_args()

    # Load data
    intrinsic = None
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir, args.factor, recenter=True, bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        testskip = args.testskip
        faketestskip = args.faketestskip
        if jt.mpi and jt.mpi.local_rank() != 0:
            testskip = faketestskip
            faketestskip = 1
        if args.do_intrinsic:
            images, poses, intrinsic, render_poses, hwf, i_split = load_blender_data(
                args.datadir, args.half_res, args.testskip, args.blender_factor, True)
        else:
            images, poses, render_poses, hwf, i_split = load_blender_data(
                args.datadir, args.half_res, args.testskip, args.blender_factor)
        print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split
        i_test_tot = i_test
        i_test = i_test[::args.faketestskip]

        near = args.near
        far = args.far
        print(args.do_intrinsic)
        print("hwf", hwf)
        print("near", near)
        print("far", far)

        if args.white_bkgd:
            images = images[..., :3] * images[..., -1:] + (1. - images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':
        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)
        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
    global_step = start

    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = jt.array(render_poses)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        with jt.no_grad():
            testsavedir = os.path.join(
                basedir, expname,
                'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            rgbs, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test,
                                  savedir=testsavedir, render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                             to8b(rgbs), fps=30, quality=8)
            return

    # Prepare raybatch tensor if batching random rays
    accumulation_steps = 1
    N_rand = args.N_rand // accumulation_steps
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack([get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]], 0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:, None]], 1)  # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0)  # train images only
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])  # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0

    # Move training data to GPU
    images = jt.array(images.astype(np.float32))
    poses = jt.array(poses)
    if use_batching:
        rays_rgb = jt.array(rays_rgb)

    N_iters = 51000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
    if not jt.mpi or jt.mpi.local_rank() == 0:
        date = str(datetime.datetime.now())
        date = date[:date.rfind(":")].replace("-", "") \
            .replace(":", "") \
            .replace(" ", "_")
        gpu_idx = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
        log_dir = os.path.join("./logs", "summaries", "log_" + date + "_gpu" + gpu_idx)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        writer = SummaryWriter(log_dir=log_dir)

    start = start + 1
    for i in trange(start, N_iters):
        # jt.display_memory_info()
        time0 = time.time()

        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1, 3*?]
            batch = jt.transpose(batch, (1, 0, 2))
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = jt.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0
        else:
            # Random from one image
            np.random.seed(i)
            img_i = np.random.choice(i_train)
            target = images[img_i]  # .squeeze(0)
            pose = poses[img_i, :3, :4]  # .squeeze(0)

            if N_rand is not None:
                rays_o, rays_d = pinhole_get_rays(H, W, focal, pose, intrinsic)  # (H, W, 3), (H, W, 3)

                if i < args.precrop_iters:
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = jt.stack(
                        jt.meshgrid(
                            jt.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH),
                            jt.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW)), -1)
                    if i == start:
                        print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
                else:
                    coords = jt.stack(
                        jt.meshgrid(jt.linspace(0, H - 1, H),
                                    jt.linspace(0, W - 1, W)), -1)  # (H, W, 2)

                coords = jt.reshape(coords, [-1, 2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                select_coords = coords[select_inds].int()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = jt.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)

        #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk,
                                        rays=batch_rays, verbose=i < 10,
                                        retraw=True, **render_kwargs_train)

        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][..., -1]
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)

        optimizer.backward(loss / accumulation_steps)
        if i % accumulation_steps == 0:
            optimizer.step()

        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * accumulation_steps * 1000
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################

        dt = time.time() - time0

        # Rest is logging
        if (i + 1) % args.i_weights == 0 and (not jt.mpi or jt.mpi.local_rank() == 0):
            print(i)
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            jt.save(
                {
                    'global_step': global_step,
                    'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
                    'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
                }, path)
            print('Saved checkpoints at', path)

        if i % args.i_video == 0 and i > 0:
            # Turn on testing mode
            with jt.no_grad():
                rgbs, disps = render_path(render_poses, hwf, args.chunk,
                                          render_kwargs_test, intrinsic=intrinsic)
            if not jt.mpi or jt.mpi.local_rank() == 0:
                print('Done, saving', rgbs.shape, disps.shape)
                moviebase = os.path.join(basedir, expname,
                                         '{}_spiral_{:06d}_'.format(expname, i))
                print('movie base ', moviebase)
                imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
                imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)),
                                 fps=30, quality=8)

        if i % args.i_print == 0:
            tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")

        if i % args.i_img == 0:
            img_i = np.random.choice(i_val)
            target = images[img_i]
            pose = poses[img_i, :3, :4]
            with jt.no_grad():
                rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk,
                                                c2w=pose, intrinsic=intrinsic,
                                                **render_kwargs_test)

            psnr = mse2psnr(img2mse(rgb, target))
            rgb = rgb.numpy()
            disp = disp.numpy()
            acc = acc.numpy()

            if not jt.mpi or jt.mpi.local_rank() == 0:
                writer.add_image('test/rgb', to8b(rgb), global_step, dataformats="HWC")
                writer.add_image('test/target', target.numpy(), global_step, dataformats="HWC")
                writer.add_scalar('test/psnr', psnr.item(), global_step)

            jt.clean_graph()
            jt.sync_all()
            jt.gc()

        if i % args.i_testset == 0 and i > 0:
            si_test = i_test_tot if i % args.i_tottest == 0 else i_test
            testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[si_test].shape)
            with jt.no_grad():
                rgbs, disps = render_path(jt.array(poses[si_test]), hwf, args.chunk,
                                          render_kwargs_test, savedir=testsavedir,
                                          intrinsic=intrinsic, expname=expname)
            jt.gc()

        global_step += 1
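# A minimal restore sketch (an assumption, not from the original file): the
# checkpoint written by `jt.save` in the training loop above can be reloaded
# with `jt.load` and fed to the two networks via `load_state_dict`.
def load_nerf_checkpoint(path, render_kwargs_train):
    ckpt = jt.load(path)
    render_kwargs_train['network_fn'].load_state_dict(ckpt['network_fn_state_dict'])
    render_kwargs_train['network_fine'].load_state_dict(ckpt['network_fine_state_dict'])
    return ckpt['global_step']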
alpha = 0
ckpt_step = step
resolution = 4 * 2 ** step
image_loader = SymbolDataset(args.path, transform, resolution).set_attrs(
    batch_size=batch_size.get(resolution, batch_default),
    shuffle=True)
train_loader = iter(image_loader)

jt.save(
    {
        'generator': netG.state_dict(),
        'discriminator': netD.state_dict(),
        'g_running': g_running.state_dict(),
    },
    f'FFHQ/checkpoint/train_step-{ckpt_step}.model',
)

try:
    real_image = next(train_loader)
except (OSError, StopIteration):
    train_loader = iter(image_loader)
    real_image = next(train_loader)

real_image.requires_grad = True
b_size = real_image.size(0)
real_scores = netD(real_image, step=step, alpha=alpha)
real_predict = jt.nn.softplus(-real_scores).mean()
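# A resume sketch (an assumption, not part of the original snippet): a checkpoint
# written by the jt.save call above can be reloaded to continue training from a
# given progressive-growing step.
def resume(ckpt_path, netG, netD, g_running):
    ckpt = jt.load(ckpt_path)
    netG.load_state_dict(ckpt['generator'])
    netD.load_state_dict(ckpt['discriminator'])
    g_running.load_state_dict(ckpt['g_running'])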