Example 1
    def __init__(self,
                 nstate: int,
                 nin: int,
                 nout: int,
                 activation: Callable = jn.tanh,
                 w_init: Callable = kaiming_normal):
        """Creates an RNN instance.

        Args:
            nstate: number of hidden units.
            nin: number of input units.
            nout: number of output units.
            activation: activation function for the hidden layer.
            w_init: weight initializer for RNN model weights.
        """
        self.num_inputs = nin
        self.num_outputs = nout
        self.nstate = nstate
        self.activation = activation

        # Hidden layer parameters
        self.w_xh = TrainVar(w_init((self.num_inputs, self.nstate)))
        self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
        self.b_h = TrainVar(jn.zeros(self.nstate))

        self.output_layer = Linear(self.nstate, self.num_outputs)
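The constructor above only defines the parameters; the recurrence itself is not shown. A minimal sketch of how a single step could use them (the name rnn_step and the batch-first shapes are assumptions, not part of the original module):

import jax.numpy as jn

def rnn_step(rnn, x, h):
    # x: (batch, nin) input, h: (batch, nstate) previous hidden state.
    h = rnn.activation(jn.dot(x, rnn.w_xh.value) +
                       jn.dot(h, rnn.w_hh.value) +
                       rnn.b_h.value)
    # Project the new hidden state through the output layer.
    return rnn.output_layer(h), h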
Example 2
    def __init__(self,
                 nin: int,
                 nout: int,
                 k: Union[Tuple[int, int], int],
                 strides: Union[Tuple[int, int], int] = 1,
                 dilations: Union[Tuple[int, int], int] = 1,
                 groups: int = 1,
                 padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.SAME,
                 use_bias: bool = True,
                 w_init: Callable = kaiming_normal):
        """Creates a Conv2D module instance.

        Args:
            nin: number of channels of the input tensor.
            nout: number of channels of the output tensor.
            k: size of the convolution kernel, either tuple (height, width) or single number if they're the same.
            strides: convolution strides, either tuple (stride_y, stride_x) or single number if they're the same.
            dilations: spacing between kernel points (also known as atrous convolution),
                       either tuple (dilation_y, dilation_x) or single number if they're the same.
            groups: number of groups the input and output channels are split into. When groups > 1 the convolution
                    is applied separately to each group; nin and nout must both be divisible by groups.
            padding: padding of the input tensor, either Padding.SAME, Padding.VALID or numerical values.
            use_bias: if True then the convolution will have a bias term.
            w_init: initializer for the convolution kernel (a function that takes in an HWIO shape and returns a 4D matrix).
        """
        super().__init__()
        assert nin % groups == 0, 'nin should be divisible by groups'
        assert nout % groups == 0, 'nout should be divisible by groups'
        self.b = TrainVar(jn.zeros((nout, 1, 1))) if use_bias else None
        self.w = TrainVar(w_init((*util.to_tuple(k, 2), nin // groups, nout)))  # HWIO
        self.padding = util.to_padding(padding, 2)
        self.strides = util.to_tuple(strides, 2)
        self.dilations = util.to_tuple(dilations, 2)
        self.groups = groups
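Only the constructor is shown; the forward pass is presumably built on jax.lax.conv_general_dilated, which accepts exactly the pieces stored here (HWIO kernel, strides, padding, dilations, groups). A hedged sketch, with the function name conv2d_apply and the NCHW input layout as assumptions:

from jax import lax

def conv2d_apply(conv, x):
    # x: (batch, nin, height, width); conv.w.value is laid out as HWIO.
    y = lax.conv_general_dilated(
        x, conv.w.value,
        window_strides=conv.strides,
        padding=conv.padding,            # 'SAME'/'VALID' or explicit (low, high) pairs
        rhs_dilation=conv.dilations,
        feature_group_count=conv.groups,
        dimension_numbers=('NCHW', 'HWIO', 'NCHW'))
    if conv.b is not None:
        y += conv.b.value                # (nout, 1, 1) broadcasts over height and width
    return y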
Example 3
    def _get_loss(self, loss_name: str):
        if loss_name == 'logistic':
            x = TrainVar(jn.zeros(2))
            model_vars = VarCollection({'x': x})

            def loss():
                return jn.log(jn.exp(-jn.sum(x.value)) + 1)

            return model_vars, loss
        if loss_name == 'square':
            # loss = x*x + y*y.
            x = TrainVar(jn.ones(2))
            y = TrainVar(jn.ones(3))
            model_vars = VarCollection({'x': x, 'y': y})

            def loss():
                return jn.dot(x.value, x.value) + jn.dot(y.value, y.value)

            return model_vars, loss
        if loss_name == 'rastrigin':
            d = 2
            x = TrainVar(jn.ones(d))
            model_vars = VarCollection({'x': x})

            def loss():
                return 10 * d + jn.dot(x.value, x.value) - 10 * jn.sum(
                    jn.cos(2 * math.pi * x.value))

            return model_vars, loss
        raise ValueError(f'Unknown loss function: {loss_name}')
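These (VarCollection, loss) pairs look like optimizer test fixtures. A self-contained sketch of how the 'square' fixture could be minimized with objax's standard gradient/optimizer pattern (the training loop itself is an assumption, not part of the original test):

import jax.numpy as jn
import objax
from objax import TrainVar, VarCollection

# Rebuild the 'square' fixture standalone: loss = dot(x, x) + dot(y, y), minimized at zero.
x = TrainVar(jn.ones(2))
y = TrainVar(jn.ones(3))
model_vars = VarCollection({'x': x, 'y': y})

def loss():
    return jn.dot(x.value, x.value) + jn.dot(y.value, y.value)

gv = objax.GradValues(loss, model_vars)    # calling gv() returns (gradients, loss values)
opt = objax.optimizer.SGD(model_vars)

for _ in range(100):
    grads, values = gv()
    opt(0.1, grads)                        # one SGD step with learning rate 0.1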
Example 4
    def __init__(self, hparams):
        self.hparams = hparams

        self._wpe = TrainVar(np.zeros([hparams.n_ctx, hparams.n_embd]))
        self._wte = TrainVar(np.zeros([hparams.n_vocab, hparams.n_embd]))

        self.blocks = objax.module.ModuleList(
            [Block(hparams.n_embd, hparams) for _ in range(hparams.n_layer)])
        self.norm = Norm(hparams.n_embd)
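_wpe and _wte are the positional and token embedding tables of a GPT-2 style model (initialized to zeros, presumably to be overwritten by pretrained weights). A hedged sketch of the usual input embedding step; the function name and shapes are illustrative, since the model's __call__ is not shown:

import jax.numpy as jnp

def embed_inputs(model, tokens):
    # tokens: int array of shape (batch, seq_len) with values in [0, n_vocab).
    seq_len = tokens.shape[1]
    h = model._wte.value[tokens]          # token embeddings, (batch, seq_len, n_embd)
    h = h + model._wpe.value[:seq_len]    # add learned positional embeddings
    return h                              # then fed through model.blocks and model.norm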
Example 5
    def __init__(self, repin, repout):
        nin, nout = repin.size(), repout.size()
        super().__init__(nin, nout)
        self.b = TrainVar(objax.random.uniform((nout,)) / jnp.sqrt(nout))
        self.w = TrainVar(orthogonal((nout, nin)))
        self.rep_W = rep_W = repout * repin.T

        rep_bias = repout
        self.Pw = rep_W.equivariant_projector()
        self.Pb = rep_bias.equivariant_projector()
        logging.info(f"Linear W components:{rep_W.size()} rep:{rep_W}")
Example 6
    def __init__(self, nin: int, nout: int, use_bias: bool = True, w_init: Callable = xavier_normal):
        """Creates a Linear module instance.

        Args:
            nin: number of channels of the input tensor.
            nout: number of channels of the output tensor.
            use_bias: if True then the linear layer will have a bias term.
            w_init: weight initializer for the linear layer (a function that takes in an IO shape and returns a 2D matrix).
        """
        super().__init__()
        self.b = TrainVar(jn.zeros(nout)) if use_bias else None
        self.w = TrainVar(w_init((nin, nout)))
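With w stored as an (nin, nout) matrix, the forward pass reduces to a right multiplication. A minimal sketch (the module's actual __call__ is not shown here):

import jax.numpy as jn

def linear_apply(layer, x):
    # x: (batch, nin) -> (batch, nout)
    y = jn.dot(x, layer.w.value)
    if layer.b is not None:
        y += layer.b.value
    return y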
Example 7
    def __init__(self,
                 num_models,
                 dense_kernel_size=32,
                 embedding_dim=32,
                 seed=0,
                 logit_temp=1.0,
                 orthogonal_init=True):

        if num_models <= 1:
            raise ValueError("requires at least two models")

        self.num_models = num_models
        self.logit_temp = logit_temp

        key = random.PRNGKey(seed)
        subkeys = random.split(key, 8)

        # conv stack kernels and biases
        if orthogonal_init:
            initialiser = orthogonal
        else:
            initialiser = he_normal
        self.conv_kernels = objax.ModuleList()
        self.conv_biases = objax.ModuleList()
        input_channels = 3
        for i, output_channels in enumerate([32, 64, 64, 64, 64, 64]):
            self.conv_kernels.append(
                TrainVar(initialiser()(
                    subkeys[i],
                    (num_models, 3, 3, input_channels, output_channels))))
            self.conv_biases.append(
                TrainVar(jnp.zeros((num_models, output_channels))))
            input_channels = output_channels

        # dense kernels and biases
        self.dense_kernels = TrainVar(initialiser()(
            subkeys[6],
            (num_models, 1, 1, output_channels, dense_kernel_size)))
        self.dense_biases = TrainVar(jnp.zeros(
            (num_models, dense_kernel_size)))

        # embedding kernel; no bias or non-linearity.
        if orthogonal_init:
            initialiser = orthogonal
        else:
            initialiser = glorot_normal
        self.embedding_kernels = TrainVar(initialiser()(
            subkeys[7], (num_models, 1, 1, dense_kernel_size, embedding_dim)))
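Every kernel and bias carries a leading num_models axis, which suggests the forward pass runs one convolution per ensemble member, e.g. via jax.vmap. A hedged sketch for a single layer of the conv stack; the strides, activation, and function names are assumptions, not taken from the original model:

import jax
from jax import lax

def conv_one_member(x, kernel, bias):
    # x: (batch, H, W, C_in) in NHWC layout, kernel: (3, 3, C_in, C_out) in HWIO.
    y = lax.conv_general_dilated(
        x, kernel, window_strides=(2, 2), padding='SAME',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
    return jax.nn.relu(y + bias)

# Map over the leading num_models axis of kernels/biases while broadcasting
# the same input batch to every ensemble member.
conv_all_members = jax.vmap(conv_one_member, in_axes=(None, 0, 0))
# e.g. conv_all_members(x, model.conv_kernels[0].value, model.conv_biases[0].value)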
Example 8
    def __init__(self,
                 nin: int,
                 nout: int,
                 k: Union[Tuple[int, int], int],
                 strides: Union[Tuple[int, int], int] = 1,
                 dilations: Union[Tuple[int, int], int] = 1,
                 padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.SAME,
                 use_bias: bool = True,
                 w_init: Callable = kaiming_normal):
        """Creates a ConvTranspose2D module instance.

        Args:
            nin: number of channels of the input tensor.
            nout: number of channels of the output tensor.
            k: size of the convolution kernel, either tuple (height, width) or single number if they're the same.
            strides: convolution strides, either tuple (stride_y, stride_x) or single number if they're the same.
            dilations: spacing between kernel points (also known as atrous convolution),
                       either tuple (dilation_y, dilation_x) or single number if they're the same.
            padding: padding of the input tensor, either Padding.SAME or Padding.VALID.
            use_bias: if True then the convolution will have a bias term.
            w_init: initializer for the convolution kernel (a function that takes in an HWIO shape and returns a 4D matrix).
        """
        super().__init__(nin=nout, nout=nin, k=k, strides=strides, padding=padding, use_bias=False, w_init=w_init)
        self.b = TrainVar(jn.zeros((nout, 1, 1))) if use_bias else None
        self.dilations = util.to_tuple(dilations, 2)
Example 9
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[Tuple[int, int], int],
                 stride: Union[Tuple[int, int], int] = 1,
                 padding: Union[str, Tuple[int, int], int] = 0,
                 dilation: Union[Tuple[int, int], int] = 1,
                 groups: int = 1,
                 bias: bool = False,
                 kernel_init: Callable = kaiming_normal,
                 bias_init: Callable = jnp.zeros,
                 ):
        """Creates a Conv2D module instance.

        Args:
            in_channels: number of channels of the input tensor.
            out_channels: number of channels of the output tensor.
            kernel_size: size of the convolution kernel, either tuple (height, width) or single number if they're the same.
            stride: convolution strides, either tuple (stride_y, stride_x) or single number if they're the same.
            padding: padding of the input tensor, either the string 'LIKE', a tuple (pad_y, pad_x), or a single
                     number applied to both spatial dimensions.
            dilation: spacing between kernel points (also known as atrous convolution),
                      either tuple (dilation_y, dilation_x) or single number if they're the same.
            groups: number of groups the input and output channels are split into. When groups > 1 the convolution
                    is applied separately to each group; in_channels and out_channels must both be divisible by groups.
            bias: if True then the convolution will have a bias term.
            kernel_init: initializer for the convolution kernel (a function that takes in an OIHW shape and returns a 4D matrix).
            bias_init: initializer for the bias term.
        """
        super().__init__()
        assert in_channels % groups == 0, 'in_chs should be divisible by groups'
        assert out_channels % groups == 0, 'out_chs should be divisible by groups'
        kernel_size = util.to_tuple(kernel_size, 2)
        self.weight = TrainVar(kernel_init((out_channels, in_channels // groups, *kernel_size)))  # OIHW
        self.bias = TrainVar(bias_init((out_channels,))) if bias else None
        self.strides = util.to_tuple(stride, 2)
        self.dilations = util.to_tuple(dilation, 2)
        if isinstance(padding, str):
            if padding == 'LIKE':
                padding = (
                    get_like_padding(kernel_size[0], self.strides[0], self.dilations[0]),
                    get_like_padding(kernel_size[1], self.strides[1], self.dilations[1]))
                padding = [padding, padding]
        else:
            padding = util.to_tuple(padding, 2)
            padding = [padding, padding]
        self.padding = padding
        self.groups = groups
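Unlike Example 2, this PyTorch-style module stores its kernel as OIHW and its padding as explicit (low, high) pairs per spatial dimension, which maps directly onto lax.conv_general_dilated's NCHW/OIHW dimension numbers. A hedged sketch of a matching forward pass (the real __call__ is not shown):

from jax import lax

def conv2d_forward(conv, x):
    # x: (batch, in_channels, height, width).
    y = lax.conv_general_dilated(
        x, conv.weight.value,
        window_strides=conv.strides,
        padding=conv.padding,              # [(pad_y, pad_y), (pad_x, pad_x)]
        rhs_dilation=conv.dilations,
        feature_group_count=conv.groups,
        dimension_numbers=('NCHW', 'OIHW', 'NCHW'))
    if conv.bias is not None:
        y += conv.bias.value.reshape(1, -1, 1, 1)
    return y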
Example 10
    def __init__(self, dims: Iterable[int], redux: Iterable[int], momentum: float = 0.999, eps: float = 1e-6):
        """Creates a BatchNorm module instance.

        Args:
            dims: shape of the batch normalization state variables.
            redux: list of indices of reduction axes. Batch norm statistics are computed by averaging over these axes.
            momentum: value used to compute exponential moving average of batch statistics.
            eps: small value which is used for numerical stability.
        """
        super().__init__()
        dims = tuple(dims)
        self.momentum = momentum
        self.eps = eps
        self.redux = tuple(redux)
        self.running_mean = StateVar(jn.zeros(dims))
        self.running_var = StateVar(jn.ones(dims))
        self.beta = TrainVar(jn.zeros(dims))
        self.gamma = TrainVar(jn.ones(dims))
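A hedged sketch of the training-time behaviour these variables imply: batch statistics are averaged over the redux axes and folded into the running estimates with the given momentum (the module's actual __call__, including its evaluation mode, is not shown):

import jax.numpy as jn

def batch_norm_train(bn, x):
    mean = x.mean(bn.redux, keepdims=True)
    var = x.var(bn.redux, keepdims=True)
    # Exponential moving average of the batch statistics.
    bn.running_mean.value = bn.momentum * bn.running_mean.value + (1 - bn.momentum) * mean
    bn.running_var.value = bn.momentum * bn.running_var.value + (1 - bn.momentum) * var
    x_hat = (x - mean) / jn.sqrt(var + bn.eps)
    return x_hat * bn.gamma.value + bn.beta.value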
Example 11
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = True,
            weight_init: Callable = xavier_normal,
            bias_init: Callable = jnp.zeros,
    ):
        """Creates a Linear module instance.

        Args:
            in_features: number of channels of the input tensor.
            out_features: number of channels of the output tensor.
            bias: if True then the linear layer will have a bias term.
            weight_init: weight initializer for the linear layer (a function that takes in an (out_features, in_features) shape and returns a 2D matrix).
            bias_init: initializer for the bias term.
        """
        super().__init__()
        self.weight = TrainVar(weight_init((out_features, in_features)))
        self.bias = TrainVar(bias_init(out_features)) if bias else None
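Note the weight here is (out_features, in_features), transposed relative to Example 6, so the forward pass contracts against its transpose. A minimal sketch (the module's __call__ is not shown):

import jax.numpy as jnp

def linear_forward(layer, x):
    # x: (batch, in_features) -> (batch, out_features)
    y = x @ layer.weight.value.T
    if layer.bias is not None:
        y += layer.bias.value
    return y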
Example 12
    def __init__(self, nin: int, rank: int, groups: int = 32, eps: float = 1e-5):
        """Creates a GroupNorm module instance.

        Args:
            nin: number of input channels.
            rank: rank of the input tensor.
            groups: number of normalization groups.
            eps: small value which is used for numerical stability.
        """
        groups = min(groups, nin)
        assert nin % groups == 0, 'nin should be divisible by groups'

        super(GroupNorm, self).__init__()
        self.nin = nin
        self.groups = groups
        self.eps = eps
        self.redux = tuple(range(2, rank + 1))
        var_shape = (1, nin) + (1,) * (rank - 2)
        self.gamma = TrainVar(jn.ones(var_shape))
        self.beta = TrainVar(jn.zeros(var_shape))
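A hedged sketch of the normalization these variables imply: channels are split into groups, statistics are reduced over each group's channels and the spatial axes (the redux indices line up with the tensor after the group axis is inserted), and the per-channel gamma/beta are applied at the end. The __call__ is not shown, so this is illustrative:

import jax.numpy as jn

def group_norm(gn, x):
    # x: (batch, nin, *spatial); reshape to (batch, groups, nin // groups, *spatial).
    shape = x.shape
    x = x.reshape((shape[0], gn.groups, gn.nin // gn.groups) + shape[2:])
    mean = x.mean(gn.redux, keepdims=True)
    var = x.var(gn.redux, keepdims=True)
    x = (x - mean) / jn.sqrt(var + gn.eps)
    x = x.reshape(shape)
    return x * gn.gamma.value + gn.beta.value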
Example 13
    def __init__(self,
                 nstate: int,
                 nin: int,
                 nout: int,
                 w_init: Callable = kaiming_normal):
        """Creates a GRU instance.

        Args:
            nstate: number of hidden units.
            nin: number of input units.
            nout: number of output units.
            w_init: weight initializer for GRU model weights.
        """
        self.num_inputs = nin
        self.num_outputs = nout
        self.nstate = nstate

        # Update gate parameters
        self.w_xz = TrainVar(w_init((self.num_inputs, self.nstate)))
        self.w_hz = TrainVar(w_init((self.nstate, self.nstate)))
        self.b_z = TrainVar(jn.zeros(self.nstate))

        # Reset gate parameters
        self.w_xr = TrainVar(w_init((self.num_inputs, self.nstate)))
        self.w_hr = TrainVar(w_init((self.nstate, self.nstate)))
        self.b_r = TrainVar(jn.zeros(self.nstate))

        # Candidate hidden state parameters
        self.w_xh = TrainVar(w_init((self.num_inputs, self.nstate)))
        self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
        self.b_h = TrainVar(jn.zeros(self.nstate))

        # Output layer parameters
        self.w_hq = TrainVar(w_init((self.nstate, self.num_outputs)))
        self.b_q = TrainVar(jn.zeros(self.num_outputs))
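The parameter names follow the standard GRU gating equations. A sketch of one recurrence step using them (the module's __call__ is not shown, so the step function below is illustrative):

import jax.numpy as jn
from jax.nn import sigmoid

def gru_step(gru, x, h):
    # x: (batch, nin) input, h: (batch, nstate) previous hidden state.
    z = sigmoid(x @ gru.w_xz.value + h @ gru.w_hz.value + gru.b_z.value)             # update gate
    r = sigmoid(x @ gru.w_xr.value + h @ gru.w_hr.value + gru.b_r.value)             # reset gate
    h_tilde = jn.tanh(x @ gru.w_xh.value + (r * h) @ gru.w_hh.value + gru.b_h.value) # candidate state
    h = z * h + (1 - z) * h_tilde                                                    # new hidden state
    q = h @ gru.w_hq.value + gru.b_q.value                                           # output layer
    return q, h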
Example 14
    def __init__(self, repin, repout):
        super().__init__()
        Wdim, weight_proj = bilinear_weights(repout, repin)
        self.weight_proj = jit(weight_proj)
        self.w = TrainVar(objax.random.normal((Wdim,)))  # xavier_normal((Wdim,))  # TODO: revert to xavier
        logging.info(f"BiW components: dim:{Wdim}")
Example 15
    def __init__(self, nx, nf):
        super().__init__()
        self.nx = nx
        self.nf = nf
        self.w = TrainVar(np.zeros([1, nx, nf]))
        self.b = TrainVar(np.zeros([nf]))
Example 16
    def __init__(self, n_state, axis=-1, epsilon=1e-5):
        super().__init__()
        # Standard layer-norm initialization: gain (g) to ones, bias (b) to zeros.
        self.g = TrainVar(np.ones(n_state))
        self.b = TrainVar(np.zeros(n_state))
        self.axis = axis
        self.epsilon = epsilon
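g and b are the layer-norm gain and bias. A hedged sketch of the normalization over the configured axis (the module's __call__ is not shown):

import jax.numpy as jnp

def layer_norm(norm, x):
    mean = x.mean(norm.axis, keepdims=True)
    var = x.var(norm.axis, keepdims=True)
    x = (x - mean) / jnp.sqrt(var + norm.epsilon)
    return x * norm.g.value + norm.b.value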