Ejemplo n.º 1
0
    def _get_weights(self, name, shape, dtype=None, initializer=None):
        """Create a weight variable and return it raw, dropped-out or masked.

        Args:
            name: variable name forwarded to `tf.get_variable`.
            shape: variable shape forwarded to `tf.get_variable`.
            dtype: variable dtype; `hub.dtype` is used when None.
            initializer: initializer identifier resolved with
                `initializers.get`; when None, `self.initializer` is used
                as-is.

        Returns:
            - the result of `linker.dropout` when `self.weight_dropout > 0`;
            - the raw variable when none of `self.prune_is_on`,
              `self.being_etched` and `hub.force_to_use_pruner` is truthy;
            - otherwise the tensor produced by
              `context.pruner.register_to_dense`.

        Raises:
            AssertionError: if a mask is required but `context.pruner` is
                None, if pruning and etching are both switched on, or if the
                pruner does not return a `tf.Tensor`.
        """
        init = (self.initializer if initializer is None
                else initializers.get(initializer))
        if dtype is None:
            dtype = hub.dtype

        # The global regularizer is opt-in, whereas the global constraint is
        # always queried from hub.
        reg = (hub.get_global_regularizer() if hub.use_global_regularizer
               else None)
        global_constraint = hub.get_global_constraint()

        variable = tf.get_variable(
            name, shape, dtype=dtype, initializer=init, regularizer=reg,
            constraint=global_constraint)

        # Weight dropout takes precedence: no mask is registered in this case.
        if self.weight_dropout > 0:
            return linker.dropout(variable, self.weight_dropout, rescale=True)

        needs_mask = any(
            (self.prune_is_on, self.being_etched, hub.force_to_use_pruner))
        if not needs_mask:
            return variable

        # The pruner is expected to exist already at this point (per the
        # original note, it should be created early in model.build).
        assert context.pruner is not None

        # Lottery pruning is expressed as an etch configuration string, so a
        # single registration call below covers both pruning and etching.
        if self.prune_is_on:
            assert not self.being_etched
            self.etch = 'lottery:prune_frac={}'.format(self.prune_frac)

        masked = context.pruner.register_to_dense(variable, self.etch)
        assert isinstance(masked, tf.Tensor)
        return masked
Ejemplo n.º 2
0
 def dropout(input_, dropout_rate):
     """Apply dropout to `input_` by delegating to `linker.dropout`.

     Args:
         input_: value forwarded as the first argument of `linker.dropout`.
         dropout_rate: rate forwarded as the second argument.

     Returns:
         Whatever `linker.dropout(input_, dropout_rate)` returns.
     """
     dropped = linker.dropout(input_, dropout_rate)
     return dropped