def train(net,
          train_loader,
          val_loader,
          callbacks,
          params,
          reduce_epochs=False):
    """Fit a model with the given loaders and manually specified callbacks.

    :param net: the module to fit
    :param train_loader: dataloader with training data
    :param val_loader: dataloader with validation data
    :param callbacks: list of callbacks handed to the trainer
    :param params: dictionary with the training configuration
    :param reduce_epochs: if True, train for half of the configured epochs
    :return: the fitted trainer; the best checkpoint is loaded back into net
    """
    start = time.perf_counter()

    # optionally halve the epoch budget
    if reduce_epochs:
        n_epochs = params['epochs'] // 2
    else:
        n_epochs = params['epochs']

    trainer = pl.Trainer(max_epochs=n_epochs,
                         gpus=params['gpus'],
                         accelerator=params['accelerator'],
                         default_root_dir=params['log_dir'],
                         flush_logs_every_n_steps=params['log_freq'],
                         log_every_n_steps=params['log_freq'],
                         callbacks=callbacks,
                         progress_bar_refresh_rate=params['log_refresh_rate'])
    trainer.fit(net, train_loader, val_loader)

    stop = time.perf_counter()
    print_frm('Elapsed training time: %d hours, %d minutes, %.2f seconds' %
              process_seconds(stop - start))

    # restore the weights of the best checkpoint into the network
    checkpoint = torch.load(trainer.checkpoint_callback.best_model_path)
    net.load_state_dict(checkpoint['state_dict'])

    return trainer
# Example #2
    def train_net(self, train_loader, test_loader, loss_fn, optimizer, epochs, scheduler=None, test_freq=1,
                  augmenter=None, print_stats=1, log_dir=None, write_images_freq=1, device=0):
        """
        Trains the network
        :param train_loader: data loader with training data
        :param test_loader: data loader with testing data
        :param loss_fn: loss function
        :param optimizer: optimizer for the loss function
        :param epochs: number of training epochs
        :param scheduler: optional scheduler for learning rate tuning
        :param test_freq: frequency of testing
        :param augmenter: data augmenter
        :param print_stats: frequency of logging statistics
        :param log_dir: logging directory; if None, no logs or checkpoints are written
        :param write_images_freq: frequency of writing images
        :param device: GPU device where the computations should occur
        """
        # log everything if necessary
        if log_dir is not None:
            writer = SummaryWriter(log_dir=log_dir)
        else:
            writer = None

        j_max = 0
        for epoch in range(epochs):

            print_frm('Epoch %5d/%5d' % (epoch, epochs))

            # train the model for one epoch
            self.train_epoch(loader=train_loader, loss_fn=loss_fn, optimizer=optimizer, epoch=epoch,
                             augmenter=augmenter, print_stats=print_stats, writer=writer,
                             write_images=epoch % write_images_freq == 0, device=device)

            # adjust learning rate if necessary
            if scheduler is not None:
                scheduler.step()

                # and keep track of the learning rate
                # (guarded: writer is None when log_dir is None)
                if writer is not None:
                    writer.add_scalar('learning_rate', float(scheduler.get_last_lr()[0]), epoch)

            # test the model for one epoch if necessary
            if epoch % test_freq == 0:
                j = self.test_epoch(loader=test_loader, loss_fn=loss_fn, epoch=epoch, writer=writer,
                                    write_images=True, device=device)

                # and save model if higher segmentation performance was obtained
                if j > j_max:
                    j_max = j
                    # checkpoints can only be written when a logging directory is available
                    if log_dir is not None:
                        torch.save(self, os.path.join(log_dir, 'best_checkpoint.pytorch'))

            # save model every epoch
            if log_dir is not None:
                torch.save(self, os.path.join(log_dir, 'checkpoint.pytorch'))

        if writer is not None:
            writer.close()
def validate(net, trainer, loader, params):
    """Validate a network that was trained with a specific trainer on a dataset.

    :param net: trained network
    :param trainer: trainer that was used to fit the network (provides log_dir)
    :param loader: dataloader whose dataset holds the test volume and labels
    :param params: dictionary with the validation configuration
    """
    start = time.perf_counter()

    # the first volume/label pair of the dataset serves as test data
    ds = loader.dataset
    test_data = ds.data[0]
    test_labels = ds.labels[0]

    validate_base(net,
                  test_data,
                  test_labels,
                  params['input_size'],
                  in_channels=params['in_channels'],
                  classes_of_interest=params['coi'],
                  batch_size=params['test_batch_size'],
                  write_dir=os.path.join(trainer.log_dir, 'best_predictions'),
                  val_file=os.path.join(trainer.log_dir, 'metrics.npy'),
                  device=params['gpus'][0])

    stop = time.perf_counter()
    print_frm('Elapsed testing time: %d hours, %d minutes, %.2f seconds' %
              process_seconds(stop - start))
# Example #4
def _select_subset(data,
                   labels,
                   n=1,
                   sz_size=(512, 512),
                   min_pos=0.01,
                   coi=(0, 1)):
    """Randomly draw n labeled samples that contain enough foreground pixels.

    :param data: 3D data volume
    :param labels: 3D label volume
    :param n: number of samples to extract
    :param sz_size: maximum (y, x) size of a sample
    :param min_pos: minimum fraction of positively labeled pixels per sample
    :param coi: classes of interest (class 0 is treated as background)
    :return: (data, labels) arrays of shape (n, y, x)
    """
    # output dimensions, clipped to the volume size
    z = int(n)
    y = int(min(sz_size[0], data.shape[1]))
    x = int(min(sz_size[1], data.shape[2]))
    data_ = np.zeros((z, y, x), dtype=data.dtype)
    labels_ = np.zeros((z, y, x), dtype=labels.dtype)

    # give up on a slice after this many attempts
    max_iters = 100

    # keep drawing random samples until each slice has enough foreground
    for z_ in range(z):
        attempt = 0
        done = False
        while not done:
            attempt += 1

            # draw a random sample
            data_[z_:z_ + 1], labels_[z_:z_ + 1] = sample_labeled_input(
                data, labels, (1, y, x))

            # count the positively labeled pixels, excluding background (c == 0)
            nnz = sum(
                np.sum(labels_[z_:z_ + 1] == c) for c in coi if c > 0)

            # accept when the foreground fraction is high enough
            if nnz / (y * x) > min_pos:
                print_frm('Sample %d successfully found!' % z_)
                done = True
            # or when the attempt budget is exhausted (keep the last random draw)
            if attempt > max_iters:
                print_frm(
                    'Maximum number of iterations reached.. selecting random sample'
                )
                done = True

    # select the data and return
    return data_, labels_
# Example #5
# echo the parsed command line arguments for reproducibility
print('[%s] Arguments: ' % (datetime.datetime.now()))
print('[%s] %s' % (datetime.datetime.now(), args))
# parse the comma-separated input size string into a list of ints
args.input_size = [int(item) for item in args.input_size.split(',')]
"""
Fix seed (for reproducibility)
"""
set_seed(args.seed)

# parameters
device = args.device  # computing device
n = args.n  # amount of samples to be extracted per domain
b = args.batch_size  # batch size for processing
input_size = args.input_size  # (y, x) size of the network inputs

# load the network
print_frm('Loading network')
model_file = args.net
net = _load_net(model_file, device)

# load reference patch
print_frm('Loading data')
data_file = args.data_file
df = json.load(open(data_file))  # NOTE(review): file handle never closed — consider a with-block
n_domains = len(df['raw'])
input_shape = (1, input_size[0], input_size[1])

# datasets, one per domain listed in the data file
dss = []
for d in range(n_domains):
    print_frm('Loading %s' % df['raw'][d])
    dss.append(
# Example #6
    def test_epoch(self,
                   loader_src,
                   loader_tar_ul,
                   loader_tar_l,
                   epoch,
                   writer=None,
                   write_images=False,
                   device=0):
        """
        Tests the network for one epoch on the source and labeled target data
        :param loader_src: source dataloader (labeled)
        :param loader_tar_ul: target dataloader (unlabeled) - NOTE(review): not used in this method
        :param loader_tar_l: target dataloader (labeled)
        :param epoch: current epoch
        :param writer: summary writer
        :param write_images: flag that specifies whether images should be logged
        :param device: GPU device where the computations should occur
        :return: average testing loss over the epoch
        """
        # perform testing on GPU/CPU (network in eval mode)
        # NOTE(review): no torch.no_grad() context, so gradients are tracked during testing
        module_to_device(self, device)
        self.eval()

        # keep track of the average loss during the epoch
        loss_seg_src_cum = 0.0
        loss_seg_tar_cum = 0.0
        total_loss_cum = 0.0
        cnt = 0

        # zip dataloaders
        # NOTE(review): when loader_tar_l is None, zip(loader_src) yields 1-tuples and
        # data[1] below raises IndexError — confirm this branch is ever taken
        if loader_tar_l is None:
            dl = zip(loader_src)
        else:
            dl = zip(loader_src, loader_tar_l)

        # start epoch
        y_preds = []
        ys = []
        time_start = datetime.datetime.now()
        for i, data in enumerate(dl):

            # transfer to suitable device
            x_src, y_src = tensor_to_device(data[0], device)
            x_tar_l, y_tar_l = tensor_to_device(data[1], device)
            x_src = x_src.float()
            x_tar_l = x_tar_l.float()
            y_src = y_src.long()
            y_tar_l = y_tar_l.long()

            # forward prop and compute loss
            y_src_pred = self(x_src)
            y_tar_l_pred = self(x_tar_l)
            loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
            loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
            total_loss = loss_seg_src + loss_seg_tar

            loss_seg_src_cum += loss_seg_src.data.cpu().numpy()
            loss_seg_tar_cum += loss_seg_tar.data.cpu().numpy()
            total_loss_cum += total_loss.data.cpu().numpy()
            cnt += 1

            # collect flattened target predictions and labels for metric computation
            for b in range(y_tar_l_pred.size(0)):
                y_preds.append(
                    F.softmax(y_tar_l_pred,
                              dim=1)[b, ...].view(y_tar_l_pred.size(1),
                                                  -1).data.cpu().numpy())
                ys.append(y_tar_l[b, 0, ...].flatten().cpu().numpy())

        # keep track of time
        runtime = datetime.datetime.now() - time_start
        seconds = runtime.total_seconds()
        hours = seconds // 3600
        minutes = (seconds - hours * 3600) // 60
        seconds = seconds - hours * 3600 - minutes * 60
        print_frm(
            'Epoch %5d - Runtime for testing: %d hours, %d minutes, %f seconds'
            % (epoch, hours, minutes, seconds))

        # prep for metric computation: per-class jaccard and accuracy metrics
        y_preds = np.concatenate(y_preds, axis=1)
        ys = np.concatenate(ys)
        js = np.asarray([
            jaccard((ys == i).astype(int), y_preds[i, :])
            for i in range(len(self.coi))
        ])
        ams = np.asarray([
            accuracy_metrics((ys == i).astype(int), y_preds[i, :])
            for i in range(len(self.coi))
        ])

        # don't forget to compute the average and print it
        loss_seg_src_avg = loss_seg_src_cum / cnt
        loss_seg_tar_avg = loss_seg_tar_cum / cnt
        total_loss_avg = total_loss_cum / cnt
        print(
            '[%s] Testing Epoch %4d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss: %.6f'
            % (datetime.datetime.now(), epoch, loss_seg_src_avg,
               loss_seg_tar_avg, total_loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            log_scalars([
                loss_seg_src_avg, loss_seg_tar_avg, total_loss_avg,
                np.mean(js, axis=0), *(np.mean(ams, axis=0))
            ], [
                'test/' + s for s in [
                    'loss-seg-src', 'loss-seg-tar', 'total-loss', 'jaccard',
                    'accuracy', 'balanced-accuracy', 'precision', 'recall',
                    'f-score'
                ]
            ],
                        writer,
                        epoch=epoch)

            # log images if necessary (foreground probabilities of the last batch)
            if write_images:
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
                y_tar_l_pred = F.softmax(y_tar_l_pred, dim=1)[:,
                                                              1:2, :, :].data
                log_images_2d([
                    x_src.data, y_src.data, y_src_pred, x_tar_l.data, y_tar_l,
                    y_tar_l_pred
                ], [
                    'test/' + s for s in [
                        'src/x', 'src/y', 'src/y-pred', 'tar/x-l', 'tar/y-l',
                        'tar/y-l-pred'
                    ]
                ],
                              writer,
                              epoch=epoch)

        return total_loss_avg
# Example #7
    def train_epoch(self,
                    loader_src,
                    loader_tar_ul,
                    loader_tar_l,
                    optimizer,
                    epoch,
                    augmenter=None,
                    print_stats=1,
                    writer=None,
                    write_images=False,
                    device=0):
        """
        Trains the network for one epoch
        :param loader_src: source dataloader (labeled)
        :param loader_tar_ul: target dataloader (unlabeled)
        :param loader_tar_l: target dataloader (labeled); optional semi-supervision
        :param optimizer: optimizer for the loss function
        :param epoch: current epoch
        :param augmenter: data augmenter
        :param print_stats: frequency of printing statistics
        :param writer: summary writer
        :param write_images: flag that specifies whether images should be logged
        :param device: GPU device where the computations should occur
        :return: average training loss over the epoch
        """
        # perform training on GPU/CPU
        module_to_device(self, device)
        self.train()

        # keep track of the average loss during the epoch
        loss_seg_src_cum = 0.0
        loss_seg_tar_cum = 0.0
        loss_rec_src_cum = 0.0
        loss_rec_tar_cum = 0.0
        loss_dc_x_cum = 0.0
        loss_dc_y_cum = 0.0
        total_loss_cum = 0.0
        cnt = 0

        # zip dataloaders (the labeled target loader is optional)
        if loader_tar_l is None:
            dl = zip(loader_src, loader_tar_ul)
        else:
            dl = zip(loader_src, loader_tar_ul, loader_tar_l)

        # start epoch
        time_start = datetime.datetime.now()
        for i, data in enumerate(dl):

            # transfer to suitable device
            data_src = tensor_to_device(data[0], device)
            x_tar_ul = tensor_to_device(data[1], device)
            if loader_tar_l is not None:
                data_tar_l = tensor_to_device(data[2], device)

            # augment if necessary
            # (the source and unlabeled-target augmentations are identical in both
            # branches; the labeled target batch is only augmented when available)
            if loader_tar_l is None:
                data_aug = (data_src[0], data_src[1])
                x_src, y_src = augment_samples(data_aug, augmenter=augmenter)
                data_aug = (x_tar_ul, x_tar_ul)
                x_tar_ul, _ = augment_samples(data_aug, augmenter=augmenter)
            else:
                data_aug = (data_src[0], data_src[1])
                x_src, y_src = augment_samples(data_aug, augmenter=augmenter)
                data_aug = (x_tar_ul, x_tar_ul)
                x_tar_ul, _ = augment_samples(data_aug, augmenter=augmenter)
                data_aug = (data_tar_l[0], data_tar_l[1])
                x_tar_l, y_tar_l = augment_samples(data_aug,
                                                   augmenter=augmenter)
                y_tar_l = get_labels(y_tar_l, coi=self.coi, dtype=int)
            y_src = get_labels(y_src, coi=self.coi, dtype=int)
            x_tar_ul = x_tar_ul.float()

            # zero the gradient buffers
            self.zero_grad()

            # get domain labels for domain confusion (0 = source, 1 = target)
            dom_labels_x = tensor_to_device(
                torch.zeros((x_src.size(0) + x_tar_ul.size(0))),
                device).long()
            dom_labels_x[x_src.size(0):] = 1
            dom_labels_y = tensor_to_device(
                torch.zeros((x_src.size(0) + x_tar_ul.size(0))),
                device).long()
            dom_labels_y[x_src.size(0):] = 1

            # check train mode and compute loss
            # (losses default to zero so that modes can accumulate only their own terms)
            loss_seg_src = torch.Tensor([0])
            loss_seg_tar = torch.Tensor([0])
            loss_rec_src = torch.Tensor([0])
            loss_rec_tar = torch.Tensor([0])
            loss_dc_x = torch.Tensor([0])
            loss_dc_y = torch.Tensor([0])
            if self.train_mode == RECONSTRUCTION:
                x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
                x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
                loss_rec_src = self.rec_loss(x_src_rec, x_src)
                # NOTE(review): argument order is (input, reconstruction) here but
                # (reconstruction, input) on the line above — harmless only if
                # rec_loss is symmetric (e.g. MSE); confirm
                loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
                loss_dc_x = self.dc_loss(
                    torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                    dom_labels_x)
                total_loss = loss_rec_src + loss_rec_tar + self.lambda_dc * loss_dc_x
            elif self.train_mode == SEGMENTATION:
                # switch between reconstructed and original inputs
                # (self.p is presumably the probability of using the original input — confirm)
                if np.random.rand() < self.p:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
                else:
                    x_src_rec, _ = self.forward_rec(x_src)
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                    dom_labels_y[:x_src.size(0)] = 1
                if np.random.rand() < self.p:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul)
                else:
                    x_tar_ul_rec, _ = self.forward_rec(x_tar_ul)
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul_rec)
                    dom_labels_y[x_src.size(0):] = 1
                loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
                loss_dc_y = self.dc_loss(
                    torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                    dom_labels_y)
                total_loss = loss_seg_src + self.lambda_dc * loss_dc_y
                # add supervised target loss when labeled target data is available
                if loader_tar_l is not None:
                    y_tar_l_pred, _ = self.forward_seg(x_tar_l)
                    loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0,
                                                                       ...])
                    total_loss = total_loss + loss_seg_tar
            else:
                # joint mode: reconstruction + segmentation + both domain-confusion terms
                x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
                if np.random.rand() < self.p:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
                else:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                    dom_labels_y[:x_src.size(0)] = 1
                x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
                if np.random.rand() < self.p:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul)
                else:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul_rec)
                    dom_labels_y[x_src.size(0):] = 1
                loss_rec_src = self.rec_loss(x_src_rec, x_src)
                # NOTE(review): same swapped argument order as in the reconstruction mode
                loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
                loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
                loss_dc_x = self.dc_loss(
                    torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                    dom_labels_x)
                loss_dc_y = self.dc_loss(
                    torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                    dom_labels_y)
                total_loss = loss_seg_src + self.lambda_rec * (loss_rec_src + loss_rec_tar) + \
                             self.lambda_dc * (loss_dc_x + loss_dc_y)
                if loader_tar_l is not None:
                    _, y_tar_l_pred, _, y_tar_l_pred_dom = self(x_tar_l)
                    loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0,
                                                                       ...])
                    total_loss = total_loss + loss_seg_tar

            loss_seg_src_cum += loss_seg_src.data.cpu().numpy()
            loss_seg_tar_cum += loss_seg_tar.data.cpu().numpy()
            loss_rec_src_cum += loss_rec_src.data.cpu().numpy()
            loss_rec_tar_cum += loss_rec_tar.data.cpu().numpy()
            loss_dc_x_cum += loss_dc_x.data.cpu().numpy()
            loss_dc_y_cum += loss_dc_y.data.cpu().numpy()
            total_loss_cum += total_loss.data.cpu().numpy()
            cnt += 1

            # backward prop
            total_loss.backward()

            # apply one step in the optimization
            optimizer.step()

            # print statistics if necessary
            if i % print_stats == 0:
                print(
                    '[%s] Epoch %5d - Iteration %5d/%5d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss rec src: %.6f - Loss rec tar: %.6f - Loss DCX: %.6f - Loss DCY: %.6f - Loss: %.6f'
                    % (datetime.datetime.now(), epoch, i,
                       len(loader_src.dataset) / loader_src.batch_size,
                       loss_seg_src_cum / cnt, loss_seg_tar_cum / cnt,
                       loss_rec_src_cum / cnt, loss_rec_tar_cum / cnt,
                       loss_dc_x_cum / cnt, loss_dc_y_cum / cnt,
                       total_loss_cum / cnt))

        # keep track of time
        runtime = datetime.datetime.now() - time_start
        seconds = runtime.total_seconds()
        hours = seconds // 3600
        minutes = (seconds - hours * 3600) // 60
        seconds = seconds - hours * 3600 - minutes * 60
        print_frm(
            'Epoch %5d - Runtime for training: %d hours, %d minutes, %f seconds'
            % (epoch, hours, minutes, seconds))

        # don't forget to compute the average and print it
        loss_seg_src_avg = loss_seg_src_cum / cnt
        loss_seg_tar_avg = loss_seg_tar_cum / cnt
        loss_rec_src_avg = loss_rec_src_cum / cnt
        loss_rec_tar_avg = loss_rec_tar_cum / cnt
        loss_dc_x_avg = loss_dc_x_cum / cnt
        loss_dc_y_avg = loss_dc_y_cum / cnt
        total_loss_avg = total_loss_cum / cnt
        print(
            '[%s] Training Epoch %4d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss rec src: %.6f - Loss rec tar: %.6f - Loss DCX: %.6f - Loss DCY: %.6f - Loss: %.6f'
            % (datetime.datetime.now(), epoch, loss_seg_src_avg,
               loss_seg_tar_avg, loss_rec_src_avg, loss_rec_tar_avg,
               loss_dc_x_avg, loss_dc_y_avg, total_loss_avg))

        # log everything
        if writer is not None:

            # always log scalars (only the terms relevant to the current train mode)
            if self.train_mode == RECONSTRUCTION:
                log_scalars(
                    [loss_rec_src_avg, loss_rec_tar_avg, loss_dc_x_avg], [
                        'train/' + s
                        for s in ['loss-rec-src', 'loss-rec-tar', 'loss-dc-x']
                    ],
                    writer,
                    epoch=epoch)
            elif self.train_mode == SEGMENTATION:
                log_scalars(
                    [loss_seg_src_avg, loss_seg_tar_avg, loss_dc_y_avg], [
                        'train/' + s
                        for s in ['loss-seg-src', 'loss-seg-tar', 'loss-dc-y']
                    ],
                    writer,
                    epoch=epoch)
            else:
                log_scalars([
                    loss_seg_src_avg, loss_seg_tar_avg, loss_rec_src_avg,
                    loss_rec_tar_avg, loss_dc_x_avg, loss_dc_y_avg
                ], [
                    'train/' + s for s in [
                        'loss-seg-src', 'loss-seg-tar', 'loss-rec-src',
                        'loss-rec-tar', 'loss-dc-x', 'loss-dc-y'
                    ]
                ],
                            writer,
                            epoch=epoch)
            log_scalars([total_loss_avg],
                        ['train/' + s for s in ['total-loss']],
                        writer,
                        epoch=epoch)

            # log images if necessary (taken from the last processed batch)
            if write_images:
                log_images_2d([x_src.data], ['train/' + s for s in ['src/x']],
                              writer,
                              epoch=epoch)
                if self.train_mode == RECONSTRUCTION:
                    log_images_2d(
                        [x_src_rec.data, x_tar_ul.data, x_tar_ul_rec.data], [
                            'train/' + s
                            for s in ['src/x-rec', 'tar/x-ul', 'tar/x-ul-rec']
                        ],
                        writer,
                        epoch=epoch)
                elif self.train_mode == SEGMENTATION:
                    y_src_pred = F.softmax(y_src_pred, dim=1)[:,
                                                              1:2, :, :].data
                    log_images_2d(
                        [y_src.data, y_src_pred],
                        ['train/' + s for s in ['src/y', 'src/y-pred']],
                        writer,
                        epoch=epoch)
                    if loader_tar_l is not None:
                        y_tar_l_pred = F.softmax(y_tar_l_pred,
                                                 dim=1)[:, 1:2, :, :].data
                        log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred], [
                            'train/' + s
                            for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
                        ],
                                      writer,
                                      epoch=epoch)
                else:
                    y_src_pred = F.softmax(y_src_pred, dim=1)[:,
                                                              1:2, :, :].data
                    log_images_2d([
                        x_src_rec.data, y_src.data, y_src_pred, x_tar_ul.data,
                        x_tar_ul_rec.data
                    ], [
                        'train/' + s for s in [
                            'src/x-rec', 'src/y', 'src/y-pred', 'tar/x-ul',
                            'tar/x-ul-rec'
                        ]
                    ],
                                  writer,
                                  epoch=epoch)
                    if loader_tar_l is not None:
                        y_tar_l_pred = F.softmax(y_tar_l_pred,
                                                 dim=1)[:, 1:2, :, :].data
                        log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred], [
                            'train/' + s
                            for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
                        ],
                                      writer,
                                      epoch=epoch)

        return total_loss_avg
# Example #8
    def test_epoch(self,
                   loader,
                   epoch,
                   writer=None,
                   write_images=False,
                   device=0):
        """
        Tests the network for one epoch
        :param loader: dataloader
        :param epoch: current epoch
        :param writer: summary writer
        :param write_images: flag that specifies whether images should be logged
        :param device: GPU device where the computations should occur
        :return: average testing loss over the epoch
        """
        # move the network to the selected device and switch to evaluation mode
        module_to_device(self, device)
        self.eval()

        # running sums for the epoch averages
        rec_sum = 0.0
        dc_sum = 0.0
        total_sum = 0.0
        n_batches = 0

        # run over the full test set
        for batch_id, batch in enumerate(loader):
            # unpack the batch and transfer it to the device
            x, dom = batch
            x = tensor_to_device(x.float(), device)
            dom = tensor_to_device(dom.long(), device)

            # forward prop: reconstruction and domain prediction
            x_pred, dom_pred = self(x)
            x_pred = torch.sigmoid(x_pred)

            # reconstruction loss plus weighted domain-confusion loss
            loss_rec = self.loss_rec_fn(x_pred, x)
            loss_dc = self.loss_dc_fn(dom_pred, dom)
            loss = loss_rec + self.lambda_reg * loss_dc

            rec_sum += loss_rec.data.cpu().numpy()
            dc_sum += loss_dc.data.cpu().numpy()
            total_sum += loss.data.cpu().numpy()
            n_batches += 1

        # epoch averages
        loss_rec_avg = rec_sum / n_batches
        loss_dc_avg = dc_sum / n_batches
        loss_avg = total_sum / n_batches
        print_frm(
            'Epoch %5d - Average test loss rec: %.6f - Average test loss DC: %.6f - Average test loss: %.6f'
            % (epoch, loss_rec_avg, loss_dc_avg, loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            log_scalars([loss_rec_avg, loss_dc_avg, loss_avg],
                        ['test/' + s for s in ['loss-rec', 'loss-dc', 'loss']],
                        writer,
                        epoch=epoch)

            # log images if necessary (last processed batch)
            if write_images:
                log_images_2d([x, x_pred],
                              ['test/' + s for s in ['x', 'x_pred']],
                              writer,
                              epoch=epoch)

        return loss_avg
# Example #9
    def test_epoch(self, loader, loss_fn, epoch, writer=None, write_images=False, device=0):
        """
        Tests the network for one epoch
        :param loader: dataloader
        :param loss_fn: loss function
        :param epoch: current epoch
        :param writer: summary writer
        :param write_images: flag that specifies whether images should be logged
        :param device: GPU device where the computations should occur
        :return: mean jaccard score over the classes of interest
        """
        # perform testing on GPU/CPU
        module_to_device(self, device)
        self.eval()

        # keep track of the average loss and metrics during the epoch
        loss_cum = 0.0
        cnt = 0

        # test loss
        y_preds = []
        ys = []
        ys_ = []
        time_start = datetime.datetime.now()
        for i, data in enumerate(loader):

            # get the inputs and transfer to suitable device
            x, y = tensor_to_device(data, device)
            y_ = get_unlabeled(y)
            x = x.float()
            y = get_labels(y, coi=self.coi, dtype=int)
            y_ = get_labels(y_, coi=[0, 255], dtype=bool)

            # forward prop
            y_pred = self(x)

            # compute loss, masking out the unlabeled voxels
            loss = loss_fn(y_pred, y[:, 0, ...], mask=~y_)
            loss_cum += loss.data.cpu().numpy()
            cnt += 1

            # collect flattened class probabilities, labels and unlabeled masks
            for b in range(y_pred.size(0)):
                y_preds.append(F.softmax(y_pred, dim=1)[b, ...].view(y_pred.size(1), -1).data.cpu().numpy())
                ys.append(y[b, 0, ...].flatten().cpu().numpy())
                ys_.append(y_[b, 0, ...].flatten().cpu().numpy())

        # keep track of time
        runtime = datetime.datetime.now() - time_start
        seconds = runtime.total_seconds()
        hours = seconds // 3600
        minutes = (seconds - hours * 3600) // 60
        seconds = seconds - hours * 3600 - minutes * 60
        print_frm(
            'Epoch %5d - Runtime for testing: %d hours, %d minutes, %f seconds' % (epoch, hours, minutes, seconds))

        # prep for metric computation: only labeled voxels (weights w) contribute
        y_preds = np.concatenate(y_preds, axis=1)
        ys = np.concatenate(ys)
        ys_ = np.concatenate(ys_)
        w = (1 - ys_).astype(bool)
        js = np.asarray([jaccard((ys == i).astype(int), y_preds[i, :], w=w) for i in range(len(self.coi))])
        ams = np.asarray([accuracy_metrics((ys == i).astype(int), y_preds[i, :], w=w) for i in range(len(self.coi))])

        # don't forget to compute the average and print it
        loss_avg = loss_cum / cnt
        print_frm('Epoch %5d - Average test loss: %.6f' % (epoch, loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            log_scalars([loss_avg, np.mean(js, axis=0), *(np.mean(ams, axis=0))], ['test/' + s for s in
                                                                                   ['loss-seg', 'jaccard', 'accuracy',
                                                                                    'balanced-accuracy', 'precision',
                                                                                    'recall', 'f-score']], writer,
                        epoch=epoch)

            # log images if necessary
            if write_images:
                log_images_3d([x], ['test/' + s for s in ['x']], writer, epoch=epoch)
                y_pred = F.softmax(y_pred, dim=1)
                for i, c in enumerate(self.coi):
                    if i != 0:  # skip background class
                        y_p = y_pred[:, i:i + 1, ...].data
                        y_t = (y == i).long()
                        # fixed: tag names previously contained a stray ')' (e.g. 'y_class_1)')
                        log_images_3d([y_t, y_p],
                                      ['test/' + s for s in ['y_class_%d' % c, 'y_pred_class_%d' % c]], writer,
                                      epoch=epoch)

        return np.mean(js)
# Example #10
from neuralnets.util.io import print_frm
from neuralnets.util.tools import set_seed

from util.tools import parse_params, get_dataloaders
from networks.factory import generate_model
from train.base import train, validate

from multiprocessing import freeze_support

if __name__ == '__main__':
    freeze_support()  # required for multiprocessing in frozen Windows executables
    """
        Parse all the arguments
    """
    print_frm('Parsing arguments')
    parser = argparse.ArgumentParser()
    # path to the YAML file that holds all training parameters
    parser.add_argument("--config",
                        "-c",
                        help="Path to the configuration file",
                        type=str,
                        default='train_supervised.yaml')
    # optionally remove intermediate checkpoints after training
    parser.add_argument(
        "--clean-up",
        help="Boolean flag that specifies cleaning of the checkpoints",
        action='store_true',
        default=False)
    args = parser.parse_args()
    # NOTE(review): FullLoader is fine for trusted configuration files only
    with open(args.config) as file:
        params = parse_params(yaml.load(file, Loader=yaml.FullLoader))
    """
Example #11
0
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from neuralnets.data.datasets import StronglyLabeledVolumeDataset
from neuralnets.networks.unet import UNet2D
from neuralnets.util.augmentation import *
from neuralnets.util.io import print_frm
from neuralnets.util.losses import get_loss_function
from neuralnets.util.tools import set_seed
from neuralnets.util.validation import validate
"""
    Parse all the arguments
"""
print_frm('Parsing arguments')
parser = argparse.ArgumentParser()

# logging parameters
# seed for all random number generators (reproducibility)
parser.add_argument("--seed",
                    help="Seed for randomization",
                    type=int,
                    default=0)
# GPU index used for all computations
parser.add_argument("--device",
                    help="GPU device for computations",
                    type=int,
                    default=0)
# directory where logs and checkpoints are written
parser.add_argument("--log_dir",
                    help="Logging directory",
                    type=str,
                    default="unet_2d")
Example #12
0
def get_dataloaders(params,
                    domain=None,
                    domain_labels_available=1.0,
                    supervised=False):
    """
    Build train, validation and test dataloaders from a configuration dict.

    :param params: configuration dictionary; the keys read below ('input_size',
                   'augmentation', 'coi', 'len_epoch', 'in_channels', 'type',
                   'train_batch_size', 'test_batch_size', 'num_workers', 'seed',
                   'tar_labels_available' and per-domain sub-dicts such as
                   params['src'] / params['tar']) define the expected schema —
                   confirm against the YAML configuration files
    :param domain: optional domain key; if None, a joint source/target setup is built
    :param domain_labels_available: fraction of labels available (single-domain case)
    :param supervised: if True (single-domain case), read data/labels/split settings
                       from the top level of params instead of params[domain]
    :return: (train_loader, val_loader, test_loader)
    """

    input_shape = (1, *(params['input_size']))
    # augmentation is applied on the training set via the dataset's transform
    transform = get_transforms(params['augmentation'], coi=params['coi'])
    print_frm('Applying data augmentation! Specifically %s' %
              str(params['augmentation']))

    if domain is None:

        # joint source/target setup: train and validation sample from both volumes,
        # testing is done on the held-out part of the target volume only
        split_src = params['src']['train_val_test_split']
        split_tar = params['tar']['train_val_test_split']
        print_frm('Train data... ')
        train = LabeledVolumeDataset(
            (params['src']['data'], params['tar']['data']),
            (params['src']['labels'], params['tar']['labels']),
            len_epoch=params['len_epoch'],
            input_shape=input_shape,
            in_channels=params['in_channels'],
            type=params['type'],
            batch_size=params['train_batch_size'],
            transform=transform,
            range_split=((0, split_src[0]), (0, split_tar[0])),
            coi=params['coi'],
            range_dir=(params['src']['split_orientation'],
                       params['tar']['split_orientation']),
            partial_labels=(1, params['tar_labels_available']),
            seed=params['seed'])
        print_frm('Validation data...')
        # NOTE(review): unlike the single-domain branch below, no transform is
        # passed here — confirm validation is meant to run without augmentation
        val = LabeledVolumeDataset(
            (params['src']['data'], params['tar']['data']),
            (params['src']['labels'], params['tar']['labels']),
            len_epoch=params['len_epoch'],
            input_shape=input_shape,
            in_channels=params['in_channels'],
            type=params['type'],
            batch_size=params['test_batch_size'],
            coi=params['coi'],
            range_split=((split_src[0], split_src[1]), (split_tar[0],
                                                        split_tar[1])),
            range_dir=(params['src']['split_orientation'],
                       params['tar']['split_orientation']),
            partial_labels=(1, params['tar_labels_available']),
            seed=params['seed'])
        print_frm('Test data...')
        test = LabeledSlidingWindowDataset(
            params['tar']['data'],
            params['tar']['labels'],
            in_channels=params['in_channels'],
            type=params['type'],
            batch_size=params['test_batch_size'],
            range_split=(split_tar[1], 1),
            range_dir=params['tar']['split_orientation'],
            coi=params['coi'])

        print_frm('Train volume shape: %s (source) - %s (target)' %
                  (str(train.data[0].shape), str(train.data[1].shape)))
        print_frm(
            'Available target labels for training: %.1f (i.e. %.2f MV)' %
            (params['tar_labels_available'] * 100, np.prod(train.data[1].shape)
             * params['tar_labels_available'] / 1000 / 1000))
        print_frm('Validation volume shape: %s (source) - %s (target)' %
                  (str(val.data[0].shape), str(val.data[1].shape)))
        print_frm('Test volume shape: %s (target)' % str(test.data[0].shape))

    else:

        # single-domain setup: settings come either from the top level of params
        # (supervised) or from the sub-dict of the requested domain
        split = params['train_val_test_split'] if supervised else params[
            domain]['train_val_test_split']
        data = params['data'] if supervised else params[domain]['data']
        labels = params['labels'] if supervised else params[domain]['labels']
        range_dir = params['split_orientation'] if supervised else params[
            domain]['split_orientation']
        print_frm('Train data...')
        train = LabeledVolumeDataset(data,
                                     labels,
                                     len_epoch=params['len_epoch'],
                                     input_shape=input_shape,
                                     in_channels=params['in_channels'],
                                     type=params['type'],
                                     batch_size=params['train_batch_size'],
                                     transform=transform,
                                     range_split=(0, split[0]),
                                     range_dir=range_dir,
                                     partial_labels=domain_labels_available,
                                     seed=params['seed'],
                                     coi=params['coi'])
        print_frm('Validation data...')
        val = LabeledVolumeDataset(data,
                                   labels,
                                   len_epoch=params['len_epoch'],
                                   input_shape=input_shape,
                                   in_channels=params['in_channels'],
                                   type=params['type'],
                                   batch_size=params['test_batch_size'],
                                   transform=transform,
                                   range_split=(split[0], split[1]),
                                   range_dir=range_dir,
                                   coi=params['coi'],
                                   partial_labels=domain_labels_available,
                                   seed=params['seed'])
        print_frm('Test data...')
        test = LabeledSlidingWindowDataset(
            data,
            labels,
            in_channels=params['in_channels'],
            type=params['type'],
            batch_size=params['test_batch_size'],
            transform=transform,
            range_split=(split[1], 1),
            range_dir=range_dir,
            coi=params['coi'])

        print_frm('Train volume shape: %s' % str(train.data[0].shape))
        print_frm(
            'Available %s labels for training: %d%% (i.e. %.2f MV)' %
            (domain, domain_labels_available * 100, np.prod(
                train.data[0].shape) * domain_labels_available / 1000 / 1000))
        print_frm('Validation volume shape: %s' % str(val.data[0].shape))
        print_frm('Test volume shape: %s' % str(test.data[0].shape))

    # wrap the datasets in dataloaders; pin_memory speeds up host-to-GPU transfer
    train_loader = DataLoader(train,
                              batch_size=params['train_batch_size'],
                              num_workers=params['num_workers'],
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=params['test_batch_size'],
                            num_workers=params['num_workers'],
                            pin_memory=True)
    test_loader = DataLoader(test,
                             batch_size=params['test_batch_size'],
                             num_workers=params['num_workers'],
                             pin_memory=True)

    return train_loader, val_loader, test_loader
Example #13
0
def mv(source, target):
    """Move *source* to *target*, logging the operation first."""
    message = '    Moving %s -> %s' % (source, target)
    print_frm(message)
    shutil.move(source, target)
Example #14
0
def cp(source, target):
    """Copy the file *source* to *target*, logging the operation first."""
    message = '    Copying %s -> %s' % (source, target)
    print_frm(message)
    shutil.copyfile(source, target)
Example #15
0
def rmdir(dir):
    """Recursively delete the directory *dir* (errors are ignored), logging first."""
    message = '    Removing %s' % dir
    print_frm(message)
    shutil.rmtree(dir, ignore_errors=True)
Example #16
0
    def __init__(self,
                 data_path,
                 input_shape,
                 split_orientation='z',
                 split_location=0.50,
                 scaling=None,
                 len_epoch=1000,
                 types=('tif3d',),
                 sampling_mode='uniform',
                 in_channels=1,
                 orientations=(0, ),
                 batch_size=1,
                 dtype='uint8',
                 norm_type='unit',
                 train=True,
                 available=-1):
        """
        Load one or more volumes, split each into a train/test part along the
        requested orientation and keep the selected part in memory.

        :param data_path: sequence of dataset paths
        :param input_shape: shape of the sampled inputs
        :param split_orientation: per-dataset split axis ('z', 'y' or 'x'); a
                                  single string is broadcast to all datasets
        :param split_location: per-dataset split fraction in [0, 1]; a single
                               number is broadcast to all datasets
        :param scaling: optional per-axis rescaling factors for the volumes
        :param len_epoch: number of samples per epoch
        :param types: per-dataset volume file type; a single entry is broadcast
        :param sampling_mode: sampling strategy identifier
        :param in_channels: number of input channels
        :param orientations: orientations used for sampling
        :param batch_size: batch size for sampling
        :param dtype: data type used to read the volumes
        :param norm_type: normalization type
        :param train: select the training part (True) or testing part (False) of each split
        :param available: kept for interface compatibility; currently unused here
        """
        self.data_path = data_path
        self.input_shape = input_shape
        self.scaling = scaling
        self.len_epoch = len_epoch
        self.sampling_mode = sampling_mode
        self.in_channels = in_channels
        self.orientations = orientations
        self.orientation = 0
        self.k = 0
        self.batch_size = batch_size
        self.norm_type = norm_type

        # broadcast scalar settings to one entry per dataset so callers may pass
        # either a single value or a sequence (previously these were indexed with
        # [k] and crashed on the scalar defaults when loading multiple datasets)
        n_datasets = len(data_path)
        if isinstance(split_orientation, str):
            split_orientation = [split_orientation] * n_datasets
        if not isinstance(split_location, (list, tuple, np.ndarray)):
            split_location = [split_location] * n_datasets
        if isinstance(types, str):
            types = [types] * n_datasets
        elif len(types) == 1 and n_datasets > 1:
            types = list(types) * n_datasets

        # load the data
        self.data = []
        self.data_sizes = []
        for k, path in enumerate(data_path):
            # report 1-based progress (the original printed 0-based indices,
            # e.g. "Loading dataset 0/2")
            print_frm('Loading dataset %d/%d: %s' % (k + 1, len(data_path), path))

            # axis index of the split orientation: z -> 0, y -> 1, x -> 2
            d = 0 if split_orientation[
                k] == 'z' else 1 if split_orientation[k] == 'y' else 2
            if split_orientation[k] == 'z':
                # z-splits can be done while reading: only load the required slices
                split = int(len(os.listdir(path)) * split_location[k])
                start = 0 if train else split
                stop = split if train else -1
                data = read_volume(path,
                                   type=types[k],
                                   dtype=dtype,
                                   start=start,
                                   stop=stop)
            else:
                # y/x splits require the full volume to be read before slicing
                data = read_volume(path, type=types[k], dtype=dtype)
                split = int(data.shape[d] * split_location[k])
                if split_orientation[k] == 'y':
                    data = data[:, :split, :] if train else data[:, split:, :]
                else:
                    data = data[:, :, :split] if train else data[:, :, split:]

            # rescale the dataset if necessary
            if scaling is not None:
                target_size = np.asarray(np.multiply(data.shape, scaling),
                                         dtype=int)
                data = F.interpolate(torch.Tensor(data[np.newaxis, np.newaxis,
                                                       ...]),
                                     size=tuple(target_size),
                                     mode='area')[0, 0, ...].numpy()

            self.data.append(data)
            self.data_sizes.append(data.size)

        # normalize the dataset sizes into sampling proportions
        self.data_sizes = np.array(self.data_sizes)
        self.data_sizes = self.data_sizes / np.sum(self.data_sizes)
from neuralnets.util.io import print_frm, read_pngseq
from neuralnets.util.tools import set_seed
from neuralnets.util.validation import segment_read, segment_ram

from util.tools import parse_params, process_seconds
from networks.factory import generate_model

from multiprocessing import freeze_support

if __name__ == '__main__':
    freeze_support()  # required for multiprocessing in frozen Windows executables

    """
        Parse all the arguments
    """
    print_frm('Parsing arguments')
    parser = argparse.ArgumentParser()
    # network configuration and trained parameter files
    parser.add_argument("--config", "-c", help="Path to the network configuration file", type=str,
                        default='../train_supervised.yaml')
    parser.add_argument("--model", "-m", help="Path to the network parameters", type=str, required=True)
    # input dataset, processing mode and output location of the segmentation
    parser.add_argument("--dataset", "-d", help="Path to the dataset that needs to be segmented", type=str,
                        required=True)
    parser.add_argument("--block_wise", "-bw", help="Flag that specifies to compute block wise or not",
                        action='store_true', default=False)
    parser.add_argument("--output", "-o", help="Path to store the output segmentation", type=str, required=True)
    parser.add_argument("--gpu", "-g", help="GPU device for computations", type=int, default=0)
    args = parser.parse_args()
    # NOTE(review): FullLoader is fine for trusted configuration files only
    with open(args.config) as file:
        params = parse_params(yaml.load(file, Loader=yaml.FullLoader))

    """
Example #18
0
args.input_size = [int(item) for item in args.input_size.split(',')]
"""
Fix seed (for reproducibility)
"""
set_seed(args.seed)

# parameters
domain_id = args.domain_id  # id of the domain where a reference patch should be selected
device = args.device  # computing device
n = args.n  # amount of samples to be extracted per domain
b = args.batch_size  # batch size for processing
input_size = args.input_size
k = args.k  # amount of closest samples to be extracted

# load the network
print_frm('Loading network')
model_file = args.net
net = _load_net(model_file, device)

# load reference patch
print_frm('Loading data')
data_file = args.data_file
df = json.load(open(data_file))
n_domains = len(df['raw'])
input_shape = (1, input_size[0], input_size[1])
dataset_ref = UnlabeledVolumeDataset(
    df['raw'][domain_id],
    split_orientation=df['split-orientation'][domain_id],
    split_location=df['split-location'][domain_id],
    input_shape=input_shape,
    type=df['types'][domain_id],
Example #19
0
from neuralnets.util.augmentation import *
from neuralnets.util.io import print_frm, mkdir
from neuralnets.util.tools import set_seed

from util.tools import parse_params
from networks.factory import generate_model
from train.base import train, validate

from multiprocessing import freeze_support

if __name__ == '__main__':
    freeze_support()  # required for multiprocessing in frozen Windows executables
    """
        Parse all the arguments
    """
    print_frm('Parsing arguments')
    parser = argparse.ArgumentParser()
    # path to the YAML file that holds all training parameters
    parser.add_argument("--config",
                        "-c",
                        help="Path to the configuration file",
                        type=str,
                        default='clem1.yaml')
    # optionally remove intermediate checkpoints after training
    parser.add_argument(
        "--clean-up",
        help="Boolean flag that specifies cleaning of the checkpoints",
        action='store_true',
        default=False)
    args = parser.parse_args()
    # NOTE(review): FullLoader is fine for trusted configuration files only
    with open(args.config) as file:
        params = parse_params(yaml.load(file, Loader=yaml.FullLoader))
    """
Example #20
0
from util.tools import parse_params, get_dataloaders, rmdir, mv, cp
from networks.factory import generate_model
from train.base import train, validate

from multiprocessing import freeze_support



if __name__ == '__main__':
    freeze_support()  # required for multiprocessing in frozen Windows executables

    """
        Parse all the arguments
    """
    print_frm('Parsing arguments')
    parser = argparse.ArgumentParser()
    # path to the YAML file that holds all training parameters
    parser.add_argument("--config", "-c", help="Path to the configuration file", type=str,
                        default='train_semi_supervised.yaml')
    # optionally remove intermediate checkpoints after training
    parser.add_argument("--clean-up", help="Boolean flag that specifies cleaning of the checkpoints",
                        action='store_true', default=False)
    args = parser.parse_args()
    # NOTE(review): FullLoader is fine for trusted configuration files only
    with open(args.config) as file:
        params = parse_params(yaml.load(file, Loader=yaml.FullLoader))

    """
    Fix seed (for reproducibility)
    """
    set_seed(params['seed'])

    """
Example #21
0
    def train_epoch(self, loader, loss_fn, optimizer, epoch, augmenter=None, print_stats=1, writer=None,
                    write_images=False, device=0):
        """
        Trains the network for one epoch
        :param loader: dataloader over (input, label) batches
        :param loss_fn: loss function; called as loss_fn(y_pred, y, mask=...) so it must support masking
        :param optimizer: optimizer for the loss function
        :param epoch: current epoch (used for logging only)
        :param augmenter: data augmenter applied jointly to inputs and labels (may be None)
        :param print_stats: frequency (in iterations) of printing statistics
        :param writer: summary writer; if None, nothing is logged
        :param write_images: if True, additionally log input/label/prediction volumes of the last batch
        :param device: GPU device where the computations should occur
        :return: average training loss over the epoch
        """
        # perform training on GPU/CPU
        module_to_device(self, device)
        self.train()

        # keep track of the average loss during the epoch
        loss_cum = 0.0
        cnt = 0

        # start epoch
        time_start = datetime.datetime.now()
        for i, data in enumerate(loader):

            # transfer to suitable device and get labels
            data = tensor_to_device(data, device)
            y = get_labels(data[1], coi=self.coi, dtype=float)

            # filter out unlabeled pixels and include them in augmentation
            # (y_ marks unlabeled voxels; it is appended so it is augmented with the sample)
            y_ = get_unlabeled(data[1], dtype=float)
            data.append(y)
            data.append(y_)

            # perform augmentation and transform to appropriate type
            x, _, y, y_ = augment_samples(data, augmenter=augmenter)
            rep = 0
            rep_max = 10
            # re-augment (up to rep_max times) while the augmentation removed all
            # labeled voxels; after rep_max failed tries fall back to the
            # un-augmented sample so the iteration still has labels
            while (
                    1 - y_).sum() == 0 and rep < rep_max:  # make sure labels are not lost in augmentation, otherwise augment new sample
                x, _, y, y_ = augment_samples(data, augmenter=augmenter)
                rep += 1
                if rep == rep_max:
                    x, _, y, y_ = data
            x = x.float()
            y = y.round().long()
            # clean labels if necessary (due to augmentations)
            if len(self.coi) > 2:
                y = clean_labels(y, len(self.coi))
            y_ = y_.bool()

            # zero the gradient buffers
            self.zero_grad()

            # forward prop
            y_pred = self(x)

            # compute loss; ~y_ masks the loss to labeled voxels only
            loss = loss_fn(y_pred, y[:, 0, ...], mask=~y_)
            loss_cum += loss.data.cpu().numpy()
            cnt += 1

            # backward prop
            loss.backward()

            # apply one step in the optimization
            optimizer.step()

            # print statistics if necessary
            if i % print_stats == 0:
                print_frm('Epoch %5d - Iteration %5d/%5d - Loss: %.6f' % (
                    epoch, i, len(loader.dataset) / loader.batch_size, loss))

        # keep track of time
        runtime = datetime.datetime.now() - time_start
        seconds = runtime.total_seconds()
        hours = seconds // 3600
        minutes = (seconds - hours * 3600) // 60
        seconds = seconds - hours * 3600 - minutes * 60
        print_frm(
            'Epoch %5d - Runtime for training: %d hours, %d minutes, %f seconds' % (epoch, hours, minutes, seconds))

        # don't forget to compute the average and print it
        loss_avg = loss_cum / cnt
        print_frm('Epoch %5d - Average train loss: %.6f' % (epoch, loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            log_scalars([loss_avg], ['train/' + s for s in ['loss-seg']], writer, epoch=epoch)

            # log images if necessary (x, y and y_pred still hold the last processed batch)
            if write_images:
                log_images_3d([x], ['train/' + s for s in ['x']], writer, epoch=epoch)
                y_pred = F.softmax(y_pred, dim=1)
                for i, c in enumerate(self.coi):
                    if not i == 0:  # skip background class
                        y_p = y_pred[:, i:i + 1, ...].data
                        y_t = (y == i).long()
                        log_images_3d([y_t, y_p],
                                      ['train/' + s for s in ['y_class_%d)' % (c), 'y_pred_class_%d)' % (c)]], writer,
                                      epoch=epoch)

        return loss_avg
Example #22
0
def validate(net,
             data,
             labels,
             input_size,
             in_channels=1,
             classes_of_interest=(0, 1),
             batch_size=1,
             write_dir=None,
             val_file=None,
             track_progress=False,
             device=0,
             orientations=(0, ),
             normalization='unit'):
    """
    Validate a network on a dataset and its labels

    :param net: image-to-image segmentation network
    :param data: 3D array (Z, Y, X) representing the 3D image
    :param labels: 3D array (Z, Y, X) representing the 3D labels; voxels with value 255
                   are treated as unlabeled and excluded from the metrics
    :param input_size: size of the inputs (either 2 or 3-tuple) for processing
    :param in_channels: Amount of subsequent slices that serve as input for the network (should be odd)
    :param classes_of_interest: index of the label of interest
    :param batch_size: batch size for processing
    :param write_dir: optionally, specify a directory to write the output
    :param val_file: optionally, specify a file to write the validation results
    :param track_progress: optionally, for tracking progress with progress bar
    :param device: GPU device where the computations should occur
    :param orientations: list of orientations to perform segmentation: 0-Z, 1-Y, 2-X (only for 2D based segmentation)
    :param normalization: type of data normalization (unit, z or minmax)
    :return: validation results, i.e. accuracy, precision, recall, f-score, jaccard and dice score
    """

    print_frm('Validating the trained network...')

    # compute segmentation for each orientation and average results
    segmentation = np.zeros((net.out_channels, *data.shape))
    for orientation in orientations:
        segmentation += segment(data,
                                net,
                                input_size,
                                in_channels=in_channels,
                                batch_size=batch_size,
                                track_progress=track_progress,
                                device=device,
                                orientation=orientation,
                                normalization=normalization)
    segmentation = segmentation / len(orientations)

    # compute metrics
    # w masks out the unlabeled (255) voxels; the Hausdorff distance is only
    # computed when the volume contains no unlabeled voxels at all
    w = labels != 255
    comp_hausdorff = np.sum(labels == 255) == 0
    js = np.asarray([
        jaccard(segmentation[i], (labels == c).astype('float'), w=w)
        for i, c in enumerate(classes_of_interest)
    ])
    ams = np.asarray([
        accuracy_metrics(segmentation[i], (labels == c).astype('float'), w=w)
        for i, c in enumerate(classes_of_interest)
    ])
    for i, c in enumerate(classes_of_interest):
        if comp_hausdorff:
            # NOTE(review): computed against the raw label volume, not the
            # binarized (labels == c) mask used above — confirm this is intended
            # for multi-class labels
            h = hausdorff_distance(segmentation[i], labels)[0]
        else:
            h = -1

        # report results
        print_frm('Validation performance for class %d: ' % c)
        print_frm('    - Accuracy: %f' % ams[i, 0])
        print_frm('    - Balanced accuracy: %f' % ams[i, 1])
        print_frm('    - Precision: %f' % ams[i, 2])
        print_frm('    - Recall: %f' % ams[i, 3])
        print_frm('    - F1 score: %f' % ams[i, 4])
        print_frm('    - IoU: %f' % js[i])
        print_frm('    - Hausdorff distance: %f' % h)

    # report results
    print_frm('Validation performance mean: ')
    print_frm('    - Accuracy: %f' % np.mean(ams[:, 0]))
    print_frm('    - Balanced accuracy: %f' % np.mean(ams[:, 1]))
    print_frm('    - Precision: %f' % np.mean(ams[:, 2]))
    print_frm('    - Recall: %f' % np.mean(ams[:, 3]))
    print_frm('    - F1 score: %f' % np.mean(ams[:, 4]))
    print_frm('    - mIoU: %f' % np.mean(js))

    # write stuff if necessary
    if write_dir is not None:
        print_frm('Writing out the segmentation...')
        mkdir(write_dir)
        # collapse per-class probability maps into a single label volume by
        # thresholding each class map at 0.5
        segmentation_volume = np.zeros(segmentation.shape[1:])
        for i, c in enumerate(classes_of_interest):
            segmentation_volume[segmentation[i] > 0.5] = c
        write_volume(segmentation_volume, write_dir, type='pngseq')
    if val_file is not None:
        # rows: classes of interest; columns: jaccard followed by the accuracy metrics
        np.save(val_file, np.concatenate((js[:, np.newaxis], ams), axis=1))
    return js, ams
Example #23
0
    def train_epoch(self,
                    loader,
                    optimizer,
                    epoch,
                    augmenter=None,
                    print_stats=1,
                    writer=None,
                    write_images=False,
                    device=0):
        """
        Trains the network for one epoch
        :param loader: dataloader over (image, domain label) batches
        :param optimizer: optimizer for the loss function
        :param epoch: current epoch (used for logging only)
        :param augmenter: data augmenter applied to the inputs (may be None)
        :param print_stats: frequency (in iterations) of printing statistics
        :param writer: summary writer; if None, nothing is logged
        :param write_images: if True, additionally log input/reconstruction images of the last batch
        :param device: GPU device where the computations should occur
        :return: average training loss over the epoch
        """
        # make sure network is on the gpu and in training mode
        module_to_device(self, device)
        self.train()

        # keep track of the average losses during the epoch
        loss_rec_cum = 0.0
        loss_dc_cum = 0.0
        loss_cum = 0.0
        cnt = 0

        # start epoch
        for i, data in enumerate(loader):

            # transfer to suitable device
            x, dom = data
            x = tensor_to_device(x.float(), device)
            dom = tensor_to_device(dom.long(), device)

            # get the inputs and augment if necessary
            if augmenter is not None:
                x = augmenter(x)

            # zero the gradient buffers
            self.zero_grad()

            # forward prop: the network returns a reconstruction and a domain prediction
            x_pred, dom_pred = self(x)
            x_pred = torch.sigmoid(x_pred)

            # compute loss: reconstruction term plus lambda_reg-weighted domain term
            loss_rec = self.loss_rec_fn(x_pred, x)
            loss_dc = self.loss_dc_fn(dom_pred, dom)
            loss = loss_rec + self.lambda_reg * loss_dc
            loss_rec_cum += loss_rec.data.cpu().numpy()
            loss_dc_cum += loss_dc.data.cpu().numpy()
            loss_cum += loss.data.cpu().numpy()
            cnt += 1

            # backward prop
            loss.backward()

            # apply one step in the optimization
            optimizer.step()

            # print statistics if necessary
            if i % print_stats == 0:
                print_frm(
                    'Epoch %5d - Iteration %5d/%5d - Loss Rec: %.6f - Loss DC: %.6f - Loss: %.6f'
                    % (epoch, i, len(loader.dataset) / loader.batch_size,
                       loss_rec, loss_dc, loss))

        # don't forget to compute the average and print it
        loss_rec_avg = loss_rec_cum / cnt
        loss_dc_avg = loss_dc_cum / cnt
        loss_avg = loss_cum / cnt
        print_frm(
            'Epoch %5d - Average train loss rec: %.6f - Average train loss DC: %.6f - Average train loss: %.6f'
            % (epoch, loss_rec_avg, loss_dc_avg, loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            log_scalars(
                [loss_rec_avg, loss_dc_avg, loss_avg],
                ['train/' + s for s in ['loss-rec', 'loss-dc', 'loss']],
                writer,
                epoch=epoch)

            # log images if necessary (x and x_pred still hold the last batch)
            if write_images:
                log_images_2d([x, x_pred],
                              ['train/' + s for s in ['x', 'x_pred']],
                              writer,
                              epoch=epoch)

        return loss_avg
Example #24
0
    def test_epoch(self,
                   loader,
                   loss_rec_fn,
                   loss_kl_fn,
                   epoch,
                   writer=None,
                   write_images=False,
                   device=0):
        """
        Tests the network for one epoch
        :param loader: dataloader over input batches
        :param loss_rec_fn: reconstruction loss function
        :param loss_kl_fn: kullback leibler loss function
        :param epoch: current epoch (used for logging only)
        :param writer: summary writer; if None, nothing is logged
        :param write_images: if True, additionally log input/reconstruction images of the last batch
        :param device: GPU device where the computations should occur
        :return: average testing loss over the epoch
        """
        # make sure network is on the gpu and in evaluation mode
        # NOTE(review): no torch.no_grad() is used here, so autograd still
        # tracks the forward passes during testing — consider wrapping the loop
        module_to_device(self, device)
        self.eval()

        # keep track of the average losses during the epoch
        loss_rec_cum = 0.0
        loss_kl_cum = 0.0
        loss_cum = 0.0
        cnt = 0

        # start epoch
        # z collects re-sampled latent codes, li the corresponding inputs
        # NOTE(review): neither is used after the loop in this method — confirm
        # whether they were meant to be returned or logged
        z = []
        li = []
        for i, data in enumerate(loader):
            # transfer to suitable device
            x = tensor_to_device(data.float(), device)

            # forward prop (self.mu / self.logvar are set as a side effect of the forward pass)
            x_pred = torch.sigmoid(self(x))
            z.append(_reparametrise(self.mu, self.logvar).cpu().data.numpy())
            li.append(x.cpu().data.numpy())

            # compute loss: reconstruction term plus beta-weighted KL term
            loss_rec = loss_rec_fn(x_pred, x)
            loss_kl = loss_kl_fn(self.mu, self.logvar)
            loss = loss_rec + self.beta * loss_kl
            loss_rec_cum += loss_rec.data.cpu().numpy()
            loss_kl_cum += loss_kl.data.cpu().numpy()
            loss_cum += loss.data.cpu().numpy()
            cnt += 1

        # don't forget to compute the average and print it
        loss_rec_avg = loss_rec_cum / cnt
        loss_kl_avg = loss_kl_cum / cnt
        loss_avg = loss_cum / cnt
        print_frm(
            'Epoch %5d - Average test loss rec: %.6f - Average test loss KL: %.6f - Average test loss: %.6f'
            % (epoch, loss_rec_avg, loss_kl_avg, loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            log_scalars([loss_rec_avg, loss_kl_avg, loss_avg],
                        ['test/' + s for s in ['loss-rec', 'loss-kl', 'loss']],
                        writer,
                        epoch=epoch)

            # log images if necessary (x and x_pred still hold the last batch)
            if write_images:
                log_images_2d([x, x_pred],
                              ['test/' + s for s in ['x', 'x_pred']],
                              writer,
                              epoch=epoch)

        return loss_avg
Example #25
0
# stack the collected per-domain results; per the inline comments the axes are
# D x N x C x M (domains x experiments x classes x metrics)
results_final = np.asarray(results_final)  # D x N x C x M
results_best = np.asarray(results_best)  # D x N x C x M

# average over classes (and convert fractions to percentages)
results_final = 100 * np.mean(results_final, axis=2)
results_best = 100 * np.mean(results_best, axis=2)

# compute average and standard deviation over experiments
results_final_mean = np.mean(results_final, axis=1)
results_final_std = np.std(results_final, axis=1)
results_best_mean = np.mean(results_best, axis=1)
results_best_std = np.std(results_best, axis=1)

# report mean performance
for i, dom in enumerate(domains):
    print_frm('Domain: %s' % dom)
    print_frm('')

    # performance of the final model
    print_frm('Validation performance final: ')
    for j, metric in enumerate(metrics):
        print_frm('    - %s: %.2f (+/- %.2f)' %
                  (metric, results_final_mean[i, j], results_final_std[i, j]))
    print_frm('')

    # performance of the best checkpoint
    print_frm('Validation performance best: ')
    for j, metric in enumerate(metrics):
        print_frm('    - %s: %.2f (+/- %.2f)' %
                  (metric, results_best_mean[i, j], results_best_std[i, j]))
    print_frm('')

    print_frm('=================================================')
Example #26
0
def generate_model(name, params):
    """
    Builds a 2D (domain adaptive) segmentation network from a configuration
    dictionary

    :param name: architecture identifier ('u-net', 'no-da', 'mmd', 'dat',
                 'ynet', 'wnet' or 'unet-ts'); any other value falls back to
                 the plain UNetDA2D, matching the original behavior
    :param params: dictionary with the (hyper)parameters of the network
    :return: the instantiated network
    """
    # parameters shared by every architecture (previously duplicated in each
    # branch)
    base_kwargs = dict(in_channels=params['in_channels'],
                       feature_maps=params['fm'],
                       levels=params['levels'],
                       dropout_enc=params['dropout'],
                       dropout_dec=params['dropout'],
                       norm=params['norm'],
                       activation=params['activation'],
                       coi=params['coi'],
                       loss_fn=params['loss'],
                       lr=params['lr'])

    if name == 'mmd':
        net = UNetMMD2D(lambda_mmd=params['lambda_mmd'], **base_kwargs)
    elif name == 'dat':
        net = UNetDAT2D(lambda_dat=params['lambda_dat'],
                        input_shape=params['input_size'],
                        **base_kwargs)
    elif name == 'ynet':
        net = YNet2D(lambda_rec=params['lambda_rec'], **base_kwargs)
    elif name == 'wnet':
        net = WNet2D(lambda_rec=params['lambda_rec'],
                     lambda_dat=params['lambda_dat'],
                     input_shape=params['input_size'],
                     **base_kwargs)
    elif name == 'unet-ts':
        net = UNetTS2D(lambda_w=params['lambda_w'],
                       lambda_o=params['lambda_o'],
                       **base_kwargs)
    else:
        # 'u-net', 'no-da' and any unrecognized name map to the plain U-Net
        net = UNetDA2D(**base_kwargs)

    # report the employed configuration
    print_frm('Employed network: %s' % str(net.__class__.__name__))
    print_frm('    - Input channels: %d' % params['in_channels'])
    print_frm('    - Initial feature maps: %d' % params['fm'])
    print_frm('    - Levels: %d' % params['levels'])
    print_frm('    - Dropout: %.2f' % params['dropout'])
    print_frm('    - Normalization: %s' % params['norm'])
    print_frm('    - Activation: %s' % params['activation'])
    print_frm('    - Classes of interest: %s' % str(params['coi']))
    print_frm('    - Initial learning rate: %f' % params['lr'])

    return net
Example #27
0
    def test_epoch(self,
                   loader_src,
                   loader_tar_ul,
                   loader_tar_l,
                   epoch,
                   writer=None,
                   write_images=False,
                   device=0):
        """
        Evaluates the network for one epoch
        :param loader_src: source dataloader (labeled)
        :param loader_tar_ul: target dataloader (unlabeled)
        :param loader_tar_l: target dataloader (labeled)
        :param epoch: current epoch
        :param writer: summary writer
        :param write_images: flag that specifies whether images are logged
        :param device: GPU device where the computations should occur
        :return: average test loss over the epoch
        """
        # perform evaluation on GPU/CPU: eval mode disables dropout and uses
        # running batch-norm statistics
        # NOTE(review): gradients are not disabled (no torch.no_grad()) -
        # confirm whether this is intentional
        module_to_device(self, device)
        self.eval()

        # keep track of the average loss (per term) during the epoch
        loss_seg_src_cum = 0.0
        loss_seg_tar_cum = 0.0
        loss_rec_src_cum = 0.0
        loss_rec_tar_cum = 0.0
        loss_dc_x_cum = 0.0
        loss_dc_y_cum = 0.0
        total_loss_cum = 0.0
        cnt = 0

        # zip dataloaders (epoch length is bounded by the shortest loader)
        dl = zip(loader_src, loader_tar_ul, loader_tar_l)

        # start epoch; predictions and labels of the labeled target samples
        # are collected for metric computation afterwards
        y_preds = []
        ys = []
        time_start = datetime.datetime.now()
        for i, data in enumerate(dl):

            # transfer to suitable device
            x_src, y_src = tensor_to_device(data[0], device)
            x_tar_ul = tensor_to_device(data[1], device)
            x_tar_l, y_tar_l = tensor_to_device(data[2], device)
            x_src = x_src.float()
            x_tar_ul = x_tar_ul.float()
            x_tar_l = x_tar_l.float()
            y_src = y_src.long()
            y_tar_l = y_tar_l.long()

            # get domain labels for domain confusion (0 = source, 1 = target)
            dom_labels_x = tensor_to_device(
                torch.zeros((x_src.size(0) + x_tar_ul.size(0))),
                device).long()
            dom_labels_x[x_src.size(0):] = 1
            dom_labels_y = tensor_to_device(
                torch.zeros((x_src.size(0) + x_tar_ul.size(0))),
                device).long()
            dom_labels_y[x_src.size(0):] = 1

            # check train mode and compute loss; terms that do not apply to
            # the active mode stay at zero
            loss_seg_src = torch.Tensor([0])
            loss_seg_tar = torch.Tensor([0])
            loss_rec_src = torch.Tensor([0])
            loss_rec_tar = torch.Tensor([0])
            loss_dc_x = torch.Tensor([0])
            loss_dc_y = torch.Tensor([0])
            if self.train_mode == RECONSTRUCTION:
                # reconstruction mode: autoencode both domains and confuse a
                # domain classifier on the reconstructions
                x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
                x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
                loss_rec_src = self.rec_loss(x_src_rec, x_src)
                # NOTE(review): argument order (input, reconstruction) is
                # swapped w.r.t. the src call above - harmless if rec_loss is
                # symmetric (e.g. MSE), confirm
                loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
                loss_dc_x = self.dc_loss(
                    torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                    dom_labels_x)
                total_loss = loss_rec_src + loss_rec_tar + self.lambda_dc * loss_dc_x
            elif self.train_mode == SEGMENTATION:
                # switch between reconstructed and original inputs
                # (with probability self.p the original input is used)
                if np.random.rand() < self.p:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
                else:
                    x_src_rec, _ = self.forward_rec(x_src)
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                    # reconstructed source inputs are relabeled as "target"
                    dom_labels_y[:x_src.size(0)] = 1
                if np.random.rand() < self.p:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul)
                else:
                    x_tar_ul_rec, _ = self.forward_rec(x_tar_ul)
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul_rec)
                    dom_labels_y[x_src.size(0):] = 1
                loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
                loss_dc_y = self.dc_loss(
                    torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                    dom_labels_y)
                total_loss = loss_seg_src + self.lambda_dc * loss_dc_y
                # supervised loss on the labeled target samples
                y_tar_l_pred, _ = self.forward_seg(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar
            else:
                # joint mode: reconstruction + segmentation + domain confusion
                x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
                if np.random.rand() < self.p:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
                else:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                    dom_labels_y[:x_src.size(0)] = 1
                x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
                if np.random.rand() < self.p:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul)
                else:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul_rec)
                    dom_labels_y[x_src.size(0):] = 1
                loss_rec_src = self.rec_loss(x_src_rec, x_src)
                # NOTE(review): argument order swapped w.r.t. the src call -
                # confirm rec_loss is symmetric
                loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
                loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
                loss_dc_x = self.dc_loss(
                    torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                    dom_labels_x)
                loss_dc_y = self.dc_loss(
                    torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                    dom_labels_y)
                total_loss = loss_seg_src + self.lambda_rec * (loss_rec_src + loss_rec_tar) + \
                             self.lambda_dc * (loss_dc_x + loss_dc_y)
                # supervised loss on the labeled target samples (full forward)
                _, y_tar_l_pred, _, y_tar_l_pred_dom = self(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar

            # accumulate running loss statistics
            loss_seg_src_cum += loss_seg_src.data.cpu().numpy()
            loss_seg_tar_cum += loss_seg_tar.data.cpu().numpy()
            loss_rec_src_cum += loss_rec_src.data.cpu().numpy()
            loss_rec_tar_cum += loss_rec_tar.data.cpu().numpy()
            loss_dc_x_cum += loss_dc_x.data.cpu().numpy()
            loss_dc_y_cum += loss_dc_y.data.cpu().numpy()
            total_loss_cum += total_loss.data.cpu().numpy()
            cnt += 1

            # collect flattened class probabilities and labels for metrics
            if self.train_mode == SEGMENTATION or self.train_mode == JOINT:
                for b in range(y_tar_l_pred.size(0)):
                    y_preds.append(
                        F.softmax(y_tar_l_pred,
                                  dim=1)[b, ...].view(y_tar_l_pred.size(1),
                                                      -1).data.cpu().numpy())
                    ys.append(y_tar_l[b, 0, ...].flatten().cpu().numpy())

        # keep track of time
        runtime = datetime.datetime.now() - time_start
        seconds = runtime.total_seconds()
        hours = seconds // 3600
        minutes = (seconds - hours * 3600) // 60
        seconds = seconds - hours * 3600 - minutes * 60
        print_frm(
            'Epoch %5d - Runtime for testing: %d hours, %d minutes, %f seconds'
            % (epoch, hours, minutes, seconds))

        # prep for metric computation: per-class jaccard and accuracy metrics
        # over the collected labeled target predictions
        if self.train_mode == SEGMENTATION or self.train_mode == JOINT:
            y_preds = np.concatenate(y_preds, axis=1)
            ys = np.concatenate(ys)
            js = np.asarray([
                jaccard((ys == i).astype(int), y_preds[i, :])
                for i in range(len(self.coi))
            ])
            ams = np.asarray([
                accuracy_metrics((ys == i).astype(int), y_preds[i, :])
                for i in range(len(self.coi))
            ])

        # don't forget to compute the average and print it
        loss_seg_src_avg = loss_seg_src_cum / cnt
        loss_seg_tar_avg = loss_seg_tar_cum / cnt
        loss_rec_src_avg = loss_rec_src_cum / cnt
        loss_rec_tar_avg = loss_rec_tar_cum / cnt
        loss_dc_x_avg = loss_dc_x_cum / cnt
        loss_dc_y_avg = loss_dc_y_cum / cnt
        total_loss_avg = total_loss_cum / cnt
        print(
            '[%s] Testing Epoch %5d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss rec src: %.6f - Loss rec tar: %.6f - Loss DCX: %.6f - Loss DCY: %.6f - Loss: %.6f'
            % (datetime.datetime.now(), epoch, loss_seg_src_avg,
               loss_seg_tar_avg, loss_rec_src_avg, loss_rec_tar_avg,
               loss_dc_x_avg, loss_dc_y_avg, total_loss_avg))

        # log everything
        if writer is not None:

            # always log scalars (which ones depend on the active train mode)
            if self.train_mode == RECONSTRUCTION:
                log_scalars(
                    [loss_rec_src_avg, loss_rec_tar_avg, loss_dc_x_avg], [
                        'test/' + s
                        for s in ['loss-rec-src', 'loss-rec-tar', 'loss-dc-x']
                    ],
                    writer,
                    epoch=epoch)
            elif self.train_mode == SEGMENTATION:
                log_scalars([
                    loss_seg_src_avg, loss_seg_tar_avg, loss_dc_y_avg,
                    np.mean(js, axis=0), *(np.mean(ams, axis=0))
                ], [
                    'test/' + s for s in [
                        'loss-seg-src', 'loss-seg-tar', 'loss-dc-y', 'jaccard',
                        'accuracy', 'balanced-accuracy', 'precision', 'recall',
                        'f-score'
                    ]
                ],
                            writer,
                            epoch=epoch)
            else:
                log_scalars([
                    loss_seg_src_avg, loss_seg_tar_avg, loss_rec_src_avg,
                    loss_rec_tar_avg, loss_dc_x_avg, loss_dc_y_avg,
                    np.mean(js, axis=0), *(np.mean(ams, axis=0))
                ], [
                    'test/' + s for s in [
                        'loss-seg-src', 'loss-seg-tar', 'loss-rec-src',
                        'loss-rec-tar', 'loss-dc-x', 'loss-dc-y', 'jaccard',
                        'accuracy', 'balanced-accuracy', 'precision', 'recall',
                        'f-score'
                    ]
                ],
                            writer,
                            epoch=epoch)
            log_scalars([total_loss_avg],
                        ['test/' + s for s in ['total-loss']],
                        writer,
                        epoch=epoch)

            # log images if necessary (only the foreground probability channel
            # of the softmax output is written)
            if write_images:
                log_images_2d([x_src.data], ['test/' + s for s in ['src/x']],
                              writer,
                              epoch=epoch)
                if self.train_mode == RECONSTRUCTION:
                    log_images_2d(
                        [x_src_rec.data, x_tar_ul.data, x_tar_ul_rec.data], [
                            'test/' + s
                            for s in ['src/x-rec', 'tar/x-ul', 'tar/x-ul-rec']
                        ],
                        writer,
                        epoch=epoch)
                elif self.train_mode == SEGMENTATION:
                    y_src_pred = F.softmax(y_src_pred, dim=1)[:,
                                                              1:2, :, :].data
                    log_images_2d(
                        [y_src.data, y_src_pred],
                        ['test/' + s for s in ['src/y', 'src/y-pred']],
                        writer,
                        epoch=epoch)
                    if loader_tar_l is not None:
                        y_tar_l_pred = F.softmax(y_tar_l_pred,
                                                 dim=1)[:, 1:2, :, :].data
                        log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred], [
                            'test/' + s
                            for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
                        ],
                                      writer,
                                      epoch=epoch)
                else:
                    y_src_pred = F.softmax(y_src_pred, dim=1)[:,
                                                              1:2, :, :].data
                    log_images_2d([
                        x_src_rec.data, y_src.data, y_src_pred, x_tar_ul.data,
                        x_tar_ul_rec.data
                    ], [
                        'test/' + s for s in [
                            'src/x-rec', 'src/y', 'src/y-pred', 'tar/x-ul',
                            'tar/x-ul-rec'
                        ]
                    ],
                                  writer,
                                  epoch=epoch)
                    if loader_tar_l is not None:
                        y_tar_l_pred = F.softmax(y_tar_l_pred,
                                                 dim=1)[:, 1:2, :, :].data
                        log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred], [
                            'test/' + s
                            for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
                        ],
                                      writer,
                                      epoch=epoch)

        return total_loss_avg
Example #28
0
    def __init__(self,
                 data_path,
                 label_path,
                 input_shape=None,
                 split_orientation='z',
                 split_location=0.50,
                 scaling=None,
                 len_epoch=1000,
                 type='tif3d',
                 coi=(0, 1),
                 in_channels=1,
                 orientations=(0, ),
                 batch_size=1,
                 data_dtype='uint8',
                 label_dtype='uint8',
                 norm_type='unit',
                 train=True,
                 available=-1):
        """
        Labeled volume dataset that splits the data (and labels) into a train
        and a test part along a selected orientation

        :param data_path: path to the image data (loaded by the base class)
        :param label_path: path to the label data
        :param input_shape: shape of the sampled inputs (forwarded to the base class)
        :param split_orientation: axis of the train/test split ('z', 'y' or 'x')
        :param split_location: relative (fractional) position of the split
        :param scaling: optional scaling applied to the volume shape
        :param len_epoch: number of samples in an epoch (forwarded to the base class)
        :param type: file type of the volumes (e.g. 'tif3d')
        :param coi: classes of interest
        :param in_channels: number of input channels (forwarded to the base class)
        :param orientations: sampling orientations (forwarded to the base class)
        :param batch_size: batch size (forwarded to the base class)
        :param data_dtype: dtype of the image volume
        :param label_dtype: dtype of the label volume
        :param norm_type: normalization type (forwarded to the base class)
        :param train: True selects the training part, False the test part
        :param available: if positive, restrict the data to this many samples
        """
        # the base class receives the image data and the same split settings
        super().__init__(data_path,
                         input_shape,
                         split_orientation=split_orientation,
                         split_location=split_location,
                         scaling=scaling,
                         len_epoch=len_epoch,
                         type=type,
                         in_channels=in_channels,
                         orientations=orientations,
                         batch_size=batch_size,
                         dtype=data_dtype,
                         norm_type=norm_type,
                         train=train)

        self.label_path = label_path
        self.coi = coi
        self.available = available

        # load labels: map the split orientation to an axis index
        d = 0 if split_orientation == 'z' else 1 if split_orientation == 'y' else 2
        if split_orientation == 'z':
            # z-split: only read the slices of the requested part
            split = int(len(os.listdir(label_path)) * split_location)
            start = 0 if train else split
            stop = split if train else -1
            # NOTE(review): stop=-1 for the test part - verify read_volume
            # semantics (the last slice may be excluded)
            self.labels = read_volume(label_path,
                                      type=type,
                                      dtype=label_dtype,
                                      start=start,
                                      stop=stop)
        else:
            # y/x-split: read the full volume, then slice it along axis d
            data = read_volume(label_path, type=type, dtype=label_dtype)
            split = int(data.shape[d] * split_location)
            if split_orientation == 'y':
                self.labels = data[:, :split, :] if train else data[:,
                                                                    split:, :]
            else:
                self.labels = data[:, :, :split] if train else data[:, :,
                                                                    split:]

        # rescale the labels if necessary
        # NOTE(review): 'area' interpolation yields non-integer label values -
        # confirm labels are re-discretized downstream
        if scaling is not None:
            target_size = np.asarray(np.multiply(self.labels.shape, scaling),
                                     dtype=int)
            self.labels = F.interpolate(torch.Tensor(self.labels[np.newaxis,
                                                                 np.newaxis,
                                                                 ...]),
                                        size=tuple(target_size),
                                        mode='area')[0, 0, ...].numpy()

        # select a crop of the data if necessary and report the sizes
        print_frm('Original dataset size: %d x %d x %d (total: %d)' %
                  (self.data.shape[0], self.data.shape[1], self.data.shape[2],
                   self.data.size))
        if available > 0:
            self.data, self.labels = _select_subset(self.data,
                                                    self.labels,
                                                    n=available,
                                                    coi=coi)
        t_str = 'training' if train else 'testing'
        print_frm('Used for %s: %d x %d x %d (total: %d)' %
                  (t_str, self.data.shape[0], self.data.shape[1],
                   self.data.shape[2], self.data.size))
Example #29
0
    def train_epoch(self,
                    loader_src,
                    loader_tar_ul,
                    loader_tar_l,
                    optimizer,
                    epoch,
                    augmenter=None,
                    print_stats=1,
                    writer=None,
                    write_images=False,
                    device=0):
        """
        Trains the network for one epoch

        :param loader_src: source dataloader (labeled)
        :param loader_tar_ul: target dataloader (unlabeled)
                              NOTE(review): never read in this method - kept
                              for interface compatibility, confirm
        :param loader_tar_l: target dataloader (labeled), may be None
        :param optimizer: optimizer for the loss function
        :param epoch: current epoch
        :param augmenter: data augmenter
        :param print_stats: frequency of printing statistics
        :param writer: summary writer
        :param write_images: flag that specifies whether images are logged
        :param device: GPU device where the computations should occur
        :return: average training loss over the epoch
        """
        # perform training on GPU/CPU
        module_to_device(self, device)
        self.train()

        # keep track of the average loss during the epoch
        loss_seg_src_cum = 0.0
        loss_seg_tar_cum = 0.0
        total_loss_cum = 0.0
        cnt = 0

        # zip dataloaders: with a labeled target loader the epoch length is
        # bounded by the shorter of the two loaders
        if loader_tar_l is None:
            dl = zip(loader_src)
        else:
            dl = zip(loader_src, loader_tar_l)

        # start epoch
        time_start = datetime.datetime.now()
        for i, data in enumerate(dl):

            # transfer the source samples to a suitable device and augment;
            # this happens regardless of whether labeled target data is used
            data_src = tensor_to_device(data[0], device)
            x_src, y_src = augment_samples((data_src[0], data_src[1]),
                                           augmenter=augmenter)

            # same for the labeled target samples, if available
            if loader_tar_l is not None:
                data_tar_l = tensor_to_device(data[1], device)
                x_tar_l, y_tar_l = augment_samples(
                    (data_tar_l[0], data_tar_l[1]), augmenter=augmenter)
                y_tar_l = get_labels(y_tar_l, coi=self.coi, dtype=int)

            # zero the gradient buffers
            self.zero_grad()

            # forward prop and compute loss; the target term defaults to zero
            # when no labeled target data is available
            loss_seg_tar = torch.Tensor([0])
            y_src_pred = self(x_src)
            loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
            total_loss = loss_seg_src
            if loader_tar_l is not None:
                y_tar_l_pred = self(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar

            # accumulate running loss statistics
            loss_seg_src_cum += loss_seg_src.data.cpu().numpy()
            loss_seg_tar_cum += loss_seg_tar.data.cpu().numpy()
            total_loss_cum += total_loss.data.cpu().numpy()
            cnt += 1

            # backward prop
            total_loss.backward()

            # apply one step in the optimization
            optimizer.step()

            # print statistics if necessary
            if i % print_stats == 0:
                print(
                    '[%s] Epoch %5d - Iteration %5d/%5d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss: %.6f'
                    % (datetime.datetime.now(), epoch, i,
                       len(loader_src.dataset) / loader_src.batch_size,
                       loss_seg_src_cum / cnt, loss_seg_tar_cum / cnt,
                       total_loss_cum / cnt))

        # keep track of time
        runtime = datetime.datetime.now() - time_start
        seconds = runtime.total_seconds()
        hours = seconds // 3600
        minutes = (seconds - hours * 3600) // 60
        seconds = seconds - hours * 3600 - minutes * 60
        print_frm(
            'Epoch %5d - Runtime for training: %d hours, %d minutes, %f seconds'
            % (epoch, hours, minutes, seconds))

        # don't forget to compute the average and print it
        loss_seg_src_avg = loss_seg_src_cum / cnt
        loss_seg_tar_avg = loss_seg_tar_cum / cnt
        total_loss_avg = total_loss_cum / cnt
        print(
            '[%s] Training Epoch %4d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss: %.6f'
            % (datetime.datetime.now(), epoch, loss_seg_src_avg,
               loss_seg_tar_avg, total_loss_avg))

        # log everything
        if writer is not None:

            # always log scalars
            log_scalars([loss_seg_src_avg, loss_seg_tar_avg, total_loss_avg], [
                'train/' + s
                for s in ['loss-seg-src', 'loss-seg-tar', 'total-loss']
            ],
                        writer,
                        epoch=epoch)

            # log images if necessary (only the foreground probability channel
            # of the softmax output is written)
            if write_images:
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
                log_images_2d(
                    [x_src.data, y_src.data, y_src_pred],
                    ['train/' + s for s in ['src/x', 'src/y', 'src/y-pred']],
                    writer,
                    epoch=epoch)
                if loader_tar_l is not None:
                    y_tar_l_pred = F.softmax(y_tar_l_pred,
                                             dim=1)[:, 1:2, :, :].data
                    log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred], [
                        'train/' + s
                        for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
                    ],
                                  writer,
                                  epoch=epoch)

        return total_loss_avg
Example #30
0
from neuralnets.util.io import print_frm, save
from neuralnets.util.tools import set_seed, log_hparams

from multiprocessing import freeze_support
from sklearn.model_selection import GridSearchCV
from torch.utils.data import DataLoader

from util.tools import parse_params, parse_search_grid, get_transforms
from networks.factory import generate_classifier

if __name__ == '__main__':
    freeze_support()
    """
        Parse all the arguments
    """
    print_frm('Parsing arguments')
    parser = argparse.ArgumentParser()
    parser.add_argument("--config",
                        "-c",
                        help="Path to the configuration file",
                        type=str,
                        default='cross_validate.yaml')
    args = parser.parse_args()
    with open(args.config) as file:
        params = parse_params(yaml.load(file, Loader=yaml.FullLoader))
    """
    Fix seed (for reproducibility)
    """
    set_seed(params['seed'])
    """
        Load the data