def _gen_layers(self, layers, te_dim, use_bias, no_weights, init_weights,
                ce_dim, noise_dim):
    """Generate all layers of this network.

    This method will create the parameters of each layer. Note, this
    method should only be called by the constructor.

    This method will add the attributes "_hidden_dims" and "_out_dims".
    If "no_weights" is False, it will also create an attribute "_theta"
    and initialize all parameters. Otherwise, "_theta" is set to None.

    Args:
        See constructor arguments.
    """
    ### Compute the shapes of all parameters.
    # Hidden layers.
    self._hidden_dims = []
    prev_dim = te_dim
    if ce_dim is not None:
        prev_dim += ce_dim
    if noise_dim != -1:
        prev_dim += noise_dim
    for i, size in enumerate(layers):
        self._hidden_dims.append([size, prev_dim])
        if use_bias:
            self._hidden_dims.append([size])
        prev_dim = size
    self._last_hidden_size = prev_dim

    # Output layers.
    self._out_dims = []
    for i, dims in enumerate(self.target_shapes):
        nouts = np.prod(dims)
        self._out_dims.append([nouts, self._last_hidden_size])
        if use_bias:
            self._out_dims.append([nouts])

    if no_weights:
        self._theta = None
        return

    ### Create parameter tensors.
    # If "use_bias" is True, then each even-indexed entry of this list will
    # contain a weight matrix and each odd-indexed entry the corresponding
    # bias vector. Otherwise, it only contains a weight matrix per layer.
    self._theta = nn.ParameterList()
    for i, dims in enumerate(self._hidden_dims + self._out_dims):
        self._theta.append(nn.Parameter(torch.Tensor(*dims),
                                        requires_grad=True))

    if init_weights is not None:
        assert (len(init_weights) == len(self._theta))
        for i in range(len(init_weights)):
            assert (np.all(np.equal(list(init_weights[i].shape),
                                    list(self._theta[i].shape))))
            self._theta[i].data = init_weights[i]
    else:
        for i in range(0, len(self._theta), 2 if use_bias else 1):
            if use_bias:
                init_params(self._theta[i], self._theta[i + 1])
            else:
                init_params(self._theta[i])
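# Example (hedged illustration, not part of the original code): for
# `layers=[50]`, `te_dim=8`, `use_bias=True`, `ce_dim=None`, `noise_dim=-1`
# and `target_shapes=[[10, 5], [10]]`, the method above computes
#
#     self._hidden_dims == [[50, 8], [50]]
#     self._out_dims    == [[50, 50], [50], [10, 50], [10]]
#
# i.e., one (weight, bias) pair per hidden layer, followed by one
# (weight, bias) pair per output tensor, where each output weight matrix
# maps the last hidden layer to `np.prod(target_shape)` outputs.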
def __init__(self, in_channels, out_channels, in_height, in_width,
             kernel_size, stride=1, padding=0, bias=True, no_weights=False):
    super(LocalConv2dLayer, self).__init__()

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(padding, int):
        padding = (padding, padding)

    self._in_channels = in_channels
    self._out_channels = out_channels
    self._in_height = in_height
    self._in_width = in_width
    self._kernel_size = kernel_size
    self._stride = stride
    self._padding = padding
    self._has_bias = bias
    self._no_weights = no_weights

    self._out_height = (in_height - kernel_size[0] + 2 * padding[0]) // \
        stride[0] + 1
    self._out_width = (in_width - kernel_size[1] + 2 * padding[1]) // \
        stride[1] + 1

    # Size of a single receptive field.
    rf_size = in_channels * kernel_size[0] * kernel_size[1]
    self._rf_size = rf_size
    # Number of pixels per output feature map.
    num_pix = self._out_height * self._out_width
    self._num_pix = num_pix

    self._weights = None
    self._param_shapes = [[out_channels, rf_size, num_pix]]
    if bias:
        self._param_shapes.append([out_channels, num_pix])

    if not no_weights:
        self._weights = nn.ParameterList()
        self.register_parameter('filters', nn.Parameter( \
            torch.Tensor(*self._param_shapes[0]), requires_grad=True))
        self._weights.append(self.filters)

        if bias:
            self.register_parameter('bias', nn.Parameter( \
                torch.Tensor(*self._param_shapes[1]), requires_grad=True))
            self._weights.append(self.bias)
            init_params(self.filters, self.bias)
        else:
            self.register_parameter('bias', None)
            init_params(self.filters)
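# Example (hedged sketch): unlike `nn.Conv2d`, every spatial output
# location of this locally connected layer owns its own filters, hence the
# filter shape `[out_channels, rf_size, num_pix]` above. E.g., on 32x32
# RGB inputs:
#
#     layer = LocalConv2dLayer(3, 64, 32, 32, kernel_size=5, stride=2)
#     # out_height = (32 - 5 + 0) // 2 + 1 = 14 (same for out_width)
#     # layer.filters has shape [64, 3 * 5 * 5, 14 * 14]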
def __init__(self, weight_shapes, activation_fn=torch.nn.ReLU(),
             use_bias=True, no_weights=False, init_weights=None,
             dropout_rate=-1, out_fn=None, verbose=True,
             use_spectral_norm=False, use_batch_norm=False):
    """Initialize the network.

    Args:
        weight_shapes: A list of lists of integers, denoting the shape of
            each parameter tensor in this network. Note, this parameter
            only has an effect on the construction of this network if
            "no_weights" is False. Otherwise, it is just used to check the
            shapes of the input to the network in the forward method.
        activation_fn: The nonlinearity used in hidden layers. If None, no
            nonlinearity will be applied.
        use_bias: Whether layers may have bias terms.
        no_weights: If set to True, no trainable parameters will be
            constructed, i.e., weights are assumed to be produced ad-hoc
            by a hypernetwork and passed to the forward function.
        init_weights (optional): This option is for convenience reasons.
            The option expects a list of parameter values that are used to
            initialize the network weights. As such, it provides a
            convenient way of initializing a network with a weight draw
            produced by the hypernetwork.
        dropout_rate: If -1, no dropout will be applied. Otherwise a number
            between 0 and 1 is expected, denoting the dropout rate of
            hidden layers.
        out_fn (optional): If provided, this function will be applied to
            the output neurons of the network. Note, this changes the
            output of the forward method.
        verbose: Whether to print the number of weights in the network.
        use_spectral_norm: Use spectral normalization for training.
        use_batch_norm: Whether batch normalization should be used.
    """
    # FIXME find a way using super to handle multiple inheritance.
    #super(MainNetwork, self).__init__()
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    warn('Please use class "mnets.mlp.MLP" instead.', DeprecationWarning)

    if use_spectral_norm:
        raise NotImplementedError('Spectral normalization not yet ' +
                                  'implemented for this network.')
    if use_batch_norm:
        raise NotImplementedError('Batch normalization not yet ' +
                                  'implemented for this network.')

    assert(len(weight_shapes) > 0)
    self._all_shapes = weight_shapes
    self._has_bias = use_bias
    self._a_fun = activation_fn
    assert(init_weights is None or no_weights is False)
    self._no_weights = no_weights
    self._dropout_rate = dropout_rate
    self._out_fn = out_fn

    self._has_fc_out = True

    if use_spectral_norm and no_weights:
        raise ValueError('Cannot use spectral norm in a network without ' +
                         'parameters.')

    if use_spectral_norm:
        self._spec_norm = nn.utils.spectral_norm
    else:
        self._spec_norm = lambda x: x  # identity

    if verbose:
        print('Creating an MLP with %d weights' \
              % (MnetAPIV2.shapes_to_num_weights(self._all_shapes))
              + (', that uses dropout.' if dropout_rate != -1 else '.'))

    if dropout_rate != -1:
        assert(dropout_rate >= 0. and dropout_rate <= 1.)
        self._dropout = nn.Dropout(p=dropout_rate)

    self._weights = None
    if no_weights:
        self._hyper_shapes = self._all_shapes
        self._is_properly_setup()
        return

    ### Define and initialize network weights.
    # Each even-indexed entry of this list will contain a weight tensor
    # and each odd-indexed entry the corresponding bias vector.
    self._weights = nn.ParameterList()
    for i, dims in enumerate(self._all_shapes):
        self._weights.append(nn.Parameter(torch.Tensor(*dims),
                                          requires_grad=True))

    if init_weights is not None:
        assert(len(init_weights) == len(self._all_shapes))
        for i in range(len(init_weights)):
            assert(np.all(np.equal(list(init_weights[i].shape),
                                   list(self._weights[i].shape))))
            self._weights[i].data = init_weights[i]
    else:
        for i in range(0, len(self._weights), 2 if use_bias else 1):
            if use_bias:
                init_params(self._weights[i], self._weights[i + 1])
            else:
                init_params(self._weights[i])

    self._is_properly_setup()
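# Example (hedged sketch): for
# `weight_shapes=[[20, 10], [20], [5, 20], [5]]` and `use_bias=True`, the
# loop above yields `self._weights == [W0, b0, W1, b1]`, and the default
# initialization calls `init_params(W0, b0)` and `init_params(W1, b1)`,
# i.e., once per (weight, bias) pair.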
def __init__(self, in_shape=[32, 32, 3], num_classes=10, verbose=True, n=5,
             no_weights=False, init_weights=None, use_batch_norm=True,
             bn_track_stats=True, distill_bn_stats=False,
             use_context_mod=False, context_mod_inputs=False,
             no_last_layer_context_mod=False, context_mod_no_weights=False,
             context_mod_post_activation=False,
             context_mod_gain_offset=False,
             context_mod_apply_pixel_wise=False):
    super(ResNet, self).__init__(num_classes, verbose)

    self._in_shape = in_shape
    self._n = n

    assert (init_weights is None or \
            (not no_weights or not context_mod_no_weights))
    self._no_weights = no_weights

    assert (not use_batch_norm or (not distill_bn_stats or bn_track_stats))
    self._use_batch_norm = use_batch_norm
    self._bn_track_stats = bn_track_stats
    self._distill_bn_stats = distill_bn_stats and use_batch_norm

    self._use_context_mod = use_context_mod
    self._context_mod_inputs = context_mod_inputs
    self._no_last_layer_context_mod = no_last_layer_context_mod
    self._context_mod_no_weights = context_mod_no_weights
    self._context_mod_post_activation = context_mod_post_activation
    self._context_mod_gain_offset = context_mod_gain_offset
    self._context_mod_apply_pixel_wise = context_mod_apply_pixel_wise

    self._kernel_size = [3, 3]
    self._filter_sizes = [16, 16, 32, 64]

    self._has_bias = True
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    # We don't use any output non-linearity.
    self._has_linear_out = True

    self._param_shapes = []
    self._weights = None if no_weights and context_mod_no_weights \
        else nn.ParameterList()
    self._hyper_shapes_learned = None \
        if not no_weights and not context_mod_no_weights else []

    #################################################
    ### Define and initialize context mod weights ###
    #################################################
    self._context_mod_layers = nn.ModuleList() if use_context_mod else None
    self._context_mod_shapes = [] if use_context_mod else None

    if use_context_mod:
        cm_ind = 0
        cm_shapes = []  # Output shape of all layers.
        if context_mod_inputs:
            cm_shapes.append([in_shape[2], *in_shape[:2]])

        layer_out_shapes = self._compute_layer_out_sizes()
        if no_last_layer_context_mod:
            cm_shapes.extend(layer_out_shapes[:-1])
        else:
            cm_shapes.extend(layer_out_shapes)

        if not context_mod_apply_pixel_wise:
            # Only scalar gain and shift per feature map!
            for i, s in enumerate(cm_shapes):
                if len(s) == 3:
                    cm_shapes[i] = [s[0], 1, 1]

        for i, s in enumerate(cm_shapes):
            cmod_layer = ContextModLayer(s,
                no_weights=context_mod_no_weights,
                apply_gain_offset=context_mod_gain_offset)
            self._context_mod_layers.append(cmod_layer)

            self._param_shapes.extend(cmod_layer.param_shapes)
            self._context_mod_shapes.extend(cmod_layer.param_shapes)
            if context_mod_no_weights:
                self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
            else:
                self._weights.extend(cmod_layer.weights)

            # FIXME ugly code. Move initialization somewhere else.
            if not context_mod_no_weights and init_weights is not None:
                assert (len(cmod_layer.weights) == 2)
                for ii in range(2):
                    assert (np.all(np.equal( \
                        list(init_weights[cm_ind].shape),
                        list(cmod_layer.weights[ii].shape))))
                    cmod_layer.weights[ii].data = init_weights[cm_ind]
                    cm_ind += 1

        if init_weights is not None:
            init_weights = init_weights[cm_ind:]

    ###########################
    ### Print infos to user ###
    ###########################
    # Compute the total number of weights in this network and display
    # them to the user.
    # Note, this complicated calculation is not necessary as we can simply
    # count the number of weights afterwards. But it's an additional sanity
    # check for us.
    fs = self._filter_sizes
    num_weights = np.prod(self._kernel_size) * \
        (in_shape[2] * fs[0] + np.sum([fs[i] * fs[i + 1] + \
            (2 * n - 1) * fs[i + 1] ** 2 for i in range(3)])) + \
        (fs[0] + 2 * n * np.sum([fs[i] for i in range(1, 4)])) + \
        (fs[-1] * num_classes + num_classes)

    cm_num_weights = MainNetInterface.shapes_to_num_weights( \
        self._context_mod_shapes) if use_context_mod else 0
    num_weights += cm_num_weights

    if use_batch_norm:
        # The gamma and beta parameters of a batch norm layer are
        # learned as well.
        num_weights += 2 * (fs[0] + \
            2 * n * np.sum([fs[i] for i in range(1, 4)]))

    if verbose:
        print('A ResNet with %d layers and %d weights is created' \
              % (6 * n + 2, num_weights)
              + (' (including %d context-mod weights).' % cm_num_weights \
                 if cm_num_weights > 0 else '.'))

    ################################################
    ### Define and initialize batch norm weights ###
    ################################################
    self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

    if use_batch_norm:
        if distill_bn_stats:
            self._hyper_shapes_distilled = []

        bn_ind = 0
        for i, s in enumerate(self._filter_sizes):
            if i == 0:
                num = 1
            else:
                num = 2 * n

            for j in range(num):
                bn_layer = BatchNormLayer(s, affine=not no_weights,
                    track_running_stats=bn_track_stats)
                self._batchnorm_layers.append(bn_layer)

                if distill_bn_stats:
                    self._hyper_shapes_distilled.extend( \
                        [list(p.shape) for p in bn_layer.get_stats(0)])

                if not no_weights and init_weights is not None:
                    assert (len(bn_layer.weights) == 2)
                    for ii in range(2):
                        assert (np.all(np.equal( \
                            list(init_weights[bn_ind].shape),
                            list(bn_layer.weights[ii].shape))))
                        bn_layer.weights[ii].data = init_weights[bn_ind]
                        bn_ind += 1

        if init_weights is not None:
            init_weights = init_weights[bn_ind:]

    # Note, method `_compute_hyper_shapes` doesn't take context-mod into
    # consideration.
    self._param_shapes.extend(self._compute_hyper_shapes(no_weights=True))
    assert (num_weights == \
            MainNetInterface.shapes_to_num_weights(self._param_shapes))

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if no_weights:
        if self._hyper_shapes_learned is None:
            self._hyper_shapes_learned = self._compute_hyper_shapes()
        else:
            # Context-mod weights are already included.
            self._hyper_shapes_learned.extend(self._compute_hyper_shapes())
        self._is_properly_setup()
        return

    if use_batch_norm:
        for bn_layer in self._batchnorm_layers:
            self._weights.extend(bn_layer.weights)

    ##########################################
    ### Define and initialize layer weights ###
    ##########################################
    ### Does not include context-mod or batchnorm weights.
    # First layer.
    self._layer_weight_tensors.append(
        nn.Parameter(torch.Tensor(self._filter_sizes[0], self._in_shape[2],
                                  *self._kernel_size),
                     requires_grad=True))
    self._layer_bias_vectors.append(
        nn.Parameter(torch.Tensor(self._filter_sizes[0]),
                     requires_grad=True))

    # Each block consists of 2n layers.
    for i in range(1, len(self._filter_sizes)):
        in_filters = self._filter_sizes[i - 1]
        out_filters = self._filter_sizes[i]
        for _ in range(2 * n):
            self._layer_weight_tensors.append(
                nn.Parameter(torch.Tensor(out_filters, in_filters,
                                          *self._kernel_size),
                             requires_grad=True))
            self._layer_bias_vectors.append(
                nn.Parameter(torch.Tensor(out_filters),
                             requires_grad=True))
            # Note, only the first layer in this block has potentially a
            # different number of input filters.
            in_filters = out_filters

    # After the average pooling, there is one more dense layer.
    self._layer_weight_tensors.append(
        nn.Parameter(torch.Tensor(num_classes, self._filter_sizes[-1]),
                     requires_grad=True))
    self._layer_bias_vectors.append(
        nn.Parameter(torch.Tensor(num_classes), requires_grad=True))

    # We add the weights interleaved, such that the weight tensor and bias
    # vector of each layer are consecutive. This fulfills the requirements
    # of attribute `mask_fc_out`.
    for i in range(len(self._layer_weight_tensors)):
        self._weights.append(self._layer_weight_tensors[i])
        self._weights.append(self._layer_bias_vectors[i])

    ### Initialize weights.
    if init_weights is not None:
        num_layers = 6 * n + 2
        assert (len(init_weights) == 2 * num_layers)
        offset = 0
        if use_batch_norm:
            offset = 2 * (6 * n + 1)
            assert (len(self._weights) == offset + 2 * num_layers)
        for i in range(len(init_weights)):
            j = offset + i
            assert (np.all(np.equal(list(init_weights[i].shape),
                                    list(self._weights[j].shape))))
            self._weights[j].data = init_weights[i]
    else:
        for i in range(len(self._layer_weight_tensors)):
            init_params(self._layer_weight_tensors[i],
                        self._layer_bias_vectors[i])

    self._is_properly_setup()
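# Hedged sanity check of the closed-form weight count above: for the
# defaults (n=5, in_shape=[32, 32, 3], num_classes=10, filter sizes
# [16, 16, 32, 64]) the network has 6*5 + 2 = 32 layers and
#
#     conv weights: 9 * (3*16 + (16*16 + 9*16**2) + (16*32 + 9*32**2)
#                        + (32*64 + 9*64**2))    = 461232
#     conv biases:  16 + 2*5 * (16 + 32 + 64)    = 1136
#     final fc:     64 * 10 + 10                 = 650
#
# i.e., 463018 weights in total (before batchnorm and context-mod).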
def __init__(self, n_in=1, n_out=1, hidden_layers=[2000, 2000],
             activation_fn=torch.nn.ReLU(), use_bias=True, no_weights=False,
             init_weights=None, dropout_rate=-1, use_spectral_norm=False,
             use_batch_norm=False, bn_track_stats=True,
             distill_bn_stats=False, use_context_mod=False,
             context_mod_inputs=False, no_last_layer_context_mod=False,
             context_mod_no_weights=False,
             context_mod_post_activation=False,
             context_mod_gain_offset=False, out_fn=None, verbose=True):
    # FIXME find a way using super to handle multiple inheritance.
    #super(MainNetwork, self).__init__()
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    if use_spectral_norm:
        raise NotImplementedError('Spectral normalization not yet ' +
                                  'implemented for this network.')
    if use_batch_norm and use_context_mod:
        # FIXME Does it make sense to have both enabled?
        # I.e., should we produce a warning or error?
        pass

    self._a_fun = activation_fn
    assert(init_weights is None or \
           (not no_weights or not context_mod_no_weights))
    self._no_weights = no_weights
    self._dropout_rate = dropout_rate
    #self._use_spectral_norm = use_spectral_norm
    self._use_batch_norm = use_batch_norm
    self._bn_track_stats = bn_track_stats
    self._distill_bn_stats = distill_bn_stats and use_batch_norm
    self._use_context_mod = use_context_mod
    self._context_mod_inputs = context_mod_inputs
    self._no_last_layer_context_mod = no_last_layer_context_mod
    self._context_mod_no_weights = context_mod_no_weights
    self._context_mod_post_activation = context_mod_post_activation
    self._context_mod_gain_offset = context_mod_gain_offset
    self._out_fn = out_fn

    self._has_bias = use_bias
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    self._has_linear_out = True if out_fn is None else False

    if use_spectral_norm and no_weights:
        raise ValueError('Cannot use spectral norm in a network without ' +
                         'parameters.')

    # FIXME make sure that this implementation is correct in all situations
    # (e.g., what to do if weights are passed to the forward method?).
    if use_spectral_norm:
        self._spec_norm = nn.utils.spectral_norm
    else:
        self._spec_norm = lambda x: x  # identity

    self._param_shapes = []
    self._weights = None if no_weights and context_mod_no_weights \
        else nn.ParameterList()
    self._hyper_shapes_learned = None \
        if not no_weights and not context_mod_no_weights else []

    if dropout_rate != -1:
        assert (dropout_rate >= 0. and dropout_rate <= 1.)
        self._dropout = nn.Dropout(p=dropout_rate)

    ### Define and initialize context mod weights.
    self._context_mod_layers = nn.ModuleList() if use_context_mod else None
    self._context_mod_shapes = [] if use_context_mod else None

    if use_context_mod:
        cm_ind = 0
        cm_sizes = []
        if context_mod_inputs:
            cm_sizes.append(n_in)
        cm_sizes.extend(hidden_layers)
        if not no_last_layer_context_mod:
            cm_sizes.append(n_out)

        for i, n in enumerate(cm_sizes):
            cmod_layer = ContextModLayer(n,
                no_weights=context_mod_no_weights,
                apply_gain_offset=context_mod_gain_offset)
            self._context_mod_layers.append(cmod_layer)

            self._param_shapes.extend(cmod_layer.param_shapes)
            self._context_mod_shapes.extend(cmod_layer.param_shapes)
            if context_mod_no_weights:
                self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
            else:
                self._weights.extend(cmod_layer.weights)

            # FIXME ugly code. Move initialization somewhere else.
            if not context_mod_no_weights and init_weights is not None:
                assert (len(cmod_layer.weights) == 2)
                for ii in range(2):
                    assert(np.all(np.equal( \
                        list(init_weights[cm_ind].shape),
                        list(cmod_layer.weights[ii].shape))))
                    cmod_layer.weights[ii].data = init_weights[cm_ind]
                    cm_ind += 1

        if init_weights is not None:
            init_weights = init_weights[cm_ind:]

    ### Define and initialize batch norm weights.
    self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

    if use_batch_norm:
        if distill_bn_stats:
            self._hyper_shapes_distilled = []

        bn_ind = 0
        for i, n in enumerate(hidden_layers):
            bn_layer = BatchNormLayer(n, affine=not no_weights,
                track_running_stats=bn_track_stats)
            self._batchnorm_layers.append(bn_layer)

            self._param_shapes.extend(bn_layer.param_shapes)
            if no_weights:
                self._hyper_shapes_learned.extend(bn_layer.param_shapes)
            else:
                self._weights.extend(bn_layer.weights)

            if distill_bn_stats:
                self._hyper_shapes_distilled.extend( \
                    [list(p.shape) for p in bn_layer.get_stats(0)])

            # FIXME ugly code. Move initialization somewhere else.
            if not no_weights and init_weights is not None:
                assert (len(bn_layer.weights) == 2)
                for ii in range(2):
                    assert(np.all(np.equal( \
                        list(init_weights[bn_ind].shape),
                        list(bn_layer.weights[ii].shape))))
                    bn_layer.weights[ii].data = init_weights[bn_ind]
                    bn_ind += 1

        if init_weights is not None:
            init_weights = init_weights[bn_ind:]

    # Compute shapes of linear layers.
    linear_shapes = MLP.weight_shapes(n_in=n_in, n_out=n_out,
        hidden_layers=hidden_layers, use_bias=use_bias)
    self._param_shapes.extend(linear_shapes)

    num_weights = MainNetInterface.shapes_to_num_weights(self._param_shapes)

    if verbose:
        if use_context_mod:
            cm_num_weights = 0
            for cm_layer in self._context_mod_layers:
                cm_num_weights += MainNetInterface.shapes_to_num_weights( \
                    cm_layer.param_shapes)

        print('Creating an MLP with %d weights' % num_weights
              + (' (including %d weights associated with ' % cm_num_weights
                 + 'context modulation)' if use_context_mod else '') + '.'
              + (' The network uses dropout.' if dropout_rate != -1 else '')
              + (' The network uses batchnorm.' if use_batch_norm else ''))

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if no_weights:
        self._hyper_shapes_learned.extend(linear_shapes)
        self._is_properly_setup()
        return

    ### Define and initialize linear weights.
    for i, dims in enumerate(linear_shapes):
        self._weights.append(nn.Parameter(torch.Tensor(*dims),
                                          requires_grad=True))
        if len(dims) == 1:
            self._layer_bias_vectors.append(self._weights[-1])
        else:
            self._layer_weight_tensors.append(self._weights[-1])

    if init_weights is not None:
        assert (len(init_weights) == len(linear_shapes))
        for i in range(len(init_weights)):
            assert (np.all(np.equal(list(init_weights[i].shape),
                                    linear_shapes[i])))
            if use_bias:
                if i % 2 == 0:
                    self._layer_weight_tensors[i // 2].data = \
                        init_weights[i]
                else:
                    self._layer_bias_vectors[i // 2].data = init_weights[i]
            else:
                self._layer_weight_tensors[i].data = init_weights[i]
    else:
        for i in range(len(self._layer_weight_tensors)):
            if use_bias:
                init_params(self._layer_weight_tensors[i],
                            self._layer_bias_vectors[i])
            else:
                init_params(self._layer_weight_tensors[i])

    self._is_properly_setup()
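# Example (hedged sketch of the resulting parameter order): with
# `hidden_layers=[2000, 2000]`, `use_context_mod=True`,
# `context_mod_inputs=True` and `use_batch_norm=True`, `self._weights`
# is populated in the order
#
#     [cm gain/shift for the input and each modulated layer]
#     + [bn scale/shift per hidden layer]
#     + [W1, b1, W2, b2, W_out, b_out]
#
# matching the order in which `self._param_shapes` was extended above.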
def __init__(self, in_shape=(32, 32, 3), num_classes=10, no_weights=False,
             init_weights=None, use_context_mod=False,
             context_mod_inputs=False, no_last_layer_context_mod=False,
             context_mod_no_weights=False,
             context_mod_post_activation=False,
             context_mod_gain_offset=False, context_mod_gain_softplus=False,
             context_mod_apply_pixel_wise=False):
    super(BioConvNet, self).__init__(num_classes, True)

    assert(len(in_shape) == 3)
    # FIXME This assertion is not mandatory but a sanity check that the
    # user uses the Tensorflow layout.
    assert(in_shape[2] in [1, 3])
    assert(init_weights is None or \
           (not no_weights or not context_mod_no_weights))
    self._in_shape = in_shape

    self._no_weights = no_weights
    self._use_context_mod = use_context_mod
    self._context_mod_inputs = context_mod_inputs
    self._no_last_layer_context_mod = no_last_layer_context_mod
    self._context_mod_no_weights = context_mod_no_weights
    self._context_mod_post_activation = context_mod_post_activation
    self._context_mod_gain_offset = context_mod_gain_offset
    self._context_mod_gain_softplus = context_mod_gain_softplus
    self._context_mod_apply_pixel_wise = context_mod_apply_pixel_wise

    self._has_bias = True
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    self._has_linear_out = True

    self._param_shapes = []
    self._weights = None if no_weights and context_mod_no_weights \
        else nn.ParameterList()
    self._hyper_shapes_learned = None \
        if not no_weights and not context_mod_no_weights else []
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    # Shapes of output activities for context-modulation, if used.
    cm_shapes = []  # Output shape of all layers.
    if context_mod_inputs:
        cm_shapes.append([in_shape[2], *in_shape[:2]])

    ### Define and initialize all conv and linear layers
    ### Bio-conv layers.
    H = in_shape[0]
    W = in_shape[1]
    C_in = in_shape[2]
    C = [64, 128, 256]
    K = [5, 5, 3]
    S = [2, 2, 1]
    P = [0, 0, 1]
    self._conv_layer = []
    for i, C_out in enumerate(C):
        self._conv_layer.append(LocalConv2dLayer(C_in, C_out, H, W, K[i],
            stride=S[i], padding=P[i], no_weights=no_weights))
        H = self._conv_layer[-1].out_height
        W = self._conv_layer[-1].out_width
        cm_shapes.append([C_out, H, W])
        C_in = C_out

        self._param_shapes.extend(self._conv_layer[-1].param_shapes)
        if no_weights:
            self._hyper_shapes_learned.extend( \
                self._conv_layer[-1].param_shapes)
        else:
            self._weights.extend(self._conv_layer[-1].weights)

            assert(len(self._conv_layer[-1].weights) == 2)
            self._layer_weight_tensors.append(self._conv_layer[-1].filters)
            self._layer_bias_vectors.append(self._conv_layer[-1].bias)

    ### Linear layers
    n_in = H * W * C_out
    assert(n_in == 6400)
    n = [1024, num_classes]
    for i, n_out in enumerate(n):
        W_shape = [n_out, n_in]
        b_shape = [n_out]
        # Note, the last layer's shape might not be used for context-
        # modulation.
        if i < (len(n) - 1) or not no_last_layer_context_mod:
            cm_shapes.append([n_out])
        n_in = n_out

        self._param_shapes.extend([W_shape, b_shape])
        if no_weights:
            self._hyper_shapes_learned.extend([W_shape, b_shape])
        else:
            W = nn.Parameter(torch.Tensor(*W_shape), requires_grad=True)
            b = nn.Parameter(torch.Tensor(*b_shape), requires_grad=True)
            init_params(W, b)
            self._weights.extend([W, b])
            self._layer_weight_tensors.append(W)
            self._layer_bias_vectors.append(b)

    ### Define and initialize context mod weights.
    self._context_mod_layers = nn.ModuleList() if use_context_mod else None
    self._context_mod_shapes = [] if use_context_mod else None
    self._context_mod_weights = nn.ParameterList() if use_context_mod \
        else None

    if use_context_mod:
        if not context_mod_apply_pixel_wise:
            # Only scalar gain and shift per feature map!
            for i, s in enumerate(cm_shapes):
                if len(s) == 3:
                    cm_shapes[i] = [s[0], 1, 1]

        for i, s in enumerate(cm_shapes):
            cmod_layer = ContextModLayer(s,
                no_weights=context_mod_no_weights,
                apply_gain_offset=context_mod_gain_offset,
                apply_gain_softplus=context_mod_gain_softplus)
            self._context_mod_layers.append(cmod_layer)

            self._context_mod_shapes.extend(cmod_layer.param_shapes)
            if not context_mod_no_weights:
                self._context_mod_weights.extend(cmod_layer.weights)

        # We always add the context-mod weights/shapes at the beginning of
        # our list attributes.
        self._param_shapes = self._context_mod_shapes + self._param_shapes
        if context_mod_no_weights:
            self._hyper_shapes_learned = self._context_mod_shapes + \
                self._hyper_shapes_learned
        else:
            tmp = self._weights
            self._weights = nn.ParameterList(self._context_mod_weights)
            for w in tmp:
                self._weights.append(w)

    ### Apply custom init if given.
    if init_weights is not None:
        assert(len(self.weights) == len(init_weights))
        for i in range(len(init_weights)):
            assert(np.all(np.equal(list(init_weights[i].shape),
                                   list(self._weights[i].shape))))
            self._weights[i].data = init_weights[i]

    ### Print user info.
    num_weights = MainNetInterface.shapes_to_num_weights( \
        self._param_shapes)
    if use_context_mod:
        cm_num_weights = MainNetInterface.shapes_to_num_weights( \
            self._context_mod_shapes)
    print('Creating bio-plausible convnet with %d weights' % num_weights
          + (' (including %d weights associated with ' % cm_num_weights
             + 'context modulation)' if use_context_mod else '') + '.')

    self._is_properly_setup()
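# Hedged sanity check of the conv shapes used above: for the default
# `in_shape=(32, 32, 3)`, the three locally connected layers produce
#
#     layer 1 (k=5, s=2, p=0): (32 - 5) // 2 + 1 = 14    -> 14x14x64
#     layer 2 (k=5, s=2, p=0): (14 - 5) // 2 + 1 = 5     -> 5x5x128
#     layer 3 (k=3, s=1, p=1): (5 - 3 + 2) // 1 + 1 = 5  -> 5x5x256
#
# so the first linear layer sees 5 * 5 * 256 = 6400 inputs, which is
# exactly what `assert(n_in == 6400)` checks.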
def __init__(self, in_shape=(32, 32, 3), num_classes=10, n=4, k=10,
             num_feature_maps=(16, 16, 32, 64), use_bias=True,
             use_fc_bias=None, no_weights=False, use_batch_norm=True,
             bn_track_stats=True, distill_bn_stats=False, dropout_rate=-1,
             chw_input_format=False, verbose=True, **kwargs):
    super(WRN, self).__init__(num_classes, verbose)

    ### Parse or set context-mod arguments ###
    rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
    if len(rem_kwargs) > 0:
        raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
    # Since this is a conv-net, we may also want to add the following.
    if 'context_mod_apply_pixel_wise' not in kwargs.keys():
        kwargs['context_mod_apply_pixel_wise'] = False

    self._use_context_mod = kwargs['use_context_mod']
    self._context_mod_inputs = kwargs['context_mod_inputs']
    self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
    self._context_mod_no_weights = kwargs['context_mod_no_weights']
    self._context_mod_post_activation = \
        kwargs['context_mod_post_activation']
    self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
    self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
    self._context_mod_apply_pixel_wise = \
        kwargs['context_mod_apply_pixel_wise']

    ### Check or parse remaining arguments ###
    self._in_shape = in_shape
    self._n = n
    self._k = k
    if use_fc_bias is None:
        use_fc_bias = use_bias
    # Also, checkout attribute `_has_bias` below.
    self._use_bias = use_bias
    self._use_fc_bias = use_fc_bias
    self._no_weights = no_weights
    assert not use_batch_norm or (not distill_bn_stats or bn_track_stats)
    self._use_batch_norm = use_batch_norm
    self._bn_track_stats = bn_track_stats
    self._distill_bn_stats = distill_bn_stats and use_batch_norm
    self._dropout_rate = dropout_rate
    self._chw_input_format = chw_input_format

    # The original authors found that the best configuration uses this
    # kernel in all convolutional layers.
    self._kernel_size = (3, 3)
    if len(num_feature_maps) != 4:
        raise ValueError('Option "num_feature_maps" must be a list of 4 ' +
                         'integers.')
    self._filter_sizes = list(num_feature_maps)
    if k != 1:
        for i in range(1, 4):
            self._filter_sizes[i] = k * num_feature_maps[i]
    # Strides used in the first layer of each convolutional group.
    self._strides = (1, 1, 2, 2)

    ### Set required class attributes ###
    # Note, we did overwrite the getter for attribute `has_bias`, as it is
    # not applicable if the values of `use_bias` and `use_fc_bias` differ.
    self._has_bias = use_bias if use_bias == use_fc_bias else False
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer!
    self._mask_fc_out = True
    self._has_linear_out = True

    self._param_shapes = []
    self._param_shapes_meta = []
    self._internal_params = None if no_weights and \
        self._context_mod_no_weights else nn.ParameterList()
    self._hyper_shapes_learned = None \
        if not no_weights and not self._context_mod_no_weights else []
    self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        is None else []
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if dropout_rate != -1:
        assert dropout_rate >= 0. and dropout_rate <= 1.
        self._dropout = nn.Dropout(p=dropout_rate)

    #################################
    ### Create context mod layers ###
    #################################
    self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
        else None

    if self._use_context_mod:
        cm_layer_inds = []
        cm_shapes = []  # Output shape of all layers.
        if self._context_mod_inputs:
            cm_shapes.append([in_shape[2], *in_shape[:2]])
            # We reserve layer zero for input context-mod. Otherwise,
            # there is no layer zero.
            cm_layer_inds.append(0)

        layer_out_shapes = self._compute_layer_out_sizes()
        cm_shapes.extend(layer_out_shapes)
        # All layer indices `l` with `l mod 3 == 0` are context-mod layers.
        cm_layer_inds.extend(range(3, 3 * len(layer_out_shapes) + 1, 3))
        if self._no_last_layer_context_mod:
            cm_shapes = cm_shapes[:-1]
            cm_layer_inds = cm_layer_inds[:-1]

        if not self._context_mod_apply_pixel_wise:
            # Only scalar gain and shift per feature map!
            for i, s in enumerate(cm_shapes):
                if len(s) == 3:
                    cm_shapes[i] = [s[0], 1, 1]

        self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

    ###############################
    ### Create batchnorm layers ###
    ###############################
    # We just use even numbers starting from 2 as layer indices for
    # batchnorm layers.
    if use_batch_norm:
        bn_sizes = []
        for i, s in enumerate(self._filter_sizes):
            if i == 0:
                bn_sizes.append(s)
            else:
                bn_sizes.extend([s] * (2 * n))

        # All layer indices `l` with `l mod 3 == 2` are batchnorm layers.
        self._add_batchnorm_layers(bn_sizes, no_weights,
            bn_layers=list(range(2, 3 * len(bn_sizes) + 1, 3)),
            distill_bn_stats=distill_bn_stats,
            bn_track_stats=bn_track_stats)

    ######################################
    ### Create skip connection weights ###
    ######################################
    # We use 1x1 convolutional layers for residual blocks in case the
    # number of input and output feature maps disagrees. We also use 1x1
    # convolutions whenever a stride greater than 1 is applied. This is
    # not necessary in my opinion (as it adds extra weights that do not
    # affect the downsampling itself), but commonly done; for instance, in
    # the original PyTorch implementation.
    # Note, at most three 1x1 layers may be added to the network.
    # Note, we use 1x1 conv layers without biases.
    skip_1x1_shapes = []
    self._group_has_1x1 = [False] * 3
    for i in range(1, 4):
        if self._filter_sizes[i - 1] != self._filter_sizes[i] or \
                self._strides[i] != 1:
            skip_1x1_shapes.append([self._filter_sizes[i],
                                    self._filter_sizes[i - 1], 1, 1])
            self._group_has_1x1[i - 1] = True

    for s in skip_1x1_shapes:
        if not no_weights:
            self._internal_params.append(nn.Parameter( \
                torch.Tensor(*s), requires_grad=True))
            self._layer_weight_tensors.append(self._internal_params[-1])
            init_params(self._layer_weight_tensors[-1])
        else:
            self._hyper_shapes_learned.append(s)
            self._hyper_shapes_learned_ref.append(len(self.param_shapes))

        self._param_shapes.append(s)
        self._param_shapes_meta.append({
            'name': 'weight',
            'index': -1 if no_weights else len(self._internal_params) - 1,
            'layer': -1
        })

    ############################################################
    ### Create convolutional layers and final linear weights ###
    ############################################################
    # Convolutional layers will get IDs `l` such that `l mod 3 == 1`.
    layer_id = 1
    for i in range(5):
        if i == 0:  ### First layer.
            num = 1
            prev_fs = self._in_shape[2]
            curr_fs = self._filter_sizes[0]
        elif i == 4:  ### Final fully-connected layer.
            num = 1
            curr_fs = num_classes
        else:  # Group of residual blocks.
            num = 2 * n
            curr_fs = self._filter_sizes[i]

        for _ in range(num):
            if i == 4:
                layer_shapes = [[curr_fs, prev_fs]]
                if use_fc_bias:
                    layer_shapes.append([curr_fs])
            else:
                layer_shapes = [[curr_fs, prev_fs, *self._kernel_size]]
                if use_bias:
                    layer_shapes.append([curr_fs])

            for s in layer_shapes:
                if not no_weights:
                    self._internal_params.append(nn.Parameter( \
                        torch.Tensor(*s), requires_grad=True))
                    if len(s) == 1:
                        self._layer_bias_vectors.append( \
                            self._internal_params[-1])
                    else:
                        self._layer_weight_tensors.append( \
                            self._internal_params[-1])
                else:
                    self._hyper_shapes_learned.append(s)
                    self._hyper_shapes_learned_ref.append( \
                        len(self.param_shapes))

                self._param_shapes.append(s)
                self._param_shapes_meta.append({
                    'name': 'weight' if len(s) != 1 else 'bias',
                    'index': -1 if no_weights else \
                        len(self._internal_params) - 1,
                    'layer': layer_id
                })

            prev_fs = curr_fs
            layer_id += 3

            # Initialize weights.
            if not no_weights:
                init_params(self._layer_weight_tensors[-1],
                    self._layer_bias_vectors[-1] \
                    if len(layer_shapes) == 2 else None)

    ###########################
    ### Print infos to user ###
    ###########################
    if verbose:
        if self._use_context_mod:
            cm_param_shapes = []
            for cm_layer in self.context_mod_layers:
                cm_param_shapes.extend(cm_layer.param_shapes)
            cm_num_params = \
                MainNetInterface.shapes_to_num_weights(cm_param_shapes)

        print('Creating a WideResnet "%s" with %d weights' \
              % (str(self), self.num_params)
              + (' (including %d weights associated with ' % cm_num_params
                 + 'context modulation)' if self._use_context_mod else '')
              + '.'
              + (' The network uses batchnorm.' if use_batch_norm else '')
              + (' The network uses dropout.' if dropout_rate != -1 \
                 else ''))

    self._is_properly_setup(check_has_bias=False)
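# Hedged notes on the layer-ID scheme above: each network layer occupies
# three consecutive IDs, with `l % 3 == 1` for conv/fc weights,
# `l % 3 == 2` for the associated batchnorm layer and `l % 3 == 0` for
# the context-mod layer acting on that layer's output (ID 0 is reserved
# for input context-mod). E.g., conv layers get IDs 1, 4, 7, ...,
# batchnorm layers 2, 5, 8, ... and context-mod layers 0, 3, 6, ....
# Presumably, the defaults (n=4, k=10) correspond to the WRN-28-10
# configuration of Zagoruyko & Komodakis.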
def __init__(self, in_shape=(224, 224, 3), num_classes=1000, use_bias=True,
             use_fc_bias=None, num_feature_maps=(64, 64, 128, 256, 512),
             blocks_per_group=(2, 2, 2, 2), projection_shortcut=False,
             bottleneck_blocks=False, cutout_mod=False, no_weights=False,
             use_batch_norm=True, bn_track_stats=True,
             distill_bn_stats=False, chw_input_format=False, verbose=True,
             **kwargs):
    super(ResNetIN, self).__init__(num_classes, verbose)

    ### Parse or set context-mod arguments ###
    rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
    if 'context_mod_apply_pixel_wise' in rem_kwargs:
        rem_kwargs.remove('context_mod_apply_pixel_wise')
    if len(rem_kwargs) > 0:
        raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
    # Since this is a conv-net, we may also want to add the following.
    if 'context_mod_apply_pixel_wise' not in kwargs.keys():
        kwargs['context_mod_apply_pixel_wise'] = False

    self._use_context_mod = kwargs['use_context_mod']
    self._context_mod_inputs = kwargs['context_mod_inputs']
    self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
    self._context_mod_no_weights = kwargs['context_mod_no_weights']
    self._context_mod_post_activation = \
        kwargs['context_mod_post_activation']
    self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
    self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
    self._context_mod_apply_pixel_wise = \
        kwargs['context_mod_apply_pixel_wise']

    ### Check or parse remaining arguments ###
    self._in_shape = in_shape
    self._projection_shortcut = projection_shortcut
    self._bottleneck_blocks = bottleneck_blocks
    self._cutout_mod = cutout_mod
    if use_fc_bias is None:
        use_fc_bias = use_bias
    # Also, checkout attribute `_has_bias` below.
    self._use_bias = use_bias
    self._use_fc_bias = use_fc_bias
    self._no_weights = no_weights
    assert not use_batch_norm or (not distill_bn_stats or bn_track_stats)
    self._use_batch_norm = use_batch_norm
    self._bn_track_stats = bn_track_stats
    self._distill_bn_stats = distill_bn_stats and use_batch_norm
    self._chw_input_format = chw_input_format

    if len(blocks_per_group) != 4:
        raise ValueError('Option "blocks_per_group" must be a list of 4 ' +
                         'integers.')
    self._num_blocks = blocks_per_group
    if len(num_feature_maps) != 5:
        raise ValueError('Option "num_feature_maps" must be a list of 5 ' +
                         'integers.')
    self._filter_sizes = list(num_feature_maps)

    # The first layer of groups 3, 4 and 5 uses a strided convolution, so
    # the shortcut connections need to perform a downsampling operation.
    # In addition, whenever traversing from one group to the next, the
    # number of feature maps might change. In all these cases, the network
    # might benefit from smart shortcut connections, which means using
    # projection shortcuts, where a 1x1 conv is used for the mentioned
    # skip connection.
    self._num_non_ident_skips = 3  # Strided convs: 2->3, 3->4 and 4->5.
    fs1 = self._filter_sizes[1]
    if self._bottleneck_blocks:
        fs1 *= 4
    if self._filter_sizes[0] != fs1:
        self._num_non_ident_skips += 1  # Also handle 1->2.
    self._group_has_1x1 = [False] * 4
    if self._projection_shortcut:
        for i in range(3, 3 - self._num_non_ident_skips, -1):
            self._group_has_1x1[i] = True
    # Number of conv layers (excluding skip connections).
    self._num_main_conv_layers = 1 + int(np.sum([self._num_blocks[i] * \
        (3 if self._bottleneck_blocks else 2) for i in range(4)]))

    # The original architecture uses a 7x7 kernel in the first conv layer
    # and 3x3 or 1x1 kernels in all remaining layers.
    self._init_kernel_size = (7, 7)
    # All 3x3 layers have padding 1 and 1x1 layers have padding 0.
    self._init_padding = 3
    self._init_stride = 2
    if self._cutout_mod:
        self._init_kernel_size = (3, 3)
        self._init_padding = 1
        self._init_stride = 1

    ### Set required class attributes ###
    # Note, we did overwrite the getter for attribute `has_bias`, as it is
    # not applicable if the values of `use_bias` and `use_fc_bias` differ.
    self._has_bias = use_bias if use_bias == use_fc_bias else False
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer!
    self._mask_fc_out = True
    self._has_linear_out = True

    self._param_shapes = []
    self._param_shapes_meta = []
    self._internal_params = None if no_weights and \
        self._context_mod_no_weights else nn.ParameterList()
    self._hyper_shapes_learned = None \
        if not no_weights and not self._context_mod_no_weights else []
    self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        is None else []
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    #################################
    ### Create context mod layers ###
    #################################
    self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
        else None

    if self._use_context_mod:
        cm_layer_inds = []
        cm_shapes = []  # Output shape of all layers.
        if self._context_mod_inputs:
            cm_shapes.append([in_shape[2], *in_shape[:2]])
            # We reserve layer zero for input context-mod. Otherwise,
            # there is no layer zero.
            cm_layer_inds.append(0)

        layer_out_shapes = self._compute_layer_out_sizes()
        cm_shapes.extend(layer_out_shapes)
        # All layer indices `l` with `l mod 3 == 0` are context-mod layers.
        cm_layer_inds.extend(range(3, 3 * len(layer_out_shapes) + 1, 3))
        if self._no_last_layer_context_mod:
            cm_shapes = cm_shapes[:-1]
            cm_layer_inds = cm_layer_inds[:-1]

        if not self._context_mod_apply_pixel_wise:
            # Only scalar gain and shift per feature map!
            for i, s in enumerate(cm_shapes):
                if len(s) == 3:
                    cm_shapes[i] = [s[0], 1, 1]

        self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

    ###############################
    ### Create batchnorm layers ###
    ###############################
    # We just use even numbers starting from 2 as layer indices for
    # batchnorm layers.
    if use_batch_norm:
        bn_sizes = []
        for i, s in enumerate(self._filter_sizes):
            if i == 0:
                bn_sizes.append(s)
            else:
                for _ in range(self._num_blocks[i - 1]):
                    if self._bottleneck_blocks:
                        bn_sizes.extend([s, s, 4 * s])
                    else:
                        bn_sizes.extend([s, s])

        # All layer indices `l` with `l mod 3 == 2` are batchnorm layers.
        bn_layers = list(range(2, 3 * len(bn_sizes) + 1, 3))

        # We also need a batchnorm layer per skip connection that uses 1x1
        # projections.
        if self._projection_shortcut:
            bn_layer_ind_skip = 3 * (self._num_main_conv_layers + 1) + 2
            factor = 4 if self._bottleneck_blocks else 1
            for i in range(4):  # For each transition between conv groups.
                if self._group_has_1x1[i]:
                    bn_sizes.append(self._filter_sizes[i + 1] * factor)
                    bn_layers.append(bn_layer_ind_skip)
                    bn_layer_ind_skip += 3

        self._add_batchnorm_layers(bn_sizes, no_weights,
            bn_layers=bn_layers, distill_bn_stats=distill_bn_stats,
            bn_track_stats=bn_track_stats)

    ######################################
    ### Create skip connection weights ###
    ######################################
    if self._projection_shortcut:
        layer_ind_skip = 3 * (self._num_main_conv_layers + 1) + 1
        factor = 4 if self._bottleneck_blocks else 1
        n_in = self._filter_sizes[0]
        for i in range(4):  # For each transition between conv groups.
            if not self._group_has_1x1[i]:
                continue

            n_out = self._filter_sizes[i + 1] * factor
            skip_1x1_shape = [n_out, n_in, 1, 1]

            if not no_weights:
                self._internal_params.append(nn.Parameter( \
                    torch.Tensor(*skip_1x1_shape), requires_grad=True))
                self._layer_weight_tensors.append( \
                    self._internal_params[-1])
                self._layer_bias_vectors.append(None)
                init_params(self._layer_weight_tensors[-1])
            else:
                self._hyper_shapes_learned.append(skip_1x1_shape)
                self._hyper_shapes_learned_ref.append( \
                    len(self.param_shapes))

            self._param_shapes.append(skip_1x1_shape)
            self._param_shapes_meta.append({
                'name': 'weight',
                'index': -1 if no_weights else \
                    len(self._internal_params) - 1,
                'layer': layer_ind_skip
            })

            layer_ind_skip += 3
            n_in = n_out

    ############################################################
    ### Create convolutional layers and final linear weights ###
    ############################################################
    # Convolutional layers will get IDs `l` such that `l mod 3 == 1`.
    layer_id = 1
    n_per_block = 3 if self._bottleneck_blocks else 2
    for i in range(6):
        if i == 0:  ### First layer.
            num = 1
            prev_fs = self._in_shape[2]
            curr_fs = self._filter_sizes[0]
            kernel_size = self._init_kernel_size
            #stride = self._init_stride
        elif i == 5:  ### Final fully-connected layer.
            num = 1
            curr_fs = num_classes
            kernel_size = None
        else:  # Group of residual blocks.
            num = self._num_blocks[i - 1] * n_per_block
            curr_fs = self._filter_sizes[i]
            kernel_size = (3, 3)  # Depends on block structure!

        for n in range(num):
            if i == 5:
                layer_shapes = [[curr_fs, prev_fs]]
                if use_fc_bias:
                    layer_shapes.append([curr_fs])
                prev_fs = curr_fs
            else:
                if i > 0 and self._bottleneck_blocks:
                    if n % 3 == 0:
                        fs = curr_fs
                        ks = (1, 1)
                    elif n % 3 == 1:
                        fs = curr_fs
                        ks = kernel_size
                    else:
                        fs = 4 * curr_fs
                        ks = (1, 1)
                elif i > 0 and not self._bottleneck_blocks:
                    fs = curr_fs
                    ks = kernel_size
                else:
                    fs = curr_fs
                    ks = kernel_size

                layer_shapes = [[fs, prev_fs, *ks]]
                if use_bias:
                    layer_shapes.append([fs])
                prev_fs = fs

            for s in layer_shapes:
                if not no_weights:
                    self._internal_params.append(nn.Parameter( \
                        torch.Tensor(*s), requires_grad=True))
                    if len(s) == 1:
                        self._layer_bias_vectors.append( \
                            self._internal_params[-1])
                    else:
                        self._layer_weight_tensors.append( \
                            self._internal_params[-1])
                else:
                    self._hyper_shapes_learned.append(s)
                    self._hyper_shapes_learned_ref.append( \
                        len(self.param_shapes))

                self._param_shapes.append(s)
                self._param_shapes_meta.append({
                    'name': 'weight' if len(s) != 1 else 'bias',
                    'index': -1 if no_weights else \
                        len(self._internal_params) - 1,
                    'layer': layer_id
                })

            layer_id += 3

            # Initialize weights.
            if not no_weights:
                init_params(self._layer_weight_tensors[-1],
                    self._layer_bias_vectors[-1] \
                    if len(layer_shapes) == 2 else None)

    ###########################
    ### Print infos to user ###
    ###########################
    if verbose:
        if self._use_context_mod:
            cm_param_shapes = []
            for cm_layer in self.context_mod_layers:
                cm_param_shapes.extend(cm_layer.param_shapes)
            cm_num_params = \
                MainNetInterface.shapes_to_num_weights(cm_param_shapes)

        print('Creating a "%s" with %d weights' \
              % (str(self), self.num_params)
              + (' (including %d weights associated with ' % cm_num_params
                 + 'context modulation)' if self._use_context_mod else '')
              + '.'
              + (' The network uses batchnorm.' if use_batch_norm else ''))

    self._is_properly_setup(check_has_bias=False)
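# Hedged example of the block arithmetic above: with the defaults
# (`blocks_per_group=(2, 2, 2, 2)`, `bottleneck_blocks=False`),
# `self._num_main_conv_layers = 1 + 2*2*4 = 17`, which together with the
# final fc layer gives the ResNet-18 layout. Analogously,
# `blocks_per_group=(3, 4, 6, 3)` with `bottleneck_blocks=True` yields
# `1 + 3*(3+4+6+3) = 49` conv layers plus the fc layer, i.e. ResNet-50.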
def __init__(self, in_shape=(28, 28, 1), num_classes=10, verbose=True,
             arch='mnist_large', no_weights=False, init_weights=None,
             dropout_rate=-1, #0.5
             **kwargs):
    super(LeNet, self).__init__(num_classes, verbose)

    self._in_shape = in_shape

    assert arch in LeNet._ARCHITECTURES.keys()
    self._chosen_arch = LeNet._ARCHITECTURES[arch]
    if num_classes != 10:
        self._chosen_arch[-2][0] = num_classes
        self._chosen_arch[-1][0] = num_classes

    # Sanity check, given current implementation.
    if arch.startswith('mnist'):
        if not in_shape[0] == in_shape[1] == 28:
            raise ValueError('MNIST LeNet architectures expect input ' +
                             'images of size 28x28.')
    else:
        if not in_shape[0] == in_shape[1] == 32:
            raise ValueError('CIFAR LeNet architectures expect input ' +
                             'images of size 32x32.')

    ### Parse or set context-mod arguments ###
    rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
    if len(rem_kwargs) > 0:
        raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
    # Since this is a conv-net, we may also want to add the following.
    if 'context_mod_apply_pixel_wise' not in kwargs.keys():
        kwargs['context_mod_apply_pixel_wise'] = False

    self._use_context_mod = kwargs['use_context_mod']
    self._context_mod_inputs = kwargs['context_mod_inputs']
    self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
    self._context_mod_no_weights = kwargs['context_mod_no_weights']
    self._context_mod_post_activation = \
        kwargs['context_mod_post_activation']
    self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
    self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
    self._context_mod_apply_pixel_wise = \
        kwargs['context_mod_apply_pixel_wise']

    ### Setup class attributes ###
    assert(init_weights is None or \
           (not no_weights or not self._context_mod_no_weights))
    self._no_weights = no_weights
    self._dropout_rate = dropout_rate

    self._has_bias = True
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer!
    self._mask_fc_out = True
    self._has_linear_out = True

    self._param_shapes = []
    self._param_shapes_meta = []
    self._internal_params = None if no_weights and \
        self._context_mod_no_weights else nn.ParameterList()
    self._hyper_shapes_learned = None \
        if not no_weights and not self._context_mod_no_weights else []
    self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        is None else []
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if dropout_rate != -1:
        assert (dropout_rate >= 0. and dropout_rate <= 1.)
        # FIXME `nn.Dropout2d` zeroes out whole feature maps. Is that
        # really desired here?
        self._drop_conv1 = nn.Dropout2d(p=dropout_rate)
        self._drop_conv2 = nn.Dropout2d(p=dropout_rate)
        self._drop_fc1 = nn.Dropout(p=dropout_rate)

    ### Define and initialize context mod layers/weights ###
    self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
        else None

    if self._use_context_mod:
        cm_layer_inds = []
        cm_shapes = []  # Output shape of all context-mod layers.
        if self._context_mod_inputs:
            cm_shapes.append([in_shape[2], *in_shape[:2]])
            # We reserve layer zero for input context-mod. Otherwise,
            # there is no layer zero.
            cm_layer_inds.append(0)

        layer_out_shapes = self._compute_layer_out_sizes()
        # Context-modulation is applied after the pooling layers.
        # So we delete the shapes of the conv-layer outputs and keep the
        # ones of the pooling layer outputs.
        del layer_out_shapes[2]
        del layer_out_shapes[0]
        cm_shapes.extend(layer_out_shapes)
        cm_layer_inds.extend(range(2, 2 * len(layer_out_shapes) + 1, 2))
        if self._no_last_layer_context_mod:
            cm_shapes = cm_shapes[:-1]
            cm_layer_inds = cm_layer_inds[:-1]

        if not self._context_mod_apply_pixel_wise:
            # Only scalar gain and shift per feature map!
            for i, s in enumerate(cm_shapes):
                if len(s) == 3:
                    cm_shapes[i] = [s[0], 1, 1]

        self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

    ### Define and add conv- and fc-layer weights.
    for i, s in enumerate(self._chosen_arch):
        if not no_weights:
            self._internal_params.append(
                nn.Parameter(torch.Tensor(*s), requires_grad=True))
            if len(s) == 1:
                self._layer_bias_vectors.append(self._internal_params[-1])
            else:
                self._layer_weight_tensors.append(
                    self._internal_params[-1])
        else:
            self._hyper_shapes_learned.append(s)
            self._hyper_shapes_learned_ref.append(len(self.param_shapes))

        self._param_shapes.append(s)
        self._param_shapes_meta.append({
            'name': 'weight' if len(s) != 1 else 'bias',
            'index': -1 if no_weights else len(self._internal_params) - 1,
            'layer': 2 * (i // 2) + 1
        })

    ### Initialize weights.
    if init_weights is not None:
        assert len(init_weights) == len(self.weights)
        for i in range(len(init_weights)):
            assert np.all(np.equal(list(init_weights[i].shape),
                                   self.weights[i].shape))
            self.weights[i].data = init_weights[i]
    else:
        for i in range(len(self._layer_weight_tensors)):
            init_params(self._layer_weight_tensors[i],
                        self._layer_bias_vectors[i])

    ### Print user info.
    if verbose:
        if self._use_context_mod:
            cm_param_shapes = []
            for cm_layer in self.context_mod_layers:
                cm_param_shapes.extend(cm_layer.param_shapes)
            cm_num_weights = \
                MainNetInterface.shapes_to_num_weights(cm_param_shapes)

        print('Creating a LeNet with %d weights' % self.num_params
              + (' (including %d weights associated with ' % cm_num_weights
                 + 'context modulation)' if self._use_context_mod else '')
              + '.'
              + (' The network uses dropout.' if dropout_rate != -1 \
                 else ''))

    self._is_properly_setup()
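# Hedged illustration of the ID scheme above: with two conv and two fc
# layers, the `2 * (i // 2) + 1` rule assigns `layer` IDs 1, 1, 3, 3,
# 5, 5, 7, 7 to the four (weight, bias) pairs, while the even IDs
# 0, 2, 4, 6, 8 are reserved for the optional context-mod layers (input,
# after each pooling layer, after fc1 and after the output layer; the
# last one is dropped if `no_last_layer_context_mod` is set).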
def __init__(self, n_in=1, n_out=1, hidden_layers=(10, 10), activation_fn=torch.nn.ReLU(), use_bias=True, no_weights=False, init_weights=None, dropout_rate=-1, use_spectral_norm=False, use_batch_norm=False, bn_track_stats=True, distill_bn_stats=False, use_context_mod=False, context_mod_inputs=False, no_last_layer_context_mod=False, context_mod_no_weights=False, context_mod_post_activation=False, context_mod_gain_offset=False, context_mod_gain_softplus=False, out_fn=None, verbose=True): # FIXME find a way using super to handle multiple inheritance. nn.Module.__init__(self) MainNetInterface.__init__(self) # FIXME Spectral norm is incorrectly implemented. Function # `nn.utils.spectral_norm` needs to be called in the constructor, such # that sepc norm is wrapped around a module. if use_spectral_norm: raise NotImplementedError('Spectral normalization not yet ' + 'implemented for this network.') if use_batch_norm and use_context_mod: # FIXME Does it make sense to have both enabled? # I.e., should we produce a warning or error? pass self._a_fun = activation_fn assert(init_weights is None or \ (not no_weights or not context_mod_no_weights)) self._no_weights = no_weights self._dropout_rate = dropout_rate #self._use_spectral_norm = use_spectral_norm self._use_batch_norm = use_batch_norm self._bn_track_stats = bn_track_stats self._distill_bn_stats = distill_bn_stats and use_batch_norm self._use_context_mod = use_context_mod self._context_mod_inputs = context_mod_inputs self._no_last_layer_context_mod = no_last_layer_context_mod self._context_mod_no_weights = context_mod_no_weights self._context_mod_post_activation = context_mod_post_activation self._context_mod_gain_offset = context_mod_gain_offset self._context_mod_gain_softplus = context_mod_gain_softplus self._out_fn = out_fn self._has_bias = use_bias self._has_fc_out = True # We need to make sure that the last 2 entries of `weights` correspond # to the weight matrix and bias vector of the last layer. self._mask_fc_out = True self._has_linear_out = True if out_fn is None else False if use_spectral_norm and no_weights: raise ValueError('Cannot use spectral norm in a network without ' + 'parameters.') # FIXME make sure that this implementation is correct in all situations # (e.g., what to do if weights are passed to the forward method?). if use_spectral_norm: self._spec_norm = nn.utils.spectral_norm else: self._spec_norm = lambda x: x # identity self._param_shapes = [] self._param_shapes_meta = [] self._weights = None if no_weights and context_mod_no_weights \ else nn.ParameterList() self._hyper_shapes_learned = None \ if not no_weights and not context_mod_no_weights else [] self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \ is None else [] if dropout_rate != -1: assert (dropout_rate >= 0. and dropout_rate <= 1.) self._dropout = nn.Dropout(p=dropout_rate) ### Define and initialize context mod weights. 
self._context_mod_layers = nn.ModuleList() if use_context_mod else None self._context_mod_shapes = [] if use_context_mod else None if use_context_mod: cm_ind = 0 cm_sizes = [] if context_mod_inputs: cm_sizes.append(n_in) cm_sizes.extend(hidden_layers) if not no_last_layer_context_mod: cm_sizes.append(n_out) for i, n in enumerate(cm_sizes): cmod_layer = ContextModLayer( n, no_weights=context_mod_no_weights, apply_gain_offset=context_mod_gain_offset, apply_gain_softplus=context_mod_gain_softplus) self._context_mod_layers.append(cmod_layer) self.param_shapes.extend(cmod_layer.param_shapes) assert len(cmod_layer.param_shapes) == 2 self._param_shapes_meta.extend([ {'name': 'cm_scale', 'index': -1 if context_mod_no_weights else \ len(self._weights), 'layer': -1}, # 'layer' is set later. {'name': 'cm_shift', 'index': -1 if context_mod_no_weights else \ len(self._weights)+1, 'layer': -1}, # 'layer' is set later. ]) self._context_mod_shapes.extend(cmod_layer.param_shapes) if context_mod_no_weights: self._hyper_shapes_learned.extend(cmod_layer.param_shapes) else: self._weights.extend(cmod_layer.weights) # FIXME ugly code. Move initialization somewhere else. if not context_mod_no_weights and init_weights is not None: assert (len(cmod_layer.weights) == 2) for ii in range(2): assert(np.all(np.equal( \ list(init_weights[cm_ind].shape), list(cm_ind.weights[ii].shape)))) cmod_layer.weights[ii].data = init_weights[cm_ind] cm_ind += 1 if init_weights is not None: init_weights = init_weights[cm_ind:] if context_mod_no_weights: self._hyper_shapes_learned_ref = \ list(range(len(self._param_shapes))) ### Define and initialize batch norm weights. self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None if use_batch_norm: if distill_bn_stats: self._hyper_shapes_distilled = [] bn_ind = 0 for i, n in enumerate(hidden_layers): bn_layer = BatchNormLayer(n, affine=not no_weights, track_running_stats=bn_track_stats) self._batchnorm_layers.append(bn_layer) self._param_shapes.extend(bn_layer.param_shapes) assert len(bn_layer.param_shapes) == 2 self._param_shapes_meta.extend([ { 'name': 'bn_scale', 'index': -1 if no_weights else len(self._weights), 'layer': -1 }, # 'layer' is set later. { 'name': 'bn_shift', 'index': -1 if no_weights else len(self._weights) + 1, 'layer': -1 }, # 'layer' is set later. ]) if no_weights: self._hyper_shapes_learned.extend(bn_layer.param_shapes) else: self._weights.extend(bn_layer.weights) if distill_bn_stats: self._hyper_shapes_distilled.extend( \ [list(p.shape) for p in bn_layer.get_stats(0)]) # FIXME ugly code. Move initialization somewhere else. if not no_weights and init_weights is not None: assert (len(bn_layer.weights) == 2) for ii in range(2): assert(np.all(np.equal( \ list(init_weights[bn_ind].shape), list(bn_layer.weights[ii].shape)))) bn_layer.weights[ii].data = init_weights[bn_ind] bn_ind += 1 if init_weights is not None: init_weights = init_weights[bn_ind:] ### Compute shapes of linear layers. linear_shapes = MLP.weight_shapes(n_in=n_in, n_out=n_out, hidden_layers=hidden_layers, use_bias=use_bias) self._param_shapes.extend(linear_shapes) for i, s in enumerate(linear_shapes): self._param_shapes_meta.append({ 'name': 'weight' if len(s) != 1 else 'bias', 'index': -1 if no_weights else len(self._weights) + i, 'layer': -1 # 'layer' is set later. }) num_weights = MainNetInterface.shapes_to_num_weights( self._param_shapes) ### Set missing meta information of param_shapes. 
    offset = 1 if use_context_mod and context_mod_inputs else 0
    shift = 1
    if use_batch_norm:
        shift += 1
    if use_context_mod:
        shift += 1
    cm_offset = 2 if context_mod_post_activation else 1
    bn_offset = 1 if context_mod_post_activation else 2
    cm_ind = 0
    bn_ind = 0
    layer_ind = 0

    for i, dd in enumerate(self._param_shapes_meta):
        if dd['name'].startswith('cm'):
            if offset == 1 and i in [0, 1]:
                dd['layer'] = 0
            else:
                if cm_ind < len(hidden_layers):
                    dd['layer'] = offset + cm_ind * shift + cm_offset
                else:
                    assert cm_ind == len(hidden_layers) and \
                        not no_last_layer_context_mod
                    # No batchnorm in output layer.
                    dd['layer'] = offset + cm_ind * shift + 1
                if dd['name'] == 'cm_shift':
                    cm_ind += 1
        elif dd['name'].startswith('bn'):
            dd['layer'] = offset + bn_ind * shift + bn_offset
            if dd['name'] == 'bn_shift':
                bn_ind += 1
        else:
            dd['layer'] = offset + layer_ind * shift
            if not use_bias or dd['name'] == 'bias':
                layer_ind += 1

    ### User information.
    if verbose:
        if use_context_mod:
            cm_num_weights = 0
            for cm_layer in self._context_mod_layers:
                cm_num_weights += MainNetInterface.shapes_to_num_weights( \
                    cm_layer.param_shapes)

        print('Creating an MLP with %d weights' % num_weights
              + (' (including %d weights associated with context '
                 'modulation)' % cm_num_weights if use_context_mod else '')
              + '.'
              + (' The network uses dropout.' if dropout_rate != -1 else '')
              + (' The network uses batchnorm.' if use_batch_norm else ''))

    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()

    if no_weights:
        self._hyper_shapes_learned.extend(linear_shapes)
        if use_context_mod:
            if context_mod_no_weights:
                self._hyper_shapes_learned_ref = \
                    list(range(len(self._param_shapes)))
            else:
                ncm = len(self._context_mod_shapes)
                self._hyper_shapes_learned_ref = \
                    list(range(ncm, len(self._param_shapes)))
        self._is_properly_setup()
        return

    ### Define and initialize linear weights.
    for i, dims in enumerate(linear_shapes):
        self._weights.append(nn.Parameter(torch.Tensor(*dims),
                                          requires_grad=True))
        if len(dims) == 1:
            self._layer_bias_vectors.append(self._weights[-1])
        else:
            self._layer_weight_tensors.append(self._weights[-1])

    if init_weights is not None:
        assert len(init_weights) == len(linear_shapes)
        for i in range(len(init_weights)):
            assert np.all(np.equal(list(init_weights[i].shape),
                                   linear_shapes[i]))
            if use_bias:
                if i % 2 == 0:
                    self._layer_weight_tensors[i // 2].data = init_weights[i]
                else:
                    self._layer_bias_vectors[i // 2].data = init_weights[i]
            else:
                self._layer_weight_tensors[i].data = init_weights[i]
    else:
        for i in range(len(self._layer_weight_tensors)):
            if use_bias:
                init_params(self._layer_weight_tensors[i],
                            self._layer_bias_vectors[i])
            else:
                init_params(self._layer_weight_tensors[i])

    if self._num_context_mod_shapes() == 0:
        # Note, that might be the case if no hidden layers exist and no
        # input or output modulation is used.
        self._use_context_mod = False

    self._is_properly_setup()
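# Usage sketch (illustration only, not part of the original source): with
# ``no_weights=True``, the MLP holds no parameters of its own and expects
# them ad-hoc, e.g., from a hypernetwork. Assuming the attribute and
# forward-method conventions used throughout this interface, a forward pass
# could look as follows (the random tensors below merely stand in for an
# actual hypernetwork output):
#
#   net = MLP(n_in=8, n_out=2, hidden_layers=(16,), no_weights=True)
#   weights = [torch.randn(*s) for s in net.hyper_shapes_learned]
#   x = torch.randn(32, 8)
#   y = net.forward(x, weights=weights)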
def __init__(self, n_in, n_out=1, inp_chunk_dim=100, out_chunk_dim=10,
             cemb_size=8, cemb_init_std=1., red_layers=(10, 10),
             net_layers=(10, 10), activation_fn=torch.nn.ReLU(),
             use_bias=True, dynamic_biases=None, no_weights=False,
             init_weights=None, dropout_rate=-1, use_spectral_norm=False,
             use_batch_norm=False, bn_track_stats=True,
             distill_bn_stats=False, verbose=True):
    # FIXME find a way using super to handle multiple inheritance.
    nn.Module.__init__(self)
    MainNetInterface.__init__(self)

    self._n_in = n_in
    self._n_out = n_out
    self._inp_chunk_dim = inp_chunk_dim
    self._out_chunk_dim = out_chunk_dim
    self._cemb_size = cemb_size
    self._a_fun = activation_fn
    self._no_weights = no_weights

    self._has_bias = use_bias
    self._has_fc_out = True
    # We need to make sure that the last 2 entries of `weights` correspond
    # to the weight matrix and bias vector of the last layer.
    self._mask_fc_out = True
    self._has_linear_out = True # Ensure that `out_fn` is `None`!

    self._param_shapes = []
    #self._param_shapes_meta = [] # TODO implement!
    self._weights = None if no_weights else nn.ParameterList()
    self._hyper_shapes_learned = None if not no_weights else []
    #self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
    #    is None else [] # TODO implement.
    self._layer_weight_tensors = nn.ParameterList()
    self._layer_bias_vectors = nn.ParameterList()
    self._context_mod_layers = None
    self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

    #################################
    ### Generate Chunk Embeddings ###
    #################################

    self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
    last_chunk_size = n_in % inp_chunk_dim
    if last_chunk_size != 0:
        self._pad = inp_chunk_dim - last_chunk_size
    else:
        self._pad = -1

    cemb_shape = [self._num_cembs, cemb_size]
    self._param_shapes.append(cemb_shape)
    if no_weights:
        self._cembs = None
        self._hyper_shapes_learned.append(cemb_shape)
    else:
        self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                   requires_grad=True)
        nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)
        self._weights.append(self._cembs)

    ############################
    ### Setup Dynamic Biases ###
    ############################

    self._has_dyn_bias = None
    if dynamic_biases is not None:
        assert np.all(np.array(dynamic_biases) >= 0) and \
            np.all(np.array(dynamic_biases) < len(red_layers) + 1)
        dynamic_biases = np.sort(np.unique(dynamic_biases))

        # For each layer in the `reducer` where we want to apply a dynamic
        # bias, we have to create a weight matrix for a corresponding linear
        # layer (whose own bias term we just ignore).
        # Note, a `ParameterList` (rather than a `ModuleList`) is used, since
        # the entries are plain parameters or `None` placeholders.
        self._dyn_bias_weights = nn.ParameterList()
        self._has_dyn_bias = []
        for i in range(len(red_layers) + 1):
            if i in dynamic_biases:
                self._has_dyn_bias.append(True)

                trgt_dim = out_chunk_dim
                if i < len(red_layers):
                    trgt_dim = red_layers[i]
                trgt_shape = [trgt_dim, cemb_size]

                self._param_shapes.append(trgt_shape)
                if no_weights:
                    self._dyn_bias_weights.append(None)
                    self._hyper_shapes_learned.append(trgt_shape)
                else:
                    self._dyn_bias_weights.append(nn.Parameter( \
                        torch.Tensor(*trgt_shape), requires_grad=True))
                    self._weights.append(self._dyn_bias_weights[-1])
                    init_params(self._dyn_bias_weights[-1])
                    self._layer_weight_tensors.append( \
                        self._dyn_bias_weights[-1])
                    self._layer_bias_vectors.append(None)
            else:
                self._has_dyn_bias.append(False)
                self._dyn_bias_weights.append(None)

    ################################
    ### Create `Reducer` Network ###
    ################################

    red_inp_dim = inp_chunk_dim + \
        (cemb_size if dynamic_biases is None else 0)
    self._reducer = MLP(n_in=red_inp_dim, n_out=out_chunk_dim,
        hidden_layers=red_layers, activation_fn=activation_fn,
        use_bias=use_bias, no_weights=no_weights, init_weights=None,
        dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm,
        use_batch_norm=use_batch_norm, bn_track_stats=bn_track_stats,
        distill_bn_stats=distill_bn_stats,
        # We use context modulation to realize dynamic biases, since it
        # allows a different modulation per sample in the input mini-batch.
        # Hence, we can process several chunks in parallel with the reducer
        # network.
        use_context_mod=dynamic_biases is not None,
        context_mod_inputs=False,
        no_last_layer_context_mod=False,
        context_mod_no_weights=True,
        context_mod_post_activation=False,
        context_mod_gain_offset=False,
        context_mod_gain_softplus=False,
        out_fn=None, verbose=True)

    if dynamic_biases is not None:
        # FIXME We have to extract the param shapes from
        # `self._reducer.param_shapes`, as well as from
        # `self._reducer._hyper_shapes_learned`, that belong to context-mod
        # layers. We may not add them to our own `param_shapes` attribute,
        # as these are not parameters (due to our misuse of the context-mod
        # layers).
        # Note, in the `forward` method, we need to supply context-mod
        # weights for all reducer layers, independent of whether they have a
        # dynamic bias or not. We can do so by providing constant ones for
        # all gains and constant zero shifts for all layers without dynamic
        # biases (note, we need to ensure the correct batch dimension!).
        raise NotImplementedError('Dynamic biases are not yet implemented!')
    assert self._reducer._context_mod_layers is None

    ### Take over all attributes from the underlying MLP.
    for s in self._reducer.param_shapes:
        self._param_shapes.append(s)
    if no_weights:
        for s in self._reducer._hyper_shapes_learned:
            self._hyper_shapes_learned.append(s)
    else:
        for p in self._reducer._weights:
            self._weights.append(p)
    for p in self._reducer._layer_weight_tensors:
        self._layer_weight_tensors.append(p)
    for p in self._reducer._layer_bias_vectors:
        self._layer_bias_vectors.append(p)
    if use_batch_norm:
        for p in self._reducer._batchnorm_layers:
            self._batchnorm_layers.append(p)
    if self._reducer._hyper_shapes_distilled is not None:
        self._hyper_shapes_distilled = []
        for s in self._reducer._hyper_shapes_distilled:
            self._hyper_shapes_distilled.append(s)

    ###############################
    ### Create Actual `Network` ###
    ###############################

    net_inp_dim = out_chunk_dim * self._num_cembs
    self._network = MLP(n_in=net_inp_dim, n_out=n_out,
        hidden_layers=net_layers, activation_fn=activation_fn,
        use_bias=use_bias, no_weights=no_weights, init_weights=None,
        dropout_rate=dropout_rate, use_spectral_norm=use_spectral_norm,
        use_batch_norm=use_batch_norm, bn_track_stats=bn_track_stats,
        distill_bn_stats=distill_bn_stats, use_context_mod=False,
        out_fn=None, verbose=True)

    ### Take over all attributes from the underlying MLP.
    for s in self._network.param_shapes:
        self._param_shapes.append(s)
    if no_weights:
        for s in self._network._hyper_shapes_learned:
            self._hyper_shapes_learned.append(s)
    else:
        for p in self._network._weights:
            self._weights.append(p)
    for p in self._network._layer_weight_tensors:
        self._layer_weight_tensors.append(p)
    for p in self._network._layer_bias_vectors:
        self._layer_bias_vectors.append(p)
    if use_batch_norm:
        for p in self._network._batchnorm_layers:
            self._batchnorm_layers.append(p)
    if self._hyper_shapes_distilled is not None:
        assert self._network._hyper_shapes_distilled is not None
        for s in self._network._hyper_shapes_distilled:
            self._hyper_shapes_distilled.append(s)

    ######################################
    ### Take over given Initialization ###
    ######################################
    if init_weights is not None:
        assert len(init_weights) == len(self._weights)
        for i in range(len(init_weights)):
            assert np.all(np.equal(list(init_weights[i].shape),
                                   self._param_shapes[i]))
            self._weights[i].data = init_weights[i]

    ######################
    ### Finalize Setup ###
    ######################
    num_weights = MainNetInterface.shapes_to_num_weights(self.param_shapes)
    if verbose:
        print('Constructed MLP that processes dimensionality-reduced ' +
              'inputs through chunking. The network has a total of ' +
              '%d weights.' % num_weights)

    self._is_properly_setup()
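# Worked example of the chunking arithmetic above (illustration only; the
# class name `ChunkSqueezer` is assumed here): for ``n_in=250`` and
# ``inp_chunk_dim=100``, the input is split into ``ceil(250 / 100) = 3``
# chunks, the last of which is padded with ``100 - 250 % 100 = 50`` values.
# Each chunk, concatenated with its ``cemb_size``-dimensional chunk
# embedding, is compressed by the `reducer` to ``out_chunk_dim=10``
# dimensions, so the actual `network` receives ``3 * 10 = 30`` inputs:
#
#   net = ChunkSqueezer(n_in=250, n_out=5, inp_chunk_dim=100,
#                       out_chunk_dim=10, cemb_size=8)
#   y = net.forward(torch.randn(32, 250)) # -> shape [32, 5]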
def _add_fc_layers(self, in_sizes, out_sizes, no_weights, fc_layers=None):
    """Add fully-connected layers to the network.

    This method will set the weight requirements for fully-connected layers
    correctly. During the :meth:`forward` computation, those weights can be
    used in combination with :func:`torch.nn.functional.linear`.

    Note:
        This method should only be called inside the constructor of any
        class that implements this interface.

    Note:
        Bias weights are handled based on attribute :attr:`has_bias`.

    Note:
        This method assumes that the attributes :attr:`param_shapes_meta`
        and :attr:`hyper_shapes_learned_ref` already exist.

    Note:
        Generated weights will be automatically added to attributes
        :attr:`layer_bias_vectors` and :attr:`layer_weight_tensors`.

    Note:
        Standard initialization will be applied to created weights.

    Args:
        in_sizes (list): List of integers denoting the input size of each
            added fc-layer.
        out_sizes (list): List of integers denoting the output size of each
            added fc-layer.
        no_weights (bool): If ``True``, fc-layers will be generated without
            internal parameters :attr:`internal_params`.
        fc_layers (list, optional): See attribute ``cm_layers`` of method
            :meth:`_add_context_mod_layers`.
    """
    assert len(in_sizes) == len(out_sizes)
    assert fc_layers is None or len(fc_layers) == len(in_sizes)

    assert self._param_shapes_meta is not None
    assert self._hyper_shapes_learned_ref is not None

    if self._layer_weight_tensors is None:
        self._layer_weight_tensors = torch.nn.ParameterList()
    if self._layer_bias_vectors is None:
        self._layer_bias_vectors = torch.nn.ParameterList()

    for i, n_in in enumerate(in_sizes):
        n_out = out_sizes[i]

        s_w = [n_out, n_in]
        s_b = [n_out] if self.has_bias else None

        for j, s in enumerate([s_w, s_b]):
            if s is None:
                continue
            is_bias = j == 1 # `s_w` comes first, `s_b` second.

            if not no_weights:
                self._internal_params.append(torch.nn.Parameter( \
                    torch.Tensor(*s), requires_grad=True))
                if is_bias:
                    self._layer_bias_vectors.append( \
                        self._internal_params[-1])
                else:
                    self._layer_weight_tensors.append( \
                        self._internal_params[-1])
            else:
                self._hyper_shapes_learned.append(s)
                self._hyper_shapes_learned_ref.append( \
                    len(self.param_shapes))

            self._param_shapes.append(s)
            self._param_shapes_meta.append({
                'name': 'bias' if is_bias else 'weight',
                'index': -1 if no_weights else \
                    len(self._internal_params) - 1,
                'layer': -1 if fc_layers is None else fc_layers[i]
            })

        if not no_weights:
            init_params(self._layer_weight_tensors[-1],
                self._layer_bias_vectors[-1] if self.has_bias else None)
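# Sketch of how weights registered by `_add_fc_layers` are typically consumed
# in a `forward` implementation (illustration only; it assumes the attribute
# layout established above):
#
#   import torch.nn.functional as F
#
#   h = x
#   for l, W in enumerate(self._layer_weight_tensors):
#       b = self._layer_bias_vectors[l] if self.has_bias else None
#       h = F.linear(h, W, bias=b)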