예제 #1
0
    def evaluate(self, list_of_images, **kwargs):
        """Evaluate test images and save/log a multires grid of the fakes.

        Runs the model on every entry of ``list_of_images`` with gradients
        disabled. It then:

        - tiles the collected ``fake_A`` / ``fake_B`` outputs into one grid
          each (3 columns) with ``util.draw_multires_figure``;
        - logs both grids to wandb at ``self.step``;
        - writes both grids as PNGs under
          ``<checkpoints_dir>/<name>/eval_images/``.

        The model is switched back to training mode afterwards, even if
        evaluation raises.

        :param list_of_images: inputs, each passed to ``self.set_inputs``.
        :param kwargs: unused; accepted for interface compatibility.
        """
        self.set_eval()
        try:
            with torch.no_grad():
                evaluated_images_A = []
                evaluated_images_B = []
                for image in tqdm(list_of_images):
                    self.set_inputs(image)
                    self.forward()
                    # Keep only the first three entries of the last axis
                    # (presumably the RGB channels -- confirm).
                    evaluated_images_A.append(
                        tensor_to_image(self.fake_A)[..., :3])
                    evaluated_images_B.append(
                        tensor_to_image(self.fake_B)[..., :3])
                path_A = self.opt.checkpoints_dir + f"/{self.opt.name}/eval_images/grid_fakeA_{self.step}.png"
                path_B = self.opt.checkpoints_dir + f"/{self.opt.name}/eval_images/grid_fakeB_{self.step}.png"
                image_A = util.draw_multires_figure(np.array(evaluated_images_A),
                                                    n_columns=3)
                image_B = util.draw_multires_figure(np.array(evaluated_images_B),
                                                    n_columns=3)
                wandb.log({'eval_fakeA': [wandb.Image(image_A)]}, step=self.step)
                wandb.log({'eval_fakeB': [wandb.Image(image_B)]}, step=self.step)
                save_image(path_A, np.array(image_A))
                save_image(path_B, np.array(image_B))
        finally:
            # Always restore training mode so a failed evaluation does not
            # leave the model stuck in eval mode for subsequent training.
            self.set_train()
예제 #2
0
 def save_generated(self, data):
     """Run the model on ``data["A"]`` and write ``self.fake`` as a PNG.

     The file is written to
     ``<checkpoints_dir>/<name>/generated/<data['A_name']>_fake.png``.
     Only the first three entries of the converted image's last axis are
     saved.
     """
     source = data["A"]
     self.set_inputs(source)
     self.forward()
     out_file = self.opt.checkpoints_dir + f"/{self.opt.name}/generated/{data['A_name']}_fake.png"
     rgb = tensor_to_image(self.fake)[..., :3]
     save_image(out_file, rgb)
예제 #3
0
 def get_visuals(self):
     """Collect the current visuals as a ``{name: image}`` dict.

     Each name listed in ``self.visual_names`` is read as an attribute of
     this object and converted with ``tensor_to_image``. Keys keep the order
     of ``self.visual_names``.
     """
     return {
         name: tensor_to_image(getattr(self, name))
         for name in self.visual_names
     }
예제 #4
0
 def save_image(self, save_dir, img_gen):
     """Tile ``img_gen`` into a grid (8 per row) and write ``gen_result.png``.

     :param save_dir: directory the file is written into. It is joined with
         ``os.path.join`` and is not created here.
     :param img_gen: generated batch, converted with ``tensor_to_image``
         (using ``self.image_resize``) before being gridded.
     """
     target_file = os.path.join(save_dir, 'gen_result.png')
     converted = tensor_to_image(img_gen, self.image_resize)
     grid = make_grid(converted, nrow=8)
     # This bare name resolves to the module-level save_image, not to this
     # method: a method body does not look up class attributes by bare name.
     save_image(grid, target_file)
예제 #5
0
        tf.reduce_mean((content_outputs[name] - content_targets[name])**2)
        for name in content_outputs.keys()
    ])
    content_loss *= CONTENT_WEIGHT / NUM_CONTENT_LAYERS
    loss = style_loss + content_loss
    return loss


# Shared optimizer for train_step, created lazily on first use.
# Adam keeps per-variable moment estimates between updates. Constructing a
# new optimizer on every step discards that state, so each update would
# behave like Adam's very first step.
_train_step_optimizer = None


def _get_train_step_optimizer():
    """Return the shared Adam optimizer used by train_step, creating it once."""
    global _train_step_optimizer
    if _train_step_optimizer is None:
        _train_step_optimizer = tf.optimizers.Adam(
            learning_rate=0.02, beta_1=0.99, epsilon=1e-1)
    return _train_step_optimizer


def train_step(image):
    """Take one optimization step that updates ``image`` in place.

    Steps:

    1. Compute ``style_content_loss`` on ``extractor(image)``.
    2. Add ``TOTAL_VARIATION_WEIGHT`` times the total variation of ``image``.
    3. Apply one Adam update to ``image``.
    4. Clip the result with ``util.clip_0_1``.

    :param image: tensor containing target image. It must support
        ``assign`` (e.g. a ``tf.Variable``) because it is modified in place.
    """
    with tf.GradientTape() as tape:
        tape.watch(image)
        outputs = extractor(image)
        loss = style_content_loss(outputs)
        loss += TOTAL_VARIATION_WEIGHT * tf.image.total_variation(image)

    grad = tape.gradient(loss, image)
    # Reuse the same optimizer across steps so Adam's moment state persists.
    _get_train_step_optimizer().apply_gradients([(grad, image)])
    image.assign(util.clip_0_1(image))


# Training loop: one train_step per iteration, epochs * STEPS_PER_EPOCH total.
total_steps = epochs * STEPS_PER_EPOCH
for _step in tqdm(range(total_steps)):
    train_step(image)

# Save the optimized image. FILE_NAME is a relative path, so the file is
# written to the current working directory.
FILE_NAME = "stylized-image.png"
result_image = util.tensor_to_image(image)
result_image.save(FILE_NAME)