Example #1
from lucid.optvis import objectives

# `inceptionv1` is a pytest fixture; `assert_gradient_ascent` is a test helper.
def test_L1(inceptionv1):
    objective = objectives.L1()  # penalizes the input by default
    assert_gradient_ascent(objective, inceptionv1)
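
In practice, the L1 objective above is used as a penalty term alongside a feature-visualization objective. The following is a minimal sketch of that pattern with lucid's render_vis and the InceptionV1 model zoo entry; the layer name, channel index, and penalty weight are illustrative choices, not taken from the original test.

import lucid.modelzoo.vision_models as models
import lucid.optvis.objectives as objectives
import lucid.optvis.render as render

model = models.InceptionV1()
model.load_graphdef()

# Maximize one channel while penalizing the L1 norm of the input image.
obj = objectives.channel("mixed4a_pre_relu", 476) - 0.05 * objectives.L1(constant=0.5)
images = render.render_vis(model, obj, thresholds=(512,))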
Example #2
    def run(self,
            layer,
            class_,
            channel=None,
            style_template=None,
            transforms=False,
            opt_steps=500,
            gram_coeff=1e-14):
        """
    layer         : layer_name to visualize
    class_        : class to consider
    style_template: template for comparision of generated activation maximization map
    transforms    : transforms required
    opt_steps     : number of optimization steps
    """

        self.layer = layer
        self.channel = channel if channel is not None else 0

        with tf.Graph().as_default() as graph, tf.Session() as sess:

            if style_template is not None:
                try:
                    # Preferred path: the template is a saved .npy activation map.
                    gram_template = tf.constant(
                        np.load(style_template),  #[1:-1,:,:],
                        dtype=tf.float32)
                except (OSError, ValueError):
                    # Fallback: read the template as an image file instead.
                    image = cv2.imread(style_template)
                    print(image.shape)
                    # Pad one pixel at the top and bottom of the image; leave
                    # any remaining axes (width, channels) unpadded.
                    pad_width = ((1, 1),) + ((0, 0),) * (image.ndim - 1)
                    gram_template = tf.constant(
                        np.pad(image, pad_width),  #[1:-1,:,:],
                        dtype=tf.float32)
            else:
                gram_template = None

            obj = self._channel(self.layer + "/convolution",
                                self.channel,
                                gram=gram_template,
                                gram_coeff=gram_coeff)
            obj += -self.L1 * objectives.L1(constant=.5)
            obj += -self.TV * objectives.total_variation()
            #obj += self.blur * objectives.blur_input_each_step()

            if transforms:
                transforms = [
                    transform.pad(self.jitter),
                    transform.jitter(self.jitter),
                    #transform.random_scale([self.scale ** (n/10.) for n in range(-10, 11)]),
                    #transform.random_rotate(range(-self.rotate, self.rotate + 1))
                ]
            else:
                transforms = []

            T = render.make_vis_T(
                self.model,
                obj,
                param_f=lambda: self.image(240,
                                           channels=self.n_channels,
                                           fft=self.decorrelate,
                                           decorrelate=self.decorrelate),
                optimizer=None,
                transforms=transforms,
                relu_gradient_override=False)
            tf.global_variables_initializer().run()

            images_array = []

            for i in range(opt_steps):
                T("vis_op").run()
                images_array.append(
                    T("input").eval()[:, :, :, -1].reshape((240, 240)))

            plt.figure(figsize=(10, 10))
            # for i in range(1, self.n_channels+1):
            #   plt.imshow(np.load(style_template)[:, :, i-1], cmap='gray',
            #              interpolation='bilinear', vmin=0., vmax=1.)
            #   plt.savefig('gram_template_{}.png'.format(i), bbox_inches='tight')

            texture_images = []

            for i in range(1, self.n_channels + 1):
                # plt.subplot(1, self.n_channels, i)
                image = T("input").eval()[:, :, :, i - 1].reshape((240, 240))
                print("channel: ", i, image.min(), image.max())
                # plt.imshow(image, cmap='gray',
                #            interpolation='bilinear', vmin=0., vmax=1.)
                # plt.xticks([])
                # plt.yticks([])
                texture_images.append(image)
                # show(np.hstack(T("input").eval()))

                os.makedirs(os.path.join(self.savepath, class_), exist_ok=True)
                # print(self.savepath, class_, self.layer+'_' + str(self.channel) +'.png')
                # plt.savefig(os.path.join(self.savepath, class_, self.layer+'_' + str(self.channel) + '_' + str(i) +'_noreg.png'), bbox_inches='tight')
            # plt.show()
            # print(np.array(texture_images).shape)

        return np.array(texture_images), images_array
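
The run() method above depends on instance attributes (self.model, self.savepath, self.L1, self.TV, self.jitter, self.n_channels, self.decorrelate, and the self.image/self._channel helpers) that are defined elsewhere in the class. A hypothetical invocation could look like the sketch below; the class name TextureVisualizer and the argument values are assumptions for illustration only.

# Hypothetical usage; class name and constructor arguments are assumed.
viz = TextureVisualizer(model=model, savepath="./results")
textures, history = viz.run(
    layer="conv2d_2",                    # assumed layer name in the model
    class_="zebra",                      # subdirectory used when saving
    channel=3,
    style_template="gram_template.npy",  # .npy activation map or image file
    transforms=True,
    opt_steps=500,
)
# textures: array of shape (n_channels, 240, 240)
# history : last input channel captured at every optimization step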