# Assumes the module-level imports used elsewhere in tframe, e.g.
#   import tensorflow as tf
#   from tframe import context, hub, initializers, linker
def _get_weights(self, name, shape, dtype=None, initializer=None):
  # Resolve initializer, falling back to this layer's default
  if initializer is None: initializer = self.initializer
  else: initializer = initializers.get(initializer)
  # Set default dtype if not specified
  if dtype is None: dtype = hub.dtype
  # Get regularizer if necessary
  regularizer = None
  if hub.use_global_regularizer:
    regularizer = hub.get_global_regularizer()
  # Get constraint if necessary
  constraint = hub.get_global_constraint()
  # Create (or fetch) the weight variable
  weights = tf.get_variable(
    name, shape, dtype=dtype, initializer=initializer,
    regularizer=regularizer, constraint=constraint)
  # If weight dropout is positive, apply dropout and return
  if self.weight_dropout > 0:
    return linker.dropout(weights, self.weight_dropout, rescale=True)
  # If no mask needs to be created, return the weight variable directly
  if not any(
      [self.prune_is_on, self.being_etched, hub.force_to_use_pruner]):
    return weights
  # Register with the pruner, which should have been created early in
  # model.build
  assert context.pruner is not None
  # Lottery logic has been merged into the etch logic: pruning is expressed
  # as a 'lottery' etch spec, e.g. 'lottery:prune_frac=0.1'
  if self.prune_is_on:
    assert not self.being_etched
    self.etch = 'lottery:prune_frac={}'.format(self.prune_frac)
  # Register the kernel to the pruner and use the masked tensor instead
  masked_weights = context.pruner.register_to_dense(weights, self.etch)
  # (The previous separate registration paths, including register_with_mask
  # with an etched surface, were merged into the single register_to_dense
  # call above.)
  assert isinstance(masked_weights, tf.Tensor)
  return masked_weights
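
# --- Illustrative sketch (not part of tframe) --------------------------------
# `register_to_dense` above returns a masked tensor. A minimal sketch of what
# such a registration is assumed to produce: the kernel is paired with a
# same-shaped, non-trainable binary mask and the element-wise product is used
# in place of the raw variable. The helper name `masked_weights_sketch`, the
# variable name 'mask', and the mask details are assumptions for illustration,
# not tframe's actual pruner implementation.
def masked_weights_sketch(weights):
  mask = tf.get_variable(
    'mask', shape=weights.shape, dtype=weights.dtype,
    initializer=tf.ones_initializer(), trainable=False)
  # Entries pruned by the mask (mask == 0) stay zero in both the forward and
  # the backward pass, since gradients flow through the multiplication
  return weights * mask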
def dropout(input_, dropout_rate):
  """Thin wrapper delegating to linker.dropout."""
  return linker.dropout(input_, dropout_rate)
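
# --- Illustrative sketch (not part of tframe) --------------------------------
# `linker.dropout` is assumed to implement inverted dropout: each element is
# zeroed with probability `dropout_rate`, and when `rescale` is True the
# survivors are scaled by 1 / (1 - dropout_rate) so the expected output
# matches the input. `dropout_sketch` below is a hypothetical stand-in written
# against the TF1-style API used above, not tframe's actual implementation.
def dropout_sketch(x, dropout_rate, rescale=True):
  keep_prob = 1.0 - dropout_rate
  # Bernoulli mask: 1 with probability keep_prob, else 0
  mask = tf.cast(tf.random_uniform(tf.shape(x)) < keep_prob, x.dtype)
  y = x * mask
  # Rescale survivors so E[output] == input (inverted dropout)
  if rescale: y = y / keep_prob
  return y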