Example #1
def main(lr, batch_size, image_size, K, T, num_iter, gpu):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu[0])

    data_path = "../data/vimeo_interp_test/target/"
    f = open(data_path + "tri_testlist.txt", "r")
    trainfiles = [l[:-1] for l in f.readlines()]
    f.close()
    margin = 0.3
    updateD = False
    updateG = True
    iters = 0
    prefix = ("VIMEO_MCNET" + "_image_size=" + str(image_size) + "_K=" +
              str(K) + "_T=" + str(T) + "_batch_size=" + str(batch_size) +
              "_lr=" + str(lr))

    print("\n" + prefix + "\n")
    checkpoint_dir = "../models/" + prefix + "/"
    samples_dir = "../samples/" + prefix + "/"
    summary_dir = "../logs/" + prefix + "/"

    if not exists(checkpoint_dir):
        makedirs(checkpoint_dir)
    if not exists(samples_dir):
        makedirs(samples_dir)
    if not exists(summary_dir):
        makedirs(summary_dir)

    with tf.device("/gpu:%d" % gpu[0]):
        model = MCNET(image_size=[image_size, image_size],
                      c_dim=3,
                      K=K,
                      batch_size=batch_size,
                      T=T,
                      checkpoint_dir=checkpoint_dir)

        g_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
            model.L_img, var_list=model.g_vars)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=gpu_options)) as sess:

        tf.global_variables_initializer().run()

        if model.load(sess, checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        g_sum = tf.summary.merge(
            [model.L_p_sum, model.L_gdl_sum, model.loss_sum])

        writer = tf.summary.FileWriter(summary_dir, sess.graph)

        counter = iters + 1
        start_time = time.time()
        while iters < num_iter:
            mini_batches = get_minibatches_idx(len(trainfiles),
                                               batch_size,
                                               shuffle=True)
            for _, batchidx in mini_batches:
                # Skip the last, incomplete mini-batch so the placeholders
                # always receive exactly batch_size samples.
                if len(batchidx) != batch_size:
                    continue
                seq_batch = np.zeros(
                    (batch_size, image_size, image_size, K + T, 3),
                    dtype="float32")
                diff_batch = np.zeros(
                    (batch_size, image_size, image_size, K - 1, 3),
                    dtype="float32")
                t0 = time.time()
                Ts = np.repeat(np.array([T]), batch_size, axis=0)
                Ks = np.repeat(np.array([K]), batch_size, axis=0)
                paths = np.repeat(data_path, batch_size, axis=0)
                tfiles = np.array(trainfiles)[batchidx]
                shapes = np.repeat(np.array([image_size]),
                                   batch_size,
                                   axis=0)
                output = [
                    load_vimeo_data(f, p, img_sze, k, t)
                    for f, p, img_sze, k, t in zip(tfiles, paths, shapes, Ks, Ts)
                ]

                for i in range(batch_size):
                    seq_batch[i] = output[i][0]
                    diff_batch[i] = output[i][1]

                _, summary_str = sess.run(
                    [g_optim, g_sum],
                    feed_dict={
                        model.diff_in: diff_batch,
                        model.xt: seq_batch[:, :, :, K - 1],
                        model.target: seq_batch
                    })
                writer.add_summary(summary_str, counter)

                errL_img = model.L_img.eval({
                    model.diff_in: diff_batch,
                    model.xt: seq_batch[:, :, :, K - 1],
                    model.target: seq_batch
                })
                counter += 1
                if counter % 50 == 0:
                    print("Iters: [%2d] time: %4.4f, L_img: %.8f" %
                          (iters, time.time() - start_time, errL_img))

                if np.mod(counter, 10) == 1:
                    samples = sess.run(
                        [model.G],
                        feed_dict={
                            model.diff_in: diff_batch,
                            model.xt: seq_batch[:, :, :, K - 1],
                            model.target: seq_batch
                        })[0]
                    # IPython.embed()
                    samples = samples[0].swapaxes(0, 2).swapaxes(1, 2)
                    # IPython.embed()

                    sbatch = seq_batch[0, :, :, :].swapaxes(0,
                                                            2).swapaxes(1, 2)

                    sbatch2 = sbatch.copy()
                    # IPython.embed()
                    sbatch2[K:, :, :] = samples
                    # IPython.embed()
                    samples = np.concatenate((sbatch2, sbatch), axis=0)
                    # IPython.embed()
                    print("Saving sample ...")
                    save_images(samples, [2, K + T],
                                samples_dir + "train_%s.png" % (iters))
                if np.mod(counter, 10000) == 2:
                    model.save(sess, checkpoint_dir, counter)

                iters += 1
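
The listing starts at main() and omits the file's import preamble. Below is a minimal sketch of the imports these examples appear to rely on; the standard-library, NumPy, TensorFlow 1.x, and joblib imports are certain, while the module paths for the project-local helpers (MCNET, load_vimeo_data, load_kth_data, get_minibatches_idx, save_images) are an assumption about the repository layout.

# Assumed import preamble -- module paths for the project-local helpers are a guess.
import os
import time
from os import makedirs
from os.path import exists

import numpy as np
import tensorflow as tf                  # TensorFlow 1.x API
from joblib import Parallel, delayed     # used by Examples #2 and #3

from mcnet import MCNET                                 # model definition (project-local)
from utils import get_minibatches_idx, save_images      # batching / visualization helpers
from load_data import load_vimeo_data, load_kth_data    # per-sample loaders (assumed module)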
Example #2
def main(lr, batch_size, alpha, beta, image_size_h, image_size_w, K,
         T, num_iter, gpu):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu[0])
    
    data_path = "/data1/ikusyou/vimeo_septuplet/sequences/"
    f = open(data_path+"sep_trainlist.txt","r")
    trainfiles = [l[:-1] for l in f.readlines()]
    f.close()

    # new part for validation files, @@@ please fill the path
    val_path = ""
    f = open(val_path+"", "r")
    val_files = [l[:-1] for l in f.readlines()]
    f.close()
    val_num = len(val_files)
    val_seq = np.zeros((val_num, 240, 416, K+T, 3), dtype="float32")
    val_diff = np.zeros((val_num, 240, 416, K-1, 3), dtype="float32")

    # Read the whole validation set into memory once, up front.
    for idx, f_name in enumerate(val_files):
        output = load_vimeo_data(f_name, val_path, [240, 416], K, T)
        val_seq[idx] = output[0]
        val_diff[idx] = output[1]
    # Done reading the validation data.


    margin = 0.3 
    updateD = False
    updateG = True
    iters = 0
    prefix  = ("VIMEO_MCNET_V1.1.1"
            + "_image_size_h="+str(image_size_h)
            + "_image_size_w="+str(image_size_w)
            + "_K="+str(K)
            + "_T="+str(T)
            + "_batch_size="+str(batch_size)
            + "_lr="+str(lr))

    print("\n"+prefix+"\n")
    checkpoint_dir = "../models/"+prefix+"/"
    samples_dir = "../samples/"+prefix+"/"
    summary_dir = "../logs/"+prefix+"/"

    if not exists(checkpoint_dir):
        makedirs(checkpoint_dir)
    if not exists(samples_dir):
        makedirs(samples_dir)
    if not exists(summary_dir):
        makedirs(summary_dir)
    
    with tf.device("/gpu:%d"%gpu[0]):
        model = MCNET(image_size=[image_size_h,image_size_w], c_dim=3,
                    K=K, batch_size=batch_size, T=T,
                    checkpoint_dir=checkpoint_dir)
        # d_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
        #     model.d_loss, var_list=model.d_vars
        # )
        g_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
            alpha*model.L_img, var_list=model.g_vars
        )

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                    log_device_placement=False,
                    gpu_options=gpu_options)) as sess:

        tf.global_variables_initializer().run()

        if model.load(sess, checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        g_sum = tf.summary.merge([model.L_p_sum,
                                model.L_gdl_sum, model.loss_sum])
        # d_sum = tf.summary.merge([model.d_loss_real_sum, model.d_loss_sum,
        #                         model.d_loss_fake_sum])
        writer = tf.summary.FileWriter(summary_dir, sess.graph)

        counter = iters+1
        start_time = time.time()
        # IPython.embed()
        with Parallel(n_jobs=batch_size) as parallel:
            while iters < num_iter:
                mini_batches = get_minibatches_idx(len(trainfiles), batch_size, shuffle=True)
                for _, batchidx in mini_batches:
                    # Skip the last, incomplete mini-batch so the placeholders
                    # always receive exactly batch_size samples.
                    if len(batchidx) != batch_size:
                        continue
                    seq_batch = np.zeros((batch_size, image_size_h, image_size_w,
                                          K+T, 3), dtype="float32")
                    diff_batch = np.zeros((batch_size, image_size_h, image_size_w,
                                           K-1, 3), dtype="float32")
                    t0 = time.time()
                    Ts = np.repeat(np.array([T]), batch_size, axis=0)
                    Ks = np.repeat(np.array([K]), batch_size, axis=0)
                    paths = np.repeat(data_path, batch_size, axis=0)
                    tfiles = np.array(trainfiles)[batchidx]
                    shapes = np.repeat(np.array([image_size_h]), batch_size, axis=0)
                    output = parallel(delayed(load_vimeo_data)(f, p, img_sze, k, t)
                                      for f, p, img_sze, k, t in zip(tfiles, paths,
                                                                     shapes, Ks, Ts))

                    for i in range(batch_size):
                        seq_batch[i] = output[i][0]
                        diff_batch[i] = output[i][1]


                    _, summary_str = sess.run([g_optim, g_sum],
                                                feed_dict={model.diff_in: diff_batch,
                                                        model.xt: seq_batch[:,:,:,K-1],
                                                        model.target: seq_batch})

                    
                    writer.add_summary(summary_str, counter)
                    
                    errL_img = model.L_img.eval({model.diff_in: diff_batch,
                                                model.xt: seq_batch[:,:,:,K-1],
                                                model.target: seq_batch})



                    counter += 1
                    if counter % 50 == 0:
                        print(
                            "Iters: [%2d] time: %4.4f, img_loss: %.8f"
                            % (iters, time.time() - start_time, errL_img)
                        )
                        val_loss = []
                        # Use separate variable names so the training batch is
                        # not overwritten; [:-1] skips the last, possibly
                        # incomplete slice to avoid indexing past the end.
                        for idx in range(0, val_num, batch_size)[:-1]:
                            val_diff_batch = val_diff[idx:idx + batch_size]
                            val_seq_batch = val_seq[idx:idx + batch_size]
                            val_err = model.L_img.eval({model.diff_in: val_diff_batch,
                                                        model.xt: val_seq_batch[:,:,:,K-1],
                                                        model.target: val_seq_batch})
                            val_loss.append(np.mean(val_err))
                        print("Validation img_loss: %.8f" % np.mean(val_loss))



                    if np.mod(counter, 200) == 1:
                        samples = sess.run([model.G],
                                            feed_dict={model.diff_in: diff_batch,
                                                        model.xt: seq_batch[:,:,:,K-1],
                                                        model.target: seq_batch})[0]
                        # IPython.embed()
                        samples = samples[0].swapaxes(0,2).swapaxes(1,2)
                        # IPython.embed()

                        sbatch  = seq_batch[0,:,:,:].swapaxes(0,2).swapaxes(1,2)
                        
                        sbatch2 = sbatch.copy()
                        # IPython.embed()
                        sbatch2[K:,:,:] = samples
                        # IPython.embed()
                        samples = np.concatenate((sbatch2,sbatch), axis=0)
                        # IPython.embed()
                        print("Saving sample ...")
                        save_images(samples, [2, K+T], 
                                    samples_dir+"train_%s.png" % (iters))
                    if np.mod(counter, 10000) == 2:
                        model.save(sess, checkpoint_dir, counter)

                    iters += 1
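
Example #2 fetches each mini-batch with joblib's Parallel/delayed instead of the plain list comprehension used in Example #1. The call shape is generic and worth isolating; the sketch below uses a dummy loader purely to illustrate it (load_one and the path strings are placeholders, not part of the original code).

from joblib import Parallel, delayed

def load_one(path):
    # stand-in for an expensive per-sample loader such as load_vimeo_data
    return len(path)

paths = ["clip_%04d" % i for i in range(8)]

# One worker per sample, mirroring Parallel(n_jobs=batch_size) above.
with Parallel(n_jobs=len(paths)) as parallel:
    batch = parallel(delayed(load_one)(p) for p in paths)
print(batch)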
Example #3
def main(lr, batch_size, alpha, beta, image_size, K, T, num_iter, gpu):
    data_path = "../data/KTH/"
    f = open(data_path + "train_data_list_trimmed.txt", "r")
    trainfiles = f.readlines()
    f.close()
    margin = 0.3
    updateD = True
    updateG = True
    iters = 0
    prefix = ("KTH_MCNET" + "_image_size=" + str(image_size) + "_K=" + str(K) +
              "_T=" + str(T) + "_batch_size=" + str(batch_size) + "_alpha=" +
              str(alpha) + "_beta=" + str(beta) + "_lr=" + str(lr))

    print("\n" + prefix + "\n")
    checkpoint_dir = "../models/" + prefix + "/"
    samples_dir = "../samples/" + prefix + "/"
    summary_dir = "../logs/" + prefix + "/"

    if not exists(checkpoint_dir):
        makedirs(checkpoint_dir)
    if not exists(samples_dir):
        makedirs(samples_dir)
    if not exists(summary_dir):
        makedirs(summary_dir)

    with tf.device("/gpu:%d" % gpu[0]):
        model = MCNET(image_size=[image_size, image_size],
                      c_dim=1,
                      K=K,
                      batch_size=batch_size,
                      T=T,
                      checkpoint_dir=checkpoint_dir)
        d_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
            model.d_loss, var_list=model.d_vars)
        g_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
            alpha * model.L_img + beta * model.L_GAN, var_list=model.g_vars)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=gpu_options)) as sess:

        tf.global_variables_initializer().run()

        if model.load(sess, checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        g_sum = tf.summary.merge(
            [model.L_p_sum, model.L_gdl_sum, model.loss_sum, model.L_GAN_sum])
        d_sum = tf.summary.merge(
            [model.d_loss_real_sum, model.d_loss_sum, model.d_loss_fake_sum])
        writer = tf.summary.FileWriter(summary_dir, sess.graph)

        counter = iters + 1
        start_time = time.time()

        with Parallel(n_jobs=batch_size) as parallel:
            while iters < num_iter:
                mini_batches = get_minibatches_idx(len(trainfiles),
                                                   batch_size,
                                                   shuffle=True)
                for _, batchidx in mini_batches:
                    if len(batchidx) == batch_size:
                        seq_batch = np.zeros(
                            (batch_size, image_size, image_size, K + T, 1),
                            dtype="float32")
                        diff_batch = np.zeros(
                            (batch_size, image_size, image_size, K - 1, 1),
                            dtype="float32")
                        t0 = time.time()
                        Ts = np.repeat(np.array([T]), batch_size, axis=0)
                        Ks = np.repeat(np.array([K]), batch_size, axis=0)
                        paths = np.repeat(data_path, batch_size, axis=0)
                        tfiles = np.array(trainfiles)[batchidx]
                        shapes = np.repeat(np.array([image_size]),
                                           batch_size,
                                           axis=0)
                        output = parallel(
                            delayed(load_kth_data)(f, p, img_sze, k, t)
                            for f, p, img_sze, k, t in zip(
                                tfiles, paths, shapes, Ks, Ts))
                        for i in range(batch_size):
                            seq_batch[i] = output[i][0]
                            diff_batch[i] = output[i][1]

                        if updateD:
                            _, summary_str = sess.run(
                                [d_optim, d_sum],
                                feed_dict={
                                    model.diff_in: diff_batch,
                                    model.xt: seq_batch[:, :, :, K - 1],
                                    model.target: seq_batch
                                })
                            writer.add_summary(summary_str, counter)

                        if updateG:
                            _, summary_str = sess.run(
                                [g_optim, g_sum],
                                feed_dict={
                                    model.diff_in: diff_batch,
                                    model.xt: seq_batch[:, :, :, K - 1],
                                    model.target: seq_batch
                                })
                            writer.add_summary(summary_str, counter)

                        errD_fake = model.d_loss_fake.eval({
                            model.diff_in: diff_batch,
                            model.xt: seq_batch[:, :, :, K - 1],
                            model.target: seq_batch
                        })
                        errD_real = model.d_loss_real.eval({
                            model.diff_in: diff_batch,
                            model.xt: seq_batch[:, :, :, K - 1],
                            model.target: seq_batch
                        })
                        errG = model.L_GAN.eval({
                            model.diff_in: diff_batch,
                            model.xt: seq_batch[:, :, :, K - 1],
                            model.target: seq_batch
                        })

                        if errD_fake < margin or errD_real < margin:
                            updateD = False
                        if errD_fake > (1. - margin) or errD_real > (1. -
                                                                     margin):
                            updateG = False
                        if not updateD and not updateG:
                            updateD = True
                            updateG = True

                        counter += 1

                        print(
                            "Iters: [%2d] time: %4.4f, d_loss: %.8f, L_GAN: %.8f"
                            % (iters, time.time() - start_time,
                               errD_fake + errD_real, errG))

                        if np.mod(counter, 100) == 1:
                            samples = sess.run(
                                [model.G],
                                feed_dict={
                                    model.diff_in: diff_batch,
                                    model.xt: seq_batch[:, :, :, K - 1],
                                    model.target: seq_batch
                                })[0]
                            samples = samples[0].swapaxes(0, 2).swapaxes(1, 2)
                            sbatch = seq_batch[0, :, :,
                                               K:].swapaxes(0,
                                                            2).swapaxes(1, 2)
                            samples = np.concatenate((samples, sbatch), axis=0)
                            print("Saving sample ...")
                            save_images(samples[:, :, :, ::-1], [2, T],
                                        samples_dir + "train_%s.png" % (iters))
                        if np.mod(counter, 500) == 2:
                            model.save(sess, checkpoint_dir, counter)

                        iters += 1
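
Example #3 (and Example #4 below) balances the discriminator and generator with a margin rule: freeze D when it is winning too easily, freeze G when D is losing badly, and re-enable both if neither may train. A small standalone sketch of that rule, with a helper name and return convention of my own choosing rather than the source's:

def balance_updates(err_d_fake, err_d_real, update_d, update_g, margin=0.3):
    # Discriminator is too strong: stop updating it so the generator can catch up.
    if err_d_fake < margin or err_d_real < margin:
        update_d = False
    # Discriminator is too weak: stop updating the generator instead.
    if err_d_fake > (1. - margin) or err_d_real > (1. - margin):
        update_g = False
    # Deadlock: train both again.
    if not update_d and not update_g:
        update_d, update_g = True, True
    return update_d, update_g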
Example #4
def main(name, lr, batch_size, alpha, beta, image_size, K, T, num_iter, gpu,
         nonlinearity, samples_every, gdl, channels, dataset, residual,
         planes):
    margin = 0.3
    updateD = True
    updateG = True
    iters = 0
    namestr = name if len(name) == 0 else "_" + name
    prefix = (dataset.replace('/', '-') + namestr + "_channels=" +
              str(channels) + "_alpha=" + str(alpha) + "_planes=" +
              str(planes) + "_beta=" + str(beta) + "_lr=" + str(lr) +
              "_nonlin=" + str(nonlinearity) + "_res=" + str(residual) +
              "_gdl=" + str(gdl))

    print("\n" + prefix + "\n")
    checkpoint_dir = "../models/" + prefix + "/"
    samples_dir = "../samples/" + prefix + "/"
    summary_dir = "../logs/" + prefix + "/"

    normalizer = partial(normalize_data, dataset, channels, K)
    train_data, test_data, num_workers = load_dataset(dataset,
                                                      T + K,
                                                      image_size,
                                                      channels,
                                                      transforms=[normalizer])

    train_loader = DataLoader(train_data,
                              num_workers=num_workers,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              pin_memory=True)

    def get_training_batch():
        while True:
            for sequence in train_loader:
                yield sequence

    training_batch_generator = get_training_batch()

    checkpoint_dir = "../models/" + prefix + "/"
    samples_dir = "../samples/" + prefix + "/"
    summary_dir = "../logs/" + prefix + "/"

    if not exists(checkpoint_dir):
        makedirs(checkpoint_dir)
    if not exists(samples_dir):
        makedirs(samples_dir)
    if not exists(summary_dir):
        makedirs(summary_dir)

    with tf.device("/gpu:%d" % gpu[0]):
        model = MCNET(image_size=[image_size, image_size],
                      c_dim=channels,
                      K=K,
                      batch_size=batch_size,
                      T=T,
                      checkpoint_dir=checkpoint_dir,
                      nonlinearity=nonlinearity,
                      gdl_weight=gdl,
                      residual=residual,
                      planes=planes)
        d_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
            model.d_loss, var_list=model.d_vars)
        # facd_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
        #     model.facd_loss, var_list=model.facd_vars
        # )
        g_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(
            # alpha * model.L_img + beta * model.L_GAN + gamma * model.L_FAC,
            alpha * model.L_img + beta * model.L_GAN,
            var_list=model.g_vars)
        print("GDL: ", model.gdl_weight)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
    # InteractiveSession installs itself as the default session, so the
    # initializer .run() and the .eval() calls below work without passing
    # the session explicitly.
    sess = tf.InteractiveSession(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=True,
        intra_op_parallelism_threads=3,
        inter_op_parallelism_threads=3,
        gpu_options=gpu_options))

    tf.global_variables_initializer().run()

    if model.load(sess, checkpoint_dir):
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")

    g_sum = tf.summary.merge(
        [model.L_p_sum, model.L_gdl_sum, model.loss_sum, model.L_GAN_sum])
    d_sum = tf.summary.merge(
        [model.d_loss_real_sum, model.d_loss_sum, model.d_loss_fake_sum])
    writer = tf.summary.FileWriter(summary_dir, sess.graph)

    counter = iters + 1
    start_time = time.time()

    with Parallel(n_jobs=batch_size) as parallel:
        while iters < num_iter:
            for batch_index in range(100000):
                seq_batch, diff_batch = next(training_batch_generator)
                # show(seq_batch[0])
                # show(diff_batch[0])
                # ipdb.set_trace()
                seq_batch = seq_batch.numpy()
                diff_batch = diff_batch.numpy()

                if updateD:
                    _, summary_str = sess.run(
                        [d_optim, d_sum],
                        feed_dict={
                            model.diff_in: diff_batch,
                            model.xt: seq_batch[:, :, :, K - 1],
                            model.target: seq_batch
                        })
                    writer.add_summary(summary_str, counter)

                if updateG:
                    _, summary_str = sess.run(
                        [g_optim, g_sum],
                        feed_dict={
                            model.diff_in: diff_batch,
                            model.xt: seq_batch[:, :, :, K - 1],
                            model.target: seq_batch
                        })
                    writer.add_summary(summary_str, counter)

                errD_fake = model.d_loss_fake.eval({
                    model.diff_in: diff_batch,
                    model.xt: seq_batch[:, :, :, K - 1],
                    model.target: seq_batch
                })
                errD_real = model.d_loss_real.eval({
                    model.diff_in: diff_batch,
                    model.xt: seq_batch[:, :, :, K - 1],
                    model.target: seq_batch
                })
                errG = model.L_GAN.eval({
                    model.diff_in: diff_batch,
                    model.xt: seq_batch[:, :, :, K - 1],
                    model.target: seq_batch
                })
                errImage = model.L_img.eval({
                    model.diff_in: diff_batch,
                    model.xt: seq_batch[:, :, :, K - 1],
                    model.target: seq_batch
                })

                if errD_fake < margin or errD_real < margin:
                    updateD = False
                if errD_fake > (1. - margin) or errD_real > (1. - margin):
                    updateG = False
                if not updateD and not updateG:
                    updateD = True
                    updateG = True

                counter += 1

                print(
                    "Iters: [%2d] time: %4.4f, d_loss: %.8f, L_GAN: %.8f, L_img: %.8f"
                    % (iters, time.time() - start_time, errD_fake + errD_real,
                       errG, errImage))

                if (counter % samples_every) == 0:
                    samples = sess.run(
                        [model.G],
                        feed_dict={
                            model.diff_in: diff_batch,
                            model.xt: seq_batch[:, :, :, K - 1],
                            model.target: seq_batch
                        })[0]
                    # ipdb.set_trace()
                    samples_pad = np.array(samples)
                    samples_pad.fill(0)
                    generations = np.concatenate((samples_pad, samples), 3)
                    generations = generations[0].swapaxes(0, 2).swapaxes(1, 2)
                    sbatch = seq_batch[0].swapaxes(0, 2).swapaxes(1, 2)
                    generations = np.concatenate((generations, sbatch), axis=0)
                    print("Saving sample ...")
                    # ipdb.set_trace()
                    save_images(generations[:, :, :, ::-1], [2, T + K],
                                samples_dir + "train_%s.png" % (iters))
                if np.mod(counter, 500) == 2:
                    model.save(sess, checkpoint_dir, counter)

                iters += 1
    sess.close()
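
All four examples only define main(); in the original repositories it is presumably driven by an argparse block. A hedged sketch of such a driver for Example #1's signature follows; the flag names and defaults are illustrative guesses, not the actual CLI of the source.

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--K", type=int, default=1)
    parser.add_argument("--T", type=int, default=1)
    parser.add_argument("--num_iter", type=int, default=100000)
    parser.add_argument("--gpu", type=int, nargs="+", default=[0])
    args = parser.parse_args()

    main(args.lr, args.batch_size, args.image_size,
         args.K, args.T, args.num_iter, args.gpu)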