import time
import logging

import numpy as np
import torch
import torch.nn as nn

# NOTE: project-level helpers and constants (set_logger, LOG_PATH, MODEL, device,
# UNetResNet34, SegmentationModule, LR_Scheduler, load_checkpoint, save_checkpoint,
# LearningRate, NUM_EPOCHS, warm_start, last_checkpoint_path, checkpoint_path,
# early_stopping_round) are defined elsewhere in this repo.


def run_check_net(train_dl, val_dl, multi_gpu=[0, 1]):
    set_logger(LOG_PATH)
    logging.info('\n\n')
    #---
    if MODEL == 'UNetResNet34':
        net = UNetResNet34(debug=False).cuda(device=device)
    #elif MODEL == 'RESNET18':
    #    net = AtlasResNet18(debug=False).cuda(device=device)

    # optionally freeze decoder blocks:
    # for param in net.named_parameters():
    #     if param[0][:8] in ['decoder5']:  #'decoder5', 'decoder4', 'decoder3', 'decoder2'
    #         param[1].requires_grad = False

    # dummy sgd to see if it can converge ...
    #optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
    #                            lr=LearningRate, momentum=0.9, weight_decay=0.0001)
    #optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.045)  #LearningRate
    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
    #                                                       factor=0.5, patience=4,  #4 for resnet34
    #                                                       verbose=False, threshold=0.0001,
    #                                                       threshold_mode='rel', cooldown=0,
    #                                                       min_lr=0, eps=1e-08)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.9, last_epoch=-1)
    train_params = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = torch.optim.SGD(train_params, momentum=0.9,
                                weight_decay=0.0001, lr=LearningRate)
    scheduler = LR_Scheduler('poly', LearningRate, NUM_EPOCHS,
                             len(train_dl))  #lr_scheduler in ['poly', 'step', 'cos']

    if warm_start:
        logging.info('warm_start: ' + last_checkpoint_path)
        net, _ = load_checkpoint(last_checkpoint_path, net)

    # using multi GPU
    if multi_gpu is not None:
        net = nn.DataParallel(net, device_ids=multi_gpu)
        #use sync_batchnorm
        #net = convert_model(net)

    diff = 0
    best_val_metric = -0.1
    optimizer.zero_grad()
    #seed = get_seed()
    #seed = SEED
    #logging.info('aug seed: ' + str(seed))
    #ia.imgaug.seed(seed)
    #np.random.seed(seed)
    for i_epoch in range(NUM_EPOCHS):
        t0 = time.time()

        # iterate through trainset
        if multi_gpu is not None:
            net.module.set_mode('train')
        else:
            net.set_mode('train')
        train_loss_list = []  #train_metric_list
        #logit_list, truth_list = [], []
        for i, (images, masks) in enumerate(train_dl):
            ## adjust learning rate
            scheduler(optimizer, i, i_epoch, best_val_metric)
            input_data = images.to(device=device, dtype=torch.float)
            # truth is 1 for a non-zero (non-empty) mask, 0 otherwise
            truth = (torch.sum(masks.reshape(masks.size(0), masks.size(1), -1),
                               dim=2, keepdim=False) != 0).to(device=device, dtype=torch.float)
            logit = net(input_data)
            #logit_list.append(logit)
            #truth_list.append(truth)
            if multi_gpu is not None:
                _train_loss = net.module.criterion(logit, truth)
                #_train_metric = net.module.metric(logit, truth)  #device='gpu'
            else:
                _train_loss = net.criterion(logit, truth)
                #_train_metric = net.metric(logit, truth)  #device='gpu'
            train_loss_list.append(_train_loss.item())
            #train_metric_list.append(_train_metric.item())  #.detach()

            # gradient accumulation; acc_step=1 means an update on every batch
            acc_step = 1
            _train_loss = _train_loss / acc_step
            _train_loss.backward()
            if (i + 1) % acc_step == 0:
                optimizer.step()
                optimizer.zero_grad()
        train_loss = np.mean(train_loss_list)
        #train_metric = np.mean(train_metric_list)
        # if multi_gpu is not None:
        #     train_metric, train_tn, train_fp, train_fn, train_tp, train_auc, train_pos_percent = \
        #         net.module.metric(torch.cat(logit_list, dim=0), torch.cat(truth_list, dim=0))
        # else:
        #     train_metric, train_tn, train_fp, train_fn, train_tp, train_auc, train_pos_percent = \
        #         net.metric(torch.cat(logit_list, dim=0), torch.cat(truth_list, dim=0))

        # compute valid loss & metrics (concatenate the whole valid set,
        # then compute loss and metrics on it in one pass)
        if multi_gpu is not None:
            net.module.set_mode('valid')
        else:
            net.set_mode('valid')
        with torch.no_grad():
            # per-batch variant:
            # val_loss_list, val_metric_list = [], []
            # for i, (image, masks) in enumerate(val_dl):
            #     input_data = image.to(device=device, dtype=torch.float)
            #     truth = masks.to(device=device, dtype=torch.float)
            #     logit = net(input_data)
            #     if multi_gpu is not None:
            #         _val_loss = net.module.criterion(logit, truth)
            #         _val_metric = net.module.metric(logit, truth)  #device='gpu'
            #     else:
            #         _val_loss = net.criterion(logit, truth)
            #         _val_metric = net.metric(logit, truth)  #device='gpu'
            #     val_loss_list.append(_val_loss.item())
            #     val_metric_list.append(_val_metric.item())  #.detach()
            # val_loss = np.mean(val_loss_list)
            # val_metric = np.mean(val_metric_list)
            logit_valid, truth_valid = None, None
            for j, (images, masks) in enumerate(val_dl):
                input_data = images.to(device=device, dtype=torch.float)
                # truth is 1 for a non-zero (non-empty) mask, 0 otherwise
                truth = (torch.sum(masks.reshape(masks.size(0), masks.size(1), -1),
                                   dim=2, keepdim=False) != 0).to(device=device, dtype=torch.float)
                logit = net(input_data)
                if logit_valid is None:
                    logit_valid = logit
                    truth_valid = truth
                else:
                    logit_valid = torch.cat((logit_valid, logit), dim=0)
                    truth_valid = torch.cat((truth_valid, truth), dim=0)
            if multi_gpu is not None:
                val_loss = net.module.criterion(logit_valid, truth_valid)
                _, val_metric, val_tn, val_fp, val_fn, val_tp, val_pos_percent = \
                    net.module.metric(logit_valid, truth_valid)
            else:
                val_loss = net.criterion(logit_valid, truth_valid)
                _, val_metric, val_tn, val_fp, val_fn, val_tp, val_pos_percent = \
                    net.metric(logit_valid, truth_valid)

        # Adjust learning rate
        #scheduler.step(val_metric)

        # if i_epoch >= 30:  # optional gate to delay early stopping
        if val_metric > best_val_metric:
            best_val_metric = val_metric
            is_best = True
            diff = 0
        else:
            is_best = False
            diff += 1
            if diff > early_stopping_round:
                logging.info('Early Stopping: val_metric does not increase %d rounds'
                             % early_stopping_round)
                #print('Early Stopping: val_iou does not increase %d rounds' % early_stopping_round)
                break
        # else:  # pairs with the commented epoch gate above
        #     is_best = False

        #save checkpoint
        checkpoint_dict = {
            'epoch': i_epoch,
            'state_dict': net.module.state_dict() if multi_gpu is not None else net.state_dict(),
            'optim_dict': optimizer.state_dict(),
            'metrics': {'train_loss': train_loss, 'val_loss': val_loss,
                        'val_metric': val_metric},
        }
        save_checkpoint(checkpoint_dict, is_best=is_best, checkpoint=checkpoint_path)

        #if i_epoch % 20 == 0:
        if i_epoch > -1:
            logging.info('[EPOCH %05d] train_loss: %0.5f; val_loss, val_metric: %0.5f, %0.5f'
                         % (i_epoch, train_loss.item(), val_loss.item(), val_metric))
            logging.info('val_pos_percent: %.3f' % val_pos_percent)
            logging.info('val (tn, fp, fn, tp): %d, %d, %d, %d'
                         % (val_tn, val_fp, val_fn, val_tp))
            logging.info('time elapsed: %0.1f min' % ((time.time() - t0) / 60))
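
# For reference: a minimal sketch of what the 'poly' mode of LR_Scheduler above
# is assumed to do, i.e. decay the base LR polynomially with overall training
# progress. The class name PolyLRSketch, the 0.9 power, and the exact internals
# are assumptions for illustration, not this repo's actual implementation; only
# the call signature scheduler(optimizer, i, i_epoch, best_val_metric) is taken
# from the training loop above.
class PolyLRSketch:
    def __init__(self, base_lr, num_epochs, iters_per_epoch, power=0.9):
        self.base_lr = base_lr
        self.iters_per_epoch = iters_per_epoch
        self.total_iters = num_epochs * iters_per_epoch
        self.power = power

    def __call__(self, optimizer, i, epoch, best_metric=None):
        # fraction of training completed so far, across all epochs
        cur_iter = epoch * self.iters_per_epoch + i
        lr = self.base_lr * (1.0 - cur_iter / self.total_iters) ** self.power
        for group in optimizer.param_groups:
            group['lr'] = lr
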
# NOTE: this second definition shadows the run_check_net above when both live in
# the same module; it is the encoder/decoder (SegmentationModule) variant that
# also returns a classification logit (logit_clf) alongside the mask logit.
def run_check_net(train_dl, val_dl, multi_gpu=[0, 1], nonempty_only_loss=False):
    set_logger(LOG_PATH)
    logging.info('\n\n')
    #---
    enc, dec = MODEL.split('_')[0], MODEL.split('_')[1]
    net = SegmentationModule(net_enc=enc, net_dec=dec).cuda(device=device)

    # optionally freeze decoder blocks:
    # for param in net.named_parameters():
    #     if param[0][:8] in ['decoder5']:  #'decoder5', 'decoder4', 'decoder3', 'decoder2'
    #         param[1].requires_grad = False

    # train_params = [{'params': net.get_1x_lr_params(), 'lr': LearningRate},
    #                 {'params': net.get_10x_lr_params(), 'lr': LearningRate * 10}]  #for resnet backbone
    train_params = filter(lambda p: p.requires_grad, net.parameters())

    # dummy sgd to see if it can converge ...
    #optimizer = torch.optim.SGD(train_params,
    #                            lr=LearningRate, momentum=0.9, weight_decay=0.0001)
    #optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.045)  #LearningRate
    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
    #                                                       factor=0.5, patience=4,  #4 for resnet34
    #                                                       verbose=False, threshold=0.0001,
    #                                                       threshold_mode='rel', cooldown=0,
    #                                                       min_lr=0, eps=1e-08)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.9, last_epoch=-1)

    #for deeplabv3plus, customized
    optimizer = torch.optim.SGD(train_params, momentum=0.9,
                                weight_decay=0.0001, lr=LearningRate)
    scheduler = LR_Scheduler('poly', LearningRate, NUM_EPOCHS,
                             len(train_dl))  #lr_scheduler in ['poly', 'step', 'cos']

    if warm_start:
        logging.info('warm_start: ' + last_checkpoint_path)
        net, _ = load_checkpoint(last_checkpoint_path, net)

    # using multi GPU
    if multi_gpu is not None:
        net = nn.DataParallel(net, device_ids=multi_gpu)
        #use sync_batchnorm
        #net = convert_model(net)

    diff = 0
    best_val_metric = -0.1
    optimizer.zero_grad()
    #seed = get_seed()
    #seed = SEED
    #logging.info('aug seed: ' + str(seed))
    #ia.imgaug.seed(seed)
    #np.random.seed(seed)
    for i_epoch in range(NUM_EPOCHS):
        ### adjust learning rate
        #scheduler.step(epoch=i_epoch)
        #print('lr: %f' % scheduler.get_lr()[0])
        t0 = time.time()

        # iterate through trainset
        if multi_gpu is not None:
            net.module.set_mode('train')
        else:
            net.set_mode('train')
        train_loss_list, train_metric_list = [], []
        #for seed in [1]:  #[1, SEED]: augment raw data with a duplicate (augmented) copy
        #seed = get_seed()
        #np.random.seed(seed)
        #ia.imgaug.seed(i // 10)
        for i, (image, masks) in enumerate(train_dl):
            ## adjust learning rate
            scheduler(optimizer, i, i_epoch, best_val_metric)
            input_data = image.to(device=device, dtype=torch.float)
            truth = masks.to(device=device, dtype=torch.float)
            #set_trace()
            logit, logit_clf = net(input_data)  #[:, :3, :, :]
            if multi_gpu is not None:
                _train_loss = net.module.criterion(logit, truth, nonempty_only_loss, logit_clf)
                _train_metric = net.module.metric(logit, truth, nonempty_only_loss, logit_clf)  #device='gpu'
            else:
                _train_loss = net.criterion(logit, truth, nonempty_only_loss, logit_clf)
                _train_metric = net.metric(logit, truth, nonempty_only_loss, logit_clf)  #device='gpu'
            train_loss_list.append(_train_loss.item())
            train_metric_list.append(_train_metric.item())  #.detach()

            # gradient accumulation; acc_step=1 means an update on every batch
            acc_step = 1
            _train_loss = _train_loss / acc_step
            _train_loss.backward()
            if (i + 1) % acc_step == 0:
                optimizer.step()
                optimizer.zero_grad()
        train_loss = np.mean(train_loss_list)
        train_metric = np.mean(train_metric_list)

        # compute valid loss & metrics per batch, then average
        if multi_gpu is not None:
            net.module.set_mode('valid')
        else:
            net.set_mode('valid')
        with torch.no_grad():
            val_loss_list, val_metric_list = [], []
            for i, (image, masks) in enumerate(val_dl):
                input_data = image.to(device=device, dtype=torch.float)
                truth = masks.to(device=device, dtype=torch.float)
                logit, logit_clf = net(input_data)
                if multi_gpu is not None:
                    _val_loss = net.module.criterion(logit, truth, nonempty_only_loss, logit_clf)
                    _val_metric = net.module.metric(logit, truth, nonempty_only_loss, logit_clf)  #device='gpu'
                else:
                    _val_loss = net.criterion(logit, truth, nonempty_only_loss, logit_clf)
                    _val_metric = net.metric(logit, truth, nonempty_only_loss, logit_clf)  #device='gpu'
                val_loss_list.append(_val_loss.item())
                val_metric_list.append(_val_metric.item())  #.detach()
            val_loss = np.mean(val_loss_list)
            val_metric = np.mean(val_metric_list)

            # full-valid-set variant (concatenate on cpu, then score once):
            # logit_valid, truth_valid = None, None
            # for j, (image, masks) in enumerate(val_dl):
            #     input_data = image.to(device=device, dtype=torch.float)
            #     logit = net(input_data).cpu().float()
            #     truth = masks.cpu().float()
            #     if logit_valid is None:
            #         logit_valid = logit
            #         truth_valid = truth
            #     else:
            #         logit_valid = torch.cat((logit_valid, logit), dim=0)
            #         truth_valid = torch.cat((truth_valid, truth), dim=0)
            # if multi_gpu is not None:
            #     val_loss = net.module.criterion(logit_valid, truth_valid)
            #     val_metric = net.module.metric(logit_valid, truth_valid)
            # else:
            #     val_loss = net.criterion(logit_valid, truth_valid)
            #     val_metric = net.metric(logit_valid, truth_valid)

        # Adjust learning rate
        #scheduler.step(val_metric)

        # training at 1024 resolution is harder and can stop too early,
        # so force a minimum number of epochs before early stopping kicks in
        if i_epoch >= 10:  #-1
            if val_metric > best_val_metric:
                best_val_metric = val_metric
                is_best = True
                diff = 0
            else:
                is_best = False
                diff += 1
                if diff > early_stopping_round:
                    logging.info('Early Stopping: val_metric does not increase %d rounds'
                                 % early_stopping_round)
                    #print('Early Stopping: val_iou does not increase %d rounds' % early_stopping_round)
                    break
        else:
            is_best = False

        #save checkpoint
        checkpoint_dict = {
            'epoch': i_epoch,  # was the batch index i, which looked like a bug
            'state_dict': net.module.state_dict() if multi_gpu is not None else net.state_dict(),
            'optim_dict': optimizer.state_dict(),
            'metrics': {'train_loss': train_loss, 'val_loss': val_loss,
                        'train_metric': train_metric, 'val_metric': val_metric},
        }
        save_checkpoint(checkpoint_dict, is_best=is_best, checkpoint=checkpoint_path)

        #if i_epoch % 20 == 0:
        if i_epoch > -1:
            logging.info('[EPOCH %05d] train_loss, train_metric: %0.5f, %0.5f; '
                         'val_loss, val_metric: %0.5f, %0.5f; time elapsed: %0.1f min'
                         % (i_epoch, train_loss.item(), train_metric.item(),
                            val_loss.item(), val_metric.item(), (time.time() - t0) / 60))
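
# A minimal usage sketch, assuming train_dl/val_dl are standard
# torch.utils.data.DataLoader objects yielding (image, masks) tensor pairs.
# SiimDataset, the fold argument, and the batch size are hypothetical
# placeholders, not names from this repo.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    train_dl = DataLoader(SiimDataset(fold='train'), batch_size=8,
                          shuffle=True, num_workers=4, drop_last=True)
    val_dl = DataLoader(SiimDataset(fold='valid'), batch_size=8,
                        shuffle=False, num_workers=4)

    # train on GPUs 0 and 1; compute the loss on all masks,
    # not only the non-empty ones
    run_check_net(train_dl, val_dl, multi_gpu=[0, 1], nonempty_only_loss=False)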