def impl_selecting_fn(**kwargs):
    """The wrapper function to be returned."""
    if _is_xla():  # JAX, XLA breakout.
      return default_fn(**kwargs)
    if NUMPY_MODE:  # Numpy breakout.
      return cpu_fn(**kwargs)

    # Import locally to avoid TF dependency for TFP-on-JAX.
    from tensorflow.python.eager import function  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top

    # Each time a `tf.function` is called, we will give it a unique
    # identifiable API name, so that Grappler won't get confused when it
    # sees multiple samplers in the same graph, and it will be able
    # to pair up the different implementations across them.
    api_name = '{}_{}'.format(fn_name, str(uuid.uuid4()))
    defun_default_fn = _generate_defun_backend(
        default_fn, api_name)
    defun_cpu_fn = _generate_defun_backend(
        cpu_fn, api_name, preferred_device=_CPU_DEVICE_NAME)

    # Call the default sampling impl and register the CPU-specialized impl.
    # Grappler will kick in during session execution to optimize the graph.
    samples, runtime = defun_default_fn(**kwargs)
    function.register(defun_cpu_fn, **kwargs)
    return samples, runtime
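
The helper _generate_defun_backend is not shown in this example. Below is a minimal sketch of what it could look like, assuming the 'api_implements' and 'api_preferred_device' function attributes that Grappler's implementation selector matches on; the attribute names and the use of defun_with_attributes are assumptions based on public TensorFlow internals, not taken from this snippet.

from tensorflow.python.eager import function  # pylint: disable=g-direct-tensorflow-import

_API_IMPLEMENTS_ATTRIBUTE = 'api_implements'  # assumed attribute key
_API_PREFERRED_DEVICE_ATTRIBUTE = 'api_preferred_device'  # assumed attribute key


def _generate_defun_backend(fn, api_name, preferred_device=None):
  # Wrap `fn` in a defun tagged with the shared API name (and, optionally,
  # a preferred device) so Grappler can pair up and swap implementations.
  attributes = {_API_IMPLEMENTS_ATTRIBUTE: api_name}
  if preferred_device is not None:
    attributes[_API_PREFERRED_DEVICE_ATTRIBUTE] = preferred_device
  return function.defun_with_attributes(
      func=fn, attributes=attributes, autograph=False)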
Example #2
    def _defun_gru_call(self, inputs, initial_state, training):
        # Use the new defun approach for backend implementation swap.
        # Note that different implementations need to have the same function
        # signature, e.g., the tensor parameters need to have the same shapes
        # and dtypes.
        if self.go_backwards:
            # Reverse time axis.
            inputs = K.reverse(inputs, 0 if self.time_major else 1)

        self.reset_dropout_mask()
        dropout_mask = self.get_dropout_mask_for_cell(inputs,
                                                      training,
                                                      count=3)
        if dropout_mask is not None:
            inputs *= dropout_mask[0]
        if ops.executing_eagerly_outside_functions():
            # Under eager context, the device placement is already known. Prefer the
            # GPU implementation when a GPU is available.
            if context.num_gpus() > 0:
                last_output, outputs, new_h, runtime = cudnn_gru(
                    inputs=inputs,
                    init_h=initial_state[0],
                    kernel=self.cell.kernel,
                    recurrent_kernel=self.cell.recurrent_kernel,
                    bias=self.cell.bias,
                    time_major=self.time_major)
            else:
                last_output, outputs, new_h, runtime = standard_gru(
                    inputs=inputs,
                    init_h=initial_state[0],
                    kernel=self.cell.kernel,
                    recurrent_kernel=self.cell.recurrent_kernel,
                    bias=self.cell.bias,
                    activation=self.activation,
                    recurrent_activation=self.recurrent_activation,
                    time_major=self.time_major)
        else:
            api_name = 'gru_' + str(uuid.uuid4())
            defun_standard_gru = _generate_defun_backend(
                api_name, _CPU_DEVICE_NAME, standard_gru)
            defun_cudnn_gru = _generate_defun_backend(api_name,
                                                      _GPU_DEVICE_NAME,
                                                      cudnn_gru)
            # Call the normal GRU impl and register the CuDNN impl function.
            # Grappler will kick in during session execution to optimize the graph.
            last_output, outputs, new_h, runtime = defun_standard_gru(
                inputs=inputs,
                init_h=initial_state[0],
                kernel=self.cell.kernel,
                recurrent_kernel=self.cell.recurrent_kernel,
                bias=self.cell.bias,
                activation=self.activation,
                recurrent_activation=self.recurrent_activation,
                time_major=self.time_major)

            function.register(defun_cudnn_gru, inputs, initial_state[0],
                              self.cell.kernel, self.cell.recurrent_kernel,
                              self.cell.bias, self.time_major)
        states = [new_h]
        return last_output, outputs, runtime, states
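
The attribute tagging done by _generate_defun_backend has a public counterpart: the experimental_implements argument of tf.function. A small runnable sketch of the same pairing idea, using an illustrative API name and placeholder bodies (none of this is taken from the snippet above):

import tensorflow as tf

@tf.function(experimental_implements='example_gru_api')
def fast_path(x):
  # Placeholder for a device-specialized implementation.
  return x * 2.0

@tf.function(experimental_implements='example_gru_api')
def reference_path(x):
  # Placeholder for the generic implementation; it computes the same result,
  # which is the contract the swap relies on.
  return x + x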
Example #3
  def call(self, inputs, mask=None, training=None, initial_state=None):
    if isinstance(inputs, list):
      initial_state = inputs[1:]
      inputs = inputs[0]
    elif initial_state is not None:
      pass
    elif self.stateful:
      initial_state = self.states
    else:
      initial_state = self.get_initial_state(inputs)

    if len(initial_state) != len(self.states):
      raise ValueError('Layer has ' + str(len(self.states)) +
                       ' states but was passed ' + str(len(initial_state)) +
                       ' initial states.')

    if self.go_backwards:
      # Reverse time axis.
      inputs = K.reverse(inputs, 1)

    if ops.executing_eagerly_outside_functions():
      if context.num_gpus() > 0:
        outputs, [new_h, new_c], runtime = cudnn_lstm(
            inputs, initial_state[0], initial_state[1], self.kernel,
            self.recurrent_kernel, self.bias, self.units)
      else:
        outputs, [new_h, new_c], runtime = normal_lstm(
            inputs, initial_state[0], initial_state[1], self.kernel,
            self.recurrent_kernel, self.bias, self.units, self.activation,
            self.recurrent_activation)
    else:
      outputs, [new_h, new_c], runtime = normal_lstm(
          inputs, initial_state[0], initial_state[1], self.kernel,
          self.recurrent_kernel, self.bias, self.units, self.activation,
          self.recurrent_activation)

      function.register(cudnn_lstm, inputs, initial_state[0], initial_state[1],
                        self.kernel, self.recurrent_kernel, self.bias,
                        self.units)

    states = [new_h, new_c]

    if self.stateful:
      updates = []
      for i in range(len(states)):
        updates.append(state_ops.assign(self.states[i], states[i]))
      self.add_update(updates, inputs)

    if self.return_sequences:
      output = outputs
    else:
      output = outputs[:, -1, :]

    if self.return_state:
      return [output] + states
    else:
      return output, runtime
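
A brief usage sketch of the calling convention handled at the top of call(): Keras RNN layers accept the initial state either through the initial_state keyword or packed into the inputs list. The layer, shapes, and values below are illustrative:

import tensorflow as tf

layer = tf.keras.layers.LSTM(4, return_state=True)
x = tf.random.normal([2, 5, 3])  # [batch, time, feature]
h0 = tf.zeros([2, 4])            # initial cell output state
c0 = tf.zeros([2, 4])            # initial cell hidden state
output, h, c = layer(x, initial_state=[h0, c0])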
Example #4
def lstm_with_backend_selection(inputs, init_h, init_c, kernel,
                                recurrent_kernel, bias, mask, time_major,
                                go_backwards, activation, recurrent_activation,
                                sequence_lengths, zero_output_for_mask):
  """Call the LSTM with optimized backend kernel selection.

  Under the hood, this function will create two TF functions: one with the
  most generic kernel, which can run on any device, and a second one with a
  CuDNN-specific kernel, which can only run on a GPU.

  The first function will be called with normal_lstm_params, while the second
  function is not called but only registered in the graph. Grappler will then
  do the proper graph rewrite and swap in the optimized TF function based on
  the device placement.

  Args:
    inputs: Input tensor of LSTM layer.
    init_h: Initial state tensor for the cell output.
    init_c: Initial state tensor for the cell hidden state.
    kernel: Weights for cell kernel.
    recurrent_kernel: Weights for cell recurrent kernel.
    bias: Weights for cell kernel bias and recurrent bias. Only recurrent bias
      is used in this case.
    mask: Boolean tensor used to mask out the steps within the sequence.
    time_major: Boolean, whether the inputs are in the format of
      [time, batch, feature] or [batch, time, feature].
    go_backwards: Boolean (default False). If True, process the input sequence
      backwards and return the reversed sequence.
    activation: Activation function to use for output.
    recurrent_activation: Activation function to use for hidden recurrent state.
    sequence_lengths: The lengths of all sequences coming from a variable length
      input, such as ragged tensors. If the input has a fixed timestep size,
      this should be None.
    zero_output_for_mask: Boolean, whether to output zero for masked timesteps.

  Returns:
    List of output tensors, same as standard_lstm.
  """
  params = {
      'inputs': inputs,
      'init_h': init_h,
      'init_c': init_c,
      'kernel': kernel,
      'recurrent_kernel': recurrent_kernel,
      'bias': bias,
      'mask': mask,
      'time_major': time_major,
      'go_backwards': go_backwards,
      'activation': activation,
      'recurrent_activation': recurrent_activation,
      'sequence_lengths': sequence_lengths,
      'zero_output_for_mask': zero_output_for_mask,
  }

  def gpu_lstm_with_fallback(inputs, init_h, init_c, kernel, recurrent_kernel,
                             bias, mask, time_major, go_backwards, activation,
                             recurrent_activation, sequence_lengths,
                             zero_output_for_mask):
    """Use CuDNN kernel when mask is none or strictly right padded."""
    if mask is None:
      return gpu_lstm(
          inputs=inputs,
          init_h=init_h,
          init_c=init_c,
          kernel=kernel,
          recurrent_kernel=recurrent_kernel,
          bias=bias,
          mask=mask,
          time_major=time_major,
          go_backwards=go_backwards,
          sequence_lengths=sequence_lengths)

    def input_right_padded():
      return gpu_lstm(
          inputs=inputs,
          init_h=init_h,
          init_c=init_c,
          kernel=kernel,
          recurrent_kernel=recurrent_kernel,
          bias=bias,
          mask=mask,
          time_major=time_major,
          go_backwards=go_backwards,
          sequence_lengths=sequence_lengths)

    def input_not_right_padded():
      return standard_lstm(
          inputs=inputs,
          init_h=init_h,
          init_c=init_c,
          kernel=kernel,
          recurrent_kernel=recurrent_kernel,
          bias=bias,
          mask=mask,
          time_major=time_major,
          go_backwards=go_backwards,
          activation=activation,
          recurrent_activation=recurrent_activation,
          sequence_lengths=sequence_lengths,
          zero_output_for_mask=zero_output_for_mask)

    return control_flow_ops.cond(
        is_sequence_right_padded(mask, time_major),
        true_fn=input_right_padded,
        false_fn=input_not_right_padded)

  # Each time a `tf.function` is called, we will give it a unique
  # identifiable API name, so that Grappler won't get confused when it
  # sees multiple LSTM layers added into the same graph, and it will be able
  # to pair up the different implementations across them.
  api_name = 'lstm_' + str(uuid.uuid4())
  defun_standard_lstm = _generate_defun_backend(
      api_name, _CPU_DEVICE_NAME, standard_lstm)
  defun_gpu_lstm = _generate_defun_backend(api_name, _GPU_DEVICE_NAME,
                                           gpu_lstm_with_fallback)

  # Call the normal LSTM impl and register the CuDNN impl function.
  # Grappler will kick in during session execution to optimize the graph.
  last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
      **params)
  function.register(defun_gpu_lstm, **params)

  return last_output, outputs, new_h, new_c, runtime
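
The right-padding predicate used in the cond above is not shown. Here is a minimal sketch of one way to implement it, assuming a [batch, time] boolean mask and omitting the time_major transpose; this is an assumption-level sketch, not the library's exact code:

import tensorflow as tf

def is_sequence_right_padded_sketch(mask):
  # True iff every row of the mask is a run of True followed only by False,
  # i.e. the sequence is strictly right padded.
  max_len = tf.shape(mask)[1]
  lengths = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)
  right_padded = tf.sequence_mask(lengths, maxlen=max_len)
  return tf.reduce_all(tf.equal(mask, right_padded))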
Example #5
def register_cudnn_defun():
  function.register(defun_cudnn_lstm, **cudnn_lstm_kwargs)
  # Return some dummy value since tf.cond requires a return value.
  return 0
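
A self-contained toy of the same dummy-return pattern: both branches of tf.cond must return values with matching structure, so the registration branch returns a throwaway 0 just like the no-op branch. maybe_register below is an illustrative stand-in for register_cudnn_defun:

import tensorflow as tf

def maybe_register():
  # A side effect such as function.register(...) would go here.
  return 0

tf.cond(tf.constant(True), true_fn=maybe_register, false_fn=lambda: 0)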
Example #6
    def call(self, inputs, mask=None, training=None, initial_state=None):
        # LSTM does not support constants. Ignore them during processing.
        inputs, initial_state, _ = self._process_inputs(
            inputs, initial_state, None)

        if isinstance(mask, list):
            mask = mask[0]

        input_shape = K.int_shape(inputs)
        timesteps = input_shape[0] if self.time_major else input_shape[1]

        if mask is not None or not self.could_use_cudnn:
            # CuDNN does not support masking; fall back to the normal LSTM.
            kwargs = {'training': training}

            def step(inputs, states):
                return self.cell.call(inputs, states, **kwargs)

            last_output, outputs, states = K.rnn(
                step,
                inputs,
                initial_state,
                constants=None,
                go_backwards=self.go_backwards,
                mask=mask,
                unroll=self.unroll,
                input_length=timesteps,
                time_major=self.time_major,
                zero_output_for_mask=self.zero_output_for_mask)
            runtime = _runtime('unknown')
        else:
            # Use the new defun approach for backend implementation swap.
            # Note that different implementations need to have the same function
            # signature, e.g., the tensor parameters need to have the same shapes
            # and dtypes. Since CuDNN has an extra set of biases, those biases
            # will be passed to both the normal and CuDNN implementations.
            if self.go_backwards:
                # Reverse time axis.
                inputs = K.reverse(inputs, 0 if self.time_major else 1)

            self.reset_dropout_mask()
            dropout_mask = self.get_dropout_mask_for_cell(inputs,
                                                          training,
                                                          count=4)
            if dropout_mask is not None:
                inputs *= dropout_mask[0]

            if context.executing_eagerly():
                device_type = _get_context_device_type()
                if device_type == _GPU_DEVICE_NAME or (device_type is None and
                                                       context.num_gpus() > 0):
                    # Under eager context, check the device placement and prefer the
                    # GPU implementation when a GPU is available.
                    last_output, outputs, new_h, new_c, runtime = cudnn_lstm(
                        inputs, initial_state[0], initial_state[1],
                        self.cell.kernel, self.cell.recurrent_kernel,
                        self.cell.bias, self.time_major)
                else:
                    last_output, outputs, new_h, new_c, runtime = standard_lstm(
                        inputs, initial_state[0], initial_state[1],
                        self.cell.kernel, self.cell.recurrent_kernel,
                        self.cell.bias, self.activation,
                        self.recurrent_activation, self.time_major)
            else:
                # Each time a `tf.function` is called, we will give it a unique
                # identifiable API name, so that Grappler won't get confused when it
                # sees multiple LSTM layers added into the same graph, and it will
                # be able to pair up the different implementations across them.
                api_name = 'lstm_' + str(uuid.uuid4())
                defun_standard_lstm = _generate_defun_backend(
                    api_name, _CPU_DEVICE_NAME, standard_lstm)
                defun_cudnn_lstm = _generate_defun_backend(
                    api_name, _GPU_DEVICE_NAME, cudnn_lstm)

                # Call the normal LSTM impl and register the CuDNN impl function.
                # Grappler will kick in during session execution to optimize the
                # graph.
                last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
                    inputs, initial_state[0], initial_state[1],
                    self.cell.kernel, self.cell.recurrent_kernel,
                    self.cell.bias, self.activation, self.recurrent_activation,
                    self.time_major)

                function.register(defun_cudnn_lstm, inputs, initial_state[0],
                                  initial_state[1], self.cell.kernel,
                                  self.cell.recurrent_kernel, self.cell.bias,
                                  self.time_major)
            states = [new_h, new_c]

        if self.stateful:
            updates = []
            for i in range(len(states)):
                updates.append(state_ops.assign(self.states[i], states[i]))
            self.add_update(updates, inputs)

        if self.return_sequences:
            output = outputs
        else:
            output = last_output

        if self.return_state:
            return [output] + list(states)
        elif self.return_runtime:
            return output, runtime
        else:
            return output
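
The helper _get_context_device_type used in the eager branch above is not shown. A minimal sketch of what it plausibly does, built on the public DeviceSpec parser; the exact internals are an assumption:

from tensorflow.python.eager import context  # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import device as device_lib  # pylint: disable=g-direct-tensorflow-import

def _get_context_device_type_sketch():
  # Device type (e.g. 'GPU' or 'CPU') of the current eager device scope,
  # or None when no explicit device scope is active.
  current_device = context.context().device_name
  if current_device is None:
    return None
  return device_lib.DeviceSpec.from_string(current_device).device_type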