def main():
    """Main function wrapper for demo script.

    Runs the trained model over every file found (recursively) under
    ``args["DEMO_DIRECTORY"]``. Each file is preprocessed, turned into a
    single-sample batch, passed through ``MyNet``, and the decoded
    prediction is printed. If ``args["TRAINED_WEIGHTS_FILE"]`` is None, a
    message is printed and nothing else happens.
    """
    seed = args["SEED"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    weightsFile = args["TRAINED_WEIGHTS_FILE"]
    if weightsFile is None:
        print("Path to trained weights file not specified.")
        return

    print("Trained Weights File: %s" % (weightsFile))
    print("Demo Directory: %s" % (args["DEMO_DIRECTORY"]))

    # The weights path is CODE_DIRECTORY and the file name joined by plain
    # string concatenation (no separator is inserted).
    model = MyNet()
    stateDict = torch.load(args["CODE_DIRECTORY"] + weightsFile, map_location=device)
    model.load_state_dict(stateDict)
    model.to(device)

    print("Running Demo ....")
    for dirPath, _subDirs, fileNames in os.walk(args["DEMO_DIRECTORY"]):
        for fileName in fileNames:
            samplePath = os.path.join(dirPath, fileName)
            preprocess_sample(samplePath)
            sampleInput, _ = prepare_input(samplePath)

            # Add a leading batch dimension of size 1.
            batch = torch.unsqueeze(sampleInput, dim=0).float().to(device)

            model.eval()
            with torch.no_grad():
                outputBatch = model(batch)

            predictions = decode(outputBatch)
            prediction = predictions[0][:]
            print("File: %s" % (fileName))
            print("Prediction: %s" % (prediction))
            print("\n")
    print("Demo Completed.")
def main():
    """Main function wrapper for testing script.

    Loads the weights named by ``args["TRAINED_WEIGHTS_FILE"]`` into
    ``MyNet`` and prints the loss and metric returned by ``evaluate`` on the
    "test" split of ``MyDataset``. The weights path is built by plain string
    concatenation of ``args["CODE_DIRECTORY"]`` and the file name. If no
    weights file is configured, a message is printed and the function
    returns.
    """
    random.seed(args["SEED"])
    np.random.seed(args["SEED"])
    torch.manual_seed(args["SEED"])

    if torch.cuda.is_available():
        device = torch.device("cuda")
        kwargs = {"num_workers": args["NUM_WORKERS"], "pin_memory": True}
    else:
        device = torch.device("cpu")
        kwargs = {}

    if args["TRAINED_WEIGHTS_FILE"] is None:
        print("Path to the trained weights file not specified.")
        return

    testData = MyDataset("test", datadir=args["DATA_DIRECTORY"])
    # Evaluation does not need a random order. A fixed order keeps batch
    # composition identical across runs and seeds, which matters if
    # evaluate() averages per batch (NOTE(review): confirm in evaluate()).
    testLoader = DataLoader(
        testData, batch_size=args["BATCH_SIZE"], shuffle=False, **kwargs
    )

    print("Trained Weights File: %s" % (args["TRAINED_WEIGHTS_FILE"]))

    model = MyNet()
    model.load_state_dict(
        torch.load(
            args["CODE_DIRECTORY"] + args["TRAINED_WEIGHTS_FILE"],
            map_location=device,
        )
    )
    model.to(device)

    criterion = MyLoss()
    regularizer = L2Regularizer(lambd=args["LAMBDA"])

    print("Testing the trained model ....")
    testLoss, testMetric = evaluate(model, testLoader, criterion, regularizer, device)
    print("| Test Loss: %.6f || Test Metric: %.3f |" % (testLoss, testMetric))
    print("Testing Done.")
def load_model(self):
    """Load the network from ``self.o_net_path`` and run it on every image in ``self.image_dir``.

    Despite its name, this method also performs inference. For each entry
    of ``self.image_dir`` it:

    1. opens the image;
    2. parses the file name with ``self.parse_image_name``;
    3. applies ``self.transforms``;
    4. predicts landmarks and scales them by 192;
    5. prints them and passes them to ``self.save_pred``.

    If ``self.o_net_path`` is None, a message is printed and the method
    returns. Before this change that case crashed with a ``NameError``.

    :return: None
    """
    # TODO 1: load the model.
    if self.o_net_path is None:
        print('Path to the trained weights file not specified.')
        return
    print('=======> loading')
    net = MyNet(use_cuda=False)
    # The input tensors below are never moved off the CPU, so the network
    # must live on the CPU too. map_location lets checkpoints saved from a
    # GPU load on CPU-only hosts.
    net.load_state_dict(torch.load(self.o_net_path, map_location='cpu'))
    net.to('cpu')
    net.eval()

    # TODO 2: prepare the data.
    for item in os.listdir(self.image_dir):
        # Context manager closes the underlying file handle after use.
        with Image.open(os.path.join(self.image_dir, item)) as _img:
            parse_result = self.parse_image_name(item)
            landmark_and_format = parse_result['landmark_and_format']
            name = parse_result['name']

            img = self.transforms(_img).unsqueeze(0)  # add batch dimension
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                pred = net(img)
            # NOTE(review): 192 presumably rescales normalized coordinates to
            # pixel units of the input size; confirm against training targets.
            pred = pred * 192
            print('the pred landmark is :', pred)
            print("=" * 20)
            try:
                self.save_pred(_img, name, landmark_and_format, pred.detach().numpy())
            except Exception as err:
                # Best effort: report the failing file and keep going, but do
                # not swallow KeyboardInterrupt/SystemExit like a bare except.
                print('Error:', item, err)
def _save_curve_plot(trainCurve, valCurve, title, ylabel, savePath):
    """Plot a training and a validation curve against epoch number and save the figure.

    Epoch numbers on the x axis start at 1. The training curve is drawn in
    blue and the validation curve in red. The figure is closed even if
    saving fails, so repeated calls do not accumulate open figures.

    Args:
        trainCurve: per-epoch training values.
        valCurve: per-epoch validation values.
        title: figure title.
        ylabel: y-axis label.
        savePath: destination file passed to ``plt.savefig``.
    """
    plt.figure()
    try:
        plt.title(title)
        plt.xlabel("Epoch No.")
        plt.ylabel(ylabel)
        plt.plot(
            list(range(1, len(trainCurve) + 1)), trainCurve, "blue", label="Train"
        )
        plt.plot(
            list(range(1, len(valCurve) + 1)), valCurve, "red", label="Validation"
        )
        plt.legend()
        plt.savefig(savePath)
    finally:
        plt.close()


def main():
    """Main function wrapper for training script.

    Seeds the RNGs and splits the "train" split of ``MyDataset`` into
    train/validation subsets (``args["VALIDATION_SPLIT"]``). It then trains
    ``MyNet`` with Adam and an exponential LR decay for
    ``args["NUM_EPOCHS"]`` epochs, optionally starting from
    ``args["PRETRAINED_WEIGHTS_FILE"]``.

    Every ``args["SAVE_FREQUENCY"]`` epochs it writes the weights plus the
    loss and metric plots under ``<CODE_DIRECTORY>/checkpoints``. Any
    existing checkpoints directory is removed only after the user confirms
    on stdin; answering "n" exits.
    """
    matplotlib.use("Agg")  # non-interactive backend: figures are only saved
    random.seed(args["SEED"])
    np.random.seed(args["SEED"])
    torch.manual_seed(args["SEED"])

    if torch.cuda.is_available():
        device = torch.device("cuda")
        kwargs = {"num_workers": args["NUM_WORKERS"], "pin_memory": True}
    else:
        device = torch.device("cpu")
        kwargs = {}

    trainData = MyDataset("train", datadir=args["DATA_DIRECTORY"])
    valSize = int(args["VALIDATION_SPLIT"] * len(trainData))
    trainSize = len(trainData) - valSize
    trainData, valData = random_split(trainData, [trainSize, valSize])
    trainLoader = DataLoader(
        trainData, batch_size=args["BATCH_SIZE"], shuffle=True, **kwargs
    )
    # NOTE(review): shuffling is not needed for validation. It is left as-is
    # because changing it would alter the seeded RNG sequence consumed
    # during training and so break comparability with earlier runs.
    valLoader = DataLoader(
        valData, batch_size=args["BATCH_SIZE"], shuffle=True, **kwargs
    )

    model = MyNet()
    model.to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=args["LEARNING_RATE"],
        betas=(args["MOMENTUM1"], args["MOMENTUM2"]),
    )
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args["LR_DECAY"])
    criterion = MyLoss()
    regularizer = L2Regularizer(lambd=args["LAMBDA"])

    checkpointDir = args["CODE_DIRECTORY"] + "/checkpoints"
    if os.path.exists(checkpointDir):
        while True:
            char = input("Continue and remove the 'checkpoints' directory? y/n: ")
            if char == "y":
                break
            if char == "n":
                sys.exit()
            print("Invalid input")
        shutil.rmtree(checkpointDir)

    os.mkdir(checkpointDir)
    os.mkdir(checkpointDir + "/plots")
    os.mkdir(checkpointDir + "/weights")

    if args["PRETRAINED_WEIGHTS_FILE"] is not None:
        print("Pretrained Weights File: %s" % (args["PRETRAINED_WEIGHTS_FILE"]))
        print("Loading the pretrained weights ....")
        model.load_state_dict(
            torch.load(
                args["CODE_DIRECTORY"] + args["PRETRAINED_WEIGHTS_FILE"],
                map_location=device,
            )
        )
        model.to(device)
        print("Loading Done.")

    trainingLossCurve = list()
    validationLossCurve = list()
    trainingMetricCurve = list()
    validationMetricCurve = list()

    numTotalParams, numTrainableParams = num_params(model)
    print("Number of total parameters in the model = %d" % (numTotalParams))
    print("Number of trainable parameters in the model = %d" % (numTrainableParams))

    print("Training the model ....")
    for epoch in range(1, args["NUM_EPOCHS"] + 1):
        trainingLoss, trainingMetric = train(
            model, trainLoader, optimizer, criterion, regularizer, device
        )
        trainingLossCurve.append(trainingLoss)
        trainingMetricCurve.append(trainingMetric)

        validationLoss, validationMetric = evaluate(
            model, valLoader, criterion, regularizer, device
        )
        validationLossCurve.append(validationLoss)
        validationMetricCurve.append(validationMetric)

        print(
            (
                "| Epoch: %03d |"
                "| Tr.Loss: %.6f Val.Loss: %.6f |"
                "| Tr.Metric: %.3f Val.Metric: %.3f |"
            )
            % (
                epoch,
                trainingLoss,
                validationLoss,
                trainingMetric,
                validationMetric,
            )
        )

        scheduler.step()

        if epoch % args["SAVE_FREQUENCY"] == 0:
            # Only the file-name template is formatted, so braces in
            # CODE_DIRECTORY cannot be misread as format fields.
            savePath = checkpointDir + "/weights/epoch_{:04d}-metric_{:.3f}.pt".format(
                epoch, validationMetric
            )
            torch.save(model.state_dict(), savePath)

            _save_curve_plot(
                trainingLossCurve,
                validationLossCurve,
                "Loss Curves",
                "Loss value",
                checkpointDir + "/plots/epoch_{:04d}_loss.png".format(epoch),
            )
            _save_curve_plot(
                trainingMetricCurve,
                validationMetricCurve,
                "Metric Curves",
                "Metric",
                checkpointDir + "/plots/epoch_{:04d}_metric.png".format(epoch),
            )

    print("Training Done.")