Exemple #1
0
def run_main(args):
    """Train and evaluate a drug-response predictor on an expression matrix.

    Steps, in order: read the data and label CSVs, drop samples whose label
    for ``args.drug`` equals the missing-value marker, optionally keep only
    the genes selected by ``ut.highly_variable_genes`` and/or reduce with
    PCA, min-max scale, split into train/valid/test, optionally resample
    the training split, optionally pre-train an AE/VAE encoder, train a DNN
    or GCN predictor, then log test metrics (and, for classification, write
    a report CSV plus ROC/PR figures under ``saved/``).

    Args:
        args: Parsed argument namespace. Attributes read here: ``epochs``,
            ``bottleneck``, ``drug``, ``missing_value``, ``data_path``,
            ``label_path``, ``test_size``, ``valid_size``,
            ``var_genes_disp``, ``source_model_path``, ``encoder_path``,
            ``logging_file``, ``batch_size``, ``encoder_h_dims``,
            ``predictor_h_dims``, ``dimreduce``, ``predition``,
            ``sampling``, ``PCA_dim``, ``load_source_model``,
            ``predictor``, ``pretrain``, ``freeze_pretrain`` and, depending
            on the branch, ``VAErepram``, ``GCNreduce_path``,
            ``binarizied``, ``GCNfeature``.

    Returns:
        None. Results are written to the log file and to ``saved/``.

    Raises:
        ValueError: If ``args.predictor`` is neither ``"DNN"`` nor
            ``"GCN"``, or if a DNN predictor is requested with a
            ``dimreduce`` other than ``"AE"``/``"VAE"``.

    Side effects:
        ``sys.stderr`` is redirected to ``<logging_file><timestamp>.err``
        for the remainder of the process and the root logger is configured.
    """
    # Define parameters
    epochs = args.epochs
    dim_au_out = args.bottleneck  # e.g. 8, 16, 32, 64, 128, 256, 512
    select_drug = args.drug
    na = args.missing_value
    data_path = args.data_path
    label_path = args.label_path
    test_size = args.test_size
    valid_size = args.valid_size
    g_disperson = args.var_genes_disp
    model_path = args.source_model_path
    encoder_path = args.encoder_path
    log_path = args.logging_file
    batch_size = args.batch_size
    reduce_model = args.dimreduce
    prediction = args.predition
    sampling = args.sampling
    PCA_dim = args.PCA_dim

    # Hidden layer sizes arrive as comma-separated strings, e.g. "512,256".
    encoder_hdims = [int(h) for h in args.encoder_h_dims.split(",")]
    preditor_hdims = [int(h) for h in args.predictor_h_dims.split(",")]
    load_model = bool(args.load_source_model)

    preditor_path = (model_path + reduce_model + args.predictor + prediction +
                     select_drug + '.pkl')

    # Read data
    data_r = pd.read_csv(data_path, index_col=0)
    label_r = pd.read_csv(label_path, index_col=0)
    label_r = label_r.fillna(na)

    now = time.strftime("%Y-%m-%d-%H-%M-%S")

    ut.save_arguments(args, now)

    # Initialize logging and std out
    out_path = log_path + now + ".err"
    log_path = log_path + now + ".log"

    # The handle is intentionally left open: it replaces sys.stderr for the
    # rest of the process, so closing it here would break later writes.
    out = open(out_path, "w")
    sys.stderr = out

    logging.basicConfig(
        level=logging.INFO,  # minimum level that gets recorded
        filename=log_path,
        # 'a' appends to an existing log; 'w' would overwrite it on each run.
        filemode='a',
        format=
        '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )
    logging.getLogger('matplotlib.font_manager').disabled = True

    logging.info(args)

    # Filter out na values
    selected_idx = label_r.loc[:, select_drug] != na

    if g_disperson is not None:
        hvg, adata = ut.highly_variable_genes(data_r, min_disp=g_disperson)
        # Rename columns if duplication exist
        data_r.columns = adata.var_names
        # Extract hvgs
        data = data_r.loc[selected_idx, hvg]
    else:
        data = data_r.loc[selected_idx, :]

    # Do PCA if PCA_dim!=0
    if PCA_dim != 0:
        data = PCA(n_components=PCA_dim).fit_transform(data)

    # Extract labels
    label = label_r.loc[selected_idx, select_drug]

    # Scaling data
    mmscaler = preprocessing.MinMaxScaler()
    lbscaler = preprocessing.MinMaxScaler()

    data = mmscaler.fit_transform(data)
    label = label.values.reshape(-1, 1)

    if prediction == "regression":
        label = lbscaler.fit_transform(label)
        dim_model_out = 1
    else:
        le = LabelEncoder()
        label = le.fit_transform(label)
        dim_model_out = 2

    logging.info(np.std(data))
    logging.info(np.mean(data))

    # Split training / validation / test sets
    X_train_all, X_test, Y_train_all, Y_test = train_test_split(
        data, label, test_size=test_size, random_state=42)
    X_train, X_valid, Y_train, Y_valid = train_test_split(X_train_all,
                                                          Y_train_all,
                                                          test_size=valid_size,
                                                          random_state=42)
    # sampling method (applied to the training split only)
    if sampling is None:
        X_train, Y_train = sam.nosampling(X_train, Y_train)
        logging.info("nosampling")
    elif sampling == "upsampling":
        X_train, Y_train = sam.upsampling(X_train, Y_train)
        logging.info("upsampling")
    elif sampling == "downsampling":
        X_train, Y_train = sam.downsampling(X_train, Y_train)
        logging.info("downsampling")
    elif sampling == "SMOTE":
        X_train, Y_train = sam.SMOTEsampling(X_train, Y_train)
        logging.info("SMOTE")
    else:
        # Training continues on the unsampled split, as before; the message
        # is raised to warning level so the misconfiguration is visible.
        logging.warning("not a legal sampling method")

    logging.info(data.shape)
    logging.info(label.shape)
    logging.info(X_train.max())
    logging.info(X_train.min())

    # Select the Training device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logging.info(device)
    # torch.cuda.set_device only accepts CUDA devices; calling it with the
    # CPU fallback raises, so guard it.
    if device.type == "cuda":
        torch.cuda.set_device(device)

    # Construct datasets and data loaders
    X_trainTensor = torch.FloatTensor(X_train).to(device)
    X_validTensor = torch.FloatTensor(X_valid).to(device)
    X_testTensor = torch.FloatTensor(X_test).to(device)
    X_allTensor = torch.FloatTensor(data).to(device)

    # MSELoss needs float targets, CrossEntropyLoss needs integer class ids.
    if prediction == "regression":
        Y_trainTensor = torch.FloatTensor(Y_train).to(device)
        Y_validTensor = torch.FloatTensor(Y_valid).to(device)
    else:
        Y_trainTensor = torch.LongTensor(Y_train).to(device)
        Y_validTensor = torch.LongTensor(Y_valid).to(device)

    # Reconstruction datasets: input and target are the same tensor.
    train_dataset = TensorDataset(X_trainTensor, X_trainTensor)
    valid_dataset = TensorDataset(X_validTensor, X_validTensor)

    X_trainDataLoader = DataLoader(dataset=train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
    X_validDataLoader = DataLoader(dataset=valid_dataset,
                                   batch_size=batch_size,
                                   shuffle=True)

    # Supervised datasets: expression -> label.
    trainreducedDataset = TensorDataset(X_trainTensor, Y_trainTensor)
    validreducedDataset = TensorDataset(X_validTensor, Y_validTensor)

    trainDataLoader_p = DataLoader(dataset=trainreducedDataset,
                                   batch_size=batch_size,
                                   shuffle=True)
    validDataLoader_p = DataLoader(dataset=validreducedDataset,
                                   batch_size=batch_size,
                                   shuffle=True)

    dataloaders_train = {'train': trainDataLoader_p, 'val': validDataLoader_p}

    if bool(args.pretrain):
        dataloaders_pretrain = {
            'train': X_trainDataLoader,
            'val': X_validDataLoader
        }
        if reduce_model == "VAE":
            encoder = VAEBase(input_dim=data.shape[1],
                              latent_dim=dim_au_out,
                              h_dims=encoder_hdims)
        else:
            encoder = AEBase(input_dim=data.shape[1],
                             latent_dim=dim_au_out,
                             h_dims=encoder_hdims)

        logging.info(encoder)
        encoder.to(device)

        optimizer_e = optim.Adam(encoder.parameters(), lr=1e-2)
        loss_function_e = nn.MSELoss()
        exp_lr_scheduler_e = lr_scheduler.ReduceLROnPlateau(optimizer_e)

        if reduce_model == "AE":
            encoder, loss_report_en = t.train_AE_model(
                net=encoder,
                data_loaders=dataloaders_pretrain,
                optimizer=optimizer_e,
                loss_function=loss_function_e,
                n_epochs=epochs,
                scheduler=exp_lr_scheduler_e,
                save_path=encoder_path)
        elif reduce_model == "VAE":
            encoder, loss_report_en = t.train_VAE_model(
                net=encoder,
                data_loaders=dataloaders_pretrain,
                optimizer=optimizer_e,
                n_epochs=epochs,
                scheduler=exp_lr_scheduler_e,
                save_path=encoder_path)

        logging.info("Pretrained finished")

    # Train model of predictor

    if args.predictor == "DNN":
        if reduce_model == "AE":
            model = PretrainedPredictor(input_dim=X_train.shape[1],
                                        latent_dim=dim_au_out,
                                        h_dims=encoder_hdims,
                                        hidden_dims_predictor=preditor_hdims,
                                        output_dim=dim_model_out,
                                        pretrained_weights=encoder_path,
                                        freezed=bool(args.freeze_pretrain))
        elif reduce_model == "VAE":
            model = PretrainedVAEPredictor(
                input_dim=X_train.shape[1],
                latent_dim=dim_au_out,
                h_dims=encoder_hdims,
                hidden_dims_predictor=preditor_hdims,
                output_dim=dim_model_out,
                pretrained_weights=encoder_path,
                freezed=bool(args.freeze_pretrain),
                z_reparam=bool(args.VAErepram))
        else:
            # Previously this fell through and failed later with a NameError
            # on ``model``.
            raise ValueError(
                "DNN predictor requires dimreduce 'AE' or 'VAE', got %r" %
                (reduce_model, ))

    elif args.predictor == "GCN":

        if reduce_model == "VAE":
            gcn_encoder = VAEBase(input_dim=data.shape[1],
                                  latent_dim=dim_au_out,
                                  h_dims=encoder_hdims)
        else:
            gcn_encoder = AEBase(input_dim=data.shape[1],
                                 latent_dim=dim_au_out,
                                 h_dims=encoder_hdims)

        gcn_encoder.load_state_dict(torch.load(args.GCNreduce_path))
        gcn_encoder.to(device)

        # Embed each split with the loaded encoder; the embeddings are used
        # to build one graph per split.
        train_embeddings = gcn_encoder.encode(X_trainTensor)
        zOut_tr = train_embeddings.cpu().detach().numpy()
        valid_embeddings = gcn_encoder.encode(X_validTensor)
        zOut_va = valid_embeddings.cpu().detach().numpy()
        test_embeddings = gcn_encoder.encode(X_testTensor)
        zOut_te = test_embeddings.cpu().detach().numpy()

        adj_tr, edgeList_tr = g.generateAdj(
            zOut_tr,
            graphType='KNNgraphStatsSingleThread',
            para='euclidean' + ':' + str('10'),
            adjTag=True)
        adj_va, edgeList_va = g.generateAdj(
            zOut_va,
            graphType='KNNgraphStatsSingleThread',
            para='euclidean' + ':' + str('10'),
            adjTag=True)
        adj_te, edgeList_te = g.generateAdj(
            zOut_te,
            graphType='KNNgraphStatsSingleThread',
            para='euclidean' + ':' + str('10'),
            adjTag=True)

        Adj_trainTensor = preprocess_graph(adj_tr)
        Adj_validTensor = preprocess_graph(adj_va)
        Adj_testTensor = preprocess_graph(adj_te)

        Z_trainTensor = torch.FloatTensor(zOut_tr).to(device)
        Z_validTensor = torch.FloatTensor(zOut_va).to(device)
        Z_testTensor = torch.FloatTensor(zOut_te).to(device)

        if args.binarizied == 0:
            # Replace each embedding value by 1.0 if it exceeds the
            # per-dimension mean of its own split, else 0.0.
            zDiscret_tr = 1.0 * (zOut_tr > np.mean(zOut_tr, axis=0))
            zDiscret_va = 1.0 * (zOut_va > np.mean(zOut_va, axis=0))
            zDiscret_te = 1.0 * (zOut_te > np.mean(zOut_te, axis=0))

            Z_trainTensor = torch.FloatTensor(zDiscret_tr).to(device)
            Z_validTensor = torch.FloatTensor(zDiscret_va).to(device)
            Z_testTensor = torch.FloatTensor(zDiscret_te).to(device)

        ZTensors_train = {'train': Z_trainTensor, 'val': Z_validTensor}
        XTensors_train = {'train': X_trainTensor, 'val': X_validTensor}

        YTensors_train = {'train': Y_trainTensor, 'val': Y_validTensor}
        AdjTensors_train = {'train': Adj_trainTensor, 'val': Adj_validTensor}

        # Node features are either the raw inputs ("x") or the embeddings.
        if args.GCNfeature == "x":
            dim_GCNin = X_allTensor.shape[1]
            GCN_trainTensors = XTensors_train
            GCN_testTensor = X_testTensor
        else:
            dim_GCNin = Z_testTensor.shape[1]
            GCN_trainTensors = ZTensors_train
            GCN_testTensor = Z_testTensor

        model = GCNPredictor(input_feat_dim=dim_GCNin,
                             hidden_dim1=encoder_hdims[0],
                             hidden_dim2=dim_au_out,
                             dropout=0.5,
                             hidden_dims_predictor=preditor_hdims,
                             output_dim=dim_model_out,
                             pretrained_weights=encoder_path,
                             freezed=bool(args.freeze_pretrain))
    else:
        raise ValueError("Unsupported predictor %r; expected 'DNN' or 'GCN'" %
                         (args.predictor, ))

    logging.info(model)
    model.to(device)

    # Define optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-2)

    if prediction == "regression":
        loss_function = nn.MSELoss()
    else:
        loss_function = nn.CrossEntropyLoss()

    exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer)

    if args.predictor == "GCN":
        model, report = t.train_GCNpreditor_model(model=model,
                                                  z=GCN_trainTensors,
                                                  y=YTensors_train,
                                                  adj=AdjTensors_train,
                                                  optimizer=optimizer,
                                                  loss_function=loss_function,
                                                  n_epochs=epochs,
                                                  scheduler=exp_lr_scheduler,
                                                  save_path=preditor_path)

    else:
        model, report = t.train_predictor_model(model,
                                                dataloaders_train,
                                                optimizer,
                                                loss_function,
                                                epochs,
                                                exp_lr_scheduler,
                                                load=load_model,
                                                save_path=preditor_path)

    # Inference on the held-out test split. Switch to eval mode so layers
    # such as dropout are disabled, and skip autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        if args.predictor != 'GCN':
            dl_result = model(X_testTensor).detach().cpu().numpy()
        else:
            dl_result = model(GCN_testTensor,
                              Adj_testTensor).detach().cpu().numpy()

    logging.info('Performances: R/Pearson/Mse/')

    if prediction == "regression":
        # sklearn metrics take (y_true, y_pred); R^2 is not symmetric, so
        # the order matters.
        logging.info(r2_score(Y_test, dl_result))
        logging.info(pearsonr(dl_result.flatten(), Y_test.flatten()))
        logging.info(mean_squared_error(Y_test, dl_result))
    else:
        lb_results = np.argmax(dl_result, axis=1)
        # Score of class index 1, used as the ranking score for AUROC / AP.
        pb_results = dl_result[:, 1]

        report_dict = classification_report(Y_test,
                                            lb_results,
                                            output_dict=True)
        report_df = pd.DataFrame(report_dict).T
        ap_score = average_precision_score(Y_test, pb_results)
        auroc_score = roc_auc_score(Y_test, pb_results)

        report_df['auroc_score'] = auroc_score
        report_df['ap_score'] = ap_score

        report_df.to_csv("saved/logs/" + reduce_model + args.predictor +
                         prediction + select_drug + now + '_report.csv')

        logging.info(classification_report(Y_test, lb_results))
        logging.info(ap_score)
        logging.info(auroc_score)

        # Chance-level baseline for the ROC plot. Kept in its own variable so
        # the trained network is not overwritten.
        naive_model = DummyClassifier(strategy='stratified')
        naive_model.fit(X_train, Y_train)
        yhat = naive_model.predict_proba(X_test)
        naive_probs = yhat[:, 1]

        ut.plot_roc_curve(Y_test,
                          naive_probs,
                          pb_results,
                          title=str(auroc_score),
                          path="saved/figures/" + reduce_model +
                          args.predictor + prediction + select_drug + now +
                          '_roc.pdf')
        ut.plot_pr_curve(Y_test,
                         pb_results,
                         title=ap_score,
                         path="saved/figures/" + reduce_model +
                         args.predictor + prediction + select_drug + now +
                         '_prc.pdf')
def train(args):
    """Build a TCN classifier graph (TensorFlow 1.x) and train it.

    Creates timestamped log and model directories, splits and loads the
    dataset, builds the network (TCN -> ``bottleneck`` dense -> ``logits``
    dense) with a loss that sums cross-entropy, weighted center loss and L2
    regularization, then runs ``args.max_epochs`` epochs. After each epoch
    the test set is evaluated and a checkpoint is saved whenever the mean
    test accuracy improves.

    Args:
        args: Parsed argument namespace. Attributes read here: ``log_dir``,
            ``model_dir``, ``dataset_filename``, ``timesteps``,
            ``num_channels``, ``kernel_size``, ``dropout``,
            ``embedding_size``, ``weight_decay``, ``n_classes``,
            ``center_loss_alpha``, ``center_loss_factor``, ``batch_size``,
            ``lr_epoch``, ``lr_values``, ``pretrained_model``,
            ``max_epochs``, ``display``.

    Returns:
        None.

    Raises:
        FileNotFoundError: If ``args.pretrained_model`` is set but no
            checkpoint state is found there.
    """
    sub_dir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(args.log_dir, sub_dir)
    model_dir = os.path.join(args.model_dir, sub_dir)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    train_logger = utils.Logger('train',
                                file_name=os.path.join(log_dir, 'train.log'),
                                control_log=False)
    test_logger = utils.Logger('test',
                               file_name=os.path.join(log_dir, 'test.log'))

    utils.save_arguments(args, os.path.join(log_dir, 'arguments.txt'))

    # data
    split_dataset(args)
    base_dir = os.path.dirname(args.dataset_filename)
    train_dataset = load_data(os.path.join(base_dir, 'train.txt'))
    test_dataset = load_data(os.path.join(base_dir, 'test.txt'))

    dataset_size = train_dataset.num_examples
    train_logger.info('dataset size: %s' % dataset_size)

    tf.reset_default_graph()
    with tf.Graph().as_default(), tf.device('/gpu:1'):
        tf.set_random_seed(10)
        x = tf.placeholder(tf.float32,
                           shape=[None, args.timesteps, 1],
                           name='input')
        y = tf.placeholder(tf.int64, shape=[None], name='label')
        # Depth must match the width of the logits layer below, which is
        # args.n_classes (the original hard-coded 2).
        one_hot_y = tf.one_hot(y, depth=args.n_classes, dtype=tf.int64)
        is_training = tf.placeholder(tf.bool, name='training')

        tcn = TemporalConvNet(args.num_channels, args.kernel_size,
                              args.dropout)
        # Only the last time step of the TCN output feeds the bottleneck.
        prelogits = tf.layers.dense(
            tcn(x, training=is_training)[:, -1, :],
            args.embedding_size,
            activation=None,
            kernel_initializer=tf.orthogonal_initializer(),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(
                args.weight_decay),
            name='bottleneck')

        logits = tf.layers.dense(
            prelogits,
            args.n_classes,
            activation=None,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(
                args.weight_decay),
            name='logits')
        # Named tensor only; not referenced again inside this function.
        embeddings = tf.nn.l2_normalize(prelogits, axis=1, name='embeddings')

        # accuracy
        tpr_op, fpr_op, g_mean_op, accuracy_op = calc_accuracy(logits, y)

        # loss
        with tf.variable_scope('loss'):
            cross_entropy_op = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                                           labels=one_hot_y),
                name='cross_entropy')
            center_loss_op, centers, centers_update_op = utils.center_loss(
                prelogits, y, args.n_classes, args.center_loss_alpha)
            regularization_loss_op = tf.reduce_sum(tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES),
                                                   name='l2_loss')
            loss_op = (center_loss_op * args.center_loss_factor +
                       cross_entropy_op + regularization_loss_op)

        # optimizer; the control dependency makes every train step also run
        # the center update op.
        with tf.variable_scope('optimizer'), tf.control_dependencies(
            [centers_update_op]):
            global_step = tf.Variable(0, trainable=False, name='global_step')
            boundaries = [
                int(epoch * dataset_size / args.batch_size)
                for epoch in args.lr_epoch
            ]
            learning_rate_op = tf.train.piecewise_constant(
                global_step, boundaries, args.lr_values, name='learning_rate')
            # NOTE(review): the optimizer uses the constant
            # args.lr_values[0]; learning_rate_op is only evaluated for
            # logging, so the logged schedule is not what Adam applies.
            # Confirm whether that is intended before changing it.
            optimizer = tf.train.AdamOptimizer(args.lr_values[0],
                                               name='optimizer')
            train_op = optimizer.minimize(loss_op, global_step)

        # summary
        tf.summary.scalar('cross_entropy', cross_entropy_op)
        tf.summary.scalar('center_loss', center_loss_op)
        tf.summary.scalar('accuracy', accuracy_op)
        summary_op = tf.summary.merge_all()

        saver = tf.train.Saver(max_to_keep=100)
        config = tf.ConfigProto(allow_soft_placement=True,
                                gpu_options=tf.GPUOptions(allow_growth=True))
        with tf.Session(config=config) as sess:
            writer = tf.summary.FileWriter(log_dir, sess.graph)
            try:
                tf.global_variables_initializer().run()

                if args.pretrained_model:
                    ckpt = tf.train.get_checkpoint_state(
                        args.pretrained_model)
                    # get_checkpoint_state returns None when the directory
                    # holds no checkpoint; fail with a clear message instead
                    # of an AttributeError on None.
                    if ckpt is None or not ckpt.model_checkpoint_path:
                        raise FileNotFoundError(
                            'no checkpoint found in %s' %
                            args.pretrained_model)
                    saver.restore(sess, ckpt.model_checkpoint_path)

                steps_per_epoch = int(np.ceil(dataset_size / args.batch_size))
                best_test_accuracy = 0.0
                train_feed_base = {is_training: True}
                try:
                    for epoch in range(1, args.max_epochs + 1):
                        # Fresh progress bar per epoch. The original rebound
                        # the same name, wrapping the previous epoch's tqdm
                        # object in a new one each time.
                        progress = tqdm(range(steps_per_epoch),
                                        desc='Epoch: {:d}'.format(epoch),
                                        ascii=True)
                        for step in progress:
                            feature, label = train_dataset.next_batch(
                                args.batch_size)
                            feature = np.reshape(
                                feature, (args.batch_size, args.timesteps, 1))
                            feed = dict(train_feed_base)
                            feed[x] = feature
                            feed[y] = label
                            if step % args.display == 0:
                                # Every args.display steps also fetch the
                                # metrics and log them.
                                tensor_list = [
                                    train_op, summary_op, global_step,
                                    accuracy_op, tpr_op, fpr_op, g_mean_op,
                                    cross_entropy_op, center_loss_op,
                                    learning_rate_op
                                ]
                                (_, summary, train_step, accuracy, tpr, fpr,
                                 g_mean, cross_entropy, center_loss,
                                 learning_rate) = sess.run(tensor_list,
                                                           feed_dict=feed)
                                train_logger.info(
                                    'Train Step: %d, accuracy: %.3f%%, tpr: %.3f%%, fpr: %.3f%%, g_mean: %.3f%%, cross_entropy: %.4f, center_loss: %.4f, learning_rate: %f'
                                    % (train_step, accuracy * 100, tpr * 100,
                                       fpr * 100, g_mean * 100, cross_entropy,
                                       center_loss, learning_rate))
                            else:
                                _, summary, train_step = sess.run(
                                    [train_op, summary_op, global_step],
                                    feed_dict=feed)
                            writer.add_summary(summary,
                                               global_step=train_step)
                            writer.flush()

                        # evaluate on whole batches of the test set
                        num_batches = (test_dataset.num_examples //
                                       args.batch_size)
                        accuracy_array = np.zeros((num_batches, ), np.float32)
                        tpr_array = np.zeros((num_batches, ), np.float32)
                        fpr_array = np.zeros((num_batches, ), np.float32)
                        g_mean_array = np.zeros((num_batches, ), np.float32)
                        eval_ops = [accuracy_op, tpr_op, fpr_op, g_mean_op]
                        for i in range(num_batches):
                            feature, label = test_dataset.next_batch(
                                args.batch_size)
                            feature = np.reshape(
                                feature, (args.batch_size, args.timesteps, 1))
                            feed_dict = {
                                x: feature,
                                y: label,
                                is_training: False
                            }
                            (accuracy_array[i], tpr_array[i], fpr_array[i],
                             g_mean_array[i]) = sess.run(eval_ops,
                                                         feed_dict=feed_dict)
                        test_accuracy = np.mean(accuracy_array)
                        test_logger.info(
                            'Validation Epoch: %d, train_step: %d, accuracy: %.3f%%, tpr: %.3f%%, fpr: %.3f%%, g_mean: %.3f%%'
                            % (epoch, train_step, test_accuracy * 100,
                               np.mean(tpr_array) * 100,
                               np.mean(fpr_array) * 100,
                               np.mean(g_mean_array) * 100))

                        # Keep a checkpoint only when test accuracy improves.
                        if test_accuracy > best_test_accuracy:
                            best_test_accuracy = test_accuracy
                            saver.save(sess,
                                       os.path.join(model_dir, 'arc_fault'),
                                       global_step=train_step)

                except Exception as e:
                    # Best-effort: log the failure and fall through to
                    # cleanup, as the original did.
                    train_logger.error(e)
            finally:
                # Close the summary writer on every exit path, including
                # KeyboardInterrupt, which ``except Exception`` misses.
                writer.close()
Exemple #3
0
def run_main(args):
    """Adapt a trained source predictor to a target encoder (ADDA or DaNN).

    Configures logging, builds the target encoder (AE or VAE), loads the
    source predictor weights from ``args.source_model_path``, optionally
    loads pre-trained target-encoder weights, and then runs the transfer
    step selected by ``args.transfer``.

    NOTE(review): this function reads ``data``, ``Xsource_train``,
    ``dataloaders_source`` and ``dataloaders_pretrain`` without assigning
    them. They are presumably module-level names, or the data-loading
    section was stripped from this excerpt -- confirm before running.

    Args:
        args: Parsed argument namespace. Attributes read here: ``epochs``,
            ``bottleneck``, ``freeze_pretrain``, ``source_model_path``,
            ``target_model_path``, ``logging_file``, ``source_h_dims``,
            ``pretrain``, ``predition``, ``dimreduce``, ``p_h_dims``,
            ``transfer`` and, depending on the branch, ``VAErepram``,
            ``GAMMA_mmd``, ``mmd_weight``.

    Returns:
        None.

    Raises:
        ValueError: If ``args.dimreduce`` is neither ``"AE"`` nor ``"VAE"``
            (the original failed later with a NameError on ``encoder``).

    Side effects:
        ``sys.stderr`` is redirected to ``<logging_file><timestamp>.err``
        for the remainder of the process and the root logger is configured.
    """
    # ---------------- Loading parameters ----------------
    epochs = args.epochs
    dim_au_out = args.bottleneck  # e.g. 8, 16, 32, 64, 128, 256, 512
    freeze = args.freeze_pretrain
    source_model_path = args.source_model_path
    target_model_path = args.target_model_path
    log_path = args.logging_file
    # Hidden layer sizes arrive as comma-separated strings, e.g. "512,256".
    encoder_hdims = [int(h) for h in args.source_h_dims.split(",")]
    pretrain = args.pretrain
    prediction = args.predition
    reduce_model = args.dimreduce
    predict_hdims = [int(h) for h in args.p_h_dims.split(",")]
    load_model = True

    # Misc
    now = time.strftime("%Y-%m-%d-%H-%M-%S")
    # Initialize logging and std out
    out_path = log_path + now + ".err"
    log_path = log_path + now + ".log"

    # The handle is intentionally left open: it replaces sys.stderr for the
    # rest of the process.
    out = open(out_path, "w")
    sys.stderr = out

    # Logging information
    logging.basicConfig(
        level=logging.INFO,
        filename=log_path,
        filemode='a',
        format=
        '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )
    logging.getLogger('matplotlib.font_manager').disabled = True

    logging.info(args)

    # Save arguments
    ut.save_arguments(args, now)

    # ---------------- Device selection ----------------
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logging.info(device)
    # torch.cuda.set_device only accepts CUDA devices; calling it with the
    # CPU fallback raises, so guard it.
    if device.type == "cuda":
        torch.cuda.set_device(device)

    # ---------------- Model construction ----------------
    # Construct target encoder
    if reduce_model == "AE":
        encoder = AEBase(input_dim=data.shape[1],
                         latent_dim=dim_au_out,
                         h_dims=encoder_hdims)
    elif reduce_model == "VAE":
        encoder = VAEBase(input_dim=data.shape[1],
                          latent_dim=dim_au_out,
                          h_dims=encoder_hdims)
    else:
        raise ValueError("Unsupported dimreduce %r; expected 'AE' or 'VAE'" %
                         (reduce_model, ))

    logging.info("Target encoder structure is: ")
    logging.info(encoder)

    encoder.to(device)
    # Not used further in this excerpt; kept so a re-training step can be
    # plugged back in.
    optimizer_e = optim.Adam(encoder.parameters(), lr=1e-2)
    loss_function_e = nn.MSELoss()
    exp_lr_scheduler_e = lr_scheduler.ReduceLROnPlateau(optimizer_e)

    # Output width of the source predictor: 1 for regression, else 2 classes.
    if prediction == "regression":
        dim_model_out = 1
    else:
        dim_model_out = 2

    # Load source model before transfer
    if reduce_model == "AE":
        source_model = PretrainedPredictor(input_dim=Xsource_train.shape[1],
                                           latent_dim=dim_au_out,
                                           h_dims=encoder_hdims,
                                           hidden_dims_predictor=predict_hdims,
                                           output_dim=dim_model_out,
                                           pretrained_weights=None,
                                           freezed=freeze)
    else:
        source_model = PretrainedVAEPredictor(
            input_dim=Xsource_train.shape[1],
            latent_dim=dim_au_out,
            h_dims=encoder_hdims,
            hidden_dims_predictor=predict_hdims,
            output_dim=dim_model_out,
            pretrained_weights=None,
            freezed=freeze,
            z_reparam=bool(args.VAErepram))
    source_model.load_state_dict(torch.load(source_model_path))
    # The whole source predictor is handed to the transfer routines as the
    # "source encoder".
    source_encoder = source_model
    logging.info("Load pretrained source model from: " + source_model_path)

    source_encoder.to(device)

    # ---------------- Target encoder pre-training ----------------
    # args.pretrain is either '0' (skip) or a path to stored encoder weights.
    if str(pretrain) != '0':
        pretrain = str(pretrain)

        try:
            encoder.load_state_dict(torch.load(pretrain))
        except Exception:
            # ``except Exception`` (not bare) so Ctrl-C / SystemExit still
            # propagate.
            # NOTE(review): no re-training code exists in this excerpt, so
            # the encoder keeps its initial weights here even though the
            # messages below say otherwise -- confirm against the full file.
            logging.warning("Loading failed, procceed to re-train model")
            logging.info("Pretrained finished")
        else:
            logging.info("Load pretrained target encoder from " + pretrain)

    # ---------------- Transfer learning ----------------
    # Using ADDA transfer learning
    if args.transfer == 'ADDA':

        # Set discriminator model
        discriminator = Predictor(input_dim=dim_au_out, output_dim=2)
        discriminator.to(device)
        loss_d = nn.CrossEntropyLoss()
        optimizer_d = optim.Adam(encoder.parameters(), lr=1e-2)
        exp_lr_scheduler_d = lr_scheduler.ReduceLROnPlateau(optimizer_d)

        # Adversarial training
        discriminator, encoder, report_, report2_ = t.train_ADDA_model(
            source_encoder,
            encoder,
            discriminator,
            dataloaders_source,
            dataloaders_pretrain,
            loss_d,
            loss_d,
            # Should here be all optimizer d?
            optimizer_d,
            optimizer_d,
            exp_lr_scheduler_d,
            exp_lr_scheduler_d,
            epochs,
            device,
            target_model_path)

        logging.info("Transfer ADDA finished")

    # DaNN model
    elif args.transfer == 'DaNN':

        # Set predictor loss
        loss_d = nn.CrossEntropyLoss()
        optimizer_d = optim.Adam(encoder.parameters(), lr=1e-2)
        exp_lr_scheduler_d = lr_scheduler.ReduceLROnPlateau(optimizer_d)

        # Set DaNN model
        DaNN_model = DaNN(source_model=source_encoder, target_model=encoder)
        DaNN_model.to(device)

        def loss(x, y, GAMMA=args.GAMMA_mmd):
            """Distribution loss: ``mmd.mmd_loss`` with GAMMA pre-bound."""
            return mmd.mmd_loss(x, y, GAMMA)

        loss_disrtibution = loss

        # Train DaNN model
        DaNN_model, report_ = t.train_DaNN_model(
            DaNN_model,
            dataloaders_source,
            dataloaders_pretrain,
            # Should here be all optimizer d?
            optimizer_d,
            loss_d,
            epochs,
            exp_lr_scheduler_d,
            dist_loss=loss_disrtibution,
            load=load_model,
            weight=args.mmd_weight,
            save_path=target_model_path + "_DaNN.pkl")

        encoder = DaNN_model.target_model
        source_model = DaNN_model.source_model
        logging.info("Transfer DaNN finished")
Exemple #4
0
def run_main(args):
    ################################################# START SECTION OF LOADING PARAMETERS #################################################
    # Read parameters
    epochs = args.epochs
    dim_au_out = args.bottleneck  #8, 16, 32, 64, 128, 256,512
    na = args.missing_value
    data_path = DATA_MAP[args.target_data]
    test_size = args.test_size
    select_drug = args.drug
    freeze = args.freeze_pretrain
    valid_size = args.valid_size
    g_disperson = args.var_genes_disp
    min_n_genes = args.min_n_genes
    max_n_genes = args.max_n_genes
    source_model_path = args.source_model_path
    target_model_path = args.target_model_path
    log_path = args.logging_file
    batch_size = args.batch_size
    encoder_hdims = args.source_h_dims.split(",")
    encoder_hdims = list(map(int, encoder_hdims))
    source_data_path = args.source_data
    pretrain = args.pretrain
    prediction = args.predition
    data_name = args.target_data
    label_path = args.label_path
    reduce_model = args.dimreduce
    predict_hdims = args.p_h_dims.split(",")
    predict_hdims = list(map(int, predict_hdims))
    leiden_res = args.cluster_res
    load_model = bool(args.load_target_model)

    # Misc
    now = time.strftime("%Y-%m-%d-%H-%M-%S")
    # Initialize logging and std out
    out_path = log_path + now + ".err"
    log_path = log_path + now + ".log"

    out = open(out_path, "w")
    sys.stderr = out

    # Logging information
    logging.basicConfig(
        level=logging.INFO,
        filename=log_path,
        filemode='a',
        format=
        '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )
    logging.getLogger('matplotlib.font_manager').disabled = True

    logging.info(args)

    # Save arguments
    args_df = ut.save_arguments(args, now)
    ################################################# END SECTION OF LOADING PARAMETERS #################################################

    ################################################# START SECTION OF SINGLE CELL DATA REPROCESSING #################################################
    # Load data and preprocessing
    adata = pp.read_sc_file(data_path)

    if data_name == 'GSE117872':
        adata = ut.specific_process(adata,
                                    dataname=data_name,
                                    select_origin=args.batch_id)
    elif data_name == 'GSE122843':
        adata = ut.specific_process(adata, dataname=data_name)
    elif data_name == 'GSE110894':
        adata = ut.specific_process(adata, dataname=data_name)
    elif data_name == 'GSE112274':
        adata = ut.specific_process(adata, dataname=data_name)
    elif data_name == 'GSE116237':
        adata = ut.specific_process(adata, dataname=data_name)
    elif data_name == 'GSE108383':
        adata = ut.specific_process(adata, dataname=data_name)
    elif data_name == 'GSE140440':
        adata = ut.specific_process(adata, dataname=data_name)
    elif data_name == 'GSE129730':
        adata = ut.specific_process(adata, dataname=data_name)
    elif data_name == 'GSE149383':
        adata = ut.specific_process(adata, dataname=data_name)
    else:
        adata = adata

    sc.pp.filter_cells(adata, min_genes=200)
    sc.pp.filter_genes(adata, min_cells=3)

    adata = pp.cal_ncount_ngenes(adata)

    # Show statistics after QC
    sc.pl.violin(adata,
                 ['n_genes_by_counts', 'total_counts', 'pct_counts_mt-'],
                 jitter=0.4,
                 multi_panel=True,
                 save=data_name,
                 show=False)
    sc.pl.scatter(adata, x='total_counts', y='pct_counts_mt-', show=False)
    sc.pl.scatter(adata, x='total_counts', y='n_genes_by_counts', show=False)

    if args.remove_genes == 0:
        r_genes = []
    else:
        r_genes = REMOVE_GENES
    #Preprocess data by filtering
    if data_name not in ['GSE112274', 'GSE140440']:
        adata = pp.receipe_my(adata,
                              l_n_genes=min_n_genes,
                              r_n_genes=max_n_genes,
                              filter_mincells=args.min_c,
                              filter_mingenes=args.min_g,
                              normalize=True,
                              log=True,
                              remove_genes=r_genes)
    else:
        adata = pp.receipe_my(adata,
                              l_n_genes=min_n_genes,
                              r_n_genes=max_n_genes,
                              filter_mincells=args.min_c,
                              percent_mito=100,
                              filter_mingenes=args.min_g,
                              normalize=True,
                              log=True,
                              remove_genes=r_genes)

    # Select highly variable genes
    sc.pp.highly_variable_genes(adata,
                                min_disp=g_disperson,
                                max_disp=np.inf,
                                max_mean=6)
    sc.pl.highly_variable_genes(adata, save=data_name, show=False)
    adata.raw = adata
    adata = adata[:, adata.var.highly_variable]

    # Preprocess data if specific processing is required
    data = adata.X
    # PCA
    # Generate neighbor graph
    sc.tl.pca(adata, svd_solver='arpack')
    sc.pp.neighbors(adata, n_neighbors=10)
    # Generate cluster labels
    sc.tl.leiden(adata, resolution=leiden_res)
    sc.tl.umap(adata)
    sc.pl.umap(adata,
               color=['leiden'],
               save=data_name + 'umap' + now,
               show=False)
    adata.obs['leiden_origin'] = adata.obs['leiden']
    adata.obsm['X_umap_origin'] = adata.obsm['X_umap']
    data_c = adata.obs['leiden'].astype("long").to_list()
    ################################################# END SECTION OF SINGLE CELL DATA REPROCESSING #################################################

    ################################################# START SECTION OF LOADING SC DATA TO THE TENSORS #################################################
    # Prepare to normalize and split target data
    mmscaler = preprocessing.MinMaxScaler()

    try:
        data = mmscaler.fit_transform(data)

    except:
        logging.warning("Only one class, no ROC")

        # Process sparse data
        data = data.todense()
        data = mmscaler.fit_transform(data)

    # Split data to train and valid set
    # Along with the leiden conditions for CVAE purposes
    Xtarget_train, Xtarget_valid, Ctarget_train, Ctarget_valid = train_test_split(
        data, data_c, test_size=valid_size, random_state=42)

    # Select the device of gpu
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Assuming that we are on a CUDA machine, this should print a CUDA device:
    logging.info(device)
    try:
        torch.cuda.set_device(device)
    except:
        logging.warning("No GPU detected, will apply cpu to process")

    # Construct datasets and data loaders
    Xtarget_trainTensor = torch.FloatTensor(Xtarget_train).to(device)
    Xtarget_validTensor = torch.FloatTensor(Xtarget_valid).to(device)

    # Use leiden label if CVAE is applied
    Ctarget_trainTensor = torch.LongTensor(Ctarget_train).to(device)
    Ctarget_validTensor = torch.LongTensor(Ctarget_valid).to(device)

    X_allTensor = torch.FloatTensor(data).to(device)
    C_allTensor = torch.LongTensor(data_c).to(device)

    train_dataset = TensorDataset(Xtarget_trainTensor, Ctarget_trainTensor)
    valid_dataset = TensorDataset(Xtarget_validTensor, Ctarget_validTensor)

    Xtarget_trainDataLoader = DataLoader(dataset=train_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)
    Xtarget_validDataLoader = DataLoader(dataset=valid_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)

    dataloaders_pretrain = {
        'train': Xtarget_trainDataLoader,
        'val': Xtarget_validDataLoader
    }
    ################################################# END SECTION OF LOADING SC DATA TO THE TENSORS #################################################

    ################################################# START SECTION OF LOADING BULK DATA  #################################################
    # Read source data
    data_r = pd.read_csv(source_data_path, index_col=0)
    label_r = pd.read_csv(label_path, index_col=0)
    label_r = label_r.fillna(na)

    # Extract labels
    selected_idx = label_r.loc[:, select_drug] != na
    label = label_r.loc[selected_idx, select_drug]

    label = label.values.reshape(-1, 1)

    if prediction == "regression":
        lbscaler = preprocessing.MinMaxScaler()
        label = lbscaler.fit_transform(label)
        dim_model_out = 1
    else:
        le = preprocessing.LabelEncoder()
        label = le.fit_transform(label)
        dim_model_out = 2

    # Process source data
    mmscaler = preprocessing.MinMaxScaler()
    source_data = mmscaler.fit_transform(data_r)

    # Split source data
    Xsource_train_all, Xsource_test, Ysource_train_all, Ysource_test = train_test_split(
        source_data, label, test_size=test_size, random_state=42)
    Xsource_train, Xsource_valid, Ysource_train, Ysource_valid = train_test_split(
        Xsource_train_all,
        Ysource_train_all,
        test_size=valid_size,
        random_state=42)

    # Transform source data
    # Construct datasets and data loaders
    Xsource_trainTensor = torch.FloatTensor(Xsource_train).to(device)
    Xsource_validTensor = torch.FloatTensor(Xsource_valid).to(device)

    if prediction == "regression":
        Ysource_trainTensor = torch.FloatTensor(Ysource_train).to(device)
        Ysource_validTensor = torch.FloatTensor(Ysource_valid).to(device)
    else:
        Ysource_trainTensor = torch.LongTensor(Ysource_train).to(device)
        Ysource_validTensor = torch.LongTensor(Ysource_valid).to(device)

    sourcetrain_dataset = TensorDataset(Xsource_trainTensor,
                                        Ysource_trainTensor)
    sourcevalid_dataset = TensorDataset(Xsource_validTensor,
                                        Ysource_validTensor)

    Xsource_trainDataLoader = DataLoader(dataset=sourcetrain_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)
    Xsource_validDataLoader = DataLoader(dataset=sourcevalid_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)

    dataloaders_source = {
        'train': Xsource_trainDataLoader,
        'val': Xsource_validDataLoader
    }
    ################################################# END SECTION OF LOADING BULK DATA  #################################################

    ################################################# START SECTION OF MODEL CONSTRUCTION  #################################################
    # Construct target encoder
    if reduce_model == "AE":
        encoder = AEBase(input_dim=data.shape[1],
                         latent_dim=dim_au_out,
                         h_dims=encoder_hdims)
        loss_function_e = nn.MSELoss()
    elif reduce_model == "VAE":
        encoder = VAEBase(input_dim=data.shape[1],
                          latent_dim=dim_au_out,
                          h_dims=encoder_hdims)
    elif reduce_model == "CVAE":
        # Number of condition is equal to the number of clusters
        encoder = CVAEBase(input_dim=data.shape[1],
                           n_conditions=len(set(data_c)),
                           latent_dim=dim_au_out,
                           h_dims=encoder_hdims)

    if torch.cuda.is_available():
        encoder.cuda()

    logging.info("Target encoder structure is: ")
    logging.info(encoder)

    encoder.to(device)
    optimizer_e = optim.Adam(encoder.parameters(), lr=1e-2)
    loss_function_e = nn.MSELoss()
    exp_lr_scheduler_e = lr_scheduler.ReduceLROnPlateau(optimizer_e)

    # Load source model before transfer
    if prediction == "regression":
        dim_model_out = 1
    else:
        dim_model_out = 2
    # Load AE model
    if reduce_model == "AE":

        source_model = PretrainedPredictor(input_dim=Xsource_train.shape[1],
                                           latent_dim=dim_au_out,
                                           h_dims=encoder_hdims,
                                           hidden_dims_predictor=predict_hdims,
                                           output_dim=dim_model_out,
                                           pretrained_weights=None,
                                           freezed=freeze)
        source_model.load_state_dict(torch.load(source_model_path))
        source_encoder = source_model
    # Load VAE model
    elif reduce_model in ["VAE", "CVAE"]:
        source_model = PretrainedVAEPredictor(
            input_dim=Xsource_train.shape[1],
            latent_dim=dim_au_out,
            h_dims=encoder_hdims,
            hidden_dims_predictor=predict_hdims,
            output_dim=dim_model_out,
            pretrained_weights=None,
            freezed=freeze,
            z_reparam=bool(args.VAErepram))
        source_model.load_state_dict(torch.load(source_model_path))
        source_encoder = source_model
    logging.info("Load pretrained source model from: " + source_model_path)

    source_encoder.to(device)
    ################################################# END SECTION OF MODEL CONSTRUCTION  #################################################

    ################################################# START SECTION OF SC MODEL PRETRAINING  #################################################
    # Pretrain target encoder
    # Pretrain using the autoencoder if pretrain is not False
    if (str(pretrain) != '0'):
        # Pretrained target encoder if there are not stored files in the harddisk
        train_flag = True
        pretrain = str(pretrain)
        if (os.path.exists(pretrain) == True):
            try:
                encoder.load_state_dict(torch.load(pretrain))
                logging.info("Load pretrained target encoder from " + pretrain)
                train_flag = False

            except:
                logging.warning("Loading failed, procceed to re-train model")

        if train_flag == True:

            if reduce_model == "AE":
                encoder, loss_report_en = t.train_AE_model(
                    net=encoder,
                    data_loaders=dataloaders_pretrain,
                    optimizer=optimizer_e,
                    loss_function=loss_function_e,
                    n_epochs=epochs,
                    scheduler=exp_lr_scheduler_e,
                    save_path=pretrain)
            elif reduce_model == "VAE":
                encoder, loss_report_en = t.train_VAE_model(
                    net=encoder,
                    data_loaders=dataloaders_pretrain,
                    optimizer=optimizer_e,
                    n_epochs=epochs,
                    scheduler=exp_lr_scheduler_e,
                    save_path=pretrain)

            elif reduce_model == "CVAE":
                encoder, loss_report_en = t.train_CVAE_model(
                    net=encoder,
                    data_loaders=dataloaders_pretrain,
                    optimizer=optimizer_e,
                    n_epochs=epochs,
                    scheduler=exp_lr_scheduler_e,
                    save_path=pretrain)
            logging.info("Pretrained finished")

        # Before Transfer learning, we test the performance of using no transfer performance:
        # Use vae result to predict
        if (args.dimreduce != "CVAE"):
            embeddings_pretrain = encoder.encode(X_allTensor)
        else:
            embeddings_pretrain = encoder.encode(X_allTensor, C_allTensor)

        pretrain_prob_prediction = source_model.predict(
            embeddings_pretrain).detach().cpu().numpy()
        adata.obs["sens_preds_pret"] = pretrain_prob_prediction[:, 1]
        adata.obs["sens_label_pret"] = pretrain_prob_prediction.argmax(axis=1)

        # # Use umap result to predict
        ## This section is removed because the dim problem and the performance problem

        # sc.tl.pca(adata,  n_comps=max(50,2*dim_au_out),svd_solver='arpack')
        # sc.tl.umap(adata, n_components=dim_au_out)
        # embeddings_umap = torch.FloatTensor(adata.obsm["X_umap"]).to(device)
        # umap_prob_prediction = source_model.predict(embeddings_umap).detach().cpu().numpy()
        # adata.obs["sens_preds_umap"] = umap_prob_prediction[:,1]
        # adata.obs["sens_label_umap"] = umap_prob_prediction.argmax(axis=1)

        # # Use tsne result to predict
        # #sc.tl.tsne(adata, n_pcs=dim_au_out)

        # X_pca = adata.obsm["X_pca"]

        # # Replace tsne by pac beacause TSNE is very slow
        # X_tsne =  adata.obsm["X_umap"]
        # #X_tsne = TSNE(n_components=dim_au_out,method='exact').fit_transform(X_pca)
        # embeddings_tsne = torch.FloatTensor(X_tsne).to(device)
        # tsne_prob_prediction = source_model.predict(embeddings_tsne).detach().cpu().numpy()
        # adata.obs["sens_preds_tsne"] = tsne_prob_prediction[:,1]
        # adata.obs["sens_label_tsne"] = tsne_prob_prediction.argmax(axis=1)
        # adata.obsm["X_tsne_pret"] = X_tsne

        # Add embeddings to the adata object
        embeddings_pretrain = embeddings_pretrain.detach().cpu().numpy()
        adata.obsm["X_pre"] = embeddings_pretrain
################################################# END SECTION OF SC MODEL PRETRAINING  #################################################

################################################# START SECTION OF TRANSFER LEARNING TRAINING #################################################
# Using ADDA transfer learning
    if args.transfer == 'ADDA':

        # Set discriminator model
        discriminator = Predictor(input_dim=dim_au_out, output_dim=2)
        discriminator.to(device)
        loss_d = nn.CrossEntropyLoss()
        optimizer_d = optim.Adam(encoder.parameters(), lr=1e-2)
        exp_lr_scheduler_d = lr_scheduler.ReduceLROnPlateau(optimizer_d)

        # Adversarial training
        discriminator, encoder, report_, report2_ = t.train_ADDA_model(
            source_encoder,
            encoder,
            discriminator,
            dataloaders_source,
            dataloaders_pretrain,
            loss_d,
            loss_d,
            # Should here be all optimizer d?
            optimizer_d,
            optimizer_d,
            exp_lr_scheduler_d,
            exp_lr_scheduler_d,
            epochs,
            device,
            target_model_path)

        logging.info("Transfer ADDA finished")

    # DaNN model
    elif args.transfer == 'DaNN':

        # Set predictor loss
        loss_d = nn.CrossEntropyLoss()
        optimizer_d = optim.Adam(encoder.parameters(), lr=1e-2)
        exp_lr_scheduler_d = lr_scheduler.ReduceLROnPlateau(optimizer_d)

        # Set DaNN model
        DaNN_model = DaNN(source_model=source_encoder, target_model=encoder)
        DaNN_model.to(device)

        def loss(x, y, GAMMA=args.GAMMA_mmd):
            result = mmd.mmd_loss(x, y, GAMMA)
            return result

        loss_disrtibution = loss

        # Train DaNN model
        DaNN_model, report_ = t.train_DaNN_model(
            DaNN_model,
            dataloaders_source,
            dataloaders_pretrain,
            # Should here be all optimizer d?
            optimizer_d,
            loss_d,
            epochs,
            exp_lr_scheduler_d,
            dist_loss=loss_disrtibution,
            load=load_model,
            weight=args.mmd_weight,
            save_path=target_model_path + "_DaNN.pkl")

        encoder = DaNN_model.target_model
        source_model = DaNN_model.source_model
        logging.info("Transfer DaNN finished")
        if (load_model == False):
            ut.plot_loss(report_[0], path="figures/train_loss_" + now + ".pdf")
            ut.plot_loss(report_[1],
                         path="figures/mmd_loss_" + now + ".pdf",
                         set_ylim=False)

        if (args.dimreduce != 'CVAE'):
            # Attribute test using integrated gradient

            # Generate a target model including encoder and predictor
            target_model = TargetModel(source_model, encoder)

            # Allow require gradients and process label
            X_allTensor.requires_grad_()
            # Run integrated gradient check
            # Return adata and feature integrated gradient

            ytarget_allPred = target_model(X_allTensor).detach().cpu().numpy()
            ytarget_allPred = ytarget_allPred.argmax(axis=1)

            adata, attrp1, senNeu_c0_genes, senNeu_c1_genes = ut.integrated_gradient_differential(
                net=target_model,
                input=X_allTensor,
                clip="positive",
                target=ytarget_allPred,
                adata=adata,
                ig_fc=1,
                save_name=reduce_model + args.predictor + prediction +
                select_drug + "sensNeuron" + now)

            adata, attrn1, resNeu_c0_genes, resNeu_c1_genes = ut.integrated_gradient_differential(
                net=target_model,
                input=X_allTensor,
                clip="negative",
                target=ytarget_allPred,
                adata=adata,
                ig_fc=1,
                save_name=reduce_model + args.predictor + prediction +
                select_drug + "restNeuron" + now)

            sc.pl.heatmap(attrp1,
                          senNeu_c0_genes,
                          groupby='sensitive',
                          cmap='RdBu_r',
                          save=data_name + args.transfer + args.dimreduce +
                          "_seNc0_" + now,
                          show=False)
            sc.pl.heatmap(attrp1,
                          senNeu_c1_genes,
                          groupby='sensitive',
                          cmap='RdBu_r',
                          save=data_name + args.transfer + args.dimreduce +
                          "_seNc1_" + now,
                          show=False)
            sc.pl.heatmap(attrn1,
                          resNeu_c0_genes,
                          groupby='sensitive',
                          cmap='RdBu_r',
                          save=data_name + args.transfer + args.dimreduce +
                          "_reNc0_" + now,
                          show=False)
            sc.pl.heatmap(attrn1,
                          resNeu_c1_genes,
                          groupby='sensitive',
                          cmap='RdBu_r',
                          save=data_name + args.transfer + args.dimreduce +
                          "_reNc1_" + now,
                          show=False)

            # CHI2 Test on predictive features
            SFD = SelectFdr(chi2)
            SFD.fit(adata.raw.X, ytarget_allPred)
            adata.raw.var['chi2_pval'] = SFD.pvalues_
            adata.raw.var['chi2_score'] = SFD.scores_
            df_chi2_genes = adata.raw.var[
                (SFD.pvalues_ < 0.05) & (adata.raw.var.highly_variable == True)
                & (adata.raw.var.n_cells > args.min_c)]
            df_chi2_genes.sort_values(by="chi2_pval",
                                      ascending=True,
                                      inplace=True)
            df_chi2_genes.to_csv("saved/results/chi2_pval_genes" +
                                 args.predictor + prediction + select_drug +
                                 now + '.csv')

        else:
            print()
################################################# END SECTION OF TRANSFER LEARNING TRAINING #################################################

################################################# START SECTION OF PREPROCESSING FEATURES #################################################
# Extract feature embeddings
# Extract prediction probabilities

    if (args.dimreduce != "CVAE"):
        embedding_tensors = encoder.encode(X_allTensor)
    else:
        embedding_tensors = encoder.encode(X_allTensor, C_allTensor)

    prediction_tensors = source_model.predictor(embedding_tensors)
    embeddings = embedding_tensors.detach().cpu().numpy()
    predictions = prediction_tensors.detach().cpu().numpy()

    # Transform prediction probabilities to 0-1 labels
    if (prediction == "regression"):
        adata.obs["sens_preds"] = predictions
    else:
        adata.obs["sens_preds"] = predictions[:, 1]
        adata.obs["sens_label"] = predictions.argmax(axis=1)
        adata.obs["sens_label"] = adata.obs["sens_label"].astype('category')
        adata.obs["rest_preds"] = predictions[:, 0]

    adata.write("saved/adata/before_ann" + data_name + now + ".h5ad")

    ################################################# END SECTION OF PREPROCESSING FEATURES #################################################

    ################################################# START SECTION OF ANALYSIS AND POST PROCESSING #################################################
    # Pipeline of scanpy
    # Add embeddings to the adata package
    adata.obsm["X_Trans"] = embeddings
    #sc.tl.umap(adata)
    sc.pp.neighbors(adata, n_neighbors=10, use_rep="X_Trans")
    # Use t-sne on transfer learning features
    sc.tl.tsne(adata, use_rep="X_Trans")
    # Leiden on the data
    # sc.tl.leiden(adata)
    # Plot tsne
    sc.pl.tsne(adata, save=data_name + now, color=["leiden"], show=False)

    # Differential expression genes
    sc.tl.rank_genes_groups(adata, 'leiden', method='wilcoxon')
    sc.pl.rank_genes_groups(adata,
                            n_genes=args.n_DE_genes,
                            sharey=False,
                            save=data_name + now,
                            show=False)

    # Differential expression genes across 0-1 classes
    sc.tl.rank_genes_groups(adata, 'sens_label', method='wilcoxon')
    adata = ut.de_score(adata, clustername='sens_label')
    # save DE genes between 0-1 class
    for label in [0, 1]:

        try:
            df_degs = get_de_dataframe(adata, label)
            df_degs.to_csv("saved/results/DEGs_class_" + str(label) +
                           args.predictor + prediction + select_drug + now +
                           '.csv')
        except:
            logging.warning("Only one class, no two calsses critical genes")

    # Generate reports of scores
    report_df = args_df

    # Data specific benchmarking

    sens_pb_pret = adata.obs['sens_preds_pret']
    lb_pret = adata.obs['sens_label_pret']

    # sens_pb_umap = adata.obs['sens_preds_umap']
    # lb_umap = adata.obs['sens_label_umap']

    # sens_pb_tsne = adata.obs['sens_preds_tsne']
    # lb_tsne = adata.obs['sens_label_tsne']

    if ('sensitive' in adata.obs.keys()):

        report_df = report_df.T
        Y_test = adata.obs['sensitive']
        sens_pb_results = adata.obs['sens_preds']
        lb_results = adata.obs['sens_label']

        le_sc = LabelEncoder()
        le_sc.fit(['Resistant', 'Sensitive'])
        label_descrbie = le_sc.inverse_transform(Y_test)
        adata.obs['sens_truth'] = label_descrbie

        color_list = ["sens_truth", "sens_label", 'sens_preds']
        color_score_list = [
            "Sensitive_score", "Resistant_score", "1_score", "0_score"
        ]

        sens_score = pearsonr(adata.obs["sens_preds"],
                              adata.obs["Sensitive_score"])[0]
        resistant_score = pearsonr(adata.obs["rest_preds"],
                                   adata.obs["Resistant_score"])[0]

        report_df['prob_sens_pearson'] = sens_score
        report_df['prob_rest_pearson'] = resistant_score

        try:
            cluster_score_sens = pearsonr(adata.obs["1_score"],
                                          adata.obs["Sensitive_score"])[0]
            report_df['sens_pearson'] = cluster_score_sens
        except:
            logging.warning(
                "Prediction score 1 not exist, fill adata with 0 values")
            adata.obs["1_score"] = np.zeros(len(adata))

        try:
            cluster_score_resist = pearsonr(adata.obs["0_score"],
                                            adata.obs["Resistant_score"])[0]
            report_df['rest_pearson'] = cluster_score_resist

        except:
            logging.warning(
                "Prediction score 0 not exist, fill adata with 0 values")
            adata.obs["0_score"] = np.zeros(len(adata))

    #if (data_name in ['GSE110894','GSE117872']):
        ap_score = average_precision_score(Y_test, sens_pb_results)
        ap_pret = average_precision_score(Y_test, sens_pb_pret)
        # ap_umap = average_precision_score(Y_test, sens_pb_umap)
        # ap_tsne = average_precision_score(Y_test, sens_pb_tsne)

        report_dict = classification_report(Y_test,
                                            lb_results,
                                            output_dict=True)
        f1score = report_dict['weighted avg']['f1-score']
        report_df['f1_score'] = f1score
        classification_report_df = pd.DataFrame(report_dict).T
        classification_report_df.to_csv("saved/results/clf_report_" +
                                        reduce_model + args.predictor +
                                        prediction + select_drug + now +
                                        '.csv')

        # report_dict_umap = classification_report(Y_test, lb_umap, output_dict=True)
        # classification_report_umap_df = pd.DataFrame(report_dict_umap).T
        # classification_report_umap_df.to_csv("saved/results/clf_umap_report_" + reduce_model + args.predictor+ prediction + select_drug+now + '.csv')

        report_dict_pret = classification_report(Y_test,
                                                 lb_pret,
                                                 output_dict=True)
        classification_report_pret_df = pd.DataFrame(report_dict_pret).T
        classification_report_pret_df.to_csv("saved/results/clf_pret_report_" +
                                             reduce_model + args.predictor +
                                             prediction + select_drug + now +
                                             '.csv')

        # report_dict_tsne = classification_report(Y_test, lb_tsne, output_dict=True)
        # classification_report_tsne_df = pd.DataFrame(report_dict_tsne).T
        # classification_report_tsne_df.to_csv("saved/results/clf_tsne_report_" + reduce_model + args.predictor+ prediction + select_drug+now + '.csv')

        try:
            auroc_score = roc_auc_score(Y_test, sens_pb_results)

            auroc_pret = average_precision_score(Y_test, sens_pb_pret)
            # auroc_umap = average_precision_score(Y_test, sens_pb_umap)
            # auroc_tsne = average_precision_score(Y_test, sens_pb_tsne)
        except:
            logging.warning("Only one class, no ROC")
            auroc_pret = auroc_umap = auroc_tsne = auroc_score = 0

        report_df['auroc_score'] = auroc_score
        report_df['ap_score'] = ap_score

        report_df['auroc_pret'] = auroc_pret
        report_df['ap_pret'] = ap_pret

        # report_df['auroc_umap'] = auroc_umap
        # report_df['ap_umap'] = ap_umap

        # report_df['auroc_tsne'] = auroc_tsne
        # report_df['ap_tsne'] = ap_tsne

        ap_title = "ap: " + str(Decimal(ap_score).quantize(Decimal('0.0000')))
        auroc_title = "roc: " + str(
            Decimal(auroc_score).quantize(Decimal('0.0000')))
        title_list = ["Ground truth", "Prediction", "Probability"]

    else:

        color_list = ["leiden", "sens_label", 'sens_preds']
        title_list = ['Cluster', "Prediction", "Probability"]
        color_score_list = color_list

    # Simple analysis do neighbors in adata using PCA embeddings
    #sc.pp.neighbors(adata)

    # Run UMAP dimension reduction
    sc.pp.neighbors(adata)
    sc.tl.umap(adata)
    # Run leiden clustering
    # sc.tl.leiden(adata,resolution=leiden_res)
    # Plot umap
    # sc.pl.umap(adata,color=[color_list[0],'sens_label_umap','sens_preds_umap'],save=data_name+args.transfer+args.dimreduce+now,show=False,title=title_list)
    # Plot transfer learning on umap
    sc.pl.umap(adata,
               color=color_list + color_score_list,
               save=data_name + args.transfer + args.dimreduce + "umap_all" +
               now,
               show=False)
    sc.settings.set_figure_params(dpi=100,
                                  frameon=False,
                                  figsize=(4, 3),
                                  facecolor='white')
    sc.pl.umap(adata,
               color=['sensitivity', 'leiden', 'sens_label', 'sens_preds'],
               title=[
                   'Cell sensitivity', 'Cell clusters',
                   'Transfer learning prediction', 'Prediction probability'
               ],
               save=data_name + args.transfer + args.dimreduce + "umap_pred" +
               now,
               show=False,
               ncols=4)

    sc.pl.umap(adata,
               color=color_score_list,
               title=[
                   'Sensitive gene score', 'Resistant gene score',
                   'Sensitive gene score (prediction)',
                   'Resistant gene score (prediction)'
               ],
               save=data_name + args.transfer + args.dimreduce +
               "umap_scores" + now,
               show=False,
               ncols=2)

    # sc.pl.umap(adata,color=['Sample name'],

    #     save=data_name+args.transfer+args.dimreduce+"umap_sm"+now,show=False,ncols=4)
    try:
        sc.pl.umap(adata,
                   color=adata.var.sort_values(
                       "integrated_gradient_sens_class0").head().index,
                   save=data_name + args.transfer + args.dimreduce +
                   "_cgenes0_" + now,
                   show=False)
        sc.pl.umap(adata,
                   color=adata.var.sort_values(
                       "integrated_gradient_sens_class1").head().index,
                   save=data_name + args.transfer + args.dimreduce +
                   "_cgenes1_" + now,
                   show=False)

        # c0_genes = df_11_genes.loc[df_11_genes.pval<0.05].head().index
        # c1_genes = df_00_genes.loc[df_00_genes.pval<0.05].head().index

        # sc.pl.umap(adata,color=c0_genes,neighbors_key="Trans",save=data_name+args.transfer+args.dimreduce+"_cgenes0_TL"+now,show=False)
        # sc.pl.umap(adata,color=c1_genes,neighbors_key="Trans",save=data_name+args.transfer+args.dimreduce+"_cgenes1_TL"+now,show=False)
    except:
        logging.warning("IG results not avaliable")

    # Run embeddings using transfered embeddings
    sc.pp.neighbors(adata, use_rep='X_Trans', key_added="Trans")
    sc.tl.umap(adata, neighbors_key="Trans")
    sc.tl.leiden(adata,
                 neighbors_key="Trans",
                 key_added="leiden_trans",
                 resolution=leiden_res)
    sc.pl.umap(adata,
               color=color_list,
               neighbors_key="Trans",
               save=data_name + args.transfer + args.dimreduce + "_TL" + now,
               show=False,
               title=title_list)
    # Plot cell score on umap
    sc.pl.umap(adata,
               color=color_score_list,
               neighbors_key="Trans",
               save=data_name + args.transfer + args.dimreduce + "_score_TL" +
               now,
               show=False,
               title=color_score_list)

    # This tsne is based on transfer learning feature
    sc.pl.tsne(adata,
               color=color_list,
               neighbors_key="Trans",
               save=data_name + args.transfer + args.dimreduce + "_TL" + now,
               show=False,
               title=title_list)
    # Use the original t-SNE version for visualization

    sc.tl.tsne(adata)
    # This tsne is based on transfer learning feature
    # sc.pl.tsne(adata,color=[color_list[0],'sens_label_tsne','sens_preds_tsne'],save=data_name+args.transfer+args.dimreduce+"_original_tsne"+now,show=False,title=title_list)

    # Plot umap of the pretrained (autoencoder) embeddings
    sc.pp.neighbors(adata, use_rep='X_pre', key_added="Pret")
    sc.tl.umap(adata, neighbors_key="Pret")
    sc.tl.leiden(adata,
                 neighbors_key="Pret",
                 key_added="leiden_Pret",
                 resolution=leiden_res)
    sc.pl.umap(adata,
               color=[color_list[0], 'sens_label_pret', 'sens_preds_pret'],
               neighbors_key="Pret",
               save=data_name + args.transfer + args.dimreduce +
               "_umap_Pretrain_" + now,
               show=False)
    # Ari between two transfer learning embedding and sensitivity label
    ari_score_trans = adjusted_rand_score(adata.obs['leiden_trans'],
                                          adata.obs['sens_label'])
    ari_score = adjusted_rand_score(adata.obs['leiden'],
                                    adata.obs['sens_label'])

    pret_ari_score = adjusted_rand_score(adata.obs['leiden_origin'],
                                         adata.obs['leiden_Pret'])
    transfer_ari_score = adjusted_rand_score(adata.obs['leiden_origin'],
                                             adata.obs['leiden_trans'])

    sc.pl.umap(adata,
               color=['leiden_origin', 'leiden_trans', 'leiden_Pret'],
               save=data_name + args.transfer + args.dimreduce +
               "_comp_Pretrain_" + now,
               show=False)
    #report_df = args_df
    report_df['ari_score'] = ari_score
    report_df['ari_trans_score'] = ari_score_trans

    report_df['ari_pre_umap'] = pret_ari_score
    report_df['ari_trans_umap'] = transfer_ari_score

    # Trajectory of adata
    adata, corelations = trajectory(adata,
                                    root_key='sensitive',
                                    genes_vis=senNeu_c0_genes[:5],
                                    root=1,
                                    now=now,
                                    plot=True)

    gene_cor = {}
    # Trajectory
    for g in np.array(senNeu_c0_genes):
        gene = g
        express_vec = adata[:, gene].X
        corr = pearsonr(
            np.array(express_vec).ravel(),
            np.array(adata.obs["dpt_pseudotime"]))[0]
        gene_cor[gene] = corr

    try:
        for k in corelations.keys():
            report_df['cor_dpt_' + k] = corelations[k][0]
            report_df['cor_pvl_' + k] = corelations[k][1]
    except:
        logging.warning(
            "Some of the coorelation cannot be reterived from the dictional")
################################################# END SECTION OF ANALYSIS AND POST PROCESSING #################################################

################################################# START SECTION OF ANALYSIS FOR BULK DATA #################################################
# bdata = sc.AnnData(data_r)
# bdata.obs = label_r
# bulk_degs={}
# sc.tl.rank_genes_groups(bdata, select_drug, method='wilcoxon')
# bdata = ut.de_score(bdata,select_drug)
# for label in set(label_r.loc[:,select_drug]):
#     try:
#         df_degs = get_de_dataframe(bdata,label)
#         bulk_degs[label] = df_degs.iloc[:50,:].names
#         df_degs.to_csv("saved/results/DEGs_bulk_" +str(label)+ args.predictor+ prediction + select_drug+now + '.csv')
#     except:
#         logging.warning("Only one class, no two calsses critical genes")

# Xsource_allTensor = torch.FloatTensor(data_r.values).to(device)
# Ysource_preTensor = source_model(Xsource_allTensor)
# Ysource_prediction = Ysource_preTensor.detach().cpu().numpy()
# bdata.obs["sens_preds"] = Ysource_prediction[:,1]
# bdata.obs["sens_label"] = Ysource_prediction.argmax(axis=1)
# bdata.obs["sens_label"] = bdata.obs["sens_label"].astype('category')
# bdata.obs["rest_preds"] = Ysource_prediction[:,0]
# sc.tl.score_genes(adata, bulk_degs['sensitive'],score_name="bulk_sens_score" )
# sc.tl.score_genes(adata, bulk_degs['resistant'],score_name="bulk_rest_score" )
# sc.pl.umap(adata,color=['bulk_sens_score','bulk_rest_score'],save=data_name+args.transfer+args.dimreduce+"umap_bg_all"+now,show=False)

# try:
#     bulk_score_sens = pearsonr(adata.obs["1_score"],adata.obs["bulk_sens_score"])[0]
#     report_df['bulk_sens_pearson'] = bulk_score_sens
#     cluster_score_resist = pearsonr(adata.obs["0_score"],adata.obs["bulk_rest_score"])[0]
#     report_df['bulk_rest_pearson'] = cluster_score_resist

# except:
#     logging.warning("Bulk level gene score not exist")

# Save adata
    adata.write("saved/adata/" + data_name + now + ".h5ad")

    # Save report
    report_df = report_df.T
    report_df.to_csv("saved/results/report" + reduce_model + args.predictor +
                     prediction + select_drug + now + '.csv')
def run_main(args):
    ################################################# START SECTION OF LOADING PARAMETERS #################################################
    # Read parameters
    epochs = args.epochs
    dim_au_out = args.bottleneck  #8, 16, 32, 64, 128, 256,512
    na = args.missing_value
    data_path = DATA_MAP[args.target_data]
    test_size = args.test_size
    select_drug = args.drug
    freeze = args.freeze_pretrain
    valid_size = args.valid_size
    g_disperson = args.var_genes_disp
    min_n_genes = args.min_n_genes
    max_n_genes = args.max_n_genes
    source_model_path = args.source_model_path
    target_model_path = args.target_model_path
    log_path = args.logging_file
    batch_size = args.batch_size
    encoder_hdims = args.source_h_dims.split(",")
    encoder_hdims = list(map(int, encoder_hdims))
    source_data_path = args.source_data
    pretrain = args.pretrain
    prediction = args.predition
    data_name = args.target_data
    label_path = args.label_path
    reduce_model = args.dimreduce
    predict_hdims = args.p_h_dims.split(",")
    predict_hdims = list(map(int, predict_hdims))
    leiden_res = args.cluster_res
    load_model = bool(args.load_target_model)

    # Misc
    now = time.strftime("%Y-%m-%d-%H-%M-%S")
    # Initialize logging and std out
    out_path = log_path + now + ".err"
    log_path = log_path + now + ".log"

    out = open(out_path, "w")
    sys.stderr = out

    # Logging information
    logging.basicConfig(
        level=logging.INFO,
        filename=log_path,
        filemode='a',
        format=
        '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    )
    logging.getLogger('matplotlib.font_manager').disabled = True

    logging.info(args)

    # Save arguments
    args_df = ut.save_arguments(args, now)
    AUROC = list()
    AP = list()
    Fone = list()
    ################################################# END SECTION OF LOADING PARAMETERS #################################################
    cycletime = 50
    while cycletime < 101:
        adata = pp.read_sc_file(data_path)
        adata = resample(adata,
                         n_samples=350,
                         random_state=cycletime,
                         replace=False)
        # replace=True: the same element of a may be drawn repeatedly.
        # replace=False: each element of a can be drawn at most once.
        ################################################# START SECTION OF SINGLE CELL DATA REPROCESSING #################################################
        # Load data and preprocessing

        sc.pp.filter_cells(adata, min_genes=200)
        sc.pp.filter_genes(adata, min_cells=3)

        adata = pp.cal_ncount_ngenes(adata)

        # Show statistics after QC
        sc.pl.violin(adata,
                     ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'],
                     jitter=0.4,
                     multi_panel=True,
                     save=data_name,
                     show=False)
        sc.pl.scatter(adata, x='total_counts', y='pct_counts_mt', show=False)
        sc.pl.scatter(adata,
                      x='total_counts',
                      y='n_genes_by_counts',
                      show=False)

        if args.remove_genes == 0:
            r_genes = []
        else:
            r_genes = REMOVE_GENES

        #Preprocess data by filtering
        adata = pp.receipe_my(adata,
                              l_n_genes=min_n_genes,
                              r_n_genes=max_n_genes,
                              filter_mincells=args.min_c,
                              filter_mingenes=args.min_g,
                              normalize=True,
                              log=True,
                              remove_genes=r_genes)

        # Select highly variable genes
        sc.pp.highly_variable_genes(adata,
                                    min_disp=g_disperson,
                                    max_disp=np.inf,
                                    max_mean=6)
        sc.pl.highly_variable_genes(adata, save=data_name, show=False)
        adata.raw = adata
        adata = adata[:, adata.var.highly_variable]

        # Preprocess data if a dataset-specific process is required
        if data_name == 'GSE117872':
            adata = ut.specific_process(adata,
                                        dataname=data_name,
                                        select_origin=args.batch_id)
            data = adata.X
        elif data_name == 'GSE122843':
            adata = ut.specific_process(adata, dataname=data_name)
            data = adata.X
        elif data_name == 'GSE110894':
            adata = ut.specific_process(adata, dataname=data_name)
            data = adata.X
        elif data_name == 'GSE112274':
            adata = ut.specific_process(adata, dataname=data_name)
            data = adata.X
        elif data_name == 'GSE116237':
            adata = ut.specific_process(adata, dataname=data_name)
            data = adata.X
        elif data_name == 'GSE108383':
            adata = ut.specific_process(adata, dataname=data_name)
            data = adata.X
        else:
            data = adata.X

        # PCA
        # Generate neighbor graph
        sc.tl.pca(adata, svd_solver='arpack')
        sc.pp.neighbors(adata, n_neighbors=10)
        # Generate cluster labels
        sc.tl.leiden(adata, resolution=leiden_res)
        sc.tl.umap(adata)
        # sc.pl.umap(adata,color=['leiden'],save=data_name+'umap'+now,show=False)
        adata.obs['leiden_origin'] = adata.obs['leiden']
        adata.obsm['X_umap_origin'] = adata.obsm['X_umap']
        data_c = adata.obs['leiden'].astype("long").to_list()
        ################################################# END SECTION OF SINGLE CELL DATA REPROCESSING #################################################

        ################################################# START SECTION OF LOADING SC DATA TO THE TENSORS #################################################
        # Prepare to normalize and split target data
        mmscaler = preprocessing.MinMaxScaler()

        try:
            data = mmscaler.fit_transform(data)

        except:
            logging.warning("Only one class, no ROC")

            # Process sparse data
            data = data.todense()
            data = mmscaler.fit_transform(data)

        # Split data to train and valid set
        # Along with the leiden conditions for CVAE propose
        Xtarget_train, Xtarget_valid, Ctarget_train, Ctarget_valid = train_test_split(
            data, data_c, test_size=valid_size, random_state=42)

        # Select the device of gpu
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Assuming that we are on a CUDA machine, this should print a CUDA device:
        logging.info(device)
        torch.cuda.set_device(device)

        # Construct datasets and data loaders
        Xtarget_trainTensor = torch.FloatTensor(Xtarget_train).to(device)
        Xtarget_validTensor = torch.FloatTensor(Xtarget_valid).to(device)

        # Use leiden label if CVAE is applied
        Ctarget_trainTensor = torch.LongTensor(Ctarget_train).to(device)
        Ctarget_validTensor = torch.LongTensor(Ctarget_valid).to(device)

        X_allTensor = torch.FloatTensor(data).to(device)
        C_allTensor = torch.LongTensor(data_c).to(device)

        train_dataset = TensorDataset(Xtarget_trainTensor, Ctarget_trainTensor)
        valid_dataset = TensorDataset(Xtarget_validTensor, Ctarget_validTensor)

        Xtarget_trainDataLoader = DataLoader(dataset=train_dataset,
                                             batch_size=batch_size,
                                             shuffle=True)
        Xtarget_validDataLoader = DataLoader(dataset=valid_dataset,
                                             batch_size=batch_size,
                                             shuffle=True)

        dataloaders_pretrain = {
            'train': Xtarget_trainDataLoader,
            'val': Xtarget_validDataLoader
        }
        ################################################# END SECTION OF LOADING SC DATA TO THE TENSORS #################################################

        ################################################# START SECTION OF LOADING BULK DATA  #################################################
        # Read source data
        data_r = pd.read_csv(source_data_path, index_col=0)
        label_r = pd.read_csv(label_path, index_col=0)
        label_r = label_r.fillna(na)

        # Extract labels
        selected_idx = label_r.loc[:, select_drug] != na
        label = label_r.loc[selected_idx, select_drug]

        label = label.values.reshape(-1, 1)

        if prediction == "regression":
            lbscaler = preprocessing.MinMaxScaler()
            label = lbscaler.fit_transform(label)
            dim_model_out = 1
        else:
            le = preprocessing.LabelEncoder()
            label = le.fit_transform(label)
            dim_model_out = 2

        # Process source data
        mmscaler = preprocessing.MinMaxScaler()
        source_data = mmscaler.fit_transform(data_r)

        # Split source data
        Xsource_train_all, Xsource_test, Ysource_train_all, Ysource_test = train_test_split(
            source_data, label, test_size=test_size, random_state=42)
        Xsource_train, Xsource_valid, Ysource_train, Ysource_valid = train_test_split(
            Xsource_train_all,
            Ysource_train_all,
            test_size=valid_size,
            random_state=42)

        # Transform source data
        # Construct datasets and data loaders
        Xsource_trainTensor = torch.FloatTensor(Xsource_train).to(device)
        Xsource_validTensor = torch.FloatTensor(Xsource_valid).to(device)

        if prediction == "regression":
            Ysource_trainTensor = torch.FloatTensor(Ysource_train).to(device)
            Ysource_validTensor = torch.FloatTensor(Ysource_valid).to(device)
        else:
            Ysource_trainTensor = torch.LongTensor(Ysource_train).to(device)
            Ysource_validTensor = torch.LongTensor(Ysource_valid).to(device)

        sourcetrain_dataset = TensorDataset(Xsource_trainTensor,
                                            Ysource_trainTensor)
        sourcevalid_dataset = TensorDataset(Xsource_validTensor,
                                            Ysource_validTensor)

        Xsource_trainDataLoader = DataLoader(dataset=sourcetrain_dataset,
                                             batch_size=batch_size,
                                             shuffle=True)
        Xsource_validDataLoader = DataLoader(dataset=sourcevalid_dataset,
                                             batch_size=batch_size,
                                             shuffle=True)

        dataloaders_source = {
            'train': Xsource_trainDataLoader,
            'val': Xsource_validDataLoader
        }
        ################################################# END SECTION OF LOADING BULK DATA  #################################################

        ################################################# START SECTION OF MODEL CONSTRUCTION  #################################################
        # Construct target encoder
        if reduce_model == "AE":
            encoder = AEBase(input_dim=data.shape[1],
                             latent_dim=dim_au_out,
                             h_dims=encoder_hdims)
            loss_function_e = nn.MSELoss()
        elif reduce_model == "VAE":
            encoder = VAEBase(input_dim=data.shape[1],
                              latent_dim=dim_au_out,
                              h_dims=encoder_hdims)
        elif reduce_model == "CVAE":
            # Number of condition is equal to the number of clusters
            encoder = CVAEBase(input_dim=data.shape[1],
                               n_conditions=len(set(data_c)),
                               latent_dim=dim_au_out,
                               h_dims=encoder_hdims)

        if torch.cuda.is_available():
            encoder.cuda()

        logging.info("Target encoder structure is: ")
        logging.info(encoder)

        encoder.to(device)
        optimizer_e = optim.Adam(encoder.parameters(), lr=1e-2)
        loss_function_e = nn.MSELoss()
        exp_lr_scheduler_e = lr_scheduler.ReduceLROnPlateau(optimizer_e)

        # Load source model before transfer
        if prediction == "regression":
            dim_model_out = 1
        else:
            dim_model_out = 2
        # Load AE model
        if reduce_model == "AE":

            source_model = PretrainedPredictor(
                input_dim=Xsource_train.shape[1],
                latent_dim=dim_au_out,
                h_dims=encoder_hdims,
                hidden_dims_predictor=predict_hdims,
                output_dim=dim_model_out,
                pretrained_weights=None,
                freezed=freeze)
            source_model.load_state_dict(torch.load(source_model_path))
            source_encoder = source_model
        # Load VAE model
        elif reduce_model in ["VAE", "CVAE"]:
            source_model = PretrainedVAEPredictor(
                input_dim=Xsource_train.shape[1],
                latent_dim=dim_au_out,
                h_dims=encoder_hdims,
                hidden_dims_predictor=predict_hdims,
                output_dim=dim_model_out,
                pretrained_weights=None,
                freezed=freeze,
                z_reparam=bool(args.VAErepram))
            source_model.load_state_dict(torch.load(source_model_path))
            source_encoder = source_model
        logging.info("Load pretrained source model from: " + source_model_path)

        source_encoder.to(device)
        ################################################# END SECTION OF MODEL CONSTRUCTION  #################################################

        ################################################# START SECTION OF SC MODEL PRETRAININIG  #################################################
        # Pretrain target encoder
        # Pretrain using the autoencoder unless pretrain is '0'
        if (str(pretrain) != '0'):
            # Pretrained target encoder if there are not stored files in the harddisk
            train_flag = True
            pretrain = str(pretrain)
            if (os.path.exists(pretrain) == True):
                try:
                    encoder.load_state_dict(torch.load(pretrain))
                    logging.info("Load pretrained target encoder from " +
                                 pretrain)
                    train_flag = False

                except:
                    logging.warning(
                        "Loading failed, procceed to re-train model")

            if train_flag == True:

                if reduce_model == "AE":
                    encoder, loss_report_en = t.train_AE_model(
                        net=encoder,
                        data_loaders=dataloaders_pretrain,
                        optimizer=optimizer_e,
                        loss_function=loss_function_e,
                        n_epochs=epochs,
                        scheduler=exp_lr_scheduler_e,
                        save_path=pretrain)
                elif reduce_model == "VAE":
                    encoder, loss_report_en = t.train_VAE_model(
                        net=encoder,
                        data_loaders=dataloaders_pretrain,
                        optimizer=optimizer_e,
                        n_epochs=epochs,
                        scheduler=exp_lr_scheduler_e,
                        save_path=pretrain)

                elif reduce_model == "CVAE":
                    encoder, loss_report_en = t.train_CVAE_model(
                        net=encoder,
                        data_loaders=dataloaders_pretrain,
                        optimizer=optimizer_e,
                        n_epochs=epochs,
                        scheduler=exp_lr_scheduler_e,
                        save_path=pretrain)
                logging.info("Pretrained finished")

            # Before Transfer learning, we test the performance of using no transfer performance:
            # Use vae result to predict
            if (args.dimreduce != "CVAE"):
                embeddings_pretrain = encoder.encode(X_allTensor)
            else:
                embeddings_pretrain = encoder.encode(X_allTensor, C_allTensor)

            pretrain_prob_prediction = source_model.predict(
                embeddings_pretrain).detach().cpu().numpy()
            adata.obs["sens_preds_pret"] = pretrain_prob_prediction[:, 1]
            adata.obs["sens_label_pret"] = pretrain_prob_prediction.argmax(
                axis=1)

            # Add embeddings to the adata object
            embeddings_pretrain = embeddings_pretrain.detach().cpu().numpy()
            adata.obsm["X_pre"] = embeddings_pretrain
    ################################################# END SECTION OF SC MODEL PRETRAININIG  #################################################

    ################################################# START SECTION OF TRANSFER LEARNING TRAINING #################################################
    # Using ADDA transfer learning
        if args.transfer == 'ADDA':

            # Set discriminator model
            discriminator = Predictor(input_dim=dim_au_out, output_dim=2)
            discriminator.to(device)
            loss_d = nn.CrossEntropyLoss()
            optimizer_d = optim.Adam(encoder.parameters(), lr=1e-2)
            exp_lr_scheduler_d = lr_scheduler.ReduceLROnPlateau(optimizer_d)

            # Adversarial training
            discriminator, encoder, report_, report2_ = t.train_ADDA_model(
                source_encoder,
                encoder,
                discriminator,
                dataloaders_source,
                dataloaders_pretrain,
                loss_d,
                loss_d,
                # Should here be all optimizer d?
                optimizer_d,
                optimizer_d,
                exp_lr_scheduler_d,
                exp_lr_scheduler_d,
                epochs,
                device,
                target_model_path)

            logging.info("Transfer ADDA finished")

        # DaNN model
        elif args.transfer == 'DaNN':

            # Set predictor loss
            loss_d = nn.CrossEntropyLoss()
            optimizer_d = optim.Adam(encoder.parameters(), lr=1e-2)
            exp_lr_scheduler_d = lr_scheduler.ReduceLROnPlateau(optimizer_d)

            # Set DaNN model
            DaNN_model = DaNN(source_model=source_encoder,
                              target_model=encoder)
            DaNN_model.to(device)

            def loss(x, y, GAMMA=args.GAMMA_mmd):
                result = mmd.mmd_loss(x, y, GAMMA)
                return result

            loss_disrtibution = loss

            # Train DaNN model
            DaNN_model, report_ = t.train_DaNN_model(
                DaNN_model,
                dataloaders_source,
                dataloaders_pretrain,
                # Should here be all optimizer d?
                optimizer_d,
                loss_d,
                epochs,
                exp_lr_scheduler_d,
                dist_loss=loss_disrtibution,
                load=load_model,
                weight=args.mmd_weight,
                save_path=target_model_path + "_DaNN.pkl")

            encoder = DaNN_model.target_model
            source_model = DaNN_model.source_model
            logging.info("Transfer DaNN finished")
            if (load_model == False):
                ut.plot_loss(report_[0],
                             path="figures/train_loss_" + now + ".pdf")
                ut.plot_loss(report_[1],
                             path="figures/mmd_loss_" + now + ".pdf")

            if (args.dimreduce != 'CVAE'):
                # Attribute test using integrated gradient

                # Generate a target model including encoder and predictor
                target_model = TargetModel(source_model, encoder)

                # Allow require gradients and process label
                Xtarget_validTensor.requires_grad_()

                # Run integrated gradient check
                # Return adata and feature integrated gradient

                ytarget_validPred = target_model(
                    Xtarget_validTensor).detach().cpu().numpy()
                ytarget_validPred = ytarget_validPred.argmax(axis=1)

                adata, attr = ut.integrated_gradient_check(
                    net=target_model,
                    input=Xtarget_validTensor,
                    target=ytarget_validPred,
                    adata=adata,
                    n_genes=args.n_DL_genes,
                    save_name=reduce_model + args.predictor + prediction +
                    select_drug + now)

                adata, attr0 = ut.integrated_gradient_check(
                    net=target_model,
                    input=Xtarget_validTensor,
                    target=ytarget_validPred,
                    target_class=0,
                    adata=adata,
                    n_genes=args.n_DL_genes,
                    save_name=reduce_model + args.predictor + prediction +
                    select_drug + now)

            else:
                print()
    ################################################# END SECTION OF TRANSER LEARNING TRAINING #################################################

    ################################################# START SECTION OF PREPROCESSING FEATURES #################################################
    # Extract feature embeddings
    # Extract prediction probabilities

        if (args.dimreduce != "CVAE"):
            embedding_tensors = encoder.encode(X_allTensor)
        else:
            embedding_tensors = encoder.encode(X_allTensor, C_allTensor)

        prediction_tensors = source_model.predictor(embedding_tensors)
        embeddings = embedding_tensors.detach().cpu().numpy()
        predictions = prediction_tensors.detach().cpu().numpy()

        # Transform prediction probabilities to 0-1 labels
        if (prediction == "regression"):
            adata.obs["sens_preds"] = predictions
        else:
            adata.obs["sens_preds"] = predictions[:, 1]
            adata.obs["sens_label"] = predictions.argmax(axis=1)
            adata.obs["sens_label"] = adata.obs["sens_label"].astype(
                'category')
            adata.obs["rest_preds"] = predictions[:, 0]
    ################################################# END SECTION OF PREPROCESSING FEATURES #################################################

    ################################################# START SECTION OF ANALYSIS AND POST PROCESSING #################################################
    # Pipeline of scanpy
    # Add embeddings to the adata package
        adata.obsm["X_Trans"] = embeddings
        #sc.tl.umap(adata)
        sc.pp.neighbors(adata, n_neighbors=10, use_rep="X_Trans")
        # Use t-sne on transfer learning features
        sc.tl.tsne(adata, use_rep="X_Trans")
        # Leiden on the data
        # sc.tl.leiden(adata)
        # Plot tsne
        # sc.pl.tsne(adata,save=data_name+now,color=["leiden"],show=False)

        # Differential expression genes
        sc.tl.rank_genes_groups(adata, 'leiden', method='wilcoxon')
        # sc.pl.rank_genes_groups(adata, n_genes=args.n_DE_genes, sharey=False,save=data_name+now,show=False)

        # Differential expression genes across 0-1 classes
        sc.tl.rank_genes_groups(adata, 'sens_label', method='wilcoxon')
        adata = ut.de_score(adata, clustername='sens_label')
        # save DE genes between 0-1 class
        # for label in [0,1]:

        #     try:
        #         df_degs = get_de_dataframe(adata,label)
        #         #df_degs.to_csv("saved/results/DEGs_class_" +str(label)+ args.predictor+ prediction + select_drug+now + '.csv')
        #     except:
        #         logging.warning("Only one class, no two calsses critical genes")

        # Generate reports of scores
        report_df = args_df

        # Data specific benchmarking

        # sens_pb_pret = adata.obs['sens_preds_pret']
        # lb_pret = adata.obs['sens_label_pret']

        # sens_pb_umap = adata.obs['sens_preds_umap']
        # lb_umap = adata.obs['sens_label_umap']

        # sens_pb_tsne = adata.obs['sens_preds_tsne']
        # lb_tsne = adata.obs['sens_label_tsne']

        if (data_name == 'GSE117872'):

            label = adata.obs['cluster']
            label_no_ho = label

            if len(label[label != "Sensitive"]) > 0:
                # label[label != "Sensitive"] = 'Resistant'
                label_no_ho[label_no_ho != "Resistant"] = 'Sensitive'
                adata.obs['sens_truth'] = label_no_ho

            le_sc = LabelEncoder()
            le_sc.fit(['Resistant', 'Sensitive'])
            sens_pb_results = adata.obs['sens_preds']
            Y_test = le_sc.transform(label_no_ho)
            lb_results = adata.obs['sens_label']
            color_list = ["sens_truth", "sens_label", 'sens_preds']
            color_score_list = [
                "Sensitive_score", "Resistant_score", "1_score", "0_score"
            ]

            sens_score = pearsonr(adata.obs["sens_preds"],
                                  adata.obs["Sensitive_score"])[0]
            resistant_score = pearsonr(adata.obs["sens_preds"],
                                       adata.obs["Resistant_score"])[0]

            cluster_score_sens = pearsonr(adata.obs["1_score"],
                                          adata.obs["Sensitive_score"])[0]
            cluster_score_resist = pearsonr(adata.obs["0_score"],
                                            adata.obs["Resistant_score"])[0]

            report_df['sens_pearson'] = sens_score
            report_df['resist_pearson'] = resistant_score
            report_df['1_pearson'] = cluster_score_sens
            report_df['0_pearson'] = cluster_score_resist

        elif (data_name == 'GSE110894'):

            report_df = report_df.T
            Y_test = adata.obs['sensitive']
            sens_pb_results = adata.obs['sens_preds']
            lb_results = adata.obs['sens_label']

            le_sc = LabelEncoder()
            le_sc.fit(['Resistant', 'Sensitive'])
            label_descrbie = le_sc.inverse_transform(Y_test)
            adata.obs['sens_truth'] = label_descrbie

            color_list = ["sens_truth", "sens_label", 'sens_preds']
            color_score_list = [
                "sensitive_score", "resistant_score", "1_score", "0_score"
            ]

            sens_score = pearsonr(adata.obs["sens_preds"],
                                  adata.obs["sensitive_score"])[0]
            resistant_score = pearsonr(adata.obs["sens_preds"],
                                       adata.obs["resistant_score"])[0]

            report_df['sens_pearson'] = sens_score
            report_df['resist_pearson'] = resistant_score

            cluster_score_sens = pearsonr(adata.obs["1_score"],
                                          adata.obs["sensitive_score"])[0]
            cluster_score_resist = pearsonr(adata.obs["0_score"],
                                            adata.obs["resistant_score"])[0]

            report_df['1_pearson'] = cluster_score_sens
            report_df['0_pearson'] = cluster_score_resist

        elif (data_name == 'GSE108383'):

            report_df = report_df.T
            Y_test = adata.obs['sensitive']
            sens_pb_results = adata.obs['sens_preds']
            lb_results = adata.obs['sens_label']

            le_sc = LabelEncoder()
            le_sc.fit(['Resistant', 'Sensitive'])
            label_descrbie = le_sc.inverse_transform(Y_test)
            adata.obs['sens_truth'] = label_descrbie

            color_list = ["sens_truth", "sens_label", 'sens_preds']
            color_score_list = [
                "sensitive_score", "resistant_score", "1_score", "0_score"
            ]

            sens_score = pearsonr(adata.obs["sens_preds"],
                                  adata.obs["sensitive_score"])[0]
            resistant_score = pearsonr(adata.obs["sens_preds"],
                                       adata.obs["resistant_score"])[0]

            report_df['sens_pearson'] = sens_score
            report_df['resist_pearson'] = resistant_score

            cluster_score_sens = pearsonr(adata.obs["1_score"],
                                          adata.obs["sensitive_score"])[0]
            cluster_score_resist = pearsonr(adata.obs["0_score"],
                                            adata.obs["resistant_score"])[0]

            report_df['1_pearson'] = cluster_score_sens
            report_df['0_pearson'] = cluster_score_resist

        if (data_name in ['GSE110894', 'GSE117872', 'GSE108383']):
            ap_score = average_precision_score(Y_test, sens_pb_results)

            report_dict = classification_report(Y_test,
                                                lb_results,
                                                output_dict=True)
            classification_report_df = pd.DataFrame(report_dict).T
            #classification_report_df.to_csv("saved/results/clf_report_" + reduce_model + args.predictor+ prediction + select_drug+now + '.csv')
            sum_classification_report_df = pd.DataFrame()
            sum_classification_report_df = sum_classification_report_df.append(
                classification_report_df, ignore_index=False)
            try:
                auroc_score = roc_auc_score(Y_test, sens_pb_results)

            except:
                logging.warning("Only one class, no ROC")
                auroc_pret = auroc_umap = auroc_tsne = auroc_score = 0

            report_df['auroc_score'] = auroc_score
            report_df['ap_score'] = ap_score

            ap_title = "ap: " + str(
                Decimal(ap_score).quantize(Decimal('0.0000')))
            auroc_title = "roc: " + str(
                Decimal(auroc_score).quantize(Decimal('0.0000')))
            title_list = ["Ground truth", "Prediction", "Probability"]

        else:

            color_list = ["leiden", "sens_label", 'sens_preds']
            title_list = ['Cluster', "Prediction", "Probability"]
            color_score_list = color_list

        # Simple analysis do neighbors in adata using PCA embeddings
        #sc.pp.neighbors(adata)

        # Run UMAP dimension reduction
        sc.pp.neighbors(adata)
        sc.tl.umap(adata)
        # Run leiden clustering
        # sc.tl.leiden(adata,resolution=leiden_res)
        # # Plot uamp
        # sc.pl.umap(adata,color=[color_list[0],'sens_label_umap','sens_preds_umap'],save=data_name+args.transfer+args.dimreduce+now,show=False,title=title_list)

        # Run embeddings using transfered embeddings
        sc.pp.neighbors(adata, use_rep='X_Trans', key_added="Trans")
        sc.tl.umap(adata, neighbors_key="Trans")
        sc.tl.leiden(adata,
                     neighbors_key="Trans",
                     key_added="leiden_trans",
                     resolution=leiden_res)
        # sc.pl.umap(adata,color=color_list,neighbors_key="Trans",save=data_name+args.transfer+args.dimreduce+"_TL"+now,show=False,title=title_list)
        # Plot tsne
        # sc.pl.umap(adata,color=color_score_list,neighbors_key="Trans",save=data_name+args.transfer+args.dimreduce+"_score_TL"+now,show=False,title=color_score_list)

        # This tsne is based on transfer learning feature
        # sc.pl.tsne(adata,color=color_list,neighbors_key="Trans",save=data_name+args.transfer+args.dimreduce+"_TL"+now,show=False,title=title_list)
        # Use the original t-SNE version to visualize

        sc.tl.tsne(adata)
        # This tsne is based on transfer learning feature
        # sc.pl.tsne(adata,color=[color_list[0],'sens_label_tsne','sens_preds_tsne'],save=data_name+args.transfer+args.dimreduce+"_original_tsne"+now,show=False,title=title_list)

        # Plot tsne of the pretrained (autoencoder) embeddings
        sc.pp.neighbors(adata, use_rep='X_pre', key_added="Pret")
        sc.tl.umap(adata, neighbors_key="Pret")
        sc.tl.leiden(adata,
                     neighbors_key="Pret",
                     key_added="leiden_Pret",
                     resolution=leiden_res)
        # sc.pl.umap(adata,color=[color_list[0],'sens_label_pret','sens_preds_pret'],neighbors_key="Pret",save=data_name+args.transfer+args.dimreduce+"_umap_Pretrain_"+now,show=False)

        # Ari between two transfer learning embedding and sensitivity label
        ari_score_trans = adjusted_rand_score(adata.obs['leiden_trans'],
                                              adata.obs['sens_label'])
        ari_score = adjusted_rand_score(adata.obs['leiden'],
                                        adata.obs['sens_label'])

        pret_ari_score = adjusted_rand_score(adata.obs['leiden_origin'],
                                             adata.obs['leiden_Pret'])
        transfer_ari_score = adjusted_rand_score(adata.obs['leiden_origin'],
                                                 adata.obs['leiden_trans'])

        # sc.pl.umap(adata,color=['leiden_origin','leiden_trans','leiden_Pret'],save=data_name+args.transfer+args.dimreduce+"_comp_Pretrain_"+now,show=False)
        #report_df = args_df
        report_df['ari_score'] = ari_score
        report_df['ari_trans_score'] = ari_score_trans

        report_df['ari_pre_umap'] = pret_ari_score
        report_df['ari_trans_umap'] = transfer_ari_score

        cluster_ids = set(adata.obs['leiden'])

        # Two class: sens and resistant between clustering label

        # Trajectory of adata
        #adata = trajectory(adata,now=now)

        # Draw PDF
        # sc.pl.draw_graph(adata, color=['leiden', 'dpt_pseudotime'],save=data_name+args.dimreduce+"leiden+trajectory")
        # sc.pl.draw_graph(adata, color=['sens_preds', 'dpt_pseudotime_leiden_trans','leiden_trans'],save=data_name+args.dimreduce+"sens_preds+trajectory")

        # Save adata
        #adata.write("saved/adata/"+data_name+now+".h5ad")

        # Save report
        report_df = report_df.T
        AP.append(ap_score)
        AUROC.append(auroc_score)
        Fone.append(report_dict['weighted avg']['f1-score'])
        cycletime = cycletime + 1

    AUROC2 = pd.DataFrame(AUROC)
    AP2 = pd.DataFrame(AP)
    Fonefinal = pd.DataFrame(Fone)
    AUROC2.to_csv("saved/results/AUROC2report" + reduce_model +
                  args.predictor + prediction + select_drug + now + '.csv')
    ################################################# END SECTION OF ANALYSIS AND POST PROCESSING #################################################
    Fonefinal.to_csv("saved/results/clf_report_Fonefinal" + reduce_model +
                     args.predictor + prediction + select_drug + now + '.csv')
    AP2.to_csv("saved/results/clf_umap_report_AP2" + reduce_model +
               args.predictor + prediction + select_drug + now + '.csv')
Exemple #6
0
def main():
    """Train and evaluate a model over several independent experiments.

    Arguments come from ``parser.parserCIFAR()``. The dataset named by
    ``args.dataset`` ("cifar10" or "mnist") is located on disk and loaded,
    then either:

    * ``args.visualize`` is set: a saved model is loaded from
      ``./saved_models/<dataset>/<model_name>.pt``, moved to CPU, handed to
      the matching visualization routine, and the function returns ``None``;
    * otherwise ``args.experiments`` models are trained for up to
      ``args.epochs`` epochs each, recording the value returned by
      ``test(...)`` after every epoch.

    Returns:
        ``None`` in visualization mode; otherwise ``{'loss': -average_acc}``
        where ``average_acc`` is the mean of the last-epoch column of the
        results matrix.

    Raises:
        ValueError: if no dataset directory is found or the dataset name is
            not one of the implemented ones.
    """
    # Parse the arguments
    args, device, model_name, _paths, kwargs = parser.parserCIFAR()
    print(args)
    print('Name of the model:', model_name)

    # Find the dataset
    dataset = args.dataset
    if Path('/dataset/' + dataset).is_dir():
        print('Using folder /dataset')
        # NOTE(review): the message mentions /dataset but the loaders are
        # handed './' -- confirm this is intended before changing it.
        path_to_dataset = './'
    elif Path('/datasets2/' + dataset).is_dir():
        path_to_dataset = '/datasets2/' + dataset
    elif Path('../data').is_dir():
        print('Looking for the dataset locally')
        path_to_dataset = '../data'
    else:
        raise ValueError("No path to the dataset found")

    save_path = './saved_models/{}/{}.pt'.format(dataset, model_name)
    load_path = './saved_models/{}/{}.pt'.format(dataset, model_name)
    results_path = './results/{}/{}.txt'.format(dataset, model_name)

    if dataset == "cifar10":
        train_loader, test_loader = utils.load_cifar(path_to_dataset, args, kwargs)
    elif dataset == "mnist":
        train_loader, test_loader = utils.load_mnist(path_to_dataset, args, kwargs)
    else:
        raise ValueError("Dataset {} not implemented".format(dataset))

    # Case where there is no training
    if args.visualize:
        model = StandardNet(args) if args.standard else ProductNet(args)
        # The model is visualized on CPU, so map the checkpoint there directly;
        # otherwise a checkpoint saved from a GPU cannot be loaded on a
        # CPU-only machine.
        model.load_state_dict(torch.load(load_path, map_location='cpu'))
        print('Model loaded')
        model = model.to('cpu')
        show = visualize_standard_conv if args.standard else visualize
        show(model, model_name, args)
        return

    n_expe = args.experiments
    # One row per experiment, one column per epoch.
    results = np.zeros((n_expe, args.epochs))
    for expe in range(n_expe):
        model = StandardNet(args) if args.standard else ProductNet(args)
        model = model.to(device)
        optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=1e-5)
        # A fresh stopper per experiment; None disables early stopping.
        early_stopper = None
        if args.early_stopping:
            early_stopper = utils.EarlyStopping(mode='max',
                                                patience=args.early_stopping)

        for epoch in range(1, args.epochs + 1):
            train(args, model, device, train_loader, optimizer, epoch)
            acc = test(model, device, test_loader)
            results[expe, epoch - 1] = acc
            stop = early_stopper.step(acc) if early_stopper is not None else False
            if stop:
                print("Early stopping triggered at iteration", epoch)
                print("Stopping the training.")
                # Fill the epochs that were never run with the last accuracy
                # so the final column stays meaningful for the average below.
                results[expe, epoch:] = acc
                break
        print('Experiment', expe, 'finished.')

    if args.save_results or args.save_model:
        utils.save_arguments(args.dataset, model_name)
    if args.save_model:
        # Only the model of the last experiment is saved.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), save_path)
    if args.save_results:
        Path(results_path).parent.mkdir(parents=True, exist_ok=True)
        np.savetxt(results_path, results)
    average_acc = np.mean(results[:, -1])
    print('All experiments done. Average accuracy:', average_acc)
    # Negated so that a minimizer maximizes accuracy.
    return {'loss': -average_acc}