Example #1
  def __named_apply_prep(self, ctx, name, fn, attrs, do_apply):
    if attrs and '_ellipsis' in attrs:
      raise Exception("Can't use attribute ellipsis in function apply")

    fn = unwrap_bag(fn)

    base_scope_name = None
    if hasattr(fn, '_name'):
      base_scope_name = fn._name()

    if not base_scope_name:
      base_scope_name = 'fnc'

    g = tf.get_default_graph()
    scope_name = "q___%s" % g.unique_name(base_scope_name, False).split("/")[-1]

    if not hasattr(fn, 'apply') and hasattr(fn, 'apply_attrs'):
      fn = fn.apply_attrs(self, attrs)
      fn = unwrap_bag(fn)
      attrs = None

    result = do_apply(fn, scope_name)

    ctx.possible_leaf(result)

    if name is not None:
      result = ctx.define_local(name, result)

      # HACK(adamb) The tf.identity call below just demands that the result is a Tensor.
      if isinstance(result, RetvalBag) and result.len() == 1:
        result = result.get(None)
      if isinstance(result, tf.Tensor):
        tf.identity(result, name=name)

    return result
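Nearly every example on this page funnels values through unwrap_bag before using them. A minimal sketch of its likely behavior, inferred purely from the call sites here (an assumption, not the project's actual implementation):

  # Hypothetical sketch, inferred from call sites: a single-value RetvalBag
  # collapses to its one value; anything else passes through unchanged.
  def unwrap_bag(value):
    if isinstance(value, RetvalBag) and value.len() == 1:
      return value.get(None)
    return value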
Example #2
    def keyword_apply(unwrapped_fn, scope_name):
      unwrapped_kwargs = {}
      for key, value in kwargs.items():
        value = unwrap_bag(value)
        ctx.eliminate_leaf(value)
        unwrapped_kwargs[key] = value

      return unwrapped_fn.apply_kw(self, ctx, name, attrs, unwrapped_kwargs)
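keyword_apply matches the do_apply callback signature that __named_apply_prep (Example #1) invokes as do_apply(fn, scope_name). A hypothetical wiring inside the same class (the enclosing method name is an assumption for illustration):

    # Hypothetical enclosing method: keyword_apply closes over ctx, name,
    # attrs, and kwargs, then is handed to __named_apply_prep as do_apply.
    def _apply_keyword_args(self, ctx, name, fn, attrs, kwargs):
      def keyword_apply(unwrapped_fn, scope_name):
        ...  # body as in Example #2
      return self.__named_apply_prep(ctx, name, fn, attrs, keyword_apply)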
Example #3
  def _maybe_export_function(self, package_name, subctx, name, value):
    if not name[0].isupper():
      eprint("not capitalized", name)
      return

    value = unwrap_bag(value)
    eprint("considering", name)

    if not isinstance(value, graph_function.DeclaredFunction):
      eprint("isn't a declared function", type(value))
      return

    fn = value

    if fn.has_attrs():
      eprint("has attributes, skipping.")
      return

    g = tf.get_default_graph()
    var_collection_name = "%s:variable_names" % (g.unique_name(name, False))
    var_set = set()
    def on_var(var):
      var_set.add(var.name)
    self.add_variable_listener(on_var)

    eprint("exporting", name, fn)
    with tf.variable_scope(name):
      with tf.variable_scope("inputs"):
        args = [tf.placeholder(arg_dtype, arg_shape, arg_name) for (arg_name, arg_shape, arg_dtype) in fn._arg_specs()]

      subctx2 = subctx.subcontext()
      fn.apply(self, subctx2, "_", None, args)

      with tf.variable_scope("outputs"):
        g = tf.get_default_graph()
        for (retval_name, retval_inner_name) in fn._retval_specs():
          tensor_prefix = "%s/%s" % (package_name, name)
          try:
            returned_tensor = g.get_tensor_by_name("%s/_/%s:0" % (tensor_prefix, retval_inner_name))
          except KeyError as ke:
            eprint("repeating lookup of %s in prefix %s for retval %s" % (retval_inner_name, tensor_prefix, retval_name))
            # If we fail to find the tensor above, perhaps it was just an input.
            try:
              returned_tensor = g.get_tensor_by_name("%s/inputs/%s:0" % (tensor_prefix, retval_inner_name))
            except KeyError:
              nodes = [n.name for n in tf.get_default_graph().as_graph_def().node]
              nodes.sort()
              eprint('error, but got nodes', nodes)
              raise ke

          tf.identity(returned_tensor, name=retval_name)

    for var_name in var_set:
      g.add_to_collection(var_collection_name, var_name)

    self.remove_variable_listener(on_var)
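The add_variable_listener / remove_variable_listener pair implies a simple observer registry on the visitor, which is how var_set gets populated above. A minimal sketch of that registry (an assumption; the project's actual class may differ):

  # Hypothetical registry: each registered listener is called once per
  # variable created while it is registered.
  class VariableListenerMixin:
    def __init__(self):
      self._variable_listeners = []

    def add_variable_listener(self, fn):
      self._variable_listeners.append(fn)

    def remove_variable_listener(self, fn):
      self._variable_listeners.remove(fn)

    def _notify_variable_created(self, var):  # hypothetical hook name
      for fn in self._variable_listeners:
        fn(var)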
Example #4
  def positional_apply(unwrapped_fn, scope_name):
    unwrapped_args = []
    for arg in args:
      arg = unwrap_bag(arg)
      ctx.eliminate_leaf(arg)
      unwrapped_args.append(arg)
    if hasattr(unwrapped_fn, 'apply'):
      return unwrapped_fn.apply(self, ctx, scope_name, attrs, unwrapped_args)
    else:
      raise Exception("Can't apply non-function %s with unwrapped args %s" % (unwrapped_fn, unwrapped_args))
Example #5
def _sf_while_inner(use_device, visitor_class, ctx, exprs):
    with tf.Graph().as_default() as g:
        with tf.device(use_device):
            visitor = visitor_class()
            final_tensor = visitor._visit_exprs(ctx, exprs)

            # We do not want to include shapes, since inferred shapes will cause problems
            # for shape inference upon import and re-export.

            # HACK(adamb) Since TensorFlow uses __del__ to clean up py_funcs, we need to copy them.
            cleanup_py_funcs_used_in_graph = []
            if hasattr(g, "_cleanup_py_funcs_used_in_graph"):
                cleanup_py_funcs_used_in_graph = g._cleanup_py_funcs_used_in_graph[:]

            return (tf.train.export_meta_graph(),
                    cleanup_py_funcs_used_in_graph,
                    unwrap_bag(final_tensor).name)
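Example #6 below consumes this triple directly; the serialized meta graph is later re-imported inside the while loop's cond and body closures:

    # Verbatim from Example #6: the triple unpacks into the serialized graph,
    # the py_func cleanup list, and the name of the tensor to fetch.
    cond_meta_graph_def, cond_cleanup_funcs, cond_retval_name = _sf_while_inner(
        use_device, type(visitor), cond_ctx, [cond_expr])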
Example #6
def _sf_while_loop(visitor, ctx, cond_expr, body_exprs, body_retvals,
                   init_exprs):
    # Need to evaluate body_exprs first, looking for all variables that will be created
    # internally. Roll up into nested variable contexts. Unroll these contexts to be
    # passed as part of var_list. Within def body(*a), repackage these variables into
    # variable contexts. Use these contexts *instead of* creating variables directly.
    # So we want to be able to indicate whether or not contexts should be allowed to
    # create variables on the fly. Within a while cond/body, they should not. Otherwise,
    # they can (e.g. when compiling a graph { ... } expression).

    # track these so we can eventually remove them.
    proxy_cruft = set()
    proxied_placeholder_names = OrderedDict()
    proxied_placeholders = OrderedDict()

    # Track replacements, mapping each external entity to its internal name.
    def proxy(v):
        # eprint("proxy", v)

        if isinstance(v, RetvalBag):
            if v.graph is None:
                return v

            if v.graph == tf.get_default_graph():
                return v

            return v.wrap(proxy)

        if not isinstance(v, (tf.Operation, tf.Tensor, tf.Variable)):
            return v

        if v.graph == tf.get_default_graph():
            return v

        if v.name in proxied_placeholders:
            return proxied_placeholders[v.name]

        if v.name in proxied_placeholder_names:
            placeholder_name = proxied_placeholder_names[v.name]
        else:
            placeholder_name = "Proxy_%d" % len(proxied_placeholder_names)

        p = None
        with tf.name_scope(None):
            with tf.control_dependencies(None):
                p_name = None
                if isinstance(v, tf.Tensor) and v.dtype._is_ref_dtype:
                    p = tf.Variable(initial_value=zero_value_for_dtype(
                        v.dtype),
                                    trainable=False,
                                    collections=[],
                                    name=placeholder_name,
                                    dtype=v.dtype.base_dtype,
                                    validate_shape=False)
                    p.set_shape(v.get_shape())
                    p_name = "%s" % p.op.name
                    proxy_cruft.add(p_name)
                    proxy_cruft.add("%s/read" % p.op.name)
                    proxy_cruft.add("%s/Assign" % p.op.name)
                    proxy_cruft.add("%s/initial_value" % p.op.name)
                elif isinstance(v, tf.Variable):
                    p = tf.Variable(initial_value=zero_value_for_dtype(
                        v.dtype),
                                    trainable=False,
                                    collections=[],
                                    name=placeholder_name,
                                    dtype=v.dtype.base_dtype,
                                    validate_shape=False)
                    p.set_shape(v.get_shape())
                    p_name = "%s:0" % p.op.name
                    p = tf.get_default_graph().get_tensor_by_name(p_name)
                    v = v.graph.get_tensor_by_name("%s:0" % v.op.name)
                    proxy_cruft.add(p_name)
                    proxy_cruft.add("%s/read" % p.op.name)
                    proxy_cruft.add("%s/Assign" % p.op.name)
                    proxy_cruft.add("%s/initial_value" % p.op.name)
                else:
                    p = tf.placeholder(v.dtype,
                                       shape=v.get_shape(),
                                       name=placeholder_name)
                    p_name = p.op.name
                    proxy_cruft.add(p_name)

        proxied_placeholders[v.name] = p
        proxied_placeholder_names[v.name] = p_name

        if placeholder_name and placeholder_name != p.op.name:
            raise Exception(
                "Created placeholder with unexpected name: %s vs %s" %
                (placeholder_name, p.op.name))

        return p

    g = tf.get_default_graph()
    while_loop_name = g.unique_name("while", False)

    # eprint('init_exprs', init_exprs)
    initial_value_ctx = ctx.subcontext()
    initial_value_ctx._proxy = proxy
    initial_tensor_list = None
    initial_local_names = None
    with tf.variable_scope('%s_init' % while_loop_name):
        initial_tensor_list = [
            unwrap_bag(visitor.visit(initial_value_ctx, expr))
            for expr in init_exprs
        ]
        initial_local_names = [define[1] for define in init_exprs]

    local_name_by_tensor_name = dict(
        zip([t.name for t in initial_tensor_list], initial_local_names))

    device_stack = g._device_function_stack
    use_device = None
    if len(device_stack) > 0:
        use_device = device_stack[-1]

    eprint("Will use device", use_device)

    # Ensure we have a placeholder for every initial value.
    with tf.Graph().as_default():
        with tf.device(use_device):
            for local_name in initial_local_names:
                initial_value_ctx.get_local(local_name)

    # Don't let cached placeholders from init_exprs infect our graph.
    proxied_placeholders = OrderedDict()
    cond_ctx = initial_value_ctx.subcontext()
    cond_meta_graph_def, cond_cleanup_funcs, cond_retval_name = _sf_while_inner(
        use_device, type(visitor), cond_ctx, [cond_expr])

    # Don't let cached placeholders from cond_exprs infect our graph.
    proxied_placeholders = OrderedDict()
    body_ctx = initial_value_ctx.subcontext()
    body_meta_graph_def, body_cleanup_funcs, _ = _sf_while_inner(
        use_device, type(visitor), body_ctx, body_exprs)

    # HACK(adamb) Don't actually import any nodes that are only proxies.
    #     This should probably be done automatically by the TF import
    #     logic, but empirically this is not the case.
    _while_prune(cond_meta_graph_def, proxy_cruft)
    _while_fix_colocations(cond_meta_graph_def, proxy_cruft)

    _while_prune(body_meta_graph_def, proxy_cruft)
    _while_fix_colocations(body_meta_graph_def, proxy_cruft)

    body_retval_dict = dict(body_retvals)
    body_retval_names = []
    next_value_ixs = []

    loop_vars = [
        g.get_tensor_by_name(v_name)
        for v_name in proxied_placeholder_names.keys()
    ]

    for ix, t in enumerate(loop_vars):
        # if it's in initial_tensor_list, then look up its init_local_name
        # if we have a retval for this init_local_name, then use the inner_retval
        # otherwise pass through.
        if t.name in local_name_by_tensor_name:
            local_name = local_name_by_tensor_name[t.name]
            if local_name in body_retval_dict:
                # eprint("while next vals", ix, t.get_shape(), t.name, local_name, body_retval_dict[local_name])
                body_retval_names.append("%s:0" % body_retval_dict[local_name])
                next_value_ixs.append(ix)
            else:
                # eprint("while next vals skipped", ix, local_name)
                pass
        else:
            # eprint("while next vals t.name", ix, t.name)
            pass

    # eprint("while initial_local_names", initial_local_names)
    # eprint("while initial_tensor_list", initial_tensor_list)
    # eprint("while proxied_placeholder_names", proxied_placeholder_names)
    # eprint("while local_name_by_tensor_name", local_name_by_tensor_name)

    def cond(*a):
        # We use a variable_scope because name_scope has a strange
        # only-sometimes-present trailing / that messes with everything.
        cond_import_scope = '%s_cond' % while_loop_name

        _while_fix_context_scope(cond_meta_graph_def, cond_import_scope)

        return _sf_while_embed(
            cond_import_scope, dict(zip(proxied_placeholder_names.values(),
                                        a)), [cond_retval_name],
            cond_meta_graph_def, cond_cleanup_funcs)[0]

    def body(*a):
        body_input_map = dict(zip(proxied_placeholder_names.values(), a))
        # eprint("while body", body_input_map)

        # We use a variable_scope because name_scope has a strange
        # only-sometimes-present trailing / that messes with everything.
        body_import_scope = '%s_body' % while_loop_name

        _while_fix_context_scope(body_meta_graph_def, body_import_scope)

        next_values = _sf_while_embed(body_import_scope, body_input_map,
                                      body_retval_names, body_meta_graph_def,
                                      body_cleanup_funcs)

        body_results = list(a)
        for ix, val in zip(next_value_ixs, next_values):
            val.set_shape(a[ix].get_shape())
            # eprint('while shape', ix, a[ix], a[ix].get_shape(), val, val.get_shape())
            # val.set_shape(val.get_shape())
            body_results[ix] = val

        # eprint('while body_results', body_results)
        return body_results

    # If we're referencing variables, we need to alert listeners.
    for v in loop_vars:
        visitor._visit_result(v)

    results = tf.while_loop(
        cond=cond,
        body=body,
        loop_vars=loop_vars,
        parallel_iterations=1,
        back_prop=False,
        name=while_loop_name.split("/")[-1],
    )

    if not isinstance(results, list):
        results = [results]

    r = {}
    for k_name, v in zip(proxied_placeholder_names.keys(), results):
        if k_name in local_name_by_tensor_name:
            r[local_name_by_tensor_name[k_name]] = v

    return RetvalBag(r)
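The proxy logic above leans on a zero_value_for_dtype helper that is not shown here. A minimal sketch consistent with its use as a dummy initial_value (an assumption, not the project's code):

    import tensorflow as tf

    # Hypothetical: a scalar zero of the base dtype. Since the call site
    # passes validate_shape=False, the shape need not match the proxied value.
    def zero_value_for_dtype(dtype):
        return tf.zeros([], dtype=dtype.base_dtype)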
Example #7
  def apply_attrs(self, ctx, function, attrs):
    return unwrap_bag(function).apply_attrs(self, attrs)
Example #8
  def list(self, ctx, *entries):
    return [unwrap_bag(e) for e in entries]
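Both of the last two one-liners are thin delegations through unwrap_bag. A hypothetical call (variable names assumed) makes the effect of list concrete:

  # Produces a plain Python list of unwrapped values.
  values = visitor.list(ctx, bag_a, bag_b)  # == [unwrap_bag(bag_a), unwrap_bag(bag_b)]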