def conditional_save_img_comparison_with_intensity(self, mode, i, ele, pred,
                                                   pred_intensity, epoch):
    """Accumulate, save and log a comparison grid that includes intensity.

    Only active when ``mode`` is ``'val'`` or ``'eval'``; otherwise a no-op.
    With ``skip = 100``:

    * batch 0 starts ``self.img_merge`` from
      ``vis_utils.merge_into_row_with_intensity(ele, pred, pred_intensity)``;
    * batches 100, 200, ..., 700 each append one such row via
      ``vis_utils.add_row`` (8 rows in total);
    * batch 800 saves the grid to ``self._get_img_comparison_name(mode,
      epoch)`` and logs a half-resolution copy to ``self.writer`` under the
      tag ``'comparison'`` at step ``epoch``.

    All other batch indices are ignored. If a pass has fewer than 801
    batches, nothing is saved or logged.

    Args:
        mode: Run mode string; only 'val' and 'eval' do anything.
        i: Batch index within the current pass.
        ele: Batch data, passed straight through to ``vis_utils``.
        pred: Prediction, passed straight through to ``vis_utils``.
        pred_intensity: Predicted intensity, passed through to ``vis_utils``.
        epoch: Current epoch; used for the output filename and as the
            TensorBoard global step.
    """
    if mode not in ('val', 'eval'):
        return
    skip = 100
    if i == 0:
        self.img_merge = vis_utils.merge_into_row_with_intensity(
            ele, pred, pred_intensity)
    elif i % skip == 0 and i < 8 * skip:
        row = vis_utils.merge_into_row_with_intensity(
            ele, pred, pred_intensity)
        self.img_merge = vis_utils.add_row(self.img_merge, row)
    elif i == 8 * skip:
        filename = self._get_img_comparison_name(mode, epoch)
        vis_utils.save_image(self.img_merge, filename)

        # The merged grid is HWC (it is transposed with (2, 0, 1) below).
        img_hwc = np.asarray(self.img_merge)
        # Nearest-neighbour 2x downsample of the two spatial axes only.
        # This replaces skimage.transform.rescale(float64_copy, 0.5, order=0).
        # On a 3-D array, that call also rescales the last (channel) axis
        # unless told it is a channel axis (the default since skimage 0.16).
        # Depending on the version, it can Gaussian-smooth float input before
        # downsampling even with order=0. It also handed TensorBoard float64
        # data still in the source value range. Slicing keeps the source
        # dtype, so TensorBoard's dtype-based scaling (uint8 -> [0, 255],
        # float -> [0, 1]) applies as intended.
        img_small = img_hwc[::2, ::2]
        # SummaryWriter.add_image expects CHW by default.
        img_chw = np.transpose(img_small, (2, 0, 1))
        # Log at step `epoch`: `i` is always 8 * skip in this branch, so
        # using it as the step stacked every epoch's image at one step.
        self.writer.add_image('comparison', img_chw, epoch)
def conditional_save_img_comparison(self, mode, i, ele, pred, epoch):
    """Build an 8-row comparison grid during validation and save it once.

    Does nothing unless ``mode`` is ``'val'`` or ``'eval'``. Batch 0 starts
    ``self.img_merge`` from ``vis_utils.merge_into_row(ele, pred)``. Batches
    100, 200, ..., 700 each append one more row via ``vis_utils.add_row``.
    At batch 800 the accumulated grid is written to the path returned by
    ``self._get_img_comparison_name(mode, epoch)``. Every other batch index
    is ignored.

    Args:
        mode: Run mode string; only 'val' and 'eval' do anything.
        i: Batch index within the current pass.
        ele: Batch data, passed straight through to ``vis_utils``.
        pred: Prediction, passed straight through to ``vis_utils``.
        epoch: Current epoch, used only to build the output filename.
    """
    if not (mode == 'val' or mode == 'eval'):
        return

    stride = 100
    final_batch = stride * 8

    if i == 0:
        # First sampled batch: start a fresh grid.
        self.img_merge = vis_utils.merge_into_row(ele, pred)
        return

    if i == final_batch:
        # Grid is complete (rows from batches 0, 100, ..., 700): write it out.
        out_path = self._get_img_comparison_name(mode, epoch)
        vis_utils.save_image(self.img_merge, out_path)
        return

    if i < final_batch and i % stride == 0:
        new_row = vis_utils.merge_into_row(ele, pred)
        self.img_merge = vis_utils.add_row(self.img_merge, new_row)