import os
import pickle

from torch.utils import data

# The dataset and statistics classes used below (CityscapesDataSet, CityscapesValDataSet,
# CityscapesTestDataSet, CityscapesTrainInform, CamVidDataSet, CamVidValDataSet,
# CamVidTrainInform, ...) are provided by this repository's dataset package.


def build_dataset_train(dataset, input_size, batch_size, train_type, random_scale,
                        random_mirror, num_workers):
    data_dir = os.path.join('./dataset/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    train_data_list = os.path.join(data_dir, dataset + '_' + train_type + '_list.txt')
    val_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file caches the dataset mean, std and class weights
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        datas = pickle.load(open(inform_data_file, "rb"))

    if dataset == "cityscapes":
        trainLoader = data.DataLoader(
            CityscapesDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                              mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CityscapesValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        return datas, trainLoader, valLoader

    elif dataset == "camvid":
        trainLoader = data.DataLoader(
            CamVidDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                          mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        # for the 'trainval' training split, validate on the CamVid test list instead
        if train_type == 'trainval':
            val_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt')

        valLoader = data.DataLoader(
            CamVidValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True)

        return datas, trainLoader, valLoader
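# --- Usage sketch (illustrative only) -----------------------------------------
# A minimal example of calling build_dataset_train for CamVid. The crop size,
# batch size and worker count below are placeholder values, not defaults taken
# from this repository.
def _example_build_train_loaders():
    datas, trainLoader, valLoader = build_dataset_train(
        dataset='camvid',
        input_size=(360, 480),   # (height, width) crop, placeholder
        batch_size=8,
        train_type='trainval',
        random_scale=True,
        random_mirror=True,
        num_workers=4)
    print("mean:", datas['mean'], "class weights:", datas['classWeights'])
    print("train batches:", len(trainLoader), "val batches:", len(valLoader))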
def build_dataset_test(dataset, num_workers, none_gt=False):
    data_dir = os.path.join('./dataset/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    test_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file caches the dataset mean, std and class weights
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        datas = pickle.load(open(inform_data_file, "rb"))

    if dataset == "cityscapes":
        # for cityscapes, set none_gt=False to evaluate on the validation set
        # (ground truth available) and none_gt=True to run on the test set
        if none_gt:
            testLoader = data.DataLoader(
                CityscapesTestDataSet(data_dir, test_data_list, mean=datas['mean']),
                batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
        else:
            test_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
            testLoader = data.DataLoader(
                CityscapesValDataSet(data_dir, test_data_list, mean=datas['mean']),
                batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
        return datas, testLoader

    elif dataset == "camvid":
        testLoader = data.DataLoader(
            CamVidValDataSet(data_dir, test_data_list, mean=datas['mean']),
            batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
        return datas, testLoader
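# --- Usage sketch (illustrative only) -----------------------------------------
# Building a Cityscapes evaluation loader. none_gt=False selects the validation
# list (labels available); none_gt=True selects the test list (no ground truth).
def _example_build_test_loader():
    datas, testLoader = build_dataset_test('cityscapes', num_workers=4, none_gt=False)
    print("evaluation samples:", len(testLoader))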
def build_dataset_predict(dataset_path, dataset, num_workers, none_gt=False):
    data_dir = os.path.join('./dataset/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file caches the dataset mean, std and class weights
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        elif dataset == 'custom_dataset':
            dataCollect = CustomTrainInform(data_dir, 2, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository now supports three datasets: cityscapes, camvid and custom_dataset, "
                "%s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        datas = pickle.load(open(inform_data_file, "rb"))

    # only the custom dataset has a prediction path; note that the dataset object
    # itself (not a DataLoader) is returned here
    if dataset == "custom_dataset":
        testLoader = CustomPredictDataSet(dataset_path, mean=datas['mean'])
        return datas, testLoader
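# --- Usage sketch (illustrative only) -----------------------------------------
# Prediction on a folder of images with the two-class custom dataset. The folder
# path is a placeholder; num_workers is unused on this code path.
def _example_build_predict_set():
    datas, predictSet = build_dataset_predict('./my_images/', 'custom_dataset', num_workers=0)
    print("images to predict:", len(predictSet))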
import os
import random
import pickle

import torch
import torch.backends.cudnn as cudnn
from torch.utils import data
from torchvision import transforms

# Repository-specific modules used below (import paths not shown in this excerpt):
# CGNet, CamVidTrainInform, CamVidValDataSet, CrossEntropyLoss2d, netParams, test.


def test_model(args):
    """
    main function for testing
    args:
       args: global arguments
    """
    print("=====> Check if the cached file exists")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        # collect mean, std and class-weight information
        dataCollect = CamVidTrainInform(args.data_dir, args.classes,
                                        train_set_file=args.dataset_list,
                                        inform_data_file=args.inform_data_file)
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print('Error while pickling data. Please check.')
            exit(-1)
    else:
        print("%s exists" % (args.inform_data_file))
        datas = pickle.load(open(args.inform_data_file, "rb"))

    print(args)
    global network_type

    if args.cuda:
        print("=====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or wrong gpu id, please run without --cuda")

    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True

    M = args.M
    N = args.N
    model = CGNet.Context_Guided_Network(classes=args.classes, M=M, N=N)
    network_type = "CGNet"
    print("=====> current architecture: CGNet_M%sN%s" % (M, N))

    total_paramters = netParams(model)
    print("the number of parameters: " + str(total_paramters))
    print("data['classWeights']: ", datas['classWeights'])
    weight = torch.from_numpy(datas['classWeights'])
    print("=====> Dataset statistics")
    print("mean and std: ", datas['mean'], datas['std'])

    # define the optimization criterion
    criteria = CrossEntropyLoss2d(weight, args.ignore_label)
    if args.cuda:
        model = model.cuda()
        criteria = criteria.cuda()

    # load the test set
    train_transform = transforms.Compose([transforms.ToTensor()])
    testLoader = data.DataLoader(
        CamVidValDataSet(args.data_dir, args.test_data_list, f_scale=1, mean=datas['mean']),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
        pin_memory=True, drop_last=True)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #model.load_state_dict(convert_state_dict(checkpoint['model']))
            model.load_state_dict(checkpoint['model'])
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    print("=====> beginning test")
    print("length of test set:", len(testLoader))
    mIOU_val, per_class_iu = test(args, testLoader, model, criteria)
    print(mIOU_val)
    print(per_class_iu)
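# --- Usage sketch (illustrative only) -----------------------------------------
# test_model() reads its configuration from an argparse-style namespace. The
# values below are placeholders that show which attributes the function accesses;
# they are not the repository's actual defaults.
def _example_test_args():
    from argparse import Namespace
    return Namespace(
        inform_data_file='./dataset/inform/camvid_inform.pkl',
        data_dir='./dataset/camvid',
        dataset_list='camvid_trainval_list.txt',
        test_data_list='./dataset/camvid/camvid_test_list.txt',
        classes=11, M=3, N=21,
        ignore_label=11,
        batch_size=1, num_workers=4,
        cuda=False, gpus='0',
        resume='')  # path to a trained checkpoint, left empty here

# test_model(_example_test_args())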
            # (continuation of build_dataset_test: evaluate on the validation list
            # when ground truth is required)
            test_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
            testLoader = data.DataLoader(
                CityscapesValDataSet(data_dir, test_data_list, mean=datas['mean']),
                batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
        return datas, testLoader

    elif dataset == "camvid":
        testLoader = data.DataLoader(
            CamVidValDataSet(data_dir, test_data_list, mean=datas['mean']),
            batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
        return datas, testLoader


if __name__ == "__main__":
    dataCollection = CamVidTrainInform("/home/mohamed/RINet/dataset/camvid",
                                       classes=11,
                                       train_set_file="camvid_trainval_list.txt",
                                       inform_data_file="inform/camvid_inform.pkl")
    data = dataCollection.collectDataAndSave()
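# --- Usage sketch (illustrative only) -----------------------------------------
# The statistics file written by collectDataAndSave() can be inspected directly;
# the keys 'mean', 'std' and 'classWeights' are the ones read throughout this
# repository. The path below matches the __main__ block above.
def _example_inspect_inform_file():
    import pickle
    with open("inform/camvid_inform.pkl", "rb") as f:
        datas = pickle.load(f)
    print("mean:", datas['mean'])
    print("std:", datas['std'])
    print("class weights:", datas['classWeights'])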
def build_dataset_train(dataset, input_size, batch_size, train_type, random_scale,
                        random_mirror, num_workers, args):
    data_dir = os.path.join('/media/sdb/datasets/segment/', dataset)
    dataset_list = dataset + '_trainval_list.txt'
    train_data_list = os.path.join(data_dir, dataset + '_' + train_type + '_list.txt')
    val_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    if dataset == "cityscapes":
        # inform_data_file caches the dataset mean, std and class weights
        if not os.path.isfile(inform_data_file):
            print("%s is not found" % (inform_data_file))
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
            datas = dataCollect.collectDataAndSave()
            if datas is None:
                print("error while pickling data. Please check.")
                exit(-1)
        else:
            print("find file: ", str(inform_data_file))
            datas = pickle.load(open(inform_data_file, "rb"))

        trainLoader = data.DataLoader(
            CityscapesDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                              mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CityscapesValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=batch_size, shuffle=False, num_workers=num_workers,
            pin_memory=True, drop_last=False)

    elif dataset == "camvid":
        # inform_data_file caches the dataset mean, std and class weights
        if not os.path.isfile(inform_data_file):
            print("%s is not found" % (inform_data_file))
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
            datas = dataCollect.collectDataAndSave()
            if datas is None:
                print("error while pickling data. Please check.")
                exit(-1)
        else:
            print("find file: ", str(inform_data_file))
            datas = pickle.load(open(inform_data_file, "rb"))

        trainLoader = data.DataLoader(
            CamVidDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                          mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CamVidValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True)

    elif dataset == "ade20k":
        # ADE20K reuses the cached Cityscapes statistics and torchvision transforms
        inform_data_file = os.path.join('./dataset/inform/', 'cityscapes_inform.pkl')
        datas = pickle.load(open(inform_data_file, "rb"))

        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        data_kwargs = {'transform': input_transform,
                       'base_size': args.input_size[0] + 40,
                       'crop_size': args.input_size[0],
                       'encode': False}
        train_dataset = ADE20KSegmentation(split='train', mode='train', **data_kwargs)
        val_dataset = ADE20KSegmentation(split='val', mode='val', **data_kwargs)

        train_sampler = make_data_sampler(train_dataset, shuffle=True, distributed=False)
        train_batch_sampler = make_batch_data_sampler(train_sampler, batch_size)
        val_sampler = make_data_sampler(val_dataset, shuffle=False, distributed=False)
        val_batch_sampler = make_batch_data_sampler(val_sampler, batch_size)

        trainLoader = data.DataLoader(dataset=train_dataset, batch_sampler=train_batch_sampler,
                                      num_workers=args.num_workers, pin_memory=True)
        valLoader = data.DataLoader(dataset=val_dataset, batch_sampler=val_batch_sampler,
                                    num_workers=args.num_workers, pin_memory=True)

    else:
        raise NotImplementedError(
            "This repository now supports datasets: cityscapes, camvid and ade20k, %s is not included" % dataset)

    return datas, trainLoader, valLoader
import os
import random
import pickle

import torch
import torch.backends.cudnn as cudnn
from torch.utils import data
from torchvision import transforms

# Repository-specific modules used below (import paths not shown in this excerpt):
# CGNet, CamVidTrainInform, CamVidDataSet, CamVidValDataSet, CrossEntropyLoss2d,
# netParams, train, val.


def train_model(args):
    """
    args:
       args: global arguments
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    print("=====> checking if inform_data_file exists")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        # collect mean, std and class-weight information
        dataCollect = CamVidTrainInform(args.data_dir, args.classes,
                                        train_set_file=args.dataset_list,
                                        inform_data_file=args.inform_data_file)
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(args.inform_data_file))
        datas = pickle.load(open(args.inform_data_file, "rb"))

    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or wrong gpu id, please run without --cuda")

    args.seed = random.randint(1, 10000)
    print("====> Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True

    M = args.M
    N = args.N

    print("=====> building network")
    model = CGNet.Context_Guided_Network(classes=args.classes, M=M, N=N)
    network_type = "CGNet"
    print("=====> current architecture: CGNet")

    print("=====> computing network parameters")
    total_paramters = netParams(model)
    print("the number of parameters: " + str(total_paramters))
    print("data['classWeights']: ", datas['classWeights'])
    print('=====> Dataset statistics')
    print('mean and std: ', datas['mean'], datas['std'])

    # define the optimization criterion
    weight = torch.from_numpy(datas['classWeights'])
    criteria = CrossEntropyLoss2d(weight, args.ignore_label)

    args.gpu_nums = 1
    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = torch.nn.DataParallel(model).cuda()
        else:
            print("single GPU for training")
            model = model.cuda()

    args.savedir = (args.savedir + args.dataset + '/' + network_type + "_M" + str(M) +
                    'N' + str(N) + 'bs' + str(args.batch_size) + 'gpu' + str(args.gpu_nums) +
                    "_" + str(args.train_type) + '/')
    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    # data augmentation: compose the data with transforms
    train_transform = transforms.Compose([transforms.ToTensor()])
    trainLoader = data.DataLoader(
        CamVidDataSet(args.data_dir, args.train_data_list, crop_size=input_size,
                      scale=args.random_scale, mirror=args.random_mirror, mean=datas['mean']),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
        pin_memory=True, drop_last=True)

    valLoader = data.DataLoader(
        CamVidValDataSet(args.data_dir, args.val_data_list, f_scale=1, mean=datas['mean']),
        batch_size=1, shuffle=True, num_workers=args.num_workers,
        pin_memory=True, drop_last=True)

    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            #model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("=====> loading checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s" % (str(total_paramters)))
        logger.write("\n%s\t\t%s\t\t%s\t\t%s\t\t%s\t\t" %
                     ('Epoch', 'Loss(Tr)', 'Loss(val)', 'mIOU (tr)', 'mIOU (val)'))
    logger.flush()

    optimizer = torch.optim.Adam(model.parameters(), args.lr, (0.9, 0.999),
                                 eps=1e-08, weight_decay=5e-4)

    print('=====> beginning training')
    for epoch in range(start_epoch, args.max_epochs):
        # training
        lossTr, per_class_iu_tr, mIOU_tr, lr = train(args, trainLoader, model, criteria,
                                                     optimizer, epoch)

        # validation every 50 epochs
        if epoch % 50 == 0:
            mIOU_val, per_class_iu = val(args, valLoader, model, criteria)
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" %
                         (epoch, lossTr, mIOU_tr, mIOU_val, lr))
            logger.flush()
            print("epoch: " + str(epoch) + ' Details')
            print("\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\t lr= %.6f"
                  % (epoch, lossTr, mIOU_tr, mIOU_val, lr))
        else:
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.7f" % (epoch, lossTr, mIOU_tr, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print("\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t lr= %.6f"
                  % (epoch, lossTr, mIOU_tr, lr))

        # save the model: every 20 epochs, and every epoch over the last 10
        model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
        state = {"epoch": epoch + 1, "model": model.state_dict()}
        if epoch > args.max_epochs - 10:
            torch.save(state, model_file_name)
        elif not epoch % 20:
            torch.save(state, model_file_name)

    logger.close()