Exemple #1
0
def _load_daae(state_dict, device=0):
    """
    Load a pretrained pytorch DAAE2D state dict

    :param state_dict: state dict of a daae
    :param device: index of the device (if there are no GPU devices, it will be moved to the CPU)
    :return: a module that corresponds to the trained network
    """
    from networks.daae import DAAE2D

    # recover the architecture hyperparameters from the parameter shapes
    keys = list(state_dict.keys())
    first_conv_w = state_dict[
        'encoder.features.convblock1.conv1.unit.0.weight']
    bottleneck_w = state_dict['encoder.bottleneck.0.weight']
    dc_w = state_dict['domain_classifier.linear2.unit.0.weight']

    feature_maps = first_conv_w.size(0)
    # the level count is encoded as a digit inside the name of a decoder
    # upconv layer (relies on the state dict preserving insertion order)
    levels = int(keys[-15][len('decoder.features.upconv')])
    bottleneck_in_features = bottleneck_w.size(1)
    bottleneck_dim = bottleneck_w.size(0)
    # invert the relation between spatial input size and bottleneck width
    x = int(
        np.sqrt(bottleneck_in_features * 2**(3 * levels - 1) / feature_maps))
    norm = 'batch' if 'norm' in keys[2] else 'instance'
    n_hidden = dc_w.size(1)
    n_domains = dc_w.size(0)

    # build the network with the recovered configuration; training-only
    # hyperparameters are set to inert defaults
    net = DAAE2D(lambda_reg=0.0,
                 input_size=[x, x],
                 bottleneck_dim=bottleneck_dim,
                 feature_maps=feature_maps,
                 levels=levels,
                 dropout_enc=0.0,
                 norm=norm,
                 activation='relu',
                 fc_channels=(n_hidden, n_domains))

    # restore the trained parameters
    net.load_state_dict(state_dict)

    # map to the correct device
    module_to_device(net, device=device)

    return net
Exemple #2
0
    def train_epoch(self,
                    loader_src,
                    loader_tar_ul,
                    loader_tar_l,
                    optimizer,
                    epoch,
                    augmenter=None,
                    print_stats=1,
                    writer=None,
                    write_images=False,
                    device=0):
        """
        Trains the network for one epoch on labeled source data and,
        optionally, labeled target data.

        :param loader_src: source dataloader (labeled)
        :param loader_tar_ul: target dataloader (unlabeled); accepted for
            interface compatibility but never referenced in this body
        :param loader_tar_l: target dataloader (labeled), may be None
        :param optimizer: optimizer for the loss function
        :param epoch: current epoch
        :param augmenter: data augmenter
        :param print_stats: frequency of printing statistics
        :param writer: summary writer
        :param write_images: frequency of writing images
        :param device: GPU device where the computations should occur
        :return: average training loss over the epoch
        """
        # perform training on GPU/CPU
        module_to_device(self, device)
        self.train()

        # keep track of the average loss during the epoch
        loss_seg_src_cum = 0.0
        loss_seg_tar_cum = 0.0
        total_loss_cum = 0.0
        cnt = 0

        # zip dataloaders; zip() over a single loader still yields 1-tuples,
        # so the data[0] indexing below works in both branches
        if loader_tar_l is None:
            dl = zip(loader_src)
        else:
            dl = zip(loader_src, loader_tar_l)

        # start epoch
        time_start = datetime.datetime.now()
        for i, data in enumerate(dl):

            # transfer to suitable device
            data_src = tensor_to_device(data[0], device)
            if loader_tar_l is not None:
                data_tar_l = tensor_to_device(data[1], device)

            # augment if necessary
            if loader_tar_l is None:
                data_aug = (data_src[0], data_src[1])
                x_src, y_src = augment_samples(data_aug, augmenter=augmenter)
            else:
                data_aug = (data_src[0], data_src[1])
                x_src, y_src = augment_samples(data_aug, augmenter=augmenter)
                data_aug = (data_tar_l[0], data_tar_l[1])
                x_tar_l, y_tar_l = augment_samples(data_aug,
                                                   augmenter=augmenter)
                # map the raw target labels onto class-of-interest indices
                # NOTE(review): y_src is not passed through get_labels here,
                # unlike the target labels — confirm this asymmetry is intended
                y_tar_l = get_labels(y_tar_l, coi=self.coi, dtype=int)

            # zero the gradient buffers
            self.zero_grad()

            # forward prop and compute loss
            # placeholder so the cumulative target loss is defined even when
            # there is no labeled target data
            loss_seg_tar = torch.Tensor([0])
            y_src_pred = self(x_src)
            # y[:, 0, ...] drops the singleton channel dimension for the loss
            loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
            total_loss = loss_seg_src
            if loader_tar_l is not None:
                y_tar_l_pred = self(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar

            loss_seg_src_cum += loss_seg_src.data.cpu().numpy()
            loss_seg_tar_cum += loss_seg_tar.data.cpu().numpy()
            total_loss_cum += total_loss.data.cpu().numpy()
            cnt += 1

            # backward prop
            total_loss.backward()

            # apply one step in the optimization
            optimizer.step()

            # print statistics if necessary
            if i % print_stats == 0:
                print(
                    '[%s] Epoch %5d - Iteration %5d/%5d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss: %.6f'
                    % (datetime.datetime.now(), epoch, i,
                       len(loader_src.dataset) / loader_src.batch_size,
                       loss_seg_src_cum / cnt, loss_seg_tar_cum / cnt,
                       total_loss_cum / cnt))

        # keep track of time
        runtime = datetime.datetime.now() - time_start
        seconds = runtime.total_seconds()
        hours = seconds // 3600
        minutes = (seconds - hours * 3600) // 60
        seconds = seconds - hours * 3600 - minutes * 60
        print_frm(
            'Epoch %5d - Runtime for training: %d hours, %d minutes, %f seconds'
            % (epoch, hours, minutes, seconds))

        # compute the epoch averages and print them
        # NOTE(review): cnt stays 0 if the dataloader yields nothing, which
        # would raise ZeroDivisionError here
        loss_seg_src_avg = loss_seg_src_cum / cnt
        loss_seg_tar_avg = loss_seg_tar_cum / cnt
        total_loss_avg = total_loss_cum / cnt
        print(
            '[%s] Training Epoch %4d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss: %.6f'
            % (datetime.datetime.now(), epoch, loss_seg_src_avg,
               loss_seg_tar_avg, total_loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            log_scalars([loss_seg_src_avg, loss_seg_tar_avg, total_loss_avg], [
                'train/' + s
                for s in ['loss-seg-src', 'loss-seg-tar', 'total-loss']
            ],
                        writer,
                        epoch=epoch)

            # log images if necessary (last batch of the epoch)
            if write_images:
                # keep only channel 1 (softmax probability) for visualization
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
                log_images_2d(
                    [x_src.data, y_src.data, y_src_pred],
                    ['train/' + s for s in ['src/x', 'src/y', 'src/y-pred']],
                    writer,
                    epoch=epoch)
                if loader_tar_l is not None:
                    y_tar_l_pred = F.softmax(y_tar_l_pred,
                                             dim=1)[:, 1:2, :, :].data
                    log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred], [
                        'train/' + s
                        for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
                    ],
                                  writer,
                                  epoch=epoch)

        return total_loss_avg
Exemple #3
0
    def test_epoch(self,
                   loader_src,
                   loader_tar_ul,
                   loader_tar_l,
                   epoch,
                   writer=None,
                   write_images=False,
                   device=0):
        """
        Evaluates the network for one epoch on labeled source and labeled
        target data, and computes segmentation metrics on the target data.

        :param loader_src: source dataloader (labeled)
        :param loader_tar_ul: target dataloader (unlabeled); accepted for
            interface compatibility but never referenced in this body
        :param loader_tar_l: target dataloader (labeled)
        :param epoch: current epoch
        :param writer: summary writer
        :param write_images: frequency of writing images
        :param device: GPU device where the computations should occur
        :return: average test loss over the epoch
        """
        # perform evaluation on GPU/CPU
        module_to_device(self, device)
        self.eval()
        # NOTE(review): gradients are not disabled (no torch.no_grad());
        # confirm whether that is intentional

        # keep track of the average loss during the epoch
        loss_seg_src_cum = 0.0
        loss_seg_tar_cum = 0.0
        total_loss_cum = 0.0
        cnt = 0

        # zip dataloaders
        # NOTE(review): the loader_tar_l is None branch looks dead — data[1]
        # is indexed unconditionally in the loop below, so this method
        # effectively requires a labeled target loader
        if loader_tar_l is None:
            dl = zip(loader_src)
        else:
            dl = zip(loader_src, loader_tar_l)

        # start epoch
        y_preds = []
        ys = []
        time_start = datetime.datetime.now()
        for i, data in enumerate(dl):

            # transfer to suitable device
            x_src, y_src = tensor_to_device(data[0], device)
            x_tar_l, y_tar_l = tensor_to_device(data[1], device)
            x_src = x_src.float()
            x_tar_l = x_tar_l.float()
            y_src = y_src.long()
            y_tar_l = y_tar_l.long()

            # forward prop and compute loss
            y_src_pred = self(x_src)
            y_tar_l_pred = self(x_tar_l)
            # y[:, 0, ...] drops the singleton channel dimension for the loss
            loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
            loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
            total_loss = loss_seg_src + loss_seg_tar

            loss_seg_src_cum += loss_seg_src.data.cpu().numpy()
            loss_seg_tar_cum += loss_seg_tar.data.cpu().numpy()
            total_loss_cum += total_loss.data.cpu().numpy()
            cnt += 1

            # collect flattened per-class probabilities and labels on the
            # labeled target data for the metric computation afterwards
            for b in range(y_tar_l_pred.size(0)):
                y_preds.append(
                    F.softmax(y_tar_l_pred,
                              dim=1)[b, ...].view(y_tar_l_pred.size(1),
                                                  -1).data.cpu().numpy())
                ys.append(y_tar_l[b, 0, ...].flatten().cpu().numpy())

        # keep track of time
        runtime = datetime.datetime.now() - time_start
        seconds = runtime.total_seconds()
        hours = seconds // 3600
        minutes = (seconds - hours * 3600) // 60
        seconds = seconds - hours * 3600 - minutes * 60
        print_frm(
            'Epoch %5d - Runtime for testing: %d hours, %d minutes, %f seconds'
            % (epoch, hours, minutes, seconds))

        # prep for metric computation: one-vs-rest jaccard and accuracy
        # metrics per class of interest
        y_preds = np.concatenate(y_preds, axis=1)
        ys = np.concatenate(ys)
        js = np.asarray([
            jaccard((ys == i).astype(int), y_preds[i, :])
            for i in range(len(self.coi))
        ])
        ams = np.asarray([
            accuracy_metrics((ys == i).astype(int), y_preds[i, :])
            for i in range(len(self.coi))
        ])

        # compute the epoch averages and print them
        loss_seg_src_avg = loss_seg_src_cum / cnt
        loss_seg_tar_avg = loss_seg_tar_cum / cnt
        total_loss_avg = total_loss_cum / cnt
        print(
            '[%s] Testing Epoch %4d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss: %.6f'
            % (datetime.datetime.now(), epoch, loss_seg_src_avg,
               loss_seg_tar_avg, total_loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            log_scalars([
                loss_seg_src_avg, loss_seg_tar_avg, total_loss_avg,
                np.mean(js, axis=0), *(np.mean(ams, axis=0))
            ], [
                'test/' + s for s in [
                    'loss-seg-src', 'loss-seg-tar', 'total-loss', 'jaccard',
                    'accuracy', 'balanced-accuracy', 'precision', 'recall',
                    'f-score'
                ]
            ],
                        writer,
                        epoch=epoch)

            # log images if necessary (last batch of the epoch)
            if write_images:
                # keep only channel 1 (softmax probability) for visualization
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
                y_tar_l_pred = F.softmax(y_tar_l_pred, dim=1)[:,
                                                              1:2, :, :].data
                log_images_2d([
                    x_src.data, y_src.data, y_src_pred, x_tar_l.data, y_tar_l,
                    y_tar_l_pred
                ], [
                    'test/' + s for s in [
                        'src/x', 'src/y', 'src/y-pred', 'tar/x-l', 'tar/y-l',
                        'tar/y-l-pred'
                    ]
                ],
                              writer,
                              epoch=epoch)

        return total_loss_avg
Exemple #4
0
    def test_epoch(self,
                   loader_src,
                   loader_tar_ul,
                   loader_tar_l,
                   epoch,
                   writer=None,
                   write_images=False,
                   device=0):
        """
        Evaluates the network for one epoch; the losses that are computed
        depend on self.train_mode (RECONSTRUCTION, SEGMENTATION or joint).

        :param loader_src: source dataloader (labeled)
        :param loader_tar_ul: target dataloader (unlabeled)
        :param loader_tar_l: target dataloader (labeled)
        :param epoch: current epoch
        :param writer: summary writer
        :param write_images: frequency of writing images
        :param device: GPU device where the computations should occur
        :return: average test loss over the epoch
        """
        # perform evaluation on GPU/CPU
        module_to_device(self, device)
        self.eval()
        # NOTE(review): gradients are not disabled (no torch.no_grad());
        # confirm whether that is intentional

        # keep track of the average loss during the epoch
        loss_seg_src_cum = 0.0
        loss_seg_tar_cum = 0.0
        loss_rec_src_cum = 0.0
        loss_rec_tar_cum = 0.0
        loss_dc_x_cum = 0.0
        loss_dc_y_cum = 0.0
        total_loss_cum = 0.0
        cnt = 0

        # zip dataloaders; iteration stops at the shortest of the three
        dl = zip(loader_src, loader_tar_ul, loader_tar_l)

        # start epoch
        y_preds = []
        ys = []
        time_start = datetime.datetime.now()
        for i, data in enumerate(dl):

            # transfer to suitable device
            x_src, y_src = tensor_to_device(data[0], device)
            x_tar_ul = tensor_to_device(data[1], device)
            x_tar_l, y_tar_l = tensor_to_device(data[2], device)
            x_src = x_src.float()
            x_tar_ul = x_tar_ul.float()
            x_tar_l = x_tar_l.float()
            y_src = y_src.long()
            y_tar_l = y_tar_l.long()

            # get domain labels for domain confusion: the first x_src.size(0)
            # entries are source (0), the remaining ones target (1)
            dom_labels_x = tensor_to_device(
                torch.zeros((x_src.size(0) + x_tar_ul.size(0))),
                device).long()
            dom_labels_x[x_src.size(0):] = 1
            dom_labels_y = tensor_to_device(
                torch.zeros((x_src.size(0) + x_tar_ul.size(0))),
                device).long()
            dom_labels_y[x_src.size(0):] = 1

            # check train mode and compute loss; the zero tensors keep the
            # cumulative sums below defined for losses a mode does not compute
            loss_seg_src = torch.Tensor([0])
            loss_seg_tar = torch.Tensor([0])
            loss_rec_src = torch.Tensor([0])
            loss_rec_tar = torch.Tensor([0])
            loss_dc_x = torch.Tensor([0])
            loss_dc_y = torch.Tensor([0])
            if self.train_mode == RECONSTRUCTION:
                # reconstruction mode: autoencoding + domain confusion on x
                x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
                x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
                loss_rec_src = self.rec_loss(x_src_rec, x_src)
                # NOTE(review): argument order (input, reconstruction) differs
                # from the (reconstruction, input) order used for the source
                # loss above — harmless for a symmetric loss such as MSE,
                # confirm otherwise
                loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
                loss_dc_x = self.dc_loss(
                    torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                    dom_labels_x)
                total_loss = loss_rec_src + loss_rec_tar + self.lambda_dc * loss_dc_x
            elif self.train_mode == SEGMENTATION:
                # switch between reconstructed and original inputs with
                # probability self.p
                if np.random.rand() < self.p:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
                else:
                    x_src_rec, _ = self.forward_rec(x_src)
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                    # presumably marks reconstructed inputs as the other
                    # domain for the confusion loss — TODO confirm
                    dom_labels_y[:x_src.size(0)] = 1
                if np.random.rand() < self.p:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul)
                else:
                    x_tar_ul_rec, _ = self.forward_rec(x_tar_ul)
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul_rec)
                    dom_labels_y[x_src.size(0):] = 1
                # y[:, 0, ...] drops the singleton channel dimension
                loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
                loss_dc_y = self.dc_loss(
                    torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                    dom_labels_y)
                total_loss = loss_seg_src + self.lambda_dc * loss_dc_y
                # supervised loss on the labeled target batch
                y_tar_l_pred, _ = self.forward_seg(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar
            else:
                # joint mode: reconstruction + segmentation + both domain
                # confusion terms
                x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
                if np.random.rand() < self.p:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
                else:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                    dom_labels_y[:x_src.size(0)] = 1
                x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
                if np.random.rand() < self.p:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul)
                else:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul_rec)
                    dom_labels_y[x_src.size(0):] = 1
                loss_rec_src = self.rec_loss(x_src_rec, x_src)
                # NOTE(review): same (input, reconstruction) argument-order
                # inconsistency as in the reconstruction branch above
                loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
                loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
                loss_dc_x = self.dc_loss(
                    torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                    dom_labels_x)
                loss_dc_y = self.dc_loss(
                    torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                    dom_labels_y)
                total_loss = loss_seg_src + self.lambda_rec * (loss_rec_src + loss_rec_tar) + \
                             self.lambda_dc * (loss_dc_x + loss_dc_y)
                # supervised loss on the labeled target batch, using the full
                # forward pass
                _, y_tar_l_pred, _, y_tar_l_pred_dom = self(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar

            loss_seg_src_cum += loss_seg_src.data.cpu().numpy()
            loss_seg_tar_cum += loss_seg_tar.data.cpu().numpy()
            loss_rec_src_cum += loss_rec_src.data.cpu().numpy()
            loss_rec_tar_cum += loss_rec_tar.data.cpu().numpy()
            loss_dc_x_cum += loss_dc_x.data.cpu().numpy()
            loss_dc_y_cum += loss_dc_y.data.cpu().numpy()
            total_loss_cum += total_loss.data.cpu().numpy()
            cnt += 1

            # collect flattened per-class probabilities and labels on the
            # labeled target data for the metric computation afterwards
            if self.train_mode == SEGMENTATION or self.train_mode == JOINT:
                for b in range(y_tar_l_pred.size(0)):
                    y_preds.append(
                        F.softmax(y_tar_l_pred,
                                  dim=1)[b, ...].view(y_tar_l_pred.size(1),
                                                      -1).data.cpu().numpy())
                    ys.append(y_tar_l[b, 0, ...].flatten().cpu().numpy())

        # keep track of time
        runtime = datetime.datetime.now() - time_start
        seconds = runtime.total_seconds()
        hours = seconds // 3600
        minutes = (seconds - hours * 3600) // 60
        seconds = seconds - hours * 3600 - minutes * 60
        print_frm(
            'Epoch %5d - Runtime for testing: %d hours, %d minutes, %f seconds'
            % (epoch, hours, minutes, seconds))

        # prep for metric computation: one-vs-rest jaccard and accuracy
        # metrics per class of interest (only meaningful when a segmentation
        # head was evaluated)
        if self.train_mode == SEGMENTATION or self.train_mode == JOINT:
            y_preds = np.concatenate(y_preds, axis=1)
            ys = np.concatenate(ys)
            js = np.asarray([
                jaccard((ys == i).astype(int), y_preds[i, :])
                for i in range(len(self.coi))
            ])
            ams = np.asarray([
                accuracy_metrics((ys == i).astype(int), y_preds[i, :])
                for i in range(len(self.coi))
            ])

        # compute the epoch averages and print them
        loss_seg_src_avg = loss_seg_src_cum / cnt
        loss_seg_tar_avg = loss_seg_tar_cum / cnt
        loss_rec_src_avg = loss_rec_src_cum / cnt
        loss_rec_tar_avg = loss_rec_tar_cum / cnt
        loss_dc_x_avg = loss_dc_x_cum / cnt
        loss_dc_y_avg = loss_dc_y_cum / cnt
        total_loss_avg = total_loss_cum / cnt
        print(
            '[%s] Testing Epoch %5d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss rec src: %.6f - Loss rec tar: %.6f - Loss DCX: %.6f - Loss DCY: %.6f - Loss: %.6f'
            % (datetime.datetime.now(), epoch, loss_seg_src_avg,
               loss_seg_tar_avg, loss_rec_src_avg, loss_rec_tar_avg,
               loss_dc_x_avg, loss_dc_y_avg, total_loss_avg))

        # log everything
        if writer is not None:

            # always log scalars; the scalar set depends on the train mode
            if self.train_mode == RECONSTRUCTION:
                log_scalars(
                    [loss_rec_src_avg, loss_rec_tar_avg, loss_dc_x_avg], [
                        'test/' + s
                        for s in ['loss-rec-src', 'loss-rec-tar', 'loss-dc-x']
                    ],
                    writer,
                    epoch=epoch)
            elif self.train_mode == SEGMENTATION:
                log_scalars([
                    loss_seg_src_avg, loss_seg_tar_avg, loss_dc_y_avg,
                    np.mean(js, axis=0), *(np.mean(ams, axis=0))
                ], [
                    'test/' + s for s in [
                        'loss-seg-src', 'loss-seg-tar', 'loss-dc-y', 'jaccard',
                        'accuracy', 'balanced-accuracy', 'precision', 'recall',
                        'f-score'
                    ]
                ],
                            writer,
                            epoch=epoch)
            else:
                log_scalars([
                    loss_seg_src_avg, loss_seg_tar_avg, loss_rec_src_avg,
                    loss_rec_tar_avg, loss_dc_x_avg, loss_dc_y_avg,
                    np.mean(js, axis=0), *(np.mean(ams, axis=0))
                ], [
                    'test/' + s for s in [
                        'loss-seg-src', 'loss-seg-tar', 'loss-rec-src',
                        'loss-rec-tar', 'loss-dc-x', 'loss-dc-y', 'jaccard',
                        'accuracy', 'balanced-accuracy', 'precision', 'recall',
                        'f-score'
                    ]
                ],
                            writer,
                            epoch=epoch)
            log_scalars([total_loss_avg],
                        ['test/' + s for s in ['total-loss']],
                        writer,
                        epoch=epoch)

            # log images if necessary (last batch of the epoch)
            if write_images:
                log_images_2d([x_src.data], ['test/' + s for s in ['src/x']],
                              writer,
                              epoch=epoch)
                if self.train_mode == RECONSTRUCTION:
                    log_images_2d(
                        [x_src_rec.data, x_tar_ul.data, x_tar_ul_rec.data], [
                            'test/' + s
                            for s in ['src/x-rec', 'tar/x-ul', 'tar/x-ul-rec']
                        ],
                        writer,
                        epoch=epoch)
                elif self.train_mode == SEGMENTATION:
                    # keep only channel 1 (softmax probability) for
                    # visualization
                    y_src_pred = F.softmax(y_src_pred, dim=1)[:,
                                                              1:2, :, :].data
                    log_images_2d(
                        [y_src.data, y_src_pred],
                        ['test/' + s for s in ['src/y', 'src/y-pred']],
                        writer,
                        epoch=epoch)
                    if loader_tar_l is not None:
                        y_tar_l_pred = F.softmax(y_tar_l_pred,
                                                 dim=1)[:, 1:2, :, :].data
                        log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred], [
                            'test/' + s
                            for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
                        ],
                                      writer,
                                      epoch=epoch)
                else:
                    y_src_pred = F.softmax(y_src_pred, dim=1)[:,
                                                              1:2, :, :].data
                    log_images_2d([
                        x_src_rec.data, y_src.data, y_src_pred, x_tar_ul.data,
                        x_tar_ul_rec.data
                    ], [
                        'test/' + s for s in [
                            'src/x-rec', 'src/y', 'src/y-pred', 'tar/x-ul',
                            'tar/x-ul-rec'
                        ]
                    ],
                                  writer,
                                  epoch=epoch)
                    if loader_tar_l is not None:
                        y_tar_l_pred = F.softmax(y_tar_l_pred,
                                                 dim=1)[:, 1:2, :, :].data
                        log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred], [
                            'test/' + s
                            for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
                        ],
                                      writer,
                                      epoch=epoch)

        return total_loss_avg
Exemple #5
0
    def train_epoch(self,
                    loader_src,
                    loader_tar_ul,
                    loader_tar_l,
                    optimizer,
                    epoch,
                    augmenter=None,
                    print_stats=1,
                    writer=None,
                    write_images=False,
                    device=0):
        """
        Trains the network for one epoch
        :param loader_src: source dataloader (labeled)
        :param loader_tar_ul: target dataloader (unlabeled)
        :param loader_tar_l: target dataloader (labeled)
        :param optimizer: optimizer for the loss function
        :param epoch: current epoch
        :param augmenter: data augmenter
        :param print_stats: frequency of printing statistics
        :param writer: summary writer
        :param write_images: frequency of writing images
        :param device: GPU device where the computations should occur
        :return: average training loss over the epoch
        """
        # perform training on GPU/CPU
        module_to_device(self, device)
        self.train()

        # keep track of the average loss during the epoch
        loss_seg_src_cum = 0.0
        loss_seg_tar_cum = 0.0
        loss_rec_src_cum = 0.0
        loss_rec_tar_cum = 0.0
        loss_dc_x_cum = 0.0
        loss_dc_y_cum = 0.0
        total_loss_cum = 0.0
        cnt = 0

        # zip dataloaders
        if loader_tar_l is None:
            dl = zip(loader_src, loader_tar_ul)
        else:
            dl = zip(loader_src, loader_tar_ul, loader_tar_l)

        # start epoch
        time_start = datetime.datetime.now()
        for i, data in enumerate(dl):

            # transfer to suitable device
            data_src = tensor_to_device(data[0], device)
            x_tar_ul = tensor_to_device(data[1], device)
            if loader_tar_l is not None:
                data_tar_l = tensor_to_device(data[2], device)

            # augment if necessary
            if loader_tar_l is None:
                data_aug = (data_src[0], data_src[1])
                x_src, y_src = augment_samples(data_aug, augmenter=augmenter)
                data_aug = (x_tar_ul, x_tar_ul)
                x_tar_ul, _ = augment_samples(data_aug, augmenter=augmenter)
            else:
                data_aug = (data_src[0], data_src[1])
                x_src, y_src = augment_samples(data_aug, augmenter=augmenter)
                data_aug = (x_tar_ul, x_tar_ul)
                x_tar_ul, _ = augment_samples(data_aug, augmenter=augmenter)
                data_aug = (data_tar_l[0], data_tar_l[1])
                x_tar_l, y_tar_l = augment_samples(data_aug,
                                                   augmenter=augmenter)
                y_tar_l = get_labels(y_tar_l, coi=self.coi, dtype=int)
            y_src = get_labels(y_src, coi=self.coi, dtype=int)
            x_tar_ul = x_tar_ul.float()

            # zero the gradient buffers
            self.zero_grad()

            # get domain labels for domain confusion
            dom_labels_x = tensor_to_device(
                torch.zeros((x_src.size(0) + x_tar_ul.size(0))),
                device).long()
            dom_labels_x[x_src.size(0):] = 1
            dom_labels_y = tensor_to_device(
                torch.zeros((x_src.size(0) + x_tar_ul.size(0))),
                device).long()
            dom_labels_y[x_src.size(0):] = 1

            # check train mode and compute loss
            loss_seg_src = torch.Tensor([0])
            loss_seg_tar = torch.Tensor([0])
            loss_rec_src = torch.Tensor([0])
            loss_rec_tar = torch.Tensor([0])
            loss_dc_x = torch.Tensor([0])
            loss_dc_y = torch.Tensor([0])
            if self.train_mode == RECONSTRUCTION:
                x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
                x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
                loss_rec_src = self.rec_loss(x_src_rec, x_src)
                loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
                loss_dc_x = self.dc_loss(
                    torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                    dom_labels_x)
                total_loss = loss_rec_src + loss_rec_tar + self.lambda_dc * loss_dc_x
            elif self.train_mode == SEGMENTATION:
                # switch between reconstructed and original inputs
                if np.random.rand() < self.p:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
                else:
                    x_src_rec, _ = self.forward_rec(x_src)
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                    dom_labels_y[:x_src.size(0)] = 1
                if np.random.rand() < self.p:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul)
                else:
                    x_tar_ul_rec, _ = self.forward_rec(x_tar_ul)
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul_rec)
                    dom_labels_y[x_src.size(0):] = 1
                loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
                loss_dc_y = self.dc_loss(
                    torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                    dom_labels_y)
                total_loss = loss_seg_src + self.lambda_dc * loss_dc_y
                if loader_tar_l is not None:
                    y_tar_l_pred, _ = self.forward_seg(x_tar_l)
                    loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0,
                                                                       ...])
                    total_loss = total_loss + loss_seg_tar
            else:
                x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
                if np.random.rand() < self.p:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
                else:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                    dom_labels_y[:x_src.size(0)] = 1
                x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
                if np.random.rand() < self.p:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul)
                else:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul_rec)
                    dom_labels_y[x_src.size(0):] = 1
                loss_rec_src = self.rec_loss(x_src_rec, x_src)
                loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
                loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
                loss_dc_x = self.dc_loss(
                    torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                    dom_labels_x)
                loss_dc_y = self.dc_loss(
                    torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                    dom_labels_y)
                total_loss = loss_seg_src + self.lambda_rec * (loss_rec_src + loss_rec_tar) + \
                             self.lambda_dc * (loss_dc_x + loss_dc_y)
                if loader_tar_l is not None:
                    _, y_tar_l_pred, _, y_tar_l_pred_dom = self(x_tar_l)
                    loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0,
                                                                       ...])
                    total_loss = total_loss + loss_seg_tar

            loss_seg_src_cum += loss_seg_src.data.cpu().numpy()
            loss_seg_tar_cum += loss_seg_tar.data.cpu().numpy()
            loss_rec_src_cum += loss_rec_src.data.cpu().numpy()
            loss_rec_tar_cum += loss_rec_tar.data.cpu().numpy()
            loss_dc_x_cum += loss_dc_x.data.cpu().numpy()
            loss_dc_y_cum += loss_dc_y.data.cpu().numpy()
            total_loss_cum += total_loss.data.cpu().numpy()
            cnt += 1

            # backward prop
            total_loss.backward()

            # apply one step in the optimization
            optimizer.step()

            # print statistics of necessary
            if i % print_stats == 0:
                print(
                    '[%s] Epoch %5d - Iteration %5d/%5d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss rec src: %.6f - Loss rec tar: %.6f - Loss DCX: %.6f - Loss DCY: %.6f - Loss: %.6f'
                    % (datetime.datetime.now(), epoch, i,
                       len(loader_src.dataset) / loader_src.batch_size,
                       loss_seg_src_cum / cnt, loss_seg_tar_cum / cnt,
                       loss_rec_src_cum / cnt, loss_rec_tar_cum / cnt,
                       loss_dc_x_cum / cnt, loss_dc_y_cum / cnt,
                       total_loss_cum / cnt))

        # keep track of time
        runtime = datetime.datetime.now() - time_start
        seconds = runtime.total_seconds()
        hours = seconds // 3600
        minutes = (seconds - hours * 3600) // 60
        seconds = seconds - hours * 3600 - minutes * 60
        print_frm(
            'Epoch %5d - Runtime for training: %d hours, %d minutes, %f seconds'
            % (epoch, hours, minutes, seconds))

        # don't forget to compute the average and print it
        loss_seg_src_avg = loss_seg_src_cum / cnt
        loss_seg_tar_avg = loss_seg_tar_cum / cnt
        loss_rec_src_avg = loss_rec_src_cum / cnt
        loss_rec_tar_avg = loss_rec_tar_cum / cnt
        loss_dc_x_avg = loss_dc_x_cum / cnt
        loss_dc_y_avg = loss_dc_y_cum / cnt
        total_loss_avg = total_loss_cum / cnt
        print(
            '[%s] Training Epoch %4d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss rec src: %.6f - Loss rec tar: %.6f - Loss DCX: %.6f - Loss DCY: %.6f - Loss: %.6f'
            % (datetime.datetime.now(), epoch, loss_seg_src_avg,
               loss_seg_tar_avg, loss_rec_src_avg, loss_rec_tar_avg,
               loss_dc_x_avg, loss_dc_y_avg, total_loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            if self.train_mode == RECONSTRUCTION:
                log_scalars(
                    [loss_rec_src_avg, loss_rec_tar_avg, loss_dc_x_avg], [
                        'train/' + s
                        for s in ['loss-rec-src', 'loss-rec-tar', 'loss-dc-x']
                    ],
                    writer,
                    epoch=epoch)
            elif self.train_mode == SEGMENTATION:
                log_scalars(
                    [loss_seg_src_avg, loss_seg_tar_avg, loss_dc_y_avg], [
                        'train/' + s
                        for s in ['loss-seg-src', 'loss-seg-tar', 'loss-dc-y']
                    ],
                    writer,
                    epoch=epoch)
            else:
                log_scalars([
                    loss_seg_src_avg, loss_seg_tar_avg, loss_rec_src_avg,
                    loss_rec_tar_avg, loss_dc_x_avg, loss_dc_y_avg
                ], [
                    'train/' + s for s in [
                        'loss-seg-src', 'loss-seg-tar', 'loss-rec-src',
                        'loss-rec-tar', 'loss-dc-x', 'loss-dc-y'
                    ]
                ],
                            writer,
                            epoch=epoch)
            log_scalars([total_loss_avg],
                        ['train/' + s for s in ['total-loss']],
                        writer,
                        epoch=epoch)

            # log images if necessary
            if write_images:
                log_images_2d([x_src.data], ['train/' + s for s in ['src/x']],
                              writer,
                              epoch=epoch)
                if self.train_mode == RECONSTRUCTION:
                    log_images_2d(
                        [x_src_rec.data, x_tar_ul.data, x_tar_ul_rec.data], [
                            'train/' + s
                            for s in ['src/x-rec', 'tar/x-ul', 'tar/x-ul-rec']
                        ],
                        writer,
                        epoch=epoch)
                elif self.train_mode == SEGMENTATION:
                    y_src_pred = F.softmax(y_src_pred, dim=1)[:,
                                                              1:2, :, :].data
                    log_images_2d(
                        [y_src.data, y_src_pred],
                        ['train/' + s for s in ['src/y', 'src/y-pred']],
                        writer,
                        epoch=epoch)
                    if loader_tar_l is not None:
                        y_tar_l_pred = F.softmax(y_tar_l_pred,
                                                 dim=1)[:, 1:2, :, :].data
                        log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred], [
                            'train/' + s
                            for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
                        ],
                                      writer,
                                      epoch=epoch)
                else:
                    y_src_pred = F.softmax(y_src_pred, dim=1)[:,
                                                              1:2, :, :].data
                    log_images_2d([
                        x_src_rec.data, y_src.data, y_src_pred, x_tar_ul.data,
                        x_tar_ul_rec.data
                    ], [
                        'train/' + s for s in [
                            'src/x-rec', 'src/y', 'src/y-pred', 'tar/x-ul',
                            'tar/x-ul-rec'
                        ]
                    ],
                                  writer,
                                  epoch=epoch)
                    if loader_tar_l is not None:
                        y_tar_l_pred = F.softmax(y_tar_l_pred,
                                                 dim=1)[:, 1:2, :, :].data
                        log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred], [
                            'train/' + s
                            for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
                        ],
                                      writer,
                                      epoch=epoch)

        return total_loss_avg
Exemple #6
0
    def test_epoch(self,
                   loader_src,
                   loader_tar,
                   loss_seg_fn,
                   loss_rec_fn,
                   epoch,
                   writer=None,
                   write_images=False,
                   device=0):
        """
        Tests the network for one epoch

        :param loader_src: source dataloader (should be labeled)
        :param loader_tar: target dataloader (should be labeled)
        :param loss_seg_fn: segmentation loss function
        :param loss_rec_fn: reconstruction loss function
        :param epoch: current epoch
        :param writer: summary writer
        :param write_images: frequency of writing images
        :param device: GPU device where the computations should occur
        :return: average testing loss over the epoch
        """
        # perform testing on GPU/CPU and switch to evaluation mode
        module_to_device(self, device)
        self.eval()

        # keep track of the average loss and metrics during the epoch
        loss_seg_cum = 0.0
        loss_rec_cum = 0.0
        total_loss_cum = 0.0
        cnt = 0

        # start epoch; predictions/labels are accumulated for metric computation
        y_src_preds = []
        ys_src = []
        y_tar_preds = []
        ys_tar = []
        for i, data in enumerate(zip(loader_src, loader_tar)):
            # get inputs and transfer to suitable device
            x_src, y_src = tensor_to_device(data[0], device)
            x_tar, y_tar = tensor_to_device(data[1], device)
            y_src = get_labels(y_src, coi=self.coi, dtype=int)
            y_tar = get_labels(y_tar, coi=self.coi, dtype=int)
            x_src = x_src.float()
            x_tar = x_tar.float()

            # zero the gradient buffers
            self.zero_grad()

            # forward prop: segmentation output is returned directly,
            # reconstruction output is cached on the module afterwards
            y_src_pred = self(x_src)
            x_src_pred = self.reconstruction_outputs
            y_tar_pred = self(x_tar)
            x_tar_pred = self.reconstruction_outputs

            # compute loss: segmentation on source + averaged reconstruction
            # on both domains, weighted by lambda_rec
            loss_seg = loss_seg_fn(y_src_pred, y_src)
            loss_rec = 0.5 * (loss_rec_fn(x_src_pred, x_src) +
                              loss_rec_fn(x_tar_pred, x_tar))
            total_loss = loss_seg + self.lambda_rec * loss_rec
            loss_seg_cum += loss_seg.data.cpu().numpy()
            loss_rec_cum += loss_rec.data.cpu().numpy()
            total_loss_cum += total_loss.data.cpu().numpy()
            cnt += 1

            # collect foreground probabilities (channel 1) and labels
            for b in range(y_src_pred.size(0)):
                y_src_preds.append(
                    F.softmax(y_src_pred, dim=1).data.cpu().numpy()[b, 1, ...])
                y_tar_preds.append(
                    F.softmax(y_tar_pred, dim=1).data.cpu().numpy()[b, 1, ...])
                ys_src.append(y_src[b, 0, ...].cpu().numpy())
                ys_tar.append(y_tar[b, 0, ...].cpu().numpy())

        # compute interesting metrics
        y_src_preds = np.asarray(y_src_preds)
        y_tar_preds = np.asarray(y_tar_preds)
        ys_src = np.asarray(ys_src)
        ys_tar = np.asarray(ys_tar)
        j_src = jaccard(ys_src, y_src_preds)
        # BUG FIX: the target jaccard was previously computed against the
        # source labels (ys_src); it must use the target labels (ys_tar)
        j_tar = jaccard(ys_tar, y_tar_preds)
        a_src, ba_src, p_src, r_src, f_src = accuracy_metrics(
            ys_src, y_src_preds)
        a_tar, ba_tar, p_tar, r_tar, f_tar = accuracy_metrics(
            ys_tar, y_tar_preds)

        # don't forget to compute the average and print it
        loss_seg_avg = loss_seg_cum / cnt
        loss_rec_avg = loss_rec_cum / cnt
        total_loss_avg = total_loss_cum / cnt
        print('[%s] Epoch %5d - Loss seg: %.6f - Loss rec: %.6f - Loss: %.6f' %
              (datetime.datetime.now(), epoch, loss_seg_avg, loss_rec_avg,
               total_loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            # BUG FIX: the 'loss-seg'/'loss-rec' tags were swapped with
            # respect to the logged values; tags now match the value order
            log_scalars([
                loss_seg_avg, loss_rec_avg, total_loss_avg, j_src, a_src,
                ba_src, p_src, r_src, f_src, j_tar, a_tar, ba_tar, p_tar,
                r_tar, f_tar
            ], [
                'test/' + s for s in [
                    'loss-seg', 'loss-rec', 'total-loss', 'src/jaccard',
                    'src/accuracy', 'src/balanced-accuracy', 'src/precision',
                    'src/recall', 'src/f-score', 'tar/jaccard', 'tar/accuracy',
                    'tar/balanced-accuracy', 'tar/precision', 'tar/recall',
                    'tar/f-score'
                ]
            ],
                        writer,
                        epoch=epoch)

            # log images if necessary (foreground probability of last batch)
            if write_images:
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, ...].data
                y_tar_pred = F.softmax(y_tar_pred, dim=1)[:, 1:2, ...].data
                log_images_3d([
                    x_src, x_src_pred.data, y_src, y_src_pred, x_tar,
                    x_tar_pred.data, y_tar, y_tar_pred
                ], [
                    'test/' + s for s in [
                        'src/x', 'src/x-pred', 'src/y', 'src/y-pred', 'tar/x',
                        'tar/x-pred', 'tar/y', 'tar/y-pred'
                    ]
                ],
                              writer,
                              epoch=epoch)

        return total_loss_avg
Exemple #7
0
    def train_epoch_semi_supervised(self,
                                    loader_src,
                                    loader_tar_ul,
                                    loader_tar_l,
                                    loss_seg_fn,
                                    loss_rec_fn,
                                    optimizer,
                                    epoch,
                                    augmenter_src=None,
                                    augmenter_tar=None,
                                    print_stats=1,
                                    writer=None,
                                    write_images=False,
                                    device=0):
        """
        Trains the network for one epoch

        :param loader_src: source dataloader (labeled)
        :param loader_tar_ul: target dataloader (unlabeled)
        :param loader_tar_l: target dataloader (labeled)
        :param loss_seg_fn: segmentation loss function
        :param loss_rec_fn: reconstruction loss function
        :param optimizer: optimizer for the loss function
        :param epoch: current epoch
        :param augmenter_src: source data augmenter
        :param augmenter_tar: target data augmenter
        :param print_stats: frequency of printing statistics
        :param writer: summary writer
        :param write_images: frequency of writing images
        :param device: GPU device where the computations should occur
        :return: average training loss over the epoch
        """
        # perform training on GPU/CPU and switch to training mode
        module_to_device(self, device)
        self.train()

        # keep track of the average loss during the epoch
        loss_seg_cum = 0.0
        loss_rec_cum = 0.0
        total_loss_cum = 0.0
        cnt = 0

        # start epoch (iteration count is bounded by the shortest loader)
        for i, data in enumerate(zip(loader_src, loader_tar_ul, loader_tar_l)):

            # transfer to suitable device
            data_src = tensor_to_device(data[0], device)
            x_tar_ul = tensor_to_device(data[1], device)
            data_tar_l = tensor_to_device(data[2], device)

            # augment if necessary
            x_src, y_src = augment_samples(data_src, augmenter=augmenter_src)
            x_tar_l, y_tar_l = augment_samples(data_tar_l,
                                               augmenter=augmenter_tar)
            y_src = get_labels(y_src, coi=self.coi, dtype=int)
            y_tar_l = get_labels(y_tar_l, coi=self.coi, dtype=int)
            x_tar_ul = x_tar_ul.float()

            # zero the gradient buffers
            self.zero_grad()

            # forward prop: segmentation output is returned directly,
            # reconstruction output is cached on the module afterwards
            y_src_pred = self(x_src)
            x_src_pred = self.reconstruction_outputs
            y_tar_ul_pred = self(x_tar_ul)
            x_tar_ul_pred = self.reconstruction_outputs
            y_tar_l_pred = self(x_tar_l)
            x_tar_l_pred = self.reconstruction_outputs

            # compute loss: segmentation averaged over the labeled domains,
            # reconstruction averaged over source and unlabeled target
            loss_seg = 0.5 * (loss_seg_fn(y_src_pred, y_src) +
                              loss_seg_fn(y_tar_l_pred, y_tar_l))
            loss_rec = 0.5 * (loss_rec_fn(x_src_pred, x_src) +
                              loss_rec_fn(x_tar_ul_pred, x_tar_ul))
            total_loss = loss_seg + self.lambda_rec * loss_rec
            loss_seg_cum += loss_seg.data.cpu().numpy()
            loss_rec_cum += loss_rec.data.cpu().numpy()
            total_loss_cum += total_loss.data.cpu().numpy()
            cnt += 1

            # backward prop
            total_loss.backward()

            # apply one step in the optimization
            optimizer.step()

            # print statistics if necessary (instantaneous batch losses)
            if i % print_stats == 0:
                print(
                    '[%s] Epoch %5d - Iteration %5d/%5d - Loss seg: %.6f - Loss rec: %.6f - Loss: %.6f'
                    % (datetime.datetime.now(), epoch, i,
                       len(loader_src.dataset) / loader_src.batch_size,
                       loss_seg, loss_rec, total_loss))

        # don't forget to compute the average and print it
        loss_seg_avg = loss_seg_cum / cnt
        loss_rec_avg = loss_rec_cum / cnt
        total_loss_avg = total_loss_cum / cnt
        print('[%s] Epoch %5d - Loss seg: %.6f - Loss rec: %.6f - Loss: %.6f' %
              (datetime.datetime.now(), epoch, loss_seg_avg, loss_rec_avg,
               total_loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            # BUG FIX: the 'loss-seg'/'loss-rec' tags were swapped with
            # respect to the logged values; tags now match the value order
            log_scalars(
                [loss_seg_avg, loss_rec_avg, total_loss_avg],
                ['train/' + s for s in ['loss-seg', 'loss-rec', 'total-loss']],
                writer,
                epoch=epoch)

            # log images if necessary (foreground probability of last batch)
            if write_images:
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, ...].data
                y_tar_l_pred = F.softmax(y_tar_l_pred, dim=1)[:, 1:2, ...].data
                log_images_3d([
                    x_src, x_src_pred.data, y_src, y_src_pred, x_tar_l,
                    x_tar_l_pred.data, y_tar_l, y_tar_l_pred
                ], [
                    'train/' + s for s in [
                        'src/x', 'src/x-pred', 'src/y', 'src/y-pred', 'tar/x',
                        'tar/x-pred', 'tar/y', 'tar/y-pred'
                    ]
                ],
                              writer,
                              epoch=epoch)

        return total_loss_avg
Exemple #8
0
    def test_epoch(self,
                   loader,
                   loss_rec_fn,
                   loss_kl_fn,
                   epoch,
                   writer=None,
                   write_images=False,
                   device=0):
        """
        Tests the (variational autoencoder) network for one epoch

        :param loader: dataloader
        :param loss_rec_fn: reconstruction loss function
        :param loss_kl_fn: kullback leibler loss function
        :param epoch: current epoch
        :param writer: summary writer
        :param write_images: frequency of writing images
        :param device: GPU device where the computations should occur
        :return: average testing loss over the epoch
        """
        # make sure network is on the gpu and in evaluation mode
        module_to_device(self, device)
        self.eval()

        # keep track of the average losses during the epoch
        loss_rec_cum = 0.0
        loss_kl_cum = 0.0
        loss_cum = 0.0
        cnt = 0

        # start epoch
        # NOTE(review): z (sampled latent codes) and li (inputs) are collected
        # below but never read in this method — presumably left over from
        # latent-space visualization; confirm before removing, since the
        # _reparametrise call also consumes random state
        z = []
        li = []
        for i, data in enumerate(loader):
            # transfer to suitable device
            x = tensor_to_device(data.float(), device)

            # forward prop (sigmoid maps the decoder output to [0, 1]);
            # self.mu and self.logvar are cached by the forward pass
            x_pred = torch.sigmoid(self(x))
            z.append(_reparametrise(self.mu, self.logvar).cpu().data.numpy())
            li.append(x.cpu().data.numpy())

            # compute loss: reconstruction term + beta-weighted KL term
            loss_rec = loss_rec_fn(x_pred, x)
            loss_kl = loss_kl_fn(self.mu, self.logvar)
            loss = loss_rec + self.beta * loss_kl
            loss_rec_cum += loss_rec.data.cpu().numpy()
            loss_kl_cum += loss_kl.data.cpu().numpy()
            loss_cum += loss.data.cpu().numpy()
            cnt += 1

        # don't forget to compute the average and print it
        loss_rec_avg = loss_rec_cum / cnt
        loss_kl_avg = loss_kl_cum / cnt
        loss_avg = loss_cum / cnt
        print_frm(
            'Epoch %5d - Average test loss rec: %.6f - Average test loss KL: %.6f - Average test loss: %.6f'
            % (epoch, loss_rec_avg, loss_kl_avg, loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            log_scalars([loss_rec_avg, loss_kl_avg, loss_avg],
                        ['test/' + s for s in ['loss-rec', 'loss-kl', 'loss']],
                        writer,
                        epoch=epoch)

            # log images if necessary (inputs/reconstructions of last batch)
            if write_images:
                log_images_2d([x, x_pred],
                              ['test/' + s for s in ['x', 'x_pred']],
                              writer,
                              epoch=epoch)

        return loss_avg
Exemple #9
0
    def train_epoch(self,
                    loader,
                    loss_rec_fn,
                    loss_kl_fn,
                    optimizer,
                    epoch,
                    augmenter=None,
                    print_stats=1,
                    writer=None,
                    write_images=False,
                    device=0):
        """
        Trains the (variational autoencoder) network for one epoch

        :param loader: dataloader
        :param loss_rec_fn: reconstruction loss function
        :param loss_kl_fn: kullback leibler loss function
        :param optimizer: optimizer for the loss function
        :param epoch: current epoch
        :param augmenter: data augmenter
        :param print_stats: frequency of printing statistics
        :param writer: summary writer
        :param write_images: frequency of writing images
        :param device: GPU device where the computations should occur
        :return: average training loss over the epoch
        """
        # move the network to the compute device and enable training mode
        module_to_device(self, device)
        self.train()

        # running sums of the individual loss terms over the epoch
        cum_rec, cum_kl, cum_total = 0.0, 0.0, 0.0
        n_batches = 0

        for batch_id, data in enumerate(loader):

            # transfer the batch to the compute device, augmenting if requested
            x = tensor_to_device(data.float(), device)
            if augmenter is not None:
                x = augmenter(x)

            # reset gradients, then run the forward pass
            # (sigmoid maps the decoder output to [0, 1])
            self.zero_grad()
            x_pred = torch.sigmoid(self(x))

            # evaluate the loss: reconstruction term + beta-weighted KL term
            loss_rec = loss_rec_fn(x_pred, x)
            loss_kl = loss_kl_fn(self.mu, self.logvar)
            loss = loss_rec + self.beta * loss_kl

            cum_rec += loss_rec.data.cpu().numpy()
            cum_kl += loss_kl.data.cpu().numpy()
            cum_total += loss.data.cpu().numpy()
            n_batches += 1

            # backward pass and one optimization step
            loss.backward()
            optimizer.step()

            # periodically report instantaneous batch statistics
            if batch_id % print_stats == 0:
                print_frm(
                    'Epoch %5d - Iteration %5d/%5d - Loss Rec: %.6f - Loss KL: %.6f - Loss: %.6f'
                    % (epoch, batch_id,
                       len(loader.dataset) / loader.batch_size, loss_rec,
                       loss_kl, loss))

        # epoch-level averages
        loss_rec_avg = cum_rec / n_batches
        loss_kl_avg = cum_kl / n_batches
        loss_avg = cum_total / n_batches
        print_frm(
            'Epoch %5d - Average train loss rec: %.6f - Average train loss KL: %.6f - Average train loss: %.6f'
            % (epoch, loss_rec_avg, loss_kl_avg, loss_avg))

        # tensorboard logging of scalars and (optionally) last-batch images
        if writer is not None:
            log_scalars(
                [loss_rec_avg, loss_kl_avg, loss_avg],
                ['train/' + s for s in ['loss-rec', 'loss-kl', 'loss']],
                writer,
                epoch=epoch)
            if write_images:
                log_images_2d([x, x_pred],
                              ['train/' + s for s in ['x', 'x_pred']],
                              writer,
                              epoch=epoch)

        return loss_avg
Exemple #10
0
def segment_multichannel_3d(data,
                            net,
                            input_shape,
                            in_channels=1,
                            batch_size=1,
                            step_size=None,
                            train=False,
                            track_progress=False,
                            device=0,
                            orientation=0,
                            normalization='unit'):
    """
    Segment a multichannel 3D image using a specific network

    :param data: 4D array (C, Z, Y, X) representing the multichannel 3D image
    :param net: image-to-image segmentation network
    :param input_shape: size of the inputs (either 2 or 3-tuple)
    :param in_channels: amount of subsequent slices that serve as input for the network (should be odd)
    :param batch_size: batch size for processing
    :param step_size: step size of the sliding window
    :param train: evaluate the network in training mode
    :param track_progress: optionally, for tracking progress with progress bar
    :param device: GPU device where the computations should occur
    :param orientation: orientation to perform segmentation: 0-Z, 1-Y, 2-X (only for 2D based segmentation)
    :param normalization: type of data normalization (unit, z or minmax)
    :return: the segmented image
    """

    # make sure we compute everything on the correct device
    module_to_device(net, device)

    # set the network in the correct mode
    if train:
        net.train()
    else:
        net.eval()

    # orient data if necessary
    data = _orient(data, orientation)

    # pad data if necessary
    data, pad_width = _pad(data, input_shape, in_channels)

    # 2D or 3D
    is2d = len(input_shape) == 2

    # get the amount of channels
    channels = data.shape[0]
    if is2d:
        channels = in_channels

    # initialize the step size
    step_size = _init_step_size(step_size, input_shape, is2d)

    # gaussian window for smooth block merging
    g_window = _init_gaussian_window(input_shape, is2d)

    # allocate space
    seg_cum = np.zeros((net.out_channels, *data.shape[1:]))
    counts_cum = np.zeros(data.shape[1:])

    # define sliding window
    sw = _init_sliding_window(data, step_size, input_shape, in_channels, is2d,
                              track_progress, normalization)

    # start prediction
    batch_counter = 0
    batch = np.zeros((batch_size, channels, *input_shape))
    positions = np.zeros((batch_size, 3), dtype=int)
    for (z, y, x, inputs) in sw:

        # fill batch
        batch[batch_counter, ...] = inputs
        positions[batch_counter, :] = [z, y, x]

        # increment batch counter
        batch_counter += 1

        # perform segmentation when a full batch is filled
        if batch_counter == batch_size:
            # process a single batch
            _process_batch(net, batch, device, seg_cum, counts_cum, g_window,
                           positions, batch_size, input_shape, in_channels,
                           is2d)

            # reset batch counter
            batch_counter = 0

    # BUG FIX: the trailing batch was previously processed unconditionally on
    # the full buffers, which re-accumulated stale entries from the previous
    # batch (or double-processed the last full batch). Only the batch_counter
    # leftover entries are processed now.
    if batch_counter > 0:
        _process_batch(net, batch[:batch_counter], device, seg_cum,
                       counts_cum, g_window, positions[:batch_counter],
                       batch_counter, input_shape, in_channels, is2d)

    # crop out the symmetric extension and compute segmentation
    data, seg_cum, counts_cum = _crop(data, seg_cum, counts_cum, pad_width)
    for c in range(net.out_channels):
        seg_cum[c, ...] = np.divide(seg_cum[c, ...], counts_cum)

    # reorient data to its original orientation
    data = _orient(data, orientation)
    seg_cum = _orient(seg_cum, orientation)

    return seg_cum
Exemple #11
0
def segment_multichannel_2d(data,
                            net,
                            input_shape,
                            batch_size=1,
                            step_size=None,
                            train=False,
                            track_progress=False,
                            device=0,
                            normalization='unit'):
    """
    Segment a multichannel 2D image using a specific network

    :param data: 3D array (C, Y, X) representing the multichannel 2D image
    :param net: image-to-image segmentation network
    :param input_shape: size of the inputs (2-tuple)
    :param batch_size: batch size for processing
    :param step_size: step size of the sliding window
    :param train: evaluate the network in training mode
    :param track_progress: optionally, for tracking progress with progress bar
    :param device: GPU device where the computations should occur
    :param normalization: type of data normalization (unit, z or minmax)
    :return: the segmented image
    """

    # make sure we compute everything on the correct device
    module_to_device(net, device)

    # set the network in the correct mode
    if train:
        net.train()
    else:
        net.eval()

    # pad data if necessary (a singleton Z axis is inserted for _pad)
    data, pad_width = _pad(data[:, np.newaxis, ...], input_shape, 1)
    data = data[:, 0, ...]

    # get the amount of channels
    channels = data.shape[0]

    # initialize the step size
    step_size = _init_step_size(step_size, input_shape, True)

    # gaussian window for smooth block merging
    g_window = _init_gaussian_window(input_shape, True)

    # allocate space
    seg_cum = np.zeros((net.out_channels, 1, *data.shape[1:]))
    counts_cum = np.zeros((1, *data.shape[1:]))

    # define sliding window
    sw = _init_sliding_window(data[np.newaxis, ...],
                              [channels, *step_size[1:]], input_shape,
                              channels, True, track_progress, normalization)

    # start prediction
    batch_counter = 0
    batch = np.zeros((batch_size, channels, *input_shape))
    positions = np.zeros((batch_size, 3), dtype=int)
    for (z, y, x, inputs) in sw:

        # fill batch
        batch[batch_counter, ...] = inputs
        positions[batch_counter, :] = [z, y, x]

        # increment batch counter
        batch_counter += 1

        # perform segmentation when a full batch is filled
        if batch_counter == batch_size:
            # process a single batch
            _process_batch(net, batch, device, seg_cum, counts_cum, g_window,
                           positions, batch_size, input_shape, 1, True)

            # reset batch counter
            batch_counter = 0

    # BUG FIX: the trailing batch was previously processed unconditionally on
    # the full buffers, which re-accumulated stale entries from the previous
    # batch (or double-processed the last full batch). Only the batch_counter
    # leftover entries are processed now.
    if batch_counter > 0:
        _process_batch(net, batch[:batch_counter], device, seg_cum,
                       counts_cum, g_window, positions[:batch_counter],
                       batch_counter, input_shape, 1, True)

    # crop out the symmetric extension and compute segmentation
    data, seg_cum, counts_cum = _crop(data[:, np.newaxis, ...], seg_cum,
                                      counts_cum, pad_width)
    for c in range(net.out_channels):
        seg_cum[c, ...] = np.divide(seg_cum[c, ...], counts_cum)
    seg_cum = seg_cum[:, 0, ...]

    return seg_cum