def build_dataset_train(dataset, input_size, batch_size, train_type, random_scale, random_mirror, num_workers):
    data_dir = os.path.join('./dataset/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    train_data_list = os.path.join(data_dir, dataset + '_' + train_type + '_list.txt')
    val_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file caches the dataset mean, std and class-weight statistics
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        datas = pickle.load(open(inform_data_file, "rb"))

    if dataset == "cityscapes":
        trainLoader = data.DataLoader(
            CityscapesDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                              mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CityscapesValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        return datas, trainLoader, valLoader

    elif dataset == "camvid":
        trainLoader = data.DataLoader(
            CamVidDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                          mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        # in 'trainval' mode the CamVid test list is used for validation
        if train_type == 'trainval':
            val_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt')

        valLoader = data.DataLoader(
            CamVidValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True)

        return datas, trainLoader, valLoader
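# Usage sketch (illustrative only; the crop size, batch size and worker count below are
# assumed example values, not settings prescribed by this repository):
#
#   datas, trainLoader, valLoader = build_dataset_train(
#       dataset='cityscapes', input_size=(512, 1024), batch_size=8,
#       train_type='train', random_scale=True, random_mirror=True, num_workers=4)
#
# `datas` is the unpickled statistics dict ('mean', 'std', 'classWeights') used by the
# training and evaluation scripts below.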
def build_dataset_test(dataset, num_workers, none_gt=False):
    data_dir = os.path.join('./dataset/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    test_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file caches the dataset mean, std and class-weight statistics
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        datas = pickle.load(open(inform_data_file, "rb"))

    if dataset == "cityscapes":
        # for cityscapes, set none_gt=False to evaluate on the validation set
        # and none_gt=True to run on the test set (no ground truth available)
        if none_gt:
            testLoader = data.DataLoader(
                CityscapesTestDataSet(data_dir, test_data_list, mean=datas['mean']),
                batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
        else:
            test_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
            testLoader = data.DataLoader(
                CityscapesValDataSet(data_dir, test_data_list, mean=datas['mean']),
                batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)

        return datas, testLoader

    elif dataset == "camvid":
        testLoader = data.DataLoader(
            CamVidValDataSet(data_dir, test_data_list, mean=datas['mean']),
            batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)

        return datas, testLoader
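# Usage sketch (illustrative only): with none_gt=False the returned loader iterates over
# the Cityscapes validation list (labels available for scoring), while none_gt=True
# switches to the test list, whose labels are withheld by the benchmark server:
#
#   datas, valLoader = build_dataset_test('cityscapes', num_workers=4, none_gt=False)
#   datas, testLoader = build_dataset_test('cityscapes', num_workers=4, none_gt=True)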
def build_dataset_predict(dataset_path, dataset, num_workers, none_gt=False):
    data_dir = os.path.join('./dataset/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file caches the dataset mean, std and class-weight statistics
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        elif dataset == 'custom_dataset':
            dataCollect = CustomTrainInform(data_dir, 2, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository now supports three datasets: cityscapes, camvid and custom_dataset, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        datas = pickle.load(open(inform_data_file, "rb"))

    if dataset == "custom_dataset":
        testLoader = CustomPredictDataSet(dataset_path, mean=datas['mean'])
        return datas, testLoader
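# Usage sketch (illustrative only; the sample directory path is an assumed example). Note
# that for the custom_dataset branch the function returns the dataset object itself rather
# than a DataLoader, so callers iterate over it directly or wrap it themselves:
#
#   datas, predictSet = build_dataset_predict('./samples/', 'custom_dataset', num_workers=0)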
def test_func(args):
    """
    main function for testing

    param args: global arguments
    return: None
    """
    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("no GPU found or wrong gpu id, please run without --cuda")

    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('=====> checking if processed cached_data_file exists')
    if not os.path.isfile(args.inform_data_file):
        # collect the mean, std and class-weight information
        dataCollect = CityscapesTrainInform(args.data_dir, args.classes,
                                            train_set_file=args.dataset_list,
                                            inform_data_file=args.inform_data_file)
        data = dataCollect.collectDataAndSave()
        if data is None:
            print("error while pickling data, please check")
            exit(-1)
    else:
        data = pickle.load(open(args.inform_data_file, "rb"))

    M = args.M
    N = args.N
    model = CGNet.Context_Guided_Network(classes=args.classes, M=M, N=N)
    network_type = "CGNet"
    print("Arch: CGNet")

    # define the optimization criterion
    weight = torch.from_numpy(data['classWeights'])  # convert the numpy array to a torch tensor
    if args.cuda:
        weight = weight.cuda()
    criteria = CrossEntropyLoss2d(weight)  # class-weighted cross-entropy

    if args.cuda:
        model = model.cuda()  # use the GPU for inference
        criteria = criteria.cuda()
        cudnn.benchmark = True

    print('Dataset statistics')
    print('mean and std: ', data['mean'], data['std'])
    print('classWeights: ', data['classWeights'])

    if args.save_seg_dir:
        if not os.path.exists(args.save_seg_dir):
            os.makedirs(args.save_seg_dir)

    # test set loader
    testLoader = torch.utils.data.DataLoader(
        CityscapesTestDataSet(args.data_dir, args.test_data_list, mean=data['mean']),
        batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # model.load_state_dict(checkpoint['model'])
            model.load_state_dict(convert_state_dict(checkpoint['model']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    print("=====> beginning testing")
    print("test set length: ", len(testLoader))
    test(args, testLoader, model)
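# The class-weighted 2D cross-entropy used above can be expressed with stock PyTorch. A
# minimal sketch, assuming CrossEntropyLoss2d wraps nn.CrossEntropyLoss and ignores
# unlabeled pixels (the class name, and the ignore value 255, are assumptions for
# illustration, not the repository's own implementation):
import torch.nn as nn

class _WeightedCE2d(nn.Module):
    """Per-pixel cross-entropy over (N, C, H, W) logits and (N, H, W) label maps."""

    def __init__(self, weight=None, ignore_index=255):
        super().__init__()
        self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)

    def forward(self, output, target):
        # nn.CrossEntropyLoss handles the spatial dimensions directly
        return self.loss(output, target)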
def train_model(args):
    """
    args:
       args: global arguments
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    print("=====> checking if inform_data_file exists")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        # collect the mean, std and class-weight information
        dataCollect = CityscapesTrainInform(args.data_dir, args.classes,
                                            train_set_file=args.dataset_list,
                                            inform_data_file=args.inform_data_file)
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(args.inform_data_file))
        datas = pickle.load(open(args.inform_data_file, "rb"))

    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or wrong gpu id, please run without --cuda")

    # args.seed = random.randint(1, 10000)
    args.seed = 9830
    print("====> Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    cudnn.enabled = True

    model = MobileNetV3(model_mode="SMALL", num_classes=args.classes)
    network_type = "MobileNetV3"
    print("=====> current architecture: MobileNetV3")

    print("=====> computing network parameters")
    total_paramters = netParams(model)
    print("the number of parameters: " + str(total_paramters))

    print("data['classWeights']: ", datas['classWeights'])
    print('=====> Dataset statistics')
    print('mean and std: ', datas['mean'], datas['std'])

    # define the optimization criterion
    weight = torch.from_numpy(datas['classWeights'])
    criteria = CrossEntropyLoss2d(weight)

    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = torch.nn.DataParallel(model).cuda()  # multi-card data parallel
        else:
            print("single GPU for training")
            model = model.cuda()  # single-card training

    args.savedir = (args.savedir + args.dataset + '/' + network_type + 'bs' +
                    str(args.batch_size) + 'gpu' + str(args.gpu_nums) + "_" +
                    str(args.train_type) + '/')
    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    # note: this transform is defined but not used by the loaders below
    train_transform = transforms.Compose([transforms.ToTensor()])

    trainLoader = data.DataLoader(
        CityscapesDataSet(args.data_dir, args.train_data_list, crop_size=input_size,
                          scale=args.random_scale, mirror=args.random_mirror,
                          mean=datas['mean']),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
        pin_memory=True, drop_last=True)

    valLoader = data.DataLoader(
        CityscapesValDataSet(args.data_dir, args.val_data_list, f_scale=1, mean=datas['mean']),
        batch_size=1, shuffle=True, num_workers=args.num_workers,
        pin_memory=True, drop_last=True)

    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            # model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("=====> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
        logger.write("\nGlobal configuration as follows:")
        for key, value in vars(args).items():
            logger.write("\n{:16} {}".format(key, value))
        logger.write("\nParameters: %s" % (str(total_paramters)))
        logger.write("\n%s\t\t%s\t\t%s\t\t%s\t\t%s\t\t" %
                     ('Epoch', 'Loss(Tr)', 'mIOU (tr)', 'mIOU (val)', 'lr'))
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Global configuration as follows:")
        for key, value in vars(args).items():
            logger.write("\n{:16} {}".format(key, value))
        logger.write("\nParameters: %s" % (str(total_paramters)))
        logger.write("\n%s\t\t%s\t\t%s\t\t%s\t\t%s\t\t" %
                     ('Epoch', 'Loss(Tr)', 'mIOU (tr)', 'mIOU (val)', 'lr'))
    logger.flush()

    optimizer = torch.optim.Adam(model.parameters(), args.lr, (0.9, 0.999),
                                 eps=1e-08, weight_decay=5e-4)

    print('=====> beginning training')
    for epoch in range(start_epoch, args.max_epochs):
        # training
        lossTr, per_class_iu_tr, mIOU_tr, lr = train(args, trainLoader, model, criteria,
                                                     optimizer, epoch)

        # validation every 50 epochs
        if epoch % 50 == 0:
            mIOU_val, per_class_iu = val(args, valLoader, model, criteria)
            # record training information
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" %
                         (epoch, lossTr, mIOU_tr, mIOU_val, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print("\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\t lr= %.6f" %
                  (epoch, lossTr, mIOU_tr, mIOU_val, lr))
        else:
            # record training information
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.7f" % (epoch, lossTr, mIOU_tr, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print("\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t lr= %.6f" %
                  (epoch, lossTr, mIOU_tr, lr))

        # save the model: every 20 epochs, and every epoch in the final 10
        model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
        state = {"epoch": epoch + 1, "model": model.state_dict()}
        if epoch > args.max_epochs - 10:
            torch.save(state, model_file_name)
        elif not epoch % 20:
            torch.save(state, model_file_name)

    logger.close()
def test_func(args):
    """
    main function for testing

    param args: global arguments
    return: None
    """
    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("no GPU found or wrong gpu id, please run without --cuda")
    device = 'cuda' if args.cuda else 'cpu'

    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('=====> checking if processed cached_data_file exists')
    if not os.path.isfile(args.inform_data_file):
        # collect the mean, std and class-weight information
        dataCollect = CityscapesTrainInform(args.data_dir, args.classes,
                                            train_set_file=args.dataset_list,
                                            inform_data_file=args.inform_data_file)
        data = dataCollect.collectDataAndSave()
        if data is None:
            print("error while pickling data, please check")
            exit(-1)
    else:
        data = pickle.load(open(args.inform_data_file, "rb"))

    model = MobileNetV3(model_mode="SMALL", num_classes=args.classes)
    network_type = "MobileNetV3"
    print("Arch: MobileNetV3")

    if args.cuda:
        model = model.to(device)  # use the GPU for inference
        cudnn.benchmark = True

    print('Dataset statistics')
    print('mean and std: ', data['mean'], data['std'])
    print('classWeights: ', data['classWeights'])

    # test set loader
    testLoader = torch.utils.data.DataLoader(
        CityscapesTestDataSet(args.data_dir, args.test_data_list, mean=data['mean']),
        batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # model.load_state_dict(checkpoint['model'])
            model.load_state_dict(convert_state_dict(checkpoint['model']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    print("=====> beginning testing")
    print("test set length: ", len(testLoader))
    test(args, testLoader, model, device, data)
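# The convert_state_dict call above is typically needed because torch.nn.DataParallel
# prepends 'module.' to every parameter name when a checkpoint is saved. A minimal sketch
# of such a helper, assuming that is the behavior required here (the function name and
# exact behavior are assumptions, not the repository's own utility):
def _strip_dataparallel_prefix(state_dict):
    """Return a copy of state_dict with any leading 'module.' removed from each key."""
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}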
def build_dataset_train(dataset, input_size, batch_size, train_type, random_scale, random_mirror, num_workers, args):
    data_dir = os.path.join('/media/sdb/datasets/segment/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    train_data_list = os.path.join(data_dir, dataset + '_' + train_type + '_list.txt')
    val_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    if dataset == "cityscapes":
        # inform_data_file caches the dataset mean, std and class-weight statistics
        if not os.path.isfile(inform_data_file):
            print("%s is not found" % (inform_data_file))
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
            datas = dataCollect.collectDataAndSave()
            if datas is None:
                print("error while pickling data. Please check.")
                exit(-1)
        else:
            print("find file: ", str(inform_data_file))
            datas = pickle.load(open(inform_data_file, "rb"))

        trainLoader = data.DataLoader(
            CityscapesDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                              mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CityscapesValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=batch_size, shuffle=False, num_workers=num_workers,
            pin_memory=True, drop_last=False)

    elif dataset == "camvid":
        # inform_data_file caches the dataset mean, std and class-weight statistics
        if not os.path.isfile(inform_data_file):
            print("%s is not found" % (inform_data_file))
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
            datas = dataCollect.collectDataAndSave()
            if datas is None:
                print("error while pickling data. Please check.")
                exit(-1)
        else:
            print("find file: ", str(inform_data_file))
            datas = pickle.load(open(inform_data_file, "rb"))

        trainLoader = data.DataLoader(
            CamVidDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                          mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CamVidValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True)

    elif dataset == "ade20k":
        # ADE20K reuses the cached Cityscapes statistics file
        inform_data_file = os.path.join('./dataset/inform/', 'cityscapes_inform.pkl')
        datas = pickle.load(open(inform_data_file, "rb"))

        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        data_kwargs = {'transform': input_transform,
                       'base_size': args.input_size[0] + 40,
                       'crop_size': args.input_size[0],
                       'encode': False}
        train_dataset = ADE20KSegmentation(split='train', mode='train', **data_kwargs)
        val_dataset = ADE20KSegmentation(split='val', mode='val', **data_kwargs)

        train_sampler = make_data_sampler(train_dataset, shuffle=True, distributed=False)
        train_batch_sampler = make_batch_data_sampler(train_sampler, batch_size)
        val_sampler = make_data_sampler(val_dataset, shuffle=False, distributed=False)
        val_batch_sampler = make_batch_data_sampler(val_sampler, batch_size)

        trainLoader = data.DataLoader(dataset=train_dataset, batch_sampler=train_batch_sampler,
                                      num_workers=args.num_workers, pin_memory=True)
        valLoader = data.DataLoader(dataset=val_dataset, batch_sampler=val_batch_sampler,
                                    num_workers=args.num_workers, pin_memory=True)

    else:
        raise NotImplementedError(
            "This repository now supports datasets: cityscapes, camvid and ade20k, %s is not included" % dataset)

    return datas, trainLoader, valLoader