Ejemplo n.º 1
0
    def compute_inception_score(self, epoch, idx):
        """Sample images from the generator, save them, and score them.

        Runs the generator in inference mode (``train=False``) on
        ``self.model.sample_z``, writes the resulting image grid to
        ``./<sample_dir>/train_<epoch>_<idx>.png`` and returns the Inception
        Score of those images.

        Args:
            epoch: Epoch number; used only in the output file name.
            idx: Step index within the epoch; used only in the file name.

        Returns:
            The value returned by ``get_inception_score`` (called with
            ``splits=5``), or ``np.nan`` if any step fails. Failures are
            printed, not raised, so a scoring problem does not abort training.
        """
        try:
            # Build the inference-mode sampler once and cache it. Calling
            # ``generator(...)`` on every invocation would add another copy of
            # the generator ops to the TF graph each time, so the graph (and
            # memory use) would keep growing over a training run.
            sample_op = getattr(self, '_inception_sample_op', None)
            if sample_op is None:
                sample_op = self.model.generator(self.model.z, train=False)
                self._inception_sample_op = sample_op

            generated_images = self.sess.run(
                sample_op,
                feed_dict={self.model.z: self.model.sample_z},
            )
            save_images(
                generated_images,
                image_manifold_size(generated_images.shape[0]),
                './{}/train_{:02d}_{:04d}.png'.format(self.model.sample_dir,
                                                      epoch, idx))

            # Rescale each image linearly from [-1, 1] to [0, 255]; this
            # assumes the generator output lies in [-1, 1].
            generated_images_list = [(image + 1) * 255 / 2
                                     for image in generated_images]
            return get_inception_score(generated_images_list,
                                       self.sess,
                                       splits=5)

        except Exception as e:
            # Deliberate best-effort: report and signal "no score".
            print("Sampling error:", e)
            return np.nan
Ejemplo n.º 2
0
    def train(self):
        """Run the DCGAN training loop.

        Initializes global and local variables and restores the latest
        checkpoint from ``./logs/model`` if one exists. It then starts the
        input queue runners and performs ``EPOCH * STEPS`` iterations. Each
        iteration does one discriminator update followed by two generator
        updates on a fresh uniform-noise batch. Every 50 steps a checkpoint is
        written and a grid of samples is saved to
        ``./train_<epoch>_<step>.png``.

        The queue-runner threads are always stopped and joined, even when
        training raises or is interrupted.
        """
        # Local variables are initialized as well; the queue-runner input
        # threads depend on them (do not drop this).
        self.sess.run([tf.global_variables_initializer(),
                       tf.local_variables_initializer()])
        with self.sess.as_default():
            # Restore a previously saved model, if any.
            model_dir = "./logs/model"
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
            ckpt = tf.train.get_checkpoint_state(model_dir)
            if ckpt is not None:
                print("[*] Success to read {}".format(ckpt.model_checkpoint_path))
                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            else:
                print("[*] Failed to find a checkpoint")

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)
            try:
                for epoch in range(EPOCH):
                    for step in range(STEPS):
                        counter = epoch * STEPS + step
                        batch_z = np.random.uniform(
                            -1, 1, [TFR_process.BATCH_SIZE, self.z_dim]
                        ).astype(np.float32)
                        feed = {self.z: batch_z}

                        # One discriminator update ...
                        _, summary_str = self.sess.run(
                            [self.d_optim, self.d_sum], feed_dict=feed)
                        self.writer.add_summary(summary_str, counter)
                        # ... followed by two generator updates.
                        for _ in range(2):
                            _, summary_str = self.sess.run(
                                [self.g_optim, self.g_sum], feed_dict=feed)
                            self.writer.add_summary(summary_str, counter)

                        error_d_fake = self.d_loss_fake.eval(feed)
                        # d_loss_real is evaluated without a feed; presumably
                        # its input is read from the queue pipeline.
                        error_d_real = self.d_loss_real.eval()
                        error_g = self.g_loss.eval(feed)
                        print("epoch: {0}, step: {1}".format(epoch, step))
                        print("d_loss: {0:.8f}, g_loss: {1:.8f}".format(
                            error_d_fake + error_d_real, error_g))

                        # Periodic checkpoint and sample grid.
                        if step % 50 == 0:
                            self.saver.save(self.sess, "./logs/model/DCGAN.model",
                                            global_step=counter)

                            sample_z = np.random.uniform(
                                -1, 1, size=(self.batch_size, self.z_dim))
                            samples = self.sess.run(self.s,
                                                    feed_dict={self.z: sample_z})
                            utils.save_images(
                                samples, utils.image_manifold_size(samples.shape[0]),
                                './train_{:02d}_{:04d}.png'.format(epoch, step))
            finally:
                # Always shut the input threads down, even on error/interrupt.
                coord.request_stop()
                coord.join(threads)
def dcgan_train(z_dim=100, batch_size=BATCH_SIZE):
    """Build the DCGAN graph via ``dcgan`` and train it in a new session.

    The latest checkpoint in ``./logs/model`` is restored if present. Each of
    the ``EPOCH * STEPS`` iterations performs one discriminator update and two
    generator updates. Every 50 steps a checkpoint is written and a sample
    grid is saved to ``./train_<epoch>_<step>.png``. The queue-runner threads
    are always stopped and joined, even if training raises.

    Args:
        z_dim: Length of the noise vector fed to ``end_points['initial_z']``.
        batch_size: Passed to ``dcgan`` and used as the sample-grid size.
    """
    end_points = dcgan(z_dim=z_dim, batch_size=batch_size)
    z = end_points['initial_z']

    with tf.Session(config=config) as sess:
        saver = tf.train.Saver()
        writer = tf.summary.FileWriter("./logs", sess.graph)
        # Local variables are initialized as well; the queue-runner input
        # threads depend on them (do not drop this).
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        model_dir = "./logs/model"
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt is not None:
            print("[*] Success to read {}".format(ckpt.model_checkpoint_path))
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print("[*] Failed to find a checkpoint")

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for epoch in range(EPOCH):
                for step in range(STEPS):
                    counter = epoch * STEPS + step
                    batch_z = np.random.uniform(
                        -1, 1, [TFR_process.BATCH_SIZE, z_dim]).astype(np.float32)
                    feed = {z: batch_z}

                    # One discriminator update ...
                    _, summary_str = sess.run(
                        [end_points['d_optim'], end_points['d_sum']],
                        feed_dict=feed)
                    writer.add_summary(summary_str, counter)
                    # ... followed by two generator updates.
                    for _ in range(2):
                        _, summary_str = sess.run(
                            [end_points['g_optim'], end_points['g_sum']],
                            feed_dict=feed)
                        writer.add_summary(summary_str, counter)

                    error_d_fake = end_points['d_loss_fake'].eval(feed)
                    # Evaluated without a feed; presumably fed by the queue.
                    error_d_real = end_points['d_loss_real'].eval()
                    error_g = end_points['g_loss'].eval(feed)
                    print("epoch: {0}, step: {1}".format(epoch, step))
                    print("d_loss: {0:.8f}, g_loss: {1:.8f}".format(
                        error_d_fake + error_d_real, error_g))

                    # Periodic checkpoint and sample grid.
                    if step % 50 == 0:
                        saver.save(sess, "./logs/model/DCGAN.model",
                                   global_step=counter)

                        sample_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))
                        samples = sess.run(end_points['sample_output'],
                                           feed_dict={z: sample_z})
                        utils.save_images(
                            samples, utils.image_manifold_size(samples.shape[0]),
                            './train_{:02d}_{:04d}.png'.format(epoch, step))
        finally:
            # Always shut the input threads down, even on error/interrupt.
            coord.request_stop()
            coord.join(threads)
Ejemplo n.º 4
0
def gen_with_wgan():
    """Generate a single image from a fixed latent code with a saved WGAN-GP.

    The latent code is ``XIAOBIANTAI_FACE`` scaled by 1/10 and reshaped to a
    ``(1, 100)`` batch. The image is written to ``./wgan_face.png``.

    Returns:
        -1 if ``./logs/model`` does not exist; otherwise ``None``.
    """
    if not os.path.exists("./logs/model"):
        tf.logging.info("[*] Failed to find direct './logs/model'")
        return -1

    sample_z = (np.array(XIAOBIANTAI_FACE) / 10.0).reshape(1, 100)

    # Use the session as a context manager so it is closed once sampling is
    # done; previously it was created and never closed.
    with tf.Session(config=config) as sess:
        wgan = WGAN_GP(sess, batch_size=1)
        samples = wgan.train(TRAIN_FLAG=False, batch_z=sample_z)

    utils.save_images(samples, utils.image_manifold_size(samples.shape[0]),
                      './wgan_face.png')
Ejemplo n.º 5
0
def reload_dcgan():
    """Restore the latest DCGAN checkpoint and save a grid of random samples.

    The grid is written to ``./reload.png``.

    Returns:
        -1 if ``./logs/model`` does not exist or contains no checkpoint;
        otherwise ``None``.
    """
    if not os.path.exists("./logs/model"):
        tf.logging.info("[*] Failed to find direct './logs/model'")
        return -1
    end_points = dcgan()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        ckpt = tf.train.get_checkpoint_state("./logs/model")
        if ckpt is None:
            # Nothing in this session initializes the variables, so sampling
            # without a checkpoint would fail on uninitialized variables.
            tf.logging.info("[*] Failed to find a checkpoint")
            return -1
        tf.logging.info("[*] Success to read {}".format(
            ckpt.model_checkpoint_path))
        saver.restore(sess, ckpt.model_checkpoint_path)

        # 64 noise vectors of length 100, uniform in [-1, 1).
        sample_z = np.random.uniform(-1, 1, size=(64, 100))
        samples = sess.run(end_points['sample_output'],
                           feed_dict={end_points['initial_z']: sample_z})
        utils.save_images(samples, utils.image_manifold_size(samples.shape[0]),
                          './reload.png')
Ejemplo n.º 6
0
    def train(self, TRAIN_FLAG=True, batch_z=None):
        """Train the WGAN, or restore it and generate samples.

        Variables are initialized and the latest checkpoint in
        ``./logs/model`` is restored if present.

        Args:
            TRAIN_FLAG: When falsy, skip training and return the output of
                ``self.s`` for ``batch_z``.
            batch_z: Noise batch fed to ``self.z`` in sampling mode. It is
                ignored (regenerated every step) when training.

        Returns:
            The generated samples in sampling mode; ``None`` after training.
        """
        # Local variables are initialized as well; the queue-runner input
        # threads depend on them (do not drop this).
        self.sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        with self.sess.as_default():
            # Load the pre-trained model, if any.
            model_dir = "./logs/model"
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
            ckpt = tf.train.get_checkpoint_state(model_dir)
            if ckpt is not None:
                print("[*] Success to read {}".format(
                    ckpt.model_checkpoint_path))
                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            else:
                print("[*] Failed to find a checkpoint")

            if not TRAIN_FLAG:
                return self.sess.run(self.s, feed_dict={self.z: batch_z})

            # Directory for the sample grids generated during training;
            # created once instead of being re-checked inside the loop.
            train_res_path = 'train_result'
            if not os.path.exists("./{}".format(train_res_path)):
                os.mkdir(train_res_path)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)
            try:
                for epoch in range(EPOCH):
                    for step in range(STEPS):
                        counter = epoch * STEPS + step
                        batch_z = np.random.uniform(
                            -1, 1, [TFR_process.BATCH_SIZE, self.z_dim]).astype(
                                np.float32)
                        feed = {self.z: batch_z}

                        # Two discriminator iterations per generator step.
                        # Each runs ``clip_updates`` and then the D optimizer;
                        # the clip op's fetched value is not needed.
                        for _ in range(2):
                            self.sess.run(self.clip_updates, feed_dict=feed)
                            _, summary_str = self.sess.run(
                                [self.d_optim, self.d_sum], feed_dict=feed)
                            self.writer.add_summary(summary_str, counter)
                        # One generator step.
                        _, summary_str = self.sess.run(
                            [self.g_optim, self.g_sum], feed_dict=feed)
                        self.writer.add_summary(summary_str, counter)

                        # Report the current losses.
                        error_d = self.d_loss.eval(feed)
                        error_g = self.g_loss.eval(feed)
                        print("epoch: {0}, step: {1}".format(epoch, step))
                        print("d_loss: {0:.8f}, g_loss: {1:.8f}".format(
                            error_d, error_g))

                        # Periodic checkpoint and intermediate sample grid.
                        if step % 50 == 0:
                            self.saver.save(self.sess,
                                            "./logs/model/WGAN.model",
                                            global_step=counter)

                            sample_z = np.random.uniform(
                                -1, 1, size=(self.batch_size, self.z_dim))
                            samples = self.sess.run(
                                self.s, feed_dict={self.z: sample_z})
                            utils.save_images(
                                samples,
                                utils.image_manifold_size(samples.shape[0]),
                                './train_result/train_{:02d}_{:04d}.png'.format(
                                    epoch, step))
            finally:
                # Always shut the input threads down, even on error/interrupt.
                coord.request_stop()
                coord.join(threads)
Ejemplo n.º 7
0
    def train(self, config):
        """Train the DCGAN on MNIST or on an image-folder dataset.

        Builds Adam optimizers for D and G, initializes variables, merges the
        TensorBoard summaries and restores a checkpoint if available. It then
        runs ``config.epoch`` epochs of mini-batch training, with one D update
        and two G updates per batch. A sample grid is written to
        ``config.sample_dir`` when ``counter % 100 == 1``. A checkpoint is
        saved to ``config.checkpoint_dir`` when ``counter % 500 == 2``.

        Args:
            config: Run configuration. This method reads ``learning_rate``,
                ``beta1``, ``dataset``, ``epoch``, ``train_size``,
                ``batch_size``, ``sample_dir`` and ``checkpoint_dir``.
        """
        is_mnist = config.dataset == 'mnist'

        d_optim = tf.train.AdamOptimizer(
            config.learning_rate, beta1=config.beta1
        ).minimize(self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(
            config.learning_rate, beta1=config.beta1
        ).minimize(self.g_loss, var_list=self.g_vars)
        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            # Fallback for old TensorFlow releases that only provide the
            # deprecated initializer name.
            tf.initialize_all_variables().run()

        # Note: ``self.d_sum`` is read here as an individual summary and then
        # replaced by the merged discriminator summary.
        self.g_sum = merge_summary(
            [self.z_sum, self.d__sum, self.G_sum, self.d_loss_fake_sum,
             self.g_loss_sum])
        self.d_sum = merge_summary(
            [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
        self.writer = SummaryWriter("./logs", self.sess.graph)

        # Fixed noise and inputs reused for every periodic sample grid.
        sample_z = np.random.uniform(-1, 1, size=(self.sample_num, self.z_dim))
        if is_mnist:
            sample_inputs = self.data_X[0:self.sample_num]
            sample_labels = self.data_y[0:self.sample_num]
        else:
            sample_inputs = self._load_image_batch(self.data[0:self.sample_num])

        counter = 1
        start_time = time.time()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        for epoch in range(config.epoch):
            if is_mnist:
                num_examples = len(self.data_X)
            else:
                # The file list is re-globbed at the start of every epoch.
                self.data = glob(os.path.join(
                    "./data", config.dataset, self.input_fname_pattern))
                num_examples = len(self.data)
            batch_idxs = min(num_examples, config.train_size) // config.batch_size

            for idx in range(batch_idxs):
                start = idx * config.batch_size
                stop = start + config.batch_size
                if is_mnist:
                    batch_images = self.data_X[start:stop]
                    batch_labels = self.data_y[start:stop]
                else:
                    batch_images = self._load_image_batch(self.data[start:stop])

                batch_z = np.random.uniform(
                    -1, 1, [config.batch_size, self.z_dim]
                ).astype(np.float32)

                # Feeds for D updates, G updates/fake losses, and real loss.
                # The conditional (MNIST) model additionally needs labels.
                d_feed = {self.inputs: batch_images, self.z: batch_z}
                g_feed = {self.z: batch_z}
                real_feed = {self.inputs: batch_images}
                if is_mnist:
                    for feed in (d_feed, g_feed, real_feed):
                        feed[self.y] = batch_labels

                # Update D network
                _, summary_str = self.sess.run(
                    [d_optim, self.d_sum], feed_dict=d_feed)
                self.writer.add_summary(summary_str, counter)

                # Update G network twice to make sure that d_loss does not
                # go to zero (different from paper)
                for _ in range(2):
                    _, summary_str = self.sess.run(
                        [g_optim, self.g_sum], feed_dict=g_feed)
                    self.writer.add_summary(summary_str, counter)

                errD_fake = self.d_loss_fake.eval(g_feed)
                errD_real = self.d_loss_real.eval(real_feed)
                errG = self.g_loss.eval(g_feed)

                counter += 1
                print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, "
                      "g_loss: %.8f" % (epoch, config.epoch, idx, batch_idxs,
                                        time.time() - start_time,
                                        errD_fake + errD_real, errG))

                if counter % 100 == 1:
                    sample_feed = {self.z: sample_z, self.inputs: sample_inputs}
                    if is_mnist:
                        sample_feed[self.y] = sample_labels
                        self._save_training_sample(sample_feed, config, epoch, idx)
                    else:
                        # Best-effort for image datasets: a failed sample must
                        # not stop training, but Ctrl-C / SystemExit still must.
                        try:
                            self._save_training_sample(
                                sample_feed, config, epoch, idx)
                        except Exception as e:
                            print("one pic error!...")
                            print("  reason: {!r}".format(e))

                if counter % 500 == 2:
                    self.save(config.checkpoint_dir, counter)

    def _load_image_batch(self, files):
        """Load ``files`` with ``get_image`` into one float32 numpy array."""
        images = [get_image(path,
                            input_height=self.input_height,
                            input_width=self.input_width,
                            resize_height=self.output_height,
                            resize_width=self.output_width,
                            crop=self.crop,
                            grayscale=self.grayscale)
                  for path in files]
        batch = np.array(images).astype(np.float32)
        if self.grayscale:
            # Insert a singleton channel axis (axis 3) for grayscale images.
            batch = batch[:, :, :, None]
        return batch

    def _save_training_sample(self, sample_feed, config, epoch, idx):
        """Run the sampler on ``sample_feed``, save the grid, print losses."""
        samples, d_loss, g_loss = self.sess.run(
            [self.sampler, self.d_loss, self.g_loss], feed_dict=sample_feed)
        save_images(
            samples, image_manifold_size(samples.shape[0]),
            './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
        print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
Ejemplo n.º 8
0
class DCGAN(object):
    """DCGAN whose real images come from ``TFR_process.batch_from_tfr()``.

    The whole graph (networks, losses, optimizers, summaries, saver and
    summary writer) is built in ``__init__``. ``train`` runs the training loop.
    """

    def __init__(self,
                 sess,
                 learning_rate=0.0002, beat1=0.5,
                 z_dim=100,
                 c_dim=3, batch_size=BATCH_SIZE,
                 gf_dim=64, gfc_dim=1024,
                 df_dim=64, dfc_dim=1024,
                 input_height=48, input_width=48):
        """Build the DCGAN graph.

        Args:
            sess: TensorFlow session used for training and summaries.
            learning_rate: Adam learning rate for both D and G.
            beat1: Adam ``beta1`` for both optimizers. The name is a
                misspelling of ``beta1``, kept for backward compatibility.
            z_dim: Length of the noise vector.
            c_dim: Number of image channels.
            batch_size: Images per batch. The discriminator's final reshape
                and the generator's deconv output shapes use this static size.
            gf_dim: Base channel count of the generator layers.
            gfc_dim: Carried over from an AC-GAN variant; stored but not used
                by the layers in this class.
            df_dim: Base channel count of the discriminator layers.
            dfc_dim: Carried over from an AC-GAN variant; stored but not used
                by the layers in this class.
            input_height: Image height.
            input_width: Image width.
        """
        self.sess = sess
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.gf_dim = gf_dim
        self.gfc_dim = gfc_dim
        self.df_dim = df_dim
        self.dfc_dim = dfc_dim
        self.batch_size = batch_size
        self.input_height = input_height
        self.input_width = input_width

        self.inputs, _ = TFR_process.batch_from_tfr()
        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')

        # Real data through the discriminator (creates D's variables).
        d, d_logits = self.discriminator(self.inputs, reuse=False)
        # Fake data from the generator (creates G's variables).
        g = self.generator(self.z)
        # Fake data through the same discriminator, so its variables are reused.
        d_, d_logits_ = self.discriminator(g, reuse=True)
        # Sampling node: with train=False the generator reuses the variables
        # created above and runs batch norm in inference mode.
        self.s = self.generator(self.z, train=False)

        # D real loss: push D's output on real images toward 1.
        self.d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits, labels=tf.ones_like(d)))
        # D fake loss: push D's output on generated images toward 0.
        self.d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_, labels=tf.zeros_like(d_)))

        # G loss: push D's output on generated images toward 1.
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_, labels=tf.ones_like(d_)))
        # Total D loss: real loss + fake loss.
        self.d_loss = tf.add(self.d_loss_real, self.d_loss_fake)

        # Each optimizer only updates its own network's variables.
        t_vars = tf.trainable_variables()
        g_vars = [var for var in t_vars if var.name.startswith('generator')]
        d_vars = [var for var in t_vars if var.name.startswith('discriminator')]

        self.d_optim = tf.train.AdamOptimizer(learning_rate, beta1=beat1) \
            .minimize(self.d_loss, var_list=d_vars)
        self.g_optim = tf.train.AdamOptimizer(learning_rate, beta1=beat1) \
            .minimize(self.g_loss, var_list=g_vars)

        z_sum = tf.summary.histogram("z", self.z)
        d_sum = tf.summary.histogram("d", d)
        d__sum = tf.summary.histogram("d_", d_)
        g_sum = tf.summary.image("G", g)
        d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.merge([z_sum, d__sum, g_sum, d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([z_sum, d_sum, d_loss_real_sum, d_loss_sum])

        self.saver = tf.train.Saver()
        self.writer = tf.summary.FileWriter("./logs", self.sess.graph)

    def train(self):
        """Run the training loop.

        Initializes global and local variables and restores the latest
        checkpoint from ``./logs/model`` if present. It then starts the input
        queue runners and performs ``EPOCH * STEPS`` iterations, each with one
        D update followed by two G updates. Every 50 steps a checkpoint is
        written and a sample grid is saved to ``./train_<epoch>_<step>.png``.
        The queue-runner threads are always stopped and joined.
        """
        # Local variables are initialized as well; the queue-runner input
        # threads depend on them (do not drop this).
        self.sess.run([tf.global_variables_initializer(),
                       tf.local_variables_initializer()])

        with self.sess.as_default():
            # Load the pre-trained model, if any.
            model_dir = "./logs/model"
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
            ckpt = tf.train.get_checkpoint_state(model_dir)
            if ckpt is not None:
                print("[*] Success to read {}".format(ckpt.model_checkpoint_path))
                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            else:
                print("[*] Failed to find a checkpoint")

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)
            try:
                for epoch in range(EPOCH):
                    for step in range(STEPS):
                        counter = epoch * STEPS + step
                        batch_z = np.random.uniform(
                            -1, 1, [TFR_process.BATCH_SIZE, self.z_dim]
                        ).astype(np.float32)
                        feed = {self.z: batch_z}

                        # One discriminator update, then two generator updates.
                        _, summary_str = self.sess.run(
                            [self.d_optim, self.d_sum], feed_dict=feed)
                        self.writer.add_summary(summary_str, counter)
                        for _ in range(2):
                            _, summary_str = self.sess.run(
                                [self.g_optim, self.g_sum], feed_dict=feed)
                            self.writer.add_summary(summary_str, counter)

                        # Report the current losses. d_loss_real reads real
                        # images from the input queue, so it needs no feed.
                        error_d_fake = self.d_loss_fake.eval(feed)
                        error_d_real = self.d_loss_real.eval()
                        error_g = self.g_loss.eval(feed)
                        print("epoch: {0}, step: {1}".format(epoch, step))
                        print("d_loss: {0:.8f}, g_loss: {1:.8f}".format(
                            error_d_fake + error_d_real, error_g))

                        # Periodic checkpoint and intermediate sample grid.
                        if step % 50 == 0:
                            self.saver.save(self.sess, "./logs/model/DCGAN.model",
                                            global_step=counter)

                            sample_z = np.random.uniform(
                                -1, 1, size=(self.batch_size, self.z_dim))
                            samples = self.sess.run(self.s,
                                                    feed_dict={self.z: sample_z})
                            utils.save_images(
                                samples, utils.image_manifold_size(samples.shape[0]),
                                './train_{:02d}_{:04d}.png'.format(epoch, step))
            finally:
                # Always shut the input threads down, even on error/interrupt.
                coord.request_stop()
                coord.join(threads)

    def discriminator(self, image, reuse=False):
        """Return ``(sigmoid(logits), logits)`` for a batch of images.

        The network is four conv2d layers with leaky ReLU (batch norm on
        layers 1-3), followed by a linear layer producing one logit per image.
        Variables live in the ``discriminator`` scope; pass ``reuse=True`` to
        share them with an earlier call.
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            h0 = lrelu(conv2d(image, self.df_dim, scope='d_h0_conv'))
            h1 = lrelu(batch_normal(conv2d(h0, self.df_dim * 2, scope='d_h1_conv'), scope='d_bn1'))
            h2 = lrelu(batch_normal(conv2d(h1, self.df_dim * 4, scope='d_h2_conv'), scope='d_bn2'))
            h3 = lrelu(batch_normal(conv2d(h2, self.df_dim * 8, scope='d_h3_conv'), scope='d_bn3'))
            h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, scope='d_h4_lin')

            return tf.nn.sigmoid(h4), h4

    def generator(self, z, train=True):
        """Map noise ``z`` to images in the range of ``tanh`` ([-1, 1]).

        A linear projection is reshaped to a small feature map and then
        upsampled by four deconv layers to ``(input_height, input_width,
        c_dim)``. With ``train=False`` the existing ``generator`` variables
        are reused and ``batch_normal`` is called with ``train=False``.
        """
        with tf.variable_scope("generator") as scope:
            if not train:
                scope.reuse_variables()

            # Spatial sizes at each upsampling stage, computed with
            # conv_out_size_same (presumably ceil(size / 2)).
            s_h, s_w = self.input_height, self.input_width
            s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
            s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
            s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
            s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

            z_ = linear(
                z, self.gf_dim * 8 * s_h16 * s_w16, scope='g_h0_lin')

            h0 = tf.reshape(z_, [-1, s_h16, s_w16, self.gf_dim * 8])
            h0 = tf.nn.relu(batch_normal(h0, train=train, scope='g_bn0'))

            h1 = deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim * 4], scope='g_h1')
            h1 = tf.nn.relu(batch_normal(h1, train=train, scope='g_bn1'))

            h2 = deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2], scope='g_h2')
            h2 = tf.nn.relu(batch_normal(h2, train=train, scope='g_bn2'))

            h3 = deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim * 1], scope='g_h3')
            h3 = tf.nn.relu(batch_normal(h3, train=train, scope='g_bn3'))

            h4 = deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], scope='g_h4')

            return tf.nn.tanh(h4)
Ejemplo n.º 9
0
    def train(self, config):
        """Train the model!
        """
        d_clip = None

        ##############################
        # Define the optimizers
        if self.model_type == self.GAN:
            d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
                .minimize(self.d_loss, var_list=self.d_vars)
            g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
                .minimize(self.g_loss, var_list=self.g_vars)

        elif self.model_type == self.WGAN:
            # Wasserstein GAN
            d_optim = tf.train.RMSPropOptimizer(config.learning_rate) \
                .minimize(self.d_loss, var_list=self.d_vars)
            g_optim = tf.train.RMSPropOptimizer(config.learning_rate) \
                .minimize(self.g_loss, var_list=self.g_vars)

            # After every gradient update on the discriminator model, clamp its weights to a
            # small fixed range, [-d_clip_limit, d_clip_limit].
            d_clip = tf.group(*[
                v.assign(
                    tf.clip_by_value(v, -self.d_clip_limit, self.d_clip_limit))
                for v in self.d_vars
            ])

        elif self.model_type == self.WGAN_GP:
            d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1, beta2=config.beta2) \
                .minimize(self.d_loss, var_list=self.d_vars)
            g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1, beta2=config.beta2) \
                .minimize(self.g_loss, var_list=self.g_vars)

        tf.global_variables_initializer().run()

        # Merge summary
        g_sum_list = [
            self.z_sum, self.d__sum, self.G_sum, self.g_loss_sum,
            self.d_loss_fake_sum
        ]
        d_sum_list = [
            self.z_sum, self.d_sum, self.inputs_sum, self.d_loss_sum,
            self.d_loss_real_sum
        ]

        if self.model_type in (self.WGAN, self.WGAN_GP
                               ) and self.l1_regularizer_scale is not None:
            g_sum_list += [self.reg_summ]
            d_sum_list += [self.reg_summ]

        if self.model_type == self.WGAN_GP:
            d_sum_list += [self.gp_loss_sum, self.grad_norm_sum]

        self.g_sum = tf.summary.merge(g_sum_list)
        self.d_sum = tf.summary.merge(d_sum_list)

        self.writer = tf.summary.FileWriter(
            os.path.join("./logs", self.model_dir), self.sess.graph)

        # Set up the sample images
        sample_feed_dict = self.get_sample_data(config)

        # Create a sample image every `sample_every_step` steps.
        sample_every_step = int(config.max_iter // 20)

        start_time = time.time()
        could_load, checkpoint_counter = self.load()

        counter = 1  # Count how many batches we have processed.
        d_counter = 0  # Count number of batches used for training D
        g_counter = 0  # Count number of batches used for training G

        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        ##############################
        # Start training!

        inf_data_gen = self.inf_get_next_batch(config)

        for iter_count in xrange(config.max_iter):
            if self.model_type == self.GAN:
                _d_iters = 1
            else:
                # For WGAN or WGAN_GP model, we are allowed to train the D network to be very good at
                # the beginning as a warm start. Because theoretically Wasserstain distance does not
                # suffer the vanishing gradient dilemma that vanila GAN is facing.
                _d_iters = 100 if iter_count < 25 or np.mod(
                    iter_count, 500) == 0 else self.d_iter

            # Update D network
            counter += _d_iters
            d_counter += _d_iters
            for _ in range(_d_iters):
                epoch, step, d_train_feed_dict, g_train_feed_dict = inf_data_gen.next(
                )
                self.sess.run(d_optim, feed_dict=d_train_feed_dict)
                if d_clip is not None:
                    self.sess.run(d_clip)

            summary_str = self.sess.run(self.d_sum,
                                        feed_dict=d_train_feed_dict)
            self.writer.add_summary(summary_str, iter_count)

            # Update G network
            g_counter += 1
            _, summary_str = self.sess.run([g_optim, self.g_sum],
                                           feed_dict=g_train_feed_dict)
            self.writer.add_summary(summary_str, iter_count)

            d_err = self.d_loss.eval(d_train_feed_dict)
            g_err = self.g_loss.eval(g_train_feed_dict)

            if np.mod(iter_count, 100) == 0:
                print(
                    "Iter: %d Epoch: %d [%d/%d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                    % (iter_count, epoch, d_counter, g_counter,
                       time.time() - start_time, d_err, g_err))

            if np.mod(iter_count, sample_every_step) == 1:
                samples, d_loss, g_loss = self.sess.run(
                    [self.sampler, self.d_loss, self.g_loss],
                    feed_dict=sample_feed_dict)

                image_path = os.path.join(
                    self.sample_dir,
                    "train_{:02d}_{:04d}.png".format(epoch, step))
                save_images(samples, image_manifold_size(samples.shape[0]),
                            image_path)
                print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))

                # Save the model.
                self.save(counter)
Ejemplo n.º 10
0
    def train(self, config):
        """Train the four-slot, label-conditioned patch GAN.

        For every mini-batch this runs one discriminator update followed by
        ``config.g_epoch`` generator updates on the same noise.  It logs the
        losses, writes a sample grid every 100 batches and a checkpoint every
        500 batches.

        Args:
            config: run configuration.  This method reads ``learning_rate``,
                ``beta1``, ``epoch``, ``train_size``, ``batch_size``,
                ``g_epoch``, ``sample_dir`` and ``checkpoint_dir``.
        """
        d_optim = tf.train.AdamOptimizer(config.learning_rate,
                                         beta1=config.beta1).minimize(
                                             self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(config.learning_rate,
                                         beta1=config.beta1).minimize(
                                             self.g_loss, var_list=self.g_vars)
        tf.global_variables_initializer().run()

        self.g_sum = summarys['merge']([
            self.z_sum, self.d__sum, self.G_sum, self.d_loss_fake_sum,
            self.g_loss_sum
        ])
        self.d_sum = summarys['merge'](
            [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
        self.writer = summarys['writer']('.logs', self.sess.graph)

        # The graph exposes four patch slots: x_set / z_set / y_set[0..3].
        num_slots = 4
        # Size of one patch when the image is tiled into a
        # sqrt(num_patches) x sqrt(num_patches) grid.
        patch_h = int(self.input_height / np.sqrt(self.num_patches))
        patch_w = int(self.input_width / np.sqrt(self.num_patches))
        # Two-bit position code appended to the class labels of each slot
        # (slot 0 -> [0, 0], 1 -> [0, 1], 2 -> [1, 0], 3 -> [1, 1]).
        # These never change, so build them once instead of per feed.
        position_codes = [
            np.array([code] * self.batch_size).astype(np.float32)
            for code in ([0, 0], [0, 1], [1, 0], [1, 1])
        ]

        def merged(*feeds):
            """Combine several feed dicts into one new dict."""
            result = {}
            for feed in feeds:
                result.update(feed)
            return result

        counter = 1
        start_time = time.time()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            counter = checkpoint_counter
            print("[*] load success")
        else:
            print("[!] load failed !!")

        for epoch in tqdm(range(config.epoch)):
            batch_idxs = min(len(self.data_x),
                             config.train_size) // config.batch_size
            for idx in range(batch_idxs):
                start = idx * config.batch_size
                end = start + config.batch_size
                batch_images = self.data_x[start:end]
                batch_labels = self.data_y[start:end]
                batch_z = generate_z(self.batch_size, self.z_dim * 4)

                # NOTE(review): every x_set slot receives the same top-left
                # patch, although y_set tags the slots with four different
                # positions.  Confirm whether each slot should get its own tile.
                real_patch = batch_images[:, 0:patch_h, 0:patch_w, :]
                x_feed = {self.x_set[i]: real_patch for i in range(num_slots)}
                # Slot i gets the i-th z_dim-wide column block of batch_z.
                z_feed = {
                    self.z_set[i]:
                    batch_z[:, i * self.z_dim:(i + 1) * self.z_dim]
                    for i in range(num_slots)
                }
                y_feed = {
                    self.y_set[i]:
                    np.concatenate([batch_labels, position_codes[i]], axis=1)
                    for i in range(num_slots)
                }
                fake_feed = merged(z_feed, y_feed)  # generator path
                real_feed = merged(x_feed, y_feed)  # real-image path
                full_feed = merged(x_feed, z_feed, y_feed)

                # Update D network once.
                _, summary_str = self.sess.run([d_optim, self.d_sum],
                                               feed_dict=full_feed)
                self.writer.add_summary(summary_str, counter)

                # Update G network g_epoch times on the same noise.
                for _ in range(config.g_epoch):
                    _, summary_str = self.sess.run([g_optim, self.g_sum],
                                                   feed_dict=fake_feed)
                    self.writer.add_summary(summary_str, counter)

                errD_fake = self.d_loss_fake.eval(fake_feed)
                errD_real = self.d_loss_real.eval(real_feed)
                errG = self.g_loss.eval(fake_feed)

                counter += 1
                if np.mod(counter, 100) == 1:
                    samples, d_loss, g_loss = self.sess.run(
                        [self.sampler, self.d_loss, self.g_loss],
                        feed_dict=full_feed)
                    save_images(
                        samples, image_manifold_size(samples.shape[0]),
                        './{}/train_{:02d}_{:04d}.png'.format(
                            config.sample_dir, epoch, idx))

                if np.mod(counter, 500) == 2:
                    self.save(config.checkpoint_dir, counter)

                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, batch_idxs,
                         time.time() - start_time, errD_fake + errD_real, errG))
Ejemplo n.º 11
0
def train(params):
    """Training loop."""
    # with tf.Graph().as_default(), tf.device('/cpu:0'):
    with tf.Graph().as_default(), tf.device('/device:GPU:0'):
        global_step = tf.Variable(0, trainable=False)
        # OPTIMIZER
        num_training_samples = count_text_lines(args.filenames_file)
        steps_per_epoch = np.ceil(num_training_samples / params.batch_size).astype(np.int32)
        num_total_steps = params.num_epochs * steps_per_epoch
        start_learning_rate = args.learning_rate
        boundaries = [np.int32((3/5) * num_total_steps), np.int32((4/5) * num_total_steps)]
        values = [args.learning_rate, args.learning_rate / 2, args.learning_rate / 4]
        learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
        print("Total number of samples: {}".format(num_training_samples))
        print("Total number of steps: {}".format(num_total_steps))

        z = tf.placeholder(tf.float32, [params.batch_size, params.z_dim], name='z_noise')
        dataloader = MonodepthDataloader(args.data_path, args.filenames_file, params, args.dataset, args.mode)
        left = dataloader.left_image_batch
        right = dataloader.right_image_batch

        fake_generated_left_image = []
        reuse_variables = tf.AUTO_REUSE
        # split for each gpu
        model_generator = MonodepthGenerateModel(params, args.mode, z, reuse_variables, 0)
        left_splits = tf.split(left,  args.num_gpus, 0)[0]
        left_splits_fake = tf.split(model_generator.get_model(), args.num_gpus, 0)[0]
        right_splits = tf.split(right, args.num_gpus, 0)[0]

        with tf.variable_scope('discriminator', reuse=reuse_variables):
                differences = tf.subtract(left_splits_fake, left_splits)
                alpha_shape = [params.batch_size] + [1] * (differences.shape.ndims - 1)
                alpha = tf.random_uniform(shape=alpha_shape, minval=0., maxval=1.)
                interpolates = left_splits + (alpha * differences)
                left_splits_wasserstein_model = MonodepthModel(params, args.mode, interpolates,                                                               right_splits, reuse_variables, left_splits_fake, 1)
                gradients = tf.gradients(left_splits_wasserstein_model.logistic,
                                     [interpolates], stop_gradients=interpolates, colocate_gradients_with_ops=True)[0]
                slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
                gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
                _gradient_penalty = 10 * gradient_penalty

        with tf.variable_scope('discriminator', reuse=reuse_variables):
            model_real = MonodepthModel(params, args.mode, left_splits,
                                                  right_splits, reuse_variables,None, 0)
        loss_discriminator_real = model_real.discriminator_total_loss
        real_feature_set = model_real.get_feature_set()

        with tf.variable_scope('discriminator', reuse=reuse_variables):
            model_fake = MonodepthModel(params, args.mode, left_splits_fake, right_splits,
                                    reuse_variables,None, 10)

        fake_feature_set = model_fake.get_feature_set()
        loss_discriminator = loss_discriminator_real
        generator_loss = tf.nn.l2_loss((real_feature_set - fake_feature_set))
        # total_loss_generator = tf.reduce_mean(generator_loss) + wasserstein_generator_loss(model_fake.logistic_linear)
        # total_loss_discriminator = wasserstein_discriminator_loss(model_real.logistic_linear, model_fake.logistic_linear)+ \
        #                             loss_discriminator + _gradient_penalty
        total_loss_generator = tf.reduce_mean(generator_loss)-tf.reduce_mean(model_fake.logistic)
        total_loss_discriminator = tf.reduce_mean(model_fake.logistic) - tf.reduce_mean(model_real.logistic) \
                                   + loss_discriminator + _gradient_penalty

        # with tf.device('/device:GPU:0'):
        opt_discriminator_step = tf.train.AdamOptimizer(learning_rate)
        opt_generator_step = tf.train.AdamOptimizer(learning_rate)

        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(extra_update_ops):
            g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="generator/*")
            d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="discriminator/*")
            d_optim = opt_discriminator_step.minimize(total_loss_discriminator, var_list=d_vars)
            g_optim = opt_generator_step.minimize(total_loss_generator, var_list=g_vars)

        tf.summary.scalar('learning_rate', learning_rate, ['discriminator_0'])
        tf.summary.scalar('total_loss', total_loss_discriminator, ['discriminator_0'])
        summary_op = tf.summary.merge_all('discriminator_0')

        # SESSION
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        config.log_device_placement = False
        sess = tf.Session(config=config)

        # SAVER
        summary_writer = tf.summary.FileWriter(args.log_directory + '/' + args.model_name, sess.graph)
        train_saver = tf.train.Saver()

        # COUNT PARAMS
        total_num_parameters = 0
        for variable in tf.trainable_variables():
            total_num_parameters += np.array(variable.get_shape().as_list()).prod()
        print("Number of trainable parameters: {}".format(total_num_parameters))

        # INIT
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coordinator = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)

        # LOAD CHECKPOINT IF SET
        if args.checkpoint_path != '':
            train_saver.restore(sess, args.checkpoint_path.split(".")[0])

            if args.retrain:
                sess.run(global_step.assign(0))

        # GO!
        start_step = global_step.eval(session=sess)
        start_time = time.time()
        sample_dataset = np.random.uniform(low=-1, high=1, size=(params.batch_size, params.z_dim)).astype(np.float32)

        for step in range(start_step, num_total_steps):
            before_op_time = time.time()

            batch_z = np.random.uniform(low=-1, high=1, size=(params.batch_size, params.z_dim)).astype(np.float32)
            _, loss_value_discriminator = sess.run([d_optim, total_loss_discriminator], feed_dict={z: batch_z})
            # _, loss_value_discriminator, images_original = sess.run([d_optim, total_loss_discriminator, dataloader.left_image_batch], feed_dict={z: batch_z})
            # print("size-------> {}".format(images_original))
            for _ in range(2):
                _, loss_value_generator, generated_images = sess.run([g_optim, total_loss_generator, model_generator.samplter_network],
                                               feed_dict={z: batch_z})
            duration = time.time() - before_op_time
            if step and step % 100 == 0:
                _, loss_value_generator, generated_images = sess.run(
                    [g_optim, total_loss_generator, model_generator.samplter_network],
                    feed_dict={z: sample_dataset})
                save_images(generated_images, image_manifold_size(generated_images.shape[0]),
                             '{}/train_{:02d}_{:04d}.png'.format(params.sample_dir, step, 1))
                examples_per_sec = params.batch_size / duration
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / step - 1.0) * time_sofar
                print_string = 'batch {:>6} | examples/s: {:4.2f} | loss_discriminator: {:.5f} | time elapsed: {:.2f}h ' \
                               '| time left: {:.2f}h'
                print(print_string.format(step, examples_per_sec, loss_value_discriminator, time_sofar,
                                          training_time_left))
                print_string = 'batch {:>6} | examples/s: {:4.2f} | loss_generator: {:.5f} | time elapsed: {:.2f}h | ' \
                               'time left: {:.2f}h'
                print(print_string.format(step, examples_per_sec, loss_value_generator, time_sofar,
                                          training_time_left))
                summary_str = sess.run(summary_op, feed_dict={z: batch_z})
                summary_writer.add_summary(summary_str, global_step=step)
            if step and step % 10000 == 0:
                train_saver.save(sess, args.log_directory + '/' + args.model_name + '/model', global_step=step)

        train_saver.save(sess, args.log_directory + '/' + args.model_name + '/model', global_step=num_total_steps)
Ejemplo n.º 12
0
    def train(self, config):
        """Run the label-conditioned DCGAN training loop.

        Per mini-batch this does one discriminator step, then two generator
        steps; the extra G step keeps d_loss from collapsing to zero.  It then
        prints progress, writes a sample grid every 100 batches and saves a
        checkpoint every 500.

        Args:
            config: run configuration.  This method reads ``learning_rate``,
                ``beta1``, ``epoch``, ``train_size``, ``batch_size``,
                ``sample_dir`` and ``checkpoint_dir``.
        """
        adam_lr, adam_beta1 = config.learning_rate, config.beta1
        # var_list restricts each optimizer to computing gradients for (and
        # updating) only its own network's parameters.
        d_optim = tf.train.AdamOptimizer(adam_lr, beta1=adam_beta1).minimize(
            self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(adam_lr, beta1=adam_beta1).minimize(
            self.g_loss, var_list=self.g_vars)
        tf.global_variables_initializer().run()

        g_parts = [self.z_sum, self.d__sum, self.G_sum,
                   self.d_loss_fake_sum, self.g_loss_sum]
        d_parts = [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum]
        self.g_sum = merge_summary(g_parts)
        self.d_sum = merge_summary(d_parts)
        self.writer = SummaryWriter("logdir", self.sess.graph)

        # Fixed noise plus the first sample_num examples; reused for every
        # sample grid.
        sample_z = np.random.uniform(-1, 1, size=(self.sample_num, self.z_dim))
        sample_feed = {
            self.z: sample_z,
            self.inputs: self.data_X[0:self.sample_num],
            self.y: self.data_y[0:self.sample_num],
        }

        start_time = time.time()
        loaded, ckpt_counter = self.load(self.checkpoint_dir)
        counter = ckpt_counter if loaded else 1
        print(" [*] Load SUCCESS" if loaded else " [!] Load failed...")

        bs = config.batch_size
        for epoch in range(config.epoch):
            batch_idxs = min(len(self.data_X), config.train_size) // bs
            for idx in range(batch_idxs):
                lo, hi = idx * bs, (idx + 1) * bs
                batch_images = self.data_X[lo:hi]
                batch_labels = self.data_y[lo:hi]
                batch_z = np.random.uniform(-1, 1, [bs, self.z_dim]).astype(np.float32)

                gen_feed = {self.z: batch_z, self.y: batch_labels}
                disc_feed = dict(gen_feed)
                disc_feed[self.inputs] = batch_images

                # One discriminator step.
                _, summary_str = self.sess.run([d_optim, self.d_sum], feed_dict=disc_feed)
                self.writer.add_summary(summary_str, counter)

                # Two generator steps per D step so that d_loss does not go
                # to zero (different from the paper).
                for _ in range(2):
                    _, summary_str = self.sess.run([g_optim, self.g_sum], feed_dict=gen_feed)
                    self.writer.add_summary(summary_str, counter)

                errD_fake = self.d_loss_fake.eval(gen_feed)
                errD_real = self.d_loss_real.eval({self.inputs: batch_images, self.y: batch_labels})
                errG = self.g_loss.eval(gen_feed)

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, batch_idxs, time.time() - start_time,
                         errD_fake + errD_real, errG))

                if np.mod(counter, 100) == 1:
                    samples, d_loss, g_loss = self.sess.run(
                        [self.sampler, self.d_loss, self.g_loss], feed_dict=sample_feed)
                    grid = image_manifold_size(samples.shape[0])
                    save_images(samples, grid,
                                './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
                    print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))

                if np.mod(counter, 500) == 2:
                    self.save(config.checkpoint_dir, counter)
Ejemplo n.º 13
0
    def train(self, config):
        """Train the conditional GAN on ``self.data_x`` / ``self.data_y``.

        Each iteration runs one discriminator update followed by two generator
        updates on the same minibatch, writing a summary after every optimizer
        step.  Every 100 steps a sample grid is generated from a fixed noise
        vector and saved under ``config.sample_dir``; every 500 steps a
        checkpoint is written to ``config.checkpoint_dir``.  Training resumes
        from ``self.checkpoint_dir`` when ``self.load`` succeeds.

        Args:
            config: training options; this method reads ``learning_rate``,
                ``beta1``, ``epoch``, ``train_size``, ``batch_size``,
                ``sample_dir`` and ``checkpoint_dir``.
        """
        d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
            .minimize(self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
            .minimize(self.g_loss, var_list=self.g_vars)
        # Initialise only after both optimizers are built so their variables
        # are covered by the initializer as well.
        tf.global_variables_initializer().run()

        # NOTE(review): `summarys` is defined elsewhere; presumably 'merge' wraps
        # tf.summary.merge and 'writer' wraps tf.summary.FileWriter - confirm.
        self.g_sum = summarys['merge']([self.z_sum, self.d__sum, self.G_sum,
                                        self.d_loss_fake_sum, self.g_loss_sum])
        # The new self.d_sum merges in the previous self.d_sum summary.
        self.d_sum = summarys['merge']([self.z_sum, self.d_sum,
                                        self.d_loss_real_sum, self.d_loss_sum])
        # NOTE(review): '.logs' (a hidden directory) may be a typo for './logs'.
        self.writer = summarys['writer']('.logs', self.sess.graph)

        # Fixed noise/data reused for every sample grid so grids are comparable
        # over the course of training.
        sample_z = np.random.uniform(-1, 1, size=(self.sample_num, self.z_dim)) \
            .astype(np.float32)
        sample_inputs = self.data_x[0:self.sample_num]
        sample_labels = self.data_y[0:self.sample_num]

        counter = 1
        start_time = time.time()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            counter = checkpoint_counter
            print("[*] load success")
        else:
            print("[!] load failed !!")

        batch_size = config.batch_size
        # Number of full minibatches per epoch; a trailing partial batch is dropped.
        batch_idxs = min(len(self.data_x), config.train_size) // batch_size
        if batch_idxs == 0:
            # Nothing to train on; without this guard the per-epoch report
            # below would reference loss variables that were never assigned.
            print("[!] not enough training data for one batch of size %d" % batch_size)
            return

        for epoch in tqdm(range(config.epoch)):
            for idx in range(batch_idxs):
                lo, hi = idx * batch_size, (idx + 1) * batch_size
                batch_images = self.data_x[lo:hi]
                batch_labels = self.data_y[lo:hi]
                batch_z = np.random.uniform(-1, 1, [batch_size, self.z_dim]) \
                    .astype(np.float32)

                # Update D network once.
                _, summary_str = self.sess.run([d_optim, self.d_sum],
                                               feed_dict={self.inputs: batch_images,
                                                          self.y: batch_labels,
                                                          self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # Update G network twice per D step (the usual DCGAN trick to
                # keep d_loss from going to zero).
                for _ in range(2):
                    summary_str = self.sess.run([g_optim, self.g_sum],
                                                feed_dict={self.z: batch_z,
                                                           self.y: batch_labels})[1]
                    self.writer.add_summary(summary_str, counter)

                errD_fake = self.d_loss_fake.eval({self.z: batch_z, self.y: batch_labels})
                errD_real = self.d_loss_real.eval({self.inputs: batch_images, self.y: batch_labels})
                errG = self.g_loss.eval({self.z: batch_z, self.y: batch_labels})

                counter += 1
                if np.mod(counter, 100) == 1:
                    # Feed the fixed sample_z (not batch_z) so the noise matches
                    # sample_inputs/sample_labels in size and stays constant.
                    samples, d_loss, g_loss = self.sess.run(
                        [self.sampler, self.d_loss, self.g_loss],
                        feed_dict={self.z: sample_z,
                                   self.inputs: sample_inputs,
                                   self.y: sample_labels})
                    save_images(samples, image_manifold_size(samples.shape[0]),
                                './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
                    print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))

                if np.mod(counter, 500) == 2:
                    self.save(config.checkpoint_dir, counter)

            # Per-epoch report; the losses are those of the epoch's last minibatch.
            print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                  % (epoch, idx, batch_idxs,
                     time.time() - start_time, errD_fake + errD_real, errG))