Code example #1
    def __init__(self,
                 in_shape=(224, 224, 3),
                 num_classes=1000,
                 use_bias=True,
                 use_fc_bias=None,
                 num_feature_maps=(64, 64, 128, 256, 512),
                 blocks_per_group=(2, 2, 2, 2),
                 projection_shortcut=False,
                 bottleneck_blocks=False,
                 cutout_mod=False,
                 no_weights=False,
                 use_batch_norm=True,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 chw_input_format=False,
                 verbose=True,
                 **kwargs):
        super(ResNetIN, self).__init__(num_classes, verbose)

        ### Parse or set context-mod arguments ###
        rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
        if 'context_mod_apply_pixel_wise' in rem_kwargs:
            rem_kwargs.remove('context_mod_apply_pixel_wise')
        if len(rem_kwargs) > 0:
            raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
        # Since this is a conv-net, we may also want to add the following.
        if 'context_mod_apply_pixel_wise' not in kwargs.keys():
            kwargs['context_mod_apply_pixel_wise'] = False

        self._use_context_mod = kwargs['use_context_mod']
        self._context_mod_inputs = kwargs['context_mod_inputs']
        self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
        self._context_mod_no_weights = kwargs['context_mod_no_weights']
        self._context_mod_post_activation = \
            kwargs['context_mod_post_activation']
        self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
        self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
        self._context_mod_apply_pixel_wise = \
            kwargs['context_mod_apply_pixel_wise']

        ### Check or parse remaining arguments ###
        self._in_shape = in_shape
        self._projection_shortcut = projection_shortcut
        self._bottleneck_blocks = bottleneck_blocks
        self._cutout_mod = cutout_mod
        if use_fc_bias is None:
            use_fc_bias = use_bias
        # Also, check out attribute `_has_bias` below.
        self._use_bias = use_bias
        self._use_fc_bias = use_fc_bias
        self._no_weights = no_weights
        assert not use_batch_norm or (not distill_bn_stats or bn_track_stats)
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._chw_input_format = chw_input_format

        if len(blocks_per_group) != 4:
            raise ValueError('Option "blocks_per_group" must be a list of 4 ' +
                             'integers.')
        self._num_blocks = blocks_per_group
        if len(num_feature_maps) != 5:
            raise ValueError('Option "num_feature_maps" must be a list of 5 ' +
                             'integers.')
        self._filter_sizes = list(num_feature_maps)

        # The first layer of groups 3, 4 and 5 uses a strided convolution, so
        # the shortcut connections need to perform a downsampling operation. In
        # addition, whenever traversing from one group to the next, the number
        # of feature maps might change. In all these cases, the network might
        # benefit from smart shortcut connections, which means using projection
        # shortcuts, where a 1x1 conv is used for the mentioned skip connection.
        self._num_non_ident_skips = 3  # Strided convs: 2->3, 3->4 and 4->5
        fs1 = self._filter_sizes[1]
        if self._bottleneck_blocks:
            fs1 *= 4
        if self._filter_sizes[0] != fs1:
            self._num_non_ident_skips += 1  # Also handle 1->2.
        self._group_has_1x1 = [False] * 4
        if self._projection_shortcut:
            for i in range(3, 3 - self._num_non_ident_skips, -1):
                self._group_has_1x1[i] = True
        # Number of conv layers (excluding skip connections)
        self._num_main_conv_layers = 1 + int(np.sum([self._num_blocks[i] * \
            (3 if self._bottleneck_blocks else 2) for i in range(4)]))
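        # E.g., for the default configuration (basic blocks with
        # `blocks_per_group=(2, 2, 2, 2)`), this yields 1 + 4 * 2 * 2 = 17
        # conv layers, which together with the final fully-connected layer
        # corresponds to a ResNet-18.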

        # The original architecture uses a 7x7 kernel in the first conv layer
        # and 3x3 or 1x1 kernels in all remaining layers.
        self._init_kernel_size = (7, 7)
        # All 3x3 layers have padding 1 and 1x1 layers have padding 0.
        self._init_padding = 3
        self._init_stride = 2

        if self._cutout_mod:
            self._init_kernel_size = (3, 3)
            self._init_padding = 1
            self._init_stride = 1

        ### Set required class attributes ###
        # Note, we overwrite the getter of attribute `has_bias`, as it is
        # not applicable if the values of `use_bias` and `use_fc_bias` differ.
        self._has_bias = use_bias if use_bias == use_fc_bias else False
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer!
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        self._param_shapes_meta = []
        self._internal_params = None if no_weights and \
            self._context_mod_no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not self._context_mod_no_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        #################################
        ### Create context mod layers ###
        #################################
        self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
            else None

        if self._use_context_mod:
            cm_layer_inds = []
            cm_shapes = []  # Output shape of all layers.
            if self._context_mod_inputs:
                cm_shapes.append([in_shape[2], *in_shape[:2]])
                # We reserve layer zero for input context-mod. Otherwise, there
                # is no layer zero.
                cm_layer_inds.append(0)

            layer_out_shapes = self._compute_layer_out_sizes()
            cm_shapes.extend(layer_out_shapes)
            # All layer indices `l` with `l mod 3 == 0` are context-mod layers.
            cm_layer_inds.extend(range(3, 3 * len(layer_out_shapes) + 1, 3))
            if self._no_last_layer_context_mod:
                cm_shapes = cm_shapes[:-1]
                cm_layer_inds = cm_layer_inds[:-1]
            if not self._context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
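                # E.g., a layer output of shape [64, 56, 56] is then modulated
                # by gains and shifts of shape [64, 1, 1].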
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

        ###############################
        ### Create batchnorm layers ###
        ###############################
        # Layer indices `l` with `l mod 3 == 2` are reserved for batchnorm
        # layers.
        if use_batch_norm:
            bn_sizes = []
            for i, s in enumerate(self._filter_sizes):
                if i == 0:
                    bn_sizes.append(s)
                else:
                    for _ in range(self._num_blocks[i - 1]):
                        if self._bottleneck_blocks:
                            bn_sizes.extend([s, s, 4 * s])
                        else:
                            bn_sizes.extend([s, s])

            # All layer indices `l` with `l mod 3 == 2` are batchnorm layers.
            bn_layers = list(range(2, 3 * len(bn_sizes) + 1, 3))

            # We also need a batchnorm layer per skip connection that uses 1x1
            # projections.
            if self._projection_shortcut:
                bn_layer_ind_skip = 3 * (self._num_main_conv_layers + 1) + 2

                factor = 4 if self._bottleneck_blocks else 1
                for i in range(4):  # For each transition between conv groups.
                    if self._group_has_1x1[i]:
                        bn_sizes.append(self._filter_sizes[i + 1] * factor)
                        bn_layers.append(bn_layer_ind_skip)

                        bn_layer_ind_skip += 3

            self._add_batchnorm_layers(bn_sizes,
                                       no_weights,
                                       bn_layers=bn_layers,
                                       distill_bn_stats=distill_bn_stats,
                                       bn_track_stats=bn_track_stats)

        ######################################
        ### Create skip connection weights ###
        ######################################
        if self._projection_shortcut:
            layer_ind_skip = 3 * (self._num_main_conv_layers + 1) + 1

            factor = 4 if self._bottleneck_blocks else 1

            n_in = self._filter_sizes[0]
            for i in range(4):  # For each transition between conv groups.
                if not self._group_has_1x1[i]:
                    continue

                n_out = self._filter_sizes[i + 1] * factor

                skip_1x1_shape = [n_out, n_in, 1, 1]

                if not no_weights:
                    self._internal_params.append(nn.Parameter( \
                        torch.Tensor(*skip_1x1_shape), requires_grad=True))
                    self._layer_weight_tensors.append(
                        self._internal_params[-1])
                    self._layer_bias_vectors.append(None)
                    init_params(self._layer_weight_tensors[-1])
                else:
                    self._hyper_shapes_learned.append(skip_1x1_shape)
                    self._hyper_shapes_learned_ref.append( \
                        len(self.param_shapes))

                self._param_shapes.append(skip_1x1_shape)
                self._param_shapes_meta.append({
                    'name': 'weight',
                    'index': -1 if no_weights else \
                        len(self._internal_params)-1,
                    'layer': layer_ind_skip
                })

                layer_ind_skip += 3
                n_in = n_out

        ############################################################
        ### Create convolutional layers and final linear weights ###
        ############################################################
        # Convolutional layers will get IDs `l` such that `l mod 3 == 1`.
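        # Hence, layer IDs unfold as: 1 (conv), 2 (its batchnorm), 3 (its
        # context-mod), 4 (next conv), 5, 6, ... The IDs of the 1x1 projection
        # convs follow after the main network (see above).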
        layer_id = 1
        n_per_block = 3 if self._bottleneck_blocks else 2
        for i in range(6):
            if i == 0:  ### First layer.
                num = 1
                prev_fs = self._in_shape[2]
                curr_fs = self._filter_sizes[0]

                kernel_size = self._init_kernel_size
                #stride = self._init_stride
            elif i == 5:  ### Final fully-connected layer.
                num = 1
                curr_fs = num_classes

                kernel_size = None
            else:  # Group of residual blocks.
                num = self._num_blocks[i - 1] * n_per_block
                curr_fs = self._filter_sizes[i]

                kernel_size = (3, 3)  # depends on block structure!

            for n in range(num):
                if i == 5:
                    layer_shapes = [[curr_fs, prev_fs]]
                    if use_fc_bias:
                        layer_shapes.append([curr_fs])

                    prev_fs = curr_fs
                else:
                    if i > 0 and self._bottleneck_blocks:
                        # Bottleneck block: 1x1 -> 3x3 -> 1x1 convs, where the
                        # last 1x1 conv quadruples the number of feature maps.
                        if n % 3 == 0:
                            fs = curr_fs
                            ks = (1, 1)
                        elif n % 3 == 1:
                            fs = curr_fs
                            ks = kernel_size
                        else:
                            fs = 4 * curr_fs
                            ks = (1, 1)
                    else:  # First conv layer or basic blocks.
                        fs = curr_fs
                        ks = kernel_size

                    layer_shapes = [[fs, prev_fs, *ks]]
                    if use_bias:
                        layer_shapes.append([fs])

                    prev_fs = fs

                for s in layer_shapes:
                    if not no_weights:
                        self._internal_params.append(nn.Parameter( \
                            torch.Tensor(*s), requires_grad=True))
                        if len(s) == 1:
                            self._layer_bias_vectors.append( \
                                self._internal_params[-1])
                        else:
                            self._layer_weight_tensors.append( \
                                self._internal_params[-1])
                    else:
                        self._hyper_shapes_learned.append(s)
                        self._hyper_shapes_learned_ref.append( \
                            len(self.param_shapes))

                    self._param_shapes.append(s)
                    self._param_shapes_meta.append({
                        'name': 'weight' if len(s) != 1 else 'bias',
                        'index': -1 if no_weights else \
                            len(self._internal_params)-1,
                        'layer': layer_id
                    })

                layer_id += 3

                # Initialize weights.
                if not no_weights:
                    init_params(self._layer_weight_tensors[-1],
                        self._layer_bias_vectors[-1] \
                        if len(layer_shapes) == 2 else None)

        ###########################
        ### Print infos to user ###
        ###########################
        if verbose:
            if self._use_context_mod:
                cm_param_shapes = []
                for cm_layer in self.context_mod_layers:
                    cm_param_shapes.extend(cm_layer.param_shapes)
                cm_num_params = \
                    MainNetInterface.shapes_to_num_weights(cm_param_shapes)

            print('Creating a "%s" with %d weights' \
                  % (str(self), self.num_params)
                  + (' (including %d weights associated with ' % cm_num_params
                     + 'context modulation)' if self._use_context_mod else '')
                  + '.'
                  + (' The network uses batchnorm.' if use_batch_norm else ''))

        self._is_properly_setup(check_has_bias=False)
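A minimal usage sketch of this constructor (an illustration, not part of the original source; it assumes the class above is in scope as `ResNetIN`, that omitted context-mod kwargs receive defaults via `MainNetInterface._parse_context_mod_args`, and that the usual `MainNetInterface` properties such as `hyper_shapes_learned` are exposed):

# Minimal usage sketch; assumes `ResNetIN` (defined above) is in scope.

# A ResNet-18-like ImageNet network that maintains its own weights.
net = ResNetIN(in_shape=(224, 224, 3), num_classes=1000,
               blocks_per_group=(2, 2, 2, 2), bottleneck_blocks=False)

# With `no_weights=True`, parameters must be supplied externally (e.g., by a
# hypernetwork); their shapes are collected in `net.hyper_shapes_learned`.
net_ext = ResNetIN(no_weights=True)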
Code example #2
    def __init__(self, in_shape=(32, 32, 3), num_classes=10, n=4, k=10,
                 num_feature_maps=(16, 16, 32, 64), use_bias=True,
                 use_fc_bias=None, no_weights=False, use_batch_norm=True,
                 bn_track_stats=True, distill_bn_stats=False, dropout_rate=-1,
                 chw_input_format=False, verbose=True, **kwargs):
        super(WRN, self).__init__(num_classes, verbose)

        ### Parse or set context-mod arguments ###
        rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
        if len(rem_kwargs) > 0:
            raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
        # Since this is a conv-net, we may also want to add the following.
        if 'context_mod_apply_pixel_wise' not in kwargs.keys():
            kwargs['context_mod_apply_pixel_wise'] = False

        self._use_context_mod = kwargs['use_context_mod']
        self._context_mod_inputs = kwargs['context_mod_inputs']
        self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
        self._context_mod_no_weights = kwargs['context_mod_no_weights']
        self._context_mod_post_activation = \
            kwargs['context_mod_post_activation']
        self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
        self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
        self._context_mod_apply_pixel_wise = \
            kwargs['context_mod_apply_pixel_wise']

        ### Check or parse remaining arguments ###
        self._in_shape = in_shape
        self._n = n
        self._k = k
        if use_fc_bias is None:
            use_fc_bias = use_bias
        # Also, check out attribute `_has_bias` below.
        self._use_bias = use_bias
        self._use_fc_bias = use_fc_bias
        self._no_weights = no_weights
        assert not use_batch_norm or (not distill_bn_stats or bn_track_stats)
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._dropout_rate = dropout_rate
        self._chw_input_format = chw_input_format

        # The original authors found that the best configuration uses this
        # kernel in all convolutional layers.
        self._kernel_size = (3, 3)
        if len(num_feature_maps) != 4:
            raise ValueError('Option "num_feature_maps" must be a list of 4 ' +
                             'integers.')
        self._filter_sizes = list(num_feature_maps)
        if k != 1:
            for i in range(1, 4):
                self._filter_sizes[i] = k * num_feature_maps[i]
        # Strides used in the first layer of each convolutional group.
        self._strides = (1, 1, 2, 2)
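        # E.g., with the defaults `n=4` and `k=10`, the widened filter sizes
        # become (16, 160, 320, 640) and each of the three conv groups holds
        # 2*n = 8 conv layers, i.e., a WRN-28-10-style configuration.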

        ### Set required class attributes ###
        # Note, we overwrite the getter of attribute `has_bias`, as it is
        # not applicable if the values of `use_bias` and `use_fc_bias` differ.
        self._has_bias = use_bias if use_bias == use_fc_bias else False
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer!
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        self._param_shapes_meta = []
        self._internal_params = None if no_weights and \
            self._context_mod_no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not self._context_mod_no_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if dropout_rate != -1:
            assert dropout_rate >= 0. and dropout_rate <= 1.
            self._dropout = nn.Dropout(p=dropout_rate)

        #################################
        ### Create context mod layers ###
        #################################
        self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
            else None

        if self._use_context_mod:
            cm_layer_inds = []
            cm_shapes = [] # Output shape of all layers.
            if self._context_mod_inputs:
                cm_shapes.append([in_shape[2], *in_shape[:2]])
                # We reserve layer zero for input context-mod. Otherwise, there
                # is no layer zero.
                cm_layer_inds.append(0)

            layer_out_shapes = self._compute_layer_out_sizes()
            cm_shapes.extend(layer_out_shapes)
            # All layer indices `l` with `l mod 3 == 0` are context-mod layers.
            cm_layer_inds.extend(range(3, 3*len(layer_out_shapes)+1, 3))
            if self._no_last_layer_context_mod:
                cm_shapes = cm_shapes[:-1]
                cm_layer_inds = cm_layer_inds[:-1]
            if not self._context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

        ###############################
        ### Create batchnorm layers ###
        ###############################
        # Layer indices `l` with `l mod 3 == 2` are reserved for batchnorm
        # layers.
        if use_batch_norm:
            bn_sizes = []
            for i, s in enumerate(self._filter_sizes):
                if i == 0:
                    bn_sizes.append(s)
                else:
                    bn_sizes.extend([s] * (2*n))

            # All layer indices `l` with `l mod 3 == 2` are batchnorm layers.
            self._add_batchnorm_layers(bn_sizes, no_weights,
                bn_layers=list(range(2, 3*len(bn_sizes)+1, 3)),
                distill_bn_stats=distill_bn_stats,
                bn_track_stats=bn_track_stats)

        ######################################
        ### Create skip connection weights ###
        ######################################
        # We use 1x1 convolutional layers for residual blocks in case the
        # number of input and output feature maps disagrees. We also use 1x1
        # convolutions whenever a stride greater than 1 is applied. This is not
        # necessary in my opinion (as it adds extra weights that do not affect
        # the downsampling itself), but it is commonly done, for instance, in
        # the original PyTorch implementation.
        # Note, at most three 1x1 layers are added to the network.
        # Note, we use 1x1 conv layers without biases.
        skip_1x1_shapes = []
        self._group_has_1x1 = [False] * 3
        for i in range(1, 4):
            if self._filter_sizes[i-1] != self._filter_sizes[i] or \
                    self._strides[i] != 1:
                skip_1x1_shapes.append([self._filter_sizes[i],
                                        self._filter_sizes[i-1], 1, 1])
                self._group_has_1x1[i-1] = True
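        # E.g., with the default configuration, all three groups receive a 1x1
        # projection: in group 1 the number of feature maps changes (16 -> 160
        # for k=10), and in groups 2 and 3 both the feature-map count and the
        # stride change.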

        for s in skip_1x1_shapes:
            if not no_weights:
                self._internal_params.append(nn.Parameter( \
                    torch.Tensor(*s), requires_grad=True))
                self._layer_weight_tensors.append(self._internal_params[-1])
                init_params(self._layer_weight_tensors[-1])
            else:
                self._hyper_shapes_learned.append(s)
                self._hyper_shapes_learned_ref.append(len(self.param_shapes))

            self._param_shapes.append(s)
            self._param_shapes_meta.append({
                'name': 'weight',
                'index': -1 if no_weights else \
                    len(self._internal_params)-1,
                'layer': -1
            })

        ############################################################
        ### Create convolutional layers and final linear weights ###
        ############################################################
        # Convolutional layers will get IDs `l` such that `l mod 3 == 1`.
        layer_id = 1
        for i in range(5):
            if i == 0: ### First layer.
                num = 1
                prev_fs = self._in_shape[2]
                curr_fs = self._filter_sizes[0]
            elif i == 4: ### Final fully-connected layer.
                num = 1
                curr_fs = num_classes
            else: # Group of residual blocks.
                num = 2 * n
                curr_fs = self._filter_sizes[i]

            for _ in range(num):
                if i == 4:
                    layer_shapes = [[curr_fs, prev_fs]]
                    if use_fc_bias:
                        layer_shapes.append([curr_fs])
                else:
                    layer_shapes = [[curr_fs, prev_fs, *self._kernel_size]]
                    if use_bias:
                        layer_shapes.append([curr_fs])

                for s in layer_shapes:
                    if not no_weights:
                        self._internal_params.append(nn.Parameter( \
                            torch.Tensor(*s), requires_grad=True))
                        if len(s) == 1:
                            self._layer_bias_vectors.append( \
                                self._internal_params[-1])
                        else:
                            self._layer_weight_tensors.append( \
                                self._internal_params[-1])
                    else:
                        self._hyper_shapes_learned.append(s)
                        self._hyper_shapes_learned_ref.append( \
                            len(self.param_shapes))

                    self._param_shapes.append(s)
                    self._param_shapes_meta.append({
                        'name': 'weight' if len(s) != 1 else 'bias',
                        'index': -1 if no_weights else \
                            len(self._internal_params)-1,
                        'layer': layer_id
                    })

                prev_fs = curr_fs
                layer_id += 3

                # Initialize weights.
                if not no_weights:
                    init_params(self._layer_weight_tensors[-1],
                        self._layer_bias_vectors[-1] \
                        if len(layer_shapes) == 2 else None)

        ###########################
        ### Print infos to user ###
        ###########################
        if verbose:
            if self._use_context_mod:
                cm_param_shapes = []
                for cm_layer in self.context_mod_layers:
                    cm_param_shapes.extend(cm_layer.param_shapes)
                cm_num_params = \
                    MainNetInterface.shapes_to_num_weights(cm_param_shapes)

            print('Creating a WideResNet "%s" with %d weights' \
                  % (str(self), self.num_params)
                  + (' (including %d weights associated with ' % cm_num_params
                     + 'context modulation)' if self._use_context_mod else '')
                  + '.'
                  + (' The network uses batchnorm.' if use_batch_norm else '')
                  + (' The network uses dropout.' if dropout_rate != -1 \
                     else ''))

        self._is_properly_setup(check_has_bias=False)
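A minimal usage sketch (illustrative, not part of the original source; it assumes the class above is in scope as `WRN`; the dropout rate below is an arbitrary choice):

# Minimal usage sketch; assumes `WRN` (defined above) is in scope.

# A WRN-28-10-style CIFAR network; `n=4` and `k=10` are the defaults.
net = WRN(in_shape=(32, 32, 3), num_classes=10, n=4, k=10, dropout_rate=0.3)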
Code example #3
    def __init__(
            self,
            in_shape=(28, 28, 1),
            num_classes=10,
            verbose=True,
            arch='mnist_large',
            no_weights=False,
            init_weights=None,
            dropout_rate=-1,  #0.5
            **kwargs):
        super(LeNet, self).__init__(num_classes, verbose)

        self._in_shape = in_shape
        assert arch in LeNet._ARCHITECTURES.keys()
        # Note, we copy the architecture spec, since it might be modified
        # below and we must not mutate the shared class attribute.
        self._chosen_arch = [list(s) for s in LeNet._ARCHITECTURES[arch]]
        if num_classes != 10:
            self._chosen_arch[-2][0] = num_classes
            self._chosen_arch[-1][0] = num_classes

        # Sanity check, given current implementation.
        if arch.startswith('mnist'):
            if not in_shape[0] == in_shape[1] == 28:
                raise ValueError('MNIST LeNet architectures expect input ' +
                                 'images of size 28x28.')
        else:
            if not in_shape[0] == in_shape[1] == 32:
                raise ValueError('CIFAR LeNet architectures expect input ' +
                                 'images of size 32x32.')

        ### Parse or set context-mod arguments ###
        rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
        if len(rem_kwargs) > 0:
            raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
        # Since this is a conv-net, we may also want to add the following.
        if 'context_mod_apply_pixel_wise' not in kwargs.keys():
            kwargs['context_mod_apply_pixel_wise'] = False

        self._use_context_mod = kwargs['use_context_mod']
        self._context_mod_inputs = kwargs['context_mod_inputs']
        self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
        self._context_mod_no_weights = kwargs['context_mod_no_weights']
        self._context_mod_post_activation = \
            kwargs['context_mod_post_activation']
        self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
        self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
        self._context_mod_apply_pixel_wise = \
            kwargs['context_mod_apply_pixel_wise']

        ### Setup class attributes ###
        assert(init_weights is None or \
               (not no_weights or not self._context_mod_no_weights))
        self._no_weights = no_weights
        self._dropout_rate = dropout_rate

        self._has_bias = True
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer!
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        self._param_shapes_meta = []
        self._internal_params = None if no_weights and \
            self._context_mod_no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not self._context_mod_no_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if dropout_rate != -1:
            assert (dropout_rate >= 0. and dropout_rate <= 1.)
            # FIXME `nn.Dropout2d` zeroes out whole feature maps. Is that really
            # desired here?
            self._drop_conv1 = nn.Dropout2d(p=dropout_rate)
            self._drop_conv2 = nn.Dropout2d(p=dropout_rate)
            self._drop_fc1 = nn.Dropout(p=dropout_rate)

        ### Define and initialize context mod layers/weights ###
        self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
            else None

        if self._use_context_mod:
            cm_layer_inds = []
            cm_shapes = []  # Output shape of all context-mod layers.
            if self._context_mod_inputs:
                cm_shapes.append([in_shape[2], *in_shape[:2]])
                # We reserve layer zero for input context-mod. Otherwise, there
                # is no layer zero.
                cm_layer_inds.append(0)

            layer_out_shapes = self._compute_layer_out_sizes()
            # Context-modulation is applied after the pooling layers.
            # So we delete the shapes of the conv-layer outputs and keep the
            # ones of the pooling layer outputs.
            del layer_out_shapes[2]
            del layer_out_shapes[0]
            cm_shapes.extend(layer_out_shapes)
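            # Context-mod layers get the even layer IDs 2, 4, ... (the conv-
            # and fc-layer weights below use the odd IDs 1, 3, 5, ...).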
            cm_layer_inds.extend(range(2, 2 * len(layer_out_shapes) + 1, 2))
            if self._no_last_layer_context_mod:
                cm_shapes = cm_shapes[:-1]
                cm_layer_inds = cm_layer_inds[:-1]

            if not self._context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

        ### Define and add conv- and fc-layer weights.
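        # Note, `self._chosen_arch` lists a weight shape followed by the
        # corresponding bias shape for each layer, so `2 * (i // 2) + 1` below
        # assigns both entries of a pair the same odd layer ID.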
        for i, s in enumerate(self._chosen_arch):
            if not no_weights:
                self._internal_params.append(
                    nn.Parameter(torch.Tensor(*s), requires_grad=True))
                if len(s) == 1:
                    self._layer_bias_vectors.append(self._internal_params[-1])
                else:
                    self._layer_weight_tensors.append(
                        self._internal_params[-1])
            else:
                self._hyper_shapes_learned.append(s)
                self._hyper_shapes_learned_ref.append(len(self.param_shapes))

            self._param_shapes.append(s)
            self._param_shapes_meta.append({
                'name': 'weight' if len(s) != 1 else 'bias',
                'index': -1 if no_weights else len(self._internal_params) - 1,
                'layer': 2 * (i // 2) + 1
            })

        ### Initialize weights.
        if init_weights is not None:
            assert len(init_weights) == len(self.weights)
            for i in range(len(init_weights)):
                assert np.all(
                    np.equal(list(init_weights[i].shape),
                             self.weights[i].shape))
                self.weights[i].data = init_weights[i]
        else:
            for i in range(len(self._layer_weight_tensors)):
                init_params(self._layer_weight_tensors[i],
                            self._layer_bias_vectors[i])

        ### Print user info.
        if verbose:
            if self._use_context_mod:
                cm_param_shapes = []
                for cm_layer in self.context_mod_layers:
                    cm_param_shapes.extend(cm_layer.param_shapes)
                cm_num_weights = \
                    MainNetInterface.shapes_to_num_weights(cm_param_shapes)

            print('Creating a LeNet with %d weights' % self.num_params
                  + (' (including %d weights associated with ' % cm_num_weights
                     + 'context modulation)' if self._use_context_mod else '')
                  + '.'
                  + (' The network uses dropout.' if dropout_rate != -1 \
                     else ''))

        self._is_properly_setup()
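A minimal usage sketch (illustrative, not part of the original source; it assumes the class above is in scope as `LeNet` and that `'mnist_large'` is a key of `LeNet._ARCHITECTURES`, as the default argument suggests):

# Minimal usage sketch; assumes `LeNet` (defined above) is in scope.

# The default MNIST variant expects 28x28 inputs; CIFAR variants expect
# 32x32 inputs (cf. the sanity check in the constructor).
net = LeNet(in_shape=(28, 28, 1), num_classes=10, arch='mnist_large',
            dropout_rate=0.5)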