def restore_ae(data, graph_path, grid, frame=-1):
    """Restore the autoencoder and pass one decoded frame to ``util.contour``.

    Args:
        data: Not read by this function.
        graph_path: Location handed to ``util.find_graph`` to obtain the
            object whose ``restore`` method loads the weights.
        grid: Not read by this function.
        frame: Forwarded to ``fetch_data.get_volume`` as ``time_idx``.
    """
    tf.reset_default_graph()
    default_graph = tf.get_default_graph()
    restorer = util.find_graph(graph_path)
    init_op = tf.global_variables_initializer()

    with tf.Session() as session:
        session.run(init_op)

        # NOTE(review): the checkpoint is looked up in config.path_e while the
        # graph handle comes from graph_path -- confirm both are meant to
        # refer to the same model directory.
        restorer.restore(session,
                         tf.train.latest_checkpoint(config.path_e))

        feed = fetch_data.get_volume(config.benchmark_data,
                                     time_idx=frame,
                                     batch_size=1,
                                     scaling_factor=1)
        decoder_output = default_graph.get_tensor_by_name('Decoder/decoder:0')
        decoded = session.run(decoder_output, feed_dict=feed)

        # The decoder output is handed over together with the fed velocity.
        util.contour(decoded, feed['velocity:0'])
def train_integrator():
    """Train the integrator network on encodings from a restored autoencoder.

    Imports the first ``*.ckpt.meta`` graph found in ``config.path_e``,
    restores the latest checkpoint from that directory, and then, for
    ``config.training_runs`` iterations, encodes a fetched sequence with the
    autoencoder tensors and runs one integrator training step on the result.
    Summaries go to a time-stamped sub-directory of ``config.tensor_board``
    and checkpoints to ``config.path_i``.

    Raises:
        RuntimeError: If ``config.train_integrator_network`` is false, or if
            ``config.path_e`` holds no ``*.ckpt.meta`` file.
        SystemExit: If a meta graph file is found but cannot be imported.
    """
    tf.reset_default_graph()
    graph = tf.get_default_graph()

    # Without the network there is nothing to train; fail here instead of
    # with an AttributeError on a placeholder value at the first step.
    if not config.train_integrator_network:
        raise RuntimeError(
            'config.train_integrator_network is disabled; '
            'no integrator network to train')

    if config.conv:
        int_net = nn.Convo_IntegratorNetwork(
            config.data_size,
            param_state_size=config.field_state,
            sdf_state_size=config.sdf_state)
    else:
        int_net = nn.IntegratorNetwork(
            param_state_size=config.param_state_size,
            sdf_state_size=config.sdf_state)

    # Created before the autoencoder meta graph is imported; the imported
    # variables get their values from the checkpoint restore further down.
    init = tf.global_variables_initializer()

    graph_handle = None
    for file_name in os.listdir(config.path_e):
        print(file_name)
        if file_name.endswith('.ckpt.meta'):
            try:
                graph_handle = tf.train.import_meta_graph(
                    os.path.join(config.path_e, file_name))
            except IOError:
                print('Cant import graph')
                # Non-zero status: this is a failure, not a clean shutdown.
                raise SystemExit(1)
            print(graph)
            break

    if graph_handle is None:
        raise RuntimeError(
            'No .ckpt.meta file found in %s' % config.path_e)

    with tf.Session() as sess:
        sess.run(init)
        sub_dir = os.path.join(
            config.tensor_board,
            'integrator_' + time.strftime("%Y%m%d-%H%M%S"))
        os.mkdir(sub_dir)
        writer = tf.summary.FileWriter(sub_dir)
        try:
            writer.add_graph(sess.graph)
            saver = tf.train.Saver(tf.global_variables())

            graph_handle.restore(
                sess, tf.train.latest_checkpoint(config.path_e))
            print('Model restored')

            full_encoding = graph.get_tensor_by_name(
                "Latent_State/full_encoding:0")
            encoded_sdf = graph.get_tensor_by_name(
                "Boundary_conditions/encoded_sdf:0")

            for i in range(config.training_runs):
                input_sequence, input_0 = fetch_data.get_volume(
                    config.data_path,
                    1,
                    sequential=True,
                    sequence_length=config.sequence_length)

                # Autoencoder pass: encoding of the start frame plus the
                # per-step SDF and label encodings of the whole sequence.
                start_encoding = sess.run(full_encoding, input_0)
                sdf_encodings, label_encodings = sess.run(
                    [encoded_sdf, full_encoding], input_sequence)

                integrator_feed_dict = {
                    'label_encodings:0': label_encodings,
                    'sdf_encodings:0': sdf_encodings,
                    'start_encoding:0': start_encoding,
                    "sequence_length:0": config.sequence_length
                }

                # The zero check must come first, otherwise a frequency of 0
                # raises ZeroDivisionError in the modulo.
                write_summary = (config.f_tensorboard != 0
                                 and i % config.f_tensorboard == 0
                                 and os.path.isdir(config.tensor_board))
                if write_summary:
                    _, int_loss, merged_int = sess.run(
                        [int_net.train_int, int_net.loss_int,
                         int_net.merged_int],
                        integrator_feed_dict)
                    writer.add_summary(merged_int, i)
                else:
                    _, int_loss = sess.run(
                        [int_net.train_int, int_net.loss_int],
                        integrator_feed_dict)

                if i % 10 == 0:
                    print('Training Run', i, 'Learning Rate', config.lr_max,
                          '// Encoder Loss:', -1, '// Integrator Loss',
                          int_loss)

                should_save = (config.save_freq
                               and os.path.isdir(config.path_i)
                               and config.meta_graphs
                               and i % config.save_freq == 0
                               and config.save_freq > 2)
                if should_save:
                    saver.save(
                        sess,
                        os.path.join(config.path_i,
                                     "trained_integrator_model.ckpt"))
                    print('Saving graph')
        finally:
            # Flush pending summary events even if training aborts.
            writer.close()
def train_network():
    """Train the autoencoder network and periodically log, save and sample it.

    Builds ``nn.NetWork``, optionally restores the latest checkpoint from
    ``config.path_e`` (when ``util.find_graph`` returns a truthy handle), and
    runs ``config.training_runs`` training steps. Along the way it writes
    summaries under ``config.tensor_board``, saves checkpoints to
    ``config.path_e``, and every 500 steps evaluates a benchmark sample,
    stores the decoded field under ``config.test_field_path`` and records a
    full execution trace.
    """
    tf.reset_default_graph()
    net = nn.NetWork(config.data_size,
                     param_state_size=config.param_state_size)
    init = tf.global_variables_initializer()
    # Scaling is fixed to 1; the lookup through
    # fetch_data.get_scaling_factor(config.data_path) is disabled.
    SCF = 1

    with tf.Session() as sess:
        sess.run(init)
        sub_dir = os.path.join(
            config.tensor_board,
            'TS' + util.get_name_ext() + time.strftime("%Y%m%d-%H%M%S"))
        sub_dir_test = os.path.join(
            config.tensor_board,
            'TS_Test' + util.get_name_ext() + time.strftime("%Y%m%d-%H%M%S"))
        os.mkdir(sub_dir)
        writer = tf.summary.FileWriter(sub_dir)
        writer_test = tf.summary.FileWriter(sub_dir_test)
        try:
            writer.add_graph(sess.graph)
            saver = tf.train.Saver(tf.global_variables())
            graph_handle = util.find_graph(config.path_e)
            store_integrator_loss = -1
            if graph_handle:
                graph_handle.restore(
                    sess, tf.train.latest_checkpoint(config.path_e))

            for i in range(config.training_runs):
                inputs = fetch_data.get_volume(config.data_path,
                                               batch_size=config.batch_size,
                                               scaling_factor=SCF)
                inputs['Train/step:0'] = i

                # Zero check first: a frequency of 0 must disable summaries,
                # not raise ZeroDivisionError in the modulo.
                write_summary = (config.f_tensorboard != 0
                                 and i % config.f_tensorboard == 0
                                 and os.path.isdir(config.tensor_board))
                if write_summary:
                    loss, lr, merged, _ = sess.run(
                        [net.loss, net.lr, net.merged, net.train], inputs)
                    writer.add_summary(merged, i)
                else:
                    loss, lr, _ = sess.run([net.loss, net.lr, net.train],
                                           inputs)

                # A save frequency of 0 disables checkpointing, matching the
                # guard used by train_integrator.
                if (config.save_freq and os.path.isdir(config.path_e)
                        and i % config.save_freq == 0):
                    saver.save(
                        sess,
                        os.path.join(
                            config.path_e,
                            util.get_name_ext() + "_trained_model.ckpt"))
                    print('Saving graph')

                if i % 500 == 0 and os.path.isdir(config.tensor_board):
                    inputs_ci = fetch_data.get_volume(config.benchmark_data,
                                                      1,
                                                      scaling_factor=SCF)
                    if inputs_ci:
                        util.create_gif_encoder(config.benchmark_data,
                                                sess,
                                                net,
                                                i=i,
                                                save_frequency=5000,
                                                SCF=SCF)
                    else:
                        # No benchmark sample: fall back to training data.
                        inputs_ci = fetch_data.get_volume(config.data_path,
                                                          1,
                                                          scaling_factor=SCF)
                        util.create_gif_encoder(config.data_path,
                                                sess,
                                                net,
                                                i=i,
                                                save_frequency=5000,
                                                SCF=SCF)
                    inputs_ci['Train/step:0'] = i

                    # NOTE(review): this rebinds `loss`, so the progress line
                    # printed for this step shows the evaluation loss rather
                    # than the training loss -- confirm that is intended.
                    loss, merged = sess.run([net.loss, net.merged],
                                            inputs_ci)
                    writer_test.add_summary(merged, i)

                    test_field = sess.run(net.y, inputs_ci)
                    np.save(
                        os.path.join(
                            config.test_field_path,
                            'train_field' + util.get_name_ext() +
                            time.strftime("%Y%m%d-%H%M")), test_field)

                    # Record execution stats. This performs one additional
                    # training step on the current batch with full tracing.
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    _ = sess.run([net.train],
                                 feed_dict=inputs,
                                 options=run_options,
                                 run_metadata=run_metadata)
                    writer.add_run_metadata(run_metadata, 'step%d' % i)

                if i % 10 == 0:
                    print('Training Run', i, 'Learning Rate', lr,
                          '// Encoder Loss:', loss, '// Integrator Loss',
                          store_integrator_loss)
        finally:
            # Flush pending summary events even if training aborts.
            writer.close()
            writer_test.close()
def create_gif_integrator(sess,
                          net,
                          autoencoder_graph,
                          roll_out,
                          i=0,
                          gif_length=500,
                          save_frequency=5000,
                          SCF=1,
                          restore=False):
    """Roll the integrator forward frame by frame and save the result as a gif.

    Nothing happens unless ``i`` is a multiple of ``save_frequency`` and
    greater than 1. Frames are read from ``config.alt_dir`` when that
    directory exists, otherwise from ``config.data_path``. The first frame is
    encoded with the autoencoder; every later step feeds the integrator's
    previous output back in together with the encoded SDF of the next frame.
    Each decoded field contributes one image: the norm over the last axis of
    the slice at index 16 along axis 2.

    Args:
        sess: Session used for every ``run`` call.
        net: Integrator object exposing ``full_encoding``.
        autoencoder_graph: Graph providing the encoder/decoder tensors by name.
        roll_out: Not used by this function.
        i: Current training step; also part of the output file name.
        gif_length: Maximum number of frames to generate.
        save_frequency: Step interval at which a gif is produced.
        SCF: Forwarded to ``fetch_data.get_volume`` as ``scaling_factor``.
        restore: Not used by this function.
    """
    if i % save_frequency or i <= 1:
        return

    search_dir = config.data_path
    if os.path.isdir(config.alt_dir):
        search_dir = config.alt_dir

    try:
        frames = []
        sdf = autoencoder_graph.get_tensor_by_name(
            "Boundary_conditions/encoded_sdf:0")
        full_encoding = autoencoder_graph.get_tensor_by_name(
            "Latent_State/full_encoding:0")
        reconstructed_v = autoencoder_graph.get_tensor_by_name(
            "Decoder/decoder:0")
        next_encoding = None

        for frame_idx in range(gif_length):
            # For frame_idx > 0 this fetch only serves as a range check; the
            # encoding itself is carried over from the previous step.
            try:
                input_i = fetch_data.get_volume(search_dir,
                                                batch_size=1,
                                                time_idx=frame_idx,
                                                scaling_factor=SCF)
            except IndexError:
                print(
                    'index out of range -> creating gif with stashed frames')
                break
            # An empty result means no more data, as in create_gif_encoder.
            if not input_i:
                break

            if frame_idx == 0:
                full_enc = sess.run(full_encoding, feed_dict=input_i)
            else:
                full_enc = next_encoding

            try:
                input_next = fetch_data.get_volume(search_dir,
                                                   batch_size=1,
                                                   time_idx=(1 + frame_idx),
                                                   scaling_factor=SCF)
            except IndexError:
                print(
                    'index out of range -> creating gif with stashed frames')
                break
            if not input_next:
                break

            next_encoded_sdf = sess.run(sdf, feed_dict=input_next)
            integrator_feed_dict = {
                'parameter_encodings:0': next_encoded_sdf,
                'start_encoding:0': [full_enc],
                "sequence_length:0": 1
            }
            next_encoding = sess.run(net.full_encoding,
                                     feed_dict=integrator_feed_dict)
            v_next = sess.run(reconstructed_v,
                              feed_dict={
                                  'sdf:0': input_next['sdf:0'],
                                  'Latent_State/full_encoding:0':
                                  np.squeeze(next_encoding, axis=1)
                              })
            # NOTE(review): slice index 16 is hard-coded here, whereas
            # create_gif_encoder uses int(config.data_size / 2) -- confirm
            # whether the two should agree.
            image = np.linalg.norm(np.squeeze(v_next[0, :, 16, :, :]),
                                   axis=2)
            frames.append(image)

        # Reducing over an empty list raises ValueError, which the OSError
        # handler below would not catch.
        if not frames:
            print('No frames available -> gif not written')
            return

        path = os.path.join(config.gif_path,
                            'integrator_vel_field_' + str(i) + '.gif')
        stack = np.asarray(frames)
        lowest = np.amin(stack)
        span = np.amax(stack) - lowest
        if span == 0:
            # Constant movie: map everything to 0 instead of dividing by 0.
            span = 1
        movie = np.uint8(255 * (stack - lowest) / span)
        imageio.mimwrite(path, movie)
        print('gif saved at:', path)
    except OSError:
        print(
            'No valid .npy test file in output dir / alternative dir not found'
        )
def create_gif_encoder(path,
                       sess,
                       net,
                       i=0,
                       gif_length=2000,
                       save_frequency=5000,
                       SCF=1,
                       restore=False):
    """Render autoencoder reconstructions of a data set into two gifs.

    Nothing happens unless ``i`` is a multiple of ``save_frequency``. For each
    frame read from ``path`` three images are placed side by side: the fed
    velocity, the reconstruction, and the reconstruction obtained with the
    velocity input zeroed. Every image is the norm over the last axis of the
    slice at index ``int(config.data_size / 2)`` along axis 2. A second gif
    holds the per-channel slices of the ``diff`` tensor, each scaled to uint8.
    Both files are written to ``config.gif_path``.

    Args:
        path: Directory passed to ``fetch_data.get_volume``.
        sess: Session used for every ``run`` call.
        net: Either an object exposing ``y`` and ``d_labels``, or one that
            offers ``get_tensor_by_name`` for the same tensors.
        i: Current training step; also part of the output file names.
        gif_length: Maximum number of frames to render.
        save_frequency: Step interval at which gifs are produced.
        SCF: Forwarded to ``fetch_data.get_volume`` as ``scaling_factor``.
        restore: Not used by this function.
    """
    if i % save_frequency:
        return

    # Only built when a gif is actually written.
    colormap = cm.get_cmap('inferno', 12)
    search_dir = path

    try:
        frames = []
        diff_frames = []
        print('Generating gif')

        # Looked up once; only a missing attribute should trigger the
        # by-name fallback, so nothing broader than AttributeError is caught.
        try:
            decoder_tensor = net.y
            diff_tensor = net.d_labels
        except AttributeError:
            decoder_tensor = net.get_tensor_by_name('Decoder/decoder:0')
            diff_tensor = net.get_tensor_by_name('Differentiate/diff:0')

        mid = int(config.data_size / 2)

        for frame_idx in range(gif_length):
            print(frame_idx)
            try:
                test_input = fetch_data.get_volume(search_dir,
                                                   batch_size=1,
                                                   time_idx=frame_idx,
                                                   scaling_factor=SCF)
            except IndexError:
                print(
                    'index out of range -> creating gif with stashed frames')
                break
            if not test_input:
                break

            reconstructed_vel = sess.run(decoder_tensor,
                                         feed_dict=test_input)
            image = np.linalg.norm(np.squeeze(
                reconstructed_vel[0, :, mid, :, :]),
                                   axis=2)
            image_gt = np.linalg.norm(np.squeeze(
                test_input['velocity:0'][0, :, mid, :, :]),
                                      axis=2)
            image_diffs = sess.run(diff_tensor, feed_dict=test_input)

            # Feed a zero velocity field. A fresh array is used so the buffer
            # returned by fetch_data is not modified in place.
            test_input['velocity:0'] = np.zeros_like(
                test_input['velocity:0'])
            reconstructed_vel0 = sess.run(decoder_tensor,
                                          feed_dict=test_input)
            image0 = np.linalg.norm(np.squeeze(
                reconstructed_vel0[0, :, mid, :, :]),
                                    axis=2)
            full_image = np.concatenate((image_gt, image, image0), axis=1)

            panels = []
            for channel in range(np.shape(image_diffs)[4]):
                # NOTE(review): slice index 16 is hard-coded while the images
                # above use int(config.data_size / 2) -- confirm.
                panel = np.squeeze(image_diffs[0, :, 16, :, channel])
                if np.amax(panel) > 0:
                    lowest = np.amin(panel)
                    span = np.amax(panel) - lowest
                    if span == 0:
                        # Constant positive panel: avoid dividing by zero.
                        span = 1
                    panel = np.uint8(255 * ((panel - lowest) / span))
                else:
                    panel = np.uint8(panel)
                panels.append(panel)
            diff_image = np.concatenate(panels, axis=1) if panels else []

            frames.append(full_image)
            diff_frames.append(diff_image)

        # Reducing over an empty list raises ValueError, which the OSError
        # handler below would not catch.
        if not frames:
            print('No frames available -> gif not written')
            return

        gif_file = os.path.join(
            config.gif_path,
            'test_vel_field_' + get_name_ext() + str(i) + '.gif')
        stack = np.asarray(frames)
        lowest = np.amin(stack)
        span = np.amax(stack) - lowest
        if span == 0:
            # Constant movie: map everything to 0 instead of dividing by 0.
            span = 1
        movie = np.uint8(255 * colormap((stack - lowest) / span))
        imageio.mimwrite(gif_file, movie)

        diff_path = os.path.join(
            config.gif_path,
            'diff_vel_field_' + get_name_ext() + str(i) + '.gif')
        imageio.mimwrite(diff_path, diff_frames)
        print('gif saved at:', gif_file)
    except OSError:
        print(
            'No valid .npy test file in output dir / alternative dir not found'
        )