コード例 #1
0
    def on_epoch_end(self, data: Data):
        """Run ``self.trials`` N-way one-shot trials and log the accuracy.

        Each trial asks ``self.dataset.one_shot_trial(self.N)`` for two lists of
        image paths. Each list is loaded as a grayscale, channel-last float32
        batch scaled to [0, 1], and both batches are fed to ``self.model`` as a
        tuple. A trial is counted as correct when the highest score is at
        index 0 and the scores are not (near-)constant.

        Args:
            data: Receives ``self.correct / self.total`` under
                ``self.outputs[0]`` via ``write_with_log``.

        Raises:
            FileNotFoundError: If OpenCV cannot read one of the image paths.
        """
        def load_gray(paths):
            # Stack grayscale images into an (N, H, W, 1) float32 batch in [0, 1].
            images = []
            for path in paths:
                img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                if img is None:
                    # cv2.imread signals failure by returning None instead of raising.
                    raise FileNotFoundError("Unable to read image: {}".format(path))
                images.append(np.expand_dims(img, -1) / 255.)
            return np.array(images, dtype=np.float32)

        for _ in range(self.trials):
            img_path = self.dataset.one_shot_trial(self.N)
            input_img = (load_gray(img_path[0]), load_gray(img_path[1]))
            prediction_score = feed_forward(self.model, input_img, training=False).numpy()

            # Index 0 is treated as the matching pair (presumably one_shot_trial puts
            # the same-class pair first -- confirm). np.argmax returns the first
            # maximum, so a constant output would "win" at index 0; the std check
            # rejects such degenerate predictions.
            if np.argmax(prediction_score) == 0 and prediction_score.std() > 0.01:
                self.correct += 1
            self.total += 1

        # NOTE(review): the counters are not reset in this method, so the value is
        # cumulative unless they are reset elsewhere; this raises ZeroDivisionError
        # if self.total is still 0 (e.g. trials == 0).
        data.write_with_log(self.outputs[0], self.correct / self.total)
コード例 #2
0
    def on_epoch_end(self, data: Data):
        """Run ``self.trials`` N-way one-shot trials with a PyTorch model and log accuracy.

        Each trial asks ``self.dataset.one_shot_trial(self.N)`` for two lists of
        image paths. Each list is loaded as a grayscale, channel-first
        (N, 1, H, W) float32 batch scaled to [0, 1] and moved to the model's
        device, and both batches are fed to the model as a tuple. A trial is
        counted as correct when the highest score is at index 0 and the scores
        are not (near-)constant.

        Args:
            data: Receives ``self.correct / self.total`` under
                ``self.outputs[0]`` via ``write_with_log``.

        Raises:
            FileNotFoundError: If OpenCV cannot read one of the image paths.
        """
        device = next(self.model.parameters()).device
        # Only a DataParallel wrapper exposes ``.module``. Checking the wrapper type
        # (rather than the GPU count) avoids an AttributeError for unwrapped models.
        if isinstance(self.model, torch.nn.DataParallel):
            model = self.model.module
        else:
            model = self.model

        def load_gray(paths):
            # Stack grayscale images into an (N, 1, H, W) float32 batch in [0, 1].
            images = []
            for path in paths:
                img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                if img is None:
                    # cv2.imread signals failure by returning None instead of raising.
                    raise FileNotFoundError("Unable to read image: {}".format(path))
                images.append(np.expand_dims(img, 0) / 255.)
            return np.array(images, dtype=np.float32)

        for _ in range(self.trials):
            img_path = self.dataset.one_shot_trial(self.N)
            input_img = (to_tensor(load_gray(img_path[0]), "torch").to(device),
                         to_tensor(load_gray(img_path[1]), "torch").to(device))
            # Pure evaluation: no autograd graph is needed.
            with torch.no_grad():
                prediction_score = feed_forward(
                    model, input_img, training=False).detach().cpu().numpy()

            # Index 0 is treated as the matching pair (presumably one_shot_trial puts
            # the same-class pair first -- confirm). np.argmax returns the first
            # maximum, so a constant output would "win" at index 0; the std check
            # rejects such degenerate predictions.
            if np.argmax(prediction_score) == 0 and prediction_score.std() > 0.01:
                self.correct += 1
            self.total += 1

        # NOTE(review): the counters are not reset in this method, so the value is
        # cumulative unless they are reset elsewhere; this raises ZeroDivisionError
        # if self.total is still 0 (e.g. trials == 0).
        data.write_with_log(self.outputs[0], self.correct / self.total)
コード例 #3
0
 def on_epoch_end(self, state):
     """Save generated sample images for epochs listed in ``self.epoch_model_map``.

     For such an epoch, the mapped PyTorch model is fed ``self.num_sample``
     standard-normal latent vectors of shape (1, self.latent_dim). Each output
     is transposed from BCHW to BHWC, squeezed, min-max scaled to [0, 255] as
     uint8, and written to ``self.save_dir`` as ``image_at_<epoch:08d>_<i>.png``.

     Args:
         state: Unused by this method.
     """
     epoch = self.system.epoch_idx
     if epoch not in self.epoch_model_map:
         return
     model = self.epoch_model_map[epoch]
     # Only a DataParallel wrapper exposes ``.module``. Checking the wrapper type
     # (rather than the GPU count) avoids an AttributeError for unwrapped models.
     if isinstance(model, torch.nn.DataParallel):
         model = model.module
     device = next(model.parameters()).device
     # Pure inference: no autograd graph is needed.
     with torch.no_grad():
         for i in range(self.num_sample):
             random_vectors = torch.normal(mean=0.0, std=1.0,
                                           size=(1, self.latent_dim)).to(device)
             pred = feed_forward(model, random_vectors, training=False).to("cpu")
             disp_img = np.transpose(pred.detach().numpy(), (0, 2, 3, 1))  # BCHW -> BHWC
             disp_img = np.squeeze(disp_img)
             # Min-max scale to [0, 1]; self.eps (presumably a small positive
             # constant) keeps a constant image from dividing by zero.
             disp_img -= disp_img.min()
             disp_img /= (disp_img.max() + self.eps)
             disp_img = np.uint8(disp_img * 255)
             # Format only the file name: formatting the joined path would also
             # interpret any '{' or '}' characters that appear in save_dir.
             filename = 'image_at_{:08d}_{}.png'.format(epoch, i)
             cv2.imwrite(os.path.join(self.save_dir, filename), disp_img)
     print("on epoch {}, saving image to {}".format(epoch, self.save_dir))
コード例 #4
0
 def on_epoch_end(self, state):
     """Save generated sample images for epochs listed in ``self.epoch_model_map``.

     For such an epoch, the mapped TensorFlow model is fed ``self.num_sample``
     standard-normal latent vectors of shape [1, self.latent_dim]. Each output
     is squeezed, min-max scaled to [0, 255] as uint8, and written to
     ``self.save_dir`` as ``image_at_<epoch:08d>_<i>.png``.

     Args:
         state: Unused by this method.
     """
     epoch = self.system.epoch_idx
     if epoch not in self.epoch_model_map:
         return
     model = self.epoch_model_map[epoch]
     for i in range(self.num_sample):
         random_vectors = tf.random.normal([1, self.latent_dim])
         pred = feed_forward(model, random_vectors, training=False)
         # No transpose here: the output is presumably already channel-last
         # (BHWC), as cv2.imwrite expects -- confirm against the model.
         disp_img = np.squeeze(pred.numpy())
         # Min-max scale to [0, 1]; self.eps (presumably a small positive
         # constant) keeps a constant image from dividing by zero.
         disp_img -= disp_img.min()
         disp_img /= (disp_img.max() + self.eps)
         disp_img = np.uint8(disp_img * 255)
         # Format only the file name: formatting the joined path would also
         # interpret any '{' or '}' characters that appear in save_dir.
         filename = 'image_at_{:08d}_{}.png'.format(epoch, i)
         cv2.imwrite(os.path.join(self.save_dir, filename), disp_img)
     print("on epoch {}, saving image to {}".format(epoch, self.save_dir))