def test_model(args):
    """Evaluate a CGNet model on the CamVid test list and print the scores.

    Loads (or, if the cache file is missing, computes and saves) the dataset
    statistics, builds the network and the weighted loss, optionally restores
    a checkpoint, then runs ``test`` and prints the mean IoU and the
    per-class IoU it returns.

    args:
        args: global arguments. Attributes read here: ``inform_data_file``,
            ``data_dir``, ``classes``, ``dataset_list``, ``cuda``, ``gpus``,
            ``M``, ``N``, ``ignore_label``, ``test_data_list``,
            ``batch_size``, ``num_workers`` and ``resume``.
            ``args.seed`` is overwritten with a freshly drawn random seed.

    raises:
        SystemExit: if the dataset statistics could not be collected.
        Exception: if ``args.cuda`` is set but no GPU is available.
    """
    print("=====> Check if the cached file exists ")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        # Collect mean, std and class-weight information and cache it.
        dataCollect = CamVidTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file,
        )
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print('Error while pickling data. Please check.')
            # SystemExit is what exit() raises; raising it directly does not
            # depend on the ``site`` module having installed the builtin.
            raise SystemExit(-1)
    else:
        print("%s exists" % (args.inform_data_file))
        # NOTE(security): pickle executes arbitrary code on load; only point
        # inform_data_file at files produced by this project.
        with open(args.inform_data_file, "rb") as inform_file:
            datas = pickle.load(inform_file)

    print(args)
    global network_type

    if args.cuda:
        print("=====> Use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # The seed is always re-drawn; it is printed so a run can be identified.
    args.seed = random.randint(1, 10000)
    print("Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    cudnn.enabled = True

    M = args.M
    N = args.N
    model = CGNet.Context_Guided_Network(classes=args.classes, M=M, N=N)
    network_type = "CGNet"
    print("=====> current architeture: CGNet_M%sN%s" % (M, N))
    total_paramters = netParams(model)
    print("the number of parameters: " + str(total_paramters))

    print("data['classWeights']: ", datas['classWeights'])
    weight = torch.from_numpy(datas['classWeights'])
    print("=====> Dataset statistics")
    print("mean and std: ", datas['mean'], datas['std'])

    # Define optimization criteria (class-weighted loss with an ignore label).
    criteria = CrossEntropyLoss2d(weight, args.ignore_label)
    if args.cuda:
        model = model.cuda()
        criteria = criteria.cuda()

    # Load the test set. Evaluation must see every sample exactly once in a
    # fixed order: shuffling adds nothing here, and drop_last would silently
    # discard the tail of the list whenever its length is not a multiple of
    # batch_size, biasing the reported scores.
    testLoader = data.DataLoader(
        CamVidValDataSet(args.data_dir, args.test_data_list,
                         f_scale=1, mean=datas['mean']),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    if args.resume:
        if os.path.isfile(args.resume):
            print("=====> loading checkpoint '{}'".format(args.resume))
            # Remap tensors to the CPU when running without --cuda so a
            # checkpoint written on a GPU machine can still be restored.
            checkpoint = torch.load(
                args.resume,
                map_location=None if args.cuda else "cpu",
            )
            # NOTE(review): the keys of checkpoint['model'] must match this
            # (unwrapped) model; a checkpoint saved from a DataParallel
            # wrapper presumably needs its keys converted first - confirm.
            model.load_state_dict(checkpoint['model'])
        else:
            # Best effort: evaluation continues with the freshly built model.
            print("=====> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    print("=====> beginning test")
    print("length of test set:", len(testLoader))
    mIOU_val, per_class_iu = test(args, testLoader, model, criteria)
    print(mIOU_val)
    print(per_class_iu)
def train_model(args):
    """Train a CGNet model on CamVid, logging and checkpointing as it goes.

    Loads (or, if the cache file is missing, computes and saves) the dataset
    statistics, builds the network, loss, data loaders and Adam optimizer,
    optionally resumes from a checkpoint, then runs the epoch loop. Every
    50th epoch is also validated. A checkpoint is written every 20th epoch
    and for each of the final epochs of the run.

    args:
        args: global arguments. Attributes read here: ``input_size``
            (a ``"h,w"`` string), ``inform_data_file``, ``data_dir``,
            ``classes``, ``dataset_list``, ``cuda``, ``gpus``, ``M``, ``N``,
            ``ignore_label``, ``savedir``, ``dataset``, ``batch_size``,
            ``train_type``, ``train_data_list``, ``val_data_list``,
            ``random_scale``, ``random_mirror``, ``num_workers``,
            ``resume``, ``logFile``, ``lr`` and ``max_epochs``.
            ``args.seed``, ``args.gpu_nums`` and ``args.savedir`` are
            overwritten.

    raises:
        SystemExit: if the dataset statistics could not be collected.
        Exception: if ``args.cuda`` is set but no GPU is available.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    print("=====> checking if inform_data_file exists")
    if not os.path.isfile(args.inform_data_file):
        print("%s is not found" % (args.inform_data_file))
        # Collect mean, std and class-weight information and cache it.
        dataCollect = CamVidTrainInform(
            args.data_dir,
            args.classes,
            train_set_file=args.dataset_list,
            inform_data_file=args.inform_data_file,
        )
        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            # SystemExit is what exit() raises; raising it directly does not
            # depend on the ``site`` module having installed the builtin.
            raise SystemExit(-1)
    else:
        print("find file: ", str(args.inform_data_file))
        # NOTE(security): pickle executes arbitrary code on load; only point
        # inform_data_file at files produced by this project.
        with open(args.inform_data_file, "rb") as inform_file:
            datas = pickle.load(inform_file)

    print(args)
    global network_type

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # The seed is always re-drawn; it is printed so a run can be identified.
    args.seed = random.randint(1, 10000)
    print("====> Random Seed: ", args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    cudnn.enabled = True

    M = args.M
    N = args.N
    print("=====> building network")
    model = CGNet.Context_Guided_Network(classes=args.classes, M=M, N=N)
    network_type = "CGNet"
    print("=====> current architeture: CGNet")

    print("=====> computing network parameters")
    total_paramters = netParams(model)
    print("the number of parameters: " + str(total_paramters))

    print("data['classWeights']: ", datas['classWeights'])
    print('=====> Dataset statistics')
    print('mean and std: ', datas['mean'], datas['std'])

    # Define optimization criteria (class-weighted loss with an ignore label).
    weight = torch.from_numpy(datas['classWeights'])
    criteria = CrossEntropyLoss2d(weight, args.ignore_label)

    # gpu_nums is part of the save directory name below, so it is set even
    # when training runs without --cuda.
    args.gpu_nums = 1
    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = torch.nn.DataParallel(model).cuda()
        else:
            print("single GPU for training")
            model = model.cuda()

    # The run directory encodes the configuration of this training run.
    args.savedir = (args.savedir + args.dataset + '/' + network_type
                    + "_M" + str(M) + 'N' + str(N)
                    + 'bs' + str(args.batch_size)
                    + 'gpu' + str(args.gpu_nums)
                    + "_" + str(args.train_type) + '/')
    os.makedirs(args.savedir, exist_ok=True)

    # Training data is augmented through the dataset's own crop/scale/mirror
    # options; validation runs one image at a time at the original scale.
    trainLoader = data.DataLoader(
        CamVidDataSet(args.data_dir, args.train_data_list,
                      crop_size=input_size, scale=args.random_scale,
                      mirror=args.random_mirror, mean=datas['mean']),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    valLoader = data.DataLoader(
        CamVidValDataSet(args.data_dir, args.val_data_list,
                         f_scale=1, mean=datas['mean']),
        batch_size=1,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            # Remap tensors to the CPU when running without --cuda so a
            # checkpoint written on a GPU machine can still be restored.
            checkpoint = torch.load(
                args.resume,
                map_location=None if args.cuda else "cpu",
            )
            start_epoch = checkpoint['epoch']
            # NOTE(review): the checkpoint keys must match this model as
            # built above (DataParallel-wrapped or not) - confirm when
            # resuming on a different number of GPUs.
            model.load_state_dict(checkpoint['model'])
            print("=====> loading checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            # Best effort: training starts from scratch at epoch 0.
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True

    logFileLoc = args.savedir + args.logFile
    # Decide before opening: append mode creates the file if it is missing.
    log_is_new = not os.path.isfile(logFileLoc)

    # The context manager closes the log even if training raises.
    with open(logFileLoc, 'a') as logger:
        if log_is_new:
            logger.write("Parameters: %s" % (str(total_paramters)))
            # Header matches the values written on validation epochs.
            # Non-validation rows omit the mIOU (val) field.
            logger.write(
                "\n%s\t\t%s\t\t%s\t\t%s\t\t%s\t\t"
                % ('Epoch', 'Loss(Tr)', 'mIOU (tr)', 'mIOU (val)', 'lr'))
        logger.flush()

        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                     (0.9, 0.999), eps=1e-08,
                                     weight_decay=5e-4)

        print('=====> beginning training')
        for epoch in range(start_epoch, args.max_epochs):
            # training
            lossTr, per_class_iu_tr, mIOU_tr, lr = train(
                args, trainLoader, model, criteria, optimizer, epoch)

            # validation (every 50th epoch only)
            if epoch % 50 == 0:
                mIOU_val, per_class_iu = val(args, valLoader, model, criteria)
                logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f"
                             % (epoch, lossTr, mIOU_tr, mIOU_val, lr))
                logger.flush()
                print("epoch: " + str(epoch) + ' Details')
                print(
                    "\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\t lr= %.6f"
                    % (epoch, lossTr, mIOU_tr, mIOU_val, lr))
            else:
                logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.7f"
                             % (epoch, lossTr, mIOU_tr, lr))
                logger.flush()
                print("Epoch : " + str(epoch) + ' Details')
                print(
                    "\nEpoch No.: %d\tTrain Loss = %.4f\t mIOU(tr) = %.4f\t lr= %.6f"
                    % (epoch, lossTr, mIOU_tr, lr))

            # Save the model: every 20th epoch, plus every epoch in the
            # final stretch of the run.
            if epoch > args.max_epochs - 10 or epoch % 20 == 0:
                model_file_name = os.path.join(
                    args.savedir, 'model_' + str(epoch + 1) + '.pth')
                state = {"epoch": epoch + 1, "model": model.state_dict()}
                torch.save(state, model_file_name)