예제 #1
0
def validate(state_dict_path, use_gpu, device):
    """Evaluate a trained UNet on the BRATS2018 validation split.

    Loads the weights from ``state_dict_path``, runs the model over the
    't1ce' validation scans and prints the running and final mean dice score.

    Args:
        state_dict_path: path to a saved ``state_dict`` checkpoint file.
        use_gpu: when False the checkpoint tensors are remapped onto the CPU.
        device: torch device the model and batches are moved to.
    """
    model = UNet(n_channels=1, n_classes=2)
    # Remap GPU-saved tensors onto the CPU when running without a GPU.
    model.load_state_dict(torch.load(state_dict_path, map_location='cpu' if not use_gpu else device))
    model.to(device)
    # Hoisted out of the batch loop: eval mode only needs to be set once.
    model.eval()

    val_transforms = transforms.Compose([
        ToTensor(),
        NormalizeBRATS()])

    BraTS_val_ds = BRATS2018('./BRATS2018',
                             data_set='val',
                             seg_type='et',
                             scan_type='t1ce',
                             transform=val_transforms)

    data_loader = DataLoader(BraTS_val_ds, batch_size=2, shuffle=False, num_workers=0)

    running_dice_score = 0.

    for batch_ind, batch in enumerate(data_loader):
        imgs, targets = batch
        imgs = imgs.to(device)
        targets = targets.to(device)

        with torch.no_grad():
            outputs = model(imgs)
            preds = torch.argmax(F.softmax(outputs, dim=1), dim=1)

            # Weight each batch's dice by its size so the final mean stays
            # exact even when the last batch is smaller than batch_size.
            running_dice_score += dice_score(preds, targets) * targets.size(0)
            print('running dice score: {:.6f}'.format(running_dice_score))

    dice = running_dice_score / len(BraTS_val_ds)
    print('mean dice score of the validating set: {:.6f}'.format(dice))
예제 #2
0
def train():
    """Entry point: build a UNet from run arguments and hand off to the training loop."""
    args = setup_run_arguments()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"[INFO] Initializing UNet-model using: {device}")

    net = UNet(n_channels=args.n_channels, n_classes=args.n_classes, bilinear=True)

    # Warm-start from a saved state dict when one was supplied.
    if args.from_pretrained:
        net.load_state_dict(torch.load(args.from_pretrained, map_location=device))

    net.to(device=device)

    training_loop.run(
        network=net,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.learning_rate,
        device=device,
        n_classes=args.n_classes,
        val_percent=args.val_percent,
        image_dir=args.image_dir,
        mask_dir=args.mask_dir,
        checkpoint_path=args.checkpoint_path,
        loss=args.loss,
        num_workers=args.num_workers,
    )
예제 #3
0
class EventGANBase(object):
    """Wraps a pretrained UNet generator that maps a pair of images to an event volume.

    The latest checkpoint found in ``options.checkpoint_dir`` is loaded at
    construction time and the generator is moved to the available device.
    """

    def __init__(self, options):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.generator = UNet(num_input_channels=2*options.n_image_channels,
                              num_output_channels=options.n_time_bins * 2,
                              skip_type='concat',
                              activation='relu',
                              num_encoders=4,
                              base_num_channels=32,
                              num_residual_blocks=2,
                              norm='BN',
                              use_upsample_conv=True,
                              with_activation=True,
                              sn=options.sn,
                              multi=False)
        latest_checkpoint = get_latest_checkpoint(options.checkpoint_dir)
        # map_location keeps loading working on CPU-only hosts even when the
        # checkpoint was saved from CUDA tensors.
        checkpoint = torch.load(latest_checkpoint, map_location=self.device)
        self.generator.load_state_dict(checkpoint["gen"])
        self.generator.to(self.device)

    def forward(self, images, is_train=False):
        """Run the generator on ``images`` (2xHxW or Bx2xHxW).

        When ``is_train`` is False the generator is temporarily switched to
        eval mode and gradients are disabled; it is put back into train mode
        afterwards (note: unconditionally — the previous mode is not restored).
        """
        if len(images.shape) == 3:
            # Promote a single 2xHxW sample to a batch of one.
            images = images[None, ...]
        assert len(images.shape) == 4 and images.shape[1] == 2, \
            "Input images must be either 2xHxW or Bx2xHxW."
        if not is_train:
            with torch.no_grad():
                self.generator.eval()
                event_volume = self.generator(images)
            self.generator.train()
        else:
            event_volume = self.generator(images)

        return event_volume
def prediction_to_json(image_path, chkp_path, net=None) -> dict:
    """
    Convert mask prediction to json. The format matches the format in the training annotation data:

    {'filename':file_name, 'labels':
    [{'name': label_name, 'annotations': [{'id':some_unique_integer_id, 'segmentation':[x,y,x,y,x,y....]}
                                             ....] }
        ....]
        }

    Args:
        image_path: path of the image to run the prediction on.
        chkp_path: checkpoint used to initialise a UNet when ``net`` is None.
        net: optional pre-loaded network; when omitted a UNet(3, 4) is built
            and the checkpoint weights are loaded on the available device.

    Returns:
        Annotation dict with 'filename', 'height', 'width' and 'labels'.
    """
    file_name = os.path.basename(image_path)
    annotation = {'filename': file_name, 'labels': []}

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Explicit None-check: a supplied module must never be rebuilt.
    if net is None:
        net = UNet(n_channels=3, n_classes=4)

        net.to(device=device)
        net.load_state_dict(torch.load(chkp_path, map_location=device))

    # Context manager releases the image file handle once prediction is done.
    with Image.open(image_path) as img:
        msk = predict_on_image(net=net, device=device, src_img=img)
    msk = msk.transpose((1, 2, 0))

    h, w, n_labels = msk.shape
    # NOTE(review): rgb_mask is built below but never returned or saved —
    # confirm whether it is still needed.
    rgb_mask = np.ones((h, w, 3), dtype=np.uint8)
    annotation['height'] = h
    annotation['width'] = w

    # Channel 0 is skipped — presumably background; verify against the labels.
    for label in range(1, n_labels):
        color = hex_labels[str(label)]
        category = category_labels[str(label)]
        c_label = {'color': color, 'name': category, 'annotations': []}

        # Cast once and reuse for both the contour extraction and the index mask.
        label_int = msk[:, :, label].astype(int)
        label_mask = label_int.astype(np.uint8)
        contours, hierarchy = cv2.findContours(label_mask, cv2.RETR_TREE,
                                               cv2.CHAIN_APPROX_SIMPLE)

        # Flatten each contour into the [x, y, x, y, ...] segmentation format.
        for contour in contours:
            vector_points = []
            for x, y in contour.reshape((len(contour), 2)):
                vector_points += [float(x), float(y)]

            c_label['annotations'].append({'segmentation': vector_points})

        idx = np.where(label_int == 1)
        rgb_mask[idx] = colors_from_hex[str(label)]

        annotation['labels'].append(c_label)

    return annotation
예제 #5
0
    def load_finetuned_model(self, baseline_model):
        """
        Loads the augmentation net, sample reweighting net, and baseline model.

        When ``self.args.load_finetune_checkpoint`` is set, restores the saved
        weights for the augmentation net (and, best-effort, the reweighting
        net) and the baseline model's weight decay. All three models are moved
        to ``self.device``.

        Note: sets all these models to train mode.
        """
        # MNIST images are 28x28 grayscale; every other supported dataset is
        # treated here as 32x32 RGB. Both use 10 classes.
        if self.args.dataset == DATASET_MNIST:
            imsize, in_channel, num_classes = 28, 1, 10
        else:
            imsize, in_channel, num_classes = 32, 3, 10

        augment_net = UNet(
            in_channels=in_channel,
            n_classes=in_channel,
            depth=2,
            wf=3,
            padding=True,
            batch_norm=False,
            do_noise_channel=True,
            up_mode='upconv',
            use_identity_residual=True)  # TODO(PV): Initialize UNet properly

        # This ResNet outputs scalar weights to be applied element-wise to the per-example losses
        reweighting_net = Net(1, 0.0, imsize, in_channel, 0.0, num_classes=1)

        if self.args.load_finetune_checkpoint:
            checkpoint = torch.load(self.args.load_finetune_checkpoint)
            if 'weight_decay' in checkpoint:
                baseline_model.weight_decay = checkpoint['weight_decay']
            augment_net.load_state_dict(checkpoint['augment_model_state_dict'])
            # Best-effort: older checkpoints predate the reweighting net, so a
            # missing key is tolerated and the freshly built net is used as-is.
            try:
                reweighting_net.load_state_dict(
                    checkpoint['reweighting_model_state_dict'])
            except KeyError:
                pass

        augment_net = augment_net.to(self.device)
        reweighting_net = reweighting_net.to(self.device)
        baseline_model = baseline_model.to(self.device)
        augment_net.train()
        reweighting_net.train()
        baseline_model.train()
        return augment_net, reweighting_net, baseline_model
예제 #6
0
        Learning rate:     {args.lr}
        Weight decay:      {args.weight_decay}
        Device:            GPU{args.gpu}
        Log name:          {args.save}
    ''')

    torch.cuda.set_device(args.gpu)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # choose a model
    if args.model == 'unet':
        net = UNet()
    elif args.model == 'nestedunet':
        net = NestedUNet()

    net.to(device=device)

    # choose a dataset

    if args.dataset == 'promise12':
        dir_data = '../data/promise12'
        trainset = Promise12(dir_data, mode='train')
        valset = Promise12(dir_data, mode='val')
    elif args.dataset == 'chaos':
        dir_data = '../data/chaos'
        trainset = Chaos(dir_data, mode='train')
        valset = Chaos(dir_data, mode='val')

    try:
        train_net(net=net,
                  trainset=trainset,
예제 #7
0
def train(input_data_type,
          grade,
          seg_type,
          num_classes,
          batch_size,
          epochs,
          use_gpu,
          learning_rate,
          w_decay,
          pre_trained=False):
    """Train a residual UNet on the requested BRATS modality.

    Runs a train/val phase per epoch, tracks loss, pixel accuracy and mean
    IoU with ``Evaluator``, keeps the best (by val mean IoU) weights, saves
    periodic checkpoints and final numpy metric arrays.

    NOTE(review): relies on module-level names (``device``, ``score_dir``,
    ``step_size``, ``gamma``, ``pre_trained_path``, ``logger``) — confirm
    they are defined where this function is used.

    Args:
        input_data_type: modality name used to build the dataset/dataloader.
        grade: tumour grade passed through to the dataset factory.
        seg_type: segmentation target type.
        num_classes: number of classes fed to the Evaluator.
        batch_size: mini-batch size.
        epochs: number of training epochs.
        use_gpu: when True the model is moved to ``device``.
        learning_rate: Adam learning rate.
        w_decay: Adam weight decay.
        pre_trained: when truthy, resume model weights from ``pre_trained_path``.

    Returns:
        (model, optimizer) with the best validation weights loaded.
    """
    logger.info('Start training using {} modal.'.format(input_data_type))
    model = UNet(4, 4, residual=True, expansion=2)

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(params=model.parameters(),
                           lr=learning_rate,
                           weight_decay=w_decay)

    if pre_trained:
        checkpoint = torch.load(pre_trained_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

    if use_gpu:
        ts = time.time()
        model.to(device)

        print("Finish cuda loading, time elapsed {}".format(time.time() - ts))

    scheduler = lr_scheduler.StepLR(
        optimizer, step_size=step_size,
        gamma=gamma)  # decay LR by a factor of 0.5 every 5 epochs

    data_set, data_loader = get_dataset_dataloader(input_data_type,
                                                   seg_type,
                                                   batch_size,
                                                   grade=grade)

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_iou = 0.0

    # Row 0 holds the 'train' phase, row 1 the 'val' phase.
    epoch_loss = np.zeros((2, epochs))
    epoch_acc = np.zeros((2, epochs))
    epoch_class_acc = np.zeros((2, epochs))
    epoch_mean_iou = np.zeros((2, epochs))
    evaluator = Evaluator(num_classes)

    def term_int_handler(signal_num, frame):
        # On SIGINT/SIGTERM: persist metrics and the best weights, then exit.
        np.save(os.path.join(score_dir, 'epoch_accuracy'), epoch_acc)
        np.save(os.path.join(score_dir, 'epoch_mean_iou'), epoch_mean_iou)
        np.save(os.path.join(score_dir, 'epoch_loss'), epoch_loss)

        model.load_state_dict(best_model_wts)

        logger.info('Got terminated and saved model.state_dict')
        torch.save(model.state_dict(),
                   os.path.join(score_dir, 'terminated_model.pt'))
        torch.save(
            {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, os.path.join(score_dir, 'terminated_model.tar'))

        quit()

    signal.signal(signal.SIGINT, term_int_handler)
    signal.signal(signal.SIGTERM, term_int_handler)

    for epoch in range(epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, epochs))
        logger.info('-' * 28)

        for phase_ind, phase in enumerate(['train', 'val']):
            if phase == 'train':
                model.train()
                logger.info(phase)
            else:
                model.eval()
                logger.info(phase)

            evaluator.reset()
            running_loss = 0.0
            running_dice = 0.0

            for batch_ind, batch in enumerate(data_loader[phase]):
                imgs, targets = batch
                imgs = imgs.to(device)
                targets = targets.to(device)

                # zero the learnable parameters gradients
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(imgs)
                    loss = criterion(outputs, targets)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                preds = torch.argmax(F.softmax(outputs, dim=1),
                                     dim=1,
                                     keepdim=True)
                # .item() detaches the scalar: accumulating the raw loss
                # tensor would retain the autograd graph (and GPU memory)
                # for the entire epoch.
                running_loss += loss.item() * imgs.size(0)
                logger.debug('Batch {} running loss: {:.4f}'.format(batch_ind,\
                    running_loss))

                # test the iou and pixelwise accuracy using evaluator
                preds = torch.squeeze(preds, dim=1)
                preds = preds.cpu().numpy()
                targets = targets.cpu().numpy()
                evaluator.add_batch(targets, preds)

            epoch_loss[phase_ind, epoch] = running_loss / len(data_set[phase])
            epoch_acc[phase_ind, epoch] = evaluator.Pixel_Accuracy()
            epoch_class_acc[phase_ind,
                            epoch] = evaluator.Pixel_Accuracy_Class()
            epoch_mean_iou[phase_ind,
                           epoch] = evaluator.Mean_Intersection_over_Union()

            logger.info('{} loss: {:.4f}, acc: {:.4f}, class acc: {:.4f}, mean iou: {:.6f}'.format(phase,\
                epoch_loss[phase_ind, epoch],\
                epoch_acc[phase_ind, epoch],\
                epoch_class_acc[phase_ind, epoch],\
                epoch_mean_iou[phase_ind, epoch]))

            # Track the best weights by validation mean IoU.
            if phase == 'val' and epoch_mean_iou[phase_ind, epoch] > best_iou:
                best_iou = epoch_mean_iou[phase_ind, epoch]
                best_model_wts = copy.deepcopy(model.state_dict())

            # Periodic checkpoint every 10 epochs.
            if phase == 'val' and (epoch + 1) % 10 == 0:
                logger.info('Saved model.state_dict in epoch {}'.format(epoch +
                                                                        1))
                torch.save(
                    model.state_dict(),
                    os.path.join(score_dir,
                                 'epoch{}_model.pt'.format(epoch + 1)))

        print()

    time_elapsed = time.time() - since
    logger.info('Training completed in {}m {}s'.format(int(time_elapsed / 60),\
        int(time_elapsed) % 60))

    # load best model weights
    model.load_state_dict(best_model_wts)

    # save numpy results
    np.save(os.path.join(score_dir, 'epoch_accuracy'), epoch_acc)
    np.save(os.path.join(score_dir, 'epoch_mean_iou'), epoch_mean_iou)
    np.save(os.path.join(score_dir, 'epoch_loss'), epoch_loss)

    return model, optimizer
예제 #8
0
class NNUnet(pl.LightningModule):
    """LightningModule wrapping a dynamic UNet for nnU-Net style training.

    Handles training/validation/test steps, optional test-time augmentation
    (flips), 2D/3D sliding-window inference, HPU-specific lazy-mode stepping
    and checkpoint tensor adjustment, and best-dice bookkeeping.
    """

    def __init__(self, args):
        super(NNUnet, self).__init__()
        self.args = args
        if not hasattr(self.args, "drop_block"):  # For backward compability
            self.args.drop_block = False
        self.save_hyperparameters()
        self.build_nnunet()
        self.loss = Loss(self.args.focal)
        self.dice = Dice(self.n_class)
        # Best-metric bookkeeping, per foreground class.
        self.best_sum = 0
        self.best_sum_epoch = 0
        self.best_dice = self.n_class * [0]
        self.best_epoch = self.n_class * [0]
        self.best_sum_dice = self.n_class * [0]
        self.learning_rate = args.learning_rate
        self.tta_flips = get_tta_flips(args.dim)
        self.test_idx = 0
        self.test_imgs = []
        if self.args.exec_mode in ["train", "evaluate"]:
            self.dllogger = get_dllogger(args.results)

    def forward(self, img):
        """Dispatch to the plain model (benchmark) or TTA/sliding inference."""
        if self.args.benchmark:
            if self.args.dim == 2 and self.args.data2d_dim == 3:
                img = layout_2d(img, None)
            return self.model(img)
        return self.tta_inference(img) if self.args.tta else self.do_inference(
            img)

    def training_step(self, batch, batch_idx):
        """One optimisation step: forward, (deep-supervision) loss, lazy mark."""
        img, lbl = self.get_train_data(batch)
        pred = self.model(img)
        loss = self.compute_loss(pred, lbl)
        mark_step(self.args.run_lazy_mode)
        return loss

    def on_before_zero_grad(self, optimizer):
        mark_step(self.args.run_lazy_mode)

    def on_after_backward(self):
        mark_step(self.args.run_lazy_mode)

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp,
                       using_lbfgs):
        optimizer.step(closure=optimizer_closure)
        mark_step(self.args.run_lazy_mode)

    def validation_step(self, batch, batch_idx):
        """Compute val loss and update the dice metric (skipped early epochs)."""
        if self.current_epoch < self.args.skip_first_n_eval:
            return None
        img, lbl = batch["image"], batch["label"]
        if self.args.hpus:
            img, lbl = img.to(torch.device("hpu"),
                              non_blocking=False), lbl.to(torch.device("hpu"),
                                                          non_blocking=False)
        pred = self.forward(img)
        loss = self.loss(pred, lbl)
        self.dice.update(pred, lbl[:, 0])
        mark_step(self.args.run_lazy_mode)
        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        """Predict on a test batch; optionally restore full-volume shape and save."""
        print("Start test")
        if self.args.exec_mode == "evaluate":
            return self.validation_step(batch, batch_idx)
        img = batch["image"]
        if self.args.hpus:
            img = img.to(torch.device("hpu"), non_blocking=False)
        if self.args.channels_last:
            if img.ndim == 4 or self.args.dim == 2:
                img = img.contiguous(memory_format=torch.channels_last)
            elif img.ndim == 5 and self.args.dim == 3:
                img = img.contiguous(memory_format=torch.channels_last_3d)
            mark_step(self.args.run_lazy_mode)

        pred = self.forward(img)
        mark_step(self.args.run_lazy_mode)

        if self.args.save_preds:
            # meta rows: [0] crop min corner, [1] crop max corner, [2] original shape.
            meta = batch["meta"][0].cpu().detach().numpy()
            original_shape = meta[2]
            min_d, max_d = meta[0, 0], meta[1, 0]
            min_h, max_h = meta[0, 1], meta[1, 1]
            min_w, max_w = meta[0, 2], meta[1, 2]

            # Paste the cropped prediction back into the original volume.
            final_pred = torch.zeros((1, pred.shape[1], *original_shape),
                                     device=img.device)
            final_pred[:, :, min_d:max_d, min_h:max_h, min_w:max_w] = pred
            final_pred = nn.functional.softmax(final_pred, dim=1)
            final_pred = final_pred.squeeze(0).cpu().detach().numpy()

            if not all(original_shape == final_pred.shape[1:]):
                class_ = final_pred.shape[0]
                resized_pred = np.zeros((class_, *original_shape))
                for i in range(class_):
                    resized_pred[i] = resize(final_pred[i],
                                             original_shape,
                                             order=3,
                                             mode="edge",
                                             cval=0,
                                             clip=True,
                                             anti_aliasing=False)
                final_pred = resized_pred

            self.save_mask(final_pred)

    def on_save_checkpoint(self, checkpoint):
        """On HPU runs, move checkpoint tensors to CPU and undo filter permutes."""
        if not self.args.hpus:
            return
        state_dict = checkpoint['state_dict']
        optimizer_states = checkpoint['optimizer_states']
        optimizer_state_dict = optimizer_states[0]['state']

        for k, v in checkpoint["callbacks"].items():
            if isinstance(v, dict):
                for k1, v1 in v.items():
                    if isinstance(v1, torch.Tensor):
                        v[k1] = v1.to("cpu")

        adjust_tensors_for_save(state_dict,
                                optimizer_state_dict,
                                to_device="cpu",
                                to_filters_last=False,
                                lazy_mode=self.args.run_lazy_mode,
                                permute=True)

    def build_nnunet(self):
        """Construct the dynamic UNet from dataset-derived parameters."""
        in_channels, n_class, kernels, strides, self.patch_size = get_unet_params(
            self.args)
        # n_class includes background; metrics track foreground classes only.
        self.n_class = n_class - 1
        self.model = UNet(
            in_channels=in_channels,
            n_class=n_class,
            kernels=kernels,
            strides=strides,
            dimension=self.args.dim,
            residual=self.args.residual,
            attention=self.args.attention,
            drop_block=self.args.drop_block,
            normalization_layer=self.args.norm,
            negative_slope=self.args.negative_slope,
            deep_supervision=self.args.deep_supervision,
        )
        if is_main_process():
            print(
                f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}"
            )

    def compute_loss(self, preds, label):
        """Loss with optional deep supervision (halved weight per extra head)."""
        if self.args.deep_supervision:
            loss = self.loss(preds[0], label)
            for i, pred in enumerate(preds[1:]):
                downsampled_label = nn.functional.interpolate(
                    label, pred.shape[2:])
                loss += 0.5**(i + 1) * self.loss(pred, downsampled_label)
            c_norm = 1 / (2 - 2**(-len(preds)))
            return c_norm * loss
        return self.loss(preds, label)

    def do_inference(self, image):
        """Route to 3D sliding-window or the appropriate 2D inference path."""
        if self.args.dim == 3:
            return self.sliding_window_inference(image)
        if self.args.data2d_dim == 2:
            return self.model(image)
        if self.args.exec_mode == "predict":
            return self.inference2d_test(image)
        return self.inference2d(image)

    def tta_inference(self, img):
        """Average predictions over the identity and all configured flips."""
        pred = self.do_inference(img)
        for flip_idx in self.tta_flips:
            pred += flip(self.do_inference(flip(img, flip_idx)), flip_idx)
        pred /= len(self.tta_flips) + 1
        return pred

    def inference2d(self, image):
        """Slice a 3D volume along depth and predict in fixed-size 2D batches."""
        batch_modulo = image.shape[2] % self.args.val_batch_size
        if batch_modulo != 0:
            # Pad the depth axis so it divides evenly into val batches.
            batch_pad = self.args.val_batch_size - batch_modulo
            image = nn.ConstantPad3d((0, 0, 0, 0, batch_pad, 0), 0)(image)
            mark_step(self.args.run_lazy_mode)
        image = torch.transpose(image.squeeze(0), 0, 1)
        preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
        if self.args.hpus:
            preds = None
            for start in range(0,
                               image.shape[0] - self.args.val_batch_size + 1,
                               self.args.val_batch_size):
                end = start + self.args.val_batch_size
                pred = self.model(image[start:end])
                # 'is None' (not '=='): tensor.__eq__(None) is not an identity test.
                preds = pred if preds is None else torch.cat(
                    (preds, pred), dim=0)
                mark_step(self.args.run_lazy_mode)
            if batch_modulo != 0:
                preds = preds[batch_pad:]
                mark_step(self.args.run_lazy_mode)
        else:
            preds = torch.zeros(preds_shape,
                                dtype=image.dtype,
                                device=image.device)
            for start in range(0,
                               image.shape[0] - self.args.val_batch_size + 1,
                               self.args.val_batch_size):
                end = start + self.args.val_batch_size
                pred = self.model(image[start:end])
                preds[start:end] = pred.data
            if batch_modulo != 0:
                preds = preds[batch_pad:]
        return torch.transpose(preds, 0, 1).unsqueeze(0)

    def inference2d_test(self, image):
        """Per-depth-slice sliding-window inference for predict mode."""
        preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
        preds = torch.zeros(preds_shape,
                            dtype=image.dtype,
                            device=image.device)
        for depth in range(image.shape[2]):
            preds[:, :, depth] = self.sliding_window_inference(image[:, :,
                                                                     depth])
        return preds

    def sliding_window_inference(self, image):
        """MONAI sliding-window inference (HPU-adapted implementation when needed)."""
        if self.args.hpus:
            from models.monai_sliding_window_inference import sliding_window_inference
        else:
            from monai.inferers import sliding_window_inference
        return sliding_window_inference(
            inputs=image,
            roi_size=self.patch_size,
            sw_batch_size=self.args.val_batch_size,
            predictor=self.model,
            overlap=self.args.overlap,
            mode=self.args.blend,
        )

    @staticmethod
    def metric_mean(name, outputs):
        """Mean of a named scalar metric across step outputs."""
        return torch.stack([out[name] for out in outputs]).mean(dim=0)

    def validation_epoch_end(self, outputs):
        """Aggregate dice/loss, update best trackers and log metrics."""
        if self.current_epoch < self.args.skip_first_n_eval:
            # Log a tiny increasing value so checkpoint callbacks stay happy.
            self.log("dice_sum", 0.001 * self.current_epoch)
            self.dice.reset()
            return None
        loss = self.metric_mean("val_loss", outputs)
        dice = self.dice.compute()
        dice_sum = torch.sum(dice)
        if dice_sum >= self.best_sum:
            self.best_sum = dice_sum
            self.best_sum_dice = dice[:]
            self.best_sum_epoch = self.current_epoch
        for i, dice_i in enumerate(dice):
            if dice_i > self.best_dice[i]:
                self.best_dice[i], self.best_epoch[
                    i] = dice_i, self.current_epoch

        if is_main_process():
            metrics = {}
            metrics.update({"mean dice": round(torch.mean(dice).item(), 2)})
            metrics.update(
                {"TOP_mean": round(torch.mean(self.best_sum_dice).item(), 2)})
            if self.n_class > 1:
                metrics.update({
                    f"L{i+1}": round(m.item(), 2)
                    for i, m in enumerate(dice)
                })
                metrics.update({
                    f"TOP_L{i+1}": round(m.item(), 2)
                    for i, m in enumerate(self.best_sum_dice)
                })
            metrics.update({"val_loss": round(loss.item(), 4)})
            self.dllogger.log(step=self.current_epoch, data=metrics)
            self.dllogger.flush()

        self.log("val_loss", loss)
        self.log("dice_sum", dice_sum)

    def test_epoch_end(self, outputs):
        if self.args.exec_mode == "evaluate":
            self.eval_dice = self.dice.compute()

    def configure_optimizers(self):
        """Build the configured optimizer and (optional) LR scheduler."""
        if self.args.hpus:
            self.model = self.model.to(get_device(self.args))
            permute_params(self.model, True, self.args.run_lazy_mode)
        # Avoid instantiate optimizers if not have to
        # since might not be supported
        if self.args.optimizer.lower() == 'sgd':
            optimizer = SGD(self.parameters(),
                            lr=self.learning_rate,
                            momentum=self.args.momentum)
        elif self.args.optimizer.lower() == 'adam':
            optimizer = Adam(self.parameters(),
                             lr=self.learning_rate,
                             weight_decay=self.args.weight_decay)
        elif self.args.optimizer.lower() == 'radam':
            optimizer = RAdam(self.parameters(),
                              lr=self.learning_rate,
                              weight_decay=self.args.weight_decay)
        elif self.args.optimizer.lower() == 'adamw':
            optimizer = torch.optim.AdamW(self.parameters(),
                                          lr=self.learning_rate,
                                          weight_decay=self.args.weight_decay)
        elif self.args.optimizer.lower() == 'fusedadamw':
            from habana_frameworks.torch.hpex.optimizers import FusedAdamW
            optimizer = FusedAdamW(self.parameters(),
                                   lr=self.learning_rate,
                                   eps=1e-08,
                                   weight_decay=self.args.weight_decay)
        else:
            assert False, "optimizer {} not supported".format(
                self.args.optimizer.lower())

        scheduler = {
            "none":
            None,
            "multistep":
            torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 self.args.steps,
                                                 gamma=self.args.factor),
            "cosine":
            torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       self.args.max_epochs),
            "plateau":
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=self.args.factor,
                patience=self.args.lr_patience),
        }[self.args.scheduler.lower()]

        opt_dict = {"optimizer": optimizer, "monitor": "val_loss"}
        if scheduler is not None:
            opt_dict.update({"lr_scheduler": scheduler})
        return opt_dict

    def save_mask(self, pred):
        """Save a prediction under the matching test filename (strips '_x')."""
        if self.test_idx == 0:
            data_path = get_path(self.args)
            self.test_imgs, _ = get_test_fnames(self.args, data_path)
        fname = os.path.basename(self.test_imgs[self.test_idx]).replace(
            "_x", "")
        np.save(os.path.join(self.save_dir, fname), pred, allow_pickle=False)
        self.test_idx += 1

    def get_train_data(self, batch):
        """Fetch a training batch, apply 2D layout / HPU / channels-last handling."""
        img, lbl = batch["image"], batch["label"]
        if self.args.dim == 2 and self.args.data2d_dim == 3:
            img, lbl = layout_2d(img, lbl)
        if self.args.hpus:
            img, lbl = img.to(torch.device("hpu"),
                              non_blocking=False), lbl.to(torch.device("hpu"),
                                                          non_blocking=False)
        if self.args.channels_last:
            if img.ndim == 4:
                img = img.contiguous(memory_format=torch.channels_last)
                lbl = lbl.contiguous(memory_format=torch.channels_last)
            elif img.ndim == 5:
                img = img.contiguous(memory_format=torch.channels_last_3d)
                lbl = lbl.contiguous(memory_format=torch.channels_last_3d)
            mark_step(self.args.run_lazy_mode)
        return img, lbl
예제 #9
0
    'cmap': 'jet',
    'vmin': 0,
    'vmax': eval_label.max()
}, {
    'cmap': 'jet',
    'vmin': 0,
    'vmax': eval_label.max()
})

net_is_3d = False
# Report multi-GPU availability; DataParallel below spans all visible devices.
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs.")

device_ids = [i for i in range(torch.cuda.device_count())]
model = nn.DataParallel(model, device_ids=device_ids)
model = model.to(device)

# Load the checkpoint matching the chosen experiment.
# NOTE(review): no map_location is passed — presumably these runs always have
# a GPU available; confirm before running on a CPU-only host.
if experiment == "Unet":
    model.load_state_dict(torch.load("best_weights.pth"))

elif experiment == "DeepLab":
    model.load_state_dict(torch.load(f"best_weights_{backbone}_deeplab.pth"))

model.eval()

# Tile the evaluation image/label into windows (corners kept for stitching).
eval_images, eval_labels, eval_label_corners = batch_generator(
    eval_image, eval_label, **windowing_params, return_corners=True)

eval_dataset = PlateletDataset(eval_images, eval_labels, train=False)

prob_maps = stitch(model, eval_images, eval_labels, eval_label.shape,
예제 #10
0
def inference():
    """Run liver/tumor segmentation on a LiTS dataset.

    Supports two modes:
      * evaluation mode (``--evaluate``): runs on a validation set with
        ground-truth masks and computes per-volume liver/lesion scores;
      * inference mode (default): runs on the test set and, with
        ``--save2dir``, writes submission-ready ``test-segmentation-X.nii``
        files.

    Returns:
        list: per-volume ``[vol_ind, liver_scores, lesion_scores]`` entries
        (empty unless ``--evaluate`` is given).

    Raises:
        ValueError: if neither ``--load-weights`` nor ``--load-model``
            is provided.
    """
    parser = argparse.ArgumentParser(description="Inference mode")
    parser.add_argument('-testf', "--test-filepath", type=str, default=None, required=True,
                        help="testing dataset filepath.")
    parser.add_argument("-eval", "--evaluate", action="store_true", default=False,
                        help="Evaluation mode")
    parser.add_argument("--load-weights", type=str, default=None,
                        help="Load pretrained weights, torch state_dict() (filepath, default: None)")
    parser.add_argument("--load-model", type=str, default=None,
                        help="Load pretrained model, entire model (filepath, default: None)")

    parser.add_argument("--save2dir", type=str, default=None,
                        help="save the prediction labels to the directory (default: None)")
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size")

    parser.add_argument("--num-cpu", type=int, default=10,
                        help="Number of CPUs to use in parallel for dataloader.")
    parser.add_argument('--cuda', type=int, default=0,
                        help='CUDA visible device (use CPU if -1, default: 0)')
    args = parser.parse_args()

    printYellow("="*10 + " Inference mode. "+"="*10)
    if args.save2dir:
        os.makedirs(args.save2dir, exist_ok=True)

    device = torch.device("cuda:{}".format(args.cuda) if torch.cuda.is_available()
                          and (args.cuda >= 0) else "cpu")

    # Single-channel input; map intensities to roughly [-1, 1].
    transform_normalize = transforms.Normalize(mean=[0.5],
                                               std=[0.5])

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transform_normalize
    ])

    data_loader_params = {'batch_size': args.batch_size,
                          'shuffle': False,  # keep slice order: volumes are re-assembled by index below
                          'num_workers': args.num_cpu,
                          'drop_last': False,
                          'pin_memory': False
                          }

    test_set = LiTSDataset(args.test_filepath,
                           dtype=np.float32,
                           pixelwise_transform=data_transform,
                           inference_mode=(not args.evaluate),
                           )
    dataloader_test = torch.utils.data.DataLoader(test_set, **data_loader_params)
    # =================== Build model ===================
    if args.load_weights:
        model = UNet(in_ch=1,
                     out_ch=3,  # there are 3 classes: 0: background, 1: liver, 2: tumor
                     depth=4,
                     start_ch=64,
                     inc_rate=2,
                     kernel_size=3,
                     padding=True,
                     batch_norm=True,
                     spec_norm=False,
                     dropout=0.5,
                     up_mode='upconv',
                     include_top=True,
                     include_last_act=False,
                     )
        model.load_state_dict(torch.load(args.load_weights))
        printYellow("Successfully loaded pretrained weights.")
    elif args.load_model:
        # load entire model
        model = torch.load(args.load_model)
        printYellow("Successfully loaded pretrained model.")
    else:
        # Fail fast with a clear message instead of hitting a NameError
        # on `model.eval()` below when no checkpoint source was given.
        raise ValueError(
            "Either --load-weights or --load-model must be provided.")
    model.eval()
    model.to(device)

    # n_batch_per_epoch = len(dataloader_test)

    sigmoid_act = torch.nn.Sigmoid()
    st = time.time()

    # Per-volume metadata needed to split the flat slice stream back into
    # volumes and to write geometrically-correct medical image headers.
    volume_start_index = test_set.volume_start_index
    spacing = test_set.spacing
    direction = test_set.direction  # use it for the submission
    offset = test_set.offset

    msk_pred_buffer = []
    if args.evaluate:
        msk_gt_buffer = []

    # Forward pass over all slices; predictions are accumulated on CPU.
    for data_batch in tqdm(dataloader_test):
        # import ipdb
        # ipdb.set_trace()
        if args.evaluate:
            img, msk_gt = data_batch
            msk_gt_buffer.append(msk_gt.cpu().detach().numpy())
        else:
            img = data_batch
        img = img.to(device)
        with torch.no_grad():
            msk_pred = model(img)  # shape (N, 3, H, W)
            msk_pred = sigmoid_act(msk_pred)
        msk_pred_buffer.append(msk_pred.cpu().detach().numpy())

    msk_pred_buffer = np.vstack(msk_pred_buffer)  # shape (N, 3, H, W)
    if args.evaluate:
        msk_gt_buffer = np.vstack(msk_gt_buffer)

    results = []
    # Re-assemble slices into volumes using the recorded start indices.
    for vol_ind, vol_start_ind in enumerate(volume_start_index):
        if vol_ind == len(volume_start_index) - 1:
            volume_msk = msk_pred_buffer[vol_start_ind:]  # shape (N, 3, H, W)
            if args.evaluate:
                volume_msk_gt = msk_gt_buffer[vol_start_ind:]
        else:
            vol_end_ind = volume_start_index[vol_ind+1]
            volume_msk = msk_pred_buffer[vol_start_ind:vol_end_ind]  # shape (N, 3, H, W)
            if args.evaluate:
                volume_msk_gt = msk_gt_buffer[vol_start_ind:vol_end_ind]
        if args.evaluate:
            # liver (GT label >= 1 means liver-or-tumor)
            liver_scores = get_scores(volume_msk[:, 1] >= 0.5, volume_msk_gt >= 1, spacing[vol_ind])
            # tumor
            lesion_scores = get_scores(volume_msk[:, 2] >= 0.5, volume_msk_gt == 2, spacing[vol_ind])
            print("Liver dice", liver_scores['dice'], "Lesion dice", lesion_scores['dice'])
            results.append([vol_ind, liver_scores, lesion_scores])
            # ===========================
        else:
            # import ipdb; ipdb.set_trace()
            if args.save2dir:
                # reverse the order, because we prioritize tumor, liver then background.
                msk_pred = (volume_msk >= 0.5)[:, ::-1, ...]  # shape (N, 3, H, W)
                msk_pred = np.argmax(msk_pred, axis=1)  # shape (N, H, W) = (z, x, y)
                msk_pred = np.transpose(msk_pred, axes=(1, 2, 0))  # shape (x, y, z)
                # remember to correct 'direction' and np.transpose before the submission !!!
                if direction[vol_ind][0] == -1:
                    # x-axis
                    msk_pred = msk_pred[::-1, ...]
                if direction[vol_ind][1] == -1:
                    # y-axis
                    msk_pred = msk_pred[:, ::-1, :]
                if direction[vol_ind][2] == -1:
                    # z-axis
                    msk_pred = msk_pred[..., ::-1]
                # save medical image header as well
                # see: http://loli.github.io/medpy/generated/medpy.io.header.Header.html
                file_header = med_header(spacing=tuple(spacing[vol_ind]),
                                         offset=tuple(offset[vol_ind]),
                                         direction=np.diag(direction[vol_ind]))
                # submission guide:
                # see: https://github.com/PatrickChrist/LITS-CHALLENGE/blob/master/submission-guide.md
                # test-segmentation-X.nii
                filepath = os.path.join(args.save2dir, f"test-segmentation-{vol_ind}.nii")
                med_save(msk_pred, filepath, hdr=file_header)
    if args.save2dir:
        # NOTE(review): in inference mode `results` is empty, so this writes
        # an empty results.pkl — confirm this is intentional.
        # outpath = os.path.join(args.save2dir, "results.csv")
        outpath = os.path.join(args.save2dir, "results.pkl")
        with open(outpath, "wb") as file:
            final_result = {}
            final_result['liver'] = defaultdict(list)
            final_result['tumor'] = defaultdict(list)
            for vol_ind, liver_scores, lesion_scores in results:
                # [OTC] assuming vol_ind is continuous
                for key in liver_scores:
                    final_result['liver'][key].append(liver_scores[key])
                for key in lesion_scores:
                    final_result['tumor'][key].append(lesion_scores[key])
            pickle.dump(final_result, file, protocol=3)
        # ======== code from official metric ========
        # create line for csv file
        # outstr = str(vol_ind) + ','
        # for l in [liver_scores, lesion_scores]:
        #     for k, v in l.items():
        #         outstr += str(v) + ','
        #         outstr += '\n'
        # # create header for csv file if necessary
        # if not os.path.isfile(outpath):
        #     headerstr = 'Volume,'
        #     for k, v in liver_scores.items():
        #         headerstr += 'Liver_' + k + ','
        #     for k, v in liver_scores.items():
        #         headerstr += 'Lesion_' + k + ','
        #     headerstr += '\n'
        #     outstr = headerstr + outstr
        # # write to file
        # f = open(outpath, 'a+')
        # f.write(outstr)
        # f.close()
        # ===========================
    printGreen(f"Total elapsed time: {time.time()-st}")
    return results
예제 #11
0
def train(args):
    """Train a material-decomposition network on CONRAD projection data.

    Handles dataset/testset loading (with optional label normalization),
    checkpoint resume, a training loop of ``args.epochs`` epochs times
    ``args.iterations`` iterations, and per-epoch validation with SSIM/R
    metrics written to a CSV plus optional tensorboard logging.

    Args:
        args: parsed CLI namespace (epochs, iterations, lr, momentum,
            batch sizes, data/run folders, model/name, norm, tb, ...).
    """
    '''
    -------------------------Hyperparameters--------------------------
    '''
    EPOCHS = args.epochs
    START = 0  # could enter a checkpoint start epoch
    ITER = args.iterations  # per epoch
    LR = args.lr
    MOM = args.momentum
    # LOGInterval = args.log_interval
    BATCHSIZE = args.batch_size
    TEST_BATCHSIZE = args.test_batch_size
    NUMBER_OF_WORKERS = args.workers
    DATA_FOLDER = args.data
    TESTSET_FOLDER = args.testset
    ROOT = args.run
    WEIGHT_DIR = os.path.join(ROOT, "weights")
    CUSTOM_LOG_DIR = os.path.join(ROOT, "additionalLOGS")
    CHECKPOINT = os.path.join(WEIGHT_DIR,
                              str(args.model) + str(args.name) + ".pt")
    useTensorboard = args.tb

    # check existance of data
    if not os.path.isdir(DATA_FOLDER):
        print("data folder not existant or in wrong layout.\n\t", DATA_FOLDER)
        # BUGFIX: exit with a nonzero status on error (was exit(0)).
        exit(1)
    # check existance of testset
    if TESTSET_FOLDER is not None and not os.path.isdir(TESTSET_FOLDER):
        # BUGFIX: report the missing testset folder (was printing DATA_FOLDER).
        print("testset folder not existant or in wrong layout.\n\t",
              TESTSET_FOLDER)
        exit(1)
    '''
    ---------------------------preparations---------------------------
    '''

    # CUDA for PyTorch
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    print("using device: ", str(device))

    # loading the validation samples to make online evaluations
    path_to_valX = args.valX
    path_to_valY = args.valY
    valX = None
    valY = None
    if path_to_valX is not None and path_to_valY is not None \
            and os.path.exists(path_to_valX) and os.path.exists(path_to_valY) \
            and os.path.isfile(path_to_valX) and os.path.isfile(path_to_valY):
        with torch.no_grad():
            valX, valY = torch.load(path_to_valX, map_location='cpu'), \
                   torch.load(path_to_valY, map_location='cpu')
    '''
    ---------------------------loading dataset and normalizing---------------------------
    '''
    # Dataloader Parameters
    train_params = {
        'batch_size': BATCHSIZE,
        'shuffle': True,
        'num_workers': NUMBER_OF_WORKERS
    }
    test_params = {
        'batch_size': TEST_BATCHSIZE,
        'shuffle': False,
        'num_workers': NUMBER_OF_WORKERS
    }

    # create a folder for the weights and custom logs
    if not os.path.isdir(WEIGHT_DIR):
        os.makedirs(WEIGHT_DIR)
    if not os.path.isdir(CUSTOM_LOG_DIR):
        os.makedirs(CUSTOM_LOG_DIR)

    labelsNorm = None
    # NORMLABEL
    # normalizing on a trainingset wide mean and std
    mean = None
    std = None
    if args.norm:
        print('computing mean and std over trainingset')
        # computes mean and std over all ground truths in dataset to tackle the problem of numerical insignificance
        mean, std = computeMeanStdOverDataset('CONRADataset', DATA_FOLDER,
                                              train_params, device)
        print('\niodine (mean/std): {}\t{}'.format(mean[0], std[0]))
        print('water (mean/std): {}\t{}\n'.format(mean[1], std[1]))
        labelsNorm = transforms.Normalize(mean=[0, 0], std=std)
        # sanity-check: recompute statistics with the normalization applied
        m2, s2 = computeMeanStdOverDataset('CONRADataset',
                                           DATA_FOLDER,
                                           train_params,
                                           device,
                                           transform=labelsNorm)
        print("new mean and std are:")
        print('\nnew iodine (mean/std): {}\t{}'.format(m2[0], s2[0]))
        print('new water (mean/std): {}\t{}\n'.format(m2[1], s2[1]))

    traindata = CONRADataset(DATA_FOLDER,
                             True,
                             device=device,
                             precompute=True,
                             transform=labelsNorm)

    # fall back to the (non-train split of the) data folder when no
    # dedicated testset folder was provided
    testdata = None
    if TESTSET_FOLDER is not None:
        testdata = CONRADataset(TESTSET_FOLDER,
                                False,
                                device=device,
                                precompute=True,
                                transform=labelsNorm)
    else:
        testdata = CONRADataset(DATA_FOLDER,
                                False,
                                device=device,
                                precompute=True,
                                transform=labelsNorm)

    trainingset = DataLoader(traindata, **train_params)
    testset = DataLoader(testdata, **test_params)
    '''
    ----------------loading model and checkpoints---------------------
    '''

    if args.model == "unet":
        m = UNet(2, 2).to(device)
        print(
            "using the U-Net architecture with {} trainable params; Good Luck!"
            .format(count_trainables(m)))
    else:
        m = simpleConvNet(2, 2).to(device)

    o = optim.SGD(m.parameters(), lr=LR, momentum=MOM)

    loss_fn = nn.MSELoss()

    test_loss = None
    train_loss = None

    # resume from the latest (or interactively chosen) checkpoint if any exist
    if len(os.listdir(WEIGHT_DIR)) != 0:
        checkpoints = os.listdir(WEIGHT_DIR)
        checkDir = {}
        latestCheckpoint = 0
        for i, checkpoint in enumerate(checkpoints):
            # checkpoint files are named <model><name><global_step>.pt
            stepOfCheckpoint = int(
                checkpoint.split(str(args.model) +
                                 str(args.name))[-1].split('.pt')[0])
            checkDir[stepOfCheckpoint] = checkpoint
            latestCheckpoint = max(latestCheckpoint, stepOfCheckpoint)
            print("[{}] {}".format(stepOfCheckpoint, checkpoint))
        # if on development machine, prompt for input, else just take the most recent one
        if 'faui' in os.uname()[1]:
            toUse = int(input("select checkpoint to use: "))
        else:
            toUse = latestCheckpoint
        checkpoint = torch.load(os.path.join(WEIGHT_DIR, checkDir[toUse]))
        m.load_state_dict(checkpoint['model_state_dict'])
        m.to(device)  # pushing weights to gpu
        o.load_state_dict(checkpoint['optimizer_state_dict'])
        train_loss = checkpoint['train_loss']
        test_loss = checkpoint['test_loss']
        START = checkpoint['epoch']
        print("using checkpoint {}:\n\tloss(train/test): {}/{}".format(
            toUse, train_loss, test_loss))
    else:
        print("starting from scratch")
    '''
    -----------------------------training-----------------------------
    '''
    global_step = 0
    # calculating initial loss
    if test_loss is None or train_loss is None:
        print("calculating initial loss")
        m.eval()
        print("testset...")
        test_loss = calculate_loss(set=testset,
                                   loss_fn=loss_fn,
                                   length_set=len(testdata),
                                   dev=device,
                                   model=m)
        print("trainset...")
        train_loss = calculate_loss(set=trainingset,
                                    loss_fn=loss_fn,
                                    length_set=len(traindata),
                                    dev=device,
                                    model=m)

    ## SSIM and R value
    R = []
    SSIM = []
    performanceFLE = os.path.join(CUSTOM_LOG_DIR, "performance.csv")
    with open(performanceFLE, 'w+') as f:
        f.write(
            "step, SSIMiodine, SSIMwater, Riodine, Rwater, train_loss, test_loss\n"
        )
    print("computing ssim and r coefficents to: {}".format(performanceFLE))

    # printing runtime information
    print(
        "starting training at {} for {} epochs {} iterations each\n\t{} total".
        format(START, EPOCHS, ITER, EPOCHS * ITER))

    print("\tbatchsize: {}\n\tloss: {}\n\twill save results to \"{}\"".format(
        BATCHSIZE, train_loss, CHECKPOINT))
    print(
        "\tmodel: {}\n\tlearningrate: {}\n\tmomentum: {}\n\tnorming output space: {}"
        .format(args.model, LR, MOM, args.norm))

    #start actual training loops
    for e in range(START, START + EPOCHS):
        # iterations will not be interupted with validation and metrics
        for i in range(ITER):
            global_step = (e * ITER) + i

            # training
            m.train()
            iteration_loss = 0
            for x, y in tqdm(trainingset):
                x, y = x.to(device=device,
                            dtype=torch.float), y.to(device=device,
                                                     dtype=torch.float)
                pred = m(x)
                loss = loss_fn(pred, y)
                iteration_loss += loss.item()
                o.zero_grad()
                loss.backward()
                o.step()
            print("\niteration {}: --accumulated loss {}".format(
                global_step, iteration_loss))

        # validation, saving and logging
        print("\nvalidating")
        m.eval()  # disable dropout batchnorm etc
        print("testset...")
        test_loss = calculate_loss(set=testset,
                                   loss_fn=loss_fn,
                                   length_set=len(testdata),
                                   dev=device,
                                   model=m)
        print("trainset...")
        train_loss = calculate_loss(set=trainingset,
                                    loss_fn=loss_fn,
                                    length_set=len(traindata),
                                    dev=device,
                                    model=m)

        print("calculating SSIM and R coefficients")
        currSSIM, currR = performance(set=testset,
                                      dev=device,
                                      model=m,
                                      bs=TEST_BATCHSIZE)
        print("SSIM (iod/water): {}/{}\nR (iod/water): {}/{}".format(
            currSSIM[0], currSSIM[1], currR[0], currR[1]))
        with open(performanceFLE, 'a') as f:
            newCSVline = "{}, {}, {}, {}, {}, {}, {}\n".format(
                global_step, currSSIM[0], currSSIM[1], currR[0], currR[1],
                train_loss, test_loss)
            f.write(newCSVline)
            print("wrote new line to csv:\n\t{}".format(newCSVline))
        '''
            if valX and valY were set in preparations, use them to perform analytics.
            if not, use the first sample from the testset to perform analytics
        '''
        with torch.no_grad():
            truth, pred = None, None
            IMAGE_LOG_DIR = os.path.join(CUSTOM_LOG_DIR, str(global_step))
            if not os.path.isdir(IMAGE_LOG_DIR):
                os.makedirs(IMAGE_LOG_DIR)

            if valX is not None and valY is not None:
                # pad the single validation sample to a full batch
                batched = np.zeros((BATCHSIZE, *valX.numpy().shape))
                batched[0] = valX.numpy()
                batched = torch.from_numpy(batched).to(device=device,
                                                       dtype=torch.float)
                pred = m(batched)
                pred = pred.cpu().numpy()[0]
                truth = valY.numpy()  # still on cpu

                assert pred.shape == truth.shape
            else:
                for x, y in testset:
                    # x, y in shape[2,2,480,620] [b,c,h,w]
                    x, y = x.to(device=device,
                                dtype=torch.float), y.to(device=device,
                                                         dtype=torch.float)
                    pred = m(x)
                    pred = pred.cpu().numpy()[
                        0]  # taking only the first sample of batch
                    truth = y.cpu().numpy()[
                        0]  # first projection for evaluation
            advanvedMetrics(truth, pred, mean, std, global_step, args.norm,
                            IMAGE_LOG_DIR)

        print("logging")
        CHECKPOINT = os.path.join(
            WEIGHT_DIR,
            str(args.model) + str(args.name) + str(global_step) + ".pt")
        torch.save(
            {
                'epoch': e + 1,  # end of this epoch; so resume at next.
                'model_state_dict': m.state_dict(),
                'optimizer_state_dict': o.state_dict(),
                'train_loss': train_loss,
                'test_loss': test_loss
            },
            CHECKPOINT)
        print('\tsaved weigths to: ', CHECKPOINT)
        if logger is not None and train_loss is not None:
            logger.add_scalar('test_loss', test_loss, global_step=global_step)
            logger.add_scalar('train_loss',
                              train_loss,
                              global_step=global_step)
            logger.add_image("iodine-prediction",
                             pred[0].reshape(1, 480, 620),
                             global_step=global_step)
            logger.add_image("water-prediction",
                             pred[1].reshape(1, 480, 620),
                             global_step=global_step)
            # logger.add_image("water-prediction", wat)
            print(
                "\ttensorboard updated with test/train loss and a sample image"
            )
        elif train_loss is not None:
            print("\tloss of global-step {}: {}".format(
                global_step, train_loss))
        elif not useTensorboard:
            print("\t(tb-logging disabled) test/train loss: {}/{} ".format(
                test_loss, train_loss))
        else:
            print("\tno loss accumulated yet")

    # saving final results
    print("saving upon exit")
    torch.save(
        {
            'epoch': EPOCHS,
            'model_state_dict': m.state_dict(),
            'optimizer_state_dict': o.state_dict(),
            'train_loss': train_loss,
            'test_loss': test_loss
        }, CHECKPOINT)
    print('\tsaved progress to: ', CHECKPOINT)
    if logger is not None and train_loss is not None:
        logger.add_scalar('test_loss', test_loss, global_step=global_step)
        logger.add_scalar('train_loss', train_loss, global_step=global_step)
예제 #12
0
def main(args):
    """Evaluate a sparse 3D U-Net on the SemanticKITTI validation split.

    Builds the log/checkpoint directory layout, loads the best checkpoint
    (quits if none exists), then runs multi-GPU inference via manual model
    replication and reports accuracy / per-class IoU.

    Args:
        args: parsed CLI namespace (gpu, log_dir, model, npoint,
            batch_size, num_gpus, ...).
    """
    def log_string(str):
        #        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('part_seg')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    # NOTE(review): this overwrites the `args` passed into main() with a
    # fresh parse_args() — confirm that is intentional.
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    root = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/'
    #    file_list = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/train2.list'
    val_list = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/val2.list'
    #    TRAIN_DATASET = KittiDataset(root = root, file_list=file_list, npoints=args.npoint, training=True, augment=True)
    #    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=2)
    TEST_DATASET = KittiDataset(root=root,
                                file_list=val_list,
                                npoints=args.npoint,
                                training=False,
                                augment=False)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 drop_last=True,
                                                 num_workers=2)
    #    log_string("The number of training data is: %d" % len(TRAIN_DATASET))
    log_string("The number of test data is: %d" % len(TEST_DATASET))
    #    num_classes = 16

    num_devices = args.num_gpus  #torch.cuda.device_count()
    #    assert num_devices > 1, "Cannot detect more than 1 GPU."
    #    print(num_devices)
    devices = list(range(num_devices))
    target_device = devices[0]

    #    MODEL = importlib.import_module(args.model)

    net = UNet(4, 20, nPlanes)

    #    net = MODEL.get_model(num_classes, normal_channel=args.normal)
    net = net.to(target_device)

    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    # BUGFIX: was a bare `except:`, which also swallows KeyboardInterrupt
    # and SystemExit; narrow to Exception.
    except Exception:
        log_string('No existing model, starting training from scratch...')
        quit()

    if 1:

        with torch.no_grad():
            net.eval()
            evaluator = iouEval(num_classes, ignore)

            evaluator.reset()
            #            for iteration, (points, target, ins, mask) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
            for iteration, (points, target, ins,
                            mask) in enumerate(testDataLoader):
                # per-batch evaluator, used only for progress printing
                evaone = iouEval(num_classes, ignore)
                evaone.reset()
                cur_batch_size, NUM_POINT, _ = points.size()

                # cap the evaluation at 128 batches
                if iteration > 128:
                    break

                # manually shard the batch across devices
                inputs, targets, masks = [], [], []
                for i in range(num_devices):
                    start = int(i * (cur_batch_size / num_devices))
                    end = int((i + 1) * (cur_batch_size / num_devices))
                    with torch.cuda.device(devices[i]):
                        pc = points[start:end, :, :].to(devices[i])
                        #feas = points[start:end,:,3:].to(devices[i])
                        targeti = target[start:end, :].to(devices[i])
                        maski = mask[start:end, :].to(devices[i])

                        # voxelize: coordinates, features, labels, masks, offsets
                        locs, feas, label, maski, offsets = input_layer(
                            pc, targeti, maski, scale.to(devices[i]),
                            spatialSize.to(devices[i]), True)
                        #                        print(locs.size(), feas.size(), label.size(), maski.size(), offsets.size())
                        org_coords = locs[1]
                        label = Variable(label, requires_grad=False)

                        inputi = ME.SparseTensor(feas.cpu(), locs[0].cpu())
                        inputs.append([inputi.to(devices[i]), org_coords])
                        targets.append(label)
                        masks.append(maski)

                # replicate the network and run each shard on its device
                replicas = parallel.replicate(net, devices)
                outputs = parallel.parallel_apply(replicas,
                                                  inputs,
                                                  devices=devices)

                # gather per-device outputs back to the CPU
                seg_pred = outputs[0].cpu()
                mask = masks[0].cpu()
                target = targets[0].cpu()
                for i in range(1, num_devices):
                    seg_pred = torch.cat((seg_pred, outputs[i].cpu()), 0)
                    mask = torch.cat((mask, masks[i].cpu()), 0)
                    target = torch.cat((target, targets[i].cpu()), 0)

                # label 0 is "unlabeled": drop those points before scoring
                seg_pred = seg_pred[target > 0, :]
                target = target[target > 0]
                _, seg_pred = seg_pred.data.max(1)  #[1]

                target = target.data.numpy()

                evaluator.addBatch(seg_pred, target)

                evaone.addBatch(seg_pred, target)
                cur_accuracy = evaone.getacc()
                cur_jaccard, class_jaccard = evaone.getIoU()
                print('%.4f %.4f' % (cur_accuracy, cur_jaccard))

            m_accuracy = evaluator.getacc()
            m_jaccard, class_jaccard = evaluator.getIoU()

            log_string('Validation set:\n'
                       'Acc avg {m_accuracy:.3f}\n'
                       'IoU avg {m_jaccard:.3f}'.format(m_accuracy=m_accuracy,
                                                        m_jaccard=m_jaccard))
            # print also classwise
            for i, jacc in enumerate(class_jaccard):
                if i not in ignore:
                    log_string(
                        'IoU class {i:} [{class_str:}] = {jacc:.3f}'.format(
                            i=i,
                            class_str=class_strings[class_inv_remap[i]],
                            jacc=jacc))
def train():
    """Train a 2-class UNet for heart (RV) segmentation.

    Loads a pickled dataset cache when available (building and caching it
    otherwise), splits the samples into train/validation batches via a
    shuffled index list, then runs the epoch loop: optimization with
    optional L1/L1L2 regularization on the training split, loss and Dice
    tracking on the validation split, periodic checkpointing and sample
    prediction images, per-epoch CSV loss logging, and a final loss plot.
    """
    startTime = time.time()
    args = parameters.parse_arguments()
    logging.basicConfig(filename=args.logfile, level=logging.INFO)
    logging.critical("\n\n" + args.log_header)
    logging.info(args)
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"TIME: {time.time() - startTime}s Using device {device}")

    logging.info(f"TIME: {time.time()-startTime}s Loading dataset")
    # Reuse the pickled dataset cache when possible; rebuild and re-cache it
    # on any cache failure.  The previous bare `except:` also swallowed
    # KeyboardInterrupt/SystemExit -- catch only plausible cache errors.
    try:
        with open(os.path.join(args.datadir, "data.pkl"), "rb") as f:
            data = pickle.load(f)
    except (OSError, pickle.PickleError, EOFError, AttributeError):
        data = DataLoader(args.datadir,
                          int(args.batchsize),
                          shuffle=int(args.shuffle))
        with open(os.path.join(args.datadir, "data.pkl"), "wb") as f:
            pickle.dump(data, f)
    data.batchSize = int(args.batchsize)
    logging.info(f"TIME: {time.time()-startTime}s Dataset Loaded")

    # Deterministic train/validation split: after shuffling,
    # indices[:trainEndIndex] are training batches, the rest validation.
    random.seed(args.seed)
    indices = list(range(len(data)))
    random.shuffle(
        indices
    )  # 0:floor((1-validationFrac)*len(data)) will be training data, rest will be validation data
    trainEndIndex = math.floor((1 - args.validation_frac) * (len(data)))

    model = UNet(in_channels=1,
                 num_classes=2,
                 start_filts=int(args.conv_filters),
                 up_mode=args.mode,
                 depth=int(args.depth),
                 batchnorm=args.batchnorm)
    model.reset_params()
    model = model.to(device)

    # Build the optimizer requested on the command line.
    optimizer = None
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lrstart)
        logging.info(f"TIME: {time.time()-startTime}s Optimizer: adam")
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lrstart,
                              momentum=args.momentum)
        logging.info(f"TIME: {time.time()-startTime}s Optimizer: SGD")
    elif args.optimizer == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(), lr=args.lrstart)
        logging.info(f"TIME: {time.time()-startTime}s Optimizer: RMSProp")
    else:
        logging.error(
            f"TIME: {time.time()-startTime}s Incorrect optimizer given")

    # Learning-rate scheduler; the fallback StepLR with step_size == epochs
    # effectively keeps the learning rate constant for the whole run.
    scheduler = None
    if args.lrscheduler == "steplr":
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.decay)
        logging.info(f"TIME: {time.time()-startTime}s LRScheduler: StepLR")
    elif args.lrscheduler == "exponentiallr":
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                     gamma=args.decay)
        logging.info(
            f"TIME: {time.time()-startTime}s LRScheduler: exponentialLR")
    else:
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=int(args.epochs))
        logging.info(
            f"TIME: {time.time()-startTime}s LRScheduler: lr shouldn't change with epochs"
        )

    criteria = CombinedLoss(args.lambda_loss, args.loss_type)
    diceCoeff = DiceLoss()
    TL = []  # per-epoch training losses
    VL = []  # per-epoch validation losses
    if not os.path.exists(os.path.join(os.getcwd(), "loss_files")):
        os.makedirs(os.path.join(os.getcwd(), "loss_files"))
    lossFile = open(os.path.join("loss_files", args.log_header + ".csv"), "w+")
    lossFile.write("Epoch,TrainLoss,ValidationLoss,Dice Coefficient\n")

    for epoch in tqdm(range(1, int(args.epochs) + 1), desc="Training model"):
        trainLoss = 0
        valLoss = 0
        trainingSample = 0
        testSample = 0
        netCoeff = 0
        for i in range(len(data)):
            images, masks = data[i]
            images = torch.tensor(images.astype(np.float32))
            masks = torch.tensor(masks.astype(np.float32))
            images = images.to(device)
            masks = masks.to(device)
            # NHWC -> NCHW (assumes data is stored channels-last -- TODO confirm)
            images = torch.transpose(images, 1, 3)
            masks = torch.transpose(masks, 1, 3)
            if i in indices[:trainEndIndex]:
                # --- training batch ---
                trainingSample += images.shape[0]
                networkPred = model(images)
                if args.regularization == 'l1':
                    reg = L1_regularization(model, args.reg_lamda1)
                    loss = criteria(masks, networkPred) + reg
                elif args.regularization == 'l1l2':
                    reg = L1L2_regularization(model, args.reg_lamda1,
                                              args.reg_lamda2)
                    loss = criteria(masks, networkPred) + reg
                else:
                    loss = criteria(masks, networkPred)
                loss.backward()
                trainLoss += loss.item()
                optimizer.step()
                model.zero_grad()
            else:
                # --- validation batch ---
                with torch.no_grad():
                    testSample += images.shape[0]
                    prediction = model(images)
                    if (epoch % args.save_epochs
                            == 0) or (epoch == 1) or (epoch == args.epochs):
                        # Periodically dump qualitative results: heart image,
                        # predicted/actual masks, and their RGB overlays.
                        imgPath = os.path.join("validation_sample",
                                               args.log_header,
                                               f"epoch {epoch}")
                        if not os.path.exists(imgPath):
                            os.makedirs(imgPath)
                        hrt = images[0, 0, :, :].to("cpu")
                        plt.imshow(np.array(hrt), cmap='gray')
                        plt.title("Heart Image")
                        plt.savefig(os.path.join(imgPath, "heart.png"))
                        plt.clf()
                        msk1 = prediction[0, 0, :, :].to("cpu")
                        plt.imshow(np.array(msk1), cmap='gray')
                        plt.title("Predicted Mask 1")
                        plt.savefig(os.path.join(imgPath, "pred-mask1.png"))
                        plt.clf()
                        msk2 = prediction[0, 1, :, :].to("cpu")
                        plt.imshow(np.array(msk2), cmap='gray')
                        plt.title("Predicted Mask 2")
                        plt.savefig(os.path.join(imgPath, "pred-mask2.png"))
                        plt.clf()

                        # Overlay both predicted masks (red/green channels)
                        # on the heart image.
                        msk = np.zeros((192, 192, 3))
                        msk[:, :, 0] = np.array(msk1)
                        msk[:, :, 1] = np.array(msk2)
                        plt.imshow(np.array(hrt), cmap='gray')
                        plt.imshow(msk, cmap='jet', alpha=0.4)
                        plt.title("predicted-RV")
                        plt.savefig(os.path.join(imgPath, "pred-RV.png"))
                        plt.clf()
                        msk1 = masks[0, 0, :, :].to("cpu")
                        plt.imshow(np.array(msk1), cmap='gray')
                        plt.title("Actual Mask 1")
                        plt.savefig(os.path.join(imgPath, "actual-mask1.png"))
                        plt.clf()
                        msk2 = masks[0, 1, :, :].to("cpu")
                        plt.imshow(np.array(msk2), cmap='gray')
                        plt.title("Actual Mask 2")
                        plt.savefig(os.path.join(imgPath, "actual-mask2.png"))
                        plt.clf()
                        # Same overlay for the ground-truth masks.
                        msk = np.zeros((192, 192, 3))
                        msk[:, :, 0] = np.array(msk1)
                        msk[:, :, 1] = np.array(msk2)
                        plt.imshow(np.array(hrt), cmap='gray')
                        plt.imshow(msk, cmap='jet', alpha=0.4)
                        plt.title("actual-RV")
                        plt.savefig(os.path.join(imgPath, "actual-RV.png"))
                        plt.clf()

                    if args.regularization == 'l1':
                        reg = L1_regularization(model, args.reg_lamda1)
                        loss = criteria(masks, prediction) + reg
                    elif args.regularization == 'l1l2':
                        reg = L1L2_regularization(model, args.reg_lamda1,
                                                  args.reg_lamda2)
                        loss = criteria(masks, prediction) + reg
                    else:
                        loss = criteria(masks, prediction)
                    valLoss += loss.item()
                    coeff = diceCoeff(masks, prediction)
                    netCoeff += torch.sum(1 - coeff).item()
        if (epoch % int(args.save_epochs) == 0) or (epoch == int(args.epochs)):
            if not os.path.exists(args.model_save_dir):
                os.makedirs(args.model_save_dir)
            # save model + optimizer state so training can be resumed
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                os.path.join(args.model_save_dir,
                             f"model-epoch({epoch}).hdf5"))
            logging.info(
                f"TIME: {time.time()-startTime}s Model state saved for epoch: {epoch}"
            )
        # max(..., 1) guards against an empty split (validation_frac of 0 or
        # 1) which previously raised ZeroDivisionError here.
        trainDenom = 2 * max(trainingSample, 1)
        valDenom = 2 * max(testSample, 1)
        logging.info(
            f"TIME: {time.time()-startTime}s TRAINING: Epoch: {epoch}, lr: {scheduler.get_last_lr()}, loss: {trainLoss/trainDenom}"
        )
        logging.info(
            f"TIME: {time.time()-startTime}s VALIDATION: Epoch: {epoch}, lr: {scheduler.get_last_lr()}, loss: {valLoss/valDenom}"
        )
        TL.append(trainLoss / trainDenom)
        VL.append(valLoss / valDenom)
        lossFile.write(
            f"{epoch},{trainLoss/trainDenom},{valLoss/valDenom},{netCoeff/valDenom}\n"
        )
        scheduler.step(
        )  # https://www.deeplearningwizard.com/deep_learning/boosting_models_pytorch/lr_scheduling/
    # Previously the CSV file handle was never closed (resource leak).
    lossFile.close()
    plt.plot(list(range(1, int(args.epochs) + 1)), TL, label="Training loss")
    plt.plot(list(range(1, int(args.epochs) + 1)), VL, label="Validation loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend(loc="best")
    if not os.path.exists(os.path.join(os.getcwd(), "plots")):
        os.makedirs(os.path.join(os.getcwd(), "plots"))
    plt.savefig(os.path.join("plots", args.log_header + ".png"))
                                         batch_size=1,
                                         shuffle=False)

# Build the training split of the histology dataset and its dataloader.
partition = 'train'
unet_train = HistologyData(ROOT_DIR, partition, True)
unet_loader = torch.utils.data.DataLoader(
    unet_train,
    batch_size=1,
    shuffle=True,
)

# Create model
# ShapeUNet consumes 15-channel 512x512 inputs (one channel per mask value
# below -- TODO confirm); UNet consumes 3-channel 512x512 images.
model = ShapeUNet((15, 512, 512))
unet = UNet((3, 512, 512))
model.to(device)
unet.to(device)

# One entry per segmentation class label.
mask_values = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)
# here not RGB but BGR because of OPENCV.
real_colors = ((0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255), (85, 0, 0),
               (0, 170, 0), (255, 0, 127), (0, 255, 255), (0, 85, 0),
               (255, 0, 255), (255, 85, 0), (255, 165, 0), (255, 255, 0),
               (128, 130, 128), (128, 190, 190))
# Separate Adam optimizers for the two networks (same learning rate).
lr = 1e-4
optimizer = Adam(model.parameters(), lr=lr)
NUM_OF_EPOCHS = 40

lr1 = 1e-4
unet_optim = Adam(unet.parameters(), lr=lr1)
train_network_on_top_of_other(model, train_loader, val_loader, optimizer, unet,
Example #15
0
file: ei.py  project: edongdongchen/EI
    def train_ei_adv(self,
                     dataloader,
                     physics,
                     transform,
                     epochs,
                     lr,
                     alpha,
                     ckp_interval,
                     schedule,
                     residual=True,
                     pretrained=None,
                     task='',
                     loss_type='l2',
                     cat=True,
                     report_psnr=False,
                     lr_cos=False):
        """Adversarially train an Equivariant-Imaging UNet generator.

        Args:
            dataloader: yields training batches for the closure.
            physics: forward operator passed to the training closure.
            transform: group transform for the equivariance (EI) loss.
            epochs: number of training epochs.
            lr: dict with keys 'G', 'D', 'WD' (generator/discriminator
                learning rates and generator weight decay).
            alpha: weighting passed to the closure (EI loss weight --
                TODO confirm against closure_ei_adv).
            ckp_interval: save a checkpoint every this many epochs.
            schedule: schedule argument for adjust_learning_rate.
            residual: build the UNet with residual connections.
            pretrained: optional checkpoint path to initialize the generator.
            task: tag appended to the checkpoint directory name.
            loss_type: 'l2' (MSE) or 'l1' losses for the MC/EI terms.
            cat: use concatenating skip connections in the UNet.
            report_psnr: additionally log PSNR/MSE columns per epoch.
            lr_cos: use cosine learning-rate decay.

        Raises:
            ValueError: if loss_type is not 'l2' or 'l1' (the original code
                left the criteria undefined and failed later with NameError).
        """
        save_path = './ckp/{}_ei_adv_{}'.format(get_timestamp(), task)

        os.makedirs(save_path, exist_ok=True)

        generator = UNet(in_channels=self.in_channels,
                         out_channels=self.out_channels,
                         compact=4,
                         residual=residual,
                         circular_padding=True,
                         cat=cat)

        if pretrained:
            checkpoint = torch.load(pretrained)
            generator.load_state_dict(checkpoint['state_dict'])

        discriminator = Discriminator(
            (self.in_channels, self.img_width, self.img_height))

        generator = generator.to(self.device)
        discriminator = discriminator.to(self.device)

        # Validate loss_type up front instead of letting an unknown value
        # surface as a NameError inside the training closure.
        if loss_type == 'l2':
            criterion_mc = torch.nn.MSELoss().to(self.device)
            criterion_ei = torch.nn.MSELoss().to(self.device)
        elif loss_type == 'l1':
            criterion_mc = torch.nn.L1Loss().to(self.device)
            criterion_ei = torch.nn.L1Loss().to(self.device)
        else:
            raise ValueError(
                "loss_type must be 'l2' or 'l1', got {!r}".format(loss_type))

        criterion_gan = torch.nn.MSELoss().to(self.device)

        optimizer_G = Adam(generator.parameters(),
                           lr=lr['G'],
                           weight_decay=lr['WD'])
        optimizer_D = Adam(discriminator.parameters(),
                           lr=lr['D'],
                           weight_decay=0)

        # CSV logger; the column set depends on whether PSNR is reported.
        if report_psnr:
            log = LOG(save_path,
                      filename='training_loss',
                      field_name=[
                          'epoch', 'loss_mc', 'loss_ei', 'loss_g', 'loss_G',
                          'loss_D', 'psnr', 'mse'
                      ])
        else:
            log = LOG(save_path,
                      filename='training_loss',
                      field_name=[
                          'epoch', 'loss_mc', 'loss_ei', 'loss_g', 'loss_G',
                          'loss_D'
                      ])

        for epoch in range(epochs):
            adjust_learning_rate(optimizer_G, epoch, lr['G'], lr_cos, epochs,
                                 schedule)
            adjust_learning_rate(optimizer_D, epoch, lr['D'], lr_cos, epochs,
                                 schedule)

            loss = closure_ei_adv(generator, discriminator, dataloader,
                                  physics, transform, optimizer_G, optimizer_D,
                                  criterion_mc, criterion_ei, criterion_gan,
                                  alpha, self.dtype, self.device, report_psnr)

            log.record(epoch + 1, *loss)

            if report_psnr:
                print(
                    '{}\tEpoch[{}/{}]\tfc={:.4e}\tti={:.4e}\tg={:.4e}\tG={:.4e}\tD={:.4e}\tpsnr={:.4f}\tmse={:.4e}'
                    .format(get_timestamp(), epoch, epochs, *loss))
            else:
                print(
                    '{}\tEpoch[{}/{}]\tfc={:.4e}\tti={:.4e}\tg={:.4e}\tG={:.4e}\tD={:.4e}'
                    .format(get_timestamp(), epoch, epochs, *loss))

            # Periodic checkpoint of both networks and both optimizers.
            if epoch % ckp_interval == 0 or epoch + 1 == epochs:
                state = {
                    'epoch': epoch,
                    'state_dict_G': generator.state_dict(),
                    'state_dict_D': discriminator.state_dict(),
                    'optimizer_G': optimizer_G.state_dict(),
                    'optimizer_D': optimizer_D.state_dict()
                }
                torch.save(
                    state,
                    os.path.join(save_path, 'ckp_{}.pth.tar'.format(epoch)))
        log.close()
def main():
    """Train a LiTS liver/tumor UNet from command-line arguments.

    Parses CLI options, seeds RNGs, builds train/validation dataloaders with
    geometric + pixelwise transforms, constructs (or loads) the model, then
    alternates train/validation passes per epoch, saving the model whenever
    the validation loss improves.

    Returns:
        The best (lowest) validation mean loss seen over all epochs.
    """
    parser = argparse.ArgumentParser(description="Train the model")
    parser.add_argument('-trainf', "--train-filepath", type=str, default=None, required=True,
                        help="training dataset filepath.")
    parser.add_argument('-validf', "--val-filepath", type=str, default=None,
                        help="validation dataset filepath.")
    parser.add_argument("--shuffle", action="store_true", default=False,
                        help="Shuffle the dataset")
    parser.add_argument("--load-weights", type=str, default=None,
                        help="load pretrained weights")
    parser.add_argument("--load-model", type=str, default=None,
                        help="load pretrained model, entire model (filepath, default: None)")

    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument('--epochs', type=int, default=30,
                        help='number of epochs to train (default: 30)')
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size")

    parser.add_argument('--img-shape', type=str, default="(1,512,512)",
                        help='Image shape (default "(1,512,512)"')

    parser.add_argument("--num-cpu", type=int, default=10,
                        help="Number of CPUs to use in parallel for dataloader.")
    parser.add_argument('--cuda', type=int, default=0,
                        help='CUDA visible device (use CPU if -1, default: 0)')
    parser.add_argument('--cuda-non-deterministic', action='store_true', default=False,
                        help="sets flags for non-determinism when using CUDA (potentially fast)")

    parser.add_argument('-lr', type=float, default=0.0005,
                        help='Learning rate')
    parser.add_argument('--seed', type=int, default=0,
                        help='Seed (numpy and cuda if GPU is used.).')

    parser.add_argument('--log-dir', type=str, default=None,
                        help='Save the results/model weights/logs under the directory.')

    args = parser.parse_args()

    # TODO: support image reshape
    img_shape = tuple(map(int, args.img_shape.strip()[1:-1].split(",")))

    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
        best_model_path = os.path.join(args.log_dir, "model_weights.pth")
    else:
        best_model_path = None

    # Seed every RNG we use so runs are reproducible (unless CUDA
    # non-determinism is explicitly requested).
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if args.cuda >= 0:
            if args.cuda_non_deterministic:
                printBlue("Warning: using CUDA non-deterministc. Could be faster but results might not be reproducible.")
            else:
                printBlue("Using CUDA deterministc. Use --cuda-non-deterministic might accelerate the training a bit.")
            # Make CuDNN Determinist
            torch.backends.cudnn.deterministic = not args.cuda_non_deterministic

            # torch.cuda.manual_seed(args.seed)
            torch.cuda.manual_seed_all(args.seed)

    # TODO [OPT] enable multi-GPUs ?
    # https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
    device = torch.device("cuda:{}".format(args.cuda) if torch.cuda.is_available()
                          and (args.cuda >= 0) else "cpu")

    # ================= Build dataloader =================
    # Single-channel normalization into [-1, 1].
    transform_normalize = transforms.Normalize(mean=[0.5],
                                               std=[0.5])

    # Warning: DO NOT use geometry transform (do it in the dataloader instead)
    data_transform = transforms.Compose([
        # Use OpenCVRotation, OpenCVXXX, ... (our implementation) for
        # geometric augmentation; ColorJitter is incompatible with
        # one-channel float images.
        transforms.ToTensor(),  # already done in the dataloader
        transform_normalize
    ])

    geo_transform = GeoCompose([
        OpenCVRotation(angles=(-10, 10),
                       scales=(0.9, 1.1),
                       centers=(-0.05, 0.05)),

        # TODO add more data augmentation here
    ])

    def worker_init_fn(worker_id):
        # WARNING spawn start method is used,
        # worker_init_fn cannot be an unpicklable object, e.g., a lambda function.
        # A work-around for issue #5059: https://github.com/pytorch/pytorch/issues/5059
        np.random.seed()

    data_loader_train = {'batch_size': args.batch_size,
                         'shuffle': args.shuffle,
                         'num_workers': args.num_cpu,
                         'drop_last': True,  # for GAN-like
                         'pin_memory': False,
                         'worker_init_fn': worker_init_fn,
                         }

    data_loader_valid = {'batch_size': args.batch_size,
                         'shuffle': False,
                         'num_workers': args.num_cpu,
                         'drop_last': False,
                         'pin_memory': False,
                         }

    train_set = LiTSDataset(args.train_filepath,
                            dtype=np.float32,
                            geometry_transform=geo_transform,  # TODO enable data augmentation
                            pixelwise_transform=data_transform,
                            )
    valid_set = LiTSDataset(args.val_filepath,
                            dtype=np.float32,
                            pixelwise_transform=data_transform,
                            )

    dataloader_train = torch.utils.data.DataLoader(train_set, **data_loader_train)
    dataloader_valid = torch.utils.data.DataLoader(valid_set, **data_loader_valid)
    # =================== Build model ===================
    # TODO: control the model by bash command
    # Build the network unconditionally: the original only instantiated it
    # inside the --load-weights branch, so running without any pretrained
    # option crashed with NameError at model.to(device).
    model = UNet(in_ch=1,
                 out_ch=3,  # there are 3 classes: 0: background, 1: liver, 2: tumor
                 depth=4,
                 start_ch=32,  # 64
                 inc_rate=2,
                 kernel_size=5,  # 3
                 padding=True,
                 batch_norm=True,
                 spec_norm=False,
                 dropout=0.5,
                 up_mode='upconv',
                 include_top=True,
                 include_last_act=False,
                 )
    if args.load_weights:
        printYellow(f"Loading pretrained weights from: {args.load_weights}...")
        model.load_state_dict(torch.load(args.load_weights))
        printYellow("+ Done.")
    elif args.load_model:
        # load entire model (replaces the freshly built one)
        model = torch.load(args.load_model)
        printYellow("Successfully loaded pretrained model.")

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.95))  # TODO
    best_valid_loss = float('inf')
    # TODO TODO: add learning decay

    for epoch in range(args.epochs):
        # valid_mode == 0: training pass; valid_mode == 1: validation pass.
        for valid_mode, dataloader in enumerate([dataloader_train, dataloader_valid]):
            n_batch_per_epoch = len(dataloader)
            if args.debug:
                n_batch_per_epoch = 1

            # infinite dataloader allows several update per iteration (for special models e.g. GAN)
            dataloader = infinite_dataloader(dataloader)
            if valid_mode:
                printYellow("Switch to validation mode.")
                model.eval()
                prev_grad_mode = torch.is_grad_enabled()
                torch.set_grad_enabled(False)
            else:
                model.train()

            st = time.time()
            cum_loss = 0
            for iter_ind in range(n_batch_per_epoch):
                supplement_logs = ""
                optimizer.zero_grad()

                img, msk = next(dataloader)
                img, msk = img.to(device), msk.to(device)

                # TODO this is ugly: convert dtype and convert the shape from (N, 1, 512, 512) to (N, 512, 512)
                msk = msk.to(torch.long).squeeze(1)

                msk_pred = model(img)  # shape (N, 3, 512, 512)

                # label_weights is determined according the liver_ratio & tumor_ratio
                loss = DiceLoss(msk_pred, msk, label_weights=[1., 20., 50.], device=device)

                if not valid_mode:
                    loss.backward()
                    optimizer.step()

                loss = loss.item()  # release
                cum_loss += loss
                if valid_mode:
                    print("\r--------(valid) {:.2%} Loss: {:.3f} (time: {:.1f}s) |supp: {}".format(
                        (iter_ind+1)/n_batch_per_epoch, cum_loss/(iter_ind+1), time.time()-st, supplement_logs), end="")
                else:
                    print("\rEpoch: {:3}/{} {:.2%} Loss: {:.3f} (time: {:.1f}s) |supp: {}".format(
                        (epoch+1), args.epochs, (iter_ind+1)/n_batch_per_epoch, cum_loss/(iter_ind+1), time.time()-st, supplement_logs), end="")
            print()
            if valid_mode:
                torch.set_grad_enabled(prev_grad_mode)

        valid_mean_loss = cum_loss/(iter_ind+1)  # validation (mean) loss of the current epoch

        if best_model_path and (valid_mean_loss < best_valid_loss):
            printGreen("Valid loss decreases from {:.5f} to {:.5f}, saving best model.".format(
                best_valid_loss, valid_mean_loss))
            best_valid_loss = valid_mean_loss
            # Save the entire model (not just the weights).
            torch.save(model, best_model_path)

    return best_valid_loss
Example #17
0
    if not (os.path.exists(CHECKPOINT) and os.path.isfile(CHECKPOINT)):
        print("weights in wrong format or non-existant: \n\t{}".format(
            CHECKPOINT))
        exit()

    # loading the model
    m = None
    if args.model == "unet":
        m = UNet(2, 2).to(device)
    else:
        m = simpleConvNet(2, 2).to(device)

    print("loading model weights from \"{}\"".format(CHECKPOINT))
    checkpoint = torch.load(CHECKPOINT)
    m.load_state_dict(checkpoint['model_state_dict'])
    m.to(device)  # pushing weights to gpu
    train_loss = checkpoint['train_loss']
    test_loss = checkpoint['test_loss']
    START = checkpoint['epoch']

    scans = [
        os.path.join(root_dir, i)
        for i in os.listdir(os.path.abspath(root_dir)) if
        os.path.isdir(os.path.join(os.path.abspath(root_dir), i)) and "_" in i
    ]
    if len(scans) == 0:
        print(
            "no scan data found (folder name must be in format mmddhhmmss_x with x beeing the serialnumber"
        )
        exit()
    else:
Example #18
0
                            pin_memory=True,
                            shuffle=False,
                            drop_last=False)

# learning-rate
LEARNING_RATE = 1e-3

# Number of epochs
N_EPOCHS = 10

# tensorboard
writer = SummaryWriter(log_dir='./{}'.format(MODEL_NAME), comment=MODEL_NAME)

# Define the model: UNet mapping 3-channel images to NUM_PTS outputs
# (presumably landmark heatmaps -- TODO confirm against UNet definition).
model = UNet(3, NUM_PTS)
model.to(device)
with torch.no_grad():
    # Print a layer-by-layer summary using one training batch's image shape.
    summary(model, next(iter(train_dataloader))['image'].shape[1:])

# Set up the optimization parameters
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, amsgrad=True)
# Adaptive Wing loss instead of plain MSE (F.mse_loss).
criterion = AdaptiveWingLoss()

# Running state used to track the best validation result
best_val_loss, best_model_state_dict = np.inf, {}

CURRENT_EPOCH = 0

for epoch in range(CURRENT_EPOCH, N_EPOCHS):