Example #1
    def _build_call_outputs(self, func_outputs, result):
        """Maps the fdef output list to actual output structure.

        Args:
          func_outputs: The outputs originally defined by the graph function. It
            could potentially be a nested structure.
          result: Output lists defined by FunctionDef.

        Returns:
          The actual call output.
        """
        if self._func_outputs is None:
            return None
        if isinstance(ag_core.getval(self._func_outputs), ops.Tensor):
            return result[0]

        outputs = []
        for o in func_outputs:
            vo = ag_core.getval(o)
            if isinstance(vo, ops.Tensor):
                outputs.append(result[self._returns_to_fedf_outputs[id(vo)]])
            elif type(vo) in (tuple, list):
                outputs.append(self._build_call_outputs(o, result))
            else:
                outputs.append(o)

        return tuple(outputs) if type(func_outputs) is tuple else outputs
Example #2
 def variational_objective(params, t):
     # stochastic estimate of the variational lower bound
     
     w_mu, w_log_s2, s_pi, log_s2_w, pi_w, log_s2 = unpack_params(params)
     
     # compute the expectation of the "data fit" term and the entropy
     # by Monte Carlo sampling
     datafit = 0.
     entropy = 0.
     for _ in range(num_samples):
         # acquire M Bernoulli samples
         s = Bernoulli(pi=np.column_stack([1 - s_pi, s_pi]), T=0.5)[:, 1]
         # acquire M Gaussian distributed samples
         mean = s*w_mu
         var  = s*np.exp(w_log_s2) + (1-s)*np.exp(log_s2_w)
         w = mean + np.sqrt(var) * np.random.randn(M)
         # compute the log of the joint probability
         datafit = datafit \
                   + logprob(s, w, log_s2_w, pi_w, log_s2, X, y, batch_size, t)
         # compute the entropy of q(w, s)
         mean = getval(mean)
         var  = getval(var)
         s_pi = getval(s_pi)
         entropy = entropy \
                   + np.sum( 0.5*np.log(2*np.pi*var) + 0.5/var*np.power(w-mean, 2) ) \
                   - np.sum( s*np.log(s_pi) + (1-s)*np.log(1-s_pi) )
     datafit = datafit / num_samples
     entropy = entropy / num_samples
     # the lower bound to maximize
     lower_bound = datafit + entropy
     return -lower_bound
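For reference, a minimal self-contained sketch of the same pattern, driving a Monte Carlo variational objective with autograd's grad and adam. This is an illustration, not part of the project above; it assumes autograd is installed and that getval is importable from autograd.tracer (recent releases; older versions expose it from autograd.core).

import autograd.numpy as np
from autograd import grad
from autograd.misc.optimizers import adam
from autograd.tracer import getval

# Fit q = N(mu, exp(log_s)^2) to p = N(3, 1). As in the objective above,
# getval detaches the variational parameters inside the entropy term.
def objective(params, t):
    mu, log_s = params[0], params[1]
    eps = np.random.randn(20)
    w = mu + np.exp(log_s) * eps                 # reparameterized samples
    log_p = -0.5 * (w - 3.0) ** 2                # target, up to a constant
    log_q = -0.5 * ((w - getval(mu)) / np.exp(getval(log_s))) ** 2 - getval(log_s)
    return -np.mean(log_p - log_q)               # negative lower bound

opt = adam(grad(objective), np.array([0.0, 0.0]), step_size=0.05, num_iters=500)
print(opt)  # roughly [3.0, 0.0]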
Example #3
  def _build_call_outputs(self, func_outputs, result):
    """Maps the fdef output list to actual output structure.

    Args:
      func_outputs: The outputs originally defined by the graph function. It
        could potentially be a nested structure.
      result: Output lists defined by FunctionDef.
    Returns:
      The actual call output.
    """
    if self._func_outputs is None:
      return None
    if isinstance(ag_core.getval(self._func_outputs), ops.Tensor):
      return result[0]

    outputs = []
    for o in func_outputs:
      vo = ag_core.getval(o)
      if isinstance(vo, ops.Tensor):
        outputs.append(result[self._returns_to_fedf_outputs[id(vo)]])
      elif type(vo) in (tuple, list):
        outputs.append(self._build_call_outputs(o, result))
      else:
        outputs.append(o)

    return tuple(outputs) if type(func_outputs) is tuple else outputs
Example #4
    def _objfunc(self, params, t):
        # Score-function (REINFORCE) surrogate: the samples are drawn with
        # detached params, so the gradient of this mean is the classic
        # E[(log p - log q) * grad log q] estimator.
        samps = self.sample_var_approx(getval(params),
                                       n_samples=self.N_SAMPLES)

        return np.mean(
            self.log_var_approx(samps, params) *
            (self.log_prob(samps) -
             self.log_var_approx(samps, getval(params))))
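Multiplying by a detached copy of the integrand is the score-function (REINFORCE) identity: with the samples and the bracketed factor held constant, differentiating the mean yields E[(log p - log q) * grad log q]. A toy numerical check of that identity (hypothetical log_q and target, assuming autograd as above):

import autograd.numpy as np
from autograd import grad
from autograd.tracer import getval

def log_q(x, mu):                        # q = N(mu, 1), up to a constant
    return -0.5 * (x - mu) ** 2

def surrogate(mu, n=200000):
    x = getval(mu) + np.random.randn(n)  # samples detached from mu
    f = getval(x ** 2)                   # integrand, also detached
    return np.mean(log_q(x, mu) * f)     # grad gives E[f * grad log q]

print(grad(surrogate)(0.5))  # approximates d/dmu E[x^2] = 2*mu = 1.0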
Example #5
def flatten(value):
    """value can be any nesting of tuples, arrays, dicts.
       returns 1D numpy array and an unflatten function."""
    if isinstance(getval(value), np.ndarray):
        def unflatten(vector):
            return np.reshape(vector, value.shape)
        return np.ravel(value), unflatten

    elif isinstance(getval(value), float):
        return np.array([value]), lambda x : x[0]

    elif isinstance(getval(value), tuple):
        if not value:
            return np.array([]), lambda x : ()
        flattened_first, unflatten_first = flatten(value[0])
        flattened_rest, unflatten_rest = flatten(value[1:])
        def unflatten(vector):
            N = len(flattened_first)
            return (unflatten_first(vector[:N]),) + unflatten_rest(vector[N:])

        return np.concatenate((flattened_first, flattened_rest)), unflatten

    elif isinstance(getval(value), list):
        if not value:
            return np.array([]), lambda x : []
        flattened_first, unflatten_first = flatten(value[0])
        flattened_rest, unflatten_rest = flatten(value[1:])
        def unflatten(vector):
            N = len(flattened_first)
            return [unflatten_first(vector[:N])] + unflatten_rest(vector[N:])

        return np.concatenate((flattened_first, flattened_rest)), unflatten

    elif isinstance(getval(value), dict):
        flattened = []
        unflatteners = []
        lengths = []
        keys = []
        for k, v in sorted(iteritems(value), key=itemgetter(0)):
            cur_flattened, cur_unflatten = flatten(v)
            flattened.append(cur_flattened)
            unflatteners.append(cur_unflatten)
            lengths.append(len(cur_flattened))
            keys.append(k)

        def unflatten(vector):
            split_ixs = np.cumsum(lengths)
            pieces = np.split(vector, split_ixs)
            return {key: unflattener(piece)
                    for piece, unflattener, key in zip(pieces, unflatteners, keys)}

        return np.concatenate(flattened), unflatten

    else:
        raise Exception("Don't know how to flatten type {}".format(type(value)))
Example #7
 def mut_add(self, x, y):
   """Add wrapper safe for IndexedSlices and LazyZero."""
   if isinstance(ag_core.getval(x), tensor.LazyZero):
     return y
   if isinstance(ag_core.getval(y), tensor.LazyZero):
     return x
   if isinstance(x, ops.IndexedSlices):
     x = _indexed_slices_to_tensor(x)
   if isinstance(y, ops.IndexedSlices):
     y = _indexed_slices_to_tensor(y)
   return math_ops.add(x, y)
Example #8
 def jac_fun(*args, **kwargs):
     arg_in = args[argnum]
     output = fun(*args, **kwargs)
     assert isinstance(getval(arg_in), np.ndarray), "Must have array input"
     assert isinstance(getval(output), np.ndarray), "Must have array output"
     jac = np.zeros(output.shape + arg_in.shape)
     input_slice = (slice(None),) * len(arg_in.shape)
     for idxs in it.product(*map(range, output.shape)):
         scalar_fun = lambda *args, **kwargs : fun(*args, **kwargs)[idxs]
         jac[idxs + input_slice] = grad(scalar_fun, argnum=argnum)(*args, **kwargs)
     return jac
Example #9
 def jac_fun(*args, **kwargs):
     arg_in = args[argnum]
     output = fun(*args, **kwargs)
     assert isinstance(getval(arg_in), np.ndarray), "Must have array input"
     assert isinstance(getval(output), np.ndarray), "Must have array output"
     jac = np.zeros(output.shape + arg_in.shape)
     input_slice = (slice(None),) * len(arg_in.shape)
     for idxs in it.product(*list(map(range, output.shape))):
         scalar_fun = lambda *args, **kwargs : fun(*args, **kwargs)[idxs]
         jac[idxs + input_slice] = grad(scalar_fun, argnum=argnum)(*args, **kwargs)
     return jac
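Both jac_fun variants assemble a dense Jacobian one scalar output at a time via grad. Assuming the library in use is autograd (as getval and grad suggest), its public jacobian operator produces the same array with the same output-dims-then-input-dims layout:

import autograd.numpy as np
from autograd import jacobian

def f(x):
    return np.array([x[0] ** 2, x[0] * x[1], np.sin(x[1])])

J = jacobian(f)(np.array([1.0, 2.0]))
print(J.shape)  # (3, 2): output.shape + arg_in.shape, matching jac_fun above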
Example #10
def as_scalar(x):
    vs = vspace(getval(x))
    if vs.iscomplex:
        x = np.real(x)
    if vs.shape == ():
        return x
    elif vs.size == 1:
        return x.reshape(())
    else:
        raise TypeError("Output {} can't be cast to float. "
                        "Function grad requires a scalar-valued function. "
                        "Try jacobian or elementwise_grad.".format(getval(x)))
Example #11
def as_scalar(x):
    vs = vspace(getval(x))
    if vs.iscomplex:
        x = np.real(x)
    if vs.shape == ():
        return x
    elif vs.size == 1:
        return x.reshape(())
    else:
        raise TypeError(
            "Output {} can't be cast to float. "
            "Function grad requires a scalar-valued function. "
            "Try jacobian or elementwise_grad.".format(getval(x)))
Example #12
def flatten(value):
    """Flattens any nesting of tuples, arrays, or dicts.
       Returns 1D numpy array and an unflatten function.
       Doesn't preserve mixed numeric types (e.g. floats and ints).
       Assumes dict keys are sortable."""
    if isinstance(getval(value), np.ndarray):
        shape = value.shape

        def unflatten(vector):
            return np.reshape(vector, shape)

        return np.ravel(value), unflatten

    elif isinstance(getval(value), (float, int)):
        return np.array([value]), lambda x: x[0]

    elif isinstance(getval(value), (tuple, list)):
        constructor = type(getval(value))
        if not value:
            return np.array([]), lambda x: constructor()
        flat_pieces, unflatteners = zip(*map(flatten, value))
        split_indices = np.cumsum([len(vec) for vec in flat_pieces[:-1]])

        def unflatten(vector):
            pieces = np.split(vector, split_indices)
            return constructor(
                unflatten(v) for unflatten, v in zip(unflatteners, pieces))

        return np.concatenate(flat_pieces), unflatten

    elif isinstance(getval(value), dict):
        items = sorted(iteritems(value), key=itemgetter(0))
        keys, flat_pieces, unflatteners = zip(*[(k, ) + flatten(v)
                                                for k, v in items])
        split_indices = np.cumsum([len(vec) for vec in flat_pieces[:-1]])

        def unflatten(vector):
            pieces = np.split(vector, split_indices)
            return {
                key: unflattener(piece)
                for piece, unflattener, key in zip(pieces, unflatteners, keys)
            }

        return np.concatenate(flat_pieces), unflatten

    else:
        raise Exception("Don't know how to flatten type {}".format(
            type(value)))
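A hypothetical round-trip with this flatten, assuming the names the snippet relies on (iteritems from six, itemgetter from operator, and autograd.numpy as np) are in scope:

params = {'b': (0.5, [1.0, 2.0]), 'w': np.ones((2, 2))}
flat, unflatten = flatten(params)
print(flat.shape)           # (7,)
restored = unflatten(flat)
print(restored['w'].shape)  # (2, 2)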
Example #13
def convert_to_eager_tensor(t, dtype=None):
    """Converts the given `value` to an `EagerTensor`."""
    if isinstance(ag_core.getval(t), ops.EagerTensor):
        if dtype is not None and t.dtype != dtype:
            raise TypeError("Expected tensor with type %r not %r" %
                            (dtype, t.dtype))
        return t
    # Handle converting ResourceVariable to Tensor.
    # TODO(josh11b): get rid of this explicit ugly conversion once we have a more
    # general scheme in place.
    try:
        return t._dense_var_to_tensor(dtype=dtype, as_ref=False)  # pylint: disable=protected-access
    except AttributeError:
        pass
    if isinstance(t, (int, float)):
        # Use a scalar cache. This will put each scalar of each type only once on
        # each device. Scalars don't use much device memory but copying scalars can
        # trigger memcpys which are slow.
        device = context.context().device_name
        cache_key = device, t, dtype, type(t)
        tensor = _scalar_cache.get(cache_key, None)
        if tensor is not None:
            return tensor
        value = ops.EagerTensor(t, dtype=dtype)
        _scalar_cache[cache_key] = value
        return value
    return ops.EagerTensor(t, dtype=dtype)
Example #14
    def _estimate_ELBO_noscore(self, params, t):
        samps = self.sample_var_approx(params, n_samples=self.N_SAMPLES)

        # Detaching params in log_var_approx eliminates the score-function
        # term, leaving the pure path-derivative (reparameterization) estimate.
        return -np.mean(
            self.log_prob(samps) - self.log_var_approx(samps, getval(params)),
            axis=0)
Example #15
def repeat_to_match_shape(x, axis, keepdims):
    """Returns a function that repeats an array along axis to get a given shape.
       Also returns the number of repetitions of the array."""
    assert isinstance(axis, (type(None), int, tuple))

    if not isarray(x):
        return I, 1
    shape = x.shape
    if axis is None:
        dtype = None
        if anp.iscomplexobj(x):
            dtype = getval(anp.array(x)).dtype   # np.full() has a bug for complex numbers
        if keepdims:
            return lambda g : anp.full(shape, anp.sum(g), dtype=dtype), anp.prod(shape)
        else:
            return lambda g : anp.full(shape, g, dtype=dtype), anp.prod(shape)
    elif isinstance(axis, int):
        if keepdims:
            return lambda g : anp.repeat(g, shape[axis], axis), shape[axis]
        else:
            return lambda g : anp.repeat(anp.expand_dims(g, axis),
                                         shape[axis], axis), shape[axis]
    else:
        repeats  = [shape[i] if i in axis else 1 for i in range(len(shape))]
        expanded = [shape[i] if i not in axis else 1 for i in range(len(shape))]
        num_reps = anp.prod(anp.array(shape)[list(axis)])

        if keepdims:
            return lambda g: anp.tile(g, repeats), num_reps
        else:
            return lambda g: anp.tile(anp.reshape(g, expanded), repeats), num_reps
Example #16
def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
  """Returns the constant value of the given tensor, if efficiently calculable.

  This function attempts to partially evaluate the given tensor, and
  returns its value as a numpy ndarray if this succeeds.

  TODO(mrry): Consider whether this function should use a registration
  mechanism like gradients and ShapeFunctions, so that it is easily
  extensible.

  NOTE: If `constant_value(tensor)` returns a non-`None` result, it will no
  longer be possible to feed a different value for `tensor`. This allows the
  result of this function to influence the graph that is constructed, and
  permits static shape optimizations.

  Args:
    tensor: The Tensor to be evaluated.
    partial: If True, the returned numpy array is allowed to have partially
      evaluated values. Values that can't be evaluated will be None.

  Returns:
    A numpy ndarray containing the constant value of the given `tensor`,
    or None if it cannot be calculated.

  Raises:
    TypeError: if tensor is not an ops.Tensor.
  """
  if isinstance(ag_core.getval(tensor), ops.EagerTensor):
    return tensor.numpy()
  ret = _ConstantValue(tensor, partial)
  if ret is not None:
    # The caller may now depend on the constant value of `tensor`, so we
    # conservatively prevent it from being fed.
    tensor.graph.prevent_feeding(tensor)
  return ret
Example #18
 def _compute_backprop(self):
   """Computes the backprop function object for this function."""
   self._has_backprop = True
   with self._graph.as_default(), context.graph_mode():
     c = _CapturingContext()
     with c:
       filtered_outputs = [
           ag_core.getval(x) for x in self._returns if x is not None
       ]
       self._out_grad_placeholders = [
           graph_placeholder(x.dtype, x.shape) for x in filtered_outputs
       ]
       in_gradients = gradients_impl.gradients(
           filtered_outputs,
           self._input_placeholders,
           grad_ys=self._out_grad_placeholders)
       shapes = [x.shape for x in in_gradients if x is not None]
   captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
   forward_function_def = graph_to_function_def.graph_to_function_def(
       self._graph, self._ops, self._input_placeholders,
       filtered_outputs + captures)
   self._forward_fdef = _DefinedFunction(forward_function_def)
   _register_with_name(_forward_name(self._func_name), forward_function_def)
   backward_outputs = [x for x in in_gradients if x is not None]
   all_inputs = self._out_grad_placeholders + captures
   backward_function_def = graph_to_function_def.graph_to_function_def(
       self._graph, [x.op for x in self._out_grad_placeholders
                    ] + list(sorted(c.known_ops, key=lambda x: x.name)),
       all_inputs, backward_outputs)
   _register_with_name(_backward_name(self._func_name), backward_function_def)
   self._backward_function = _GraphModeFunction(
       all_inputs, [], backward_function_def, self._graph, c.known_ops,
       in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes)
Example #19
def convert_to_mixed_eager_tensors(values):
    v = [
        t if isinstance(ag_core.getval(t), tensor.Tensor) else tensor.Tensor(t)
        for t in values
    ]
    types = [t.dtype for t in v]
    return types, v
Example #20
def args_to_matching_eager(l, default_dtype=None):
    """Convert sequence `l` to eager same-type Tensors."""
    # TODO(josh11b): Could we do a better job if we also passed in the
    # allowed dtypes when that was known?

    # Is some input already a Tensor with a dtype?
    dtype = None
    for t in l:
        if isinstance(ag_core.getval(t), tensor.Tensor):
            dtype = t.dtype
            break

    if dtype is None:
        # TODO(josh11b): At the moment, I don't think this can fail, but at some
        # point we likely should have some logic to prevent bad conversions.
        dtype = default_dtype

    if dtype is None:
        # Infer a dtype based on the first value, and use that dtype for the
        # remaining values.
        ret = []
        for t in l:
            ret.append(ops.convert_to_tensor(t, dtype))
            if dtype is None:
                dtype = ret[-1].dtype
    else:
        ret = [ops.convert_to_tensor(t, dtype) for t in l]

    return dtype, ret
Example #22
def args_to_matching_eager(l, default_dtype=None):
  """Convert sequence `l` to eager same-type Tensors."""
  # TODO(josh11b): Could we do a better job if we also passed in the
  # allowed dtypes when that was known?

  # Is some input already a Tensor with a dtype?
  dtype = None
  for t in l:
    if isinstance(ag_core.getval(t), tensor.Tensor):
      dtype = t.dtype
      break

  if dtype is None:
    # TODO(josh11b): At the moment, I don't think this can fail, but at some
    # point we likely should have some logic to prevent bad conversions.
    dtype = default_dtype

  if dtype is None:
    # Infer a dtype based on the first value, and use that dtype for the
    # remaining values.
    ret = []
    for t in l:
      ret.append(tensor.convert_to_eager_tensor(t, dtype))
      if dtype is None:
        dtype = ret[-1].dtype
  else:
    ret = [tensor.convert_to_eager_tensor(t, dtype) for t in l]

  return dtype, ret
Example #23
def convert_to_eager_tensor(t, dtype=None):
    if isinstance(ag_core.getval(t), Tensor):
        if dtype is not None and t.dtype != dtype:
            raise TypeError("Expected tensor with type %r not %r" %
                            (dtype, t.dtype))
        return t
    return Tensor(t, dtype=dtype)
Example #24
def execute(op_name, num_outputs, inputs, attrs=None, name=None):
  """Execute a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch.
                 (Explicitly provided instead of being inferred for performance
                 reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    name: Customized name for the operation.

  Returns:
    None if there are no outputs, a single Tensor object if there is one output
    and a list of Tensor objects if there are multiple outputs.

  Raises:
    An exception on error.
  """
  ctx = context.get_default_context()
  # TODO(apassos) move this to convert_to_tensor
  inputs = [ag_core.getval(x) for x in inputs]
  # pylint: disable=protected-access
  input_handles = [c._handle for c in inputs]
  device_name = ctx.device_name
  try:
    outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
                                            str(op_name), input_handles, attrs,
                                            num_outputs)
    # pylint: enable=protected-access
  except core._NotOkStatusException as e:  # pylint: disable=protected-access
    if name is not None:
      message = e.message + " name: " + name
    else:
      message = e.message
    raise core._status_to_exception(e.code, message)  # pylint: disable=protected-access
  # pylint: enable=protected-access

  tensors = [tensor._tensor_from_handle(x) for x in outh]  # pylint: disable=protected-access
  # TODO(alive, cais): Use the execution callback mechanism.
  if core.active_trace() is not None:
    trace_name = name if name else op_name
    for t in tensors:
      # pylint: disable=protected-access
      core.active_trace().record_tensor(trace_name,
                                        ops.tensor_id(t),
                                        t._device_name(),
                                        t.shape.num_elements())
      # pylint: enable=protected-access

  # TODO(cais): Optimize this, perhaps by replacing this execute function with
  # a different one when there are execution callback(s).
  for callback in ctx.post_execution_callbacks:
    callback(op_name, name, attrs, inputs, tensors)

  return tensors
Example #25
def match_complex(vs, x):
    x_iscomplex = vspace(getval(x)).iscomplex
    if x_iscomplex and not vs.iscomplex:
        return anp.real(x)
    elif not x_iscomplex and vs.iscomplex:
        return x + 0j
    else:
        return x
Example #27
def fwd_grad_concatenate_args(argnum, g, ans, gvs, vs, axis_args, kwargs):
    result = []
    for i in range(1, len(axis_args)):
        if i == argnum:
            result.append(g)
        else:
            result.append(anp.zeros_like(getval(axis_args[i])))
    return anp.concatenate_args(axis_args[0], *result)
Example #28
    def ggnvp_maker(*args, **kwargs):
        f_vjp, f_x = make_vjp(f, f_argnum)(*args, **kwargs)
        g_hvp, grad_g_x = make_vjp(grad(g))(f_x)
        f_vjp_vjp, _ = make_vjp(f_vjp)(vspace(getval(grad_g_x)).zeros())

        def ggnvp(v):
            return f_vjp(g_hvp(f_vjp_vjp(v)))

        return ggnvp
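The maker composes a VJP of f, a Hessian-vector product of g, and a JVP of f (built as the VJP of f's VJP) into a generalized Gauss-Newton vector product, v -> J_f^T H_g J_f v. A hedged usage sketch via the make_ggnvp that recent autograd releases export (calling convention assumed from autograd >= 1.2):

import autograd.numpy as np
from autograd import make_ggnvp

def f(x):                      # inner map, R^2 -> R^3
    return np.array([x[0], x[0] * x[1], x[1] ** 2])

def g(y):                      # outer scalar loss, R^3 -> R
    return 0.5 * np.sum(y ** 2)

ggnvp = make_ggnvp(f, g=g)(np.array([1.0, 2.0]))
print(ggnvp(np.array([1.0, 0.0])))  # [5. 2.]: first column of J^T H_g J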
Example #29
File: tape.py Project: lengjia/RRL
def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
                          side_outputs, backward_function):
  """Gradient for _record_operation."""
  del ans, vs, gvs, output_tensors, input_tensors
  backward_args = tuple(g) + tuple(side_outputs)
  if ag_core.isnode(backward_args):
    backward_args = list(backward_args)
  tensors = nest.flatten(backward_function(*backward_args))
  return _EagerList([ag_core.getval(t) for t in tensors])
Example #30
 def grad_fn(*outputs):
     """Generated gradient function."""
     tensors = inputs + result_copies + list(outputs)
     tensors = [ag_core.getval(x) for x in tensors]
     result = _magic_gradient_function(op_name, attrs, len(inputs),
                                       num_outputs, *(tensors))
     if _tracing:
         print("Gradient for", (name if name else op_name), "inputs",
               inputs, "output_grads", outputs)
     return result
Example #31
 def grad_fn(*outputs):
   """Generated gradient function."""
   tensors = inputs + result_copies + list(outputs)
   tensors = [ag_core.getval(x) for x in tensors]
   result = _magic_gradient_function(op_name, attrs, len(inputs),
                                     num_outputs, *(tensors))
   if _tracing:
     print("Gradient for", (name if name else op_name), "inputs", inputs,
           "output_grads", outputs)
   return result
Example #32
 def variational_objective(var_param):
     samples = approx.sample(var_param, self.num_mc_samples)
     if self._use_path_deriv:
         var_param_stopped = getval(var_param)
         lower_bound = np.mean(self.model(samples) - approx.log_density(var_param_stopped, samples))
     elif approx.supports_entropy:
         lower_bound = np.mean(self.model(samples)) + approx.entropy(var_param)
     else:
         lower_bound = np.mean(self.model(samples) - approx.log_density(samples))
     return -lower_bound
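The getval on var_param is the path-derivative ("sticking the landing") estimator: dropping the score term of the entropy keeps the gradient unbiased and makes its variance vanish at the optimum. A self-contained check of that property, under the same autograd assumptions as earlier:

import autograd.numpy as np
from autograd import grad
from autograd.tracer import getval

def neg_lower_bound(mu, n=10):
    z = mu + np.random.randn(n)            # reparameterized samples, q = N(mu, 1)
    log_p = -0.5 * z ** 2                  # target p = N(0, 1), up to a constant
    log_q = -0.5 * (z - getval(mu)) ** 2   # var_param_stopped, as above
    return -np.mean(log_p - log_q)

print(grad(neg_lower_bound)(0.0))  # exactly 0.0 at the optimum, for any draw
print(grad(neg_lower_bound)(0.7))  # exactly 0.7 (the true KL gradient), zero variance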
Example #33
def flatten(value):
    """Flattens any nesting of tuples, arrays, or dicts.
       Returns 1D numpy array and an unflatten function.
       Doesn't preserve mixed numeric types (e.g. floats and ints).
       Assumes dict keys are sortable."""
    try:
        vs = vspace(getval(value))
    except TypeError:
        raise Exception("Don't know how to flatten type {}".format(
            type(value)))
    return vs.flatten(value), vs.unflatten
Example #34
def flatten(value):
    """Flattens any nesting of tuples, arrays, or dicts.
       Returns 1D numpy array and an unflatten function.
       Doesn't preserve mixed numeric types (e.g. floats and ints).
       Assumes dict keys are sortable."""
    if isinstance(getval(value), np.ndarray):
        shape = value.shape
        def unflatten(vector):
            return np.reshape(vector, shape)
        return np.ravel(value), unflatten

    elif isinstance(getval(value), (float, int)):
        return np.array([value]), lambda x : x[0]

    elif isinstance(getval(value), (tuple, list)):
        constructor = type(getval(value))
        if not value:
            return np.array([]), lambda x : constructor()
        flat_pieces, unflatteners = zip(*map(flatten, value))
        split_indices = np.cumsum([len(vec) for vec in flat_pieces[:-1]])

        def unflatten(vector):
            pieces = np.split(vector, split_indices)
            return constructor(unflatten(v) for unflatten, v in zip(unflatteners, pieces))

        return np.concatenate(flat_pieces), unflatten

    elif isinstance(getval(value), dict):
        items = sorted(iteritems(value), key=itemgetter(0))
        keys, flat_pieces, unflatteners = zip(*[(k,) + flatten(v) for k, v in items])
        split_indices = np.cumsum([len(vec) for vec in flat_pieces[:-1]])

        def unflatten(vector):
            pieces = np.split(vector, split_indices)
            return {key: unflattener(piece)
                    for piece, unflattener, key in zip(pieces, unflatteners, keys)}

        return np.concatenate(flat_pieces), unflatten

    else:
        raise Exception("Don't know how to flatten type {}".format(type(value)))
Example #35
def flatten(value):
    # value can be any nesting of tuples, lists, and arrays.
    # Returns a 1D numpy array and an unflatten function.
    if isinstance(getval(value), np.ndarray):

        def unflatten(vector):
            return np.reshape(vector, value.shape)

        return np.ravel(value), unflatten

    elif isinstance(getval(value), float):
        return np.array([value]), lambda x: x[0]

    elif isinstance(getval(value), tuple):
        if not value:
            return np.array([]), lambda x: ()
        flattened_first, unflatten_first = flatten(value[0])
        flattened_rest, unflatten_rest = flatten(value[1:])

        def unflatten(vector):
            N = len(flattened_first)
            return (unflatten_first(vector[:N]), ) + unflatten_rest(vector[N:])

        return np.concatenate((flattened_first, flattened_rest)), unflatten

    elif isinstance(getval(value), list):
        if not value:
            return np.array([]), lambda x: []

        flattened_first, unflatten_first = flatten(value[0])
        flattened_rest, unflatten_rest = flatten(value[1:])

        def unflatten(vector):
            N = len(flattened_first)
            return [unflatten_first(vector[:N])] + unflatten_rest(vector[N:])

        return np.concatenate((flattened_first, flattened_rest)), unflatten

    else:
        raise Exception("Don't know how to flatten type {}".format(
            type(value)))
Example #36
def _get_defun_inputs(args):
  """Maps the inputs args to graph inputs."""
  ret = []
  for a in args:
    a = ag_core.getval(a)
    if isinstance(a, (tensor.LazyZero, ops.Tensor, tensor.Tensor)):
      ret.append(graph_placeholder(a.dtype, a.shape))
    elif type(a) in (tuple, list):
      ret.append(_get_defun_inputs(a))
    else:
      ret.append(a)
  return tuple(ret) if type(args) is tuple else ret
Example #37
def _cache_key(x):
  """Cache key for tfe functions."""
  x = ag_core.getval(x)
  if isinstance(x, tensor.Tensor):
    return _TensorDtype(x.dtype, x._shape_tuple())  # pylint: disable=protected-access
  if isinstance(x, tensor.LazyZero):
    return _TensorDtype(x.dtype, tuple(x.shape.as_list()))  # pylint: disable=protected-access
  if isinstance(x, np.ndarray):
    return ("array", x.shape, tuple(x.reshape(-1)))
  if type(x) in (list, tuple):
    return tuple([_cache_key(a) for a in x])
  return x
Example #38
def _cache_key(x):
    """Cache key for tfe functions."""
    x = ag_core.getval(x)
    if isinstance(x, tensor.Tensor):
        return _TensorDtype(x.dtype, x._shape_tuple())  # pylint: disable=protected-access
    if isinstance(x, tensor.LazyZero):
        return _TensorDtype(x.dtype, tuple(x.shape.as_list()))  # pylint: disable=protected-access
    if isinstance(x, np.ndarray):
        return ("array", x.shape, tuple(x.reshape(-1)))
    if type(x) in (list, tuple):
        return tuple([_cache_key(a) for a in x])
    return x
Example #39
def _get_defun_inputs(args):
    """Maps the inputs args to graph inputs."""
    ret = []
    for a in args:
        a = ag_core.getval(a)
        if isinstance(a, (tensor.LazyZero, ops.Tensor, tensor.Tensor)):
            ret.append(graph_placeholder(a.dtype, a.shape))
        elif type(a) in (tuple, list):
            ret.append(_get_defun_inputs(a))
        else:
            ret.append(a)
    return tuple(ret) if type(args) is tuple else ret
Example #40
def convert_to_eager_tensor(t, dtype=None):
  """Converts the given `value` to an `EagerTensor`."""
  if isinstance(ag_core.getval(t), ops.EagerTensor):
    if dtype is not None and t.dtype != dtype:
      raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
    return t
  # Handle converting ResourceVariable to Tensor.
  # TODO(josh11b): get rid of this explicit ugly conversion once we have a more
  # general scheme in place.
  try:
    return t._dense_var_to_tensor(dtype=dtype, as_ref=False)  # pylint: disable=protected-access
  except AttributeError:
    pass
  return ops.EagerTensor(t, dtype=dtype)
Example #41
def _defun_internal(name, func, args, kwds):
    """Defines and returns graph-mode version of func."""
    with context.graph_mode():
        tmp_graph = ops.Graph()
        # Copy the graph collections to ensure summaries and other things work. This
        # lets the function access (but not mutate) collections of the containing
        # graph, such as the global step and the summary writer collections.
        curr_graph = ops.get_default_graph()
        for collection in curr_graph.collections:
            tmp_graph.get_collection_ref(
                collection)[:] = curr_graph.get_collection(collection)
        with tmp_graph.as_default():
            func_inputs = _get_defun_inputs(args)

            captures = {}
            with capture_tensors(captures):
                func_outputs = func(*func_inputs, **kwds)
            ids = list(sorted(captures.keys()))
            if ids:
                extra_inputs, extra_placeholders = zip(
                    *[captures[x] for x in ids])
            else:
                extra_inputs = []
                extra_placeholders = []
            outputs_list = nest.flatten(func_outputs)
            output_shapes = [x.shape for x in outputs_list if x is not None]

    flat_inputs = [
        x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
    ]
    all_inputs = flat_inputs + list(extra_placeholders)

    func_def_outputs = [
        ag_core.getval(x) for x in outputs_list if x is not None
    ]
    inference_function_def = graph_to_function_def.graph_to_function_def(
        tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
    # Register any other functions defined in the graph
    # TODO(ashankar): Oh lord, forgive me for this lint travesty.
    for f in tmp_graph._functions.values():  # pylint: disable=protected-access
        # TODO(ashankar): What about the gradient registry?
        _register_with_name(f.name, f.definition)
    _register_with_name(_inference_name(name), inference_function_def)

    return _GraphModeFunction(all_inputs, extra_inputs,
                              inference_function_def, tmp_graph,
                              tmp_graph.get_operations(), func_outputs,
                              _map_sequence_obj_to_idx(func_def_outputs),
                              output_shapes)
Example #42
def flatten(value):
    # value can be any nesting of tuples, lists, and arrays.
    # Returns a 1D numpy array and an unflatten function.
    if isinstance(getval(value), np.ndarray):
        def unflatten(vector):
            return np.reshape(vector, value.shape)
        return np.ravel(value), unflatten

    elif isinstance(getval(value), float):
        return np.array([value]), lambda x : x[0]

    elif isinstance(getval(value), tuple):
        if not value:
            return np.array([]), lambda x : ()
        flattened_first, unflatten_first = flatten(value[0])
        flattened_rest, unflatten_rest = flatten(value[1:])
        def unflatten(vector):
            N = len(flattened_first)
            return (unflatten_first(vector[:N]),) + unflatten_rest(vector[N:])

        return np.concatenate((flattened_first, flattened_rest)), unflatten

    elif isinstance(getval(value), list):
        if not value:
            return np.array([]), lambda x : []

        flattened_first, unflatten_first = flatten(value[0])
        flattened_rest, unflatten_rest = flatten(value[1:])
        def unflatten(vector):
            N = len(flattened_first)
            return [unflatten_first(vector[:N])] + unflatten_rest(vector[N:])

        return np.concatenate((flattened_first, flattened_rest)), unflatten

    else:
        raise Exception("Don't know how to flatten type {}".format(type(value)))
Example #43
    def _backprop_call(self, args):
        """Calls the wrapped function and records the result on a tape."""
        all_args = args + self._extra_inputs
        signature = self._forward_fdef.definition.signature
        if context.in_graph_mode():
            g = ops.get_default_graph()
            g._add_function(self._forward_fdef)  # pylint: disable=protected-access
            unwrapped_args = [ag_core.getval(x) for x in all_args]
            op = g.create_op(
                signature.name,
                [ops.convert_to_tensor(x) for x in unwrapped_args],
                [dtypes.DType(x.type) for x in signature.output_arg],
                op_def=signature,
                name="FunctionCall",
                compute_shapes=False)
            outputs = op.outputs
            outputs = [outputs] if isinstance(outputs,
                                              (tensor.Tensor, ops.Tensor,
                                               type(None))) else list(outputs)
            for i, s in enumerate(self._output_shapes):
                outputs[i].set_shape(s)
        else:
            outputs = execute.execute(str(signature.name),
                                      num_outputs=len(signature.output_arg),
                                      inputs=all_args)
        real_outputs = outputs[:len(self._returns)]
        side_outputs = outputs[len(self._returns):]
        watched_extra_inputs = []
        for t in self._extra_inputs:
            tid = ops.tensor_id(t)
            for t in tape._tape_stack.stack:  # pylint: disable=protected-access
                w = t.value.tensors.get(tid, None)
                if w is not None:
                    watched_extra_inputs.append(w)
                    break
            else:  # Note: for-else here done on purpose
                watched_extra_inputs.append(t)

        def backward_function_wrapper(*outputs):
            outputs = outputs[len(real_outputs):]
            return self._backward_function(*outputs)

        real_outputs = tape.record_operation(real_outputs,
                                             (args + watched_extra_inputs),
                                             side_outputs,
                                             backward_function_wrapper)

        return self._build_call_outputs(self._returns, real_outputs)
Example #44
  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    input_tensors = [_watch_value_from_tape(x) for x in args
                     if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
                     or ag_core.isnode(x)]
    result, grad_fn = f(*args, **kwargs)

    flat_result = nest.flatten(result)
    flat_result = [ag_core.getval(x) for x in flat_result]
    flat_result = tape.record_operation(
        flat_result,
        input_tensors,
        [],
        grad_fn)
    flat_result = list(flat_result)
    return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
Example #45
  def _backprop_call(self, args):
    """Calls the wrapped function and records the result on a tape."""
    all_args = args + self._extra_inputs
    signature = self._forward_fdef.definition.signature
    if context.in_graph_mode():
      g = ops.get_default_graph()
      g._add_function(self._forward_fdef)  # pylint: disable=protected-access
      unwrapped_args = [ag_core.getval(x) for x in all_args]
      op = g.create_op(
          signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args],
          [dtypes.DType(x.type) for x in signature.output_arg],
          op_def=signature,
          name="FunctionCall",
          compute_shapes=False)
      outputs = op.outputs
      outputs = [outputs] if isinstance(
          outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
      for i, s in enumerate(self._output_shapes):
        outputs[i].set_shape(s)
    else:
      outputs = execute.execute(
          signature.name,
          num_outputs=len(signature.output_arg),
          inputs=all_args)
    real_outputs = outputs[:len(self._returns)]
    side_outputs = outputs[len(self._returns):]
    watched_extra_inputs = []
    for t in self._extra_inputs:
      tid = ops.tensor_id(t)
      for t in tape._tape_stack.stack:  # pylint: disable=protected-access
        w = t.value.tensors.get(tid, None)
        if w is not None:
          watched_extra_inputs.append(w)
          break
      else:  # Note: for-else here done on purpose
        watched_extra_inputs.append(t)

    def backward_function_wrapper(*outputs):
      outputs = outputs[len(real_outputs):]
      return self._backward_function(*outputs)
    real_outputs = tape.record_operation(
        real_outputs,
        (args + watched_extra_inputs),
        side_outputs,
        backward_function_wrapper)

    return self._build_call_outputs(self._returns, real_outputs)
Example #46
def _defun_internal(name, func, args, kwds):
  """Defines and returns graph-mode version of func."""
  with context.graph_mode():
    tmp_graph = ops.Graph()
    # Copy the graph collections to ensure summaries and other things work. This
    # lets the function access (but not mutate) collections of the containing
    # graph, such as the global step and the summary writer collections.
    curr_graph = ops.get_default_graph()
    for collection in curr_graph.collections:
      tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
          collection)
    with tmp_graph.as_default():
      func_inputs = _get_defun_inputs(args)

      captures = {}
      with capture_tensors(captures):
        func_outputs = func(*func_inputs, **kwds)
      ids = list(sorted(captures.keys()))
      if ids:
        extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
      else:
        extra_inputs = []
        extra_placeholders = []
      outputs_list = nest.flatten(func_outputs)
      output_shapes = [x.shape for x in outputs_list if x is not None]

  flat_inputs = [
      x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
  ]
  all_inputs = flat_inputs + list(extra_placeholders)

  func_def_outputs = [ag_core.getval(x) for x in outputs_list if x is not None]
  inference_function_def = graph_to_function_def.graph_to_function_def(
      tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
  # Register any other functions defined in the graph
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    # TODO(ashankar): What about the gradient registry?
    _register_with_name(f.name, f.definition)
  _register_with_name(_inference_name(name), inference_function_def)

  return _GraphModeFunction(
      all_inputs, extra_inputs, inference_function_def, tmp_graph,
      tmp_graph.get_operations(), func_outputs,
      _map_sequence_obj_to_idx(func_def_outputs), output_shapes)
Example #47
def repeat_to_match_shape(x, axis, keepdims):
    """Returns a function that repeats an array along axis to get a given shape.
       Also returns the number of repetitions of the array."""
    if not isarray(x):
        return I, 1
    shape = x.shape
    if axis is None:
        dtype = None
        if anp.iscomplexobj(x):
            dtype = getval(anp.array(x)).dtype   # np.full() has a bug for complex numbers
        if keepdims:
            return lambda g : anp.full(shape, anp.sum(g), dtype=dtype), anp.prod(shape)
        else:
            return lambda g : anp.full(shape, g, dtype=dtype), anp.prod(shape)
    else:
        if keepdims:
            return lambda g : anp.repeat(g, shape[axis], axis), shape[axis]
        else:
            return lambda g : anp.repeat(anp.expand_dims(g, axis),
                                         shape[axis], axis), shape[axis]
Example #48
def args_to_matching_eager(l, default_dtype=None):
  """Convert sequence `l` to eager same-type Tensors."""
  # TODO(josh11b): Could we do a better job if we also passed in the
  # allowed dtypes when that was known?

  # Is some input already a Tensor with a dtype?
  dtype = None
  for t in l:
    if isinstance(ag_core.getval(t), tensor.Tensor):
      dtype = t.dtype
      break

  if dtype is None:
    # Infer a dtype based on the first value, and use that dtype for the
    # remaining values.
    ret = []
    for t in l:
      ret.append(ops.convert_to_tensor(t, dtype, preferred_dtype=default_dtype))
      if dtype is None:
        dtype = ret[-1].dtype
  else:
    ret = [ops.convert_to_tensor(t, dtype) for t in l]

  return dtype, ret
Example #49
  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    input_tensors = [_watch_value_from_tape(x) for x in args
                     if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
                     or ag_core.isnode(x)]
    result, grad_fn = f(*args, **kwargs)
    result_size = len(result) if isinstance(result, (list, tuple)) else 1

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      outputs = outputs[result_size:]
      return grad_fn(*outputs)

    flat_result = nest.flatten(result)
    flat_result = [ag_core.getval(x) for x in flat_result]
    flat_result = tape.record_operation(
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    flat_result = list(flat_result)
    return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
Example #50
def args_to_mixed_eager_tensors(lists):
  """Converts a list of same-length lists of values to eager tensors."""
  assert len(lists) > 1

  # Generate an error if len(lists[i]) is not the same for all i.
  lists_ret = []
  for l in lists[1:]:
    if len(l) != len(lists[0]):
      raise ValueError(
          "Expected list arguments to be the same length: %d != %d (%r vs. %r)"
          % (len(lists[0]), len(l), lists[0], l))
    lists_ret.append([])

  # Convert the first element of each list first, then the second element, etc.
  types = []
  for i in range(len(lists[0])):
    dtype = None
    # If any list has a Tensor, use that dtype
    for l in lists:
      if isinstance(ag_core.getval(l[i]), tensor.Tensor):
        dtype = l[i].dtype
        break
    if dtype is None:
      # Convert the first one and use its dtype.
      lists_ret[0].append(tensor.convert_to_eager_tensor(lists[0][i]))
      dtype = lists_ret[0][i].dtype
      for j in range(1, len(lists)):
        lists_ret[j].append(
            tensor.convert_to_eager_tensor(lists[j][i], dtype=dtype))
    else:
      # Convert everything to the found dtype.
      for j in range(len(lists)):
        lists_ret[j].append(
            tensor.convert_to_eager_tensor(lists[j][i], dtype=dtype))
    types.append(dtype)
  return types, lists_ret
Example #51
 def zeros_like(self):
     return {k : zeros_like(v) for k, v in six.iteritems(getval(self))}
Example #52
 def zeros_like(value):
     return tuple([zeros_like(item) for item in getval(value)])
Example #53
 def double_val_fun(*args, **kwargs):
     val = fun(*args, **kwargs)
     return val, getval(val)
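Every snippet on this page leans on the same primitive: getval unwraps an autograd box to the raw value it carries, cutting that value out of the traced graph. A two-line demonstration (again assuming autograd.tracer.getval):

from autograd import grad
from autograd.tracer import getval

def f(x):
    return x * getval(x)   # the second factor is a detached constant

print(grad(f)(3.0))        # 3.0, not the 6.0 that d(x*x)/dx would give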
Example #54
def _ConstantValue(tensor, partial):
  # TODO(touts): Support Variables?
  if not isinstance(ag_core.getval(tensor), ops.Tensor):
    raise TypeError("tensor is not a Tensor")
  if tensor.op.type == "Const":
    return MakeNdarray(tensor.op.get_attr("value"))
  elif tensor.op.type == "Shape":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      return np.array([dim.value for dim in input_shape.dims],
                      dtype=tensor.dtype.as_numpy_dtype)
    else:
      return None
  elif tensor.op.type == "Size":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      return np.prod([dim.value for dim in input_shape.dims], dtype=np.int32)
    else:
      return None
  elif tensor.op.type == "Rank":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.ndims is not None:
      return np.ndarray(shape=(), buffer=np.array([input_shape.ndims]),
                        dtype=np.int32)
    else:
      return None
  elif tensor.op.type == "Range":
    start = constant_value(tensor.op.inputs[0])
    if start is None:
      return None
    limit = constant_value(tensor.op.inputs[1])
    if limit is None:
      return None
    delta = constant_value(tensor.op.inputs[2])
    if delta is None:
      return None
    return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype)
  elif tensor.op.type == "Cast":
    pre_cast = constant_value(tensor.op.inputs[0])
    if pre_cast is None:
      return None
    cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
    return pre_cast.astype(cast_dtype.as_numpy_dtype)
  elif tensor.op.type == "Concat":
    dim = constant_value(tensor.op.inputs[0])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[1:]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "ConcatV2":
    dim = constant_value(tensor.op.inputs[-1])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[:-1]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "Pack":
    values = []
    # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
    # and shouldn't be produced, but to deal sensibly with them here we check
    # and return None.
    if not tensor.op.inputs:
      return None
    for x in tensor.op.inputs:
      value = constant_value(x, partial)
      if value is None and not partial:
        return None
      values.append(value)
    return np.array(values)
  elif tensor.op.type == "Fill":
    fill_shape = tensor.shape
    fill_value = constant_value(tensor.op.inputs[1])
    if fill_shape.is_fully_defined() and fill_value is not None:
      return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype)
    else:
      return None
  elif tensor.op.type == "Equal":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.equal(value1, value2)
  elif tensor.op.type == "NotEqual":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.not_equal(value1, value2)
  else:
    return None
Example #55
def make_grad_tile(ans, x, reps):
    reps = [reps] if anp.isscalar(reps) else reps
    def tile_grad(g):
        for axis, rep in enumerate(reps):
            g = sum(anp.split(g, rep, axis))
        return anp.reshape(g, x.shape)
    return tile_grad
anp.tile.defgrad(make_grad_tile)

def make_grad_transpose(ans, x, axes=None):
    if axes is not None:
        axes = anp.argsort(axes)
    return lambda g : anp.transpose(g, axes)
anp.transpose.defgrad(make_grad_transpose)

isarray = lambda x : isinstance(getval(x), anp.ndarray)

def repeat_to_match_shape(x, axis, keepdims):
    """Returns a function that repeats an array along axis to get a given shape.
       Also returns the number of repetitions of the array."""
    assert isinstance(axis, (type(None), int, tuple))

    if not isarray(x):
        return I, 1
    shape = x.shape
    if axis is None:
        dtype = None
        if anp.iscomplexobj(x):
            dtype = getval(anp.array(x)).dtype   # np.full() has a bug for complex numbers
        if keepdims:
            return lambda g : anp.full(shape, anp.sum(g), dtype=dtype), anp.prod(shape)