Beispiel #1
0
    def build(self, input_shape):
        """Create the LSTM cell weights, split per gate.

        Kernel, recurrent kernel and (optionally) bias are each created as
        four separate variables — one per gate (i/f/c/o) — because the RNN
        compiler requires this split layout instead of the fused weights
        that stock Keras uses.
        """
        input_dim = input_shape[-1]
        tf_version = _tf_utils.tf_version()
        # add_weight() grew a 'caching_device' argument in TF 2.1; only pass
        # it when the running version supports it.
        supports_caching_device = tf_version >= '2.1.0'

        kernel_args = {
            'shape': (input_dim, self.units),
            'initializer': self.kernel_initializer,
            'regularizer': self.kernel_regularizer,
            'constraint': self.kernel_constraint,
        }
        recurrent_args = {
            'shape': (self.units, self.units),
            'initializer': self.recurrent_initializer,
            'regularizer': self.recurrent_regularizer,
            'constraint': self.recurrent_constraint,
        }
        if supports_caching_device:
            default_caching_device = _caching_device(self)
            kernel_args['caching_device'] = default_caching_device
            recurrent_args['caching_device'] = default_caching_device

        # Create all input kernels first, then all recurrent kernels, so the
        # variable creation order matches the original layout.
        for gate in ('i', 'f', 'c', 'o'):
            setattr(self, 'kernel_' + gate,
                    self.add_weight(name='kernel_' + gate, **kernel_args))
        for gate in ('i', 'f', 'c', 'o'):
            setattr(self, 'recurrent_kernel_' + gate,
                    self.add_weight(name='recurrent_kernel_' + gate,
                                    **recurrent_args))

        if self.use_bias:
            bias_initializer = self.bias_initializer
            # The forget gate bias may be initialized to ones (Jozefowicz et
            # al. trick) when unit_forget_bias is set.
            if self.unit_forget_bias:
                forget_bias_initializer = initializers.get('ones')
            else:
                forget_bias_initializer = bias_initializer

            bias_args = {
                'shape': (self.units,),
                'regularizer': self.bias_regularizer,
                'constraint': self.bias_constraint,
            }
            if supports_caching_device:
                bias_args['caching_device'] = default_caching_device

            # Per-gate initializers; only the forget gate may differ.
            gate_initializers = {
                'i': bias_initializer,
                'f': forget_bias_initializer,
                'c': bias_initializer,
                'o': bias_initializer,
            }
            for gate, gate_init in gate_initializers.items():
                setattr(self, 'bias_' + gate,
                        self.add_weight(name='bias_' + gate,
                                        initializer=gate_init,
                                        **bias_args))
        else:
            self.bias_i = None
            self.bias_f = None
            self.bias_c = None
            self.bias_o = None
        self.built = True
Beispiel #2
0
from __future__ import print_function

from distutils.version import LooseVersion

from tensorflow.keras import activations
from tensorflow.keras import backend as K
from tensorflow.keras import constraints
from tensorflow.keras import initializers
from tensorflow.keras import layers as keras_layers
from tensorflow.keras import regularizers
from tensorflow.python.ops import array_ops
from tensorflow.python.training.tracking import data_structures

from tf_nndct.utils import tf_utils as _tf_utils

if _tf_utils.tf_version() >= LooseVersion('2.6'):
    from keras.utils import tf_utils
    from keras.layers import recurrent
else:
    from tensorflow.python.keras.utils import tf_utils
    from tensorflow.python.keras.layers import recurrent


class LSTMCell(recurrent.DropoutRNNCellMixin, keras_layers.Layer):
    """Cell class for the LSTM layer.

  Arguments:
    units: Positive integer, dimensionality of the output space.
    activation: Activation function to use.
      Default: hyperbolic tangent (`tanh`).
      If you pass `None`, no activation is applied
Beispiel #3
0
from tensorflow.keras import activations
from tensorflow.keras import layers
from tensorflow.python.util import nest

from tf_nndct.graph import OpTypes
from tf_nndct.graph import dtypes
from tf_nndct.graph import op_def
from tf_nndct.graph import ops
from tf_nndct.graph import utils
from tf_nndct.utils import generic_utils
from tf_nndct.utils import keras_utils
from tf_nndct.utils import registry
from tf_nndct.utils import tensor_utils
from tf_nndct.utils import tf_utils

if tf_utils.tf_version() >= LooseVersion('2.6'):
    from keras.layers import recurrent
    from keras.layers import recurrent_v2
else:
    from tensorflow.python.keras.layers import recurrent
    from tensorflow.python.keras.layers import recurrent_v2

keras = tf.keras

_NO_LAYER_NAME = '_NO_LAYER_NAME'


class OpBuilder(object):
    """Builder for graph ops."""

    # Class-level counter shared by all OpBuilder instances.
    # NOTE(review): presumably maps an op type to the number of ops of that
    # type created so far (for unique op naming) — confirm against the
    # methods that update it (not visible in this chunk).
    _OP_COUNT = {}
Beispiel #4
0
import tensorflow as tf

from distutils.version import LooseVersion
from typing import Any, Callable, Dict, List, Optional, Union

from tensorflow.keras import layers
from tensorflow.python.eager import def_function
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import nest

from tf_nndct.utils import logging
from tf_nndct.utils import tf_utils

_is_tf_later_than_220 = tf_utils.tf_version() >= LooseVersion('2.2')
if _is_tf_later_than_220:
  from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph

keras = tf.keras

def data_format():
  """Return the image data format configured in the Keras backend."""
  backend = keras.backend
  return backend.image_data_format()

def _keras_weight_name(name):
  # Given 'dense/kernel:0', return 'kernel'.
  return name.split('/')[-1].rsplit(':', 1)[0]

def get_named_weights(layer):
  params = collections.OrderedDict()
  weights = layer.get_weights()
Beispiel #5
0
 def build(self, input_shape):
     """Dispatch to the TF-version-specific build implementation."""
     # TF 2.1 changed the internal Keras APIs this layer builds against,
     # so pre-2.1 and post-2.1 use separate build paths.
     is_pre_2_1 = _tf_utils.tf_version() < LooseVersion('2.1.0')
     builder = self._build_v200 if is_pre_2_1 else self._build_v210
     return builder(input_shape)