Ejemplo n.º 1
0
Archivo: nns.py Proyecto: tricky61/nns
def reset_keras(model):
    """Reset the Keras backend session to release GPU memory.

    Closes the current backend session, clears Keras' global state, drops
    this function's reference to ``model``, runs a garbage-collection pass,
    and installs a fresh ``Session`` built from a default ``ConfigProto``.

    See:
        https://forums.fast.ai/t/how-could-i-release-gpu-memory-of-keras/2023/18
        https://github.com/keras-team/keras/issues/12625

    Args:
        model: The Keras model being discarded. ``del`` below only unbinds
            this function's local name; the caller must drop its own
            references too for the model to become collectable.
    """
    # Keep a handle on the live session so it can still be closed explicitly
    # after clear_session() resets Keras' global state.
    sess = backend.get_session()
    backend.clear_session()
    sess.close()
    # ``model`` is a parameter and therefore always bound here, so ``del``
    # cannot raise and needs no exception handling.
    del model
    gc.collect()
    backend.set_session(Session(config=ConfigProto()))
Ejemplo n.º 2
0
def GenerateModelV1(tf_saved_model_dir, tftrt_saved_model_dir):
    """Generate and convert a model using TFv1 API.

    Builds a small graph with two float32 placeholders of shape
    ``[None, 1, 1]`` (``input1``, ``input2``) and one variable ``v1``,
    combined by ``GetGraph``. It saves that graph as a SavedModel under the
    default serving signature key, then converts the SavedModel with TF-TRT
    (``is_dynamic_op=True``).

    Args:
        tf_saved_model_dir: Output directory for the plain TF SavedModel.
        tftrt_saved_model_dir: Output directory for the TF-TRT converted
            SavedModel.
    """
    graph = ops.Graph()
    with graph.as_default():
        input1 = array_ops.placeholder(dtype=dtypes.float32,
                                       shape=[None, 1, 1],
                                       name="input1")
        input2 = array_ops.placeholder(dtype=dtypes.float32,
                                       shape=[None, 1, 1],
                                       name="input2")
        var = variables.Variable([[[1.0]]],
                                 dtype=dtypes.float32,
                                 name="v1")
        out = GetGraph(input1, input2, var)

    signature_def = signature_def_utils.build_signature_def(
        inputs={
            "input1": utils.build_tensor_info(input1),
            "input2": utils.build_tensor_info(input2)
        },
        outputs={"output": utils.build_tensor_info(out)},
        # method_name names the kind of computation (predict/classify/
        # regress). DEFAULT_SERVING_SIGNATURE_DEF_KEY is the key used in
        # signature_def_map below, not a method name.
        method_name=signature_constants.PREDICT_METHOD_NAME)

    saved_model_builder = builder.SavedModelBuilder(tf_saved_model_dir)
    with Session(graph=graph) as sess:
        sess.run(var.initializer)
        saved_model_builder.add_meta_graph_and_variables(
            sess, [tag_constants.SERVING],
            signature_def_map={
                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature_def
            })
    saved_model_builder.save()

    # Convert TF model to TensorRT
    converter = trt_convert.TrtGraphConverter(
        input_saved_model_dir=tf_saved_model_dir, is_dynamic_op=True)
    converter.convert()
    converter.save(tftrt_saved_model_dir)
Ejemplo n.º 3
0
    def run():
        """Build a counter graph and print its value before and after three increments.

        Prints four values: the initial counter value, then the value after
        each of three increment steps.
        """
        # A variable named "counter" holding the scalar 0.
        counter = Variable(0, name="counter")
        # An op that writes counter + 1 back into counter.
        increment = assign(counter, add(counter, constant(1)))
        # Variables must be initialized by running an init op once the graph
        # is launched, so that op has to be added to the graph first.
        init = initialize_all_variables()

        # Launch the graph and run the ops.
        with Session() as sess:
            sess.run(init)
            # Initial value of the counter.
            print(sess.run(counter))
            # Increment three times, printing the counter after each step.
            for _ in range(3):
                sess.run(increment)
                print(sess.run(counter))
Ejemplo n.º 4
0
import random
import numpy as np
from statistics import median, mean
from collections import Counter

from tensorflow.core.protobuf.config_pb2 import ConfigProto, GPUOptions
from tensorflow.python import Session
from tensorflow.python.keras import Sequential
from tensorflow.python.keras.backend import set_session
from tensorflow.python.keras.layers import Dense, Dropout
from tensorflow.python.keras.optimizer_v2.adam import Adam

# Limit this process to 80% of GPU memory and allocate it on demand, then
# make this session the one the Keras backend uses.
config = ConfigProto(
    gpu_options=GPUOptions(
        per_process_gpu_memory_fraction=0.8,
        allow_growth=True,
    )
)
session = Session(config=config)
set_session(session)

# Learning rate and CartPole environment / data-generation settings.
LR = 1e-3
# NOTE(review): `gym` does not appear in the visible import block above;
# unless it is imported elsewhere, this line raises NameError. Confirm.
env = gym.make("CartPole-v0")
env.reset()
goal_steps = 500
score_requirement = 50
initial_games = 10000


def initial_population():
    """Set up the accumulators for generating initial training data.

    NOTE(review): this excerpt is truncated. As shown, it only creates two
    empty local lists and has no return statement, so calling it returns
    None. The rest of the body is not visible here; confirm against the
    full file.
    """
    # Training samples. Per the original "[OBS, MOVES]" note, each entry is
    # presumably an [observation, moves] pair; confirm in the full body.
    training_data = []
    # All scores:
    scores = []
Ejemplo n.º 5
0
 def config_tensorflow(memory_fraction=0.8, allow_growth=True):
     """Create a GPU-configured TensorFlow session and register it with Keras.

     Args:
         memory_fraction: Value for ``per_process_gpu_memory_fraction``, the
             share of GPU memory this process may allocate. Defaults to 0.8,
             the previously hard-coded value.
         allow_growth: Value for ``gpu_options.allow_growth``; when True,
             GPU memory is allocated on demand rather than reserved up
             front. Defaults to True, the previously hard-coded value.

     Returns:
         The newly created ``Session``, which is also installed as the Keras
         backend session via ``set_session``.
     """
     config = ConfigProto(gpu_options=GPUOptions(
         per_process_gpu_memory_fraction=memory_fraction))
     config.gpu_options.allow_growth = allow_growth
     session = Session(config=config)
     set_session(session)
     # Returning the session lets callers close or reuse it; callers that
     # ignore the return value behave exactly as before.
     return session