Example #1
0
                print("Epoch: (%3d) (%5d/%5d)" % (ep, it_in_epoch, it_per_epoch))

            # sample
            if (it + 1) % 2000 == 0:
                save_dir = './output/%s/sample_training' % experiment_name
                pylib.mkdir(save_dir)

                img_rec_opt_sample, img_intp_opt_sample = sess.run([img_rec_sample, img_intp_sample],
                                                                   feed_dict={img: img_ipt_sample})
                img_rec_opt_sample, img_intp_opt_sample = img_rec_opt_sample.squeeze(), img_intp_opt_sample.squeeze()
                # ipt_rec = np.concatenate((img_ipt_sample, img_rec_opt_sample), axis=2).squeeze()
                img_opt_sample = sess.run(img_sample).squeeze()

                # im.imwrite(im.immerge(ipt_rec, padding=img_shape[0] // 8),
                #            '%s/Epoch_(%d)_(%dof%d)_img_rec.png' % (save_dir, ep, it_in_epoch, it_per_epoch))
                im.imwrite(im.immerge(img_intp_opt_sample, n_col=1, padding=0),
                           '%s/Epoch_(%d)_(%dof%d)_img_intp.png' % (save_dir, ep, it_in_epoch, it_per_epoch))
                im.imwrite(im.immerge(img_opt_sample),
                           '%s/Epoch_(%d)_(%dof%d)_img_sample.png' % (save_dir, ep, it_in_epoch, it_per_epoch))

                if fid_stats_path:
                    try:
                        mu_gen, sigma_gen = fid.calculate_activation_statistics(im.im2uint(
                            np.concatenate([sess.run(fid_sample).squeeze() for _ in range(5)], 0)), sess,
                            batch_size=100)
                        fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
                    except:
                        fid_value = -1.
                    fid_summary = tf.Summary()
                    fid_summary.value.add(tag='FID', simple_value=fid_value)
                    summary_writer.add_summary(fid_summary, it)
                    print("FID: %s" % fid_value)
            # # summary
            tl.summary(G_loss_dict,
                       step=G_optimizer.iterations,
                       name='G_losses')
            tl.summary(D_loss_dict,
                       step=G_optimizer.iterations,
                       name='D_losses')
            tl.summary({'learning rate': G_lr_scheduler.current_learning_rate},
                       step=G_optimizer.iterations,
                       name='learning rate')

            # sample
            optim_iterations = G_optimizer.iterations.numpy()
            if optim_iterations % args.checkpoint_iterations == 0:
                A, B = next(test_iter)
                A2B, B2A, A2B2A, B2A2B = sample(A, B)
                img = im.immerge(np.concatenate([A, A2B, A2B2A, B, B2A, B2A2B],
                                                axis=0),
                                 n_rows=2)
                im.imwrite(
                    img, py.join(sample_dir,
                                 'iter-%09d.jpg' % optim_iterations))

                # save checkpoint
                file_prefix = os.path.join(
                    checkpoint_dir, 'ep{}-step{}'.format(ep, optim_iterations))
                checkpoint.save(file_prefix)

        # update epoch counter
        ep_cnt.assign_add(1)
Example #3
0
    for x_real in tqdm.tqdm(data_loader, desc='Inner Epoch Loop'):
        x_real = x_real.to(device)

        D_loss_dict = train_D(x_real)
        it_d += 1

        if it_d % args.n_d == 0:
            G_loss_dict = train_G(x_real)
            it_g += 1

        # sample
        if it_g % 100 == 0:
            x_fake = sample(z)
            x_fake = np.transpose(x_fake.data.cpu().numpy(), (0, 2, 3, 1))
            img = im.immerge(x_fake, n_rows=10).squeeze()
            im.imwrite(img, py.join(sample_dir, 'iter-%09d.jpg' % it_g))

    # save checkpoint
    torchlib.save_checkpoint(
        {
            'ep': ep,
            'it_d': it_d,
            'it_g': it_g,
            'D': D.state_dict(),
            'G': G.state_dict(),
            'D_optimizer': D_optimizer.state_dict(),
            'G_optimizer': G_optimizer.state_dict()
        },
        py.join(ckpt_dir, 'Epoch_(%d).ckpt' % ep),
        max_keep=args.epochs)
Example #4
0
                              i - 1] = _b_sample_ipt[..., i - 1] * test_int
        # print(_b_sample_ipt)
        start_time = time.time()
        x_sample_opt_list.append(
            sess.run(x_sample,
                     feed_dict={
                         xa_sample: xa_sample_ipt,
                         _b_sample: _b_sample_ipt,
                         raw_b_sample: raw_a_sample_ipt
                     }))
        duration = time.time() - start_time
        # print('duration of process No.{} attribution({}) of image {}.png is: {}'.format(i,
        #                                                                                 'no-change' if i == 0 else atts[i - 1],
        #                                                                                 idx + 182638 if img is None else img[idx],
        #                                                                                 duration))
    sample = np.concatenate(x_sample_opt_list, 2)

    if test_slide: save_folder = 'sample_testing_slide'
    elif multi_atts: save_folder = 'sample_testing_multi'
    else: save_folder = 'sample_testing'
    save_dir = './output/%s/%s' % (experiment_name, save_folder)
    pylib.mkdir(save_dir)
    # im.imshow(sample.squeeze(0))
    im.imwrite(sample.squeeze(0), img_name.split('.')[0] + '1.png')

    print('%s done!' % img_name)
except:
    traceback.print_exc()
finally:
    sess.close()
Example #5
0
    def train(self):
        """Train the model whose graph was already built on ``self.sess``.

        Hyper-parameters are read from ``self.config``. Each iteration runs
        ``config["dStep"]`` discriminator updates followed by one generator
        update; the learning rate ``config["gLr"]`` is divided by 10 every
        ``config["lrDecayEpoch"]`` epochs. Progress is printed every 100
        iterations. A checkpoint is saved every ``config["modelSaveEpoch"]``
        iterations and a sample grid every ``config["sampleEpoch"]``
        iterations; a falsy setting means once per epoch.

        On exit (normal completion, an exception, or Ctrl-C) a final
        checkpoint is saved and the session is closed.
        """
        it_cnt, update_cnt = tl.counter()

        # saver
        saver = tf.train.Saver(max_to_keep=10)

        # summary writer
        summary_writer = tf.summary.FileWriter(self.config["projectSummary"],
                                               self.sess.graph)

        # initialization
        ckpt_dir = self.config["projectCheckpoints"]
        # NOTE: `epoch` first holds the configured total number of epochs;
        # inside the loop it is rebound to the current epoch index (max_it is
        # computed before that happens).
        epoch = self.config["totalEpoch"]
        n_d = self.config["dStep"]
        atts = self.config["selectedAttrs"]
        thres_int = self.config["thresInt"]
        test_int = self.config["sampleThresInt"]
        n_sample = self.config["sampleNum"]
        img_size = self.config["imsize"]
        sample_freq = self.config["sampleEpoch"]
        save_freq = self.config["modelSaveEpoch"]
        lr_base = self.config["gLr"]
        lrDecayEpoch = self.config["lrDecayEpoch"]
        try:
            # NOTE(review): `clear` is not defined in this method. Unless it is
            # a module-level global, the NameError is caught below and the
            # checkpoint is never restored -- confirm. An explicit raise
            # replaces the former `assert`, which `python -O` would strip.
            if clear != False:  # noqa: E712 (same test as the old assert)
                raise RuntimeError('clear requested; not restoring checkpoint')
            tl.load_checkpoint(ckpt_dir, self.sess)
        except Exception:
            print('NOTE: Initializing all parameters...')
            self.sess.run(tf.global_variables_initializer())

        # Defaults for the final save in `finally`. Without them that save
        # raised NameError (and skipped sess.close()) whenever the loop body
        # never ran.
        it_per_epoch = 0
        it_in_epoch = 0

        # train
        try:
            # data for sampling
            xa_sample_ipt, a_sample_ipt = self.val_loader.get_next()
            b_sample_ipt_list = [a_sample_ipt
                                 ]  # the first is for reconstruction
            for i in range(len(atts)):
                tmp = np.array(a_sample_ipt, copy=True)
                tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
                tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
                b_sample_ipt_list.append(tmp)

            it_per_epoch = len(self.data_loader) // (self.config["batchSize"] *
                                                     (n_d + 1))
            max_it = epoch * it_per_epoch

            for it in range(self.sess.run(it_cnt), max_it):
                with pylib.Timer(is_output=False) as t:
                    self.sess.run(update_cnt)

                    # which epoch
                    epoch = it // it_per_epoch
                    it_in_epoch = it % it_per_epoch + 1
                    # learning rate: divided by 10 every lrDecayEpoch epochs
                    lr_ipt = lr_base / (10**(epoch // lrDecayEpoch))

                    # train D
                    for i in range(n_d):
                        d_summary_opt, _ = self.sess.run(
                            [self.d_summary, self.d_step],
                            feed_dict={self.lr: lr_ipt})
                    summary_writer.add_summary(d_summary_opt, it)

                    # train G
                    g_summary_opt, _ = self.sess.run(
                        [self.g_summary, self.g_step],
                        feed_dict={self.lr: lr_ipt})
                    summary_writer.add_summary(g_summary_opt, it)

                    # display
                    if (it + 1) % 100 == 0:
                        print("Epoch: (%3d) (%5d/%5d) Time: %s!" %
                              (epoch, it_in_epoch, it_per_epoch, t))

                    # save
                    if (it + 1) % (save_freq
                                   if save_freq else it_per_epoch) == 0:
                        save_path = saver.save(
                            self.sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                            (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
                        print('Model is saved at %s!' % save_path)

                    # sample
                    if (it + 1) % (sample_freq
                                   if sample_freq else it_per_epoch) == 0:

                        x_sample_opt_list = [
                            xa_sample_ipt,
                            np.full((n_sample, img_size, img_size // 10, 3),
                                    -1.0)
                        ]
                        raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 -
                                            1) * thres_int

                        for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                            _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                            if i > 0:  # i == 0 is for reconstruction
                                _b_sample_ipt[..., i - 1] = _b_sample_ipt[
                                    ..., i - 1] * test_int / thres_int
                            x_sample_opt_list.append(
                                self.sess.run(self.x_sample,
                                              feed_dict={
                                                  self.xa_sample:
                                                  xa_sample_ipt,
                                                  self._b_sample:
                                                  _b_sample_ipt,
                                                  self.raw_b_sample:
                                                  raw_b_sample_ipt
                                              }))
                            last_images = x_sample_opt_list[-1]
                            if i > 0:  # add a mark (+/-) in the upper-left corner to identify add/remove an attribute
                                for nnn in range(last_images.shape[0]):
                                    last_images[nnn, 2:5, 0:7, :] = 1.
                                    if _b_sample_ipt[nnn, i - 1] > 0:
                                        last_images[nnn, 0:7, 2:5, :] = 1.
                                        last_images[nnn, 1:6, 3:4, :] = -1.
                                    last_images[nnn, 3:4, 1:6, :] = -1.
                        sample = np.concatenate(x_sample_opt_list, 2)

                        im.imwrite(im.immerge(sample, n_sample, 1), '%s/Epoch_(%d)_(%dof%d).jpg' % \
                                    (self.config["projectSamples"], epoch, it_in_epoch, it_per_epoch))
        except (Exception, KeyboardInterrupt):
            # Ctrl-C is swallowed on purpose, so interrupting training still
            # falls through to the checkpoint save below.
            traceback.print_exc()
        finally:
            try:
                save_path = saver.save(
                    self.sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                    (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
                print('Model is saved at %s!' % save_path)
            finally:
                # Close the session even if the final save fails.
                self.sess.close()
Example #6
0
    def train(self):
        """Build the generator/discriminator training graph and train it.

        The network classes come from ``components.<modelScriptName>``
        (``Genc``, ``Gdec``, ``D``, ``gradient_penalty``) and
        ``components.STU.<stuScriptName>`` (``Gstu``). A new TF session is
        created; when ``config["threads"] >= 0`` its thread counts are limited.

        The discriminator loss is the negated critic gap plus
        ``gradient_penalty * 10`` plus an attribute classification loss. The
        generator loss is the adversarial term plus ``10 *`` the attribute loss
        plus ``config["recWeight"] *`` an L1 reconstruction loss.

        Parameters are restored from ``config["projectCheckpoints"]`` when
        ``config["mode"] == "finetune"``; otherwise they are initialized. Each
        iteration runs ``config["dStep"]`` D updates and one G update. The
        learning rate ``config["gLr"]`` is divided by 10 every
        ``config["lrDecayEpoch"]`` epochs. Checkpoints and sample grids are
        written every ``modelSaveEpoch`` / ``sampleEpoch`` iterations (once per
        epoch when the setting is falsy).

        On exit (normal completion, an exception, or Ctrl-C) a final
        checkpoint is saved and the session is closed.
        """
        import importlib

        ckpt_dir = self.config["projectCheckpoints"]
        # NOTE: `epoch` first holds the configured total number of epochs;
        # inside the loop it is rebound to the current epoch index (max_it is
        # computed before that happens).
        epoch = self.config["totalEpoch"]
        n_d = self.config["dStep"]
        atts = self.config["selectedAttrs"]
        thres_int = self.config["thresInt"]
        test_int = self.config["sampleThresInt"]
        n_sample = self.config["sampleNum"]
        img_size = self.config["imsize"]
        sample_freq = self.config["sampleEpoch"]
        save_freq = self.config["modelSaveEpoch"]
        lr_base = self.config["gLr"]
        lrDecayEpoch = self.config["lrDecayEpoch"]
        n_att = len(self.config["selectedAttrs"])

        if self.config["threads"] >= 0:
            cpu_config = tf.ConfigProto(
                intra_op_parallelism_threads=self.config["threads"] // 2,
                inter_op_parallelism_threads=self.config["threads"] // 2,
                device_count={'CPU': self.config["threads"]})
            cpu_config.gpu_options.allow_growth = True
            sess = tf.Session(config=cpu_config)
        else:
            sess = tl.session()

        data_loader = Celeba(self.config["dataset_path"],
                             self.config["selectedAttrs"],
                             self.config["imsize"],
                             self.config["batchSize"],
                             part='train',
                             sess=sess,
                             crop=(self.config["imCropSize"] > 0))

        val_loader = Celeba(self.config["dataset_path"],
                            self.config["selectedAttrs"],
                            self.config["imsize"],
                            self.config["sampleNum"],
                            part='val',
                            shuffle=False,
                            sess=sess,
                            crop=(self.config["imCropSize"] > 0))

        # import_module returns the leaf module directly. The former
        # __import__(..., fromlist=True) passed a bool where a list of names is
        # expected, which raises TypeError if the target is a package.
        package = importlib.import_module("components." +
                                          self.config["modelScriptName"])
        GencClass = getattr(package, 'Genc')
        GdecClass = getattr(package, 'Gdec')
        DClass = getattr(package, 'D')
        GP = getattr(package, "gradient_penalty")

        package = importlib.import_module("components.STU." +
                                          self.config["stuScriptName"])
        GstuClass = getattr(package, 'Gstu')

        Genc = partial(GencClass,
                       dim=self.config["GConvDim"],
                       n_layers=self.config["GLayerNum"],
                       multi_inputs=1)

        Gdec = partial(GdecClass,
                       dim=self.config["GConvDim"],
                       n_layers=self.config["GLayerNum"],
                       shortcut_layers=self.config["skipNum"],
                       inject_layers=self.config["injectLayers"],
                       one_more_conv=self.config["oneMoreConv"])

        Gstu = partial(GstuClass,
                       dim=self.config["stuDim"],
                       n_layers=self.config["skipNum"],
                       inject_layers=self.config["skipNum"],
                       kernel_size=self.config["stuKS"],
                       norm=None,
                       pass_state='stu')

        D = partial(DClass,
                    n_att=n_att,
                    dim=self.config["DConvDim"],
                    fc_dim=self.config["DFcDim"],
                    n_layers=self.config["DLayerNum"])

        # inputs

        xa = data_loader.batch_op[0]
        a = data_loader.batch_op[1]
        # Target attributes are the source attributes shuffled across the batch.
        b = tf.random_shuffle(a)
        _a = (tf.to_float(a) * 2 - 1) * self.config["thresInt"]
        _b = (tf.to_float(b) * 2 - 1) * self.config["thresInt"]

        xa_sample = tf.placeholder(
            tf.float32,
            shape=[None, self.config["imsize"], self.config["imsize"], 3])
        _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
        raw_b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
        lr = tf.placeholder(tf.float32, shape=[])

        # generate: the generators are conditioned on the attribute difference
        z = Genc(xa)
        zb = Gstu(z, _b - _a)
        xb_ = Gdec(zb, _b - _a)
        with tf.control_dependencies([xb_]):
            # reconstruction: zero attribute difference
            za = Gstu(z, _a - _a)
            xa_ = Gdec(za, _a - _a)

        # discriminate
        xa_logit_gan, xa_logit_att = D(xa)
        xb__logit_gan, xb__logit_att = D(xb_)

        wd = tf.reduce_mean(xa_logit_gan) - tf.reduce_mean(xb__logit_gan)
        d_loss_gan = -wd
        gp = GP(D, xa, xb_)
        xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)
        d_loss = d_loss_gan + gp * 10.0 + xa_loss_att

        xb__loss_gan = -tf.reduce_mean(xb__logit_gan)
        xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)
        xa__loss_rec = tf.losses.absolute_difference(xa, xa_)
        g_loss = xb__loss_gan + xb__loss_att * 10.0 + xa__loss_rec * self.config[
            "recWeight"]

        d_var = tl.trainable_variables('D')
        d_step = tf.train.AdamOptimizer(
            lr, beta1=self.config["beta1"]).minimize(d_loss, var_list=d_var)
        g_var = tl.trainable_variables('G')
        g_step = tf.train.AdamOptimizer(
            lr, beta1=self.config["beta1"]).minimize(g_loss, var_list=g_var)

        d_summary = tl.summary(
            {
                d_loss_gan: 'd_loss_gan',
                gp: 'gp',
                xa_loss_att: 'xa_loss_att',
            },
            scope='D')

        lr_summary = tl.summary({lr: 'lr'}, scope='Learning_Rate')

        g_summary = tl.summary(
            {
                xb__loss_gan: 'xb__loss_gan',
                xb__loss_att: 'xb__loss_att',
                xa__loss_rec: 'xa__loss_rec',
            },
            scope='G')

        d_summary = tf.summary.merge([d_summary, lr_summary])

        # sample
        test_label = _b_sample - raw_b_sample
        x_sample = Gdec(Gstu(Genc(xa_sample, is_training=False),
                             test_label,
                             is_training=False),
                        test_label,
                        is_training=False)

        it_cnt, update_cnt = tl.counter()

        # saver
        saver = tf.train.Saver(max_to_keep=self.config["max2Keep"])

        # summary writer
        summary_writer = tf.summary.FileWriter(self.config["projectSummary"],
                                               sess.graph)

        # initialization
        if self.config["mode"] == "finetune":
            print("Continue training the model")
            tl.load_checkpoint(ckpt_dir, sess)
            print("Load previous model successfully!")
        else:
            print('Initializing all parameters...')
            sess.run(tf.global_variables_initializer())

        # Defaults for the final save in `finally`. Without them that save
        # raised NameError (and skipped sess.close()) whenever the loop body
        # never ran.
        it_per_epoch = 0
        it_in_epoch = 0

        # train
        try:
            # data for sampling
            xa_sample_ipt, a_sample_ipt = val_loader.get_next()
            b_sample_ipt_list = [a_sample_ipt
                                 ]  # the first is for reconstruction
            for i in range(len(atts)):
                tmp = np.array(a_sample_ipt, copy=True)
                tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
                tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
                b_sample_ipt_list.append(tmp)

            it_per_epoch = len(data_loader) // (self.config["batchSize"] *
                                                (n_d + 1))
            max_it = epoch * it_per_epoch

            print("Start to train the graph!")
            for it in range(sess.run(it_cnt), max_it):
                with pylib.Timer(is_output=False) as t:
                    sess.run(update_cnt)

                    # which epoch
                    epoch = it // it_per_epoch
                    it_in_epoch = it % it_per_epoch + 1
                    # learning rate: divided by 10 every lrDecayEpoch epochs
                    lr_ipt = lr_base / (10**(epoch // lrDecayEpoch))

                    # train D
                    for i in range(n_d):
                        d_summary_opt, _ = sess.run([d_summary, d_step],
                                                    feed_dict={lr: lr_ipt})
                    summary_writer.add_summary(d_summary_opt, it)

                    # train G
                    g_summary_opt, _ = sess.run([g_summary, g_step],
                                                feed_dict={lr: lr_ipt})
                    summary_writer.add_summary(g_summary_opt, it)

                    # display
                    if (it + 1) % 100 == 0:
                        print("Epoch: (%3d) (%5d/%5d) Time: %s!" %
                              (epoch, it_in_epoch, it_per_epoch, t))

                    # save
                    if (it + 1) % (save_freq
                                   if save_freq else it_per_epoch) == 0:
                        save_path = saver.save(
                            sess, '%s/Epoch_(%d).ckpt' % (ckpt_dir, epoch))
                        print('Model is saved at %s!' % save_path)

                    # sample
                    if (it + 1) % (sample_freq
                                   if sample_freq else it_per_epoch) == 0:

                        x_sample_opt_list = [
                            xa_sample_ipt,
                            np.full((n_sample, img_size, img_size // 10, 3),
                                    -1.0)
                        ]
                        raw_b_sample_ipt = (b_sample_ipt_list[0].copy() * 2 -
                                            1) * thres_int

                        for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                            _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                            if i > 0:  # i == 0 is for reconstruction
                                _b_sample_ipt[..., i - 1] = _b_sample_ipt[
                                    ..., i - 1] * test_int / thres_int
                            x_sample_opt_list.append(
                                sess.run(x_sample,
                                         feed_dict={
                                             xa_sample: xa_sample_ipt,
                                             _b_sample: _b_sample_ipt,
                                             raw_b_sample: raw_b_sample_ipt
                                         }))
                            last_images = x_sample_opt_list[-1]
                            if i > 0:  # add a mark (+/-) in the upper-left corner to identify add/remove an attribute
                                for nnn in range(last_images.shape[0]):
                                    last_images[nnn, 2:5, 0:7, :] = 1.
                                    if _b_sample_ipt[nnn, i - 1] > 0:
                                        last_images[nnn, 0:7, 2:5, :] = 1.
                                        last_images[nnn, 1:6, 3:4, :] = -1.
                                    last_images[nnn, 3:4, 1:6, :] = -1.
                        sample = np.concatenate(x_sample_opt_list, 2)

                        im.imwrite(im.immerge(sample, n_sample, 1), '%s/Epoch_(%d)_(%dof%d).jpg' % \
                                    (self.config["projectSamples"], epoch, it_in_epoch, it_per_epoch))
        except (Exception, KeyboardInterrupt):
            # Ctrl-C is swallowed on purpose, so interrupting training still
            # falls through to the checkpoint save below.
            traceback.print_exc()
        finally:
            try:
                save_path = saver.save(
                    sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
                    (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
                print('Model is saved at %s!' % save_path)
            finally:
                # Close the session even if the final save fails.
                sess.close()
Example #7
0
def runModel(image_url, file_name, test_att, n_slide, image_labels,
             model_type):
    """Edit one downloaded image with a trained attribute-editing model.

    The image at ``image_url`` is downloaded with ``wget`` into the
    experiment's data directory. Attribute ``test_att`` is set to the last step
    of an ``n_slide``-step intensity slider. The result is written to
    ``./output/<experiment>/sample_testing_slide_<test_att>/<file_name>`` and
    the downloaded file is removed again.

    Args:
        image_url: URL of the input image.
        file_name: name of the downloaded image inside the data directory;
            also used as the output file name.
        test_att: attribute to edit; must be one of the experiment's ``atts``.
        n_slide: number of slider steps (converted with ``int``). Only the
            last step is rendered.
        image_labels: labels forwarded to ``data.Celeba``.
        model_type: ``0`` uses ``--experiment_name`` (default
            ``"384_shortcut1_inject1_none_hd"``); any other value uses
            ``"128_custom"``.

    Returns:
        ``"http://129.32.22.10:7001" + save_location + file_name``, where
        ``save_location`` is ``/output/<experiment>/sample_testing_slide_<att>/``.
        ``save_location`` is empty if no batch was processed.

    Raises:
        AssertionError: ``test_att`` is None or not one of the model's attributes.
        Exception: ``' [*] No checkpoint!'`` if the checkpoint cannot be loaded.
    """
    import subprocess

    # ==============================================================================
    # =                                    param                                   =
    # ==============================================================================

    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_name',
                        dest='experiment_name',
                        default="384_shortcut1_inject1_none_hd",
                        help='experiment_name')
    parser.add_argument('--test_att', dest='test_att', help='test_att')
    parser.add_argument('--test_int_min',
                        dest='test_int_min',
                        type=float,
                        default=-1.0,
                        help='test_int_min')
    parser.add_argument('--test_int_max',
                        dest='test_int_max',
                        type=float,
                        default=1.0,
                        help='test_int_max')
    # This function runs inside a host process (it returns a public URL).
    # parse_args() would exit on that process's own command-line flags, so
    # unknown flags are ignored instead.
    args_, _ = parser.parse_known_args()

    if model_type == 0:
        experiment_name = args_.experiment_name
    else:
        experiment_name = "128_custom"

    print("EXPERIMENT NAME WORKING:" + experiment_name)

    with open('./output/%s/setting.txt' % experiment_name) as f:
        args = json.load(f)

    # model
    atts = args['atts']
    n_att = len(atts)
    img_size = args['img_size']
    shortcut_layers = args['shortcut_layers']
    inject_layers = args['inject_layers']
    enc_dim = args['enc_dim']
    dec_dim = args['dec_dim']
    dis_dim = args['dis_dim']
    dis_fc_dim = args['dis_fc_dim']
    enc_layers = args['enc_layers']
    dec_layers = args['dec_layers']
    dis_layers = args['dis_layers']
    # testing
    thres_int = args['thres_int']
    test_int_min = args_.test_int_min
    test_int_max = args_.test_int_max
    # others
    use_cropped_img = args['use_cropped_img']
    n_slide = int(n_slide)

    # Raise explicitly rather than `assert` so the check survives `python -O`.
    # An unknown attribute is rejected here; otherwise it would fail inside the
    # sampling loop, where the error is only printed.
    if test_att is None or test_att not in atts:
        raise AssertionError('test_att should be chosen in %s' % str(atts))

    # ==============================================================================
    # =                                   graphs                                   =
    # ==============================================================================

    # Where wget stores the image, and the relative path it is deleted from.
    if experiment_name == "128_custom":
        download_dir = "/home/tug44606/AttGAN-Tensorflow-master/data/img_align_celeba"
        local_dir = "./data/img_align_celeba"
    else:
        download_dir = "/home/tug44606/AttGAN-Tensorflow-master/data/img_crop_celeba"
        local_dir = "./data/img_crop_celeba"

    # data
    sess = tl.session()
    save_location = ""
    try:
        # get image
        print(image_url)
        # image_url is caller-supplied. It is passed as its own argv entry,
        # after "--", instead of being spliced into a shell command line,
        # which allowed shell injection. As with os.system, a non-zero wget
        # exit status is not treated as fatal here.
        subprocess.run(["wget", "-P", download_dir, "--", image_url],
                       check=False)

        print("Working")

        # pass image with labels to dataset
        te_data = data.Celeba('./data',
                              atts,
                              img_size,
                              1,
                              part='val',
                              sess=sess,
                              crop=not use_cropped_img,
                              image_labels=image_labels,
                              file_name=file_name)

        # models
        Genc = partial(models.Genc, dim=enc_dim, n_layers=enc_layers)
        Gdec = partial(models.Gdec,
                       dim=dec_dim,
                       n_layers=dec_layers,
                       shortcut_layers=shortcut_layers,
                       inject_layers=inject_layers)

        # inputs
        xa_sample = tf.placeholder(tf.float32,
                                   shape=[None, img_size, img_size, 3])
        _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])

        # sample
        x_sample = Gdec(Genc(xa_sample, is_training=False),
                        _b_sample,
                        is_training=False)

        # ==============================================================================
        # =                                    test                                    =
        # ==============================================================================

        # initialization
        ckpt_dir = './output/%s/checkpoints' % experiment_name
        print("CHECKPOINT DIR: " + ckpt_dir)
        try:
            tl.load_checkpoint(ckpt_dir, sess)
        except Exception as err:
            raise Exception(' [*] No checkpoint!') from err

        att_idx = atts.index(test_att)

        # sample
        try:
            for idx, batch in enumerate(te_data):
                xa_sample_ipt = batch[0]
                b_sample_ipt = batch[1]

                x_sample_opt_list = []

                # Only the last slider step is rendered.
                for i in range(n_slide - 1, n_slide):
                    if n_slide == 1:
                        # The interpolation below divides by n_slide - 1. A
                        # one-step slider uses the maximum intensity, which is
                        # what the last step yields for every other n_slide.
                        test_int = test_int_max
                    else:
                        test_int = (test_int_max - test_int_min) / (
                            n_slide - 1) * i + test_int_min
                    _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                    _b_sample_ipt[..., att_idx] = test_int
                    x_sample_opt_list.append(
                        sess.run(x_sample,
                                 feed_dict={
                                     xa_sample: xa_sample_ipt,
                                     _b_sample: _b_sample_ipt
                                 }))

                sample = np.concatenate(x_sample_opt_list, 2)
                save_location = '/output/%s/sample_testing_slide_%s/' % (
                    experiment_name, test_att)
                save_dir = './output/%s/sample_testing_slide_%s' % (
                    experiment_name, test_att)
                pylib.mkdir(save_dir)
                im.imwrite(sample.squeeze(0), '%s/%s' % (save_dir, file_name))

                print('%d.png done!' % idx)

                if idx + 1 == te_data._img_num:
                    break
        except Exception:
            traceback.print_exc()
    finally:
        sess.close()
        # Best-effort cleanup; the former `rm` also ignored failures.
        # basename keeps a crafted file_name from deleting anything outside
        # local_dir. wget -P saves directly into that directory anyway.
        try:
            os.remove(os.path.join(local_dir, os.path.basename(file_name)))
        except OSError:
            pass

    return "http://129.32.22.10:7001" + save_location + file_name
        raw_b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int  # -0.5,0.5

        for a, w in zip(test_atts, test_ints):
            i = atts.index(a)
            raw_b_sample_ipt[:, i] = raw_b_sample_ipt[:, i] * w

        print(raw_a_sample_ipt)
        print(raw_b_sample_ipt)
        print(raw_b_sample_ipt - raw_a_sample_ipt)

        start_ = time.time()
        x_att_res = sess.run(x_sample,
                             feed_dict={
                                 xa_sample: xa_sample_ipt,
                                 _b_sample: raw_b_sample_ipt,
                                 raw_b_sample: raw_a_sample_ipt
                             })
        end_ = time.time()

        print("cost: {}".format(end_ - start_))
        img_name = os.path.basename(te_data.img_paths[idx])
        im.imwrite(x_att_res.squeeze(0),
                   '%s/%s.png' % (testing_result_dir, img_name))

        print('{}.png done!'.format(img_name))

except:
    traceback.print_exc()
finally:
    sess.close()
Example #9
0
def sample_A2B(A):
    """Translate a batch from domain A to B, then cycle it back to A.

    Both generators are called with ``training=False``.

    Returns:
        A pair ``(A2B, A2B2A)``: the translated batch and its reconstruction.
    """
    # NOTE(review): the twin `sample_B2A` is decorated with @tf.function but
    # this function is not, so it runs eagerly -- confirm this is intended.
    translated = G_A2B(A, training=False)
    return translated, G_B2A(translated, training=False)


@tf.function
def sample_B2A(B):
    """Translate a batch from domain B to A, then cycle it back to B.

    Both generators are called with ``training=False``.

    Returns:
        A pair ``(B2A, B2A2B)``: the translated batch and its reconstruction.
    """
    translated = G_B2A(B, training=False)
    return translated, G_A2B(translated, training=False)


def _write_translations(dataset, translate, src_paths, direction):
    """Save the translation of every test image under generated_images/<direction>.

    Each output file reuses the file name (``py.name_ext``) of the source image
    at the same running position in ``src_paths``. Only the translated image is
    written; the cycle reconstruction returned by ``translate`` is ignored.
    """
    out_dir = py.join(args.experiment_dir, 'generated_images', direction)
    py.mkdir(out_dir)
    n_written = 0
    for batch in dataset:
        translated, reconstructed = translate(batch)
        for _src, out_img, _rec in zip(batch, translated, reconstructed):
            im.imwrite(out_img.numpy(),
                       py.join(out_dir, py.name_ext(src_paths[n_written])))
            n_written += 1


_write_translations(A_dataset_test, sample_A2B, A_img_paths_test, 'A2B')
_write_translations(B_dataset_test, sample_B2A, B_img_paths_test, 'B2A')
Example #10
0
        for i, b_sample_ipt in enumerate(b_sample_ipt_list):
            x_sample_opt_list.append(
                sess.run(x_sample,
                         feed_dict={
                             xa_sample: xa_sample_ipt,
                             _b_sample: b_sample_ipt,
                         }))
            mask_sample_opt_list.append(
                sess.run(mask_sample,
                         feed_dict={
                             xa_sample: xa_sample_ipt,
                             _b_sample: b_sample_ipt
                         }))
        sample = np.concatenate(x_sample_opt_list, 2)
        masks = np.concatenate(mask_sample_opt_list, 2)

        save_folder = 'sample_testing'
        save_dir = './output/%s/%s' % (experiment_name, save_folder)
        pylib.mkdir(save_dir)
        im.imwrite(
            sample.squeeze(0), '%s/%06d%s.png' %
            (save_dir, idx + 182638 if img is None else img[idx], '_%s' % ''))
        im.imwrite(
            masks.squeeze(0), '%s/Mask%06d%s.png' %
            (save_dir, idx + 182638 if img is None else img[idx], '_%s' % ''))

        print('%06d.png done!' % (idx + 182638 if img is None else img[idx]))
except:
    traceback.print_exc()
finally:
    sess.close()
Example #11
0
                    vutils.save_image(x_real.detach(),
                                      py.join(
                                          out_path,
                                          'img-%d-%d.jpg' % (num_samples, i)),
                                      normalize=True)
                    num_samples -= 1
                    print('saving ', num_samples)
else:
    idx = get_indices(dataset, args.num)
    data_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             sampler=sampler.SubsetRandomSampler(idx),
                             num_workers=0,
                             drop_last=True,
                             pin_memory=False)
    num_samples = args.num_samples_per_class
    out_path = py.join(args.out_dir, args.output_dir)
    os.makedirs(out_path, exist_ok=True)

    while (num_samples > 0):
        for (x_real, labels) in iter(data_loader):
            if num_samples > 0:
                x_real = np.transpose(x_real.data.cpu().numpy(), (0, 2, 3, 1))
                img = im.immerge(x_real, n_rows=1).squeeze()
                im.imwrite(
                    img,
                    py.join(out_path,
                            'img-%d-%d.jpg' % (num_samples, args.num)))
                num_samples -= 1
                print('saving ', num_samples)
Example #12
0
    def test(self):
        """Sample the trained model on the test split and save comparison strips.

        For every test image one PNG named ``%06d.png`` is written to
        ``self.config["projectSamples"]``. Along the width it contains: the
        input image, a thin separator filled with -1, the reconstruction
        (target attributes == source attributes) and one edited result per
        attribute in ``selectedAttrs`` (that attribute inverted, then
        conflicts resolved via ``Celeba.check_attribute_conflict``).

        Config keys read: imsize, selectedAttrs, thresInt, projectSamples,
        sampleThresInt, useSpecifiedImage, specifiedTestImages, dataset_path,
        imCropSize, com_base, modelScriptName, stuScriptName, GConvDim,
        GLayerNum, skipNum, injectLayers, oneMoreConv, stuDim,
        stuInjectLayers, stuKS, projectCheckpoints, ckpt_prefix.

        Errors while building the dataset/graph or restoring the checkpoint
        propagate to the caller; errors during sampling are printed and
        swallowed. The TF session is closed in every case.
        """
        img_size = self.config["imsize"]
        n_att = len(self.config["selectedAttrs"])
        atts = self.config["selectedAttrs"]
        thres_int = self.config["thresInt"]
        save_dir = self.config["projectSamples"]
        test_int = self.config["sampleThresInt"]

        # data
        sess = tl.session()
        # Guard everything after session creation so the session is released
        # even when dataset creation, graph construction or checkpoint
        # restoring raises (previously only the sampling loop was guarded and
        # the session leaked on those failures).
        try:
            SpecifiedImages = None
            if self.config["useSpecifiedImage"]:
                SpecifiedImages = self.config["specifiedTestImages"]
            te_data = Celeba(self.config["dataset_path"],
                             atts,
                             img_size,
                             1,  # presumably the batch size; sampling below assumes 1
                             part='test',
                             sess=sess,
                             crop=(self.config["imCropSize"] > 0),
                             im_no=SpecifiedImages)

            # models: Genc/Gdec come from `modelScriptName`, Gstu from
            # `stuScriptName`, both resolved under `com_base`.
            package = __import__(self.config["com_base"] +
                                 self.config["modelScriptName"],
                                 fromlist=True)
            GencClass = getattr(package, 'Genc')
            GdecClass = getattr(package, 'Gdec')
            package = __import__(self.config["com_base"] +
                                 self.config["stuScriptName"],
                                 fromlist=True)
            GstuClass = getattr(package, 'Gstu')

            Genc = partial(GencClass,
                           dim=self.config["GConvDim"],
                           n_layers=self.config["GLayerNum"],
                           multi_inputs=1)

            Gdec = partial(GdecClass,
                           dim=self.config["GConvDim"],
                           n_layers=self.config["GLayerNum"],
                           shortcut_layers=self.config["skipNum"],
                           inject_layers=self.config["injectLayers"],
                           one_more_conv=self.config["oneMoreConv"])

            Gstu = partial(GstuClass,
                           dim=self.config["stuDim"],
                           n_layers=self.config["skipNum"],
                           inject_layers=self.config["stuInjectLayers"],
                           kernel_size=self.config["stuKS"],
                           norm=None,
                           pass_state="stu")

            # inputs (images are NHWC)
            xa_sample = tf.placeholder(tf.float32,
                                       shape=[None, img_size, img_size, 3])
            _b_sample = tf.placeholder(tf.float32, shape=[None, n_att])
            raw_b_sample = tf.placeholder(tf.float32, shape=[None, n_att])

            # sample: the generator is conditioned on the attribute difference
            # (target - source); both are fed below as (label * 2 - 1) * thres_int.
            test_label = _b_sample - raw_b_sample
            x_sample = Gdec(Gstu(Genc(xa_sample, is_training=False),
                                 test_label,
                                 is_training=False),
                            test_label,
                            is_training=False)
            print("Graph build success!")

            # Load pretrained model
            ckpt_dir = os.path.join(self.config["projectCheckpoints"],
                                    self.config["ckpt_prefix"])
            restorer = tf.train.Saver()
            restorer.restore(sess, ckpt_dir)
            print("Load pretrained model successfully!")

            # Make sure the output directory exists before writing samples.
            os.makedirs(save_dir, exist_ok=True)

            # test
            print("Start to test!")
            try:
                for idx, batch in enumerate(te_data):
                    xa_sample_ipt = batch[0]
                    a_sample_ipt = batch[1]

                    # Target attribute vectors (presumably {0, 1} labels):
                    # entry 0 keeps the source attributes (reconstruction);
                    # entry k >= 1 inverts attribute k - 1 and then resolves
                    # conflicting attributes.
                    b_sample_ipt_list = [a_sample_ipt.copy()]
                    for i in range(len(atts)):
                        tmp = np.array(a_sample_ipt, copy=True)
                        tmp[:, i] = 1 - tmp[:, i]  # inverse attribute
                        tmp = Celeba.check_attribute_conflict(tmp, atts[i], atts)
                        b_sample_ipt_list.append(tmp)

                    # The strip starts with the input image and a thin
                    # separator column filled with -1 (batch size 1 assumed).
                    x_sample_opt_list = [
                        xa_sample_ipt,
                        np.full((1, img_size, img_size // 10, 3), -1.0)
                    ]
                    # Source attributes mapped to {-thres_int, +thres_int}.
                    raw_a_sample_ipt = (a_sample_ipt * 2 - 1) * thres_int
                    for i, b_sample_ipt in enumerate(b_sample_ipt_list):
                        _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
                        if i > 0:  # i == 0 is the reconstruction
                            # Scale the edited attribute by the configured intensity.
                            _b_sample_ipt[..., i - 1] = (
                                _b_sample_ipt[..., i - 1] * test_int)
                        x_sample_opt_list.append(
                            sess.run(x_sample,
                                     feed_dict={
                                         xa_sample: xa_sample_ipt,
                                         _b_sample: _b_sample_ipt,
                                         raw_b_sample: raw_a_sample_ipt
                                     }))
                    # Concatenate along width (axis 2 of the NHWC batches).
                    sample = np.concatenate(x_sample_opt_list, 2)

                    # Name outputs by CelebA image number (182638 is presumably
                    # the first test-split image) unless explicit images were
                    # requested.
                    image_no = (idx + 182638 if SpecifiedImages is None
                                else SpecifiedImages[idx])
                    im.imwrite(sample.squeeze(0),
                               '%s/%06d.png' % (save_dir, image_no))
                    print('%06d.png done!' % image_no)
            except Exception:
                # Best-effort sampling: report and stop, but let
                # KeyboardInterrupt/SystemExit propagate.
                traceback.print_exc()
        finally:
            sess.close()
Example #13
0
        ep_cnt.assign_add(1)

        # train for an epoch
        for x_real in tqdm.tqdm(dataset, desc='Inner Epoch Loop', total=len_dataset):
            D_loss_dict = train_D(x_real)
            tl.summary(D_loss_dict, step=D_optimizer.iterations, name='D_losses')

            if D_optimizer.iterations.numpy() % args.n_d == 0:
                G_loss_dict = train_G()
                tl.summary(G_loss_dict, step=G_optimizer.iterations, name='G_losses')

            # sample
            if G_optimizer.iterations.numpy() % 100 == 0:
                x_fake = sample(z)
                img = im.immerge(x_fake, n_rows=10).squeeze()
                im.imwrite(img, py.join(sample_dir, 'iter-%09d.jpg' % G_optimizer.iterations.numpy()))

        # save checkpoint
        checkpoint.save(ep)
        #calculate fid after 500 epoch
        #be nice to me graphic card
        if ep > 400:
            #generate fake pics
            samples_dir = "./evaluation/generated"
            py.mkdir(samples_dir)
            for i in range(0,1000):
                z_fid = tf.random.normal(shape=(1, 1, 1, args.z_dim))
                x_fake_fid = G(z_fid, training=False)
                img_fid = im.immerge_(x_fake_fid).squeeze()
                im.imwrite( img_fid, py.join(samples_dir, 'fake%03d.jpg' %i))
            #generate npy & compare & savefid
Example #14
0
    B2A = G_B2A(B, training=False)
    B2A2B = G_A2B(B2A, training=False)
    return B2A, B2A2B


# run: for every test image write [input | translation | cycle reconstruction]
# side by side into samples_testing/A2B and samples_testing/B2A, keeping the
# source file names.
for direction, test_dataset, translate, src_paths in (
        ('A2B', A_dataset_test, sample_A2B, A_img_paths_test),
        ('B2A', B_dataset_test, sample_B2A, B_img_paths_test),
):
    save_dir = py.join(args.experiment_dir, 'samples_testing', direction)
    py.mkdir(save_dir)
    i = 0
    for batch in test_dataset:
        translated, cycled = translate(batch)
        for src, out, rec in zip(batch, translated, cycled):
            row = [src.numpy(), out.numpy(), rec.numpy()]
            img = np.concatenate(row, axis=1)
            im.imwrite(img, py.join(save_dir, py.name_ext(src_paths[i])))
            i += 1

# 生成的B的数据集
Example #15
0
        sum_loss_D, it_D = 0, 0
        sum_loss_G, it_G = 0, 0
        # train for an epoch
        for x_real in dataset:
            D_loss_dict = train_D(x_real)
            sum_loss_D += float(D_loss_dict['D_loss'])
            it_D += 1

            if D_optimizer.iterations.numpy() % args.n_d == 0:
                G_loss_dict = train_G()
                sum_loss_G = float(G_loss_dict['g_loss'])
                it_G += 1

        with open(path.join(summary_dir, 'g_loss.txt'), 'a+') as file:
            file.write(str(sum_loss_G / it_G) + '\n')

        with open(path.join(summary_dir, 'd_loss.txt'), 'a+') as file:
            file.write(str(sum_loss_D / it_D) + '\n')

        x_fake = sample(z)
        img = im.immerge(x_fake, n_rows=10).squeeze()
        im.imwrite(img, py.join(sample_dir, '1', 'iter-%4d.jpg' % ep))

        x_fake = sample(z2)
        img = im.immerge(x_fake, n_rows=10).squeeze()
        im.imwrite(img, py.join(sample_dir, '2', 'iter-%4d.jpg' % ep))

        # save checkpoint
        checkpoint.save(ep)
Example #16
0
            if G_optimizer.iterations.numpy() % 100 == 0:
                ground_truth = get_Mask(x_real)
                x1_real = get_PET(x_real)
                x2_real = get_CT(x_real)
                x1_fake, x2_fake = G(ground_truth, training=True)

                x1_fake = x1_fake[:,:,:,0]
                x2_fake = x2_fake[:,:,:,0]



                img1_real = im.immerge(x1_real.numpy(),n_rows=10)
                img2_real = im.immerge(x2_real.numpy(),n_rows=10)
                img3_real = im.immerge(ground_truth.numpy(),n_rows=10)
                img4_real = tf.concat([img1_real[:,:,0],img2_real[:,:,0],img3_real[:,:,0]],-1)
                im.imwrite(img4_real.numpy(), py.join(sample_dir, 'img4R-iter-%09d.jpg' % G_optimizer.iterations.numpy()))

                #print('\n Shape of the generated images:')
                #print('x1_fake.shape = ', x1_fake.shape)
                #print('x2_fake.shape = ', x2_fake.shape)

                img1 = im.immerge(x1_fake.numpy(), n_rows=10)
                img2 = im.immerge(x2_fake.numpy(), n_rows=10)
                img3 = im.immerge(ground_truth.numpy(), n_rows=10)
                img4 = tf.concat([img1,img2,img3[:,:,0]],-1)
                im.imwrite(img4.numpy(), py.join(sample_dir, 'img4-iter-%09d.jpg' % G_optimizer.iterations.numpy()))

                # Added by K.C: update the mean loss functions every 100 iterations, and plot them out
                D1_loss_summary.append(D_loss_dict.get('d1_loss','').numpy())
                D1_GP_summary.append(D_loss_dict.get('gp1', '').numpy())
                D2_loss_summary.append(D_loss_dict.get('d2_loss','').numpy())
Example #17
0
                # batch data
                z_ipt = np.random.normal(size=[batch_size, z_dim])

                g_summary_opt, _ = sess.run([g_summary, g_step],
                                            feed_dict={z: z_ipt})
                summary_writer.add_summary(g_summary_opt, it)

            # display
            if it % 1 == 0:
                print("Epoch: (%3d) (%5d/%5d)" % (ep, i + 1, it_per_epoch))

            # sample
            if it % 1000 == 0:
                f_sample_opt = sess.run(f_sample,
                                        feed_dict={
                                            z_sample: z_ipt_sample
                                        }).squeeze()

                save_dir = './output/%s/sample_training' % experiment_name
                pylib.mkdir(save_dir)
                im.imwrite(
                    im.immerge(f_sample_opt), '%s/Epoch_(%d)_(%dof%d).jpg' %
                    (save_dir, ep, i + 1, it_per_epoch))

        save_path = saver.save(sess, '%s/Epoch_%d.ckpt' % (ckpt_dir, ep))
        print('Model is saved in file: %s' % save_path)
except:
    traceback.print_exc()
finally:
    sess.close()
Example #18
0
def train():
    """Train a CycleGAN-style A<->B image translation model.

    Configuration comes from ``parse_args()``. Everything is written under
    ``output/<args.dataset>``: ``settings.json`` (the parsed args),
    ``checkpoints/`` (at most 5 kept), ``summaries/train`` (TF summary event
    files) and ``samples_training/`` (sample grids).

    Training images are the ``*.png`` files in
    ``<args.datasets_dir>/<args.dataset>/trainA`` and ``.../trainB``; sample
    grids are drawn from ``testA`` / ``testB``.

    Resuming: the epoch counter ``ep_cnt`` is part of the checkpoint, and
    epochs below its value are skipped. A failed restore is only printed.

    Returns:
        None.
    """
    # ===================================== Args =====================================
    args = parse_args()
    output_dir = os.path.join('output', args.dataset)
    os.makedirs(output_dir, exist_ok=True)
    settings_path = os.path.join(output_dir, 'settings.json')
    pylib.args_to_json(settings_path, args)

    # ===================================== Data =====================================
    A_img_paths = pylib.glob(
        os.path.join(args.datasets_dir, args.dataset, 'trainA'), '*.png')
    B_img_paths = pylib.glob(
        os.path.join(args.datasets_dir, args.dataset, 'trainB'), '*.png')
    print(f'len(A_img_paths) = {len(A_img_paths)}')
    print(f'len(B_img_paths) = {len(B_img_paths)}')
    load_size = [args.load_size_height, args.load_size_width]
    crop_size = [args.crop_size_height, args.crop_size_width]
    # Training pairs; len_dataset (steps per epoch) also sizes the LR schedule.
    A_B_dataset, len_dataset = data.make_zip_dataset(A_img_paths,
                                                     B_img_paths,
                                                     args.batch_size,
                                                     load_size,
                                                     crop_size,
                                                     training=True,
                                                     repeat=False)

    # Pools of generated images (size args.pool_size) applied to the fakes in
    # train_step before the discriminator update -- presumably an image
    # history buffer; confirm against data.ItemPool.
    A2B_pool = data.ItemPool(args.pool_size)
    B2A_pool = data.ItemPool(args.pool_size)

    A_img_paths_test = pylib.glob(
        os.path.join(args.datasets_dir, args.dataset, 'testA'), '*.png')
    B_img_paths_test = pylib.glob(
        os.path.join(args.datasets_dir, args.dataset, 'testB'), '*.png')
    # Repeating test stream, consumed with next() for periodic sample grids.
    A_B_dataset_test, _ = data.make_zip_dataset(A_img_paths_test,
                                                B_img_paths_test,
                                                args.batch_size,
                                                load_size,
                                                crop_size,
                                                training=False,
                                                repeat=True)

    # ===================================== Models =====================================
    model_input_shape = crop_size + [
        3
    ]  # [args.crop_size_height, args.crop_size_width, 3]

    G_A2B = module.ResnetGenerator(input_shape=model_input_shape, n_blocks=6)
    G_B2A = module.ResnetGenerator(input_shape=model_input_shape, n_blocks=6)

    D_A = module.ConvDiscriminator(input_shape=model_input_shape)
    D_B = module.ConvDiscriminator(input_shape=model_input_shape)

    d_loss_fn, g_loss_fn = tf2gan.get_adversarial_losses_fn(
        args.adversarial_loss_mode)
    cycle_loss_fn = tf.losses.MeanAbsoluteError()
    identity_loss_fn = tf.losses.MeanAbsoluteError()

    # LinearDecay schedules built from args.lr, args.epochs * len_dataset and
    # args.epoch_decay * len_dataset (i.e. measured in optimizer steps).
    G_lr_scheduler = module.LinearDecay(args.lr, args.epochs * len_dataset,
                                        args.epoch_decay * len_dataset)
    D_lr_scheduler = module.LinearDecay(args.lr, args.epochs * len_dataset,
                                        args.epoch_decay * len_dataset)
    G_optimizer = tf.keras.optimizers.Adam(learning_rate=G_lr_scheduler,
                                           beta_1=args.beta_1)
    D_optimizer = tf.keras.optimizers.Adam(learning_rate=D_lr_scheduler,
                                           beta_1=args.beta_1)

    # ===================================== Training steps =====================================
    @tf.function
    def train_generators(A, B):
        """One joint update of G_A2B and G_B2A.

        Returns the fake batches A2B / B2A (computed before this update) for
        the discriminator step, plus a dict of the individual loss terms.
        """
        with tf.GradientTape() as t:
            A2B = G_A2B(A, training=True)
            B2A = G_B2A(B, training=True)
            A2B2A = G_B2A(A2B, training=True)
            B2A2B = G_A2B(B2A, training=True)
            # Identity terms: a target-domain image fed to its generator
            # should come back unchanged.
            A2A = G_B2A(A, training=True)
            B2B = G_A2B(B, training=True)

            A2B_d_logits = D_B(A2B, training=True)
            B2A_d_logits = D_A(B2A, training=True)

            A2B_g_loss = g_loss_fn(A2B_d_logits)
            B2A_g_loss = g_loss_fn(B2A_d_logits)
            A2B2A_cycle_loss = cycle_loss_fn(A, A2B2A)
            B2A2B_cycle_loss = cycle_loss_fn(B, B2A2B)
            A2A_id_loss = identity_loss_fn(A, A2A)
            B2B_id_loss = identity_loss_fn(B, B2B)

            # adversarial + cycle_loss_weight * cycle + identity_loss_weight * identity
            G_loss = (A2B_g_loss + B2A_g_loss) + (
                A2B2A_cycle_loss +
                B2A2B_cycle_loss) * args.cycle_loss_weight + (
                    A2A_id_loss + B2B_id_loss) * args.identity_loss_weight

        G_grad = t.gradient(
            G_loss, G_A2B.trainable_variables + G_B2A.trainable_variables)
        G_optimizer.apply_gradients(
            zip(G_grad, G_A2B.trainable_variables + G_B2A.trainable_variables))

        return A2B, B2A, {
            'A2B_g_loss': A2B_g_loss,
            'B2A_g_loss': B2A_g_loss,
            'A2B2A_cycle_loss': A2B2A_cycle_loss,
            'B2A2B_cycle_loss': B2A2B_cycle_loss,
            'A2A_id_loss': A2A_id_loss,
            'B2B_id_loss': B2B_id_loss
        }

    @tf.function
    def train_discriminators(A, B, A2B, B2A):
        """One joint update of D_A and D_B on real vs. (pooled) fake images.

        The loss adds a gradient penalty (mode args.gradient_penalty_mode)
        weighted by args.gradient_penalty_weight.
        """
        with tf.GradientTape() as t:
            A_d_logits = D_A(A, training=True)
            B2A_d_logits = D_A(B2A, training=True)
            B_d_logits = D_B(B, training=True)
            A2B_d_logits = D_B(A2B, training=True)

            A_d_loss, B2A_d_loss = d_loss_fn(A_d_logits, B2A_d_logits)
            B_d_loss, A2B_d_loss = d_loss_fn(B_d_logits, A2B_d_logits)
            D_A_gp = tf2gan.gradient_penalty(functools.partial(D_A,
                                                               training=True),
                                             A,
                                             B2A,
                                             mode=args.gradient_penalty_mode)
            D_B_gp = tf2gan.gradient_penalty(functools.partial(D_B,
                                                               training=True),
                                             B,
                                             A2B,
                                             mode=args.gradient_penalty_mode)

            D_loss = (A_d_loss + B2A_d_loss) + (B_d_loss + A2B_d_loss) + (
                D_A_gp + D_B_gp) * args.gradient_penalty_weight

        D_grad = t.gradient(D_loss,
                            D_A.trainable_variables + D_B.trainable_variables)
        D_optimizer.apply_gradients(
            zip(D_grad, D_A.trainable_variables + D_B.trainable_variables))

        return {
            'A_d_loss': A_d_loss + B2A_d_loss,
            'B_d_loss': B_d_loss + A2B_d_loss,
            'D_A_gp': D_A_gp,
            'D_B_gp': D_B_gp
        }

    def train_step(A, B):
        """Generator update, then discriminator update on pooled fakes.

        Deliberately not a tf.function: the item pools run eagerly.
        """
        A2B, B2A, G_loss_dict = train_generators(A, B)

        # cannot autograph `A2B_pool`
        A2B = A2B_pool(
            A2B)  # or A2B = A2B_pool(A2B.numpy()), but it is much slower
        B2A = B2A_pool(B2A)  # because of the communication between CPU and GPU

        D_loss_dict = train_discriminators(A, B, A2B, B2A)

        return G_loss_dict, D_loss_dict

    @tf.function
    def sample(A, B):
        """Inference-mode translations and cycle reconstructions for sample grids."""
        A2B = G_A2B(A, training=False)
        B2A = G_B2A(B, training=False)
        A2B2A = G_B2A(A2B, training=False)
        B2A2B = G_A2B(B2A, training=False)
        return A2B, B2A, A2B2A, B2A2B

    # ===================================== Runner code =====================================
    # epoch counter (checkpointed, so training can resume mid-schedule)
    ep_cnt = tf.Variable(initial_value=0, trainable=False, dtype=tf.int64)

    # checkpoint
    checkpoint = tf2lib.Checkpoint(dict(G_A2B=G_A2B,
                                        G_B2A=G_B2A,
                                        D_A=D_A,
                                        D_B=D_B,
                                        G_optimizer=G_optimizer,
                                        D_optimizer=D_optimizer,
                                        ep_cnt=ep_cnt),
                                   os.path.join(output_dir, 'checkpoints'),
                                   max_to_keep=5)
    try:  # restore checkpoint including the epoch counter
        checkpoint.restore().assert_existing_objects_matched()
    except Exception as e:
        # Typically nothing to restore on a fresh run; training starts from scratch.
        print(e)

    # summary
    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(output_dir, 'summaries', 'train'))

    # sample
    test_iter = iter(A_B_dataset_test)
    sample_dir = os.path.join(output_dir, 'samples_training')
    os.makedirs(sample_dir, exist_ok=True)

    # main loop
    with train_summary_writer.as_default():
        for ep in tqdm.trange(args.epochs, desc='Epoch Loop'):
            # Skip epochs already completed according to the restored counter.
            if ep < ep_cnt:
                continue

            # update epoch counter
            ep_cnt.assign_add(1)

            # train for an epoch
            for A, B in tqdm.tqdm(A_B_dataset,
                                  desc='Inner Epoch Loop',
                                  total=len_dataset):
                G_loss_dict, D_loss_dict = train_step(A, B)

                # summary (all logged against the generator step counter)
                tf2lib.summary(G_loss_dict,
                               step=G_optimizer.iterations,
                               name='G_losses')
                tf2lib.summary(D_loss_dict,
                               step=G_optimizer.iterations,
                               name='D_losses')
                tf2lib.summary(
                    {'learning rate': G_lr_scheduler.current_learning_rate},
                    step=G_optimizer.iterations,
                    name='learning rate')

                # sample: every 100 generator steps, merge the six groups
                # A, A2B, A2B2A, B, B2A, B2A2B into a 6-row grid image.
                if G_optimizer.iterations.numpy() % 100 == 0:
                    A, B = next(test_iter)
                    A2B, B2A, A2B2A, B2A2B = sample(A, B)
                    img = imlib.immerge(np.concatenate(
                        [A, A2B, A2B2A, B, B2A, B2A2B], axis=0),
                                        n_rows=6)
                    imlib.imwrite(
                        img,
                        os.path.join(
                            sample_dir,
                            'iter-%09d.jpg' % G_optimizer.iterations.numpy()))

            # save checkpoint
            checkpoint.save(ep)
Example #19
0
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")
torch.manual_seed(0)

# NOTE(review): other datasets leave output_channels / n_G_upsamplings
# undefined here unless they are set earlier in the file.
if args.dataset in ['cifar10', 'fashion_mnist', 'mnist', 'imagenet']:  # 32x32
    output_channels = 3
    n_G_upsamplings = n_D_downsamplings = 3

for experiment in experiment_names:
    output_dir = py.join('output_new', 'output', experiment)

    G = module.ConvGenerator(args.z_dim,
                             output_channels,
                             n_upsamplings=n_G_upsamplings).to(device)

    # load the trained generator weights
    ckpt_dir = py.join(output_dir, 'checkpoints', args.checkpoint_name)
    out_dir = py.join(output_dir, args.output_dir)
    py.mkdir(ckpt_dir)
    py.mkdir(out_dir)
    ckpt = torchlib.load_checkpoint(ckpt_dir)
    G.load_state_dict(ckpt['G'])
    # Sampling is inference: switch layers such as BatchNorm / Dropout to
    # eval behaviour. In train mode BatchNorm would normalize with per-batch
    # statistics and update its running averages while sampling.
    G.eval()

    # No autograd graph is needed for sampling.
    with torch.no_grad():
        for i in range(args.num_samples):
            # Noise is drawn on the CPU generator and then moved, which keeps
            # the seeded sample stream independent of the device.
            z = torch.randn(args.batch_size, args.z_dim, 1, 1).to(device)
            x_fake = G(z)
            # NCHW -> NHWC for the image helpers.
            x_fake = np.transpose(x_fake.cpu().numpy(), (0, 2, 3, 1))
            img = im.immerge(x_fake, n_rows=1).squeeze()
            img_path = py.join(out_dir, 'img-%d.jpg' % i)
            im.imwrite(img, img_path)
            print(img_path)
Example #20
0
    raise Exception(' [*] No checkpoint!')

# sample: for each test image invert every attribute in `test_atts`, set their
# magnitude to the matching entry of `test_ints`, and save
# [input | separator | edited] side by side.
try:
    for idx, batch in enumerate(te_data):
        xa_sample_ipt, a_sample_ipt = batch[0], batch[1]

        # Flip the requested attributes, resolving conflicts after each flip.
        b_sample_ipt = np.array(a_sample_ipt, copy=True)
        for att_name in test_atts:
            att_idx = atts.index(att_name)
            b_sample_ipt[:, att_idx] = 1 - b_sample_ipt[:, att_idx]   # inverse attribute
            b_sample_ipt = data.Celeba.check_attribute_conflict(b_sample_ipt, atts[att_idx], atts)

        # Labels -> {-thres_int, +thres_int}; then rescale the edited columns
        # so their magnitude equals the requested intensity.
        _b_sample_ipt = (b_sample_ipt * 2 - 1) * thres_int
        for att_name, intensity in zip(test_atts, test_ints):
            col = atts.index(att_name)
            _b_sample_ipt[..., col] = _b_sample_ipt[..., col] * intensity / thres_int

        separator = np.full((1, img_size, img_size // 10, 3), -1.0)
        edited = sess.run(x_sample, feed_dict={xa_sample: xa_sample_ipt, _b_sample: _b_sample_ipt})
        sample = np.concatenate([xa_sample_ipt, separator, edited], 2)

        save_dir = './output/%s/sample_testing_multi_%s' % (experiment_name, str(test_atts))
        pylib.mkdir(save_dir)
        image_no = idx + 182638
        im.imwrite(sample.squeeze(0), '%s/%d.png' % (save_dir, image_no))
        print('%d.png done!' % image_no)

except Exception:
    traceback.print_exc()
finally:
    sess.close()
Example #21
0
                for z_ipt_sample in z_ipt_samples:
                    f_sample_opt = sess.run(f_sample,
                                            feed_dict={
                                                z_sample: z_ipt_sample,
                                                c_sample: c_ipt_sample
                                            }).squeeze()

                    k_prod = 1
                    for k in ks:
                        k_prod *= k
                        f_sample_opts_k = list(f_sample_opt)
                        for idx in range(len(f_sample_opts_k)):
                            if idx % (len(f_sample_opts_k) / k_prod) != 0:
                                f_sample_opts_k[idx] = np.zeros_like(
                                    f_sample_opts_k[idx])
                        merge.append(np.concatenate(f_sample_opts_k, axis=1))
                merge = np.concatenate(merge, axis=0)

                save_dir = './output/%s/sample_training' % experiment_name
                pylib.mkdir(save_dir)
                im.imwrite(
                    merge, '%s/Epoch_(%d)_(%dof%d).jpg' %
                    (save_dir, ep, i + 1, it_per_epoch))

        save_path = saver.save(sess, '%s/Epoch_%d.ckpt' % (ckpt_dir, ep))
        print('Model is saved in file: %s' % save_path)
except:
    traceback.print_exc()
finally:
    sess.close()
Example #22
0
# Main training loop; summaries go to `train_summary_writer`.
with train_summary_writer.as_default():
    for ep in tqdm.trange(args.epochs, desc='Epoch Loop'):
        if ep < ep_cnt:
            # Epoch already counted by `ep_cnt` (e.g. when resuming).
            continue
        ep_cnt.assign_add(1)  # update epoch counter

        # train for an epoch
        for A, B in tqdm.tqdm(A_B_dataset, desc='Inner Epoch Loop', total=len_dataset):
            G_loss_dict, D_loss_dict = train_step(A, B)

            # Log losses and the learning rate against the generator step.
            g_step = G_optimizer.iterations
            tl.summary(G_loss_dict, step=g_step, name='G_losses')
            tl.summary(D_loss_dict, step=g_step, name='D_losses')
            tl.summary({'learning rate': G_lr_scheduler.current_learning_rate},
                       step=g_step, name='learning rate')

            # Every 100 generator steps, dump a grid of test translations.
            g_iter = g_step.numpy()
            if g_iter % 100 != 0:
                continue
            A, B = next(test_iter)
            A2B, B2A, A2B2A, B2A2B = sample(A, B)
            stacked = np.concatenate([A, A2B, A2B2A, B, B2A, B2A2B], axis=0)
            img = im.immerge(stacked, n_rows=2)
            im.imwrite(img, py.join(sample_dir, 'iter-%09d.jpg' % g_iter))

        # save checkpoint once per epoch
        checkpoint.save(ep)
Example #23
0
                sess.run(x_sample,
                         feed_dict={
                             xa_sample: xa_sample_ipt,
                             _b_sample: _b_sample_ipt,
                             raw_b_sample: raw_a_sample_ipt
                         }))
            duration = time.time() - start_time
            # print('duration of process No.{} attribution({}) of image {}.png is: {}'.format(i,
            #                                                                                 'no-change' if i == 0 else atts[i - 1],
            #                                                                                 idx + 182638 if img is None else img[idx],
            #                                                                                 duration))
        sample = np.concatenate(x_sample_opt_list, 2)

        if test_slide: save_folder = 'sample_testing_slide'
        elif multi_atts: save_folder = 'sample_testing_multi'
        else: save_folder = 'sample_testing'
        save_dir = './output_train_diff_att/%s/%s' % (experiment_name,
                                                      save_folder)
        pylib.mkdir(save_dir)
        # im.imshow(sample.squeeze(0))
        im.imwrite(
            sample.squeeze(0), '%s/%06d%s.png' %
            (save_dir, idx + 182638 if img is None else img[idx], '_%s' %
             (str(test_atts)) if multi_atts else ''))

        print('%06d.png done!' % (idx + 182638 if img is None else img[idx]))
except:
    traceback.print_exc()
finally:
    sess.close()
Example #24
0
                            ..., i - 1] * test_int / thres_int
                    x_sample_opt_list.append(
                        sess.run(x_sample,
                                 feed_dict={
                                     xa_sample: xa_sample_ipt,
                                     _b_sample: _b_sample_ipt,
                                     raw_b_sample: raw_b_sample_ipt
                                 }))
                    last_images = x_sample_opt_list[-1]
                    if i > 0:  # add a mark (+/-) in the upper-left corner to identify add/remove an attribute
                        for nnn in range(last_images.shape[0]):
                            last_images[nnn, 2:5, 0:7, :] = 1.
                            if _b_sample_ipt[nnn, i - 1] > 0:
                                last_images[nnn, 0:7, 2:5, :] = 1.
                                last_images[nnn, 1:6, 3:4, :] = -1.
                            last_images[nnn, 3:4, 1:6, :] = -1.
                sample = np.concatenate(x_sample_opt_list, 2)

                save_dir = './output/%s/sample_training' % experiment_name
                pylib.mkdir(save_dir)
                im.imwrite(im.immerge(sample, n_sample, 1), '%s/Epoch_(%d)_(%dof%d).jpg' % \
                                                            (save_dir, epoch, it_in_epoch, it_per_epoch))
except:
    traceback.print_exc()
finally:
    save_path = saver.save(
        sess, '%s/Epoch_(%d)_(%dof%d).ckpt' %
        (ckpt_dir, epoch, it_in_epoch, it_per_epoch))
    print('Model is saved at %s!' % save_path)
    sess.close()
Example #25
0
 def run(epoch, iter):
     """Evaluate ``x_f`` in ``sess`` and save the result as an image grid.

     The output of ``sess.run(x_f)`` is merged by ``im.immerge`` into a grid
     with ``int(args.n_samples ** 0.5)`` rows (the floored square root), then
     written to ``'<save_dir>/Epoch-<epoch>_Iter-<iter>.jpg'``.

     NOTE(review): presumably ``x_f`` is the generator's fake-sample tensor
     and ``args.n_samples`` is a perfect square -- confirm against the caller.
     NOTE(review): the parameter name ``iter`` shadows the builtin; it is
     left unchanged because callers may pass it by keyword.
     """
     x_f_opt = sess.run(x_f)
     sample = im.immerge(x_f_opt, n_rows=int(args.n_samples**0.5))
     im.imwrite(sample, '%s/Epoch-%d_Iter-%d.jpg' % (save_dir, epoch, iter))