def __named_apply_prep(self, ctx, name, fn, attrs, do_apply):
    if attrs and '_ellipsis' in attrs:
        raise Exception("Can't use attribute ellipsis in function apply")

    fn = unwrap_bag(fn)
    base_scope_name = None
    if hasattr(fn, '_name'):
        base_scope_name = fn._name()
    if not base_scope_name:
        base_scope_name = 'fnc'

    # Derive a unique, per-application scope name from the function's own name.
    g = tf.get_default_graph()
    scope_name = "q___%s" % g.unique_name(base_scope_name, False).split("/")[-1]

    # If the function only knows how to apply attributes, apply them now and unwrap
    # the result before doing the actual apply.
    if not hasattr(fn, 'apply') and hasattr(fn, 'apply_attrs'):
        fn = fn.apply_attrs(self, attrs)
        fn = unwrap_bag(fn)
        attrs = None

    result = do_apply(fn, scope_name)
    ctx.possible_leaf(result)

    if name is not None:
        result = ctx.define_local(name, result)
        # HACK(adamb) The tf.identity call below just demands that the result is a Tensor.
        if isinstance(result, RetvalBag) and result.len() == 1:
            result = result.get(None)
        if isinstance(result, tf.Tensor):
            tf.identity(result, name=name)

    return result
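# A sketch of how __named_apply_prep is expected to be driven (the enclosing method
# shown here is hypothetical; only keyword_apply / positonal_apply style callbacks
# appear in this file):
#
#   def _named_apply_keywords(self, ctx, name, fn, attrs, kwargs):
#       def keyword_apply(unwrapped_fn, scope_name):
#           ...  # see keyword_apply below
#       return self.__named_apply_prep(ctx, name, fn, attrs, keyword_apply)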
def keyword_apply(unwrapped_fn, scope_name):
    # Note: kwargs, ctx, name, and attrs are free variables bound in the enclosing scope.
    unwrapped_kwargs = {}
    for key, value in kwargs.items():
        value = unwrap_bag(value)
        ctx.eliminate_leaf(value)
        unwrapped_kwargs[key] = value
    return unwrapped_fn.apply_kw(self, ctx, name, attrs, unwrapped_kwargs)
def _maybe_export_function(self, package_name, subctx, name, value):
    if not name[0].isupper():
        eprint("not capitalized", name)
        return

    value = unwrap_bag(value)
    eprint("considering", name)
    if not isinstance(value, graph_function.DeclaredFunction):
        eprint("isn't a declared function", type(value))
        return

    fn = value
    if fn.has_attrs():
        eprint("has attributes, skipping.")
        return

    g = tf.get_default_graph()
    var_collection_name = "%s:variable_names" % (g.unique_name(name, False))
    var_set = set()
    def on_var(var):
        var_set.add(var.name)
    self.add_variable_listener(on_var)

    eprint("exporting", name, fn)
    with tf.variable_scope(name):
        with tf.variable_scope("inputs"):
            args = [tf.placeholder(arg_dtype, arg_shape, arg_name)
                    for (arg_name, arg_shape, arg_dtype) in fn._arg_specs()]

        subctx2 = subctx.subcontext()
        fn.apply(self, subctx2, "_", None, args)

        with tf.variable_scope("outputs"):
            g = tf.get_default_graph()
            for (retval_name, retval_inner_name) in fn._retval_specs():
                tensor_prefix = "%s/%s" % (package_name, name)
                try:
                    returned_tensor = g.get_tensor_by_name(
                        "%s/_/%s:0" % (tensor_prefix, retval_inner_name))
                except KeyError as ke:
                    eprint("repeating lookup of %s in prefix %s for retval %s" % (
                        retval_inner_name, tensor_prefix, retval_name))
                    # If we fail to find the tensor above, perhaps it was just an input.
                    try:
                        returned_tensor = g.get_tensor_by_name(
                            "%s/inputs/%s:0" % (tensor_prefix, retval_inner_name))
                    except KeyError:
                        nodes = [n.name for n in tf.get_default_graph().as_graph_def().node]
                        nodes.sort()
                        eprint('error, but got nodes', nodes)
                        raise ke
                tf.identity(returned_tensor, name=retval_name)

    for var_name in var_set:
        g.add_to_collection(var_collection_name, var_name)

    self.remove_variable_listener(on_var)
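# Rough summary of the graph layout _maybe_export_function produces (derived from the
# code above; exact prefixes depend on the enclosing package scope):
#   <package_name>/<Name>/inputs/<arg_name>       placeholders for each declared argument
#   <package_name>/<Name>/_/...                   the body of the applied function
#   <package_name>/<Name>/outputs/<retval_name>   identity aliases for each return value
# Names of variables created while applying the function are recorded in the graph
# collection "%s:variable_names" % g.unique_name(name, False).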
def positonal_apply(unwrapped_fn, scope_name):
    # Note: args, ctx, and attrs are free variables bound in the enclosing scope.
    unwrapped_args = []
    for arg in args:
        arg = unwrap_bag(arg)
        ctx.eliminate_leaf(arg)
        unwrapped_args.append(arg)

    if hasattr(unwrapped_fn, 'apply'):
        return unwrapped_fn.apply(self, ctx, scope_name, attrs, unwrapped_args)
    else:
        raise Exception("Can't apply non-function %s with unwrapped args %s" % (
            unwrapped_fn, unwrapped_args))
def _sf_while_inner(use_device, visitor_class, ctx, exprs):
    with tf.Graph().as_default() as g:
        with tf.device(use_device):
            visitor = visitor_class()
            final_tensor = visitor._visit_exprs(ctx, exprs)

            # We do not want to include shapes, since inferred shapes will cause problems
            # for shape inference upon import and re-export.

            # HACK(adamb) Since TensorFlow uses __del__ to clean up py_funcs, we need to copy them.
            cleanup_py_funcs_used_in_graph = []
            if hasattr(g, "_cleanup_py_funcs_used_in_graph"):
                cleanup_py_funcs_used_in_graph = g._cleanup_py_funcs_used_in_graph[:]

            return (tf.train.export_meta_graph(),
                    cleanup_py_funcs_used_in_graph,
                    unwrap_bag(final_tensor).name)
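# _sf_while_inner traces the cond/body exprs in a throwaway tf.Graph and hands back
# (meta_graph_def, cleanup_py_funcs, final_tensor_name). A minimal sketch of how such
# a MetaGraphDef can be spliced back into the current graph (an assumption about what
# the _sf_while_embed helper, not shown here, does with it):
#
#   tf.train.import_meta_graph(
#       meta_graph_def,
#       import_scope=import_scope,
#       input_map=input_map)  # maps proxy placeholder names to outer loop-var tensors
#   out = tf.get_default_graph().get_tensor_by_name(
#       "%s/%s" % (import_scope, retval_name))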
def _sf_while_loop(visitor, ctx, cond_expr, body_exprs, body_retvals, init_exprs):
    # Need to evaluate body_exprs first, looking for all variables that will be created
    # internally. Roll up into nested variable contexts. Unroll these contexts to be
    # passed as part of var_list. Within def body(*a), repackage these variables into
    # variable contexts. Use these contexts *instead of* creating variables directly.

    # So we want to be able to indicate whether or not contexts should be allowed to
    # create variables on the fly. Within a while cond/body, they should not. Otherwise,
    # they can (e.g. when compiling a graph { ... } expression).

    # Track these so we can eventually remove them.
    proxy_cruft = set()
    proxied_placeholder_names = OrderedDict()
    proxied_placeholders = OrderedDict()

    # Track replacements, mapping each external entity's name to its internal proxy.
    def proxy(v):
        # eprint("proxy", v)
        if isinstance(v, RetvalBag):
            if v.graph is None:
                return v
            if v.graph == tf.get_default_graph():
                return v
            return v.wrap(proxy)

        if not isinstance(v, (tf.Operation, tf.Tensor, tf.Variable)):
            return v
        if v.graph == tf.get_default_graph():
            return v
        if v.name in proxied_placeholders:
            return proxied_placeholders[v.name]

        if v.name in proxied_placeholder_names:
            placeholder_name = proxied_placeholder_names[v.name]
        else:
            placeholder_name = "Proxy_%d" % len(proxied_placeholder_names)

        p = None
        with tf.name_scope(None):
            with tf.control_dependencies(None):
                p_name = None
                if isinstance(v, tf.Tensor) and v.dtype._is_ref_dtype:
                    p = tf.Variable(
                        initial_value=zero_value_for_dtype(v.dtype),
                        trainable=False,
                        collections=[],
                        name=placeholder_name,
                        dtype=v.dtype.base_dtype,
                        validate_shape=False)
                    p.set_shape(v.get_shape())
                    p_name = "%s" % p.op.name
                    proxy_cruft.add(p_name)
                    proxy_cruft.add("%s/read" % p.op.name)
                    proxy_cruft.add("%s/Assign" % p.op.name)
                    proxy_cruft.add("%s/initial_value" % p.op.name)
                elif isinstance(v, tf.Variable):
                    p = tf.Variable(
                        initial_value=zero_value_for_dtype(v.dtype),
                        trainable=False,
                        collections=[],
                        name=placeholder_name,
                        dtype=v.dtype.base_dtype,
                        validate_shape=False)
                    p.set_shape(v.get_shape())
                    p_name = "%s:0" % p.op.name
                    p = tf.get_default_graph().get_tensor_by_name(p_name)
                    v = v.graph.get_tensor_by_name("%s:0" % v.op.name)
                    proxy_cruft.add(p_name)
                    proxy_cruft.add("%s/read" % p.op.name)
                    proxy_cruft.add("%s/Assign" % p.op.name)
                    proxy_cruft.add("%s/initial_value" % p.op.name)
                else:
                    p = tf.placeholder(v.dtype, shape=v.get_shape(), name=placeholder_name)
                    p_name = p.op.name
                    proxy_cruft.add(p_name)

        proxied_placeholders[v.name] = p
        proxied_placeholder_names[v.name] = p_name
        if placeholder_name and placeholder_name != p.op.name:
            raise Exception(
                "Created placeholder with unexpected name: %s vs %s" % (placeholder_name, p.op.name))

        return p

    g = tf.get_default_graph()
    while_loop_name = g.unique_name("while", False)

    # eprint('init_exprs', init_exprs)
    initial_value_ctx = ctx.subcontext()
    initial_value_ctx._proxy = proxy

    initial_tensor_list = None
    initial_local_names = None
    with tf.variable_scope('%s_init' % while_loop_name):
        initial_tensor_list = [
            unwrap_bag(visitor.visit(initial_value_ctx, expr))
            for expr in init_exprs
        ]
        initial_local_names = [define[1] for define in init_exprs]
        local_name_by_tensor_name = dict(
            zip([t.name for t in initial_tensor_list], initial_local_names))

    device_stack = g._device_function_stack
    use_device = None
    if len(device_stack) > 0:
        use_device = device_stack[-1]
    eprint("Will use device", use_device)

    # Ensure we have a placeholder for every initial value.
    with tf.Graph().as_default():
        with tf.device(use_device):
            for local_name in initial_local_names:
                initial_value_ctx.get_local(local_name)

    # Don't let cached placeholders from init_exprs infect our graph.
    proxied_placeholders = OrderedDict()
    cond_ctx = initial_value_ctx.subcontext()
    cond_meta_graph_def, cond_cleanup_funcs, cond_retval_name = _sf_while_inner(
        use_device, type(visitor), cond_ctx, [cond_expr])

    # Don't let cached placeholders from cond_expr infect our graph.
    proxied_placeholders = OrderedDict()
    body_ctx = initial_value_ctx.subcontext()
    body_meta_graph_def, body_cleanup_funcs, _ = _sf_while_inner(
        use_device, type(visitor), body_ctx, body_exprs)

    # HACK(adamb) Don't actually import any nodes that are only proxies.
    # This should probably be done automatically by the TF import
    # logic, but empirically this is not the case.
    _while_prune(cond_meta_graph_def, proxy_cruft)
    _while_fix_colocations(cond_meta_graph_def, proxy_cruft)
    _while_prune(body_meta_graph_def, proxy_cruft)
    _while_fix_colocations(body_meta_graph_def, proxy_cruft)

    body_retval_dict = dict(body_retvals)
    body_retval_names = []
    next_value_ixs = []
    loop_vars = [
        g.get_tensor_by_name(v_name)
        for v_name in proxied_placeholder_names.keys()
    ]

    ix = -1
    for t in loop_vars:
        ix += 1
        # If it's in initial_tensor_list, then look up its init local name.
        # If we have a retval for this init local name, then use the inner retval.
        # Otherwise pass through.
        if t.name in local_name_by_tensor_name:
            local_name = local_name_by_tensor_name[t.name]
            if local_name in body_retval_dict:
                # eprint("while next vals", ix, t.get_shape(), t.name, local_name, body_retval_dict[local_name])
                body_retval_names.append("%s:0" % body_retval_dict[local_name])
                next_value_ixs.append(ix)
            else:
                # eprint("while next vals skipped", ix, local_name)
                pass
        else:
            # eprint("while next vals t.name", ix, t.name)
            pass

    # eprint("while initial_local_names", initial_local_names)
    # eprint("while initial_tensor_list", initial_tensor_list)
    # eprint("while proxied_placeholder_names", proxied_placeholder_names)
    # eprint("while local_name_by_tensor_name", local_name_by_tensor_name)

    def cond(*a):
        # We use a variable_scope because name_scope has a strange
        # only-sometimes-present trailing / that messes with everything.
        cond_import_scope = '%s_cond' % while_loop_name
        _while_fix_context_scope(cond_meta_graph_def, cond_import_scope)
        return _sf_while_embed(
            cond_import_scope,
            dict(zip(proxied_placeholder_names.values(), a)),
            [cond_retval_name],
            cond_meta_graph_def,
            cond_cleanup_funcs)[0]

    def body(*a):
        body_input_map = dict(zip(proxied_placeholder_names.values(), a))
        # eprint("while body", body_input_map)

        # We use a variable_scope because name_scope has a strange
        # only-sometimes-present trailing / that messes with everything.
        body_import_scope = '%s_body' % while_loop_name
        _while_fix_context_scope(body_meta_graph_def, body_import_scope)
        next_values = _sf_while_embed(
            body_import_scope,
            body_input_map,
            body_retval_names,
            body_meta_graph_def,
            body_cleanup_funcs)

        body_results = list(a)
        for ix, val in zip(next_value_ixs, next_values):
            val.set_shape(a[ix].get_shape())
            # eprint('while shape', ix, a[ix], a[ix].get_shape(), val, val.get_shape())
            # val.set_shape(val.get_shape())
            body_results[ix] = val
        # eprint('while body_results', body_results)
        return body_results

    # If we're referencing variables, we need to alert listeners.
    for v in loop_vars:
        visitor._visit_result(v)

    results = tf.while_loop(
        cond=cond,
        body=body,
        loop_vars=loop_vars,
        parallel_iterations=1,
        back_prop=False,
        name=while_loop_name.split("/")[-1],
    )
    if type(results) != list:
        results = [results]

    r = {}
    for k_name, v in zip(proxied_placeholder_names.keys(), results):
        if k_name in local_name_by_tensor_name:
            r[local_name_by_tensor_name[k_name]] = v
    return RetvalBag(r)
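# End-to-end flow of _sf_while_loop, in brief:
#   1. Visit init_exprs in a "<while>_init" scope to produce the initial loop values.
#   2. Trace cond_expr and body_exprs in throwaway graphs (_sf_while_inner); any
#      reference back to the outer graph is rewritten by proxy() into a Proxy_N
#      placeholder/variable and remembered in proxied_placeholder_names.
#   3. Prune the proxy-only nodes from both MetaGraphDefs, then re-embed them inside
#      tf.while_loop's cond/body, feeding the loop variables through the proxies.
#   4. Return a RetvalBag mapping each init-expr local name to the corresponding
#      loop output tensor.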
def apply_attrs(self, ctx, function, attrs):
    return unwrap_bag(function).apply_attrs(self, attrs)
def list(self, ctx, *entries):
    return [unwrap_bag(e) for e in entries]