def train(dataset_frame1, dataset_frame2, dataset_frame3):
    """Train `Voxel_flow_model` on frame triplets using a feed-dict loop.

    Frames 1 and 3 are concatenated along axis 3 (6 channels) and fed as the
    model input; frame 2 (3 channels) is fed as the target.  The loop runs
    for `FLAGS.max_steps` steps, prints the loss every 10 steps, dumps the
    current batch's predictions/targets as PNGs every 500 steps and saves a
    checkpoint to `FLAGS.train_dir` every 5000 steps and on the last step.

    NOTE: as the file stands, a later `def train(...)` rebinds this name, so
    this three-frame version is shadowed after the module is loaded.

    Args:
        dataset_frame1: dataset object for the first input frame; must
            provide `read_data_list_file()` and `process_func(line)`.
        dataset_frame2: dataset object for the target (middle) frame.
        dataset_frame3: dataset object for the second input frame.

    Raises:
        ValueError: if the data list holds fewer entries than one batch
            (previously this surfaced as a ZeroDivisionError in the loop).
    """
    with tf.Graph().as_default():
        # Create input and target placeholder.
        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(None, 256, 256, 6))
        target_placeholder = tf.placeholder(tf.float32,
                                            shape=(None, 256, 256, 3))

        # Prepare model.
        model = Voxel_flow_model()
        prediction = model.inference(input_placeholder)
        reproduction_loss = model.loss(prediction, target_placeholder)
        # Only the reproduction loss is optimised (no prior loss term).
        total_loss = reproduction_loss

        # Constant learning rate; no schedule is applied.
        learning_rate = FLAGS.initial_learning_rate

        # Adam optimiser, split into compute/apply steps.
        opt = tf.train.AdamOptimizer(learning_rate)
        grads = opt.compute_gradients(total_loss)
        update_op = opt.apply_gradients(grads)

        # Register summaries.  Uses the tf.summary.* API, matching the other
        # train() in this file.
        # NOTE(review): tf.summary.image is documented for 1/3/4-channel
        # tensors but input_placeholder has 6 channels; summary_op is never
        # run below, so this is not exercised -- confirm before enabling.
        tf.summary.scalar('total_loss', total_loss)
        tf.summary.scalar('reproduction_loss', reproduction_loss)
        tf.summary.image('Input Image', input_placeholder, 3)
        tf.summary.image('Output Image', prediction, 3)
        tf.summary.image('Target Image', target_placeholder, 3)

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables())

        # Merged summary op (built but currently not evaluated in the loop).
        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()

        # Read the three file lists up front so bad input fails before a
        # session is created.
        frame_datasets = (dataset_frame1, dataset_frame2, dataset_frame3)
        data_lists = [ds.read_data_list_file() for ds in frame_datasets]

        def _reshuffle():
            # Re-seeding before every shuffle applies the same permutation
            # to each (equal-length) list, keeping the triplets aligned.
            for data_list in data_lists:
                random.seed(1)
                shuffle(data_list)

        _reshuffle()

        data_size = len(data_lists[0])
        epoch_num = int(data_size / FLAGS.batch_size)  # batches per epoch
        if epoch_num < 1:
            raise ValueError(
                'Need at least one full batch: %d samples, batch size %s' %
                (data_size, FLAGS.batch_size))

        # The context manager closes the session even if training fails.
        with tf.Session() as sess:
            sess.run(init)

            # Writes the graph definition to the training directory.
            summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                                   graph=sess.graph)
            try:
                for step in range(0, FLAGS.max_steps):
                    batch_idx = step % epoch_num
                    start = int(batch_idx * FLAGS.batch_size)
                    stop = int((batch_idx + 1) * FLAGS.batch_size)

                    # Load batch data for each of the three frames.
                    batch_frame1, batch_frame2, batch_frame3 = [
                        np.array([ds.process_func(line)
                                  for line in data_list[start:stop]])
                        for ds, data_list in zip(frame_datasets, data_lists)
                    ]

                    feed_dict = {
                        input_placeholder:
                        np.concatenate((batch_frame1, batch_frame3), 3),
                        target_placeholder:
                        batch_frame2
                    }

                    # Run single step update.
                    _, loss_value = sess.run([update_op, total_loss],
                                             feed_dict=feed_dict)

                    if batch_idx == 0:
                        # Shuffle data at each epoch (after its first batch
                        # has been consumed, as before).
                        _reshuffle()
                        print('Epoch Number: %d' % int(step / epoch_num))

                    # Output Summary
                    if step % 10 == 0:
                        print("Loss at step %d: %f" % (step, loss_value))

                    if step % 500 == 0:
                        # Run a batch of images and dump them to disk.
                        prediction_np, target_np = sess.run(
                            [prediction, target_placeholder],
                            feed_dict=feed_dict)
                        for i in range(0, prediction_np.shape[0]):
                            file_name = (FLAGS.train_image_dir + str(i) +
                                         '_out.png')
                            file_name_label = (FLAGS.train_image_dir +
                                               str(i) + '_gt.png')
                            imwrite(file_name, prediction_np[i, :, :, :])
                            imwrite(file_name_label, target_np[i, :, :, :])

                    # Save checkpoint
                    if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
                        checkpoint_path = os.path.join(FLAGS.train_dir,
                                                       'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
            finally:
                # Flush and release the event file.
                summary_writer.close()
def test(dataset_frame1, dataset_frame2, dataset_frame3):
    """Run a restored model over a three-frame test set and report PSNR.

    For every entry in the test list, a batch of 64 samples is fed: the
    first 63 list entries plus the entry under test as the last sample.
    Only the last sample's prediction and target are written to
    `FLAGS.test_image_dir` as `<idx>_out.png` / `<idx>_gt.png`.

    NOTE: as the file stands, a later `def test(...)` rebinds this name, so
    this three-frame version is shadowed after the module is loaded.

    Args:
        dataset_frame1: dataset object for the first input frame; must
            provide `read_data_list_file()` and `process_func(line)`.
        dataset_frame2: dataset object for the target (middle) frame.
        dataset_frame3: dataset object for the second input frame.

    Raises:
        IOError: if `FLAGS.pretrained_model_checkpoint_path` is set but does
            not exist or holds no checkpoint state.
        ValueError: if the test data list is empty.
    """
    with tf.Graph().as_default():
        # Create input and target placeholder.
        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(None, 256, 256, 6))
        target_placeholder = tf.placeholder(tf.float32,
                                            shape=(None, 256, 256, 3))

        # Prepare model.  The model is built with is_train=True even though
        # this is evaluation.
        # NOTE(review): presumably why each test image is padded with 63
        # other samples -- confirm against Voxel_flow_model.
        model = Voxel_flow_model(is_train=True)
        prediction = model.inference(input_placeholder)
        reproduction_loss = model.loss(prediction, target_placeholder)
        total_loss = reproduction_loss

        # Read the file lists for the three frames.
        data_list_frame1 = dataset_frame1.read_data_list_file()
        data_list_frame2 = dataset_frame2.read_data_list_file()
        data_list_frame3 = dataset_frame3.read_data_list_file()
        data_size = len(data_list_frame1)
        if data_size == 0:
            raise ValueError('Test data list is empty.')

        # The context manager closes the session even if evaluation fails.
        with tf.Session() as sess:
            # Restore checkpoint from file.
            # NOTE(review): no variable initialiser is run in this function,
            # so an empty checkpoint path leaves the variables uninitialised.
            ckpt_dir = FLAGS.pretrained_model_checkpoint_path
            if ckpt_dir:
                if not tf.gfile.Exists(ckpt_dir):
                    raise IOError(
                        'Checkpoint path does not exist: %s' % ckpt_dir)
                ckpt = tf.train.get_checkpoint_state(ckpt_dir)
                if ckpt is None or not ckpt.model_checkpoint_path:
                    raise IOError('No checkpoint found in %s' % ckpt_dir)
                restorer = tf.train.Saver()
                restorer.restore(sess, ckpt.model_checkpoint_path)
                print('%s: Pre-trained model restored from %s' %
                      (datetime.now(), ckpt.model_checkpoint_path))

            # The 63 padding samples are the same for every test image, so
            # load them once instead of once per image.
            # Assumes process_func is deterministic for a given list entry
            # -- TODO confirm.
            pad_frame1 = [
                dataset_frame1.process_func(ll)
                for ll in data_list_frame1[0:63]
            ]
            pad_frame2 = [
                dataset_frame2.process_func(ll)
                for ll in data_list_frame2[0:63]
            ]
            pad_frame3 = [
                dataset_frame3.process_func(ll)
                for ll in data_list_frame3[0:63]
            ]

            PSNR = 0
            for id_img in range(0, data_size):
                # Load the sample under test and append it as the last
                # element of each batch.
                batch_data_frame1 = np.array(pad_frame1 + [
                    dataset_frame1.process_func(data_list_frame1[id_img])
                ])
                batch_data_frame2 = np.array(pad_frame2 + [
                    dataset_frame2.process_func(data_list_frame2[id_img])
                ])
                batch_data_frame3 = np.array(pad_frame3 + [
                    dataset_frame3.process_func(data_list_frame3[id_img])
                ])

                feed_dict = {
                    input_placeholder:
                    np.concatenate((batch_data_frame1, batch_data_frame3), 3),
                    target_placeholder:
                    batch_data_frame2
                }
                # Forward pass only; no update op is run.
                prediction_np, target_np, loss_value = sess.run(
                    [prediction, target_placeholder, total_loss],
                    feed_dict=feed_dict)
                print("Loss for image %d: %f" % (id_img, loss_value))
                file_name = FLAGS.test_image_dir + str(id_img) + '_out.png'
                file_name_label = (FLAGS.test_image_dir + str(id_img) +
                                   '_gt.png')
                # Index -1 is the sample under test.
                imwrite(file_name, prediction_np[-1, :, :, :])
                imwrite(file_name_label, target_np[-1, :, :, :])
                # NOTE(review): this uses the *sum* of squared errors over
                # the whole 64-sample batch and a peak value of 255; a
                # conventional PSNR uses the mean squared error of the
                # evaluated image and the data's actual peak -- formula left
                # unchanged, confirm the intended metric and pixel range.
                PSNR += 10 * np.log10(
                    255.0 * 255.0 /
                    np.sum(np.square(prediction_np - target_np)))

            # Fixed: previously divided by len(data_list), an undefined name.
            print("Overall PSNR: %f db" % (PSNR / data_size))
def test(
    dataset_frame1, dataset_frame2, dataset_frame3, dataset_frame4,
    dataset_frame5, dataset_frame6, dataset_frame7, dataset_frame8,
    dataset_frame9, dataset_frame10, dataset_frame11, dataset_frame12,
    dataset_frame13, dataset_frame14, dataset_frame15, dataset_frame16,
    dataset_frame17, dataset_frame18, dataset_frame19, dataset_frame20,
    dataset_frame21, dataset_frame22, dataset_frame23, dataset_frame24,
    dataset_frame25, dataset_frame26, dataset_frame27, dataset_frame28,
    dataset_frame29, dataset_frame30, dataset_frame31, dataset_frame32,
    dataset_frame33, dataset_frame34, dataset_frame35, dataset_frame36,
    dataset_frame37, dataset_frame38, dataset_frame39, dataset_frame40,
    dataset_frame41
):
    """Roll a restored model forward for 10 steps on a multi-frame window.

    The first `FLAGS.num_in` frames are concatenated along axis 3 as the
    model input and the last configured frame (frame
    `FLAGS.num_in + FLAGS.num_out`) is fed as the target.  Each frame's
    batch is built from list entries 73..80 of its data list.  After every
    step the last sample of each frame slot is replaced by the last sample
    of the next slot, and the last sample of the final slot is replaced by
    the model's newest prediction, which is also saved to
    `FLAGS.test_image_dir` as `<step>_out.png`.

    Changes from the previous exec()-based version:
      * frames are held in plain lists instead of exec-generated variables;
      * every frame is processed with its own dataset's `process_func`
        (previously all but the first sample used `dataset_frame1`);
      * the target is the last configured frame rather than a hard-coded
        frame 41 (identical when num_in + num_out == 41);
      * the reported PSNR is averaged over the steps actually evaluated
        rather than over the data-list length.

    Args:
        dataset_frame1 ... dataset_frame41: one dataset object per frame
            position; each must provide `read_data_list_file()` and
            `process_func(line)`.  Only the first
            `FLAGS.num_in + FLAGS.num_out` are used.

    Raises:
        ValueError: if `FLAGS.num_in` / `FLAGS.num_out` do not describe at
            least one input and one output frame within the 41 supplied.
        IndexError: if a data list has no entries at index 73 or beyond.
        IOError: if `FLAGS.pretrained_model_checkpoint_path` is set but does
            not exist or holds no checkpoint state.
    """
    # Gather the positional datasets so frames can be addressed by index.
    datasets = [
        dataset_frame1, dataset_frame2, dataset_frame3, dataset_frame4,
        dataset_frame5, dataset_frame6, dataset_frame7, dataset_frame8,
        dataset_frame9, dataset_frame10, dataset_frame11, dataset_frame12,
        dataset_frame13, dataset_frame14, dataset_frame15, dataset_frame16,
        dataset_frame17, dataset_frame18, dataset_frame19, dataset_frame20,
        dataset_frame21, dataset_frame22, dataset_frame23, dataset_frame24,
        dataset_frame25, dataset_frame26, dataset_frame27, dataset_frame28,
        dataset_frame29, dataset_frame30, dataset_frame31, dataset_frame32,
        dataset_frame33, dataset_frame34, dataset_frame35, dataset_frame36,
        dataset_frame37, dataset_frame38, dataset_frame39, dataset_frame40,
        dataset_frame41,
    ]

    num_in = FLAGS.num_in
    num_frames = FLAGS.num_in + FLAGS.num_out
    if not 1 <= num_in < num_frames <= len(datasets):
        raise ValueError(
            'num_in=%s and num_out=%s must give at least one input and one '
            'output frame and at most %d frames in total' %
            (FLAGS.num_in, FLAGS.num_out, len(datasets)))

    # Fixed evaluation window (8 consecutive list entries) and step count.
    window_start, window_stop = 73, 81
    rollout_steps = 10

    with tf.Graph().as_default():
        # Create input and target placeholder.
        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(None, 256, 256,
                                                  FLAGS.num_in * 3))
        target_placeholder = tf.placeholder(tf.float32,
                                            shape=(None, 256, 256,
                                                   FLAGS.num_out * 3))

        # Prepare model (built with is_train=True, as before).
        model = Voxel_flow_model(is_train=True)
        prediction = model.inference(input_placeholder)
        reproduction_loss = model.loss(prediction, target_placeholder)
        total_loss = reproduction_loss

        # The context manager closes the session even if evaluation fails.
        with tf.Session() as sess:
            # Restore checkpoint from file.
            # NOTE(review): no variable initialiser is run in this function,
            # so an empty checkpoint path leaves the variables uninitialised.
            ckpt_dir = FLAGS.pretrained_model_checkpoint_path
            if ckpt_dir:
                if not tf.gfile.Exists(ckpt_dir):
                    raise IOError(
                        'Checkpoint path does not exist: %s' % ckpt_dir)
                ckpt = tf.train.get_checkpoint_state(ckpt_dir)
                if ckpt is None or not ckpt.model_checkpoint_path:
                    raise IOError('No checkpoint found in %s' % ckpt_dir)
                restorer = tf.train.Saver()
                restorer.restore(sess, ckpt.model_checkpoint_path)
                print('%s: Pre-trained model restored from %s' %
                      (datetime.now(), ckpt.model_checkpoint_path))

            # Build one batch per frame position from the fixed window.
            batches = []
            for frame_no, dataset in enumerate(datasets[:num_frames], 1):
                data_list = dataset.read_data_list_file()
                lines = data_list[window_start:window_stop]
                if not lines:
                    raise IndexError(
                        'Data list for frame %d has no entry at index %d' %
                        (frame_no, window_start))
                batches.append(
                    np.array([dataset.process_func(line) for line in lines]))

            print("feed_dict: input <- frames 1..%d, target <- frame %d" %
                  (num_in, num_frames))

            PSNR = 0
            for id_img in range(0, rollout_steps):
                # The target fits target_placeholder only when it carries
                # FLAGS.num_out * 3 channels.
                feed_dict = {
                    input_placeholder: np.concatenate(batches[:num_in], 3),
                    target_placeholder: batches[-1],
                }

                # Forward pass only; the loss is fetched but not reported.
                prediction_np, target_np, _ = sess.run(
                    [prediction, target_placeholder, total_loss],
                    feed_dict=feed_dict)

                # Save the prediction for the last sample of the batch.
                # imsave is called as (array, path), as before.
                file_name = FLAGS.test_image_dir + str(id_img) + '_out.png'
                imsave(prediction_np[-1, :, :, :], file_name)
                print(id_img)

                # NOTE(review): sum (not mean) of squared errors over the
                # whole batch with a peak of 255; also, from the second step
                # on, the last target sample is the previous prediction.
                # Formula left unchanged -- confirm the intended metric.
                PSNR += 10 * np.log10(
                    255.0 * 255.0 /
                    np.sum(np.square(prediction_np - target_np)))

                # Slide the window for the last sample: each slot takes the
                # next slot's last sample, and the final slot takes the new
                # prediction.
                for k in range(num_frames - 1):
                    batches[k][-1] = batches[k + 1][-1]
                batches[-1][-1] = prediction_np[-1, :, :, :]

            # Average over the steps actually evaluated.
            print("Overall PSNR: %f db" % (PSNR / rollout_steps))
def train(
    dataset_frame1, dataset_frame2, dataset_frame3, dataset_frame4,
    dataset_frame5, dataset_frame6, dataset_frame7, dataset_frame8
):  # dataset_frame9, dataset_frame10, dataset_frame11, dataset_frame12, dataset_frame13, dataset_frame14, dataset_frame15, dataset_frame16, dataset_frame17, dataset_frame18, dataset_frame19, dataset_frame20,dataset_frame21):
    # , dataset_frame22, dataset_frame23, dataset_frame24, dataset_frame25, dataset_frame26, dataset_frame27, dataset_frame28, dataset_frame29, dataset_frame30, dataset_frame31, dataset_frame32, dataset_frame33, dataset_frame34, dataset_frame35, dataset_frame36, dataset_frame37, dataset_frame38, dataset_frame39, dataset_frame40,
    # dataset_frame41, dataset_frame42, dataset_frame43, dataset_frame44, dataset_frame45, dataset_frame46, dataset_frame47, dataset_frame48, dataset_frame49, dataset_frame50, dataset_frame51, dataset_frame52, dataset_frame53, dataset_frame54):
    """Trains a model."""
    with tf.Graph().as_default():
        # Create input and target placeholder.
        input_placeholder = tf.placeholder(tf.float32,
                                           shape=(None, 256, 256,
                                                  FLAGS.num_in * 3))
        target_placeholder = tf.placeholder(tf.float32,
                                            shape=(None, 256, 256,
                                                   FLAGS.num_out * 3))

        # input_resized = tf.image.resize_area(input_placeholder, [128, 128])
        # target_resized = tf.image.resize_area(target_placeholder,[128, 128])

        # Prepare model.
        model = Voxel_flow_model()
        prediction = model.inference(input_placeholder)
        # reproduction_loss, prior_loss = model.loss(prediction, target_placeholder)
        reproduction_loss = model.loss(prediction, target_placeholder)
        # total_loss = reproduction_loss + prior_loss
        total_loss = reproduction_loss

        # Perform learning rate scheduling.
        learning_rate = FLAGS.initial_learning_rate

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(learning_rate)
        grads = opt.compute_gradients(total_loss)
        update_op = opt.apply_gradients(grads)

        # Create summaries
        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        summaries.append(tf.summary.scalar('total_loss', total_loss))
        summaries.append(
            tf.summary.scalar('reproduction_loss', reproduction_loss))
        # summaries.append(tf.summary.scalar('prior_loss', prior_loss))
        summaries.append(tf.summary.image('Input Image', input_placeholder, 3))
        summaries.append(tf.summary.image('Output Image', prediction, 3))
        summaries.append(
            tf.summary.image('Target Image', target_placeholder, 3))

        # Create a saver.
        saver = tf.train.Saver(tf.all_variables())

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.initialize_all_variables()
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
        sess.run(init)

        # Summary Writter
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)

        # Training loop using feed dict method.
        # data_list_frame1 = dataset_frame1.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame1)

        # data_list_frame2 = dataset_frame2.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame2)

        # data_list_frame3 = dataset_frame3.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame3)

        # data_list_frame4 = dataset_frame4.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame4)

        # data_list_frame5 = dataset_frame5.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame5)

        # data_list_frame6 = dataset_frame6.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame6)

        # data_list_frame7 = dataset_frame7.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame7)

        # data_list_frame8 = dataset_frame8.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame8)

        # data_list_frame9 = dataset_frame9.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame9)

        # data_list_frame10 = dataset_frame10.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame10)

        # data_list_frame11 = dataset_frame11.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame11)

        # data_list_frame12 = dataset_frame12.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame12)

        # data_list_frame13 = dataset_frame13.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame13)

        # data_list_frame14 = dataset_frame14.read_data_list_file()
        # random.seed(1)
        # shuffle(data_list_frame14)
        feed_str = "{input_placeholder:np.concatenate(("
        for i in range(FLAGS.num_in):
            feed_str += "batch_data_frame{}, ".format(i + 1)
        feed_str = feed_str[:-2] + '),3)'
        feed_str += ", target_placeholder:np.concatenate(("
        for i in range(FLAGS.num_in, FLAGS.num_in + FLAGS.num_out):
            feed_str += "batch_data_frame{}, ".format(i + 1)
        feed_str = feed_str[:-2] + '),3)}'
        m = globals()
        n = locals()
        for i in range(FLAGS.num_out + FLAGS.num_in):
            exec(
                "data_list_frame{}=dataset_frame{}.read_data_list_file()".
                format(i + 1, i + 1), m, n)
            exec("random.seed(1)", m, n)
            exec("shuffle(data_list_frame{})".format(i + 1), m, n)

        exec("data_size = len(data_list_frame1)", m, n)
        exec("epoch_num = int(data_size / FLAGS.batch_size)", m, n)
        results = np.array([[0, 0]])
        # num_workers = 1

        # load_fn_frame1 = partial(dataset_frame1.process_func)
        # p_queue_frame1 = PrefetchQueue(load_fn_frame1, data_list_frame1, FLAGS.batch_size, shuffle=False, num_workers=num_workers)

        # load_fn_frame2 = partial(dataset_frame2.process_func)
        # p_queue_frame2 = PrefetchQueue(load_fn_frame2, data_list_frame2, FLAGS.batch_size, shuffle=False, num_workers=num_workers)

        # load_fn_frame3 = partial(dataset_frame3.process_func)
        # p_queue_frame3 = PrefetchQueue(load_fn_frame3, data_list_frame3, FLAGS.batch_size, shuffle=False, num_workers=num_workers)

        for step in range(0, FLAGS.max_steps):
            n['step'] = step
            exec("batch_idx = step % epoch_num", m, n)

            # batch_data_list_frame1 = data_list_frame1[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame2 = data_list_frame2[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame3 = data_list_frame3[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame4 = data_list_frame4[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame5 = data_list_frame5[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame6 = data_list_frame6[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame7 = data_list_frame7[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame8 = data_list_frame8[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame9 = data_list_frame9[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame10 = data_list_frame10[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame11 = data_list_frame11[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame12 = data_list_frame12[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame13 = data_list_frame13[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]
            # batch_data_list_frame14 = data_list_frame14[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]

            for i in range(FLAGS.num_out + FLAGS.num_in):
                exec(
                    "batch_data_list_frame{} = data_list_frame{}[int(batch_idx * FLAGS.batch_size) : int((batch_idx + 1) * FLAGS.batch_size)]"
                    .format(i + 1, i + 1), m, n)

            # Load batch data.
            # batch_data_frame1 = np.array([dataset_frame1.process_func(line) for line in batch_data_list_frame1])
            # batch_data_frame2 = np.array([dataset_frame2.process_func(line) for line in batch_data_list_frame2])
            # batch_data_frame3 = np.array([dataset_frame3.process_func(line) for line in batch_data_list_frame3])
            # batch_data_frame4 = np.array([dataset_frame4.process_func(line) for line in batch_data_list_frame4])
            # batch_data_frame5 = np.array([dataset_frame5.process_func(line) for line in batch_data_list_frame5])
            # batch_data_frame6 = np.array([dataset_frame6.process_func(line) for line in batch_data_list_frame6])
            # batch_data_frame7 = np.array([dataset_frame7.process_func(line) for line in batch_data_list_frame7])
            # batch_data_frame8 = np.array([dataset_frame8.process_func(line) for line in batch_data_list_frame8])
            # batch_data_frame9 = np.array([dataset_frame9.process_func(line) for line in batch_data_list_frame9])
            # batch_data_frame10 = np.array([dataset_frame10.process_func(line) for line in batch_data_list_frame10])
            # batch_data_frame11 = np.array([dataset_frame11.process_func(line) for line in batch_data_list_frame11])
            # batch_data_frame12 = np.array([dataset_frame12.process_func(line) for line in batch_data_list_frame12])
            # batch_data_frame13 = np.array([dataset_frame13.process_func(line) for line in batch_data_list_frame13])
            # batch_data_frame14 = np.array([dataset_frame14.process_func(line) for line in batch_data_list_frame14])

            # print(np.array([n['dataset_frame1'].process_func(line) for line in n['batch_data_list_frame1']]))
            for i in range(FLAGS.num_out + FLAGS.num_in):
                exec(
                    "tr = np.array([dataset_frame{}.process_func(batch_data_list_frame{}[0])])"
                    .format(i + 1, i + 1), m, n)
                exec("np.expand_dims(tr,axis=0)")
                exec("s='batch_data_list_frame{}'".format(i + 1), m, n)
                for line in n[n['s']][1:]:
                    n['line'] = line
                    exec(
                        "tr=np.append(tr,np.expand_dims(dataset_frame1.process_func(line),axis=0),axis=0)",
                        m, n)
                exec('batch_data_frame{} = tr'.format(i + 1), m, n)
            # batch_data_frame1 = p_queue_frame1.get_batch()
            # batch_data_frame2 = p_queue_frame2.get_batch()
            # batch_data_frame3 = p_queue_frame3.get_batch()

            exec("feed_dict = " + feed_str, m, n)
            # exec("feed_dict = "+feed_str+", target_placeholder: batch_data_frame21}",m,n)

            # exec("feed_dict = {input_placeholder: np.concatenate((batch_data_frame1, batch_data_frame2, batch_data_frame3, batch_data_frame4, batch_data_frame5, batch_data_frame6, batch_data_frame7, batch_data_frame8, batch_data_frame9, batch_data_frame10, batch_data_frame11, batch_data_frame12, batch_data_frame13, batch_data_frame14, batch_data_frame15, batch_data_frame16, batch_data_frame17, batch_data_frame18, batch_data_frame19, batch_data_frame20), 3), target_placeholder: np.concatenate((batch_data_frame11,batch_data_frame12,batch_data_frame13,batch_data_frame14),3)}",m,n)
            # exec("feed_dict = {input_placeholder: np.concatenate((batch_data_frame1, batch_data_frame2, batch_data_frame3, batch_data_frame4, batch_data_frame5, batch_data_frame6, batch_data_frame7, batch_data_frame8, batch_data_frame9, batch_data_frame10, batch_data_frame11, batch_data_frame12, batch_data_frame13, batch_data_frame14, batch_data_frame15, batch_data_frame16, batch_data_frame17, batch_data_frame18, batch_data_frame19, batch_data_frame20, batch_data_frame21, batch_data_frame22, batch_data_frame23, batch_data_frame24, batch_data_frame25, batch_data_frame26, batch_data_frame27, batch_data_frame28, batch_data_frame29, batch_data_frame30, batch_data_frame31, batch_data_frame32, batch_data_frame33, batch_data_frame34, batch_data_frame35, batch_data_frame36, batch_data_frame37, batch_data_frame38, batch_data_frame39, batch_data_frame40), 3), target_placeholder: batch_data_frame41}",m,n)

            # Run single step update.
            _, loss_value = sess.run([update_op, total_loss],
                                     feed_dict=n['feed_dict'])

            if n['batch_idx'] == 0:
                # Shuffle data at each epoch.
                # random.seed(1)
                # shuffle(data_list_frame1)
                # random.seed(1)
                # shuffle(data_list_frame2)
                # random.seed(1)
                # shuffle(data_list_frame3)
                # random.seed(1)
                # shuffle(data_list_frame4)
                # random.seed(1)
                # shuffle(data_list_frame5)
                # random.seed(1)
                # shuffle(data_list_frame6)
                # random.seed(1)
                # shuffle(data_list_frame7)
                # random.seed(1)
                # shuffle(data_list_frame8)
                # random.seed(1)
                # shuffle(data_list_frame9)
                # random.seed(1)
                # shuffle(data_list_frame10)
                # random.seed(1)
                # shuffle(data_list_frame11)
                # random.seed(1)
                # shuffle(data_list_frame12)
                # random.seed(1)
                # shuffle(data_list_frame13)
                # random.seed(1)
                # shuffle(data_list_frame14)
                for i in range(FLAGS.num_in + FLAGS.num_out):
                    exec("random.seed(1)")
                    exec("shuffle(data_list_frame{})".format(i + 1), m, n)
                print('Epoch Number: %d' % int(step / n['epoch_num']))

            # Output Summary
            if step % 10 == 0:
                # summary_str = sess.run(summary_op, feed_dict = feed_dict)
                # summary_writer.add_summary(summary_str, step)
                print("Loss at step %d: %f" % (step, loss_value))
                results = np.append(results, [[step, loss_value]], axis=0)

            if step % 100 == 0:
                # Run a batch of images
                try:
                    prediction_np, target_np = sess.run(
                        [prediction, target_placeholder],
                        feed_dict=n['feed_dict'])
                    for i in range(0, prediction_np.shape[0]):
                        for j in range(0, FLAGS.num_out):
                            file_name = FLAGS.train_image_dir + str(
                                i) + '_out_{}.png'.format(j)
                            file_name_label = FLAGS.train_image_dir + str(
                                i) + '_gt_{}.png'.format(j)
                            imsave(prediction_np[i, :, :, j * 3:(j + 1) * 3],
                                   file_name)
                            imsave(target_np[i, :, :, j * 3:(j + 1) * 3],
                                   file_name_label)
                except ValueError:
                    print(prediction_np[0, :, :, 0:3].shape)
                    print(target_np[0, :, :, 0:3].shape)
                    break

            # Save checkpoint
            if step % 200 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
                np.savetxt(FLAGS.train_image_dir + "results.csv",
                           results[1:, :],
                           delimiter=",",
                           header="iter,loss",
                           comments='')


# Example #5
# 0


def _make_frame_iterator(dataset_frame):
  """Builds the batched, initializable input iterator for one frame stream.

  Args:
    dataset_frame: object exposing `read_data_list_file()`, which returns the
      list of entries for this frame position.

  Returns:
    A `(iterator, num_entries)` tuple: the initializable iterator yielding
    batches of `FLAGS.batch_size` mapped entries, and the length of the list
    returned by `read_data_list_file()`.
  """
  data_list = dataset_frame.read_data_list_file()
  dataset = tf.data.Dataset.from_tensor_slices(tf.constant(data_list))
  # NOTE(review): every stream is shuffled with the same seed and buffer size,
  # presumably so that the three frames of a sample stay aligned across the
  # independent pipelines -- confirm before changing either value.
  dataset = dataset.repeat().shuffle(buffer_size=1000000, seed=1).map(_read_image)
  dataset = dataset.prefetch(100)
  iterator = dataset.batch(FLAGS.batch_size).make_initializable_iterator()
  return iterator, len(data_list)


def train(dataset_frame1, dataset_frame2, dataset_frame3):
  """Trains a model.

  Builds a fresh graph in which the batches of `dataset_frame1` and
  `dataset_frame3` are concatenated along axis 3 to form the network input and
  the batches of `dataset_frame2` are the target, then runs
  `FLAGS.max_steps` Adam updates on the loss returned by
  `Voxel_flow_model.loss`.

  During training it prints the loss every 10 steps, writes summaries to
  `FLAGS.train_dir` every 100 steps, dumps prediction / ground-truth images
  under the `FLAGS.train_image_dir` prefix every 500 steps, and saves a
  checkpoint to `FLAGS.train_dir` every 500 steps and after the last step.

  If `FLAGS.pretrained_model_checkpoint_path` is set and contains a checkpoint
  state, the variables are restored from it; otherwise they are freshly
  initialized.

  Args:
    dataset_frame1: data source for the first input frame; must expose
      `read_data_list_file()`.
    dataset_frame2: data source for the target frame; same interface.
    dataset_frame3: data source for the second input frame; same interface.
  """
  with tf.Graph().as_default():
    # Create input pipelines (one per frame position).
    batch_frame1, data_size = _make_frame_iterator(dataset_frame1)
    batch_frame2, _ = _make_frame_iterator(dataset_frame2)
    batch_frame3, _ = _make_frame_iterator(dataset_frame3)
    iterator_initializers = [
        batch_frame1.initializer,
        batch_frame2.initializer,
        batch_frame3.initializer,
    ]

    # Input and target tensors. Despite the names these are iterator outputs,
    # not feedable placeholders: every `sess.run` that evaluates them draws a
    # fresh batch from the pipelines.
    input_placeholder = tf.concat([batch_frame1.get_next(), batch_frame3.get_next()], 3)
    target_placeholder = batch_frame2.get_next()

    # Prepare model.
    model = Voxel_flow_model(is_train=True)
    prediction, flow_motion, flow_mask = model.inference(input_placeholder)
    reproduction_loss = model.loss(prediction, flow_motion,
                                   flow_mask, target_placeholder,
                                   FLAGS.lambda_motion, FLAGS.lambda_mask)
    total_loss = reproduction_loss

    # Constant learning rate (no schedule is applied).
    learning_rate = FLAGS.initial_learning_rate

    # Create an optimizer that performs gradient descent.
    opt = tf.train.AdamOptimizer(learning_rate)
    grads = opt.compute_gradients(total_loss)
    update_op = opt.apply_gradients(grads)

    # Create summaries. The tf.summary.* ops register themselves in the
    # SUMMARIES collection, which is what merge_all() reads below.
    tf.summary.scalar('total_loss', total_loss)
    tf.summary.scalar('reproduction_loss', reproduction_loss)
    tf.summary.image('Input Image (before)', input_placeholder[:, :, :, 0:3], 3)
    tf.summary.image('Input Image (after)', input_placeholder[:, :, :, 3:6], 3)
    tf.summary.image('Output Image', prediction, 3)
    tf.summary.image('Target Image', target_placeholder, 3)

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())

    # Build the summary operation from all registered summaries.
    summary_op = tf.summary.merge_all()

    # Look up the pre-trained checkpoint once; None means "start from scratch".
    ckpt = None
    if FLAGS.pretrained_model_checkpoint_path:
      ckpt = tf.train.get_checkpoint_state(
          FLAGS.pretrained_model_checkpoint_path)

    sess = tf.Session()
    summary_writer = None
    try:
      if ckpt:
        # Restore checkpoint from file.
        restorer = tf.train.Saver()
        restorer.restore(sess, ckpt.model_checkpoint_path)
        print('%s: Pre-trained model restored from %s' %
              (datetime.now(), ckpt.model_checkpoint_path))
        sess.run(iterator_initializers)
      else:
        print('No existing checkpoints.')
        init = tf.global_variables_initializer()
        sess.run([init] + iterator_initializers)

      # Summary writer.
      summary_writer = tf.summary.FileWriter(
          FLAGS.train_dir,
          graph=sess.graph)

      # Number of steps per epoch. Clamped to 1 so that a data list shorter
      # than one batch does not cause a division by zero below.
      epoch_num = max(1, data_size // FLAGS.batch_size)

      for step in range(0, FLAGS.max_steps):
        batch_idx = step % epoch_num

        # Run single step update.
        _, loss_value = sess.run([update_op, total_loss])

        if batch_idx == 0:
          print('Epoch Number: %d' % (step // epoch_num))

        if step % 10 == 0:
          print("Loss at step %d: %f" % (step, loss_value))

        if step % 100 == 0:
          # Output summary. This is a separate run, so it is computed on a
          # new batch rather than the one just trained on.
          summary_str = sess.run(summary_op)
          summary_writer.add_summary(summary_str, step)

        if step % 500 == 0:
          # Run a batch of images (again a new batch) and dump them; files
          # are overwritten on every dump because names only carry the index.
          prediction_np, target_np = sess.run([prediction, target_placeholder])
          for i in range(prediction_np.shape[0]):
            file_name = FLAGS.train_image_dir + str(i) + '_out.png'
            file_name_label = FLAGS.train_image_dir + str(i) + '_gt.png'
            imwrite(file_name, prediction_np[i, :, :, :])
            imwrite(file_name_label, target_np[i, :, :, :])

        # Save checkpoint.
        if step % 500 == 0 or (step + 1) == FLAGS.max_steps:
          checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
          saver.save(sess, checkpoint_path, global_step=step)
    finally:
      # Flush pending summary events and release the session even when
      # training is interrupted by an exception.
      if summary_writer is not None:
        summary_writer.close()
      sess.close()