コード例 #1
0
ファイル: main.py プロジェクト: T4vi/LAIF_task
def main():
    """Build the graph, restore weights and run the mode chosen in ``para.mode``.

    Handled modes are ``train``, ``validation``, ``test`` and ``predict``;
    any other value runs nothing. A ``KeyboardInterrupt`` raised while the
    mode is running is reported on stdout instead of propagating, and
    'Stop' is always printed last.
    """
    para = params_setup()
    logging_config_setup(para)

    logging.info('Creating graph')
    graph, model, data_generator = create_graph(para)

    with tf.Session(config=config_setup(), graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        logging.info('Loading weights')
        load_weights(para, sess, model)
        print_num_of_trainable_parameters()

        try:
            mode = para.mode
            if mode == 'train':
                logging.info('Started training')
                train(para, sess, model, data_generator)
                # An empty path means "do not save the final model".
                final_path = para.save_final_model_path
                if final_path != '':
                    save_weights(sess, model, final_path)
            elif mode in ('validation', 'test'):
                # Both modes go through the same evaluation routine and
                # differ only in the message that is logged.
                if mode == 'validation':
                    logging.info('Started validation')
                else:
                    logging.info('Started testing')
                test(para, sess, model, data_generator)
            elif mode == 'predict':
                logging.info('Predicting')
                predict(
                    para,
                    sess,
                    model,
                    data_generator,
                    './data/solar-energy3/solar_predict.txt',
                    para.samples,
                )
        except KeyboardInterrupt:
            print('KeyboardInterrupt')
        finally:
            print('Stop')
コード例 #2
0
def main():
    """Create the graph, load weights and train or evaluate per ``para.mode``.

    Progress is reported with ``print``. Modes other than ``train`` and
    ``test`` run nothing. A ``KeyboardInterrupt`` during the run is printed
    rather than propagated, and 'Stop' is always printed last.
    """
    para = params_setup()
    logging_config_setup(para)

    print("Creating graph...")
    graph, model, data_generator = create_graph(para)
    print("Done creating graph.")

    with tf.Session(config=config_setup(), graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        print("Loading weights...")
        load_weights(para, sess, model)
        print_num_of_trainable_parameters()

        # Debugging recipes (disabled). Uncomment as needed.
        #
        # 1) Names of the tensors that hold the attention alphas, e.g.
        #    "model/rnn/cond/rnn/multi_rnn_cell/cell_0/cell_0/temporal_pattern_attention_cell_wrapper/attention/Sigmoid:0"
        # for item in [n.name for n in tf.get_default_graph().as_graph_def().node
        #              if (n.name.find("temporal_pattern_attention_cell_wrapper/attention")!=-1 and
        #                  n.name.find("Sigmoid")!=-1)]:
        #     print(item)
        #
        # 2) Names of ops containing "ben_multiply"
        # for op in tf.get_default_graph().get_operations():
        #     if(op.name.find("ben_multiply")!=-1):
        #         print(str(op.name))
        #
        # 3) Regression kernel and bias
        # reg_weights = [v for v in tf.global_variables() if v.name == "model/dense_2/kernel:0"][0]
        # reg_bias = [v for v in tf.global_variables() if v.name == "model/dense_2/bias:0"][0]
        # print("Reg Weights:", sess.run(reg_weights))
        # print("Reg Bias:", sess.run(reg_bias) * data_generator.scale[0])

        try:
            mode = para.mode
            if mode == 'train':
                train(para, sess, model, data_generator)
            elif mode == 'test':
                print("Evaluating model...")
                test(para, sess, model, data_generator)
        except KeyboardInterrupt:
            print('KeyboardInterrupt')
        finally:
            print('Stop')
コード例 #3
0
ファイル: main.py プロジェクト: zyf0220/TPA-LSTM
def main():
    """Create the graph, load weights and run ``train`` or ``test``.

    The routine is picked from ``para.mode``; any other mode runs nothing.
    A ``KeyboardInterrupt`` during the run is printed rather than
    propagated, and 'Stop' is always printed last.
    """
    para = params_setup()
    logging_config_setup(para)

    graph, model, data_generator = create_graph(para)

    with tf.Session(config=config_setup(), graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        load_weights(para, sess, model)
        print_num_of_trainable_parameters()

        try:
            # Pick the routine first, then invoke it with the shared arguments.
            runner = None
            if para.mode == 'train':
                runner = train
            elif para.mode == 'test':
                runner = test

            if runner is not None:
                runner(para, sess, model, data_generator)
        except KeyboardInterrupt:
            print('KeyboardInterrupt')
        finally:
            print('Stop')
コード例 #4
0
def main():
    """Train inside a fresh session and hand the collected arrays to the caller.

    Returns:
        A tuple ``(all_inputs, all_labels, all_outputs, model)`` built from
        the three values returned by ``train`` plus the model object.
        If training is interrupted with ``KeyboardInterrupt`` the interrupt
        is printed and the function returns ``None``.

    'Stop' is printed in every case before the function returns.
    """
    para = sess_params_setup()
    logging_config_setup(para)

    graph, model, data_generator = create_graph(para)

    with tf.Session(config=config_setup(), graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        load_weights(para, sess, model)
        print_num_of_trainable_parameters()

        try:
            inputs, labels, outputs = train(para, sess, model, data_generator)
        except KeyboardInterrupt:
            print('KeyboardInterrupt')
        else:
            return inputs, labels, outputs, model
        finally:
            print('Stop')

        # Disabled SHAP explanation sketch, kept for reference:
        #
        # all_inputs = all_inputs.permute(0,2,1)
        #
        # modelSize = 500
        # sampleSize = 500
        #
        # model_input = [tf.convert_to_tensor(x, dtype=tf.float32) for x in all_inputs[:modelSize]]
        # model_output = tf.convert_to_tensor(all_outputs[:modelSize], dtype=tf.float32)
        #
        # data = [x for x in all_inputs[:sampleSize]]
        # X = [x for x in all_inputs[:modelSize]]
        #
        # model2 = (model_input, model_output)
        # explainer = shap.DeepExplainer(model2, data, sess)
        # shap_values = explainer.shap_values(X)
        #
        # return shap_values
def main():
    """Load the model and print its global variables and fetchable ops.

    After the graph is created and weights are loaded, every entry of
    ``tf.global_variables()`` is printed, followed by the list of
    operations in the default graph that ``Graph.is_fetchable`` accepts.
    A ``KeyboardInterrupt`` is printed rather than propagated, and the
    closing message is always printed last.
    """
    para = params_setup()
    logging_config_setup(para)

    graph, model, data_generator = create_graph(para)

    with tf.Session(config=config_setup(), graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        load_weights(para, sess, model)
        print_num_of_trainable_parameters()

        try:
            # EXTRACT WEIGHTS HERE
            for variable in tf.global_variables():  # tf.trainable_variables():
                print(variable)

            # VIEW FETCHABLE OPS
            # Enumerate the ops from the graph itself; the candidate list
            # must be defined before it is filtered, otherwise the print
            # below fails with NameError.
            default_graph = tf.get_default_graph()
            candidate_ops = default_graph.get_operations()
            print([op for op in candidate_ops
                   if default_graph.is_fetchable(op)])

        except KeyboardInterrupt:
            print('KeyboardInterrupt')
        finally:
            print('Weights extracted. Stop.')
コード例 #6
0
ファイル: main.py プロジェクト: frede-stg/TPA-LSTM
def main():
    """Train the model, or evaluate it and plot observed vs. predicted values.

    In ``test`` mode the arrays returned by ``test`` are rescaled with
    ``data_generator.scale`` and ``data_generator.min_value``, the mean
    squared error of the first column is printed, and both series are
    plotted against a fixed weekly date index. Modes other than ``train``
    and ``test`` run nothing. A ``KeyboardInterrupt`` is printed rather
    than propagated, and 'Stop' is always printed last.
    """
    para = params_setup()
    logging_config_setup(para)

    graph, model, data_generator = create_graph(para)

    with tf.Session(config=config_setup(), graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        load_weights(para, sess, model)
        print_num_of_trainable_parameters()

        try:
            if para.mode == 'train':
                train(para, sess, model, data_generator)
            elif para.mode == 'test':
                obs, predicted = test(para, sess, model, data_generator)
                # Undo the generator's scaling so values are in original units.
                obs = obs * data_generator.scale + data_generator.min_value
                predicted = predicted * data_generator.scale + data_generator.min_value
                print("MSE: ", mean_squared_error(obs[:, 0], predicted[:, 0]))
                # pd.DatetimeIndex no longer accepts start/end/freq
                # (removed in pandas 1.0); pd.date_range builds the same
                # weekly index and works on old and new pandas alike.
                idx = pd.date_range(start='2016-10-16',
                                    end='2018-11-04',
                                    freq='W')
                obs_df = pd.DataFrame(data=obs[:, 0],
                                      columns=['Observed'],
                                      index=idx)
                pred_df = pd.DataFrame(data=predicted[:, 0],
                                       columns=['Predicted'],
                                       index=idx)
                df = pd.concat([obs_df, pred_df], axis=1)
                df.plot()
                plt.show()

        except KeyboardInterrupt:
            print('KeyboardInterrupt')
        finally:
            print('Stop')
コード例 #7
0
#%%
# Cell-style script: parse default parameters, train a model, then rebuild
# the graph in 'test' mode and evaluate it.
para = parser.parse_args(args=[])
para.logging_level = logging.INFO
logging_config_setup(para)

#%%
# Persist the parameters next to the model so the run can be reproduced.
create_dir(para.model_dir)
create_dir(para.output_dir)
json_path = para.model_dir + '/parameters.json'
# Use a context manager so the file is flushed and closed deterministically
# instead of relying on garbage collection of an anonymous handle.
with open(json_path, 'w') as json_file:
    json.dump(vars(para), json_file, indent=4)

# %%
graph = tf.Graph()
# %%
graph, model, data_generator = create_graph(para)

# %%
# Training pass.
with tf.Session(config=config_setup(), graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    load_weights(para, sess, model)
    print_num_of_trainable_parameters()
    train(para, sess, model, data_generator)
# %%
# Evaluation pass: switch the mode and build a fresh graph and session.
para.mode = 'test'
graph, model, data_generator = create_graph(para)
with tf.Session(config=config_setup(), graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    load_weights(para, sess, model)
    print_num_of_trainable_parameters()
    test(para, sess, model, data_generator)
コード例 #8
0
from lib.model_utils import create_model_dir, load_weights, create_graph
from lib.setup import config_setup, logging_config_setup

from lib.pretrain import pretrain
from lib.rl import policy_gradient
from lib.test import test

if __name__ == "__main__":
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    PARA = params_setup()
    create_model_dir(PARA)
    logging_config_setup(PARA)
    print_parameters(PARA)

    GRAPH, MODEL = create_graph(PARA)

    with tf.Session(config=config_setup(), graph=GRAPH) as sess:
        sess.run(tf.global_variables_initializer())
        load_weights(PARA, sess, MODEL)

        COORD = tf.train.Coordinator()
        THREADS = tf.train.start_queue_runners(sess=sess, coord=COORD)
        try:
            if PARA.mode == 'pretrain':
                pretrain(PARA, sess, MODEL)
            elif PARA.mode == 'rl':
                policy_gradient(PARA, sess, MODEL)
            elif PARA.mode == 'test':
                test(PARA, sess, MODEL)
        except KeyboardInterrupt: