Example #1
    def _build(self, laplacian, inputs):
        input_shape = tuple(inputs.get_shape().as_list())
        if len(input_shape) != 3:
            raise snt.IncompatibleShapeError(
                "{}: rank of shape must be 3 not: {}".format(
                    self.scope_name, len(input_shape)))

        if input_shape[2] is None:
            raise snt.IncompatibleShapeError(
                "{}: Input size must be specified at module build time".format(
                    self.scope_name))

        if self._input_shape is not None and input_shape[
                2] != self._input_shape[2]:
            raise snt.IncompatibleShapeError(
                "{}: Input shape must be [batch_size, {}, {}] not: [batch_size, {}, {}]"
                .format(self.scope_name, self._input_shape[1],
                        self._input_shape[2], input_shape[1], input_shape[2]))

        self._input_shape = input_shape
        dtype = inputs.dtype

        if "w" not in self._initializers:
            self._initializers["w"] = tfutils.create_linear_initializer(
                self._input_shape[2], self._output_size, dtype)

        if "b" not in self._initializers and self._use_bias:
            self._initializers["b"] = tfutils.create_bias_initializer(
                self._input_shape[2], self._output_size, dtype)

        weight_shape = (self._input_shape[2], self.output_size)
        self._w = tf.get_variable(
            "w",
            shape=weight_shape,
            dtype=dtype,
            initializer=self._initializers["w"],
            partitioner=self._partitioners.get("w", None),
            regularizer=self._regularizers.get("w", None))
        if self._w not in tf.get_collection('weights'):
            tf.add_to_collection('weights', self._w)
        outputs = tfutils.matmul(inputs, self._w)

        if self._use_bias:
            bias_shape = (self.output_size, )
            self._b = tf.get_variable(
                "b",
                shape=bias_shape,
                dtype=dtype,
                initializer=self._initializers["b"],
                partitioner=self._partitioners.get("b", None),
                regularizer=self._regularizers.get("b", None))
            if self._b not in tf.get_collection('biases'):
                tf.add_to_collection('biases', self._b)
            outputs += self._b

        return outputs
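
Example #1 builds a plain per-node linear layer: a weight matrix of shape [in_size, out_size] (plus an optional bias) is applied independently to every node's feature vector, with tfutils.matmul presumably broadcasting the matrix product over the batch dimension. Below is a minimal NumPy sketch of the same computation; the shape values and variable names are made up for illustration.

import numpy as np

batch, nodes, in_size, out_size = 2, 5, 4, 3
inputs = np.random.randn(batch, nodes, in_size)   # [batch_size, num_nodes, in_size]
w = np.random.randn(in_size, out_size)            # plays the role of self._w
b = np.zeros(out_size)                            # plays the role of self._b

# inputs @ w broadcasts the matrix product over the batch and node axes
outputs = inputs @ w + b
assert outputs.shape == (batch, nodes, out_size)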
Example #2
    def _build(self, inputs):
        # Based on:
        # https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/layers/python/layers/normalization.py
        input_shape = tuple(inputs.get_shape().as_list())
        if len(input_shape) != 3:
            raise snt.IncompatibleShapeError(
                "{}: rank of shape must be 3 not: {}".format(
                    self.scope_name, len(input_shape)))

        if input_shape[2] is None:
            raise snt.IncompatibleShapeError(
                "{}: Input size must be specified at module build time".format(
                    self.scope_name))
        self._input_shape = input_shape
        dtype = inputs.dtype
        group_sizes = [
            self.group_size, self._input_shape[2] // self.group_size
        ]
        broadcast_shape = [1, 1] + group_sizes
        self._gamma = tf.get_variable("gamma",
                                      shape=(self._input_shape[2]),
                                      dtype=dtype,
                                      initializer=self._initializers["gamma"])
        if self._gamma not in tf.get_collection('weights'):
            tf.add_to_collection('weights', self._gamma)
        self._gamma = tf.reshape(self._gamma, broadcast_shape)

        self._beta = tf.get_variable("beta",
                                     shape=(self._input_shape[2], ),
                                     dtype=dtype,
                                     initializer=self._initializers["beta"])
        if self._beta not in tf.get_collection('biases'):
            tf.add_to_collection('biases', self._beta)
        self._beta = tf.reshape(self._beta, broadcast_shape)

        ##### Actually perform operations
        # Reshape input
        original_shape = [-1, self._input_shape[1], self._input_shape[2]]
        inputs_shape = [-1, self._input_shape[1]] + group_sizes

        inputs = tf.reshape(inputs, inputs_shape)

        # Normalize
        mean, variance = tf.nn.moments(inputs, [1, 3], keep_dims=True)
        gain = tf.rsqrt(variance + 1e-7) * self._gamma
        offset = -mean * gain + self._beta
        outputs = inputs * gain + offset

        # Reshape back to output
        outputs = tf.reshape(outputs, original_shape)

        return outputs
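
Example #2 implements a group normalization over [batch_size, num_nodes, channels] inputs: the channel axis is reshaped into a [group_size, channels // group_size] grid, moments are computed over the node axis and the second of those two axes, and the learned gamma/beta are folded into a single gain/offset before reshaping back. A rough NumPy equivalent follows, assuming group_size divides the channel count; the concrete sizes are made up.

import numpy as np

batch, nodes, channels, group_size = 2, 6, 8, 4
x = np.random.randn(batch, nodes, channels)
gamma, beta = np.ones(channels), np.zeros(channels)

# Split channels into [group_size, channels // group_size], as in the module
g = x.reshape(batch, nodes, group_size, channels // group_size)
mean = g.mean(axis=(1, 3), keepdims=True)          # matches tf.nn.moments(..., [1, 3])
var = g.var(axis=(1, 3), keepdims=True)
g_hat = (g - mean) / np.sqrt(var + 1e-7)

# gamma/beta broadcast with the same [1, 1, group_size, channels // group_size] shape
y = g_hat * gamma.reshape(1, 1, group_size, -1) + beta.reshape(1, 1, group_size, -1)
y = y.reshape(batch, nodes, channels)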
Example #3
    def _build(self, laplacian, inputs):
        input_shape = tuple(inputs.get_shape().as_list())
        if len(input_shape) != 3:
            raise snt.IncompatibleShapeError(
                "{}: rank of shape must be 3 not: {}".format(
                    self.scope_name, len(input_shape)))

        if input_shape[2] is None:
            raise snt.IncompatibleShapeError(
                "{}: Input size must be specified at module build time".format(
                    self.scope_name))

        if input_shape[1] is None:
            raise snt.IncompatibleShapeError(
                "{}: Number of nodes must be specified at module build time".
                format(self.scope_name))

        if self._input_shape is not None and \
            (input_shape[2] != self._input_shape[2] or \
             input_shape[1] != self._input_shape[1]):
            raise snt.IncompatibleShapeError(
                "{}: Input shape must be [batch_size, {}, {}] not: [batch_size, {}, {}]"
                .format(self.scope_name, self._input_shape[1],
                        self._input_shape[2], input_shape[1], input_shape[2]))

        self._input_shape = input_shape
        dtype = inputs.dtype

        for k, s in self.weight_keys:
            if k not in self._initializers:
                self._initializers[k] = tfutils.create_linear_initializer(
                    self._input_shape[2], s, dtype)

        if self._use_bias:
            for k, s in self.bias_keys:
                if k not in self._initializers:
                    self._initializers[k] = tfutils.create_bias_initializer(
                        self._input_shape[2], s, dtype)

        for k, s in self.weight_keys:
            weight_shape = (self._input_shape[2], s)
            self.weights[k] = tf.get_variable(
                k,
                shape=weight_shape,
                dtype=dtype,
                initializer=self._initializers[k],
                partitioner=self._partitioners.get(k, None),
                regularizer=self._regularizers.get(k, None))
            if self.weights[k] not in tf.get_collection('weights'):
                tf.add_to_collection('weights', self.weights[k])

        if self._use_bias:
            for k, s in self.bias_keys:
                bias_shape = (s, )
                self.weights[k] = tf.get_variable(
                    k,
                    shape=bias_shape,
                    dtype=dtype,
                    initializer=self._initializers[k],
                    partitioner=self._partitioners.get(k, None),
                    regularizer=self._regularizers.get(k, None))
                if self.weights[k] not in tf.get_collection('biases'):
                    tf.add_to_collection('biases', self.weights[k])

        preactiv_ = tfutils.matmul(inputs, self.weights["w"])
        f1_ = tfutils.matmul(inputs, self.weights["f1"])
        f2_ = tfutils.matmul(inputs, self.weights["f2"])
        if self._use_bias:
            f1_ += self.weights["d1"]
            f2_ += self.weights["d2"]
        preattn_mat_ = f1_ + tf.transpose(f2_, [0, 2, 1])
        if self._sparse:
            preattn_mat = self._attn_activ(preattn_mat_) * laplacian
        else:
            preattn_mat = self._attn_activ(preattn_mat_) + laplacian
        attn_mat = tf.nn.softmax(preattn_mat, axis=-1)
        preactiv = tfutils.batch_matmul(attn_mat, preactiv_)
        skip = tfutils.matmul(inputs, self.weights["u"])

        if self._use_bias:
            preactiv += self.weights["b"]
            skip += self.weights["c"]

        activ = self._activ(preactiv) + skip

        return activ
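
Example #3 is a graph-attention style layer: "f1" and "f2" project each node's features to a score, the pairwise sum of those scores gives attention logits, the laplacian argument either masks (the sparse branch) or biases (the dense branch) the logits before a row-wise softmax, and the resulting attention matrix weights the "w"-transformed features, with a linear "u" skip connection added on top. The NumPy sketch below follows the sparse branch; the single-score width of "f1"/"f2", the tanh standing in for the unspecified _attn_activ and _activ, and the random 0/1 matrix standing in for laplacian are all assumptions for illustration.

import numpy as np

def softmax(z, axis=-1):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

batch, nodes, in_size, out_size = 2, 5, 4, 3
x = np.random.randn(batch, nodes, in_size)
adj = np.random.randint(0, 2, (batch, nodes, nodes)).astype(float)   # stands in for laplacian
w, u = np.random.randn(in_size, out_size), np.random.randn(in_size, out_size)
f1, f2 = np.random.randn(in_size, 1), np.random.randn(in_size, 1)

scores = x @ f1 + np.transpose(x @ f2, (0, 2, 1))   # [batch, nodes, nodes] pairwise logits
attn = softmax(np.tanh(scores) * adj, axis=-1)      # the sparse branch: mask, then softmax
out = np.tanh(attn @ (x @ w)) + x @ u               # attention-weighted features plus skip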
Example #4
    def _build(self, laplacian, inputs):
        input_shape = tuple(inputs.get_shape().as_list())
        if len(input_shape) != 3:
            raise snt.IncompatibleShapeError(
                "{}: rank of shape must be 3 not: {}".format(
                    self.scope_name, len(input_shape)))

        if input_shape[2] is None:
            raise snt.IncompatibleShapeError(
                "{}: Input size must be specified at module build time".format(
                    self.scope_name))

        if input_shape[1] is None:
            raise snt.IncompatibleShapeError(
                "{}: Number of nodes must be specified at module build time".
                format(self.scope_name))

        if self._input_shape is not None and \
            (input_shape[2] != self._input_shape[2] or \
             input_shape[1] != self._input_shape[1]):
            raise snt.IncompatibleShapeError(
                "{}: Input shape must be [batch_size, {}, {}] not: [batch_size, {}, {}]"
                .format(self.scope_name, self._input_shape[1],
                        self._input_shape[2], input_shape[1], input_shape[2]))

        self._input_shape = input_shape
        dtype = inputs.dtype

        if "w" not in self._initializers:
            self._initializers["w"] = tfutils.create_linear_initializer(
                self._input_shape[2], self._output_size, dtype)
        if "u" not in self._initializers:
            self._initializers["u"] = tfutils.create_linear_initializer(
                self._input_shape[2], self._output_size, dtype)

        if "b" not in self._initializers and self._use_bias:
            self._initializers["b"] = tfutils.create_bias_initializer(
                self._input_shape[2], self._output_size, dtype)
        if "c" not in self._initializers and self._use_bias:
            self._initializers["c"] = tfutils.create_bias_initializer(
                self._input_shape[2], self._output_size, dtype)

        weight_shape = (self._input_shape[2], self.output_size)
        self._w = tf.get_variable(
            "w",
            shape=weight_shape,
            dtype=dtype,
            initializer=self._initializers["w"],
            partitioner=self._partitioners.get("w", None),
            regularizer=self._regularizers.get("w", None))
        if self._w not in tf.get_collection('weights'):
            tf.add_to_collection('weights', self._w)
        self._u = tf.get_variable(
            "u",
            shape=weight_shape,
            dtype=dtype,
            initializer=self._initializers["u"],
            partitioner=self._partitioners.get("u", None),
            regularizer=self._regularizers.get("u", None))
        if self._u not in tf.get_collection('weights'):
            tf.add_to_collection('weights', self._u)
        preactiv_ = tfutils.matmul(inputs, self._w)
        preactiv = tfutils.batch_matmul(laplacian, preactiv_)
        skip = tfutils.matmul(inputs, self._u)

        if self._use_bias:
            bias_shape = (self.output_size, )
            self._b = tf.get_variable(
                "b",
                shape=bias_shape,
                dtype=dtype,
                initializer=self._initializers["b"],
                partitioner=self._partitioners.get("b", None),
                regularizer=self._regularizers.get("b", None))
            if self._b not in tf.get_collection('biases'):
                tf.add_to_collection('biases', self._b)
            self._c = tf.get_variable(
                "c",
                shape=bias_shape,
                dtype=dtype,
                initializer=self._initializers["c"],
                partitioner=self._partitioners.get("c", None),
                regularizer=self._regularizers.get("c", None))
            if self._c not in tf.get_collection('biases'):
                tf.add_to_collection('biases', self._c)
            preactiv += self._b
            skip += self._c

        activ = self._activ(preactiv) + skip

        return activ
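
Example #4 is a graph-convolution style layer: node features are linearly transformed by "w", propagated with a batched matmul against the supplied laplacian, passed through the activation, and summed with a linear "u" skip path, with "b" and "c" as the optional biases of the two paths. A rough NumPy sketch follows; the shapes are made up and ReLU stands in for the unspecified self._activ (an assumption).

import numpy as np

batch, nodes, in_size, out_size = 2, 5, 4, 3
x = np.random.randn(batch, nodes, in_size)
lap = np.random.randn(batch, nodes, nodes)          # stands in for the laplacian argument
w, u = np.random.randn(in_size, out_size), np.random.randn(in_size, out_size)
b, c = np.zeros(out_size), np.zeros(out_size)

propagated = lap @ (x @ w) + b                      # tfutils.batch_matmul(laplacian, inputs @ w)
skip = x @ u + c
out = np.maximum(propagated, 0.0) + skip            # ReLU assumed for self._activ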