Example #1
def feature_inversion(img, model, layer=None, n_steps=512, cossim_pow=0.0):
  with tf.Graph().as_default(), tf.Session() as sess:
    img = imgToModelSize(img, model)
    
    objective = objectives.Objective.sum([
        1.0 * dot_compare(layer, cossim_pow=cossim_pow),
        objectives.blur_input_each_step(),
    ])

    t_input = tf.placeholder(tf.float32, img.shape)
    param_f = param.image(img.shape[0], decorrelate=True, fft=True, alpha=False)
    param_f = tf.stack([param_f[0], t_input])

    transforms = [
      transform.pad(8, mode='constant', constant_value=.5),
      transform.jitter(8),
      transform.random_scale([0.9, 0.95, 1.05, 1.1] + [1]*4),
      transform.random_rotate(list(range(-5, 5)) + [0]*5),
      transform.jitter(2),
    ]

    T = render.make_vis_T(model, objective, param_f, transforms=transforms)
    loss, vis_op, t_image = T("loss"), T("vis_op"), T("input")

    tf.global_variables_initializer().run()
    for i in range(n_steps): _ = sess.run([vis_op], {t_input: img})

    result = t_image.eval(feed_dict={t_input: img})
    show(result[0])
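# feature_inversion above relies on two helpers that are not shown. Below is a
# sketch of plausible implementations, following the Lucid feature-inversion
# notebook; the originals may differ in detail.
import numpy as np
import scipy.ndimage as nd
import tensorflow as tf
from lucid.optvis import objectives

def imgToModelSize(arr, model):
    # Scale the shorter side up to the model's input width, then center-crop
    W = model.image_shape[0]
    w, h, _ = arr.shape
    s = float(W) / min(w, h)
    arr = nd.zoom(arr, [s, s, 1], mode="nearest")
    w, h, _ = arr.shape
    dw, dh = (w - W) // 2, (h - W) // 2
    return arr[dw:dw + W, dh:dh + W]

@objectives.wrap_objective
def dot_compare(layer, batch=1, cossim_pow=0):
    # Dot product between activations of the fixed image (index `batch`) and
    # the optimized image (index 0), optionally reweighted by cosine similarity
    def inner(T):
        dot = tf.reduce_sum(T(layer)[batch] * T(layer)[0])
        mag = tf.sqrt(tf.reduce_sum(T(layer)[0] ** 2))
        cossim = dot / (1e-6 + mag)
        return dot * cossim ** cossim_pow
    return inner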
Example #2
    def render_activation_grid_very_naive(self,
                                          img,
                                          layer="mixed4d",
                                          W=42,
                                          n_steps=256):

        # Get the activations
        with tf.Graph().as_default(), tf.Session() as sess:
            t_input = tf.placeholder("float32", [None, None, None, 3])
            T = render.import_model(self.model, t_input, t_input)
            acts = T(layer).eval({t_input: img[None]})[0]
        acts_flat = acts.reshape([-1] + [acts.shape[2]])

        # Render an image for each activation vector
        def param_f():
            return param.image(W, batch=acts_flat.shape[0])

        obj = objectives.Objective.sum([
            objectives.direction(layer, v, batch=n)
            for n, v in enumerate(acts_flat)
        ])
        thresholds = (n_steps // 2, n_steps)
        vis_imgs = render.render_vis(self.model,
                                     obj,
                                     param_f,
                                     thresholds=thresholds)[-1]

        # Combine the images and display the resulting grid
        vis_imgs_ = vis_imgs.reshape(list(acts.shape[:2]) + [W, W, 3])
        vis_imgs_cropped = vis_imgs_[:, :, 2:-2, 2:-2, :]
        show(np.hstack(np.hstack(vis_imgs_cropped)))
        return vis_imgs_cropped
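# Hypothetical usage of the naive grid (the object holding this method is an
# assumption; `load`, the image URL, and the layer name follow the Lucid
# building-blocks notebooks):
img = load("https://storage.googleapis.com/lucid-static/building-blocks/examples/dog_cat.png")[..., :3]
grid = visualizer.render_activation_grid_very_naive(img, layer="mixed4d", W=48, n_steps=256)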
Example #3
def keras_render_vis(input_model,
                     objective_f,
                     param_f=None,
                     optimizer=None,
                     transforms=None,
                     thresholds=[512],
                     print_objectives=None,
                     verbose=True,
                     use_fixed_seed=False,
                     raw=False):
    if use_fixed_seed:
        tf.set_random_seed(0)

    t_image, loss, train = keras_make_vis_T(input_model, objective_f, param_f,
                                            optimizer, transforms)
    if isinstance(thresholds, int):
        thresholds = [thresholds]

    cache_m = None
    cache_v = None
    lr = 0.05
    beta1 = 0.9
    beta2 = 0.999
    iters = 0
    images = []

    try:
        for i in range(max(thresholds) + 1):
            loss, grads = train([t_image, 0])
            step, cache_m, cache_v, iters = adam(grads, cache_m, cache_v,
                                                 iters, lr, beta1, beta2)
            t_image += step
            if i in thresholds:
                vis = t_image.copy()  # copy so later in-place updates don't mutate saved snapshots
                images.append(vis)
                if verbose:
                    print(i, loss)
                    show(np.hstack(vis))
        try:
            del loss, train  # clear references so the graph can be freed
        except NameError:
            pass

        return np.array(images)

    except KeyboardInterrupt:
        print("Interrupted optimization at step: ", i)
        print("will return the last iteration image only")
        vis = t_image
        show(np.hstack(vis))

        del loss, train  #clear graphs
        return normalize_array(np.hstack(vis))
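# The adam() helper used above is not defined in this snippet. Here is a
# minimal sketch of a manual Adam step consistent with how it is called;
# the original may differ in detail. Assumes numpy.
import numpy as np

def adam(grads, cache_m, cache_v, iters, lr, beta1, beta2, eps=1e-8):
    # Initialize the moment caches on the first call
    if cache_m is None:
        cache_m = np.zeros_like(grads)
        cache_v = np.zeros_like(grads)
    iters += 1
    cache_m = beta1 * cache_m + (1 - beta1) * grads
    cache_v = beta2 * cache_v + (1 - beta2) * np.square(grads)
    m_hat = cache_m / (1 - beta1 ** iters)  # bias correction
    v_hat = cache_v / (1 - beta2 ** iters)
    step = lr * m_hat / (np.sqrt(v_hat) + eps)
    return step, cache_m, cache_v, iters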
Example #4
def render_vis(model,
               objective_f,
               file_name,
               filter_idx,
               param_f=None,
               optimizer=None,
               transforms=None,
               thresholds=(512, ),
               verbose=True,
               relu_gradient_override=True,
               use_fixed_seed=False):
    with tf.Graph().as_default() as graph, tf.Session() as sess:

        if use_fixed_seed:  # does not mean results are reproducible, see Args doc
            tf.set_random_seed(0)

        T = render.make_vis_T(model, objective_f, param_f, optimizer,
                              transforms, relu_gradient_override)
        loss, vis_op, t_image = T("loss"), T("vis_op"), T("input")
        tf.global_variables_initializer().run()

        loss_p = 0
        try:
            for i in range(max(thresholds) + 1):
                loss_, _ = sess.run([loss, vis_op])
                if i in thresholds:
                    vis = t_image.eval()
                    loss_p = loss_
                    plt.title("Filter {}, {:.5f}".format(filter_idx, loss_))
                    plt.imshow(np.hstack(vis))
                    plt.axis('off')
                    plt.savefig(file_name)
                    plt.clf()
        except KeyboardInterrupt:
            vis = t_image.eval()
            show(np.hstack(vis))

        return loss_p
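# Hypothetical call to this file-saving variant (model, layer, and filter
# names are illustrative assumptions):
model = models.InceptionV1()
model.load_graphdef()
final_loss = render_vis(model, "mixed4a_pre_relu:476", "filter_476.png", 476,
                        thresholds=(128, 256, 512))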
Example #5
    def train(self, thresholds=range(0, 5000, 30)):
        self.images = []
        print(self.sess.run(self.actions))
        vis = self.sess.run(self.final_canvas_state)
        show(np.hstack(vis[:2]))
        try:
            for i in range(max(thresholds) + 1):
                content_loss_, _ = self.sess.run([self.loss, self.vis_op])
                if i in thresholds:
                    vis = self.sess.run(self.resized_final)
                    print(i, content_loss_)
                    show(np.hstack(vis[:2]))
        except KeyboardInterrupt:
            vis = self.sess.run(self.final_canvas_state)
            show(np.hstack(vis[:2]))
Example #6
def render_vis(model,
               objective_f,
               param_f=None,
               optimizer=None,
               transforms=None,
               thresholds=(512, ),
               print_objectives=None,
               verbose=True,
               relu_gradient_override=True,
               use_fixed_seed=False):
    """Flexible optimization-base feature vis.

  There's a lot of ways one might wish to customize otpimization-based
  feature visualization. It's hard to create an abstraction that stands up
  to all the things one might wish to try.

  This function probably can't do *everything* you want, but it's much more
  flexible than a naive attempt. The basic abstraction is to split the problem
  into several parts. Consider the rguments:

  Args:
    model: The model to be visualized, from Alex's modelzoo.
    objective_f: The objective our visualization maximizes.
      See the objectives module for more details.
    param_f: Parameterization of the image we're optimizing.
      See the parameterization module for more details.
      Defaults to a naively parameterized [1, 128, 128, 3] image.
    optimizer: Optimizer to optimize with. Either tf.train.Optimizer instance,
      or a function from (graph, sess) to such an instance.
      Defaults to Adam with lr .05.
    transforms: A list of stochastic transformations that get composed,
      which our visualization should robustly activate the network against.
      See the transform module for more details.
      Defaults to [transform.jitter(8)].
    thresholds: A list of numbers of optimization steps, at which we should
      save (and display if verbose=True) the visualization.
    print_objectives: A list of objectives separate from those being optimized,
      whose values get logged during the optimization.
    verbose: Should we display the visualization when we hit a threshold?
      This should only be used in IPython.
    relu_gradient_override: Whether to use the gradient override scheme
      described in lucid/misc/redirected_relu_grad.py. On by default!
    use_fixed_seed: Seed the RNG with a fixed value so results are reproducible.
      Off by default. As of tf 1.8 this does not work as intended, see:
      https://github.com/tensorflow/tensorflow/issues/9171
  Returns:
    2D array of optimization results containing evaluations of the supplied
    param_f snapshotted at the specified thresholds. Usually that will mean one
    or multiple channel visualizations stacked on top of each other.
  """

    with tf.Graph().as_default() as graph, tf.Session() as sess:

        if use_fixed_seed:  # does not mean results are reproducible, see Args doc
            tf.set_random_seed(0)

        T = make_vis_T(model, objective_f, param_f, optimizer, transforms,
                       relu_gradient_override)
        print_objective_func = make_print_objective_func(print_objectives, T)
        loss, vis_op, t_image = T("loss"), T("vis_op"), T("input")
        tf.global_variables_initializer().run()

        images = []
        try:
            for i in range(max(thresholds) + 1):
                loss_, _ = sess.run([loss, vis_op])
                if i in thresholds:
                    vis = t_image.eval()
                    images.append(vis)
                    if verbose:
                        print(i, loss_)
                        print_objective_func(sess)
                        show(np.hstack(vis))
        except KeyboardInterrupt:
            log.warn("Interrupted optimization at step {:d}.".format(i + 1))
            vis = t_image.eval()
            show(np.hstack(vis))

        return images
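# A typical call, following standard Lucid tutorial usage (the model and
# objective names here are illustrative):
from lucid.modelzoo import vision_models as models

model = models.InceptionV1()
model.load_graphdef()
imgs = render_vis(model, "mixed4a_pre_relu:476", thresholds=(128, 256, 512))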
Example #7
    def render_activation_grid_less_naive(self,
                                          img,
                                          layer="mixed4d",
                                          W=42,
                                          n_groups=6,
                                          subsample_factor=1,
                                          n_steps=256):
        # Get the activations
        with tf.Graph().as_default(), tf.Session() as sess:
            t_input = tf.placeholder("float32", [None, None, None, 3])
            T = render.import_model(self.model, t_input, t_input)
            acts = T(layer).eval({t_input: img[None]})[0]
        acts_flat = acts.reshape([-1] + [acts.shape[2]])
        N = acts_flat.shape[0]

        # The trick to avoiding "decoherence" is to recognize images that
        # correspond to similar activation vectors and let them share part of
        # their parameterization (see the grouped images below).
        if n_groups > 0:
            reducer = ChannelReducer(n_groups, "NMF")
            groups = reducer.fit_transform(acts_flat)
            groups /= groups.max(0)
        else:
            groups = np.zeros([])

        print(groups.shape)

        # The key trick to increasing memory efficiency is random sampling.
        # Even though we're visualizing lots of images, we only run a small
        # subset through the network at once. In order to do this, we'll need
        # to hold tensors in a tensorflow graph around the visualization process.
        with tf.Graph().as_default() as graph, tf.Session() as sess:
            # Using the groups, create a parameterization of images that
            # partly shares parameters between the images for similar
            # activation vectors. Each one still has a full set of unique
            # parameters, and could optimize to any image. We're just making
            # it easier to find solutions where things are the same.
            group_imgs_raw = param.fft_image([n_groups, W, W, 3])
            unique_imgs_raw = param.fft_image([N, W, W, 3])
            opt_imgs = param.to_valid_rgb(
                tf.stack([
                    0.7 * unique_imgs_raw[i] +
                    0.5 * sum(groups[i, j] * group_imgs_raw[j]
                              for j in range(n_groups))
                    for i in range(N)
                ]),
                decorrelate=True)

            # Construct a random batch to optimize this step
            batch_size = 64
            rand_inds = tf.random_uniform([batch_size], 0, N, dtype=tf.int32)
            pres_imgs = tf.gather(opt_imgs, rand_inds)
            pres_acts = tf.gather(acts_flat, rand_inds)
            obj = objectives.Objective.sum([
                objectives.direction(layer, pres_acts[n], batch=n)
                for n in range(batch_size)
            ])

            # Actually do the optimization...
            T = render.make_vis_T(self.model, obj, param_f=pres_imgs)
            tf.global_variables_initializer().run()

            for i in range(n_steps):
                T("vis_op").run()
                if (i + 1) % (n_steps // 2) == 0:
                    show(pres_imgs.eval()[::4])

            vis_imgs = opt_imgs.eval()

        # Combine the images and display the resulting grid
        print("")
        vis_imgs_ = vis_imgs.reshape(list(acts.shape[:2]) + [W, W, 3])
        vis_imgs_cropped = vis_imgs_[:, :, 2:-2, 2:-2, :]
        show(np.hstack(np.hstack(vis_imgs_cropped)))
        return vis_imgs_cropped
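# The NMF grouping step above, in isolation. A minimal sketch; the shapes are
# illustrative (e.g. a 14x14 grid of 528-channel mixed4d activations):
import numpy as np
from lucid.misc.channel_reducer import ChannelReducer

acts_flat = np.random.rand(14 * 14, 528).astype(np.float32)
reducer = ChannelReducer(6, "NMF")          # 6 groups, non-negative factorization
groups = reducer.fit_transform(acts_flat)   # [196, 6] soft group memberships
groups /= groups.max(0)                     # normalize each group to a max of 1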
Example #8
# The `load` function takes a URL or a local filepath. Input images will be forced to a square aspect ratio.


# Load the content image: use a command-line path if given, else a default local file
if len(sys.argv) > 1:
    local_path = sys.argv[1]
else:
    local_path = "./F561f22668fee4.jpg"

CONTENT_IMAGE = load(local_path)[..., :3]  # Remove transparency channel

# Or hard-code a different image path
#CONTENT_IMAGE = load("local_path.jpg")[..., :3]  # Remove transparency channel

show(CONTENT_IMAGE)

# ## Run!


lol = LucidGraph(CONTENT_IMAGE,
                 32,
                 8,
                 NUMBER_STROKES,
                 painter_type=PAINTER_MODE,
                 gpu_mode=False,
                 connected=CONNECTED_STROKES,
                 alternate=False,
                 bw=BW)
Example #9
transforms = [
    transform.random_scale([SCALE**(n / 10.) for n in range(-10, 11)]),
    transform.random_rotate(range(-ROTATE, ROTATE + 1))
]
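# SCALE and ROTATE above are constants defined elsewhere in the source
# notebook; for reference, plausible (purely assumed) values would be:
SCALE = 1.1   # assumed maximum random scale factor
ROTATE = 10   # assumed rotation range in degrees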

imgs = render.render_vis(model,
                         "mixed4b_pre_relu:452",
                         transforms=transforms,
                         param_f=lambda: param.image(64),
                         thresholds=[2048],
                         verbose=False,
                         relu_gradient_override=True,
                         use_fixed_seed=True)
plt.imshow(imgs[0][0])

# Note that we're doubling the image scale to make artifacts more obvious
show([nd.zoom(img[0], [2, 2, 1], order=0) for img in imgs])

model = models.InceptionV1_slim()
model.load_graphdef()

out = render.render_vis(model, 'InceptionV1/InceptionV1/Mixed_4c/concat:452')
plt.imshow(out[0][0])

model = models.VGG19_caffe()
model.load_graphdef()
nodes_tab = [n.name for n in tf.get_default_graph().as_graph_def().node]
out = render.render_vis(model, 'conv3_1/conv3_1:1')
plt.figure()
plt.imshow(out[0][0])

LEARNING_RATE = 0.0005
Example #10
model.load_graphdef()

model.show_graph()

"""## Visualize Neuron

See the [lucid tutorial](https://colab.research.google.com/github/tensorflow/lucid/blob/master/notebooks/tutorial.ipynb) to learn more.

We pick `InceptionV4/InceptionV4/Mixed_6b/concat` from above, and chose to focus on unit 0.
"""

model = models.VGG16_caffe()
model.load_graphdef()

_ = render.render_vis(model, "conv1_1/conv1_1:0")

"""## Caricature

See the [inversion and caricature notebook](https://colab.research.google.com/github/tensorflow/lucid/blob/master/notebooks/misc/feature_inversion_caricatures.ipynb) to learn more.
"""

from lucid.recipes.caricature import feature_inversion

img = load("https://storage.googleapis.com/lucid-static/building-blocks/examples/dog_cat.png")

model = models.VGG16_caffe()
model.load_graphdef()

result = feature_inversion(img, model, "conv1_1/conv1_1", n_steps=512, cossim_pow=0.0)
show(result)
Example #11
def make_lucid_imagenet_dataset(
        model_tag,
        n_random=9,
        min_size=3,
        max_prop=0.8,
        display=True,
        infodir='/project/clusterability_in_neural_networks/results/',
        savedir='/project/clusterability_in_neural_networks/datasets/'):

    assert model_tag in VIS_NETS

    with open(infodir + model_tag + '_clustering_info.pkl', 'rb') as f:
        clustering_info = pickle.load(f)

    layer_names = clustering_info['layers']
    labels_in_layers = [
        np.array(lyr_labels) for lyr_labels in clustering_info['labels']
    ]
    layer_sizes = [len(labels) for labels in labels_in_layers]
    n_clusters = max([max(labels) for labels in labels_in_layers]) + 1

    if model_tag == 'vgg16':
        lucid_net = models.VGG16_caffe()
    elif model_tag == 'vgg19':
        lucid_net = models.VGG19_caffe()
    else:
        lucid_net = models.ResnetV1_50_slim()
    lucid_net.load_graphdef()
    layer_map = NETWORK_LAYER_MAP[model_tag]

    max_images = []  # to be filled with images that maximize cluster activations
    # min_images = []  # to be filled with images that minimize cluster activations
    random_max_images = []  # to be filled with images that maximize random units' activations
    # random_min_images = []  # to be filled with images that minimize random units' activations
    max_losses = []  # to be filled with losses
    # min_losses = []  # to be filled with losses
    random_max_losses = []  # to be filled with losses
    # random_min_losses = []  # to be filled with losses
    sm_sizes = []  # list of submodule sizes
    sm_layer_sizes = []  # list of layer sizes
    sm_layers = []  # list of layer names
    sm_clusters = []  # list of cluster indices

    for layer_name, labels, layer_size in zip(layer_names, labels_in_layers,
                                              layer_sizes):

        if layer_name not in layer_map.keys():
            continue

        lucid_name = layer_map[layer_name]
        max_size = max_prop * layer_size

        for clust_i in range(n_clusters):

            sm_binary = labels == clust_i
            sm_size = sum(sm_binary)
            if sm_size <= min_size or sm_size >= max_size:  # skip if too big or small
                continue

            sm_sizes.append(sm_size)
            sm_layer_sizes.append(layer_size)
            sm_layers.append(layer_name)
            sm_clusters.append(clust_i)

            print(f'{model_tag}, layer names: {layer_name}, {lucid_name}')
            print(f'submodule_size: {sm_size}, layer_size: {layer_size}')

            sm_idxs = [i for i in range(layer_size) if sm_binary[i]]
            max_obj = sum(
                [objectives.channel(lucid_name, unit) for unit in sm_idxs])
            # min_obj = -1 * sum([objectives.channel(lucid_name, unit) for unit in sm_idxs])

            max_im, max_loss = render_vis_with_loss(lucid_net,
                                                    max_obj,
                                                    size=IMAGE_SIZE_IMAGENET,
                                                    thresholds=(256, ))
            max_images.append(max_im)
            max_losses.append(max_loss)
            # min_im, min_loss = render_vis_with_loss(lucid_net, min_obj)
            # min_images.append(min_im)
            # min_losses.append(min_loss)
            if display:
                print(f'loss: {round(max_loss, 3)}')
                show(max_im)

            rdm_losses = []
            rdm_ims = []
            for _ in range(n_random):  # random max results
                rdm_idxs = np.random.choice(np.array(range(layer_size)),
                                            size=sm_size,
                                            replace=False)
                random_max_obj = sum([
                    objectives.channel(lucid_name, unit) for unit in rdm_idxs
                ])
                random_max_im, random_max_loss = render_vis_with_loss(
                    lucid_net,
                    random_max_obj,
                    size=IMAGE_SIZE_IMAGENET,
                    thresholds=(256, ))
                random_max_images.append(random_max_im)
                random_max_losses.append(random_max_loss)
                rdm_losses.append(round(random_max_loss, 3))
                rdm_ims.append(np.squeeze(random_max_im))
            if display:
                print(f'random losses: {rdm_losses}')
                show(np.hstack(rdm_ims))

            # for _ in range(n_random):  # random min results
            #     rdm_idxs = np.random.choice(np.array(range(layer_size)), size=sm_size, replace=False)
            #     random_min_obj = -1 * sum([objectives.channel(lucid_name, unit) for unit in rdm_idxs])
            #     random_min_im, random_min_loss = render_vis_with_loss(lucid_net, random_min_obj)
            #     random_min_images.append(random_min_im)
            #     random_min_losses.append(random_min_loss)

    max_images = np.squeeze(np.array(max_images))
    # min_images = np.squeeze(np.array(min_images))
    random_max_images = np.squeeze(np.array(random_max_images))
    # random_min_images = np.squeeze(np.array(random_min_images))
    max_losses = np.array(max_losses)
    # min_losses = np.array(min_losses)
    random_max_losses = np.array(random_max_losses)
    # random_min_losses = np.array(random_min_losses)

    results = {
        'max_images': max_images,  # 'min_images': min_images,
        'random_max_images': random_max_images,  # 'random_min_images': random_min_images,
        'max_losses': max_losses,  # 'min_losses': min_losses,
        'random_max_losses': random_max_losses,  # 'random_min_losses': random_min_losses,
        'sm_sizes': sm_sizes,
        'sm_layer_sizes': sm_layer_sizes,
        'sm_layers': sm_layers,
        'sm_clusters': sm_clusters
    }

    with open(savedir + model_tag + '_max_data.pkl', 'wb') as f:
        pickle.dump(results, f)
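# render_vis_with_loss is not defined in this snippet. A minimal sketch,
# modeled on lucid's render.render_vis but also returning the final loss
# value (the original implementation may differ):
import tensorflow as tf
from lucid.optvis import param, render

def render_vis_with_loss(model, objective_f, size=224, thresholds=(256,),
                         transforms=None):
    with tf.Graph().as_default(), tf.Session() as sess:
        T = render.make_vis_T(model, objective_f,
                              param_f=lambda: param.image(size),
                              transforms=transforms)
        loss_t, vis_op, t_image = T("loss"), T("vis_op"), T("input")
        tf.global_variables_initializer().run()
        loss = None
        for _ in range(max(thresholds) + 1):
            loss, _ = sess.run([loss_t, vis_op])
        return t_image.eval(), loss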
Example #12
def make_lucid_dataset(
        model_tag,
        lucid_net,
        all_labels,
        is_unpruned,
        transforms=[],
        n_random=9,
        min_size=5,
        max_prop=0.8,
        display=True,
        savedir='/project/clusterability_in_neural_networks/datasets/',
        savetag=''):

    if 'cnn' in model_tag.lower():
        cnn_params = CNN_VGG_MODEL_PARAMS if 'vgg' in str(
            model_tag).lower() else CNN_MODEL_PARAMS
        layer_sizes = [cl['filters'] for cl in cnn_params['conv']]
        layer_names = ['conv2d/Relu'] + [
            f'conv2d_{i}/Relu' for i in range(1, len(layer_sizes))
        ]
    else:  # it's an mlp
        layer_sizes = [256, 256, 256, 256]
        layer_names = ['dense/Relu'] + [
            f'dense_{i}/Relu' for i in range(1, len(layer_sizes))
        ]
    if not is_unpruned:
        layer_names = ['prune_low_magnitude_' + ln for ln in layer_names]

    labels_in_layers = [
        np.array(lyr_labels)
        for lyr_labels in list(splitter(all_labels, layer_sizes))
    ]

    max_images = []  # to be filled with images that maximize cluster activations
    random_max_images = []  # to be filled with images that maximize random units' activations
    max_losses = []  # to be filled with losses
    random_max_losses = []  # to be filled with losses
    sm_sizes = []  # list of submodule sizes
    sm_layer_sizes = []  # list of layer sizes
    sm_layers = []  # list of layer names
    sm_clusters = []  # list of cluster indices

    imsize = IMAGE_SIZE_CIFAR10 if 'vgg' in model_tag.lower() else IMAGE_SIZE

    for layer_name, labels, layer_size in zip(layer_names, labels_in_layers,
                                              layer_sizes):

        max_size = max_prop * layer_size

        for clust_i in range(max(all_labels) + 1):

            sm_binary = labels == clust_i
            sm_size = sum(sm_binary)
            if sm_size <= min_size or sm_size >= max_size:  # skip if too big or small
                continue

            sm_sizes.append(sm_size)
            sm_layer_sizes.append(layer_size)
            sm_layers.append(layer_name)
            sm_clusters.append(clust_i)

            # print(f'{model_tag}, layer: {layer_name}')
            # print(f'submodule_size: {sm_size}, layer_size: {layer_size}')

            sm_idxs = [i for i in range(layer_size) if sm_binary[i]]
            max_obj = sum(
                [objectives.channel(layer_name, unit) for unit in sm_idxs])

            max_im, max_loss = render_vis_with_loss(lucid_net,
                                                    max_obj,
                                                    size=imsize,
                                                    transforms=transforms)
            max_images.append(max_im)
            max_losses.append(max_loss)
            if display:
                print(f'loss: {round(max_loss, 3)}')
                show(max_im)

            rdm_losses = []
            rdm_ims = []
            for _ in range(n_random):  # random max results
                rdm_idxs = np.random.choice(np.array(range(layer_size)),
                                            size=sm_size,
                                            replace=False)
                random_max_obj = sum([
                    objectives.channel(layer_name, unit) for unit in rdm_idxs
                ])
                random_max_im, random_max_loss = render_vis_with_loss(
                    lucid_net,
                    random_max_obj,
                    size=imsize,
                    transforms=transforms)
                random_max_images.append(random_max_im)
                random_max_losses.append(random_max_loss)
                rdm_ims.append(np.squeeze(random_max_im))
                rdm_losses.append(round(random_max_loss, 3))
            if display:
                print(f'random losses: {rdm_losses}')
                show(np.hstack(rdm_ims))

    max_images = np.squeeze(np.array(max_images))
    random_max_images = np.squeeze(np.array(random_max_images))
    max_losses = np.array(max_losses)
    random_max_losses = np.array(random_max_losses)

    results = {
        'max_images': max_images,
        'random_max_images': random_max_images,
        'max_losses': max_losses,
        'random_max_losses': random_max_losses,
        'sm_sizes': sm_sizes,
        'sm_layer_sizes': sm_layer_sizes,
        'sm_layers': sm_layers,
        'sm_clusters': sm_clusters
    }

    if is_unpruned:
        suff = '_unpruned_max_data'
    else:
        suff = '_pruned_max_data'

    with open(savedir + model_tag + suff + savetag + '.pkl', 'wb') as f:
        pickle.dump(results, f)
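# The splitter helper used above is not defined in this snippet; a minimal
# sketch consistent with how it is called (partition a flat label sequence
# into consecutive per-layer chunks):
from itertools import islice

def splitter(seq, sizes):
    it = iter(seq)
    for size in sizes:
        yield list(islice(it, size))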
Example #13
def optimize_input(obj,
                   model,
                   param_f,
                   transforms,
                   lr=0.05,
                   step_n=512,
                   num_output_channels=4,
                   do_render=False,
                   out_name="out"):

    sess = create_session()

    # Set up optimization problem
    size = 84
    t_size = tf.placeholder_with_default(size, [])
    T = render.make_vis_T(
        model,
        obj,
        param_f=param_f,
        transforms=transforms,
        optimizer=tf.train.AdamOptimizer(lr),
    )

    tf.global_variables_initializer().run()

    if do_render:
        video_fn = out_name + '.mp4'
        writer = FFMPEG_VideoWriter(video_fn, (size, size * 4), 60.0)

    # Optimization loop
    try:
        for i in range(step_n):
            _, loss, img = sess.run([T("vis_op"), T("loss"), T("input")])

            if do_render:
                if num_output_channels == 1:
                    # Single output channel: tile it to grayscale RGB
                    img = img[..., -1:]
                    img = np.tile(img, 3)
                else:
                    # Stack the channels vertically into one grayscale column,
                    # then tile to RGB for the video writer
                    img = img.transpose([0, 3, 1, 2])
                    img = img.reshape([84 * 4, 84, 1])
                    img = np.tile(img, 3)
                writer.write_frame(_normalize_array(img))
                if i > 0 and i % 50 == 0:
                    clear_output()
                    print("%d / %d  score: %f" % (i, step_n, loss))
                    show(img)

    except KeyboardInterrupt:
        pass
    finally:
        if do_render:
            print("closing...")
            writer.close()

    # Save trained variables
    if do_render:
        train_vars = sess.graph.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES)
        params = np.array(sess.run(train_vars), object)
        save(params, out_name + '.npy')

        # Save the final image, rendered at a larger size via the t_size placeholder
        final_img = T("input").eval({t_size: 600})[..., -1:]
        save(final_img, out_name + '.jpg', quality=90)

    out = T("input").eval({t_size: 84})
    sess.close()
    return out
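# optimize_input above relies on helpers imported elsewhere. A plausible set
# of imports, assuming the Lucid codebase plus moviepy for video writing
# (the exact module paths are assumptions):
import numpy as np
import tensorflow as tf
from IPython.display import clear_output
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter
from lucid.misc.io import save, show
from lucid.misc.io.serialize_array import _normalize_array
from lucid.misc.tfutil import create_session
from lucid.optvis import render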