Example #1
0
    def _process(self):
        """Run ``self.process()`` once and record the pre-processing config.

        Before doing any work, the ``__repr__`` of the current
        ``pre_transform`` / ``pre_filter`` is compared with the value saved in
        ``self.processed_dir`` by a previous run, and a warning is logged when
        they differ. If every path in ``self.processed_paths`` already exists,
        nothing else happens. Otherwise ``self.processed_dir`` is created,
        ``self.process()`` is called and the current ``__repr__`` values are
        saved next to the processed data for the next comparison.
        """
        transform_path = osp.join(self.processed_dir, 'pre_transform.pkl')
        filter_path = osp.join(self.processed_dir, 'pre_filter.pkl')

        # A mismatch means the cached data on disk was produced with different
        # pre-processing than what is configured now; it is only reported, the
        # cache is still reused below.
        if osp.exists(transform_path) and jt.load(transform_path) != __repr__(
                self.pre_transform):
            logging.warning(
                'The `pre_transform` argument differs from the one used in '
                'the pre-processed version of this dataset. If you really '
                'want to make use of another pre-processing technique, make '
                'sure to delete `{}` first.'.format(self.processed_dir))
        if osp.exists(filter_path) and jt.load(filter_path) != __repr__(
                self.pre_filter):
            logging.warning(
                'The `pre_filter` argument differs from the one used in the '
                'pre-processed version of this dataset. If you really want to '
                'make use of another pre-filtering technique, make sure to '
                'delete `{}` first.'.format(self.processed_dir))

        if files_exist(self.processed_paths):  # pragma: no cover
            return

        print('Processing...')

        makedirs(self.processed_dir)
        self.process()

        # Persist the configuration used for this run so that later runs can
        # detect a changed pre_transform / pre_filter (see the checks above).
        jt.save(__repr__(self.pre_transform), transform_path)
        jt.save(__repr__(self.pre_filter), filter_path)

        print('Done!')
 def test_save(self):
     """Round-trip a nested list/dict holding ``jt.Var`` values via jt.save/jt.load.

     Plain Python values must come back unchanged; the jittor arrays are
     compared element-wise against the equivalent NumPy array after loading.
     """
     import os
     import tempfile

     pp = [1, 2, jt.array([1, 2, 3]), {"a": [1, 2, 3], "b": jt.array([1, 2, 3])}]
     # A private temporary directory replaces the fixed "/tmp/xx.pkl": it
     # exists on every platform, cannot collide with concurrently running
     # tests, and is removed when the block exits.
     with tempfile.TemporaryDirectory() as tmp_dir:
         path = os.path.join(tmp_dir, "xx.pkl")
         jt.save(pp, path)
         x = jt.load(path)
         assert x[:2] == [1, 2]
         assert (x[2] == np.array([1, 2, 3])).all()
         assert x[3]['a'] == [1, 2, 3]
         assert (x[3]['b'] == np.array([1, 2, 3])).all()
Example #3
0
def save_checkpoint(checkpoint_path, model, _optimizers, logger, cfg,
                    **kwargs):
    """Serialize model and optimizer state plus ``cfg`` to ``checkpoint_path``.

    The saved dict contains ``'state_dict'`` (``model.state_dict()``),
    ``'optimizer'`` (``_optimizers.state_dict()``) and ``'cfg'``. Any extra
    keyword arguments are merged in last, so they may override those keys.
    The destination is reported via ``logger.info`` after ``jittor.save``
    returns.
    """
    model_state = model.state_dict()
    optimizer_state = _optimizers.state_dict()
    # kwargs come last so caller-supplied entries win on key collisions.
    checkpoint = {
        'state_dict': model_state,
        'optimizer': optimizer_state,
        'cfg': cfg,
        **kwargs,
    }
    jittor.save(checkpoint, checkpoint_path)
    logger.info('models saved to %s' % checkpoint_path)
Example #4
0
def train(network: model.Siren, optim: jittor.optim.Optimizer, loss_fn, epochs, coords, gt, save_path):
    """Fit ``network`` to ``gt`` at ``coords``, checkpointing on improvement.

    Every epoch evaluates ``network(coords)``, computes ``loss_fn`` against
    ``gt`` and calls ``optim.step`` with that loss. Whenever the scalar loss
    is lower than every previous one, ``network.state_dict()`` is written to
    ``save_path``. A progress line is emitted via ``tqdm.write`` every 10
    epochs.
    """
    network.train()
    best_loss = np.inf
    for epoch in tqdm(range(epochs)):
        loss_var = loss_fn(network(coords), gt)
        optim.step(loss_var)
        current = loss_var.item()
        if current < best_loss:
            best_loss = current
            jittor.save(network.state_dict(), save_path)
        if not epoch % 10:
            tqdm.write(f"epoch: {epoch}, loss: {current}, min_loss: {best_loss}")
Example #5
0
 def test_save(self):
     """Check that jt.save / jt.load round-trip a nested container.

     The payload mixes plain ints, a plain list and ``jt.Var`` arrays stored
     both directly in the list and inside a dict. After loading, the Var
     entries are compared element-wise with a NumPy array. The file is
     written under ``jt.flags.cache_path``.
     """
     var_entry = jt.array([1, 2, 3])
     dict_entry = {"a": [1, 2, 3], "b": jt.array([1, 2, 3])}
     pp = [1, 2, var_entry, dict_entry]
     target = jt.flags.cache_path + "/xx.pkl"
     jt.save(pp, target)
     restored = jt.load(target)
     expected = np.array([1, 2, 3])
     assert restored[:2] == [1, 2]
     assert (restored[2] == expected).all()
     assert restored[3]['a'] == [1, 2, 3]
     assert (restored[3]['b'] == expected).all()
Example #6
0
    def save(self, name, **kwargs):
        """Write a checkpoint called ``name`` into ``self.save_dir``.

        This is a no-op when ``self.save_dir`` is falsy or
        ``self.save_to_disk`` is falsy. Otherwise a dict is built with the model
        state under ``"model"``, plus ``"optimizer"`` / ``"scheduler"`` state
        for whichever of those attributes is not None. ``kwargs`` is merged in
        last and may override those keys. The dict is saved to
        ``<save_dir>/<name>.pth``, and that path is passed to
        ``self.tag_last_checkpoint``.
        """
        if not self.save_dir or not self.save_to_disk:
            return

        data = {"model": self.model.state_dict()}
        optional_parts = (("optimizer", self.optimizer),
                          ("scheduler", self.scheduler))
        for key, component in optional_parts:
            if component is not None:
                data[key] = component.state_dict()
        data.update(kwargs)

        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        jt.save(data, save_file)
        self.tag_last_checkpoint(save_file)
Example #7
0
def train(hyp, opt, tb_writer=None):
    """Train a detection ``Model`` and return the latest evaluation results.

    Args:
        hyp: dict of hyperparameters (learning rates, momentum, loss gains,
            warmup settings, ...). The entries ``weight_decay``, ``box``,
            ``cls`` and ``obj`` are rescaled in place.
        opt: parsed options (``save_dir``, ``epochs``, ``batch_size``,
            ``weights``, ``data``, ``cfg`` and various flags).
        tb_writer: optional TensorBoard-style writer. When given, per-epoch
            scalars are logged to it, plus the class histogram when plotting
            is enabled.

    Side effects:
        Writes ``hyp.yaml``, ``opt.yaml``, ``results.txt`` and
        ``weights/last.pkl`` / ``weights/best.pkl`` under ``opt.save_dir``,
        and optionally copies results/weights to ``gs://<opt.bucket>`` via
        ``gsutil``.

    Returns:
        The most recent ``results`` tuple from ``test.test``
        (P, R, mAP@.5, mAP@.5:.95, val box/obj/cls loss).
    """
    logger.info(
        colorstr('hyperparameters: ') + ', '.join(f'{k}={v}'
                                                  for k, v in hyp.items()))
    save_dir, epochs, batch_size, weights = Path(
        opt.save_dir), opt.epochs, opt.batch_size, opt.weights

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / 'last.pkl'
    best = wdir / 'best.pkl'
    results_file = save_dir / 'results.txt'

    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    plots = not opt.evolve  # create plots
    cuda = not opt.no_cuda
    if cuda:
        jt.flags.use_cuda = 1

    init_seeds(1)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict

    check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if opt.single_cls and len(
        data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (
        len(names), nc, opt.data)  # check

    # Model
    model = Model(opt.cfg, ch=3, nc=nc)  # create
    pretrained = weights.endswith('.pkl')
    if pretrained:
        model.load(weights)  # load

    # Optimizer
    nbs = 64  # nominal batch size
    # NOTE(review): `accumulate` only scales weight_decay and EMA updates here;
    # optimizer.step() below runs on every batch, so no gradient accumulation
    # actually happens.
    accumulate = max(round(nbs / batch_size), 1)
    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
    logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")

    # Parameter groups: pg0 = BatchNorm weights (no decay), pg1 = other
    # weights (decay), pg2 = all biases.
    pg0, pg1, pg2 = [], [], []
    for _, v in model.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, jt.Var):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, jt.Var):
            pg1.append(v.weight)  # apply decay

    if opt.adam:
        optimizer = optim.Adam(pg0,
                               lr=hyp['lr0'],
                               betas=(hyp['momentum'],
                                      0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0,
                              lr=hyp['lr0'],
                              momentum=hyp['momentum'],
                              nesterov=True)

    optimizer.add_param_group({
        'params': pg1,
        'weight_decay': hyp['weight_decay']
    })  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' %
                (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = optim.LambdaLR(optimizer, lr_lambda=lf)

    loggers = {}  # loggers dict

    start_epoch, best_fitness = 0, 0.0

    # Image sizes
    gs = int(model.stride.max())  # grid size (max stride)
    nl = model.model[-1].nl  # number of detection layers (scales hyp['obj'])
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size
                         ]  # verify imgsz are gs-multiples

    # EMA
    ema = ModelEMA(model)

    # Trainloader
    dataloader = create_dataloader(train_path,
                                   imgsz,
                                   batch_size,
                                   gs,
                                   opt,
                                   hyp=hyp,
                                   augment=True,
                                   cache=opt.cache_images,
                                   rect=opt.rect,
                                   workers=opt.workers,
                                   image_weights=opt.image_weights,
                                   quad=opt.quad,
                                   prefix=colorstr('train: '))

    # Concatenate the training labels once; reused for the class check,
    # plotting and the class histogram below.
    labels = np.concatenate(dataloader.labels, 0)
    mlc = labels[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (
        mlc, nc, opt.data, nc - 1)

    ema.updates = start_epoch * nb // accumulate  # set EMA updates
    testloader = create_dataloader(
        test_path,
        imgsz_test,
        batch_size,
        gs,
        opt,  # testloader
        hyp=hyp,
        cache=opt.cache_images and not opt.notest,
        rect=True,
        workers=opt.workers,
        pad=0.5,
        prefix=colorstr('val: '))

    c = jt.array(labels[:, 0])  # classes

    if plots:
        plot_labels(labels, save_dir, loggers)
        if tb_writer:
            tb_writer.add_histogram('classes', c.numpy(), 0)

    # Anchors
    if not opt.noautoanchor:
        check_anchors(dataloader,
                      model=model,
                      thr=hyp['anchor_t'],
                      imgsz=imgsz)

    # Model parameters
    hyp['box'] *= 3. / nl  # scale to layers
    hyp['cls'] *= nc / 80. * 3. / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640)**2 * 3. / nl  # scale to image size and layers
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(
        dataloader.labels, nc) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    # number of warmup iterations, max(hyp['warmup_epochs'] epochs, 1k iterations)
    nw = max(round(hyp['warmup_epochs'] * nb), 1000)
    maps = np.zeros(nc)  # mAP per class
    # P, R, mAP@.5, mAP@.5:.95, val_loss(box, obj, cls)
    results = (0, 0, 0, 0, 0, 0, 0)
    scheduler.last_epoch = start_epoch - 1  # do not move
    logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
                f'Using {dataloader.num_workers} dataloader workers\n'
                f'Logging results to {save_dir}\n'
                f'Starting training for {epochs} epochs...')
    for epoch in range(start_epoch, epochs):  # epoch ---------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            cw = model.class_weights.numpy() * (1 -
                                                maps)**2 / nc  # class weights
            iw = labels_to_image_weights(dataloader.labels,
                                         nc=nc,
                                         class_weights=cw)  # image weights
            dataloader.indices = random.choices(
                range(dataloader.n), weights=iw,
                k=dataloader.n)  # rand weighted idx

        mloss = jt.zeros((4, ))  # mean losses
        pbar = enumerate(dataloader)
        logger.info(
            ('\n' + '%10s' * 7) %
            ('Epoch', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
        pbar = tqdm(pbar, total=nb)  # progress bar
        for i, (imgs, targets, paths, _) in pbar:  # batch ------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from warmup_bias_lr to lr0, all other lrs
                    # rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [
                        hyp['warmup_bias_lr'] if j == 2 else 0.0,
                        x['initial_lr'] * lf(epoch)
                    ])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(
                            ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5,
                                      imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]
                          ]  # new shape (stretched to gs-multiple)
                    imgs = nn.interpolate(imgs,
                                          size=ns,
                                          mode='bilinear',
                                          align_corners=False)
            # Forward
            pred = model(imgs)  # forward
            loss, loss_items = compute_loss(pred, targets,
                                            model)  # loss scaled by batch_size
            if opt.quad:
                loss *= 4.

            # Optimize
            optimizer.step(loss)
            if ema:
                ema.update(model)

            # Print
            mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
            s = ('%10s' + '%10.4g' * 6) % ('%g/%g' %
                                           (epoch, epochs - 1), *mloss,
                                           targets.shape[0], imgs.shape[-1])
            pbar.set_description(s)

            # Plot the first few training batches in the background.
            if plots and ni < 3:
                f = save_dir / f'train_batch{ni}.jpg'  # filename
                Thread(target=plot_images,
                       args=(imgs, targets, paths, f),
                       daemon=True).start()
            # end batch --------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # mAP
        if ema:
            ema.update_attr(model,
                            include=[
                                'yaml', 'nc', 'hyp', 'gr', 'names', 'stride',
                                'class_weights'
                            ])
        final_epoch = epoch + 1 == epochs
        if not opt.notest or final_epoch:  # Calculate mAP
            results, maps, times = test.test(data=opt.data,
                                             batch_size=batch_size,
                                             imgsz=imgsz_test,
                                             model=ema.ema,
                                             single_cls=opt.single_cls,
                                             dataloader=testloader,
                                             save_dir=save_dir,
                                             plots=plots and final_epoch)

        # Write: training losses followed by
        # P, R, mAP@.5, mAP@.5:.95, val_loss(box, obj, cls)
        with open(results_file, 'a') as f:
            f.write(s + '%10.4g' * 7 % results + '\n')
        if len(opt.name) and opt.bucket:
            os.system('gsutil cp %s gs://%s/results/results%s.txt' %
                      (results_file, opt.bucket, opt.name))

        # Log
        tags = [
            'train/box_loss',
            'train/obj_loss',
            'train/cls_loss',  # train loss
            'metrics/precision',
            'metrics/recall',
            'metrics/mAP_0.5',
            'metrics/mAP_0.5-0.95',
            'val/box_loss',
            'val/obj_loss',
            'val/cls_loss',  # val loss
            'x/lr0',
            'x/lr1',
            'x/lr2'
        ]  # params
        for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
            if tb_writer:
                if hasattr(x, "numpy"):
                    x = x.numpy()
                tb_writer.add_scalar(tag, x, epoch)  # tensorboard

        # Update best mAP
        # weighted combination of [P, R, mAP@.5, mAP@.5:.95]
        fi = fitness(np.array(results).reshape(1, -1))
        if fi > best_fitness:
            best_fitness = fi

        # Save model
        save = (not opt.nosave) or (final_epoch and not opt.evolve)
        if save:
            # Save last, and best when this epoch set a new best fitness
            jt.save(ema.ema.state_dict(), last)
            if best_fitness == fi:
                jt.save(ema.ema.state_dict(), best)
        # end epoch ------------------------------------------------------------
    # end training
    final = best if best.exists() else last  # final model
    if opt.bucket:
        os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload

    # Plots
    if plots:
        plot_results(save_dir=save_dir)  # save as results.png

    # Test best.pkl
    # The loop always runs to completion, so the number of finished epochs is
    # epochs - start_epoch; computing it this way avoids reading the loop
    # variable `epoch`, which is unbound when no epoch ran.
    logger.info('%g epochs completed in %.3f hours.\n' %
                (max(epochs - start_epoch, 0), (time.time() - t0) / 3600))
    best_model = Model(opt.cfg)
    best_model.load(str(final))
    best_model = best_model.fuse()
    if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
        for conf, iou, save_json in ([0.25, 0.45, False],
                                     [0.001, 0.65, True]):  # speed, mAP tests
            # batch_size: the previously referenced `total_batch_size` was
            # never defined in this function and raised NameError.
            results, _, _ = test.test(opt.data,
                                      batch_size=batch_size,
                                      imgsz=imgsz_test,
                                      conf_thres=conf,
                                      iou_thres=iou,
                                      model=best_model,
                                      single_cls=opt.single_cls,
                                      dataloader=testloader,
                                      save_dir=save_dir,
                                      save_json=save_json,
                                      plots=False)

    return results
Example #8
0
 def process(self):
     """Load the raw Planetoid data, optionally pre-transform it, and save it.

     The (possibly transformed) data object is wrapped in a one-element list,
     passed through ``self.collate`` and written to
     ``self.processed_paths[0]``.
     """
     raw = read_planetoid_data(self.raw_dir, self.name)
     if self.pre_transform is not None:
         raw = self.pre_transform(raw)
     jt.save(self.collate([raw]), self.processed_paths[0])
Example #9
0
def train():
    """Train a NeRF model configured from the command line.

    Parses arguments with ``config_parser()``, loads an ``llff``,
    ``blender`` or ``deepvoxels`` dataset (any other type prints a message
    and returns), builds the model with ``create_nerf`` and optimizes it up
    to ``N_iters`` (51000) iterations. With ``args.render_only`` it renders
    once, writes a video and returns.

    During training it:
      * saves a checkpoint every ``args.i_weights`` iterations;
      * renders rgb/disparity videos every ``args.i_video`` iterations;
      * on ``args.i_print`` iterations, logs loss/PSNR. When that iteration is
        also a multiple of ``args.i_img`` it logs a validation render to
        TensorBoard, and when it is also a multiple of ``args.i_testset`` it
        renders the test set.

    Under MPI, checkpoints, videos and TensorBoard summaries are written
    only by local rank 0.
    """
    parser = config_parser()
    args = parser.parse_args()
    # True outside MPI, or on local rank 0 when running under MPI.
    is_main_process = not jt.mpi or jt.mpi.local_rank() == 0

    # Load data
    intrinsic = None
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir,
            args.factor,
            recenter=True,
            bd_factor=.75,
            spherify=args.spherify)
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])
        # Full test split, used by the periodic test-set render below (it was
        # previously only defined for blender, causing a NameError here).
        i_test_tot = i_test

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.

        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        testskip = args.testskip
        faketestskip = args.faketestskip
        # NOTE(review): these per-rank overrides are never read -- the calls
        # below use args.testskip / args.faketestskip directly -- so non-zero
        # ranks currently load and subsample exactly like rank 0. Confirm the
        # intended behavior before wiring them in.
        if not is_main_process:
            testskip = faketestskip
            faketestskip = 1
        if args.do_intrinsic:
            images, poses, intrinsic, render_poses, hwf, i_split = load_blender_data(
                args.datadir, args.half_res, args.testskip,
                args.blender_factor, True)
        else:
            images, poses, render_poses, hwf, i_split = load_blender_data(
                args.datadir, args.half_res, args.testskip,
                args.blender_factor)
        print('Loaded blender', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split
        # Keep the full test split; i_test itself is subsampled for the
        # cheaper periodic evaluations.
        i_test_tot = i_test
        i_test = i_test[::args.faketestskip]

        near = args.near
        far = args.far
        print(args.do_intrinsic)
        print("hwf", hwf)
        print("near", near)
        print("far", far)

        if args.white_bkgd:
            # Composite RGBA onto a white background using the alpha channel.
            images = images[..., :3] * images[..., -1:] + (1. -
                                                           images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':

        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)

        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split
        i_test_tot = i_test

        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    # Videos / render-only output use the (subsampled) test poses.
    render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        # Both handles are closed deterministically (the source file used to
        # be left open until garbage collection).
        with open(f, 'w') as file, open(args.config, 'r') as config_file:
            file.write(config_file.read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
        args)
    global_step = start

    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = jt.array(render_poses)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        with jt.no_grad():
            testsavedir = os.path.join(
                basedir, expname, 'renderonly_{}_{:06d}'.format(
                    'test' if args.render_test else 'path', start))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            rgbs, _ = render_path(render_poses,
                                  hwf,
                                  args.chunk,
                                  render_kwargs_test,
                                  savedir=testsavedir,
                                  render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                             to8b(rgbs),
                             fps=30,
                             quality=8)

            return

    # Prepare raybatch tensor if batching random rays
    accumulation_steps = 1
    N_rand = args.N_rand // accumulation_steps
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack(
            [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]],
            0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:, None]],
                                  1)  # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb,
                                [0, 2, 3, 1, 4])  # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train],
                            0)  # train images only
        rays_rgb = np.reshape(rays_rgb,
                              [-1, 3, 3])  # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)

        print('done')
        i_batch = 0

    # Move training data to GPU
    images = jt.array(images.astype(np.float32))
    poses = jt.array(poses)
    if use_batching:
        rays_rgb = jt.array(rays_rgb)

    N_iters = 51000  # fixed iteration budget (not read from args)
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    if is_main_process:
        date = str(datetime.datetime.now())
        date = (date[:date.rfind(":")].replace("-", "")
                .replace(":", "").replace(" ", "_"))
        gpu_idx = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
        log_dir = os.path.join("./logs", "summaries",
                               "log_" + date + "_gpu" + gpu_idx)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        writer = SummaryWriter(log_dir=log_dir)

    start = start + 1
    for i in trange(start, N_iters):
        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1, 3*?]
            batch = jt.transpose(batch, (1, 0, 2))
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = jt.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0

        else:
            # Random from one image
            np.random.seed(i)
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3, :4]
            if N_rand is not None:
                rays_o, rays_d = pinhole_get_rays(
                    H, W, focal, pose, intrinsic)  # (H, W, 3), (H, W, 3)
                if i < args.precrop_iters:
                    # Sample only from a centered crop during early iterations.
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = jt.stack(
                        jt.meshgrid(
                            jt.linspace(H // 2 - dH, H // 2 + dH - 1, 2 * dH),
                            jt.linspace(W // 2 - dW, W // 2 + dW - 1, 2 * dW)),
                        -1)
                    if i == start:
                        print(
                            f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}"
                        )
                else:
                    coords = jt.stack(
                        jt.meshgrid(jt.linspace(0, H - 1, H),
                                    jt.linspace(0, W - 1, W)), -1)  # (H, W, 2)

                coords = jt.reshape(coords, [-1, 2])  # (H * W, 2)
                select_inds = np.random.choice(coords.shape[0],
                                               size=[N_rand],
                                               replace=False)  # (N_rand,)
                select_coords = coords[select_inds].int()  # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = jt.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0],
                                  select_coords[:, 1]]  # (N_rand, 3)

        #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H,
                                        W,
                                        focal,
                                        chunk=args.chunk,
                                        rays=batch_rays,
                                        verbose=i < 10,
                                        retraw=True,
                                        **render_kwargs_train)
        img_loss = img2mse(rgb, target_s)
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            # Also supervise the extra 'rgb0' prediction (presumably the
            # coarse network's output -- confirm in render()).
            loss = loss + img2mse(extras['rgb0'], target_s)

        optimizer.backward(loss / accumulation_steps)
        if i % accumulation_steps == 0:
            optimizer.step()

        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * accumulation_steps * 1000
        new_lrate = args.lrate * (decay_rate**(global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################

        # Rest is logging
        if (i + 1) % args.i_weights == 0 and is_main_process:
            print(i)
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            jt.save(
                {
                    'global_step':
                    global_step,
                    'network_fn_state_dict':
                    render_kwargs_train['network_fn'].state_dict(),
                    'network_fine_state_dict':
                    render_kwargs_train['network_fine'].state_dict(),
                }, path)
            print('Saved checkpoints at', path)

        if i % args.i_video == 0 and i > 0:
            # Turn on testing mode; every rank renders, only rank 0 writes.
            with jt.no_grad():
                rgbs, disps = render_path(render_poses,
                                          hwf,
                                          args.chunk,
                                          render_kwargs_test,
                                          intrinsic=intrinsic)
            if is_main_process:
                print('Done, saving', rgbs.shape, disps.shape)
                moviebase = os.path.join(
                    basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
                print('movie base ', moviebase)
                imageio.mimwrite(moviebase + 'rgb.mp4',
                                 to8b(rgbs),
                                 fps=30,
                                 quality=8)
                imageio.mimwrite(moviebase + 'disp.mp4',
                                 to8b(disps / np.max(disps)),
                                 fps=30,
                                 quality=8)

        if i % args.i_print == 0:
            tqdm.write(
                f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")
            if i % args.i_img == 0:
                img_i = np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3, :4]
                with jt.no_grad():
                    rgb, disp, acc, extras = render(H,
                                                    W,
                                                    focal,
                                                    chunk=args.chunk,
                                                    c2w=pose,
                                                    intrinsic=intrinsic,
                                                    **render_kwargs_test)
                psnr = mse2psnr(img2mse(rgb, target))
                rgb = rgb.numpy()
                disp = disp.numpy()
                acc = acc.numpy()

                if is_main_process:
                    writer.add_image('test/rgb',
                                     to8b(rgb),
                                     global_step,
                                     dataformats="HWC")
                    writer.add_image('test/target',
                                     target.numpy(),
                                     global_step,
                                     dataformats="HWC")
                    writer.add_scalar('test/psnr', psnr.item(), global_step)

            jt.clean_graph()
            jt.sync_all()
            jt.gc()

            # Test-set rendering only happens on i_print iterations as well.
            if i % args.i_testset == 0 and i > 0:
                si_test = i_test_tot if i % args.i_tottest == 0 else i_test
                testsavedir = os.path.join(basedir, expname,
                                           'testset_{:06d}'.format(i))
                os.makedirs(testsavedir, exist_ok=True)
                print('test poses shape', poses[si_test].shape)
                with jt.no_grad():
                    rgbs, disps = render_path(jt.array(poses[si_test]),
                                              hwf,
                                              args.chunk,
                                              render_kwargs_test,
                                              savedir=testsavedir,
                                              intrinsic=intrinsic,
                                              expname=expname)
                jt.gc()
        global_step += 1
Example #10
0
                alpha = 0
                ckpt_step = step

            resolution = 4 * 2**step

            image_loader = SymbolDataset(args.path, transform,
                                         resolution).set_attrs(
                                             batch_size=batch_size.get(
                                                 resolution, batch_default),
                                             shuffle=True)
            train_loader = iter(image_loader)

            jt.save(
                {
                    'generator': netG.state_dict(),
                    'discriminator': netD.state_dict(),
                    'g_running': g_running.state_dict(),
                },
                f'FFHQ/checkpoint/train_step-{ckpt_step}.model',
            )

        try:
            real_image = next(train_loader)
        except (OSError, StopIteration):
            train_loader = iter(image_loader)
            real_image = next(train_loader)

        real_image.requires_grad = True
        b_size = real_image.size(0)

        real_scores = netD(real_image, step=step, alpha=alpha)
        real_predict = jt.nn.softplus(-real_scores).mean()