def create_model(ema=False):
    """Build a 3D U-Net (2 output classes, 1 input channel) on the GPU.

    Args:
        ema: when true, every parameter is detached in place so autograd no
            longer tracks it (presumably for a teacher / EMA copy whose
            weights are updated manually elsewhere).

    Returns:
        The CUDA-resident network.
    """
    model = unet_3D(n_classes=2, in_channels=1).cuda()
    if not ema:
        return model
    for weight in model.parameters():
        weight.detach_()
    return model
def net_factory(net_type="unet_3D", num_classes=3, in_channels=1):
    """Instantiate a 3D network by name and move it to the GPU.

    Args:
        net_type: ``"unet_3D"`` or ``"unet_3D_dv_semi"``.
        num_classes: number of output classes passed as ``n_classes``.
        in_channels: number of input channels.

    Returns:
        The CUDA network, or ``None`` when ``net_type`` is not recognised.
    """
    if net_type == "unet_3D":
        return unet_3D(n_classes=num_classes, in_channels=in_channels).cuda()
    if net_type == "unet_3D_dv_semi":
        return unet_3D_dv_semi(n_classes=num_classes,
                               in_channels=in_channels).cuda()
    # Unknown names are reported as None rather than raising.
    return None
def net_factory_3d(net_type="unet_3D", in_chns=1, class_num=2):
    """Instantiate one of the supported 3D networks on the GPU.

    Args:
        net_type: one of ``"unet_3D"``, ``"attention_unet"``, ``"voxresnet"``
            or ``"vnet"``.
        in_chns: number of input channels.
        class_num: number of output classes.

    Returns:
        The CUDA network, or ``None`` for an unrecognised ``net_type``.
    """
    if net_type == "unet_3D":
        return unet_3D(n_classes=class_num, in_channels=in_chns).cuda()
    if net_type == "attention_unet":
        return Attention_UNet(n_classes=class_num, in_channels=in_chns).cuda()
    if net_type == "voxresnet":
        return VoxResNet(in_chns=in_chns, feature_chns=64,
                         class_num=class_num).cuda()
    if net_type == "vnet":
        return VNet(n_channels=in_chns, n_classes=class_num,
                    normalization='batchnorm', has_dropout=True).cuda()
    return None
def Inference(FLAGS):
    """Evaluate the best checkpoint of ``FLAGS.model`` on the test list.

    Loads ``../model/<exp>/<model>/<model>_best_model.pth`` into a 3D U-Net
    (2 classes, 1 input channel) and runs ``test_all_case`` over
    ``test.txt`` under ``FLAGS.root_path`` with 96^3 patches and stride 64.
    Predictions are written to a freshly (re)created
    ``../model/<exp>/<model>_Prediction`` directory.

    Args:
        FLAGS: parsed arguments providing ``exp``, ``model`` and
            ``root_path``.

    Returns:
        The value returned by ``test_all_case`` (the averaged metrics).
    """
    snapshot_path = "../model/{}/{}".format(FLAGS.exp, FLAGS.model)
    num_classes = 2
    # Keep predictions inside this experiment's directory. A fixed directory
    # name would ignore FLAGS.exp, and the rmtree below would then wipe
    # predictions belonging to an unrelated experiment.
    test_save_path = "../model/{}/{}_Prediction".format(FLAGS.exp, FLAGS.model)
    if os.path.exists(test_save_path):
        shutil.rmtree(test_save_path)
    os.makedirs(test_save_path)

    net = unet_3D(n_classes=num_classes, in_channels=1).cuda()
    save_mode_path = os.path.join(
        snapshot_path, '{}_best_model.pth'.format(FLAGS.model))
    net.load_state_dict(torch.load(save_mode_path))
    print("init weight from {}".format(save_mode_path))
    net.eval()

    avg_metric = test_all_case(net, base_dir=FLAGS.root_path, method=FLAGS.model,
                               test_list="test.txt", num_classes=num_classes,
                               patch_size=(96, 96, 96), stride_xy=64, stride_z=64,
                               test_save_path=test_save_path)
    return avg_metric
def train(args, snapshot_path):
    """Adversarially train a 3D U-Net with a discriminator on BraTS2019.

    Each iteration runs two steps:
      1. Segmentation step: supervised CE + Dice loss on the labeled part of
         the batch, plus an adversarial term. The adversarial term pushes the
         discriminator (``DAN``) to classify the unlabeled predictions as
         "labeled".
      2. Discriminator step: ``DAN`` learns to separate labeled (target 1)
         from unlabeled (target 0) predictions. In this step the segmentation
         outputs are computed without gradients.

    The model is validated on ``val.txt`` every 200 iterations, keeping the
    best checkpoint. It is also checkpointed every 3000 iterations.

    Args:
        args: parsed arguments (``base_lr``, ``root_path``, ``batch_size``,
            ``labeled_bs``, ``labeled_num``, ``max_iterations``, ``seed``,
            ``patch_size``, ``DAN_lr``, ``model``).
        snapshot_path: directory for TensorBoard logs and checkpoints.

    Returns:
        The string ``"Training Finished!"``.
    """
    num_classes = 2
    base_lr = args.base_lr
    train_data_path = args.root_path
    batch_size = args.batch_size
    max_iterations = args.max_iterations
    labeled_bs = args.labeled_bs
    unlabeled_bs = batch_size - labeled_bs

    model = unet_3D(n_classes=num_classes, in_channels=1).cuda()
    DAN = FC3DDiscriminator(num_classes=num_classes).cuda()

    db_train = BraTS2019(base_dir=train_data_path,
                         split='train',
                         num=None,
                         transform=transforms.Compose([
                             RandomRotFlip(),
                             RandomCrop(args.patch_size),
                             ToTensor(),
                         ]))

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    labeled_idxs = list(range(0, args.labeled_num))
    # NOTE(review): 250 is hard-coded as the training-split size; confirm it
    # matches len(db_train).
    unlabeled_idxs = list(range(args.labeled_num, 250))
    batch_sampler = TwoStreamBatchSampler(
        labeled_idxs, unlabeled_idxs, batch_size, unlabeled_bs)
    trainloader = DataLoader(db_train, batch_sampler=batch_sampler,
                             num_workers=4, pin_memory=True,
                             worker_init_fn=worker_init_fn)

    model.train()
    optimizer = optim.SGD(model.parameters(), lr=base_lr,
                          momentum=0.9, weight_decay=0.0001)
    DAN_optimizer = optim.Adam(
        DAN.parameters(), lr=args.DAN_lr, betas=(0.9, 0.99))
    ce_loss = CrossEntropyLoss()
    dice_loss = losses.DiceLoss(num_classes)

    # Discriminator targets, built once from the batch composition.
    # Assumes the sampler puts the labeled_bs labeled volumes first and the
    # unlabeled ones after them (the [labeled_bs:] slicing below relies on it).
    DAN_target = torch.tensor([1] * labeled_bs + [0] * unlabeled_bs,
                              dtype=torch.long).cuda()
    # In the segmentation step, unlabeled predictions should look "labeled".
    fool_target = torch.ones(unlabeled_bs, dtype=torch.long).cuda()

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    for _ in iterator:
        for sampled_batch in trainloader:
            volume_batch = sampled_batch['image'].cuda()
            label_batch = sampled_batch['label'].cuda()

            # ---- Segmentation step (discriminator in eval mode) ----
            model.train()
            DAN.eval()

            outputs = model(volume_batch)
            outputs_soft = torch.softmax(outputs, dim=1)

            # Supervised loss on the labeled part only. Using the whole batch
            # would train on the labels of the "unlabeled" volumes.
            loss_ce = ce_loss(outputs[:labeled_bs], label_batch[:labeled_bs])
            loss_dice = dice_loss(outputs_soft[:labeled_bs],
                                  label_batch[:labeled_bs].unsqueeze(1))
            supervised_loss = 0.5 * (loss_dice + loss_ce)

            consistency_weight = get_current_consistency_weight(iter_num // 150)
            DAN_outputs = DAN(outputs_soft[labeled_bs:],
                              volume_batch[labeled_bs:])
            consistency_loss = F.cross_entropy(DAN_outputs, fool_target)

            loss = supervised_loss + consistency_weight * consistency_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # ---- Discriminator step (segmentation net frozen) ----
            model.eval()
            DAN.train()
            with torch.no_grad():
                outputs = model(volume_batch)
                outputs_soft = torch.softmax(outputs, dim=1)
            # The DAN forward must run outside no_grad; otherwise DAN_loss
            # has no graph and backward() fails.
            DAN_outputs = DAN(outputs_soft, volume_batch)
            DAN_loss = F.cross_entropy(DAN_outputs, DAN_target)
            DAN_optimizer.zero_grad()
            DAN_loss.backward()
            DAN_optimizer.step()

            # Poly learning-rate decay for the segmentation optimizer only.
            lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

            iter_num = iter_num + 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/loss_ce', loss_ce, iter_num)
            writer.add_scalar('info/loss_dice', loss_dice, iter_num)
            writer.add_scalar('info/consistency_loss',
                              consistency_loss, iter_num)
            writer.add_scalar('info/consistency_weight',
                              consistency_weight, iter_num)

            logging.info(
                'iteration %d : loss : %f, loss_ce: %f, loss_dice: %f' %
                (iter_num, loss.item(), loss_ce.item(), loss_dice.item()))

            if iter_num % 20 == 0:
                # Log five axial slices (z = 20, 30, ..., 60) of sample 0.
                image = volume_batch[0, 0:1, :, :, 20:61:10].permute(
                    3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/Image', grid_image, iter_num)

                image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(
                    3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Predicted_label',
                                 grid_image, iter_num)

                image = label_batch[0, :, :, 20:61:10].unsqueeze(
                    0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Groundtruth_label',
                                 grid_image, iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                avg_metric = test_all_case(
                    model, args.root_path, test_list="val.txt",
                    num_classes=num_classes, patch_size=args.patch_size,
                    stride_xy=64, stride_z=64)
                if avg_metric[:, 0].mean() > best_performance:
                    best_performance = avg_metric[:, 0].mean()
                    save_mode_path = os.path.join(
                        snapshot_path, 'iter_{}_dice_{}.pth'.format(
                            iter_num, round(best_performance, 4)))
                    save_best = os.path.join(
                        snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best)

                writer.add_scalar('info/val_dice_score',
                                  avg_metric[0, 0], iter_num)
                writer.add_scalar('info/val_hd95',
                                  avg_metric[0, 1], iter_num)
                logging.info(
                    'iteration %d : dice_score : %f hd95 : %f' % (
                        iter_num, avg_metric[0, 0].mean(),
                        avg_metric[0, 1].mean()))
                model.train()

            if iter_num % 3000 == 0:
                save_mode_path = os.path.join(
                    snapshot_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()
    return "Training Finished!"