def main():
    """Visual sanity check for the lidar training images.

    Loads the first file found in ``training_images/lidar``, displays it,
    then shows a manual RGB -> grayscale conversion using the ITU-R BT.601
    luma weights. Interactive inspection only; returns nothing.

    Changes vs. original: removed ~40 lines of commented-out dead code and
    the unused ``resize_size`` / ``img_rgb_dir`` locals; single-argument
    prints parenthesized (valid in both Python 2 and 3).
    """
    cur_dir = get_cur_dir()
    img_dir = os.path.join(cur_dir, "training_images")
    img_lidar_dir = os.path.join(img_dir, "lidar")

    # Plain files only -- skip any sub-directories.
    lidar_files = [f for f in os.listdir(img_lidar_dir)
                   if os.path.isfile(os.path.join(img_lidar_dir, f))]

    # NOTE(review): os.listdir order is arbitrary, so "first" file is not
    # deterministic across platforms; sort if a specific image is expected.
    file_name = lidar_files[0]
    print("Processing {}".format(file_name))
    img_lidar_file_path = os.path.join(img_lidar_dir, file_name)
    im_lidar = Image.open(img_lidar_file_path)
    print(im_lidar.size)
    im_lidar.show()

    im_lidar_arr = np.array(im_lidar)
    test = Image.fromarray(im_lidar_arr, 'RGB')
    test.show()

    # Manual grayscale conversion with BT.601 luma coefficients
    # (equivalent to im_lidar.convert('L') up to rounding).
    im_lidar_arr = (im_lidar_arr[:, :, 0] * 0.299
                    + im_lidar_arr[:, :, 1] * 0.587
                    + im_lidar_arr[:, :, 2] * 0.114)
    test = Image.fromarray(im_lidar_arr)
    test.show()
def load_checkpoints(load_requested = True, checkpoint_dir = None):
    """Create a ``tf.train.Saver`` and optionally restore the latest checkpoint.

    Args:
        load_requested: when truthy, restore from ``checkpoint_dir`` if a
            checkpoint state file exists there.
        checkpoint_dir: directory holding checkpoints; defaults to the
            current directory, resolved at call time.

    Returns:
        ``(saver, chkpoint_num)``: the Saver and the global step parsed from
        the restored checkpoint filename (``...-<step>``), or 0 if nothing
        was restored.

    Fixes vs. original: the old default ``checkpoint_dir=get_cur_dir()`` was
    evaluated once at import time (def-time default), freezing the directory;
    the path split on a literal ``'/'`` was not portable; ``== True``
    replaced with plain truthiness.
    """
    if checkpoint_dir is None:
        checkpoint_dir = get_cur_dir()  # resolved lazily, per call
    saver = tf.train.Saver(max_to_keep = None)
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    chkpoint_num = 0
    if checkpoint and checkpoint.model_checkpoint_path and load_requested:
        saver.restore(get_session(), checkpoint.model_checkpoint_path)
        # Checkpoint files are named "<prefix>-<global_step>"; recover the
        # step number. os.path.basename is portable, unlike split('/').
        chk_file = os.path.basename(checkpoint.model_checkpoint_path)
        chkpoint_num = int(chk_file.split('-')[-1])
        warn("loaded checkpoint: {0}".format(checkpoint.model_checkpoint_path))
    else:
        warn("Could not find old checkpoint")
        if not os.path.exists(checkpoint_dir):
            mkdir_p(checkpoint_dir)
    return saver, chkpoint_num
def main():
    """Entry point: build the VAE network and train it on ``training_images``.

    Opens a single-threaded TF session, seeds all RNGs, constructs the model
    for 210x160 single-channel images with a 2048-dim latent, and hands off
    to ``train_net``.
    """
    # Enter a single-threaded TF session for the remainder of the run.
    session = U.single_threaded_session()
    session.__enter__()
    set_global_seeds(0)

    # Training images live under <current dir>/training_images.
    base_dir = get_cur_dir()
    training_dir = osp.join(base_dir, "training_images")

    header("Load model")
    network = mymodel(name="mynet", img_shape=[210, 160, 1], latent_dim=2048)
    header("Load model")
    train_net(model=network, img_dir=training_dir)
# --- Module setup -----------------------------------------------------------
import glob
import argparse
import os
import time
import sys
import tensorflow as tf
from itertools import count
# Project-local helpers.
from misc_util import get_cur_dir, warn, mkdir_p
import cv2
from utils.utils import label_to_gt_box2d, bbox_iou, random_distort_image, draw_bbox2d_on_image
import numpy as np
from config import cfg
# from data_aug import image_augmentation

# Module-level configuration, run at import time: a KITTI-style object
# dataset under <cur_dir>/data/object, plus a debug-image output directory
# that is created eagerly (side effect on import).
cur_dir = get_cur_dir()
dataset_dir = os.path.join(cur_dir, 'data/object')
warn("dataset_dir: {}".format(dataset_dir))
dataset = 'training'
split_file = 'trainset.txt'
test_img_save_dir = 'test_img'
test_img_save_dir = os.path.join(cur_dir, test_img_save_dir)
mkdir_p(test_img_save_dir)


def image_augmentation(f_rgb, f_label, width, height, jitter, hue, saturation, exposure):
    """Augment an RGB image / label pair.

    NOTE(review): this function appears unfinished -- it builds empty
    accumulators, reads the label file, and then ends (implicitly returning
    None). Only ``f_label`` is actually used so far; the remaining
    parameters are presumably for the jitter/HSV distortion steps that were
    never merged -- confirm against `utils.utils.random_distort_image`.
    """
    rgb_imgs = []
    ious = []
    org_imgs = []
    # One array entry per raw line of the label file. The file handle is not
    # closed explicitly (left as-is: documentation-only change).
    label = np.array([line for line in open(f_label, 'r').readlines()])
def test_net(model, img_dir, max_iter=1000000, check_every_n=500, loss_check_n=10, save_model_freq=1000, batch_size=128):
    """Restore a trained siamese VAE and sweep individual latent features.

    Builds the same graph as the trainer (so every checkpoint variable --
    including the Adam slot variables created by ``optimizer.minimize`` --
    exists), restores the newest checkpoint from ``./chk1``, then for latent
    features 25..31 decodes the test image while sliding that single feature
    over [-10, 10) in 0.1 steps, saving each reconstruction as a PNG.

    Changes vs. original: the large commented-out training loop and unused
    pure-Python locals (``decoded_img``, ``weight_loss``, ``iter_log`` etc.)
    were removed as dead code; single-argument prints parenthesized; the
    loop-invariant ``np.arange`` hoisted. All TF graph construction is kept
    unchanged so checkpoint compatibility is preserved.
    """
    img1 = U.get_placeholder_cached(name="img1")
    img2 = U.get_placeholder_cached(name="img2")
    # Placeholders used only by the testing path (encode / decode below).
    img_test = U.get_placeholder_cached(name="img_test")
    reconst_tp = U.get_placeholder_cached(name="reconst_tp")

    vae_loss = U.mean(model.vaeloss)
    latent_z1_tp = model.latent_z1
    latent_z2_tp = model.latent_z2
    losses = [
        U.mean(model.vaeloss),
        U.mean(model.siam_loss),
        U.mean(model.kl_loss1),
        U.mean(model.kl_loss2),
        U.mean(model.reconst_error1),
        U.mean(model.reconst_error2),
    ]
    tf.summary.scalar('Total Loss', losses[0])
    tf.summary.scalar('Siam Loss', losses[1])
    tf.summary.scalar('kl1_loss', losses[2])
    tf.summary.scalar('kl2_loss', losses[3])
    tf.summary.scalar('reconst_err1', losses[4])
    tf.summary.scalar('reconst_err2', losses[5])

    compute_losses = U.function([img1, img2], vae_loss)

    # The optimizer is never stepped here, but constructing it creates the
    # Adam slot variables that the saved checkpoint contains -- removing it
    # would break the restore below.
    lr = 0.00005
    optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=0.01 / batch_size)
    all_var_list = model.get_trainable_variables()
    img1_var_list = all_var_list
    optimize_expr1 = optimizer.minimize(vae_loss, var_list=img1_var_list)
    merged = tf.summary.merge_all()

    train = U.function([img1, img2],
                       [losses[0], losses[1], losses[2], losses[3],
                        losses[4], losses[5], latent_z1_tp, latent_z2_tp,
                        merged],
                       updates=[optimize_expr1])
    get_reconst_img = U.function(
        [img1, img2],
        [model.reconst1, model.reconst2, latent_z1_tp, latent_z2_tp])
    get_latent_var = U.function([img1, img2], [latent_z1_tp, latent_z2_tp])

    # Testing entry points: encode an image, decode a latent vector.
    test = U.function([img_test], model.latent_z_test)
    test_reconst = U.function([reconst_tp], [model.reconst_test])

    cur_dir = get_cur_dir()
    chk_save_dir = os.path.join(cur_dir, "chk1")
    log_save_dir = os.path.join(cur_dir, "log")
    validate_img_saver_dir = os.path.join(cur_dir, "validate_images")
    test_img_saver_dir = os.path.join(cur_dir, "test_images")
    testing_img_dir = os.path.join(cur_dir, "dataset/test_img")

    train_writer = U.summary_writer(dir=log_save_dir)
    U.initialize()
    saver, chk_file_num = U.load_checkpoints(load_requested=True, checkpoint_dir=chk_save_dir)
    validate_img_saver = Img_Saver(validate_img_saver_dir)
    test_img_saver = Img_Saver(test_img_saver_dir)

    training_images_list = read_dataset(img_dir)
    testing_images_list = read_dataset(testing_img_dir)

    testing = True
    print(testing_images_list)
    if testing:
        # NOTE(review): hard-coded index 6 -- raises IndexError if the test
        # set has fewer than 7 images.
        test_file_name = testing_images_list[6]
        print(test_file_name)
        test_img = load_single_img(dir_name=testing_img_dir, img_name=test_file_name)
        test_features = np.arange(25, 32)
        # Identical sweep for every feature; hoisted out of the loop.
        test_variation = np.arange(-10, 10, 0.1)
        for test_feature in test_features:
            z = test(test_img)
            print(np.shape(z))
            print(z)
            for idx in range(len(test_variation)):
                # Perturb exactly one latent feature, then decode.
                z_test = np.copy(z)
                z_test[0, test_feature] = z_test[0, test_feature] + test_variation[idx]
                reconst_test = test_reconst(z_test)
                test_save_img = np.squeeze(reconst_test[0])
                test_save_img = Image.fromarray(test_save_img)
                img_file_name = "test_feat_{}_var_({}).png".format(
                    test_feature, test_variation[idx])
                test_img_saver.save(test_save_img, img_file_name, sub_dir=None)
            # Also save the unperturbed reconstruction for reference.
            reconst_test = test_reconst(z)
            test_save_img = np.squeeze(reconst_test[0])
            test_save_img = Image.fromarray(test_save_img)
            img_file_name = "test_feat_{}_var_original.png".format(test_feature)
            test_img_saver.save(test_save_img, img_file_name, sub_dir=None)
def main():
    """CLI entry point for the disentanglement experiments.

    Base paper: https://openreview.net/pdf?id=Sy2fzU9gl
    Parses ``--dataset`` / ``--mode`` / ``--disentangled_feat`` /
    ``--num_gpus``, builds the model (one tower per GPU for dsprites), and
    dispatches to train / classifier_train / test.
    """
    # (1) parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset')  # chairs, celeba, dsprites
    parser.add_argument('--mode')  # train, test
    parser.add_argument('--disentangled_feat', type=int)
    parser.add_argument('--num_gpus', type=int, default=1)
    args = parser.parse_args()
    dataset = args.dataset
    mode = args.mode
    disentangled_feat = args.disentangled_feat
    # Artifact directories are keyed by dataset + disentangled feature count.
    chkfile_name = "chk_{}_{}".format(dataset, disentangled_feat)
    logfile_name = "log_{}_{}".format(dataset, disentangled_feat)
    validatefile_name = "val_{}_{}".format(dataset, disentangled_feat)
    # (2) Dataset
    # NOTE(review): dir_name is computed but never used below, and img_dir
    # is hard-coded to the chairs layout regardless of --dataset.
    if dataset == 'chairs':
        dir_name = "/dataset/chairs/training_img"
    elif dataset == 'celeba':
        dir_name = 'temporarily not available'
    elif dataset == 'dsprites':
        dir_name = '/dataset/dsprites'  # This is dummy, for dsprites dataset, we are using data_manager
    else:
        header("Unknown dataset name")
    cur_dir = get_cur_dir()
    cur_dir = osp.join(cur_dir, 'dataset')
    cur_dir = osp.join(cur_dir, 'chairs')
    img_dir = osp.join(cur_dir, 'training_img')  # This is for chairs
    # (3) Set experiment configuration per dataset, following beta-VAE
    # ( https://openreview.net/pdf?id=Sy2fzU9gl ).
    if dataset == 'chairs':
        latent_dim = 32
        loss_weight = {'siam': 50000.0, 'kl': 30000.0}
        batch_size = 32
        max_epoch = 300
        lr = 0.0001
    elif dataset == 'celeba':
        latent_dim = 32
        loss_weight = {'siam': 1000.0, 'kl': 30000.0}
        batch_size = 512
        max_epoch = 300
        lr = 0.0001
    elif dataset == 'dsprites':
        latent_dim = 10
        loss_weight = {'siam': 1.0, 'kl': 1.0}
        batch_size = 1024
        max_epoch = 300
        lr = 0.001
        # shape, rotation, size, x, y => Don't know why there are only 4
        # features in paper p6. Need to check more about it.
        feat_size = 5
        cls_batch_per_gpu = 15
        cls_L = 10
    entangled_feat = latent_dim - disentangled_feat
    # (4) Open Tensorflow session. Important: if we stop using a
    # single-threaded/mgpu session, this needs to change.
    # sess = U.single_threaded_session()
    sess = U.mgpu_session()
    sess.__enter__()
    set_global_seeds(0)
    num_gpus = args.num_gpus
    # Model Setting
    # (5) Import model (merged into models.py).
    # Only celeba has RGB channels; the others are single-channel.
    if dataset == 'chairs':
        import models
        mynet = models.mymodel(name="mynet", img_shape=[64, 64, 1], latent_dim=latent_dim, disentangled_feat=disentangled_feat, mode=mode, loss_weight=loss_weight)
    elif dataset == 'celeba':
        import models
        mynet = models.mymodel(name="mynet", img_shape=[64, 64, 3], latent_dim=latent_dim, disentangled_feat=disentangled_feat, mode=mode, loss_weight=loss_weight)
    elif dataset == 'dsprites':
        import models
        # One model tower per GPU over equal batch splits; the batch size
        # must therefore be a multiple of num_gpus (asserted below).
        img_shape = [None, 64, 64, 1]
        img1 = U.get_placeholder(name="img1", dtype=tf.float32, shape=img_shape)
        img2 = U.get_placeholder(name="img2", dtype=tf.float32, shape=img_shape)
        feat_cls = U.get_placeholder(name="feat_cls", dtype=tf.int32, shape=None)
        tf.assert_equal(tf.shape(img1)[0], tf.shape(img2)[0])
        tf.assert_equal(tf.floormod(tf.shape(img1)[0], num_gpus), 0)
        tf.assert_equal(tf.floormod(tf.shape(feat_cls)[0], num_gpus), 0)
        img1splits = tf.split(img1, num_gpus, 0)
        img2splits = tf.split(img2, num_gpus, 0)
        feat_cls_splits = tf.split(feat_cls, num_gpus, 0)
        mynets = []
        with tf.variable_scope(tf.get_variable_scope()):
            for gid in range(num_gpus):
                with tf.name_scope('gpu%d' % gid) as scope:
                    with tf.device('/gpu:%d' % gid):
                        mynet = models.mymodel(name="mynet", img1=img1splits[gid], img2=img2splits[gid], img_shape=img_shape[1:], latent_dim=latent_dim, disentangled_feat=disentangled_feat, mode=mode, loss_weight=loss_weight, feat_cls=feat_cls_splits[gid], feat_size=feat_size, cls_L=cls_L, cls_batch_per_gpu=cls_batch_per_gpu)
                        mynets.append(mynet)
                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()
    else:
        header("Unknown model name")
    # (6) Train or test the model.
    # Testing by adding noise on a latent feature is not merged yet.
    # NOTE(review): `mynets` (and cls_batch_per_gpu / cls_L / feat_size) are
    # only defined in the dsprites branch, so 'train'/'classifier_train'
    # with chairs or celeba raises NameError here -- confirm intended usage.
    if mode == 'train':
        mgpu_train_net(models=mynets, num_gpus=num_gpus, mode=mode, img_dir=img_dir, dataset=dataset, chkfile_name=chkfile_name, logfile_name=logfile_name, validatefile_name=validatefile_name, entangled_feat=entangled_feat, max_epoch=max_epoch, batch_size=batch_size, lr=lr)
        # train_net(model=mynets[0], mode = mode, img_dir = img_dir, dataset = dataset, chkfile_name = chkfile_name, logfile_name = logfile_name, validatefile_name = validatefile_name, entangled_feat = entangled_feat, max_epoch = max_epoch, batch_size = batch_size, lr = lr)
    elif mode == 'classifier_train':
        warn("Classifier Train")
        mgpu_classifier_train_net(models=mynets, num_gpus=num_gpus, cls_batch_per_gpu=cls_batch_per_gpu, cls_L=cls_L, mode=mode, img_dir=img_dir, dataset=dataset, chkfile_name=chkfile_name, logfile_name=logfile_name, validatefile_name=validatefile_name, entangled_feat=entangled_feat, max_epoch=max_epoch, batch_size=batch_size, lr=lr)
    elif mode == 'test':
        header("Need to be merged")
    else:
        header("Unknown mode name")
def train_net(model, mode, img_dir, dataset, chkfile_name, logfile_name, validatefile_name, entangled_feat, max_epoch = 300, check_every_n = 500, loss_check_n = 10, save_model_freq = 5, batch_size = 512, lr = 0.001):
    """Single-GPU training loop for the siamese VAE.

    Builds the loss graph and Adam step for ``model``, restores the latest
    checkpoint from ``<cwd>/<chkfile_name>``, then (in 'train' mode) runs
    ``max_epoch`` epochs, periodically logging losses, saving validation
    reconstructions, writing TF summaries, and checkpointing every epoch.

    NOTE(review) known hazards, left as-is in this doc-only pass:
      * ``num_batch = manager.get_len()`` runs for every dataset, but
        ``manager`` is only created for 'dsprites' -> NameError for
        chairs/celeba in 'train' mode.
      * The 'test' branch uses ``test``/``test_reconst``/``test_img_saver``,
        none of which are defined in this function's scope.
      * ``save_model_freq`` is effectively unused (its guard is commented
        out), so a checkpoint is written every epoch.
    """
    img1 = U.get_placeholder_cached(name="img1")
    img2 = U.get_placeholder_cached(name="img2")
    vae_loss = U.mean(model.vaeloss)
    latent_z1_tp = model.latent_z1
    latent_z2_tp = model.latent_z2
    # Scalar summaries over each loss component.
    losses = [U.mean(model.vaeloss), U.mean(model.siam_loss), U.mean(model.kl_loss1), U.mean(model.kl_loss2), U.mean(model.reconst_error1), U.mean(model.reconst_error2), ]
    # Siamese loss normalized by the number of entangled features.
    siam_normal = losses[1]/entangled_feat
    siam_max = U.mean(model.max_siam_loss)
    tf.summary.scalar('Total Loss', losses[0])
    tf.summary.scalar('Siam Loss', losses[1])
    tf.summary.scalar('kl1_loss', losses[2])
    tf.summary.scalar('kl2_loss', losses[3])
    tf.summary.scalar('reconst_err1', losses[4])
    tf.summary.scalar('reconst_err2', losses[5])
    tf.summary.scalar('Siam Normal', siam_normal)
    tf.summary.scalar('Siam Max', siam_max)
    compute_losses = U.function([img1, img2], vae_loss)
    optimizer=tf.train.AdamOptimizer(learning_rate=lr, epsilon = 0.01/batch_size)
    all_var_list = model.get_trainable_variables()
    img1_var_list = all_var_list
    optimize_expr1 = optimizer.minimize(vae_loss, var_list=img1_var_list)
    merged = tf.summary.merge_all()
    # One optimization step: returns all loss components, both latents and
    # the merged summary, and applies the Adam update as a side effect.
    train = U.function([img1, img2], [losses[0], losses[1], losses[2], losses[3], losses[4], losses[5], latent_z1_tp, latent_z2_tp, merged], updates = [optimize_expr1])
    get_reconst_img = U.function([img1, img2], [model.reconst1, model.reconst2, latent_z1_tp, latent_z2_tp])
    get_latent_var = U.function([img1, img2], [latent_z1_tp, latent_z2_tp])
    # Output directories: checkpoints, TF event logs, validation images.
    cur_dir = get_cur_dir()
    chk_save_dir = os.path.join(cur_dir, chkfile_name)
    log_save_dir = os.path.join(cur_dir, logfile_name)
    validate_img_saver_dir = os.path.join(cur_dir, validatefile_name)
    if dataset == 'chairs' or dataset == 'celeba':
        test_img_saver_dir = os.path.join(cur_dir, "test_images")
        testing_img_dir = os.path.join(cur_dir, "dataset/{}/test_img".format(dataset))
    train_writer = U.summary_writer(dir = log_save_dir)
    U.initialize()
    # Resume from the newest checkpoint; chk_file_epoch_num is the epoch the
    # saved run reached (0 when starting fresh).
    saver, chk_file_epoch_num = U.load_checkpoints(load_requested = True, checkpoint_dir = chk_save_dir)
    if dataset == 'chairs' or dataset == 'celeba':
        validate_img_saver = Img_Saver(Img_dir = validate_img_saver_dir)
    elif dataset == 'dsprites':
        validate_img_saver = BW_Img_Saver(Img_dir = validate_img_saver_dir)  # Black and White, temporary usage
    else:
        warn("Unknown dataset Error")
        # break
    warn(img_dir)
    if dataset == 'chairs' or dataset == 'celeba':
        training_images_list = read_dataset(img_dir)
        n_total_train_data = len(training_images_list)
        testing_images_list = read_dataset(testing_img_dir)
        n_total_testing_data = len(testing_images_list)
    elif dataset == 'dsprites':
        # dsprites ships as a single .npz archive served by DataManager.
        cur_dir = osp.join(cur_dir, 'dataset')
        cur_dir = osp.join(cur_dir, 'dsprites')
        img_dir = osp.join(cur_dir, 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        manager = DataManager(img_dir, batch_size)
    else:
        warn("Unknown dataset Error")
        # break
    meta_saved = False
    if mode == 'train':
        for epoch_idx in range(chk_file_epoch_num+1, max_epoch):
            t_epoch_start = time.time()
            # NOTE(review): manager only exists for 'dsprites' (see docstring).
            num_batch = manager.get_len()
            for batch_idx in range(num_batch):
                # Draw a batch of image pairs: 2*batch_size random files for
                # the file-based datasets, or the manager's next batch.
                if dataset == 'chairs' or dataset == 'celeba':
                    idx = random.sample(range(n_total_train_data), 2*batch_size)
                    batch_files = [training_images_list[i] for i in idx]
                    [images1, images2] = load_image(dir_name = img_dir, img_names = batch_files)
                elif dataset == 'dsprites':
                    [images1, images2] = manager.get_next()
                img1, img2 = images1, images2
                # NOTE(review): this extra forward pass discards its outputs.
                [l1, l2, _, _] = get_reconst_img(img1, img2)
                [loss0, loss1, loss2, loss3, loss4, loss5, latent1, latent2, summary] = train(img1, img2)
                # Periodic console logging.
                if batch_idx % 50 == 1:
                    header("******* epoch: {}/{} batch: {}/{} *******".format(epoch_idx, max_epoch, batch_idx, num_batch))
                    warn("Total Loss: {}".format(loss0))
                    warn("Siam loss: {}".format(loss1))
                    warn("kl1_loss: {}".format(loss2))
                    warn("kl2_loss: {}".format(loss3))
                    warn("reconst_err1: {}".format(loss4))
                    warn("reconst_err2: {}".format(loss5))
                # Periodic validation: save original/reconstruction pairs.
                if batch_idx % check_every_n == 1:
                    if dataset == 'chairs' or dataset == 'celeba':
                        idx = random.sample(range(len(training_images_list)), 2*5)
                        validate_batch_files = [training_images_list[i] for i in idx]
                        [images1, images2] = load_image(dir_name = img_dir, img_names = validate_batch_files)
                    elif dataset == 'dsprites':
                        [images1, images2] = manager.get_next()
                    [reconst1, reconst2, _, _] = get_reconst_img(images1, images2)
                    if dataset == 'chairs':
                        for img_idx in range(len(images1)):
                            sub_dir = "iter_{}".format(batch_idx)
                            save_img = np.squeeze(images1[img_idx])
                            save_img = Image.fromarray(save_img)
                            img_file_name = "{}_ori.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                            save_img = np.squeeze(reconst1[img_idx])
                            save_img = Image.fromarray(save_img)
                            img_file_name = "{}_rec.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                    elif dataset == 'celeba':
                        # Same as chairs but images are RGB.
                        for img_idx in range(len(images1)):
                            sub_dir = "iter_{}".format(batch_idx)
                            save_img = np.squeeze(images1[img_idx])
                            save_img = Image.fromarray(save_img, 'RGB')
                            img_file_name = "{}_ori.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                            save_img = np.squeeze(reconst1[img_idx])
                            save_img = Image.fromarray(save_img, 'RGB')
                            img_file_name = "{}_rec.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                    elif dataset == 'dsprites':
                        # Float arrays handed to the B/W saver directly.
                        for img_idx in range(len(images1)):
                            sub_dir = "iter_{}".format(batch_idx)
                            # save_img = images1[img_idx].reshape(64, 64)
                            save_img = np.squeeze(images1[img_idx])
                            save_img = save_img.astype(np.float32)
                            img_file_name = "{}_ori.jpg".format(img_idx)
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                            # save_img = reconst1[img_idx].reshape(64, 64)
                            save_img = np.squeeze(reconst1[img_idx])
                            save_img = save_img.astype(np.float32)
                            img_file_name = "{}_rec.jpg".format(img_idx)
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                # Periodic TF summary write (indexed by batch, not epoch).
                if batch_idx % loss_check_n == 1:
                    train_writer.add_summary(summary, batch_idx)
            t_epoch_end = time.time()
            t_epoch_run = t_epoch_end - t_epoch_start
            if dataset == 'dsprites':
                # Throughput report for the epoch.
                t_check = manager.sample_size / t_epoch_run
                warn("==========================================")
                warn("Run {} th epoch in {} sec: {} images / sec".format(epoch_idx+1, t_epoch_run, t_check))
                warn("==========================================")
            # Checkpoint every epoch; the meta graph is written only once.
            # if epoch_idx % save_model_freq == 0:
            if meta_saved == True:
                saver.save(U.get_session(), chk_save_dir + '/' + 'checkpoint', global_step = epoch_idx, write_meta_graph = False)
            else:
                print "Save meta graph"
                saver.save(U.get_session(), chk_save_dir + '/' + 'checkpoint', global_step = epoch_idx, write_meta_graph = True)
                meta_saved = True
    # Testing: sweep one latent feature of the first test image.
    elif mode == 'test':
        test_file_name = testing_images_list[0]
        test_img = load_single_img(dir_name = testing_img_dir, img_name = test_file_name)
        test_feature = 31
        test_variation = np.arange(-5, 5, 0.1)
        # NOTE(review): `test`, `test_reconst`, `test_img_saver` are not
        # defined anywhere in this function -- this branch cannot run as-is.
        z = test(test_img)
        for idx in range(len(test_variation)):
            z_test = np.copy(z)
            z_test[0, test_feature] = z_test[0, test_feature] + test_variation[idx]
            reconst_test = test_reconst(z_test)
            test_save_img = np.squeeze(reconst_test[0])
            test_save_img = Image.fromarray(test_save_img)
            img_file_name = "test_feat_{}_var_({}).png".format(test_feature, test_variation[idx])
            test_img_saver.save(test_save_img, img_file_name, sub_dir = None)
        reconst_test = test_reconst(z)
        test_save_img = np.squeeze(reconst_test[0])
        test_save_img = Image.fromarray(test_save_img)
        img_file_name = "test_feat_{}_var_original.png".format(test_feature)
        test_img_saver.save(test_save_img, img_file_name, sub_dir = None)
def mgpu_classifier_train_net(models, num_gpus, cls_batch_per_gpu, cls_L, mode, img_dir, dataset, chkfile_name, logfile_name, validatefile_name, entangled_feat, max_epoch = 300, check_every_n = 500, loss_check_n = 10, save_model_freq = 5, batch_size = 512, lr = 0.001):
    """Multi-GPU training loop for the latent-feature classifier head.

    Expects one model tower per GPU (``models``), each already wired to its
    split of the shared placeholders. Only the classifier variables
    (scope segment starting with "cls") are trained, with Adagrad; the VAE
    weights are restored from the checkpoint and left frozen.

    NOTE(review): the training loop at the bottom only handles
    'dsprites'; other datasets iterate without doing anything. Placement of
    per-tower vs. aggregate summaries was reconstructed from mangled
    formatting -- confirm against the original layout.
    """
    img1 = U.get_placeholder_cached(name="img1")
    img2 = U.get_placeholder_cached(name="img2")
    feat_cls = U.get_placeholder_cached(name="feat_cls")
    # batch size must be multiples of ntowers (# of GPUs)
    ntowers = len(models)
    tf.assert_equal(tf.shape(img1)[0], tf.shape(img2)[0])
    tf.assert_equal(tf.floormod(tf.shape(img1)[0], ntowers), 0)
    img1splits = tf.split(img1, ntowers, 0)
    img2splits = tf.split(img2, ntowers, 0)
    # Per-tower collections, aggregated after the loop.
    tower_vae_loss = []
    tower_latent_z1_tp = []
    tower_latent_z2_tp = []
    tower_losses = []
    tower_siam_max = []
    tower_reconst1 = []
    tower_reconst2 = []
    tower_cls_loss = []
    for gid, model in enumerate(models):
        with tf.name_scope('gpu%d' % gid) as scope:
            with tf.device('/gpu:%d' % gid):
                vae_loss = U.mean(model.vaeloss)
                latent_z1_tp = model.latent_z1
                latent_z2_tp = model.latent_z2
                losses = [U.mean(model.vaeloss), U.mean(model.siam_loss), U.mean(model.kl_loss1), U.mean(model.kl_loss2), U.mean(model.reconst_error1), U.mean(model.reconst_error2), ]
                siam_max = U.mean(model.max_siam_loss)
                cls_loss = U.mean(model.cls_loss)
                tower_vae_loss.append(vae_loss)
                tower_latent_z1_tp.append(latent_z1_tp)
                tower_latent_z2_tp.append(latent_z2_tp)
                tower_losses.append(losses)
                tower_siam_max.append(siam_max)
                tower_reconst1.append(model.reconst1)
                tower_reconst2.append(model.reconst2)
                tower_cls_loss.append(cls_loss)
                # Per-tower classifier loss summary (under the gpu scope).
                tf.summary.scalar('Cls Loss', cls_loss)
    # Aggregate across towers: mean losses, concatenated batch outputs.
    vae_loss = U.mean(tower_vae_loss)
    siam_max = U.mean(tower_siam_max)
    latent_z1_tp = tf.concat(tower_latent_z1_tp, 0)
    latent_z2_tp = tf.concat(tower_latent_z2_tp, 0)
    model_reconst1 = tf.concat(tower_reconst1, 0)
    model_reconst2 = tf.concat(tower_reconst2, 0)
    cls_loss = U.mean(tower_cls_loss)
    # Transpose tower_losses (towers x components -> components x towers)
    # and average each component.
    losses = [[] for _ in range(len(losses))]
    for tl in tower_losses:
        for i, l in enumerate(tl):
            losses[i].append(l)
    losses = [U.mean(l) for l in losses]
    siam_normal = losses[1] / entangled_feat
    tf.summary.scalar('total/cls_loss', cls_loss)
    compute_losses = U.function([img1, img2], vae_loss)
    # Variables are shared across towers, so any tower's list works here.
    # The scope-name segment at index 2 separates the frozen VAE weights
    # from the trainable classifier head.
    all_var_list = model.get_trainable_variables()
    vae_var_list = [v for v in all_var_list if v.name.split("/")[2].startswith("vae")]
    cls_var_list = [v for v in all_var_list if v.name.split("/")[2].startswith("cls")]
    warn("{}".format(all_var_list))
    warn("=======================")
    warn("{}".format(vae_var_list))
    warn("=======================")
    warn("{}".format(cls_var_list))
    # with tf.device('/cpu:0'):
    # optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon = 0.01/batch_size)
    # optimize_expr1 = optimizer.minimize(vae_loss, var_list=vae_var_list)
    # Only the classifier head is optimized; the VAE stays frozen.
    feat_cls_optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
    optimize_expr2 = feat_cls_optimizer.minimize(cls_loss, var_list=cls_var_list)
    merged = tf.summary.merge_all()
    # train = U.function([img1, img2],
    #     [losses[0], losses[1], losses[2], losses[3], losses[4], losses[5], latent_z1_tp, latent_z2_tp, merged], updates = [optimize_expr1])
    classifier_train = U.function([img1, img2, feat_cls], [cls_loss, latent_z1_tp, latent_z2_tp, merged], updates = [optimize_expr2])
    get_reconst_img = U.function([img1, img2], [model_reconst1, model_reconst2, latent_z1_tp, latent_z2_tp])
    get_latent_var = U.function([img1, img2], [latent_z1_tp, latent_z2_tp])
    # Output directories; classifier logs go to a separate "cls_" prefix.
    cur_dir = get_cur_dir()
    chk_save_dir = os.path.join(cur_dir, chkfile_name)
    log_save_dir = os.path.join(cur_dir, logfile_name)
    cls_logfile_name = 'cls_{}'.format(logfile_name)
    cls_log_save_dir = os.path.join(cur_dir, cls_logfile_name)
    validate_img_saver_dir = os.path.join(cur_dir, validatefile_name)
    if dataset == 'chairs' or dataset == 'celeba':
        test_img_saver_dir = os.path.join(cur_dir, "test_images")
        testing_img_dir = os.path.join(cur_dir, "dataset/{}/test_img".format(dataset))
    cls_train_writer = U.summary_writer(dir = cls_log_save_dir)
    U.initialize()
    # Restore the pretrained VAE (and any prior classifier state).
    saver, chk_file_epoch_num = U.load_checkpoints(load_requested = True, checkpoint_dir = chk_save_dir)
    if dataset == 'chairs' or dataset == 'celeba':
        validate_img_saver = Img_Saver(Img_dir = validate_img_saver_dir)
    elif dataset == 'dsprites':
        validate_img_saver = BW_Img_Saver(Img_dir = validate_img_saver_dir)  # Black and White, temporary usage
    else:
        warn("Unknown dataset Error")
        # break
    warn("dataset: {}".format(dataset))
    if dataset == 'chairs' or dataset == 'celeba':
        training_images_list = read_dataset(img_dir)
        n_total_train_data = len(training_images_list)
        testing_images_list = read_dataset(testing_img_dir)
        n_total_testing_data = len(testing_images_list)
    elif dataset == 'dsprites':
        # dsprites ships as a single .npz archive served by DataManager.
        cur_dir = osp.join(cur_dir, 'dataset')
        cur_dir = osp.join(cur_dir, 'dsprites')
        img_dir = osp.join(cur_dir, 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        manager = DataManager(img_dir, batch_size)
    else:
        warn("Unknown dataset Error")
        # break
    meta_saved = False
    cls_train_iter = 10000
    for cls_train_i in range(cls_train_iter):
        # warn("Train:{}".format(cls_train_i))
        if dataset == 'dsprites':
            # At every epoch, train classifier and check result
            # (1) Load images: cls_L pairs per classifier sample, one random
            # fixed latent feature per (gpu, batch) slot.
            num_img_pair = cls_L * num_gpus * cls_batch_per_gpu
            feat = np.random.randint(len(manager.latents_sizes)-1, size = num_gpus * cls_batch_per_gpu)
            [images1, images2] = manager.get_image_fixed_feat_batch(feat, num_img_pair)
            # (2) Feed through the placeholders and take one Adagrad step.
            [classification_loss, _, _, summary] = classifier_train(images1, images2, feat)
            if cls_train_i % 100 == 0:
                warn("cls loss {}: {}".format(cls_train_i, classification_loss))
                cls_train_writer.add_summary(summary, cls_train_i)
def mgpu_train_net(models, num_gpus, mode, img_dir, dataset, chkfile_name, logfile_name, validatefile_name, entangled_feat, max_epoch = 300, check_every_n = 500, loss_check_n = 10, save_model_freq = 5, batch_size = 512, lr = 0.001):
    """Build the multi-GPU training graph for the siamese VAE and run the
    epoch-based training loop.

    One pre-built model per GPU is expected in `models` (one "tower" each).
    Per-tower losses are averaged across towers, an Adam step is attached to
    the VAE variables, and — when mode == 'train' — batches are drawn from the
    selected dataset, trained on, periodically visualized via the validate
    image saver, and checkpointed once per epoch.

    Parameters
    ----------
    models : list of per-GPU model objects exposing vaeloss/siam_loss/kl_loss*/
        reconst_error*/latent_z*/cls_loss/reconst* and get_trainable_variables().
    num_gpus : number of GPUs (NOTE(review): unused here — tower count is taken
        from len(models) instead; confirm the two always agree).
    mode : only 'train' triggers the training loop; any other value builds the
        graph and returns after setup.
    img_dir : image directory for 'chairs'/'celeba'; overwritten with the
        .npz path for 'dsprites'.
    dataset : one of 'chairs', 'celeba', 'dsprites'.
    chkfile_name, logfile_name, validatefile_name : directory names (relative
        to get_cur_dir()) for checkpoints, summaries, and validation images.
    entangled_feat : divisor used to normalize the siamese loss for the
        'total/Siam Normal' summary.
    """
    # Cached placeholders shared with the model-construction code elsewhere
    # in the project.
    img1 = U.get_placeholder_cached(name="img1")
    img2 = U.get_placeholder_cached(name="img2")
    feat_cls = U.get_placeholder_cached(name="feat_cls")
    # batch size must be multiples of ntowers (# of GPUs)
    ntowers = len(models)
    # NOTE(review): tf.assert_equal returns an op that is never run/added as a
    # control dependency here, so these shape checks are not actually enforced.
    tf.assert_equal(tf.shape(img1)[0], tf.shape(img2)[0])
    tf.assert_equal(tf.floormod(tf.shape(img1)[0], ntowers), 0)
    # NOTE(review): the splits are computed but not used below — presumably the
    # towers in `models` were already wired to per-GPU input slices elsewhere.
    img1splits = tf.split(img1, ntowers, 0)
    img2splits = tf.split(img2, ntowers, 0)
    # Per-tower collections, later averaged/concatenated across GPUs.
    tower_vae_loss = []
    tower_latent_z1_tp = []
    tower_latent_z2_tp = []
    tower_losses = []
    tower_siam_max = []
    tower_reconst1 = []
    tower_reconst2 = []
    tower_cls_loss = []
    for gid, model in enumerate(models):
        with tf.name_scope('gpu%d' % gid) as scope:
            with tf.device('/gpu:%d' % gid):
                vae_loss = U.mean(model.vaeloss)
                latent_z1_tp = model.latent_z1
                latent_z2_tp = model.latent_z2
                # Loss order: [total, siam, kl1, kl2, reconst1, reconst2] —
                # relied on by index throughout the rest of this function.
                losses = [U.mean(model.vaeloss), U.mean(model.siam_loss), U.mean(model.kl_loss1), U.mean(model.kl_loss2), U.mean(model.reconst_error1), U.mean(model.reconst_error2), ]
                siam_max = U.mean(model.max_siam_loss)
                cls_loss = U.mean(model.cls_loss)
                tower_vae_loss.append(vae_loss)
                tower_latent_z1_tp.append(latent_z1_tp)
                tower_latent_z2_tp.append(latent_z2_tp)
                tower_losses.append(losses)
                tower_siam_max.append(siam_max)
                tower_reconst1.append(model.reconst1)
                tower_reconst2.append(model.reconst2)
                tower_cls_loss.append(cls_loss)
                # Per-tower summaries (name-scoped under gpu%d).
                tf.summary.scalar('Total Loss', losses[0])
                tf.summary.scalar('Siam Loss', losses[1])
                tf.summary.scalar('kl1_loss', losses[2])
                tf.summary.scalar('kl2_loss', losses[3])
                tf.summary.scalar('reconst_err1', losses[4])
                tf.summary.scalar('reconst_err2', losses[5])
                tf.summary.scalar('Siam Max', siam_max)
    # Aggregate across towers: scalars are averaged, tensors concatenated
    # along the batch axis to rebuild full-batch outputs.
    vae_loss = U.mean(tower_vae_loss)
    siam_max = U.mean(tower_siam_max)
    latent_z1_tp = tf.concat(tower_latent_z1_tp, 0)
    latent_z2_tp = tf.concat(tower_latent_z2_tp, 0)
    model_reconst1 = tf.concat(tower_reconst1, 0)
    model_reconst2 = tf.concat(tower_reconst2, 0)
    cls_loss = U.mean(tower_cls_loss)
    # Transpose tower_losses (towers x 6) into 6 lists and average each,
    # yielding the cross-GPU mean of every loss component.
    losses = [[] for _ in range(len(losses))]
    for tl in tower_losses:
        for i, l in enumerate(tl):
            losses[i].append(l)
    losses = [U.mean(l) for l in losses]
    siam_normal = losses[1] / entangled_feat
    tf.summary.scalar('total/Total Loss', losses[0])
    tf.summary.scalar('total/Siam Loss', losses[1])
    tf.summary.scalar('total/kl1_loss', losses[2])
    tf.summary.scalar('total/kl2_loss', losses[3])
    tf.summary.scalar('total/reconst_err1', losses[4])
    tf.summary.scalar('total/reconst_err2', losses[5])
    tf.summary.scalar('total/Siam Normal', siam_normal)
    tf.summary.scalar('total/Siam Max', siam_max)
    compute_losses = U.function([img1, img2], vae_loss)
    # Variables are taken from the *last* tower's model; presumably variables
    # are shared/reused across towers — TODO confirm.
    all_var_list = model.get_trainable_variables()
    # Scope layout assumed: <outer>/<tower>/<vae...|cls...>/... (index 2).
    vae_var_list = [v for v in all_var_list if v.name.split("/")[2].startswith("vae")]
    cls_var_list = [v for v in all_var_list if v.name.split("/")[2].startswith("cls")]
    # Debug dump of the variable partition.
    warn("{}".format(all_var_list))
    warn("==========================")
    warn("{}".format(vae_var_list))
    # warn("==========================")
    # warn("{}".format(cls_var_list))
    # with tf.device('/cpu:0'):
    optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon = 0.01/batch_size)
    optimize_expr1 = optimizer.minimize(vae_loss, var_list=vae_var_list)
    # Separate optimizer for the feature classifier head; its train op is
    # built here but only the VAE op is attached to `train` below.
    feat_cls_optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
    optimize_expr2 = feat_cls_optimizer.minimize(cls_loss, var_list=cls_var_list)
    merged = tf.summary.merge_all()
    # Callables: one train step (VAE update + losses + summaries), plus
    # inference-only reconstruction / latent fetches.
    train = U.function([img1, img2], [losses[0], losses[1], losses[2], losses[3], losses[4], losses[5], latent_z1_tp, latent_z2_tp, merged], updates = [optimize_expr1])
    get_reconst_img = U.function([img1, img2], [model_reconst1, model_reconst2, latent_z1_tp, latent_z2_tp])
    get_latent_var = U.function([img1, img2], [latent_z1_tp, latent_z2_tp])
    # Output/input directory layout, rooted at the project directory.
    cur_dir = get_cur_dir()
    chk_save_dir = os.path.join(cur_dir, chkfile_name)
    log_save_dir = os.path.join(cur_dir, logfile_name)
    validate_img_saver_dir = os.path.join(cur_dir, validatefile_name)
    if dataset == 'chairs' or dataset == 'celeba':
        test_img_saver_dir = os.path.join(cur_dir, "test_images")
        testing_img_dir = os.path.join(cur_dir, "dataset/{}/test_img".format(dataset))
    train_writer = U.summary_writer(dir = log_save_dir)
    U.initialize()
    # Resume from the latest checkpoint if one exists; chk_file_epoch_num is
    # the epoch to continue from.
    saver, chk_file_epoch_num = U.load_checkpoints(load_requested = True, checkpoint_dir = chk_save_dir)
    if dataset == 'chairs' or dataset == 'celeba':
        validate_img_saver = Img_Saver(Img_dir = validate_img_saver_dir)
    elif dataset == 'dsprites':
        validate_img_saver = BW_Img_Saver(Img_dir = validate_img_saver_dir) # Black and White, temporary usage
    else:
        warn("Unknown dataset Error")
        # break
    warn("dataset: {}".format(dataset))
    # Dataset loading: file lists for image-folder datasets, DataManager over
    # the packed .npz for dsprites.
    if dataset == 'chairs' or dataset == 'celeba':
        training_images_list = read_dataset(img_dir)
        n_total_train_data = len(training_images_list)
        testing_images_list = read_dataset(testing_img_dir)
        n_total_testing_data = len(testing_images_list)
    elif dataset == 'dsprites':
        cur_dir = osp.join(cur_dir, 'dataset')
        cur_dir = osp.join(cur_dir, 'dsprites')
        img_dir = osp.join(cur_dir, 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        manager = DataManager(img_dir, batch_size)
    else:
        warn("Unknown dataset Error")
        # break
    meta_saved = False
    if mode == 'train':
        for epoch_idx in range(chk_file_epoch_num+1, max_epoch):
            t_epoch_start = time.time()
            # NOTE(review): num_batch comes from `manager`, which only exists
            # for the dsprites branch — the chairs/celeba path would raise a
            # NameError here; verify intended usage.
            num_batch = manager.get_len()
            for batch_idx in range(num_batch):
                # Draw a paired batch (two image sets) from the dataset.
                if dataset == 'chairs' or dataset == 'celeba':
                    idx = random.sample(range(n_total_train_data), 2*batch_size)
                    batch_files = [training_images_list[i] for i in idx]
                    [images1, images2] = load_image(dir_name = img_dir, img_names = batch_files)
                elif dataset == 'dsprites':
                    [images1, images2] = manager.get_next()
                # NOTE(review): this rebinds the local names img1/img2 away
                # from the placeholders above; harmless here (both are only
                # used as feed values afterwards) but confusing.
                img1, img2 = images1, images2
                [l1, l2, _, _] = get_reconst_img(img1, img2)
                [loss0, loss1, loss2, loss3, loss4, loss5, latent1, latent2, summary] = train(img1, img2)
                # Console progress every 50 batches.
                if batch_idx % 50 == 1:
                    header("******* epoch: {}/{} batch: {}/{} *******".format(epoch_idx, max_epoch, batch_idx, num_batch))
                    warn("Total Loss: {}".format(loss0))
                    warn("Siam loss: {}".format(loss1))
                    warn("kl1_loss: {}".format(loss2))
                    warn("kl2_loss: {}".format(loss3))
                    warn("reconst_err1: {}".format(loss4))
                    warn("reconst_err2: {}".format(loss5))
                # Periodically save original/reconstruction pairs for visual
                # inspection (5 image pairs).
                if batch_idx % check_every_n == 1:
                    if dataset == 'chairs' or dataset == 'celeba':
                        idx = random.sample(range(len(training_images_list)), 2*5)
                        validate_batch_files = [training_images_list[i] for i in idx]
                        [images1, images2] = load_image(dir_name = img_dir, img_names = validate_batch_files)
                    elif dataset == 'dsprites':
                        [images1, images2] = manager.get_next()
                    [reconst1, reconst2, _, _] = get_reconst_img(images1, images2)
                    if dataset == 'chairs':
                        # Grayscale PNGs named after the source file.
                        for img_idx in range(len(images1)):
                            sub_dir = "iter_{}_{}".format(epoch_idx, batch_idx)
                            save_img = np.squeeze(images1[img_idx])
                            save_img = Image.fromarray(save_img)
                            img_file_name = "{}_ori.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                            save_img = np.squeeze(reconst1[img_idx])
                            save_img = Image.fromarray(save_img)
                            img_file_name = "{}_rec.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                    elif dataset == 'celeba':
                        # Same as chairs but saved as RGB.
                        for img_idx in range(len(images1)):
                            sub_dir = "iter_{}_{}".format(epoch_idx, batch_idx)
                            save_img = np.squeeze(images1[img_idx])
                            save_img = Image.fromarray(save_img, 'RGB')
                            img_file_name = "{}_ori.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                            save_img = np.squeeze(reconst1[img_idx])
                            save_img = Image.fromarray(save_img, 'RGB')
                            img_file_name = "{}_rec.png".format(validate_batch_files[img_idx].split('.')[0])
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                    elif dataset == 'dsprites':
                        # Float arrays handed directly to the B/W saver.
                        for img_idx in range(len(images1)):
                            sub_dir = "iter_{}_{}".format(epoch_idx, batch_idx)
                            # save_img = images1[img_idx].reshape(64, 64)
                            save_img = np.squeeze(images1[img_idx])
                            save_img = save_img.astype(np.float32)
                            img_file_name = "{}_ori.jpg".format(img_idx)
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                            # save_img = reconst1[img_idx].reshape(64, 64)
                            save_img = np.squeeze(reconst1[img_idx])
                            save_img = save_img.astype(np.float32)
                            img_file_name = "{}_rec.jpg".format(img_idx)
                            validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                # NOTE(review): summaries are logged against batch_idx, which
                # resets every epoch — TensorBoard curves will overwrite each
                # other across epochs; epoch_idx*num_batch+batch_idx was likely
                # intended.
                if batch_idx % loss_check_n == 1:
                    train_writer.add_summary(summary, batch_idx)
            t_epoch_end = time.time()
            t_epoch_run = t_epoch_end - t_epoch_start
            if dataset == 'dsprites':
                # Throughput report (images/sec) for the epoch.
                t_check = manager.sample_size / t_epoch_run
                warn("==========================================")
                warn("Run {} th epoch in {} sec: {} images / sec".format(epoch_idx+1, t_epoch_run, t_check))
                warn("==========================================")
            # Checkpoint every epoch; the meta graph is written only once.
            if meta_saved == True:
                saver.save(U.get_session(), chk_save_dir + '/' + 'checkpoint', global_step = epoch_idx, write_meta_graph = False)
            else:
                print "Save meta graph"
                saver.save(U.get_session(), chk_save_dir + '/' + 'checkpoint', global_step = epoch_idx, write_meta_graph = True)
                meta_saved = True
def train_net(model, manager, chkfile_name, logfile_name, validatefile_name, entangled_feat, max_iter = 6000001, check_every_n = 1000, loss_check_n = 10, save_model_freq = 5000, batch_size = 32):
    """Single-GPU training loop for the siamese VAE on a DataManager-backed
    dataset (64x64 black-and-white images, e.g. dsprites).

    Builds loss summaries and an Adagrad train op over all trainable
    variables, restores the latest checkpoint if present, then iterates:
    sample a paired batch, take one train step, log losses, periodically save
    original/reconstruction pairs and write summaries, and checkpoint every
    `save_model_freq` iterations.

    NOTE(review): a second, different `train_net` is defined later in this
    module and shadows this one at import time — confirm which definition
    callers actually intend to use.
    """
    img1 = U.get_placeholder_cached(name="img1")
    img2 = U.get_placeholder_cached(name="img2")
    # Testing
    # img_test = U.get_placeholder_cached(name="img_test")
    # reconst_tp = U.get_placeholder_cached(name="reconst_tp")
    vae_loss = U.mean(model.vaeloss)
    latent_z1_tp = model.latent_z1
    latent_z2_tp = model.latent_z2
    # Loss order: [total, siam, kl1, kl2, reconst1, reconst2] — indexed
    # positionally below and in the summaries.
    losses = [U.mean(model.vaeloss), U.mean(model.siam_loss), U.mean(model.kl_loss1), U.mean(model.kl_loss2), U.mean(model.reconst_error1), U.mean(model.reconst_error2), ]
    # Siamese loss normalized by the number of entangled features.
    siam_normal = losses[1]/entangled_feat
    siam_max = U.mean(model.max_siam_loss)
    tf.summary.scalar('Total Loss', losses[0])
    tf.summary.scalar('Siam Loss', losses[1])
    tf.summary.scalar('kl1_loss', losses[2])
    tf.summary.scalar('kl2_loss', losses[3])
    tf.summary.scalar('reconst_err1', losses[4])
    tf.summary.scalar('reconst_err2', losses[5])
    tf.summary.scalar('Siam Normal', siam_normal)
    tf.summary.scalar('Siam Max', siam_max)
    # decoded_img = [model.reconst1, model.reconst2]
    compute_losses = U.function([img1, img2], vae_loss)
    lr = 0.005
    optimizer=tf.train.AdagradOptimizer(learning_rate=lr)
    all_var_list = model.get_trainable_variables()
    # print all_var_list
    # All variables are trained together (per-branch filtering is disabled).
    img1_var_list = all_var_list #[v for v in all_var_list if v.name.split("/")[1].startswith("proj1") or v.name.split("/")[1].startswith("unproj1")]
    optimize_expr1 = optimizer.minimize(vae_loss, var_list=img1_var_list)
    merged = tf.summary.merge_all()
    # One call = one optimizer step; returns the six losses, both latents and
    # the merged summary protobuf.
    train = U.function([img1, img2], [losses[0], losses[1], losses[2], losses[3], losses[4], losses[5], latent_z1_tp, latent_z2_tp, merged], updates = [optimize_expr1])
    # Uses the *mean* reconstructions (no sampling) for visualization.
    get_reconst_img = U.function([img1, img2], [model.reconst1_mean, model.reconst2_mean, latent_z1_tp, latent_z2_tp])
    get_latent_var = U.function([img1, img2], [latent_z1_tp, latent_z2_tp])
    # testing
    # test = U.function([img_test], model.latent_z_test)
    # test_reconst = U.function([reconst_tp], [model.reconst_test])
    cur_dir = get_cur_dir()
    chk_save_dir = os.path.join(cur_dir, chkfile_name)
    log_save_dir = os.path.join(cur_dir, logfile_name)
    validate_img_saver_dir = os.path.join(cur_dir, validatefile_name)
    # test_img_saver_dir = os.path.join(cur_dir, "test_images")
    # testing_img_dir = os.path.join(cur_dir, "dataset/test_img")
    train_writer = U.summary_writer(dir = log_save_dir)
    U.initialize()
    # Resume from the latest checkpoint; chk_file_num is the iteration to
    # continue from.
    saver, chk_file_num = U.load_checkpoints(load_requested = True, checkpoint_dir = chk_save_dir)
    validate_img_saver = BW_Img_Saver(validate_img_saver_dir)
    # testing
    # test_img_saver = Img_Saver(test_img_saver_dir)
    meta_saved = False
    # NOTE(review): these logs are appended nowhere below — dead locals.
    iter_log = []
    loss1_log = []
    loss2_log = []
    loss3_log = []
    training_images_list = manager.imgs # read_dataset(img_dir)
    n_total_train_data = len(training_images_list)
    # testing_images_list = read_dataset(testing_img_dir)
    # n_total_testing_data = len(testing_images_list)
    training = True
    testing = False
    if training == True:
        for num_iter in range(chk_file_num+1, max_iter):
            header("******* {}th iter: *******".format(num_iter))
            # 2*batch_size indices: the manager splits them into two image sets.
            idx = random.sample(range(n_total_train_data), 2*batch_size)
            batch_files = idx
            # print batch_files
            [images1, images2] = manager.get_images(indices = idx)
            # NOTE(review): rebinds img1/img2 away from the placeholders;
            # they are only used as feed values from here on.
            img1, img2 = images1, images2
            [l1, l2, _, _] = get_reconst_img(img1, img2)
            [loss0, loss1, loss2, loss3, loss4, loss5, latent1, latent2, summary] = train(img1, img2)
            warn("Total Loss: {}".format(loss0))
            warn("Siam loss: {}".format(loss1))
            warn("kl1_loss: {}".format(loss2))
            warn("kl2_loss: {}".format(loss3))
            warn("reconst_err1: {}".format(loss4))
            warn("reconst_err2: {}".format(loss5))
            # warn("num_iter: {} check: {}".format(num_iter, check_every_n))
            # warn("Total Loss: {}".format(loss6))
            # Periodic visual validation: save 5 original/reconstruction pairs.
            if num_iter % check_every_n == 1:
                header("******* {}th iter: *******".format(num_iter))
                idx = random.sample(range(len(training_images_list)), 2*5)
                [images1, images2] = manager.get_images(indices = idx)
                [reconst1, reconst2, _, _] = get_reconst_img(images1, images2)
                # for i in range(len(latent1[0])):
                #     print "{} th: {:.2f}".format(i, np.mean(np.abs(latent1[:, i] - latent2[:, i])))
                for img_idx in range(len(images1)):
                    sub_dir = "iter_{}".format(num_iter)
                    # Images are flattened 64x64 — reshape before saving.
                    save_img = images1[img_idx].reshape(64, 64)
                    save_img = save_img.astype(np.float32)
                    img_file_name = "{}_ori.jpg".format(img_idx)
                    validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                    save_img = reconst1[img_idx].reshape(64, 64)
                    save_img = save_img.astype(np.float32)
                    img_file_name = "{}_rec.jpg".format(img_idx)
                    validate_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
            if num_iter % loss_check_n == 1:
                train_writer.add_summary(summary, num_iter)
            # Periodic checkpoint; the meta graph is written only on the first
            # save of this run.
            if num_iter > 11 and num_iter % save_model_freq == 1:
                if meta_saved == True:
                    saver.save(U.get_session(), chk_save_dir + '/' + 'checkpoint', global_step = num_iter, write_meta_graph = False)
                else:
                    print "Save meta graph"
                    saver.save(U.get_session(), chk_save_dir + '/' + 'checkpoint', global_step = num_iter, write_meta_graph = True)
                    meta_saved = True
def train_net(model, img_dir, max_iter = 100000, check_every_n = 20, save_model_freq = 1000, batch_size = 128):
    """Alternating two-branch autoencoder training over an image-folder
    dataset (2D/3D image pairs, judging by the saved file names).

    Two variable groups (proj1/unproj1 vs proj2/unproj2) are optimized on
    alternating iterations against a shared match error plus their own
    reconstruction error. Periodically saves original and reconstructed
    images and checkpoints the session.

    NOTE(review): this definition shadows the earlier `train_net` in this
    module (same name, incompatible signature) — at import time only this one
    survives; confirm that is intended before relying on either.
    """
    img1 = U.get_placeholder_cached(name="img1")
    img2 = U.get_placeholder_cached(name="img2")
    # Cross-branch latent match error plus per-branch reconstruction errors.
    mean_loss1 = U.mean(model.match_error)
    mean_loss2 = U.mean(model.reconst_error1)
    mean_loss3 = U.mean(model.reconst_error2)
    decoded_img = [model.reconst1, model.reconst2]
    # NOTE(review): weight_loss is never applied to the losses below — dead.
    weight_loss = [1, 1, 1]
    compute_losses = U.function([img1, img2], [mean_loss1, mean_loss2, mean_loss3])
    lr = 0.00001
    optimizer=tf.train.AdamOptimizer(learning_rate=lr, epsilon = 0.01/batch_size)
    all_var_list = model.get_trainable_variables()
    # Partition variables by scope: branch 1 = proj1/unproj1, branch 2 =
    # proj2/unproj2 (scope layout assumed: <outer>/<branch>/..., index 1).
    img1_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("proj1") or v.name.split("/")[1].startswith("unproj1")]
    img2_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("proj2") or v.name.split("/")[1].startswith("unproj2")]
    # Each branch minimizes the shared match error + its own reconstruction
    # error, updating only its own variables.
    img1_loss = mean_loss1 + mean_loss2
    img2_loss = mean_loss1 + mean_loss3
    optimize_expr1 = optimizer.minimize(img1_loss, var_list=img1_var_list)
    optimize_expr2 = optimizer.minimize(img2_loss, var_list=img2_var_list)
    img1_train = U.function([img1, img2], [mean_loss1, mean_loss2, mean_loss3], updates = [optimize_expr1])
    img2_train = U.function([img1, img2], [mean_loss1, mean_loss2, mean_loss3], updates = [optimize_expr2])
    get_reconst_img = U.function([img1, img2], decoded_img)
    U.initialize()
    name = "test"
    cur_dir = get_cur_dir()
    # Hard-coded output layout (unlike the other trainers, which take
    # directory names as parameters).
    chk_save_dir = os.path.join(cur_dir, "chkfiles")
    log_save_dir = os.path.join(cur_dir, "log")
    test_img_saver_dir = os.path.join(cur_dir, "test_images")
    # Resume from the latest checkpoint; chk_file_num is the iteration to
    # continue from.
    saver, chk_file_num = U.load_checkpoints(load_requested = True, checkpoint_dir = chk_save_dir)
    test_img_saver = Img_Saver(test_img_saver_dir)
    meta_saved = False
    # NOTE(review): these logs are appended nowhere below — dead locals.
    iter_log = []
    loss1_log = []
    loss2_log = []
    loss3_log = []
    training_images_list = read_dataset(img_dir)
    for num_iter in range(chk_file_num+1, max_iter):
        # Alternate: even iterations train branch 1, odd iterations branch 2.
        header("******* {}th iter: Img {} side *******".format(num_iter, num_iter%2 + 1))
        idx = random.sample(range(len(training_images_list)), batch_size)
        batch_files = [training_images_list[i] for i in idx]
        [images1, images2] = load_image(dir_name = img_dir, img_names = batch_files)
        # NOTE(review): rebinds img1/img2 away from the placeholders; they are
        # only used as feed values from here on.
        img1, img2 = images1, images2
        # args = images1, images2
        if num_iter%2 == 0:
            [loss1, loss2, loss3] = img1_train(img1, img2)
        elif num_iter%2 == 1:
            [loss1, loss2, loss3] = img2_train(img1, img2)
        warn("match_error: {}".format(loss1))
        warn("reconst_err1: {}".format(loss2))
        warn("reconst_err2: {}".format(loss3))
        warn("num_iter: {} check: {}".format(num_iter, check_every_n))
        # Periodic visual check: save 10 original + reconstructed pairs for
        # both branches.
        if num_iter % check_every_n == 1:
            idx = random.sample(range(len(training_images_list)), 10)
            test_batch_files = [training_images_list[i] for i in idx]
            [images1, images2] = load_image(dir_name = img_dir, img_names = test_batch_files)
            [reconst1, reconst2] = get_reconst_img(images1, images2)
            for img_idx in range(len(images1)):
                sub_dir = "iter_{}".format(num_iter)
                save_img = np.squeeze(images1[img_idx])
                save_img = Image.fromarray(save_img)
                img_file_name = "{}_ori_2d.jpg".format(test_batch_files[img_idx])
                test_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                save_img = np.squeeze(images2[img_idx])
                save_img = Image.fromarray(save_img)
                img_file_name = "{}_ori_3d.jpg".format(test_batch_files[img_idx])
                test_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                save_img = np.squeeze(reconst1[img_idx])
                save_img = Image.fromarray(save_img)
                img_file_name = "{}_rec_2d.jpg".format(test_batch_files[img_idx])
                test_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
                save_img = np.squeeze(reconst2[img_idx])
                save_img = Image.fromarray(save_img)
                img_file_name = "{}_rec_3d.jpg".format(test_batch_files[img_idx])
                test_img_saver.save(save_img, img_file_name, sub_dir = sub_dir)
        # Periodic checkpoint; the meta graph is written only on the first
        # save of this run.
        if num_iter > 11 and num_iter % save_model_freq == 1:
            if meta_saved == True:
                saver.save(U.get_session(), chk_save_dir + '/' + 'checkpoint', global_step = num_iter, write_meta_graph = False)
            else:
                print "Save meta graph"
                saver.save(U.get_session(), chk_save_dir + '/' + 'checkpoint', global_step = num_iter, write_meta_graph = True)
                meta_saved = True