def train_net():
    dataset = generate_dataset(cfg.DATA_NAME, cfg, 'train', cfg.DATA_AUG)
    # dataset = generate_dataset(cfg.DATA_NAME, cfg, 'train_cross'+cfg.DATA_CROSS, cfg.DATA_AUG)
    dataloader = DataLoader(dataset,
                            batch_size=cfg.TRAIN_BATCHES,
                            shuffle=cfg.TRAIN_SHUFFLE,
                            num_workers=cfg.DATA_WORKERS,
                            drop_last=True)
    test_dataset = generate_dataset(cfg.DATA_NAME, cfg, 'test')
    # test_dataset = generate_dataset(cfg.DATA_NAME, cfg, 'test_cross'+cfg.DATA_CROSS)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=cfg.TEST_BATCHES,
                                 shuffle=False,
                                 num_workers=cfg.DATA_WORKERS)

    net = generate_net(cfg)
    print('Use %d GPU' % cfg.TRAIN_GPUS)
    device = torch.device(0)
    if cfg.TRAIN_GPUS > 1:
        net = nn.DataParallel(net)
        patch_replication_callback(net)
    net.to(device)

    if cfg.TRAIN_CKPT:
        pretrained_dict = torch.load(cfg.TRAIN_CKPT)
        net_dict = net.state_dict()
        # Keep only checkpoint weights whose name and shape match the current model.
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if (k in net_dict) and (v.shape == net_dict[k].shape)
        }
        net_dict.update(pretrained_dict)
        net.load_state_dict(net_dict)
        # net.load_state_dict(torch.load(cfg.TRAIN_CKPT), False)

    # criterion = nn.CrossEntropyLoss(ignore_index=255)
    criterion = nn.BCEWithLogitsLoss()
    # optimizer = optim.SGD(net.parameters(), lr=cfg.TRAIN_LR, momentum=cfg.TRAIN_MOMENTUM)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),
                           lr=cfg.TRAIN_LR)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)  # decay LR by a factor of 0.5 every 50 epochs

    itr = cfg.TRAIN_MINEPOCH * len(dataloader)
    max_itr = cfg.TRAIN_EPOCHS * len(dataloader)
    best_jacc = 0.
    best_epoch = 0
    for epoch in range(cfg.TRAIN_MINEPOCH, cfg.TRAIN_EPOCHS):
        running_loss = 0.0
        running_dice_loss = 0.0
        net.train()
        scheduler.step()
        # now_lr = scheduler.get_lr()
        for i_batch, sample_batched in enumerate(dataloader):
            now_lr = adjust_lr(optimizer, itr, max_itr)
            inputs_batched, labels_batched = sample_batched['image'], sample_batched['segmentation']
            optimizer.zero_grad()
            inputs_batched = inputs_batched.cuda()
            labels_batched = labels_batched.long().cuda()
            outputs = net(inputs_batched)
            loss = 0
            dice_loss = 0
            # The net returns several outputs (deep supervision); average the losses.
            for output in outputs:
                output = output.cuda()
                loss += criterion(output, make_one_hot(labels_batched))
                soft_predicts_batched = nn.Softmax(dim=1)(output)
                dice_loss += Jaccard_loss_cal(labels_batched, soft_predicts_batched, eps=1e-7)
            loss /= len(outputs)
            dice_loss /= len(outputs)
            (loss + dice_loss).backward()
            optimizer.step()
            running_loss += loss.item()
            running_dice_loss += dice_loss.item()
            itr += 1
        i_batch = i_batch + 1
        print('epoch:%d/%d\tmean loss:%g\tmean dice loss:%g \n' %
              (epoch, cfg.TRAIN_EPOCHS, running_loss / i_batch, running_dice_loss / i_batch))

        #### start testing now
        if epoch % 10 == 0:
            IoUP = test_one_epoch(test_dataset, test_dataloader, net, epoch)
            if IoUP > best_jacc:
                model_snapshot(net.state_dict(),
                               new_file=os.path.join(cfg.MODEL_SAVE_DIR,
                                                     'model-best-%s_%s_%s_epoch%d_jac%.3f.pth' %
                                                     (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, epoch, IoUP)),
                               old_file=os.path.join(cfg.MODEL_SAVE_DIR,
                                                     'model-best-%s_%s_%s_epoch%d_jac%.3f.pth' %
                                                     (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, best_epoch, best_jacc)))
                best_jacc = IoUP
                best_epoch = epoch
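
# The loop above calls make_one_hot and Jaccard_loss_cal, which are defined
# elsewhere in the repo. A minimal sketch of what they are assumed to do:
# make_one_hot turns an integer label map into a one-hot float tensor (needed
# by BCEWithLogitsLoss), and Jaccard_loss_cal is a soft (1 - IoU) loss on the
# foreground channel. Both signatures and the binary-foreground assumption
# are ours, not the repo's actual code.
def make_one_hot(labels, num_classes=cfg.MODEL_NUM_CLASSES):
    # labels: (N, H, W) long tensor -> (N, C, H, W) float one-hot tensor
    one_hot = torch.zeros(labels.size(0), num_classes, labels.size(1), labels.size(2),
                          device=labels.device)
    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)

def Jaccard_loss_cal(labels, soft_predicts, eps=1e-7):
    # soft_predicts: (N, C, H, W) softmax probabilities; channel 1 is assumed foreground.
    target = (labels == 1).float()
    pred = soft_predicts[:, 1, :, :]
    inter = (pred * target).sum(dim=(1, 2))
    union = pred.sum(dim=(1, 2)) + target.sum(dim=(1, 2)) - inter
    return (1.0 - (inter + eps) / (union + eps)).mean()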
def train_net():
    # dataset = generate_dataset(cfg.DATA_NAME, cfg, 'train_cross'+cfg.DATA_CROSS, cfg.DATA_AUG)
    dataset = generate_dataset(cfg.DATA_NAME, cfg, 'train' + cfg.DATA_CROSS, cfg.DATA_AUG)
    dataloader = DataLoader(dataset,
                            batch_size=cfg.TRAIN_BATCHES,
                            shuffle=cfg.TRAIN_SHUFFLE,
                            num_workers=cfg.DATA_WORKERS,
                            drop_last=True)
    # dataset_mixup = generate_dataset(cfg.DATA_NAME, cfg, 'train_cross'+cfg.DATA_CROSS, cfg.DATA_AUG)
    dataset_mixup = generate_dataset(cfg.DATA_NAME, cfg, 'train' + cfg.DATA_CROSS, cfg.DATA_AUG)
    dataloader_mixup = DataLoader(dataset_mixup,
                                  batch_size=cfg.TRAIN_BATCHES,
                                  shuffle=cfg.TRAIN_SHUFFLE,
                                  num_workers=cfg.DATA_WORKERS,
                                  drop_last=True)
    # test_dataset = generate_dataset(cfg.DATA_NAME, cfg, 'test_cross'+cfg.DATA_CROSS)
    test_dataset = generate_dataset(cfg.DATA_NAME, cfg, 'test' + cfg.DATA_CROSS)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=cfg.TEST_BATCHES,
                                 shuffle=False,
                                 num_workers=cfg.DATA_WORKERS)

    net = generate_net(cfg)
    # if cfg.TRAIN_TBLOG:
    #     from tensorboardX import SummaryWriter
    #     # Set the Tensorboard logger
    #     tblogger = SummaryWriter(cfg.LOG_DIR)
    print('Use %d GPU' % cfg.TRAIN_GPUS)
    device = torch.device(0)
    if cfg.TRAIN_GPUS > 1:
        net = nn.DataParallel(net)
        patch_replication_callback(net)
    net.to(device)

    if cfg.TRAIN_CKPT:
        pretrained_dict = torch.load(cfg.TRAIN_CKPT)
        net_dict = net.state_dict()
        # for i, p in enumerate(net_dict):
        #     print(i, p)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if (k in net_dict) and (v.shape == net_dict[k].shape)
        }
        net_dict.update(pretrained_dict)
        net.load_state_dict(net_dict)
        # net.load_state_dict(torch.load(cfg.TRAIN_CKPT), False)

    # for i, para in enumerate(net.named_parameters()):
    #     (name, param) = para
    #     print(i, name)
    # Split parameters by index: 38-47 form the threshold branch, everything
    # below 38 the segmentation head, the remainder the backbone.
    threshold_dict = []
    segment_dict = []
    backbone_dict = []
    for i, para in enumerate(net.parameters()):
        if 38 <= i <= 47:
            threshold_dict.append(para)
        elif i < 38:
            segment_dict.append(para)
        else:
            backbone_dict.append(para)
        # print(i)
    thr_optimizer = optim.SGD(threshold_dict, lr=10 * cfg.TRAIN_LR, momentum=cfg.TRAIN_MOMENTUM)
    seg_optimizer = optim.SGD(params=[
        {'params': backbone_dict, 'lr': cfg.TRAIN_LR},
        {'params': segment_dict, 'lr': 10 * cfg.TRAIN_LR},
    ], momentum=cfg.TRAIN_MOMENTUM)
    '''optimizer = optim.SGD(
        params = [
            {'params': get_params(net.module, key='1x'), 'lr': cfg.TRAIN_LR},
            {'params': get_params(net.module, key='10x'), 'lr': 10*cfg.TRAIN_LR}
        ],
        momentum=cfg.TRAIN_MOMENTUM
    )'''
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.TRAIN_LR_MST, gamma=cfg.TRAIN_LR_GAMMA, last_epoch=-1)
    itr = cfg.TRAIN_MINEPOCH * len(dataloader)
    max_itr = cfg.TRAIN_EPOCHS_lr * len(dataloader)
    # tblogger = SummaryWriter(cfg.LOG_DIR)
    # net.train()
    best_jacc = 0.
    best_epoch = 0
    for epoch in range(cfg.TRAIN_MINEPOCH, cfg.TRAIN_EPOCHS):
        running_loss = 0.0
        seg_jac_running_loss = 0.0
        dice_running_loss = 0.0
        grad_running_loss = 0.0
        average_running_loss = 0.0
        mixup_running_loss = 0.0
        mixup_seg_jac_running_loss = 0.0
        mixup_dice_running_loss = 0.0
        mixup_grad_running_loss = 0.0
        mixup_average_running_loss = 0.0
        dataset_list = []
        net.train()
        # scheduler.step()
        # now_lr = scheduler.get_lr()
        for i_batch, (sample_batched, mixup_batched) in enumerate(zip(dataloader, dataloader_mixup)):
            now_lr = adjust_lr(seg_optimizer, itr, max_itr)
            name_batched = sample_batched['name']
            inputs_batched1, labels_batched_cpu1 = sample_batched['image'], sample_batched['segmentation']
            inputs_batched2, labels_batched_cpu2 = mixup_batched['image'], mixup_batched['segmentation']
            labels_batched1 = labels_batched_cpu1.long().to(1)
            labels_batched2 = labels_batched_cpu2.long().to(1)
            loss, seg_jac_loss, cosSim1, deep_feature1 = train_one_batch(
                inputs_batched1, labels_batched1, net, seg_optimizer, thr_optimizer)
            running_loss += loss.item()
            seg_jac_running_loss += seg_jac_loss.item()

            ##############################
            #### obtain mixup samples ####
            ##############################
            if epoch >= cfg.Mixup_start_epoch:
                margin = cfg.Threshold_margin
                imgresize = torch.nn.UpsamplingBilinear2d(
                    size=(int(cfg.DATA_RESCALE / 4), int(cfg.DATA_RESCALE / 4)))
                deep_feature2, predicts_batched2 = net(inputs_batched2)
                predicts_batched2 = predicts_batched2.to(1)
                Softmax_predicts_batched = nn.Softmax(dim=1)(predicts_batched2)
                cosSim2 = Softmax_predicts_batched.mul(make_one_hot(labels_batched2)).sum(dim=1)
                cosSim2 = torch.from_numpy(cosSim2.cpu().data.numpy()).to(1)
                cosSim1 = torch.from_numpy(cosSim1).to(1)
                feature_label1 = imgresize(torch.unsqueeze(labels_batched_cpu1.to(1), 1))
                feature_label2 = imgresize(torch.unsqueeze(labels_batched_cpu2.to(1), 1))
                alpha = cfg.Alpha
                random_lambda = np.random.beta(alpha, alpha)
                mixup_input = input_mixup(inputs_batched1, inputs_batched2, random_lambda)
                mixup_label = label_mixup(labels_batched1, labels_batched2, cosSim1, cosSim2, random_lambda)
                feature_mixup_label = imgresize(torch.unsqueeze(mixup_label, 1))
                mixup_label = mixup_label.long().to(1)
                cosSimilarity = input_mixup(cosSim1, cosSim2, random_lambda)
                cosSim1 = imgresize(torch.unsqueeze(cosSim1, 1))
                cosSim2 = imgresize(torch.unsqueeze(cosSim2, 1))
                ### here cosSim all means the confidence map of segmentation branch
                mixup_list = {
                    'deep_feature1': deep_feature1,
                    'deep_feature2': deep_feature2,
                    'feature_label1': feature_label1,
                    'feature_label2': feature_label2,  # the original mapped this key to feature_label1, an apparent copy-paste bug
                    'feature_mixup_label': feature_mixup_label,
                    'cosSim1': cosSim1,
                    'cosSim2': cosSim2,
                    'random_lambda': random_lambda,
                    'cosSimilarity': cosSimilarity
                }
                loss, seg_jac_loss, MM_consistent_loss, Simi_consistent_loss, mixup_feature_seg_loss = train_one_batch(
                    mixup_input, mixup_label, net, seg_optimizer, thr_optimizer,
                    list=mixup_list, phase='mixup')
                mixup_running_loss += loss.item()
                mixup_seg_jac_running_loss += seg_jac_loss.item()
                mixup_dice_running_loss += MM_consistent_loss.item()
                mixup_grad_running_loss += 10 * Simi_consistent_loss.item()
                mixup_average_running_loss += 10 * mixup_feature_seg_loss.item()
            itr += 1
            '''if (epoch) % 50 == 0:
                [batch, channel, height, width] = mixup_input.size()
                for i in range(batch):
                    mixup1 = inputs_batched[i,:,:,:].cpu().numpy()
                    mixup2 = mixup_inputs_batched[i,:,:,:].cpu().numpy()
                    mixup_ = mixup_input[i,:,:,:].cpu().numpy()
                    mixup1_l = labels_batched[i,:,:].cpu().numpy()
                    mixup2_l = mixup_labels_batched[i,:,:].cpu().numpy()
                    mixup_l = mixup_label[i,1,:,:].cpu().numpy()
                    cosSimi1 = Sample_cos[i,0,:,:].cpu().numpy()
                    cosSimi2 = Mixup_cos[i,0,:,:].cpu().numpy()
                    dataset_list.append({'inputs_batched':np.uint8(mixup1*255),
                                         'mixup_inputs_batched':np.uint8(mixup2*255),
                                         'mixup_input':np.uint8(mixup_*255),
                                         'name':name_batched[i],
                                         'mixup_label':np.uint8(mixup_l*255),
                                         'inputs_labels':np.uint8(mixup1_l*255),
                                         'mixup_labels_batched':np.uint8(mixup2_l*255),
                                         'cosSimi1':np.uint8(cosSimi1*255),
                                         'cosSimi2':np.uint8(cosSimi2*255)})
            if (epoch) % 50 == 0:
                dataset.save_result_train_mixup(dataset_list, cfg.MODEL_NAME)'''
        i_batch = i_batch + 1
        print('epoch:%d/%d\tSegCE loss:%g \tSegJaccard loss:%g' %
              (epoch, cfg.TRAIN_EPOCHS, running_loss / i_batch, seg_jac_running_loss / i_batch))
        if epoch >= cfg.Mixup_start_epoch:
            print('Mixup:\tSegCE loss:%g \tSegJaccard loss:%g \tMFMC loss:%g \tMCMC loss:%g \tMMfeature loss:%g \n' %
                  (mixup_running_loss / i_batch, mixup_seg_jac_running_loss / i_batch,
                   mixup_dice_running_loss / i_batch, mixup_grad_running_loss / i_batch,
                   mixup_average_running_loss / i_batch))

        #### start testing now
        if (epoch) % 2 == 0:
            Dice_score, IoUP = test_one_epoch(test_dataset, test_dataloader, net, epoch)
            if Dice_score > best_jacc:
                model_snapshot(net.state_dict(),
                               new_file=os.path.join(cfg.MODEL_SAVE_DIR,
                                                     'model-best-%s_%s_%s_epoch%d_dice%.3f.pth' %
                                                     (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, epoch, Dice_score)),
                               old_file=os.path.join(cfg.MODEL_SAVE_DIR,
                                                     'model-best-%s_%s_%s_epoch%d_dice%.3f.pth' %
                                                     (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, best_epoch, best_jacc)))
                best_jacc = Dice_score
                best_epoch = epoch
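
# adjust_lr is referenced by every training loop in this file but defined
# elsewhere in the repo. A minimal sketch, assuming the poly learning-rate
# policy common in DeepLab-style code (power 0.9); the handling of the extra
# 10x parameter groups here is our assumption:
def adjust_lr(optimizer, itr, max_itr, base_lr=None, power=0.9):
    base_lr = cfg.TRAIN_LR if base_lr is None else base_lr
    now_lr = base_lr * (1 - itr / (max_itr + 1)) ** power
    optimizer.param_groups[0]['lr'] = now_lr
    # Any extra groups (e.g. a 10x head) keep their relative scale.
    for g in optimizer.param_groups[1:]:
        g['lr'] = 10 * now_lr
    return now_lr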
def test_net():
    dataset = generate_dataset(cfg.DATA_NAME, cfg, 'test')
    dataloader = DataLoader(dataset,
                            batch_size=cfg.TEST_BATCHES,
                            shuffle=False,
                            num_workers=cfg.DATA_WORKERS)

    net = generate_net(cfg)
    print('net initialize')
    if cfg.TEST_CKPT is None:
        raise ValueError('test.py: cfg.MODEL_CKPT can not be empty in test period')
    print('Use %d GPU' % cfg.TEST_GPUS)
    device = torch.device('cuda')
    if cfg.TEST_GPUS > 1:
        net = nn.DataParallel(net)
        patch_replication_callback(net)
    net.to(device)

    print('start loading model %s' % cfg.TEST_CKPT)
    model_dict = torch.load(cfg.TEST_CKPT, map_location=device)
    net.load_state_dict(model_dict)

    net.eval()
    result_list = []
    with torch.no_grad():
        for i_batch, sample_batched in enumerate(dataloader):
            name_batched = sample_batched['name']
            row_batched = sample_batched['row']
            col_batched = sample_batched['col']

            [batch, channel, height, width] = sample_batched['image'].size()
            multi_avg = torch.zeros((batch, cfg.MODEL_NUM_CLASSES, height, width),
                                    dtype=torch.float32).to(1)
            for rate in cfg.TEST_MULTISCALE:
                inputs_batched = sample_batched['image_%f' % rate]
                predicts = net(inputs_batched).to(1)
                predicts_batched = predicts.clone()
                del predicts
                if cfg.TEST_FLIP:
                    inputs_batched_flip = torch.flip(inputs_batched, [3])
                    predicts_flip = torch.flip(net(inputs_batched_flip), [3]).to(1)
                    predicts_batched_flip = predicts_flip.clone()
                    del predicts_flip
                    predicts_batched = (predicts_batched + predicts_batched_flip) / 2.0
                predicts_batched = F.interpolate(predicts_batched, size=None,
                                                 scale_factor=1 / rate,
                                                 mode='bilinear', align_corners=True)
                multi_avg = multi_avg + predicts_batched
                del predicts_batched
            multi_avg = multi_avg / len(cfg.TEST_MULTISCALE)
            result = torch.argmax(multi_avg, dim=1).cpu().numpy().astype(np.uint8)

            for i in range(batch):
                row = row_batched[i]
                col = col_batched[i]
                # max_edge = max(row,col)
                # rate = cfg.DATA_RESCALE / max_edge
                # new_row = row*rate
                # new_col = col*rate
                # s_row = (cfg.DATA_RESCALE-new_row)//2
                # s_col = (cfg.DATA_RESCALE-new_col)//2
                # p = predicts_batched[i, s_row:s_row+new_row, s_col:s_col+new_col]
                p = result[i, :, :] * 255
                p = cv2.resize(p, dsize=(col, row), interpolation=cv2.INTER_NEAREST)
                result_list.append({'predict': p, 'name': name_batched[i]})
            print('%d/%d' % (i_batch, len(dataloader)))
    dataset.save_result(result_list, cfg.MODEL_NAME)
    dataset.do_python_eval(cfg.MODEL_NAME)
    print('Test finished')
def test_net():
    dataset = RemoDataset.RemoDataset(cfg, 'val')
    dataloader = DataLoader(dataset,
                            batch_size=cfg.bs,
                            shuffle=False,
                            num_workers=cfg.num_workers,
                            # collate_fn=collate_fn,
                            # drop_last=True
                            )

    net = generate_net(cfg)
    print('net initialize')
    if cfg.TEST_CKPT is None:
        raise ValueError('test.py: cfg.MODEL_CKPT can not be empty in test period')
    print('Use %d GPU' % cfg.TEST_GPUS)
    device = torch.device('cuda')
    if cfg.TEST_GPUS > 1:
        net = nn.DataParallel(net)
        patch_replication_callback(net)
    net.to(device)

    print('start loading model %s' % cfg.TEST_CKPT)
    model_dict = torch.load(cfg.TEST_CKPT, map_location=device)
    # The checkpoint was saved from a DataParallel model; strip the leading
    # 'module.' prefix so the keys match the current state dict.
    from collections import OrderedDict
    new_model_dict = OrderedDict()
    mod = net.state_dict()
    for k, v in model_dict.items():
        if k[7:] in mod.keys():
            name = k[7:]  # remove 'module.'
            new_model_dict[name] = v
    net.load_state_dict(new_model_dict)

    net.eval()
    result_list = []
    with torch.no_grad():
        hist = np.zeros((4, 4))  # 4x4 confusion matrix, never filled in this variant; see the fast_hist sketch after this function
        # for i_batch, sample_batched in tqdm(enumerate(dataloader)):
        for sample_batched in tqdm(dataloader):
            name_batched = sample_batched['name']
            row_batched = sample_batched['row']
            col_batched = sample_batched['col']
            labels_batched = sample_batched['segmentation']

            [batch, channel, height, width] = sample_batched['image'].size()
            multi_avg = torch.zeros((batch, cfg.MODEL_NUM_CLASSES, height, width),
                                    dtype=torch.float32).to(0)
            for rate in cfg.TEST_MULTISCALE:
                inputs_batched = sample_batched['image_%f' % rate]
                # inputs_batched = sample_batched['image']
                inputs_batched = inputs_batched.cuda(0)
                predicts = net(inputs_batched).to(0)
                predicts_batched = predicts.clone()
                del predicts
                if cfg.TEST_FLIP:
                    inputs_batched_flip = torch.flip(inputs_batched, [3])
                    predicts_flip = torch.flip(net(inputs_batched_flip), [3]).to(0)
                    predicts_batched_flip = predicts_flip.clone()
                    del predicts_flip
                    predicts_batched = (predicts_batched + predicts_batched_flip) / 2.0
                predicts_batched = F.interpolate(predicts_batched, size=None,
                                                 scale_factor=1 / rate,
                                                 mode='bilinear', align_corners=True)
                multi_avg = multi_avg + predicts_batched
                del predicts_batched
            multi_avg = multi_avg / len(cfg.TEST_MULTISCALE)
            result = torch.argmax(multi_avg, dim=1).cpu().numpy().astype(np.uint8)

            for i in range(batch):
                row = row_batched[i]
                col = col_batched[i]
                p = result[i, :, :]
                p = cv2.resize(p, dsize=(col, row), interpolation=cv2.INTER_NEAREST)
                # labels = labels_batched[i].cpu().numpy()
                result_list.append({'predict': p, 'gt': name_batched[i]})
                # result_list.append({'predict': p, 'gt': labels})
    dataset.do_python_eval(result_list)
    print('Test finished')
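
# The unused hist = np.zeros((4, 4)) above suggests a 4-class confusion
# matrix was planned. A minimal sketch of how such a matrix is typically
# accumulated and turned into per-class IoU; this helper is our assumption,
# not part of the repo:
def fast_hist(gt, pred, num_classes=4):
    # gt, pred: flat integer label arrays of equal length
    mask = (gt >= 0) & (gt < num_classes)
    return np.bincount(num_classes * gt[mask].astype(int) + pred[mask].astype(int),
                       minlength=num_classes ** 2).reshape(num_classes, num_classes)

# Usage inside the loop: hist += fast_hist(labels.flatten(), p.flatten())
# Per-class IoU from the accumulated matrix:
#   iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-10)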
def train_net():
    dataset = generate_dataset(cfg.DATA_NAME, cfg, 'train', cfg.DATA_AUG)
    dataloader = DataLoader(dataset,
                            batch_size=cfg.TRAIN_BATCHES,
                            shuffle=cfg.TRAIN_SHUFFLE,
                            num_workers=cfg.DATA_WORKERS,
                            drop_last=True)

    net = generate_net(cfg)
    if cfg.TRAIN_TBLOG:
        from tensorboardX import SummaryWriter
        # Set the Tensorboard logger
        tblogger = SummaryWriter(cfg.LOG_DIR)

    print('Use %d GPU' % cfg.TRAIN_GPUS)
    device = torch.device(0)
    if cfg.TRAIN_GPUS > 1:
        net = nn.DataParallel(net)
        patch_replication_callback(net)
    net.to(device)

    if cfg.TRAIN_CKPT:
        pretrained_dict = torch.load(cfg.TRAIN_CKPT)
        net_dict = net.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if (k in net_dict) and (v.shape == net_dict[k].shape)
        }
        net_dict.update(pretrained_dict)
        net.load_state_dict(net_dict)
        # net.load_state_dict(torch.load(cfg.TRAIN_CKPT), False)

    criterion = nn.CrossEntropyLoss(ignore_index=255)
    optimizer = optim.SGD(
        params=[
            {'params': get_params(net.module, key='1x'), 'lr': cfg.TRAIN_LR},
            {'params': get_params(net.module, key='10x'), 'lr': 10 * cfg.TRAIN_LR},
        ],
        momentum=cfg.TRAIN_MOMENTUM,
        weight_decay=cfg.TRAIN_WEIGHT_DECAY,
    )
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.TRAIN_LR_MST, gamma=cfg.TRAIN_LR_GAMMA, last_epoch=-1)
    itr = cfg.TRAIN_MINEPOCH * len(dataloader)
    max_itr = cfg.TRAIN_EPOCHS * len(dataloader)
    running_loss = 0.0
    tblogger = SummaryWriter(cfg.LOG_DIR)
    for epoch in range(cfg.TRAIN_MINEPOCH, cfg.TRAIN_EPOCHS):
        # scheduler.step()
        # now_lr = scheduler.get_lr()
        for i_batch, sample_batched in enumerate(dataloader):
            now_lr = adjust_lr(optimizer, itr, max_itr)
            inputs_batched, labels_batched = sample_batched['image'], sample_batched['segmentation']
            optimizer.zero_grad()
            labels_batched = labels_batched.long().to(1)
            # foreground_pix = (torch.sum(labels_batched != 0).float() + 1) / (cfg.DATA_RESCALE**2 * cfg.TRAIN_BATCHES)
            predicts_batched = net(inputs_batched)
            predicts_batched = predicts_batched.to(1)
            loss = criterion(predicts_batched, labels_batched)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            print('epoch:%d/%d\tbatch:%d/%d\titr:%d\tlr:%g\tloss:%g ' %
                  (epoch, cfg.TRAIN_EPOCHS, i_batch, dataset.__len__() // cfg.TRAIN_BATCHES,
                   itr + 1, now_lr, running_loss))
            if cfg.TRAIN_TBLOG and itr % 100 == 0:
                # inputs = np.array((inputs_batched[0]*128+128).numpy().transpose((1,2,0)), dtype=np.uint8)
                inputs = inputs_batched.numpy()[0]
                # inputs = inputs_batched.numpy()[0]/2.0 + 0.5
                labels = labels_batched[0].cpu().numpy()
                labels_color = dataset.label2colormap(labels).transpose((2, 0, 1))
                predicts = torch.argmax(predicts_batched[0], dim=0).cpu().numpy()
                predicts_color = dataset.label2colormap(predicts).transpose((2, 0, 1))
                pix_acc = np.sum(labels == predicts) / (cfg.DATA_RESCALE**2)
                tblogger.add_scalar('loss', running_loss, itr)
                tblogger.add_scalar('lr', now_lr, itr)
                tblogger.add_scalar('pixel acc', pix_acc, itr)
                tblogger.add_image('Input', inputs, itr)
                tblogger.add_image('Label', labels_color, itr)
                tblogger.add_image('Output', predicts_color, itr)
            running_loss = 0.0

            if itr % 5000 == 0:
                save_path = os.path.join(cfg.MODEL_SAVE_DIR,
                                         '%s_%s_%s_itr%d.pth' %
                                         (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, itr))
                torch.save(net.state_dict(), save_path)
                print('%s has been saved' % save_path)
            itr += 1

    save_path = os.path.join(cfg.MODEL_SAVE_DIR,
                             '%s_%s_%s_epoch%d_all.pth' %
                             (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, cfg.TRAIN_EPOCHS))
    torch.save(net.state_dict(), save_path)
    if cfg.TRAIN_TBLOG:
        tblogger.close()
    print('%s has been saved' % save_path)
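
# get_params(net.module, key='1x'/'10x') splits parameters so the backbone
# trains at the base LR and the head at 10x, the usual DeepLab recipe (a
# later variant in this file uses keys 'backbone'/'cls'/'others'). It is
# defined elsewhere in the repo; a minimal sketch of the 1x/10x split, under
# the assumption that backbone parameter names start with 'backbone':
def get_params(module, key):
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        is_backbone = name.startswith('backbone')
        # '1x' yields backbone parameters, '10x' yields everything else.
        if (key == '1x') == is_backbone:
            yield param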
def train_net():
    print('start')
    dataset = generate_dataset(cfg.DATA_NAME, cfg, 'train')
    dataloader = DataLoader(dataset,
                            batch_size=cfg.TRAIN_BATCHES,
                            shuffle=cfg.TRAIN_SHUFFLE,
                            num_workers=cfg.DATA_WORKERS,
                            drop_last=True)

    net = generate_net(cfg)
    if cfg.TRAIN_TBLOG:
        from tensorboardX import SummaryWriter
        # Set the Tensorboard logger
        tblogger = SummaryWriter(cfg.LOG_DIR)

    # os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
    # print('zzz', torch.cuda.current_device())
    print('Use %d GPU' % cfg.TRAIN_GPUS)
    # print('zrj', torch.cuda.device_count())
    device = torch.device('cuda:4')
    if cfg.TRAIN_GPUS > 1:
        net = nn.DataParallel(net, device_ids=[4, 7])
        patch_replication_callback(net)
    net.to(device)

    if cfg.TRAIN_CKPT:
        pretrained_dict = torch.load(cfg.TRAIN_CKPT)
        net_dict = net.state_dict()
        # print(net_dict.keys())
        # input()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if (k in net_dict) and (v.shape == net_dict[k].shape)
        }
        net_dict.update(pretrained_dict)
        net.load_state_dict(net_dict)
        # net.load_state_dict(torch.load(cfg.TRAIN_CKPT), False)
    # print('zzz1')

    criterion = nn.CrossEntropyLoss(ignore_index=255)
    hgo_loss = HgoLoss()
    optimizer = optim.SGD(params=[
        {'params': get_params(net.module, key='1x'), 'lr': cfg.TRAIN_LR},
        {'params': get_params(net.module, key='10x'), 'lr': 10 * cfg.TRAIN_LR},
    ], momentum=cfg.TRAIN_MOMENTUM)
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.TRAIN_LR_MST, gamma=cfg.TRAIN_LR_GAMMA, last_epoch=-1)
    itr = cfg.TRAIN_MINEPOCH * len(dataloader)
    max_itr = cfg.TRAIN_EPOCHS * len(dataloader)
    running_loss = 0.0
    tblogger = SummaryWriter(cfg.LOG_DIR)
    # net.eval()
    for epoch in range(cfg.TRAIN_MINEPOCH, cfg.TRAIN_EPOCHS):
        # scheduler.step()
        # now_lr = scheduler.get_lr()
        # running_loss = 0.0
        for i_batch, sample_batched in enumerate(dataloader):
            now_lr = adjust_lr(optimizer, itr, max_itr)
            inputs_img1_batched, labels_img1_batched = sample_batched['image1'], sample_batched['segmentation1']
            inputs_img2_batched, labels_img2_batched = sample_batched['image2'], sample_batched['segmentation2']
            inputs_label_batched = sample_batched['label']
            # print('input', inputs_img1_batched.shape, inputs_img2_batched.shape)
            optimizer.zero_grad()
            labels_img1_batched = labels_img1_batched.long().to(4)
            labels_img2_batched = labels_img2_batched.long().to(4)
            inputs_label_batched = inputs_label_batched.float().to(4)
            # foreground_pix = (torch.sum(labels_batched != 0).float() + 1) / (cfg.DATA_RESCALE**2 * cfg.TRAIN_BATCHES)
            # print('input', inputs_batched.shape)
            predicts_img1_batched, predicts_img2_batched, predicts_label = net(
                inputs_img1_batched, inputs_img2_batched)
            # print('out', predicts_img1_batched.shape, predicts_label.shape)
            predicts_img1_batched = predicts_img1_batched.to(4)
            predicts_img2_batched = predicts_img2_batched.to(4)
            predicts_label = predicts_label.float().to(4)
            loss_img1 = criterion(predicts_img1_batched, labels_img1_batched)
            loss_img2 = criterion(predicts_img2_batched, labels_img2_batched)
            # print('inputs_label_batched', inputs_label_batched.shape)
            loss_2c = hgo_loss(predicts_label, inputs_label_batched)
            # print('zzzhrj', predicts_batched.shape, labels_batched.shape)
            # Weighted sum of the two segmentation losses and the change loss.
            loss_1 = 0.5 * (loss_img1 + loss_img2)
            loss = 0.7 * loss_2c + 0.3 * loss_1
            loss.backward()
            optimizer.step()
            # print('zrj', loss.item(), running_loss)
            running_loss += loss.item()
            # print('zrj', type(loss), running_loss)

            print('epoch:%d/%d\tbatch:%d/%d\titr:%d\tlr:%g\tloss:%g\tloss1:%g\tloss2:%g ' %
                  (epoch, cfg.TRAIN_EPOCHS, i_batch, dataset.__len__() // cfg.TRAIN_BATCHES,
                   itr + 1, now_lr, running_loss, loss_1, loss_2c))
            # if cfg.TRAIN_TBLOG and itr % 100 == 0:
            #     #inputs = np.array((inputs_batched[0]*128+128).numpy().transpose((1,2,0)),dtype=np.uint8)
            #     #inputs = inputs_batched.numpy()[0]
            #     inputs = inputs_batched.numpy()[0]/2.0 + 0.5
            #     labels = labels_batched[0].cpu().numpy()
            #     labels_color = dataset.label2colormap(labels).transpose((2,0,1))
            #     predicts = torch.argmax(predicts_batched[0],dim=0).cpu().numpy()
            #     predicts_color = dataset.label2colormap(predicts).transpose((2,0,1))
            #     pix_acc = np.sum(labels==predicts)/(cfg.DATA_RESCALE**2)
            #     tblogger.add_scalar('loss', running_loss, itr)
            #     tblogger.add_scalar('loss_1', loss_1, itr)
            #     tblogger.add_scalar('loss_2', loss_2c, itr)
            #     tblogger.add_scalar('lr', now_lr, itr)
            #     tblogger.add_scalar('pixel acc', pix_acc, itr)
            #     tblogger.add_image('Input', inputs, itr)
            #     tblogger.add_image('Label', labels_color, itr)
            #     tblogger.add_image('Output', predicts_color, itr)
            running_loss = 0.0

            if itr % 60 == 0:
                save_path = os.path.join(cfg.MODEL_SAVE_DIR,
                                         '%s_%s_%s_itr%d.pth' %
                                         (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, itr))
                torch.save(net.state_dict(), save_path)
                print('%s has been saved' % save_path)
            itr += 1

    save_path = os.path.join(cfg.MODEL_SAVE_DIR,
                             '%s_%s_%s_epoch%d_all.pth' %
                             (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, cfg.TRAIN_EPOCHS))
    torch.save(net.state_dict(), save_path)
    if cfg.TRAIN_TBLOG:
        tblogger.close()
    print('%s has been saved' % save_path)
def train_net():
    # dataset = generate_dataset(cfg.DATA_NAME, cfg, 'train_cross'+cfg.DATA_CROSS, cfg.DATA_AUG)
    dataset = generate_dataset(cfg.DATA_NAME, cfg, 'train', cfg.DATA_AUG)
    dataloader = DataLoader(dataset,
                            batch_size=cfg.TRAIN_BATCHES,
                            shuffle=cfg.TRAIN_SHUFFLE,
                            num_workers=cfg.DATA_WORKERS,
                            drop_last=True)
    dataset_mixup = generate_dataset(cfg.DATA_NAME, cfg, 'train', cfg.DATA_AUG)
    dataloader_mixup = DataLoader(dataset_mixup,
                                  batch_size=cfg.TRAIN_BATCHES,
                                  shuffle=cfg.TRAIN_SHUFFLE,
                                  num_workers=cfg.DATA_WORKERS,
                                  drop_last=True)
    # test_dataset = generate_dataset(cfg.DATA_NAME, cfg, 'test_cross'+cfg.DATA_CROSS)
    test_dataset = generate_dataset(cfg.DATA_NAME, cfg, 'test')
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=cfg.TEST_BATCHES,
                                 shuffle=False,
                                 num_workers=cfg.DATA_WORKERS)

    net = generate_net(cfg)
    print('Use %d GPU' % cfg.TRAIN_GPUS)
    device = torch.device(0)
    if cfg.TRAIN_GPUS > 1:
        net = nn.DataParallel(net)
        patch_replication_callback(net)
    net.to(device)

    if cfg.TRAIN_CKPT:
        pretrained_dict = torch.load(cfg.TRAIN_CKPT)
        net_dict = net.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if (k in net_dict) and (v.shape == net_dict[k].shape)
        }
        net_dict.update(pretrained_dict)
        net.load_state_dict(net_dict)
        # net.load_state_dict(torch.load(cfg.TRAIN_CKPT), False)

    optimizer = optim.SGD(params=[
        {'params': get_params(net.module, key='1x'), 'lr': cfg.TRAIN_LR},
        {'params': get_params(net.module, key='10x'), 'lr': 10 * cfg.TRAIN_LR},
    ], momentum=cfg.TRAIN_MOMENTUM)

    itr = cfg.TRAIN_MINEPOCH * len(dataloader)
    max_itr = cfg.TRAIN_EPOCHS * len(dataloader)
    best_jacc = 0.
    best_epoch = 0
    for epoch in range(cfg.TRAIN_MINEPOCH, cfg.TRAIN_EPOCHS):
        running_loss = 0.0
        seg_jac_running_loss = 0.0
        SPL_loss1 = 0.0
        SPL_loss2 = 0.0
        epoch_samples1 = 0.0
        epoch_samples2 = 0.0
        Lambda1 = 0.0
        Lambda0 = 0.0
        mixup_running_loss = 0.0
        mixup_seg_jac_running_loss = 0.0
        mixup_SPL_loss1 = 0.0
        mixup_SPL_loss2 = 0.0
        mixup_epoch_samples1 = 0.0
        mixup_epoch_samples2 = 0.0
        mixup_Lambda1 = 0.0
        mixup_Lambda0 = 0.0
        dataset_list = []
        net.train()
        #########################################################
        ########### give lambda && decay coefficient ############
        #########################################################
        for i_batch, (sample_batched, mixup_batched) in enumerate(zip(dataloader, dataloader_mixup)):
            now_lr = adjust_lr(optimizer, itr, max_itr)
            inputs_batched, labels_batched = sample_batched['image'], sample_batched['segmentation']
            mixup_inputs_batched, mixup_labels_batched = mixup_batched['image'], mixup_batched['segmentation']
            labels_batched = labels_batched.long().to(1)
            mixup_labels_batched = mixup_labels_batched.long().to(1)
            alpha = cfg.Mixup_Alpha
            random_lambda = np.random.beta(alpha, alpha)
            mixuped_input = input_mixup(inputs_batched, mixup_inputs_batched, random_lambda)
            mixuped_label = label_mixup(labels_batched, mixup_labels_batched, random_lambda)
            # First the plain batch, then the mixed batch.
            CE, JA = train_one_epoch(inputs_batched, labels_batched, net, optimizer, epoch)
            running_loss += CE.item()
            seg_jac_running_loss += JA.item()
            itr += 1
            CE, JA = train_one_epoch(mixuped_input, mixuped_label, net, optimizer, epoch)
            mixup_running_loss += CE.item()
            mixup_seg_jac_running_loss += JA.item()
        i_batch = i_batch + 1
        print('\nepoch:%d/%d\tCE loss:%g\tJA loss:%g' %
              (epoch, cfg.TRAIN_EPOCHS, running_loss / i_batch, seg_jac_running_loss / i_batch))
        print('mixup: \tCE loss:%g\tJA loss:%g' %
              (mixup_running_loss / i_batch, mixup_seg_jac_running_loss / i_batch))
        # if (epoch) % 50 == 0:
        #     dataset.save_result_train_weight(dataset_list, cfg.MODEL_NAME)

        #### start testing now
        Acc_array = 0.
        Prec_array = 0.
        Spe_array = 0.
        Rec_array = 0.
        IoU_array = 0.
        Dice_array = 0.
        HD_array = 0.
        sample_num = 0.
        result_list = []
        net.eval()
        with torch.no_grad():
            if epoch % 4 == 0:
                for i_batch, sample_batched in enumerate(test_dataloader):
                    name_batched = sample_batched['name']
                    row_batched = sample_batched['row']
                    col_batched = sample_batched['col']

                    [batch, channel, height, width] = sample_batched['image'].size()
                    multi_avg = torch.zeros((batch, cfg.MODEL_NUM_CLASSES, height, width),
                                            dtype=torch.float32).to(1)
                    for rate in cfg.TEST_MULTISCALE:
                        inputs_batched = sample_batched['image_%f' % rate]
                        _, predicts = net(inputs_batched)
                        predicts = predicts.to(1)
                        predicts_batched = predicts.clone()
                        del predicts
                        predicts_batched = F.interpolate(predicts_batched, size=None,
                                                         scale_factor=1 / rate,
                                                         mode='bilinear', align_corners=True)
                        multi_avg = multi_avg + predicts_batched
                        del predicts_batched
                    multi_avg = multi_avg / len(cfg.TEST_MULTISCALE)
                    result = torch.argmax(multi_avg, dim=1).cpu().numpy().astype(np.uint8)
                    labels_batched = sample_batched['segmentation'].cpu().numpy()

                    for i in range(batch):
                        row = row_batched[i]
                        col = col_batched[i]
                        p = result[i, :, :]
                        l = labels_batched[i, :, :]
                        # p = cv2.resize(p, dsize=(col,row), interpolation=cv2.INTER_NEAREST)
                        # l = cv2.resize(l, dsize=(col,row), interpolation=cv2.INTER_NEAREST)
                        predict = np.int32(p)
                        gt = np.int32(l)
                        cal = gt < 255
                        mask = (predict == gt) * cal

                        # Binary counts: P/T are predicted/ground-truth positives,
                        # TP/TN true positives/negatives. (The original also declared
                        # per-class zero arrays here that were immediately overwritten.)
                        P = np.sum((predict == 1)).astype(np.float64)
                        T = np.sum((gt == 1)).astype(np.float64)
                        TP = np.sum((gt == 1) * (predict == 1)).astype(np.float64)
                        TN = np.sum((gt == 0) * (predict == 0)).astype(np.float64)

                        Acc = (TP + TN) / (T + P - TP + TN)
                        Prec = TP / (P + 10e-6)
                        Spe = TN / (P - TP + TN)
                        Rec = TP / T
                        DICE = 2 * TP / (T + P)
                        IoU = TP / (T + P - TP)
                        HD = max(directed_hausdorff(predict, gt)[0],
                                 directed_hausdorff(gt, predict)[0])  # the original passed (predict, gt) twice
                        # HD is overwritten just below: the 'HD' slot actually
                        # carries the F2 score, as the print statement labels it.
                        beta = 2
                        HD = Rec * Prec * (1 + beta**2) / (Rec + beta**2 * Prec + 1e-10)
                        Acc_array += Acc
                        Prec_array += Prec
                        Spe_array += Spe
                        Rec_array += Rec
                        Dice_array += DICE
                        IoU_array += IoU
                        HD_array += HD
                        sample_num += 1
                        # p = cv2.resize(p, dsize=(col,row), interpolation=cv2.INTER_NEAREST)
                        result_list.append({
                            'predict': np.uint8(p * 255),
                            'label': np.uint8(l * 255),
                            'IoU': IoU,
                            'name': name_batched[i]
                        })

                Acc_score = Acc_array * 100 / sample_num
                Prec_score = Prec_array * 100 / sample_num
                Spe_score = Spe_array * 100 / sample_num
                Rec_score = Rec_array * 100 / sample_num
                Dice_score = Dice_array * 100 / sample_num
                IoUP = IoU_array * 100 / sample_num
                HD_score = HD_array * 100 / sample_num
                print('%10s:%7.3f%% %10s:%7.3f%% %10s:%7.3f%% %10s:%7.3f%% %10s:%7.3f%% %10s:%7.3f%% %10s:%7.3f%%\n' %
                      ('Acc', Acc_score, 'Sen', Rec_score, 'Spe', Spe_score, 'Prec', Prec_score,
                       'Dice', Dice_score, 'Jac', IoUP, 'F2', HD_score))
                if Dice_score > best_jacc:
                    model_snapshot(net.state_dict(),
                                   new_file=os.path.join(cfg.MODEL_SAVE_DIR,
                                                         'model-best-%s_%s_%s_epoch%d_dice%.3f.pth' %
                                                         (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, epoch, Dice_score)),
                                   old_file=os.path.join(cfg.MODEL_SAVE_DIR,
                                                         'model-best-%s_%s_%s_epoch%d_dice%.3f.pth' %
                                                         (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, best_epoch, best_jacc)))
                    best_jacc = Dice_score
                    best_epoch = epoch
                if epoch % 50 == 0:
                    dataset.save_result_train(result_list, cfg.MODEL_NAME)
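
# input_mixup is the standard mixup convex combination of two tensors. A
# minimal sketch; label_mixup is deliberately not sketched, since it takes
# extra confidence-map arguments in the earlier variant and its exact
# behavior cannot be inferred from this file:
def input_mixup(x1, x2, lam):
    # lam is drawn from Beta(alpha, alpha); returns the element-wise blend.
    return lam * x1 + (1 - lam) * x2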
def train_net():
    period = 'train'
    transform = 'weak'
    dataset = generate_dataset(cfg, period=period, transform=transform)

    def worker_init_fn(worker_id):
        # NOTE: a fixed seed per worker makes the augmentation stream identical
        # every epoch; derive the seed from the epoch or torch.initial_seed()
        # if that is not intended.
        np.random.seed(1 + worker_id)

    dataloader = DataLoader(dataset,
                            batch_size=cfg.TRAIN_BATCHES,
                            shuffle=cfg.TRAIN_SHUFFLE,
                            num_workers=cfg.DATA_WORKERS,
                            pin_memory=True,
                            drop_last=True,
                            worker_init_fn=worker_init_fn)
    if cfg.GPUS > 1:
        net = generate_net(cfg,
                           batchnorm=SynchronizedBatchNorm2d,
                           dilated=cfg.MODEL_BACKBONE_DILATED,
                           multi_grid=cfg.MODEL_BACKBONE_MULTIGRID,
                           deep_base=cfg.MODEL_BACKBONE_DEEPBASE)
    else:
        net = generate_net(cfg,
                           batchnorm=nn.BatchNorm2d,
                           dilated=cfg.MODEL_BACKBONE_DILATED,
                           multi_grid=cfg.MODEL_BACKBONE_MULTIGRID,
                           deep_base=cfg.MODEL_BACKBONE_DEEPBASE)
    if cfg.TRAIN_CKPT:
        net.load_state_dict(torch.load(cfg.TRAIN_CKPT), strict=True)
        print('load pretrained model')
    if cfg.TRAIN_TBLOG:
        from tensorboardX import SummaryWriter
        # Set the Tensorboard logger
        tblogger = SummaryWriter(cfg.LOG_DIR)

    print('Use %d GPU' % cfg.GPUS)
    device = torch.device(0)
    if cfg.GPUS > 1:
        net = nn.DataParallel(net)
        patch_replication_callback(net)
        parameter_source = net.module
    else:
        parameter_source = net
    net.to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=255)
    optimizer = optim.SGD(
        params=[
            {'params': get_params(parameter_source, key='backbone'), 'lr': cfg.TRAIN_LR},
            {'params': get_params(parameter_source, key='cls'), 'lr': 10 * cfg.TRAIN_LR},
            {'params': get_params(parameter_source, key='others'), 'lr': cfg.TRAIN_LR},
        ],
        momentum=cfg.TRAIN_MOMENTUM,
        weight_decay=cfg.TRAIN_WEIGHT_DECAY
    )
    itr = cfg.TRAIN_MINEPOCH * len(dataset) // cfg.TRAIN_BATCHES
    max_itr = cfg.TRAIN_ITERATION
    max_epoch = max_itr * cfg.TRAIN_BATCHES // len(dataset) + 1
    tblogger = SummaryWriter(cfg.LOG_DIR)
    # (the original re-created the identical criterion here; the duplicate is dropped)
    scaler = torch.cuda.amp.GradScaler()

    with tqdm(total=max_itr) as pbar:
        for epoch in range(cfg.TRAIN_MINEPOCH, max_epoch):
            for i_batch, sample in enumerate(dataloader):
                now_lr = adjust_lr(optimizer, itr, max_itr, cfg.TRAIN_LR, cfg.TRAIN_POWER)
                optimizer.zero_grad()
                inputs, seg_label = sample['image'], sample['segmentation']
                n, c, h, w = inputs.size()

                # Mixed-precision forward/backward with loss scaling.
                with torch.cuda.amp.autocast():
                    pred1 = net(inputs.to(0))
                    loss = criterion(pred1, seg_label.to(0))
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                pbar.set_description("loss=%g " % (loss.item()))
                pbar.update(1)
                time.sleep(0.001)
                # print('epoch:%d/%d\tbatch:%d/%d\titr:%d\tlr:%g\tloss:%g' %
                #       (epoch, max_epoch, i_batch, len(dataset)//(cfg.TRAIN_BATCHES),
                #        itr+1, now_lr, loss.item()))

                if cfg.TRAIN_TBLOG and itr % 100 == 0:
                    inputs1 = img_denorm(inputs[-1].cpu().numpy()).astype(np.uint8)
                    label1 = sample['segmentation'][-1].cpu().numpy()
                    label_color1 = dataset.label2colormap(label1).transpose((2, 0, 1))
                    n, c, h, w = inputs.size()
                    seg_vis1 = torch.argmax(pred1[-1], dim=0).detach().cpu().numpy()
                    seg_color1 = dataset.label2colormap(seg_vis1).transpose((2, 0, 1))
                    tblogger.add_scalar('loss', loss.item(), itr)
                    tblogger.add_scalar('lr', now_lr, itr)
                    tblogger.add_image('Input', inputs1, itr)
                    tblogger.add_image('Label', label_color1, itr)
                    tblogger.add_image('SEG1', seg_color1, itr)
                itr += 1
                if itr >= max_itr:
                    break
            save_path = os.path.join(cfg.MODEL_SAVE_DIR,
                                     '%s_%s_%s_epoch%d.pth' %
                                     (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, epoch))
            torch.save(parameter_source.state_dict(), save_path)
            print('%s has been saved' % save_path)
            # Keep only the latest epoch checkpoint.
            remove_path = os.path.join(cfg.MODEL_SAVE_DIR,
                                       '%s_%s_%s_epoch%d.pth' %
                                       (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, epoch - 1))
            if os.path.exists(remove_path):
                os.remove(remove_path)

    save_path = os.path.join(cfg.MODEL_SAVE_DIR,
                             '%s_%s_%s_itr%d_all.pth' %
                             (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, cfg.TRAIN_ITERATION))
    torch.save(parameter_source.state_dict(), save_path)
    if cfg.TRAIN_TBLOG:
        tblogger.close()
    print('%s has been saved' % save_path)
    writelog(cfg, period)
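
# img_denorm above undoes the dataset normalization for TensorBoard display.
# It is defined elsewhere in the repo; a minimal sketch, assuming the usual
# ImageNet mean/std normalization on a CHW float array (the exact statistics
# are our assumption):
def img_denorm(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    # img: (C, H, W) normalized float array -> values back in [0, 255]
    img = img * np.array(std)[:, None, None] + np.array(mean)[:, None, None]
    return np.clip(img * 255, 0, 255)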
def test_net():
    dataset = generate_dataset(cfg.DATA_NAME, cfg, 'val')
    dataloader = DataLoader(dataset,
                            batch_size=cfg.TEST_BATCHES,
                            shuffle=False,
                            num_workers=cfg.DATA_WORKERS)

    net = generate_net(cfg)
    print('net initialize')
    if cfg.TEST_CKPT is None:
        raise ValueError('test.py: cfg.MODEL_CKPT can not be empty in test period')
    print('Use %d GPU' % cfg.TEST_GPUS)
    device = torch.device('cuda:6')
    if cfg.TEST_GPUS > 1:
        net = nn.DataParallel(net, device_ids=[6, 7])
        patch_replication_callback(net)
    net.to(device)

    print('start loading model %s' % cfg.TEST_CKPT)
    model_dict = torch.load(cfg.TEST_CKPT, map_location=device)
    net.load_state_dict(model_dict)

    net.eval()
    result_list = []
    changed_right = 0
    changed_wrong = 0
    unchanged_right = 0
    unchanged_wrong = 0
    num = 0
    c_num = 0
    with torch.no_grad():
        for i_batch, sample_batched in enumerate(dataloader):
            # name_batched = sample_batched['name']
            # row_batched = sample_batched['row']
            # col_batched = sample_batched['col']
            inputs_img1_batched = sample_batched['image1']
            inputs_img2_batched = sample_batched['image2']
            gt_label = sample_batched['label']
            # print('debug', len(gt_label))

            [batch, channel, height, width] = sample_batched['image1'].size()
            multi_avg = torch.zeros((batch, cfg.MODEL_NUM_CLASSES, height, width),
                                    dtype=torch.float32).to(6)
            # for rate in cfg.TEST_MULTISCALE:
            #     inputs_img1_batched = sample_batched['image1_%f'%rate]
            #     inputs_img2_batched = sample_batched['image2_%f'%rate]
            # print(inputs_img1_batched.shape)
            pre1, pre2, label = net(inputs_img1_batched, inputs_img2_batched)
            gt_label = gt_label.unsqueeze(1)
            # A negative score means 'changed'; tally the four outcomes.
            for i in range(len(label)):
                print(gt_label[i], label[i])
                if gt_label[i].item() < 0:
                    c_num += 1
                    if label[i].item() < 0:
                        changed_right += 1
                    else:
                        changed_wrong += 1
                else:
                    if label[i].item() > 0:
                        unchanged_right += 1
                    else:
                        unchanged_wrong += 1
            # NOTE: raises ZeroDivisionError if no changed pair has been seen yet (c_num == 0).
            print('call back', changed_right, unchanged_right, changed_right / c_num)
            # if gt_label[i].item()
            # print('predicts', label, gt_label)
            # predicts_batched = predicts.clone()
            # del predicts
            # if cfg.TEST_FLIP:
            #     inputs_batched_flip = torch.flip(inputs_batched,[3])
            #     predicts_flip = torch.flip(net(inputs_batched_flip),[3]).to(5)
            #     predicts_batched_flip = predicts_flip.clone()
            #     del predicts_flip
            #     predicts_batched = (predicts_batched + predicts_batched_flip) / 2.0
            # print('predicts_batched',predicts_batched.shape)
            # predicts_batched = F.interpolate(predicts_batched, size=None, scale_factor=1/rate, mode='bilinear', align_corners=True)
            # print('predicts_batched',predicts_batched.shape)
            # multi_avg = multi_avg + predicts_batched
            # del predicts_batched
            # print('multi_avg',multi_avg.shape)
            # print('t1 multi_avg',multi_avg.shape)
            # multi_avg = multi_avg / len(cfg.TEST_MULTISCALE)
            # print('tt multi_avg',multi_avg.shape)
            # result = torch.argmax(multi_avg, dim=1).cpu().numpy().astype(np.uint8)
            # print('result',result.shape)
            # for i in range(batch):
            #     row = row_batched[i]
            #     col = col_batched[i]
            #     # max_edge = max(row,col)
            #     # rate = cfg.DATA_RESCALE / max_edge
            #     # new_row = row*rate
            #     # new_col = col*rate
            #     # s_row = (cfg.DATA_RESCALE-new_row)//2
            #     # s_col = (cfg.DATA_RESCALE-new_col)//2
            #     # p = predicts_batched[i, s_row:s_row+new_row, s_col:s_col+new_col]
            #     p = result[i,:,:]
            #     p = cv2.resize(p, dsize=(col,row), interpolation=cv2.INTER_NEAREST)
            #     result_list.append({'predict':p, 'name':name_batched[i]})
            # print('%d/%d'%(i_batch,len(dataloader)))
    # dataset.save_result(result_list, cfg.MODEL_NAME)
    # dataset.do_python_eval(cfg.MODEL_NAME)
    print('Test finished')
def train_net():
    # dataset = generate_dataset(cfg.DATA_NAME, cfg, 'train_cross'+cfg.DATA_CROSS, cfg.DATA_AUG)
    dataset = generate_dataset(cfg.DATA_NAME, cfg, 'train' + cfg.DATA_CROSS, cfg.DATA_AUG)
    dataloader = DataLoader(dataset,
                            batch_size=cfg.TRAIN_BATCHES,
                            shuffle=cfg.TRAIN_SHUFFLE,
                            num_workers=cfg.DATA_WORKERS,
                            drop_last=True)
    # test_dataset = generate_dataset(cfg.DATA_NAME, cfg, 'test_cross'+cfg.DATA_CROSS)
    test_dataset = generate_dataset(cfg.DATA_NAME, cfg, 'test' + cfg.DATA_CROSS)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=cfg.TEST_BATCHES,
                                 shuffle=False,
                                 num_workers=cfg.DATA_WORKERS)

    net = generate_net(cfg)
    # if cfg.TRAIN_TBLOG:
    #     from tensorboardX import SummaryWriter
    #     # Set the Tensorboard logger
    #     tblogger = SummaryWriter(cfg.LOG_DIR)
    print('Use %d GPU' % cfg.TRAIN_GPUS)
    device = torch.device(0)
    if cfg.TRAIN_GPUS > 1:
        net = nn.DataParallel(net)
        patch_replication_callback(net)
    net.to(device)

    if cfg.TRAIN_CKPT:
        pretrained_dict = torch.load(cfg.TRAIN_CKPT)
        net_dict = net.state_dict()
        # for i, p in enumerate(net_dict):
        #     print(i, p)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if (k in net_dict) and (v.shape == net_dict[k].shape)
        }
        net_dict.update(pretrained_dict)
        net.load_state_dict(net_dict)
        # net.load_state_dict(torch.load(cfg.TRAIN_CKPT), False)

    # for i, para in enumerate(net.named_parameters()):
    #     (name, param) = para
    #     print(i, name)
    # Same index-based split as the mixup variant above: 38-47 threshold
    # branch, below 38 segmentation head, remainder backbone.
    threshold_dict = []
    segment_dict = []
    backbone_dict = []
    for i, para in enumerate(net.parameters()):
        if 38 <= i <= 47:
            threshold_dict.append(para)
        elif i < 38:
            segment_dict.append(para)
        else:
            backbone_dict.append(para)
        # print(i)
    thr_optimizer = optim.SGD(threshold_dict, lr=10 * cfg.TRAIN_LR, momentum=cfg.TRAIN_MOMENTUM)
    seg_optimizer = optim.SGD(params=[
        {'params': backbone_dict, 'lr': cfg.TRAIN_LR},
        {'params': segment_dict, 'lr': 10 * cfg.TRAIN_LR},
    ], momentum=cfg.TRAIN_MOMENTUM)
    '''optimizer = optim.SGD(
        params = [
            {'params': get_params(net.module, key='1x'), 'lr': cfg.TRAIN_LR},
            {'params': get_params(net.module, key='10x'), 'lr': 10*cfg.TRAIN_LR}
        ],
        momentum=cfg.TRAIN_MOMENTUM
    )'''
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.TRAIN_LR_MST, gamma=cfg.TRAIN_LR_GAMMA, last_epoch=-1)
    itr = cfg.TRAIN_MINEPOCH * len(dataloader)
    max_itr = cfg.TRAIN_EPOCHS_lr * len(dataloader)
    # tblogger = SummaryWriter(cfg.LOG_DIR)
    # net.train()
    best_jacc = 0.
    best_epoch = 0
    for epoch in range(cfg.TRAIN_MINEPOCH, cfg.TRAIN_EPOCHS):
        running_loss = 0.0
        seg_jac_running_loss = 0.0
        dice_running_loss = 0.0
        grad_running_loss = 0.0
        average_running_loss = 0.0
        mixup_running_loss = 0.0
        mixup_seg_jac_running_loss = 0.0
        mixup_dice_running_loss = 0.0
        mixup_grad_running_loss = 0.0
        mixup_average_running_loss = 0.0
        dataset_list = []
        net.train()
        # scheduler.step()
        # now_lr = scheduler.get_lr()
        for i_batch, sample_batched in enumerate(dataloader):
            now_lr = adjust_lr(seg_optimizer, itr, max_itr)
            name_batched = sample_batched['name']
            inputs_batched1, labels_batched_cpu1 = sample_batched['image'], sample_batched['segmentation']
            labels_batched1 = labels_batched_cpu1.long().to(1)
            loss, seg_jac_loss, dice_loss, grad_loss, sup_loss = train_one_batch(
                inputs_batched1, labels_batched1, net, seg_optimizer, thr_optimizer)
            running_loss += loss.item()
            seg_jac_running_loss += seg_jac_loss.item()
            dice_running_loss += dice_loss.item()
            grad_running_loss += grad_loss.item()
            average_running_loss += sup_loss.item()
            itr += 1  # missing in the original, so adjust_lr never advanced; added for consistency with the other variants
        i_batch = i_batch + 1
        print('epoch:%d/%d\tSegCE loss:%g \tSegJaccard loss:%g \tThrJaccard loss:%g \tThrGrad loss:%g \tThrSup loss:%g' %
              (epoch, cfg.TRAIN_EPOCHS, running_loss / i_batch, seg_jac_running_loss / i_batch,
               dice_running_loss / i_batch, grad_running_loss / i_batch, average_running_loss / i_batch))

        #### start testing now
        if (epoch) % 2 == 0:
            Dice_score, IoUP = test_one_epoch(test_dataset, test_dataloader, net, epoch)
            if Dice_score > best_jacc:
                model_snapshot(net.state_dict(),
                               new_file=os.path.join(cfg.MODEL_SAVE_DIR,
                                                     'model-best-%s_%s_%s_epoch%d_dice%.3f.pth' %
                                                     (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, epoch, Dice_score)),
                               old_file=os.path.join(cfg.MODEL_SAVE_DIR,
                                                     'model-best-%s_%s_%s_epoch%d_dice%.3f.pth' %
                                                     (cfg.MODEL_NAME, cfg.MODEL_BACKBONE, cfg.DATA_NAME, best_epoch, best_jacc)))
                best_jacc = Dice_score
                best_epoch = epoch
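
# model_snapshot keeps a single rolling best checkpoint: save the new best
# file and delete the previous one. It is defined elsewhere in the repo; a
# minimal sketch consistent with how it is called above:
def model_snapshot(state_dict, new_file, old_file=None):
    if old_file is not None and os.path.exists(old_file):
        os.remove(old_file)  # drop the previous best checkpoint
    torch.save(state_dict, new_file)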