Example #1
def test_net(model,
             img_dir,
             max_iter=1000000,
             check_every_n=500,
             loss_check_n=10,
             save_model_freq=1000,
             batch_size=128):
    img1 = U.get_placeholder_cached(name="img1")
    img2 = U.get_placeholder_cached(name="img2")

    # Testing
    img_test = U.get_placeholder_cached(name="img_test")
    reconst_tp = U.get_placeholder_cached(name="reconst_tp")

    vae_loss = U.mean(model.vaeloss)

    latent_z1_tp = model.latent_z1
    latent_z2_tp = model.latent_z2

    losses = [
        U.mean(model.vaeloss),
        U.mean(model.siam_loss),
        U.mean(model.kl_loss1),
        U.mean(model.kl_loss2),
        U.mean(model.reconst_error1),
        U.mean(model.reconst_error2),
    ]

    tf.summary.scalar('Total Loss', losses[0])
    tf.summary.scalar('Siam Loss', losses[1])
    tf.summary.scalar('kl1_loss', losses[2])
    tf.summary.scalar('kl2_loss', losses[3])
    tf.summary.scalar('reconst_err1', losses[4])
    tf.summary.scalar('reconst_err2', losses[5])

    decoded_img = [model.reconst1, model.reconst2]

    weight_loss = [1, 1, 1]

    compute_losses = U.function([img1, img2], vae_loss)
    lr = 0.00005
    optimizer = tf.train.AdamOptimizer(learning_rate=lr,
                                       epsilon=0.01 / batch_size)

    all_var_list = model.get_trainable_variables()

    # print all_var_list
    # Train all variables; the commented filter below previously restricted training
    # to the proj1/unproj1 sub-networks:
    # [v for v in all_var_list if v.name.split("/")[1].startswith("proj1") or v.name.split("/")[1].startswith("unproj1")]
    img1_var_list = all_var_list
    optimize_expr1 = optimizer.minimize(vae_loss, var_list=img1_var_list)
    merged = tf.summary.merge_all()
    train = U.function([img1, img2], [
        losses[0], losses[1], losses[2], losses[3], losses[4], losses[5],
        latent_z1_tp, latent_z2_tp, merged
    ],
                       updates=[optimize_expr1])
    get_reconst_img = U.function(
        [img1, img2],
        [model.reconst1, model.reconst2, latent_z1_tp, latent_z2_tp])
    get_latent_var = U.function([img1, img2], [latent_z1_tp, latent_z2_tp])

    # [testing -> ]
    test = U.function([img_test], model.latent_z_test)
    test_reconst = U.function([reconst_tp], [model.reconst_test])
    # [testing <- ]

    cur_dir = get_cur_dir()
    chk_save_dir = os.path.join(cur_dir, "chk1")
    log_save_dir = os.path.join(cur_dir, "log")
    validate_img_saver_dir = os.path.join(cur_dir, "validate_images")
    test_img_saver_dir = os.path.join(cur_dir, "test_images")
    testing_img_dir = os.path.join(cur_dir, "dataset/test_img")

    train_writer = U.summary_writer(dir=log_save_dir)

    U.initialize()

    saver, chk_file_num = U.load_checkpoints(load_requested=True,
                                             checkpoint_dir=chk_save_dir)
    validate_img_saver = Img_Saver(validate_img_saver_dir)

    # [testing -> ]
    test_img_saver = Img_Saver(test_img_saver_dir)
    # [testing <- ]

    meta_saved = False

    iter_log = []
    loss1_log = []
    loss2_log = []

    loss3_log = []

    training_images_list = read_dataset(img_dir)
    n_total_train_data = len(training_images_list)

    testing_images_list = read_dataset(testing_img_dir)
    n_total_testing_data = len(testing_images_list)

    training = False
    testing = True

    # if training == True:
    # 	for num_iter in range(chk_file_num+1, max_iter):
    # 		header("******* {}th iter: *******".format(num_iter))

    # 		idx = random.sample(range(n_total_train_data), 2*batch_size)
    # 		batch_files = [training_images_list[i] for i in idx]
    # 		# print batch_files
    # 		[images1, images2] = load_image(dir_name = img_dir, img_names = batch_files)
    # 		img1, img2 = images1, images2
    # 		[l1, l2, _, _] = get_reconst_img(img1, img2)

    # 		[loss0, loss1, loss2, loss3, loss4, loss5, latent1, latent2, summary] = train(img1, img2)

    # 		warn("Total Loss: {}".format(loss0))
    # 		warn("Siam loss: {}".format(loss1))
    # 		warn("kl1_loss: {}".format(loss2))
    # 		warn("kl2_loss: {}".format(loss3))
    # 		warn("reconst_err1: {}".format(loss4))
    # 		warn("reconst_err2: {}".format(loss5))

    # 		# warn("num_iter: {} check: {}".format(num_iter, check_every_n))
    # 		# warn("Total Loss: {}".format(loss6))
    # 		if num_iter % check_every_n == 1:
    # 			header("******* {}th iter: *******".format(num_iter))
    # 			idx = random.sample(range(len(training_images_list)), 2*5)
    # 			validate_batch_files = [training_images_list[i] for i in idx]
    # 			[images1, images2] = load_image(dir_name = img_dir, img_names = validate_batch_files)
    # 			[reconst1, reconst2, _, _] = get_reconst_img(images1, images2)
    # 			# for i in range(len(latent1[0])):
    # 			# 	print "{} th: {:.2f}".format(i, np.mean(np.abs(latent1[:, i] - latent2[:, i])))
    # 			for img_idx in range(len(images1)):
    # 				sub_dir = "iter_{}".format(num_iter)

    # 				save_img = np.squeeze(images1[img_idx])
    # 				save_img = Image.fromarray(save_img)
    # 				img_file_name = "{}_ori.png".format(validate_batch_files[img_idx].split('.')[0])
    # 				validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

    # 				save_img = np.squeeze(reconst1[img_idx])
    # 				save_img = Image.fromarray(save_img)
    # 				img_file_name = "{}_rec.png".format(validate_batch_files[img_idx].split('.')[0])
    # 				validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

    # 		if num_iter % loss_check_n == 1:
    # 			train_writer.add_summary(summary, num_iter)

    # 		if num_iter > 11 and num_iter % save_model_freq == 1:
    # 			if meta_saved == True:
    # 				saver.save(U.get_session(), chk_save_dir + '/' + 'checkpoint', global_step = num_iter, write_meta_graph = False)
    # 			else:
    # 				print "Save  meta graph"
    # 				saver.save(U.get_session(), chk_save_dir + '/' + 'checkpoint', global_step = num_iter, write_meta_graph = True)
    # 				meta_saved = True

    # Testing
    print(testing_images_list)
    if testing:
        test_file_name = testing_images_list[6]
        print(test_file_name)
        test_img = load_single_img(dir_name=testing_img_dir,
                                   img_name=test_file_name)
        test_features = np.arange(25, 32)
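        # Latent traversal: sweep each selected latent dimension over a range of offsets
        # and decode, so the visual effect of that single feature can be inspected.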
        for test_feature in test_features:
            test_variation = np.arange(-10, 10, 0.1)

            z = test(test_img)
            print(np.shape(z))
            print(z)
            for idx in range(len(test_variation)):
                z_test = np.copy(z)
                z_test[0, test_feature] += test_variation[idx]
                reconst_test = test_reconst(z_test)
                test_save_img = np.squeeze(reconst_test[0])
                test_save_img = Image.fromarray(test_save_img)
                img_file_name = "test_feat_{}_var_({}).png".format(
                    test_feature, test_variation[idx])
                test_img_saver.save(test_save_img, img_file_name, sub_dir=None)
            reconst_test = test_reconst(z)
            test_save_img = np.squeeze(reconst_test[0])
            test_save_img = Image.fromarray(test_save_img)
            img_file_name = "test_feat_{}_var_original.png".format(
                test_feature)
            test_img_saver.save(test_save_img, img_file_name, sub_dir=None)
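
A minimal call sketch for the snippet above, assuming a model object that exposes the tensors it references (vaeloss, siam_loss, kl_loss1/2, reconst_error1/2, latent_z1/2, latent_z_test, reconst_test); the constructor name and image directory are hypothetical:

# Hypothetical usage sketch; SiamVAE and the directory layout are assumptions.
model = SiamVAE()
img_dir = os.path.join(get_cur_dir(), "dataset/train_img")  # assumed image directory
test_net(model, img_dir, batch_size=128)
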
Example #2
def train_net(model, mode, img_dir, dataset, chkfile_name, logfile_name, validatefile_name,
              entangled_feat, max_epoch=300, check_every_n=500, loss_check_n=10,
              save_model_freq=5, batch_size=512, lr=0.001):
    img1 = U.get_placeholder_cached(name="img1")
    img2 = U.get_placeholder_cached(name="img2")

    vae_loss = U.mean(model.vaeloss)

    latent_z1_tp = model.latent_z1
    latent_z2_tp = model.latent_z2

    losses = [
        U.mean(model.vaeloss),
        U.mean(model.siam_loss),
        U.mean(model.kl_loss1),
        U.mean(model.kl_loss2),
        U.mean(model.reconst_error1),
        U.mean(model.reconst_error2),
    ]

    siam_normal = losses[1]/entangled_feat
    siam_max = U.mean(model.max_siam_loss)

    tf.summary.scalar('Total Loss', losses[0])
    tf.summary.scalar('Siam Loss', losses[1])
    tf.summary.scalar('kl1_loss', losses[2])
    tf.summary.scalar('kl2_loss', losses[3])
    tf.summary.scalar('reconst_err1', losses[4])
    tf.summary.scalar('reconst_err2', losses[5])
    tf.summary.scalar('Siam Normal', siam_normal)
    tf.summary.scalar('Siam Max', siam_max)

    compute_losses = U.function([img1, img2], vae_loss)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=0.01 / batch_size)

    all_var_list = model.get_trainable_variables()


    img1_var_list = all_var_list
    optimize_expr1 = optimizer.minimize(vae_loss, var_list=img1_var_list)
    merged = tf.summary.merge_all()
    train = U.function([img1, img2],
                        [losses[0], losses[1], losses[2], losses[3], losses[4], losses[5], latent_z1_tp, latent_z2_tp, merged], updates = [optimize_expr1])
    get_reconst_img = U.function([img1, img2], [model.reconst1, model.reconst2, latent_z1_tp, latent_z2_tp])
    get_latent_var = U.function([img1, img2], [latent_z1_tp, latent_z2_tp])

    cur_dir = get_cur_dir()
    chk_save_dir = os.path.join(cur_dir, chkfile_name)
    log_save_dir = os.path.join(cur_dir, logfile_name)
    validate_img_saver_dir = os.path.join(cur_dir, validatefile_name)
    if dataset == 'chairs' or dataset == 'celeba':
        test_img_saver_dir = os.path.join(cur_dir, "test_images")
        testing_img_dir = os.path.join(cur_dir, "dataset/{}/test_img".format(dataset))

    train_writer = U.summary_writer(dir = log_save_dir)

    U.initialize()

    saver, chk_file_epoch_num = U.load_checkpoints(load_requested = True, checkpoint_dir = chk_save_dir)
    if dataset == 'chairs' or dataset == 'celeba':
        validate_img_saver = Img_Saver(Img_dir = validate_img_saver_dir)
    elif dataset == 'dsprites':
        validate_img_saver = BW_Img_Saver(Img_dir = validate_img_saver_dir) # Black and White, temporary usage
    else:
        warn("Unknown dataset Error")
        # break

    warn(img_dir)
    if dataset == 'chairs' or dataset == 'celeba':
        training_images_list = read_dataset(img_dir)
        n_total_train_data = len(training_images_list)
        testing_images_list = read_dataset(testing_img_dir)
        n_total_testing_data = len(testing_images_list)
    elif dataset == 'dsprites':
        cur_dir = osp.join(cur_dir, 'dataset')
        cur_dir = osp.join(cur_dir, 'dsprites')
        img_dir = osp.join(cur_dir, 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        manager = DataManager(img_dir, batch_size)
    else:
        warn("Unknown dataset Error")
        # break

    meta_saved = False

    if mode == 'train':
        for epoch_idx in range(chk_file_epoch_num+1, max_epoch):
            t_epoch_start = time.time()
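            # NOTE: manager is only constructed above for the dsprites dataset; for
            # chairs / celeba a batch count would have to be derived from the image list.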
            num_batch = manager.get_len()

            for batch_idx in range(num_batch):
                if dataset == 'chairs' or dataset == 'celeba':
                    idx = random.sample(range(n_total_train_data), 2*batch_size)
                    batch_files = [training_images_list[i] for i in idx]
                    [images1, images2] = load_image(dir_name = img_dir, img_names = batch_files)
                elif dataset == 'dsprites':
                    [images1, images2] = manager.get_next()
                img1, img2 = images1, images2
                [l1, l2, _, _] = get_reconst_img(img1, img2)

                [loss0, loss1, loss2, loss3, loss4, loss5, latent1, latent2, summary] = train(img1, img2)

                if batch_idx % 50 == 1:
                    header("******* epoch: {}/{} batch: {}/{} *******".format(epoch_idx, max_epoch, batch_idx, num_batch))
                    warn("Total Loss: {}".format(loss0))
                    warn("Siam loss: {}".format(loss1))
                    warn("kl1_loss: {}".format(loss2))
                    warn("kl2_loss: {}".format(loss3))
                    warn("reconst_err1: {}".format(loss4))
                    warn("reconst_err2: {}".format(loss5))

                if batch_idx % check_every_n == 1:
                    if dataset == 'chairs' or dataset == 'celeba':
                        idx = random.sample(range(len(training_images_list)), 2*5)
                        validate_batch_files = [training_images_list[i] for i in idx]
                        [images1, images2] = load_image(dir_name = img_dir, img_names = validate_batch_files)
                    elif dataset == 'dsprites':
                        [images1, images2] = manager.get_next()

                    [reconst1, reconst2, _, _] = get_reconst_img(images1, images2)

                    if dataset == 'chairs':
                        for img_idx in range(len(images1)):
                            sub_dir = "iter_{}".format(batch_idx)

                            save_img = np.squeeze(images1[img_idx])
                            save_img = Image.fromarray(save_img)
                            img_file_name = "{}_ori.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

                            save_img = np.squeeze(reconst1[img_idx])
                            save_img = Image.fromarray(save_img)
                            img_file_name = "{}_rec.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                    elif dataset == 'celeba':
                        for img_idx in range(len(images1)):
                            sub_dir = "iter_{}".format(batch_idx)

                            save_img = np.squeeze(images1[img_idx])
                            save_img = Image.fromarray(save_img, 'RGB')
                            img_file_name = "{}_ori.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

                            save_img = np.squeeze(reconst1[img_idx])
                            save_img = Image.fromarray(save_img, 'RGB')
                            img_file_name = "{}_rec.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                    elif dataset == 'dsprites':
                        for img_idx in range(len(images1)):
                            sub_dir = "iter_{}".format(batch_idx)

                            # save_img = images1[img_idx].reshape(64, 64)
                            save_img = np.squeeze(images1[img_idx])
                            save_img = save_img.astype(np.float32)
                            img_file_name = "{}_ori.jpg".format(img_idx)
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

                            # save_img = reconst1[img_idx].reshape(64, 64)
                            save_img = np.squeeze(reconst1[img_idx])
                            save_img = save_img.astype(np.float32)
                            img_file_name = "{}_rec.jpg".format(img_idx)
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

                if batch_idx % loss_check_n == 1:
                    train_writer.add_summary(summary, batch_idx)

            t_epoch_end = time.time()
            t_epoch_run = t_epoch_end - t_epoch_start
            if dataset == 'dsprites':
                t_check = manager.sample_size / t_epoch_run

                warn("==========================================")
                warn("Run {} th epoch in {} sec: {} images / sec".format(epoch_idx+1, t_epoch_run, t_check))
                warn("==========================================")

            # if epoch_idx % save_model_freq == 0:
            if meta_saved:
                saver.save(U.get_session(), os.path.join(chk_save_dir, 'checkpoint'), global_step=epoch_idx, write_meta_graph=False)
            else:
                print("Save meta graph")
                saver.save(U.get_session(), os.path.join(chk_save_dir, 'checkpoint'), global_step=epoch_idx, write_meta_graph=True)
                meta_saved = True

    # Testing
    elif mode == 'test':
        # These helpers mirror Example 1; they assume the model exposes latent_z_test /
        # reconst_test and that the cached "img_test" / "reconst_tp" placeholders exist.
        img_test = U.get_placeholder_cached(name="img_test")
        reconst_tp = U.get_placeholder_cached(name="reconst_tp")
        test = U.function([img_test], model.latent_z_test)
        test_reconst = U.function([reconst_tp], [model.reconst_test])
        test_img_saver = Img_Saver(test_img_saver_dir)

        test_file_name = testing_images_list[0]
        test_img = load_single_img(dir_name=testing_img_dir, img_name=test_file_name)
        test_feature = 31
        test_variation = np.arange(-5, 5, 0.1)

        z = test(test_img)
        for idx in range(len(test_variation)):
            z_test = np.copy(z)
            z_test[0, test_feature] = z_test[0, test_feature] + test_variation[idx]
            reconst_test = test_reconst(z_test)
            test_save_img = np.squeeze(reconst_test[0])
            test_save_img = Image.fromarray(test_save_img)
            img_file_name = "test_feat_{}_var_({}).png".format(test_feature, test_variation[idx])
            test_img_saver.save(test_save_img, img_file_name, sub_dir = None)
        reconst_test = test_reconst(z)
        test_save_img = np.squeeze(reconst_test[0])
        test_save_img = Image.fromarray(test_save_img)
        img_file_name = "test_feat_{}_var_original.png".format(test_feature)
        test_img_saver.save(test_save_img, img_file_name, sub_dir = None)
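
A hedged call sketch for train_net above; the constructor name, feature count, and file/directory names are assumptions, not part of the snippet:

# Hypothetical usage sketch; SiamVAE and the argument values are assumptions.
model = SiamVAE()
train_net(model, mode='train', img_dir='dataset/chairs/train_img', dataset='chairs',
          chkfile_name='chk1', logfile_name='log', validatefile_name='validate_images',
          entangled_feat=10, max_epoch=300, batch_size=512, lr=0.001)
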
Example #3
def mgpu_train_net(models, num_gpus, mode, img_dir, dataset, chkfile_name, logfile_name,
                   validatefile_name, entangled_feat, max_epoch=300, check_every_n=500,
                   loss_check_n=10, save_model_freq=5, batch_size=512, lr=0.001):
    img1 = U.get_placeholder_cached(name="img1")
    img2 = U.get_placeholder_cached(name="img2")

    feat_cls = U.get_placeholder_cached(name="feat_cls")

    # batch size must be multiples of ntowers (# of GPUs)
    ntowers = len(models)
    tf.assert_equal(tf.shape(img1)[0], tf.shape(img2)[0])
    tf.assert_equal(tf.floormod(tf.shape(img1)[0], ntowers), 0)

    img1splits = tf.split(img1, ntowers, 0)
    img2splits = tf.split(img2, ntowers, 0)

    tower_vae_loss = []
    tower_latent_z1_tp = []
    tower_latent_z2_tp = []
    tower_losses = []
    tower_siam_max = []
    tower_reconst1 = []
    tower_reconst2 = []
    tower_cls_loss = []
    for gid, model in enumerate(models):
        with tf.name_scope('gpu%d' % gid) as scope:
            with tf.device('/gpu:%d' % gid):

                vae_loss = U.mean(model.vaeloss)
                latent_z1_tp = model.latent_z1
                latent_z2_tp = model.latent_z2
                losses = [U.mean(model.vaeloss),
                          U.mean(model.siam_loss),
                          U.mean(model.kl_loss1),
                          U.mean(model.kl_loss2),
                          U.mean(model.reconst_error1),
                          U.mean(model.reconst_error2),
                          ]
                siam_max = U.mean(model.max_siam_loss)
                cls_loss = U.mean(model.cls_loss)

                tower_vae_loss.append(vae_loss)
                tower_latent_z1_tp.append(latent_z1_tp)
                tower_latent_z2_tp.append(latent_z2_tp)
                tower_losses.append(losses)
                tower_siam_max.append(siam_max)
                tower_reconst1.append(model.reconst1)
                tower_reconst2.append(model.reconst2)
                tower_cls_loss.append(cls_loss)

                tf.summary.scalar('Total Loss', losses[0])
                tf.summary.scalar('Siam Loss', losses[1])
                tf.summary.scalar('kl1_loss', losses[2])
                tf.summary.scalar('kl2_loss', losses[3])
                tf.summary.scalar('reconst_err1', losses[4])
                tf.summary.scalar('reconst_err2', losses[5])
                tf.summary.scalar('Siam Max', siam_max)

    vae_loss = U.mean(tower_vae_loss)
    siam_max = U.mean(tower_siam_max)
    latent_z1_tp = tf.concat(tower_latent_z1_tp, 0)
    latent_z2_tp = tf.concat(tower_latent_z2_tp, 0)
    model_reconst1 = tf.concat(tower_reconst1, 0)
    model_reconst2 = tf.concat(tower_reconst2, 0)
    cls_loss = U.mean(tower_cls_loss)

    losses = [[] for _ in range(len(losses))]
    for tl in tower_losses:
        for i, l in enumerate(tl):
            losses[i].append(l)

    losses = [U.mean(l) for l in losses]
    siam_normal = losses[1] / entangled_feat

    tf.summary.scalar('total/Total Loss', losses[0])
    tf.summary.scalar('total/Siam Loss', losses[1])
    tf.summary.scalar('total/kl1_loss', losses[2])
    tf.summary.scalar('total/kl2_loss', losses[3])
    tf.summary.scalar('total/reconst_err1', losses[4])
    tf.summary.scalar('total/reconst_err2', losses[5])
    tf.summary.scalar('total/Siam Normal', siam_normal)
    tf.summary.scalar('total/Siam Max', siam_max)

    compute_losses = U.function([img1, img2], vae_loss)

    all_var_list = model.get_trainable_variables()
    vae_var_list = [v for v in all_var_list if v.name.split("/")[2].startswith("vae")]
    cls_var_list = [v for v in all_var_list if v.name.split("/")[2].startswith("cls")]

    warn("{}".format(all_var_list))
    warn("==========================")
    warn("{}".format(vae_var_list))
    # warn("==========================")
    # warn("{}".format(cls_var_list))

    # with tf.device('/cpu:0'):
    optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=0.01 / batch_size)
    optimize_expr1 = optimizer.minimize(vae_loss, var_list=vae_var_list)

    feat_cls_optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
    optimize_expr2 = feat_cls_optimizer.minimize(cls_loss, var_list=cls_var_list)


    merged = tf.summary.merge_all()
    train = U.function([img1, img2],
                        [losses[0], losses[1], losses[2], losses[3], losses[4], losses[5], latent_z1_tp, latent_z2_tp, merged], updates = [optimize_expr1])


    get_reconst_img = U.function([img1, img2], [model_reconst1, model_reconst2, latent_z1_tp, latent_z2_tp])
    get_latent_var = U.function([img1, img2], [latent_z1_tp, latent_z2_tp])

    cur_dir = get_cur_dir()
    chk_save_dir = os.path.join(cur_dir, chkfile_name)
    log_save_dir = os.path.join(cur_dir, logfile_name)
    validate_img_saver_dir = os.path.join(cur_dir, validatefile_name)
    if dataset == 'chairs' or dataset == 'celeba':
        test_img_saver_dir = os.path.join(cur_dir, "test_images")
        testing_img_dir = os.path.join(cur_dir, "dataset/{}/test_img".format(dataset))

    train_writer = U.summary_writer(dir = log_save_dir)

    U.initialize()

    saver, chk_file_epoch_num = U.load_checkpoints(load_requested = True, checkpoint_dir = chk_save_dir)
    if dataset == 'chairs' or dataset == 'celeba':
        validate_img_saver = Img_Saver(Img_dir = validate_img_saver_dir)
    elif dataset == 'dsprites':
        validate_img_saver = BW_Img_Saver(Img_dir = validate_img_saver_dir) # Black and White, temporary usage
    else:
        warn("Unknown dataset Error")
        # break

    warn("dataset: {}".format(dataset))
    if dataset == 'chairs' or dataset == 'celeba':
        training_images_list = read_dataset(img_dir)
        n_total_train_data = len(training_images_list)
        testing_images_list = read_dataset(testing_img_dir)
        n_total_testing_data = len(testing_images_list)
    elif dataset == 'dsprites':
        cur_dir = osp.join(cur_dir, 'dataset')
        cur_dir = osp.join(cur_dir, 'dsprites')
        img_dir = osp.join(cur_dir, 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        manager = DataManager(img_dir, batch_size)
    else:
        warn("Unknown dataset Error")
        # break

    meta_saved = False

    if mode == 'train':
        for epoch_idx in range(chk_file_epoch_num+1, max_epoch):
            t_epoch_start = time.time()
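            # NOTE: manager is only constructed above for the dsprites dataset; for
            # chairs / celeba a batch count would have to be derived from the image list.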
            num_batch = manager.get_len()

            for batch_idx in range(num_batch):
                if dataset == 'chairs' or dataset == 'celeba':
                    idx = random.sample(range(n_total_train_data), 2*batch_size)
                    batch_files = [training_images_list[i] for i in idx]
                    [images1, images2] = load_image(dir_name = img_dir, img_names = batch_files)
                elif dataset == 'dsprites':
                    [images1, images2] = manager.get_next()
                img1, img2 = images1, images2
                [l1, l2, _, _] = get_reconst_img(img1, img2)

                [loss0, loss1, loss2, loss3, loss4, loss5, latent1, latent2, summary] = train(img1, img2)

                if batch_idx % 50 == 1:
                    header("******* epoch: {}/{} batch: {}/{} *******".format(epoch_idx, max_epoch, batch_idx, num_batch))
                    warn("Total Loss: {}".format(loss0))
                    warn("Siam loss: {}".format(loss1))
                    warn("kl1_loss: {}".format(loss2))
                    warn("kl2_loss: {}".format(loss3))
                    warn("reconst_err1: {}".format(loss4))
                    warn("reconst_err2: {}".format(loss5))

                if batch_idx % check_every_n == 1:
                    if dataset == 'chairs' or dataset == 'celeba':
                        idx = random.sample(range(len(training_images_list)), 2*5)
                        validate_batch_files = [training_images_list[i] for i in idx]
                        [images1, images2] = load_image(dir_name = img_dir, img_names = validate_batch_files)
                    elif dataset == 'dsprites':
                        [images1, images2] = manager.get_next()

                    [reconst1, reconst2, _, _] = get_reconst_img(images1, images2)

                    if dataset == 'chairs':
                        for img_idx in range(len(images1)):
                            sub_dir = "iter_{}_{}".format(epoch_idx, batch_idx)

                            save_img = np.squeeze(images1[img_idx])
                            save_img = Image.fromarray(save_img)
                            img_file_name = "{}_ori.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

                            save_img = np.squeeze(reconst1[img_idx])
                            save_img = Image.fromarray(save_img)
                            img_file_name = "{}_rec.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                    elif dataset == 'celeba':
                        for img_idx in range(len(images1)):
                            sub_dir = "iter_{}_{}".format(epoch_idx, batch_idx)

                            save_img = np.squeeze(images1[img_idx])
                            save_img = Image.fromarray(save_img, 'RGB')
                            img_file_name = "{}_ori.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

                            save_img = np.squeeze(reconst1[img_idx])
                            save_img = Image.fromarray(save_img, 'RGB')
                            img_file_name = "{}_rec.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                    elif dataset == 'dsprites':
                        for img_idx in range(len(images1)):
                            sub_dir = "iter_{}_{}".format(epoch_idx, batch_idx)

                            # save_img = images1[img_idx].reshape(64, 64)
                            save_img = np.squeeze(images1[img_idx])
                            save_img = save_img.astype(np.float32)
                            img_file_name = "{}_ori.jpg".format(img_idx)
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

                            # save_img = reconst1[img_idx].reshape(64, 64)
                            save_img = np.squeeze(reconst1[img_idx])
                            save_img = save_img.astype(np.float32)
                            img_file_name = "{}_rec.jpg".format(img_idx)
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

                if batch_idx % loss_check_n == 1:
                    train_writer.add_summary(summary, batch_idx)

            t_epoch_end = time.time()
            t_epoch_run = t_epoch_end - t_epoch_start
            if dataset == 'dsprites':
                t_check = manager.sample_size / t_epoch_run

                warn("==========================================")
                warn("Run {} th epoch in {} sec: {} images / sec".format(epoch_idx+1, t_epoch_run, t_check))
                warn("==========================================")


            if meta_saved:
                saver.save(U.get_session(), os.path.join(chk_save_dir, 'checkpoint'), global_step=epoch_idx, write_meta_graph=False)
            else:
                print("Save meta graph")
                saver.save(U.get_session(), os.path.join(chk_save_dir, 'checkpoint'), global_step=epoch_idx, write_meta_graph=True)
                meta_saved = True
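
The multi-GPU variant above splits each input batch evenly across the model towers, which is why the batch size must be a multiple of the number of GPUs. A standalone sketch of that splitting (TensorFlow 1.x API; shapes are illustrative):

# Illustrative only: splitting one batch placeholder across ntowers GPU towers.
import tensorflow as tf

ntowers = 2
img1 = tf.placeholder(tf.float32, [None, 64, 64, 1], name="img1")
img1splits = tf.split(img1, ntowers, 0)  # requires the batch size to be divisible by ntowers
for gid, img1_tower in enumerate(img1splits):
    with tf.device('/gpu:%d' % gid):
        pass  # build the per-GPU tower on img1_tower here
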
Example #4
def mgpu_classifier_train_net(models, num_gpus, cls_batch_per_gpu, cls_L, mode, img_dir, dataset,
                              chkfile_name, logfile_name, validatefile_name, entangled_feat,
                              max_epoch=300, check_every_n=500, loss_check_n=10,
                              save_model_freq=5, batch_size=512, lr=0.001):
    img1 = U.get_placeholder_cached(name="img1")
    img2 = U.get_placeholder_cached(name="img2")

    feat_cls = U.get_placeholder_cached(name="feat_cls")

    # batch size must be multiples of ntowers (# of GPUs)
    ntowers = len(models)
    tf.assert_equal(tf.shape(img1)[0], tf.shape(img2)[0])
    tf.assert_equal(tf.floormod(tf.shape(img1)[0], ntowers), 0)

    img1splits = tf.split(img1, ntowers, 0)
    img2splits = tf.split(img2, ntowers, 0)

    tower_vae_loss = []
    tower_latent_z1_tp = []
    tower_latent_z2_tp = []
    tower_losses = []
    tower_siam_max = []
    tower_reconst1 = []
    tower_reconst2 = []
    tower_cls_loss = []
    for gid, model in enumerate(models):
        with tf.name_scope('gpu%d' % gid) as scope:
            with tf.device('/gpu:%d' % gid):

                vae_loss = U.mean(model.vaeloss)
                latent_z1_tp = model.latent_z1
                latent_z2_tp = model.latent_z2
                losses = [U.mean(model.vaeloss),
                          U.mean(model.siam_loss),
                          U.mean(model.kl_loss1),
                          U.mean(model.kl_loss2),
                          U.mean(model.reconst_error1),
                          U.mean(model.reconst_error2),
                          ]
                siam_max = U.mean(model.max_siam_loss)
                cls_loss = U.mean(model.cls_loss)

                tower_vae_loss.append(vae_loss)
                tower_latent_z1_tp.append(latent_z1_tp)
                tower_latent_z2_tp.append(latent_z2_tp)
                tower_losses.append(losses)
                tower_siam_max.append(siam_max)
                tower_reconst1.append(model.reconst1)
                tower_reconst2.append(model.reconst2)
                tower_cls_loss.append(cls_loss)

                tf.summary.scalar('Cls Loss', cls_loss)

    vae_loss = U.mean(tower_vae_loss)
    siam_max = U.mean(tower_siam_max)
    latent_z1_tp = tf.concat(tower_latent_z1_tp, 0)
    latent_z2_tp = tf.concat(tower_latent_z2_tp, 0)
    model_reconst1 = tf.concat(tower_reconst1, 0)
    model_reconst2 = tf.concat(tower_reconst2, 0)
    cls_loss = U.mean(tower_cls_loss)

    losses = [[] for _ in range(len(losses))]
    for tl in tower_losses:
        for i, l in enumerate(tl):
            losses[i].append(l)

    losses = [U.mean(l) for l in losses]
    siam_normal = losses[1] / entangled_feat

    tf.summary.scalar('total/cls_loss', cls_loss)

    compute_losses = U.function([img1, img2], vae_loss)

    all_var_list = model.get_trainable_variables()
    vae_var_list = [v for v in all_var_list if v.name.split("/")[2].startswith("vae")]
    cls_var_list = [v for v in all_var_list if v.name.split("/")[2].startswith("cls")]
    warn("{}".format(all_var_list))
    warn("=======================")
    warn("{}".format(vae_var_list))
    warn("=======================")
    warn("{}".format(cls_var_list))

    # with tf.device('/cpu:0'):
    # optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon = 0.01/batch_size)
    # optimize_expr1 = optimizer.minimize(vae_loss, var_list=vae_var_list)

    feat_cls_optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
    optimize_expr2 = feat_cls_optimizer.minimize(cls_loss, var_list=cls_var_list)

    merged = tf.summary.merge_all()
    # train = U.function([img1, img2],
    #                     [losses[0], losses[1], losses[2], losses[3], losses[4], losses[5], latent_z1_tp, latent_z2_tp, merged], updates = [optimize_expr1])

    classifier_train = U.function([img1, img2, feat_cls],
                        [cls_loss, latent_z1_tp, latent_z2_tp, merged], updates = [optimize_expr2])

    get_reconst_img = U.function([img1, img2], [model_reconst1, model_reconst2, latent_z1_tp, latent_z2_tp])
    get_latent_var = U.function([img1, img2], [latent_z1_tp, latent_z2_tp])

    cur_dir = get_cur_dir()
    chk_save_dir = os.path.join(cur_dir, chkfile_name)
    log_save_dir = os.path.join(cur_dir, logfile_name)
    cls_logfile_name = 'cls_{}'.format(logfile_name)
    cls_log_save_dir = os.path.join(cur_dir, cls_logfile_name)
    validate_img_saver_dir = os.path.join(cur_dir, validatefile_name)
    if dataset == 'chairs' or dataset == 'celeba':
        test_img_saver_dir = os.path.join(cur_dir, "test_images")
        testing_img_dir = os.path.join(cur_dir, "dataset/{}/test_img".format(dataset))

    cls_train_writer = U.summary_writer(dir = cls_log_save_dir)

    U.initialize()

    saver, chk_file_epoch_num = U.load_checkpoints(load_requested = True, checkpoint_dir = chk_save_dir)
    if dataset == 'chairs' or dataset == 'celeba':
        validate_img_saver = Img_Saver(Img_dir = validate_img_saver_dir)
    elif dataset == 'dsprites':
        validate_img_saver = BW_Img_Saver(Img_dir = validate_img_saver_dir) # Black and White, temporary usage
    else:
        warn("Unknown dataset Error")
        # break

    warn("dataset: {}".format(dataset))
    if dataset == 'chairs' or dataset == 'celeba':
        training_images_list = read_dataset(img_dir)
        n_total_train_data = len(training_images_list)
        testing_images_list = read_dataset(testing_img_dir)
        n_total_testing_data = len(testing_images_list)
    elif dataset == 'dsprites':
        cur_dir = osp.join(cur_dir, 'dataset')
        cur_dir = osp.join(cur_dir, 'dsprites')
        img_dir = osp.join(cur_dir, 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        manager = DataManager(img_dir, batch_size)
    else:
        warn("Unknown dataset Error")
        # break

    meta_saved = False

    cls_train_iter = 10000
    for cls_train_i in range(cls_train_iter):
        # warn("Train:{}".format(cls_train_i))
        if dataset == 'dsprites':
            # At every epoch, train classifier and check result
            # (1) Load images
            num_img_pair = cls_L * num_gpus * cls_batch_per_gpu
            # warn("{} {} {}".format(len(manager.latents_sizes)-1, num_gpus, cls_batch_per_gpu))
            feat = np.random.randint(len(manager.latents_sizes)-1, size = num_gpus * cls_batch_per_gpu)
            [images1, images2] = manager.get_image_fixed_feat_batch(feat, num_img_pair)

            # warn("images shape:{}".format(np.shape(images1)))

            # (2) Input PH images
            [classification_loss, _, _, summary] = classifier_train(images1, images2, feat)
            if cls_train_i % 100 == 0:
                warn("cls loss {}: {}".format(cls_train_i, classification_loss))

            cls_train_writer.add_summary(summary, cls_train_i)
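
A hedged call sketch for the classifier trainer above; the per-GPU model constructor and the numeric arguments are assumptions:

# Hypothetical usage sketch; SiamVAE and the chosen sizes are assumptions.
num_gpus = 2
models = [SiamVAE() for _ in range(num_gpus)]  # one tower per GPU
mgpu_classifier_train_net(models, num_gpus, cls_batch_per_gpu=16, cls_L=10,
                          mode='train', img_dir=None, dataset='dsprites',
                          chkfile_name='chk1', logfile_name='log',
                          validatefile_name='validate_images', entangled_feat=10)
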
Example #5
def train_net(model, img_dir, max_iter=100000, check_every_n=20, save_model_freq=1000, batch_size=128):
	img1 = U.get_placeholder_cached(name="img1")
	img2 = U.get_placeholder_cached(name="img2")

	mean_loss1 = U.mean(model.match_error)
	mean_loss2 = U.mean(model.reconst_error1)
	mean_loss3 = U.mean(model.reconst_error2)

	decoded_img = [model.reconst1, model.reconst2]

	weight_loss = [1, 1, 1]

	compute_losses = U.function([img1, img2], [mean_loss1, mean_loss2, mean_loss3])
	lr = 0.00001
	optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=0.01 / batch_size)

	all_var_list = model.get_trainable_variables()

	img1_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("proj1") or v.name.split("/")[1].startswith("unproj1")]
	img2_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("proj2") or v.name.split("/")[1].startswith("unproj2")]


	img1_loss = mean_loss1 + mean_loss2
	img2_loss = mean_loss1 + mean_loss3

	optimize_expr1 = optimizer.minimize(img1_loss, var_list=img1_var_list)
	optimize_expr2 = optimizer.minimize(img2_loss, var_list=img2_var_list)

	img1_train = U.function([img1, img2], [mean_loss1, mean_loss2, mean_loss3], updates = [optimize_expr1])
	img2_train = U.function([img1, img2], [mean_loss1, mean_loss2, mean_loss3], updates = [optimize_expr2])

	get_reconst_img = U.function([img1, img2], decoded_img)

	U.initialize()

	name = "test"
	cur_dir = get_cur_dir()
	chk_save_dir = os.path.join(cur_dir, "chkfiles")
	log_save_dir = os.path.join(cur_dir, "log")
	test_img_saver_dir = os.path.join(cur_dir, "test_images")

	saver, chk_file_num = U.load_checkpoints(load_requested = True, checkpoint_dir = chk_save_dir)
	test_img_saver = Img_Saver(test_img_saver_dir)

	meta_saved = False

	iter_log = []
	loss1_log = []
	loss2_log = []
	loss3_log = []

	training_images_list = read_dataset(img_dir)

	for num_iter in range(chk_file_num+1, max_iter):
		header("******* {}th iter: Img {} side *******".format(num_iter, num_iter%2 + 1))

		idx = random.sample(range(len(training_images_list)), batch_size)
		batch_files = [training_images_list[i] for i in idx]
		[images1, images2] = load_image(dir_name = img_dir, img_names = batch_files)
		img1, img2 = images1, images2
		# args = images1, images2
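		# Alternate optimization: even iterations update the img1 branch (proj1/unproj1),
		# odd iterations update the img2 branch (proj2/unproj2); the match error is shared.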
		if num_iter % 2 == 0:
			[loss1, loss2, loss3] = img1_train(img1, img2)
		else:
			[loss1, loss2, loss3] = img2_train(img1, img2)
		warn("match_error: {}".format(loss1))
		warn("reconst_err1: {}".format(loss2))
		warn("reconst_err2: {}".format(loss3))
		warn("num_iter: {} check: {}".format(num_iter, check_every_n))
		if num_iter % check_every_n == 1:
			idx = random.sample(range(len(training_images_list)), 10)
			test_batch_files = [training_images_list[i] for i in idx]
			[images1, images2] = load_image(dir_name = img_dir, img_names = test_batch_files)
			[reconst1, reconst2] = get_reconst_img(images1, images2)
			for img_idx in range(len(images1)):
				sub_dir = "iter_{}".format(num_iter)

				save_img = np.squeeze(images1[img_idx])
				save_img = Image.fromarray(save_img)
				img_file_name = "{}_ori_2d.jpg".format(test_batch_files[img_idx])				
				test_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

				save_img = np.squeeze(images2[img_idx])
				save_img = Image.fromarray(save_img)
				img_file_name = "{}_ori_3d.jpg".format(test_batch_files[img_idx])				
				test_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

				save_img = np.squeeze(reconst1[img_idx])
				save_img = Image.fromarray(save_img)
				img_file_name = "{}_rec_2d.jpg".format(test_batch_files[img_idx])				
				test_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

				save_img = np.squeeze(reconst2[img_idx])
				save_img = Image.fromarray(save_img)
				img_file_name = "{}_rec_3d.jpg".format(test_batch_files[img_idx])				
				test_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)

		if num_iter > 11 and num_iter % save_model_freq == 1:
			if meta_saved:
				saver.save(U.get_session(), os.path.join(chk_save_dir, 'checkpoint'), global_step=num_iter, write_meta_graph=False)
			else:
				print("Save meta graph")
				saver.save(U.get_session(), os.path.join(chk_save_dir, 'checkpoint'), global_step=num_iter, write_meta_graph=True)
				meta_saved = True
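
A minimal call sketch for this alternating trainer; the constructor name is hypothetical and only assumed to expose match_error, reconst_error1/2 and reconst1/2:

# Hypothetical usage sketch; TwoStreamAutoencoder is an assumed constructor name.
model = TwoStreamAutoencoder()
img_dir = os.path.join(get_cur_dir(), "dataset/train_img")  # assumed image directory
train_net(model, img_dir, max_iter=100000, check_every_n=20, batch_size=128)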