Example #1
    def _get_queue_ops(self, var_update_op: ops.Operation,
                       source_op: ops.Operation, is_chief: bool,
                       is_trainable: bool) -> List[ops.Operation]:
        """
        Get queue operations for synchronous parameter update.

        Maintain a list of queues of size 1, one per worker. The chief machine pushes a token to each queue at the
        beginning of each update. The other workers then dequeue a token from their corresponding queue once their
        gradient has been sent to the accumulator. The enqueue and dequeue operations are grouped and have to be
        completed before the model moves on to the next step, resulting in a synchronous parameter update.

        Args:
            var_update_op: The op that applies the aggregated update to the variable.
            source_op: An additional op the chief must wait on before enqueueing tokens.
            is_chief: Whether the current worker is the chief.
            is_trainable: Whether the variable is trainable (i.e. has a gradient accumulator).

        Returns:
            A list of queue operations.
        """
        var_op = var_update_op.inputs[UPDATE_OP_VAR_POS].op

        var_update_sync_queues = \
            [data_flow_ops.FIFOQueue(1, [dtypes.bool], shapes=[[]],
                                     name='%s_update_sync_queue_%d' % (var_op.name, i),
                                     shared_name='%s_update_sync_queue_%d' % (var_op.name, i))
             for i in range(self.num_workers)]

        queue_ops = []
        if is_chief:
            if is_trainable:
                var_update_deps = [
                    self._var_op_to_accum_apply_op[var_op], source_op
                ]
            else:
                var_update_deps = [var_update_op]
            # Chief enqueues tokens to all other workers after executing variable update
            token = constant_op.constant(False)
            with ops.control_dependencies(var_update_deps):
                for i, q in enumerate(var_update_sync_queues):
                    if i != self.worker_id:
                        queue_ops.append(q.enqueue(token))
                    else:
                        queue_ops.append(gen_control_flow_ops.no_op())
        else:
            # wait for execution of var_update_op
            if is_trainable:
                with ops.control_dependencies(
                    [self._var_op_to_accum_apply_op[var_op]]):
                    dequeue = var_update_sync_queues[self.worker_id].dequeue()
            else:
                dequeue = var_update_sync_queues[self.worker_id].dequeue()
            queue_ops.append(dequeue)

        return queue_ops
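The docstring above describes a token barrier built from size-1 queues: the chief enqueues one boolean token per non-chief worker once its update dependencies have run, and every other worker blocks on a dequeue from its own queue. Below is a minimal single-process sketch of that barrier using the public TF 1.x API; the worker count, names and the stand-in update op are illustrative only.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

num_workers, chief_id = 3, 0

# One size-1 queue per worker; shared_name would let separate worker
# processes attach to the same queue in a real distributed run.
queues = [tf.FIFOQueue(1, [tf.bool], shapes=[[]],
                       shared_name='update_sync_queue_%d' % i)
          for i in range(num_workers)]

var = tf.get_variable('w', initializer=tf.constant(0.0))
var_update = var.assign_add(1.0)  # stand-in for the real variable update

# Chief: release one token per non-chief worker, but only after the update.
with tf.control_dependencies([var_update]):
    chief_enqueues = tf.group(*[q.enqueue(tf.constant(False))
                                for i, q in enumerate(queues) if i != chief_id])

# Non-chief workers: each blocks on its own queue until the chief's token arrives.
worker_dequeues = [queues[i].dequeue() for i in range(num_workers) if i != chief_id]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(chief_enqueues)   # update runs first, then tokens are enqueued
    sess.run(worker_dequeues)  # would block if the tokens were not there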
Example #2
    def _resource_apply_sparse(self, grad, var, indices):
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        beta_1_t = self._get_hyper('beta_1', var_dtype)
        beta_2_t = self._get_hyper('beta_2', var_dtype)
        epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        beta_1_power = math_ops.pow(beta_1_t, local_step)
        beta_2_power = math_ops.pow(beta_2_t, local_step)

        decay_steps = self._get_hyper('decay_steps', var_dtype)
        warmup_steps = self._get_hyper('warmup_steps', var_dtype)
        min_lr = self._get_hyper('min_lr', var_dtype)
        lr_t = tf.where(
            local_step <= warmup_steps,
            lr_t * (local_step / warmup_steps),
            min_lr + (lr_t - min_lr) *
            (1.0 - tf.minimum(local_step, decay_steps) / decay_steps),
        )
        lr_t = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))

        m = self.get_slot(var, 'm')
        m_scaled_g_values = grad * (1 - beta_1_t)
        m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
        with ops.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

        v = self.get_slot(var, 'v')
        v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
        v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
        with ops.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

        if self.amsgrad:
            v_hat = self.get_slot(var, 'vhat')
            v_hat_t = math_ops.maximum(v_hat, v_t)
            var_update = m_t / (math_ops.sqrt(v_hat_t) + epsilon_t)
        else:
            var_update = m_t / (math_ops.sqrt(v_t) + epsilon_t)

        if self._initial_weight_decay > 0.0:
            weight_decay = self._get_hyper('weight_decay', var_dtype)
            var_update += weight_decay * var
        var_update = state_ops.assign_sub(var,
                                          lr_t * var_update,
                                          use_locking=self._use_locking)

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(v_hat_t)
        return control_flow_ops.group(*updates)
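The `tf.where` expression above implements linear warmup followed by linear decay toward `min_lr`, after which the usual Adam bias correction is folded into the learning rate. A plain-Python mirror of that schedule (the constants are chosen only for illustration):

def scheduled_lr(step, base_lr=1e-3, warmup_steps=100, decay_steps=1000, min_lr=1e-5):
    """Eager mirror of the tf.where schedule above (before bias correction)."""
    if step <= warmup_steps:
        return base_lr * step / warmup_steps
    return min_lr + (base_lr - min_lr) * (1.0 - min(step, decay_steps) / decay_steps)

# scheduled_lr(50)   -> 5e-4   (half-way through warmup)
# scheduled_lr(1000) -> 1e-5   (fully decayed to min_lr)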
Example #3
    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
        var_dtype = var.dtype.base_dtype
        beta1_power, beta2_power = self._get_beta_accumulators()
        beta1_power = math_ops.cast(beta1_power, var_dtype)
        beta2_power = math_ops.cast(beta2_power, var_dtype)
        niter = self._get_niter()
        niter = math_ops.cast(niter, var_dtype)
        lr_t = math_ops.cast(self._lr_t, var_dtype)
        beta1_t = math_ops.cast(self._beta1_t, var_dtype)
        beta2_t = math_ops.cast(self._beta2_t, var_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var_dtype)

        sma_inf = 2.0 / (1.0 - beta2_t) - 1.0
        sma_t = sma_inf - 2.0 * niter * beta2_power / (1.0 - beta2_power)

        m = self.get_slot(var, 'm')
        m_t = state_ops.assign(m, beta1_t * m, use_locking=self._use_locking)
        with ops.control_dependencies([m_t]):
            m_t = scatter_add(m, indices, grad * (1 - beta1_t))
        m_corr_t = m_t / (1.0 - beta1_power)

        v = self.get_slot(var, 'v')
        v_t = state_ops.assign(v, beta2_t * v, use_locking=self._use_locking)
        with ops.control_dependencies([v_t]):
            v_t = scatter_add(v, indices,
                              (1.0 - beta2_t) * math_ops.square(grad))

        if self._amsgrad:
            vhat = self.get_slot(var, 'vhat')
            vhat_t = state_ops.assign(vhat,
                                      math_ops.maximum(vhat, v_t),
                                      use_locking=self._use_locking)
            v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta2_power) + epsilon_t)
        else:
            v_corr_t = math_ops.sqrt(v_t / (1.0 - beta2_power) + epsilon_t)

        r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) * (sma_t - 2.0) /
                            (sma_inf - 2.0) * sma_inf / sma_t)

        var_t = tf.where(sma_t > 5.0, r_t * m_corr_t / v_corr_t, m_corr_t)

        var_update = state_ops.assign_sub(var,
                                          lr_t * var_t,
                                          use_locking=self._use_locking)

        updates = [var_update, m_t, v_t]
        if self._amsgrad:
            updates.append(vhat_t)
        return control_flow_ops.group(*updates)
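The `sma_inf`, `sma_t` and `r_t` tensors above are the rectification quantities from Rectified Adam; with t the iteration count,

\rho_\infty = \frac{2}{1 - \beta_2} - 1,\qquad
\rho_t = \rho_\infty - \frac{2\, t\, \beta_2^{\,t}}{1 - \beta_2^{\,t}},\qquad
r_t = \sqrt{\frac{(\rho_t - 4)(\rho_t - 2)\,\rho_\infty}{(\rho_\infty - 4)(\rho_\infty - 2)\,\rho_t}}

and the `tf.where` on `sma_t > 5.0` applies the variance-rectified step only when `sma_t` exceeds the usual RAdam threshold of 5, falling back to the bias-corrected first moment alone otherwise.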
Example #4
        def build_and_run_model():
            dataset = dataset_ops.Dataset.from_tensor_slices(
                np.ones(10, dtype=np.float32))
            infeed_queue = ipu.ipu_infeed_queue.IPUInfeedQueue(
                dataset, "infeed")
            outfeed_queue = ipu.ipu_outfeed_queue.IPUOutfeedQueue("outfeed")

            def body(v, x):
                v = v + x
                outfed = outfeed_queue.enqueue(v)
                return v, outfed

            def my_net(v):
                return ipu.loops.repeat(10, body, v, infeed_queue)

            v = array_ops.placeholder(np.float32, shape=())
            with ipu.scopes.ipu_scope("/device:IPU:0"):
                [result] = ipu.ipu_compiler.compile(my_net, inputs=[v])
            with ops.control_dependencies([result]):
                dequeued = outfeed_queue.dequeue()

            with session.Session() as sess:
                report = ReportJSON(
                    self, sess, set_opts_fn=_use_offline_compilation_if_needed)
                sess.run(infeed_queue.initializer)
                try:
                    res, deq = sess.run([result, dequeued], {v: 0.0})
                except errors.InvalidArgumentError as e:
                    if offline_compilation_needed and "compilation only" in e.message:
                        res = []
                        deq = []
                    else:
                        raise
                events = report.get_event_trace(sess)
                return res, deq, events
Example #5
def _update_t_cur_eta_t(self):  # keras
    self.updates.append(_update_t_cur(self))
    # Cosine annealing
    if self.use_cosine_annealing:
        # ensure eta_t is updated AFTER t_cur
        with ops.control_dependencies([self.updates[-1]]):
            self.updates.append(state_ops.assign(self.eta_t,
                                                 _compute_eta_t(self)))
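`_compute_eta_t` is not shown here; in the SGDR-style cosine annealing this pattern follows, the multiplier is typically

\eta_t = \eta_{\min} + \tfrac{1}{2}\,(\eta_{\max} - \eta_{\min})\,\bigl(1 + \cos(\pi\, t_{cur} / T)\bigr)

where T is the restart period. The control dependency only guarantees ordering: `t_cur` is incremented before `eta_t` is recomputed from it.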
Example #6
        def my_model_fn(features, mode):
            logging_op = hook.log({"features": features})
            with ops.control_dependencies([logging_op]):
                predictions = math_ops.reduce_max(features)

            return model_fn_lib.EstimatorSpec(
                mode,
                predictions=predictions,
            )
Example #7
    def get_training_loss_and_op(self, compiled_training_loop):
        with ops.device(_HOST_DEVICE):
            with ops.control_dependencies([compiled_training_loop]):
                loss = self._outfeed_queue.dequeue()

            # Reduce loss over all dimensions (i.e. batch_size, gradient_accumulation_count)
            loss = math_ops.reduce_mean(math_ops.cast(loss, dtypes.float32))

        train_op = compiled_training_loop

        return loss, train_op
Example #8
    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
        learning_rate_t = math_ops.cast(self.learning_rate_t,
                                        var.dtype.base_dtype)
        beta_1_t = math_ops.cast(self.beta_1_t, var.dtype.base_dtype)
        beta_2_t = math_ops.cast(self.beta_2_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self.epsilon_t, var.dtype.base_dtype)
        weight_decay_rate_t = math_ops.cast(self.weight_decay_rate_t,
                                            var.dtype.base_dtype)

        m = self.get_slot(var, 'm')
        v = self.get_slot(var, 'v')
        beta1_power, beta2_power = self._get_beta_accumulators()
        beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
        beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
        learning_rate_t = math_ops.cast(self.learning_rate_t,
                                        var.dtype.base_dtype)
        learning_rate_t = (learning_rate_t * math_ops.sqrt(1 - beta2_power) /
                           (1 - beta1_power))

        m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)

        m_scaled_g_values = grad * (1 - beta_1_t)
        with ops.control_dependencies([m_t]):
            m_t = scatter_add(m, indices, m_scaled_g_values)

        v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
        v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
        with ops.control_dependencies([v_t]):
            v_t = scatter_add(v, indices, v_scaled_g_values)

        update = m_t / (math_ops.sqrt(v_t) + epsilon_t)

        if self._do_use_weight_decay(var.name):
            update += weight_decay_rate_t * var

        update_with_lr = learning_rate_t * update

        var_update = state_ops.assign_sub(var,
                                          update_with_lr,
                                          use_locking=self._use_locking)
        return control_flow_ops.group(*[var_update, m_t, v_t])
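Putting this `_apply_sparse_shared` together, the applied step is Adam with the bias correction folded into the learning rate and, when `_do_use_weight_decay(var.name)` holds, decoupled weight decay added to the step rather than to the gradient:

\hat{\alpha}_t = \alpha\,\frac{\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}},\qquad
\theta \leftarrow \theta - \hat{\alpha}_t\left(\frac{m_t}{\sqrt{v_t} + \epsilon} + \lambda\,\theta\right)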
Example #9
    def _finish(self, update_ops, name_scope):
        # Update the power accumulators.
        with ops.control_dependencies(update_ops):
            beta1_power, beta2_power = self._get_beta_accumulators()
            with ops.colocate_with(beta1_power):
                update_beta1 = beta1_power.assign(
                    beta1_power * self.beta_1_t, use_locking=self._use_locking)
                update_beta2 = beta2_power.assign(
                    beta2_power * self.beta_2_t, use_locking=self._use_locking)
            return control_flow_ops.group(*update_ops +
                                          [update_beta1, update_beta2],
                                          name=name_scope)
Example #10
    def _resource_apply_dense(self, grad, var):
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        beta_1_t = self._get_hyper('beta_1', var_dtype)
        beta_2_t = self._get_hyper('beta_2', var_dtype)
        accumulation_steps = self._get_hyper('accumulation_steps', 'int64')
        update_cond = tf.equal((self.iterations + 1) % accumulation_steps, 0)
        sub_step = self.iterations % accumulation_steps + 1
        local_step = math_ops.cast(self.iterations // accumulation_steps + 1,
                                   var_dtype)
        beta_1_power = math_ops.pow(beta_1_t, local_step)
        beta_2_power = math_ops.pow(beta_2_t, local_step)
        epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
        lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))
        lr = tf.where(update_cond, lr, 0.0)

        g = self.get_slot(var, 'g')
        g_a = grad / math_ops.cast(accumulation_steps, var_dtype)
        g_t = tf.where(tf.equal(sub_step, 1), g_a,
                       g + (g_a - g) / math_ops.cast(sub_step, var_dtype))
        g_t = state_ops.assign(g, g_t, use_locking=self._use_locking)

        m = self.get_slot(var, 'm')
        m_t = tf.where(update_cond, m * beta_1_t + g_t * (1 - beta_1_t), m)
        m_t = state_ops.assign(m, m_t, use_locking=self._use_locking)

        v = self.get_slot(var, 'v')
        v_t = tf.where(update_cond,
                       v * beta_2_t + (g_t * g_t) * (1 - beta_2_t), v)
        v_t = state_ops.assign(v, v_t, use_locking=self._use_locking)

        if not self.amsgrad:
            v_sqrt = math_ops.sqrt(v_t)
            var_update = state_ops.assign_sub(var,
                                              lr * m_t / (v_sqrt + epsilon_t),
                                              use_locking=self._use_locking)
            return control_flow_ops.group(*[var_update, m_t, v_t])
        else:
            v_hat = self.get_slot(var, 'vhat')
            v_hat_t = tf.where(update_cond, math_ops.maximum(v_hat, v_t),
                               v_hat)
            with ops.control_dependencies([v_hat_t]):
                v_hat_t = state_ops.assign(v_hat,
                                           v_hat_t,
                                           use_locking=self._use_locking)
            v_hat_sqrt = math_ops.sqrt(v_hat_t)
            var_update = state_ops.assign_sub(var,
                                              lr * m_t /
                                              (v_hat_sqrt + epsilon_t),
                                              use_locking=self._use_locking)
            return control_flow_ops.group(*[var_update, m_t, v_t, v_hat_t])
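The `g` slot above accumulates gradients as a running mean rather than a sum: with g_a = grad / accumulation_steps and k = sub_step, the `tf.where` reassignment is the incremental-mean recurrence

g^{(1)} = g_a^{(1)},\qquad
g^{(k)} = g^{(k-1)} + \frac{g_a^{(k)} - g^{(k-1)}}{k} = \frac{1}{k}\sum_{i=1}^{k} g_a^{(i)}

On every sub-step except the last of each window `update_cond` is false, so `lr` is zeroed and the `tf.where` branches leave `m`, `v` (and `vhat`) unchanged; the Adam step is applied only once per `accumulation_steps` iterations.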
Example #11
def test_parse_name_scope():
    with ops.Graph().as_default():
        name_scope = 'name_scope/child_name_scope'
        a = constant_op.constant(5)
        new_name = ops.prepend_name_scope(a.name, name_scope)
        assert new_name == 'name_scope/child_name_scope/Const:0'
        assert name_scope == utils.parse_name_scope(new_name)
        assert '' == utils.parse_name_scope(a.name)

        with ops.control_dependencies([no_op(name='my_op')]):
            b = constant_op.constant(6)
        name_scope = 'name_scope'
        new_name = ops.prepend_name_scope(b.op.node_def.input[0], name_scope)
        assert new_name == '^name_scope/my_op'
        assert name_scope == utils.parse_name_scope(new_name)
Example #12
    def get_predictions(self, compiled_prediction_loop):
        with ops.device(_HOST_DEVICE):
            with ops.control_dependencies([compiled_prediction_loop]):
                predictions = self._outfeed_queue.dequeue()

        if isinstance(predictions, dict):
            return predictions

        assert isinstance(predictions, list)
        if len(predictions) != 1:
            raise ValueError((
                "The last computational stage must return exactly one prediction "
                "tensor, but got {}").format(len(predictions)))

        return predictions[0]
Example #13
  def restore(self, restored_tensors, unused_restored_shapes):
    """Restores the associated tree from 'restored_tensors'.

    Args:
      restored_tensors: the tensors that were loaded from a checkpoint.
      unused_restored_shapes: the shapes this object should conform to after
        restore. Not meaningful for trees.

    Returns:
      The operation that restores the state of the tree variable.
    """
    with ops.control_dependencies([self._create_op]):
      return self._deserialize_op_func(
          self._resource_handle,
          restored_tensors[0],
      )
Example #14
  def restore(self, restored_tensors, unused_restored_shapes):
    """Restores the associated tree from 'restored_tensors'.

    Args:
      restored_tensors: the tensors that were loaded from a checkpoint.
      unused_restored_shapes: the shapes this object should conform to after
        restore. Not meaningful for trees.

    Returns:
      The operation that restores the state of the tree variable.
    """
    with ops.control_dependencies([self._create_op]):
      return self._deserialize_op_func(
          self._resource_handle,
          restored_tensors[0],
      )
Example #15
    def _aggregate_sparse_gradients(self, var_op, reduce_to_device,
                                    indexed_slices_grads, values_op_name):
        with ops.device(reduce_to_device):
            grad_accum_op_name = ops.prepend_name_scope(
                values_op_name, u"%sAccum" % AUTODIST_PREFIX)
            grad_accum = data_flow_ops.SparseConditionalAccumulator(
                dtype=indexed_slices_grads[0].values.dtype,
                shape=var_op.outputs[0].shape,
                shared_name=grad_accum_op_name,
                name=grad_accum_op_name)
            accum_apply_ops = [
                grad_accum.apply_indexed_slices_grad(
                    indexed_slices_grads[i],
                    MAX_INT64,
                    name=ops.prepend_name_scope(
                        values_op_name, u"%s-Accum-Apply" % replica_prefix(i)))
                for i in range(self.num_replicas)
            ]
            take_grad_op_name = ops.prepend_name_scope(
                values_op_name, u"%sTake-Grad" % AUTODIST_PREFIX)
            with ops.control_dependencies(accum_apply_ops):
                take_grad = grad_accum.take_indexed_slices_grad(
                    self.num_replicas, name=take_grad_op_name)

            new_indices = take_grad.indices
            new_values = take_grad.values
            new_dense_shape = take_grad.dense_shape
            if indexed_slices_grads[0].indices.dtype != new_indices.dtype:
                new_indices = math_ops.cast(
                    new_indices,
                    indexed_slices_grads[0].indices.dtype,
                    name=ops.prepend_name_scope(
                        values_op_name,
                        u"%sTake-Grad-Cast-Indices" % AUTODIST_PREFIX))
            if indexed_slices_grads[
                    0].dense_shape.dtype != new_dense_shape.dtype:
                new_dense_shape = math_ops.cast(
                    new_dense_shape,
                    indexed_slices_grads[0].dense_shape.dtype,
                    name=ops.prepend_name_scope(
                        values_op_name,
                        u"%sTake-Grad-Cast-Shape" % AUTODIST_PREFIX))
        return ops.IndexedSlices(new_values, new_indices, new_dense_shape)
Example #16
    def get_all_update_ops(self, grad_apply_finished, worker_device=None):
        """
        Create and return new update ops for proxy vars.

        Args:
            grad_apply_finished (List[Operation]): ops with which to colocate the new ops.
            worker_device (DeviceSpecV2): the device on which to create the ops.

        Returns:
            List[Operation]: the list of update ops for each proxy variable.
        """
        with ops.device(worker_device):
            with ops.control_dependencies(grad_apply_finished):
                updated_value = gen_read_var_op(
                    self._this_op, self._dtype)  # create new read var op
        update_ops = []
        for proxy_var in self._proxy_vars:
            with ops.device(proxy_var.device):
                update_ops.append(proxy_var.assign(updated_value))
        return update_ops
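Stripped of the AutoDist-internal `gen_read_var_op`, the pattern above is: once the gradient-application ops have run, take a fresh read of the master variable and assign it to every proxy copy on its own device. A rough public-API sketch of the same idea, with the device placement omitted and the variable names purely illustrative:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

master = tf.get_variable('w', initializer=tf.constant([1.0, 2.0]), use_resource=True)
proxies = [tf.get_variable('w_proxy_%d' % i, initializer=tf.constant([0.0, 0.0]),
                           use_resource=True)
           for i in range(2)]

grad_apply = master.assign_add([0.5, 0.5])    # stand-in for the real gradient update

with tf.control_dependencies([grad_apply]):
    updated_value = master.read_value()       # read only after the update has run

update_ops = [p.assign(updated_value) for p in proxies]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(update_ops)                      # proxies now hold [1.5, 2.5]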
Example #17
    def get_evaluation_loss_and_metrics(self, compiled_evaluation_loop):
        with ops.device(_HOST_DEVICE):
            with ops.control_dependencies([compiled_evaluation_loop]):
                inputs = self._outfeed_queue.dequeue()

            args, kwargs = loops._body_arguments(inputs)  # pylint: disable=protected-access
            metrics = self._captured_eval_metrics_fn(*args, **kwargs)

        if not isinstance(metrics, dict):
            raise TypeError(("The `eval_metrics_fn` must return a dict, "
                             "but got {}.").format(type(metrics)))

        if model_fn_lib.LOSS_METRIC_KEY not in metrics:
            raise KeyError(
                ("The dict returned from `eval_metrics_fn` "
                 "must contain '{}'.").format(model_fn_lib.LOSS_METRIC_KEY))

        loss = metrics.pop(model_fn_lib.LOSS_METRIC_KEY)

        return loss, metrics
Example #18
def _update_t_cur_eta_t_v2(self, lr_t=None, var=None):  # tf.keras
    t_cur_update, eta_t_update = None, None  # in case not assigned

    # update `t_cur` if iterating last `(grad, var)`
    iteration_done = (self._updates_processed == (self._updates_per_iter - 1))
    if iteration_done:
        t_cur_update = _update_t_cur(self)
        self._updates_processed = 0  # reset
    else:
        self._updates_processed += 1

    # Cosine annealing
    if self.use_cosine_annealing and iteration_done:
        # ensure eta_t is updated AFTER t_cur
        with ops.control_dependencies([t_cur_update]):
            eta_t_update = state_ops.assign(self.eta_t, _compute_eta_t(self),
                                            use_locking=self._use_locking)
        self.lr_t = lr_t * self.eta_t  # for external tracking

    return iteration_done, t_cur_update, eta_t_update
Example #19
    def _resource_apply_sparse(self, grad, var, indices):
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        beta_1_t = self._get_hyper('beta_1', var_dtype)
        beta_2_t = self._get_hyper('beta_2', var_dtype)
        epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        beta_1_power = math_ops.pow(beta_1_t, local_step)
        beta_2_power = math_ops.pow(beta_2_t, local_step)

        if self._initial_total_steps > 0:
            total_steps = self._get_hyper('total_steps', var_dtype)
            warmup_steps = total_steps * self._get_hyper('warmup_proportion', var_dtype)
            min_lr = self._get_hyper('min_lr', var_dtype)
            decay_steps = K.maximum(total_steps - warmup_steps, 1)
            decay_rate = (min_lr - lr_t) / decay_steps
            lr_t = tf.where(
                local_step <= warmup_steps,
                lr_t * (local_step / warmup_steps),
                lr_t + decay_rate * K.minimum(local_step - warmup_steps, decay_steps),
            )

        sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0
        sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)

        m = self.get_slot(var, 'm')
        m_scaled_g_values = grad * (1 - beta_1_t)
        m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
        with ops.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
        m_corr_t = m_t / (1.0 - beta_1_power)

        v = self.get_slot(var, 'v')
        v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
        v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
        with ops.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

        if self.amsgrad:
            vhat = self.get_slot(var, 'vhat')
            vhat_t = state_ops.assign(vhat,
                                      math_ops.maximum(vhat, v_t),
                                      use_locking=self._use_locking)
            v_corr_t = math_ops.sqrt(vhat_t / (1.0 - beta_2_power))
        else:
            vhat_t = None
            v_corr_t = math_ops.sqrt(v_t / (1.0 - beta_2_power))

        r_t = math_ops.sqrt((sma_t - 4.0) / (sma_inf - 4.0) *
                            (sma_t - 2.0) / (sma_inf - 2.0) *
                            sma_inf / sma_t)

        var_t = tf.where(sma_t >= 5.0, r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t)

        if self._initial_weight_decay > 0.0:
            var_t += self._get_hyper('weight_decay', var_dtype) * var

        var_update = self._resource_scatter_add(var, indices, tf.gather(-lr_t * var_t, indices))

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(vhat_t)
        return control_flow_ops.group(*updates)
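The final `_resource_scatter_add(var, indices, tf.gather(-lr_t * var_t, indices))` applies the step only to the rows addressed by `indices` rather than assigning the full dense update. In NumPy terms (the values are purely illustrative):

import numpy as np

var = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
step = np.array([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]])  # stands in for lr_t * var_t
indices = np.array([0, 2])

var[indices] += (-step)[indices]  # rows 0 and 2 move; row 1 is untouched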
Example #20
    def _resource_scatter_add(self, x, i, v):
        with ops.control_dependencies(
                [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
            return x.value()
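This helper is the building block the sparse optimizer examples above call: it issues a `resource_scatter_add` on the variable's handle and, because of the control dependency, returns a read of the variable that reflects the addition. A standalone sketch of the same pattern (the variable name and values are illustrative):

import tensorflow.compat.v1 as tf
from tensorflow.python.ops import resource_variable_ops

tf.disable_eager_execution()

m = tf.get_variable('m', initializer=tf.constant([0.0, 0.0, 0.0]), use_resource=True)
indices = tf.constant([0, 2])
updates = tf.constant([1.5, -2.0])

# Scatter-add first, then read: the returned tensor observes the addition.
with tf.control_dependencies(
        [resource_variable_ops.resource_scatter_add(m.handle, indices, updates)]):
    m_after = m.value()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(m_after))  # [ 1.5  0.  -2. ]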
Example #21
    def _get_queue_ops_stale(self, var_update_op: ops.Operation,
                             source_op: ops.Operation, is_chief: bool,
                             is_trainable: bool) -> List[ops.Operation]:
        """
        Get queue operations for staleness synchronous parameter update.

        Maintain a list of queues, each with capacity equal to <staleness>. At the beginning of each call to this
        function (whether by the chief worker or by another worker), it checks whether every queue still has spare
        capacity. If so, it pushes a token onto each queue; otherwise it does nothing (a no_op).
        The calling worker then dequeues a token from its own queue (indexed by its worker id).
        The conditional enqueue operations and the unconditional dequeue operation are grouped together and have to
        be finished before the model moves on to the next step.
        Since each invocation fills at most one row of empty space across the queues, a worker can perform up to
        <staleness> consecutive dequeue operations without blocking, achieving stale-synchronous parameter update
        with at most <staleness> steps of difference between workers.

        Args:
            var_update_op: The op that applies the aggregated update to the variable.
            source_op: An additional op the chief must wait on before dequeueing its token.
            is_chief: Whether the current worker is the chief.
            is_trainable: Whether the variable is trainable (i.e. has a gradient accumulator).

        Returns:
            A list of queue operations.
        """
        var_op = var_update_op.inputs[UPDATE_OP_VAR_POS].op

        var_update_sync_queues = \
            [data_flow_ops.FIFOQueue(self._staleness, [dtypes.bool], shapes=None,
                                     name='%s_update_sync_queue_%d' % (var_op.name, i),
                                     shared_name='%s_update_sync_queue_%d' % (var_op.name, i))
             for i in range(self.num_workers)]

        # Enqueue one token to every queue if all queues are not full.
        def _enqueue_row_op():
            enqueue_ops = []
            for q in var_update_sync_queues:
                enqueue_ops.append(q.enqueue(False))
            enqueue_a_row_ops = control_flow_ops.group(*enqueue_ops)
            return enqueue_a_row_ops

        def _no_op():
            return gen_control_flow_ops.no_op()

        switch_cond = gen_array_ops.identity(True)
        for q in var_update_sync_queues:
            switch_cond = gen_math_ops.logical_and(
                switch_cond,
                gen_math_ops.less(q.size(),
                                  gen_array_ops.identity(self._staleness)))

        enqueue_a_row_ops = control_flow_ops.cond(switch_cond, _enqueue_row_op,
                                                  _no_op)

        queue_ops = [enqueue_a_row_ops]

        if is_chief:
            if is_trainable:
                var_update_deps = [
                    self._var_op_to_accum_apply_op[var_op], source_op
                ]
            else:
                var_update_deps = [var_update_op]
            with ops.control_dependencies(var_update_deps):
                dequeue = var_update_sync_queues[self.worker_id].dequeue()
        else:
            # wait for execution of var_update_op
            if is_trainable:
                with ops.control_dependencies(
                    [self._var_op_to_accum_apply_op[var_op]]):
                    dequeue = var_update_sync_queues[self.worker_id].dequeue()
            else:
                dequeue = var_update_sync_queues[self.worker_id].dequeue()
        queue_ops.append(dequeue)

        return queue_ops
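The queue capacity is what bounds the staleness: after one row of tokens has been enqueued, a worker can complete up to <staleness> dequeues (steps) before it blocks waiting for another row. A single-process sketch of that budget (the constants are illustrative):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

staleness = 3
q = tf.FIFOQueue(staleness, [tf.bool], shapes=[[]])
enqueue = q.enqueue(tf.constant(False))
dequeue = q.dequeue()

with tf.Session() as sess:
    for _ in range(staleness):   # one full "row" of tokens for this worker
        sess.run(enqueue)
    for _ in range(staleness):   # the worker can run this many steps ahead
        sess.run(dequeue)
    # one more sess.run(dequeue) here would block until a new token arrives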
Example #22
    def body(v):
        logging_op = hook.log({"foo": v})
        with ops.control_dependencies([logging_op]):
            return v + 1