コード例 #1
0
    def __init__(self):
        """Load the two fixed noisy sample images kept on this object.

        Each image is read from ``utils.get_image_path(False, 64, <id>)``,
        upscaled by a factor of 2 with ``utils.scale_image`` (128x128 per the
        original comment), divided by 255 and reshaped to (128, 128, 1).
        Image 4003 is stored as ``noisy_img1`` and image 19983 as
        ``noisy_img2``.
        """

        def _load(image_id):
            # Same pipeline for both samples: read, upscale x2, rescale, add channel axis.
            img = utils.imread(utils.get_image_path(False, 64, image_id))
            img = utils.scale_image(img, 2.0)  # Image size 128x128
            img /= 255.0
            return img.reshape(128, 128, 1)

        self.noisy_img1 = _load(4003)
        self.noisy_img2 = _load(19983)
コード例 #2
0
    def step(self, x, y, log=False):
        """Run one optimisation step of the segmentation model.

        Scores ``self.model(x)`` against ``y`` with ``self.bce_loss``,
        back-propagates, steps ``self.optimizer`` and increments
        ``self.global_step``. When ``log`` is true, the loss, image grids of
        inputs / labels / predictions and a histogram are written to
        ``self.tfwriter``.

        Returns:
            The loss tensor for this batch.
        """
        prediction = self.model(x)
        batch_loss = self.bce_loss(prediction, y)

        optimizer = self.optimizer
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        self.global_step += 1

        if not log:
            return batch_loss

        writer, current = self.tfwriter, self.global_step
        writer.add_scalar("losses/loss", batch_loss, current)

        # Channels [2, 1, 0] reverse the first three input channels for
        # display (presumably BGR -> RGB; confirm against the data loader).
        input_grid = torchvision.utils.make_grid(x[:, [2, 1, 0]])
        label_grid = torchvision.utils.make_grid(y)
        prediction_grid = torchvision.utils.make_grid(prediction)

        writer.add_image('inputs', scale_image(input_grid), current)
        writer.add_image('labels', label_grid, current)
        writer.add_image('segmentation', prediction_grid, current)
        # NOTE(review): this histogram records the labels ``y`` under the
        # 'segmentation' tag, not the predictions; confirm this is intended.
        writer.add_histogram('segmentation', y, current)
        return batch_loss
コード例 #3
0
def run(args):
    """Sample one image from a pretrained ``Generator`` and display it.

    Loads the state dict at ``args.weights`` into a fresh ``Generator``,
    feeds a random latent of shape (1, 512, 1, 1) through it, converts the
    first output from NCHW to HWC, post-processes it with ``scale_image``
    and shows it with matplotlib (blocking).

    Args:
        args: namespace with a ``weights`` attribute (path to a state dict).
    """
    print('Loading Generator')
    model = Generator()
    # map_location='cpu' lets a checkpoint saved from a GPU load on a
    # CPU-only host; the model is moved to the GPU below when use_cuda is set.
    model.load_state_dict(torch.load(args.weights, map_location='cpu'))

    # Generate latent vector
    x = torch.randn(1, 512, 1, 1)

    if use_cuda:
        model = model.cuda()
        x = x.cuda()

    print('Executing forward pass')
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        images = model(x)

    images_np = images.cpu().numpy().transpose(0, 2, 3, 1)
    image_np = scale_image(images_np[0, ...])

    print('Output')
    plt.figure()
    plt.imshow(image_np)
    plt.show(block=True)
コード例 #4
0
def run(args, x, frame, out_dir="/home/ubuntu/prog_gans_pytorch_inference"):
    """Render one frame from latent ``x`` and save it as a PNG.

    Loads the ``Generator`` weights from ``args.weights`` and decodes ``x``.
    The caller creates ``x``; it was previously ``torch.randn(1, 512, 1, 1)``.
    The first image is post-processed with ``scale_image``, its channel axis
    is reversed for OpenCV, and it is written to ``<out_dir>/test<frame>.png``.

    Note: the generator is rebuilt and its weights re-read on every call.

    Args:
        args: namespace with a ``weights`` attribute (state-dict path).
        x: latent tensor fed to the generator.
        frame: frame index used in the output file name.
        out_dir: output directory; defaults to the previously hard-coded path.
    """
    import os

    model = Generator()
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host.
    model.load_state_dict(torch.load(args.weights, map_location='cpu'))

    if use_cuda:
        model = model.cuda()
        x = x.cuda()

    # `Variable(..., volatile=True)` is a no-op on modern PyTorch;
    # torch.no_grad() is the supported way to skip graph construction.
    with torch.no_grad():
        images = model(x)

    images_np = images.cpu().numpy().transpose(0, 2, 3, 1)
    image_np = scale_image(images_np[0, ...])

    # OpenCV writes BGR, so reverse the channel (last) axis.
    cv2image_np = image_np[..., ::-1]
    out_path = os.path.join(out_dir, "test" + str(frame) + ".png")
    # cv2.imwrite reports failure (e.g. missing directory) by returning False.
    if not cv2.imwrite(out_path, cv2image_np):
        print("Warning: could not write " + out_path)
コード例 #5
0
    def load_map_info(self):
        """Load this map's image and its companion info text.

        Opens ``<path>/<name>.png``, rescales it with ``utils.scale_image``
        using ``self.scale``, and reads ``<path>/<name>_info.txt`` as text.
        Then stores the scaled PIL image, its pygame conversion and the info
        text on ``self``.
        """
        base = self.path + "/" + self.name
        raw = Image.open(base + ".png")
        scaled = utils.scale_image(self.scale, raw)
        with open(base + "_info.txt") as info_file:
            info_text = info_file.read()

        self.image_png = scaled
        self.image_pygame = utils.image_to_pygame(scaled)
        self.map_info = info_text
コード例 #6
0
    def dis_update(self, x_dict, log=False):
        """Take one optimiser step on the discriminators.

        For every domain ``name`` in ``x_dict``, ``dis.models[name]`` scores
        the real batch ``x_dict[name]`` against fakes made by decoding the
        latent (``z + noise``) of every domain in ``x_dict`` into ``name``,
        including ``name`` itself. The mean of all these losses, weighted by
        ``self.params['gan_w']``, is back-propagated and ``optimizer_dis`` is
        stepped.

        Args:
            x_dict: mapping from domain name to a batch tensor.
            log: when true, write the unweighted loss and, if there are at
                least two domains, per-channel cross-reconstruction images to
                ``self.tfwriter``.

        Returns:
            The weighted discriminator loss tensor.
        """
        if self.distribute:
            gen = self.gen.module
            dis = self.dis.module
        else:
            gen = self.gen
            dis = self.dis

        if gen.skip_dim:
            x_skip = {name: x[:, gen.skip_dim] for name, x in x_dict.items()}
        else:
            x_skip = {name: None for name in x_dict}

        # Generator outputs are detached before reaching a discriminator, so
        # recording an autograd graph for them would only waste memory/time.
        with torch.no_grad():
            z_dict = {name: gen.encode(x, name) for name, x in x_dict.items()}
        keys = list(x_dict.keys())

        # For each discriminator, show it examples of true examples plus fakes
        # decoded into its domain from every domain's latent.
        dis_losses = []
        for name in x_dict:
            with torch.no_grad():
                fakes = [
                    gen.decode(z_dict[kk][0] + z_dict[kk][1],
                               name,
                               skip_x=x_skip[kk])
                    for kk in z_dict
                ]
            dis_losses.extend(
                dis.models[name].calc_dis_loss(fake.detach(), x_dict[name])
                for fake in fakes)

        loss_dis = torch.mean(torch.stack(dis_losses))
        loss = loss_dis * self.params['gan_w']

        self.optimizer_dis.zero_grad()
        loss.backward()
        self.optimizer_dis.step()

        if log:
            tfwriter = self.tfwriter
            tfwriter.add_scalar('losses/dis/total', loss_dis, self.global_step)

            # Cross sample: first domain's latent (without noise) decoded into
            # the second domain; needs at least two domains.
            if len(keys) > 1:
                with torch.no_grad():
                    x_12 = gen.decode(z_dict[keys[0]][0],
                                      keys[1],
                                      skip_x=x_skip[keys[0]])
                for c in range(x_12.shape[1]):
                    tfwriter.add_image(
                        f"images/cross_reconstruction/channel_{c}",
                        utils.scale_image(x_12[0, c:c + 1]),
                        self.global_step)

        return loss
コード例 #7
0
    def step(self, x, y, log=False):
        """Run one optimisation step of the regression model.

        Computes ``self.mse_loss(y, self.model(x))``, back-propagates, steps
        ``self.optimizer`` and increments ``self.global_step``. When ``log``
        is true, the loss and image grids of inputs, labels and predictions
        are written to ``self.tfwriter``.

        Args:
            x: input batch. The logging path indexes channels 2, 1, 0, so it
                needs at least three channels when ``log`` is true.
            y: regression targets.
            log: whether to write TensorBoard summaries for this step.

        Returns:
            The loss tensor.

        Raises:
            SystemExit: if the loss is NaN. Diagnostics are printed first and
                no optimiser step is taken.
        """
        y_reg = self.model(x)
        loss = self.mse_loss(y, y_reg)

        # NaN is the only value that is not equal to itself.
        if loss != loss:
            # Only tensors that exist in this method are printed
            # (y_prob / mask / sq_err are not defined here).
            print(f"Loss Regression: {loss.item()}")
            print('x', x[0, 0, :, :].cpu().detach().numpy())
            print('y', y[0, 0, :, :].cpu().detach().numpy())
            print('regression', y_reg[0, 0, :, :].cpu().detach().numpy())
            sys.exit()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.global_step += 1

        if log:
            step = self.global_step
            tfwriter = self.tfwriter
            tfwriter.add_scalar("losses/loss", loss, step)
            # create grid of images (first three input channels reversed)
            x_grid = torchvision.utils.make_grid(x[:, [2, 1, 0]], nrow=2)
            y_grid = torchvision.utils.make_grid(y[:, [0]], nrow=2)
            y_reg_grid = torchvision.utils.make_grid(y_reg[:, [0]], nrow=2)

            # write to tensorboard
            tfwriter.add_image('inputs', scale_image(x_grid), step)
            tfwriter.add_image('label', scale_image(y_grid), step)
            tfwriter.add_image('regression', scale_image(y_reg_grid), step)

        return loss
コード例 #8
0
    def _update_scaled_image(self):
        """Rebuild ``self.qtImage`` from ``self.cvImage`` at ``self.scale``.

        Clears ``self.qtImage`` when there is no source image. Otherwise:
        - the image is rescaled with ``scale_image``;
        - single-channel input is expanded to three channels;
        - the BGR data is converted to RGB and wrapped in an RGB888 ``QImage``;
        - the widget is resized to the scaled size.
        """
        if self.cvImage is None:
            self.qtImage = None
            return

        scaled_image = scale_image(self.cvImage, self.scale)
        if len(scaled_image.shape) == 2:
            scaled_image = cv2.cvtColor(scaled_image, cv2.COLOR_GRAY2BGR)

        height, width, channels = scaled_image.shape
        bytes_per_line = channels * width
        # Convert out of place. scale_image may return (a view of) the
        # source buffer, e.g. at scale 1 (TODO confirm). An in-place
        # conversion would then silently turn self.cvImage into RGB.
        rgb_image = cv2.cvtColor(scaled_image, cv2.COLOR_BGR2RGB)

        # QImage wraps the buffer without copying it. .copy() gives the QImage
        # its own pixels, so it stays valid after rgb_image is released.
        self.qtImage = QtGui.QImage(rgb_image, width, height, bytes_per_line,
                                    QtGui.QImage.Format_RGB888).copy()
        self.resize(width, height)
コード例 #9
0
def run(model, path, num_images=50):
    """Generate random samples with ``model`` and save them as JPEGs.

    For each of ``num_images`` iterations, a latent of shape (1, 512, 1, 1)
    is drawn, decoded, converted to HWC, post-processed with ``scale_image``
    and saved to ``<path>/_gen<i>.jpg``.

    Args:
        model: generator module; moved to the GPU when ``use_cuda`` is set.
        path: output directory.
        num_images: number of samples to write (default 50, as before).
    """
    # Moving the model once is enough (it was previously repeated every iteration).
    if use_cuda:
        model = model.cuda()

    # `volatile=True` is a no-op on modern PyTorch; no_grad skips graph building.
    with torch.no_grad():
        for i in range(num_images):
            # Generate latent vector
            x = torch.randn(1, 512, 1, 1)
            if use_cuda:
                x = x.cuda()
            image = model(x).cpu()

            image_np = image.numpy().transpose(0, 2, 3, 1)
            image_np = scale_image(image_np[0, ...])
            fname = os.path.join(path, '_gen{}.jpg'.format(i))
            Image.fromarray(image_np).save(fname)
            print("{}th image".format(i))
コード例 #10
0
    def test(self):
        """Denoise the available test images with the best model and save them.

        Loads the best model and ensures ``./xray_images/test_images_128x128``
        exists. For each id in 1..3999 whose input path
        ``utils.get_image_path(True, 64, i)`` exists, the image goes through
        the following steps:
        - upscaled x2 to 128x128, divided by 255, and reshaped to (128, 128, 1);
        - passed through ``self.model.predict``;
        - the prediction is clipped to [0, 1], scaled to 0-255 and saved as
          8-bit to ``utils.get_image_path(True, 128, i)``.
        """
        self.load_best_model()
        saved_dir = os.path.join(os.getcwd(), "xray_images",
                                 "test_images_128x128")
        # makedirs also creates a missing "xray_images" parent (os.mkdir would fail).
        if not os.path.exists(saved_dir):
            os.makedirs(saved_dir)
        for i in range(1, 4000):
            noisy_path = utils.get_image_path(True, 64, i)
            if not os.path.exists(noisy_path):
                continue
            test_noisy_128 = utils.imread(noisy_path)
            test_noisy_128 = utils.scale_image(test_noisy_128,
                                               2.0)  # Image size 128x128
            test_noisy_128 /= 255.0
            test_noisy_128 = test_noisy_128.reshape(128, 128, 1)

            prediction = self.model.predict(np.array([test_noisy_128]))[0]
            # Clip before the uint8 cast. Out-of-range floats would otherwise
            # wrap (e.g. 1.01 * 255 -> 1) instead of saturating at 0/255.
            prediction = np.clip(prediction, 0.0, 1.0) * 255
            prediction = prediction.astype('uint8').reshape((128, 128))
            predicted_img = Image.fromarray(prediction)
            predicted_img.save(utils.get_image_path(True, 128, i))
コード例 #11
0
    def stage_PNet(self, model, img):
        """Run the P-Net stage over an image pyramid and return refined boxes.

        For every scale from ``self.get_image_pyramid_scales``:
        - the image is rescaled and normalised, then run through ``model``;
        - the classification map is turned into scored boxes with
          ``generate_bboxes_with_scores`` (threshold 0.7);
        - boxes are filtered with ``non_maximum_suppression(..., 0.5, 'union')``.
        The survivors of all scales are pooled, filtered again at 0.7, and
        refined with their regression offsets via ``self.refine_bboxes``.

        Args:
            model: network whose ``predict`` returns three outputs; the first
                two are used as the classification and regression maps.
            img: image array of shape (h, w, channels).

        Returns:
            The result of ``self.refine_bboxes`` for the surviving boxes.
        """
        h, w, _ = img.shape
        img_size = (w, h)

        scales = self.get_image_pyramid_scales(self.min_face_size, img_size)

        # Collect per-scale results and concatenate once; repeated np.append
        # copies the whole accumulator each time. The empty seeds fix the
        # column counts (5 per box, 4 per offset) even when nothing is found.
        box_parts = [np.empty((0, 5))]
        offset_parts = [np.empty((0, 4))]

        for scale in scales:
            resized = utils.scale_image(img, scale)
            normalized = utils.normalize_image(resized)
            net_input = np.expand_dims(normalized, 0)

            cls_map, reg_map, _ = model.predict(net_input)
            # Keep score channel 1 (presumably the "face" class; confirm).
            cls_map = cls_map.squeeze()[:, :, 1]
            reg_map = reg_map.squeeze()

            boxes, indices = self.generate_bboxes_with_scores(cls_map,
                                                              scale,
                                                              threshold=0.7)
            reg_deltas = reg_map[indices]

            keep = self.non_maximum_suppression(boxes, 0.5, 'union')
            box_parts.append(boxes[keep])
            offset_parts.append(reg_deltas[keep])

        boxes_tot = np.concatenate(box_parts, axis=0)
        reg_offsets = np.concatenate(offset_parts, axis=0)

        keep = self.non_maximum_suppression(boxes_tot, 0.7, 'union')
        boxes_tot = boxes_tot[keep]
        reg_offsets = reg_offsets[keep]

        # refine bounding boxes
        return self.refine_bboxes(boxes_tot, reg_offsets)
コード例 #12
0
STYLE_WEIGHT = 1e10
TV_WEIGHT = 0.0001


def _resize_and_tensorize():
    """Return a new pipeline that resizes to IMAGE_SIZE and converts to a tensor."""
    return transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
    ])


# Content and style currently share identical preprocessing. They are kept as
# separate objects so either pipeline can be changed independently.
transform = _resize_and_tensorize()
style_transform = _resize_and_tensorize()

# VGG16 with requires_grad=False, moved to the GPU.
vgg = Vgg16(requires_grad=False).cuda()

# Preview the style image as produced by utils.scale_image
# (plt.show() typically blocks until the window is closed).
I = np.array(utils.scale_image(filename=STYLE_IMG_PATH, size=128, scale=512))
plt.imshow(I)
plt.show()

style_img = utils.load_image(filename=STYLE_IMG_PATH, size=IMAGE_SIZE)
content_img = utils.load_image(filename=CONTENT_IMG_PATH, size=IMAGE_SIZE)

style_img, content_img = style_transform(style_img), transform(content_img)

# repeat(BATCH_SIZE, 1, 1, 1) on a (C, H, W) tensor adds a leading batch
# dimension of size BATCH_SIZE (a "fake batch"); then move to the GPU.
style_img = style_img.repeat(BATCH_SIZE, 1, 1, 1).cuda()
content_img = content_img.repeat(BATCH_SIZE, 1, 1, 1).cuda()
コード例 #13
0
    def gen_update(self, x_dict, log=False):
        """Take one optimiser step on the generator (encoders/decoders).

        Encodes every domain in ``x_dict`` and builds these loss terms:
          * ``loss_recon``: MSE between each input and its own reconstruction.
          * ``loss_gan``: ``calc_gen_loss`` of each cross-domain decode, scored
            by the discriminator of the target domain.
          * ``loss_cycle_recon`` / ``loss_cycle_z_recon`` / ``loss_cycle_z``:
            decode -> encode -> decode cycle terms. The cycle is run
            ``1 + global_step // cycle_step`` times and the means are
            multiplied by that count.
          * ``loss_shared_recon``: MSE on the channel subsets listed in
            ``self.shared``, only when that attribute exists.
        The terms are weighted by ``self.params`` and summed. The sum is
        back-propagated, then ``optimizer_gen`` and ``scheduler`` are stepped.

        Args:
            x_dict: mapping from domain name to a batch tensor.
            log: when true, write scalar losses, first-channel input and
                reconstruction images, and per-channel histograms to
                ``self.tfwriter``.

        Returns:
            The total weighted generator loss tensor.
        """
        keys = list(x_dict.keys())

        if self.distribute:
            gen = self.gen.module
            dis = self.dis.module
        else:
            gen = self.gen
            dis = self.dis

        # gen.encode returns a (z, noise) pair per domain (unpacked below).
        z_dict = {name: gen.encode(x, name) for name, x in x_dict.items()}

        if gen.skip_dim:
            x_skip = {name: x[:, gen.skip_dim] for name, x in x_dict.items()}
        else:
            x_skip = {name: None for name in x_dict.keys()}

        # Within-domain reconstruction from z + noise.
        x_recon = {
            name: gen.decode(z + noise, name, skip_x=x_skip[name])
            for name, (z, noise) in z_dict.items()
        }

        key0 = keys[0]

        # NOTE(review): h_hat / w_hat are never used below.
        _, _, h_hat, w_hat = x_recon[key0].shape

        # decode all cross pairs
        # NOTE(review): apart from loss_shared_recon, these initial values are
        # overwritten after the loop and are effectively dead.
        loss_gan = 0.
        data_losses = dict()
        loss_shared_recon = 0.
        loss_cycle_z = 0.
        loss_cycle_recon = 0.
        loss_cycle_z_recon = 0.
        # Number of cycle iterations grows by one every self.cycle_step steps.
        cycles = 1 + self.global_step // self.cycle_step

        gan_losses = []
        cycle_recon_losses = []
        shared_losses = []
        cycle_zkl_losses = []
        cycle_zrecon_losses = []

        for k in keys:
            x_k = x_dict[k]
            z_k, n_k = z_dict[k]

            # Every decoder domain other than k is a cross-translation target.
            cross_keys = list(gen.decoders.keys())  #copy.copy(keys)
            cross_keys.remove(k)

            cross_recon = {
                kk: gen.decode(z_k + n_k, kk, skip_x=x_skip[k])
                for kk in cross_keys
            }

            # GAN Loss
            for kk in cross_keys:
                gan_losses.append(dis.models[kk].calc_gen_loss(
                    cross_recon[kk]))

            # Reconstruction loss
            data_losses[k] = self.mse_loss(x_k, x_recon[k])

            # Shared reconstruction loss
            # NOTE(review): `=` rebinds shared_losses for each k, so only the
            # last domain with a `shared` entry contributes; the other loss
            # lists use `+=`. Confirm whether accumulation was intended.
            if hasattr(self, 'shared') and (k in self.shared):
                shared_losses = [
                    self.mse_loss(x_k[:, self.shared[k][kk]],
                                  cross_recon[kk][:, self.shared[kk][k]])
                    for kk in cross_keys
                ]
            # NOTE(review): xk_1 is never read.
            xk_1 = x_k
            zk_1 = z_k

            for c in range(cycles):
                xk_cycle = {
                    kk: gen.decode(zk_1 + n_k, kk, skip_x=x_skip[k])
                    for kk in keys
                }  # decode each domain
                z_cycle = {
                    kk: gen.encode(_x, kk)
                    for kk, _x in xk_cycle.items()
                }  # encode back to z
                x2_cycle = {
                    kk: gen.decode(z_cycle[kk][0] + n_k, k, skip_x=x_skip[k])
                    for kk in keys
                }  # decode back to cross domains

                cycle_recon_losses += [
                    self.mse_loss(x_k, x2_cycle[kk]) for kk in keys
                ]
                # NOTE(review): `=` (not `+=`) keeps only the last domain's
                # last cycle; the result is still multiplied by `cycles` below.
                cycle_zrecon_losses = [
                    self.mse_loss(z_k, z_cycle[kk][0]) for kk in keys
                ]
                # NOTE(review): image/histogram logging is commented out, so
                # this branch only assigns an unused local.
                if log and (len(cross_keys) > 0):
                    kshow = cross_keys[0]
                    #self.tfwriter.add_image(f"images/cross_{c}_recon_{kshow}", utils.scale_image(x2_cycle[kshow][0,:1]), self.global_step)
                    #self.tfwriter.add_histogram(f"channel0/cross_reconstruct_{kshow}", x2_cycle[kshow][0,0], self.global_step)

                # NOTE(review): also rebound each cycle/domain (see above).
                cycle_zkl_losses = [
                    self._compute_kl(z2) for name, (z2, _) in z_cycle.items()
                ]

                # NOTE(review): zk_1 is only re-encoded when c > 0, so cycles
                # 0 and 1 decode the same latent. Confirm the intended schedule.
                if c > 0:
                    x1_k = x2_cycle[k]
                    zk_1, _ = gen.encode(x1_k, k)
                    #z1_cycle = {name: self.gen.encode(x, name) for name, x in x2_cycle.items()}

        loss_recon = torch.mean(
            torch.stack([x for _, x in data_losses.items()]))
        loss_gan = torch.mean(torch.stack(gan_losses))
        loss_cycle_recon = cycles * torch.mean(torch.stack(cycle_recon_losses))
        # loss_shared_recon stays the float 0. when no domain has shared channels.
        if len(shared_losses) > 0:
            loss_shared_recon += torch.mean(torch.stack(shared_losses))
        loss_cycle_z = cycles * torch.mean(torch.stack(cycle_zkl_losses))
        loss_cycle_z_recon = cycles * torch.mean(
            torch.stack(cycle_zrecon_losses))

        loss = self.params['gan_w'] * loss_gan + \
               self.params['recon_x_w'] * loss_recon + \
               self.params['recon_x_cyc_w'] * loss_cycle_recon + \
               self.params['recon_z_cyc_w'] * loss_cycle_z_recon + \
               self.params['recon_kl_cyc_w'] * loss_cycle_z + \
               self.params['recon_shared_w'] * loss_shared_recon

        self.optimizer_gen.zero_grad()
        loss.backward()
        self.optimizer_gen.step()
        # The LR scheduler is stepped on every generator update.
        self.scheduler.step()

        step = self.global_step

        if log:
            tfwriter = self.tfwriter
            #print(f"Gen: Step {self.global_step} -- Loss={loss.item()}")
            tfwriter.add_scalar("losses/gen/total", loss, step)
            tfwriter.add_scalar("losses/gen/recon", loss_recon, step)
            tfwriter.add_scalar("losses/gen/gan", loss_gan, step)
            tfwriter.add_scalar("losses/gen/cycle_recon", loss_cycle_recon,
                                step)
            tfwriter.add_scalar("losses/gen/shared_recon", loss_shared_recon,
                                step)
            # NOTE(review): loss_cycle_z (KL term) is included in the loss but not logged.
            #tfwriter.add_scalar("losses/gen/cycle_kl", loss_cc_z, step)
            tfwriter.add_scalar("losses/gen/cycle_z_recon", loss_cycle_z_recon,
                                step)

            for name in x_dict:
                x = x_dict[name]
                self.tfwriter.add_image(f"images/{name}/input",
                                        utils.scale_image(x[0, :1]), step)
                self.tfwriter.add_image(
                    f"images/{name}/reconstruction",
                    utils.scale_image(x_recon[name][0, :1]), step)
                self.tfwriter.add_scalar(f"losses/data/{name}",
                                         data_losses[name], step)
                for j in range(x.shape[1]):
                    self.tfwriter.add_histogram(
                        f"channel{j}/Observed_Domain{name}", x[:, j], step)
                    self.tfwriter.add_histogram(
                        f"channel{j}/reconstructed_Domain{name}",
                        x_recon[name][:, j], step)
        return loss
コード例 #14
0
 def load_sprites(self):
     """Load ``<path>/<name>.png`` and store its scaled and pygame forms.

     The scale factor is read from ``res["scale"]``. ``res`` is not defined
     in this method; presumably it is a module-level resource dict (TODO confirm).
     """
     sprite_file = self.path + "/" + self.name + ".png"
     self.sprite_png = Image.open(sprite_file)
     # Replace the raw image with its scaled version, then convert for pygame.
     self.sprite_png = utils.scale_image(res["scale"], self.sprite_png)
     self.sprite_pygame = utils.image_to_pygame(self.sprite_png)
コード例 #15
0
    def __call__(self):
        """Build the graph, restore the latest checkpoint if any, and train.

        Uses seed 9, batch size 5, the first 1400 images and a 70% train
        split. Each batch runs ``train_d``, ``train_g0`` and ``train_g1`` in
        that order, and losses are printed every 10 batches. Every 10th epoch
        (including epoch 0), the model is saved and ``self.all_images`` is
        evaluated on that epoch's last batch and written to ``images/`` as PNGs.
        """
        seed = 9
        self.batchsize = batch_size = 5
        start_images = 0
        number_of_images = 1400
        train_ratio = 0.7

        self.reset_graph(seed)

        self.build_model(seed, batch_size)

        # Set up a saver to load and save the model, and dump the graph for
        # TensorBoard. The writer is used for nothing else, so close it now:
        # closing flushes the event file (it was previously never closed).
        self.saver = tf.train.Saver(max_to_keep=5)
        writer = tf.summary.FileWriter("graphs", tf.get_default_graph())
        writer.close()

        # Always try to restore; the checkpoint counter is not needed here.
        could_load, _ = self.load()
        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # Load the data (both train and test)
        data_neg, data_pos = get_data(start_images, number_of_images,
                                      train_ratio, seed)

        n_epochs = 51
        # Same float expression as before. NOTE: 1400 * 0.7 is slightly below
        # 980 in binary floating point, so this yields 195 (not 196) batches.
        n_batches = int(
            np.floor(float(number_of_images * train_ratio) / batch_size))

        for epoch in range(n_epochs):
            print(f'epoch{epoch}')
            for i in range(n_batches):
                # Get the current batch
                rows = slice(i * batch_size, i * batch_size + batch_size)
                batch_pos = data_pos['train_data'][rows]
                batch_neg = data_neg['train_data'][rows]
                feed = {self.X_neg: batch_neg, self.X_pos: batch_pos}

                # Update the weights: discriminator, then both generators
                self.sess.run(self.train_d, feed_dict=feed)
                self.sess.run(self.train_g0, feed_dict=feed)
                self.sess.run(self.train_g1, feed_dict=feed)

                if i % 10 == 0:
                    [l_pix0, l_pix1, l_per0, l_per1, l_gan0, l_gan1, l_dual0,
                     l_dual1, l_D, l_g0, l_g1] = self.sess.run(
                         [self.l_pix0, self.l_pix1, self.l_per0, self.l_per1,
                          self.l_gan0, self.l_gan1, self.l_dual0, self.l_dual1,
                          self.l_cls, self.l_g0, self.l_g1],
                         feed_dict=feed)
                    print(
                        "epoch: %d batch: %d l_pix0: %g l_pix1: %g l_per0: %g l_per1: %g l_gan0: %g l_gan1: %g l_dual0: %g l_dual1: %g l_g0: %g l_g1: %g l_D: %g"
                        % (epoch, i, l_pix0, l_pix1, l_per0, l_per1, l_gan0,
                           l_gan1, l_dual0, l_dual1, l_g0, l_g1, l_D))

            if epoch % 10 == 0:
                self.save(epoch)
                # Sample images from the last batch of this epoch.
                all_images = self.sess.run(self.all_images, feed_dict=feed)

                for key, img_collection in all_images.items():
                    for j, img in enumerate(img_collection):
                        im = Image.fromarray(scale_image(img).astype('uint8'))
                        im.save(f'images/{key}-{str(j)}-epoch{str(epoch)}.png')
                        im.close()
コード例 #16
0
ファイル: test.py プロジェクト: syjang/mtcnn_study
import model
import mtcnn
import utils
import numpy as np
import cv2

if __name__ == "__main__":
    # Smoke test: run P-Net once on a single WIDER image at scale 1.
    net = model.make_pnet()

    img_path = 'dataset/wider_images/0--Parade/0_Parade_marchingband_1_6.jpg'
    img = cv2.imread(img_path)
    # cv2.imread returns None (rather than raising) for a missing or unreadable
    # file; stop with a clear message instead of failing later in scale_image.
    if img is None:
        raise SystemExit('could not read image: ' + img_path)
    scale = 1
    resized = utils.scale_image(img, scale)
    normalized = utils.normalize_image(resized)
    net_input = np.expand_dims(normalized, 0)

    print(net_input.shape)

    outputs = net.predict(net_input)
コード例 #17
0
    def step(self, x, y, mask, log=False):
        """Run one optimisation step of the masked heteroscedastic regressor.

        ``self.model(x, train=True)`` returns ``(y_hat, logvar, y_prob,
        reg_losses)``. The loss is
        ``-(logprob_classifier + cond_logprob) + 1e-2 * reg_losses``, where:
          * ``cond_logprob = -mean(exp(-logvar) * (y - y_hat)**2 + logvar)``
            over positions with ``mask == 0``. This is a Gaussian negative
            log-likelihood up to a factor of 2 and an additive constant.
          * ``logprob_classifier`` is the mean Bernoulli log-likelihood of
            ``mask`` under ``y_prob``.

        Args:
            x: input batch.
            y: regression targets. Positions with ``mask == 1`` are treated
                as zero; the caller's tensor is not modified.
            mask: 0/1 tensor marking masked positions (assumed same shape as
                ``y``; TODO confirm).
            log: whether to write TensorBoard summaries.

        Returns:
            The total loss tensor.
        """
        # Zero masked targets on a copy so the caller's ``y`` is left untouched.
        y = y.masked_fill(mask == 1, 0.)
        eps = 1e-7
        y_hat, logvar, y_prob, reg_losses = self.model(x, train=True)

        # Regression term, evaluated only where the target is observed.
        observed = mask == 0
        y_cond = torch.masked_select(y, observed)
        y_hat_cond = torch.masked_select(y_hat, observed)
        y_logvar_cond = torch.masked_select(logvar, observed)
        y_precision_cond = torch.exp(-y_logvar_cond) + eps

        cond_logprob = -torch.mean(
            y_precision_cond * (y_cond - y_hat_cond)**2 + y_logvar_cond)

        # Bernoulli log-likelihood of the mask under the predicted probabilities.
        logprob_classifier = torch.mean(
            mask * torch.log(y_prob + eps) +
            (1 - mask) * torch.log(1 - y_prob + eps))
        logprob = logprob_classifier + cond_logprob

        loss = -logprob + 1e-2 * reg_losses

        # NaN is the only value that is not equal to itself.
        if loss != loss:
            print(
                f"Loss binary: {-logprob_classifier.item()}, Loss Regression: {-cond_logprob.item()}"
            )
            print('x', x[0, 0, :, :].cpu().detach().numpy())
            print('y', y[0, 0, :, :].cpu().detach().numpy())
            print('regression', y_hat[0, 0, :, :].cpu().detach().numpy())
            print('probs', y_prob[0, 0, :, :].cpu().detach().numpy())
            print('mask', mask[0, 0, :, :].cpu().detach().numpy())
            print('sq_err',
                  ((y - y_hat)**2)[0, 0, :, :].cpu().detach().numpy())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.global_step += 1

        if log:
            step = self.global_step
            tfwriter = self.tfwriter
            tfwriter.add_scalar("losses/binary", -logprob_classifier, step)
            tfwriter.add_scalar("losses/regression", -cond_logprob, step)
            # reg_losses enters the loss with a + sign, so it is logged as-is
            # to match the other (positive) loss curves.
            tfwriter.add_scalar("losses/regularizer", reg_losses, step)
            tfwriter.add_scalar("losses/loss", loss, step)

            # Show predictions only at observed positions. This is done out of
            # place on a detached copy so the autograd-tracked y_hat is untouched.
            y_hat_shown = y_hat.detach() * (1 - mask)
            # create grid of images (first three channels reversed for display)
            x_grid = torchvision.utils.make_grid(x[:8, [2, 1, 0]])
            y_grid = torchvision.utils.make_grid(y[:8, [2, 1, 0]])
            mask_grid = torchvision.utils.make_grid(mask[:8, :1])

            seg_grid = torchvision.utils.make_grid(y_prob[:8])
            y_reg_grid = torchvision.utils.make_grid(y_hat_shown[:8, [2, 1, 0]])

            # write to tensorboard
            tfwriter.add_image('inputs', scale_image(x_grid), step)
            tfwriter.add_image('label', scale_image(y_grid), step)
            tfwriter.add_image('regression', scale_image(y_reg_grid), step)

            tfwriter.add_image('mask', mask_grid, step)
            tfwriter.add_image('segmentation', seg_grid, step)
            tfwriter.add_histogram('segmentation', y, step)
            tfwriter.add_histogram('cond_observed', y_cond, step)
            tfwriter.add_histogram('cond_regression', y_hat_cond, step)

        return loss