def __init__(self,
             weight_shape: List[int],
             bias_shape: List[int],
             use_bias: bool = True,
             weight_norm: WeightNormArgType = False,
             weight_init: TensorInitArgType = DEFAULT_WEIGHT_INIT,
             bias_init: TensorInitArgType = DEFAULT_BIAS_INIT,
             data_init: Optional[DataInitArgType] = None,
             device: Optional[str] = None):
    # resolve the device and construct the weight / bias parameter stores
    device = device or current_device()
    weight_store = get_weight_store(
        weight_shape, initializer=weight_init, weight_norm=weight_norm,
        device=device)
    bias_store = get_bias_store(
        bias_shape, initializer=bias_init, use_bias=use_bias, device=device)

    # normalize `data_init` into a `DataDependentInitializer` instance
    if data_init is not None:
        if not isinstance(data_init, initializer.DataDependentInitializer) and \
                (isinstance(data_init, type) or callable(data_init)):
            data_init = data_init()

        if not isinstance(data_init, initializer.DataDependentInitializer):
            raise TypeError(f'Unsupported data dependent initializer: '
                            f'{data_init!r}')

    super().__init__()
    self.weight_store = weight_store
    self.bias_store = bias_store

    if data_init is not None:
        data_init.register(self)
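# A minimal usage sketch (hypothetical shapes; assumes this `__init__` belongs to a
# core linear-style layer class, here called `CoreLinear` purely for illustration):
#
#     layer = CoreLinear(weight_shape=[4, 3], bias_shape=[4],
#                        weight_norm=True)   # weight carried by NormedAndScaledWeightStore
#
# Passing a `DataDependentInitializer` instance (or a zero-argument factory for one) as
# `data_init` causes it to be instantiated if needed and then registered on the layer
# via `data_init.register(self)`.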
def get_weight_store(shape: List[int],
                     initializer: TensorInitArgType = DEFAULT_WEIGHT_INIT,
                     norm_axis: int = 1,
                     weight_norm: WeightNormArgType = False,
                     device: Optional[str] = None) -> BaseParamStore:
    """
    Create a module which carries the `weight` parameter.

    Args:
        shape: The shape of the weight.
        initializer: The initializer for the weight.
        norm_axis: The axis along which to normalize the weight.
        weight_norm: The mode of weight norm.
            Use `NormedAndScaledWeightStore` if `True` or `WeightNormMode.FULL`.
            Use `NormedWeightStore` if `WeightNormMode.NO_SCALE`.
            Use `SimpleParamStore` if `False` or `WeightNormMode.NONE`.
        device: The device where to place the new weight.
            If not specified, use ``current_device()``.

    Returns:
        The weight store object.
    """
    device = device or current_device()
    if weight_norm is True or weight_norm == WeightNormMode.FULL:
        return NormedAndScaledWeightStore(shape, initializer, norm_axis, device)
    elif weight_norm == WeightNormMode.NO_SCALE:
        return NormedWeightStore(shape, initializer, norm_axis, device)
    elif weight_norm is False or weight_norm == WeightNormMode.NONE:
        return SimpleParamStore(shape, initializer, device)
    else:
        raise ValueError(f'Invalid value for argument `weight_norm`: '
                         f'{weight_norm!r}.')
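# Dispatch sketch for `get_weight_store` (shapes are hypothetical), following the
# branches above:
#
#     ws = get_weight_store([4, 3], weight_norm=True)                     # NormedAndScaledWeightStore
#     ws = get_weight_store([4, 3], weight_norm=WeightNormMode.FULL)      # NormedAndScaledWeightStore
#     ws = get_weight_store([4, 3], weight_norm=WeightNormMode.NO_SCALE)  # NormedWeightStore
#     ws = get_weight_store([4, 3], weight_norm=False)                    # SimpleParamStore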
def __init__(self,
             shape: List[int],
             initializer: TensorInitArgType,
             device: Optional[str] = None):
    super().__init__(shape)
    device = device or current_device()
    add_parameter(self, 'value',
                  variable(shape, initializer=initializer, device=device))
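# Construction sketch for the simple store (hypothetical shape; `DEFAULT_WEIGHT_INIT`
# is the module-level default initializer used elsewhere in this file):
#
#     store = SimpleParamStore([4, 3], initializer=DEFAULT_WEIGHT_INIT)
#
# The created tensor is registered on the module under the parameter name 'value'
# via `add_parameter`.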
def __init__(self,
             num_features: int,
             momentum: float = 0.1,
             epsilon: float = EPSILON,
             device: Optional[str] = None):
    super().__init__(num_features, eps=epsilon, momentum=momentum)
    device = device or current_device()
    if device != CPU_DEVICE:
        self.to(device=device)
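# Usage sketch (assumes this `__init__` belongs to a batch-norm wrapper whose parent
# class accepts `num_features`, `eps` and `momentum`, e.g. a `torch.nn.BatchNorm*d`
# subclass; the class name `BatchNorm` below is for illustration only):
#
#     bn = BatchNorm(num_features=64, momentum=0.1)
#
# The module is moved off the CPU only when the resolved device differs from
# `CPU_DEVICE`.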
def __init__(self,
             shape: List[int],
             initializer: TensorInitArgType,
             norm_axis: int = 1,
             device: Optional[str] = None,
             epsilon: float = EPSILON):
    super().__init__(shape)
    self.norm_axis = norm_axis
    device = device or current_device()
    self.epsilon = epsilon
    weight = variable(shape, initializer=initializer, device=device)
    with torch.no_grad():
        # decompose the initial weight into its direction component `v` and its norm;
        # only `v` is kept, since this store carries no scale
        v, _ = weight_norm_decompose(weight, norm_axis, epsilon)
    add_parameter(self, 'v', v)
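# Shape sketch for the no-scale weight store (hypothetical numbers):
#
#     store = NormedWeightStore([4, 3], initializer=DEFAULT_WEIGHT_INIT, norm_axis=1)
#     # store.v has shape [4, 3]; `weight_norm_decompose` splits the initial weight
#     # into a direction component and its norm along `norm_axis`, and only the
#     # direction `v` is stored here.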
def layer_to_device(layer: Module, device: Optional[str] = None) -> Module:
    """
    Move the specified module or layer to the given device.
    The module or layer may be changed in-place.

    Args:
        layer: The module or layer to be moved.
        device: The device to which the module or layer should be moved.
            If not specified, it will be moved to ``T.current_device()``.

    Returns:
        The layer instance.
    """
    if device is None:
        device = current_device()
    layer = layer.to(device=torch.device(device))
    return layer
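# Usage sketch for `layer_to_device` (the device string is hypothetical):
#
#     layer = layer_to_device(torch.nn.Linear(3, 4))            # move to current_device()
#     layer = layer_to_device(torch.nn.Linear(3, 4), 'cuda:0')  # move to an explicit device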
def get_bias_store(shape: List[int],
                   initializer: TensorInitArgType = DEFAULT_BIAS_INIT,
                   use_bias: bool = True,
                   device: Optional[str] = None) -> Optional[BaseParamStore]:
    """
    Create a module that carries the `bias` parameter.

    Args:
        shape: The shape of the bias.
        initializer: The initializer for the bias.
        use_bias: Whether or not to use the bias.
            If `False`, :obj:`None` will be returned.
        device: The device where to place the new bias.
            If not specified, use ``current_device()``.

    Returns:
        The bias store object, or :obj:`None` if `use_bias` is False.
    """
    device = device or current_device()
    if use_bias:
        return SimpleParamStore(shape, initializer, device)
    else:
        return None
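# Usage sketch for `get_bias_store` (hypothetical shape):
#
#     bias_store = get_bias_store([4])                  # SimpleParamStore carrying the bias
#     bias_store = get_bias_store([4], use_bias=False)  # None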