def update_op(self, optimizer, g):
  """Builds the training op that applies gradient `g` to `self._v`.

  The wrapped variable and each of its optimizer slots are read once `g` is
  available, and the optimizer's apply op is made to depend on those reads.
  After the apply op (and, on the dense path, the optional constraint), the
  `update_op()` of the variable and of every slot is run and grouped.
  Presumably `update_op()` persists the new values back to the wrapper's
  underlying storage -- confirm against the wrapper class.

  Args:
    optimizer: Optimizer providing `get_slot`, `get_slot_names` and the
      `_resource_apply_dense` / `_resource_apply_sparse_duplicate_indices`
      methods.
    g: Gradient for `self._v`; an `ops.IndexedSlices` takes the sparse path,
      anything else the dense path.

  Returns:
    An op grouping the `update_op()` of the variable and of all its slots.

  Raises:
    RuntimeError: If `g` is sparse and `self._v` has a constraint function.
  """
  # pylint: disable=protected-access
  with ops.colocate_with(None, ignore_existing=True):
    slots = [
        optimizer.get_slot(self._v, slot_name)
        for slot_name in optimizer.get_slot_names()
    ]
    # Read the variable and its slots only after `g` exists, and make the
    # apply op wait on these reads, so they observe the pre-update values.
    with ops.control_dependencies([g]):
      before = [self._v.read_value()] + [slot.read_value() for slot in slots]

    # Runs update_op() on the variable and every slot once `deps` have run.
    def _write_back(deps):
      with ops.control_dependencies(deps):
        return control_flow_ops.group([self._v.update_op()] +
                                      [slot.update_op() for slot in slots])

    if isinstance(g, ops.IndexedSlices):
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      with ops.control_dependencies(before):
        apply_op = optimizer._resource_apply_sparse_duplicate_indices(
            g.values, self._v, g.indices)
      return _write_back([apply_op])

    with ops.control_dependencies(before):
      apply_op = optimizer._resource_apply_dense(g, self._v)
    if self._v.constraint is not None:
      # Apply the constraint, then still run the same update_op() group as
      # the unconstrained path; returning the assign directly would skip the
      # write-back of both the variable and its slots.
      with ops.control_dependencies([apply_op]):
        constrain_op = self._v.assign(self._v.constraint(self._v))
      return _write_back([constrain_op])
    return _write_back([apply_op])
Пример #2
0
 def update_op(self, optimizer, g):
     """Returns the op that applies gradient `g` to `self._v` via `optimizer`.

     A sparse gradient (`ops.IndexedSlices`) is routed to
     `_resource_apply_sparse_duplicate_indices`; any other gradient goes to
     `_resource_apply_dense`. On the dense path, if the variable has a
     constraint function, the constrained value is assigned after the update
     and that assignment is returned instead.

     Raises:
       RuntimeError: If `g` is sparse and the variable has a constraint.
     """
     # pylint: disable=protected-access
     var = self._v
     if not isinstance(g, ops.IndexedSlices):
         dense_update = optimizer._resource_apply_dense(g, var)
         if var.constraint is None:
             return dense_update
         # The constraint must see the post-update value.
         with ops.control_dependencies([dense_update]):
             return var.assign(var.constraint(var))
     if var.constraint is not None:
         raise RuntimeError(
             "Cannot use a constraint function on a sparse variable.")
     return optimizer._resource_apply_sparse_duplicate_indices(
         g.values, var, g.indices)
Пример #3
0
  def update_op(self, optimizer, g):
    """Builds the training op that applies gradient `g` to `self._v`.

    If the wrapped variable's `params` are not trainable, a no-op is returned.
    Otherwise the variable and each of its optimizer slots are read once `g`
    is available (`v0` / `s0`), and the optimizer's apply op is made to depend
    on those reads. Afterwards `update_op(v0=...)` is run on the variable and
    on every slot, each with its own pre-update value. Presumably
    `update_op` persists the new values back to the underlying storage --
    confirm against the wrapper class.

    Args:
      optimizer: Optimizer providing `get_slot`, `get_slot_names` and the
        `_resource_apply_dense` / `_resource_apply_sparse_duplicate_indices`
        methods.
      g: Gradient for `self._v`; an `ops.IndexedSlices` takes the sparse path,
        anything else the dense path.

    Returns:
      A no-op when `params` are not trainable; otherwise an op grouping the
      `update_op()` of the variable and of all its slots.

    Raises:
      RuntimeError: If `g` is sparse and `self._v` has a constraint function.
    """
    # pylint: disable=protected-access
    if not self._v.params.trainable:
      return control_flow_ops.no_op()

    with ops.colocate_with(None, ignore_existing=True):
      slots = [
          optimizer.get_slot(self._v, slot_name)
          for slot_name in optimizer.get_slot_names()
      ]
      self._v._track_optimizer_slots(slots)

      # Snapshot pre-update values once `g` exists; the apply op below waits
      # on these reads. NOTE(review): prefetch is disabled when `bp_v2` is
      # set -- confirm the intended semantics of that flag.
      with ops.control_dependencies([g]):
        v0 = self._v.read_value(do_prefetch=not self._v.params.bp_v2)
        s0 = [slot.read_value() for slot in slots]
        before = [v0] + s0

      # Runs update_op(v0=...) on the variable and every slot after `deps`.
      def _write_back(deps):
        with ops.control_dependencies(deps):
          return control_flow_ops.group(
              [self._v.update_op(v0=v0)] +
              [slot.update_op(v0=slot_v0) for slot, slot_v0 in zip(slots, s0)])

      if isinstance(g, ops.IndexedSlices):
        if self._v.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        with ops.control_dependencies(before):
          apply_op = optimizer._resource_apply_sparse_duplicate_indices(
              g.values, self._v, g.indices)
        return _write_back([apply_op])

      with ops.control_dependencies(before):
        apply_op = optimizer._resource_apply_dense(g, self._v)
      if self._v.constraint is not None:
        # Apply the constraint, then still run the same update_op() group as
        # the unconstrained path; returning the assign directly would skip the
        # write-back of both the variable and its slots.
        with ops.control_dependencies([apply_op]):
          constrain_op = self._v.assign(self._v.constraint(self._v))
        return _write_back([constrain_op])
      return _write_back([apply_op])
 def update_op_asynchronous(self, optimizer, g, index):
     """Returns the op applying gradient `g` to `self._v`, forwarding `index`.

     A sparse gradient (`ops.IndexedSlices`) is routed to
     `_resource_apply_sparse_duplicate_indices`; any other gradient goes to
     `_resource_apply_dense`. In both cases `index` is passed as the extra
     trailing argument. NOTE(review): this assumes an optimizer whose apply
     methods accept that extra argument; also, the variable's `constraint` is
     not consulted on either path here -- confirm that is intended.
     """
     # pylint: disable=protected-access
     if not isinstance(g, ops.IndexedSlices):
         return optimizer._resource_apply_dense(g, self._v, index)
     return optimizer._resource_apply_sparse_duplicate_indices(
         g.values, self._v, g.indices, index)
 def update_op_asynchronous(self, optimizer, g, index):
   """Dispatches gradient `g` for `self._v` to the optimizer, passing `index`.

   Sparse gradients (`ops.IndexedSlices`) use
   `_resource_apply_sparse_duplicate_indices(values, var, indices, index)`;
   all others use `_resource_apply_dense(g, var, index)`. The resulting op is
   returned unchanged. NOTE(review): the variable's `constraint` is not
   applied or checked here -- confirm that is intended.
   """
   # pylint: disable=protected-access
   if isinstance(g, ops.IndexedSlices):
     apply_fn = optimizer._resource_apply_sparse_duplicate_indices
     apply_args = (g.values, self._v, g.indices, index)
   else:
     apply_fn = optimizer._resource_apply_dense
     apply_args = (g, self._v, index)
   return apply_fn(*apply_args)