예제 #1
0
def test_RPS_MLP_training():
    """Smoke-test that RPS_MLP training runs end-to-end without raising.

    Builds tiny all-zero synthetic datasets, runs a single training
    epoch of ``train_bp_MLP`` on CPU, and succeeds if no exception
    is raised.
    """

    def _zero_triplet(n):
        # (signal, target, band-power) tensors filled with zeros;
        # shapes mirror the real MEG data layout used elsewhere.
        return TensorDataset(
            torch.zeros([n, 204, 501]),
            torch.zeros([n, 2]),
            torch.zeros([n, 204, 6]),
        )

    train_set = _zero_triplet(50)
    valid_set = _zero_triplet(10)

    print(len(train_set))

    device = "cpu"

    train_dl = DataLoader(
        train_set, batch_size=10, shuffle=False, num_workers=1
    )
    valid_dl = DataLoader(
        valid_set, batch_size=2, shuffle=False, num_workers=1
    )

    epochs = 1

    # Swap the network here to smoke-test a different architecture.
    net = models.RPS_MLP()
    optimizer = Adam(net.parameters(), lr=0.00001)
    loss_function = torch.nn.MSELoss()

    print("begin training...")
    model, _, _ = train_bp_MLP(
        net,
        train_dl,
        valid_dl,
        optimizer,
        loss_function,
        device,
        epochs,
        10,
        0,
        "",
    )

    print("Training do not rise error")
예제 #2
0
def main(args):
    """Cross-subject MEG regression pipeline with transfer learning.

    Trains a network (``RPS_MLP`` if ``mlp`` else ``RPS_MNet_ivan``) on the
    train split, evaluates it on the test and validation splits, fine-tunes
    it on the transfer split, re-evaluates, and logs all parameters,
    metrics and figures to mlflow.

    Args:
        args: Parsed CLI namespace; reads ``data_dir``, ``figure_dir``,
            ``model_dir``, ``sub``, ``hand``, ``batch_size``,
            ``batch_size_valid``, ``batch_size_test``, ``epochs``,
            ``learning_rate``, ``weight_decay``, ``patience``, ``desc``
            and ``experiment``.
    """
    data_dir = args.data_dir
    figure_path = args.figure_dir
    model_path = args.model_dir

    # Set skip_training to False if the model has to be trained, to True if the model has to be loaded.
    # NOTE(review): with skip_training=True, train_time/train_loss/valid_loss
    # are never assigned, so the metric logging below would raise NameError —
    # confirm before flipping this flag.
    skip_training = False

    # Set the torch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device = {}".format(device))

    # Initialize parameters
    parameters = Params_cross(
        subject_n=args.sub,
        hand=args.hand,
        batch_size=args.batch_size,
        valid_batch_size=args.batch_size_valid,
        test_batch_size=args.batch_size_test,
        epochs=args.epochs,
        lr=args.learning_rate,
        wd=args.weight_decay,
        patience=args.patience,
        device=device,
        desc=args.desc,
    )
    # Import data and generate train-, valid- and test-set
    # Set if generate with RPS values or not (check network architecture used later)

    print("Testing: {} ".format(parameters.desc))

    # mlp selects the band-power-only RPS_MLP; otherwise RPS_MNet_ivan is used.
    mlp = False

    train_dataset = MEG_Cross_Dataset(data_dir,
                                      parameters.subject_n,
                                      parameters.hand,
                                      mode="train")

    valid_dataset = MEG_Cross_Dataset(data_dir,
                                      parameters.subject_n,
                                      parameters.hand,
                                      mode="val")

    test_dataset = MEG_Cross_Dataset(data_dir,
                                     parameters.subject_n,
                                     parameters.hand,
                                     mode="test")

    transfer_dataset = MEG_Cross_Dataset(data_dir,
                                         parameters.subject_n,
                                         parameters.hand,
                                         mode="transf")

    print("Train dataset len {}, valid dataset len {}, test dataset len {}, "
          "transfer dataset len {}".format(
              len(train_dataset),
              len(valid_dataset),
              len(test_dataset),
              len(transfer_dataset),
          ))

    # Initialize the dataloaders
    trainloader = DataLoader(train_dataset,
                             batch_size=parameters.batch_size,
                             shuffle=True,
                             num_workers=4)

    validloader = DataLoader(valid_dataset,
                             batch_size=parameters.valid_batch_size,
                             shuffle=True,
                             num_workers=4)

    testloader = DataLoader(
        test_dataset,
        batch_size=parameters.test_batch_size,
        shuffle=False,
        num_workers=4,
    )

    transferloader = DataLoader(transfer_dataset,
                                batch_size=parameters.valid_batch_size,
                                shuffle=True,
                                num_workers=4)

    # Initialize network
    if mlp:
        net = RPS_MLP()
    else:
        # Get the n_times dimension
        with torch.no_grad():
            # BUG FIX: DataLoader iterators expose __next__, not a .next()
            # method, in Python 3 / recent PyTorch — use the builtin next().
            sample, y, _ = next(iter(trainloader))

        n_times = sample.shape[-1]
        net = RPS_MNet_ivan(n_times)

    print(net)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        net = nn.DataParallel(net)

    # Training loop
    if not skip_training:
        print("Begin training....")

        # Check the optimizer before running (different from model to model)
        optimizer = Adam(net.parameters(),
                         lr=parameters.lr,
                         weight_decay=parameters.wd)
        # optimizer = SGD(net.parameters(), lr=parameters.lr, momentum=0.9, weight_decay=parameters.wd)

        # Halve the LR when the monitored (validation) loss plateaus.
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode="min",
                                      factor=0.5,
                                      patience=15)

        print("scheduler : ", scheduler)

        loss_function = torch.nn.MSELoss()
        # loss_function = torch.nn.L1Loss()
        start_time = timer.time()

        if mlp:
            net, train_loss, valid_loss = train_bp_MLP(
                net,
                trainloader,
                validloader,
                optimizer,
                scheduler,
                loss_function,
                parameters.device,
                parameters.epochs,
                parameters.patience,
                parameters.hand,
                model_path,
            )
        else:
            net, train_loss, valid_loss = train_bp(
                net,
                trainloader,
                validloader,
                optimizer,
                scheduler,
                loss_function,
                parameters.device,
                parameters.epochs,
                parameters.patience,
                parameters.hand,
                model_path,
            )

        train_time = timer.time() - start_time
        print("Training done in {:.4f}".format(train_time))

        # visualize the loss as the network trained
        fig = plt.figure(figsize=(10, 4))
        plt.plot(range(1,
                       len(train_loss) + 1),
                 train_loss,
                 label="Training Loss")
        plt.plot(range(1,
                       len(valid_loss) + 1),
                 valid_loss,
                 label="Validation Loss")

        # find position of lowest validation loss
        minposs = valid_loss.index(min(valid_loss)) + 1
        plt.axvline(
            minposs,
            linestyle="--",
            color="r",
            label="Early Stopping Checkpoint",
        )

        plt.xlabel("epochs")
        plt.ylabel("loss")
        # plt.ylim(0, 0.5) # consistent scale
        # plt.xlim(0, len(train_loss)+1) # consistent scale
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.show()
        # NOTE(review): savefig after show() can save a blank canvas on
        # interactive backends — confirm the backend is non-interactive.
        plt.savefig(os.path.join(figure_path, "loss_plot.pdf"))

    if not skip_training:
        # Save the trained model
        save_pytorch_model(net, model_path, "model.pth")
    else:
        # Load the model (properly select the model architecture)
        net = RPS_MNet()
        net = load_pytorch_model(net, os.path.join(model_path, "model.pth"),
                                 parameters.device)

    # Evaluation
    print("Evaluation...")
    net.eval()
    y_pred = []
    y = []
    y_pred_valid = []
    y_valid = []

    # if RPS integration
    with torch.no_grad():
        if mlp:
            # MLP consumes only the band-power features.
            for _, labels, bp in testloader:
                labels, bp = (
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y.extend(list(labels[:, parameters.hand]))
                y_pred.extend((list(net(bp))))

            for _, labels, bp in validloader:
                labels, bp = (
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y_valid.extend(list(labels[:, parameters.hand]))
                y_pred_valid.extend((list(net(bp))))
        else:
            # MNet consumes the raw data plus band-power features.
            for data, labels, bp in testloader:
                data, labels, bp = (
                    data.to(parameters.device),
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y.extend(list(labels[:, parameters.hand]))
                y_pred.extend((list(net(data, bp))))

            for data, labels, bp in validloader:
                data, labels, bp = (
                    data.to(parameters.device),
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y_valid.extend(list(labels[:, parameters.hand]))
                y_pred_valid.extend((list(net(data, bp))))

    # Calculate Evaluation measures
    print("Evaluation measures")
    mse = mean_squared_error(y, y_pred)
    rmse = mean_squared_error(y, y_pred, squared=False)
    mae = mean_absolute_error(y, y_pred)
    r2 = r2_score(y, y_pred)

    rmse_valid = mean_squared_error(y_valid, y_pred_valid, squared=False)
    r2_valid = r2_score(y_valid, y_pred_valid)
    valid_loss_last = min(valid_loss)

    print("Test set ")
    print("mean squared error {}".format(mse))
    print("root mean squared error {}".format(rmse))
    print("mean absolute error {}".format(mae))
    print("r2 score {}".format(r2))

    print("Validation set")
    print("root mean squared error valid {}".format(rmse_valid))
    print("r2 score valid {}".format(r2_valid))
    print("last value of the validation loss: {}".format(valid_loss_last))

    # plot y_new against the true value focus on 100 timepoints
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    times = np.arange(200)
    ax.plot(times, y_pred[0:200], color="b", label="Predicted")
    ax.plot(times, y[0:200], color="r", label="True")
    ax.set_xlabel("Times")
    ax.set_ylabel("Target")
    ax.set_title("Sub {}, hand {}, Target prediction".format(
        str(parameters.subject_n), "sx" if parameters.hand == 0 else "dx"))
    plt.legend()
    plt.savefig(os.path.join(figure_path, "Times_prediction_focus.pdf"))
    plt.show()

    # plot y_new against the true value
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    times = np.arange(len(y_pred))
    ax.plot(times, y_pred, color="b", label="Predicted")
    ax.plot(times, y, color="r", label="True")
    ax.set_xlabel("Times")
    ax.set_ylabel("Target")
    ax.set_title("Sub {}, hand {}, target prediction".format(
        str(parameters.subject_n), "sx" if parameters.hand == 0 else "dx"))
    plt.legend()
    plt.savefig(os.path.join(figure_path, "Times_prediction.pdf"))
    plt.show()

    # scatterplot y predicted against the true value
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    ax.scatter(np.array(y), np.array(y_pred), color="b", label="Predicted")
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")
    # plt.legend()
    plt.savefig(os.path.join(figure_path, "Scatter.pdf"))
    plt.show()

    # scatterplot y predicted against the true value
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    ax.scatter(np.array(y_valid),
               np.array(y_pred_valid),
               color="b",
               label="Predicted")
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")
    # plt.legend()
    plt.savefig(os.path.join(figure_path, "Scatter_valid.pdf"))
    plt.show()

    # Transfer learning, feature extraction.

    optimizer_trans = SGD(net.parameters(), lr=3e-4)

    loss_function_trans = torch.nn.MSELoss()
    # loss_function_trans = torch.nn.L1Loss()

    if mlp:
        net, train_loss = train_mlp_transfer(
            net,
            transferloader,
            optimizer_trans,
            loss_function_trans,
            parameters.device,
            50,
            parameters.patience,
            parameters.hand,
            model_path,
        )
    else:
        # net, train_loss = train_bp_transfer(
        #     net,
        #     transferloader,
        #     optimizer_trans,
        #     loss_function_trans,
        #     parameters.device,
        #     50,
        #     parameters.patience,
        #     parameters.hand,
        #     model_path,
        # )
        net, train_loss = train_bp_fine_tuning(net, transferloader,
                                               optimizer_trans,
                                               loss_function_trans,
                                               parameters.device, 50, 10,
                                               parameters.hand, model_path)

    # Evaluation
    print("Evaluation after transfer...")
    net.eval()
    y_pred = []
    y = []

    # if RPS integration
    with torch.no_grad():
        if mlp:
            for _, labels, bp in testloader:
                labels, bp = (
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y.extend(list(labels[:, parameters.hand]))
                y_pred.extend((list(net(bp))))
        else:
            for data, labels, bp in testloader:
                data, labels, bp = (
                    data.to(parameters.device),
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y.extend(list(labels[:, parameters.hand]))
                y_pred.extend((list(net(data, bp))))

    print("Evaluation measures")
    rmse_trans = mean_squared_error(y, y_pred, squared=False)
    r2_trans = r2_score(y, y_pred)

    print("root mean squared error after transfer learning {}".format(
        rmse_trans))
    print("r2 score after transfer learning  {}".format(r2_trans))

    # scatterplot y predicted against the true value
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    ax.scatter(np.array(y), np.array(y_pred), color="b", label="Predicted")
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")
    # plt.legend()
    plt.savefig(os.path.join(figure_path, "Scatter_after_trans.pdf"))
    plt.show()

    # log the model and parameters using mlflow tracker
    with mlflow.start_run(experiment_id=args.experiment) as run:
        for key, value in vars(parameters).items():
            mlflow.log_param(key, value)

        mlflow.log_param("Time", train_time)

        mlflow.log_metric("MSE", mse)
        mlflow.log_metric("RMSE", rmse)
        mlflow.log_metric("MAE", mae)
        mlflow.log_metric("R2", r2)

        mlflow.log_metric("RMSE_Valid", rmse_valid)
        mlflow.log_metric("R2_Valid", r2_valid)
        mlflow.log_metric("Valid_loss", valid_loss_last)

        mlflow.log_metric("RMSE_T", rmse_trans)
        mlflow.log_metric("R2_T", r2_trans)

        mlflow.log_artifact(os.path.join(figure_path, "Times_prediction.pdf"))
        mlflow.log_artifact(
            os.path.join(figure_path, "Times_prediction_focus.pdf"))
        mlflow.log_artifact(os.path.join(figure_path, "loss_plot.pdf"))
        mlflow.log_artifact(os.path.join(figure_path, "Scatter.pdf"))
        mlflow.log_artifact(os.path.join(figure_path, "Scatter_valid.pdf"))
        mlflow.log_artifact(
            os.path.join(figure_path, "Scatter_after_trans.pdf"))
        mlflow.pytorch.log_model(net, "models")
예제 #3
0
def main(args):
    """Within-subject MEG regression: train, evaluate and log one model.

    Trains ``RPS_MLP`` / ``RPS_MNet_ivan`` / ``MNet_ivan`` depending on the
    ``mlp`` and ``rps`` flags, evaluates on the validation and test splits,
    saves the prediction arrays for post-analysis, and logs parameters,
    metrics and figures to mlflow.

    Args:
        args: Parsed CLI namespace; reads ``data_dir``, ``figure_dir``,
            ``model_dir``, ``sub``, ``hand``, ``batch_size``,
            ``batch_size_valid``, ``batch_size_test``, ``epochs``,
            ``learning_rate``, ``weight_decay``, ``patience``, ``desc``
            and ``experiment``.
    """
    data_dir = args.data_dir
    figure_path = args.figure_dir
    model_path = args.model_dir

    # Set skip_training to False if the model has to be trained, to True if the model has to be loaded.
    # NOTE(review): with skip_training=True, train_time/train_loss/valid_loss
    # are never assigned, so min(valid_loss) and the mlflow logging below
    # would raise NameError — confirm before flipping this flag.
    skip_training = False

    # Set the torch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device = {}".format(device))

    # Initialize parameters
    parameters = Params_cross(subject_n=args.sub,
                              hand=args.hand,
                              batch_size=args.batch_size,
                              valid_batch_size=args.batch_size_valid,
                              test_batch_size=args.batch_size_test,
                              epochs=args.epochs,
                              lr=args.learning_rate,
                              wd=args.weight_decay,
                              patience=args.patience,
                              device=device,
                              desc=args.desc)

    # Set if generate with RPS values or not (check network architecture used later)
    # if mlp = rps-mlp, elif rps = rps-mnet, else mnet
    mlp = False
    rps = True
    print("Creating dataset")

    # Generate the custom dataset
    train_dataset = MEG_Within_Dataset_ivan(data_dir,
                                            parameters.subject_n,
                                            parameters.hand,
                                            mode="train")

    test_dataset = MEG_Within_Dataset_ivan(data_dir,
                                           parameters.subject_n,
                                           parameters.hand,
                                           mode="test")

    valid_dataset = MEG_Within_Dataset_ivan(data_dir,
                                            parameters.subject_n,
                                            parameters.hand,
                                            mode="val")

    # split the dataset in train, test and valid sets.

    print("train set {}, val set {}, test set {}".format(
        len(train_dataset), len(valid_dataset), len(test_dataset)))

    # train_dataset, valid_test, test_dataset = random_split(dataset, [train_len, valid_len, test_len],
    #                                                        generator=torch.Generator().manual_seed(42))
    # train_dataset, valid_test, test_dataset = random_split(dataset, [train_len, valid_len, test_len])
    # Better vizualization
    # train_valid_dataset = Subset(dataset, list(range(train_len+valid_len)))
    # test_dataset = Subset(dataset, list(range(train_len+valid_len, len(dataset))))
    #
    # train_dataset, valid_dataset = random_split(train_valid_dataset, [train_len, valid_len])

    # Initialize the dataloaders
    trainloader = DataLoader(train_dataset,
                             batch_size=parameters.batch_size,
                             shuffle=True,
                             num_workers=1)
    validloader = DataLoader(valid_dataset,
                             batch_size=parameters.valid_batch_size,
                             shuffle=True,
                             num_workers=1)
    testloader = DataLoader(test_dataset,
                            batch_size=parameters.test_batch_size,
                            shuffle=False,
                            num_workers=1)

    # Get the n_times dimension

    if mlp:
        net = RPS_MLP()
        # net = RPS_CNN()
    else:
        # Get the n_times dimension
        with torch.no_grad():
            # BUG FIX: DataLoader iterators expose __next__, not a .next()
            # method, in Python 3 / recent PyTorch — use the builtin next().
            sample, y, _ = next(iter(trainloader))

        n_times = sample.shape[-1]
        if rps:
            net = RPS_MNet_ivan(n_times)
        else:
            net = MNet_ivan(n_times)

    print(net)
    # Report the parameter count per named tensor and in total.
    total_params = 0
    for name, parameter in net.named_parameters():
        param = parameter.numel()
        print("param {} : {}".format(name,
                                     param if parameter.requires_grad else 0))
        total_params += param
    print(f"Total Trainable Params: {total_params}")

    # Training loop or model loading
    if not skip_training:
        print("Begin training....")

        # Check the optimizer before running (different from model to model)
        # optimizer = Adam(net.parameters(), lr=parameters.lr)
        optimizer = Adam(net.parameters(),
                         lr=parameters.lr,
                         weight_decay=parameters.wd)
        # optimizer = SGD(net.parameters(), lr=parameters.lr, momentum=0.9, weight_decay=parameters.wd)
        # optimizer = SGD(net.parameters(), lr=parameters.lr, momentum=0.9)

        print("optimizer : ", optimizer)

        # Halve the LR when the monitored (validation) loss plateaus.
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode="min",
                                      factor=0.5,
                                      patience=15)

        print("scheduler : ", scheduler)

        loss_function = torch.nn.MSELoss()
        # loss_function = torch.nn.L1Loss()
        print("loss :", loss_function)
        start_time = timer.time()

        if mlp:
            net, train_loss, valid_loss = train_bp_MLP(
                net,
                trainloader,
                validloader,
                optimizer,
                scheduler,
                loss_function,
                parameters.device,
                parameters.epochs,
                parameters.patience,
                parameters.hand,
                model_path,
            )
        else:
            if rps:
                net, train_loss, valid_loss = train_bp(
                    net,
                    trainloader,
                    validloader,
                    optimizer,
                    scheduler,
                    loss_function,
                    parameters.device,
                    parameters.epochs,
                    parameters.patience,
                    parameters.hand,
                    model_path,
                )
            else:
                net, train_loss, valid_loss = train(
                    net,
                    trainloader,
                    validloader,
                    optimizer,
                    scheduler,
                    loss_function,
                    parameters.device,
                    parameters.epochs,
                    parameters.patience,
                    parameters.hand,
                    model_path,
                )

        train_time = timer.time() - start_time
        print("Training done in {:.4f}".format(train_time))

        # visualize the loss as the network trained
        fig = plt.figure(figsize=(10, 4))
        plt.plot(range(1,
                       len(train_loss) + 1),
                 train_loss,
                 label='Training Loss')
        plt.plot(range(1,
                       len(valid_loss) + 1),
                 valid_loss,
                 label='Validation Loss')

        # find position of lowest validation loss
        minposs = valid_loss.index(min(valid_loss)) + 1
        plt.axvline(minposs,
                    linestyle='--',
                    color='r',
                    label='Early Stopping Checkpoint')

        plt.xlabel("epochs")
        plt.ylabel("loss")
        # plt.ylim(0, 0.5) # consistent scale
        # plt.xlim(0, len(train_loss)+1) # consistent scale
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.show()
        # NOTE(review): savefig after show() can save a blank canvas on
        # interactive backends — confirm the backend is non-interactive.
        plt.savefig(os.path.join(figure_path, "loss_plot.pdf"))

    if not skip_training:
        # Save the trained model
        save_pytorch_model(net, model_path, "model.pth")
    else:
        # Load the model (properly select the model architecture)
        net = RPS_MNet()
        net = load_pytorch_model(net, os.path.join(model_path, "model.pth"),
                                 parameters.device)

    # Evaluation
    print("Evaluation...")
    net.eval()
    y_pred = []
    y = []
    y_pred_valid = []
    y_valid = []

    # if RPS integration
    with torch.no_grad():
        if mlp:
            # MLP consumes only the band-power features.
            for _, labels, bp in testloader:
                labels, bp = labels.to(parameters.device), \
                             bp.to(parameters.device)
                y.extend(list(labels[:, parameters.hand]))
                y_pred.extend((list(net(bp))))

            for _, labels, bp in validloader:
                labels, bp = (
                    labels.to(parameters.device),
                    bp.to(parameters.device),
                )
                y_valid.extend(list(labels[:, parameters.hand]))
                y_pred_valid.extend((list(net(bp))))
        else:
            if rps:
                # RPS-MNet consumes raw data plus band-power features.
                for data, labels, bp in testloader:
                    data, labels, bp = (
                        data.to(parameters.device),
                        labels.to(parameters.device),
                        bp.to(parameters.device),
                    )
                    y.extend(list(labels[:, parameters.hand]))
                    y_pred.extend((list(net(data, bp))))

                for data, labels, bp in validloader:
                    data, labels, bp = (
                        data.to(parameters.device),
                        labels.to(parameters.device),
                        bp.to(parameters.device),
                    )
                    y_valid.extend(list(labels[:, parameters.hand]))
                    y_pred_valid.extend((list(net(data, bp))))

            else:
                # Plain MNet consumes only the raw data.
                for data, labels, _ in testloader:
                    data, labels = (
                        data.to(parameters.device),
                        labels.to(parameters.device),
                    )
                    y.extend(list(labels[:, parameters.hand]))
                    y_pred.extend((list(net(data))))

                for data, labels, _ in validloader:
                    data, labels = (
                        data.to(parameters.device),
                        labels.to(parameters.device),
                    )
                    y_valid.extend(list(labels[:, parameters.hand]))
                    y_pred_valid.extend((list(net(data))))

    # Calculate Evaluation measures
    print("Evaluation measures")
    mse = mean_squared_error(y, y_pred)
    rmse = mean_squared_error(y, y_pred, squared=False)
    mae = mean_absolute_error(y, y_pred)
    r2 = r2_score(y, y_pred)

    rmse_valid = mean_squared_error(y_valid, y_pred_valid, squared=False)
    r2_valid = r2_score(y_valid, y_pred_valid)
    valid_loss_last = min(valid_loss)

    print("Test set ")
    print("mean squared error {}".format(mse))
    print("root mean squared error {}".format(rmse))
    print("mean absolute error {}".format(mae))
    print("r2 score {}".format(r2))

    print("Validation set")
    print("root mean squared error valid {}".format(rmse_valid))
    print("r2 score valid {}".format(r2_valid))
    print("last value of the validation loss: {}".format(valid_loss_last))

    # plot y_new against the true value, focus on the first 1000 timepoints
    # NOTE(review): assumes the test split yields >= 1000 predictions;
    # a shorter split would make the x/y lengths mismatch — confirm.
    fig, ax = plt.subplots(1, 1, figsize=[14, 6])
    times = np.arange(1000)
    ax.plot(times, y_pred[:1000], color="b", label="Predicted")
    ax.plot(times, y[:1000], color="r", label="True")
    ax.set_xlabel("Times")
    ax.set_ylabel("Target")
    ax.set_title("Sub {}, hand {} prediction".format(
        str(parameters.subject_n), "sx" if parameters.hand == 0 else "dx"))
    plt.legend()
    plt.savefig(os.path.join(figure_path, "Times_prediction_focus.pdf"))
    plt.show()

    # plot y_new against the true value
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    times = np.arange(len(y_pred))
    ax.plot(times, y_pred, color="b", label="Predicted")
    ax.plot(times, y, color="r", label="True")
    ax.set_xlabel("Times")
    ax.set_ylabel("Target")
    ax.set_title("Sub {}, hand {}, prediction".format(
        str(parameters.subject_n), "sx" if parameters.hand == 0 else "dx"))
    plt.legend()
    plt.savefig(os.path.join(figure_path, "Times_prediction.pdf"))
    plt.show()

    # scatterplot y predicted against the true value
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    ax.scatter(np.array(y), np.array(y_pred), color="b", label="Predicted")
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")
    # plt.legend()
    plt.savefig(os.path.join(figure_path, "Scatter.pdf"))
    plt.show()

    # scatterplot y predicted against the true value
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    ax.scatter(np.array(y_valid),
               np.array(y_pred_valid),
               color="b",
               label="Predicted")
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")
    # plt.legend()
    plt.savefig(os.path.join(figure_path, "Scatter_valid.pdf"))
    plt.show()

    # Save prediction for post analysis

    out_file = "prediction_sub{}_hand_{}.npz".format(
        str(parameters.subject_n), "left" if parameters.hand == 0 else "right")
    np.savez(os.path.join(data_dir, out_file), y_pred=y_pred, y=y)

    # log the model and parameters using mlflow tracker
    with mlflow.start_run(experiment_id=args.experiment) as run:
        for key, value in vars(parameters).items():
            mlflow.log_param(key, value)

        mlflow.log_param("Time", train_time)

        mlflow.log_metric("MSE", mse)
        mlflow.log_metric("RMSE", rmse)
        mlflow.log_metric("MAE", mae)
        mlflow.log_metric("R2", r2)
        mlflow.log_metric("RMSE_Valid", rmse_valid)
        mlflow.log_metric("R2_Valid", r2_valid)
        mlflow.log_metric("Valid_loss", valid_loss_last)

        mlflow.log_artifact(os.path.join(figure_path, "Times_prediction.pdf"))
        mlflow.log_artifact(
            os.path.join(figure_path, "Times_prediction_focus.pdf"))
        mlflow.log_artifact(os.path.join(figure_path, "loss_plot.pdf"))
        mlflow.log_artifact(os.path.join(figure_path, "Scatter.pdf"))
        mlflow.log_artifact(os.path.join(figure_path, "Scatter_valid.pdf"))
        mlflow.pytorch.log_model(net, "models")
예제 #4
0
def main(args):
    """Train and evaluate a network predicting hand movement from MEG data.

    Pipeline: build the subject's raw-file list, generate the (optionally
    RPS-augmented) dataset, split it into train/valid/test, train the
    selected network with early stopping (unless ``skip_training``),
    evaluate regression metrics on the test set, save diagnostic figures,
    and log parameters, metrics, artifacts and the model to mlflow.

    Args:
        args: parsed command-line namespace. Expected attributes:
            data_dir, figure_dir, model_dir, sub, hand, batch_size,
            batch_size_valid, batch_size_test, epochs, learning_rate,
            duration, overlap, patience, y_measure, s_n_layer,
            s_kernel_size, t_n_layer, t_kernel_size, max_pooling,
            ff_n_layer, ff_hidden_channels, dropout, activation,
            experiment (mlflow experiment id).
    """
    data_dir = args.data_dir
    figure_path = args.figure_dir
    model_path = args.model_dir

    # Generate the data input path list. Each subject has 3 runs stored in
    # 3 different files. Subject 3 has no run 1, so its range starts at 2.
    subj_id = "/sub" + str(args.sub) + "/ball0"
    raw_fnames = [
        "".join([data_dir, subj_id, str(i), "_sss_trans.fif"])
        for i in range(1 if args.sub != 3 else 2, 4)
    ]

    # local
    # subj_id = "/sub"+str(args.sub)+"/ball"
    # raw_fnames = ["".join([data_dir, subj_id, str(i), "_sss.fif"]) for i in range(1, 2)]

    # Set skip_training to False if the model has to be trained, to True if
    # the model has to be loaded.
    skip_training = False

    # Set the torch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device = {}".format(device))

    # Initialize parameters
    parameters = Params_tunable(
        subject_n=args.sub,
        hand=args.hand,
        batch_size=args.batch_size,
        valid_batch_size=args.batch_size_valid,
        test_batch_size=args.batch_size_test,
        epochs=args.epochs,
        lr=args.learning_rate,
        duration=args.duration,
        overlap=args.overlap,
        patience=args.patience,
        device=device,
        y_measure=args.y_measure,
        s_n_layer=args.s_n_layer,
        # s_kernel_size=args.s_kernel_size,  # Local
        s_kernel_size=json.loads(" ".join(args.s_kernel_size)),
        t_n_layer=args.t_n_layer,
        # t_kernel_size=args.t_kernel_size,  # Local
        t_kernel_size=json.loads(" ".join(args.t_kernel_size)),
        max_pooling=args.max_pooling,
        ff_n_layer=args.ff_n_layer,
        ff_hidden_channels=args.ff_hidden_channels,
        dropout=args.dropout,
        activation=args.activation,
    )

    # Set if generate with RPS values or not (check network architecture used later)
    rps = True

    # Generate the custom dataset. The RPS variant yields (data, label, bp)
    # triples; the no-bp variant yields (data, label) pairs.
    if rps:
        dataset = MEG_Dataset(
            raw_fnames,
            parameters.duration,
            parameters.overlap,
            parameters.y_measure,
            normalize_input=True,
        )
    else:
        dataset = MEG_Dataset_no_bp(
            raw_fnames,
            parameters.duration,
            parameters.overlap,
            parameters.y_measure,
            normalize_input=True,
        )

    # split the dataset in train, test and valid sets.
    train_len, valid_len, test_len = len_split(len(dataset))
    print(
        "{} + {} + {} = {}?".format(
            train_len, valid_len, test_len, len(dataset)
        )
    )

    # train_dataset, valid_test, test_dataset = random_split(dataset, [train_len, valid_len, test_len],
    #                                                        generator=torch.Generator().manual_seed(42))
    train_dataset, valid_test, test_dataset = random_split(
        dataset, [train_len, valid_len, test_len]
    )

    # Better vizualization
    # train_valid_dataset = Subset(dataset, list(range(train_len+valid_len)))
    # test_dataset = Subset(dataset, list(range(train_len+valid_len, len(dataset))))
    #
    # train_dataset, valid_dataset = random_split(train_valid_dataset, [train_len, valid_len])

    # Initialize the dataloaders
    trainloader = DataLoader(
        train_dataset,
        batch_size=parameters.batch_size,
        shuffle=True,
        num_workers=1,
    )
    validloader = DataLoader(
        valid_test,
        batch_size=parameters.valid_batch_size,
        shuffle=True,
        num_workers=1,
    )
    testloader = DataLoader(
        test_dataset,
        batch_size=parameters.test_batch_size,
        shuffle=False,
        num_workers=1,
    )

    # Get the n_times dimension from one sample batch.
    # NOTE: iter(...).next() is the removed Python-2 idiom and raises
    # AttributeError on modern PyTorch/Python; use the built-in next().
    with torch.no_grad():
        # Changes if RPS integration or not
        if rps:
            x, _, _ = next(iter(trainloader))
        else:
            x, _ = next(iter(trainloader))

    n_times = x.shape[-1]

    # Initialize network. Alternative architectures kept for reference:
    # net = LeNet5(n_times)
    # net = ResNet([2, 2, 2], 64, n_times)
    # net = SCNN(parameters.s_n_layer,
    #                    parameters.s_kernel_size,
    #                    parameters.t_n_layer,
    #                    parameters.t_kernel_size,
    #                    n_times,
    #                    parameters.ff_n_layer,
    #                    parameters.ff_hidden_channels,
    #                    parameters.dropout,
    #                    parameters.max_pooling,
    #                    parameters.activation)
    # net = MNet(n_times)
    # net = RPS_SCNN(parameters.s_n_layer,
    #                    parameters.s_kernel_size,
    #                    parameters.t_n_layer,
    #                    parameters.t_kernel_size,
    #                    n_times,
    #                    parameters.ff_n_layer,
    #                    parameters.ff_hidden_channels,
    #                    parameters.dropout,
    #                    parameters.max_pooling,
    #                    parameters.activation)

    net = RPS_MNet(n_times)
    # net = RPS_MLP()
    mlp = False

    print(net)

    # Initialized so mlflow logging below does not raise NameError when the
    # model is loaded instead of trained (skip_training=True).
    train_time = 0.0

    # Training loop or model loading
    if not skip_training:
        print("Begin training....")

        # Check the optimizer before running (different from model to model)
        optimizer = Adam(net.parameters(), lr=parameters.lr, weight_decay=5e-4)
        # optimizer = SGD(net.parameters(), lr=parameters.lr, weight_decay=5e-4)

        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5,
                                      patience=15)

        print("scheduler : ", scheduler)

        loss_function = torch.nn.MSELoss()
        start_time = timer.time()
        # Pick the training routine matching the chosen architecture.
        if rps:
            if mlp:
                net, train_loss, valid_loss = train_bp_MLP(
                    net,
                    trainloader,
                    validloader,
                    optimizer,
                    scheduler,
                    loss_function,
                    parameters.device,
                    parameters.epochs,
                    parameters.patience,
                    parameters.hand,
                    model_path,
                )
            else:
                net, train_loss, valid_loss = train_bp(
                    net,
                    trainloader,
                    validloader,
                    optimizer,
                    scheduler,
                    loss_function,
                    parameters.device,
                    parameters.epochs,
                    parameters.patience,
                    parameters.hand,
                    model_path,
                )
        else:
            net, train_loss, valid_loss = train(
                net,
                trainloader,
                validloader,
                optimizer,
                scheduler,
                loss_function,
                parameters.device,
                parameters.epochs,
                parameters.patience,
                parameters.hand,
                model_path,
            )

        train_time = timer.time() - start_time
        print("Training done in {:.4f}".format(train_time))

        # visualize the loss as the network trained
        fig = plt.figure(figsize=(10, 4))
        plt.plot(
            range(1, len(train_loss) + 1), train_loss, label="Training Loss"
        )
        plt.plot(
            range(1, len(valid_loss) + 1), valid_loss, label="Validation Loss"
        )

        # find position of lowest validation loss
        minposs = valid_loss.index(min(valid_loss)) + 1
        plt.axvline(
            minposs,
            linestyle="--",
            color="r",
            label="Early Stopping Checkpoint",
        )

        plt.xlabel("epochs")
        plt.ylabel("loss")
        # plt.ylim(0, 0.5) # consistent scale
        # plt.xlim(0, len(train_loss)+1) # consistent scale
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        # Save BEFORE show: with interactive backends show() consumes the
        # figure, which previously produced a blank loss_plot.pdf.
        plt.savefig(os.path.join(figure_path, "loss_plot.pdf"))
        plt.show()

        # Save the trained model
        save_pytorch_model(net, model_path, "Baselinemodel_SCNN_swap.pth")
    else:
        # Load the model (properly select the model architecture)
        net = RPS_MNet()
        net = load_pytorch_model(
            net, os.path.join(model_path, "model.pth"), parameters.device
        )

    # Evaluation
    print("Evaluation...")
    net.eval()
    y_pred = []
    y = []

    # Collect predictions over the test set; the unpacking depends on
    # whether the dataset carries RPS band-power (bp) tensors.
    with torch.no_grad():
        if rps:
            if mlp:
                for _, labels, bp in testloader:
                    labels = labels.to(parameters.device)
                    bp = bp.to(parameters.device)
                    y.extend(list(labels[:, parameters.hand]))
                    y_pred.extend((list(net(bp))))
            else:
                for data, labels, bp in testloader:
                    data, labels, bp = (
                        data.to(parameters.device),
                        labels.to(parameters.device),
                        bp.to(parameters.device),
                    )
                    y.extend(list(labels[:, parameters.hand]))
                    y_pred.extend((list(net(data, bp))))
        else:
            for data, labels in testloader:
                data, labels = (
                    data.to(parameters.device),
                    labels.to(parameters.device),
                )
                y.extend(list(labels[:, parameters.hand]))
                y_pred.extend((list(net(data))))

    print("SCNN_swap...")
    # Calculate Evaluation measures
    mse = mean_squared_error(y, y_pred)
    rmse = mean_squared_error(y, y_pred, squared=False)
    mae = mean_absolute_error(y, y_pred)
    r2 = r2_score(y, y_pred)
    print("mean squared error {}".format(mse))
    print("root mean squared error {}".format(rmse))
    print("mean absolute error {}".format(mae))
    print("r2 score {}".format(r2))

    # plot y_new against the true value focus on 100 timepoints
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    times = np.arange(100)
    ax.plot(times, y_pred[0:100], color="b", label="Predicted")
    ax.plot(times, y[0:100], color="r", label="True")
    ax.set_xlabel("Times")
    ax.set_ylabel("{}".format(parameters.y_measure))
    ax.set_title(
        "Sub {}, hand {}, {} prediction".format(
            str(parameters.subject_n),
            "sx" if parameters.hand == 0 else "dx",
            parameters.y_measure,
        )
    )
    plt.legend()
    plt.savefig(os.path.join(figure_path, "Times_prediction_focus.pdf"))
    plt.show()

    # plot y_new against the true value
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    times = np.arange(len(y_pred))
    ax.plot(times, y_pred, color="b", label="Predicted")
    ax.plot(times, y, color="r", label="True")
    ax.set_xlabel("Times")
    ax.set_ylabel("{}".format(parameters.y_measure))
    ax.set_title(
        "Sub {}, hand {}, {} prediction".format(
            str(parameters.subject_n),
            "sx" if parameters.hand == 0 else "dx",
            parameters.y_measure,
        )
    )
    plt.legend()
    plt.savefig(os.path.join(figure_path, "Times_prediction.pdf"))
    plt.show()

    # scatterplot y predicted against the true value
    fig, ax = plt.subplots(1, 1, figsize=[10, 4])
    ax.scatter(np.array(y), np.array(y_pred), color="b", label="Predicted")
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")
    # plt.legend()
    plt.savefig(os.path.join(figure_path, "Scatter.pdf"))
    plt.show()

    # log the model and parameters using mlflow tracker
    with mlflow.start_run(experiment_id=args.experiment) as run:
        for key, value in vars(parameters).items():
            mlflow.log_param(key, value)

        mlflow.log_param("Time", train_time)

        mlflow.log_metric("MSE", mse)
        mlflow.log_metric("RMSE", rmse)
        mlflow.log_metric("MAE", mae)
        mlflow.log_metric("R2", r2)

        mlflow.log_artifact(os.path.join(figure_path, "Times_prediction.pdf"))
        mlflow.log_artifact(
            os.path.join(figure_path, "Times_prediction_focus.pdf")
        )
        mlflow.log_artifact(os.path.join(figure_path, "loss_plot.pdf"))
        mlflow.log_artifact(os.path.join(figure_path, "Scatter.pdf"))
        mlflow.pytorch.log_model(net, "models")