예제 #1
0
 def apply(self, inputs, mask=None):
     """Run a masked convolution followed by a masked dense layer.

     Args:
       inputs: Input handed to the first (convolutional) masked module.
       mask: Optional container indexed by the layer names
         'MaskedModule_0' and 'MaskedModule_1'. When None, both layers
         receive mask=None.

     Returns:
       The result of the second masked.MaskedModule call.
     """
     # The per-layer entry is only looked up when a mask was supplied.
     if mask is None:
         conv_mask = None
     else:
         conv_mask = mask['MaskedModule_0']
     hidden = masked.MaskedModule(
         inputs,
         features=self.NUM_FEATURES[0],
         kernel_size=(5, 5),
         wrapped_module=flax.nn.Conv,
         mask=conv_mask)

     if mask is None:
         dense_mask = None
     else:
         dense_mask = mask['MaskedModule_1']
     return masked.MaskedModule(
         hidden,
         features=self.NUM_FEATURES[1],
         wrapped_module=flax.nn.Dense,
         mask=dense_mask)
예제 #2
0
    def apply(self, inputs, mask=None):
        """Flatten the input and run two masked dense layers in sequence.

        Args:
          inputs: Array-like with `shape` and `reshape`; every dimension
            after the leading one is collapsed into a single axis.
          mask: Optional container indexed by the layer names
            'MaskedModule_0' and 'MaskedModule_1'. Any falsy value (None or
            an empty container) means both layers receive mask=None.

        Returns:
          The result of the second masked.MaskedModule call.
        """
        leading_dim = inputs.shape[0]
        flat = inputs.reshape(leading_dim, -1)

        # Truthiness check: an empty mask is treated the same as no mask.
        if mask:
            first_mask = mask['MaskedModule_0']
        else:
            first_mask = None
        hidden = masked.MaskedModule(
            flat,
            features=self.NUM_FEATURES[0],
            wrapped_module=flax.nn.Dense,
            mask=first_mask)

        if mask:
            second_mask = mask['MaskedModule_1']
        else:
            second_mask = None
        return masked.MaskedModule(
            hidden,
            features=self.NUM_FEATURES[1],
            wrapped_module=flax.nn.Dense,
            mask=second_mask)
예제 #3
0
 def apply(self, inputs, mask=None):
     """Run a single masked 3x3 convolution.

     Args:
       inputs: Input handed to the masked convolution.
       mask: Optional container indexed by the layer name
         'MaskedModule_0'. When None, the layer receives mask=None.

     Returns:
       The result of the masked.MaskedModule call.
     """
     if mask is None:
         conv_mask = None
     else:
         conv_mask = mask['MaskedModule_0']
     return masked.MaskedModule(
         inputs,
         features=self.NUM_FEATURES,
         kernel_size=(3, 3),
         wrapped_module=flax.nn.Conv,
         mask=conv_mask)
예제 #4
0
 def apply(self, inputs, mask=None):
   """Flatten the input and run a single masked dense layer.

   Args:
     inputs: Array-like with `shape` and `reshape`; every dimension after
       the leading one is collapsed into a single axis.
     mask: Optional container indexed by the layer name 'MaskedModule_0'.
       When None, the layer receives mask=None.

   Returns:
     The result of the masked.MaskedModule call.
   """
   leading_dim = inputs.shape[0]
   flat = inputs.reshape(leading_dim, -1)

   if mask is None:
     dense_mask = None
   else:
     dense_mask = mask['MaskedModule_0']
   return masked.MaskedModule(
       flat,
       features=self.NUM_FEATURES,
       wrapped_module=flax.deprecated.nn.Dense,
       mask=dense_mask)
예제 #5
0
    def apply(self, inputs, mask=None):
        """Run a masked 3x3 convolution with a kaiming-normal kernel init.

        Args:
          inputs: Input handed to the masked convolution.
          mask: Optional container indexed by the layer name
            'MaskedModule_0'. Any falsy value (None or an empty container)
            means the layer receives mask=None.

        Returns:
          The result of the masked.MaskedModule call.
        """
        # Truthiness check: an empty mask is treated the same as no mask.
        if mask:
            conv_mask = mask['MaskedModule_0']
        else:
            conv_mask = None

        return masked.MaskedModule(
            inputs,
            features=self.NUM_FEATURES,
            wrapped_module=flax.nn.Conv,
            kernel_size=(3, 3),
            mask=conv_mask,
            kernel_init=flax.nn.initializers.kaiming_normal())
예제 #6
0
    def apply(self, inputs, mask=None):
        """Flatten the input and run a masked dense layer (kaiming-normal init).

        Args:
          inputs: Array-like with `shape` and `reshape`; every dimension
            after the leading one is collapsed into a single axis.
          mask: Optional container indexed by the layer name
            'MaskedModule_0'. Any falsy value (None or an empty container)
            means the layer receives mask=None.

        Returns:
          The result of the masked.MaskedModule call.
        """
        leading_dim = inputs.shape[0]
        flat = inputs.reshape(leading_dim, -1)

        # Truthiness check: an empty mask is treated the same as no mask.
        if mask:
            dense_mask = mask['MaskedModule_0']
        else:
            dense_mask = None

        return masked.MaskedModule(
            flat,
            features=self.NUM_FEATURES,
            wrapped_module=flax.nn.Dense,
            mask=dense_mask,
            kernel_init=flax.nn.initializers.kaiming_normal())
예제 #7
0
 def apply(self, inputs, mask=None):
     """Run three masked convolutions in sequence.

     The first two layers use 5x5 and 3x3 kernels. The third layer's kernel
     size is taken from the second layer's output shape, dropping its first
     and last dimensions.

     Args:
       inputs: Input handed to the first masked convolution.
       mask: Optional container indexed by the layer names
         'MaskedModule_0', 'MaskedModule_1' and 'MaskedModule_2'. When
         None, every layer receives mask=None.

     Returns:
       The result of the third masked.MaskedModule call.
     """
     if mask is None:
         first_mask = None
     else:
         first_mask = mask['MaskedModule_0']
     hidden = masked.MaskedModule(
         inputs,
         features=self.NUM_FEATURES[0],
         kernel_size=(5, 5),
         wrapped_module=flax.deprecated.nn.Conv,
         mask=first_mask)

     if mask is None:
         second_mask = None
     else:
         second_mask = mask['MaskedModule_1']
     hidden = masked.MaskedModule(
         hidden,
         features=self.NUM_FEATURES[1],
         kernel_size=(3, 3),
         wrapped_module=flax.deprecated.nn.Conv,
         mask=second_mask)

     # Kernel spans every dimension of `hidden` except the first and last.
     full_kernel = hidden.shape[1:-1]
     if mask is None:
         third_mask = None
     else:
         third_mask = mask['MaskedModule_2']
     return masked.MaskedModule(
         hidden,
         features=self.NUM_FEATURES[2],
         kernel_size=full_kernel,
         wrapped_module=flax.deprecated.nn.Conv,
         mask=third_mask)
예제 #8
0
    def apply(self, inputs, *args, mask=None, **kwargs):
        """Run a masked 3x3 convolution whose kernel init receives the mask.

        Args:
          inputs: Input handed to the masked convolution.
          *args: Accepted but not used by this method.
          mask: Optional container indexed by the layer name
            'MaskedModule_0'. Any falsy value (None or an empty container)
            means the layer receives mask=None.
          **kwargs: Forwarded unchanged to masked.MaskedModule.

        Returns:
          The result of the masked.MaskedModule call.
        """
        # Truthiness check: an empty mask is treated the same as no mask.
        if mask:
            conv_mask = mask['MaskedModule_0']
        else:
            conv_mask = None

        # The initializer gets the layer's 'kernel' entry, or None when the
        # layer has no mask.
        kernel_mask = None if conv_mask is None else conv_mask['kernel']

        return masked.MaskedModule(
            inputs,
            features=self.NUM_FEATURES,
            wrapped_module=flax.nn.Conv,
            kernel_size=(3, 3),
            mask=conv_mask,
            kernel_init=init.kaiming_sparse_normal(kernel_mask),
            **kwargs)
예제 #9
0
    def apply(self, inputs, *args, mask=None, **kwargs):
        """Flatten the input and run a masked dense layer with mask-aware init.

        Args:
          inputs: Array-like with `shape` and `reshape`; every dimension
            after the leading one is collapsed into a single axis.
          *args: Accepted but not used by this method.
          mask: Optional container indexed by the layer name
            'MaskedModule_0'. Any falsy value (None or an empty container)
            means the layer receives mask=None.
          **kwargs: Forwarded unchanged to masked.MaskedModule.

        Returns:
          The result of the masked.MaskedModule call.
        """
        leading_dim = inputs.shape[0]
        flat = inputs.reshape(leading_dim, -1)

        # Truthiness check: an empty mask is treated the same as no mask.
        if mask:
            dense_mask = mask['MaskedModule_0']
        else:
            dense_mask = None

        # The initializer gets the layer's 'kernel' entry, or None when the
        # layer has no mask.
        kernel_mask = None if dense_mask is None else dense_mask['kernel']

        return masked.MaskedModule(
            flat,
            features=self.NUM_FEATURES,
            wrapped_module=flax.nn.Dense,
            mask=dense_mask,
            kernel_init=init.kaiming_sparse_normal(kernel_mask),
            **kwargs)