Example #1
    def __init__(self, flags):
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
        tf_config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
        self.sess = tf.Session(config=tf_config)
        self.flags = flags
        self.iter_time = 0
        self.num_examples_IS = 1000
        self._make_folders()
        self._init_logger()

        self.dataset = Dataset(self.sess, self.flags, self.flags.dataset, log_path=self.log_out_dir)
        self.model = WGAN_GP(self.sess, self.flags, self.dataset, log_path=self.log_out_dir)

        self.saver = tf.train.Saver()
        self.sess.run([tf.global_variables_initializer()])
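Examples #1 and #2 differ mainly in how they budget GPU memory: per_process_gpu_memory_fraction reserves a fixed share of the device up front, while allow_growth lets the allocation expand on demand. A minimal side-by-side sketch of the two TF1 configurations (the 0.5 fraction is illustrative, not a value from the examples):

import tensorflow as tf  # TF 1.x API, as in the examples

# Reserve a fixed share of GPU memory up front (Example #1 reserves all of it):
fixed = tf.ConfigProto(
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5))

# Start small and grow the allocation on demand (Example #2):
lazy = tf.ConfigProto()
lazy.gpu_options.allow_growth = True

sess = tf.Session(config=lazy)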
Example #2
    def __init__(self, flags):
        run_config = tf.ConfigProto()
        run_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=run_config)

        self.flags = flags
        self.iter_time = 0
        self._make_folders()
        self._init_logger()

        self.dataset = Dataset(self.flags.dataset,
                               self.flags,
                               image_size=(128, 256, 3))
        self.model = WGAN_GP(self.sess,
                             self.flags,
                             self.dataset.image_size,
                             self.dataset(),
                             log_path=self.log_out_dir)

        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())
Example #3
def main():
	opt = TrainOptions().parse()

	# writer
	writer = SummaryWriter(os.path.join(opt.log_dir, 'runs'))

	# dataset
	transform = transforms.Compose([
		transforms.Resize((opt.input_size, opt.input_size)),
		transforms.RandomHorizontalFlip(),
		transforms.ToTensor(),
		transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
	])
	if opt.dataset == 'cifar10':
		dataset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=transform)
	elif opt.dataset == 'cifar100':
		dataset = torchvision.datasets.CIFAR100('data', train=True, download=True, transform=transform)
	elif opt.dataset == 'stl10':
		dataset = torchvision.datasets.STL10('data', split='train', download=True, transform=transform)
	else:
		raise ValueError(f'unsupported dataset: {opt.dataset}')
	loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)

	# model
	wgan_gp = WGAN_GP(opt)
	wgan_gp.train(loader, opt, writer)
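The Normalize(mean=0.5, std=0.5) transform in Example #3 maps pixel values from [0, 1] to [-1, 1], the range a tanh generator output is usually trained against; Example #7 below inverts it with (x + 1) / 2 before saving images. A quick check of the mapping:

import torch
from torchvision import transforms

normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
x = torch.rand(3, 8, 8)             # pixels in [0, 1], as ToTensor produces
y = normalize(x)                    # (x - 0.5) / 0.5 = 2x - 1, so y lies in [-1, 1]
print(y.min() >= -1, y.max() <= 1)  # tensor(True) tensor(True)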
Example #4
def main():
    # parse arguments
    args = parse_args()
    if args is None:
        exit()
    print(args)

    gan = WGAN_GP(args)

    # launch the graph in a session
    gan.train()
    print(" [*] Training finished!")

    # visualize learned generator
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")
Example #5
import sys

from wgan_gp import WGAN_GP

sys.path.append("..")
import utils

dataset, _, timesteps = utils.load_splitted_dataset()
# dataset, _, timesteps = utils.load_resized_mnist()

use_mbd = False
use_packing = False
packing_degree = 2

run_dir, img_dir, model_dir, generated_datesets_dir = utils.generate_run_dir()

config_2 = {
    'timesteps': timesteps,
    'use_mbd': use_mbd,
    'use_packing': use_packing,
    'packing_degree': packing_degree,
    'run_dir': run_dir,
    'img_dir': img_dir,
    'model_dir': model_dir,
    'generated_datesets_dir': generated_datesets_dir,
}

config = utils.merge_config_and_save(config_2)

wgan_gp = WGAN_GP(config)
losses = wgan_gp.train(dataset)
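utils.merge_config_and_save is not shown in Example #5; presumably it overlays the per-run settings onto shared defaults and persists the merged result under run_dir. A plausible sketch (the defaults argument and the config.json filename are assumptions):

import json
import os

def merge_config_and_save(overrides, defaults=None):
    # Later entries win: per-run overrides replace shared defaults.
    config = dict(defaults or {})
    config.update(overrides)
    with open(os.path.join(config['run_dir'], 'config.json'), 'w') as f:
        json.dump(config, f, indent=2, default=str)
    return config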
Example #6
def main():
    parser = argparse.ArgumentParser(description="WGAN-GP")

    # Saving parameters
    parser.add_argument("--name",
                        "-n",
                        "-id",
                        type=str,
                        default=str(int(time.time())),
                        help="Name/ID of the current training model")
    parser.add_argument("--resume_from",
                        "-rf",
                        type=int,
                        default=0,
                        help="Number of epoch to resume from (if existing)")
    parser.add_argument(
        "--checkpoint_interval",
        "-ci",
        type=int,
        default=20,
        help=
        "Number of epoch before saving a checkpoint (0 to disable checkpoints) (default = 20)"
    )

    # Model hyper parameters
    parser.add_argument("--learning_rate_d",
                        "-lrd",
                        type=float,
                        default=2e-4,
                        help="Learning rate of the critic (default = 2e-4)")
    parser.add_argument("--learning_rate_g",
                        "-lrg",
                        type=float,
                        default=2e-4,
                        help="Learning rate of the generator (default = 2e-4)")
    parser.add_argument("--beta_1",
                        "-b1",
                        type=float,
                        default=0.5,
                        help="BETA 1 of the optimizer (default = 0.5)")
    parser.add_argument("--beta_2",
                        "-b2",
                        type=float,
                        default=0.9,
                        help="BETA 2 of the optimizer (default = 0.9)")
    parser.add_argument("--training_ratio",
                        "-tr",
                        type=int,
                        default=5,
                        help="Training ratio of the critic (default = 5)")
    parser.add_argument(
        "--gradient_penalty_weight",
        "-gpd",
        type=int,
        default=10,
        help="Gradient penalty weight applied to the critic (default = 10)")
    parser.add_argument(
        "--z_size",
        type=int,
        default=128,
        help="Size of the noise vector of the generator (default = 128)")

    # General hyper parameters
    parser.add_argument("--epoch",
                        type=int,
                        default=10000,
                        help="Number of epoch to train (default = 10000)")
    parser.add_argument("--batch_size",
                        "-bs",
                        type=int,
                        default=512,
                        help="Size of the dataset mini-batch (default = 512)")
    parser.add_argument(
        "--buffer_size",
        "-bus",
        type=int,
        default=2048,
        help="Size of the buffer of the dataset iterator (default = 2048)")
    parser.add_argument(
        "--prefetch_size",
        "-ps",
        type=int,
        default=10,
        help="Size of prefetching of the dataset iterator (default = 10)")

    # Layers hyper parameters
    parser.add_argument(
        "--bn_momentum",
        "-bm",
        type=float,
        default=0.8,
        help="Momentum of the batch normalization layer (default = 0.8)")
    parser.add_argument("--lr_alpha",
                        "-la",
                        type=float,
                        default=0.2,
                        help="Alpha of the LeakyReLU layer (default = 0.2)")
    parser.add_argument(
        "--kernel_size",
        "-ks",
        type=int,
        default=5,
        help=
        "Size of the kernel of the convolutional layer (best if odd) (default = 5)"
    )
    parser.add_argument(
        "--rn_stddev",
        "-rs",
        type=float,
        default=0.02,
        help=
        "Standard deviation of the initialization of the weights of each layers (default = 0.02)"
    )
    parser.add_argument(
        "--min_weight",
        "-mw",
        type=int,
        default=5,
        help=
        "Minimum size pow(2, mw) of the first layer of convolutional layer (doubles each times) (default = 5)"
    )

    # Dataset parameters
    parser.add_argument("--type",
                        "-t",
                        type=str,
                        default="digits",
                        choices=[
                            "custom", "digits", "fashion", "cifar10",
                            "cifar100", "celebA_128", "LAG48", "LAG128",
                            "cars64"
                        ],
                        help="Type of dataset to use (default = 'digits')")
    args = parser.parse_args()
    print(args)

    from wgan_gp import WGAN_GP
    from toolbox import extract_mnist
    from tensorflow.keras.datasets import mnist, fashion_mnist
    from tensorflow.keras.datasets import cifar10, cifar100

    if args.type == "custom":
        print("Custom type is not yet implemented !")
        return
    elif args.type in ["digits", "fashion"]:
        sample_shape = (7, 7)
        output_shape = (28, 28, 1)
        min_wh = 7
        tensor_to_img = False
        nb_layers = 3
        data_dir = "keras"
        X_train = extract_mnist((mnist, fashion_mnist)[args.type == "fashion"])
    elif args.type in ["cifar10", "cifar100"]:
        sample_shape = (7, 7)
        output_shape = (32, 32, 3)
        min_wh = 4
        tensor_to_img = False
        nb_layers = 4
        data_dir = "keras"
        X_train = (cifar10, cifar100)[args.type == "cifar100"]
        X_train = extract_mnist(X_train,
                                img_shape=output_shape)  # , label = 1)
    elif args.type == "celebA_128":
        sample_shape = (5, 5)
        output_shape = (128, 128, 3)
        min_wh = 4
        data_dir = './datasets/celebA_128'
        X_train = np.array(os.listdir(data_dir))
        tensor_to_img = True
        nb_layers = 6
    elif args.type == "LAG48":
        sample_shape = (5, 5)
        output_shape = (48, 48, 3)
        min_wh = 3
        data_dir = './datasets/_binarynumpy/normalized_LAGdataset_48.npy'
        X_train = np.load(data_dir)
        tensor_to_img = False
        nb_layers = 5
    elif args.type == "LAG128":
        sample_shape = (5, 5)
        output_shape = (128, 128, 3)
        min_wh = 4
        data_dir = './datasets/_binarynumpy/normalized_LAGdataset_128.npy'
        X_train = np.load(data_dir)
        tensor_to_img = False
        nb_layers = 6
    elif args.type == "cars64":
        sample_shape = (5, 5)
        output_shape = (64, 64, 3)
        min_wh = 4
        data_dir = './datasets/_binarynumpy/normalized_cars.npy'
        X_train = np.load(data_dir)
        tensor_to_img = False
        nb_layers = 5

    name = f"wgan-gp_{args.type}_{args.name}"
    weights = [
        pow(2, i) for i in range(args.min_weight, args.min_weight + nb_layers)
    ]

    model = WGAN_GP(name, args.learning_rate_d, args.learning_rate_g,
                    args.beta_1, args.beta_2, args.training_ratio,
                    args.gradient_penalty_weight, args.z_size,
                    args.bn_momentum, args.lr_alpha, args.kernel_size,
                    args.rn_stddev)
    model.feed_data(X_train, data_dir, tensor_to_img, args.batch_size,
                    args.buffer_size, args.prefetch_size)
    model.set_output(sample_shape, output_shape)
    model.create_model(args.min_weight, min_wh, weights, nb_layers)
    model.print_desc(args.resume_from)
    model.train(args.epoch, args.checkpoint_interval, args.resume_from)
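The weights list in Example #6 doubles the channel width per layer, starting from 2 ** min_weight. With the default min_weight=5 and the digits branch (nb_layers=3) it evaluates to:

min_weight, nb_layers = 5, 3
weights = [pow(2, i) for i in range(min_weight, min_weight + nb_layers)]
print(weights)  # [32, 64, 128]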
Example #7
def main():
    # initialise parameters
    batch_size = 64
    n_epochs = 50
    n_critic_steps = 2      # number of critic updates per generator update
    lr = 1e-4
    z_size = 100
    img_size = 32
    img_channel_size = 3
    c_channel_size = 64
    g_channel_size = 64
    penalty_strength = 10
    samples_dir = 'experimentx'
    create_folder(samples_dir)
    create_folder(f"{samples_dir}/generator")
    create_folder(f"{samples_dir}/critic")

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # create GAN
    wgan = WGAN_GP(z_size, penalty_strength, img_size, img_channel_size, c_channel_size, g_channel_size, device).to(device)
    print('Critic summary: ', wgan.critic)
    print('Generator summary: ', wgan.generator)

    # create optimizers
    critic_optimizer = torch.optim.Adam(wgan.critic.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0)
    generator_optimizer = torch.optim.Adam(wgan.generator.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0)

    # get data loaders
    train_loader, valid_loader, test_loader = get_data_loader("data/svhn", batch_size)

    # train
    step_no = 0
    critic_losses = []
    generator_losses = []
    for epoch in range(n_epochs):
        print('Epoch ', epoch)
        # fixed latents to use for visualisation.
        fixed_latents = Variable(wgan.generator.sample_latent(64))
        if use_cuda:
            fixed_latents = fixed_latents.cuda()

        # save 1 real image
        for i, (x, _) in enumerate(train_loader):
            if i == 0:
                save_image((x.cpu() + 1) / 2, f"{samples_dir}/original_sample_{i}.png")
                break

        wgan.train()
        for i, (x, _) in enumerate(train_loader):
            step_no += 1
            # print('Step: {}----'.format(step_no))
            x = x.to(device)

            # train critic
            critic_optimizer.zero_grad()
            critic_loss = wgan.critic_loss(x)
            critic_loss.backward()
            critic_optimizer.step()

            # save critic loss
            # print('Critic loss: {}'.format(critic_loss.item()))
            critic_losses.append(critic_loss.item())

            # train generator every 'n_critic_steps' times.
            if step_no % n_critic_steps == 0:
                generator_optimizer.zero_grad()
                z = wgan.sample_noise(batch_size)
                generator_loss = wgan.generator_loss(z)
                generator_loss.backward()
                generator_optimizer.step()

                # save generator loss
                # print('Generator loss: {}'.format(generator_loss.item()))
                generator_losses.append(generator_loss.item())

        # generate and save images.
        # save model
        torch.save(wgan.generator.state_dict(), f"{samples_dir}/generator/epoch_{epoch}.pt")
        torch.save(wgan.critic.state_dict(), f"{samples_dir}/critic/epoch_{epoch}.pt")
        # save image
        save_image((wgan.generator(fixed_latents).cpu() + 1) / 2, f"{samples_dir}/generated_train_{epoch}.png")
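Example #7 delegates the penalty term to wgan.critic_loss(x) without showing it. Below is a minimal sketch of the WGAN-GP gradient penalty from Gulrajani et al. (2017), assuming a critic module that maps an image batch to scalar scores; it is a generic implementation, not the code inside Example #7's WGAN_GP class:

import torch

def gradient_penalty(critic, real, fake, penalty_strength=10.0):
    # Random interpolation points between real and generated samples.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interpolates)
    # Gradient of the critic's scores w.r.t. the interpolated inputs.
    grads = torch.autograd.grad(
        outputs=scores, inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True, retain_graph=True)[0]
    grads = grads.view(grads.size(0), -1)
    # Penalize gradient norms that deviate from 1.
    return penalty_strength * ((grads.norm(2, dim=1) - 1) ** 2).mean()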
Example #8
class Solver(object):
    def __init__(self, flags):
        run_config = tf.ConfigProto()
        run_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=run_config)

        self.flags = flags
        self.iter_time = 0
        self.num_examples_IS = 1000
        self._make_folders()
        self._init_logger()

        self.dataset = Dataset(self.sess,
                               self.flags,
                               self.flags.dataset,
                               log_path=self.log_out_dir)
        self.model = WGAN_GP(self.sess,
                             self.flags,
                             self.dataset,
                             log_path=self.log_out_dir)

        self.saver = tf.train.Saver()
        self.sess.run([tf.global_variables_initializer()])

        # tf_utils.show_all_variables()

    def _make_folders(self):
        if self.flags.is_train:  # train stage
            if self.flags.load_model is None:
                cur_time = datetime.now().strftime("%Y%m%d-%H%M")
                self.model_out_dir = "{}/model/{}".format(
                    self.flags.dataset, cur_time)
                if not os.path.isdir(self.model_out_dir):
                    os.makedirs(self.model_out_dir)
            else:
                cur_time = self.flags.load_model
                self.model_out_dir = "{}/model/{}".format(
                    self.flags.dataset, cur_time)

            self.sample_out_dir = "{}/sample/{}".format(
                self.flags.dataset, cur_time)
            if not os.path.isdir(self.sample_out_dir):
                os.makedirs(self.sample_out_dir)

            self.log_out_dir = "{}/logs/{}".format(self.flags.dataset,
                                                   cur_time)
            self.train_writer = tf.summary.FileWriter(
                "{}/logs/{}".format(self.flags.dataset, cur_time),
                graph=self.sess.graph)

        elif not self.flags.is_train:  # test stage
            self.model_out_dir = "{}/model/{}".format(self.flags.dataset,
                                                      self.flags.load_model)
            self.test_out_dir = "{}/test/{}".format(self.flags.dataset,
                                                    self.flags.load_model)
            self.log_out_dir = "{}/logs/{}".format(self.flags.dataset,
                                                   self.flags.load_model)

            if not os.path.isdir(self.test_out_dir):
                os.makedirs(self.test_out_dir)

    def _init_logger(self):
        formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s')
        # file handler
        file_handler = logging.FileHandler(
            os.path.join(self.log_out_dir, 'solver.log'))
        file_handler.setFormatter(formatter)
        file_handler.setLevel(logging.INFO)
        # stream handler
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        # add handlers
        logger.addHandler(file_handler)
        logger.addHandler(stream_handler)

        if self.flags.is_train:
            logger.info('gpu_index: {}'.format(self.flags.gpu_index))
            logger.info('batch_size: {}'.format(self.flags.batch_size))
            logger.info('dataset: {}'.format(self.flags.dataset))

            logger.info('is_train: {}'.format(self.flags.is_train))
            logger.info('learning_rate: {}'.format(self.flags.learning_rate))
            logger.info('num_critic: {}'.format(self.flags.num_critic))
            logger.info('z_dim: {}'.format(self.flags.z_dim))
            logger.info('lambda_: {}'.format(self.flags.lambda_))
            logger.info('beta1: {}'.format(self.flags.beta1))
            logger.info('beta2: {}'.format(self.flags.beta2))

            logger.info('iters: {}'.format(self.flags.iters))
            logger.info('print_freq: {}'.format(self.flags.print_freq))
            logger.info('save_freq: {}'.format(self.flags.save_freq))
            logger.info('sample_freq: {}'.format(self.flags.sample_freq))
            logger.info('inception_freq: {}'.format(self.flags.inception_freq))
            logger.info('sample_batch: {}'.format(self.flags.sample_batch))
            logger.info('load_model: {}'.format(self.flags.load_model))

    def train(self):
        # load the provided checkpoint, if any
        if self.flags.load_model is not None:
            if self.load_model():
                logger.info(' [*] Load SUCCESS!\n')
            else:
                logger.info(' [!] Load Failed...\n')

        # for iter_time in range(self.flags.iters):
        while self.iter_time < self.flags.iters:
            # sample images and save them
            self.sample(self.iter_time)

            # train_step
            loss, summary = self.model.train_step()
            self.model.print_info(loss, self.iter_time)
            self.train_writer.add_summary(summary, self.iter_time)
            self.train_writer.flush()

            if self.flags.dataset == 'cifar10':
                self.get_inception_score(
                    self.iter_time)  # calculate inception score

            # save model
            self.save_model(self.iter_time)
            self.iter_time += 1

        self.save_model(self.flags.iters)

    def test(self):
        if self.load_model():
            logger.info(' [*] Load SUCCESS!')
        else:
            logger.info(' [!] Load Failed...')

        num_iters = 20
        for iter_time in range(num_iters):
            print('iter_time: {}'.format(iter_time))

            imgs = self.model.test_step()
            self.model.plots(imgs, iter_time, self.test_out_dir)

    def get_inception_score(self, iter_time):
        if np.mod(iter_time, self.flags.inception_freq) == 0:
            sample_size = 100
            all_samples = []
            for _ in range(int(self.num_examples_IS / sample_size)):
                imgs = self.model.sample_imgs(sample_size=sample_size)
                all_samples.append(imgs[0])

            all_samples = np.concatenate(all_samples, axis=0)
            all_samples = ((all_samples + 1.) * 255. / 2.).astype(np.uint8)

            mean_IS, std_IS = get_inception_score(list(all_samples),
                                                  self.flags)
            # print('Inception score iter: {}, IS: {}'.format(self.iter_time, mean_IS))

            plot.plot('inception score', mean_IS)
            plot.flush(self.log_out_dir)  # write logs
            plot.tick()

    def sample(self, iter_time):
        if np.mod(iter_time, self.flags.sample_freq) == 0:
            imgs = self.model.sample_imgs(sample_size=self.flags.sample_batch)
            self.model.plots(imgs, iter_time, self.sample_out_dir)

    def save_model(self, iter_time):
        if np.mod(iter_time + 1, self.flags.save_freq) == 0:
            model_name = 'model'
            self.saver.save(self.sess,
                            os.path.join(self.model_out_dir, model_name),
                            global_step=iter_time)
            logger.info('[*] Model saved! Iter: {}'.format(iter_time))

    def load_model(self):
        logger.info(' [*] Reading checkpoint...')

        ckpt = tf.train.get_checkpoint_state(self.model_out_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess,
                               os.path.join(self.model_out_dir, ckpt_name))

            meta_graph_path = ckpt.model_checkpoint_path + '.meta'
            self.iter_time = int(meta_graph_path.split('-')[-1].split('.')[0])

            logger.info('[*] Load iter_time: {}'.format(self.iter_time))
            return True
        else:
            return False
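load_model in Example #8 recovers the iteration counter by parsing the checkpoint filename that Saver.save(..., global_step=iter_time) produced. With an illustrative path in the layout _make_folders creates:

# Saver.save(..., global_step=5000) yields files like 'model-5000.meta', so:
meta_graph_path = 'cifar10/model/20190101-1200/model-5000.meta'
iter_time = int(meta_graph_path.split('-')[-1].split('.')[0])
print(iter_time)  # 5000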
Example #9
class Solver(object):
    def __init__(self, flags):
        run_config = tf.ConfigProto()
        run_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=run_config)

        self.flags = flags
        self.iter_time = 0
        self._make_folders()
        self._init_logger()

        self.dataset = Dataset(self.flags.dataset,
                               self.flags,
                               image_size=(128, 256, 3))
        self.model = WGAN_GP(self.sess,
                             self.flags,
                             self.dataset.image_size,
                             self.dataset(),
                             log_path=self.log_out_dir)

        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())

    def _make_folders(self):
        if self.flags.is_train:  # train stage
            if self.flags.load_model is None:
                cur_time = datetime.now().strftime("%Y%m%d-%H%M")
                self.model_out_dir = "{}/model/{}".format(
                    self.flags.dataset, cur_time)
                if not os.path.isdir(self.model_out_dir):
                    os.makedirs(self.model_out_dir)
            else:
                cur_time = self.flags.load_model
                self.model_out_dir = "{}/model/{}".format(
                    self.flags.dataset, cur_time)

            self.sample_out_dir = "{}/sample/{}".format(
                self.flags.dataset, cur_time)
            if not os.path.isdir(self.sample_out_dir):
                os.makedirs(self.sample_out_dir)

            self.log_out_dir = "{}/logs/{}".format(self.flags.dataset,
                                                   cur_time)
            self.train_writer = tf.summary.FileWriter(
                "{}/logs/{}".format(self.flags.dataset, cur_time),
                graph=self.sess.graph)
        else:  # test stage
            self.model_out_dir = "{}/model/{}".format(self.flags.dataset,
                                                      self.flags.load_model)
            self.test_out_dir = "{}/test/{}".format(self.flags.dataset,
                                                    self.flags.load_model)
            self.log_out_dir = "{}/logs/{}".format(self.flags.dataset,
                                                   self.flags.load_model)

            if not os.path.isdir(self.test_out_dir):
                os.makedirs(self.test_out_dir)

    def _init_logger(self):
        formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s')
        # file handler
        file_handler = logging.FileHandler(
            os.path.join(self.log_out_dir, 'solver.log'))
        file_handler.setFormatter(formatter)
        file_handler.setLevel(logging.INFO)
        # stream handler
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        # add handlers
        logger.addHandler(file_handler)
        logger.addHandler(stream_handler)

        if self.flags.is_train:
            logger.info('gpu_index: {}'.format(self.flags.gpu_index))
            logger.info('batch_size: {}'.format(self.flags.batch_size))
            logger.info('dataset: {}'.format(self.flags.dataset))

            logger.info('is_train: {}'.format(self.flags.is_train))
            logger.info('learning_rate: {}'.format(self.flags.learning_rate))
            logger.info('num_critic: {}'.format(self.flags.num_critic))
            logger.info('z_dim: {}'.format(self.flags.z_dim))
            logger.info('lambda_: {}'.format(self.flags.lambda_))
            logger.info('beta1: {}'.format(self.flags.beta1))
            logger.info('beta2: {}'.format(self.flags.beta2))

            logger.info('iters: {}'.format(self.flags.iters))
            logger.info('print_freq: {}'.format(self.flags.print_freq))
            logger.info('save_freq: {}'.format(self.flags.save_freq))
            logger.info('sample_freq: {}'.format(self.flags.sample_freq))
            logger.info('sample_batch: {}'.format(self.flags.sample_batch))
            logger.info('load_model: {}'.format(self.flags.load_model))

    def train(self):
        # load the provided checkpoint, if any
        if self.flags.load_model is not None:
            if self.load_model():
                logger.info(' [*] Load SUCCESS!\n')
            else:
                logger.info(' [!] Load Failed...\n')

        # threads for tfrecord
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)

        try:
            while self.iter_time < self.flags.iters:
                # sample images and save them
                self.sample(self.iter_time)

                # train_step
                loss, summary = self.model.train_step()
                self.model.print_info(loss, self.iter_time)
                self.train_writer.add_summary(summary, self.iter_time)
                self.train_writer.flush()

                # save model
                self.save_model(self.iter_time)
                self.iter_time += 1

            self.save_model(self.flags.iters)

        except KeyboardInterrupt:
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # when done, ask the thread to stop
            coord.request_stop()
            coord.join(threads)

    def test(self):
        if self.load_model():
            print(' [*] Load SUCCESS!')
        else:
            print(' [!] Load Failed...')

        num_iters = 20
        total_time = 0.
        for iter_time in range(num_iters):
            print('iter_time: {}'.format(iter_time))

            # measure inference time
            start_time = time.time()
            imgs = self.model.sample_test()  # inference
            total_time += time.time() - start_time
            self.model.plots_test(imgs, iter_time, self.test_out_dir)

        print('Avg processing time: {:.2f} msec.'.format(total_time / num_iters * 1000.))

    def sample(self, iter_time):
        if np.mod(iter_time, self.flags.sample_freq) == 0:
            imgs = self.model.sample_imgs(sample_size=self.flags.sample_batch)
            self.model.plots(imgs, iter_time, self.sample_out_dir)

    def save_model(self, iter_time):
        if np.mod(iter_time + 1, self.flags.save_freq) == 0:
            model_name = 'model'
            self.saver.save(self.sess,
                            os.path.join(self.model_out_dir, model_name),
                            global_step=iter_time)
            logger.info(' [*] Model saved! Iter: {}'.format(iter_time))

    def load_model(self):
        logger.info(' [*] Reading checkpoint...')

        ckpt = tf.train.get_checkpoint_state(self.model_out_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess,
                               os.path.join(self.model_out_dir, ckpt_name))

            meta_graph_path = ckpt.model_checkpoint_path + '.meta'
            self.iter_time = int(meta_graph_path.split('-')[-1].split('.')[0])

            logger.info(' [*] Load iter_time: {}'.format(self.iter_time))
            return True
        else:
            return False