def _shuffle_frame_lists(data_lists):
    """Shuffle each list in place, re-seeding with 1 before every shuffle.

    Because the generator is reset to the same seed before each call, lists
    of equal length receive the same permutation, which keeps the i-th
    entries of the frame lists paired with each other.
    """
    for data_list in data_lists:
        random.seed(1)
        shuffle(data_list)


def train(dataset_frame1, dataset_frame2, dataset_frame3):
    """Train the voxel-flow model, feeding batches through placeholders.

    Frames 1 and 3 of each triplet are concatenated along the channel axis
    to form the network input; frame 2 is the target.  Every step runs one
    Adam update.  The loss is printed every 10 steps, the current batch's
    predictions and targets are written to ``FLAGS.train_image_dir`` every
    500 steps, and a checkpoint is saved to ``FLAGS.train_dir`` every 5000
    steps and on the final step.

    Args:
        dataset_frame1: Source of the first input frame.  Must provide
            ``read_data_list_file()`` and ``process_func(line)``.
        dataset_frame2: Source of the target frame; same interface.
        dataset_frame3: Source of the second input frame; same interface.
    """
    with tf.Graph().as_default():
        # Create input and target placeholder.
        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(None, 256, 256, 6))
        target_placeholder = tf.placeholder(tf.float32,
                                            shape=(None, 256, 256, 3))

        # Prepare model.
        model = Voxel_flow_model()
        prediction = model.inference(input_placeholder)
        reproduction_loss = model.loss(prediction, target_placeholder)
        # Only the reproduction term is optimised; no prior loss is added.
        total_loss = reproduction_loss

        # Constant learning rate (no schedule is applied).
        learning_rate = FLAGS.initial_learning_rate

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(learning_rate)
        grads = opt.compute_gradients(total_loss)
        update_op = opt.apply_gradients(grads)

        # Create summaries.
        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        summaries.append(tf.scalar_summary('total_loss', total_loss))
        summaries.append(
            tf.scalar_summary('reproduction_loss', reproduction_loss))
        summaries.append(tf.image_summary('Input Image', input_placeholder, 3))
        summaries.append(tf.image_summary('Output Image', prediction, 3))
        summaries.append(
            tf.image_summary('Target Image', target_placeholder, 3))

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables())

        # Build the merged summary operation.
        # NOTE(review): summary_op is built but never run in the loop below,
        # so only the graph is written to the event file.
        summary_op = tf.merge_all_summaries()

        # Initialise all variables.
        init = tf.initialize_all_variables()
        sess = tf.Session()
        sess.run(init)

        # Summary writer (records the graph definition).
        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                                graph=sess.graph)

        # Training loop using feed dict method.
        datasets = (dataset_frame1, dataset_frame2, dataset_frame3)
        data_lists = [dataset.read_data_list_file() for dataset in datasets]
        _shuffle_frame_lists(data_lists)

        data_size = len(data_lists[0])
        # Number of batches per epoch.  Clamp to 1 so a dataset smaller than
        # one batch trains on a single short batch instead of raising
        # ZeroDivisionError in the modulo below.
        epoch_num = max(1, int(data_size / FLAGS.batch_size))

        for step in xrange(0, FLAGS.max_steps):
            batch_idx = step % epoch_num
            batch_start = int(batch_idx * FLAGS.batch_size)
            batch_end = int((batch_idx + 1) * FLAGS.batch_size)

            # Load batch data for frame 1, frame 2 and frame 3 in that order.
            batch_data_frame1, batch_data_frame2, batch_data_frame3 = [
                np.array([
                    dataset.process_func(line)
                    for line in data_list[batch_start:batch_end]
                ]) for dataset, data_list in zip(datasets, data_lists)
            ]

            feed_dict = {
                input_placeholder:
                np.concatenate((batch_data_frame1, batch_data_frame3), 3),
                target_placeholder:
                batch_data_frame2
            }

            # Run single step update.
            _, loss_value = sess.run([update_op, total_loss],
                                     feed_dict=feed_dict)

            if batch_idx == 0:
                # Reshuffle at each epoch boundary (after the first batch of
                # the epoch has already been consumed).
                _shuffle_frame_lists(data_lists)
                print('Epoch Number: %d' % int(step / epoch_num))

            # Report the loss.
            if step % 10 == 0:
                print("Loss at step %d: %f" % (step, loss_value))

            if step % 500 == 0:
                # Dump predictions and targets for the current batch; file
                # names depend only on the in-batch index, so each dump
                # overwrites the previous one.
                prediction_np, target_np = sess.run(
                    [prediction, target_placeholder], feed_dict=feed_dict)
                for i in range(0, prediction_np.shape[0]):
                    file_name = FLAGS.train_image_dir + str(i) + '_out.png'
                    file_name_label = FLAGS.train_image_dir + str(
                        i) + '_gt.png'
                    imwrite(file_name, prediction_np[i, :, :, :])
                    imwrite(file_name_label, target_np[i, :, :, :])

            # Save checkpoint.
            if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
def test(dataset_frame1, dataset_frame2, dataset_frame3):
    """Evaluate a trained model image by image and report an average PSNR.

    For every entry of the test list a 64-image batch is assembled from the
    first 63 list entries plus the image under test (placed last).  The
    prediction and target of that last image are written to
    ``FLAGS.test_image_dir`` as ``<index>_out.png`` / ``<index>_gt.png``.

    Args:
        dataset_frame1: Source of the first input frame.  Must provide
            ``read_data_list_file()`` and ``process_func(line)``.
        dataset_frame2: Source of the target frame; same interface.
        dataset_frame3: Source of the second input frame; same interface.
    """
    with tf.Graph().as_default():
        # Create input and target placeholder.
        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(None, 256, 256, 6))
        target_placeholder = tf.placeholder(tf.float32,
                                            shape=(None, 256, 256, 3))

        # Prepare model.
        # NOTE(review): the model is built with is_train=True and every test
        # image is padded with 63 fixed images below; presumably this keeps
        # batch statistics close to training - confirm before changing.
        model = Voxel_flow_model(is_train=True)
        prediction = model.inference(input_placeholder)
        reproduction_loss = model.loss(prediction, target_placeholder)
        total_loss = reproduction_loss

        # Create a saver and a session.
        saver = tf.train.Saver(tf.all_variables())
        sess = tf.Session()

        # Restore checkpoint from file.
        # NOTE(review): when no checkpoint path is configured the variables
        # are never initialised, so the sess.run below would fail - confirm
        # that a path is always supplied.
        if FLAGS.pretrained_model_checkpoint_path:
            assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
            ckpt = tf.train.get_checkpoint_state(
                FLAGS.pretrained_model_checkpoint_path)
            restorer = tf.train.Saver()
            restorer.restore(sess, ckpt.model_checkpoint_path)
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), ckpt.model_checkpoint_path))

        # Process on test dataset.
        data_list_frame1 = dataset_frame1.read_data_list_file()
        data_size = len(data_list_frame1)

        data_list_frame2 = dataset_frame2.read_data_list_file()

        data_list_frame3 = dataset_frame3.read_data_list_file()

        PSNR = 0

        for id_img in range(0, data_size):
            # Load the single image under test.
            line_image_frame1 = dataset_frame1.process_func(
                data_list_frame1[id_img])
            line_image_frame2 = dataset_frame2.process_func(
                data_list_frame2[id_img])
            line_image_frame3 = dataset_frame3.process_func(
                data_list_frame3[id_img])

            # Filler images: the first 63 entries of each list, reloaded on
            # every iteration (process_func is opaque here, so the loads are
            # not hoisted out of the loop).
            batch_data_frame1 = [
                dataset_frame1.process_func(ll)
                for ll in data_list_frame1[0:63]
            ]
            batch_data_frame2 = [
                dataset_frame2.process_func(ll)
                for ll in data_list_frame2[0:63]
            ]
            batch_data_frame3 = [
                dataset_frame3.process_func(ll)
                for ll in data_list_frame3[0:63]
            ]

            # The image under test goes last, hence the [-1] indexing below.
            batch_data_frame1.append(line_image_frame1)
            batch_data_frame2.append(line_image_frame2)
            batch_data_frame3.append(line_image_frame3)

            batch_data_frame1 = np.array(batch_data_frame1)
            batch_data_frame2 = np.array(batch_data_frame2)
            batch_data_frame3 = np.array(batch_data_frame3)

            feed_dict = {
                input_placeholder:
                np.concatenate((batch_data_frame1, batch_data_frame3), 3),
                target_placeholder:
                batch_data_frame2
            }
            # Forward pass only; no update op is run.
            prediction_np, target_np, loss_value = sess.run(
                [prediction, target_placeholder, total_loss],
                feed_dict=feed_dict)
            print("Loss for image %d: %f" % (id_img, loss_value))
            file_name = FLAGS.test_image_dir + str(id_img) + '_out.png'
            file_name_label = FLAGS.test_image_dir + str(id_img) + '_gt.png'
            imwrite(file_name, prediction_np[-1, :, :, :])
            imwrite(file_name_label, target_np[-1, :, :, :])
            # NOTE(review): the squared error is summed (not averaged) over
            # the whole 64-image batch, and 255 is used as the peak value;
            # verify this matches the intended PSNR definition and the
            # value range produced by process_func.
            PSNR += 10 * np.log10(
                255.0 * 255.0 / np.sum(np.square(prediction_np - target_np)))

        if data_size == 0:
            # Nothing was evaluated; avoid dividing by zero.
            print("No test images found; PSNR is undefined.")
            return
        # Average over the number of evaluated images (the original divided
        # by len(data_list), a name that is not defined in this function).
        print("Overall PSNR: %f db" % (PSNR / data_size))
Esempio n. 3
0
def test(dataset_frame1, dataset_frame2, dataset_frame3):
    """Evaluate a trained edge-guided model and report mean PSNR and SSIM.

    Each test triplet is processed as a batch of one.  Sigmoid edge maps are
    computed from both input frames with a shared ``Vgg16`` network and
    concatenated with the frames as the model input.  The prediction is
    written under ``ucf101_interp_ours/<index>/frame_01_CyclicGen.png``.
    PSNR is computed inside the motion mask only; SSIM is computed on the
    full grayscale images.  Images whose motion mask sums to zero are
    skipped in both averages.

    Args:
        dataset_frame1: Source of the first input frame.  Must provide
            ``read_data_list_file()`` and ``process_func(path)``.
        dataset_frame2: Source of the ground-truth frame; same interface.
        dataset_frame3: Source of the second input frame; same interface.
            Also used to load the motion mask.
    """

    def rgb2gray(rgb):
        """Convert RGB to grayscale with weights 0.299 / 0.587 / 0.114."""
        return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])

    with tf.Graph().as_default():
        # Create input and target placeholder.
        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(None, 256, 256, 6))
        target_placeholder = tf.placeholder(tf.float32,
                                            shape=(None, 256, 256, 3))

        # Edge networks for the two input frames; the second call reuses the
        # variables created by the first.
        edge_vgg_1 = Vgg16(input_placeholder[:, :, :, :3], reuse=None)
        edge_vgg_3 = Vgg16(input_placeholder[:, :, :, 3:6], reuse=True)

        edge_1 = tf.nn.sigmoid(edge_vgg_1.fuse)
        edge_3 = tf.nn.sigmoid(edge_vgg_3.fuse)

        # Force the edge maps to (batch, height, width, 1) so they can be
        # concatenated with the input along the channel axis.
        input_height = input_placeholder.get_shape().as_list()[1]
        input_width = input_placeholder.get_shape().as_list()[2]
        edge_1 = tf.reshape(edge_1, [-1, input_height, input_width, 1])
        edge_3 = tf.reshape(edge_3, [-1, input_height, input_width, 1])

        with tf.variable_scope("Cycle_DVF"):
            # Prepare model.
            model = Voxel_flow_model(is_train=False)
            prediction = model.inference(
                tf.concat([input_placeholder, edge_1, edge_3], 3))

        sess = tf.Session()

        # Restore checkpoint from file.
        if FLAGS.pretrained_model_checkpoint_path:
            restorer = tf.train.Saver()
            restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

        # Process on test dataset.
        data_list_frame1 = dataset_frame1.read_data_list_file()
        data_size = len(data_list_frame1)

        data_list_frame2 = dataset_frame2.read_data_list_file()

        data_list_frame3 = dataset_frame3.read_data_list_file()

        # i counts only the images that contribute to the averages.
        i = 0
        PSNR = 0
        SSIM = 0

        for id_img in range(0, data_size):
            # Drop the trailing 12 characters of the list entry to get the
            # per-sequence output directory name.
            UCF_index = data_list_frame1[id_img][:-12]

            # Load a batch of exactly one triplet plus its motion mask.  The
            # list entries are rewritten to the on-disk file names by
            # trimming a fixed-length suffix and appending the real one.
            batch_data_frame1 = [
                dataset_frame1.process_func(
                    os.path.join('ucf101_interp_ours', ll)[:-5] + '00.png')
                for ll in data_list_frame1[id_img:id_img + 1]
            ]
            batch_data_frame2 = [
                dataset_frame2.process_func(
                    os.path.join('ucf101_interp_ours', ll)[:-5] + '01_gt.png')
                for ll in data_list_frame2[id_img:id_img + 1]
            ]
            batch_data_frame3 = [
                dataset_frame3.process_func(
                    os.path.join('ucf101_interp_ours', ll)[:-5] + '02.png')
                for ll in data_list_frame3[id_img:id_img + 1]
            ]
            batch_data_mask = [
                dataset_frame3.process_func(
                    os.path.join('motion_masks_ucf101_interp', ll)[:-11] +
                    'motion_mask.png')
                for ll in data_list_frame3[id_img:id_img + 1]
            ]

            batch_data_frame1 = np.array(batch_data_frame1)
            batch_data_frame2 = np.array(batch_data_frame2)
            batch_data_frame3 = np.array(batch_data_frame3)
            # Rescale the mask with (x + 1) / 2; assumes process_func
            # returns values in [-1, 1] - TODO confirm.
            batch_data_mask = (np.array(batch_data_mask) + 1.0) / 2.0

            feed_dict = {
                input_placeholder:
                np.concatenate((batch_data_frame1, batch_data_frame3), 3),
                target_placeholder:
                batch_data_frame2
            }
            # Forward pass.  The warped images are fetched but not used
            # further in this function.
            prediction_np, target_np, warped_img1, warped_img2 = sess.run(
                [
                    prediction, target_placeholder, model.warped_img1,
                    model.warped_img2
                ],
                feed_dict=feed_dict)

            # prediction_np[0] is the image batch (inference returns a
            # sequence whose first element holds the predicted frames).
            imwrite(
                'ucf101_interp_ours/' + str(UCF_index) +
                '/frame_01_CyclicGen.png', prediction_np[0][-1, :, :, :])

            mask_sum = np.sum(batch_data_mask)
            print(mask_sum)
            if mask_sum > 0:
                # Masked PSNR on images rescaled with (x + 1) / 2; the mean
                # squared error is normalised by 3 * (mask sum).
                img_pred_mask = np.expand_dims(batch_data_mask[0], -1) * (
                    prediction_np[0][-1] + 1.0) / 2.0
                img_target_mask = np.expand_dims(
                    batch_data_mask[0], -1) * (target_np[-1] + 1.0) / 2.0
                mse = np.sum((img_pred_mask - img_target_mask)**
                             2) / (3. * mask_sum)
                psnr_cur = 20.0 * np.log10(1.0) - 10.0 * np.log10(mse)

                # SSIM on the unmasked grayscale images.
                img_pred_gray = rgb2gray((prediction_np[0][-1] + 1.0) / 2.0)
                img_target_gray = rgb2gray((target_np[-1] + 1.0) / 2.0)
                ssim_cur = ssim(img_pred_gray, img_target_gray, data_range=1.0)

                PSNR += psnr_cur
                SSIM += ssim_cur

                i += 1

        if i == 0:
            # No image contributed; avoid dividing by zero below.
            print("No image had a non-empty motion mask; "
                  "PSNR and SSIM are undefined.")
            return
        print("Overall PSNR: %f db" % (PSNR / i))
        print("Overall SSIM: %f db" % (SSIM / i))
Esempio n. 4
0
        print('Copy variables from % s' % ckpt_path)

    #--test--#
    b_list = glob('./Datasets/' + dataset + '/bounding_box_train-Market/*.jpg')
    a_list = glob('./Datasets/' + dataset + '/bounding_box_train-Duke/*.jpg')

    b_save_dir = './test_predictions/' + dataset + '_spgan' + '/bounding_box_train_market2duke/'
    a_save_dir = './test_predictions/' + dataset + '_spgan' + '/bounding_box_train_duke2market/'
    utils.mkdir([a_save_dir, b_save_dir])

    for i in range(len(a_list)):
        a_real_ipt = im.imresize(im.imread(a_list[i]), [crop_size, crop_size])
        a_real_ipt.shape = 1, crop_size, crop_size, 3
        a2b_opt = sess.run(a2b, feed_dict={a_real: a_real_ipt})
        a_img_opt = a2b_opt

        img_name = os.path.basename(a_list[i])
        img_name = 'market_' + img_name  # market_style
        im.imwrite(im.immerge(a_img_opt, 1, 1), a_save_dir + img_name)
        print('Save %s' % (a_save_dir + img_name))

    for i in range(len(b_list)):
        b_real_ipt = im.imresize(im.imread(b_list[i]), [crop_size, crop_size])
        b_real_ipt.shape = 1, crop_size, crop_size, 3
        b2a_opt = sess.run(b2a, feed_dict={b_real: b_real_ipt})
        b_img_opt = b2a_opt
        img_name = os.path.basename(b_list[i])
        img_name = 'duke_' + img_name  #duke_style
        im.imwrite(im.immerge(b_img_opt, 1, 1), b_save_dir + img_name)
        print('Save %s' % (b_save_dir + img_name))
Esempio n. 5
0
def train(dataset_frame1, dataset_frame2, dataset_frame3):
  """Train the voxel-flow model from ``tf.data`` input pipelines.

  Each of the three frame lists is turned into a repeated, shuffled (seed 1),
  prefetched and batched dataset.  Frames 1 and 3 are concatenated along the
  channel axis as the network input; frame 2 is the target.  If
  ``FLAGS.pretrained_model_checkpoint_path`` holds a checkpoint state it is
  restored, otherwise all variables are freshly initialised.

  The loss is printed every 10 steps, summaries are written every 100 steps,
  and every 500 steps (plus the final step for the checkpoint) a batch of
  predictions is dumped to ``FLAGS.train_image_dir`` and a checkpoint is
  saved to ``FLAGS.train_dir``.

  Args:
    dataset_frame1: Source of the first input frame.  Must provide
      ``read_data_list_file()``.
    dataset_frame2: Source of the target frame; same interface.
    dataset_frame3: Source of the second input frame; same interface.
  """
  with tf.Graph().as_default():
    # Build one input pipeline per frame.  All three use the same shuffle
    # seed; presumably this keeps the triplets aligned across pipelines -
    # verify if the lists can differ in length.
    data_lists = []
    pipelines = []
    for dataset in (dataset_frame1, dataset_frame2, dataset_frame3):
      data_list = dataset.read_data_list_file()
      pipeline = tf.data.Dataset.from_tensor_slices(tf.constant(data_list))
      pipeline = pipeline.repeat().shuffle(
        buffer_size=1000000, seed=1).map(_read_image)
      pipelines.append(pipeline.prefetch(100))
      data_lists.append(data_list)

    batch_frame1, batch_frame2, batch_frame3 = [
      pipeline.batch(FLAGS.batch_size).make_initializable_iterator()
      for pipeline in pipelines
    ]
    iterator_initializers = [
      batch_frame1.initializer,
      batch_frame2.initializer,
      batch_frame3.initializer,
    ]

    # Input and target tensors (named "placeholder" for historical reasons;
    # they are fed by the iterators, not by a feed dict).
    input_placeholder = tf.concat(
      [batch_frame1.get_next(), batch_frame3.get_next()], 3)
    target_placeholder = batch_frame2.get_next()

    # Prepare model.
    model = Voxel_flow_model(is_train=True)
    prediction, flow_motion, flow_mask = model.inference(input_placeholder)
    reproduction_loss = model.loss(prediction, flow_motion,
                                   flow_mask, target_placeholder,
                                   FLAGS.lambda_motion, FLAGS.lambda_mask)
    total_loss = reproduction_loss

    # Constant learning rate (no schedule is applied).
    learning_rate = FLAGS.initial_learning_rate

    # Create an optimizer that performs gradient descent.
    opt = tf.train.AdamOptimizer(learning_rate)
    grads = opt.compute_gradients(total_loss)
    update_op = opt.apply_gradients(grads)

    # Create summaries.
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    summaries.append(tf.summary.scalar('total_loss', total_loss))
    summaries.append(tf.summary.scalar('reproduction_loss', reproduction_loss))
    summaries.append(tf.summary.image(
      'Input Image (before)', input_placeholder[:, :, :, 0:3], 3))
    summaries.append(tf.summary.image(
      'Input Image (after)', input_placeholder[:, :, :, 3:6], 3))
    summaries.append(tf.summary.image('Output Image', prediction, 3))
    summaries.append(tf.summary.image('Target Image', target_placeholder, 3))

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())

    # Build the merged summary operation.
    summary_op = tf.summary.merge_all()

    # Look the checkpoint state up once and reuse it below.
    ckpt = None
    if FLAGS.pretrained_model_checkpoint_path:
      ckpt = tf.train.get_checkpoint_state(
        FLAGS.pretrained_model_checkpoint_path)

    if ckpt:
      # Restore checkpoint from file.
      sess = tf.Session()
      restorer = tf.train.Saver()
      restorer.restore(sess, ckpt.model_checkpoint_path)
      print('%s: Pre-trained model restored from %s' %
        (datetime.now(), ckpt.model_checkpoint_path))
      sess.run(iterator_initializers)
    else:
      # No checkpoint: initialise every variable from scratch.
      print('No existing checkpoints.')
      init = tf.global_variables_initializer()
      sess = tf.Session()
      sess.run([init] + iterator_initializers)

    # Summary writer.
    summary_writer = tf.summary.FileWriter(
      FLAGS.train_dir,
      graph=sess.graph)

    data_size = len(data_lists[0])
    # Number of batches per epoch, used only for the epoch printout.  Clamp
    # to 1 so a dataset smaller than one batch does not raise
    # ZeroDivisionError in the modulo below.
    epoch_num = max(1, int(data_size / FLAGS.batch_size))

    for step in range(0, FLAGS.max_steps):
      batch_idx = step % epoch_num

      # Run single step update.
      _, loss_value = sess.run([update_op, total_loss])

      if batch_idx == 0:
        print('Epoch Number: %d' % int(step / epoch_num))

      if step % 10 == 0:
        print("Loss at step %d: %f" % (step, loss_value))

      if step % 100 == 0:
        # Output summary.  This is a separate sess.run, so it pulls a fresh
        # batch from the iterators rather than reusing the training batch.
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      if step % 500 == 0:
        # Dump one batch of predictions and targets (again a fresh batch);
        # file names depend only on the in-batch index, so each dump
        # overwrites the previous one.
        prediction_np, target_np = sess.run([prediction, target_placeholder])
        for i in range(0, prediction_np.shape[0]):
          file_name = FLAGS.train_image_dir + str(i) + '_out.png'
          file_name_label = FLAGS.train_image_dir + str(i) + '_gt.png'
          imwrite(file_name, prediction_np[i, :, :, :])
          imwrite(file_name_label, target_np[i, :, :, :])

      # Save checkpoint.
      if step % 500 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
Esempio n. 6
0
def train(dataset_objects):
    """Train the two-network multi-frame interpolation model.

    The first and last frames of each 9-frame sequence are the input and the
    7 frames in between are the targets.  A flow-computation network
    estimates bidirectional flow between the two input frames; for each of
    the 7 intermediate time steps an interpolation network refines the
    approximated intermediate flows and predicts visibility masks, from
    which the intermediate frame is synthesised.  The total loss combines
    reconstruction, perceptual, mask-constraint, warping and smoothness
    terms weighted by the ``FLAGS.lambda_*`` values.

    The loss is printed every 10 steps; summaries and a checkpoint are
    written every 200 steps (checkpoint also on the final step); example
    images are dumped to ``FLAGS.train_image_dir`` on steps where
    ``step % 500`` is 0, 1 or 2.

    Args:
        dataset_objects: Sequence of at least 9 dataset objects, one per
            frame position, each providing ``read_data_list_file()``.
    """
    with tf.Graph().as_default():
        # Read the file lists and build one shuffled, repeated, prefetched
        # pipeline per frame position.  Every pipeline uses the same shuffle
        # seed; presumably this keeps the sequences aligned - verify if the
        # lists can differ in length.
        data_lists = [
            dataset_obj.read_data_list_file()
            for dataset_obj in dataset_objects
        ]
        dataset_frames = [
            tf.data.Dataset.from_tensor_slices(tf.constant(data_list))
            for data_list in data_lists
        ]
        dataset_frames = [
            frame.repeat().shuffle(buffer_size=int(1e5),
                                   seed=1).map(_read_image)
            for frame in dataset_frames
        ]
        dataset_frames = [frame.prefetch(100) for frame in dataset_frames]

        # 9 sets of frames in total, 1 for each frame in a 9-frame sequence.
        batch_frames = [
            frame.batch(FLAGS.batch_size).make_initializable_iterator()
            for frame in dataset_frames
        ]
        iterator_initializers = [
            batch_frame.initializer for batch_frame in batch_frames
        ]

        # Grab the first and last frames for input ...
        input_placeholder = tf.concat(
            [batch_frames[0].get_next(), batch_frames[8].get_next()], axis=3)
        # ... and the middle 7 frames, stacked on the channel axis, as
        # ground truth.
        target_placeholder = tf.concat(
            [frame.get_next() for frame in batch_frames[1:8]], axis=3)

        # The session is created early because the VGG wrapper takes it.
        sess = tf.Session()
        # The first network (flow computation).
        computer = SloMo_model(for_interpolation=False)
        # The second network (flow interpolation).
        interpolater = SloMo_model(for_interpolation=True)
        # VGG for the perceptual loss.
        vgg_mod = vgg16(sess=sess)

        # Flow computations between the first and last frames.
        flow_01, flow_10 = computer.inference(input_placeholder)
        image_0 = input_placeholder[:, :, :, :3]
        image_1 = input_placeholder[:, :, :, 3:]

        total_loss = 0
        pred_imgs_t = []
        # For each intermediate frame: t = 1/8, 2/8, ..., 7/8.
        for idx, t in enumerate(np.arange(1.0 / 8, .999999, 1.0 / 8)):
            # Intermediate flow approximation at t; the paper calls this
            # F hat.
            flow_t0_hat = t * (-(1 - t) * flow_01 + t * flow_10)
            flow_t1_hat = (1 - t) * ((1 - t) * flow_01 - t * flow_10)

            # Warp to approximate image_0, image_1 in inputs.  These do not
            # depend on t, but are kept inside the loop so the set of ops
            # built by the opaque warp helper matches the original graph.
            approx_img_0 = computer.warp(flow_10, image_1)
            approx_img_1 = computer.warp(flow_01, image_0)

            # Interpolate intermediate frame.
            interp_input = tf.concat([
                input_placeholder, approx_img_0, approx_img_1, flow_t0_hat,
                flow_t1_hat
            ], axis=3)
            flow_t0, flow_t1, vis_mask_0, vis_mask_1 = \
                interpolater.inference(interp_input)
            # Compute intermediate frame via equation (1): a visibility- and
            # time-weighted blend of the two warped inputs, normalised by z.
            z = (1 - t) * tf.abs(vis_mask_0) + t * tf.abs(vis_mask_1)
            pred_img_t = (1 / z) * (
                (1 - t) * tf.abs(vis_mask_0) *
                computer.warp(-flow_t0, image_0) +
                t * tf.abs(vis_mask_1) * computer.warp(-flow_t1, image_1))
            pred_imgs_t += [pred_img_t]

            # Reconstruction loss @ equation (7).
            target = target_placeholder[:, :, :, idx * 3:(idx + 1) * 3]
            loss_recons = l1_loss(pred_img_t, target)
            total_loss += FLAGS.lambda_reconstruction * loss_recons

            # Perceptual loss @ equation (8).
            phi_true = vgg_mod.inference(target)
            phi_pred = vgg_mod.inference(pred_img_t)
            loss_percept = l2_loss(phi_true, phi_pred)
            total_loss += FLAGS.lambda_perceptual * loss_percept

            # Lagrangian penalty to enforce constraint @ equation (5).
            loss_constraint = l1_loss(tf.abs(vis_mask_0),
                                      1 - tf.abs(vis_mask_1))
            total_loss += FLAGS.lambda_penalty * loss_constraint

        # Warping and smoothness losses @ equations (9) and (10).
        loss_warping = l1_loss(image_0, approx_img_0) \
            + l1_loss(image_1, approx_img_1)
        loss_smooth = l1_regularizer(flow_01) + l1_regularizer(flow_10)
        # All losses.
        total_loss += FLAGS.lambda_warping * loss_warping \
            + FLAGS.lambda_smoothness * loss_smooth

        learning_rate = FLAGS.initial_learning_rate

        # Backprop operation; collect gradient norms for tensorboard.
        opt = tf.train.AdamOptimizer(learning_rate)
        update_op = slim.learning.create_train_op(total_loss,
                                                  opt,
                                                  summarize_gradients=True)

        # Collect losses for tensorboard.  loss_recons, loss_percept and
        # loss_constraint are the values left over from the last loop
        # iteration, i.e. they describe only the final intermediate frame.
        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        summaries.append(tf.summary.scalar('total_loss', total_loss))
        summaries.append(tf.summary.scalar('loss_recons', loss_recons))
        summaries.append(tf.summary.scalar('loss_percept', loss_percept))
        summaries.append(tf.summary.scalar('loss_constraint', loss_constraint))
        summaries.append(tf.summary.scalar('loss_warping', loss_warping))
        summaries.append(tf.summary.scalar('loss_smooth', loss_smooth))
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)

        # Saver for checkpoints.
        saver = tf.train.Saver(tf.global_variables())

        # Restore the model if a checkpoint exists (looked up once).
        ckpt = None
        if FLAGS.pretrained_model_checkpoint_path:
            ckpt = tf.train.get_checkpoint_state(
                FLAGS.pretrained_model_checkpoint_path)

        if ckpt:
            restorer = tf.train.Saver()
            restorer.restore(sess, ckpt.model_checkpoint_path)
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), ckpt.model_checkpoint_path))
            sess.run(iterator_initializers)
        else:
            print('No existing checkpoints.')
            # NOTE(review): this initialises every global variable; if the
            # vgg16 wrapper loaded weights through `sess` at construction,
            # confirm they are not overwritten here.
            init = tf.global_variables_initializer()
            sess.run([init] + iterator_initializers)

        data_size = len(data_lists[0])
        # Number of batches per epoch, used only for the epoch printout.
        # Clamp to 1 so a dataset smaller than one batch does not raise
        # ZeroDivisionError in the modulo below.
        epoch_num = max(1, int(data_size / FLAGS.batch_size))

        # Build the stacked-prediction tensor once.  Creating it inside the
        # training loop would add a new concat op to the graph on every
        # image dump.
        prediction = tf.concat(pred_imgs_t, axis=3)

        for step in range(0, FLAGS.max_steps):
            batch_idx = step % epoch_num
            loss_value, __ = sess.run([total_loss, update_op])

            if batch_idx == 0:
                print('Epoch Number: %d' % int(step / epoch_num))

            if step % 10 == 0:
                print("Loss at step %d: %f" % (step, loss_value))

            if step % 200 == 0:
                # Save summary (separate run, so it pulls a fresh batch).
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            if step % 500 in {0, 1, 2}:
                # Get some intermediate results (again from a fresh batch).
                prediction_np, input_np, target_np, comp_img1, comp_img2 \
                    = sess.run([prediction,
                                input_placeholder,
                                target_placeholder,
                                approx_img_0,
                                approx_img_1])
                input_1, input_2 = input_np[:, :, :, :3], input_np[:, :, :, 3:]

                step_tag = '_step' + str(step)
                file_name_comp1 = FLAGS.train_image_dir + 'comp_img0_out' \
                    + step_tag + '.png'
                file_name_comp2 = FLAGS.train_image_dir + 'comp_img1_out' \
                    + step_tag + '.png'
                # Build the ground-truth names explicitly; running
                # str.replace('out', 'gt') on the full path would also
                # rewrite a directory name that contains "out".
                file_name_comp1_gt = FLAGS.train_image_dir + 'comp_img0_gt' \
                    + step_tag + '.png'
                file_name_comp2_gt = FLAGS.train_image_dir + 'comp_img1_gt' \
                    + step_tag + '.png'

                # NOTE(review): none of the file names include examp_idx, so
                # the warped/input images (always example 0) are rewritten
                # on every pass and the per-frame images of later examples
                # overwrite earlier ones.
                for examp_idx in range(prediction_np.shape[0]):
                    imwrite(file_name_comp1, comp_img1[0, :, :, :])
                    imwrite(file_name_comp2, comp_img2[0, :, :, :])
                    imwrite(file_name_comp1_gt, input_1[0, :, :, :])
                    imwrite(file_name_comp2_gt, input_2[0, :, :, :])

                    # The two input frames are stored under both the 'out'
                    # and 'gt' prefixes as frame0 / frame8.
                    for inputs in ['out', 'gt']:
                        file_name_input1 = FLAGS.train_image_dir + inputs \
                            + step_tag + '_frame0' + '.png'
                        file_name_input2 = FLAGS.train_image_dir + inputs \
                            + step_tag + '_frame8' + '.png'
                        imwrite(file_name_input1, input_1[0, :, :, :])
                        imwrite(file_name_input2, input_2[0, :, :, :])

                    # One RGB image per intermediate frame (3 channels each).
                    for frame_idx in range(int(prediction_np.shape[3] / 3)):
                        channels = slice(frame_idx * 3, (frame_idx + 1) * 3)
                        file_name = FLAGS.train_image_dir + 'out' \
                            + step_tag + '_frame' + str(frame_idx + 1) \
                            + '.png'
                        file_name_label = FLAGS.train_image_dir + 'gt' \
                            + step_tag + '_frame' + str(frame_idx + 1) \
                            + '.png'

                        imwrite(file_name,
                                prediction_np[examp_idx, :, :, channels])
                        imwrite(file_name_label,
                                target_np[examp_idx, :, :, channels])

            # Save model weights.
            if step % 200 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)