Exemplo n.º 1
0
    def test(self):
        """Run flow estimation on every consecutive pair of input images.

        For each pair ``(input_images[k], input_images[k + 1])`` the images are
        read with ``imageio.imread`` and passed through ``factor_crop``. They are
        then scaled by 1/255 and fed to ``self.flows`` as a batch of one. The
        first batch element of each pyramid level is rescaled by
        ``20 / 2**(num_levels - level)`` and visualized with
        ``vis_flow_pyramid``.

        Each figure is saved as ``./test_figure/<dir>/<stem>.png``, where
        ``<dir>`` is the name of the directory holding the first image of the
        pair and ``<stem>`` is that image's file name without its extension.
        """
        os.makedirs('./test_figure', exist_ok=True)

        input_images = self.args.input_images
        # A list (rather than a bare zip) lets tqdm show the total number of pairs.
        image_path_pairs = list(zip(input_images[:-1], input_images[1:]))
        for img1_path, img2_path in tqdm(image_path_pairs, desc='Processing'):
            images = list(map(imageio.imread, (img1_path, img2_path)))
            images = list(map(factor_crop, images))
            images = np.array(images) / 255.
            images_expand = np.expand_dims(images, axis=0)

            flows = self.sess.run(self.flows,
                                  feed_dict={self.images: images_expand})

            # Rescale the first batch element of each level by
            # 20 / 2**(num_levels - level).
            flow_set = [flow[0] * (20 / 2**(self.model.num_levels - level))
                        for level, flow in enumerate(flows)]

            # Derive names with os.path instead of splitting on '/' and '.'.
            # The regex split mis-parsed directory or file names containing dots,
            # and failed to unpack for paths with no directory component.
            # abspath() resolves '..' so the figure cannot land outside
            # ./test_figure.
            dname = os.path.basename(os.path.dirname(os.path.abspath(img1_path)))
            fname = os.path.splitext(os.path.basename(img1_path))[0]
            out_dir = os.path.join('./test_figure', dname)
            os.makedirs(out_dir, exist_ok=True)
            vis_flow_pyramid(flow_set,
                             images=images,
                             filename=os.path.join(out_dir, f'{fname}.png'))
        print('Figure saved')
Exemplo n.º 2
0
 def test(self):
     """Evaluate the flow pyramid once and save it as a PDF figure.

     The first batch element of every level fetched from ``self.flow_pyramid``
     is passed to ``vis_flow_pyramid`` together with ``self.images``. The
     figure is written to ``./test_figure/test_<tag>.pdf``. ``<tag>`` is built
     by splitting the first input path on '/' and '.' and joining, with '_',
     the third- and second-to-last tokens (e.g. ``seq_frame`` for
     ``.../seq/frame.png``).
     """
     pyramid = [level_flow[0] for level_flow in self.sess.run(self.flow_pyramid)]
     if not os.path.exists('./test_figure'):
         os.mkdir('./test_figure')
     tokens = re.split('[/.]', self.args.input_images[0])
     fname = '_'.join(tokens[-3:-1])
     vis_flow_pyramid(pyramid,
                      images=self.images,
                      filename=f'./test_figure/test_{fname}.pdf')
     print('Figure saved')
Exemplo n.º 3
0
    def train(self):
        """Train for ``self.args.num_epochs`` epochs, validating and checkpointing.

        Each epoch:
        - runs ``self.optimizer`` and ``self.global_step_update`` on every batch
          of ``self.train_loader``, reporting loss/EPE every 20 batches via
          ``show_progress``;
        - evaluates ``self.loss`` and ``self.epe`` on every batch of
          ``self.val_loader`` and prints their means;
        - if ``self.args.visualize`` is set, visualizes the flow of the last
          validation batch to ``./figure/flow_<epoch>.pdf``;
        - saves a checkpoint to ``./model/model_<epoch>.ckpt``.

        Image batches are scaled by 1/255 before being fed.
        """
        train_start = time.time()
        for e in range(self.args.num_epochs):
            for i, (images, flows_gt) in enumerate(self.train_loader):
                images = images.numpy() / 255.0
                flows_gt = flows_gt.numpy()

                time_s = time.time()
                _, _, loss, epe = self.sess.run(
                    [self.optimizer, self.global_step_update, self.loss, self.epe],
                    feed_dict={self.images: images, self.flows_gt: flows_gt})

                if i % 20 == 0:
                    batch_time = time.time() - time_s
                    kwargs = {
                        'loss': loss,
                        'epe': epe,
                        'batch time': batch_time
                    }
                    show_progress(e + 1, i + 1, self.num_batches, **kwargs)

            loss_vals, epe_vals = [], []
            # (flows, images, flows_gt) of the final validation batch, or None if
            # the validation loader yielded nothing this epoch.
            last_val = None
            for images_val, flows_gt_val in self.val_loader:
                images_val = images_val.numpy() / 255.0
                flows_gt_val = flows_gt_val.numpy()

                flows, loss_val, epe_val = self.sess.run(
                    [self.flows, self.loss, self.epe],
                    feed_dict={self.images: images_val,
                               self.flows_gt: flows_gt_val})
                loss_vals.append(loss_val)
                epe_vals.append(epe_val)
                last_val = (flows, images_val, flows_gt_val)

            g_step = self.sess.run(self.global_step)
            print(f'\r{e+1} epoch validation, loss: {np.mean(loss_vals)}, epe: {np.mean(epe_vals)}'
                  + f', global step: {g_step}, elapsed time: {time.time()-train_start} sec.')

            # Visualize the estimated flow of the last validation batch.
            # This step is skipped when there was no validation batch; previously
            # that case raised NameError on the undefined loop variables.
            if self.args.visualize and last_val is not None:
                flows, images_val, flows_gt_val = last_val
                os.makedirs('./figure', exist_ok=True)
                # Estimated flow values are downscaled; rescale them to be
                # compatible with the ground truth.
                flow_set = [flow[0] * (20 / 2**(self.args.num_levels - level))
                            for level, flow in enumerate(flows)]
                vis_flow_pyramid(flow_set, flows_gt_val[0], images_val[0],
                                 f'./figure/flow_{str(e+1).zfill(4)}.pdf')

            os.makedirs('./model', exist_ok=True)
            self.saver.save(self.sess, f'./model/model_{e+1}.ckpt')
Exemplo n.º 4
0
    def train(self):
        """Train for ``self.args.n_epoch`` epochs, evaluating and checkpointing.

        Each epoch:
        - runs ``self.optimizer`` and ``self.global_step_update`` on every batch
          of ``self.train_loader``, reporting ``loss_reg`` and ``epe_final``
          every 20 batches via ``show_progress``;
        - evaluates ``self.loss_reg`` and ``self.epe_final`` on every batch of
          ``self.eval_loader`` and prints their means;
        - if ``self.args.visualize`` is set, visualizes the flow pyramid of the
          last evaluation batch to ``./figure/flow_<epoch>.pdf``;
        - saves a checkpoint to ``./model/model_<epoch>.ckpt``.

        Image batches are scaled by 1/255 before being fed.
        """
        train_start = time.time()
        for e in range(self.args.n_epoch):
            for i, (images, flows_gt) in enumerate(self.train_loader):
                images = images.numpy() / 255.0
                flows_gt = flows_gt.numpy()

                time_s = time.time()
                _, _, loss_reg, epe_final = self.sess.run(
                    [self.optimizer, self.global_step_update,
                     self.loss_reg, self.epe_final],
                    feed_dict={self.images: images, self.flows_gt: flows_gt})

                if i % 20 == 0:
                    batch_time = time.time() - time_s
                    kwargs = {'loss': loss_reg, 'epe': epe_final, 'batch time': batch_time}
                    show_progress(e + 1, i + 1, self.num_batches, **kwargs)

            loss_evals, epe_evals = [], []
            # (flows_pyramid, images, flows_gt) of the final evaluation batch, or
            # None if the evaluation loader yielded nothing this epoch.
            last_eval = None
            for images_eval, flows_gt_eval in self.eval_loader:
                images_eval = images_eval.numpy() / 255.0
                flows_gt_eval = flows_gt_eval.numpy()

                flows_pyramid, loss_eval, epe_eval = self.sess.run(
                    [self.flows_pyramid, self.loss_reg, self.epe_final],
                    feed_dict={self.images: images_eval,
                               self.flows_gt: flows_gt_eval})
                loss_evals.append(loss_eval)
                epe_evals.append(epe_eval)
                last_eval = (flows_pyramid, images_eval, flows_gt_eval)

            g_step = self.sess.run(self.global_step)
            print(f'\r{e+1} epoch evaluation, loss: {np.mean(loss_evals)}, epe: {np.mean(epe_evals)}'
                  + f', global step: {g_step}, elapsed time: {time.time()-train_start} sec.')

            # Visualize the estimated flow of the last evaluation batch.
            # This step is skipped when there was no evaluation batch; previously
            # that case raised NameError on the undefined loop variables.
            if self.args.visualize and last_eval is not None:
                flows_pyramid, images_eval, flows_gt_eval = last_eval
                os.makedirs('./figure', exist_ok=True)
                flow_pyramid = [f_py[0] for f_py in flows_pyramid]
                vis_flow_pyramid(flow_pyramid, flows_gt_eval[0], images_eval[0],
                                 f'./figure/flow_{str(e+1).zfill(4)}.pdf')

            os.makedirs('./model', exist_ok=True)
            self.saver.save(self.sess, f'./model/model_{e+1}.ckpt')
Exemplo n.º 5
0
    def test(self):
        """Evaluate the flow pyramid (optionally timing it) and save a PDF figure.

        ``self.flow_pyramid`` is always run once. When ``self.args.time`` is
        truthy, that first run serves as an untimed warm-up. The graph is then
        run ``n_iter`` (1000) more times under ``time.perf_counter``, and the
        mean time per run is printed.

        The first batch element of each level is visualized with
        ``vis_flow_pyramid`` and saved to ``./test_figure/test_<tag>.pdf``.
        ``<tag>`` is built by splitting the first input path on '/' and '.' and
        joining, with '_', the third- and second-to-last tokens (e.g.
        ``seq_frame`` for ``.../seq/frame.png``).
        """
        flow_pyramid = self.sess.run(self.flow_pyramid)
        if self.args.time:
            n_iter = 1000
            # The run above doubles as a warm-up. The first session run may
            # include one-time setup work that would otherwise skew the mean.
            # perf_counter is the clock intended for measuring short intervals.
            time_s = time.perf_counter()
            for _ in tqdm(range(n_iter)):
                flow_pyramid = self.sess.run(self.flow_pyramid)
            time_iter = (time.perf_counter() - time_s) / n_iter
            print(
                f'Inference time: {time_iter} sec (averaged over {n_iter} iterations)'
            )

        flow_pyramid = [fpy[0] for fpy in flow_pyramid]
        os.makedirs('./test_figure', exist_ok=True)
        fname = '_'.join(re.split('[/.]', self.args.input_images[0])[-3:-1])
        vis_flow_pyramid(flow_pyramid,
                         images=self.images,
                         filename=f'./test_figure/test_{fname}.pdf')
        print('Figure saved')
Exemplo n.º 6
0
    def test(self):
        """Evaluate the flows (optionally timing them) and save a PDF figure.

        ``self.flows`` is always run once. When ``self.args.time`` is truthy,
        that first run serves as an untimed warm-up. The graph is then run
        ``n_iter`` (1000) more times under ``time.perf_counter``, and the mean
        time per run is printed.

        The first batch element of each level is rescaled by
        ``20 / 2**(num_levels - level)``, visualized with ``vis_flow_pyramid``
        and saved to ``./test_figure/test_<tag>.pdf``. ``<tag>`` is built by
        splitting the first input path on '/' and '.' and joining, with '_',
        the third- and second-to-last tokens (e.g. ``seq_frame`` for
        ``.../seq/frame.png``).
        """
        flows = self.sess.run(self.flows)
        if self.args.time:
            n_iter = 1000
            # The run above doubles as a warm-up. The first session run may
            # include one-time setup work that would otherwise skew the mean.
            # perf_counter is the clock intended for measuring short intervals.
            time_s = time.perf_counter()
            for _ in tqdm(range(n_iter)):
                flows = self.sess.run(self.flows)
            time_iter = (time.perf_counter() - time_s) / n_iter
            print(
                f'Inference time: {time_iter} sec (averaged over {n_iter} iterations)'
            )

        flow_set = [flow[0] * (20 / 2**(self.model.num_levels - level))
                    for level, flow in enumerate(flows)]
        os.makedirs('./test_figure', exist_ok=True)
        fname = '_'.join(re.split('[/.]', self.args.input_images[0])[-3:-1])
        vis_flow_pyramid(flow_set,
                         images=self.images,
                         filename=f'./test_figure/test_{fname}.pdf')
        print('Figure saved')
Exemplo n.º 7
0
    def train(self):
        """Train for ``self.args.num_epochs`` epochs, logging merged summaries.

        Each epoch:
        - runs ``self.optimizer`` on every batch of ``self.tloader``. Whenever
          the fetched global step is a multiple of 1000, a ``self.merged``
          summary for that batch is written to ``self.twriter``;
        - writes one ``self.merged`` summary per batch of ``self.vloader`` to
          ``self.vwriter``, all tagged with the current global step;
        - if ``self.args.visualize`` is set, visualizes the flow of the last
          validation batch to ``./figure/flow_<epoch>.pdf``;
        - saves a checkpoint to ``./model/model_<epoch>.ckpt``.

        Both writers are closed afterwards, even if training raises.
        ``./figure`` and ``./model`` are then passed to ``self.exp_saver``.
        Image batches are scaled by 1/255 before being fed.
        """
        # Fetched up front so validation summaries have a step even when the
        # training loader yields no batches (previously a NameError).
        g_step = self.sess.run(self.global_step)
        try:
            for e in tqdm(range(self.args.num_epochs)):
                # Training
                for images, flows_gt in self.tloader:
                    images = images.numpy() / 255.0
                    flows_gt = flows_gt.numpy()
                    feed = {self.images: images, self.flows_gt: flows_gt}

                    _, g_step = self.sess.run([self.optimizer, self.global_step],
                                              feed_dict=feed)

                    if g_step % 1000 == 0:
                        summary = self.sess.run(self.merged, feed_dict=feed)
                        self.twriter.add_summary(summary, g_step)

                # Validation
                # (images, flows_gt) of the final validation batch, or None if the
                # validation loader yielded nothing this epoch.
                last_val = None
                for images_val, flows_gt_val in self.vloader:
                    images_val = images_val.numpy() / 255.0
                    flows_gt_val = flows_gt_val.numpy()

                    summary = self.sess.run(self.merged,
                                            feed_dict={
                                                self.images: images_val,
                                                self.flows_gt: flows_gt_val
                                            })
                    self.vwriter.add_summary(summary, g_step)
                    last_val = (images_val, flows_gt_val)

                # Visualize the estimated flow of the last validation batch.
                # This step is skipped when there was no validation batch;
                # previously that case raised NameError.
                if self.args.visualize and last_val is not None:
                    images_val, flows_gt_val = last_val
                    os.makedirs('./figure', exist_ok=True)
                    flows_val = self.sess.run(self.flows,
                                              feed_dict={
                                                  self.images: images_val,
                                                  self.flows_gt: flows_gt_val
                                              })
                    # Estimated flow values are downscaled; rescale them to be
                    # compatible with the ground truth.
                    flow_set = [flow[0] * (20 / 2**(self.args.num_levels - level))
                                for level, flow in enumerate(flows_val)]
                    vis_flow_pyramid(flow_set, flows_gt_val[0], images_val[0],
                                     f'./figure/flow_{str(e+1).zfill(4)}.pdf')

                os.makedirs('./model', exist_ok=True)
                self.saver.save(self.sess, f'./model/model_{e+1}.ckpt')
        finally:
            # Close (and thereby flush) the summary writers on every exit path.
            self.twriter.close()
            self.vwriter.close()

        self.exp_saver.append(['./figure', './model'])
        self.exp_saver.save()