def check_tensor_shape(tensor1, tensor2, name1=None, name2=None):
  """Raise ValueError unless two TF tensors have identical static shapes.

  Args:
    tensor1, tensor2: tensors to compare; each must be a `tf.Tensor`.
    name1, name2: labels used in the error message. When omitted they are
      obtained from `misc.retrieve_name` (presumably the caller's variable
      names -- confirm against misc).

  Raises:
    AssertionError: if a label is not a str or an argument is not a
      tf.Tensor (these checks vanish under `python -O`).
    ValueError: if `tensor1.shape.as_list()` != `tensor2.shape.as_list()`.
  """
  # Names are resolved in this frame rather than a helper, in case
  # retrieve_name depends on call-stack depth (NOTE(review): verify).
  if name1 is None:
    name1 = misc.retrieve_name(tensor1)
  if name2 is None:
    name2 = misc.retrieve_name(tensor2)
  assert isinstance(name1, str)
  assert isinstance(name2, str)
  assert all(isinstance(t, tf.Tensor) for t in (tensor1, tensor2))

  expected = tensor1.shape.as_list()
  actual = tensor2.shape.as_list()
  if expected == actual:
    return
  raise ValueError(
    '!! {}.shape({}) should be equal with {}.shape({}) '.format(
      name1, expected, name2, actual))
def check_callable(f, name=None, allow_none=True):
  """Return `f` unchanged if it is callable (or None, when allowed).

  Args:
    f: object to validate.
    name: label used in the error message. Only resolved (via
      `misc.retrieve_name(f)`) when validation fails.
    allow_none: whether `None` is accepted as a valid value for `f`.

  Returns:
    `f` itself when valid.

  Raises:
    TypeError: if `f` is None and `allow_none` is False, or if `f` is not
      None and not callable.
  """
  valid = allow_none if f is None else callable(f)
  if valid:
    return f
  # Resolve the display name only on the failure path, as
  # check_positive_integer does; valid calls never need it.
  if name is None:
    name = misc.retrieve_name(f)
  raise TypeError('!! {} must be callable'.format(name))
def eval_show(tensor, name=None, feed_dict=None):
  """Evaluate `tensor` in the default session, display it, and return it.

  Values with more than one dimension get a `show_status` header and are
  then printed with `pprint`; everything else is shown inline on one line.

  Args:
    tensor: anything accepted by `Session.run`.
    name: label to display; defaults to `misc.retrieve_name(tensor)`.
    feed_dict: optional feed dictionary forwarded to `Session.run`.

  Returns:
    The value returned by `Session.run`.
  """
  if name is None:
    name = misc.retrieve_name(tensor)
  # NOTE: tf.get_default_session() returns None when no session has been
  # installed as default; the `.run` call below would then fail.
  sess = tf.get_default_session()
  val = sess.run(tensor, feed_dict=feed_dict)
  # Session.run does not always return an ndarray (e.g. `bytes` for a scalar
  # string tensor, or a list when fetching a list), so treat values without
  # a `.shape` as scalars instead of crashing on `val.shape`.
  if len(getattr(val, 'shape', ())) > 1:
    show_status('{} = '.format(name))
    pprint(val)
  else:
    show_status('{} = {}'.format(name, val))
  return val
def check_positive_integer(x, allow_zero=False, name=None):
  """Return `x` if it is a positive (or non-negative) integer, else raise.

  Any `numbers.Integral` is accepted (e.g. `int` and NumPy integer scalars);
  `bool` is rejected even though it subclasses `int`.

  Args:
    x: value to validate.
    allow_zero: if True, 0 is accepted as well.
    name: label used in the error message; resolved via
      `misc.retrieve_name(x)` only when validation fails.

  Returns:
    `x` unchanged.

  Raises:
    ValueError: if `x` is not an integer, is negative, or is 0 while
      `allow_zero` is False.
  """
  import numbers

  is_integer = isinstance(x, numbers.Integral) and not isinstance(x, bool)
  if not is_integer or x < 0 or (x == 0 and not allow_zero):
    if name is None:
      name = misc.retrieve_name(x)
    # Say "non-negative" when 0 is allowed so the message matches the check.
    kind = 'non-negative' if allow_zero else 'positive'
    raise ValueError('!! {} must be a {} integer'.format(name, kind))
  return x