Esempio n. 1
0
def check_tensor_shape(tensor1, tensor2, name1=None, name2=None):
    """Ensure two TensorFlow tensors share the same static shape.

    Args:
        tensor1, tensor2: the `tf.Tensor` objects to compare.
        name1, name2: labels used in the error message; any label left as
            None is looked up with `misc.retrieve_name` on its tensor.

    Raises:
        AssertionError: if a label is not a str or an argument is not a
            `tf.Tensor` (assert-based, so these checks vanish under -O).
        ValueError: if `tensor1.shape.as_list()` differs from
            `tensor2.shape.as_list()`.
    """
    # Name lookups stay in this function's own body (no helper) so any
    # frame inspection done by misc.retrieve_name sees the same call stack.
    name1 = misc.retrieve_name(tensor1) if name1 is None else name1
    name2 = misc.retrieve_name(tensor2) if name2 is None else name2
    assert isinstance(name1, str) and isinstance(name2, str)
    assert isinstance(tensor1, tf.Tensor) and isinstance(tensor2, tf.Tensor)

    first_shape = tensor1.shape.as_list()
    second_shape = tensor2.shape.as_list()
    if first_shape == second_shape:
        return
    message = '!! {}.shape({}) should be equal with {}.shape({}) '.format(
        name1, first_shape, name2, second_shape)
    raise ValueError(message)
Esempio n. 2
0
def check_callable(f, name=None, allow_none=True):
    """Return `f` if it is callable (or None, when allowed); raise otherwise.

    Args:
        f: the object to validate.
        name: label used in the error message. When omitted it is resolved
            with `misc.retrieve_name(f)`, but only on the failure path, so a
            valid call never pays for (or depends on) the name lookup.
        allow_none: if True (default), `None` is accepted and returned as is.

    Returns:
        `f`, unchanged, so the check can be used inline.

    Raises:
        TypeError: if `f` is None and `allow_none` is False, or if `f` is
            not None and not callable.
    """
    ok = allow_none if f is None else callable(f)
    if ok:
        return f
    # The name is only needed for the error message, so resolve it lazily.
    if name is None: name = misc.retrieve_name(f)
    raise TypeError('!! {} must be callable'.format(name))
Esempio n. 3
0
def eval_show(tensor, name=None, feed_dict=None):
    """Evaluate `tensor` in the default TF session, display it, and return it.

    Multi-dimensional results are shown on their own lines via `pprint`;
    everything else is shown inline after the name.

    Args:
        tensor: a fetch accepted by `Session.run` (tensor, op, or structure).
        name: label shown before the value; resolved with
            `misc.retrieve_name(tensor)` when omitted.
        feed_dict: optional feed dictionary forwarded to `Session.run`.

    Returns:
        The evaluated value, exactly as returned by `Session.run`.

    Raises:
        ValueError: if no default session is registered.
    """
    if name is None: name = misc.retrieve_name(tensor)
    sess = tf.get_default_session()
    if sess is None:
        # Outside a `with sess.as_default():` block there is no default
        # session; fail clearly instead of with AttributeError on None.run.
        raise ValueError(
            '!! No default session is registered; evaluate {} inside a '
            '`with sess.as_default():` block'.format(name))
    val = sess.run(tensor, feed_dict=feed_dict)
    # Some fetch results have no `.shape` (e.g. bytes from string tensors,
    # lists/dicts from structured fetches); treat them as inline values.
    if len(getattr(val, 'shape', ())) > 1:
        show_status('{} = '.format(name))
        pprint(val)
    else:
        show_status('{} = {}'.format(name, val))
    return val
Esempio n. 4
0
def check_positive_integer(x, allow_zero=False, name=None):
    """Validate that `x` is a positive (or, optionally, non-negative) integer.

    Args:
        x: value to validate. Any `numbers.Integral` (e.g. `int` or NumPy
            integer scalars) is accepted; `bool` is rejected even though it
            subclasses `int`, since `True`/`False` are almost never meant
            as counts or sizes.
        allow_zero: if True, 0 is accepted as well.
        name: label used in the error message; resolved via
            `misc.retrieve_name(x)` only when validation fails.

    Returns:
        `x`, unchanged, so the check can be used inline.

    Raises:
        ValueError: if `x` is not an integer, is negative, or is zero while
            `allow_zero` is False.
    """
    import numbers
    is_int = isinstance(x, numbers.Integral) and not isinstance(x, bool)
    # `not is_int` short-circuits, so non-numbers are never compared with 0.
    if not is_int or x < 0 or (not allow_zero and x == 0):
        if name is None: name = misc.retrieve_name(x)
        kind = 'non-negative' if allow_zero else 'positive'
        raise ValueError('!! {} must be a {} integer'.format(name, kind))
    return x