def training(data_path, train_datasets, model_dir, batch_size, learning_rate, num_steps):
    """
    Train the Scaden ensemble, which consists of three models (m256, m512, m1024).

    The models are trained one after another. Each one gets a freshly reset
    default graph and its own session, and is created with the sub-directory
    ``model_dir + "/<model name>"`` as its model directory.

    :param data_path: passed to ``Scaden.train`` as ``input_path``
    :param train_datasets: whitespace-separated string of dataset names; an
        empty string or None results in an empty list being passed on
    :param model_dir: base directory; one sub-directory per model is used
    :param batch_size: forwarded unchanged to every ``Scaden`` instance
    :param learning_rate: forwarded unchanged to every ``Scaden`` instance
    :param num_steps: forwarded unchanged to every ``Scaden`` instance
    :return: None
    """
    # Convert the dataset string into a list. The truthiness check covers both
    # the empty string and None, so a missing value no longer fails on .split().
    train_datasets = train_datasets.split() if train_datasets else []
    print("Training on: " + str(train_datasets))

    # (model name, hidden layer sizes, dropout rates) for each ensemble member,
    # in training order.
    ensemble = (
        ("m256", M256_HIDDEN_UNITS, M256_DO_RATES),
        ("m512", M512_HIDDEN_UNITS, M512_DO_RATES),
        ("m1024", M1024_HIDDEN_UNITS, M1024_DO_RATES),
    )

    for model_name, hidden_units, do_rates in ensemble:
        print("Training " + model_name.upper() + " Model ...")

        # Start from an empty graph so the previous model's ops and variables
        # do not leak into this one.
        tf.reset_default_graph()
        with tf.Session() as sess:
            model = Scaden(
                sess=sess,
                model_dir=model_dir + "/" + model_name,
                model_name=model_name,
                batch_size=batch_size,
                learning_rate=learning_rate,
                num_steps=num_steps,
            )
            # The architecture is set on the instance before train() is called.
            model.hidden_units = hidden_units
            model.do_rates = do_rates
            model.train(input_path=data_path, train_datasets=train_datasets)

    print("Training finished.")
def prediction(model_dir, data_path, out_name):
    """
    Perform prediction using a trained Scaden ensemble (m256, m512, m1024).

    Each model is restored in a freshly reset default graph and its own
    session, and ``Scaden.predict`` is called on the same input. The three
    results are averaged with equal weight and written to ``out_name`` as a
    tab-separated file.

    :param model_dir: the directory containing the models; each model is
        looked up in ``model_dir + "/<model name>"``
    :param data_path: the path to the gene expression file
    :param out_name: name of the output prediction file
    :return: None
    """
    # (model name, hidden layer sizes, dropout rates) for each ensemble member.
    ensemble = (
        ("m256", M256_HIDDEN_UNITS, M256_DO_RATES),
        ("m512", M512_HIDDEN_UNITS, M512_DO_RATES),
        ("m1024", M1024_HIDDEN_UNITS, M1024_DO_RATES),
    )

    # Running sum of the per-model predictions, added in ensemble order.
    # NOTE(review): predict() presumably returns a pandas DataFrame; this code
    # only relies on it supporting "+", "/" and .to_csv() - confirm.
    summed_preds = None

    for model_name, hidden_units, do_rates in ensemble:
        # Start from an empty graph so the previous model does not interfere.
        tf.reset_default_graph()
        with tf.Session() as sess:
            model = Scaden(
                sess=sess,
                model_dir=model_dir + "/" + model_name,
                model_name=model_name,
            )
            model.hidden_units = hidden_units
            model.do_rates = do_rates

            # Predict ratios; the per-model output name is handed to predict().
            model_preds = model.predict(
                input_path=data_path,
                out_name="cdn_predictions_" + model_name + ".txt",
            )

        if summed_preds is None:
            summed_preds = model_preds
        else:
            summed_preds = summed_preds + model_preds

    # Average predictions over all ensemble members.
    preds = summed_preds / len(ensemble)
    preds.to_csv(out_name, sep="\t")