Example #1
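Uncertainty-aware mean-teacher training for semi-supervised ACDC segmentation
(in the style of UA-MT): a student UNet is supervised with cross-entropy plus
Dice loss on the labeled slices, while an EMA teacher provides consistency
targets on the unlabeled slices, masked by an estimate of the teacher's
predictive uncertainty.
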
def train(args, snapshot_path):
    base_lr = args.base_lr
    num_classes = 4
    batch_size = args.batch_size
    max_iterations = args.max_iterations

    def create_model(ema=False):
        # Network definition
        net = UNet('efficientnet-b3',
                   encoder_weights='imagenet',
                   in_channels=1,
                   classes=num_classes)
        model = net.cuda()
        if ema:
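            # Detach the teacher's parameters: they receive no gradients and
            # are updated only through the EMA update in the training loop.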
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
    ema_model = create_model(ema=True)

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    db_train = ACDC(base_dir=args.root_path,
                    split="train",
                    num=None,
                    transform=transforms.Compose(
                        [RandomGenerator(args.patch_size)]))
    db_val = ACDC(base_dir=args.root_path, split="val")

    # Indices below args.labeled_num are labeled; the rest of the 1312
    # training slices are treated as unlabeled. Each batch then combines
    # args.labeled_bs labeled and (batch_size - labeled_bs) unlabeled slices.
    labeled_idxs = list(range(0, args.labeled_num))
    unlabeled_idxs = list(range(args.labeled_num, 1312))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
                                          batch_size,
                                          batch_size - args.labeled_bs)

    trainloader = DataLoader(db_train,
                             batch_sampler=batch_sampler,
                             num_workers=4,
                             pin_memory=True,
                             worker_init_fn=worker_init_fn)

    model.train()
    ema_model.train()

    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    optimizer = optim.SGD(model.parameters(),
                          lr=base_lr,
                          momentum=0.9,
                          weight_decay=0.0001)
    ce_loss = CrossEntropyLoss()
    dice_loss = losses.DiceLoss(num_classes)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        for i_batch, sampled_batch in enumerate(trainloader):

            volume_batch, label_batch = sampled_batch['image'], sampled_batch[
                'label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[args.labeled_bs:]

            # The teacher receives a noise-perturbed copy of the unlabeled
            # images (clipped Gaussian noise), as in mean-teacher training.
            noise = torch.clamp(
                torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)
            ema_inputs = unlabeled_volume_batch + noise

            outputs = model(volume_batch)
            outputs_soft = torch.softmax(outputs, dim=1)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)
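            # Uncertainty estimation: run the teacher T times on independently
            # noised copies of the unlabeled images and take the per-pixel
            # entropy of the averaged softmax as the uncertainty map.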
            T = 8
            _, _, w, h = unlabeled_volume_batch.shape
            volume_batch_r = unlabeled_volume_batch.repeat(2, 1, 1, 1)
            stride = volume_batch_r.shape[0] // 2
            preds = torch.zeros([stride * T, num_classes, w, h]).cuda()
            for i in range(T // 2):
                ema_inputs = volume_batch_r + \
                    torch.clamp(torch.randn_like(
                        volume_batch_r) * 0.1, -0.2, 0.2)
                with torch.no_grad():
                    preds[2 * stride * i:2 * stride *
                          (i + 1)] = ema_model(ema_inputs)
            preds = F.softmax(preds, dim=1)
            preds = preds.reshape(T, stride, num_classes, w, h)
            preds = torch.mean(preds, dim=0)
            uncertainty = -1.0 * \
                torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True)

            loss_ce = ce_loss(outputs[:args.labeled_bs],
                              label_batch[:args.labeled_bs].long())
            loss_dice = dice_loss(outputs_soft[:args.labeled_bs],
                                  label_batch[:args.labeled_bs].unsqueeze(1))
            supervised_loss = 0.5 * (loss_dice + loss_ce)
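            # Consistency term: pixel-wise MSE between student and teacher
            # softmax on the unlabeled images, keeping only pixels whose
            # uncertainty is below a threshold that ramps up during training.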
            consistency_weight = get_current_consistency_weight(iter_num //
                                                                150)
            consistency_dist = losses.softmax_mse_loss(
                outputs[args.labeled_bs:],
                ema_output)  # (batch, num_classes, H, W)
            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(
                iter_num, max_iterations)) * np.log(2)
            mask = (uncertainty < threshold).float()
            consistency_loss = torch.sum(
                mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)

            loss = supervised_loss + consistency_weight * consistency_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
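            # The teacher's weights follow the student's as an exponential
            # moving average; the teacher itself receives no gradient updates.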
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)

            # Polynomial ("poly") learning-rate decay.
            lr_ = base_lr * (1.0 - iter_num / max_iterations)**0.9
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

            iter_num = iter_num + 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/loss_ce', loss_ce, iter_num)
            writer.add_scalar('info/loss_dice', loss_dice, iter_num)
            writer.add_scalar('info/consistency_loss', consistency_loss,
                              iter_num)
            writer.add_scalar('info/consistency_weight', consistency_weight,
                              iter_num)
            logging.info(
                'iteration %d : loss : %f, loss_ce: %f, loss_dice: %f' %
                (iter_num, loss.item(), loss_ce.item(), loss_dice.item()))

            if iter_num % 20 == 0:
                image = volume_batch[1, 0:1, :, :]
                writer.add_image('train/Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(outputs, dim=1),
                                       dim=1,
                                       keepdim=True)
                writer.add_image('train/Prediction', outputs[1, ...] * 50,
                                 iter_num)
                labs = label_batch[1, ...].unsqueeze(0) * 50
                writer.add_image('train/GroundTruth', labs, iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                first_total, second_total, third_total = 0.0, 0.0, 0.0
                for i_batch, sampled_batch in enumerate(valloader):
                    first, second, third = test_single_volume(
                        sampled_batch["image"], sampled_batch["label"], model)
                    first_total += np.asarray(first)
                    second_total += np.asarray(second)
                    third_total += np.asarray(third)
                first_total, second_total, third_total = first_total / \
                    len(db_val), second_total / \
                    len(db_val), third_total/len(db_val)
                writer.add_scalar('info/val_one_dice', first_total[0],
                                  iter_num)
                writer.add_scalar('info/val_one_hd95', first_total[1],
                                  iter_num)

                writer.add_scalar('info/val_two_dice', second_total[0],
                                  iter_num)
                writer.add_scalar('info/val_two_hd95', second_total[1],
                                  iter_num)

                writer.add_scalar('info/val_three_dice', third_total[0],
                                  iter_num)
                writer.add_scalar('info/val_three_hd95', third_total[1],
                                  iter_num)

                performance = (first_total[0] + second_total[0] +
                               third_total[0]) / 3
                mean_hd95 = (first_total[1] + second_total[1] +
                             third_total[1]) / 3
                writer.add_scalar('info/val_mean_dice', performance, iter_num)
                writer.add_scalar('info/val_mean_hd95', mean_hd95, iter_num)

                if performance > best_performance:
                    best_performance = performance
                    save_mode_path = os.path.join(
                        snapshot_path, 'iter_{}_dice_{}.pth'.format(
                            iter_num, round(best_performance, 4)))
                    save_best = os.path.join(
                        snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best)

                logging.info('iteration %d : mean_dice : %f mean_hd95 : %f' %
                             (iter_num, performance, mean_hd95))
                model.train()

            if iter_num % 3000 == 0:
                save_mode_path = os.path.join(snapshot_path,
                                              'iter_' + str(iter_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()
    return "Training Finished!"
Example #2
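Adversarial semi-supervised training with a fully convolutional discriminator
(DAN): the segmentation network is supervised on labeled slices and, on
unlabeled slices, trained to make its predictions indistinguishable from
labeled ones; the discriminator is then updated in a separate step.
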
def train(args, snapshot_path):
    base_lr = args.base_lr
    num_classes = 4
    batch_size = args.batch_size
    max_iterations = args.max_iterations

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    model = UNet('efficientnet-b3',
                 encoder_weights='imagenet',
                 in_channels=1,
                 classes=num_classes).cuda()

    DAN = FCDiscriminator(num_classes=num_classes)
    DAN = DAN.cuda()

    db_train = ACDC(base_dir=args.root_path,
                    split="train",
                    num=None,
                    transform=transforms.Compose(
                        [RandomGenerator(args.patch_size)]))

    labeled_idxs = list(range(0, args.labeled_num))
    unlabeled_idxs = list(range(args.labeled_num, 1312))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
                                          batch_size,
                                          batch_size - args.labeled_bs)

    trainloader = DataLoader(db_train,
                             batch_sampler=batch_sampler,
                             num_workers=16,
                             pin_memory=True,
                             worker_init_fn=worker_init_fn)

    db_val = ACDC(base_dir=args.root_path, split="val")
    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    model.train()

    optimizer = optim.SGD(model.parameters(),
                          lr=base_lr,
                          momentum=0.9,
                          weight_decay=0.0001)
    DAN_optimizer = optim.Adam(DAN.parameters(),
                               lr=args.DAN_lr,
                               betas=(0.9, 0.99))
    ce_loss = CrossEntropyLoss()
    dice_loss = losses.DiceLoss(num_classes)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        for i_batch, sampled_batch in enumerate(trainloader):

            volume_batch, label_batch = sampled_batch['image'], sampled_batch[
                'label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            # Domain labels for the discriminator: 1 for labeled slices, 0 for
            # unlabeled slices.
            DAN_target = torch.tensor([0] * args.batch_size).cuda()
            DAN_target[:args.labeled_bs] = 1
            model.train()
            DAN.eval()

            outputs = model(volume_batch)
            outputs_soft = torch.softmax(outputs, dim=1)

            loss_ce = ce_loss(outputs[:args.labeled_bs],
                              label_batch[:args.labeled_bs].long())
            loss_dice = dice_loss(outputs_soft[:args.labeled_bs],
                                  label_batch[:args.labeled_bs].unsqueeze(1))
            supervised_loss = 0.5 * (loss_dice + loss_ce)

            consistency_weight = get_current_consistency_weight(iter_num //
                                                                150)
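            # Segmentation-network step: the frozen discriminator should label
            # the predictions on unlabeled images as "labeled" (target 1); this
            # assumes labeled_bs == batch_size - labeled_bs so the target slice
            # matches the unlabeled batch.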
            DAN_outputs = DAN(outputs_soft[args.labeled_bs:],
                              volume_batch[args.labeled_bs:])

            consistency_loss = F.cross_entropy(
                DAN_outputs, (DAN_target[:args.labeled_bs]).long())
            loss = supervised_loss + consistency_weight * consistency_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Discriminator step: freeze the segmentation network and train the
            # discriminator to separate labeled from unlabeled predictions.
            model.eval()
            DAN.train()
            with torch.no_grad():
                outputs = model(volume_batch)
                outputs_soft = torch.softmax(outputs, dim=1)

            DAN_outputs = DAN(outputs_soft, volume_batch)
            DAN_loss = F.cross_entropy(DAN_outputs, DAN_target.long())
            DAN_optimizer.zero_grad()
            DAN_loss.backward()
            DAN_optimizer.step()

            lr_ = base_lr * (1.0 - iter_num / max_iterations)**0.9
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

            iter_num = iter_num + 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/loss_ce', loss_ce, iter_num)
            writer.add_scalar('info/loss_dice', loss_dice, iter_num)
            writer.add_scalar('info/consistency_loss', consistency_loss,
                              iter_num)
            writer.add_scalar('info/consistency_weight', consistency_weight,
                              iter_num)

            logging.info(
                'iteration %d : loss : %f, loss_ce: %f, loss_dice: %f' %
                (iter_num, loss.item(), loss_ce.item(), loss_dice.item()))

            if iter_num % 20 == 0:
                image = volume_batch[1, 0:1, :, :]
                writer.add_image('train/Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(outputs, dim=1),
                                       dim=1,
                                       keepdim=True)
                writer.add_image('train/Prediction', outputs[1, ...] * 50,
                                 iter_num)
                labs = label_batch[1, ...].unsqueeze(0) * 50
                writer.add_image('train/GroundTruth', labs, iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                first_total, second_total, third_total = 0.0, 0.0, 0.0
                for i_batch, sampled_batch in enumerate(valloader):
                    first, second, third = test_single_volume(
                        sampled_batch["image"], sampled_batch["label"], model)
                    first_total += np.asarray(first)
                    second_total += np.asarray(second)
                    third_total += np.asarray(third)
                first_total, second_total, third_total = first_total / \
                    len(db_val), second_total / \
                    len(db_val), third_total/len(db_val)
                writer.add_scalar('info/val_one_dice', first_total[0],
                                  iter_num)
                writer.add_scalar('info/val_one_hd95', first_total[1],
                                  iter_num)

                writer.add_scalar('info/val_two_dice', second_total[0],
                                  iter_num)
                writer.add_scalar('info/val_two_hd95', second_total[1],
                                  iter_num)

                writer.add_scalar('info/val_three_dice', third_total[0],
                                  iter_num)
                writer.add_scalar('info/val_three_hd95', third_total[1],
                                  iter_num)

                performance = (first_total[0] + second_total[0] +
                               third_total[0]) / 3
                mean_hd95 = (first_total[1] + second_total[1] +
                             third_total[1]) / 3
                writer.add_scalar('info/val_mean_dice', performance, iter_num)
                writer.add_scalar('info/val_mean_hd95', mean_hd95, iter_num)

                if performance > best_performance:
                    best_performance = performance
                    save_mode_path = os.path.join(
                        snapshot_path, 'iter_{}_dice_{}.pth'.format(
                            iter_num, round(best_performance, 4)))
                    save_best = os.path.join(
                        snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best)

                logging.info('iteration %d : mean_dice : %f mean_hd95 : %f' %
                             (iter_num, performance, mean_hd95))
                model.train()

            if iter_num % 3000 == 0:
                save_mode_path = os.path.join(snapshot_path,
                                              'iter_' + str(iter_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()
    return "Training Finished!"
def train(args, snapshot_path):
    base_lr = args.base_lr
    num_classes = 4
    batch_size = args.batch_size
    max_iterations = args.max_iterations

    def create_model(ema=False):
        # Network definition
        net = UNet('efficientnet-b3',
                   encoder_weights='imagenet',
                   in_channels=1,
                   classes=num_classes)
        model = net.cuda()
        if ema:
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
    ema_model = create_model(ema=True)

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    db_train = ACDC(base_dir=args.root_path,
                    split="train",
                    num=None,
                    transform=transforms.Compose(
                        [RandomGenerator(args.patch_size)]))
    db_val = ACDC(base_dir=args.root_path, split="val")

    labeled_idxs = list(range(0, args.labeled_num))
    unlabeled_idxs = list(range(args.labeled_num, 1312))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
                                          batch_size,
                                          batch_size - args.labeled_bs)

    trainloader = DataLoader(db_train,
                             batch_sampler=batch_sampler,
                             num_workers=4,
                             pin_memory=True,
                             worker_init_fn=worker_init_fn)

    model.train()
    ema_model.train()

    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    optimizer = optim.SGD(model.parameters(),
                          lr=base_lr,
                          momentum=0.9,
                          weight_decay=0.0001)
    ce_loss = CrossEntropyLoss()
    dice_loss = losses.DiceLoss(num_classes)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        for i_batch, sampled_batch in enumerate(trainloader):

            volume_batch, label_batch = sampled_batch['image'], sampled_batch[
                'label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[args.labeled_bs:]
            labeled_volume_batch = volume_batch[:args.labeled_bs]

            # ICT mix factors
            ict_mix_factors = np.random.beta(args.ict_alpha,
                                             args.ict_alpha,
                                             size=(args.labeled_bs // 2, 1, 1,
                                                   1))
            ict_mix_factors = torch.tensor(ict_mix_factors,
                                           dtype=torch.float).cuda()
            unlabeled_volume_batch_0 = unlabeled_volume_batch[
                0:args.labeled_bs // 2, ...]
            unlabeled_volume_batch_1 = unlabeled_volume_batch[
                args.labeled_bs // 2:, ...]

            # Mix images
            batch_ux_mixed = unlabeled_volume_batch_0 * \
                (1.0 - ict_mix_factors) + \
                unlabeled_volume_batch_1 * ict_mix_factors
            input_volume_batch = torch.cat(
                [labeled_volume_batch, batch_ux_mixed], dim=0)
            outputs = model(input_volume_batch)
            outputs_soft = torch.softmax(outputs, dim=1)
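            # Consistency target: the teacher predicts on the two unmixed
            # unlabeled halves, and its softmax outputs are mixed with the same
            # interpolation factors.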
            with torch.no_grad():
                ema_output_ux0 = torch.softmax(
                    ema_model(unlabeled_volume_batch_0), dim=1)
                ema_output_ux1 = torch.softmax(
                    ema_model(unlabeled_volume_batch_1), dim=1)
                batch_pred_mixed = ema_output_ux0 * \
                    (1.0 - ict_mix_factors) + ema_output_ux1 * ict_mix_factors

            loss_ce = ce_loss(outputs[:args.labeled_bs],
                              label_batch[:args.labeled_bs].long())
            loss_dice = dice_loss(outputs_soft[:args.labeled_bs],
                                  label_batch[:args.labeled_bs].unsqueeze(1))
            supervised_loss = 0.5 * (loss_dice + loss_ce)
            consistency_weight = get_current_consistency_weight(iter_num //
                                                                150)
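            # ICT consistency: the student's softmax on the mixed images should
            # match the mixture of the teacher's softmax outputs.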
            consistency_loss = torch.mean(
                (outputs_soft[args.labeled_bs:] - batch_pred_mixed)**2)
            loss = supervised_loss + consistency_weight * consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)
            lr_ = base_lr * (1.0 - iter_num / max_iterations)**0.9
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

            iter_num = iter_num + 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/loss_ce', loss_ce, iter_num)
            writer.add_scalar('info/loss_dice', loss_dice, iter_num)
            writer.add_scalar('info/consistency_loss', consistency_loss,
                              iter_num)
            writer.add_scalar('info/consistency_weight', consistency_weight,
                              iter_num)

            logging.info(
                'iteration %d : loss : %f, loss_ce: %f, loss_dice: %f' %
                (iter_num, loss.item(), loss_ce.item(), loss_dice.item()))

            if iter_num % 20 == 0:
                image = volume_batch[1, 0:1, :, :]
                writer.add_image('train/Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(outputs, dim=1),
                                       dim=1,
                                       keepdim=True)
                writer.add_image('train/Prediction', outputs[1, ...] * 50,
                                 iter_num)
                image = batch_ux_mixed[1, 0:1, :, :]
                writer.add_image('train/Mixed_Unlabeled', image, iter_num)
                labs = label_batch[1, ...].unsqueeze(0) * 50
                writer.add_image('train/GroundTruth', labs, iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                first_total, second_total, third_total = 0.0, 0.0, 0.0
                for i_batch, sampled_batch in enumerate(valloader):
                    first, second, third = test_single_volume(
                        sampled_batch["image"], sampled_batch["label"], model)
                    first_total += np.asarray(first)
                    second_total += np.asarray(second)
                    third_total += np.asarray(third)
                first_total, second_total, third_total = first_total / \
                    len(db_val), second_total / \
                    len(db_val), third_total/len(db_val)
                writer.add_scalar('info/val_one_dice', first_total[0],
                                  iter_num)
                writer.add_scalar('info/val_one_hd95', first_total[1],
                                  iter_num)

                writer.add_scalar('info/val_two_dice', second_total[0],
                                  iter_num)
                writer.add_scalar('info/val_two_hd95', second_total[1],
                                  iter_num)

                writer.add_scalar('info/val_three_dice', third_total[0],
                                  iter_num)
                writer.add_scalar('info/val_three_hd95', third_total[1],
                                  iter_num)

                performance = (first_total[0] + second_total[0] +
                               third_total[0]) / 3
                mean_hd95 = (first_total[1] + second_total[1] +
                             third_total[1]) / 3
                writer.add_scalar('info/val_mean_dice', performance, iter_num)
                writer.add_scalar('info/val_mean_hd95', mean_hd95, iter_num)

                if performance > best_performance:
                    best_performance = performance
                    save_mode_path = os.path.join(
                        snapshot_path, 'iter_{}_dice_{}.pth'.format(
                            iter_num, round(best_performance, 4)))
                    save_best = os.path.join(
                        snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best)

                logging.info('iteration %d : mean_dice : %f mean_hd95 : %f' %
                             (iter_num, performance, mean_hd95))
                model.train()

            if iter_num % 3000 == 0:
                save_mode_path = os.path.join(snapshot_path,
                                              'iter_' + str(iter_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()
    return "Training Finished!"