def train_model(args):
    """
    Args:
        args: global arguments
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("=====> input size:{}".format(input_size))
    print(args)

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or wrong gpu id, please run without --cuda")

    # set the seed
    setup_seed(GLOBAL_SEED)
    print("=====> set Global Seed: ", GLOBAL_SEED)

    cudnn.enabled = True
    print("=====> building network")

    # build the model and initialize its weights
    model = build_model(args.model, num_classes=args.classes)
    init_weight(model, nn.init.kaiming_normal_, nn.BatchNorm2d, 1e-3, 0.1,
                mode='fan_in')

    print("=====> computing network parameters and FLOPs")
    total_parameters = netParams(model)
    print("the number of parameters: %d ==> %.2f M" % (total_parameters, (total_parameters / 1e6)))

    # load data and data augmentation (random scale/mirror disabled here)
    datas, trainLoader, valLoader = build_dataset_train(
        args.dataset, args.classes, input_size, args.batch_size,
        args.train_type, False, False, args.num_workers)

    args.per_iter = len(trainLoader)
    args.max_iter = args.max_epochs * args.per_iter

    print('=====> Dataset statistics')
    print("data['classWeights']: ", datas['classWeights'])
    print('mean and std: ', datas['mean'], datas['std'])
    # datas['classWeights'] = np.array([4.044603, 2.0614128, 4.2246304, 6.0238333,
    #                                   10.107266, 8.601249, 8.808282], dtype=np.float32)
    # datas['mean'] = [0.5, 0.5, 0.5]
    # datas['std'] = [0.2, 0.2, 0.2]

    # define the loss function
    weight = torch.from_numpy(datas['classWeights'])
    if args.dataset == 'pollen':
        weight = torch.tensor([1., 1.])

    # NOTE: the flag-gated branches must be checked before the plain dataset
    # branches, otherwise they can never be reached
    if args.dataset == 'camvid' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'camvid':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'cityscapes' and args.use_ohem:
        # keep at least 1/16 of the pixels of a per-GPU batch for OHEM
        min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True, ignore_label=args.ignore_label,
                                          thresh=0.7, min_kept=min_kept)
    elif args.dataset == 'cityscapes' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'cityscapes' and args.use_lovaszsoftmax:
        criteria = LovaszSoftmax(ignore_index=args.ignore_label)
    elif args.dataset == 'cityscapes' and args.use_focal:
        criteria = FocalLoss2d(weight=weight, ignore_index=args.ignore_label)
    elif args.dataset == 'seed':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'remote' and args.use_ohem:
        min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True, ignore_label=args.ignore_label,
                                          thresh=0.7, min_kept=min_kept)
    elif args.dataset == 'remote' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'remote' and args.use_lovaszsoftmax:
        criteria = LovaszSoftmax(ignore_index=args.ignore_label)
    elif args.dataset == 'remote' and args.use_focal:
        criteria = FocalLoss2d(weight=weight, ignore_index=args.ignore_label)
    elif args.dataset == 'remote':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=args.ignore_label)
    else:
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=args.ignore_label)

    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = nn.DataParallel(model).cuda()  # multi-card data parallel
        else:
            args.gpu_nums = 1
            print("single GPU for training")
            model = model.cuda()  # 1-card data parallel

    args.savedir = (args.savedir + args.dataset + '/' + args.model + 'bs'
                    + str(args.batch_size) + 'gpu' + str(args.gpu_nums)
                    + "_" + str(args.train_type) + '/')
    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    start_epoch = 0
    # resume from a checkpoint to continue training
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            # model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("=====> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True
    # cudnn.deterministic = True  ## my add

    # open the log file; write the header only when creating a new file
    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s Seed: %s" % (str(total_parameters), GLOBAL_SEED))
        logger.write("\n%s\t\t%s\t%s\t%s" % ('Epoch', 'Loss(Tr)', 'mIOU (val)', 'lr'))
    logger.flush()

    # define the optimization strategy
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr, momentum=0.9, weight_decay=1e-4)
    elif args.optim == 'adam':
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-4)
    elif args.optim == 'radam':
        optimizer = RAdam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr, betas=(0.90, 0.999), eps=1e-08, weight_decay=1e-4)
    elif args.optim == 'ranger':
        optimizer = Ranger(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr, betas=(0.95, 0.999), eps=1e-08, weight_decay=1e-4)
    elif args.optim == 'adamw':
        optimizer = AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-4)

    lossTr_list = []
    epoches = []
    mIOU_val_list = []

    print('=====> beginning training')
    for epoch in range(start_epoch, args.max_epochs):
        # training
        lossTr, lr = train(args, trainLoader, model, criteria, optimizer, epoch)
        lossTr_list.append(lossTr)

        # validate every 2 epochs and on the last epoch
        if epoch % 2 == 0 or epoch == (args.max_epochs - 1):
            epoches.append(epoch)
            mIOU_val, per_class_iu = val(args, valLoader, model)
            mIOU_val_list.append(mIOU_val)
            # record train information
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.7f" % (epoch, lossTr, mIOU_val, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print("Epoch No.: %d\tTrain Loss = %.4f\t mIOU(val) = %.4f\t lr= %.6f\n"
                  % (epoch, lossTr, mIOU_val, lr))
        else:
            # record train information
            logger.write("\n%d\t\t%.4f\t\t\t\t%.7f" % (epoch, lossTr, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print("Epoch No.: %d\tTrain Loss = %.4f\t lr= %.6f\n" % (epoch, lossTr, lr))

        # save the model
        model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
        state = {"epoch": epoch + 1, "model": model.state_dict()}

        # individual saving policy per dataset: for cityscapes, keep the last
        # 10 checkpoints plus one every 50 epochs; save every epoch otherwise
        if args.dataset == 'camvid':
            torch.save(state, model_file_name)
        elif args.dataset == 'cityscapes':
            if epoch >= args.max_epochs - 10:
                torch.save(state, model_file_name)
            elif epoch % 50 == 0:
                torch.save(state, model_file_name)
        elif args.dataset == 'seed':
            torch.save(state, model_file_name)
        else:
            torch.save(state, model_file_name)

        # draw plots for visualization every 5 epochs and on the last epoch
        if epoch % 5 == 0 or epoch == (args.max_epochs - 1):
            fig1, ax1 = plt.subplots(figsize=(11, 8))
            ax1.plot(range(start_epoch, epoch + 1), lossTr_list)
            ax1.set_title("Average training loss vs epochs")
            ax1.set_xlabel("Epochs")
            ax1.set_ylabel("Current loss")
            plt.savefig(args.savedir + "loss_vs_epochs.png")
            plt.clf()

            fig2, ax2 = plt.subplots(figsize=(11, 8))
            ax2.plot(epoches, mIOU_val_list, label="Val IoU")
            ax2.set_title("Average IoU vs epochs")
            ax2.set_xlabel("Epochs")
            ax2.set_ylabel("Current IoU")
            plt.legend(loc='lower right')
            plt.savefig(args.savedir + "iou_vs_epochs.png")
            plt.close('all')

    logger.close()
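# ---------------------------------------------------------------------------
# train_model() relies on a few helpers (setup_seed, netParams) and on the
# module-level constant GLOBAL_SEED, none of which are defined in this
# section. The definitions below are minimal sketches of what they are
# assumed to do; the repository's actual implementations may differ and, if
# present, should be used instead.
# ---------------------------------------------------------------------------
import random

import numpy as np
import torch


def setup_seed(seed):
    """Seed every RNG in use so runs are reproducible (assumed helper)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def netParams(model):
    """Return the total number of parameters in a model (assumed helper)."""
    return sum(p.numel() for p in model.parameters())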
def train_model(args):
    """
    Args:
        args: global arguments
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("input size:{}".format(input_size))
    print(args)

    if args.cuda:
        print("use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or wrong gpu id, please run without --cuda")

    # set the seed
    setup_seed(GLOBAL_SEED)
    print("set Global Seed: ", GLOBAL_SEED)

    cudnn.enabled = True
    print("building network")

    # build the model and initialize its weights
    model = build_model(args.model, num_classes=args.classes)
    init_weight(model, nn.init.kaiming_normal_, nn.BatchNorm2d, 1e-3, 0.1,
                mode='fan_in')

    print("computing network parameters and FLOPs")
    total_parameters = netParams(model)
    print("the number of parameters: %d ==> %.2f M" % (total_parameters, (total_parameters / 1e6)))

    # load data and data augmentation
    datas, trainLoader, valLoader = build_dataset_train(
        args.dataset, input_size, args.batch_size, args.train_type,
        args.random_scale, args.random_mirror, args.num_workers)

    args.per_iter = len(trainLoader)
    args.max_iter = args.max_epochs * args.per_iter

    print('Dataset statistics')
    print("data['classWeights']: ", datas['classWeights'])
    print('mean and std: ', datas['mean'], datas['std'])

    # define the loss function
    weight = torch.from_numpy(datas['classWeights'])

    # NOTE: the flag-gated branches must be checked before the plain dataset
    # branches, otherwise they can never be reached
    if args.dataset == 'camvid' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'camvid':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'cityscapes' and args.use_ohem:
        # keep at least 1/16 of the pixels of a per-GPU batch for OHEM
        min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True, ignore_label=args.ignore_label,
                                          thresh=0.7, min_kept=min_kept)
    elif args.dataset == 'cityscapes' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'cityscapes' and args.use_lovaszsoftmax:
        criteria = LovaszSoftmax(ignore_index=args.ignore_label)
    elif args.dataset == 'cityscapes' and args.use_focal:
        criteria = FocalLoss2d(weight=weight, ignore_index=args.ignore_label)
    elif args.dataset == 'cityscapes':
        # plain cityscapes fallback; without it, cityscapes with no loss flag
        # would hit the NotImplementedError below
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'paris':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=args.ignore_label)
    else:
        raise NotImplementedError(
            "This repository currently supports the cityscapes and camvid datasets; %s is not included" % args.dataset)

    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = nn.DataParallel(model).cuda()  # multi-card data parallel
        else:
            args.gpu_nums = 1
            print("single GPU for training")
            model = model.cuda()  # 1-card data parallel

    args.savedir = (args.savedir + args.dataset + '/' + args.model + 'bs'
                    + str(args.batch_size) + 'gpu' + str(args.gpu_nums)
                    + "_" + str(args.train_type) + '/')
    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    # record the run configuration next to the checkpoints
    with open(args.savedir + 'args.txt', 'w') as f:
        f.write('mean:{}\nstd:{}\n'.format(datas['mean'], datas['std']))
        f.write("Parameters: {} Seed: {}\n".format(str(total_parameters), GLOBAL_SEED))
        f.write(str(args))

    start_epoch = 0
    # resume from a checkpoint to continue training
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            # model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True
    # cudnn.deterministic = True  ## my add

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=50)

    # open the log file; write the header only when creating a new file
    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("%s\t%s\t\t%s\t%s\t%s" % ('Epoch', ' lr', 'Loss(Tr)', 'Loss(Val)', 'mIOU(Val)'))
    logger.flush()

    # define the optimization strategy
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr, momentum=0.9, weight_decay=1e-4)
    elif args.optim == 'adam':
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-4)
    elif args.optim == 'radam':
        optimizer = RAdam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr, betas=(0.90, 0.999), eps=1e-08, weight_decay=1e-4)
    elif args.optim == 'ranger':
        optimizer = Ranger(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr, betas=(0.95, 0.999), eps=1e-08, weight_decay=1e-4)
    elif args.optim == 'adamw':
        optimizer = AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-4)

    lossTr_list = []
    epoches = []
    mIOU_val_list = []
    lossVal_list = []

    print('>>>>>>>>>>>beginning training>>>>>>>>>>>')
    for epoch in range(start_epoch, args.max_epochs):
        # training
        lossTr, lr = train(args, trainLoader, model, criteria, optimizer, epoch)
        lossTr_list.append(lossTr)

        # validation: compute mIoU only every args.val_miou_epochs epochs
        if epoch % args.val_miou_epochs == 0:
            epoches.append(epoch)
            val_loss, mIOU_val, per_class_iu = val(args, valLoader, criteria, model, epoch)
            mIOU_val_list.append(mIOU_val)
            lossVal_list.append(val_loss.item())
            # record train information
            logger.write("\n%d\t%.6f\t%.4f\t\t%.4f\t%0.4f\t %s"
                         % (epoch, lr, lossTr, val_loss, mIOU_val, str(per_class_iu)))
            logger.flush()
            print("Epoch %d\tlr= %.6f\tTrain Loss = %.4f\tVal Loss = %.4f\tmIOU(val) = %.4f\tper_class_iu= %s\n"
                  % (epoch, lr, lossTr, val_loss, mIOU_val, str(per_class_iu)))
        else:
            # on the other epochs, val() is assumed to return only the loss
            val_loss = val(args, valLoader, criteria, model, epoch)
            lossVal_list.append(val_loss.item())
            # record train information
            logger.write("\n%d\t%.6f\t%.4f\t\t%.4f" % (epoch, lr, lossTr, val_loss))
            logger.flush()
            print("Epoch %d\tlr= %.6f\tTrain Loss = %.4f\tVal Loss = %.4f\n"
                  % (epoch, lr, lossTr, val_loss))

        # save the model
        model_file_name = args.savedir + '/model_' + str(epoch) + '.pth'
        state = {"epoch": epoch, "model": model.state_dict()}

        # individual saving policy: keep the last 10 checkpoints, plus one
        # every 10 epochs
        if epoch >= args.max_epochs - 10:
            torch.save(state, model_file_name)
        elif epoch % 10 == 0:
            torch.save(state, model_file_name)

        # draw plots for visualization
        if os.path.isfile(args.savedir + "loss.png"):
            # rebuild the loss curves from the log file (assumed to be
            # 'log.txt') so earlier runs are included; the values read back
            # are strings and must be converted before plotting
            with open(args.savedir + 'log.txt', 'r') as f:
                next(f)  # skip the header line
                epoch_list = []
                lossTr_list = []
                lossVal_list = []
                for line in f.readlines():
                    epoch_list.append(int(line.strip().split()[0]))
                    lossTr_list.append(float(line.strip().split()[2]))
                    lossVal_list.append(float(line.strip().split()[3]))
            assert len(epoch_list) == len(lossTr_list) == len(lossVal_list)

            fig1, ax1 = plt.subplots(figsize=(11, 8))
            ax1.plot(range(0, epoch + 1), lossTr_list, label='Train_loss')
            ax1.plot(range(0, epoch + 1), lossVal_list, label='Val_loss')
            ax1.set_title("Average training loss vs epochs")
            ax1.set_xlabel("Epochs")
            ax1.set_ylabel("Current loss")
            ax1.legend()
            plt.savefig(args.savedir + "loss.png")
            plt.clf()
        else:
            fig1, ax1 = plt.subplots(figsize=(11, 8))
            ax1.plot(range(0, epoch + 1), lossTr_list, label='Train_loss')
            ax1.plot(range(0, epoch + 1), lossVal_list, label='Val_loss')
            ax1.set_title("Average training loss vs epochs")
            ax1.set_xlabel("Epochs")
            ax1.set_ylabel("Current loss")
            ax1.legend()
            plt.savefig(args.savedir + "loss.png")
            plt.clf()

        fig2, ax2 = plt.subplots(figsize=(11, 8))
        ax2.plot(epoches, mIOU_val_list, label="Val IoU")
        ax2.set_title("Average IoU vs epochs")
        ax2.set_xlabel("Epochs")
        ax2.set_ylabel("Current IoU")
        ax2.legend()
        plt.savefig(args.savedir + "mIou.png")
        plt.close('all')

        # stop early once the validation loss has stopped improving
        early_stopping.monitor(monitor=val_loss)
        if early_stopping.early_stop:
            print("Early stopping and Save checkpoint")
            if not os.path.exists(model_file_name):
                torch.save(state, model_file_name)
            break

    logger.close()
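# ---------------------------------------------------------------------------
# The second train_model() above calls EarlyStopping(patience=50),
# early_stopping.monitor(monitor=val_loss), and early_stopping.early_stop,
# but the class is not defined in this section. Below is a minimal sketch
# matching that interface, under the assumption that it tracks the best
# validation loss and stops after `patience` epochs without improvement; the
# repository's actual implementation may differ.
# ---------------------------------------------------------------------------
class EarlyStopping:
    """Flag training for stopping when the monitored value stops improving."""

    def __init__(self, patience=50, min_delta=0.0):
        self.patience = patience    # epochs to wait after the last improvement
        self.min_delta = min_delta  # minimum change that counts as improvement
        self.best = None
        self.counter = 0
        self.early_stop = False

    def monitor(self, monitor):
        # accept either a plain float or a 0-dim torch tensor
        value = monitor.item() if hasattr(monitor, 'item') else float(monitor)
        if self.best is None or value < self.best - self.min_delta:
            # improvement: remember the best value and reset the counter
            self.best = value
            self.counter = 0
        else:
            # no improvement: raise the stop flag once patience runs out
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True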