def evaluate(val_loader, model, criterion, epoch, gt_labels, threshold=0.6):
    """Run one full validation pass and report loss/dice metrics.

    Args:
        val_loader: DataLoader yielding (images_pv, images_art, labels) batches.
        model: two-input segmentation network; moved to CUDA and set to eval mode here.
        criterion: loss function applied to raw logits vs. labels.
        epoch: current epoch index (display only, for the progress bar).
        gt_labels: ground-truth label tensor covering the whole validation set,
            compared against the concatenated per-batch predictions.
        threshold: probability cutoff used to binarize foreground predictions
            before computing the global dice (default 0.6, the original
            hard-coded value).

    Returns:
        [global_dice, val_epoch_loss, val_epoch_dice] — dice over the whole
        set after thresholding, plus per-batch-averaged loss and soft dice.

    NOTE(review): relies on module-level globals `param_set`, `USE_GPU`,
    `ProgressBar`, `soft_dice`, and `F` — confirm they are defined at import.
    """
    val_progressor = ProgressBar(mode='Val', epoch=epoch,
                                 total_epoch=param_set['epoch'],
                                 model_name=param_set['model'],
                                 total=len(val_loader))
    # Ensure the model is on GPU and in eval mode (disables dropout/BN updates).
    model.cuda()
    model.eval()
    with torch.no_grad():
        val_iter_loss = 0.0
        val_iter_dice = 0.0
        pred_labels = []  # per-batch foreground probability maps (CPU)
        for step, (images_pv, images_art, labels) in enumerate(val_loader):
            val_progressor.current = step
            if USE_GPU:
                images_pv = images_pv.cuda()
                images_art = images_art.cuda()
                labels = labels.cuda()
            outputs = model(images_pv, images_art)
            val_loss = criterion(outputs, labels)
            # Softmax computed once and reused (the original computed the
            # identical tensor twice as `prob` and `val_out_soft`).
            prob = F.softmax(outputs, dim=1).data.cpu()
            # Channel 1 is treated as the foreground class.
            val_dice = soft_dice(prob[:, 1], labels.cpu().float())
            val_iter_loss += float(val_loss.item())
            val_iter_dice += val_dice.cpu()
            pred_labels.append(prob[:, 1])
            val_progressor.current_loss = val_loss.item()
            val_progressor.current_dice = val_dice.cpu()
            val_progressor()
            # Drop per-batch GPU tensors eagerly to keep memory flat.
            del images_pv, images_art, labels, outputs, val_loss, val_dice
            torch.cuda.empty_cache()
        val_progressor.done()
        # Global dice: binarize the concatenated probabilities at `threshold`
        # and compare against the full ground-truth volume.
        arr_pred_labels = torch.cat(pred_labels, dim=0)
        arr_pred_labels[arr_pred_labels < threshold] = 0
        arr_pred_labels[arr_pred_labels >= threshold] = 1
        global_dice = soft_dice(arr_pred_labels, gt_labels)
        val_epoch_loss = val_iter_loss / len(val_loader)
        val_epoch_dice = val_iter_dice / len(val_loader)
        del pred_labels, arr_pred_labels
    return [global_dice, val_epoch_loss, val_epoch_dice]
def train(param_set, model):
    """Train `model` with Adam + cross-entropy, validating after every epoch.

    Creates a timestamped result directory tree (checkpoint/log/testResult),
    logs per-step and per-epoch metrics to TensorBoard, appends validation
    dice to a text log, and saves checkpoints when validation dice >= 0.5.

    Args:
        param_set: dict of run settings; keys read here: 'result_dir',
            'imgdir', 'batch_size', 'epoch', 'epoch_save', 'model'.
        model: two-input segmentation network (portal-venous + arterial phase).

    NOTE(review): relies on module-level globals `num_workers`, `USE_GPU`,
    `ProgressBar`, `soft_dice`, `get_gt_labels`, `save_checkpoint_new`,
    `Multi_PNGDataset`, `Multi_PNGDataset_val` — confirm they are in scope.
    """
    # Result folders are keyed by wall-clock hour, so two runs within the
    # same hour share (and below, clear) the same log directory.
    folder = datetime.datetime.now().strftime('%Y-%m-%d-%H')
    save_dir = param_set['result_dir'] + folder
    ckpt_dir = save_dir + '/checkpoint'
    log_dir = save_dir + '/log'
    test_result_dir = save_dir + '/testResult'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
        os.mkdir(ckpt_dir)
        os.mkdir(test_result_dir)
    # Clear stale event files so TensorBoard shows only this run.
    for file in os.listdir(log_dir):
        print('removing ' + os.path.join(log_dir, file))
        os.remove(os.path.join(log_dir, file))
    # Fixed cross-validation split: batch5 held out for validation.
    test_batch = ['batch5']
    val_loader = DataLoader(Multi_PNGDataset_val(param_set['imgdir'], test_batch),
                            num_workers=num_workers,
                            batch_size=param_set['batch_size'],
                            shuffle=False)
    gt_labels = get_gt_labels(
        os.path.join(param_set['imgdir'], test_batch[0], 'mask/'))  # when evaluate
    '''c_weight = torch.ones(NUM_CLASSES)
    if USE_GPU:
        criterion = CrossEntropyLoss2d(c_weight.cuda())  # define the criterion'''
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=5e-4)  # define the optimizer
    # lr_schedule = lr_scheduler.MultiStepLR(optimizer, milestones=[50, 200], gamma=0.1)  # lr decay
    # lr_schedule = lr_scheduler.MultiStepLR(optimizer, milestones=[5, 8], gamma=0.1)  # for debug
    epoch_save = param_set['epoch_save']
    num_epochs = param_set['epoch']
    cvBatch = ['batch4', 'batch3', 'batch2', 'batch1']  # training folds
    writer = SummaryWriter(log_dir)
    iter_count = 0
    #writer.add_graph(model, )
    # loader = DataLoader(PNGDataset(param_set['imgdir'],cvBatch),num_workers=num_workers, batch_size=param_set['batch_size'], shuffle=True)
    loader = DataLoader(Multi_PNGDataset(param_set['imgdir'], cvBatch),
                        num_workers=num_workers,
                        batch_size=param_set['batch_size'],
                        shuffle=True)
    print('steps per epoch:', len(loader))
    # Checkpoints are only considered once validation dice clears this floor.
    best_val_dice = 0.5  # 0 for debug
    model.train()
    # NOTE(review): range(num_epochs + 1) runs num_epochs+1 epochs — confirm
    # the off-by-one is intentional (e.g. to hit the epoch == num_epochs save).
    for epoch in range(num_epochs + 1):
        # lr_schedule.step()  # lr decay
        train_progressor = ProgressBar(mode='Train', epoch=epoch,
                                       total_epoch=num_epochs,
                                       model_name=param_set['model'],
                                       total=len(loader))
        # train dataloader
        iter_loss = 0
        iter_dice = 0
        for step, (images_pv, images_art, labels) in enumerate(loader):
            train_progressor.current = step
            # Re-enter train mode each step: evaluate() switched to eval()
            # at the end of the previous epoch.
            model.train()
            if USE_GPU:
                images_pv = images_pv.cuda(non_blocking=True)  # inputs to GPU
                images_art = images_art.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)
            outputs = model(images_pv, images_art)  # forward
            # Soft dice on the foreground channel, computed on CPU for logging.
            out_soft = F.softmax(outputs, dim=1).data.cpu()
            out_dice = soft_dice(out_soft[:, 1].cpu(), labels.cpu().float())
            iter_dice += out_dice.cpu()
            loss = criterion(outputs, labels)  # loss
            iter_loss += float(loss.item())
            writer.add_scalar('train/loss', loss.item(),
                              epoch * len(loader) + step)
            # writer.add_image('train/input_pv',images_pv[0],global_step=10)
            # writer.add_image('train/input_art',images_art[0],global_step=10)
            # writer.add_image('train/gt',labels[0],global_step = 10)
            # writer.add_image('train/output',out_soft[0,1,:,:],global_step=10)
            train_progressor.current_loss = loss.item()
            train_progressor.current_dice = out_dice.cpu()
            optimizer.zero_grad()  # zero the parameter gradients
            loss.backward()  # backward
            optimizer.step()  # optimize
            train_progressor()
            iter_count += 1
            # clear cache: drop per-step tensors so GPU memory stays flat
            del images_pv, images_art, labels, loss, outputs, out_soft, out_dice
            # import gc
            # gc.collect()
            torch.cuda.empty_cache()
        train_progressor.done()
        # save best model in terms of validation dice---------
        #evaluate
        valid_result = evaluate(val_loader, model, criterion, epoch, gt_labels)
        # print("epoch {epoch}, validation dice {val_dice}".format(epoch=epoch, val_dice=valid_result[0]))  #for debug
        with open(save_dir + "/validation_log.txt", "a") as f:
            print("epoch {epoch}, validation dice {val_dice}".format(
                epoch=epoch, val_dice=valid_result[0]), file=f)
        # Only checkpoint once the global validation dice clears 0.5;
        # save_checkpoint_new additionally gates periodic saves via SAVE.
        if valid_result[0] >= 0.5:
            is_best = valid_result[0] > best_val_dice
            best_val_dice = max(valid_result[0], best_val_dice)
            filename = "{model}-{epoch:03}-{step:04}-{dice}.pth".format(
                model=param_set['model'], epoch=epoch, step=step,
                dice=valid_result[0])
            #save_checkpoint(model.state_dict(), is_best, ckpt_dir, filename)
            save_checkpoint_new(model.state_dict(), is_best, ckpt_dir, filename,
                                SAVE=(epoch > 0 and epoch % epoch_save == 0
                                      or epoch == num_epochs))
        # Epoch-level averages over all training steps.
        epoch_loss = iter_loss / len(loader)
        epoch_dice = iter_dice / len(loader)
        writer.add_scalar('train/epoch_loss', epoch_loss, epoch)
        writer.add_scalar('train/epoch_dice', epoch_dice, epoch)
    writer.close()