def _run_training_epoch(sess, train_op, training_data, params):
    """Run ``train_op`` once per mini-batch for one epoch and return the last feed dict.

    Batches are the contiguous index ranges ``[j*batch_size, (j+1)*batch_size)``
    for ``j in range(params['epoch_size'] // params['batch_size'])``; no
    shuffling is performed. Returns ``None`` when that range is empty.
    """
    batch_size = params['batch_size']
    train_dict = None
    for j in range(params['epoch_size'] // batch_size):
        batch_idxs = np.arange(j * batch_size, (j + 1) * batch_size)
        # Rebuilt every batch so that changes to params (e.g. a new
        # 'coefficient_mask') are presumably reflected in the feed — TODO confirm.
        train_dict = create_feed_dictionary(training_data, params, idxs=batch_idxs)
        sess.run(train_op, feed_dict=train_dict)
    return train_dict


def train_network(training_data, val_data, params):
    """Train, refine, and save the network built by ``full_network(params)``.

    Phase 1 runs ``params['max_epochs']`` epochs minimising ``loss``. When
    ``params['sequential_thresholding']`` is set, every
    ``params['threshold_frequency']`` epochs (skipping epoch 0) the mask
    ``|Xi| > params['coefficient_threshold']`` is recomputed.
    Phase 2 runs ``params['refinement_epochs']`` epochs minimising
    ``loss_refinement``. Afterwards the TF checkpoint is saved to
    ``params['data_path'] + params['save_name']`` and ``params`` is pickled
    to the same prefix with a ``'_params.pkl'`` suffix.

    Args:
        training_data: training data passed to ``create_feed_dictionary``.
        val_data: mapping with key ``'u'`` plus ``'du'`` (when
            ``params['model_order'] == 1``) or ``'ddu'`` (otherwise).
        params: hyper-parameter dict. Mutated in place: ``'coefficient_mask'``
            is overwritten whenever sequential thresholding fires.

    Returns:
        ``(i, x_norm, sindy_predict_norm, decoder_loss, sindy_x_loss,
        regularization_loss)``. ``i`` is the last phase-1 epoch index. The
        three losses are evaluated on the validation feed dict.

    NOTE(review): if ``params['max_epochs'] == 0`` the epoch index ``i`` is
    never bound and the final ``return`` raises ``UnboundLocalError``.
    """
    # SET UP NETWORK
    autoencoder_network = full_network(params)
    loss, losses, loss_refinement = define_loss(autoencoder_network, params)
    learning_rate = tf.placeholder(tf.float32, name='learning_rate')
    train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
    train_op_refinement = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss_refinement)
    # Created after the optimizers so their slot variables are in GLOBAL_VARIABLES too.
    saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

    validation_dict = create_feed_dictionary(val_data, params, idxs=None)

    # Mean-square normalisers handed to print_progress and returned to the caller.
    x_norm = np.mean(val_data['u']**2)
    if params['model_order'] == 1:
        sindy_predict_norm = np.mean(val_data['du']**2)
    else:
        sindy_predict_norm = np.mean(val_data['ddu']**2)

    save_path = params['data_path'] + params['save_name']

    print('TRAINING')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(params['max_epochs']):
            train_dict = _run_training_epoch(sess, train_op, training_data, params)

            if params['print_progress'] and (i % params['print_frequency'] == 0):
                print_progress(sess, i, loss, losses, train_dict, validation_dict, x_norm, sindy_predict_norm)

            if params['sequential_thresholding'] and (i % params['threshold_frequency'] == 0) and (i > 0):
                params['coefficient_mask'] = np.abs(sess.run(autoencoder_network['Xi'])) > params['coefficient_threshold']
                # validation_dict was built once up front, so its mask entry is patched by hand.
                validation_dict['coefficient_mask:0'] = params['coefficient_mask']
                print('THRESHOLDING: %d active coefficients' % np.sum(params['coefficient_mask']))

        print('REFINEMENT')
        for i_refinement in range(params['refinement_epochs']):
            train_dict = _run_training_epoch(sess, train_op_refinement, training_data, params)

            if params['print_progress'] and (i_refinement % params['print_frequency'] == 0):
                print_progress(sess, i_refinement, loss_refinement, losses, train_dict, validation_dict, x_norm, sindy_predict_norm)

        saver.save(sess, save_path)
        # Context manager guarantees the pickle is flushed and the handle closed.
        with open(save_path + '_params.pkl', 'wb') as params_file:
            pickle.dump(params, params_file)
        decoder_losses = sess.run((losses['decoder'], losses['sindy_x']), feed_dict=validation_dict)
        regularization_loss = sess.run(losses['sindy_regularization'], feed_dict=validation_dict)

        return i, x_norm, sindy_predict_norm, decoder_losses[0], decoder_losses[1], regularization_loss
data_path = os.getcwd() + '/'
# Invocation: <script> <example_problem> <save_name>
if len(sys.argv) < 3:
    sys.exit('usage: %s <example_problem> <save_name>' % sys.argv[0])
example_problem = sys.argv[1]
save_name = data_path + sys.argv[2]

# Any problem name other than 'lorenz' or 'pendulum' falls through to get_rd_data();
# element [2] of its return value is used as the test set (presumably
# (train, val, test) — TODO confirm).
if example_problem == 'lorenz':
    test_data = get_lorenz_data(100)
elif example_problem == 'pendulum':
    test_data = get_pendulum_data(50)
else:
    test_data = get_rd_data()[2]

# Rebuild the graph from the hyper-parameters saved next to the checkpoint.
# NOTE(review): unpickling can execute arbitrary code — only load params files you created.
with open(save_name + '_params.pkl', 'rb') as params_file:
    params = pickle.load(params_file)
test_dict = create_feed_dictionary(test_data, params)

autoencoder_network = full_network(params)
# Not referenced below; presumably kept so a 'learning_rate:0' tensor exists for
# any entry create_feed_dictionary puts in test_dict — TODO confirm before removing.
learning_rate = tf.placeholder(tf.float32, name='learning_rate')
saver = tf.train.Saver(
    var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

# Fetch every tensor in the network dict, in the dict's key order.
run_tuple = ()
for key in autoencoder_network.keys():
    run_tuple += (autoencoder_network[key], )

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Restoring overwrites the freshly initialised variables with the checkpoint values.
    saver.restore(sess, save_name)
    tf_results = sess.run(run_tuple, feed_dict=test_dict)

results = {}