예제 #1
0
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Reduce gradients over accumulation steps and replicas, then apply them.

        For each non-None gradient, colocated with that gradient:
          * if ``self._gradients_to_accumulate > 1`` and not pipelining, the
            gradient is divided by the accumulation count and fed through
            ``ipu_stateful_gradient_accumulate``;
          * if ``self._replicas > 1``, it is summed across replicas and passed
            through ``ipu_replication_normalise``;
          * it is cast to its variable's dtype.
        ``None`` gradients are forwarded unchanged. When pipelining, weight
        decay is added with ``self.add_WD`` before the wrapped optimizer runs.
        If ``self._sharded`` is set, sharding is propagated over the default
        graph afterwards.

        Args:
          grads_and_vars: Iterable of (gradient, variable) pairs.
          global_step: Forwarded to the wrapped optimizer.
          name: Forwarded to the wrapped optimizer.

        Returns:
          Whatever ``self._optimizer.apply_gradients`` returns.
        """

        def reduce_single(grad, var):
            # Every op for this gradient is placed alongside the gradient.
            with ops.colocate_with(grad):
                if self._gradients_to_accumulate > 1 and not self._pipelining:
                    grad = gen_poputil_ops.ipu_stateful_gradient_accumulate(
                        grad / self._gradients_to_accumulate,
                        num_mini_batches=self._gradients_to_accumulate)
                if self._replicas > 1:
                    grad = gen_poputil_ops.ipu_replication_normalise(
                        cross_replica_ops.cross_replica_sum(grad))
                return math_ops.cast(grad, var.dtype)

        reduced = [
            (grad, var) if grad is None else (reduce_single(grad, var), var)
            for grad, var in grads_and_vars
        ]

        if self._pipelining:
            # Weight decay is added here because, per the original design,
            # apply_gradients only runs on the last accumulation step.
            reduced = self.add_WD(reduced)

        outcome = self._optimizer.apply_gradients(reduced, global_step, name)
        if self._sharded:
            sharding.propagate_sharding(ops.get_default_graph())
        return outcome
예제 #2
0
            def grad_fn(grad, param, m, v):
                """Compute one Adam step with decoupled weight decay for one parameter.

                ``self`` and ``param_name`` are taken from the enclosing scope.
                No bias correction is applied to the moment estimates.

                Args:
                  grad: Gradient of ``param`` (loss-scaled).
                  param: Current parameter value.
                  m: First-moment accumulator.
                  v: Second-moment accumulator.

                Returns:
                  ``(next_param, next_v, next_m)`` -- note the velocity is
                  returned before the momentum.
                """
                if self.add_cross_replica_sums:
                    grad = cross_replica_ops.cross_replica_sum(grad)

                # Work in fp32 and undo the loss scaling.
                g32 = tf.cast(grad, dtype=tf.float32) / self.loss_scaling

                # Exponential moving averages of the gradient and its square.
                momentum = (tf.multiply(self.beta_1, m) +
                            tf.multiply(1.0 - self.beta_1, g32))
                velocity = (tf.multiply(self.beta_2, v) +
                            tf.multiply(1.0 - self.beta_2, tf.square(g32)))

                step = tf.cast(momentum / (tf.sqrt(velocity) + self.epsilon),
                               param.dtype)

                # Weight decay is added to the step directly rather than as an
                # L2 term in the loss: an L2 term would flow through m and v and
                # interact with the adaptive scaling. Adding it here matches
                # plain (non-momentum) SGD on an L2-regularised loss.
                if self._do_use_weight_decay(param_name):
                    step += self.weight_decay_rate * param

                return param - self.learning_rate * step, velocity, momentum
예제 #3
0
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Average gradients across replicas, then apply the wrapped optimizer.

        Every non-None gradient is summed across replicas with
        ``cross_replica_ops.cross_replica_sum`` and the sum is then passed
        through ``gen_poputil_ops.ipu_replication_normalise``, which normalises
        it by the replication factor. The wrapped optimizer therefore receives
        the replica-averaged gradient, not the raw sum. ``None`` gradients are
        passed through unchanged.

        Args:
          grads_and_vars: List of (gradient, variable) pairs as returned by
            compute_gradients().
          global_step: Optional Variable to increment by one after the
            variables have been updated. Forwarded to the wrapped optimizer.
          name: Optional name for the returned operation. Forwarded to the
            wrapped optimizer.

        Returns:
          The result of ``self._opt.apply_gradients``: an `Operation` that
          applies the gradients (and increments `global_step` if given).

        Raises:
          Whatever the wrapped optimizer's ``apply_gradients`` raises (for
          example on malformed ``grads_and_vars``); this method does no
          validation of its own.
        """
        averaged_grads_and_vars = []
        for grad, var in grads_and_vars:
            if grad is None:
                averaged_grads_and_vars.append((grad, var))
                continue
            with ops.colocate_with(grad):
                # Sum over replicas, then normalise by the replica count.
                averaged = gen_poputil_ops.ipu_replication_normalise(
                    cross_replica_ops.cross_replica_sum(grad))
            averaged_grads_and_vars.append((averaged, var))
        return self._opt.apply_gradients(averaged_grads_and_vars, global_step,
                                         name)
예제 #4
0
 def comp_fn():
     """Run the validation loop and return the accumulated accuracy.

     Each of ``opts["validation_batches_per_step"]`` iterations adds that
     batch's accuracy divided by the batch count. With more than one replica
     the result is summed across replicas and divided by
     ``opts['replicas'] * opts['shards']``.
     """
     def body(total_accuracy, image, label):
         batch_accuracy = tf.cast(
             validation_graph_builder(model, image, label, opts), tf.float32)
         return total_accuracy + (
             batch_accuracy / opts["validation_batches_per_step"])

     mean_accuracy = loops.repeat(int(opts["validation_batches_per_step"]),
                                  body, [tf.constant(0, tf.float32)],
                                  valid_iterator)
     if opts['replicas'] > 1:
         return cross_replica_ops.cross_replica_sum(mean_accuracy) / (
             opts['replicas'] * opts['shards'])
     return mean_accuracy
예제 #5
0
    def _reduce_to(self, reduce_op, value, destinations):
        """Reduce ``value`` across IPU replicas with SUM or MEAN.

        When not running on an IPU device, ``value`` is returned unchanged.
        On the IPU it is summed with ``cross_replica_sum``. For
        ``ReduceOp.MEAN`` the sum is additionally passed through
        ``ipu_replication_normalise``.

        Args:
          reduce_op: A ``reduce_util.ReduceOp``; only SUM and MEAN are accepted.
          value: The value to reduce.
          destinations: Ignored.

        Raises:
          ValueError: On the IPU, if ``reduce_op`` is neither SUM nor MEAN.
        """
        del destinations  # Not used by this implementation.

        if _is_current_device_ipu():
            supported_ops = (reduce_util.ReduceOp.SUM,
                             reduce_util.ReduceOp.MEAN)
            if reduce_op not in supported_ops:
                raise ValueError("Unsupported reduce op: {}".format(reduce_op))

            summed = cross_replica_ops.cross_replica_sum(value)
            if reduce_op == reduce_util.ReduceOp.MEAN:
                return gen_poputil_ops.ipu_replication_normalise(summed)
            return summed

        return value
예제 #6
0
    def _reduce_to(self, reduce_op, value, destinations):
        """Reduce ``value`` with SUM or MEAN, on the IPU or via Horovod.

        When not on an IPU device, the reduction is delegated to
        ``hvd_allreduce`` with the op mapped by ``_to_horovod_op``. On the IPU
        the value is summed with ``cross_replica_sum`` (a compiled GCL
        reduction). For ``ReduceOp.MEAN`` the sum is then passed through
        ``ipu_replication_normalise``.

        Args:
          reduce_op: A ``reduce_util.ReduceOp``; on the IPU only SUM and MEAN
            are accepted.
          value: The value to reduce.
          destinations: Ignored.

        Raises:
          ValueError: On the IPU, if ``reduce_op`` is neither SUM nor MEAN.
        """
        del destinations  # Not used by this implementation.

        if _is_current_device_ipu():
            supported_ops = (reduce_util.ReduceOp.SUM,
                             reduce_util.ReduceOp.MEAN)
            if reduce_op not in supported_ops:
                raise ValueError("Unsupported reduce op: {}".format(reduce_op))

            summed = cross_replica_ops.cross_replica_sum(value)
            if reduce_op == reduce_util.ReduceOp.MEAN:
                return gen_poputil_ops.ipu_replication_normalise(summed)
            return summed

        # Off the IPU, Horovod performs the reduction.
        return hvd_allreduce(value, op=_to_horovod_op(reduce_op))
예제 #7
0
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Accumulate, replica-reduce and un-scale gradients, then apply them.

        For each non-None gradient, colocated with that gradient:
          1. if ``self._gradient_accumulation_count > 1`` and not pipelining,
             it is passed through ``ipu_stateful_gradient_accumulate``;
          2. if ``self._replicas > 1``, it is summed across replicas and
             passed through ``ipu_replication_normalise``;
          3. if the current distribution strategy is an
             ``IPUMultiWorkerStrategy``, it is divided by the strategy's
             ``num_replicas_in_sync``;
          4. it is cast to its variable's dtype.
        ``None`` gradients are forwarded unchanged through every stage,
        including the rescale below, where dividing ``None`` would raise.

        When pipelining, weight decay is added via ``self.add_WD``. If
        ``self._grad_scale != 1.0``, every gradient is divided by it, except
        those of variables whose name contains ``'batch_norm/moving_'``.

        Args:
          grads_and_vars: Iterable of (gradient, variable) pairs.
          global_step: Forwarded to the wrapped optimizer.
          name: Forwarded to the wrapped optimizer.

        Returns:
          Whatever ``self._optimizer.apply_gradients`` returns. If
          ``self._sharded`` is set, sharding is propagated over the default
          graph before returning.
        """
        summed_grads_and_vars = []
        for grad, var in grads_and_vars:
            if grad is None:
                summed_grads_and_vars.append((grad, var))
                continue
            with ops.colocate_with(grad):
                # Gradient accumulation (skipped when pipelining).
                if (self._gradient_accumulation_count > 1
                        and not self._pipelining):
                    grad = gen_poputil_ops.ipu_stateful_gradient_accumulate(
                        grad,
                        num_mini_batches=self._gradient_accumulation_count)

                # Replication: sum across replicas, then normalise.
                if self._replicas > 1:
                    grad = gen_poputil_ops.ipu_replication_normalise(
                        cross_replica_ops.cross_replica_sum(grad))

                # Distribution with IPUMultiWorkerStrategy needs additional
                # normalisation by the strategy's num_replicas_in_sync.
                strategy = distribute.get_strategy()
                if isinstance(
                        strategy,
                        ipu_multi_worker_strategy.IPUMultiWorkerStrategy):
                    grad /= strategy.num_replicas_in_sync

                grad = math_ops.cast(grad, var.dtype)
                summed_grads_and_vars.append((grad, var))

        if self._pipelining:
            # Can do weight decay here as apply_gradients is only called on
            # the last accumulation step.
            summed_grads_and_vars = self.add_WD(summed_grads_and_vars)

        if self._grad_scale != 1.0:
            # Don't rescale batch norm moving average statistics as they are
            # not affected by loss scaling. None gradients (variables with no
            # gradient) cannot be divided and are passed through as-is.
            summed_grads_and_vars = [
                (grad, var)
                if grad is None or 'batch_norm/moving_' in var.name
                else (grad / self._grad_scale, var)
                for grad, var in summed_grads_and_vars
            ]

        ret = self._optimizer.apply_gradients(summed_grads_and_vars,
                                              global_step, name)
        if self._sharded:
            sharding.propagate_sharding(ops.get_default_graph())
        return ret
예제 #8
0
            def comp_fn():
                """Run the validation loop and return the accumulated accuracy.

                Each iteration adds the batch accuracy divided by
                ``opts["validation_batches_per_step"]``. When ``opts['latency']``
                is set, each iteration also enqueues the batch's timestamp on
                ``timestamp_queue``. If ``total_replicas * shards > 1`` and the
                run is not inference-only, the result is summed across
                replicas and divided by that product.
                """
                def body(total_accuracy, data_dict):
                    batch_accuracy = validation_graph_builder(
                        model, data_dict, opts)

                    def accumulate():
                        return total_accuracy + (
                            tf.cast(batch_accuracy, tf.float32) /
                            opts["validation_batches_per_step"])

                    if opts['latency']:
                        # Enqueue the timestamp first, then accumulate.
                        enqueue_op = timestamp_queue.enqueue(
                            data_dict['timestamp'])
                        return accumulate(), enqueue_op
                    return accumulate()

                accuracy = loops.repeat(
                    int(opts["validation_batches_per_step"]), body,
                    [tf.constant(0, tf.float32)], valid_iterator)

                device_count = opts['total_replicas'] * opts['shards']
                if device_count > 1 and not opts.get('inference', False):
                    accuracy = (cross_replica_ops.cross_replica_sum(accuracy) /
                                device_count)
                return accuracy
예제 #9
0
            def grad_fn(grad, param, m, v):
                """Compute one LAMB-style update for a single parameter.

                Steps:
                  1. Optionally sum ``grad`` across replicas.
                  2. Cast it to fp32 and undo the loss scaling.
                  3. If ``self.use_nvlamb``, rescale it by
                     ``self.clipping_value / global_norm``.
                  4. Update the Adam moments and optionally bias-correct them
                     (``self.debiasing``).
                  5. Add decoupled weight decay.
                  6. Scale the update by the trust ratio ``||w|| / ||u||`` (1 if
                     either norm is 0) times the learning rate.

                ``self``, ``param_name`` and ``global_norm`` come from the
                enclosing scope, which is not visible here.

                Args:
                  grad: Gradient of ``param`` (loss-scaled).
                  param: Current parameter value.
                  m: First-moment accumulator.
                  v: Second-moment accumulator.

                Returns:
                  ``(next_param, next_m, next_v)``.
                """
                if self.add_cross_replica_sums:
                    grad = cross_replica_ops.cross_replica_sum(grad)

                # We convert the gradient to fp32 and undo the loss scaling.
                cast_grad = tf.cast(grad, dtype=tf.float32)
                cast_grad = cast_grad / self.loss_scaling

                if self.use_nvlamb:
                    # De-normalize the gradients.
                    cast_grad = cast_grad * self.clipping_value / global_norm

                # Standard Adam update.
                next_m = (tf.multiply(self.beta_1, m) +
                          tf.multiply(1.0 - self.beta_1, cast_grad))
                next_v = (tf.multiply(self.beta_2, v) +
                          tf.multiply(1.0 - self.beta_2, tf.square(cast_grad)))
                # Bias correction of momentum and velocity.
                if self.debiasing:
                    m_hat = next_m / (1.0 - tf.pow(
                        self.beta_1, tf.cast(self.step, dtype=tf.float32)))  # x10
                    v_hat = next_v / (1.0 - tf.pow(
                        self.beta_2, tf.cast(self.step, dtype=tf.float32))
                                      )  # x1000
                else:
                    m_hat = next_m
                    v_hat = next_v

                # TODO: Check if it is possible to convert to fp16 here.
                # m_hat = tf.cast(m_hat, dtype = tf.float16)
                # v_hat = tf.cast(v_hat, dtype = tf.float16)

                update = m_hat / (tf.sqrt(tf.math.abs(v_hat)) +
                                  tf.cast(self.epsilon, dtype=v_hat.dtype))

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if self._do_use_weight_decay(param_name):
                    update += tf.cast(self.weight_decay_rate,
                                      dtype=update.dtype) * tf.cast(
                                          param, dtype=update.dtype)

                if 'qkv' in param_name:
                    # Reshape with self.forward_transform; undone below by
                    # self.backward_transform.
                    reshaped_param = self.forward_transform(param)
                    reshaped_update = self.forward_transform(update)
                else:
                    reshaped_param = tf.reshape(param, [-1])
                    reshaped_update = tf.reshape(update, [-1])

                # The weight norm is explicitly computed in fp32; the update
                # norm uses the update's dtype (fp32 when m and v are fp32).
                w_norm = linalg_ops.norm(tf.cast(reshaped_param,
                                                 dtype=tf.float32),
                                         ord=2,
                                         axis=-1)
                u_norm = linalg_ops.norm(reshaped_update, ord=2, axis=-1)

                reshaped_update = tf.cast(reshaped_update, dtype=self.target_type)

                if self.weight_clip:
                    w_norm = tf.math.minimum(
                        w_norm, tf.cast(self.weight_clip, dtype=w_norm.dtype))

                # We set the ratio to 1 if either the w norm or the u norm is 0.
                ratio = array_ops.where(
                    math_ops.greater(w_norm, 0),
                    array_ops.where(
                        math_ops.greater(u_norm, 0),
                        (tf.cast(w_norm, dtype=tf.float32) / u_norm),
                        tf.constant(1.0, dtype=tf.float32, shape=w_norm.shape)),
                    tf.constant(1.0, dtype=tf.float32, shape=w_norm.shape))

                # Append a trailing unit axis so the ratio broadcasts against
                # the update.
                ratio = tf.reshape(ratio, shape=ratio.shape.as_list() + [1])
                # We combine the learning rate and the ratio at fp32.
                ratio = ratio * tf.cast(self.learning_rate, dtype=tf.float32)
                # Downcast to the update's dtype (self.target_type, from the
                # cast above) instead of a hard-coded fp16. This keeps the
                # multiply below dtype-consistent for any target type; for
                # fp16 it is unchanged.
                ratio = tf.cast(ratio, dtype=reshaped_update.dtype)
                update_with_lr = ratio * reshaped_update
                # Backward transform to the same shape as param.
                if 'qkv' in param_name:
                    update_with_lr = self.backward_transform(update_with_lr)
                else:
                    update_with_lr = tf.reshape(update_with_lr, shape=param.shape)
                update_with_lr = tf.cast(update_with_lr, dtype=param.dtype)

                next_param = param - update_with_lr

                return next_param, next_m, next_v
예제 #10
0
 def my_body(acc, x):
     """Loop body: accumulate a cross-replica index sum and ``x * x`` into ``acc``.

     The replica index (cast to float32 and shifted by 1) is summed across
     replicas and added to ``acc``, then ``x * x`` is added. The shifted index
     is also enqueued on ``outfeed``.

     Returns:
       ``(acc, enqueue_op)``.
     """
     shifted_index = math_ops.cast(replication_index(), np.float32) + 1.0
     acc += cross_replica_sum(shifted_index)
     acc += x * x
     enqueue_op = outfeed.enqueue(shifted_index)
     return acc, enqueue_op