Code example #1
def restore_ae(data, graph_path, grid, frame=-1):
    # Rebuild the default graph and locate the stored autoencoder meta-graph.
    tf.reset_default_graph()
    graph = tf.get_default_graph()
    graph_handle = util.find_graph(graph_path)

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # Restore the trained weights from the latest checkpoint.
        graph_handle.restore(sess, tf.train.latest_checkpoint(config.path_e))
        # Fetch a single frame of the benchmark volume to use as the feed dict.
        inputs = fetch_data.get_volume(config.benchmark_data,
                                       time_idx=frame,
                                       batch_size=1,
                                       scaling_factor=1)

        # Decode the latent state back to a velocity field and plot it
        # against the ground-truth velocity.
        v = sess.run(graph.get_tensor_by_name('Decoder/decoder:0'),
                     feed_dict=inputs)
        util.contour(v, inputs['velocity:0'])
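Example #1 restores a trained autoencoder graph from the latest checkpoint under config.path_e, decodes one frame of the benchmark volume, and plots it with util.contour. A hypothetical invocation, assuming the project's config module is already populated, might look like this (note that the data and grid arguments are accepted but never used in the body above):

restore_ae(data=None, graph_path=config.path_e, grid=None, frame=10)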
Code example #2
File: train.py Project: asgerMe/Deep-Earth-
def train_integrator():
    tf.reset_default_graph()
    graph = tf.get_default_graph()
    graph_handle = 0
    int_net = 0

    if config.train_integrator_network:
        if not config.conv:
            int_net = nn.IntegratorNetwork(
                param_state_size=config.param_state_size,
                sdf_state_size=config.sdf_state)
        else:
            int_net = nn.Convo_IntegratorNetwork(
                config.data_size,
                param_state_size=config.field_state,
                sdf_state_size=config.sdf_state)

    init = tf.global_variables_initializer()

    SCF = 1  #fetch_data.get_scaling_factor(config.data_path)

    # Scan the encoder checkpoint directory for a stored meta-graph and import it.
    for file in os.listdir(config.path_e):
        print(file)
        if file.endswith('.ckpt.meta'):
            try:
                graph_handle = tf.train.import_meta_graph(
                    os.path.join(config.path_e, file))
                print(graph)
                break
            except IOError:
                print("Can't import graph")
                exit()

    with tf.Session() as sess:
        sess.run(init)

        sub_dir = os.path.join(config.tensor_board,
                               'integrator_' + time.strftime("%Y%m%d-%H%M%S"))
        os.mkdir(sub_dir)
        writer = tf.summary.FileWriter(sub_dir)
        writer.add_graph(sess.graph)

        saver = tf.train.Saver(tf.global_variables())
        store_integrator_loss_tb = 0
        store_integrator_loss = -1

        graph_handle.restore(sess, tf.train.latest_checkpoint(config.path_e))
        print('Model restored')

        full_encoding = graph.get_tensor_by_name(
            "Latent_State/full_encoding:0")
        encoded_sdf = graph.get_tensor_by_name(
            "Boundary_conditions/encoded_sdf:0")

        for i in range(config.training_runs):
            input_sequence, input_0 = fetch_data.get_volume(
                config.data_path,
                1,
                sequential=True,
                sequence_length=config.sequence_length)

            start_encoding = sess.run(full_encoding, input_0)
            sdf_encodings, label_encodings = sess.run(
                [encoded_sdf, full_encoding], input_sequence)

            integrator_feed_dict = {
                'label_encodings:0': label_encodings,
                'sdf_encodings:0': sdf_encodings,
                'start_encoding:0': start_encoding,
                "sequence_length:0": config.sequence_length
            }

            # Check f_tensorboard != 0 first to avoid a modulo-by-zero error.
            if config.f_tensorboard != 0 and i % config.f_tensorboard == 0 and os.path.isdir(
                    config.tensor_board):
                _, int_loss, merged_int = sess.run(
                    [int_net.train_int, int_net.loss_int, int_net.merged_int],
                    integrator_feed_dict)
                writer.add_summary(merged_int, i)
            else:
                _, int_loss = sess.run([int_net.train_int, int_net.loss_int],
                                       integrator_feed_dict)

            if not i % 10:
                print('Training Run', i, 'Learning Rate', config.lr_max,
                      '//  Encoder Loss:', -1, '//  Integrator Loss', int_loss)

            if config.save_freq and os.path.isdir(config.path_i):
                if config.meta_graphs and i % config.save_freq == 0 and config.save_freq > 2:
                    saver.save(
                        sess,
                        os.path.join(config.path_i,
                                     "trained_integrator_model.ckpt"))
                    print('Saving graph')
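The loop over config.path_e above is the standard TensorFlow 1.x pattern for reusing a frozen encoder: find a *.ckpt.meta file, import the stored graph, restore the latest checkpoint into the session, and look tensors up by name. A minimal, self-contained sketch of that pattern (ckpt_dir is a hypothetical stand-in for config.path_e; the tensor name matches the one used above):

import os
import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1.* under TF 2.x)

ckpt_dir = '/path/to/encoder_checkpoints'  # hypothetical checkpoint directory

tf.reset_default_graph()
meta_file = next(f for f in os.listdir(ckpt_dir) if f.endswith('.ckpt.meta'))
# import_meta_graph rebuilds the stored graph and returns a Saver for its variables.
saver = tf.train.import_meta_graph(os.path.join(ckpt_dir, meta_file))

with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
    # Restored tensors are addressed by name, exactly as in train_integrator().
    full_encoding = tf.get_default_graph().get_tensor_by_name(
        'Latent_State/full_encoding:0')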
Code example #3
File: train.py Project: asgerMe/Deep-Earth-
def train_network():
    tf.reset_default_graph()
    net = nn.NetWork(config.data_size,
                     param_state_size=config.param_state_size)
    init = tf.global_variables_initializer()

    SCF = 1  #fetch_data.get_scaling_factor(config.data_path)

    with tf.Session() as sess:
        sess.run(init)

        sub_dir = os.path.join(
            config.tensor_board,
            'TS' + util.get_name_ext() + time.strftime("%Y%m%d-%H%M%S"))
        sub_dir_test = os.path.join(
            config.tensor_board,
            'TS_Test' + util.get_name_ext() + time.strftime("%Y%m%d-%H%M%S"))
        os.mkdir(sub_dir)
        writer = tf.summary.FileWriter(sub_dir)
        writer_test = tf.summary.FileWriter(sub_dir_test)
        writer.add_graph(sess.graph)

        saver = tf.train.Saver(tf.global_variables())
        graph_handle = util.find_graph(config.path_e)

        store_integrator_loss = -1

        if graph_handle:
            graph_handle.restore(sess,
                                 tf.train.latest_checkpoint(config.path_e))

        for i in range(config.training_runs):

            inputs = fetch_data.get_volume(config.data_path,
                                           batch_size=config.batch_size,
                                           scaling_factor=SCF)
            inputs['Train/step:0'] = i

            # Check f_tensorboard != 0 first to avoid a modulo-by-zero error.
            if config.f_tensorboard != 0 and i % config.f_tensorboard == 0 and os.path.isdir(
                    config.tensor_board):
                loss, lr, merged, _ = sess.run(
                    [net.loss, net.lr, net.merged, net.train], inputs)
                writer.add_summary(merged, i)
            else:
                loss, lr, _ = sess.run([net.loss, net.lr, net.train], inputs)

            # Skip saving when save_freq is 0 (and avoid a modulo-by-zero error).
            if config.save_freq and os.path.isdir(config.path_e) and i % config.save_freq == 0:
                saver.save(
                    sess,
                    os.path.join(config.path_e,
                                 util.get_name_ext() + "_trained_model.ckpt"))
                print('Saving graph')

            if i % 500 == 0 and os.path.isdir(config.tensor_board):
                inputs_ci = fetch_data.get_volume(config.benchmark_data,
                                                  1,
                                                  scaling_factor=SCF)
                if inputs_ci:
                    util.create_gif_encoder(config.benchmark_data,
                                            sess,
                                            net,
                                            i=i,
                                            save_frequency=5000,
                                            SCF=SCF)
                else:
                    inputs_ci = fetch_data.get_volume(config.data_path,
                                                      1,
                                                      scaling_factor=SCF)
                    util.create_gif_encoder(config.data_path,
                                            sess,
                                            net,
                                            i=i,
                                            save_frequency=5000,
                                            SCF=SCF)

                inputs_ci['Train/step:0'] = i

                loss, merged = sess.run([net.loss, net.merged], inputs_ci)
                writer_test.add_summary(merged, i)
                test_field = sess.run(net.y, inputs_ci)
                np.save(
                    os.path.join(
                        config.test_field_path, 'train_field' +
                        util.get_name_ext() + time.strftime("%Y%m%d-%H%M")),
                    test_field)

                # Record execution stats
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                _ = sess.run([net.train],
                             feed_dict=inputs,
                             options=run_options,
                             run_metadata=run_metadata)
                writer.add_run_metadata(run_metadata, 'step%d' % i)

            if not i % 10:
                print('Training Run', i, 'Learning Rate', lr,
                      '//  Encoder Loss:', loss, '//  Integrator Loss',
                      store_integrator_loss)
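Both training loops feed placeholders by string name (for example inputs['Train/step:0'] = i) rather than by tensor object; TensorFlow 1.x resolves a 'name:0' key in feed_dict against the default graph. A minimal illustration of that technique with a hypothetical placeholder:

import tensorflow as tf  # TensorFlow 1.x API

tf.reset_default_graph()
step = tf.placeholder(tf.int32, shape=(), name='step')
doubled = 2 * step

with tf.Session() as sess:
    # The string key 'step:0' is resolved to the placeholder defined above.
    print(sess.run(doubled, feed_dict={'step:0': 21}))  # -> 42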
Code example #4
File: util.py Project: asgerMe/Deep-Earth-
def create_gif_integrator(sess,
                          net,
                          autoencoder_graph,
                          roll_out,
                          i=0,
                          gif_length=500,
                          save_frequency=5000,
                          SCF=1,
                          restore=False):

    if not i % save_frequency and i > 1:
        search_dir = config.data_path
        if os.path.isdir(config.alt_dir):
            search_dir = config.alt_dir
        try:
            video_length = gif_length
            MOVIE = []

            sdf = autoencoder_graph.get_tensor_by_name(
                "Boundary_conditions/encoded_sdf:0")
            full_encoding = autoencoder_graph.get_tensor_by_name(
                "Latent_State/full_encoding:0")
            reconstructed_v = autoencoder_graph.get_tensor_by_name(
                "Decoder/decoder:0")
            v_next = ''
            next_encoding = 0
            for F in range(gif_length):

                try:
                    input_i = fetch_data.get_volume(search_dir,
                                                    batch_size=1,
                                                    time_idx=F,
                                                    scaling_factor=SCF)
                except IndexError:
                    print(
                        'index out of range -> creating gif with stashed frames'
                    )
                    break

                if F == 0:
                    full_enc = sess.run(full_encoding, feed_dict=input_i)
                else:
                    full_enc = next_encoding
                try:
                    input_next = fetch_data.get_volume(search_dir,
                                                       batch_size=1,
                                                       time_idx=(1 + F),
                                                       scaling_factor=SCF)
                except IndexError:
                    print(
                        'index out of range -> creating gif with stashed frames'
                    )
                    break

                next_encoded_sdf = sess.run(sdf, feed_dict=input_next)

                integrator_feed_dict = {
                    'parameter_encodings:0': next_encoded_sdf,
                    'start_encoding:0': [full_enc],
                    "sequence_length:0": 1
                }

                next_encoding = sess.run(net.full_encoding,
                                         feed_dict=integrator_feed_dict)

                v_next = sess.run(reconstructed_v,
                                  feed_dict={
                                      'sdf:0':
                                      input_next['sdf:0'],
                                      'Latent_State/full_encoding:0':
                                      np.squeeze(next_encoding, axis=1)
                                  })

                image = np.linalg.norm(np.squeeze(v_next[0, :, 16, :, :]),
                                       axis=2)
                MOVIE.append(image)

            path = os.path.join(config.gif_path,
                                'integrator_vel_field_' + str(i) + '.gif')
            MOVIE = np.uint8(255 * (MOVIE - np.amin(MOVIE)) /
                             (np.amax(MOVIE) - np.amin(MOVIE)))
            imageio.mimwrite(path, MOVIE)

            print(
                'gif saved at:',
                os.path.join(config.gif_path,
                             'integrator_vel_field_' + str(i) + '.gif'))
        except OSError:
            print(
                'No valid .npy test file in output dir / alternative dir not found'
            )
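Before imageio.mimwrite is called, the stashed frames are rescaled to the 0-255 range in a single numpy expression. The same step pulled out into a small helper (hypothetical name frames_to_uint8), with a guard for the constant-movie case where the expression above would divide by zero:

import numpy as np

def frames_to_uint8(frames):
    """Rescale a list of equally shaped 2-D float frames to uint8 in [0, 255]."""
    stack = np.asarray(frames, dtype=np.float64)
    lo, hi = stack.min(), stack.max()
    if hi == lo:  # constant movie: avoid dividing by zero
        return np.zeros(stack.shape, dtype=np.uint8)
    return np.uint8(255 * (stack - lo) / (hi - lo))

# imageio.mimwrite(path, frames_to_uint8(MOVIE))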
Code example #5
File: util.py Project: asgerMe/Deep-Earth-
def create_gif_encoder(path,
                       sess,
                       net,
                       i=0,
                       gif_length=2000,
                       save_frequency=5000,
                       SCF=1,
                       restore=False):
    viridis = cm.get_cmap('inferno', 12)
    if not i % save_frequency:
        search_dir = path
        try:
            MOVIE = []
            diff_MOVIE = []
            print('Generating gif')
            for F in range(gif_length):
                print(F)
                try:
                    test_input = fetch_data.get_volume(search_dir,
                                                       batch_size=1,
                                                       time_idx=F,
                                                       scaling_factor=SCF)
                    if not test_input:
                        break
                except IndexError:
                    print(
                        'index out of range -> creating gif with stashed frames'
                    )
                    break
                # `net` may be a NetWork object or a restored tf.Graph; fall back
                # to tensor lookup by name when the attributes are missing.
                try:
                    network = net.y
                    diff = net.d_labels
                except AttributeError:
                    network = net.get_tensor_by_name('Decoder/decoder:0')
                    diff = net.get_tensor_by_name('Differentiate/diff:0')

                reconstructed_vel = sess.run(network, feed_dict=test_input)
                image = np.linalg.norm(np.squeeze(
                    reconstructed_vel[0, :,
                                      int(config.data_size / 2), :, :]),
                                       axis=2)

                image_gt = np.linalg.norm(np.squeeze(
                    test_input['velocity:0'][0, :,
                                             int(config.data_size / 2), :, :]),
                                          axis=2)
                image_diffs = sess.run(diff, feed_dict=test_input)

                test_input['velocity:0'] *= 0

                reconstructed_vel0 = sess.run(network, feed_dict=test_input)
                image0 = np.linalg.norm(np.squeeze(
                    reconstructed_vel0[0, :,
                                       int(config.data_size / 2), :, :]),
                                        axis=2)

                full_image = np.concatenate((image_gt, image, image0), axis=1)
                diff_image = []
                for diffs_num in range(np.shape(image_diffs)[4]):
                    # Use a separate name so the fetched `diff` tensor is not shadowed.
                    diff_slice = np.squeeze(image_diffs[0, :, 16, :, diffs_num])

                    if np.amax(diff_slice) > 0:
                        diff_slice = np.uint8(
                            255 * ((diff_slice - np.amin(diff_slice)) /
                                   (np.amax(diff_slice) - np.amin(diff_slice))))
                    else:
                        diff_slice = np.uint8(diff_slice)

                    if diffs_num > 0:
                        diff_image = np.abs(
                            np.concatenate((diff_image, diff_slice), axis=1))
                    else:
                        diff_image = np.abs(diff_slice)

                MOVIE.append(full_image)
                diff_MOVIE.append(diff_image)

            path = os.path.join(
                config.gif_path,
                'test_vel_field_' + get_name_ext() + str(i) + '.gif')
            MOVIE = np.uint8(255 * viridis(
                (MOVIE - np.amin(MOVIE)) / (np.amax(MOVIE) - np.amin(MOVIE))))
            imageio.mimwrite(path, MOVIE)

            diff_path = os.path.join(
                config.gif_path,
                'diff_vel_field_' + get_name_ext() + str(i) + '.gif')

            imageio.mimwrite(diff_path, diff_MOVIE)

            print(
                'gif saved at:',
                os.path.join(
                    config.gif_path,
                    'test_vel_field_' + get_name_ext() + str(i) + '.gif'))

        except OSError:
            print(
                'No valid .npy test file in output dir / alternative dir not found'
            )
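Example #5 additionally runs the normalized frames through a matplotlib colormap before writing the gif (the variable is named viridis, but cm.get_cmap('inferno', 12) requests a 12-level 'inferno' map). The colormap call returns RGBA floats in [0, 1], which is why the result is scaled by 255 and cast to uint8. A standalone sketch of that step, with stand-in random frames:

import numpy as np
import imageio
from matplotlib import cm

cmap = cm.get_cmap('inferno', 12)                     # 12-level quantized colormap
frames = [np.random.rand(64, 64) for _ in range(10)]  # stand-in frames

stack = np.asarray(frames)
normed = (stack - stack.min()) / (stack.max() - stack.min())
rgba = np.uint8(255 * cmap(normed))                   # shape (10, 64, 64, 4)
imageio.mimwrite('colored_field.gif', rgba)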