def check_class(line):
    """Check whether the given class name exists in models_dict."""
    if line == "":
        print("** class name missing **")
        return False
    line = line.split()
    if line[0] not in models_dict.keys():
        print("** class doesn't exist **")
        return False
    return True
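
# Usage sketch (hypothetical, not part of the original script): assuming
# models_dict maps model names such as 'unet' to their constructors,
# check_class('unet') would return True, while check_class('') prints
# "** class name missing **" and returns False.
#
# if check_class('unet'):
#     model_cls = models_dict['unet']
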
def main(args):
    ############ init config ################
    model_name = args.model
    assert model_name in models_dict.keys(), \
        "The requested model '{}' does not exist!".format(model_name)
    print('Using model: {}'.format(model_name))

    #################### init logger ###################################
    log_dir = './logs/' + args.model + '_' + args.note + '/{}'.format(
        time.strftime('%Y%m%d-%H%M%S'))
    logger = get_logger(log_dir)
    print('RUNDIR: {}'.format(log_dir))
    logger.info('{}-Train'.format(args.model))

    # settings
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    args.save_path = log_dir
    args.save_tbx_log = args.save_path + '/tbx_log'
    writer = SummaryWriter(args.save_tbx_log)

    ##################### init device #################################
    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.manualSeed)
        cudnn.benchmark = True

    ####################### init dataset ##############################
    train_loader = get_dataloder(args, split_flag="train")
    val_loader = get_dataloder(args, split_flag="valid")

    ######################## init model ###############################
    logger.info("Model Dict has keys: \n {}".format(models_dict.keys()))
    model = get_models(args)
    if torch.cuda.device_count() > 1 and args.use_cuda:
        logger.info('use: %d gpus', torch.cuda.device_count())
        model = nn.DataParallel(model)
    logger.info('param size = %fMB', calc_parameters_count(model))

    # init loss
    if args.loss == 'bce':
        criterion = nn.BCELoss()
    elif args.loss == 'bcelog':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "dice":
        criterion = DiceLoss()
    elif args.loss == "softdice":
        criterion = SoftDiceLoss()
    elif args.loss == 'bcedice':
        criterion = BCEDiceLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        logger.info("load model and criterion to gpu !")
        model = model.to(args.device)
        criterion = criterion.to(args.device)

    # init optimizer
    if args.model_optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                    weight_decay=args.weight_decay,
                                    momentum=args.momentum)
    else:
        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                     [args.beta1, args.beta2],
                                     weight_decay=args.weight_decay)

    # init scheduler (CosineAnnealingLR; StepLR alternative kept commented out)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epoch)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=30, gamma=0.1, last_epoch=-1)

    ############################### check resume ######################
    start_epoch = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=args.device)
            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            model.load_state_dict(checkpoint['state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(args.resume))

    #################################### train and val ################
    max_value = 0
    for epoch in range(start_epoch, args.epoch):
        # lr = adjust_learning_rate(args, optimizer, epoch)
        scheduler.step()
        logger.info('Epoch: %d lr %e', epoch, scheduler.get_lr()[0])

        # train
        mr, ms, mp, mf, mjc, md, macc, mean_loss = train(args, model, criterion,
                                                         train_loader, optimizer,
                                                         epoch, logger)
        # write train metrics
        writer.add_scalar('Train/Loss', mean_loss, epoch)
        writer.add_scalar('Train/mAcc', macc, epoch)
        writer.add_scalar('Train/Recall', mr, epoch)
        writer.add_scalar('Train/Specifi', ms, epoch)
        writer.add_scalar('Train/Precision', mp, epoch)
        writer.add_scalar('Train/F1', mf, epoch)
        writer.add_scalar('Train/Jc', mjc, epoch)
        writer.add_scalar('Train/Dice', md, epoch)

        # val
        vmr, vms, vmp, vmf, vmjc, vmd, vmacc, vmean_loss = val(args, model, criterion,
                                                               val_loader, epoch, logger)
        writer.add_scalar('Val/Loss', vmean_loss, epoch)
        writer.add_scalar('Val/mAcc', vmacc, epoch)
        writer.add_scalar('Val/Recall', vmr, epoch)
        writer.add_scalar('Val/Specifi', vms, epoch)
        writer.add_scalar('Val/Precision', vmp, epoch)
        writer.add_scalar('Val/F1', vmf, epoch)
        writer.add_scalar('Val/Jc', vmjc, epoch)
        writer.add_scalar('Val/Dice', vmd, epoch)

        # track the best validation Jaccard and save checkpoints
        is_best = vmjc >= max_value
        max_value = max(max_value, vmjc)
        state = {
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'state_dict': model.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        logger.info("epoch:{} best:{} max_value:{}".format(epoch, is_best, max_value))
        torch.save(state, os.path.join(args.save_path, "checkpoint.pth.tar"))
        if is_best:
            torch.save(state, os.path.join(args.save_path, "model_best.pth.tar"))
    writer.close()
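
# NOTE: `train` and `val` are defined elsewhere in the project; `main` only
# assumes that each runs one epoch and returns the tuple
# (recall, specificity, precision, f1, jaccard, dice, accuracy, mean_loss).
# The helper below is a hypothetical, minimal sketch of that contract for the
# binary ('bcelog') setting, not the project's actual implementation.
def _train_one_epoch_sketch(args, model, criterion, loader, optimizer, epoch, logger):
    model.train()
    losses, stats = [], []
    for images, targets in loader:
        images, targets = images.to(args.device), targets.to(args.device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        # binarize predictions and accumulate confusion-matrix based metrics
        preds = (torch.sigmoid(logits) > 0.5).float()
        tp = (preds * targets).sum().item()
        fp = (preds * (1 - targets)).sum().item()
        fn = ((1 - preds) * targets).sum().item()
        tn = ((1 - preds) * (1 - targets)).sum().item()
        eps = 1e-7
        recall = tp / (tp + fn + eps)
        specificity = tn / (tn + fp + eps)
        precision = tp / (tp + fp + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)
        jaccard = tp / (tp + fp + fn + eps)
        dice = 2 * tp / (2 * tp + fp + fn + eps)
        acc = (tp + tn) / (tp + tn + fp + fn + eps)
        stats.append([recall, specificity, precision, f1, jaccard, dice, acc])
    mr, ms, mp, mf, mjc, md, macc = np.mean(stats, axis=0)
    logger.info('train epoch %d mean loss %f', epoch, np.mean(losses))
    return mr, ms, mp, mf, mjc, md, macc, np.mean(losses)
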
def adjust_learning_rate(args, optimizer, epoch):
    """Decay the initial LR by the given gammas at the scheduled epochs."""
    lr = args.lr
    assert len(args.gammas) == len(args.schedule), "length of gammas and schedule should be equal"
    for (gamma, step) in zip(args.gammas, args.schedule):
        if epoch >= step:
            lr = lr * gamma
        else:
            break
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


if __name__ == '__main__':
    models_name = models_dict.keys()
    datasets_name = datasets_dict.keys()
    parser = argparse.ArgumentParser(description='UNet series baseline')
    # Add default arguments
    parser.add_argument('--model', type=str, default='unet', choices=models_name,
                        help='model to train and evaluate')
    parser.add_argument('--note', type=str, default='_',
                        help='model note')
    parser.add_argument('--dataset', type=str, default='cvc', choices=datasets_name,
                        help='dataset to train and evaluate on')
    parser.add_argument('--base_size', type=int, default=256, help="resize base size")
    parser.add_argument('--crop_size', type=int, default=256, help="crop size")
    parser.add_argument('--im_channel', type=int, default=3, help="input image channels")
    parser.add_argument('--class_num', type=int, default=1, help="number of output channels")
    parser.add_argument('--epoch', type=int, default=1600, help="number of training epochs")
    parser.add_argument('--train_batch', type=int, default=8, help="training batch size")