def _configure_logging():
    """Send INFO-level records to ``log.txt`` (overwritten) and to the console."""
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s - %(levelname)s - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename="log.txt",
                        filemode='w')
    # The file handler above keeps timestamps; the console handler prints
    # only the bare message.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('%(message)s'))
    # '' names the root logger, so every logger propagates to the console.
    logging.getLogger('').addHandler(console)


def main():
    """Run the saved Dronet model on the test pickle and plot against pitch.

    Steps, in order:
      1. Configure logging (file + console).
      2. Load the test samples and labels, wrap them in a DataLoader and
         evaluate the model restored from disk with ``ModelTrainer.Test``.
      3. Read the pitch values from the cropped test pickle.
      4. Hand the pitch values, the pickle name (without ``.pickle``) and
         the r2 score to ``Plot2Models``.

    NOTE(review): data and model paths are hard-coded to one developer's
    machine layout; consider turning them into arguments.
    """
    _configure_logging()

    DATA_PATH = "/Users/usi/PycharmProjects/data/160x90/"

    # Get baseline results
    picklename = "160x90HimaxMixedTest_12_03_20.pickle"
    x_test, y_test = DataProcessor.ProcessTestData(DATA_PATH + picklename)
    test_set = Dataset(x_test, y_test)
    # One sample per batch, original order preserved.
    params = {'batch_size': 1, 'shuffle': False, 'num_workers': 1}
    test_generator = data.DataLoader(test_set, **params)

    # NOTE(review): meaning of the [1, 1, 1] / True arguments is defined by
    # Dronet, which is not visible here.
    model = Dronet(PreActBlock, [1, 1, 1], True)
    ModelManager.Read('../PyTorch/Models/DronetHimax160x90AugCrop.pt', model)
    trainer = ModelTrainer(model)
    # Test() yields five values; only the r2 score is used below.
    _, _, r2_score2, _, _ = trainer.Test(test_generator)

    # Get pitch values
    picklename = "160x90HimaxMixedTest_12_03_20Cropped70.pickle"
    p_test = DataProcessor.GetPitchFromTestData(DATA_PATH + picklename)

    # str.find() returns -1 (truthy) when the substring is missing and 0
    # (falsy) when it is at the start, so it must not be used as a boolean;
    # a membership test expresses the intended check.
    if ".pickle" in picklename:
        picklename = picklename.replace(".pickle", '')

    Plot2Models(p_test, picklename, r2_score2)