def testScalarSummaryNameScope(self):
    """Check that a scalar created under a name scope carries the scope in its tag.

    The scalar is recorded under `record_summaries_every_n_global_steps(2)`,
    so running the summary ops while the global step is 1 must not add an
    event, and running them once the step reaches 2 must add exactly one.
    """
    with ops.Graph().as_default(), self.cached_session() as sess:
        step_var = training_util.get_or_create_global_step()
        step_var.initializer.run()
        with ops.device('/cpu:0'):
            bump_step = state_ops.assign_add(step_var, 1)
        # Move the global step from 0 to 1.
        sess.run(bump_step)

        event_dir = tempfile.mkdtemp()
        writer = summary_ops.create_file_writer(
            event_dir, max_queue=0, name='t2')
        with writer.as_default():
            with summary_ops.record_summaries_every_n_global_steps(2):
                summary_ops.initialize()
                with ops.name_scope('scope'):
                    scalar_op = summary_ops.scalar('my_scalar', 2.0)

                # Global step is 1 and 1 % 2 != 0, so neither run below may
                # write a summary. The single event that is expected is
                # presumably the writer's own header event -- TODO confirm.
                sess.run(summary_ops.all_summary_ops())
                sess.run(scalar_op)
                written = summary_test_util.events_from_logdir(event_dir)
                self.assertEqual(len(written), 1)

                # Advance the global step from 1 to 2; the summary must now
                # be written, tagged with the enclosing name scope.
                sess.run(bump_step)
                sess.run(summary_ops.all_summary_ops())
                written = summary_test_util.events_from_logdir(event_dir)
                self.assertEqual(len(written), 2)
                self.assertEqual(written[1].summary.value[0].tag,
                                 'scope/my_scalar')
    def testScalarSummaryNameScope(self):
        """Verify the tag of a name-scoped scalar recorded every 2 global steps.

        No summary may be emitted while the global step is odd (1); after the
        step is bumped to 2 exactly one summary tagged 'scope/my_scalar' must
        show up in the log directory.
        """
        with ops.Graph().as_default():
            with self.cached_session() as sess:
                global_step = training_util.get_or_create_global_step()
                global_step.initializer.run()
                with ops.device('/cpu:0'):
                    advance = state_ops.assign_add(global_step, 1)
                sess.run(advance)  # global step: 0 -> 1

                logdir = tempfile.mkdtemp()

                def logged_events():
                    # Re-read everything written to the log directory so far.
                    return summary_test_util.events_from_logdir(logdir)

                with summary_ops.create_file_writer(
                        logdir, max_queue=0, name='t2').as_default(), \
                        summary_ops.record_summaries_every_n_global_steps(2):
                    summary_ops.initialize()
                    with ops.name_scope('scope'):
                        my_scalar_op = summary_ops.scalar('my_scalar', 2.0)

                    # With the step at 1 ("1 % 2 != 0") neither of these runs
                    # should produce a summary; only one event is expected
                    # (presumably the writer's header event -- TODO confirm).
                    sess.run(summary_ops.all_summary_ops())
                    sess.run(my_scalar_op)
                    self.assertEqual(len(logged_events()), 1)

                    # global step: 1 -> 2, so the summary is recorded now.
                    sess.run(advance)
                    sess.run(summary_ops.all_summary_ops())
                    after_second_step = logged_events()
                    self.assertEqual(len(after_second_step), 2)
                    self.assertEqual(
                        after_second_step[1].summary.value[0].tag,
                        'scope/my_scalar')
Exemple #3
0
        def _host_call_fn(**kwargs):
            """Training host call that writes the training-metric summaries.

            For every summary object of the current iteration a file writer
            is opened on that summary's `logdir`, and each of its summary
            functions is invoked with the tensor stored under the key
            "summary_{i}_{j}" (i: summary index, j: function index).

            Args:
              **kwargs: Dict of {str: Tensor}, with `Tensor` of shape
                `[batch]`. Must contain key "global_step" with value of the
                current global_step Tensor; that entry is popped before the
                summary tensors are looked up.

            Returns:
              List of summary ops to run on the CPU host.
            """

            from tensorflow.python.ops import summary_ops_v2  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top

            # Only the first element of the global_step tensor is used.
            step = tf.cast(kwargs.pop("global_step")[0], dtype=tf.int64)

            for summary_index, scoped_summary in enumerate(
                    current_iteration.summaries):
                writer = summary_ops_v2.create_file_writer(
                    scoped_summary.logdir)
                with writer.as_default(), \
                        summary_ops_v2.record_summaries_every_n_global_steps(
                            n=self.config.save_summary_steps,
                            global_step=step):
                    for fn_index, write_summary in enumerate(
                            summary_fns[summary_index]):
                        key = "summary_{}_{}".format(summary_index, fn_index)
                        write_summary(kwargs[key], step=step)
                scoped_summary.clear_summary_tuples()

            return tf.compat.v1.summary.all_v2_summary_ops()
Exemple #4
0
    def construct(self, args):
        """Build the classification graph, its summaries, and initialize it.

        Everything is created inside `self.session.graph`. The method sets
        `self.images`, `self.labels`, `self.predictions`, `self.training`
        and `self.summaries` (a dict keyed by "train", "dev" and "test").

        Args:
            args: Parsed options. This method itself only reads
                `args.logdir`; the TODO below expects `args.layers`,
                `args.hidden_layer` and `args.activation` as well.
        """
        with self.session.graph.as_default():
            # Inputs: a batch of images and the matching integer labels.
            self.images = tf.placeholder(tf.float32,
                                         [None, MNIST.H, MNIST.W, MNIST.C],
                                         name="images")
            self.labels = tf.placeholder(tf.int64, [None], name="labels")

            # Computation
            hidden = tf.keras.layers.Flatten()(self.images)
            # TODO: Add `args.layers` number of hidden layers with size `args.hidden_layer`,
            # using activation from `args.activation`, allowing "none", "relu", "tanh", "sigmoid".
            # Store the results back to `hidden` variable.
            output_layer = tf.keras.layers.Dense(MNIST.LABELS)(hidden)
            self.predictions = tf.argmax(output_layer, axis=1)

            # Training: `output_layer` holds logits (from_logits=True).
            # NOTE(review): the Keras loss function presumably returns one
            # value per example rather than a reduced scalar -- confirm that
            # the "train/loss" scalar summary below accepts it.
            loss = tf.keras.losses.sparse_categorical_crossentropy(
                self.labels, output_layer, from_logits=True)
            global_step = tf.train.create_global_step()
            self.training = tf.train.AdamOptimizer().minimize(
                loss, global_step=global_step, name="training")

            # Summaries
            accuracy = tf.math.reduce_mean(
                tf.cast(tf.equal(self.labels, self.predictions), tf.float32))
            # The confusion matrix is weighted by `labels != predictions`,
            # so only misclassified examples contribute. It is reshaped to
            # [1, LABELS, LABELS, 1] so it can be logged as an image below.
            confusion_matrix = tf.reshape(
                tf.confusion_matrix(self.labels,
                                    self.predictions,
                                    weights=tf.not_equal(
                                        self.labels, self.predictions),
                                    dtype=tf.float32),
                [1, MNIST.LABELS, MNIST.LABELS, 1])

            summary_writer = tf_summary.create_file_writer(args.logdir,
                                                           flush_millis=10 *
                                                           1000)
            self.summaries = {}
            # Training summaries are recorded only every 100 global steps.
            with summary_writer.as_default(
            ), tf_summary.record_summaries_every_n_global_steps(100):
                self.summaries["train"] = [
                    tf_summary.scalar("train/loss", loss),
                    tf_summary.scalar("train/accuracy", accuracy)
                ]
            # Evaluation summaries are recorded on every run.
            with summary_writer.as_default(
            ), tf_summary.always_record_summaries():
                for dataset in ["dev", "test"]:
                    self.summaries[dataset] = [
                        tf_summary.scalar(dataset + "/accuracy", accuracy),
                        tf_summary.image(dataset + "/confusion_matrix",
                                         confusion_matrix)
                    ]
                    # The flush op is created under control dependencies on
                    # the summary ops above and appended to the same list.
                    with tf.control_dependencies(self.summaries[dataset]):
                        self.summaries[dataset].append(summary_writer.flush())

            # Initialize variables, then the summary writer for this
            # session and graph.
            self.session.run(tf.global_variables_initializer())
            with summary_writer.as_default():
                tf_summary.initialize(session=self.session,
                                      graph=self.session.graph)
    def testRecordEveryNGlobalSteps(self):
        """Check that summaries are recorded only on every 2nd global step.

        Twenty steps are executed in total -- ten by calling `run_step`
        directly and ten through a `function.defun`-wrapped version -- while
        recording every 2 global steps. Eleven events are expected in the log
        directory (presumably ten summaries plus the writer's header event
        -- TODO confirm).
        """
        logdir = tempfile.mkdtemp()
        global_step = training_util.get_or_create_global_step()

        def run_step():
            # `iteration` is read from the enclosing scope at call time.
            summary_ops.scalar('scalar', iteration, step=global_step)
            global_step.assign_add(1)

        with summary_ops.create_file_writer(logdir).as_default():
            with summary_ops.record_summaries_every_n_global_steps(
                    2, global_step):
                for iteration in range(10):
                    run_step()
                # And another 10 steps as a graph function.
                compiled_step = function.defun(run_step)
                for iteration in range(10):
                    compiled_step()

        recorded = summary_test_util.events_from_logdir(logdir)
        self.assertEqual(len(recorded), 11)
Exemple #6
0
def train():
    """Train a NeRF model as configured by the command-line arguments.

    Steps, in order:
      1. Parse arguments with `config_parser()` and optionally fix the
         random seeds.
      2. Load the dataset selected by `args.dataset_type` ('llff',
         'blender' or 'deepvoxels') and derive the near/far bounds.
      3. Write `args.txt` (and a copy of the config file, if one was given)
         into `<basedir>/<expname>`.
      4. Build the model via `create_nerf(args)`.
      5. If `args.render_only` is set, render once, save a video and return.
      6. Otherwise run the optimization loop from `start` up to 1,000,000
         iterations, periodically saving weights, videos, test-set renders
         and summaries.

    Returns:
        None. Returns early for an unknown dataset type and after a
        render-only run.
    """

    parser = config_parser()
    args = parser.parse_args()

    if args.random_seed is not None:
        print('Fixing random seed', args.random_seed)
        np.random.seed(args.random_seed)
        tf.compat.v1.set_random_seed(args.random_seed)

    # Load data

    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir,
            args.factor,
            recenter=True,
            bd_factor=.75,
            spherify=args.spherify)
        # The last column of the first pose carries (H, W, focal).
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        print('Loaded llff', images.shape, render_poses.shape, hwf,
              args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        # Validation views are the same as the test views for LLFF data.
        i_val = i_test
        i_train = np.array([
            i for i in np.arange(int(images.shape[0]))
            if (i not in i_test and i not in i_val)
        ])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = tf.reduce_min(bds) * .9
            far = tf.reduce_max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip)
        print('Loaded blender', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        near = 2.
        far = 6.

        if args.white_bkgd:
            # Composite RGB over a white background using the last channel
            # as alpha.
            images = images[..., :3] * images[..., -1:] + (1. -
                                                           images[..., -1:])
        else:
            images = images[..., :3]

    elif args.dataset_type == 'deepvoxels':

        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip)

        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf,
              args.datadir)
        i_train, i_val, i_test = i_split

        # Bounds are placed one unit either side of the mean camera
        # distance from the origin.
        hemi_R = np.mean(np.linalg.norm(poses[:, :3, -1], axis=-1))
        near = hemi_R - 1.
        far = hemi_R + 1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if args.render_test:
        render_poses = np.array(poses[i_test])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        # Read the source config inside a context manager so its handle is
        # closed deterministically instead of being left to the GC.
        with open(args.config, 'r') as config_file:
            config_text = config_file.read()
        with open(f, 'w') as file:
            file.write(config_text)

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, models = create_nerf(
        args)

    bds_dict = {
        'near': tf.cast(near, tf.float32),
        'far': tf.cast(far, tf.float32),
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        if args.render_test:
            # render_test switches to test poses
            images = images[i_test]
        else:
            # Default is smoother render_poses path
            images = None

        testsavedir = os.path.join(
            basedir, expname, 'renderonly_{}_{:06d}'.format(
                'test' if args.render_test else 'path', start))
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', render_poses.shape)

        rgbs, _ = render_path(render_poses,
                              hwf,
                              args.chunk,
                              render_kwargs_test,
                              gt_imgs=images,
                              savedir=testsavedir,
                              render_factor=args.render_factor)
        print('Done rendering', testsavedir)
        imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'),
                         to8b(rgbs),
                         fps=30,
                         quality=8)

        return

    # Create optimizer
    lrate = args.lrate
    if args.lrate_decay > 0:
        lrate = tf.keras.optimizers.schedules.ExponentialDecay(
            lrate, decay_steps=args.lrate_decay * 1000, decay_rate=0.1)
    optimizer = tf.keras.optimizers.Adam(lrate)
    # The optimizer is stored with the models so its weights are saved by
    # the periodic checkpointing below.
    models['optimizer'] = optimizer

    global_step = tf.compat.v1.train.get_or_create_global_step()
    global_step.assign(start)

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching.
        #
        # Constructs an array 'rays_rgb' of shape [N*H*W, 3, 3] where the
        # index along axis=1 is interpreted as,
        #   0: ray origin in world space
        #   1: ray direction in world space
        #   2: observed RGB color of pixel
        print('get rays')
        # get_rays_np() returns rays_origin=[H, W, 3], rays_direction=[H, W, 3]
        # for each pixel in the image. This stack() adds a new dimension.
        rays = [get_rays_np(H, W, focal, p) for p in poses[:, :3, :4]]
        rays = np.stack(rays, axis=0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.concatenate([rays, images[:, None, ...]], 1)
        # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
        rays_rgb = np.stack([rays_rgb[i] for i in i_train],
                            axis=0)  # train images only
        # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0

    N_iters = 1000000
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # Summary writers
    writer = tf.summary.create_file_writer(
        os.path.join(basedir, 'summaries', expname))
    writer.set_as_default()

    def save_weights(net, prefix, i):
        """Save `net.get_weights()` to '<prefix>_<i>.npy' in the experiment dir."""
        path = os.path.join(basedir, expname,
                            '{}_{:06d}.npy'.format(prefix, i))
        np.save(path, net.get_weights())
        print('saved weights at', path)

    for i in range(start, N_iters):
        time0 = time.time()

        # Sample random ray batch

        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch + N_rand]  # [B, 2+1, 3*?]
            batch = tf.transpose(batch, [1, 0, 2])

            # batch_rays[i, n, xyz] = ray origin or direction, example_id, 3D position
            # target_s[n, rgb] = example_id, observed color.
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                # Every ray has been used once: reshuffle and start over.
                np.random.shuffle(rays_rgb)
                i_batch = 0

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target = images[img_i]
            pose = poses[img_i, :3, :4]

            # NOTE(review): when N_rand is None this branch assigns neither
            # batch_rays nor target_s, which the render call below needs.
            if N_rand is not None:
                rays_o, rays_d = get_rays(H, W, focal, pose)
                if i < args.precrop_iters:
                    # Early iterations sample only from a central crop.
                    dH = int(H // 2 * args.precrop_frac)
                    dW = int(W // 2 * args.precrop_frac)
                    coords = tf.stack(
                        tf.meshgrid(tf.range(H // 2 - dH, H // 2 + dH),
                                    tf.range(W // 2 - dW, W // 2 + dW),
                                    indexing='ij'), -1)
                    if i < 10:
                        print('precrop', dH, dW, coords[0, 0], coords[-1, -1])
                else:
                    coords = tf.stack(
                        tf.meshgrid(tf.range(H), tf.range(W), indexing='ij'),
                        -1)
                coords = tf.reshape(coords, [-1, 2])
                select_inds = np.random.choice(coords.shape[0],
                                               size=[N_rand],
                                               replace=False)
                select_inds = tf.gather_nd(coords, select_inds[:, tf.newaxis])
                rays_o = tf.gather_nd(rays_o, select_inds)
                rays_d = tf.gather_nd(rays_d, select_inds)
                batch_rays = tf.stack([rays_o, rays_d], 0)
                target_s = tf.gather_nd(target, select_inds)

        #####  Core optimization loop  #####

        with tf.GradientTape() as tape:

            # Make predictions for color, disparity, accumulated opacity.
            rgb, disp, acc, extras = render(H,
                                            W,
                                            focal,
                                            chunk=args.chunk,
                                            rays=batch_rays,
                                            verbose=i < 10,
                                            retraw=True,
                                            **render_kwargs_train)

            # Compute MSE loss between predicted and true RGB.
            img_loss = img2mse(rgb, target_s)
            trans = extras['raw'][..., -1]
            loss = img_loss
            psnr = mse2psnr(img_loss)

            # Add MSE loss for coarse-grained model
            if 'rgb0' in extras:
                img_loss0 = img2mse(extras['rgb0'], target_s)
                loss += img_loss0
                psnr0 = mse2psnr(img_loss0)

        gradients = tape.gradient(loss, grad_vars)
        optimizer.apply_gradients(zip(gradients, grad_vars))

        dt = time.time() - time0

        #####           end            #####

        # Rest is logging

        if i % args.i_weights == 0:
            for k in models:
                save_weights(models[k], k, i)

        if i % args.i_video == 0 and i > 0:

            rgbs, disps = render_path(render_poses, hwf, args.chunk,
                                      render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname,
                                     '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4',
                             to8b(rgbs),
                             fps=30,
                             quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4',
                             to8b(disps / np.max(disps)),
                             fps=30,
                             quality=8)

            if args.use_viewdirs:
                # Render the same path again with 'c2w_staticcam' pinned to
                # the first render pose, then reset the override.
                render_kwargs_test['c2w_staticcam'] = render_poses[0][:3, :4]
                rgbs_still, _ = render_path(render_poses, hwf, args.chunk,
                                            render_kwargs_test)
                render_kwargs_test['c2w_staticcam'] = None
                imageio.mimwrite(moviebase + 'rgb_still.mp4',
                                 to8b(rgbs_still),
                                 fps=30,
                                 quality=8)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(basedir, expname,
                                       'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            render_path(poses[i_test],
                        hwf,
                        args.chunk,
                        render_kwargs_test,
                        gt_imgs=images[i_test],
                        savedir=testsavedir)
            print('Saved test set')

        if i % args.i_print == 0 or i < 10:

            print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
            print('iter time {:.05f}'.format(dt))
            with record_summaries_every_n_global_steps(args.i_print):
                tf.summary.scalar('loss', loss)
                tf.summary.scalar('psnr', psnr)
                tf.summary.histogram('tran', trans)
                # NOTE(review): psnr0 is only bound when 'rgb0' is in
                # extras; presumably that coincides with N_importance > 0
                # -- confirm against render().
                if args.N_importance > 0:
                    tf.summary.scalar('psnr0', psnr0)

            if i % args.i_img == 0:

                # Log a rendered validation view to Tensorboard
                img_i = np.random.choice(i_val)
                target = images[img_i]
                pose = poses[img_i, :3, :4]

                rgb, disp, acc, extras = render(H,
                                                W,
                                                focal,
                                                chunk=args.chunk,
                                                c2w=pose,
                                                **render_kwargs_test)

                psnr = mse2psnr(img2mse(rgb, target))

                # Save out the validation image for Tensorboard-free monitoring
                testimgdir = os.path.join(basedir, expname, 'tboard_val_imgs')
                # Always ensure the directory exists: a run resumed with
                # start > 0 never passes through i == 0, so creating it only
                # on the first iteration would make imwrite fail.
                os.makedirs(testimgdir, exist_ok=True)
                imageio.imwrite(
                    os.path.join(testimgdir, '{:06d}.png'.format(i)),
                    to8b(rgb))

                with record_summaries_every_n_global_steps(args.i_img):

                    tf.summary.image('rgb', to8b(rgb)[tf.newaxis])
                    tf.summary.image('disp', disp[tf.newaxis, ..., tf.newaxis])
                    tf.summary.image('acc', acc[tf.newaxis, ..., tf.newaxis])

                    tf.summary.scalar('psnr_holdout', psnr)
                    tf.summary.image('rgb_holdout', target[tf.newaxis])

                if args.N_importance > 0:

                    with record_summaries_every_n_global_steps(args.i_img):
                        tf.summary.image('rgb0',
                                         to8b(extras['rgb0'])[tf.newaxis])
                        tf.summary.image(
                            'disp0', extras['disp0'][tf.newaxis, ...,
                                                     tf.newaxis])
                        tf.summary.image(
                            'z_std', extras['z_std'][tf.newaxis, ...,
                                                     tf.newaxis])

        global_step.assign_add(1)