Beispiel #1
0
def main(args):
    """Train a segmentation model and save a checkpoint.

    Builds the training dataset selected by ``args.dataset``, trains the
    network selected by ``args.model`` for ``args.epoch`` epochs with the
    dice loss, logs the running loss / IoU and sample images through
    ``logger.Logger``, and finally writes a checkpoint into
    ``args.save_dir``.

    Args:
        args: Parsed options. Attributes read here: ``experiment_name``,
            ``dataset``, ``train_data``, ``transform``, ``target_channels``,
            ``max_mean``, ``batch_size``, ``num_workers``, ``device``,
            ``gpu_ids``, ``model``, ``num_kernel``, ``kernel_size``,
            ``optimizer``, ``lr``, ``epoch`` and ``save_dir``.

    Raises:
        ValueError: If ``args.model`` is not "fusion", "dilation" or "unet".
    """

    # tensorboard
    logger_tb = logger.Logger(log_dir=args.experiment_name)

    # get dataset; any unrecognised name falls back to NeuroDataset
    if args.dataset == "nuclei":
        train_dataset = NucleiDataset(args.train_data, 'train', args.transform,
                                      args.target_channels)
    elif args.dataset == "hpa":
        train_dataset = HPADataset(args.train_data, 'train', args.transform,
                                   args.max_mean, args.target_channels)
    elif args.dataset == "hpa_single":
        train_dataset = HPASingleDataset(args.train_data, 'train',
                                         args.transform)
    else:
        train_dataset = NeuroDataset(args.train_data, 'train', args.transform)

    # create dataloader
    # NOTE(review): shuffling is disabled for training - confirm intentional.
    train_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': args.num_workers
    }
    train_dataloader = DataLoader(train_dataset, **train_params)

    # device
    device = torch.device(args.device)

    # model
    if args.model == "fusion":
        model = FusionNet(args, train_dataset.dim)
    elif args.model == "dilation":
        model = DilationCNN(train_dataset.dim)
    elif args.model == "unet":
        model = UNet(args.num_kernel, args.kernel_size, train_dataset.dim,
                     train_dataset.target_dim)
    else:
        # fail with a clear message instead of an UnboundLocalError below
        raise ValueError(f"unknown model: {args.model!r}")

    if args.device == "cuda":
        # parse gpu_ids for data parallel ("0" -> [0], "0,1" -> [0, 1])
        gpu_ids = [int(gpu_id) for gpu_id in str(args.gpu_ids).split(',')]

        # parallelize computation only when several GPUs were requested
        # NOTE(review): DataParallel expects the parameters on
        # device_ids[0]; model.to("cuda") uses the current CUDA device -
        # confirm these agree when gpu_ids does not start with it.
        if len(gpu_ids) > 1:
            model = nn.DataParallel(model, gpu_ids)
    model.to(device)

    # optimizer
    parameters = model.parameters()
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(parameters, args.lr)
    else:
        optimizer = torch.optim.SGD(parameters, args.lr)

    # loss
    loss_function = dice_loss

    # global step for the logger; advanced once per training iteration
    count = 0

    # train model
    for epoch in range(args.epoch):
        model.train()

        with tqdm.tqdm(total=len(train_dataloader.dataset),
                       unit=f"epoch {epoch} itr") as progress_bar:
            total_loss = []
            total_iou = []
            for i, (x_train, y_train) in enumerate(train_dataloader):

                with torch.set_grad_enabled(True):

                    # send data and label to device
                    x = x_train.float().to(device)
                    y = y_train.float().to(device)

                    # predict segmentation (call the module so hooks run)
                    pred = model(x)

                    # calculate loss
                    loss = loss_function(pred, y)
                    total_loss.append(loss.item())

                    # calculate IoU on detached copies so the metric code
                    # cannot touch the tensors used for back prop
                    # NOTE(review): squeeze() also drops a batch dimension
                    # of size 1, in which case the zip below iterates over a
                    # different axis - confirm the expected shapes.
                    predictions = pred.detach().clone().squeeze().cpu().numpy()
                    gt = y.detach().clone().squeeze().cpu().numpy()
                    ious = [
                        metrics.get_ious(p, g, 0.5)
                        for p, g in zip(predictions, gt)
                    ]
                    total_iou.append(np.mean(ious))

                    # back prop
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # log running averages of loss and iou for this epoch
                avg_loss = np.mean(total_loss)
                avg_iou = np.mean(total_iou)

                logger_tb.update_value('train loss', avg_loss, count)
                logger_tb.update_value('train iou', avg_iou, count)

                # display segmentation of the first batch on tensorboard
                if i == 0:
                    original = x_train[0].squeeze()
                    truth = y_train[0].squeeze()
                    seg = pred[0].cpu().squeeze().detach().numpy()

                    # TODO display segmentations based on number of output
                    logger_tb.update_image("truth", truth, count)
                    logger_tb.update_image("segmentation", seg, count)
                    logger_tb.update_image("original", original, count)

                # advance step counter and progress bar on every batch, not
                # only on the first batch of the epoch
                count += 1
                progress_bar.update(len(x))

    # save model; unwrap DataParallel so the checkpoint describes the real
    # network (class name, args_dict, un-prefixed state dict keys)
    net = model.module if isinstance(model, nn.DataParallel) else model
    ckpt_dict = {
        'model_name': net.__class__.__name__,
        'model_args': net.args_dict(),
        'model_state': net.to('cpu').state_dict()
    }
    experiment_name = f"{net.__class__.__name__}_{args.dataset}_{train_dataset.target_dim}c"
    # must match the lowercase name used for dataset selection above
    if args.dataset == "hpa":
        experiment_name += f"_{args.max_mean}"
    experiment_name += f"_{args.num_kernel}"
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    ckpt_path = os.path.join(args.save_dir, f"{experiment_name}.pth")
    torch.save(ckpt_dict, ckpt_path)
Beispiel #2
0
def main(args):
    """Train a segmentation model on the nuclei data with validation.

    Trains the network selected by ``args.model`` for ``args.epoch`` epochs
    with the dice loss. After every epoch the model is evaluated sample by
    sample on the validation dataset and the mean loss, IoU and precision
    are logged through ``logger.Logger``. The final model is saved into
    ``args.save_dir``.

    Args:
        args: Parsed options. Attributes read here: ``experiment_name``,
            ``train_data``, ``val_data``, ``batch_size``, ``num_workers``,
            ``device``, ``gpu_ids``, ``model``, ``num_kernel``,
            ``kernel_size``, ``optimizer``, ``lr``, ``epoch`` and
            ``save_dir``.

    Raises:
        ValueError: If ``args.model`` is not "fusion", "dilation" or "unet".
    """

    # tensorboard
    logger_tb = logger.Logger(log_dir=args.experiment_name)

    # train dataloader and val dataset
    train_dataset = NucleiDataset(args.train_data, 'train', transform=True)
    val_dataset = NucleiDataset(args.val_data, 'val', transform=True)

    # NOTE(review): shuffling is disabled for training - confirm intentional.
    train_params = {
        'batch_size': args.batch_size,
        'shuffle': False,
        'num_workers': args.num_workers
    }

    train_dataloader = DataLoader(train_dataset, **train_params)

    # device
    device = torch.device(args.device)

    # model
    if args.model == "fusion":
        model = FusionNet(args, train_dataset.dim)
    elif args.model == "dilation":
        model = DilationCNN(train_dataset.dim)
    elif args.model == "unet":
        model = UNet(args.num_kernel, args.kernel_size, train_dataset.dim)
    else:
        # fail with a clear message instead of an UnboundLocalError below
        raise ValueError(f"unknown model: {args.model!r}")

    if args.device == "cuda":
        # parse gpu_ids for data parallel ("0" -> [0], "0,1" -> [0, 1])
        gpu_ids = [int(gpu_id) for gpu_id in str(args.gpu_ids).split(',')]

        # parallelize computation only when several GPUs were requested
        # NOTE(review): DataParallel expects the parameters on
        # device_ids[0]; model.to("cuda") uses the current CUDA device -
        # confirm these agree when gpu_ids does not start with it.
        if len(gpu_ids) > 1:
            model = nn.DataParallel(model, gpu_ids)
    model.to(device)

    # optimizer
    parameters = model.parameters()
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(parameters, args.lr)
    else:
        optimizer = torch.optim.SGD(parameters, args.lr)

    # loss
    loss_function = dice_loss

    # train model
    for epoch in range(args.epoch):
        model.train()

        with tqdm.tqdm(total=len(train_dataloader.dataset),
                       unit=f"epoch {epoch} itr") as progress_bar:
            total_loss = []
            total_iou = []
            for x_train, y_train in train_dataloader:

                with torch.set_grad_enabled(True):

                    # send data and label to device
                    x = x_train.float().to(device)
                    y = y_train.float().to(device)

                    # predict segmentation (call the module so hooks run)
                    pred = model(x)

                    # calculate loss
                    loss = loss_function(pred, y)
                    total_loss.append(loss.item())

                    # calculate IoU on detached copies so the metric code
                    # cannot touch the tensors used for back prop
                    # NOTE(review): squeeze() also drops a batch dimension
                    # of size 1, in which case the zip below iterates over a
                    # different axis - confirm the expected shapes.
                    predictions = pred.detach().clone().squeeze().cpu().numpy()
                    gt = y.detach().clone().squeeze().cpu().numpy()
                    ious = [
                        metrics.get_ious(p, g, 0.5)
                        for p, g in zip(predictions, gt)
                    ]
                    total_iou.append(np.mean(ious))

                    # back prop
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # log running averages of loss and iou; every batch of an
                # epoch is written to the same step (the epoch index)
                avg_loss = np.mean(total_loss)
                avg_iou = np.mean(total_iou)

                logger_tb.update_value('train loss', avg_loss, epoch)
                logger_tb.update_value('train iou', avg_iou, epoch)

                progress_bar.update(len(x))

        # validation
        model.eval()

        # accumulate over the WHOLE validation set; these must live outside
        # the per-sample loop, otherwise only the last sample is reported
        val_loss = []
        val_iou = []
        val_precision = []
        for idx in range(len(val_dataset)):
            x_val, y_val, mask_val = val_dataset[idx]

            with torch.no_grad():

                # send data and label to device; only the input gets an
                # explicit leading batch axis here
                x_val = np.expand_dims(x_val, axis=0)
                x = torch.tensor(x_val).float().to(device)
                y = torch.tensor(y_val).float().to(device)

                # predict segmentation
                pred = model(x)

                # calculate loss
                loss = loss_function(pred, y)
                val_loss.append(loss.item())

                # calculate IoU
                prediction = pred.detach().clone().squeeze().cpu().numpy()
                gt = y.detach().clone().squeeze().cpu().numpy()
                iou = metrics.get_ious(prediction, gt, 0.5)
                val_iou.append(iou)

                # calculate precision
                precision = metrics.compute_precision(prediction, mask_val,
                                                      0.5)
                val_precision.append(precision)

                # display segmentation of the second validation sample on
                # tensorboard; input and ground truth are always written to
                # step 0, the prediction to the current epoch
                if idx == 1:
                    original = x_val
                    truth = np.expand_dims(y_val, axis=0)
                    seg = pred.cpu().squeeze().detach().numpy()
                    seg = np.expand_dims(seg, axis=0)

                    logger_tb.update_image("original", original, 0)
                    logger_tb.update_image("ground truth", truth, 0)
                    logger_tb.update_image("segmentation", seg, epoch)

        # log metrics (skipped when the validation set is empty, so no
        # NaN means are written)
        if val_loss:
            logger_tb.update_value('val loss', np.mean(val_loss), epoch)
            logger_tb.update_value('val iou', np.mean(val_iou), epoch)
            logger_tb.update_value('val precision', np.mean(val_precision),
                                   epoch)

    # save model; unwrap DataParallel so the checkpoint describes the real
    # network (class name, args_dict, un-prefixed state dict keys)
    net = model.module if isinstance(model, nn.DataParallel) else model
    ckpt_dict = {
        'model_name': net.__class__.__name__,
        'model_args': net.args_dict(),
        'model_state': net.to('cpu').state_dict()
    }
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    ckpt_path = os.path.join(args.save_dir, f"{net.__class__.__name__}.pth")
    torch.save(ckpt_dict, ckpt_path)