Beispiel #1
0
    def infer(self, source_obj, model_dir, save_dir):
        """Run the restored generator over ``source_obj`` and write the results as PNGs.

        Each generated batch is tiled into one column image
        (``merge(..., [self.batch_size, 1])``).  Every ten such columns are
        concatenated into ``inferred_<index>.png`` inside ``save_dir``; a
        trailing group of fewer than ten columns is written after the input
        is exhausted.

        Args:
            source_obj: input wrapped by ``InjectDataProvider``.
            model_dir: directory the generator checkpoint is restored from.
            save_dir: directory that receives the PNG files.
        """
        provider = InjectDataProvider(source_obj)
        batches = provider.get_iter(self.batch_size)

        tf.global_variables_initializer().run()
        # Only the generator's variables are restored for inference.
        generator_saver = tf.train.Saver(var_list=self.retrieve_generator_vars())
        self.restore_model(generator_saver, model_dir)

        def flush(columns, index):
            out_path = os.path.join(save_dir, "inferred_%04d.png" % index)
            save_concat_images(columns, img_path=out_path)
            print("generated images saved at %s" % out_path)

        pending = []
        batch_index = -1
        for batch_index, batch in enumerate(batches):
            generated = self.generate_fake_samples(batch)[0]
            pending.append(merge(scale_back(generated), [self.batch_size, 1]))
            if len(pending) == 10:
                flush(pending, batch_index)
                pending = []
        if pending:
            # Leftover group: named after the total number of batches processed.
            flush(pending, batch_index + 1)
    def validate_model(self, images, epoch, step):
        """Generate samples for ``images``, print the losses and save a real|fake image.

        Real and generated batches are each tiled into one column
        (``merge(..., [self.batch_size, 1])``) and placed side by side, real on
        the left.  The result is written to
        ``<self.sample_dir>/<model_id>/sample_<epoch>_<step>.png``; the
        per-model directory is created when missing.
        """
        fake, real, d_loss, g_loss, l1_loss = self.generate_fake_samples(images)
        print("Sample: d_loss: %.5f, g_loss: %.5f, l1_loss: %.5f" % (d_loss, g_loss, l1_loss))

        grid = [self.batch_size, 1]
        fake_column = merge(scale_back(fake), grid)
        real_column = merge(scale_back(real), grid)
        side_by_side = np.concatenate([real_column, fake_column], axis=1)

        model_id, _ = self.get_model_id_and_dir()
        target_dir = os.path.join(self.sample_dir, model_id)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        out_path = os.path.join(target_dir, "sample_%02d_%04d.png" % (epoch, step))
        misc.imsave(out_path, side_by_side)
Beispiel #3
0
    def test(self, source_provider, model_dir, save_dir):
        """Score the restored generator on every example of ``source_provider``.

        The generator output is binarised at ``threshold`` (values >= threshold
        become 1.0, all others -1.0) and compared pixel-wise with the real
        target.  For each sample the method prints ``over``/``under``/``base``
        and the accuracy ``1 - (over + under) / base``, where

        * ``over``  counts pixels that are 1.0 in the target but not in the output,
        * ``under`` counts pixels that are not 1.0 in the target but 1.0 in the output,
        * ``base``  counts pixels that are not 1.0 in the target.

        Samples with ``base == 0`` are reported but excluded from the average.
        All real|binarised-fake batch strips are concatenated into one PNG in
        ``save_dir``, and the mean accuracy of the scored samples is printed.

        Args:
            source_provider: provider exposing ``data.examples`` and ``get_iter``.
            model_dir: directory the generator checkpoint is restored from.
            save_dir: directory that receives the PNG.
        """
        total_count = len(source_provider.data.examples)
        source_len = min(10, total_count)  # batch size: at most 10 examples
        source_iter = source_provider.get_iter(source_len)

        tf.global_variables_initializer().run()

        saver = tf.train.Saver(var_list=self.retrieve_generator_vars())
        self.restore_model(saver, model_dir)

        def save_imgs(imgs, count, threshold):
            p = os.path.join(save_dir,
                             "inferred_%04d_%.2f.png" % (count, threshold))
            save_concat_images(imgs, img_path=p)
            print("generated images saved at %s" % p)

        def pixel_errors(real_flat, fake_flat):
            """Return (over, under, base) as floats for one flattened, binarised sample."""
            real_on = real_flat == 1.0
            fake_on = fake_flat == 1.0
            over = float(np.count_nonzero(real_on & ~fake_on))
            # A miss is a non-1.0 target pixel that the output turned into 1.0
            # (pixels where both are -1.0 agree and must not be counted).
            under = float(np.count_nonzero(~real_on & fake_on))
            base = float(np.count_nonzero(~real_on))
            return over, under, base

        count = 0
        threshold = 0.1
        batch_buffer = []
        accuracy_sum = 0.0
        scored = 0
        for source_imgs in source_iter:
            fake_imgs, real_imgs, _, _, _ = self.generate_fake_samples(source_imgs)
            fake_imgs = np.asarray(fake_imgs)
            real_imgs = np.asarray(real_imgs)
            # Binarise the output; astype keeps the generator's dtype.
            fake_bin = np.where(fake_imgs >= threshold, 1.0, -1.0).astype(fake_imgs.dtype)

            n = fake_bin.shape[0]
            for real_flat, fake_flat in zip(real_imgs.reshape(n, -1), fake_bin.reshape(n, -1)):
                over, under, base = pixel_errors(real_flat, fake_flat)
                print("over:{} - under:{} - base:{}".format(over, under, base))
                if base == 0:
                    # Accuracy is undefined without any non-1.0 target pixel.
                    print("avg acc: undefined (target has no pixels != 1.0)")
                    continue
                sample_acc = 1 - ((over + under) / base)
                accuracy_sum += sample_acc
                scored += 1
                print("avg acc:{}".format(sample_acc))

            merged_fake_images = merge(scale_back(fake_bin), [source_len, 1])
            merged_real_images = merge(scale_back(real_imgs), [source_len, 1])
            batch_buffer.append(
                np.concatenate([merged_real_images, merged_fake_images], axis=1))
            count += 1
        if batch_buffer:
            # All batches end up in a single file.
            save_imgs(batch_buffer, count, threshold)

        if scored:
            print("Average accuracy: %.5f" % (accuracy_sum / scored))
        else:
            print("Average accuracy: undefined (no sample could be scored)")
    def test(self, source_provider, model_dir, save_dir):
        """Evaluate the restored generator on ``source_provider`` and save diagnostics.

        Batches hold at most 16 examples.  The generator output is binarised at
        ``threshold`` (values >= threshold become 1.0, all others -1.0).  For each
        sample the method prints:

        * mean / min / max of the raw (un-thresholded) generator output,
        * a pixel accuracy ``(valid - over - less) / valid``, where ``valid``
          counts target pixels equal to 1.0, ``over`` counts output 1.0 pixels
          where the target is not 1.0, and ``less`` counts target 1.0 pixels
          the output missed,
        * MSE, NRMSE (Euclidean), SSIM and PSNR of the binarised output
          against the target.

        Files written to ``save_dir``:

        * every raw output sample (``cgan_single_<count>_<bt>.png``),
        * real|raw-output pairs whose SSIM is > 0.8 or < 0.5,
        * the batch's real|binarised strip, split into four pieces along the
          height.

        Args:
            source_provider: provider exposing ``data.examples`` and ``get_iter``.
            model_dir: directory the generator checkpoint is restored from.
            save_dir: directory that receives the PNG files.
        """
        source_len = min(16, len(source_provider.data.examples))
        source_iter = source_provider.get_iter(source_len)

        tf.global_variables_initializer().run()

        saver = tf.train.Saver(var_list=self.retrieve_generator_vars())
        self.restore_model(saver, model_dir)

        def save_img(img, mse_diff, nrmse_diff, ssim_diff, psnr_diff):
            p = os.path.join(save_dir, "cgan_patch%.4f-%.4f-%.4f-%.4f.png" % (ssim_diff, mse_diff, nrmse_diff,
                                                                              psnr_diff))
            save_image(img, img_path=p)

        def save_batch_samples(imgs, count, threshold):
            # NOTE(review): ':' is not a legal file-name character on Windows, so this
            # save fails there (the error is printed, not raised).
            p = os.path.join(save_dir, "cgan_test_sample id:%04d_count:%04d_%.2f.png" % (self.experiment_id, count,
                                                                                         threshold))
            try:
                save_concat_images(imgs, img_path=p)
            except Exception as e:
                # Best effort: a failed strip must not abort the evaluation.
                print(e)

        def save_single_img(img, count, bt):
            p = os.path.join(save_dir, "cgan_single_%d_%d.png" % (count, bt))
            save_image(img, img_path=p)
            print("cgan single sample id: %d _ %d saved" % (count, bt))

        count = 0
        threshold = 0.1

        for source_imgs in source_iter:
            fake_imgs, real_imgs, _, _, _ = self.generate_fake_samples(source_imgs)

            # Independent arrays: the raw output is kept for the saved images,
            # while the binarised copy drives all metrics.
            fake_raw = np.array(fake_imgs)
            real = np.array(real_imgs)
            # astype keeps the generator's dtype.  This matters because the
            # compare_* metrics may reject inputs of differing dtypes.
            fake_bin = np.where(fake_raw >= threshold, 1.0, -1.0).astype(fake_raw.dtype)

            n = fake_raw.shape[0]
            raw_flat = fake_raw.reshape(n, -1)
            fake_flat = fake_bin.reshape(n, -1)
            real_flat = real.reshape(n, -1)

            # statistics of the raw generator output, per sample
            for bt in range(n):
                print("g_mean : %.05f g_min: %.05f g_max:%.05f"
                      % (np.mean(raw_flat[bt]), np.min(raw_flat[bt]), np.max(raw_flat[bt])))

            # valid pixels
            for bt in range(n):
                fake_on = fake_flat[bt] == 1.0
                real_on = real_flat[bt] == 1.0
                p_over = np.count_nonzero(fake_on & ~real_on)
                p_less = np.count_nonzero(~fake_on & real_on)
                p_valid = np.count_nonzero(real_on)
                if p_valid == 0:
                    print("cgan count %d sample %d pixel accuracy: undefined (no valid pixels)" % (count, bt))
                    continue
                p_accuracy = 1.0 * (p_valid - p_over - p_less) / p_valid
                print("cgan count %d sample %d pixel accuracy: %.05f" % (count, bt, p_accuracy))

            # save every raw sample (reshape to (H, W) assumes a single channel)
            for bt in range(n):
                save_single_img(np.reshape(fake_raw[bt], (fake_raw.shape[1], fake_raw.shape[2])), count, bt)

            # mse, nrmse, ssim and psnr of the binarised output against the target
            for bt in range(n):
                mse_diff = compare_mse(real_flat[bt], fake_flat[bt])
                nrmse_diff = compare_nrmse(real_flat[bt], fake_flat[bt], norm_type="Euclidean")
                ssim_diff = compare_ssim(real_flat[bt], fake_flat[bt])
                psnr_diff = compare_psnr(real_flat[bt], fake_flat[bt])
                print("mse diff:{} | nrmse diff:{} | ssim:{} | psnr:{}".format(mse_diff, nrmse_diff,
                                                                               ssim_diff, psnr_diff))

                # save the images with ssim > 0.8 and ssim < 0.5
                if ssim_diff > 0.8 or ssim_diff < 0.5:
                    fk_merged = merge(scale_back(fake_raw[bt:bt + 1]), [1, 1])
                    rl_merged = merge(scale_back(real[bt:bt + 1]), [1, 1])
                    pair = np.concatenate([rl_merged, fk_merged], axis=1)
                    save_img(pair, mse_diff, nrmse_diff, ssim_diff, psnr_diff)

            merged_fake_images = merge(scale_back(fake_bin), [source_len, 1])
            merged_real_images = merge(scale_back(real), [source_len, 1])
            merged_pair = np.concatenate([merged_real_images, merged_fake_images], axis=1)
            # array_split tolerates heights not divisible by 4; np.split would raise.
            save_batch_samples(np.array_split(merged_pair, 4), count, threshold)

            count += 1
Beispiel #5
0
    def test(self, source_provider, model_dir, save_dir):
        """Evaluate the restored generator on ``source_provider`` and save diagnostics.

        Batches hold at most 16 examples.  The generator output is binarised at
        ``threshold`` (values >= threshold become 1.0, all others -1.0).  For each
        sample the method prints a pixel accuracy ``(valid - over - less) / valid``,
        where ``valid`` counts target pixels equal to 1.0, ``over`` counts output
        1.0 pixels where the target is not 1.0, and ``less`` counts target 1.0
        pixels the output missed.  It then prints MSE, NRMSE (Euclidean), SSIM
        and PSNR of the binarised output against the target.

        Files written to ``save_dir`` (prefixed ``wgan``):

        * every raw output sample,
        * real|raw-output pairs whose SSIM is > 0.8 or < 0.5.

        Args:
            source_provider: provider exposing ``data.examples`` and ``get_iter``.
            model_dir: directory the generator checkpoint is restored from.
            save_dir: directory that receives the PNG files.
        """
        source_len = len(source_provider.data.examples)

        source_len = min(16, source_len)

        source_iter = source_provider.get_iter(source_len)

        tf.global_variables_initializer().run()

        saver = tf.train.Saver(var_list=self.retrieve_generator_vars())
        self.restore_model(saver, model_dir)

        def save_imgs(imgs, count, threshold):
            p = os.path.join(
                save_dir, "wgan_inferred_%04d_%.2f.png" % (count, threshold))
            save_concat_images(imgs, img_path=p)
            print("wgan generated images saved at %s" % p)

        def save_img(img, mse_diff, nrmse_diff, ssim_diff, psnr_diff):
            p = os.path.join(
                save_dir, "wgan-%.4f-%.4f-%.4f-%.4f.png" %
                (ssim_diff, mse_diff, nrmse_diff, psnr_diff))
            save_image(img, img_path=p)
            print("wgan generated ssim: %.4f images saved at %s" %
                  (ssim_diff, p))

        def save_single_img(img, count, bt):
            p = os.path.join(save_dir, "wgan_single_%d_%d.png" % (count, bt))
            save_image(img, img_path=p)
            print("wgan single sample id: %d_%d saved" % (count, bt))

        count = 0
        threshold = 0.1

        for source_imgs in source_iter:
            fake_imgs, real_imgs, d_loss, g_loss, l1_loss = self.generate_fake_samples(
                source_imgs)
            img_shape = fake_imgs.shape
            flat_shape = [img_shape[0], img_shape[1] * img_shape[2] * img_shape[3]]

            fake_imgs_reshape = np.reshape(np.array(fake_imgs), flat_shape)
            real_imgs_reshape = np.reshape(np.array(real_imgs), flat_shape)
            # Snapshot of the raw output for the saved images.  This must be a
            # copy: a plain assignment would alias the array binarised below.
            fake_imgs_reshape_saved = fake_imgs_reshape.copy()
            real_imgs_reshape_saved = real_imgs_reshape  # never modified, alias is fine

            # threshold -- fixed: >= threshold -> 1.0, otherwise -1.0.
            # astype keeps the generator's dtype for the compare_* metrics.
            fake_imgs_reshape = np.where(fake_imgs_reshape >= threshold, 1.0,
                                         -1.0).astype(fake_imgs_reshape.dtype)

            # valid pixels
            for bt in range(fake_imgs_reshape.shape[0]):
                fake_on = fake_imgs_reshape[bt] == 1.0
                real_on = real_imgs_reshape[bt] == 1.0
                p_over = np.count_nonzero(fake_on & ~real_on)
                p_less = np.count_nonzero(~fake_on & real_on)
                p_valid = np.count_nonzero(real_on)
                if p_valid == 0:
                    print("wgan count: %d sample %d pixel accuracy: undefined "
                          "(no valid pixels)" % (count, bt))
                    continue
                p_accuracy = 1.0 * (p_valid - p_over - p_less) / p_valid
                print("wgan count: %d sample %d pixel accuracy: %.05f" %
                      (count, bt, p_accuracy))

            # save every raw sample (reshape to (H, W) assumes a single channel)
            for bt in range(fake_imgs_reshape.shape[0]):
                fk_reshape = np.reshape(
                    fake_imgs_reshape_saved[bt],
                    (fake_imgs.shape[1], fake_imgs.shape[2]))
                save_single_img(fk_reshape, count, bt)

            # mse, nrmse, ssim and psnr of the binarised output against the target
            for bt in range(fake_imgs_reshape.shape[0]):
                mse_diff = compare_mse(real_imgs_reshape[bt],
                                       fake_imgs_reshape[bt])
                nrmse_diff = compare_nrmse(real_imgs_reshape[bt],
                                           fake_imgs_reshape[bt],
                                           norm_type="Euclidean")
                ssim_diff = compare_ssim(real_imgs_reshape[bt],
                                         fake_imgs_reshape[bt])
                psnr_diff = compare_psnr(real_imgs_reshape[bt],
                                         fake_imgs_reshape[bt])
                print("mse diff:{} | nrmse diff:{} | ssim:{} | psnr:{}".format(
                    mse_diff, nrmse_diff, ssim_diff, psnr_diff))

                # save the images with ssim > 0.8 and ssim < 0.5
                if ssim_diff > 0.8 or ssim_diff < 0.5:
                    fk_reshape = np.reshape(
                        fake_imgs_reshape_saved[bt],
                        (1, fake_imgs.shape[1], fake_imgs.shape[2],
                         fake_imgs.shape[3]))
                    rl_reshape = np.reshape(
                        real_imgs_reshape_saved[bt],
                        (1, real_imgs.shape[1], real_imgs.shape[2],
                         real_imgs.shape[3]))
                    fk_reshape = merge(scale_back(fk_reshape), [1, 1])
                    rl_reshape = merge(scale_back(rl_reshape), [1, 1])
                    pair = np.concatenate([rl_reshape, fk_reshape], axis=1)
                    save_img(pair, mse_diff, nrmse_diff, ssim_diff, psnr_diff)

            fake_imgs_reshape = np.reshape(fake_imgs_reshape, fake_imgs.shape)
            real_imgs_reshape = np.reshape(real_imgs_reshape, real_imgs.shape)
            merged_fake_images = merge(scale_back(fake_imgs_reshape),
                                       [source_len, 1])
            merged_real_images = merge(scale_back(real_imgs_reshape),
                                       [source_len, 1])
            merged_pair = np.concatenate(
                [merged_real_images, merged_fake_images], axis=1)

            count += 1