예제 #1
0
 def wrap(f):
     """Benchmark ``f``: build its executor once, then time ``n`` calls of it.

     ``f`` is called with no arguments and must return a zero-argument
     callable (the executor).  ``n`` is not defined here; it is read from the
     enclosing scope (presumably the decorator factory's argument -- confirm).
     The elapsed wall-clock time for the ``n`` calls is printed.

     Under the TensorFlow backend the whole run happens inside a fresh default
     graph with a freshly reset session, so benchmarks do not share graph
     state.  Any exception is printed rather than raised.
     """
     graph_ctx = None
     try:
         if is_tf():
             ctx = tf.Graph().as_default()
             ctx.__enter__()
             # Only remember the context once it is actually entered, so the
             # cleanup below never exits a context that was never entered.
             graph_ctx = ctx
             tf_reset_session()
         executor = f()
         ts = time.time()
         for _ in range(n):
             executor()
         te = time.time()
         print('func:%r %d times took: %2.4f sec' % (f.__name__, n, te - ts))
     except Exception as e:
         # Best-effort benchmark: report the failure and keep going.
         print(e)
     finally:
         # Always leave the temporary graph, even when the benchmark failed;
         # otherwise later runs would nest inside a leaked graph context.
         if graph_ctx is not None:
             graph_ctx.__exit__(None, None, None)
예제 #2
0
 def wrap(f):
     """Benchmark ``f``: build its executor once, then time ``n`` calls of it.

     ``f`` is called with no arguments and must return a zero-argument
     callable (the executor).  ``n`` is not defined here; it is read from the
     enclosing scope (presumably the decorator factory's argument -- confirm).
     The elapsed wall-clock time for the ``n`` calls is printed.

     Under the TensorFlow backend the whole run happens inside a fresh default
     graph with a freshly reset session, so benchmarks do not share graph
     state.  Any exception is printed rather than raised.
     """
     graph_ctx = None
     try:
         if is_tf():
             ctx = tf.Graph().as_default()
             ctx.__enter__()
             # Only remember the context once it is actually entered, so the
             # cleanup below never exits a context that was never entered.
             graph_ctx = ctx
             tf_reset_session()
         executor = f()
         ts = time.time()
         for _ in range(n):
             executor()
         te = time.time()
         print('func:%r %d times took: %2.4f sec' % (f.__name__, n, te - ts))
     except Exception as e:
         # Best-effort benchmark: report the failure and keep going.
         print(e)
     finally:
         # Always leave the temporary graph, even when the benchmark failed;
         # otherwise later runs would nest inside a leaked graph context.
         if graph_ctx is not None:
             graph_ctx.__exit__(None, None, None)
예제 #3
0
def is_variable(x):
    """Return whether ``x`` is a symbolic variable/node of the active backend.

    Theano: ``theano.gof.Variable``; CGT: ``cgt.core.Node``;
    TensorFlow: ``tf.Tensor`` or ``tf.Variable``.

    Raises:
        NotImplementedError: if none of the supported backends is active.
    """
    if is_theano():
        return isinstance(x, theano.gof.Variable)
    elif is_cgt():
        return isinstance(x, cgt.core.Node)
    elif is_tf():
        return isinstance(x, (tf.Tensor, tf.Variable))
    else:
        # Previously a leftover ipdb breakpoint, which hangs non-interactive
        # runs and then silently returned None.
        raise NotImplementedError(
            'is_variable is not implemented for the current backend')
예제 #4
0
def is_variable(x):
    """Return whether ``x`` is a symbolic variable/node of the active backend.

    Theano: ``theano.gof.Variable``; CGT: ``cgt.core.Node``;
    TensorFlow: ``tf.Tensor`` or ``tf.Variable``.

    Raises:
        NotImplementedError: if none of the supported backends is active.
    """
    if is_theano():
        return isinstance(x, theano.gof.Variable)
    elif is_cgt():
        return isinstance(x, cgt.core.Node)
    elif is_tf():
        return isinstance(x, (tf.Tensor, tf.Variable))
    else:
        # Previously a leftover ipdb breakpoint, which hangs non-interactive
        # runs and then silently returned None.
        raise NotImplementedError(
            'is_variable is not implemented for the current backend')
예제 #5
0
def get_inputs(outputs):
    """Return the input nodes of the computation graph that produces ``outputs``.

    Theano delegates to ``theano.gof.graph.inputs``.  CGT and TensorFlow
    walk the graph in topologically sorted order and keep the nodes that are
    inputs, so the result follows that order.

    Raises:
        NotImplementedError: if none of the supported backends is active.
    """
    if is_theano():
        return theano.gof.graph.inputs(outputs)
    elif is_cgt():
        # Materialize in case ``outputs`` is a one-shot iterable.
        outputs = list(outputs)
        return [node for node in cgt.core.topsorted(outputs) if node.is_input()]
    elif is_tf():
        outputs = list(outputs)
        return [node for node in _tf_topsorted(outputs) if _tf_is_input(node)]
    else:
        # Previously a leftover ipdb breakpoint, which hangs non-interactive
        # runs and then silently returned None.
        raise NotImplementedError(
            'get_inputs is not implemented for the current backend')
예제 #6
0
def set_value(x, val):
    """
    Set the value of the shared variable (parameter) ``x`` to ``val``.

    Raises:
        NotImplementedError: if none of the supported backends is active.
    """
    if is_theano():
        x.set_value(val)
    elif is_cgt():
        # For CGT the value is set through the node's op.
        x.op.set_value(val)
    elif is_tf():
        # NOTE(review): TensorFlow's public API has no `tf.get_session`; this
        # presumably should be the module's `tf_get_session()` helper -- confirm.
        # Each call also adds a new assign op to the graph.
        tf.get_session().run(tf.assign(x, val))
    else:
        # Previously a leftover ipdb breakpoint, which hangs non-interactive
        # runs and silently skipped the assignment.
        raise NotImplementedError(
            'set_value is not implemented for the current backend')
예제 #7
0
def shape(x):
    """Return the (symbolic) shape of ``x`` for the active backend.

    Theano and CGT return ``x.shape`` for any ``x``.  Under TensorFlow,
    ``x`` must be a ``tf.Tensor`` or ``tf.Variable``.

    Raises:
        TypeError: under TensorFlow, if ``x`` is not a ``tf.Tensor`` or
            ``tf.Variable``.
        NotImplementedError: if none of the supported backends is active.
    """
    if is_theano():
        return x.shape
    elif is_cgt():
        return x.shape
    elif is_tf():
        if isinstance(x, (tf.Tensor, tf.Variable)):
            return x.shape
        # Previously a leftover ipdb breakpoint that then returned None.
        raise TypeError(
            'shape() expects a tf.Tensor or tf.Variable, got %r' % type(x))
    else:
        raise NotImplementedError(
            'shape is not implemented for the current backend')
예제 #8
0
def set_value(x, val):
    """
    Set the value of the shared variable (parameter) ``x`` to ``val``.

    Raises:
        NotImplementedError: if none of the supported backends is active.
    """
    if is_theano():
        x.set_value(val)
    elif is_cgt():
        # For CGT the value is set through the node's op.
        x.op.set_value(val)
    elif is_tf():
        # NOTE(review): TensorFlow's public API has no `tf.get_session`; this
        # presumably should be the module's `tf_get_session()` helper -- confirm.
        # Each call also adds a new assign op to the graph.
        tf.get_session().run(tf.assign(x, val))
    else:
        # Previously a leftover ipdb breakpoint, which hangs non-interactive
        # runs and silently skipped the assignment.
        raise NotImplementedError(
            'set_value is not implemented for the current backend')
예제 #9
0
def shape(x):
    """Return the (symbolic) shape of ``x`` for the active backend.

    Theano and CGT return ``x.shape`` for any ``x``.  Under TensorFlow,
    ``x`` must be a ``tf.Tensor`` or ``tf.Variable``.

    Raises:
        TypeError: under TensorFlow, if ``x`` is not a ``tf.Tensor`` or
            ``tf.Variable``.
        NotImplementedError: if none of the supported backends is active.
    """
    if is_theano():
        return x.shape
    elif is_cgt():
        return x.shape
    elif is_tf():
        if isinstance(x, (tf.Tensor, tf.Variable)):
            return x.shape
        # Previously a leftover ipdb breakpoint that then returned None.
        raise TypeError(
            'shape() expects a tf.Tensor or tf.Variable, got %r' % type(x))
    else:
        raise NotImplementedError(
            'shape is not implemented for the current backend')
예제 #10
0
def get_inputs(outputs):
    """Return the input nodes of the computation graph that produces ``outputs``.

    Theano delegates to ``theano.gof.graph.inputs``.  CGT and TensorFlow
    walk the graph in topologically sorted order and keep the nodes that are
    inputs, so the result follows that order.

    Raises:
        NotImplementedError: if none of the supported backends is active.
    """
    if is_theano():
        return theano.gof.graph.inputs(outputs)
    elif is_cgt():
        # Materialize in case ``outputs`` is a one-shot iterable.
        outputs = list(outputs)
        return [
            node for node in cgt.core.topsorted(outputs) if node.is_input()
        ]
    elif is_tf():
        outputs = list(outputs)
        return [node for node in _tf_topsorted(outputs) if _tf_is_input(node)]
    else:
        # Previously a leftover ipdb breakpoint, which hangs non-interactive
        # runs and then silently returned None.
        raise NotImplementedError(
            'get_inputs is not implemented for the current backend')
예제 #11
0
def shape(x):
    """Return the (symbolic) shape of ``x`` for the active backend.

    Theano and CGT return ``x.shape`` for any ``x``.  Under TensorFlow,
    ``x`` must be a ``tf.Tensor`` or ``tf.Variable``.

    Raises:
        TypeError: under TensorFlow, if ``x`` is not a ``tf.Tensor`` or
            ``tf.Variable``.
        NotImplementedError: if none of the supported backends is active.
    """
    if is_theano():
        return x.shape
    elif is_cgt():
        return x.shape
    elif is_tf():
        if isinstance(x, (tf.Tensor, tf.Variable)):
            return x.shape
        # Previously a leftover ipdb breakpoint that then returned None.
        raise TypeError(
            'shape() expects a tf.Tensor or tf.Variable, got %r' % type(x))
    else:
        raise NotImplementedError(
            'shape is not implemented for the current backend')


if is_tf():
    # Module-wide TensorFlow session, created lazily by tf_get_session() and
    # replaced by tf_reset_session().
    _tf_session = None
    _tf_blank_vars = []

    def tf_get_session():
        """Return the shared session, creating it on first use."""
        global _tf_session
        if _tf_session is not None:
            return _tf_session
        _tf_session = tf.Session()
        return _tf_session

    def tf_reset_session():
        """Close the current shared session (if any) and install a new one."""
        global _tf_session
        previous = _tf_session
        if previous:
            previous.close()
        _tf_session = tf.Session()
    
예제 #12
0
from tensorfuse.config import is_theano, is_cgt, is_tf, is_mxnet
# Re-export the signal ops of whichever backend tensorfuse.config reports as
# active.  This runs once at import time and exactly one branch is taken.
if is_theano():
    from tensorfuse.backend.theano.tensor.signal import *
elif is_cgt():
    from tensorfuse.backend.cgt.tensor.signal import *
elif is_tf():
    from tensorfuse.backend.tensorflow.tensor.signal import *
elif is_mxnet():
    from tensorfuse.backend.mxnet.tensor.signal import *
else:
    # None of the backend predicates returned True; fail the import loudly.
    raise ValueError('Unknown backend')