def main():
    # os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    assert torch.cuda.is_available(), "Currently, we only support the CUDA version"

    # setup seeds
    torch.manual_seed(args.seed)
    # torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # setup network
    Network = getattr(models, args.net)
    model = Network(**args.net_params)
    model = torch.nn.DataParallel(model).to(device)

    optimizer = getattr(torch.optim, args.opt)(model.parameters(), **args.opt_params)
    criterion = getattr(criterions, args.criterion)

    msg = ''
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_iter = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optim_dict'])
            msg = "=> loaded checkpoint '{}' (iter {})".format(args.resume, checkpoint['iter'])
        else:
            msg = "=> no checkpoint found at '{}'".format(args.resume)
    else:
        msg = '-------------- New training session ----------------'
    msg += '\n' + str(args)
    logging.info(msg)

    # Data loading code
    Dataset = getattr(datasets, args.dataset)
    if args.prefix_path:
        args.train_data_dir = os.path.join(args.prefix_path, args.train_data_dir)
    train_list = os.path.join(args.train_data_dir, args.train_list)
    train_set = Dataset(train_list, root=args.train_data_dir,
                        for_train=True, transforms=args.train_transforms)

    num_iters = args.num_iters or (len(train_set) * args.num_epochs) // args.batch_size
    num_iters -= args.start_iter
    train_sampler = CycleSampler(len(train_set), num_iters * args.batch_size)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              collate_fn=train_set.collate,
                              sampler=train_sampler,
                              num_workers=args.workers,
                              pin_memory=True,
                              worker_init_fn=init_fn)

    if args.valid_list:
        valid_list = os.path.join(args.train_data_dir, args.valid_list)
        valid_set = Dataset(valid_list, root=args.train_data_dir,
                            for_train=False, transforms=args.test_transforms)
        valid_loader = DataLoader(valid_set,
                                  batch_size=1,
                                  shuffle=False,
                                  collate_fn=valid_set.collate,
                                  num_workers=args.workers,
                                  pin_memory=True)

    start = time.time()
    enum_batches = len(train_set) / float(args.batch_size)  # number of batches per epoch

    args.schedule = {int(k * enum_batches): v for k, v in args.schedule.items()}  # 17100
    # args.save_freq = int(enum_batches * args.save_freq)
    # args.valid_freq = int(enum_batches * args.valid_freq)

    losses = AverageMeter()
    torch.set_grad_enabled(True)

    for i, data in enumerate(train_loader, args.start_iter):
        elapsed_bsize = int(i / enum_batches) + 1
        epoch = int((i + 1) / enum_batches)
        setproctitle.setproctitle("Epoch:{}/{}".format(elapsed_bsize, args.num_epochs))

        # actual training
        adjust_learning_rate(optimizer, epoch, args.num_epochs, args.opt_params.lr)

        # data = [t.cuda(non_blocking=True) for t in data]
        data = [t.to(device) for t in data]
        x, target = data[:2]

        output = model(x)

        if not args.weight_type:  # compatible with the old version
            args.weight_type = 'square'

        # loss = criterion(output, target, args.eps, args.weight_type)
        # loss = criterion(output, target, args.alpha, args.gamma)  # for focal loss
        loss = criterion(output, target, *args.kwargs)

        # measure accuracy and record loss
        losses.update(loss.item(), target.numel())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # save a checkpoint every save_freq epochs and over each of the last few epochs
        if (i + 1) % int(enum_batches * args.save_freq) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 1)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 2)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 3)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 4)) == 0:
            file_name = os.path.join(ckpts, 'model_epoch_{}.pth'.format(epoch))
            torch.save({
                'iter': i + 1,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict(),
            }, file_name)

        # validation, on the same schedule
        if (i + 1) % int(enum_batches * args.valid_freq) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 1)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 2)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 3)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 4)) == 0:
            logging.info('-' * 50)
            msg = 'Iter {}, Epoch {:.4f}, {}'.format(i, i / enum_batches, 'validation')
            logging.info(msg)
            with torch.no_grad():
                validate_softmax(valid_loader, model, cfg=args.cfg, savepath='',
                                 names=valid_set.names, scoring=True, verbose=False,
                                 use_TTA=False, snapshot=False, postprocess=False,
                                 cpu_only=False)

        msg = 'Iter {0:}, Epoch {1:.4f}, Loss {2:.7f}'.format(
            i + 1, (i + 1) / enum_batches, losses.avg)
        logging.info(msg)
        losses.reset()

    i = num_iters + args.start_iter
    file_name = os.path.join(ckpts, 'model_last.pth')
    torch.save({
        'iter': i,
        'state_dict': model.state_dict(),
        'optim_dict': optimizer.state_dict(),
    }, file_name)

    msg = 'total time: {:.4f} minutes'.format((time.time() - start) / 60)
    logging.info(msg)
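# `AverageMeter` above comes from the repo's utilities and is not shown in this
# snippet. A minimal sketch of the interface these loops rely on (update(val, n)
# accumulates a running weighted average; `avg` is read and `reset()` called at
# each log interval) -- an assumption, not necessarily the exact implementation:
class AverageMeter(object):
    """Computes and stores a running average of a scalar (e.g. the loss)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # last value seen
        self.sum = 0.0    # weighted sum of values
        self.count = 0    # total weight
        self.avg = 0.0    # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count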
def main():
    # setup environments and seeds
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    assert torch.cuda.is_available(), "Currently, we only support the CUDA version"
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # setup network
    Network = getattr(models, args.net)
    model = Network(**args.net_params)
    model = torch.nn.DataParallel(model).cuda()

    optimizer = getattr(torch.optim, args.opt)(model.parameters(), **args.opt_params)
    criterion = getattr(criterions, args.criterion)

    msg = ''
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_iter = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optim_dict'])
            msg = "=> loaded checkpoint '{}' (iter {})".format(args.resume, checkpoint['iter'])
        else:
            msg = "=> no checkpoint found at '{}'".format(args.resume)
    else:
        msg = '-------------- New training session ----------------'
    msg += '\n' + str(args)
    logging.info(msg)

    # Data loading code
    Dataset = getattr(datasets, args.dataset)

    train_list = os.path.join(args.train_data_dir, args.train_list)
    train_set = Dataset(train_list, root=args.train_data_dir,
                        for_train=True, transforms=args.train_transforms)

    num_iters = args.num_iters or (len(train_set) * args.num_epochs) // args.batch_size
    num_iters -= args.start_iter
    train_sampler = CycleSampler(len(train_set), num_iters * args.batch_size)
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.batch_size,
                              collate_fn=train_set.collate,
                              sampler=train_sampler,
                              num_workers=args.workers,
                              pin_memory=True,
                              worker_init_fn=init_fn)

    start = time.time()
    enum_batches = len(train_set) / float(args.batch_size)  # number of batches per epoch

    losses = AverageMeter()
    torch.set_grad_enabled(True)

    for i, data in enumerate(train_loader, args.start_iter):
        elapsed_bsize = int(i / enum_batches) + 1
        epoch = int((i + 1) / enum_batches)
        setproctitle.setproctitle("Epoch:{}/{}".format(elapsed_bsize, args.num_epochs))

        # actual training
        adjust_learning_rate(optimizer, epoch, args.num_epochs, args.opt_params.lr)

        data = [t.cuda(non_blocking=True) for t in data]
        x, target = data[:2]

        output = model(x)

        if not args.weight_type:  # compatible with the old version
            args.weight_type = 'square'

        if args.criterion_kwargs is not None:
            loss = criterion(output, target, **args.criterion_kwargs)
        else:
            loss = criterion(output, target)

        # measure accuracy and record loss
        losses.update(loss.item(), target.numel())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # save a checkpoint every save_freq epochs and over each of the last few epochs
        if (i + 1) % int(enum_batches * args.save_freq) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 1)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 2)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 3)) == 0 \
                or (i + 1) % int(enum_batches * (args.num_epochs - 4)) == 0:
            file_name = os.path.join(ckpts, 'model_epoch_{}.pth'.format(epoch))
            torch.save({
                'iter': i + 1,  # next iteration, so resume does not repeat this one
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict(),
            }, file_name)

        msg = 'Iter {0:}, Epoch {1:.4f}, Loss {2:.7f}'.format(
            i + 1, (i + 1) / enum_batches, losses.avg)
        logging.info(msg)
        losses.reset()

    i = num_iters + args.start_iter
    file_name = os.path.join(ckpts, 'model_last.pth')
    torch.save({
        'iter': i,
        'state_dict': model.state_dict(),
        'optim_dict': optimizer.state_dict(),
    }, file_name)

    msg = 'total time: {:.4f} minutes'.format((time.time() - start) / 60)
    logging.info(msg)
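# `CycleSampler` is imported from the repo's data utilities. A sketch of the
# behaviour the loops above depend on -- yielding exactly `num_samples` indices
# by cycling through fresh random permutations of the dataset, so a single
# DataLoader pass covers the whole training run -- assuming the standard torch
# Sampler interface (a hypothetical reimplementation, not the repo's code):
import numpy as np
from torch.utils.data.sampler import Sampler

class CycleSampler(Sampler):
    def __init__(self, size, num_samples):
        self.size = size                # dataset length
        self.num_samples = num_samples  # total indices to yield

    def __iter__(self):
        indices = []
        while len(indices) < self.num_samples:
            # a fresh shuffle for each pass over the dataset
            indices.extend(np.random.permutation(self.size).tolist())
        return iter(indices[:self.num_samples])

    def __len__(self):
        return self.num_samples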
def main():
    # setup environments and seeds
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # setup networks
    Network = getattr(models, args.net)
    model = Network(**args.net_params)
    model = model.cuda()

    optimizer = getattr(torch.optim, args.opt)(model.parameters(), **args.opt_params)
    criterion = getattr(criterions, args.criterion)

    msg = ''
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_iter = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optim_dict'])
            msg = "=> loaded checkpoint '{}' (iter {})".format(args.resume, checkpoint['iter'])
        else:
            msg = "=> no checkpoint found at '{}'".format(args.resume)
    else:
        msg = '-------------- New training session ----------------'
    msg += '\n' + str(args)
    logging.info(msg)

    # Data loading code
    Dataset = getattr(datasets, args.dataset)

    # The loader draws 1000 patches from 50 subjects for each sub-epoch,
    # sampling 20 patches per subject.
    train_list = os.path.join(args.data_dir, args.train_list)
    train_set = Dataset(train_list, root=args.data_dir, for_train=True,
                        num_patches=args.num_patches,
                        transforms=args.train_transforms,
                        sample_size=args.sample_size,
                        sub_sample_size=args.sub_sample_size,
                        target_size=args.target_size)

    num_iters = args.num_iters or (len(train_set) * args.num_epochs) // args.batch_size
    num_iters -= args.start_iter
    train_sampler = CycleSampler(len(train_set), num_iters * args.batch_size)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              collate_fn=train_set.collate,
                              sampler=train_sampler,
                              num_workers=args.workers,
                              pin_memory=True,
                              worker_init_fn=init_fn)

    if args.valid_list:
        valid_list = os.path.join(args.data_dir, args.valid_list)
        valid_set = Dataset(valid_list, root=args.data_dir, for_train=False,
                            crop=False,
                            transforms=args.test_transforms,
                            sample_size=args.sample_size,
                            sub_sample_size=args.sub_sample_size,
                            target_size=args.target_size)
        valid_loader = DataLoader(valid_set, batch_size=1, shuffle=False,
                                  collate_fn=valid_set.collate,
                                  num_workers=4, pin_memory=True)

    train_valid_set = Dataset(train_list, root=args.data_dir, for_train=False,
                              crop=False,
                              transforms=args.test_transforms,
                              sample_size=args.sample_size,
                              sub_sample_size=args.sub_sample_size,
                              target_size=args.target_size)
    train_valid_loader = DataLoader(train_valid_set, batch_size=1, shuffle=False,
                                    collate_fn=train_valid_set.collate,
                                    num_workers=4, pin_memory=True)

    start = time.time()
    enum_batches = len(train_set) / float(args.batch_size)

    args.schedule = {int(k * enum_batches): v for k, v in args.schedule.items()}
    args.save_freq = int(enum_batches * args.save_freq)
    args.valid_freq = int(enum_batches * args.valid_freq)

    losses = AverageMeter()
    torch.set_grad_enabled(True)

    for i, (data, label) in enumerate(train_loader, args.start_iter):
        ## validation
        # if args.valid_list and (i % args.valid_freq) == 0:
        #     logging.info('-' * 50)
        #     msg = 'Iter {}, Epoch {:.4f}, {}'.format(i, i / enum_batches, 'validation')
        #     logging.info(msg)
        #     with torch.no_grad():
        #         validate(valid_loader, model, batch_size=args.mini_batch_size,
        #                  names=valid_set.names)

        # actual training
        adjust_learning_rate(optimizer, i)

        # split the full batch into mini-batches that fit in GPU memory
        for data in zip(*[d.split(args.mini_batch_size) for d in data]):
            data = [t.cuda(non_blocking=True) for t in data]
            x1, x2, target = data[:3]
            if len(data) > 3:  # has mask
                m1, m2 = data[3:]
                x1 = add_mask(x1, m1, 1)
                x2 = add_mask(x2, m2, 1)

            # compute output: output is n x 5 x 9 x 9 x 9, target is n x 9 x 9 x 9
            output = model((x1, x2))
            loss = criterion(output, target, args.alpha)

            # measure accuracy and record loss
            losses.update(loss.item(), target.numel())

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (i + 1) % args.save_freq == 0:
            epoch = int((i + 1) // enum_batches)
            file_name = os.path.join(ckpts, 'model_epoch_{}.tar'.format(epoch))
            torch.save({
                'iter': i + 1,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict(),
            }, file_name)

        msg = 'Iter {0:}, Epoch {1:.4f}, Loss {2:.4f}'.format(
            i + 1, (i + 1) / enum_batches, losses.avg)
        logging.info(msg)
        losses.reset()

    i = num_iters + args.start_iter
    file_name = os.path.join(ckpts, 'model_last.tar')
    torch.save({
        'iter': i,
        'state_dict': model.state_dict(),
        'optim_dict': optimizer.state_dict(),
    }, file_name)

    if args.valid_list:
        logging.info('-' * 50)
        msg = 'Iter {}, Epoch {:.4f}, {}'.format(i, i / enum_batches,
                                                 'validate validation data')
        logging.info(msg)
        with torch.no_grad():
            validate(valid_loader, model, batch_size=args.mini_batch_size,
                     names=valid_set.names, out_dir=args.out)

        # logging.info('-' * 50)
        # msg = 'Iter {}, Epoch {:.4f}, {}'.format(i, i / enum_batches,
        #                                          'validate training data')
        # logging.info(msg)
        # with torch.no_grad():
        #     validate(train_valid_loader, model, batch_size=args.mini_batch_size,
        #              names=train_valid_set.names, verbose=False)

    msg = 'total time: {:.4f} minutes'.format((time.time() - start) / 60)
    logging.info(msg)
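# `adjust_learning_rate` is defined elsewhere in the repo, and is called with
# two different signatures above: (optimizer, epoch, num_epochs, init_lr) in
# the first script and (optimizer, i) here. A sketch of the epoch-based
# variant, assuming the common polynomial ("poly") decay with power 0.9; the
# iteration-based variant presumably reads the `args.schedule` map built above
# instead:
import numpy as np

def adjust_learning_rate(optimizer, epoch, num_epochs, init_lr):
    # poly decay: lr = init_lr * (1 - epoch / num_epochs) ** 0.9
    lr = round(init_lr * np.power(1 - epoch / float(num_epochs), 0.9), 8)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr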
def main():
    h, w, z = map(int, args.input_size.split(','))
    input_size = (h, w, z)
    cudnn.enabled = True

    # make logger file
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
    # if os.path.exists(snapshot_path + '/code'):
    #     shutil.rmtree(snapshot_path + '/code')
    # shutil.copytree('.', snapshot_path + '/code',
    #                 shutil.ignore_patterns(['.git', '__pycache__']))

    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s',
                        datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    # setup environments and seeds
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # setup networks: a student model and an EMA teacher from the same definition
    # (disabled: an nnU-Net Generic_UNet setup -- load_plans_file() followed by
    # Generic_UNet(num_input_channels, base_num_features, num_classes, net_numpool, ...)
    # with softmax_helper and cudnn.benchmark = True)
    def create_model(ema=False):
        # Network definition
        Network = getattr(models, args.net)
        net = Network(**args.net_params)
        model = net.cuda()
        if ema:
            # the teacher is updated by EMA only, never by gradients
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
    ema_model = create_model(ema=True)

    # optimizer for the segmentation network
    optimizer = getattr(torch.optim, args.opt)(model.parameters(), **args.opt_params)
    optimizer.zero_grad()
    # disabled: torch.optim.Adam(model.parameters(), 3e-4, weight_decay=3e-5, amsgrad=True)

    ce_loss = getattr(loss_fun, args.criterion)

    # critic / discriminator network and its optimizer
    model_C = NetC(ngpu=1)
    model_C.train()
    model_C.cuda()
    # model_C.apply(inplace_relu)
    # optimizer_D = optim.Adam(model_D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = getattr(torch.optim, args.opt_D)(model_C.parameters(),
                                                   **args.opt_params_D)
    # disabled: a s4GAN_discriminator(num_classes=args.num_classes) as D

    msg = ''
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_iter = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optim_dict'])
            msg = "=> loaded checkpoint '{}' (iter {})".format(args.resume,
                                                               checkpoint['iter'])
        else:
            msg = "=> no checkpoint found at '{}'".format(args.resume)
    else:
        msg = '-------------- New training session ----------------'
    msg += '\n' + str(args)
    logging.info(msg)

    # Data loading code
    Dataset = getattr(datasets, args.dataset)

    all_train_list = os.path.join(args.data_dir, args.all_train_list)
    lab_train_list = os.path.join(args.data_dir, args.lab_train_list)
    unlab_train_list = os.path.join(args.data_dir, args.unlab_train_list)

    all_train_set = Dataset(all_train_list, root=args.data_dir, for_train=True,
                            transforms=args.train_transforms)
    lab_train_set = Dataset(lab_train_list, root=args.data_dir, for_train=True,
                            transforms=args.train_transforms)
    unlab_train_set = Dataset(unlab_train_list, root=args.data_dir, for_train=True,
                              transforms=args.train_transforms)

    num_iters = args.num_iters or (len(all_train_set) * args.num_epochs) // args.batch_size
    num_iters -= args.start_iter

    all_train_dataset_size = len(all_train_set)
    print('train dataset size: ', all_train_dataset_size)

    lab_train_dataset = lab_train_set
    unlab_train_dataset = unlab_train_set
    lab_train_dataset_size = len(lab_train_set)
    unlab_train_dataset_size = len(unlab_train_set)

    # cycle over the labeled and unlabeled subsets for the whole training run;
    # disabled: a fixed random split via sampler.SubsetRandomSampler over
    # pickled id lists (train_split.pkl / train_lab.pkl / train_unlab.pkl)
    train_sampler = CycleSampler(len(lab_train_set), num_iters * args.batch_size)
    train_remain_sampler = CycleSampler(len(unlab_train_set), num_iters * args.batch_size)
    train_gt_sampler = CycleSampler(len(lab_train_set), num_iters * args.batch_size)
    trainloader = DataLoader(lab_train_dataset, batch_size=args.batch_size,
                             sampler=train_sampler, num_workers=4, pin_memory=True)
    trainloader_remain = DataLoader(unlab_train_dataset, batch_size=args.batch_size,
                                    sampler=train_remain_sampler, num_workers=4,
                                    pin_memory=True)
    trainloader_gt = DataLoader(lab_train_dataset, batch_size=args.batch_size,
                                sampler=train_gt_sampler, num_workers=4,
                                pin_memory=True)

    trainloader_remain_iter = iter(trainloader_remain)
    trainloader_iter = iter(trainloader)
    trainloader_gt_iter = iter(trainloader_gt)

    if args.valid_list:
        valid_list = os.path.join(args.data_dir, args.valid_list)
        valid_set = Dataset(valid_list, root=args.data_dir, for_train=False,
                            transforms=args.test_transforms)
        valid_loader = DataLoader(valid_set, batch_size=1, shuffle=False,
                                  collate_fn=valid_set.collate,
                                  num_workers=4, pin_memory=True)

    # trilinear upsampling back to the input resolution
    interp = nn.Upsample(size=(input_size[0], input_size[1], input_size[2]),
                         mode='trilinear', align_corners=True)

    # labels for adversarial training
    pred_label = 0
    gt_label = 1
    y_real_ = Variable(torch.ones(args.batch_size, 1).cuda())
    y_fake_ = Variable(torch.zeros(args.batch_size, 1).cuda())
    Tanh = nn.Tanh()

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    start = time.time()
    enum_batches = len(all_train_set) / float(args.batch_size)
    # enum_batches = len(all_train_set) / float(args.batch_size) + 1

    args.schedule = {int(k * enum_batches): v for k, v in args.schedule.items()}
    args.save_freq = int(enum_batches * args.save_freq)
    args.valid_freq = int(enum_batches * args.valid_freq)
    args.max_iterations = num_iters + args.start_iter
    print('number of max iterations: ', args.max_iterations)

    iter_num = 0
    max_epoch = args.max_iterations // len(trainloader) + 1
    print('number of max epoch: ', max_epoch)

    losses = AverageMeter()
    torch.set_grad_enabled(True)

    batch_dice = False
    initial_lr = args.learning_rate
    lr_ = initial_lr

    max_par = 0.0
    max_score_ema = 0.0
    max_score = 0.0
    best_epoch = 0
    best_epoch_ema = 0

    for i_iter in range(args.max_iterations):
        loss_seg_value = 0
        loss_G_value = 0
        loss_consistency_value = 0
        loss_adv_value = 0
        loss_sdf_value = 0
        loss_D_value = 0
        ema_decay = 0.99
        consistency = 0.1
        consistency_rampup = 40.0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # -------- supervised pass on a labeled batch --------
        try:
            batch = next(trainloader_iter)
        except StopIteration:
            trainloader_iter = iter(trainloader)
            batch = next(trainloader_iter)

        images, labels = batch
        images = Variable(images).cuda(non_blocking=True)
        labels = Variable(labels).cuda(non_blocking=True)
        images = torch.squeeze(images, 1)
        labels = torch.squeeze(labels, 1)
        # print('images size: ', images.shape)
        # print('labels dataset size: ', labels.shape)
        ignore_mask = (labels.cpu().numpy() == 255)
        ignore_mask_D = ignore_mask
        ignore_mask_gt = (labels.cpu().numpy() == 255)

        # student forward pass: a Tanh signed-distance head and a segmentation head
        outputs_tanh_1, pred = model(images)

        # signed-distance supervision: regress the SDM of the ground truth
        with torch.no_grad():
            gt_dis = compute_sdf(labels.cpu().numpy(), pred.shape)
            gt_dis = torch.from_numpy(gt_dis).float().cuda()
        loss_sdf = F.mse_loss(outputs_tanh_1, gt_dis)

        loss_seg = ce_loss(pred, labels)
        loss = loss_seg
        losses.update(loss.item(), labels.numel())
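        # `compute_sdf` is imported from elsewhere; it converts the binary
        # ground truth into a signed distance map for the Tanh head to
        # regress. A sketch of the usual construction (an assumption --
        # negative inside the object, positive outside, normalized to
        # [-1, 1] per volume), using scipy:
        #
        #   from scipy.ndimage import distance_transform_edt as edt
        #
        #   def compute_sdf(seg, out_shape):
        #       sdf = np.zeros(out_shape)
        #       for b in range(out_shape[0]):      # batch dimension
        #           pos = seg[b].astype(bool)
        #           if pos.any():
        #               pos_dis = edt(pos)         # distance inside the object
        #               neg_dis = edt(~pos)        # distance outside the object
        #               sdf[b] = neg_dis / neg_dis.max() - pos_dis / pos_dis.max()
        #       return sdf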
        # mask the FLAIR channel with the predicted foreground of class 1;
        # classes 2 and 3 are handled analogously when enabled (disabled here)
        pred_gt = F.softmax(pred, dim=1)
        input_mask = images.clone()
        pred_fore_1_gt = pred_gt[:, 1]

        # keep only label 1 as foreground in the ground truth
        labels_n_gt = labels.cpu().numpy()
        labels_n_gt_1 = labels_n_gt  # a view of the same array, modified in place
        labels_n_gt_1[labels_n_gt_1 == 2] = 0
        labels_n_gt_1[labels_n_gt_1 == 3] = 0
        labels_fore_gt_1 = torch.tensor(labels_n_gt_1).cuda().float()

        output_masked = torch.zeros([args.batch_size, 1, 128, 128, 128]).cuda()
        output_masked[:, 0, :] = input_mask[:, 0, :, :, :] * pred_fore_1_gt
        output_masked_lab = output_masked

        # the critic compares (SDM prediction, prediction-masked image)
        # against (one-hot ground truth, GT-masked image)
        result, _ = model_C(outputs_tanh_1, output_masked)

        target_masked = torch.zeros([args.batch_size, 1, 128, 128, 128]).cuda()
        target_masked[:, 0, :] = input_mask[:, 0, :, :, :] * labels_fore_gt_1

        labels_tanh = Variable(one_hot(labels.cpu())).cuda(non_blocking=True)
        target_G, _ = model_C(labels_tanh, target_masked)
        loss_G = torch.mean(torch.abs(result - target_G))

        # -------- unsupervised pass on an unlabeled batch --------
        try:
            batch_remain = next(trainloader_remain_iter)
        except StopIteration:
            trainloader_remain_iter = iter(trainloader_remain)
            batch_remain = next(trainloader_remain_iter)

        images_remain, _ = batch_remain
        images_remain = Variable(images_remain).cuda(non_blocking=True)
        images_remain = torch.squeeze(images_remain, 1)

        # discriminator targets: labeled -> 1, unlabeled -> 0
        Dtarget = torch.tensor([1, 0]).cuda()
        model.train()
        model_C.eval()

        # perturbed input for the EMA teacher
        noise = torch.clamp(torch.randn_like(images_remain) * 0.1, -0.2, 0.2)
        ema_inputs = images_remain + noise
        with torch.no_grad():
            outputs_tanh, pred_output = model(images_remain)
        with torch.no_grad():
            ema_outputs_tanh, ema_output = ema_model(ema_inputs)

        pred_output = F.softmax(pred_output, dim=1)
        ema_output = F.softmax(ema_output, dim=1)
        pred_fore_1 = pred_output[:, 1]
        ema_pred_fore_1 = ema_output[:, 1]

        # disabled: an uncertainty mask over the consistency term (T=8 noisy
        # EMA forward passes, mean softmax, voxelwise entropy, threshold
        # ramped with sigmoid_rampup)

        # ramp-up weights for the consistency and adversarial terms
        epo = int((i_iter + 1) // enum_batches)
        con_weight = get_current_consistency_con_weight(epo)
        adv_weight = get_current_consistency_adv_weight(i_iter // 150)

        image_flair = images_remain[:, 0:1]
        input_mask = image_flair.clone()
        output_masked = torch.zeros([args.batch_size, 1, 128, 128, 128]).cuda()
        output_masked[:, 0, :] = input_mask[:, 0, :, :, :] * pred_fore_1
        output_masked_unlab = output_masked

        result_remain, Doutputs = model_C(outputs_tanh, output_masked)

        target_masked = torch.zeros([args.batch_size, 1, 128, 128, 128]).cuda()
        target_masked[:, 0, :] = input_mask[:, 0, :, :, :] * ema_pred_fore_1
        target_C, ema_Doutputs = model_C(ema_outputs_tanh, target_masked)

        # teacher-student consistency, measured in the critic's output space
        consistency_dist = torch.mean(torch.abs(result_remain - target_C))
        # threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(i_iter, args.max_iterations)) * np.log(2)
        # mask = (uncertainty < threshold).float()
        # consistency_dist = torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)
        consistency_loss = con_weight * consistency_dist

        print('consistency_loss', consistency_loss.requires_grad)
        print('loss_G', loss_G.requires_grad)
        print('loss_seg', loss_seg.requires_grad)

        # adversarial loss: push the unlabeled prediction towards the "labeled" class
        loss_adv = F.cross_entropy(Doutputs, Dtarget[:1].long())

        loss = loss_seg + loss_G + con_weight * consistency_dist \
            + args.sdm_weight * loss_sdf + adv_weight * loss_adv
        loss.backward()

        loss_seg_value += loss_seg.data.cpu().numpy()
        loss_G_value += loss_G.data.cpu().numpy()
        loss_consistency_value += consistency_loss.cpu().detach().numpy()
        loss_adv_value += loss_adv.cpu().detach().numpy()
        loss_sdf_value += loss_sdf.cpu().detach().numpy()

        optimizer.step()
        update_ema_variables(model, ema_model, ema_decay, i_iter)
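        # `update_ema_variables` applies the mean-teacher EMA update to the
        # detached teacher. A sketch of the standard form (assumed -- the
        # repo's helper may differ), with the decay warmed up from 0 over
        # the first iterations:
        #
        #   def update_ema_variables(model, ema_model, alpha, global_step):
        #       alpha = min(1 - 1 / (global_step + 1), alpha)
        #       for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        #           ema_p.data.mul_(alpha).add_(p.data, alpha=1 - alpha)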
        # -------- train the critic D --------
        model.eval()
        model_C.train()
        with torch.no_grad():
            outputs_tanh_3, outputs_3 = model(images)
            outputs_tanh_4, outputs_4 = model(images_remain)

        output_masked_lab = output_masked_lab.detach()
        output_masked_unlab = output_masked_unlab.detach()
        outputs_tanh = torch.cat((outputs_tanh_3, outputs_tanh_4), 0)
        volume_batch = torch.cat((output_masked_lab, output_masked_unlab), 0)

        # D should classify labeled vs. unlabeled predictions correctly
        result_D, Doutputs = model_C(outputs_tanh, volume_batch)
        loss_D_2 = F.cross_entropy(Doutputs, Dtarget.long())

        input_mask = images.clone()
        labels_gt = labels.clone()
        labels_n_gt_1 = labels_gt.cpu().numpy()
        labels_n_gt_1[labels_n_gt_1 == 2] = 0
        labels_n_gt_1[labels_n_gt_1 == 3] = 0
        labels_fore_gt_1 = torch.tensor(labels_n_gt_1).cuda().float()

        target_masked = torch.zeros([args.batch_size, 1, 128, 128, 128]).cuda()
        target_masked[:, 0, :] = input_mask[:, 0, :, :, :] * labels_fore_gt_1
        target_masked_lab = target_masked

        D_gt_v_gt = Variable(one_hot(labels_gt.cpu())).cuda(non_blocking=True)
        target_D, _ = model_C(D_gt_v_gt, target_masked)
        loss_D_dis = -torch.mean(torch.abs(result_D[:1] - target_D))
        loss_D = loss_D_2 + loss_D_dis
        # Dacc/Dtp/Dfn would be unreliable here because the number of samples is small

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        lr_S = optimizer.param_groups[0]['lr']
        lr_D = optimizer_D.param_groups[0]['lr']
        writer.add_scalar('lr/lr_S', lr_S, i_iter)
        writer.add_scalar('loss/loss_seg', losses.avg, i_iter)
        writer.add_scalar('loss/consistency_loss', consistency_loss, i_iter)
        writer.add_scalar('loss/loss_sdf', loss_sdf, i_iter)
        writer.add_scalar('loss/loss_adv', loss_adv, i_iter)
        writer.add_scalar('loss/loss_G', loss_G, i_iter)
        writer.add_scalar('train/consistency_weight', con_weight, i_iter)
        writer.add_scalar('train/adv_weight', adv_weight, i_iter)
        writer.add_scalar('train/consistency_dist', consistency_dist, i_iter)

        msg = ('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_G = {3:.3f}, '
               'consistency_loss = {4:.6f}, con_weight = {5:.6f}, '
               'consistency_dist = {6:.6f}, loss_sdf = {7:.6f}, '
               'loss_adv = {8:.6f}, adv_weight = {9:.6f}, '
               'sdm_weight = {10:.6f}').format(
                   i_iter, args.max_iterations, losses.avg, loss_G_value,
                   loss_consistency_value, con_weight, consistency_dist,
                   loss_sdf_value, loss_adv_value, adv_weight, args.sdm_weight)
        logging.info(msg)

        if (i_iter + 1) % args.save_freq == 0:
            epoch = int((i_iter + 1) // enum_batches)
            file_name = os.path.join(ckpts, 'model_epoch_{}.tar'.format(epoch))
            torch.save({
                'iter': i_iter + 1,
                'state_dict': model.state_dict(),
                'ema_state_dict': ema_model.state_dict(),
                'optim_dict': optimizer.state_dict(),
            }, file_name)

    i = num_iters + args.start_iter
    file_name = os.path.join(ckpts, 'model_last.tar')
    torch.save({
        'iter': i,
        'state_dict': model.state_dict(),
        'ema_state_dict': ema_model.state_dict(),
        'optim_dict': optimizer.state_dict(),
    }, file_name)
    # select the best checkpoint on the validation set, for both the student
    # and the EMA teacher, by sweeping the saved epochs
    if args.valid_list:
        logging.info('-' * 50)
        msg = 'validate validation data with the model'
        logging.info(msg)
        ckpts_dir = args.getdir()
        with torch.no_grad():
            for val_epoch in range(50, 601, 50):
                ckpt = 'model_epoch_{}.tar'.format(val_epoch)
                dice_avg_score = validate(valid_loader, model, ckpt, ckpts_dir,
                                          names=valid_set.names, out_dir=args.out)
                if dice_avg_score >= max_score:
                    best_epoch = val_epoch
                    max_score = dice_avg_score

        logging.info('MAX-------------------')
        msg = 'Epoch_MAX = {0}, WT_MAX = {1}'.format(best_epoch, max_score)
        logging.info(msg)

        msg = 'validate validation data with the EMA model'
        logging.info(msg)
        with torch.no_grad():
            for val_epoch in range(50, 601, 50):
                ckpt = 'model_epoch_{}.tar'.format(val_epoch)
                dice_avg = validate_ema(valid_loader, ema_model, ckpt, ckpts_dir,
                                        names=valid_set.names, out_dir=args.out)
                if dice_avg >= max_score_ema:
                    best_epoch_ema = val_epoch
                    max_score_ema = dice_avg

        logging.info('MAX------------------------ ')
        msg = 'Epoch_MAX_ema = {0}, Score_Max_ema = {1}'.format(best_epoch_ema,
                                                                max_score_ema)
        logging.info(msg)

    if max_score >= max_score_ema:
        best_epoch_all = best_epoch
        best_score_all = max_score
        ema_true = 0
    else:
        best_epoch_all = best_epoch_ema
        best_score_all = max_score_ema
        ema_true = 1
    msg = 'Best_Epoch_ALL= {0}, Best_Score_ALL= {1}'.format(best_epoch_all,
                                                            best_score_all)
    logging.info(msg)

    msg = 'total time: {:.4f} minutes'.format((time.time() - start) / 60)
    logging.info(msg)

    # final full validation with whichever model won
    if ema_true == 0 and args.valid_list:
        logging.info('-' * 50)
        msg = 'validate validation data with the model'
        logging.info(msg)
        with torch.no_grad():
            ckpt = 'model_epoch_{}.tar'.format(best_epoch_all)
            ckpts_dir = args.getdir()
            validate_all(valid_loader, model, ckpt, ckpts_dir,
                         names=valid_set.names, out_dir=args.out_dir)

    if ema_true == 1 and args.valid_list:
        logging.info('-' * 50)
        msg = 'validate validation data with the ema model'
        logging.info(msg)
        with torch.no_grad():
            ckpt = 'model_epoch_{}.tar'.format(best_epoch_all)
            ckpts_dir = args.getdir()
            validate_ema_all(valid_loader, ema_model, ckpt, ckpts_dir,
                             names=valid_set.names, out_dir=args.out_dir)

    msg = 'total time: {:.4f} minutes'.format((time.time() - start) / 60)
    logging.info(msg)
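# The ramp-up helpers used in the loop above are imported from elsewhere. A
# sketch of the standard mean-teacher sigmoid ramp-up (an assumed form; the
# defaults mirror the `consistency` and `consistency_rampup` constants set
# inside the loop, and `get_current_consistency_adv_weight` presumably follows
# the same shape):
import numpy as np

def sigmoid_rampup(current, rampup_length):
    """Exponential ramp-up from 0 to 1 over `rampup_length` steps."""
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

def get_current_consistency_con_weight(epoch, consistency=0.1,
                                       consistency_rampup=40.0):
    # the consistency weight ramps from ~0 to `consistency`
    # over `consistency_rampup` epochs
    return consistency * sigmoid_rampup(epoch, consistency_rampup)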