Example #1
def validation(net, testing_data, x_criterion):
    net.eval()
    val_dice_loss = 0.0
    accuracy = 0.0

    with torch.no_grad():
        for i, data in enumerate(batch_loader(testing_data, 4)):
            image = [sample['image'] for sample in data]
            label = [sample['label'] for sample in data]

            image = torch.from_numpy(np.array(image))
            label = torch.from_numpy(np.array(label))

            if GPU:
                image = image.cuda()
                label = label.cuda()

            Y = net(image)
            Y_softmax = F.softmax(Y, dim=1)

            val_dice_loss += losses.dice_loss(Y_softmax[:, 1, :, :, :],
                                              label[:, 1, :, :, :]).item()

            predictions = Y_softmax.argmax(dim=1, keepdim=True).view_as(
                label[:, 1, :, :, :])
            accuracy += predictions.eq(
                label[:, 1, :, :, :].long()).sum().item() / label.sum().item()
            #print (label.shape, label.sum())

    val_dice_loss /= (i + 1)
    accuracy /= (i + 1)
    return val_dice_loss, accuracy
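All of the examples on this page feed softmax (or sigmoid) probabilities and a binary target into dice_loss. The losses.dice_loss implementation itself is not shown; a minimal soft Dice loss that matches how these call sites use it, given here only as an illustrative sketch and not the exact function used above, could be:

import torch

def soft_dice_loss(score, target, smooth=1e-5):
    # score: predicted probability map for one class; target: binary mask of the same shape
    # (sketch of a typical soft Dice loss, assumed rather than taken from this code)
    target = target.float()
    intersect = torch.sum(score * target)
    y_sum = torch.sum(target * target)
    z_sum = torch.sum(score * score)
    dice = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
    return 1 - dice

The smoothing term keeps the loss defined when both the prediction and the target are empty.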
Example #2
def train(args, epoch, model, train_loader, optimizer, writer, lr_scheduler):
    model.train()
    nProcessed = 0
    batch_size = args.ngpu * args.batch_size
    nTrain = len(train_loader.dataset)
    loss_list = []

    for batch_idx, sample in enumerate(train_loader):
        # read data
        image, target = sample['image'], sample['label']
        image, target = Variable(image.cuda()), Variable(target.cuda(),
                                                         requires_grad=False)

        # forward
        if args.use_tm:
            outputs, tm = model(image)
            tm = torch.sigmoid(tm)
        else:
            outputs = model(image)
        outputs_soft = F.softmax(outputs, dim=1)

        loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], target == 1)
        if args.use_tm:
            loss_threshold = threshold_loss(outputs_soft[:, 1, :, :, :],
                                            tm[:, 0, ...], target == 1)
            loss = loss_seg_dice + 3 * loss_threshold
        else:
            loss = loss_seg_dice

        # backward
        lr = lr_scheduler(optimizer, batch_idx, epoch, 0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # visualize
        loss_list.append(loss.item())
        writer.add_scalar('lr', lr, epoch)

        # visualization
        nProcessed += len(image)
        partialEpoch = epoch + batch_idx / len(train_loader)
        print('Train Epoch: {:.2f} [{}/{} ({:.0f}%)]\tLoss: {:.8f}'.format(
            partialEpoch, nProcessed, nTrain,
            100. * batch_idx / len(train_loader), loss.item()))

    writer.add_scalar('train_loss', float(np.mean(loss_list)), epoch)
Example #3
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            out_dis = net(volume_batch)
            # writer.add_graph(net, volume_batch)

            # compute L1 loss between SDF Prediction and GT_SDF
            with torch.no_grad():
                gt_dis = compute_sdf(label_batch.cpu().numpy(), out_dis.shape)
                # print('np.max(gt_dis), np.min(gt_dis): ', np.max(gt_dis), np.min(gt_dis))
                gt_dis = torch.from_numpy(gt_dis).float().cuda()
                gt_dis_prob = torch.sigmoid(-1500*gt_dis)
                gt_dis_dice = dice_loss(gt_dis_prob[:, 0, :, :, :], label_batch == 1)
                # gt_dis_dice loss should be <= 0.05 (Dice Score>0.95), which means the pre-computed SDF is right.
                print('check gt_dis; dice score = ', 1 - gt_dis_dice.cpu().numpy())

            
            # compute product and L1 loss between SDF Prediction and GT_SDF
            loss_sdf_aaai = AAAI_sdf_loss(out_dis, gt_dis)
            # SDF Prediction -> heaviside function [0,1] -> Dice loss
            outputs_soft = torch.sigmoid(-1500*out_dis)
            loss_seg_dice = dice_loss(outputs_soft[:, 0, :, :, :], label_batch == 1)

            loss = loss_seg_dice + 10 * loss_sdf_aaai   # lambda=10 in this paper

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
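The snippet above predicts a signed distance field (out_dis) and checks it against a ground-truth SDF from compute_sdf, which is then squashed to a pseudo-probability with sigmoid(-1500*gt_dis). A plausible compute_sdf based on scipy's Euclidean distance transform, shown as an assumption rather than the exact helper used here (the sign convention, negative inside the object, is inferred from that sigmoid), would be:

import numpy as np
from scipy.ndimage import distance_transform_edt

def compute_sdf(segmentation, out_shape):
    # segmentation: (batch, x, y, z) binary array; out_shape assumed to be (batch, 1, x, y, z).
    # Negative inside the object, positive outside, scaled to [-1, 1]; empty or full masks are
    # left at zero. Illustrative sketch, not the original implementation.
    sdf = np.zeros(out_shape, dtype=np.float32)
    for b in range(out_shape[0]):
        mask = segmentation[b].astype(bool)
        if mask.any() and not mask.all():
            dist_out = distance_transform_edt(~mask)  # background voxels: distance to the object
            dist_in = distance_transform_edt(mask)    # object voxels: distance to the background
            sdf[b, 0] = dist_out / dist_out.max() - dist_in / dist_in.max()
    return sdf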
Example #4
    # https://arxiv.org/pdf/1901.05555.pdf - balancing by the inverse root of class frequency, or perhaps by the effective number of samples (see the sketch after this example)

    with torch.no_grad():
        assert loss.ndim == 3
        perclass_items = y.sum(dim=[0, 2])
        mask = perclass_items < 1e-3
        balanced_perclass_items = loss.shape[0] * loss.shape[2] // (y.shape[1] - mask.sum())
        weights = balanced_perclass_items / perclass_items
        weights[mask] = 0
    return loss * weights.reshape(1, -1, 1)


kldivloss = torch.nn.KLDivLoss(reduction='none')
# losspeaks = lambda y_pred, y: (focal_loss(y_pred, y).mean(dim=2) + dice_loss(y_pred, y)).mean(dim=1)
losspeaks = lambda y_pred, y: dice_loss(y_pred, y) / 2 + (kldivloss(y_pred, y) * 50).mean(dim=[1, 2])
# losspeaks = lambda y_pred, y: focal_loss(y_pred, y)
loss_enrichment = torch.nn.MSELoss(reduction='none')
trainer = Engine(partial(
    trainval.doiteration, loss_enrichment=loss_enrichment, loss_peaks=losspeaks, model=model, device=DEVICE,
    optimizer=optimizer
))
metrics.attach(trainer, "train")
ProgressBar(ncols=100).attach(trainer, metric_names='all')

trainval.attach_validation(
    trainer,
    partial(trainval.doiteration, loss_enrichment=loss_enrichment, loss_peaks=losspeaks, model=model, device=DEVICE),
    model, valloaders
)
trainval.attach_logger(trainer)
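The arXiv:1901.05555 reference in the comment at the start of this example describes class-balanced weighting by the effective number of samples, (1 - beta^n_c) / (1 - beta), as an alternative to plain inverse frequency. A hedged sketch of that weighting (the function name and the default beta are illustrative, not taken from this code):

import torch

def class_balanced_weights(per_class_counts, beta=0.999):
    # effective number of samples per class: E_c = (1 - beta**n_c) / (1 - beta)
    effective_num = 1.0 - torch.pow(torch.as_tensor(beta), per_class_counts.float())
    weights = (1.0 - beta) / effective_num.clamp_min(1e-8)
    # normalize so the weights sum to the number of classes
    return weights * per_class_counts.numel() / weights.sum()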
Example #5
    max_epoch = max_iterations // len(trainloader) + 1
    lr_ = base_lr
    net.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch[
                'label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            outputs = net(volume_batch)

            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :],
                                      label_batch == 1)
            loss = 0.5 * (loss_seg + loss_seg_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num = iter_num + 1
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
            if iter_num % 50 == 0:
                image = volume_batch[0, 0:1, :, :,
                                     20:61:10].permute(3, 0, 1,
Example #6
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch[
                'label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            outputs = net(volume_batch)

            loss_seg = F.cross_entropy(
                outputs,
                label_batch,
                weight=torch.tensor(cls_weights, dtype=torch.float32).cuda())
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = 0
            print('\n')
            for i in range(num_classes):
                loss_mid = losses.dice_loss(outputs_soft[:, i, :, :, :],
                                            label_batch == i)
                loss_seg_dice += loss_mid
                print('dice score (1-dice_loss): {:.3f}'.format(1 - loss_mid))
            print('dicetotal:{:.3f}'.format(loss_seg_dice))
            loss = 0.5 * (loss_seg + loss_seg_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num = iter_num + 1
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            logging.info(
Example #7
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            outputs_tanh, outputs = model(volume_batch)
            outputs_soft = torch.sigmoid(outputs)

            # calculate the loss
            with torch.no_grad():
                gt_dis = compute_sdf(label_batch[:].cpu(
                ).numpy(), outputs[:labeled_bs, 0, ...].shape)
                gt_dis = torch.from_numpy(gt_dis).float().cuda()
            loss_sdf = mse_loss(outputs_tanh[:labeled_bs, 0, ...], gt_dis)
            loss_seg = ce_loss(
                outputs[:labeled_bs, 0, ...], label_batch[:labeled_bs].float())
            loss_seg_dice = losses.dice_loss(
                outputs_soft[:labeled_bs, 0, :, :, :], label_batch[:labeled_bs] == 1)
            dis_to_mask = torch.sigmoid(-1500*outputs_tanh)

            consistency_loss = F.mse_loss(dis_to_mask, outputs_soft)
            supervised_loss = loss_seg_dice + args.beta * loss_sdf

            loss = supervised_loss + 0.1 * consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            dc = metrics.dice(torch.argmax(
                outputs_soft[:labeled_bs], dim=1), label_batch[:labeled_bs])

            iter_num = iter_num + 1
Example #8
def unet_3d_model_fn(imgs, gts, mode, params):
    # Append dimension for filters  (80, 366, 366) -> (80, 366, 366, 1)
    imgs = tf.expand_dims(imgs, 4)

    if params['leaky_relu']:
        activation = tf.nn.leaky_relu
    else:
        activation = tf.nn.relu

    levels = list()
    pool = imgs
    # Go down
    for layer_depth in range(params['depth']):
        if layer_depth < params['depth'] - 1:
            conv, pool = conv_conv_pool(pool, name=layer_depth, batch_norm=params['batch_norm'],
                                        n_filters=[params["n_base_filters"]*(2**layer_depth),
                                                   params["n_base_filters"]*(2**layer_depth)*2],
                                        activation=activation)
            levels.append([conv, pool])
        else:
            current_layer = conv_conv_pool(pool, name=layer_depth, pool=False, batch_norm=params['batch_norm'],
                                           n_filters=[params["n_base_filters"] * (2 ** layer_depth),
                                                      params["n_base_filters"] * (2 ** layer_depth) * 2],
                                           activation=activation)
            levels.append([current_layer])
    # Go up
    for i, layer_depth in enumerate(range(params['depth']-2, -1, -1)):
        concat = upconv_concat(current_layer, levels[layer_depth][0], n_filter=current_layer.shape[-1], name=params['depth']+i)
        current_layer = conv_conv_pool(concat, name=params['depth']+i, pool=False, batch_norm=params['batch_norm'],
                                       n_filters=[levels[layer_depth][0].shape[-1],
                                                  levels[layer_depth][0].shape[-1]],
                                       activation=activation)

    logits = tf.layers.conv3d(inputs=current_layer, filters=1, kernel_size=(1, 1, 1), padding='same', name='final')
    predictions = tf.nn.sigmoid(logits)
    # l2_loss = tf.losses.get_regularization_loss()
    # loss += l2_loss

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {"predictions": predictions,
                       "probabilities": logits,
                       "imgs": imgs}
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    gts = tf.expand_dims(gts, 4)
    # gts = tf.cast(gts, dtype=tf.float32)

    loss = dice_loss(gts, predictions)
    # loss = -tf.reduce_sum(dice(gts, predictions))
    # loss = tf.Print(loss, [loss])

    if mode == tf.estimator.ModeKeys.EVAL:
        specs = dict(mode=mode,
                     predictions={"preds": predictions,
                                  "probabilities": logits},
                     loss=loss,
                     eval_metric_ops={
                         "dice_eval": streaming_dice(labels=gts, predictions=predictions),
                         "loss_eval": tf.metrics.mean(loss)})

    else: # TRAIN
        global_step = tf.train.get_or_create_global_step()
        opt = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])
        train_op = opt.minimize(loss=loss, global_step=global_step)
        specs = dict(
            mode=mode,
            loss=loss,
            train_op=train_op,
        )

    return tf.estimator.EstimatorSpec(**specs)
Example #9
def main():
    ###################
    # init parameters #
    ###################
    args = get_args()
    # training path
    train_data_path = args.root_path
    # writer
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    batch_size = args.batch_size * args.ngpu 
    max_iterations = args.max_iterations
    base_lr = args.base_lr

    #patch_size = (112, 112, 112)
    #patch_size = (160, 160, 160)
    patch_size = (64, 128, 128)
    num_classes = 2


    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True, use_tm=args.use_tm)
    net = net.cuda()

    #db_train = LAHeart(base_dir=train_data_path,
    #                   split='train',
    #                   transform = transforms.Compose([
    #                      RandomRotFlip(),
    #                      RandomCrop(patch_size),
    #                      ToTensor(),
    #                      ]))

    db_train = ABUS(base_dir=args.root_path,
                       split='train',
                       use_dismap=args.use_dismap,
                       transform = transforms.Compose([RandomRotFlip(use_dismap=args.use_dismap), RandomCrop(patch_size, use_dismap=args.use_dismap), ToTensor(use_dismap=args.use_dismap)]))
    def worker_init_fn(worker_id):
        random.seed(args.seed+worker_id)
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    net.train()
    optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    #gdl = GeneralizedDiceLoss()

    logging.info("{} itertations per epoch".format(len(trainloader)))

    iter_num = 0
    alpha = 1.0
    max_epoch = max_iterations//len(trainloader)+1
    lr_ = base_lr
    net.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch, dis_map_batch = sampled_batch['image'], sampled_batch['label'], sampled_batch['dis_map']
            volume_batch, label_batch, dis_map_batch = volume_batch.cuda(), label_batch.cuda(), dis_map_batch.cuda()
            #print('volume_batch.shape: ', volume_batch.shape)
            if args.use_tm:
                outputs, tm = net(volume_batch)
                tm = torch.sigmoid(tm)
            else:
                outputs = net(volume_batch)
            #print('volume_batch.shape: ', volume_batch.shape)
            #print('outputs.shape, ', outputs.shape)

            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            #print(outputs_soft.shape)
            #print(label_batch.shape)
            #loss_seg_dice = gdl(outputs_soft, label_batch)
            loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
            #with torch.no_grad():
            #    # defalut using compute_sdf; however, compute_sdf1_1 is also worth to try;
            #    gt_sdf_npy = compute_sdf(label_batch.cpu().numpy(), outputs_soft.shape)
            #    gt_sdf = torch.from_numpy(gt_sdf_npy).float().cuda(outputs_soft.device.index)
            #    print('gt_sdf.shape: ', gt_sdf.shape)
            #loss_boundary = boundary_loss(outputs_soft, gt_sdf)

            #print('dis_map.shape: ', dis_map_batch.shape)
            loss_boundary = boundary_loss(outputs_soft, dis_map_batch)

            if args.use_tm:
                loss_threshold = threshold_loss(outputs_soft[:, 1, :, :, :], tm[:, 0, ...], label_batch == 1)
                loss_th = (0.1 * loss_seg + 0.9 * loss_seg_dice) + 3 * loss_threshold
                loss = alpha*(loss_th) + (1 - alpha) * loss_boundary
            else:
                loss = alpha * loss_seg_dice + (1-alpha) * loss_boundary

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            out = outputs_soft.max(1)[1]
            dice = GeneralizedDiceLoss.dice_coeficient(out, label_batch)

            iter_num = iter_num + 1
            writer.add_scalar('train/lr', lr_, iter_num)
            writer.add_scalar('train/loss_seg', loss_seg, iter_num)
            writer.add_scalar('train/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('train/alpha', alpha, iter_num)
            writer.add_scalar('train/loss', loss, iter_num)
            writer.add_scalar('train/dice', dice, iter_num)
            if args.use_tm:
                writer.add_scalar('train/loss_threshold', loss_threshold, iter_num)
            if args.use_dismap:
                writer.add_scalar('train/loss_dis', loss_boundary, iter_num)

            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
            logging.info('iteration %d : alpha : %f' % (iter_num, alpha))

            if iter_num % 50 == 0:
                image = volume_batch[0, 0:1, :, 30:71:10, :].permute(2,0,1,3)
                image = (image + 0.5) * 0.5
                grid_image = make_grid(image, 5)
                writer.add_image('train/Image', grid_image, iter_num)

                #outputs_soft = F.softmax(outputs, 1) #batchsize x num_classes x w x h x d
                image = outputs_soft[0, 1:2, :, 30:71:10, :].permute(2,0,1,3)
                grid_image = make_grid(image, 5, normalize=False)
                grid_image = grid_image.cpu().detach().numpy().transpose((1,2,0))

                gt = label_batch[0, :, 30:71:10, :].unsqueeze(0).permute(2,0,1,3)
                grid_gt = make_grid(gt, 5, normalize=False)
                grid_gt = grid_gt.cpu().detach().numpy().transpose((1,2,0))

                image_tm = dis_map_batch[0, :, :, 30:71:10, :].permute(2,0,1,3)
                #image_tm = tm[0, :, :, 30:71:10, :].permute(2,0,1,3)
                grid_tm = make_grid(image_tm, 5, normalize=False)
                grid_tm = grid_tm.cpu().detach().numpy().transpose((1,2,0))


                fig = plt.figure()
                ax = fig.add_subplot(311)
                ax.imshow(grid_gt[:, :, 0], 'gray')
                ax = fig.add_subplot(312)
                cs = ax.imshow(grid_image[:, :, 0], 'hot', vmin=0., vmax=1.)
                fig.colorbar(cs, ax=ax, shrink=0.9)
                ax = fig.add_subplot(313)
                cs = ax.imshow(grid_tm[:, :, 0], 'hot', vmin=0, vmax=1.)
                fig.colorbar(cs, ax=ax, shrink=0.9)
                writer.add_figure('train/prediction_results', fig, iter_num)
                fig.clear()

            ## change lr
            if iter_num % 5000 == 0:
                lr_ = base_lr * 0.1 ** (iter_num // 5000)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            if iter_num % 1000 == 0:
                save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(net.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num > max_iterations:
                break
            time1 = time.time()
        alpha -= 0.005
        if alpha <= 0.01:
            alpha = 0.01
        if iter_num > max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth')
    torch.save(net.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()
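Example #9 mixes its Dice/threshold term with boundary_loss(outputs_soft, dis_map_batch), where dis_map_batch is a precomputed distance map. A boundary loss in the sense of Kervadec et al. ("Boundary loss for highly unbalanced segmentation") is simply the mean of the foreground probability multiplied by the signed distance map; a minimal sketch under that assumption (the channel layout is assumed, not read from this code):

import torch

def boundary_loss(outputs_soft, gt_sdf):
    # foreground probability times signed distance (negative inside the object),
    # so probability mass placed inside the ground-truth object lowers the loss
    pc = outputs_soft[:, 1, ...]
    dc = gt_sdf[:, 1, ...]
    return torch.mean(pc * dc)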
Example #10
def main():
    ###################
    # init parameters #
    ###################
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # training path
    train_data_path = args.root_path

    # writer
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    batch_size = args.batch_size * args.ngpu 
    max_iterations = args.max_iterations
    base_lr = args.base_lr

    patch_size = (64, 128, 128)
    num_classes = 2

    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    # network
    if args.arch == 'vnet':
        net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True, use_tm=args.use_tm)
    elif args.arch == 'd2unet':
        net = D2UNet()
    else:
        raise NotImplementedError('model {} not implemented'.format(args.arch))
    net = net.cuda()

    # dataset 
    def worker_init_fn(worker_id):
        random.seed(args.seed+worker_id)
    db_train = ABUS(base_dir=args.root_path,
                       split='train',
                       fold=args.fold,
                       transform = transforms.Compose([RandomRotFlip(), RandomCrop(patch_size), ToTensor()]))
    db_val = ABUS(base_dir=args.root_path,
                       split='val',
                       fold=args.fold,
                       transform = transforms.Compose([CenterCrop(patch_size), ToTensor()]))
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    # optimizer
    optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    gdl = GeneralizedDiceLoss()

    logging.info("{} itertations per epoch".format(len(trainloader)))

    # training
    iter_num = 0
    max_epoch = max_iterations//len(trainloader)+1
    lr_ = base_lr
    net.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        for i_batch, sampled_batch in enumerate(trainloader):
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            if args.use_tm: 
                outputs, tm = net(volume_batch)
                tm = torch.sigmoid(tm)
            else:
                outputs = net(volume_batch)

            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
            if args.use_tm: 
                loss_threshold = threshold_loss(outputs_soft[:, 1, :, :, :], tm[:, 0, ...], label_batch == 1)
                loss = loss_seg_dice + 3 * loss_threshold
            else:
                loss = loss_seg_dice

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            # visualization on tensorboard
            out = outputs_soft.max(1)[1]
            dice = GeneralizedDiceLoss.dice_coeficient(out, label_batch)

            iter_num = iter_num + 1
            writer.add_scalar('train/lr', lr_, iter_num)
            writer.add_scalar('train/loss_seg', loss_seg, iter_num)
            writer.add_scalar('train/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('train/loss', loss, iter_num)
            writer.add_scalar('train/dice', dice, iter_num)
            if args.use_tm:
                writer.add_scalar('train/loss_threshold', loss_threshold, iter_num)

            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))

            if iter_num % 50 == 0:
                nrow = 5
                image = volume_batch[0, 0:1, :, 30:71:10, :].permute(2,0,1,3)
                image = (image + 0.5) * 0.5
                grid_image = make_grid(image, nrow=nrow)
                writer.add_image('train/Image', grid_image, iter_num)

                #outputs_soft = F.softmax(outputs, 1) #batchsize x num_classes x w x h x d
                image = outputs_soft[0, 1:2, :, 30:71:10, :].permute(2,0,1,3)
                grid_image = make_grid(image, nrow=nrow, normalize=False)
                grid_image = grid_image.cpu().detach().numpy().transpose((1,2,0))

                gt = label_batch[0, :, 30:71:10, :].unsqueeze(0).permute(2,0,1,3)
                grid_gt = make_grid(gt, nrow=nrow, normalize=False)
                grid_gt = grid_gt.cpu().detach().numpy().transpose((1,2,0))

                if args.use_tm:
                    image_tm = tm[0, :, :, 30:71:10, :].permute(2,0,1,3)
                else:
                    image_tm = gt
                grid_tm = make_grid(image_tm, nrow=nrow, normalize=False)
                grid_tm = grid_tm.cpu().detach().numpy().transpose((1,2,0))

                fig = plt.figure()
                ax = fig.add_subplot(311)
                ax.imshow(grid_gt[:, :, 0], 'gray')
                ax = fig.add_subplot(312)
                cs = ax.imshow(grid_image[:, :, 0], 'hot', vmin=0., vmax=1.)
                fig.colorbar(cs, ax=ax, shrink=0.9)
                ax = fig.add_subplot(313)
                cs = ax.imshow(grid_tm[:, :, 0], 'hot', vmin=0, vmax=1.)
                fig.colorbar(cs, ax=ax, shrink=0.9)
                writer.add_figure('train/prediction_results', fig, iter_num)
                fig.clear()

            ## change lr
            if iter_num % 2500 == 0:
                lr_ = base_lr * 0.1 ** (iter_num // 2500)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            if iter_num % 1000 == 0 and iter_num > 5000:
                save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(net.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num > max_iterations:
                break
            time1 = time.time()
        if iter_num > max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth')
    torch.save(net.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()
Example #11
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            outputs_dis, dis_r = model_dis(volume_batch)
            seg_r, outputs_seg = model_seg(volume_batch)

            softmask_seg = torch.sigmoid(outputs_seg)

            # calculate the loss
            with torch.no_grad():
                gt_dis = compute_sdf(label_batch.cpu().numpy(), outputs_dis.shape)
                gt_dis = torch.from_numpy(gt_dis).float().cuda()
            loss_dis = torch.norm(outputs_dis-gt_dis, 1)/torch.numel(outputs_dis)
            
            loss_ce = ce_loss(
                outputs_seg[:labeled_bs, 0, ...], label_batch[:labeled_bs].float())
            loss_dice = losses.dice_loss(
                softmask_seg[:labeled_bs, 0, :, :, :], label_batch[:labeled_bs] == 1)
            loss_seg = 0.5*(loss_ce + loss_dice)
            
            dis_to_mask = torch.sigmoid(-1500*outputs_dis)
            
            loss_dis_dice = losses.dice_loss(
                dis_to_mask[:labeled_bs, 0, :, :, :], label_batch[:labeled_bs] == 1)
            
            
            consistency_loss = torch.mean((dis_to_mask - softmask_seg) ** 2)             
            consistency_weight = get_current_consistency_weight(iter_num//150)
            
            # model loss
            model_dis_loss = loss_dis_dice + consistency_weight * consistency_loss
            model_seg_loss = loss_seg + consistency_weight * consistency_loss
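Examples #11 and #13 scale their consistency term with get_current_consistency_weight(iter_num // 150). In mean-teacher style code this is usually a sigmoid ramp-up over the first training epochs; a sketch under that assumption (the constants are illustrative defaults, not read from this code):

import numpy as np

def sigmoid_rampup(current, rampup_length):
    # exp(-5 * (1 - t)^2) ramp-up from 0 to 1 over rampup_length steps
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

def get_current_consistency_weight(epoch, consistency=0.1, consistency_rampup=40.0):
    return consistency * sigmoid_rampup(epoch, consistency_rampup)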
Example #12
def main():
    ###################
    # init parameters #
    ###################
    args = get_args()
    # training path
    train_data_path = args.root_path
    # writer
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    batch_size = args.batch_size * args.ngpu 
    max_iterations = args.max_iterations
    base_lr = args.base_lr

    patch_size = (112, 112, 80)
    num_classes = 2

    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
    net = net.cuda()

    db_train = LAHeart(base_dir=train_data_path,
                       split='train',
                       num=16,
                       transform = transforms.Compose([
                          RandomRotFlip(),
                          RandomCrop(patch_size),
                          ToTensor(),
                          ]))
    db_test = LAHeart(base_dir=train_data_path,
                       split='test',
                       transform = transforms.Compose([
                           CenterCrop(patch_size),
                           ToTensor()
                       ]))
    def worker_init_fn(worker_id):
        random.seed(args.seed+worker_id)
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    net.train()
    optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    logging.info("{} itertations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations//len(trainloader)+1
    lr_ = base_lr
    net.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            outputs = net(volume_batch)

            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
            loss = 0.5*(loss_seg+loss_seg_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num = iter_num + 1
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
            if iter_num % 50 == 0:
                image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3,0,1,2).repeat(1,3,1,1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/Image', grid_image, iter_num)

                outputs_soft = F.softmax(outputs, 1)
                image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Predicted_label', grid_image, iter_num)

                image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Groundtruth_label', grid_image, iter_num)

            ## change lr
            if iter_num % 2500 == 0:
                lr_ = base_lr * 0.1 ** (iter_num // 2500)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            if iter_num % 1000 == 0:
                save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(net.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num > max_iterations:
                break
            time1 = time.time()
        if iter_num > max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth')
    torch.save(net.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()
Example #13
            lab_image = lab['image'].to(args.device)
            lab_label = lab['label'].to(args.device)
            unl_image = unl['image'].to(args.device)

            noise = torch.clamp(torch.randn_like(unl_image) * 0.1, -0.2, 0.2)
            ema_inputs = unl_image + noise

            outputs = model(lab_image)  ## outputs logits [N, C, W, H]
            unl_pseu_label = model(unl_image)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)

            loss_seg = F.cross_entropy(outputs,
                                       torch.squeeze(lab_label, dim=1).long())
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = losses.dice_loss(outputs_soft, lab_label)
            supervised_loss = 0.5 * (loss_seg + loss_seg_dice)

            consistency_weight = get_current_consistency_weight(iter_num //
                                                                150)
            consistency_dist = consistency_criterion(unl_pseu_label,
                                                     ema_output)

            consistency_loss = consistency_weight * torch.sum(consistency_dist)
            loss = supervised_loss + 0.001 * consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay,
                                 iter_num)  ## update teacher model
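Example #13 keeps the teacher network in sync with update_ema_variables(model, ema_model, args.ema_decay, iter_num). A typical mean-teacher EMA update, shown here as a sketch of what such a helper usually does rather than the exact code behind this example:

def update_ema_variables(model, ema_model, alpha, global_step):
    # exponential moving average of the student weights; the decay is ramped up so the
    # teacher tracks the student closely at the start of training
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)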
Example #14
def main():
    ###################
    # init parameters #
    ###################
    args = get_args()
    # training path
    train_data_path = args.root_path
    # writer
    idx = args.save.rfind('/')
    log_dir = args.writer_dir + args.save[idx:]
    writer = SummaryWriter(log_dir)

    batch_size = args.batch_size * args.ngpu
    max_iterations = args.max_iterations
    base_lr = args.base_lr

    patch_size = (112, 112, 80)
    num_classes = 2

    # random
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    ## make logger file
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)
    snapshot_path = args.save
    logging.basicConfig(filename=snapshot_path + "/log.txt",
                        level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s',
                        datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    #training set
    db_train = LAHeart(base_dir=train_data_path,
                       split='train',
                       num=16,
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(patch_size),
                           ToTensor()
                       ]))

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             worker_init_fn=worker_init_fn)

    net = VNet(n_channels=1,
               n_classes=num_classes,
               normalization='batchnorm',
               has_dropout=True)
    net = net.cuda()
    net.train()
    optimizer = optim.SGD(net.parameters(),
                          lr=base_lr,
                          momentum=0.9,
                          weight_decay=0.0001)

    logging.info("{} itertations per epoch".format(len(trainloader)))

    iter_num = 0
    alpha = 1.0
    max_epoch = max_iterations // len(trainloader) + 1
    lr_ = base_lr
    net.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            # volume_batch.shape=(b,1,x,y,z) label_patch.shape=(b,x,y,z)
            volume_batch, label_batch = sampled_batch['image'], sampled_batch[
                'label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            outputs = net(volume_batch)

            loss_seg = F.cross_entropy(outputs, label_batch)
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :],
                                      label_batch == 1)
            # compute gt_signed distance function and boundary loss
            with torch.no_grad():
                # default: using compute_sdf; compute_sdf1_1 is also worth trying
                gt_sdf_npy = compute_sdf(label_batch.cpu().numpy(),
                                         outputs_soft.shape)
                gt_sdf = torch.from_numpy(gt_sdf_npy).float().cuda(
                    outputs_soft.device.index)
                # show signed distance map for debug
                # import matplotlib.pyplot as plt
                # plt.figure()
                # plt.subplot(121), plt.imshow(gt_sdf_npy[0,1,:,:,40]), plt.colorbar()
                # plt.subplot(122), plt.imshow(np.uint8(label_batch.cpu().numpy()[0,:,:,40]>0)), plt.colorbar()
                # plt.show()
            loss_boundary = boundary_loss(outputs_soft, gt_sdf)
            loss = alpha * (loss_seg + loss_seg_dice) + (1 -
                                                         alpha) * loss_boundary

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num = iter_num + 1
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('loss/loss_boundary', loss_boundary, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            writer.add_scalar('loss/alpha', alpha, iter_num)
            logging.info('iteration %d : alpha : %f' % (iter_num, alpha))
            logging.info('iteration %d : loss_seg_dice : %f' %
                         (iter_num, loss_seg_dice.item()))
            logging.info('iteration %d : loss_boundary : %f' %
                         (iter_num, loss_boundary.item()))
            logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
            if iter_num % 2 == 0:
                image = volume_batch[0, 0:1, :, :,
                                     20:61:10].permute(3, 0, 1,
                                                       2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/Image', grid_image, iter_num)

                image = outputs_soft[0, 1:2, :, :,
                                     20:61:10].permute(3, 0, 1,
                                                       2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Predicted_label', grid_image, iter_num)

                image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(
                    3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Groundtruth_label', grid_image,
                                 iter_num)

                image = gt_sdf[0, 1:2, :, :,
                               20:61:10].permute(3, 0, 1,
                                                 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/gt_sdf', grid_image, iter_num)

            ## change lr
            if iter_num % 2500 == 0:
                lr_ = base_lr * 0.1**(iter_num // 2500)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            if iter_num % 1000 == 0:
                save_mode_path = os.path.join(snapshot_path,
                                              'iter_' + str(iter_num) + '.pth')
                torch.save(net.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num > max_iterations:
                break
            time1 = time.time()
        alpha -= 0.01
        if alpha <= 0.01:
            alpha = 0.01
        if iter_num > max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path,
                                  'iter_' + str(max_iterations + 1) + '.pth')
    torch.save(net.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()
Example #15
def train_epoch(net, eval_net, labelled_data, unlabelled_data, batch_size,
                supervised_only, optimizer, x_criterion, u_criterion, K, T,
                alpha, mixup_mode, Lambda, aug_factor, decay):
    if DEBUG:
        epoch = next(c)
    net.train()
    epoch_loss = 0.0
    for i, (l_batch, u_batch) in enumerate(
            zip(batch_loader(labelled_data, batch_size),
                batch_loader(unlabelled_data, batch_size))):
        #l_batch, u_batch are lists (len = batch_size) of dicts
        l_image = [sample['image'] for sample in l_batch]
        l_label = [sample['label'] for sample in l_batch]
        u_image = [sample['image'] for sample in u_batch]

        X = list(zip(
            l_image,
            l_label))  #list of (image, onehot_label) of length = batch_size
        U = u_image  #list of image of length = batch_size

        if not supervised_only:
            copy_params(net, eval_net, decay)

            X_prime, U_prime = mix_match(X,
                                         U,
                                         eval_net=eval_net,
                                         K=K,
                                         T=T,
                                         alpha=alpha,
                                         mixup_mode=mixup_mode,
                                         aug_factor=aug_factor)

            net.train()

            X_data = torch.from_numpy(np.array([x[0] for x in X_prime]))
            X_label = torch.from_numpy(np.array([x[1] for x in X_prime]))
            U_data = torch.from_numpy(np.array([x[0] for x in U_prime]))
            U_label = torch.from_numpy(np.array([x[1] for x in U_prime]))

            if DEBUG:
                #save_as_image(X_data.numpy(), f"../debug_output/x_data")
                save_as_image(U_data.numpy(),
                              f"../debug_output/u_data_{epoch}")
                #save_as_image(X_label[:, [1], :, :, :].numpy(), f"../debug_output/x_label")
                save_as_image(U_label[:, [1], :, :, :].numpy(),
                              f"../debug_output/u_label_{epoch}")

            if GPU:
                X_data = X_data.cuda()
                X_label = X_label.cuda()
                U_data = U_data.cuda()
                U_label = U_label.cuda()

            X = torch.cat((X_data, U_data), 0)
            Y = net(X)
            Y_x = Y[:len(X_data)]
            Y_u = Y[len(X_data):]

            Y_x_softmax = F.softmax(Y_x, dim=1)
            Y_u_softmax = F.softmax(Y_u, dim=1)

            if DEBUG:
                #save_as_image(Y_x_softmax[:, [1], :, :, :].detach().cpu().numpy(), "../debug_output/x_pred")
                save_as_image(
                    Y_u_softmax[:, [1], :, :, :].detach().cpu().numpy(),
                    f"../debug_output/u_pred_{epoch}")

            loss_x_seg = x_criterion(Y_x_softmax, X_label)
            loss_x_dice = losses.dice_loss(Y_x_softmax[:, 1, :, :, :],
                                           X_label[:, 1, :, :, :])
            loss_x = 0.5 * (loss_x_seg + loss_x_dice)

            loss_u = u_criterion(Y_u_softmax, U_label)

            loss = loss_x + Lambda * loss_u

            if DEBUG:
                print(loss_x.item(), loss_u.item(), loss.item())

        else:
            #supervised_only
            X_data = torch.from_numpy(np.array(l_image))
            X_label = torch.from_numpy(np.array(l_label))
            if DEBUG:
                save_as_image(X_data.numpy(), "../debug_output/s_data")
                save_as_image(X_label[:, [1], :, :, :].numpy(),
                              "../debug_output/s_label")
            if GPU:
                X_data = X_data.cuda()
                X_label = X_label.cuda()

            Y_x = net(X_data)
            Y_x_softmax = F.softmax(Y_x, dim=1)

            if DEBUG:
                save_as_image(
                    Y_x_softmax[:, [1], :, :, :].detach().cpu().numpy(),
                    "../debug_output/s_pred")

            loss_x_seg = x_criterion(Y_x_softmax, X_label)
            loss_x_dice = losses.dice_loss(Y_x_softmax[:, 1, :, :, :],
                                           X_label[:, 1, :, :, :])
            loss = 0.5 * (loss_x_seg + loss_x_dice)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        if DEBUG:
            break
    return epoch_loss / (i + 1)
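Example #15 follows the MixMatch recipe: mix_match guesses labels for the unlabelled batch from K augmented forward passes and sharpens them with temperature T before MixUp. The sharpening step itself, as an illustrative sketch (not the mix_match used above):

import torch

def sharpen(prob, T=0.5, dim=1):
    # raise class probabilities to the power 1/T and renormalize;
    # lower T pushes the guessed labels closer to one-hot
    sharpened = prob ** (1.0 / T)
    return sharpened / sharpened.sum(dim=dim, keepdim=True)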