# Example #1
def train_network(training_data, val_data, params):
    """Build the autoencoder/SINDy graph and train it in a TF1 session.

    Training runs in two phases: a main phase of ``params['max_epochs']``
    epochs (optionally with sequential thresholding of the SINDy
    coefficients), followed by ``params['refinement_epochs']`` epochs that
    minimize ``loss_refinement`` instead of ``loss``. The trained model and
    the ``params`` dict are saved to
    ``params['data_path'] + params['save_name']``.

    Args:
        training_data: dataset dict consumed by ``create_feed_dictionary``
            in mini-batches (exact schema defined elsewhere in the project).
        val_data: dataset dict used for validation feeds; must contain key
            'u', plus 'du' (model_order == 1) or 'ddu' (otherwise), as
            read below for the normalization constants.
        params: configuration dict; keys read here include 'model_order',
            'max_epochs', 'epoch_size', 'batch_size', 'print_progress',
            'print_frequency', 'sequential_thresholding',
            'threshold_frequency', 'coefficient_threshold',
            'refinement_epochs', 'data_path', 'save_name'. Mutated in
            place: 'coefficient_mask' is overwritten during thresholding.

    Returns:
        Tuple ``(i, x_norm, sindy_predict_norm, decoder_loss, sindy_x_loss,
        regularization_loss)`` where ``i`` is the last main-phase epoch
        index and the losses are evaluated on the validation feed.
    """
    # SET UP NETWORK
    autoencoder_network = full_network(params)
    loss, losses, loss_refinement = define_loss(autoencoder_network, params)
    # Learning rate is fed at run time; create_feed_dictionary presumably
    # supplies 'learning_rate:0' — TODO confirm against that helper.
    learning_rate = tf.placeholder(tf.float32, name='learning_rate')
    train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
    train_op_refinement = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss_refinement)
    saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

    # Single fixed validation feed (idxs=None, i.e. the full validation set).
    validation_dict = create_feed_dictionary(val_data, params, idxs=None)

    # Normalization constants used only for progress reporting.
    x_norm = np.mean(val_data['u']**2)
    if params['model_order'] == 1:
        sindy_predict_norm = np.mean(val_data['du']**2)
    else:
        sindy_predict_norm = np.mean(val_data['ddu']**2)

    print('TRAINING')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(params['max_epochs']):
            # Sequential mini-batches; any remainder when epoch_size is not
            # a multiple of batch_size is silently dropped.
            # NOTE(review): if epoch_size < batch_size the inner loop never
            # runs and train_dict below would be unbound — verify configs.
            for j in range(params['epoch_size']//params['batch_size']):
                batch_idxs = np.arange(j*params['batch_size'], (j+1)*params['batch_size'])
                train_dict = create_feed_dictionary(training_data, params, idxs=batch_idxs)
                sess.run(train_op, feed_dict=train_dict)
            
            if params['print_progress'] and (i % params['print_frequency'] == 0):
                print_progress(sess, i, loss, losses, train_dict, validation_dict, x_norm, sindy_predict_norm)

            # Sequential thresholding: zero out small SINDy coefficients by
            # updating the mask fed through the 'coefficient_mask:0' tensor.
            if params['sequential_thresholding'] and (i % params['threshold_frequency'] == 0) and (i > 0):
                params['coefficient_mask'] = np.abs(sess.run(autoencoder_network['Xi'])) > params['coefficient_threshold']
                validation_dict['coefficient_mask:0'] = params['coefficient_mask']
                print('THRESHOLDING: %d active coefficients' % np.sum(params['coefficient_mask']))

        # Refinement phase: same batching, but minimizes loss_refinement
        # (defined by define_loss; presumably drops the regularization
        # term — TODO confirm there).
        print('REFINEMENT')
        for i_refinement in range(params['refinement_epochs']):
            for j in range(params['epoch_size']//params['batch_size']):
                batch_idxs = np.arange(j*params['batch_size'], (j+1)*params['batch_size'])
                train_dict = create_feed_dictionary(training_data, params, idxs=batch_idxs)
                sess.run(train_op_refinement, feed_dict=train_dict)
            
            if params['print_progress'] and (i_refinement % params['print_frequency'] == 0):
                print_progress(sess, i_refinement, loss_refinement, losses, train_dict, validation_dict, x_norm, sindy_predict_norm)

        # Persist the checkpoint and the (possibly mutated) params dict.
        saver.save(sess, params['data_path'] + params['save_name'])
        pickle.dump(params, open(params['data_path'] + params['save_name'] + '_params.pkl', 'wb'))
        # decoder_losses is a pair: (decoder reconstruction loss, sindy_x loss).
        decoder_losses = sess.run((losses['decoder'], losses['sindy_x']), feed_dict=validation_dict)
        regularization_loss = sess.run(losses['sindy_regularization'], feed_dict=validation_dict)

        return i, x_norm, sindy_predict_norm, decoder_losses[0], decoder_losses[1], regularization_loss
# --- Evaluation script: restore a trained model and run it on test data ---
# Usage (inferred from argv reads below — confirm against project docs):
#   python <script> <example_problem> <save_name>
data_path = os.getcwd() + '/'
example_problem = sys.argv[1]
save_name = data_path + sys.argv[2]

# Select the test set for the requested problem; anything other than
# 'lorenz' or 'pendulum' falls through to reaction-diffusion data
# (get_rd_data()[2] — presumably the test split; verify in that helper).
if example_problem == 'lorenz':
    test_data = get_lorenz_data(100)
elif example_problem == 'pendulum':
    test_data = get_pendulum_data(50)
else:
    test_data = get_rd_data()[2]

# Reload the exact params the model was trained with, so the rebuilt
# graph matches the saved checkpoint.
params = pickle.load(open(save_name + '_params.pkl', 'rb'))
test_dict = create_feed_dictionary(test_data, params)

# Rebuild the graph before creating the Saver so all variables are
# registered in GLOBAL_VARIABLES and can be restored.
autoencoder_network = full_network(params)
learning_rate = tf.placeholder(tf.float32, name='learning_rate')
saver = tf.train.Saver(
    var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

# Fetch every tensor in the network dict in one sess.run call.
run_tuple = ()
for key in autoencoder_network.keys():
    run_tuple += (autoencoder_network[key], )

with tf.Session() as sess:
    # Initialize then restore: restore overwrites the initialized values
    # with the checkpoint's weights.
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_name)

    tf_results = sess.run(run_tuple, feed_dict=test_dict)

results = {}