Example #1
 def __call__(self, x):
     for width in self.widths[:-1]:
         x = nn.relu(nn.Dense(width)(x))
     return nn.Dense(self.widths[-1])(x)
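The loop above relies on a `widths` attribute that the snippet does not define. A minimal sketch of the surrounding Flax module and its usage, assuming a compact module with a `widths: Sequence[int]` field (the class name MLP is hypothetical):

from typing import Sequence
import jax
import jax.numpy as jnp
from flax import linen as nn

class MLP(nn.Module):          # hypothetical wrapper for the method above
    widths: Sequence[int]      # hidden widths followed by the output width

    @nn.compact
    def __call__(self, x):
        for width in self.widths[:-1]:
            x = nn.relu(nn.Dense(width)(x))
        return nn.Dense(self.widths[-1])(x)

model = MLP(widths=[128, 64, 10])
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 32)))
y = model.apply(params, jnp.ones((1, 32)))   # y.shape == (1, 10)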
Example #2
  def __call__(
      self,
      inputs,
  ):
    """Applies ResNet model. Number of residual blocks inferred from hparams."""
    num_classes = self.num_classes
    hparams = self.hparams
    num_filters = self.num_filters
    dtype = self.dtype
    assert hparams.act_function in act_function_zoo, (
        'Activation function type is not supported.')

    x = aqt_flax_layers.ConvAqt(
        features=num_filters,
        kernel_size=(7, 7),
        strides=(2, 2),
        padding=[(3, 3), (3, 3)],
        use_bias=False,
        dtype=dtype,
        name='init_conv',
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.conv_init,
    )(inputs)
    x = nn.BatchNorm(
        use_running_average=not self.train,
        momentum=0.9,
        epsilon=1e-5,
        dtype=dtype,
        name='init_bn')(x)
    if hparams.act_function == 'relu':
      x = nn.relu(x)
      x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    else:
      # TODO(yichi): try adding other activation functions here
      # Use avg pool so that for binary nets, the distribution is symmetric.
      x = nn.avg_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    filter_multiplier = hparams.filter_multiplier
    for i, block_hparams in enumerate(hparams.residual_blocks):
      proj = block_hparams.conv_proj
      # For projection layers (unless it is the first layer), strides = (2, 2)
      if i > 0 and proj is not None:
        filter_multiplier *= 2
        strides = (2, 2)
      else:
        strides = (1, 1)
      x = ResidualBlock(
          filters=int(num_filters * filter_multiplier),
          hparams=block_hparams,
          quant_context=self.quant_context,
          strides=strides,
          train=self.train,
          dtype=dtype)(x)
    if hparams.act_function == 'none':
      # The DenseAqt below is not binarized. Without an activation function
      # there would be no nonlinearity between the last residual block and the
      # dense layer, so add a ReLU in that case.
      # TODO(yichi): try BPReLU
      x = nn.relu(x)
    x = jnp.mean(x, axis=(1, 2))

    x = aqt_flax_layers.DenseAqt(
        features=num_classes,
        dtype=dtype,
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.dense_layer,
    )(x, padding_mask=None)

    x = jnp.asarray(x, dtype)
    output = nn.log_softmax(x)
    return output
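A short trace of the filter_multiplier / strides schedule in the residual loop above; it assumes num_filters=64 and a ResNet-style hparam list in which only the first block of each of four groups carries a conv_proj (both assumptions, for illustration only):

num_filters = 64
filter_multiplier = 1.0
filters, strides_per_group = [], []
for i in range(4):                 # one iteration per projection block
    if i > 0:                      # mirrors `i > 0 and proj is not None`
        filter_multiplier *= 2
        strides_per_group.append((2, 2))
    else:
        strides_per_group.append((1, 1))
    filters.append(int(num_filters * filter_multiplier))
print(filters)             # [64, 128, 256, 512]
print(strides_per_group)   # [(1, 1), (2, 2), (2, 2), (2, 2)]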
Example #3
 def __call__(self, x):
     x = nn.relu(nn.Dense(784)(x))
     x = nn.relu(nn.Dense(200)(x))
     x = nn.relu(nn.Dense(200)(x))
     x = nn.Dense(10)(x)
     return nn.log_softmax(x)
Example #4
 def __call__(self, x):
     for layer in self.layers:
         x = layer(x)
         x = nn.relu(x)
     return x
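Because `self.layers` is used here, the sub-modules must be created ahead of time, e.g. in setup(), rather than inline under @nn.compact. A minimal sketch under that assumption (the class name StackedDense and the features field are hypothetical):

from typing import Sequence
from flax import linen as nn

class StackedDense(nn.Module):     # hypothetical wrapper for the method above
    features: Sequence[int]

    def setup(self):
        # Sub-modules assigned in setup() become the `self.layers` iterated above.
        self.layers = [nn.Dense(f) for f in self.features]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
            x = nn.relu(x)
        return x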
Example #5
    def __call__(
        self,
        inputs,
    ):
        """Applies ResNet model. Number of residual blocks inferred from hparams."""
        num_classes = self.num_classes
        hparams = self.hparams
        num_filters = self.num_filters
        dtype = self.dtype

        x = aqt_flax_layers.ConvAqt(
            features=num_filters,
            kernel_size=(7, 7),
            strides=(2, 2),
            padding=[(3, 3), (3, 3)],
            use_bias=False,
            dtype=dtype,
            name='init_conv',
            train=self.train,
            quant_context=self.quant_context,
            paxis_name=self.paxis_name,
            hparams=hparams.conv_init,
        )(inputs)
        x = nn.BatchNorm(use_running_average=not self.train,
                         momentum=0.9,
                         epsilon=1e-5,
                         dtype=dtype,
                         name='init_bn')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        filter_multiplier = hparams.filter_multiplier
        for i, block_hparams in enumerate(hparams.residual_blocks):
            proj = block_hparams.conv_proj
            # For projection layers (unless it is the first layer), strides = (2, 2)
            if i > 0 and proj is not None:
                filter_multiplier *= 2
                strides = (2, 2)
            else:
                strides = (1, 1)
            x = ResidualBlock(filters=int(num_filters * filter_multiplier),
                              hparams=block_hparams,
                              quant_context=self.quant_context,
                              strides=strides,
                              train=self.train,
                              dtype=dtype)(x)

        x = jnp.mean(x, axis=(1, 2))

        x = aqt_flax_layers.DenseAqt(
            features=num_classes,
            dtype=dtype,
            train=self.train,
            quant_context=self.quant_context,
            paxis_name=self.paxis_name,
            hparams=hparams.dense_layer,
        )(x, padding_mask=None)

        x = jnp.asarray(x, dtype)
        # The output of ViT does not have log_softmax. So that the resnet50
        # teacher produces the same type of output as ViT, the log_softmax is
        # left commented out:
        # output = nn.log_softmax(x)
        return x
Example #6
 def __call__(self, z):
     z = nn.Dense(500, name='fc1')(z)
     z = nn.relu(z)
     z = nn.Dense(784, name='fc2')(z)
     return z
Example #7
    def __call__(self, graph, feat):
        r"""

        Description
        -----------
        Compute GraphSAGE layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : jnp.ndarray or pair of jnp.ndarray
            If a single array is given, it represents the input feature of shape
            :math:`(N, D_{in})`, where :math:`D_{in}` is the size of the input
            feature and :math:`N` is the number of nodes.
            If a pair of arrays is given, the pair must contain two arrays of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.

        Returns
        -------
        jnp.ndarray
            The output feature of shape :math:`(N, D_{out})`, where :math:`D_{out}`
            is the size of the output feature.
        """
        with graph.local_scope():
            if isinstance(feat, tuple):
                feat_src = nn.Dropout(self.feat_drop)(feat[0])
                feat_dst = nn.Dropout(self.feat_drop)(feat[1])
            else:
                feat_src = feat_dst = nn.Dropout(self.feat_drop)(feat)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]

            h_self = feat_dst

            # Handle the case of graphs without edges
            if graph.number_of_edges() == 0:
                graph.dstdata['neigh'] = jnp.zeros(
                    (feat_dst.shape[0], self.in_src_feats),
                )

            if self.aggregator_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self.aggregator_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst     # same as above if homogeneous
                graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
                # divide in_degrees
                degs = graph.in_degrees()
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (jnp.expand_dims(degs, -1) + 1)
            elif self.aggregator_type == 'pool':
                graph.srcdata['h'] = nn.relu(nn.Dense(self.in_src_feats)(feat_src))
                graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self.aggregator_type == 'lstm':
                raise NotImplementedError("Not Implemented in JAX.")
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self.aggregator_type))

            # GraphSAGE GCN does not require fc_self.
            if self.aggregator_type == 'gcn':
                rst = nn.Dense(self.out_feats, use_bias=self.bias)(h_neigh)
            else:
                rst = nn.Dense(self.out_feats)(h_self) + nn.Dense(self.out_feats)(h_neigh)
            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst
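In the 'mean' branch above, update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh')) averages source features over each destination node's in-edges. The same computation on a small dense adjacency matrix, purely as an illustration (the graph and feature values are made up):

import jax.numpy as jnp

# adj[dst, src] = 1 where there is an edge src -> dst
adj = jnp.array([[0., 1., 1.],
                 [1., 0., 0.],
                 [0., 1., 0.]])
feat_src = jnp.arange(12.).reshape(3, 4)                # (N, D_in)
deg = jnp.maximum(adj.sum(axis=1, keepdims=True), 1.)   # guard isolated nodes
h_neigh = adj @ feat_src / deg                          # mean over in-neighbours
# The non-'gcn' branch then computes fc_self(h_self) + fc_neigh(h_neigh).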
Example #8
 def __call__(self, x):
     x = nn.Dense(500, name='fc1')(x)
     x = nn.relu(x)
     mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
     logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
     return mean_x, logvar_x
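This is the encoder half of a VAE: mean_x and logvar_x parameterize a diagonal Gaussian over the latents. A typical reparameterization step that would consume these outputs (assumed usage, not part of the example):

import jax
import jax.numpy as jnp

def reparameterize(rng, mean, logvar):
    # z = mean + sigma * eps with eps ~ N(0, I)
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(rng, logvar.shape)
    return mean + eps * std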
Example #9
 def __call__(self, x):
     for size in self.sizes[:-1]:
         x = Dense(size)(x)
         x = nn.relu(x)
     return Dense(self.sizes[-1])(x)
Example #10
    def __call__(self, x, mem, lengths_x, lengths_mem):
        """
        Compute multi-head self-attention.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor used to compute queries.
        mem : torch.Tensor
            The memory tensor used to compute keys and values.
        lengths_x : list
            The array of node numbers, used to segment x.
        lengths_mem : list
            The array of node numbers, used to segment mem.
        """
        batch_size = len(lengths_x)
        max_len_x = max(lengths_x)
        max_len_mem = max(lengths_mem)

        queries = self.proj_q(x).reshape((-1, self.num_heads, self.d_head))
        keys = self.proj_k(mem).reshape((-1, self.num_heads, self.d_head))
        values = self.proj_v(mem).reshape((-1, self.num_heads, self.d_head))

        # padding to (B, max_len_x/mem, num_heads, d_head)
        queries = F.pad_packed_tensor(queries, lengths_x, 0)
        keys = F.pad_packed_tensor(keys, lengths_mem, 0)
        values = F.pad_packed_tensor(values, lengths_mem, 0)

        # attention score with shape (B, num_heads, max_len_x, max_len_mem)
        e = jnp.einsum('bxhd,byhd->bhxy', queries, keys)
        # normalize
        e = e / np.sqrt(self.d_head)

        # generate mask
        mask = jnp.zeros((batch_size, max_len_x, max_len_mem))
        for i in range(batch_size):
            mask = mask.at[i, :lengths_x[i], :lengths_mem[i]].set(1)
        mask = jnp.expand_dims(mask, 1)

        e = jnp.where(
            mask == 0,
            -float('inf') * jnp.ones_like(e),
            e
        )

        # apply softmax
        alpha = jax.nn.softmax(e, axis=-1)
        # sum of value weighted by alpha
        out = jnp.einsum('bhxy,byhd->bxhd', alpha, values)
        # project to output
        out = self.proj_o(
            out.reshape((batch_size, max_len_x, self.num_heads * self.d_head)))
        # pack tensor
        out = F.pack_padded_tensor(out, lengths_x)

        # intra norm
        x = self.norm_in(x + out)

        # inter norm
        ffn_x = nn.Dense(self.d_model)(
            nn.relu(
                nn.Dropout(self.dropouth)(
                    nn.Dense(self.d_ff)(x)
                )
            )
        )

        x = self.norm_inter(x + ffn_x)

        return x
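The two einsum calls above carry the shape bookkeeping of the attention. A standalone shape check with made-up sizes (all names below are illustrative):

import jax
import jax.numpy as jnp

B, H, Dh, Lx, Lm = 2, 4, 16, 5, 7
q = jnp.ones((B, Lx, H, Dh))                              # padded queries
k = jnp.ones((B, Lm, H, Dh))                              # padded keys
v = jnp.ones((B, Lm, H, Dh))                              # padded values
e = jnp.einsum('bxhd,byhd->bhxy', q, k) / jnp.sqrt(Dh)    # (B, H, Lx, Lm)
alpha = jax.nn.softmax(e, axis=-1)
out = jnp.einsum('bhxy,byhd->bxhd', alpha, v)             # (B, Lx, H, Dh)
assert out.shape == (B, Lx, H, Dh)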
Example #11
  def __call__(self, x, train):
    def dense_layers(y, block, num_blocks, growth_rate):
      for _ in range(num_blocks):
        y = block(growth_rate)(y, train=train)
      return y

    def update_num_features(num_features, num_blocks, growth_rate, reduction):
      num_features += num_blocks * growth_rate
      if reduction is not None:
        num_features = int(math.floor(num_features * reduction))
      return num_features

    # Initial convolutional layer
    num_features = 2 * self.growth_rate
    conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
    y = conv(
        features=num_features,
        kernel_size=(3, 3),
        padding=((1, 1), (1, 1)),
        name='conv1')(x)

    # Internal dense and transition blocks
    num_blocks = _block_size_options[self.num_layers]
    block = functools.partial(
        BottleneckBlock,
        dtype=self.dtype,
        normalizer=self.normalizer)
    for i in range(3):
      y = dense_layers(y, block, num_blocks[i], self.growth_rate)
      num_features = update_num_features(num_features, num_blocks[i],
                                         self.growth_rate, self.reduction)
      y = TransitionBlock(
          num_features,
          dtype=self.dtype,
          normalizer=self.normalizer,
          use_kernel_size_as_stride_in_pooling=(
              self.use_kernel_size_as_stride_in_pooling))(y, train=train)

    # Final dense block
    y = dense_layers(y, block, num_blocks[3], self.growth_rate)

    # Final pooling
    maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
    y = maybe_normalize()(y)
    y = nn.relu(y)
    y = nn.avg_pool(
        y,
        window_shape=(4, 4),
        strides=(4, 4) if self.use_kernel_size_as_stride_in_pooling else (1, 1))

    # Classification layer
    y = jnp.reshape(y, (y.shape[0], -1))
    if self.normalize_classifier_input:
      maybe_normalize = model_utils.get_normalizer(
          self.normalize_classifier_input, train)
      y = maybe_normalize()(y)
    y = y * self.classification_scale_factor

    y = nn.Dense(self.num_outputs)(y)
    return y
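A worked pass through update_num_features makes the channel bookkeeping concrete. The numbers assume growth_rate=32, reduction=0.5, and DenseNet-121-style block sizes [6, 12, 24, 16]; none of these values come from the example itself:

import math

growth_rate, reduction = 32, 0.5
num_blocks = [6, 12, 24, 16]

num_features = 2 * growth_rate                    # 64 channels after conv1
for blocks in num_blocks[:3]:
    num_features += blocks * growth_rate          # each layer adds growth_rate channels
    num_features = int(math.floor(num_features * reduction))   # transition compresses
    print(num_features)                           # 128, 256, 512
num_features += num_blocks[3] * growth_rate       # 1024 before the final pooling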
Example #12
 def __call__(self, input, apply_relu: bool = False):
     return nn.relu(input) if apply_relu else input