def main():
    """Parse the command-line options for a run and create the model.

    Recognised options (every one takes a value):

        -m, --mesh_terms     stored in ``all_mesh_terms_path`` (default '')
        -d, --abstract_data  stored in ``abstracts_path`` (default '')
        -n, --model_name     stored in ``model_name`` (default '')
        -h, --hidden_units   converted with ``int``, default 300
        -s, --sample_size    converted with ``int``, default 500000

    The short flags ``-e``, ``-b`` and ``-p`` are accepted by the parser but
    have no handler in this function, so passing one of them ends the
    process with exit status 2. A malformed command line does the same.
    A non-integer value for ``-h``/``-s`` raises ``ValueError`` from ``int``.
    """
    # Defaults used when the corresponding option is not supplied.
    all_mesh_terms_path = ''
    abstracts_path = ''
    model_name = ''
    n_hidden_units = 300
    sample_size = 500000

    try:
        # A trailing '=' marks a long option that takes a value; without it
        # getopt hands back an empty string and leaves the value among the
        # positional arguments.
        opts, _ = getopt.getopt(
            sys.argv[1:],
            'm:d:n:e:h:s:b:p:',
            [
                'mesh_terms=',
                'abstract_data=',
                'model_name=',
                'hidden_units=',
                'sample_size=',
            ],
        )
    except getopt.GetoptError as err:
        sys.stderr.write('error: %s\n' % err)
        sys.exit(2)

    for opt, arg in opts:
        # getopt reports long options with their leading '--', so the
        # comparisons must include it.
        if opt in ('-m', '--mesh_terms'):
            all_mesh_terms_path = arg
        elif opt in ('-d', '--abstract_data'):
            abstracts_path = arg
        elif opt in ('-n', '--model_name'):
            model_name = arg
        elif opt in ('-h', '--hidden_units'):
            n_hidden_units = int(arg)
        elif opt in ('-s', '--sample_size'):
            sample_size = int(arg)
        else:
            # Parsed successfully (-e, -b, -p) but not handled here.
            sys.stderr.write('error: unhandled option %s\n' % opt)
            sys.exit(2)

    # NOTE(review): none of the values parsed above are consumed in this
    # function, and create_model is called without arguments here although
    # other call sites pass three -- confirm this is intended.
    model = ANNMesh.create_model()
def run_train_model(all_mesh_terms_path, abstracts_path, model_name, epochs,
                    n_hidden_units, sample_size, batch_size, mini_batch_size,
                    full_path):
    """Load the data, train the model, save its weights and plot the losses.

    Steps, in order:
      1. ``parse_mesh(all_mesh_terms_path, False)`` produces ``target_dict``.
      2. ``process_save_data`` produces ``X_train`` and ``Y_train``.
      3. A model is created from ``X_train['data'].shape[1]``,
         ``n_hidden_units`` and ``len(target_dict)``; if a file already
         exists at ``model_name`` its weights are loaded first.
      4. ``train`` is run with a ``ModelCheckpoint`` targeting ``model_name``
         and a ``LossHistory`` instance.
      5. The final weights are saved next to ``model_name`` under the file
         name prefixed with ``final_``.
      6. The training and validation loss histories are plotted and written
         to ``model_name + '_loss_plot.png'``.

    Args:
        all_mesh_terms_path: Passed to ``parse_mesh``.
        abstracts_path: Passed to ``process_save_data`` and ``train``.
        model_name: Weights file path; also the stem of the output names.
        epochs: Passed to ``train``.
        n_hidden_units: Second argument to ``ANNMesh.create_model``.
        sample_size: Passed to ``process_save_data`` and ``train``.
        batch_size: Passed to ``train``.
        mini_batch_size: Passed to ``train``.
        full_path: Passed to ``process_save_data``.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    print('..loading data')
    target_dict = parse_mesh(all_mesh_terms_path, False)
    X_train, Y_train = process_save_data(sample_size, target_dict, full_path,
                                         abstracts_path)

    # presumably shape[1] is the per-sample feature count -- TODO confirm
    n_inputs = X_train['data'].shape[1]
    model = ANNMesh.create_model(n_inputs, n_hidden_units, len(target_dict))
    # Start from previously stored weights when the file is already there.
    if os.path.exists(model_name):
        model.load_weights(model_name)

    print('..training')
    checkpointer = ModelCheckpoint(filepath=model_name, verbose=1,
                                   save_best_only=True)
    loss_history = LossHistory()
    model, train_loss_history, valid_loss_history = train(
        epochs, model, sample_size, abstracts_path, checkpointer,
        loss_history, batch_size, target_dict, mini_batch_size,
        X_train, Y_train)
    print(len(train_loss_history), len(valid_loss_history))

    print('..saving model')
    # Prefix only the file name so that a model_name containing a directory
    # ("models/m.h5") is saved as "models/final_m.h5" rather than inside a
    # "final_models" directory. A bare file name gives the same result as
    # plain concatenation.
    model_dir, model_file = os.path.split(model_name)
    model.save_weights(os.path.join(model_dir, "final_" + model_file))

    # Plot the loss curves. Each x-axis is sized from its own history so a
    # history whose length differs from `epochs` does not make plt.plot
    # raise after training has already completed.
    train_axis = numpy.arange(1, len(train_loss_history) + 1)
    valid_axis = numpy.arange(1, len(valid_loss_history) + 1)
    train_loss, = plt.plot(train_axis, train_loss_history, label='Train')
    val_loss, = plt.plot(valid_axis, valid_loss_history, label='Validation')
    plt.legend(handles=[train_loss, val_loss])
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.savefig(model_name + '_loss_plot.png')