Beispiel #1
0
def main(first, second, out,
         ckpt="D:\\hyz\\TONGCE-exp\\py\\MCNET.model-10002"):
    """Predict the frame that follows two input images with a trained MCNet.

    Both images are converted to YUV and used as the K=2 context frames of a
    batch-size-1 MCNet built on the CPU. The single predicted frame (T=1) is
    converted back and written to ``out``.

    Args:
        first: Path of the first (earlier) context image.
        second: Path of the second (later) context image. It must have the
            same size as ``first``.
        out: Path the predicted image is written to (via ``cv2.imwrite``).
        ckpt: Checkpoint passed to ``MCNET`` and ``model.load_ckpt``.
            Defaults to the path that used to be hard-coded here.

    Raises:
        IOError: If either input image cannot be read.
    """
    K = 2
    T = 1

    img1 = cv2.imread(first)
    img2 = cv2.imread(second)
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising, which would otherwise surface as an obscure cvtColor error.
    for path, img in ((first, img1), (second, img2)):
        if img is None:
            raise IOError("Could not read image: %s" % path)

    # NOTE(review): cv2.imread yields BGR data but the RGB<->YUV codes are
    # used. The YUV2RGB call below undoes the same channel swap, so the
    # written file keeps the input channel order. Confirm this matches the
    # colour pipeline the checkpoint was trained with.
    img1_yuv = cv2.cvtColor(img1, cv2.COLOR_RGB2YUV)
    img2_yuv = cv2.cvtColor(img2, cv2.COLOR_RGB2YUV)

    height, width = img1_yuv.shape[:2]
    c_dim = 3

    with tf.device("/cpu:0"):
        model = MCNET(image_size=[height, width],
                      batch_size=1,
                      K=K,
                      T=T,
                      c_dim=c_dim,
                      is_train=False,
                      checkpoint_dir=ckpt)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        model.load_ckpt(sess, ckpt)

        # Layout is [batch, height, width, frame, channel]. Frames must be
        # written along axis 3. `seq[..., t]` would index the channel axis
        # and only appeared to work because K + T happens to equal 3.
        # The T target slot is left as zeros.
        seq = np.zeros([1, height, width, K + T, 3], dtype="float32")
        seq[:, :, :, 0] = transform(img1_yuv)
        seq[:, :, :, 1] = transform(img2_yuv)

        # Motion input: difference between the two context frames.
        diff = np.zeros([1, height, width, K - 1, 3], dtype="float32")
        diff[:, :, :, 0] = (inverse_transform(seq[:, :, :, 1]) -
                            inverse_transform(seq[:, :, :, 0]))

        pred = sess.run(model.G,
                        feed_dict={
                            model.diff_in: diff,
                            model.xt: seq[:, :, :, K - 1],
                            model.target: seq
                        })

        # inverse_transform is assumed to map outputs to [0, 1]. Clip so
        # out-of-range values saturate instead of wrapping around in uint8.
        pred_frame = inverse_transform(pred[0].reshape([height, width, 3]))
        pred_yuv = (np.clip(pred_frame, 0.0, 1.0) * 255.0).astype(np.uint8)

        pred_rgb = cv2.cvtColor(pred_yuv, cv2.COLOR_YUV2RGB)
        cv2.imwrite(out, pred_rgb)
Beispiel #2
0
def main(lr, batch_size, image_size, K, T, num_iter, gpu):
    """Train MCNet on Vimeo triplets with the image loss only (no GAN).

    The clip list is read from ``tri_testlist.txt`` under
    ``../data/vimeo_interp_test/target/``. The generator is optimised with
    Adam on ``model.L_img`` over shuffled mini-batches until ``num_iter``
    batches have been processed. Summaries go to ``../logs/<prefix>/``,
    sample strips to ``../samples/<prefix>/`` and checkpoints to
    ``../models/<prefix>/``.

    Args:
        lr: Adam learning rate.
        batch_size: Clips per mini-batch. Trailing partial batches are
            skipped because the graph has a fixed batch size.
        image_size: Square frame size passed to the model and the loader.
        K: Number of context frames.
        T: Number of predicted frames.
        num_iter: Number of training batches to run.
        gpu: Sequence whose first element is the GPU index to use.
    """
    # Expose only the requested GPU to this process. CUDA renumbers visible
    # devices from zero, so inside the process that GPU is "/gpu:0".
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu[0])

    data_path = "../data/vimeo_interp_test/target/"
    with open(data_path + "tri_testlist.txt", "r") as list_file:
        # Strip only the line terminator. Slicing off the last character
        # would also drop a real character on a final line without "\n".
        trainfiles = [line.rstrip("\r\n") for line in list_file]
    iters = 0
    prefix = ("VIMEO_MCNET" + "_image_size=" + str(image_size) + "_K=" +
              str(K) + "_T=" + str(T) + "_batch_size=" + str(batch_size) +
              "_lr=" + str(lr))

    print("\n" + prefix + "\n")
    checkpoint_dir = "../models/" + prefix + "/"
    samples_dir = "../samples/" + prefix + "/"
    summary_dir = "../logs/" + prefix + "/"

    for directory in (checkpoint_dir, samples_dir, summary_dir):
        if not exists(directory):
            makedirs(directory)

    with tf.device("/gpu:0"):
        model = MCNET(image_size=[image_size, image_size],
                      c_dim=3,
                      K=K,
                      batch_size=batch_size,
                      T=T,
                      checkpoint_dir=checkpoint_dir)

        g_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
            model.L_img, var_list=model.g_vars)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=gpu_options)) as sess:

        tf.global_variables_initializer().run()

        if model.load(sess, checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        g_sum = tf.summary.merge(
            [model.L_p_sum, model.L_gdl_sum, model.loss_sum])

        writer = tf.summary.FileWriter(summary_dir, sess.graph)

        counter = iters + 1
        start_time = time.time()
        while iters < num_iter:
            mini_batches = get_minibatches_idx(len(trainfiles),
                                               batch_size,
                                               shuffle=True)
            for _, batchidx in mini_batches:
                # The graph has a fixed batch size, so a partial batch cannot
                # be fed. Skip it entirely; the training step used to run
                # anyway on the previous (stale) batch.
                if len(batchidx) != batch_size:
                    continue

                seq_batch = np.zeros(
                    (batch_size, image_size, image_size, K + T, 3),
                    dtype="float32")
                diff_batch = np.zeros(
                    (batch_size, image_size, image_size, K - 1, 3),
                    dtype="float32")
                Ts = np.repeat(np.array([T]), batch_size, axis=0)
                Ks = np.repeat(np.array([K]), batch_size, axis=0)
                paths = np.repeat(data_path, batch_size, axis=0)
                tfiles = np.array(trainfiles)[batchidx]
                shapes = np.repeat(np.array([image_size]),
                                   batch_size,
                                   axis=0)
                output = [
                    load_vimeo_data(fname, p, img_sze, k, t)
                    for fname, p, img_sze, k, t in zip(
                        tfiles, paths, shapes, Ks, Ts)
                ]
                for i in range(batch_size):
                    seq_batch[i] = output[i][0]
                    diff_batch[i] = output[i][1]

                feed = {
                    model.diff_in: diff_batch,
                    model.xt: seq_batch[:, :, :, K - 1],
                    model.target: seq_batch
                }
                _, summary_str = sess.run([g_optim, g_sum], feed_dict=feed)
                writer.add_summary(summary_str, counter)

                counter += 1
                if counter % 50 == 0:
                    # The loss is only needed for this log line, so skip the
                    # extra forward pass on the other iterations.
                    errL_img = model.L_img.eval(feed)
                    print("Iters: [%2d] time: %4.4f, L_img: %.8f" %
                          (iters, time.time() - start_time, errL_img))

                if np.mod(counter, 10) == 1:
                    samples = sess.run(model.G, feed_dict=feed)
                    # Move the frame axis first for the first clip
                    # (assumed [H, W, frame, C] -> [frame, H, W, C]).
                    samples = samples[0].swapaxes(0, 2).swapaxes(1, 2)
                    sbatch = seq_batch[0, :, :, :].swapaxes(0,
                                                            2).swapaxes(1, 2)

                    # Top row: context frames followed by predictions.
                    # Bottom row: ground truth.
                    sbatch2 = sbatch.copy()
                    sbatch2[K:, :, :] = samples
                    samples = np.concatenate((sbatch2, sbatch), axis=0)
                    print("Saving sample ...")
                    save_images(samples, [2, K + T],
                                samples_dir + "train_%s.png" % (iters))
                if np.mod(counter, 10000) == 2:
                    model.save(sess, checkpoint_dir, counter)

                iters += 1
Beispiel #3
0
def main(lr, batch_size, alpha, beta, image_size, K, T, num_iter, gpu):
    """Adversarially train MCNet on grayscale (c_dim=1) KTH clips.

    The discriminator minimises ``model.d_loss``. The generator minimises
    ``alpha * L_img + beta * L_GAN``. Updates of either player are paused
    when the discriminator losses leave the ``[margin, 1 - margin]`` band.
    Summaries go to ``../logs/<prefix>/``, sample strips to
    ``../samples/<prefix>/`` and checkpoints to ``../models/<prefix>/``.

    Args:
        lr: Adam learning rate for both optimisers.
        batch_size: Clips per mini-batch. This is also the number of
            parallel loader jobs. Trailing partial batches are skipped.
        alpha: Weight of the image loss in the generator objective.
        beta: Weight of the GAN loss in the generator objective.
        image_size: Square frame size passed to the model and the loader.
        K: Number of context frames.
        T: Number of predicted frames.
        num_iter: Number of training batches to run.
        gpu: Sequence whose first element is the GPU index to place ops on.
    """
    data_path = "../data/KTH/"
    with open(data_path + "train_data_list_trimmed.txt", "r") as list_file:
        trainfiles = list_file.readlines()
    margin = 0.3
    updateD = True
    updateG = True
    iters = 0
    prefix = ("KTH_MCNET" + "_image_size=" + str(image_size) + "_K=" + str(K) +
              "_T=" + str(T) + "_batch_size=" + str(batch_size) + "_alpha=" +
              str(alpha) + "_beta=" + str(beta) + "_lr=" + str(lr))

    print("\n" + prefix + "\n")
    checkpoint_dir = "../models/" + prefix + "/"
    samples_dir = "../samples/" + prefix + "/"
    summary_dir = "../logs/" + prefix + "/"

    for directory in (checkpoint_dir, samples_dir, summary_dir):
        if not exists(directory):
            makedirs(directory)

    with tf.device("/gpu:%d" % gpu[0]):
        model = MCNET(image_size=[image_size, image_size],
                      c_dim=1,
                      K=K,
                      batch_size=batch_size,
                      T=T,
                      checkpoint_dir=checkpoint_dir)
        d_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
            model.d_loss, var_list=model.d_vars)
        g_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
            alpha * model.L_img + beta * model.L_GAN, var_list=model.g_vars)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=gpu_options)) as sess:

        tf.global_variables_initializer().run()

        if model.load(sess, checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        g_sum = tf.summary.merge(
            [model.L_p_sum, model.L_gdl_sum, model.loss_sum, model.L_GAN_sum])
        d_sum = tf.summary.merge(
            [model.d_loss_real_sum, model.d_loss_sum, model.d_loss_fake_sum])
        writer = tf.summary.FileWriter(summary_dir, sess.graph)

        counter = iters + 1
        start_time = time.time()

        with Parallel(n_jobs=batch_size) as parallel:
            while iters < num_iter:
                mini_batches = get_minibatches_idx(len(trainfiles),
                                                   batch_size,
                                                   shuffle=True)
                for _, batchidx in mini_batches:
                    # The graph has a fixed batch size; skip partial batches.
                    if len(batchidx) != batch_size:
                        continue

                    seq_batch = np.zeros(
                        (batch_size, image_size, image_size, K + T, 1),
                        dtype="float32")
                    diff_batch = np.zeros(
                        (batch_size, image_size, image_size, K - 1, 1),
                        dtype="float32")
                    Ts = np.repeat(np.array([T]), batch_size, axis=0)
                    Ks = np.repeat(np.array([K]), batch_size, axis=0)
                    paths = np.repeat(data_path, batch_size, axis=0)
                    tfiles = np.array(trainfiles)[batchidx]
                    shapes = np.repeat(np.array([image_size]),
                                       batch_size,
                                       axis=0)
                    output = parallel(
                        delayed(load_kth_data)(fname, p, img_sze, k, t)
                        for fname, p, img_sze, k, t in zip(
                            tfiles, paths, shapes, Ks, Ts))
                    # range (not xrange) so this also runs on Python 3.
                    for i in range(batch_size):
                        seq_batch[i] = output[i][0]
                        diff_batch[i] = output[i][1]

                    feed = {
                        model.diff_in: diff_batch,
                        model.xt: seq_batch[:, :, :, K - 1],
                        model.target: seq_batch
                    }

                    if updateD:
                        _, summary_str = sess.run([d_optim, d_sum],
                                                  feed_dict=feed)
                        writer.add_summary(summary_str, counter)

                    if updateG:
                        _, summary_str = sess.run([g_optim, g_sum],
                                                  feed_dict=feed)
                        writer.add_summary(summary_str, counter)

                    # One forward pass for all three monitored losses instead
                    # of three separate evals on the same batch and weights.
                    errD_fake, errD_real, errG = sess.run(
                        [model.d_loss_fake, model.d_loss_real, model.L_GAN],
                        feed_dict=feed)

                    # Balance the players: pause D while it is too confident,
                    # pause G while D is failing, and if both would pause,
                    # resume both.
                    if errD_fake < margin or errD_real < margin:
                        updateD = False
                    if errD_fake > (1. - margin) or errD_real > (1. - margin):
                        updateG = False
                    if not updateD and not updateG:
                        updateD = True
                        updateG = True

                    counter += 1

                    print(
                        "Iters: [%2d] time: %4.4f, d_loss: %.8f, L_GAN: %.8f"
                        % (iters, time.time() - start_time,
                           errD_fake + errD_real, errG))

                    if np.mod(counter, 100) == 1:
                        samples = sess.run(model.G, feed_dict=feed)
                        # Frame axis first for the first clip; the top row is
                        # predictions and the bottom row is ground truth.
                        samples = samples[0].swapaxes(0, 2).swapaxes(1, 2)
                        sbatch = seq_batch[0, :, :,
                                           K:].swapaxes(0, 2).swapaxes(1, 2)
                        samples = np.concatenate((samples, sbatch), axis=0)
                        print("Saving sample ...")
                        save_images(samples[:, :, :, ::-1], [2, T],
                                    samples_dir + "train_%s.png" % (iters))
                    if np.mod(counter, 500) == 2:
                        model.save(sess, checkpoint_dir, counter)

                    iters += 1
Beispiel #4
0
_IMAGEIO_MAX_ATTEMPTS = 10


def _with_imageio_retries(func, *args):
  """Return ``func(*args)``, retrying when the call raises.

  Video access through imageio/ffmpeg fails sporadically, so a failed call
  is retried. To avoid hanging forever on a persistent error (e.g. a
  missing file), the last error is re-raised after
  ``_IMAGEIO_MAX_ATTEMPTS`` tries.
  """
  for attempt in range(_IMAGEIO_MAX_ATTEMPTS):
    try:
      return func(*args)
    except Exception:
      if attempt == _IMAGEIO_MAX_ATTEMPTS - 1:
        raise
      print("imageio failed loading frames, retrying")


def main(prefix, image_size, K, T, gpu):
  """Evaluate an MCNet checkpoint on the KTH test list.

  For each test window, K grayscale context frames are fed to the model.
  The T predicted frames are compared with ground truth using PSNR and SSIM.
  Per-frame images are rendered into GIFs under
  ``../results/images/KTH/<prefix>/``. The metric arrays (one row per
  window, one column per predicted frame) are saved to
  ``../results/quantitative/KTH/<prefix>/results_model=<name>.npz``.

  Args:
    prefix: Model directory name under ``../models/``. ``"paper_models"``
      selects the released KTH checkpoint ``MCNET.model-98502``.
    image_size: Square frame size the video frames are resized to.
    K: Number of context frames.
    T: Number of predicted frames.
    gpu: Sequence whose first element is the GPU index for the model.
  """
  data_path = "../data/KTH/"
  with open(data_path+"test_data_list.txt","r") as list_file:
    testfiles = list_file.readlines()
  c_dim = 1

  if prefix == "paper_models":
    checkpoint_dir = "../models/"+prefix+"/KTH/"
    best_model = "MCNET.model-98502"
  else:
    checkpoint_dir = "../models/"+prefix+"/"
    best_model = None # will pick last model

  with tf.device("/gpu:%d"%gpu[0]):
    model = MCNET(image_size=[image_size, image_size], batch_size=1, K=K,
                  T=T, c_dim=c_dim, checkpoint_dir=checkpoint_dir,
                  is_train=False)

  gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False,
                                        gpu_options=gpu_options)) as sess:

    tf.global_variables_initializer().run()

    loaded, model_name = model.load(sess, checkpoint_dir, best_model)

    if loaded:
      print(" [*] Load SUCCESS")
    else:
      print(" [!] Load failed... exitting")
      return

    quant_dir = "../results/quantitative/KTH/"+prefix+"/"
    save_path = quant_dir+"results_model="+model_name+".npz"
    if not exists(quant_dir):
      makedirs(quant_dir)

    vid_names = []
    psnr_err = np.zeros((0, T))
    ssim_err = np.zeros((0, T))
    for i, line in enumerate(testfiles):
      # Assumed format: "<clip name> <first frame> <last frame>". Fields 1
      # and 2 bound the window start index j below.
      tokens = line.split()
      vid_path = data_path+tokens[0]+"_uncomp.avi"
      vid = _with_imageio_retries(imageio.get_reader, vid_path, "ffmpeg")

      # Windows start every 3 frames for running/jogging, every T otherwise.
      action = vid_path.split("_")[1]
      if action in ["running", "jogging"]:
        n_skip = 3
      else:
        n_skip = T

      for j in range(int(tokens[1]), int(tokens[2])-K-T-1, n_skip):
        print("Video "+str(i)+"/"+str(len(testfiles))+". Index "+str(j)+
              "/"+str(vid.get_length()-T-1))

        folder_pref = vid_path.split("/")[-1].split(".")[0]
        folder_name = folder_pref+"."+str(j)+"-"+str(j+T)

        vid_names.append(folder_name)
        savedir = "../results/images/KTH/"+prefix+"/"+folder_name

        seq_batch = np.zeros((1, image_size, image_size,
                              K+T, c_dim), dtype="float32")
        diff_batch = np.zeros((1, image_size, image_size,
                               K-1, 1), dtype="float32")
        for t in range(K+T):
          frame = _with_imageio_retries(vid.get_data, j+t)
          img = cv2.resize(frame, (image_size, image_size))
          img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
          seq_batch[0,:,:,t] = transform(img[:,:,None])

        # Motion input: differences between consecutive context frames.
        for t in range(1, K):
          prev_frame = inverse_transform(seq_batch[0,:,:,t-1])
          next_frame = inverse_transform(seq_batch[0,:,:,t])
          diff_batch[0,:,:,t-1] = (next_frame.astype("float32") -
                                   prev_frame.astype("float32"))

        true_data = seq_batch[:,:,:,K:,:].copy()
        pred_data = np.zeros(true_data.shape, dtype="float32")
        xt = seq_batch[:,:,:,K-1]
        pred_data[0] = sess.run(model.G,
                                feed_dict={model.diff_in: diff_batch,
                                           model.xt: xt})

        if not os.path.exists(savedir):
          os.makedirs(savedir)

        # Metrics are computed for all K+T frames; only the last T
        # (the predicted ones) are kept in the result arrays.
        cpsnr = np.zeros((K+T,))
        cssim = np.zeros((K+T,))
        pred_data = np.concatenate((seq_batch[:,:,:,:K], pred_data),axis=3)
        true_data = np.concatenate((seq_batch[:,:,:,:K], true_data),axis=3)
        for t in range(K+T):
          pred = (inverse_transform(pred_data[0,:,:,t])*255).astype("uint8")
          target = (inverse_transform(true_data[0,:,:,t])*255).astype("uint8")

          cpsnr[t] = measure.compare_psnr(pred,target)
          cssim[t] = ssim.compute_ssim(Image.fromarray(cv2.cvtColor(target,
                                                       cv2.COLOR_GRAY2BGR)),
                                       Image.fromarray(cv2.cvtColor(pred,
                                                       cv2.COLOR_GRAY2BGR)))
          pred = draw_frame(pred, t < K)
          target = draw_frame(target, t < K)

          cv2.imwrite(savedir+"/pred_"+"{0:04d}".format(t)+".png", pred)
          cv2.imwrite(savedir+"/gt_"+"{0:04d}".format(t)+".png", target)

        cmd1 = "rm "+savedir+"/pred.gif"
        cmd2 = ("ffmpeg -f image2 -framerate 7 -i "+savedir+
                "/pred_%04d.png "+savedir+"/pred.gif")
        cmd3 = "rm "+savedir+"/pred*.png"

        # Comment out "system(cmd3)" if you want to keep the output images.
        # Otherwise only the GIFs will be kept.
        system(cmd1)
        system(cmd2)
        system(cmd3)

        cmd1 = "rm "+savedir+"/gt.gif"
        cmd2 = ("ffmpeg -f image2 -framerate 7 -i "+savedir+
                "/gt_%04d.png "+savedir+"/gt.gif")
        cmd3 = "rm "+savedir+"/gt*.png"

        # Comment out "system(cmd3)" if you want to keep the output images.
        # Otherwise only the GIFs will be kept.
        system(cmd1)
        system(cmd2)
        system(cmd3)

        psnr_err = np.concatenate((psnr_err, cpsnr[None,K:]), axis=0)
        ssim_err = np.concatenate((ssim_err, cssim[None,K:]), axis=0)

      # Release the ffmpeg reader before moving on to the next video.
      vid.close()

    np.savez(save_path, psnr=psnr_err, ssim=ssim_err)
    print("Results saved to "+save_path)
  print("Done.")
Beispiel #5
0
def main(lr, prefix, K, T, gpu):
    """Evaluate an MCNet checkpoint on every 10th clip of UCF-101 test split 1.

    The first K frames of each clip (240x320, colour) are used as context.
    The T predicted frames are compared with ground truth using PSNR and
    SSIM. Per-frame images are rendered into GIFs under
    ``../results/images/UCF101/<prefix>/<n>/``. The metric arrays (one row
    per clip, one column per predicted frame) are saved to
    ``../results/quantitative/UCF101/<prefix>/results_model=<name>.npz``.

    Args:
        lr: Unused; kept for interface compatibility.
        prefix: Model directory name under ``../models/``. ``"paper_models"``
            selects the released checkpoint ``S1M/MCNET.model-102502``.
        K: Number of context frames.
        T: Number of predicted frames.
        gpu: Sequence whose first element is the GPU index for the model.
    """
    data_path = "../data/UCF101/UCF-101/"
    list_path = data_path.rsplit("/", 2)[0] + "/testlist01.txt"
    with open(list_path, "r") as list_file:
        testfiles = list_file.readlines()
    image_size = [240, 320]
    c_dim = 3

    if prefix == "paper_models":
        checkpoint_dir = "../models/" + prefix + "/S1M/"
        best_model = "MCNET.model-102502"
    else:
        checkpoint_dir = "../models/" + prefix + "/"
        best_model = None  # will pick last model

    with tf.device("/gpu:%d" % gpu[0]):
        model = MCNET(image_size=image_size,
                      batch_size=1,
                      K=K,
                      T=T,
                      c_dim=c_dim,
                      checkpoint_dir=checkpoint_dir,
                      is_train=False)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=gpu_options)) as sess:

        tf.global_variables_initializer().run()

        loaded, model_name = model.load(sess, checkpoint_dir, best_model)

        if loaded:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed... exitting")
            return

        quant_dir = "../results/quantitative/UCF101/" + prefix + "/"
        save_path = quant_dir + "results_model=" + model_name + ".npz"
        if not exists(quant_dir):
            makedirs(quant_dir)

        vid_names = []
        psnr_err = np.zeros((0, T))
        ssim_err = np.zeros((0, T))
        for i in range(0, len(testfiles), 10):
            print(" Video " + str(i) + "/" + str(len(testfiles)))

            # Map the list's "HandStandPushups" spelling onto the
            # "HandstandPushups" directory name.
            testfiles[i] = testfiles[i].replace("/HandStandPushups/",
                                                "/HandstandPushups/")

            vid_path = data_path + testfiles[i].split()[0]
            folder_name = vid_path.split("/")[-1].split(".")[0]
            vid_names.append(folder_name)
            savedir = "../results/images/UCF101/" + prefix + "/" + str(i + 1)

            seq_batch = np.zeros(
                (1, image_size[0], image_size[1], K + T, c_dim),
                dtype="float32")
            diff_batch = np.zeros((1, image_size[0], image_size[1], K - 1, 1),
                                  dtype="float32")

            # Open the clip once (it used to be opened twice, leaking a
            # reader) and release it as soon as the frames are read.
            vid = imageio.get_reader(vid_path, "ffmpeg")
            try:
                for t in range(K + T):
                    # Reverse the channel order (imageio yields RGB) so the
                    # cv2 BGR2GRAY conversions below receive BGR data.
                    img = vid.get_data(t)[:, :, ::-1]
                    seq_batch[0, :, :, t] = transform(img)
            finally:
                vid.close()

            # Motion input: grayscale differences between consecutive context
            # frames, scaled by 1/255.
            for t in range(1, K):
                prev_frame = inverse_transform(seq_batch[0, :, :, t - 1]) * 255
                prev_frame = cv2.cvtColor(prev_frame.astype("uint8"),
                                          cv2.COLOR_BGR2GRAY)
                next_frame = inverse_transform(seq_batch[0, :, :, t]) * 255
                next_frame = cv2.cvtColor(next_frame.astype("uint8"),
                                          cv2.COLOR_BGR2GRAY)
                diff = (next_frame.astype("float32") -
                        prev_frame.astype("float32"))
                diff_batch[0, :, :, t - 1] = diff[:, :, None] / 255.

            true_data = seq_batch[:, :, :, K:, :].copy()
            pred_data = np.zeros(true_data.shape, dtype="float32")
            xt = seq_batch[:, :, :, K - 1]
            pred_data[0] = sess.run(model.G,
                                    feed_dict={
                                        model.diff_in: diff_batch,
                                        model.xt: xt
                                    })

            if not os.path.exists(savedir):
                os.makedirs(savedir)

            # Metrics are computed for all K+T frames; only the last T
            # (the predicted ones) are kept in the result arrays.
            cpsnr = np.zeros((K + T, ))
            cssim = np.zeros((K + T, ))
            pred_data = np.concatenate((seq_batch[:, :, :, :K], pred_data),
                                       axis=3)
            true_data = np.concatenate((seq_batch[:, :, :, :K], true_data),
                                       axis=3)
            for t in range(K + T):
                pred = (inverse_transform(pred_data[0, :, :, t]) *
                        255).astype("uint8")
                target = (inverse_transform(true_data[0, :, :, t]) *
                          255).astype("uint8")

                cpsnr[t] = measure.compare_psnr(pred, target)
                cssim[t] = ssim.compute_ssim(Image.fromarray(target),
                                             Image.fromarray(pred))

                pred = draw_frame(pred, t < K)
                target = draw_frame(target, t < K)

                cv2.imwrite(savedir + "/pred_" + "{0:04d}".format(t) + ".png",
                            pred)
                cv2.imwrite(savedir + "/gt_" + "{0:04d}".format(t) + ".png",
                            target)

            cmd1 = "rm " + savedir + "/pred.gif"
            cmd2 = ("ffmpeg -f image2 -framerate 3 -i " + savedir +
                    "/pred_%04d.png " + savedir + "/pred.gif")
            cmd3 = "rm " + savedir + "/pred*.png"

            # Comment out "system(cmd3)" if you want to keep the output images.
            # Otherwise only the GIFs will be kept.
            system(cmd1)
            system(cmd2)
            system(cmd3)

            cmd1 = "rm " + savedir + "/gt.gif"
            cmd2 = ("ffmpeg -f image2 -framerate 3 -i " + savedir +
                    "/gt_%04d.png " + savedir + "/gt.gif")
            cmd3 = "rm " + savedir + "/gt*.png"

            # Comment out "system(cmd3)" if you want to keep the output images.
            # Otherwise only the GIFs will be kept.
            system(cmd1)
            system(cmd2)
            system(cmd3)

            psnr_err = np.concatenate((psnr_err, cpsnr[None, K:]), axis=0)
            ssim_err = np.concatenate((ssim_err, cssim[None, K:]), axis=0)

        np.savez(save_path, psnr=psnr_err, ssim=ssim_err)
        print("Results saved to " + save_path)
    print("Done.")
Beispiel #6
0
def main(lr, batch_size, alpha, beta, image_size_h, image_size_w, K,
         T, num_iter, gpu):
    """Train MCNet on Vimeo septuplets with the image loss and report a
    validation loss.

    The generator is optimised with Adam on ``alpha * model.L_img``. Every
    50 iterations the training loss and the mean image loss over the
    in-memory validation set are printed. Summaries go to
    ``../logs/<prefix>/``, sample strips to ``../samples/<prefix>/`` and
    checkpoints to ``../models/<prefix>/``.

    Args:
        lr: Adam learning rate.
        batch_size: Clips per mini-batch. This is also the number of
            parallel loader jobs. Trailing partial batches are skipped.
        alpha: Weight of the image loss.
        beta: Unused (the GAN branch is disabled); kept for compatibility.
        image_size_h: Frame height passed to the model.
        image_size_w: Frame width passed to the model.
        K: Number of context frames.
        T: Number of predicted frames.
        num_iter: Number of training batches to run.
        gpu: Sequence whose first element is the GPU index to use.
    """
    # Expose only the requested GPU to this process. CUDA renumbers visible
    # devices from zero, so inside the process that GPU is "/gpu:0".
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu[0])

    data_path = "/data1/ikusyou/vimeo_septuplet/sequences/"
    with open(data_path+"sep_trainlist.txt","r") as list_file:
        # Strip only the line terminator (l[:-1] would also drop a real
        # character on a final line without "\n").
        trainfiles = [line.rstrip("\r\n") for line in list_file]

    # TODO: fill in the validation directory and list-file name. With these
    # empty placeholders, open() fails immediately.
    val_path = ""
    with open(val_path+"", "r") as list_file:
        val_files = [line.rstrip("\r\n") for line in list_file]
    val_num = len(val_files)
    # NOTE(review): validation frames are hard-coded to 240x416, while the
    # model is built with image_size_h x image_size_w. Feeding them
    # presumably requires the two to match; confirm.
    val_seq = np.zeros((val_num, 240, 416, K+T, 3), dtype="float32")
    val_diff = np.zeros((val_num, 240, 416, K-1, 3), dtype="float32")

    # Read the whole validation set into memory once, up front.
    for idx, f_name in enumerate(val_files):
        output = load_vimeo_data(f_name, val_path, [240, 416], K, T)
        val_seq[idx] = output[0]
        val_diff[idx] = output[1]

    iters = 0
    prefix  = ("VIMEO_MCNET_V1.1.1"
            + "_image_size_h="+str(image_size_h)
            + "_image_size_w="+str(image_size_w)
            + "_K="+str(K)
            + "_T="+str(T)
            + "_batch_size="+str(batch_size)
            + "_lr="+str(lr))

    print("\n"+prefix+"\n")
    checkpoint_dir = "../models/"+prefix+"/"
    samples_dir = "../samples/"+prefix+"/"
    summary_dir = "../logs/"+prefix+"/"

    for directory in (checkpoint_dir, samples_dir, summary_dir):
        if not exists(directory):
            makedirs(directory)

    with tf.device("/gpu:0"):
        model = MCNET(image_size=[image_size_h,image_size_w], c_dim=3,
                      K=K, batch_size=batch_size, T=T,
                      checkpoint_dir=checkpoint_dir)
        g_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
            alpha*model.L_img, var_list=model.g_vars
        )

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=gpu_options)) as sess:

        tf.global_variables_initializer().run()

        if model.load(sess, checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        g_sum = tf.summary.merge([model.L_p_sum,
                                  model.L_gdl_sum, model.loss_sum])
        writer = tf.summary.FileWriter(summary_dir, sess.graph)

        counter = iters+1
        start_time = time.time()
        with Parallel(n_jobs=batch_size) as parallel:
            while iters < num_iter:
                mini_batches = get_minibatches_idx(len(trainfiles),
                                                   batch_size, shuffle=True)
                for _, batchidx in mini_batches:
                    # The graph has a fixed batch size, so a partial batch
                    # cannot be fed. Skip it entirely; the training step
                    # used to run anyway on the previous (stale) batch.
                    if len(batchidx) != batch_size:
                        continue

                    seq_batch = np.zeros((batch_size, image_size_h,
                                          image_size_w, K+T, 3),
                                         dtype="float32")
                    diff_batch = np.zeros((batch_size, image_size_h,
                                           image_size_w, K-1, 3),
                                          dtype="float32")
                    Ts = np.repeat(np.array([T]),batch_size,axis=0)
                    Ks = np.repeat(np.array([K]),batch_size,axis=0)
                    paths = np.repeat(data_path, batch_size,axis=0)
                    tfiles = np.array(trainfiles)[batchidx]
                    shapes = np.repeat(np.array([image_size_h]),
                                       batch_size, axis=0)
                    output = parallel(
                        delayed(load_vimeo_data)(fname, p, img_sze, k, t)
                        for fname, p, img_sze, k, t in zip(tfiles, paths,
                                                           shapes, Ks, Ts))
                    for i in range(batch_size):
                        seq_batch[i] = output[i][0]
                        diff_batch[i] = output[i][1]

                    feed = {model.diff_in: diff_batch,
                            model.xt: seq_batch[:,:,:,K-1],
                            model.target: seq_batch}
                    _, summary_str = sess.run([g_optim, g_sum],
                                              feed_dict=feed)
                    writer.add_summary(summary_str, counter)

                    counter += 1
                    if counter % 50 == 0:
                        # Only needed for this log line, so the extra
                        # forward pass is skipped on other iterations.
                        errL_img = model.L_img.eval(feed)
                        print(
                            "Iters: [%2d] time: %4.4f, img_loss:%.8f"
                            % (iters, time.time() - start_time, errL_img)
                        )
                        # Mean image loss over every full validation batch
                        # (the upper bound keeps each slice exactly
                        # batch_size long). Separate names keep the training
                        # batch intact.
                        val_loss = []
                        for idx in range(0, val_num - batch_size + 1,
                                         batch_size):
                            val_seq_batch = val_seq[idx:idx + batch_size]
                            val_err = model.L_img.eval({
                                model.diff_in: val_diff[idx:idx + batch_size],
                                model.xt: val_seq_batch[:,:,:,K-1],
                                model.target: val_seq_batch})
                            val_loss.append(np.mean(val_err))
                        if val_loss:
                            print("Val loss: %.8f" % np.mean(val_loss))

                    if np.mod(counter, 200) == 1:
                        samples = sess.run(model.G, feed_dict=feed)
                        # Frame axis first for the first clip; top row is
                        # context + predictions, bottom row ground truth.
                        samples = samples[0].swapaxes(0,2).swapaxes(1,2)
                        sbatch = seq_batch[0,:,:,:].swapaxes(0,2).swapaxes(1,2)
                        sbatch2 = sbatch.copy()
                        sbatch2[K:,:,:] = samples
                        samples = np.concatenate((sbatch2,sbatch), axis=0)
                        print("Saving sample ...")
                        save_images(samples, [2, K+T],
                                    samples_dir+"train_%s.png" % (iters))
                    if np.mod(counter, 10000) == 2:
                        model.save(sess, checkpoint_dir, counter)

                    iters += 1
def _compose_sample_grid(samples, seq_batch, K, T):
    """Stack generated and ground-truth frames for a 2 x (K + T) sample grid.

    The first row is the generator output for batch element 0, left-padded
    with all-zero frames so that it holds exactly K + T frames. The zero
    frames take the place of the context frames, which the generator does not
    produce. The second row is the ground-truth sequence for the same element.

    Args:
        samples: output of ``model.G``. Time is taken to be axis 3, as
            elsewhere in ``main``; presumably (batch, h, w, t, c) -- TODO
            confirm against MCNET.
        seq_batch: ground-truth batch with time on axis 3 (``main`` indexes
            it as ``seq_batch[:, :, :, k]``).
        K: number of context frames.
        T: number of predicted frames.

    Returns:
        ndarray with one frame per entry along axis 0 and the last (channel)
        axis reversed, as passed to ``save_images``.
    """
    samples = np.asarray(samples)
    # Pad up to K + T frames so the generated row matches the grid width.
    # Padding by the number of generated frames only fits when K == T.
    pad_shape = list(samples.shape)
    pad_shape[3] = max(K + T - samples.shape[3], 0)
    padded = np.concatenate(
        (np.zeros(pad_shape, dtype=samples.dtype), samples), axis=3)
    # Move time to the front: (h, w, t, c) -> (t, h, w, c), one cell per frame.
    generated = padded[0].swapaxes(0, 2).swapaxes(1, 2)
    truth = np.asarray(seq_batch)[0].swapaxes(0, 2).swapaxes(1, 2)
    frames = np.concatenate((generated, truth), axis=0)
    # Reverse the channel axis before saving (likely a BGR <-> RGB swap).
    return frames[:, :, :, ::-1]


def main(name, lr, batch_size, alpha, beta, image_size, K, T, num_iter, gpu,
         nonlinearity, samples_every, gdl, channels, dataset, residual,
         planes):
    """Train MCNET adversarially on ``dataset`` for ``num_iter`` iterations.

    Args:
        name: optional run name; when non-empty it is appended to the run
            prefix that names the model/sample/log directories.
        lr: Adam learning rate for both the discriminator and generator.
        batch_size: training batch size (the loader drops the last partial
            batch).
        alpha: weight of ``model.L_img`` in the generator objective.
        beta: weight of ``model.L_GAN`` in the generator objective.
        image_size: side length of the square frames.
        K: number of context frames.
        T: number of frames to predict.
        num_iter: number of training iterations to run.
        gpu: sequence of GPU ids; only ``gpu[0]`` is used.
        nonlinearity, gdl, residual, planes: forwarded to ``MCNET``
            (``gdl`` as ``gdl_weight``).
        samples_every: a sample image is written whenever the step counter
            is a multiple of this value.
        channels: number of image channels.
        dataset: dataset identifier passed to ``load_dataset``.

    Side effects:
        Creates ``../models/<prefix>/``, ``../samples/<prefix>/`` and
        ``../logs/<prefix>/``. Writes TensorBoard summaries and sample
        images, and saves a checkpoint whenever ``counter % 500 == 2``.
    """
    margin = 0.3  # threshold for pausing D / G updates (see loop below)
    updateD = True
    updateG = True
    iters = 0
    namestr = name if len(name) == 0 else "_" + name
    prefix = (dataset.replace('/', '-') + namestr + "_channels=" +
              str(channels) + "_alpha=" + str(alpha) + "_planes=" +
              str(planes) + "_beta=" + str(beta) + "_lr=" + str(lr) +
              "_nonlin=" + str(nonlinearity) + "_res=" + str(residual) +
              "_gdl=" + str(gdl))

    print("\n" + prefix + "\n")
    checkpoint_dir = "../models/" + prefix + "/"
    samples_dir = "../samples/" + prefix + "/"
    summary_dir = "../logs/" + prefix + "/"

    normalizer = partial(normalize_data, dataset, channels, K)
    train_data, test_data, num_workers = load_dataset(dataset,
                                                      T + K,
                                                      image_size,
                                                      channels,
                                                      transforms=[normalizer])

    train_loader = DataLoader(train_data,
                              num_workers=num_workers,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              pin_memory=True)

    def get_training_batch():
        # Re-iterate the loader forever so training is bounded by num_iter,
        # not by the number of epochs.
        while True:
            for sequence in train_loader:
                yield sequence

    training_batch_generator = get_training_batch()

    for directory in (checkpoint_dir, samples_dir, summary_dir):
        if not exists(directory):
            makedirs(directory)

    with tf.device("/gpu:%d" % gpu[0]):
        model = MCNET(image_size=[image_size, image_size],
                      c_dim=channels,
                      K=K,
                      batch_size=batch_size,
                      T=T,
                      checkpoint_dir=checkpoint_dir,
                      nonlinearity=nonlinearity,
                      gdl_weight=gdl,
                      residual=residual,
                      planes=planes)
        d_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
            model.d_loss, var_list=model.d_vars)
        g_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
            alpha * model.L_img + beta * model.L_GAN,
            var_list=model.g_vars)
        print("GDL: ", model.gdl_weight)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=True,
                            intra_op_parallelism_threads=3,
                            inter_op_parallelism_threads=3,
                            gpu_options=gpu_options)

    # The context manager makes `sess` the default session and guarantees it
    # is closed. Without a default session, `Operation.run()` and
    # `Tensor.eval()` with no explicit session raise.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        if model.load(sess, checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        g_sum = tf.summary.merge(
            [model.L_p_sum, model.L_gdl_sum, model.loss_sum,
             model.L_GAN_sum])
        d_sum = tf.summary.merge(
            [model.d_loss_real_sum, model.d_loss_sum, model.d_loss_fake_sum])
        writer = tf.summary.FileWriter(summary_dir, sess.graph)

        counter = iters + 1
        start_time = time.time()

        try:
            while iters < num_iter:
                seq_batch, diff_batch = next(training_batch_generator)
                seq_batch = seq_batch.numpy()
                diff_batch = diff_batch.numpy()
                feed = {
                    model.diff_in: diff_batch,
                    model.xt: seq_batch[:, :, :, K - 1],
                    model.target: seq_batch,
                }

                if updateD:
                    _, summary_str = sess.run([d_optim, d_sum],
                                              feed_dict=feed)
                    writer.add_summary(summary_str, counter)

                if updateG:
                    _, summary_str = sess.run([g_optim, g_sum],
                                              feed_dict=feed)
                    writer.add_summary(summary_str, counter)

                # Fetch all four diagnostics in one graph execution.
                errD_fake, errD_real, errG, errImage = sess.run(
                    [model.d_loss_fake, model.d_loss_real, model.L_GAN,
                     model.L_img],
                    feed_dict=feed)

                # Update balancing:
                # - D is paused once either discriminator loss drops below
                #   `margin`.
                # - G is paused once either loss exceeds 1 - margin.
                # - If both would be paused, both are re-enabled.
                if errD_fake < margin or errD_real < margin:
                    updateD = False
                if errD_fake > (1. - margin) or errD_real > (1. - margin):
                    updateG = False
                if not updateD and not updateG:
                    updateD = True
                    updateG = True

                counter += 1

                print(
                    "Iters: [%2d] time: %4.4f, d_loss: %.8f, L_GAN: %.8f, L_img: %.8f"
                    % (iters, time.time() - start_time, errD_fake + errD_real,
                       errG, errImage))

                if counter % samples_every == 0:
                    samples = sess.run(model.G, feed_dict=feed)
                    print("Saving sample ...")
                    save_images(_compose_sample_grid(samples, seq_batch, K, T),
                                [2, T + K],
                                samples_dir + "train_%s.png" % (iters))
                if counter % 500 == 2:
                    model.save(sess, checkpoint_dir, counter)

                iters += 1
        finally:
            # Flush pending summaries even if training is interrupted.
            writer.close()