def test_RPS_MLP_training():
    train_set = TensorDataset(
        torch.zeros([50, 204, 501]),
        torch.zeros([50, 2]),
        torch.zeros([50, 204, 6]),
    )
    valid_set = TensorDataset(
        torch.zeros([10, 204, 501]),
        torch.zeros([10, 2]),
        torch.zeros([10, 204, 6]),
    )
    print(len(train_set))

    device = "cpu"

    trainloader = DataLoader(
        train_set, batch_size=10, shuffle=False, num_workers=1
    )
    validloader = DataLoader(
        valid_set, batch_size=2, shuffle=False, num_workers=1
    )

    epochs = 1

    # Change between different networks
    net = models.RPS_MLP()
    optimizer = Adam(net.parameters(), lr=0.00001)
    loss_function = torch.nn.MSELoss()

    print("begin training...")
    # NOTE: main() passes a scheduler to train_bp_MLP; this smoke test omits
    # it, so the two call signatures are inconsistent and one of them may
    # need updating.
    model, _, _ = train_bp_MLP(
        net,
        trainloader,
        validloader,
        optimizer,
        loss_function,
        device,
        epochs,
        10,
        0,
        "",
    )

    print("Training did not raise an error")
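# A minimal sketch for running the smoke test above directly as a script
# (an assumption; the repository may instead rely on pytest collection):
if __name__ == "__main__":
    test_RPS_MLP_training()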
def main(args):

    data_dir = args.data_dir
    figure_path = args.figure_dir
    model_path = args.model_dir

    # Set skip_training to False if the model has to be trained, to True if
    # the model has to be loaded.
    skip_training = False

    # Set the torch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device = {}".format(device))

    # Initialize parameters
    parameters = Params_cross(
        subject_n=args.sub,
        hand=args.hand,
        batch_size=args.batch_size,
        valid_batch_size=args.batch_size_valid,
        test_batch_size=args.batch_size_test,
        epochs=args.epochs,
        lr=args.learning_rate,
        wd=args.weight_decay,
        patience=args.patience,
        device=device,
        desc=args.desc,
    )

    # Import data and generate train-, valid- and test-set.
    # Set whether to generate RPS values or not (check the network
    # architecture used later).
    print("Testing: {} ".format(parameters.desc))

    mlp = False

    train_dataset = MEG_Cross_Dataset(data_dir, parameters.subject_n,
                                      parameters.hand, mode="train")
    valid_dataset = MEG_Cross_Dataset(data_dir, parameters.subject_n,
                                      parameters.hand, mode="val")
    test_dataset = MEG_Cross_Dataset(data_dir, parameters.subject_n,
                                     parameters.hand, mode="test")
    transfer_dataset = MEG_Cross_Dataset(data_dir, parameters.subject_n,
                                         parameters.hand, mode="transf")

    print("Train dataset len {}, valid dataset len {}, test dataset len {}, "
          "transfer dataset len {}".format(
              len(train_dataset),
              len(valid_dataset),
              len(test_dataset),
              len(transfer_dataset),
          ))

    # Initialize the dataloaders
    trainloader = DataLoader(train_dataset, batch_size=parameters.batch_size,
                             shuffle=True, num_workers=4)
    validloader = DataLoader(valid_dataset,
                             batch_size=parameters.valid_batch_size,
                             shuffle=True, num_workers=4)
    testloader = DataLoader(
        test_dataset,
        batch_size=parameters.test_batch_size,
        shuffle=False,
        num_workers=4,
    )
    transferloader = DataLoader(transfer_dataset,
                                batch_size=parameters.valid_batch_size,
                                shuffle=True, num_workers=4)

    # Initialize the network
    if mlp:
        net = RPS_MLP()
    else:
        # Get the n_times dimension
        with torch.no_grad():
            # next(iter(...)) is the Python 3 idiom; .next() was Python 2 only
            sample, y, _ = next(iter(trainloader))
            n_times = sample.shape[-1]
        net = RPS_MNet_ivan(n_times)

    print(net)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        net = nn.DataParallel(net)

    # Training loop
    if not skip_training:
        print("Begin training....")

        # Check the optimizer before running (different from model to model)
        optimizer = Adam(net.parameters(), lr=parameters.lr,
                         weight_decay=parameters.wd)
        # optimizer = SGD(net.parameters(), lr=parameters.lr, momentum=0.9,
        #                 weight_decay=parameters.wd)

        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5,
                                      patience=15)
        print("scheduler : ", scheduler)

        loss_function = torch.nn.MSELoss()
        # loss_function = torch.nn.L1Loss()

        start_time = timer.time()
        if mlp:
            net, train_loss, valid_loss = train_bp_MLP(
                net,
                trainloader,
                validloader,
                optimizer,
                scheduler,
                loss_function,
                parameters.device,
                parameters.epochs,
                parameters.patience,
                parameters.hand,
                model_path,
            )
        else:
            net, train_loss, valid_loss = train_bp(
                net,
                trainloader,
                validloader,
                optimizer,
                scheduler,
                loss_function,
                parameters.device,
                parameters.epochs,
                parameters.patience,
                parameters.hand,
                model_path,
            )
        train_time = timer.time() - start_time
        print("Training done in {:.4f}".format(train_time))

        # Visualize the loss as the network trained
        fig = plt.figure(figsize=(10, 4))
        plt.plot(range(1, len(train_loss) + 1), train_loss,
                 label="Training Loss")
        plt.plot(range(1, len(valid_loss) + 1), valid_loss,
                 label="Validation Loss")

        # Find the position of the lowest validation loss
        minposs = valid_loss.index(min(valid_loss)) + 1
        plt.axvline(
            minposs,
            linestyle="--",
            color="r",
            label="Early Stopping Checkpoint",
        )
        plt.xlabel("epochs")
        plt.ylabel("loss")
        # plt.ylim(0, 0.5)  # consistent scale
        # plt.xlim(0, len(train_loss)+1)  # consistent scale
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        image1 = fig
        # Save before show(): show() can clear the figure on some backends,
        # which would write an empty PDF.
        plt.savefig(os.path.join(figure_path, "loss_plot.pdf"))
        plt.show()

    if not skip_training:
        # Save the trained model
        save_pytorch_model(net, model_path, "model.pth")
    else:
        # Load the model (properly select the model architecture)
        net = RPS_MNet()
        net = load_pytorch_model(net, os.path.join(model_path, "model.pth"),
                                 parameters.device)

    # Evaluation
    print("Evaluation...")
    net.eval()
    y_pred = []
    y = []
    y_pred_valid = []
    y_valid = []

    # If RPS integration
    with torch.no_grad():
        if mlp:
            for _, labels, bp in testloader:
                labels, bp = (
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y.extend(list(labels[:, parameters.hand]))
                y_pred.extend(list(net(bp)))

            for _, labels, bp in validloader:
                labels, bp = (
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y_valid.extend(list(labels[:, parameters.hand]))
                y_pred_valid.extend(list(net(bp)))
        else:
            for data, labels, bp in testloader:
                data, labels, bp = (
                    data.to(parameters.device),
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y.extend(list(labels[:, parameters.hand]))
                y_pred.extend(list(net(data, bp)))

            for data, labels, bp in validloader:
                data, labels, bp = (
                    data.to(parameters.device),
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y_valid.extend(list(labels[:, parameters.hand]))
                y_pred_valid.extend(list(net(data, bp)))

    # Calculate evaluation measures
    print("Evaluation measures")
    mse = mean_squared_error(y, y_pred)
    rmse = mean_squared_error(y, y_pred, squared=False)
    mae = mean_absolute_error(y, y_pred)
    r2 = r2_score(y, y_pred)

    rmse_valid = mean_squared_error(y_valid, y_pred_valid, squared=False)
    r2_valid = r2_score(y_valid, y_pred_valid)
    valid_loss_last = min(valid_loss)

    print("Test set ")
    print("mean squared error {}".format(mse))
    print("root mean squared error {}".format(rmse))
    print("mean absolute error {}".format(mae))
    print("r2 score {}".format(r2))

    print("Validation set")
print("root mean squared error valid {}".format(rmse_valid)) print("r2 score valid {}".format(r2_valid)) print("last value of the validation loss: {}".format(valid_loss_last)) # plot y_new against the true value focus on 100 timepoints fig, ax = plt.subplots(1, 1, figsize=[10, 4]) times = np.arange(200) ax.plot(times, y_pred[0:200], color="b", label="Predicted") ax.plot(times, y[0:200], color="r", label="True") ax.set_xlabel("Times") ax.set_ylabel("Target") ax.set_title("Sub {}, hand {}, Target prediction".format( str(parameters.subject_n), "sx" if parameters.hand == 0 else "dx")) plt.legend() plt.savefig(os.path.join(figure_path, "Times_prediction_focus.pdf")) plt.show() # plot y_new against the true value fig, ax = plt.subplots(1, 1, figsize=[10, 4]) times = np.arange(len(y_pred)) ax.plot(times, y_pred, color="b", label="Predicted") ax.plot(times, y, color="r", label="True") ax.set_xlabel("Times") ax.set_ylabel("Target") ax.set_title("Sub {}, hand {}, target prediction".format( str(parameters.subject_n), "sx" if parameters.hand == 0 else "dx")) plt.legend() plt.savefig(os.path.join(figure_path, "Times_prediction.pdf")) plt.show() # scatterplot y predicted against the true value fig, ax = plt.subplots(1, 1, figsize=[10, 4]) ax.scatter(np.array(y), np.array(y_pred), color="b", label="Predicted") ax.set_xlabel("True") ax.set_ylabel("Predicted") # plt.legend() plt.savefig(os.path.join(figure_path, "Scatter.pdf")) plt.show() # scatterplot y predicted against the true value fig, ax = plt.subplots(1, 1, figsize=[10, 4]) ax.scatter(np.array(y_valid), np.array(y_pred_valid), color="b", label="Predicted") ax.set_xlabel("True") ax.set_ylabel("Predicted") # plt.legend() plt.savefig(os.path.join(figure_path, "Scatter_valid.pdf")) plt.show() # Transfer learning, feature extraction. 
    optimizer_trans = SGD(net.parameters(), lr=3e-4)
    loss_function_trans = torch.nn.MSELoss()
    # loss_function_trans = torch.nn.L1Loss()

    if mlp:
        net, train_loss = train_mlp_transfer(
            net,
            transferloader,
            optimizer_trans,
            loss_function_trans,
            parameters.device,
            50,
            parameters.patience,
            parameters.hand,
            model_path,
        )
    else:
        # net, train_loss = train_bp_transfer(
        #     net,
        #     transferloader,
        #     optimizer_trans,
        #     loss_function_trans,
        #     parameters.device,
        #     50,
        #     parameters.patience,
        #     parameters.hand,
        #     model_path,
        # )
        net, train_loss = train_bp_fine_tuning(net, transferloader,
                                               optimizer_trans,
                                               loss_function_trans,
                                               parameters.device, 50, 10,
                                               parameters.hand, model_path)

    # Evaluation
    print("Evaluation after transfer...")
    net.eval()
    y_pred = []
    y = []

    # If RPS integration
    with torch.no_grad():
        if mlp:
            for _, labels, bp in testloader:
                labels, bp = (
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y.extend(list(labels[:, parameters.hand]))
                y_pred.extend(list(net(bp)))
        else:
            for data, labels, bp in testloader:
                data, labels, bp = (
                    data.to(parameters.device),
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y.extend(list(labels[:, parameters.hand]))
                y_pred.extend(list(net(data, bp)))

    print("Evaluation measures")
    rmse_trans = mean_squared_error(y, y_pred, squared=False)
    r2_trans = r2_score(y, y_pred)

    print("root mean squared error after transfer learning {}".format(
        rmse_trans))
    print("r2 score after transfer learning {}".format(r2_trans))

    # Scatterplot of y predicted against the true value after transfer
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    ax.scatter(np.array(y), np.array(y_pred), color="b", label="Predicted")
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")
    # plt.legend()
    plt.savefig(os.path.join(figure_path, "Scatter_after_trans.pdf"))
    plt.show()

    # Log the model and parameters using the mlflow tracker
    with mlflow.start_run(experiment_id=args.experiment) as run:
        for key, value in vars(parameters).items():
            mlflow.log_param(key, value)
        mlflow.log_param("Time", train_time)
        mlflow.log_metric("MSE", mse)
        mlflow.log_metric("RMSE", rmse)
        mlflow.log_metric("MAE", mae)
        mlflow.log_metric("R2", r2)
        mlflow.log_metric("RMSE_Valid", rmse_valid)
        mlflow.log_metric("R2_Valid", r2_valid)
        mlflow.log_metric("Valid_loss", valid_loss_last)
        mlflow.log_metric("RMSE_T", rmse_trans)
        mlflow.log_metric("R2_T", r2_trans)
        mlflow.log_artifact(os.path.join(figure_path, "Times_prediction.pdf"))
        mlflow.log_artifact(
            os.path.join(figure_path, "Times_prediction_focus.pdf"))
        mlflow.log_artifact(os.path.join(figure_path, "loss_plot.pdf"))
        mlflow.log_artifact(os.path.join(figure_path, "Scatter.pdf"))
        mlflow.log_artifact(os.path.join(figure_path, "Scatter_valid.pdf"))
        mlflow.log_artifact(
            os.path.join(figure_path, "Scatter_after_trans.pdf"))
        mlflow.pytorch.log_model(net, "models")
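# A minimal sketch of the CLI entry point this main() appears to expect. The
# flag names mirror the args.* attributes used above, but the defaults and
# help strings are assumptions, not the repository's actual parser.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Cross-subject MEG training")
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--figure_dir", type=str, required=True)
    parser.add_argument("--model_dir", type=str, required=True)
    parser.add_argument("--sub", type=int, default=1)
    parser.add_argument("--hand", type=int, default=0, help="0=left, 1=right")
    parser.add_argument("--batch_size", type=int, default=80)
    parser.add_argument("--batch_size_valid", type=int, default=30)
    parser.add_argument("--batch_size_test", type=int, default=30)
    parser.add_argument("--epochs", type=int, default=80)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--desc", type=str, default="")
    parser.add_argument("--experiment", type=int, default=0)

    main(parser.parse_args())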
def main(args):

    data_dir = args.data_dir
    figure_path = args.figure_dir
    model_path = args.model_dir

    file_name = "ball_left_mean.npz"

    # Set skip_training to False if the model has to be trained, to True if
    # the model has to be loaded.
    skip_training = False

    # Set the torch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device = {}".format(device))

    # Initialize parameters
    parameters = Params_cross(subject_n=args.sub,
                              hand=args.hand,
                              batch_size=args.batch_size,
                              valid_batch_size=args.batch_size_valid,
                              test_batch_size=args.batch_size_test,
                              epochs=args.epochs,
                              lr=args.learning_rate,
                              wd=args.weight_decay,
                              patience=args.patience,
                              device=device,
                              desc=args.desc)

    # Set whether to generate RPS values or not (check the network
    # architecture used later).
    # if mlp: rps-mlp, elif rps: rps-mnet, else: mnet
    mlp = False
    rps = True

    print("Creating dataset")
    # Generate the custom dataset
    train_dataset = MEG_Within_Dataset_ivan(data_dir, parameters.subject_n,
                                            parameters.hand, mode="train")
    test_dataset = MEG_Within_Dataset_ivan(data_dir, parameters.subject_n,
                                           parameters.hand, mode="test")
    valid_dataset = MEG_Within_Dataset_ivan(data_dir, parameters.subject_n,
                                            parameters.hand, mode="val")

    # Split the dataset in train, test and valid sets.
    print("train set {}, val set {}, test set {}".format(
        len(train_dataset), len(valid_dataset), len(test_dataset)))

    # train_dataset, valid_test, test_dataset = random_split(dataset,
    #     [train_len, valid_len, test_len],
    #     generator=torch.Generator().manual_seed(42))
    # train_dataset, valid_test, test_dataset = random_split(dataset,
    #     [train_len, valid_len, test_len])

    # Better visualization
    # train_valid_dataset = Subset(dataset, list(range(train_len+valid_len)))
    # test_dataset = Subset(dataset,
    #                       list(range(train_len+valid_len, len(dataset))))
    # train_dataset, valid_dataset = random_split(train_valid_dataset,
    #                                             [train_len, valid_len])

    # Initialize the dataloaders
    trainloader = DataLoader(train_dataset, batch_size=parameters.batch_size,
                             shuffle=True, num_workers=1)
    validloader = DataLoader(valid_dataset,
                             batch_size=parameters.valid_batch_size,
                             shuffle=True, num_workers=1)
    testloader = DataLoader(test_dataset,
                            batch_size=parameters.test_batch_size,
                            shuffle=False, num_workers=1)

    # Initialize the network
    if mlp:
        net = RPS_MLP()
        # net = RPS_CNN()
    else:
        # Get the n_times dimension
        with torch.no_grad():
            # next(iter(...)) is the Python 3 idiom; .next() was Python 2 only
            sample, y, _ = next(iter(trainloader))
            n_times = sample.shape[-1]
        if rps:
            net = RPS_MNet_ivan(n_times)
        else:
            net = MNet_ivan(n_times)

    print(net)

    # Count only trainable parameters (the original also summed frozen ones,
    # which contradicted the "Trainable" label).
    total_params = 0
    for name, parameter in net.named_parameters():
        param = parameter.numel() if parameter.requires_grad else 0
        print("param {} : {}".format(name, param))
        total_params += param
    print(f"Total Trainable Params: {total_params}")

    # Training loop or model loading
    if not skip_training:
        print("Begin training....")

        # Check the optimizer before running (different from model to model)
        # optimizer = Adam(net.parameters(), lr=parameters.lr)
        optimizer = Adam(net.parameters(), lr=parameters.lr,
                         weight_decay=parameters.wd)
        # optimizer = SGD(net.parameters(), lr=parameters.lr, momentum=0.9,
        #                 weight_decay=parameters.wd)
        # optimizer = SGD(net.parameters(), lr=parameters.lr, momentum=0.9)
        print("optimizer : ", optimizer)

        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5,
                                      patience=15)
        print("scheduler : ", scheduler)

        loss_function = torch.nn.MSELoss()
        # loss_function = torch.nn.L1Loss()
        print("loss :", loss_function)

        start_time = timer.time()
        if mlp:
            net, train_loss, valid_loss = train_bp_MLP(
                net,
                trainloader,
                validloader,
                optimizer,
                scheduler,
                loss_function,
                parameters.device,
                parameters.epochs,
                parameters.patience,
                parameters.hand,
                model_path,
            )
        else:
            if rps:
                net, train_loss, valid_loss = train_bp(
                    net,
                    trainloader,
                    validloader,
                    optimizer,
                    scheduler,
                    loss_function,
                    parameters.device,
                    parameters.epochs,
                    parameters.patience,
                    parameters.hand,
                    model_path,
                )
            else:
                net, train_loss, valid_loss = train(
                    net,
                    trainloader,
                    validloader,
                    optimizer,
                    scheduler,
                    loss_function,
                    parameters.device,
                    parameters.epochs,
                    parameters.patience,
                    parameters.hand,
                    model_path,
                )
        train_time = timer.time() - start_time
        print("Training done in {:.4f}".format(train_time))

        # Visualize the loss as the network trained
        fig = plt.figure(figsize=(10, 4))
        plt.plot(range(1, len(train_loss) + 1), train_loss,
                 label="Training Loss")
        plt.plot(range(1, len(valid_loss) + 1), valid_loss,
                 label="Validation Loss")

        # Find the position of the lowest validation loss
        minposs = valid_loss.index(min(valid_loss)) + 1
        plt.axvline(minposs, linestyle="--", color="r",
                    label="Early Stopping Checkpoint")
        plt.xlabel("epochs")
        plt.ylabel("loss")
        # plt.ylim(0, 0.5)  # consistent scale
        # plt.xlim(0, len(train_loss)+1)  # consistent scale
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        image1 = fig
        # Save before show(): show() can clear the figure on some backends,
        # which would write an empty PDF.
        plt.savefig(os.path.join(figure_path, "loss_plot.pdf"))
        plt.show()

    if not skip_training:
        # Save the trained model
        save_pytorch_model(net, model_path, "model.pth")
    else:
        # Load the model (properly select the model architecture)
        net = RPS_MNet()
        net = load_pytorch_model(net, os.path.join(model_path, "model.pth"),
                                 parameters.device)

    # Evaluation
    print("Evaluation...")
    net.eval()
    y_pred = []
    y = []
    y_pred_valid = []
    y_valid = []

    # If RPS integration
    with torch.no_grad():
        if mlp:
            for _, labels, bp in testloader:
                labels, bp = (
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y.extend(list(labels[:, parameters.hand]))
                y_pred.extend(list(net(bp)))

            for _, labels, bp in validloader:
                labels, bp = (
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y_valid.extend(list(labels[:, parameters.hand]))
                y_pred_valid.extend(list(net(bp)))
        else:
            if rps:
                for data, labels, bp in testloader:
                    data, labels, bp = (
                        data.to(parameters.device),
                        labels.to(parameters.device),
                        bp.to(parameters.device),
                    )
                    y.extend(list(labels[:, parameters.hand]))
                    y_pred.extend(list(net(data, bp)))

                for data, labels, bp in validloader:
                    data, labels, bp = (
                        data.to(parameters.device),
                        labels.to(parameters.device),
                        bp.to(parameters.device),
                    )
                    y_valid.extend(list(labels[:, parameters.hand]))
                    y_pred_valid.extend(list(net(data, bp)))
            else:
                for data, labels, _ in testloader:
                    data, labels = (
                        data.to(parameters.device),
                        labels.to(parameters.device),
                    )
                    y.extend(list(labels[:, parameters.hand]))
                    y_pred.extend(list(net(data)))

                for data, labels, _ in validloader:
                    data, labels = (
                        data.to(parameters.device),
                        labels.to(parameters.device),
                    )
                    y_valid.extend(list(labels[:, parameters.hand]))
                    y_pred_valid.extend(list(net(data)))

    # Calculate evaluation measures
    print("Evaluation measures")
    mse = mean_squared_error(y, y_pred)
    rmse = mean_squared_error(y, y_pred, squared=False)
    mae = mean_absolute_error(y, y_pred)
    r2 = r2_score(y, y_pred)

    rmse_valid = mean_squared_error(y_valid, y_pred_valid, squared=False)
    r2_valid = r2_score(y_valid, y_pred_valid)
    valid_loss_last = min(valid_loss)

    print("Test set ")
    print("mean squared error {}".format(mse))
    print("root mean squared error {}".format(rmse))
    print("mean absolute error {}".format(mae))
    print("r2 score {}".format(r2))

    print("Validation set")
valid {}".format(rmse_valid)) print("r2 score valid {}".format(r2_valid)) print("last value of the validation loss: {}".format(valid_loss_last)) # plot y_new against the true value focus on 200 timepoints fig, ax = plt.subplots(1, 1, figsize=[14, 6]) times = np.arange(1000) ax.plot(times, y_pred[:1000], color="b", label="Predicted") ax.plot(times, y[:1000], color="r", label="True") ax.set_xlabel("Times") ax.set_ylabel("Target") ax.set_title("Sub {}, hand {} prediction".format( str(parameters.subject_n), "sx" if parameters.hand == 0 else "dx")) plt.legend() plt.savefig(os.path.join(figure_path, "Times_prediction_focus.pdf")) plt.show() # plot y_new against the true value fig, ax = plt.subplots(1, 1, figsize=[10, 4]) times = np.arange(len(y_pred)) ax.plot(times, y_pred, color="b", label="Predicted") ax.plot(times, y, color="r", label="True") ax.set_xlabel("Times") ax.set_ylabel("Target") ax.set_title("Sub {}, hand {}, prediction".format( str(parameters.subject_n), "sx" if parameters.hand == 0 else "dx")) plt.legend() plt.savefig(os.path.join(figure_path, "Times_prediction.pdf")) plt.show() # scatterplot y predicted against the true value fig, ax = plt.subplots(1, 1, figsize=[10, 4]) ax.scatter(np.array(y), np.array(y_pred), color="b", label="Predicted") ax.set_xlabel("True") ax.set_ylabel("Predicted") # plt.legend() plt.savefig(os.path.join(figure_path, "Scatter.pdf")) plt.show() # scatterplot y predicted against the true value fig, ax = plt.subplots(1, 1, figsize=[10, 4]) ax.scatter(np.array(y_valid), np.array(y_pred_valid), color="b", label="Predicted") ax.set_xlabel("True") ax.set_ylabel("Predicted") # plt.legend() plt.savefig(os.path.join(figure_path, "Scatter_valid.pdf")) plt.show() # Save prediction for post analysis out_file = "prediction_sub{}_hand_{}.npz".format( str(parameters.subject_n), "left" if parameters.hand == 0 else "right") np.savez(os.path.join(data_dir, out_file), y_pred=y_pred, y=y) # log the model and parameters using mlflow tracker with mlflow.start_run(experiment_id=args.experiment) as run: for key, value in vars(parameters).items(): mlflow.log_param(key, value) mlflow.log_param("Time", train_time) mlflow.log_metric("MSE", mse) mlflow.log_metric("RMSE", rmse) mlflow.log_metric("MAE", mae) mlflow.log_metric("R2", r2) mlflow.log_metric("RMSE_Valid", rmse_valid) mlflow.log_metric("R2_Valid", r2_valid) mlflow.log_metric("Valid_loss", valid_loss_last) mlflow.log_artifact(os.path.join(figure_path, "Times_prediction.pdf")) mlflow.log_artifact( os.path.join(figure_path, "Times_prediction_focus.pdf")) mlflow.log_artifact(os.path.join(figure_path, "loss_plot.pdf")) mlflow.log_artifact(os.path.join(figure_path, "Scatter.pdf")) mlflow.log_artifact(os.path.join(figure_path, "Scatter_valid.pdf")) mlflow.pytorch.log_model(net, "models")
def main(args):

    data_dir = args.data_dir
    figure_path = args.figure_dir
    model_path = args.model_dir

    # Generate the data input path list. Each subject has 3 runs stored in
    # 3 different files.
    subj_id = "/sub" + str(args.sub) + "/ball0"
    raw_fnames = [
        "".join([data_dir, subj_id, str(i), "_sss_trans.fif"])
        for i in range(1 if args.sub != 3 else 2, 4)
    ]

    # local
    # subj_id = "/sub" + str(args.sub) + "/ball"
    # raw_fnames = ["".join([data_dir, subj_id, str(i), "_sss.fif"])
    #               for i in range(1, 2)]

    # Set skip_training to False if the model has to be trained, to True if
    # the model has to be loaded.
    skip_training = False

    # Set the torch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device = {}".format(device))

    # Initialize parameters
    parameters = Params_tunable(
        subject_n=args.sub,
        hand=args.hand,
        batch_size=args.batch_size,
        valid_batch_size=args.batch_size_valid,
        test_batch_size=args.batch_size_test,
        epochs=args.epochs,
        lr=args.learning_rate,
        duration=args.duration,
        overlap=args.overlap,
        patience=args.patience,
        device=device,
        y_measure=args.y_measure,
        s_n_layer=args.s_n_layer,
        # s_kernel_size=args.s_kernel_size,  # Local
        s_kernel_size=json.loads(" ".join(args.s_kernel_size)),
        t_n_layer=args.t_n_layer,
        # t_kernel_size=args.t_kernel_size,  # Local
        t_kernel_size=json.loads(" ".join(args.t_kernel_size)),
        max_pooling=args.max_pooling,
        ff_n_layer=args.ff_n_layer,
        ff_hidden_channels=args.ff_hidden_channels,
        dropout=args.dropout,
        activation=args.activation,
    )

    # Set whether to generate RPS values or not (check the network
    # architecture used later).
    rps = True

    # Generate the custom dataset
    if rps:
        dataset = MEG_Dataset(
            raw_fnames,
            parameters.duration,
            parameters.overlap,
            parameters.y_measure,
            normalize_input=True,
        )
    else:
        dataset = MEG_Dataset_no_bp(
            raw_fnames,
            parameters.duration,
            parameters.overlap,
            parameters.y_measure,
            normalize_input=True,
        )

    # Split the dataset in train, test and valid sets.
    train_len, valid_len, test_len = len_split(len(dataset))
    print(
        "{} + {} + {} = {}?".format(
            train_len, valid_len, test_len, len(dataset)
        )
    )

    # train_dataset, valid_dataset, test_dataset = random_split(dataset,
    #     [train_len, valid_len, test_len],
    #     generator=torch.Generator().manual_seed(42))
    train_dataset, valid_dataset, test_dataset = random_split(
        dataset, [train_len, valid_len, test_len]
    )

    # Better visualization
    # train_valid_dataset = Subset(dataset, list(range(train_len+valid_len)))
    # test_dataset = Subset(dataset,
    #                       list(range(train_len+valid_len, len(dataset))))
    # train_dataset, valid_dataset = random_split(train_valid_dataset,
    #                                             [train_len, valid_len])

    # Initialize the dataloaders
    trainloader = DataLoader(
        train_dataset,
        batch_size=parameters.batch_size,
        shuffle=True,
        num_workers=1,
    )
    validloader = DataLoader(
        valid_dataset,
        batch_size=parameters.valid_batch_size,
        shuffle=True,
        num_workers=1,
    )
    testloader = DataLoader(
        test_dataset,
        batch_size=parameters.test_batch_size,
        shuffle=False,
        num_workers=1,
    )

    # Get the n_times dimension
    with torch.no_grad():
        # Changes if RPS integration or not
        if rps:
            x, _, _ = next(iter(trainloader))
        else:
            x, _ = next(iter(trainloader))
        n_times = x.shape[-1]

    # Initialize the network
    # net = LeNet5(n_times)
    # net = ResNet([2, 2, 2], 64, n_times)
    # net = SCNN(parameters.s_n_layer,
    #            parameters.s_kernel_size,
    #            parameters.t_n_layer,
    #            parameters.t_kernel_size,
    #            n_times,
    #            parameters.ff_n_layer,
    #            parameters.ff_hidden_channels,
    #            parameters.dropout,
    #            parameters.max_pooling,
    #            parameters.activation)
    # net = MNet(n_times)
    # net = RPS_SCNN(parameters.s_n_layer,
    #                parameters.s_kernel_size,
    #                parameters.t_n_layer,
    #                parameters.t_kernel_size,
    #                n_times,
    #                parameters.ff_n_layer,
    #                parameters.ff_hidden_channels,
    #                parameters.dropout,
    #                parameters.max_pooling,
    #                parameters.activation)
    net = RPS_MNet(n_times)
    # net = RPS_MLP()
    mlp = False
    print(net)

    # Training loop or model loading
    if not skip_training:
        print("Begin training....")

        # Check the optimizer before running (different from model to model)
        optimizer = Adam(net.parameters(), lr=parameters.lr, weight_decay=5e-4)
        # optimizer = SGD(net.parameters(), lr=parameters.lr, weight_decay=5e-4)

        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5,
                                      patience=15)
        print("scheduler : ", scheduler)

        loss_function = torch.nn.MSELoss()

        start_time = timer.time()
        if rps:
            if mlp:
                net, train_loss, valid_loss = train_bp_MLP(
                    net,
                    trainloader,
                    validloader,
                    optimizer,
                    scheduler,
                    loss_function,
                    parameters.device,
                    parameters.epochs,
                    parameters.patience,
                    parameters.hand,
                    model_path,
                )
            else:
                net, train_loss, valid_loss = train_bp(
                    net,
                    trainloader,
                    validloader,
                    optimizer,
                    scheduler,
                    loss_function,
                    parameters.device,
                    parameters.epochs,
                    parameters.patience,
                    parameters.hand,
                    model_path,
                )
        else:
            net, train_loss, valid_loss = train(
                net,
                trainloader,
                validloader,
                optimizer,
                scheduler,
                loss_function,
                parameters.device,
                parameters.epochs,
                parameters.patience,
                parameters.hand,
                model_path,
            )
        train_time = timer.time() - start_time
        print("Training done in {:.4f}".format(train_time))

        # Visualize the loss as the network trained
        fig = plt.figure(figsize=(10, 4))
        plt.plot(
            range(1, len(train_loss) + 1), train_loss, label="Training Loss"
        )
        plt.plot(
            range(1, len(valid_loss) + 1), valid_loss, label="Validation Loss"
        )

        # Find the position of the lowest validation loss
        minposs = valid_loss.index(min(valid_loss)) + 1
        plt.axvline(
            minposs,
            linestyle="--",
            color="r",
            label="Early Stopping Checkpoint",
        )
        plt.xlabel("epochs")
        plt.ylabel("loss")
        # plt.ylim(0, 0.5)  # consistent scale
        # plt.xlim(0, len(train_loss)+1)  # consistent scale
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        image1 = fig
        # Save before show(): show() can clear the figure on some backends,
        # which would write an empty PDF.
        plt.savefig(os.path.join(figure_path, "loss_plot.pdf"))
        plt.show()

    if not skip_training:
        # Save the trained model
        # NOTE: the load branch below expects "model.pth"; the two file names
        # are inconsistent and one of them likely needs updating.
        save_pytorch_model(net, model_path, "Baselinemodel_SCNN_swap.pth")
    else:
        # Load the model (properly select the model architecture)
        net = RPS_MNet()
        net = load_pytorch_model(
            net, os.path.join(model_path, "model.pth"), parameters.device
        )

    # Evaluation
    print("Evaluation...")
    net.eval()
    y_pred = []
    y = []

    # If RPS integration
    with torch.no_grad():
        if rps:
            if mlp:
                for _, labels, bp in testloader:
                    labels, bp = (
                        labels.to(parameters.device),
                        bp.to(parameters.device),
                    )
                    y.extend(list(labels[:, parameters.hand]))
                    y_pred.extend(list(net(bp)))
            else:
                for data, labels, bp in testloader:
                    data, labels, bp = (
                        data.to(parameters.device),
                        labels.to(parameters.device),
                        bp.to(parameters.device),
                    )
                    y.extend(list(labels[:, parameters.hand]))
                    y_pred.extend(list(net(data, bp)))
        else:
            for data, labels in testloader:
                data, labels = (
                    data.to(parameters.device),
                    labels.to(parameters.device),
                )
                y.extend(list(labels[:, parameters.hand]))
                y_pred.extend(list(net(data)))

    print("SCNN_swap...")

    # Calculate evaluation measures
    mse = mean_squared_error(y, y_pred)
    rmse = mean_squared_error(y, y_pred, squared=False)
    mae = mean_absolute_error(y, y_pred)
    r2 = r2_score(y, y_pred)

    print("mean squared error {}".format(mse))
    print("root mean squared error {}".format(rmse))
    print("mean absolute error {}".format(mae))
    print("r2 score {}".format(r2))

    # Plot y_pred against the true value, focused on 100 timepoints
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    times = np.arange(100)
    ax.plot(times, y_pred[0:100], color="b", label="Predicted")
    ax.plot(times, y[0:100], color="r", label="True")
    ax.set_xlabel("Times")
    ax.set_ylabel("{}".format(parameters.y_measure))
    ax.set_title(
        "Sub {}, hand {}, {} prediction".format(
            str(parameters.subject_n),
            "sx" if parameters.hand == 0 else "dx",
            parameters.y_measure,
        )
    )
    plt.legend()
    plt.savefig(os.path.join(figure_path, "Times_prediction_focus.pdf"))
    plt.show()

    # Plot y_pred against the true value
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    times = np.arange(len(y_pred))
    ax.plot(times, y_pred, color="b", label="Predicted")
    ax.plot(times, y, color="r", label="True")
    ax.set_xlabel("Times")
    ax.set_ylabel("{}".format(parameters.y_measure))
    ax.set_title(
        "Sub {}, hand {}, {} prediction".format(
            str(parameters.subject_n),
            "sx" if parameters.hand == 0 else "dx",
            parameters.y_measure,
        )
    )
    plt.legend()
    plt.savefig(os.path.join(figure_path, "Times_prediction.pdf"))
    plt.show()

    # Scatterplot of y predicted against the true value
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    ax.scatter(np.array(y), np.array(y_pred), color="b", label="Predicted")
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")
    # plt.legend()
    plt.savefig(os.path.join(figure_path, "Scatter.pdf"))
    plt.show()

    # Log the model and parameters using the mlflow tracker
    with mlflow.start_run(experiment_id=args.experiment) as run:
        for key, value in vars(parameters).items():
            mlflow.log_param(key, value)
        mlflow.log_param("Time", train_time)
        mlflow.log_metric("MSE", mse)
        mlflow.log_metric("RMSE", rmse)
        mlflow.log_metric("MAE", mae)
        mlflow.log_metric("R2", r2)
        mlflow.log_artifact(os.path.join(figure_path, "Times_prediction.pdf"))
        mlflow.log_artifact(
            os.path.join(figure_path, "Times_prediction_focus.pdf")
        )
        mlflow.log_artifact(os.path.join(figure_path, "loss_plot.pdf"))
        mlflow.log_artifact(os.path.join(figure_path, "Scatter.pdf"))
        mlflow.pytorch.log_model(net, "models")
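# len_split() is used above to size the random_split. A minimal sketch of
# what such a helper could look like, assuming a 70/15/15 split; the actual
# ratios live elsewhere in the repository and may differ, so this is labeled
# as a hypothetical stand-in (len_split_sketch, not the real len_split).
def len_split_sketch(dataset_len):
    train_len = int(dataset_len * 0.7)
    valid_len = int(dataset_len * 0.15)
    # Give the remainder to the test set so the three parts sum exactly.
    test_len = dataset_len - train_len - valid_len
    return train_len, valid_len, test_len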