Пример #1
0
    def test_rectangular(self):
        """Check the averaged variance regularisation loss of a 4x6 heatmap."""
        heatmap_rows = [
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.1, 0.0, 0.0, 0.0],
            [0.0, 0.1, 0.6, 0.1, 0.0, 0.0],
            [0.0, 0.0, 0.1, 0.0, 0.0, 0.0],
        ]
        # Two leading dimensions of size one wrap the single heatmap.
        heatmaps = torch.Tensor([[heatmap_rows]])
        expected = 28.88

        per_location = variance_reg_losses(heatmaps, 2.0)
        mean_loss = average_loss(per_location)

        self.assertEqual(expected, mean_loss.item())
Пример #2
0
def test_variance_rectangular():
    """Check the averaged variance regularisation loss of a 4x6 heatmap.

    The comparison is approximate: the loss is computed in the tensor's
    floating-point precision, so it cannot be relied upon to equal the
    decimal literal 28.88 bit for bit.
    """
    # Local import keeps this test self-contained.
    import pytest

    t = torch.Tensor([[[
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.1, 0.0, 0.0, 0.0],
        [0.0, 0.1, 0.6, 0.1, 0.0, 0.0],
        [0.0, 0.0, 0.1, 0.0, 0.0, 0.0],
    ]]])

    actual = average_loss(variance_reg_losses(t, 2.0))
    assert actual.item() == pytest.approx(28.88)
Пример #3
0
    def test_exact(self):
        """Check the averaged variance regularisation loss of a 4x4 heatmap."""
        t = torch.Tensor([[[
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.1, 0.0],
            [0.0, 0.1, 0.6, 0.1],
            [0.0, 0.0, 0.1, 0.0],
        ]]])

        # Tensors are passed directly; the deprecated Variable wrapper is
        # no longer needed since the Tensor/Variable merge.
        actual = average_loss(variance_reg_losses(t, 2.0))
        # .item() reads the scalar; indexing a 0-dim tensor with [0] raises
        # on current PyTorch releases.
        self.assertEqual(28.88, actual.item())
Пример #4
0
def test_average_loss_mask():
    """Check that average_loss yields zero when the mask hides every loss."""
    per_location = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
    keep = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])

    result = average_loss(per_location, keep)

    # Each non-zero loss sits at a position where the mask is zero.
    assert float(result) == 0.0
def train_new_data_with_model(n_epochs=10):
    """Train a CoordRegression model with 8 locations and checkpoint it.

    For every batch the loss is the dsntnn average of the per-location
    euclidean losses plus the JS regularisation losses. After each epoch
    the whole model object is saved under ``models/`` and the list of
    batch losses for that epoch is printed.

    Args:
        n_epochs: number of passes over the data loader. Defaults to 10,
            the value that used to be hard-coded.
    """
    # Move the model to the GPU before constructing the optimizer, as the
    # torch.optim documentation recommends.
    model = torch.nn.DataParallel(CoordRegression(n_locations=8)).cuda()
    optimizer = optim.RMSprop(model.parameters(), lr=2.5e-4, alpha=0.9)

    # Training set
    from data_process_landmarks_hw import dataloader

    for epoch in range(n_epochs):
        print("Epoch: {}/{}".format(epoch + 1, n_epochs))
        train_loss_cord = []

        model.train()

        # Training
        for i_batch, (img, landmarks) in enumerate(dataloader):
            # as_tensor avoids re-copying (and the accompanying warning)
            # when the loader already yields tensors.
            img = torch.as_tensor(img, dtype=torch.float32).to(device)
            # Landmarks are divided by 64.0 before being used as targets.
            landmarks = torch.as_tensor(landmarks / 64.0,
                                        dtype=torch.float32).to(device)

            # forward pass
            coords, heatmaps = model(img)
            # per-location euclidean losses
            euc_losses = dsntnn.euclidean_losses(coords, landmarks)
            # per-location regularisation losses
            reg_losses = dsntnn.js_reg_losses(heatmaps, landmarks, sigma_t=1.0)
            # combine losses into an overall loss
            loss = dsntnn.average_loss(euc_losses + reg_losses)

            # calculate gradients
            optimizer.zero_grad()
            loss.backward()

            # update model parameters with RMSprop
            optimizer.step()

            # Store a plain float so the list does not keep the loss
            # tensors (and their autograd history) alive for the epoch.
            train_loss_cord.append(loss.item())

            if i_batch % 20 == 19:
                print(loss, euc_losses, reg_losses)

        torch.save(
            model,
            'models/' + 'landmarks' + '_model_new_data_8' + str(epoch) + '.pt')
        print(train_loss_cord)
Пример #6
0
    def test_rectangular(self):
        """Check the averaged KL regularisation loss of a 4x6 heatmap."""
        heatmap_rows = [
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.1],
            [0.0, 0.0, 0.0, 0.0, 0.1, 0.8],
        ]
        # Two leading dimensions of size one wrap the single heatmap.
        heatmaps = torch.Tensor([[heatmap_rows]])
        target_coords = torch.Tensor([[[1, 1]]])
        expected = 1.2646753877545842

        per_location = kl_reg_losses(heatmaps, target_coords, 2.0)
        mean_loss = average_loss(per_location)

        self.assertEqual(expected, mean_loss.item())
Пример #7
0
def test_kl_rectangular():
    """Check the averaged KL regularisation loss of a 4x6 heatmap.

    The comparison is approximate: the loss is computed in the tensor's
    floating-point precision, so it cannot be relied upon to match a
    17-digit decimal literal exactly.
    """
    # Local import keeps this test self-contained.
    import pytest

    t = torch.Tensor([[[
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.1],
        [0.0, 0.0, 0.0, 0.0, 0.1, 0.8],
    ]]])
    coords = torch.Tensor([[[1, 1]]])

    actual = average_loss(kl_reg_losses(t, coords, 2.0))

    assert actual.item() == pytest.approx(1.2646753877545842)
Пример #8
0
    def test_rectangular(self):
        """Check the averaged KL regularisation loss of a 4x6 heatmap."""
        t = torch.Tensor([[[
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.1],
            [0.0, 0.0, 0.0, 0.0, 0.1, 0.8],
        ]]])
        coords = torch.Tensor([[[1, 1]]])

        # Tensors are passed directly; the deprecated Variable wrapper is
        # no longer needed since the Tensor/Variable merge.
        actual = average_loss(kl_reg_losses(t, coords, 2.0))

        # .item() reads the scalar; indexing a 0-dim tensor with [0] raises
        # on current PyTorch releases.
        self.assertEqual(1.2646753877545842, actual.item())
def test_kl_rectangular():
    """Check the averaged KL regularisation loss of a 4x6 heatmap."""
    heatmap_rows = [
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.1],
        [0.0, 0.0, 0.0, 0.0, 0.1, 0.8],
    ]
    # Two leading dimensions of size one wrap the single heatmap.
    heatmaps = torch.tensor([[heatmap_rows]])
    target_coords = torch.tensor([[[1.0, 1.0]]])

    per_location = kl_reg_losses(heatmaps, target_coords, 2.0)
    mean_loss = average_loss(per_location)

    assert mean_loss.item() == pytest.approx(1.26467538775)
Пример #10
0
    def test_batch(self):
        """Check the averaged variance loss over a batch of two heatmaps."""
        first_sample = [
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.1, 0.0],
            [0.0, 0.1, 0.6, 0.1],
            [0.0, 0.0, 0.1, 0.0],
        ]
        second_sample = [
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.2, 0.0, 0.0],
            [0.1, 0.5, 0.1, 0.0],
            [0.0, 0.1, 0.0, 0.0],
        ]
        # Each sample holds a single 4x4 heatmap.
        heatmaps = torch.Tensor([[first_sample], [second_sample]])
        expected = 28.54205

        per_location = variance_reg_losses(heatmaps, 2.0)
        mean_loss = average_loss(per_location)

        self.assertEqual(expected, mean_loss.item())
Пример #11
0
def test_variance_batch():
    """Check the averaged variance loss over a batch of two heatmaps."""
    first_sample = [
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.1, 0.0],
        [0.0, 0.1, 0.6, 0.1],
        [0.0, 0.0, 0.1, 0.0],
    ]
    second_sample = [
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.2, 0.0, 0.0],
        [0.1, 0.5, 0.1, 0.0],
        [0.0, 0.1, 0.0, 0.0],
    ]
    # Each sample holds a single 4x4 heatmap.
    heatmaps = torch.Tensor([[first_sample], [second_sample]])

    per_location = variance_reg_losses(heatmaps, 2.0)
    mean_loss = average_loss(per_location)

    assert mean_loss.item() == approx(28.54205)
Пример #12
0
 def test_3d(self):
     """Check the averaged variance loss of a single 3x3x3 heatmap."""
     outer_slice = [
         [0.000035, 0.000002, 0.000000],
         [0.009165, 0.000570, 0.000002],
         [0.147403, 0.009165, 0.000035],
     ]
     middle_slice = [
         [0.000142, 0.000009, 0.000000],
         [0.036755, 0.002285, 0.000009],
         [0.591145, 0.036755, 0.000142],
     ]
     # The first and last slices of the volume are identical.
     t = torch.Tensor([[[outer_slice, middle_slice, outer_slice]]])

     # Tensors are passed directly; the deprecated Variable wrapper is
     # no longer needed since the Tensor/Variable merge.
     actual = average_loss(variance_reg_losses(t, 0.6))
     # .item() reads the scalar; indexing a 0-dim tensor with [0] raises
     # on current PyTorch releases.
     self.assertEqual(0.18564102213775013, actual.item())
Пример #13
0
def test_variance_3d():
    """Check the averaged variance loss of a single 3x3x3 heatmap."""
    outer_slice = [
        [0.000035, 0.000002, 0.000000],
        [0.009165, 0.000570, 0.000002],
        [0.147403, 0.009165, 0.000035],
    ]
    middle_slice = [
        [0.000142, 0.000009, 0.000000],
        [0.036755, 0.002285, 0.000009],
        [0.591145, 0.036755, 0.000142],
    ]
    # The first and last slices of the volume are identical.
    volume = torch.Tensor([[[outer_slice, middle_slice, outer_slice]]])

    per_location = variance_reg_losses(volume, 0.6)
    mean_loss = average_loss(per_location)

    assert mean_loss.item() == approx(0.18564102213775013)
Пример #14
0
def test_euclidean_mask():
    """Check that masked euclidean losses average to zero."""
    origin = [0.0, 0.0]
    ones = [1.0, 1.0]

    predicted = torch.tensor([
        [origin, ones, origin],
        [ones, origin, origin],
    ])
    wanted = torch.tensor([
        [origin, origin, origin],
        [origin, origin, origin],
    ])
    # The mask is zero exactly where prediction and target disagree.
    keep = torch.tensor([
        [1.0, 0.0, 1.0],
        [0.0, 1.0, 1.0],
    ])

    per_location = euclidean_losses(predicted, wanted)
    mean_loss = average_loss(per_location, keep)

    assert 0.0 == mean_loss.item()
Пример #15
0
def calc_coord_loss(coords, heatmaps, target_var, masks):
    """Return the masked average of euclidean plus JS regularisation losses.

    The regularisation term is evaluated separately for every index along
    dimension 1 of ``heatmaps`` / ``target_var`` and stacked back along
    that dimension before being added to the euclidean term.
    """
    # Per-location euclidean losses.
    # Original author's note: shape [B, D, L, 2] (batch, depth, locations,
    # feature) -- TODO confirm against the caller.
    euclidean = dsntnn.euclidean_losses(coords, target_var)

    # Per-location regularisation losses, one slice of dimension 1 at a time.
    regularisation = torch.stack(
        [
            dsntnn.js_reg_losses(heatmaps[:, depth],
                                 target_var[:, depth],
                                 sigma_t=1.0)
            for depth in range(heatmaps.shape[1])
        ],
        1,
    )

    # Combine both terms into one overall loss.
    combined = (euclidean + regularisation).squeeze()
    return dsntnn.average_loss(combined, mask=masks)
Пример #16
0
    def test_mask(self):
        """Check that masked euclidean losses average to zero."""
        origin = [0, 0]
        ones = [1, 1]

        predicted = torch.Tensor([
            [origin, ones, origin],
            [ones, origin, origin],
        ])
        wanted = torch.Tensor([
            [origin, origin, origin],
            [origin, origin, origin],
        ])
        # The mask is zero exactly where prediction and target disagree.
        keep = torch.Tensor([
            [1, 0, 1],
            [0, 1, 1],
        ])

        per_location = euclidean_losses(predicted, wanted)
        mean_loss = average_loss(per_location, keep)

        self.assertEqual(0.0, mean_loss.item())
Пример #17
0
def test_kl_mask():
    """Check the masked average KL loss when only the first location counts."""
    first_location = [
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.1],
        [0.0, 0.0, 0.1, 0.8],
    ]
    second_location = [
        [0.8, 0.1, 0.0, 0.0],
        [0.1, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
    ]
    heatmaps = torch.Tensor([[first_location, second_location]])
    target_coords = torch.Tensor([[[1, 1], [0, 0]]])
    # The mask keeps the first location and drops the second.
    keep = torch.Tensor([[1, 0]])

    per_location = kl_reg_losses(heatmaps, target_coords, 2.0)
    mean_loss = average_loss(per_location, keep)

    assert mean_loss.item() == approx(1.2228811717796824)
Пример #18
0
    def test_mask(self):
        """Check the masked average KL loss when only the first location counts."""
        first_location = [
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.1],
            [0.0, 0.0, 0.1, 0.8],
        ]
        second_location = [
            [0.8, 0.1, 0.0, 0.0],
            [0.1, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
        ]
        heatmaps = torch.Tensor([[first_location, second_location]])
        target_coords = torch.Tensor([[[1, 1], [0, 0]]])
        # The mask keeps the first location and drops the second.
        keep = torch.Tensor([[1, 0]])
        expected = 1.2228811717796824

        per_location = kl_reg_losses(heatmaps, target_coords, 2.0)
        mean_loss = average_loss(per_location, keep)

        self.assertEqual(expected, mean_loss.item())
Пример #19
0
    def test_mask(self):
        """Check the masked average KL loss when only the first location counts."""
        first_location = [
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.1],
            [0.0, 0.0, 0.1, 0.8],
        ]
        second_location = [
            [0.8, 0.1, 0.0, 0.0],
            [0.1, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
        ]
        t = torch.Tensor([[first_location, second_location]])
        coords = torch.Tensor([[[1, 1], [0, 0]]])
        # The mask keeps the first location and drops the second.
        mask = torch.Tensor([[1, 0]])

        # Tensors are passed directly; the deprecated Variable wrapper is
        # no longer needed since the Tensor/Variable merge.
        actual = average_loss(kl_reg_losses(t, coords, 2.0), mask)

        # .item() reads the scalar; indexing a 0-dim tensor with [0] raises
        # on current PyTorch releases.
        self.assertEqual(1.2228811717796824, actual.item())
Пример #20
0
    def forward(self, heatmaps, coords, targets, device):
        """Return the mean over the batch of the per-sample dsntnn loss.

        ``targets`` is rescaled with ``(targets / 255 * 2 + 1) / 31 - 1``
        before use. When it has a different number of dimensions than
        ``coords`` it is first repeated along a new dimension 1, once per
        entry of ``coords.shape[1]``.
        """
        out = torch.Tensor([31, 31]).to(device)
        batch_size, n_stages = coords.shape[0], coords.shape[1]

        # Give the targets the same rank as the predicted coordinates.
        if len(targets.shape) != len(coords.shape):
            targets = torch.unsqueeze(targets, 1).repeat(1, n_stages, 1, 1)

        targets = (targets.div(255) * 2 + 1) / out - 1

        def sample_loss(index):
            # Euclidean plus JS regularisation loss of one batch element.
            sample_targets = targets[index, :, :, :]
            euclidean = dsntnn.euclidean_losses(coords[index, :, :, :],
                                                sample_targets)
            regularisation = dsntnn.js_reg_losses(heatmaps[index, :, :, :, :],
                                                  targets[index, :, :, :],
                                                  sigma_t=1.0)
            return dsntnn.average_loss(euclidean + regularisation)

        per_sample = [sample_loss(index) for index in range(batch_size)]
        return sum(per_sample) / batch_size
Пример #21
0
    def test_mask(self):
        """Check that masked euclidean losses average to zero."""
        origin = [0, 0]
        ones = [1, 1]

        predicted = torch.Tensor([
            [origin, ones, origin],
            [ones, origin, origin],
        ])
        wanted = torch.Tensor([
            [origin, origin, origin],
            [origin, origin, origin],
        ])
        # The mask is zero exactly where prediction and target disagree.
        keep = torch.Tensor([
            [1, 0, 1],
            [0, 1, 1],
        ])
        zero = torch.Tensor([0])

        per_location = euclidean_losses(Variable(predicted), Variable(wanted))
        mean_loss = average_loss(per_location, Variable(keep))

        self.assertEqual(zero, mean_loss.data)
Пример #22
0
            def coord_and_heatmap_loss(_coords, _heatmaps, y):
                """Sum the coordinate and heatmap losses over all outputs.

                ``_coords`` / ``_heatmaps`` may be single outputs or lists
                of outputs; a single output is treated as a one-element
                list.
                """
                if isinstance(_coords, list):
                    all_coords, all_heatmaps = _coords, _heatmaps
                else:
                    all_coords, all_heatmaps = [_coords], [_heatmaps]

                per_output = []
                # for intermediate supervision: apply loss to all outputs
                for output_coords, output_heatmaps in zip(all_coords,
                                                          all_heatmaps):
                    criterion_values = [
                        criterion(output_coords, y)
                        for criterion in coord_criterions
                    ]
                    coords_loss = sum(criterion_values)
                    # TODO different sigmas in each HG?
                    heatmap_loss = dsntnn.average_loss(
                        dsntnn.js_reg_losses(
                            output_heatmaps,
                            y,
                            sigma_t=self.config["heatmap_sigma"]))
                    # TODO add lambda for regularization strength??
                    per_output.append(coords_loss + heatmap_loss)
                return sum(per_output)
Пример #23
0
def test_euclidean_forward_and_backward():
    """Check the euclidean loss value and its gradient for (3, 4) vs (0, 0)."""
    point = [3.0, 4.0]
    origin = [0.0, 0.0]
    point_grad = [0.15, 0.20]

    predictions = torch.tensor([[point, point], [point, point]])
    wanted = torch.tensor([[origin, origin], [origin, origin]])
    wanted_grad = torch.tensor([[point_grad, point_grad],
                                [point_grad, point_grad]])

    # Detached copy that records gradients.
    leaf = predictions.detach().requires_grad_(True)

    mean_loss = average_loss(euclidean_losses(leaf, wanted))
    mean_loss.backward()

    assert 5.0 == mean_loss.item()
    assert_allclose(wanted_grad, leaf.grad)
Пример #24
0
    def test_forward_and_backward(self):
        """Check the euclidean loss value and its gradient for (3, 4) vs (0, 0)."""
        point = [3, 4]
        origin = [0, 0]
        point_grad = [0.15, 0.20]

        predictions = torch.Tensor([[point, point], [point, point]])
        wanted = torch.Tensor([[origin, origin], [origin, origin]])
        wanted_grad = torch.Tensor([[point_grad, point_grad],
                                    [point_grad, point_grad]])

        # Detached copy that records gradients.
        leaf = predictions.detach().requires_grad_(True)

        mean_loss = average_loss(euclidean_losses(leaf, wanted))
        mean_loss.backward()

        self.assertEqual(5.0, mean_loss.item())
        self.assertEqual(wanted_grad, leaf.grad)
Пример #25
0
    def test_forward_and_backward(self):
        """Check the euclidean loss value and its gradient for (3, 4) vs (0, 0)."""
        point = [3, 4]
        origin = [0, 0]
        point_grad = [0.15, 0.20]

        predictions = torch.Tensor([[point, point], [point, point]])
        wanted = torch.Tensor([[origin, origin], [origin, origin]])
        wanted_loss = torch.Tensor([5])
        wanted_grad = torch.Tensor([[point_grad, point_grad],
                                    [point_grad, point_grad]])

        # Wrapped input that records gradients.
        leaf = Variable(predictions, requires_grad=True)

        mean_loss = average_loss(euclidean_losses(leaf, Variable(wanted)))
        mean_loss.backward()
        leaf_grad = leaf.grad.data

        self.assertEqual(wanted_loss, mean_loss.data)
        self.assertEqual(wanted_grad, leaf_grad)
Пример #26
0
        input_var = input_tensor.cuda()

        eye_coords_tensor = torch.Tensor([[label_all]])
        target_tensor = (eye_coords_tensor * 2 +
                         1) / torch.Tensor(image_size) - 1
        target_var = target_tensor.cuda()

        coords, heatmaps = model(input_var)

        # Per-location euclidean losses
        euc_losses = dsntnn.euclidean_losses(coords, target_var)
        # Per-location regularization losses
        reg_losses = dsntnn.js_reg_losses(heatmaps, target_var,
                                          sigma_t=1.0).cuda()
        # Combine losses into an overall loss
        loss = dsntnn.average_loss(euc_losses + reg_losses).cuda()

        # Calculate gradients
        optimizer.zero_grad()
        loss.backward()
        count += 1

        if count % 200 == 0:
            print("process: " + str(count) + "  /2000   in epoch:   " +
                  str(i) + str(target_var))
            print("loss: " + str(loss) + " coords: " +
                  str(list(coords.data[0, 0])))
            logging.info("process: " + str(count) + "  /2000   in epoch:   " +
                         str(i) + str(target_var))

        # Update model parameters with RMSprop
Пример #27
0
 def calc_loss(mean, stddev):
     """Average loss of a 9x9 Gaussian heatmap built from mean and stddev."""
     heatmap = make_gauss(mean, [9, 9], sigma=stddev)
     # The target mean is only passed when the enclosing scope asks for it.
     if uses_mean:
         loss_args = [heatmap, target_mean, target_stddev]
     else:
         loss_args = [heatmap, target_stddev]
     return average_loss(loss_method(*loss_args))
    total_acc_misses = np.zeros((len(threshes)), dtype=int)
    model.train()
    start = time.time()

    for images, labels in train_loader:  # Get Batch
        cuda_images, cuda_labels = images.cuda(main_gpu), labels.cuda(main_gpu)
        # Forward pass
        coords, heatmaps = model(cuda_images)

        total_acc_misses = total_acc_misses + accuracy_func(
            coords, cuda_labels)

        # Loss
        euc_losses = dsntnn.euclidean_losses(coords, cuda_labels)
        reg_losses = dsntnn.js_reg_losses(heatmaps, cuda_labels, sigma_t=1.0)
        loss = dsntnn.average_loss(euc_losses + reg_losses)

        total_loss += loss.item()

        # Calculate gradients
        optimizer.zero_grad()
        loss.backward()

        # Update model parameters with RMSprop
        optimizer.step()

    end = time.time()

    torch.save(model.state_dict(),
               checkpoints_path + "epoch_{}.pth".format(epoch))
Пример #29
0
    def train(self,
              epoch,
              write2tensorboard=True,
              writer_interval=20,
              viz_model=False):
        """Run one training pass over ``self.train_dataLoader``.

        Args:
            epoch: current epoch number, used for plotting and logging.
            write2tensorboard: when true, the first batch's inputs and
                target heatmap are plotted every ``writer_interval`` epochs.
            writer_interval: plotting period, in epochs.
            viz_model: when true, ``self.viz_model`` is called with
                ``output_path='./logs/'`` before the loop starts.
        """
        # Switch the model to training mode and reset the stored losses.
        self.model.train()
        self.writer.reset_losses()
        # Optionally visualise the model (graphviz, per the original note).
        if viz_model:
            self.viz_model(output_path='./logs/')

        # ** Training loop **
        for idx, data in enumerate(tqdm(self.train_dataLoader)):
            # Load the batch onto the device as float32.
            sag_img = data['sag_image'].to(self.device, dtype=torch.float32)
            cor_img = data['cor_image'].to(self.device, dtype=torch.float32)
            heatmap = data['heatmap'].to(self.device, dtype=torch.float32)
            keypoints = data['keypoints'].to(self.device,
                                             dtype=torch.float32)
            labels = data['class_labels'].to(self.device,
                                             dtype=torch.float32)
            self.optimizer.zero_grad()  # Reset gradients

            # Shapes as noted by the original author (not checked here):
            #   seg_out = segmentation head output (B x N_OUTPUTS x H x W)
            #   heatmap = 1D heatmap (B x N_OUTPUTS x H x 1)
            #   coords  = 1D coordinates (B x N_OUTPUTS x 1)
            model_out = self.model(sag_img, cor_img)
            _pred_seg, pred_heatmap, pred_coords, pred_labels = model_out

            if self.detect:
                # Model set for detection only (no labelling).
                pred_map = torch.max(pred_heatmap, dim=1, keepdim=True)
                gt_map = torch.max(heatmap, dim=1, keepdim=True)
                ce_loss = self.criterion(pred_labels, labels)  # Classifier
                loss = cl.js_reg(pred_map.values, gt_map.values).mean()
                loss += ce_loss
            else:
                # Coordinate loss plus heatmap regularisation, masked by
                # the class labels.
                coord_loss = torch.nn.functional.mse_loss(pred_coords,
                                                          keypoints,
                                                          reduction='none')
                heatmap_reg = cl.kl_reg(pred_heatmap, heatmap)
                loss = dsntnn.average_loss(coord_loss + heatmap_reg,
                                           mask=labels)

            self.writer.train_loss.append(loss.item())
            # Optimiser step
            loss.backward()
            self.optimizer.step()

            # Only the first batch of every writer_interval-th epoch is
            # written to tensorboard.
            plot_now = (write2tensorboard and epoch % writer_interval == 0
                        and idx == 0)
            if not plot_now:
                continue

            self.writer.plot_inputs(f'Sagittal Inputs at epoch {epoch}',
                                    sag_img,
                                    targets=[keypoints, labels])
            self.writer.plot_inputs(f'Coronal Inputs at epoch {epoch}',
                                    cor_img)
            if self.detect:
                self.writer.plot_histogram(
                    f'Target heatmap at epoch {epoch}',
                    gt_map.values,
                    targets=[None, labels],
                    detect=True)
            else:
                self.writer.plot_histogram(
                    f'Target heatmap at epoch {epoch}',
                    heatmap,
                    targets=[None, labels],
                    detect=False)

        mean_train_loss = np.mean(self.writer.train_loss)
        print('Train Loss:', mean_train_loss)
        self.writer.add_scalar('Training Loss', mean_train_loss, epoch)
Пример #30
0
    def validation(self,
                   epoch,
                   write2tensorboard=True,
                   writer_interval=10,
                   write_gif=False):
        """Run one pass over ``self.val_dataLoader`` with gradients disabled.

        The mean validation loss is printed, passed to
        ``self.scheduler.step`` and logged through ``self.writer``.

        Args:
            epoch: current epoch number, used for plotting and logging.
            write2tensorboard: when true, the first batch's predictions are
                plotted every ``writer_interval`` epochs.
            writer_interval: plotting period, in epochs.
            write_gif: when true, the first sample of the first batch is
                passed to ``self.write2file``.
        """
        # NOTE(review): self.model.eval() is not called here -- confirm the
        # caller puts the model in evaluation mode.
        with torch.set_grad_enabled(False):
            print('Validation...')
            for idx, data in enumerate(tqdm(self.val_dataLoader)):
                # Load the batch onto the device as float32.
                sag_img = data['sag_image'].to(self.device,
                                               dtype=torch.float32)
                cor_img = data['cor_image'].to(self.device,
                                               dtype=torch.float32)
                heatmap = data['heatmap'].to(self.device, dtype=torch.float32)
                keypoints = data['keypoints'].to(self.device,
                                                 dtype=torch.float32)
                labels = data['class_labels'].to(self.device,
                                                 dtype=torch.float32)
                pred_seg, pred_heatmap, pred_coords, pred_labels = self.model(
                    sag_img, cor_img)

                if self.detect:
                    # Model set for detection only (no labelling).
                    pred_map = torch.max(pred_heatmap, dim=1, keepdim=True)
                    gt_map = torch.max(heatmap, dim=1, keepdim=True)
                    ce_loss = self.criterion(pred_labels, labels)  # Classifier
                    loss = cl.js_reg(pred_map.values, gt_map.values).mean()
                    loss += ce_loss
                else:
                    # Coordinate loss plus heatmap regularisation, masked
                    # by the class labels.
                    l1_loss = torch.nn.functional.mse_loss(pred_coords,
                                                           keypoints,
                                                           reduction='none')
                    js_reg = cl.kl_reg(pred_heatmap, heatmap)
                    loss = dsntnn.average_loss(l1_loss + js_reg, mask=labels)

                self.writer.val_loss.append(loss.item())
                if self.detect:
                    # ce_loss only exists in detection mode; recording it
                    # unconditionally raised NameError otherwise.
                    self.writer.ce.append(ce_loss.item())
                else:
                    self.writer.reg.append(js_reg[labels == 1].mean().item())
                    self.writer.l1.append(l1_loss[labels == 1].mean().item())

                if write_gif and idx == 0:
                    self.write2file(pred_heatmap[0], heatmap[0], epoch=epoch)

                # Write predictions to tensorboard for the first batch of
                # every writer_interval-th epoch.
                if (write2tensorboard and epoch % writer_interval == 0
                        and idx == 0):
                    print('Predicted Labels + GT: ', pred_labels, labels)
                    self.writer.plot_prediction(
                        f'Prediction at epoch {epoch}',
                        img=sag_img,
                        prediction=pred_coords,
                        targets=[keypoints, labels])
                    if self.detect:
                        self.writer.plot_histogram(
                            f'Predicted Heatmap at epoch {epoch}',
                            pred_map.values,
                            targets=[heatmap, labels],
                            detect=True)
                    else:
                        self.writer.plot_histogram(
                            f'Predicted Heatmap at epoch {epoch}',
                            pred_heatmap,
                            targets=[heatmap, labels],
                            detect=False)

            mean_val_loss = np.mean(self.writer.val_loss)
            print('Validation Loss:', mean_val_loss)
            self.scheduler.step(mean_val_loss)
            self.writer.add_scalar('Validation Loss', mean_val_loss, epoch)
            if self.detect:
                # The CE loss is only collected in detection mode.
                self.writer.add_scalar('CE loss', np.mean(self.writer.ce),
                                       epoch)
            else:
                self.writer.add_scalar('Regularisation',
                                       np.mean(self.writer.reg), epoch)
                self.writer.add_scalar('L1-Loss', np.mean(self.writer.l1),
                                       epoch)