def _build_call_outputs(self, func_outputs, result):
  """Maps the fdef output list to actual output structure.

  Args:
    func_outputs: The outputs originally defined by the graph function. It
      could potentially be a nested structure.
    result: Output lists defined by FunctionDef.
  Returns:
    The actual call output.
  """
  if self._func_outputs is None:
    return None
  if isinstance(ag_core.getval(self._func_outputs), ops.Tensor):
    return result[0]
  outputs = []
  for o in func_outputs:
    vo = ag_core.getval(o)
    if isinstance(vo, ops.Tensor):
      outputs.append(result[self._returns_to_fedf_outputs[id(vo)]])
    elif type(vo) in (tuple, list):
      outputs.append(self._build_call_outputs(o, result))
    else:
      outputs.append(o)
  return tuple(outputs) if type(func_outputs) is tuple else outputs
def variational_objective(params, t):
    # stochastic estimate of the variational lower bound
    w_mu, w_log_s2, s_pi, log_s2_w, pi_w, log_s2 = unpack_params(params)

    # compute the expectation of the "data fit" term and the entropy
    # by Monte Carlo sampling
    datafit = 0.
    entropy = 0.
    for _ in range(num_samples):
        # acquire M Bernoulli samples
        s = Bernoulli(pi=np.column_stack([1 - s_pi, s_pi]), T=0.5)[:, 1]

        # acquire M Gaussian distributed samples
        mean = s * w_mu
        var = s * np.exp(w_log_s2) + (1 - s) * np.exp(log_s2_w)
        w = mean + np.sqrt(var) * np.random.randn(M)

        # compute the log of the joint probability
        datafit = datafit \
            + logprob(s, w, log_s2_w, pi_w, log_s2, X, y, batch_size, t)

        # compute the entropy q(w,s)
        mean = getval(mean)
        var = getval(var)
        s_pi = getval(s_pi)
        entropy = entropy \
            + np.sum(0.5*np.log(2*np.pi*var) + 0.5/var*np.power(w - mean, 2)) \
            - np.sum(s*np.log(s_pi) + (1 - s)*np.log(1 - s_pi))

    datafit = datafit / num_samples
    entropy = entropy / num_samples

    # the lower bound to maximize
    lower_bound = datafit + entropy
    return -lower_bound
def _objfunc(self, params, t):
    samps = self.sample_var_approx(getval(params), n_samples=self.N_SAMPLES)
    return np.mean(
        self.log_var_approx(samps, params) *
        (self.log_prob(samps) - self.log_var_approx(samps, getval(params))))
def flatten(value): """value can be any nesting of tuples, arrays, dicts. returns 1D numpy array and an unflatten function.""" if isinstance(getval(value), np.ndarray): def unflatten(vector): return np.reshape(vector, value.shape) return np.ravel(value), unflatten elif isinstance(getval(value), float): return np.array([value]), lambda x : x[0] elif isinstance(getval(value), tuple): if not value: return np.array([]), lambda x : () flattened_first, unflatten_first = flatten(value[0]) flattened_rest, unflatten_rest = flatten(value[1:]) def unflatten(vector): N = len(flattened_first) return (unflatten_first(vector[:N]),) + unflatten_rest(vector[N:]) return np.concatenate((flattened_first, flattened_rest)), unflatten elif isinstance(getval(value), list): if not value: return np.array([]), lambda x : [] flattened_first, unflatten_first = flatten(value[0]) flattened_rest, unflatten_rest = flatten(value[1:]) def unflatten(vector): N = len(flattened_first) return [unflatten_first(vector[:N])] + unflatten_rest(vector[N:]) return np.concatenate((flattened_first, flattened_rest)), unflatten elif isinstance(getval(value), dict): flattened = [] unflatteners = [] lengths = [] keys = [] for k, v in sorted(iteritems(value), key=itemgetter(0)): cur_flattened, cur_unflatten = flatten(v) flattened.append(cur_flattened) unflatteners.append(cur_unflatten) lengths.append(len(cur_flattened)) keys.append(k) def unflatten(vector): split_ixs = np.cumsum(lengths) pieces = np.split(vector, split_ixs) return {key: unflattener(piece) for piece, unflattener, key in zip(pieces, unflatteners, keys)} return np.concatenate(flattened), unflatten else: raise Exception("Don't know how to flatten type {}".format(type(value)))
def mut_add(self, x, y):
  """Add wrapper safe for IndexedSlices and LazyZero."""
  if isinstance(ag_core.getval(x), tensor.LazyZero):
    return y
  if isinstance(ag_core.getval(y), tensor.LazyZero):
    return x
  if isinstance(x, ops.IndexedSlices):
    x = _indexed_slices_to_tensor(x)
  if isinstance(y, ops.IndexedSlices):
    y = _indexed_slices_to_tensor(y)
  return math_ops.add(x, y)
def jac_fun(*args, **kwargs): arg_in = args[argnum] output = fun(*args, **kwargs) assert isinstance(getval(arg_in), np.ndarray), "Must have array input" assert isinstance(getval(output), np.ndarray), "Must have array output" jac = np.zeros(output.shape + arg_in.shape) input_slice = (slice(None),) * len(arg_in.shape) for idxs in it.product(*map(range, output.shape)): scalar_fun = lambda *args, **kwargs : fun(*args, **kwargs)[idxs] jac[idxs + input_slice] = grad(scalar_fun, argnum=argnum)(*args, **kwargs) return jac
def jac_fun(*args, **kwargs): arg_in = args[argnum] output = fun(*args, **kwargs) assert isinstance(getval(arg_in), np.ndarray), "Must have array input" assert isinstance(getval(output), np.ndarray), "Must have array output" jac = np.zeros(output.shape + arg_in.shape) input_slice = (slice(None),) * len(arg_in.shape) for idxs in it.product(*list(map(range, output.shape))): scalar_fun = lambda *args, **kwargs : fun(*args, **kwargs)[idxs] jac[idxs + input_slice] = grad(scalar_fun, argnum=argnum)(*args, **kwargs) return jac
def as_scalar(x):
    vs = vspace(getval(x))
    if vs.iscomplex:
        x = np.real(x)
    if vs.shape == ():
        return x
    elif vs.size == 1:
        return x.reshape(())
    else:
        raise TypeError("Output {} can't be cast to float. "
                        "Function grad requires a scalar-valued function. "
                        "Try jacobian or elementwise_grad.".format(getval(x)))
def flatten(value): """Flattens any nesting of tuples, arrays, or dicts. Returns 1D numpy array and an unflatten function. Doesn't preserve mixed numeric types (e.g. floats and ints). Assumes dict keys are sortable.""" if isinstance(getval(value), np.ndarray): shape = value.shape def unflatten(vector): return np.reshape(vector, shape) return np.ravel(value), unflatten elif isinstance(getval(value), (float, int)): return np.array([value]), lambda x: x[0] elif isinstance(getval(value), (tuple, list)): constructor = type(getval(value)) if not value: return np.array([]), lambda x: constructor() flat_pieces, unflatteners = zip(*map(flatten, value)) split_indices = np.cumsum([len(vec) for vec in flat_pieces[:-1]]) def unflatten(vector): pieces = np.split(vector, split_indices) return constructor( unflatten(v) for unflatten, v in zip(unflatteners, pieces)) return np.concatenate(flat_pieces), unflatten elif isinstance(getval(value), dict): items = sorted(iteritems(value), key=itemgetter(0)) keys, flat_pieces, unflatteners = zip(*[(k, ) + flatten(v) for k, v in items]) split_indices = np.cumsum([len(vec) for vec in flat_pieces[:-1]]) def unflatten(vector): pieces = np.split(vector, split_indices) return { key: unflattener(piece) for piece, unflattener, key in zip(pieces, unflatteners, keys) } return np.concatenate(flat_pieces), unflatten else: raise Exception("Don't know how to flatten type {}".format( type(value)))
def convert_to_eager_tensor(t, dtype=None): """Converts the given `value` to an `EagerTensor`.""" if isinstance(ag_core.getval(t), ops.EagerTensor): if dtype is not None and t.dtype != dtype: raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype)) return t # Handle converting ResourceVariable to Tensor. # TODO(josh11b): get rid of this explicit ugly conversion once we have a more # general scheme in place. try: return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access except AttributeError: pass if isinstance(t, (int, float)): # Use a scalar cache. This will put each scalar of each type only once on # each device. Scalars don't use much device memory but copying scalars can # trigger memcpys which are slow. device = context.context().device_name cache_key = device, t, dtype, type(t) tensor = _scalar_cache.get(cache_key, None) if tensor is not None: return tensor value = ops.EagerTensor(t, dtype=dtype) _scalar_cache[cache_key] = value return value return ops.EagerTensor(t, dtype=dtype)
def _estimate_ELBO_noscore(self, params, t):
    samps = self.sample_var_approx(params, n_samples=self.N_SAMPLES)
    # eliminates the score function
    return -np.mean(
        self.log_prob(samps) - self.log_var_approx(samps, getval(params)),
        axis=0)  # this one appears to be correct
def repeat_to_match_shape(x, axis, keepdims): """Returns a function that repeats an array along axis to get a given shape. Also returns the number of repetitions of the array.""" assert isinstance(axis, (type(None), int, tuple)) if not isarray(x): return I, 1 shape = x.shape if axis is None: dtype=None if anp.iscomplexobj(x): dtype = getval(anp.array(x)).dtype # np.full() has a bug for complex numbers if keepdims: return lambda g : anp.full(shape, anp.sum(g), dtype=dtype), anp.prod(shape) else: return lambda g : anp.full(shape, g, dtype=dtype), anp.prod(shape) elif isinstance(axis, int): if keepdims: return lambda g : anp.repeat(g, shape[axis], axis), shape[axis] else: return lambda g : anp.repeat(anp.expand_dims(g, axis), shape[axis], axis), shape[axis] else: repeats = [shape[i] if i in axis else 1 for i in range(len(shape))] expanded = [shape[i] if i not in axis else 1 for i in range(len(shape))] num_reps = anp.prod(anp.array(shape)[list(axis)]) if keepdims: return lambda g: anp.tile(g, repeats), num_reps else: return lambda g: anp.tile(anp.reshape(g, expanded), repeats), num_reps
def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
  """Returns the constant value of the given tensor, if efficiently calculable.

  This function attempts to partially evaluate the given tensor, and
  returns its value as a numpy ndarray if this succeeds.

  TODO(mrry): Consider whether this function should use a registration
  mechanism like gradients and ShapeFunctions, so that it is easily
  extensible.

  NOTE: If `constant_value(tensor)` returns a non-`None` result, it will no
  longer be possible to feed a different value for `tensor`. This allows the
  result of this function to influence the graph that is constructed, and
  permits static shape optimizations.

  Args:
    tensor: The Tensor to be evaluated.
    partial: If True, the returned numpy array is allowed to have partially
      evaluated values. Values that can't be evaluated will be None.

  Returns:
    A numpy ndarray containing the constant value of the given `tensor`,
    or None if it cannot be calculated.

  Raises:
    TypeError: if tensor is not an ops.Tensor.
  """
  if isinstance(ag_core.getval(tensor), ops.EagerTensor):
    return tensor.numpy()
  ret = _ConstantValue(tensor, partial)
  if ret is not None:
    # The caller may now depend on the constant value of `tensor`, so we
    # conservatively prevent it from being fed.
    tensor.graph.prevent_feeding(tensor)
  return ret
def _compute_backprop(self):
  """Computes the backprop function object for this function."""
  self._has_backprop = True
  with self._graph.as_default(), context.graph_mode():
    c = _CapturingContext()
    with c:
      filtered_outputs = [
          ag_core.getval(x) for x in self._returns if x is not None
      ]
      self._out_grad_placeholders = [
          graph_placeholder(x.dtype, x.shape) for x in filtered_outputs
      ]
      in_gradients = gradients_impl.gradients(
          filtered_outputs,
          self._input_placeholders,
          grad_ys=self._out_grad_placeholders)
      shapes = [x.shape for x in in_gradients if x is not None]
  captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
  forward_function_def = graph_to_function_def.graph_to_function_def(
      self._graph, self._ops, self._input_placeholders,
      filtered_outputs + captures)
  self._forward_fdef = _DefinedFunction(forward_function_def)
  _register_with_name(_forward_name(self._func_name), forward_function_def)
  backward_outputs = [x for x in in_gradients if x is not None]
  all_inputs = self._out_grad_placeholders + captures
  backward_function_def = graph_to_function_def.graph_to_function_def(
      self._graph, [x.op for x in self._out_grad_placeholders
                   ] + list(sorted(c.known_ops, key=lambda x: x.name)),
      all_inputs, backward_outputs)
  _register_with_name(_backward_name(self._func_name), backward_function_def)
  self._backward_function = _GraphModeFunction(
      all_inputs, [], backward_function_def, self._graph, c.known_ops,
      in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes)
def convert_to_mixed_eager_tensors(values):
  v = [
      t if isinstance(ag_core.getval(t), tensor.Tensor) else tensor.Tensor(t)
      for t in values
  ]
  types = [t.dtype for t in v]
  return types, v
def args_to_matching_eager(l, default_dtype=None): """Convert sequence `l` to eager same-type Tensors.""" # TODO(josh11b): Could we do a better job if we also passed in the # allowed dtypes when that was known? # Is some input already a Tensor with a dtype? dtype = None for t in l: if isinstance(ag_core.getval(t), tensor.Tensor): dtype = t.dtype break if dtype is None: # TODO(josh11b): At the moment, I don't think this can fail, but at some # point we likely should have some logic to prevent bad conversions. dtype = default_dtype if dtype is None: # Infer a dtype based on the first value, and use that dtype for the # remaining values. ret = [] for t in l: ret.append(ops.convert_to_tensor(t, dtype)) if dtype is None: dtype = ret[-1].dtype else: ret = [ops.convert_to_tensor(t, dtype) for t in l] return dtype, ret
def args_to_matching_eager(l, default_dtype=None): """Convert sequence `l` to eager same-type Tensors.""" # TODO(josh11b): Could we do a better job if we also passed in the # allowed dtypes when that was known? # Is some input already a Tensor with a dtype? dtype = None for t in l: if isinstance(ag_core.getval(t), tensor.Tensor): dtype = t.dtype break if dtype is None: # TODO(josh11b): At the moment, I don't think this can fail, but at some # point we likely should have some logic to prevent bad conversions. dtype = default_dtype if dtype is None: # Infer a dtype based on the first value, and use that dtype for the # remaining values. ret = [] for t in l: ret.append(tensor.convert_to_eager_tensor(t, dtype)) if dtype is None: dtype = ret[-1].dtype else: ret = [tensor.convert_to_eager_tensor(t, dtype) for t in l] return dtype, ret
def convert_to_eager_tensor(t, dtype=None):
  if isinstance(ag_core.getval(t), Tensor):
    if dtype is not None and t.dtype != dtype:
      raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
    return t
  return Tensor(t, dtype=dtype)
def execute(op_name, num_outputs, inputs, attrs=None, name=None):
  """Execute a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch. (Explicitly
      provided instead of being inferred for performance reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    name: Customized name for the operation.

  Returns:
    None if there are no outputs, a single Tensor object if there is one output
    and a list of Tensor objects if there are multiple outputs.

  Raises:
    An exception on error.
  """
  ctx = context.get_default_context()
  # TODO(apassos) move this to convert_to_tensor
  inputs = [ag_core.getval(x) for x in inputs]
  # pylint: disable=protected-access
  input_handles = [c._handle for c in inputs]
  device_name = ctx.device_name
  try:
    outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
                                            str(op_name), input_handles, attrs,
                                            num_outputs)
    # pylint: enable=protected-access
  except core._NotOkStatusException as e:  # pylint: disable=protected-access
    if name is not None:
      message = e.message + " name: " + name
    else:
      message = e.message
    raise core._status_to_exception(e.code, message)  # pylint: disable=protected-access
  # pylint: enable=protected-access

  tensors = [tensor._tensor_from_handle(x) for x in outh]  # pylint: disable=protected-access
  # TODO(alive, cais): Use the execution callback mechanism.
  if core.active_trace() is not None:
    trace_name = name if name else op_name
    for t in tensors:
      # pylint: disable=protected-access
      core.active_trace().record_tensor(trace_name,
                                        ops.tensor_id(t),
                                        t._device_name(),
                                        t.shape.num_elements())
      # pylint: enable=protected-access

  # TODO(cais): Optimize this, perhaps by replacing this execute function with
  # a different one when there are execution callback(s).
  for callback in ctx.post_execution_callbacks:
    callback(op_name, name, attrs, inputs, tensors)

  return tensors
def match_complex(vs, x):
    x_iscomplex = vspace(getval(x)).iscomplex
    if x_iscomplex and not vs.iscomplex:
        return anp.real(x)
    elif not x_iscomplex and vs.iscomplex:
        return x + 0j
    else:
        return x
def fwd_grad_concatenate_args(argnum, g, ans, gvs, vs, axis_args, kwargs):
    result = []
    for i in range(1, len(axis_args)):
        if i == argnum:
            result.append(g)
        else:
            result.append(anp.zeros_like(getval(axis_args[i])))
    return anp.concatenate_args(axis_args[0], *result)
def ggnvp_maker(*args, **kwargs):
    f_vjp, f_x = make_vjp(f, f_argnum)(*args, **kwargs)
    g_hvp, grad_g_x = make_vjp(grad(g))(f_x)
    f_vjp_vjp, _ = make_vjp(f_vjp)(vspace(getval(grad_g_x)).zeros())
    def ggnvp(v):
        return f_vjp(g_hvp(f_vjp_vjp(v)))
    return ggnvp
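# A hypothetical usage sketch for the GGN-vector-product maker above, assuming
# it is the closure built by autograd's make_ggnvp(f, g, f_argnum). The model
# f and loss g below are stand-ins, not from the original code.
import autograd.numpy as np
from autograd import make_ggnvp

f = lambda w: np.tanh(w)              # "network": params -> outputs
g = lambda y: 0.5 * np.sum(y ** 2)    # convex loss on the outputs
w = np.array([0.1, -0.3, 0.5])

ggnvp = make_ggnvp(f, g)(w)           # Gauss-Newton-vector product at w
print(ggnvp(np.ones(3)))              # applies J^T H_g J to the given vector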
def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
                          side_outputs, backward_function):
  """Gradient for _record_operation."""
  del ans, vs, gvs, output_tensors, input_tensors
  backward_args = tuple(g) + tuple(side_outputs)
  if ag_core.isnode(backward_args):
    backward_args = list(backward_args)
  tensors = nest.flatten(backward_function(*backward_args))
  return _EagerList([ag_core.getval(t) for t in tensors])
def grad_fn(*outputs):
  """Generated gradient function."""
  tensors = inputs + result_copies + list(outputs)
  tensors = [ag_core.getval(x) for x in tensors]
  result = _magic_gradient_function(op_name, attrs, len(inputs), num_outputs,
                                    *(tensors))
  if _tracing:
    print("Gradient for", (name if name else op_name), "inputs", inputs,
          "output_grads", outputs)
  return result
def variational_objective(var_param):
    samples = approx.sample(var_param, self.num_mc_samples)
    if self._use_path_deriv:
        var_param_stopped = getval(var_param)
        lower_bound = np.mean(
            self.model(samples) - approx.log_density(var_param_stopped, samples))
    elif approx.supports_entropy:
        lower_bound = np.mean(self.model(samples)) + approx.entropy(var_param)
    else:
        lower_bound = np.mean(self.model(samples) - approx.log_density(samples))
    return -lower_bound
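# A minimal sketch (not from the objective above) of what getval is doing in
# these objectives: it unwraps the autograd box, so the unwrapped value is
# treated as a constant ("stop gradient") when the function is differentiated.
# The import path is an assumption; older autograd exposes getval from
# autograd.core, newer releases from autograd.tracer.
from autograd import grad
from autograd.tracer import getval  # assumed location

def f(x):
    return x * getval(x)   # the second factor is detached from the graph

print(grad(f)(3.0))        # 3.0 (the detached factor); 6.0 without getval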
def flatten(value): """Flattens any nesting of tuples, arrays, or dicts. Returns 1D numpy array and an unflatten function. Doesn't preserve mixed numeric types (e.g. floats and ints). Assumes dict keys are sortable.""" try: vs = vspace(getval(value)) except TypeError: raise Exception("Don't know how to flatten type {}".format( type(value))) return vs.flatten(value), vs.unflatten
def flatten(value): """Flattens any nesting of tuples, arrays, or dicts. Returns 1D numpy array and an unflatten function. Doesn't preserve mixed numeric types (e.g. floats and ints). Assumes dict keys are sortable.""" if isinstance(getval(value), np.ndarray): shape = value.shape def unflatten(vector): return np.reshape(vector, shape) return np.ravel(value), unflatten elif isinstance(getval(value), (float, int)): return np.array([value]), lambda x : x[0] elif isinstance(getval(value), (tuple, list)): constructor = type(getval(value)) if not value: return np.array([]), lambda x : constructor() flat_pieces, unflatteners = zip(*map(flatten, value)) split_indices = np.cumsum([len(vec) for vec in flat_pieces[:-1]]) def unflatten(vector): pieces = np.split(vector, split_indices) return constructor(unflatten(v) for unflatten, v in zip(unflatteners, pieces)) return np.concatenate(flat_pieces), unflatten elif isinstance(getval(value), dict): items = sorted(iteritems(value), key=itemgetter(0)) keys, flat_pieces, unflatteners = zip(*[(k,) + flatten(v) for k, v in items]) split_indices = np.cumsum([len(vec) for vec in flat_pieces[:-1]]) def unflatten(vector): pieces = np.split(vector, split_indices) return {key: unflattener(piece) for piece, unflattener, key in zip(pieces, unflatteners, keys)} return np.concatenate(flat_pieces), unflatten else: raise Exception("Don't know how to flatten type {}".format(type(value)))
def flatten(value):
    # value can be any nested thing ((), array, [] ) etc
    # returns numpy array
    if isinstance(getval(value), np.ndarray):
        def unflatten(vector):
            return np.reshape(vector, value.shape)
        return np.ravel(value), unflatten
    elif isinstance(getval(value), float):
        return np.array([value]), lambda x: x[0]
    elif isinstance(getval(value), tuple):
        if not value:
            return np.array([]), lambda x: ()
        flattened_first, unflatten_first = flatten(value[0])
        flattened_rest, unflatten_rest = flatten(value[1:])
        def unflatten(vector):
            N = len(flattened_first)
            return (unflatten_first(vector[:N]),) + unflatten_rest(vector[N:])
        return np.concatenate((flattened_first, flattened_rest)), unflatten
    elif isinstance(getval(value), list):
        if not value:
            return np.array([]), lambda x: []
        flattened_first, unflatten_first = flatten(value[0])
        flattened_rest, unflatten_rest = flatten(value[1:])
        def unflatten(vector):
            N = len(flattened_first)
            return [unflatten_first(vector[:N])] + unflatten_rest(vector[N:])
        return np.concatenate((flattened_first, flattened_rest)), unflatten
    else:
        raise Exception("Don't know how to flatten type {}".format(type(value)))
def _get_defun_inputs(args):
  """Maps the inputs args to graph inputs."""
  ret = []
  for a in args:
    a = ag_core.getval(a)
    if isinstance(a, (tensor.LazyZero, ops.Tensor, tensor.Tensor)):
      ret.append(graph_placeholder(a.dtype, a.shape))
    elif type(a) in (tuple, list):
      ret.append(_get_defun_inputs(a))
    else:
      ret.append(a)
  return tuple(ret) if type(args) is tuple else ret
def _cache_key(x):
  """Cache key for tfe functions."""
  x = ag_core.getval(x)
  if isinstance(x, tensor.Tensor):
    return _TensorDtype(x.dtype, x._shape_tuple())  # pylint: disable=protected-access
  if isinstance(x, tensor.LazyZero):
    return _TensorDtype(x.dtype, tuple(x.shape.as_list()))  # pylint: disable=protected-access
  if isinstance(x, np.ndarray):
    return ("array", x.shape, tuple(x.reshape(-1)))
  if type(x) in (list, tuple):
    return tuple([_cache_key(a) for a in x])
  return x
def convert_to_eager_tensor(t, dtype=None): """Converts the given `value` to an `EagerTensor`.""" if isinstance(ag_core.getval(t), ops.EagerTensor): if dtype is not None and t.dtype != dtype: raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype)) return t # Handle converting ResourceVariable to Tensor. # TODO(josh11b): get rid of this explicit ugly conversion once we have a more # general scheme in place. try: return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access except AttributeError: pass return ops.EagerTensor(t, dtype=dtype)
def _defun_internal(name, func, args, kwds):
  """Defines and returns graph-mode version of func."""
  with context.graph_mode():
    tmp_graph = ops.Graph()
    # Copy the graph collections to ensure summaries and other things work. This
    # lets the function access (but not mutate) collections of the containing
    # graph, such as the global step and the summary writer collections.
    curr_graph = ops.get_default_graph()
    for collection in curr_graph.collections:
      tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
          collection)
    with tmp_graph.as_default():
      func_inputs = _get_defun_inputs(args)

      captures = {}
      with capture_tensors(captures):
        func_outputs = func(*func_inputs, **kwds)
      ids = list(sorted(captures.keys()))
      if ids:
        extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
      else:
        extra_inputs = []
        extra_placeholders = []
      outputs_list = nest.flatten(func_outputs)
      output_shapes = [x.shape for x in outputs_list if x is not None]

  flat_inputs = [
      x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
  ]
  all_inputs = flat_inputs + list(extra_placeholders)

  func_def_outputs = [ag_core.getval(x) for x in outputs_list if x is not None]
  inference_function_def = graph_to_function_def.graph_to_function_def(
      tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
  # Register any other functions defined in the graph
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    # TODO(ashankar): What about the gradient registry?
    _register_with_name(f.name, f.definition)
  _register_with_name(_inference_name(name), inference_function_def)

  return _GraphModeFunction(all_inputs, extra_inputs, inference_function_def,
                            tmp_graph, tmp_graph.get_operations(), func_outputs,
                            _map_sequence_obj_to_idx(func_def_outputs),
                            output_shapes)
def _backprop_call(self, args): """Calls the wrapped function and records the result on a tape.""" all_args = args + self._extra_inputs signature = self._forward_fdef.definition.signature if context.in_graph_mode(): g = ops.get_default_graph() g._add_function(self._forward_fdef) # pylint: disable=protected-access unwrapped_args = [ag_core.getval(x) for x in all_args] op = g.create_op( signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args], [dtypes.DType(x.type) for x in signature.output_arg], op_def=signature, name="FunctionCall", compute_shapes=False) outputs = op.outputs outputs = [outputs] if isinstance(outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs) for i, s in enumerate(self._output_shapes): outputs[i].set_shape(s) else: outputs = execute.execute(str(signature.name), num_outputs=len(signature.output_arg), inputs=all_args) real_outputs = outputs[:len(self._returns)] side_outputs = outputs[len(self._returns):] watched_extra_inputs = [] for t in self._extra_inputs: tid = ops.tensor_id(t) for t in tape._tape_stack.stack: # pylint: disable=protected-access w = t.value.tensors.get(tid, None) if w is not None: watched_extra_inputs.append(w) break else: # Note: for-else here done on purpose watched_extra_inputs.append(t) def backward_function_wrapper(*outputs): outputs = outputs[len(real_outputs):] return self._backward_function(*outputs) real_outputs = tape.record_operation(real_outputs, (args + watched_extra_inputs), side_outputs, backward_function_wrapper) return self._build_call_outputs(self._returns, real_outputs)
def decorated(*args, **kwargs): """Decorated function with custom gradient.""" input_tensors = [_watch_value_from_tape(x) for x in args if isinstance(x, (_tensor.Tensor, tf_ops.Tensor)) or ag_core.isnode(x)] result, grad_fn = f(*args, **kwargs) flat_result = nest.flatten(result) flat_result = [ag_core.getval(x) for x in flat_result] flat_result = tape.record_operation( flat_result, input_tensors, [], grad_fn) flat_result = list(flat_result) return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
def _backprop_call(self, args): """Calls the wrapped function and records the result on a tape.""" all_args = args + self._extra_inputs signature = self._forward_fdef.definition.signature if context.in_graph_mode(): g = ops.get_default_graph() g._add_function(self._forward_fdef) # pylint: disable=protected-access unwrapped_args = [ag_core.getval(x) for x in all_args] op = g.create_op( signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args], [dtypes.DType(x.type) for x in signature.output_arg], op_def=signature, name="FunctionCall", compute_shapes=False) outputs = op.outputs outputs = [outputs] if isinstance( outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs) for i, s in enumerate(self._output_shapes): outputs[i].set_shape(s) else: outputs = execute.execute( signature.name, num_outputs=len(signature.output_arg), inputs=all_args) real_outputs = outputs[:len(self._returns)] side_outputs = outputs[len(self._returns):] watched_extra_inputs = [] for t in self._extra_inputs: tid = ops.tensor_id(t) for t in tape._tape_stack.stack: # pylint: disable=protected-access w = t.value.tensors.get(tid, None) if w is not None: watched_extra_inputs.append(w) break else: # Note: for-else here done on purpose watched_extra_inputs.append(t) def backward_function_wrapper(*outputs): outputs = outputs[len(real_outputs):] return self._backward_function(*outputs) real_outputs = tape.record_operation( real_outputs, (args + watched_extra_inputs), side_outputs, backward_function_wrapper) return self._build_call_outputs(self._returns, real_outputs)
def repeat_to_match_shape(x, axis, keepdims): """Returns a function that repeats an array along axis to get a given shape. Also returns the number of repetitions of the array.""" if not isarray(x): return I, 1 shape = x.shape if axis is None: dtype=None if anp.iscomplexobj(x): dtype = getval(anp.array(x)).dtype # np.full() has a bug for complex numbers if keepdims: return lambda g : anp.full(shape, anp.sum(g), dtype=dtype), anp.prod(shape) else: return lambda g : anp.full(shape, g, dtype=dtype), anp.prod(shape) else: if keepdims: return lambda g : anp.repeat(g, shape[axis], axis), shape[axis] else: return lambda g : anp.repeat(anp.expand_dims(g, axis), shape[axis], axis), shape[axis]
def args_to_matching_eager(l, default_dtype=None): """Convert sequence `l` to eager same-type Tensors.""" # TODO(josh11b): Could we do a better job if we also passed in the # allowed dtypes when that was known? # Is some input already a Tensor with a dtype? dtype = None for t in l: if isinstance(ag_core.getval(t), tensor.Tensor): dtype = t.dtype break if dtype is None: # Infer a dtype based on the first value, and use that dtype for the # remaining values. ret = [] for t in l: ret.append(ops.convert_to_tensor(t, dtype, preferred_dtype=default_dtype)) if dtype is None: dtype = ret[-1].dtype else: ret = [ops.convert_to_tensor(t, dtype) for t in l] return dtype, ret
def decorated(*args, **kwargs): """Decorated function with custom gradient.""" input_tensors = [_watch_value_from_tape(x) for x in args if isinstance(x, (_tensor.Tensor, tf_ops.Tensor)) or ag_core.isnode(x)] result, grad_fn = f(*args, **kwargs) result_size = len(result) if isinstance(result, (list, tuple)) else 1 # TODO(apassos): naive uses of custom_gradient will not get the correct # second derivative this way if they capture any output tensors. Change the # signature of custom_gradient. def actual_grad_fn(*outputs): outputs = outputs[result_size:] return grad_fn(*outputs) flat_result = nest.flatten(result) flat_result = [ag_core.getval(x) for x in flat_result] flat_result = tape.record_operation( flat_result, input_tensors, [], actual_grad_fn) flat_result = list(flat_result) return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
def args_to_mixed_eager_tensors(lists):
  """Converts a list of same-length lists of values to eager tensors."""
  assert len(lists) > 1

  # Generate an error if len(lists[i]) is not the same for all i.
  lists_ret = []
  for l in lists[1:]:
    if len(l) != len(lists[0]):
      raise ValueError(
          "Expected list arguments to be the same length: %d != %d (%r vs. %r)"
          % (len(lists[0]), len(l), lists[0], l))
    lists_ret.append([])

  # Convert the first element of each list first, then the second element, etc.
  types = []
  for i in range(len(lists[0])):
    dtype = None
    # If any list has a Tensor, use that dtype
    for l in lists:
      if isinstance(ag_core.getval(l[i]), tensor.Tensor):
        dtype = l[i].dtype
        break
    if dtype is None:
      # Convert the first one and use its dtype.
      lists_ret[0].append(tensor.convert_to_eager_tensor(lists[0][i]))
      dtype = lists_ret[0][i].dtype
      for j in range(1, len(lists)):
        lists_ret[j].append(
            tensor.convert_to_eager_tensor(lists[j][i], dtype=dtype))
    else:
      # Convert everything to the found dtype.
      for j in range(len(lists)):
        lists_ret[j].append(
            tensor.convert_to_eager_tensor(lists[j][i], dtype=dtype))
    types.append(dtype)
  return types, lists_ret
def zeros_like(self):
    return {k: zeros_like(v) for k, v in six.iteritems(getval(self))}
def zeros_like(value):
    return tuple([zeros_like(item) for item in getval(value)])
def double_val_fun(*args, **kwargs):
    val = fun(*args, **kwargs)
    return val, getval(val)
def _ConstantValue(tensor, partial):
  # TODO(touts): Support Variables?
  if not isinstance(ag_core.getval(tensor), ops.Tensor):
    raise TypeError("tensor is not a Tensor")
  if tensor.op.type == "Const":
    return MakeNdarray(tensor.op.get_attr("value"))
  elif tensor.op.type == "Shape":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      return np.array([dim.value for dim in input_shape.dims],
                      dtype=tensor.dtype.as_numpy_dtype)
    else:
      return None
  elif tensor.op.type == "Size":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      return np.prod([dim.value for dim in input_shape.dims], dtype=np.int32)
    else:
      return None
  elif tensor.op.type == "Rank":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.ndims is not None:
      return np.ndarray(shape=(), buffer=np.array([input_shape.ndims]),
                        dtype=np.int32)
    else:
      return None
  elif tensor.op.type == "Range":
    start = constant_value(tensor.op.inputs[0])
    if start is None:
      return None
    limit = constant_value(tensor.op.inputs[1])
    if limit is None:
      return None
    delta = constant_value(tensor.op.inputs[2])
    if delta is None:
      return None
    return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype)
  elif tensor.op.type == "Cast":
    pre_cast = constant_value(tensor.op.inputs[0])
    if pre_cast is None:
      return None
    cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
    return pre_cast.astype(cast_dtype.as_numpy_dtype)
  elif tensor.op.type == "Concat":
    dim = constant_value(tensor.op.inputs[0])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[1:]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "ConcatV2":
    dim = constant_value(tensor.op.inputs[-1])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[:-1]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "Pack":
    values = []
    # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
    # and shouldn't be produced, but to deal sensibly with them here we check
    # and return None.
    if not tensor.op.inputs:
      return None
    for x in tensor.op.inputs:
      value = constant_value(x, partial)
      if value is None and not partial:
        return None
      values.append(value)
    return np.array(values)
  elif tensor.op.type == "Fill":
    fill_shape = tensor.shape
    fill_value = constant_value(tensor.op.inputs[1])
    if fill_shape.is_fully_defined() and fill_value is not None:
      return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype)
    else:
      return None
  elif tensor.op.type == "Equal":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.equal(value1, value2)
  elif tensor.op.type == "NotEqual":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.not_equal(value1, value2)
  else:
    return None
def make_grad_tile(ans, x, reps):
    reps = [reps] if anp.isscalar(reps) else reps
    def tile_grad(g):
        for axis, rep in enumerate(reps):
            g = sum(anp.split(g, rep, axis))
        return anp.reshape(g, x.shape)
    return tile_grad
anp.tile.defgrad(make_grad_tile)

def make_grad_transpose(ans, x, axes=None):
    if axes is not None:
        axes = anp.argsort(axes)
    return lambda g: anp.transpose(g, axes)
anp.transpose.defgrad(make_grad_transpose)

isarray = lambda x: isinstance(getval(x), anp.ndarray)

def repeat_to_match_shape(x, axis, keepdims):
    """Returns a function that repeats an array along axis to get a given shape.
       Also returns the number of repetitions of the array."""
    assert isinstance(axis, (type(None), int, tuple))
    if not isarray(x):
        return I, 1
    shape = x.shape
    if axis is None:
        dtype = None
        if anp.iscomplexobj(x):
            dtype = getval(anp.array(x)).dtype  # np.full() has a bug for complex numbers
        if keepdims:
            return lambda g: anp.full(shape, anp.sum(g), dtype=dtype), anp.prod(shape)