Exemplo n.º 1
0
    def _gen_layers(self, layers, te_dim, use_bias, no_weights, init_weights,
                    ce_dim, noise_dim):
        """Compute all parameter shapes and, optionally, create the parameters.

        Note, this method should only be called by the constructor.

        The method always sets the attributes "_hidden_dims", "_out_dims" and
        "_last_hidden_size". If "no_weights" is True, "_theta" is set to None
        and nothing else happens. Otherwise "_theta" becomes a parameter list
        holding one tensor per computed shape (hidden layers first, then
        output layers), which is either filled from "init_weights" or
        initialized via "init_params".

        Args:
            layers: Sizes of the hidden layers, in order.
            te_dim: Base input dimensionality of the first hidden layer.
            use_bias: Whether a bias vector follows each weight matrix.
            no_weights: If True, only shapes are computed.
            init_weights: Optional sequence with one tensor per parameter,
                used instead of the default initialization.
            ce_dim: Extra input dimensionality; ignored if None.
            noise_dim: Extra input dimensionality; ignored if -1.
        """
        ### Shapes of all parameters.
        # The first hidden layer sees the concatenation of all used inputs.
        fan_in = te_dim
        if ce_dim is not None:
            fan_in += ce_dim
        if noise_dim != -1:
            fan_in += noise_dim

        # Hidden layers.
        hidden_shapes = []
        for width in layers:
            hidden_shapes.append([width, fan_in])
            if use_bias:
                hidden_shapes.append([width])
            fan_in = width
        self._hidden_dims = hidden_shapes
        self._last_hidden_size = fan_in

        # Output layers: one linear head per target shape, each producing
        # as many values as the target has entries.
        out_shapes = []
        for target in self.target_shapes:
            num_out = np.prod(target)
            out_shapes.append([num_out, self._last_hidden_size])
            if use_bias:
                out_shapes.append([num_out])
        self._out_dims = out_shapes

        if no_weights:
            self._theta = None
            return

        ### Parameter tensors.
        # With "use_bias", the list alternates weight matrix, bias vector,
        # weight matrix, ... (starting with a weight matrix at index 0).
        # Without it, the list holds exactly one weight matrix per layer.
        self._theta = nn.ParameterList()
        for shape in self._hidden_dims + self._out_dims:
            self._theta.append(nn.Parameter(torch.Tensor(*shape),
                                            requires_grad=True))

        if init_weights is None:
            stride = 2 if use_bias else 1
            for idx in range(0, len(self._theta), stride):
                if use_bias:
                    init_params(self._theta[idx], self._theta[idx + 1])
                else:
                    init_params(self._theta[idx])
            return

        # User-provided values must match the computed shapes one-to-one.
        assert (len(init_weights) == len(self._theta))
        for idx, value in enumerate(init_weights):
            assert (np.all(np.equal(list(value.shape),
                                    list(self._theta[idx].shape))))
            self._theta[idx].data = value
    def __init__(self, in_channels, out_channels, in_height, in_width,
                 kernel_size, stride=1, padding=0, bias=True, no_weights=False):
        """Store the layer geometry and, optionally, create its parameters.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            in_height: Height of the input feature maps.
            in_width: Width of the input feature maps.
            kernel_size: An int (used for both dimensions) or a pair
                indexed as (height, width).
            stride: An int or a pair, handled like "kernel_size".
            padding: An int or a pair, handled like "kernel_size".
            bias: Whether the layer has a bias parameter.
            no_weights: If True, only the parameter shapes are recorded in
                "_param_shapes" and "_weights" stays None.
        """
        super(LocalConv2dLayer, self).__init__()

        # A scalar is shorthand for "same value along height and width".
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 2
        if isinstance(stride, int):
            stride = (stride,) * 2
        if isinstance(padding, int):
            padding = (padding,) * 2

        self._in_channels = in_channels
        self._out_channels = out_channels
        self._in_height = in_height
        self._in_width = in_width
        self._kernel_size = kernel_size
        self._stride = stride
        self._padding = padding
        self._has_bias = bias
        self._no_weights = no_weights

        # Spatial output size, using the usual convolution arithmetic.
        padded_height = in_height - kernel_size[0] + 2 * padding[0]
        padded_width = in_width - kernel_size[1] + 2 * padding[1]
        self._out_height = padded_height // stride[0] + 1
        self._out_width = padded_width // stride[1] + 1

        # Number of input values covered by a single receptive field.
        self._rf_size = in_channels * kernel_size[0] * kernel_size[1]
        # Number of pixels per output feature map.
        self._num_pix = self._out_height * self._out_width

        # The filter tensor has a separate slice along its last axis for
        # every output pixel; the bias has one value per channel and pixel.
        filter_shape = [out_channels, self._rf_size, self._num_pix]
        bias_shape = [out_channels, self._num_pix]

        self._weights = None
        self._param_shapes = [filter_shape]
        if bias:
            self._param_shapes.append(bias_shape)

        if no_weights:
            return

        self._weights = nn.ParameterList()

        self.register_parameter('filters', nn.Parameter(
            torch.Tensor(*self._param_shapes[0]), requires_grad=True))
        self._weights.append(self.filters)

        if not bias:
            self.register_parameter('bias', None)
            init_params(self.filters)
            return

        self.register_parameter('bias', nn.Parameter(
            torch.Tensor(*self._param_shapes[1]), requires_grad=True))
        self._weights.append(self.bias)
        init_params(self.filters, self.bias)
Exemplo n.º 3
0
    def __init__(self, weight_shapes, activation_fn=torch.nn.ReLU(),
                 use_bias=True, no_weights=False, init_weights=None,
                 dropout_rate=-1, out_fn=None, verbose=True,
                 use_spectral_norm=False, use_batch_norm=False):
        """Initialize the network.

        Args:
            weight_shapes: A list of list of integers, denoting the shape of
                each parameter tensor in this network. Note, this parameter only
                has an effect on the construction of this network, if
                "no_weights" is False. Otherwise, it is just used to check the
                shapes of the input to the network in the forward method.
            activation_fn: The nonlinearity used in hidden layers. If None, no
                nonlinearity will be applied.
            use_bias: Whether layers may have bias terms.
            no_weights: If set to True, no trainable parameters will be
                constructed, i.e., weights are assumed to be produced ad-hoc
                by a hypernetwork and passed to the forward function.
            init_weights (optional): This option is for convenience reasons.
                The option expects a list of parameter values that are used to
                initialize the network weights. As such, it provides a
                convenient way of initializing a network with a weight draw
                produced by the hypernetwork. May not be combined with
                "no_weights".
            dropout_rate: If -1, no dropout will be applied. Otherwise a number
                between 0 and 1 is expected, denoting the dropout rate of hidden
                layers.
            out_fn (optional): If provided, this function will be applied to the
                output neurons of the network. Note, this changes the output
                of the forward method.
            verbose: Whether to print the number of weights in the network.
            use_spectral_norm: Use spectral normalization for training.
                Not implemented; must be False.
            use_batch_norm: Whether batch normalization should be used.
                Not implemented; must be False.

        Raises:
            NotImplementedError: If "use_spectral_norm" or "use_batch_norm"
                is set.
        """
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        # stacklevel=2 attributes the warning to the code that instantiates
        # this class rather than to this constructor.
        warn('Please use class "mnets.mlp.MLP" instead.', DeprecationWarning,
             stacklevel=2)

        if use_spectral_norm:
            raise NotImplementedError('Spectral normalization not yet ' +
                                      'implemented for this network.')
        if use_batch_norm:
            raise NotImplementedError('Batch normalization not yet ' +
                                      'implemented for this network.')

        assert(len(weight_shapes) > 0)
        self._all_shapes = weight_shapes
        self._has_bias = use_bias
        self._a_fun = activation_fn
        # Initial values can only be used if this network owns its weights.
        assert(init_weights is None or not no_weights)
        self._no_weights = no_weights
        self._dropout_rate = dropout_rate
        self._out_fn = out_fn

        self._has_fc_out = True

        # Spectral normalization was rejected above, so the wrapper applied
        # to layers is always the identity.
        self._spec_norm = lambda x: x

        if verbose:
            # NOTE(review): the weight count is computed via the interface
            # class initialized above; confirm it matches the helper that
            # was previously referenced here.
            print('Creating an MLP with %d weights' \
                      % (MainNetInterface.shapes_to_num_weights(
                          self._all_shapes))
                      + (', that uses dropout.' if dropout_rate != -1 else '.'))

        if dropout_rate != -1:
            assert(dropout_rate >= 0. and dropout_rate <= 1.)
            self._dropout = nn.Dropout(p=dropout_rate)

        self._weights = None
        if no_weights:
            # All weights have to be provided externally.
            self._hyper_shapes = self._all_shapes
            self._is_properly_setup()
            return

        ### Define and initialize network weights.
        # If "use_bias" is True, the list alternates weight tensor and bias
        # vector, starting with a weight tensor at index 0.
        self._weights = nn.ParameterList()

        for dims in self._all_shapes:
            self._weights.append(nn.Parameter(torch.Tensor(*dims),
                                              requires_grad=True))

        if init_weights is not None:
            assert(len(init_weights) == len(self._all_shapes))
            for i, value in enumerate(init_weights):
                assert(np.all(np.equal(list(value.shape),
                                       list(self._weights[i].shape))))
                self._weights[i].data = value
        else:
            for i in range(0, len(self._weights), 2 if use_bias else 1):
                if use_bias:
                    init_params(self._weights[i], self._weights[i + 1])
                else:
                    init_params(self._weights[i])

        self._is_properly_setup()
Exemplo n.º 4
0
    def __init__(self,
                 in_shape=[32, 32, 3],
                 num_classes=10,
                 verbose=True,
                 n=5,
                 no_weights=False,
                 init_weights=None,
                 use_batch_norm=True,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 use_context_mod=False,
                 context_mod_inputs=False,
                 no_last_layer_context_mod=False,
                 context_mod_no_weights=False,
                 context_mod_post_activation=False,
                 context_mod_gain_offset=False,
                 context_mod_apply_pixel_wise=False):
        """Set up a ResNet with ``6 * n + 2`` layers.

        Args:
            in_shape: Input shape given as two spatial dimensions followed
                by the number of channels (``in_shape[2]`` is used as the
                number of input channels of the first layer).
            num_classes: Number of output neurons of the final dense layer.
            verbose: Whether to print the layer and weight count.
            n: Each of the three blocks following the first layer consists
                of ``2 * n`` layers.
            no_weights: If True, no layer weights are created; their shapes
                are recorded in ``_hyper_shapes_learned`` instead and
                batchnorm layers are created without affine parameters.
            init_weights (optional): List of tensors used to initialize the
                network. Entries are consumed in this order: context-mod
                weights (if such weights are created), batchnorm weights
                (if such weights are created), then weight tensor and bias
                vector of each layer, interleaved.
            use_batch_norm: Whether batchnorm layers are created.
            bn_track_stats: Passed as ``track_running_stats`` to every
                batchnorm layer.
            distill_bn_stats: If True (and batchnorm is used), the shapes of
                the batchnorm statistics are collected in
                ``_hyper_shapes_distilled``. Requires ``bn_track_stats``.
            use_context_mod: Whether context-modulation layers are created.
            context_mod_inputs: Whether an additional context-mod layer is
                created for the network input.
            no_last_layer_context_mod: If True, no context-mod layer is
                created for the last entry of the layer output shapes.
            context_mod_no_weights: If True, context-mod layers are created
                without own weights and their shapes are added to
                ``_hyper_shapes_learned``.
            context_mod_post_activation: Stored for later use; not
                interpreted by the constructor.
            context_mod_gain_offset: Passed as ``apply_gain_offset`` to
                every context-mod layer.
            context_mod_apply_pixel_wise: If False, context-mod shapes with
                three dimensions are reduced to one entry per feature map.
        """
        super(ResNet, self).__init__(num_classes, verbose)

        self._in_shape = in_shape
        self._n = n

        # Initial values only make sense if at least some weights are
        # maintained by this network.
        assert (init_weights is None or \
                (not no_weights or not context_mod_no_weights))
        self._no_weights = no_weights

        assert (not use_batch_norm or (not distill_bn_stats or bn_track_stats))

        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm

        self._use_context_mod = use_context_mod
        self._context_mod_inputs = context_mod_inputs
        self._no_last_layer_context_mod = no_last_layer_context_mod
        self._context_mod_no_weights = context_mod_no_weights
        self._context_mod_post_activation = context_mod_post_activation
        self._context_mod_gain_offset = context_mod_gain_offset
        self._context_mod_apply_pixel_wise = context_mod_apply_pixel_wise

        self._kernel_size = [3, 3]
        self._filter_sizes = [16, 16, 32, 64]

        self._has_bias = True
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        # We don't use any output non-linearity.
        self._has_linear_out = True

        self._param_shapes = []
        self._weights = None if no_weights and context_mod_no_weights \
            else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not context_mod_no_weights else []

        #################################################
        ### Define and initialize context mod weights ###
        #################################################
        self._context_mod_layers = nn.ModuleList() if use_context_mod else None
        self._context_mod_shapes = [] if use_context_mod else None

        if use_context_mod:
            # Number of entries of `init_weights` consumed by context-mod
            # layers.
            cm_ind = 0
            cm_shapes = []  # Output shape of all layers.
            if context_mod_inputs:
                cm_shapes.append([in_shape[2], *in_shape[:2]])
            layer_out_shapes = self._compute_layer_out_sizes()
            if no_last_layer_context_mod:
                cm_shapes.extend(layer_out_shapes[:-1])
            else:
                cm_shapes.extend(layer_out_shapes)

            if not context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            for s in cm_shapes:
                cmod_layer = ContextModLayer(
                    s,
                    no_weights=context_mod_no_weights,
                    apply_gain_offset=context_mod_gain_offset)
                self._context_mod_layers.append(cmod_layer)

                self.param_shapes.extend(cmod_layer.param_shapes)
                self._context_mod_shapes.extend(cmod_layer.param_shapes)
                if context_mod_no_weights:
                    self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
                else:
                    self._weights.extend(cmod_layer.weights)

                # FIXME ugly code. Move initialization somewhere else.
                if not context_mod_no_weights and init_weights is not None:
                    assert (len(cmod_layer.weights) == 2)
                    for ii in range(2):
                        # Compare against the layer's own weights (the
                        # running index `cm_ind` is just an integer).
                        assert (np.all(np.equal( \
                            list(init_weights[cm_ind].shape),
                            list(cmod_layer.weights[ii].shape))))
                        cmod_layer.weights[ii].data = init_weights[cm_ind]
                        cm_ind += 1

            if init_weights is not None:
                # Drop the entries that have just been consumed.
                init_weights = init_weights[cm_ind:]

        ###########################
        ### Print infos to user ###
        ###########################
        # Compute the total number of weights in this network and display
        # them to the user.
        # Note, this complicated calculation is not necessary as we can simply
        # count the number of weights afterwards. But it's an additional sanity
        # check for us.
        fs = self._filter_sizes
        num_weights = np.prod(self._kernel_size) * \
                      (in_shape[2] * fs[0] + np.sum([fs[i] * fs[i + 1] + \
                                                     (2 * n - 1) * fs[i + 1] ** 2 for i in range(3)])) + \
                      (fs[0] + 2 * n * np.sum([fs[i] for i in range(1, 4)])) + \
                      (fs[-1] * num_classes + num_classes)

        cm_num_weights = MainNetInterface.shapes_to_num_weights( \
            self._context_mod_shapes) if use_context_mod else 0
        num_weights += cm_num_weights

        if use_batch_norm:
            # The gamma and beta parameters of a batch norm layer are
            # learned as well.
            num_weights += 2 * (fs[0] + \
                                2 * n * np.sum([fs[i] for i in range(1, 4)]))

        if verbose:
            print('A ResNet with %d layers and %d weights is created' \
                  % (6 * n + 2, num_weights)
                  + (' (including %d context-mod weights).' % cm_num_weights \
                         if cm_num_weights > 0 else '.'))

        ################################################
        ### Define and initialize batch norm weights ###
        ################################################
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        if use_batch_norm:
            if distill_bn_stats:
                self._hyper_shapes_distilled = []

            # Number of entries of `init_weights` consumed by batchnorm
            # layers.
            bn_ind = 0
            for i, s in enumerate(self._filter_sizes):
                # One batchnorm layer after the first layer, 2n per block.
                if i == 0:
                    num = 1
                else:
                    num = 2 * n

                for _ in range(num):
                    bn_layer = BatchNormLayer(
                        s,
                        affine=not no_weights,
                        track_running_stats=bn_track_stats)
                    self._batchnorm_layers.append(bn_layer)

                    if distill_bn_stats:
                        self._hyper_shapes_distilled.extend( \
                            [list(p.shape) for p in bn_layer.get_stats(0)])

                    if not no_weights and init_weights is not None:
                        assert (len(bn_layer.weights) == 2)
                        for ii in range(2):
                            assert (np.all(np.equal( \
                                list(init_weights[bn_ind].shape),
                                list(bn_layer.weights[ii].shape))))
                            bn_layer.weights[ii].data = init_weights[bn_ind]
                            bn_ind += 1

            if init_weights is not None:
                init_weights = init_weights[bn_ind:]

        # Note, method `_compute_hyper_shapes` doesn't take context-mod into
        # consideration.
        self._param_shapes.extend(self._compute_hyper_shapes(no_weights=True))
        assert (num_weights == \
                MainNetInterface.shapes_to_num_weights(self._param_shapes))

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if no_weights:
            if self._hyper_shapes_learned is None:
                self._hyper_shapes_learned = self._compute_hyper_shapes()
            else:
                # Context-mod weights are already included.
                self._hyper_shapes_learned.extend(self._compute_hyper_shapes())

            self._is_properly_setup()
            return

        if use_batch_norm:
            for bn_layer in self._batchnorm_layers:
                self._weights.extend(bn_layer.weights)

        ############################################
        ### Define and initialize layer weights ###
        ###########################################
        ### Does not include context-mod or batchnorm weights.
        # First layer.
        self._layer_weight_tensors.append(
            nn.Parameter(torch.Tensor(self._filter_sizes[0], self._in_shape[2],
                                      *self._kernel_size),
                         requires_grad=True))
        self._layer_bias_vectors.append(
            nn.Parameter(torch.Tensor(self._filter_sizes[0]),
                         requires_grad=True))

        # Each block consists of 2n layers.
        for i in range(1, len(self._filter_sizes)):
            in_filters = self._filter_sizes[i - 1]
            out_filters = self._filter_sizes[i]

            for _ in range(2 * n):
                self._layer_weight_tensors.append(
                    nn.Parameter(torch.Tensor(out_filters, in_filters,
                                              *self._kernel_size),
                                 requires_grad=True))
                self._layer_bias_vectors.append(
                    nn.Parameter(torch.Tensor(out_filters),
                                 requires_grad=True))
                # Note, that the first layer in this block has potentially a
                # different number of input filters.
                in_filters = out_filters

        # After the average pooling, there is one more dense layer.
        self._layer_weight_tensors.append(
            nn.Parameter(torch.Tensor(num_classes, self._filter_sizes[-1]),
                         requires_grad=True))
        self._layer_bias_vectors.append(
            nn.Parameter(torch.Tensor(num_classes), requires_grad=True))

        # Index of the first layer weight within `self._weights`. Everything
        # stored so far (context-mod and/or batchnorm weights) precedes the
        # layer weights.
        layer_offset = len(self._weights)

        # We add the weights interleaved, such that there are always consecutive
        # weight tensor and bias vector per layer. This fulfils the requirements
        # of attribute `mask_fc_out`.
        for i in range(len(self._layer_weight_tensors)):
            self._weights.append(self._layer_weight_tensors[i])
            self._weights.append(self._layer_bias_vectors[i])

        ### Initialize weights.
        if init_weights is not None:
            num_layers = 6 * n + 2
            # Only the layer weights should be left at this point.
            assert (len(init_weights) == 2 * num_layers)
            assert (len(self._weights) == layer_offset + 2 * num_layers)
            for i in range(len(init_weights)):
                j = layer_offset + i
                assert (np.all(
                    np.equal(list(init_weights[i].shape),
                             list(self._weights[j].shape))))
                self._weights[j].data = init_weights[i]
        else:
            for w, b in zip(self._layer_weight_tensors,
                            self._layer_bias_vectors):
                init_params(w, b)

        self._is_properly_setup()
Exemplo n.º 5
0
Arquivo: mlp.py Projeto: ZixuanKe/CAT
    def __init__(self,
                 n_in=1,
                 n_out=1,
                 hidden_layers=[2000, 2000],
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 no_weights=False,
                 init_weights=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 use_context_mod=False,
                 context_mod_inputs=False,
                 no_last_layer_context_mod=False,
                 context_mod_no_weights=False,
                 context_mod_post_activation=False,
                 context_mod_gain_offset=False,
                 out_fn=None,
                 verbose=True):
        """Set up a fully-connected network.

        Args:
            n_in: Number of inputs.
            n_out: Number of outputs.
            hidden_layers: Sizes of the hidden layers, in order.
            activation_fn: Stored as the network's nonlinearity.
            use_bias: Whether linear layers have bias vectors.
            no_weights: If True, no linear weights are created; their
                shapes are added to ``_hyper_shapes_learned`` instead and
                batchnorm layers are created without affine parameters.
            init_weights (optional): List of tensors used to initialize the
                network. Entries are consumed in this order: context-mod
                weights (if such weights are created), batchnorm weights
                (if such weights are created), then the linear layer
                parameters in the order of ``MLP.weight_shapes``.
            dropout_rate: If -1, no dropout layer is created. Otherwise a
                number between 0 and 1 is expected.
            use_spectral_norm: Not implemented; must be False.
            use_batch_norm: Whether one batchnorm layer per hidden layer is
                created.
            bn_track_stats: Passed as ``track_running_stats`` to every
                batchnorm layer.
            distill_bn_stats: If True (and batchnorm is used), the shapes of
                the batchnorm statistics are collected in
                ``_hyper_shapes_distilled``.
            use_context_mod: Whether context-modulation layers are created.
            context_mod_inputs: Whether an additional context-mod layer is
                created for the network input.
            no_last_layer_context_mod: If True, no context-mod layer is
                created for the output layer.
            context_mod_no_weights: If True, context-mod layers are created
                without own weights and their shapes are added to
                ``_hyper_shapes_learned``.
            context_mod_post_activation: Stored for later use; not
                interpreted by the constructor.
            context_mod_gain_offset: Passed as ``apply_gain_offset`` to
                every context-mod layer.
            out_fn (optional): Stored as output function. If given, the
                network is flagged as not having a linear output.
            verbose: Whether to print information about the created network.

        Raises:
            NotImplementedError: If ``use_spectral_norm`` is set.
        """
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        if use_spectral_norm:
            raise NotImplementedError('Spectral normalization not yet ' +
                                      'implemented for this network.')

        if use_batch_norm and use_context_mod:
            # FIXME Does it make sense to have both enabled?
            # I.e., should we produce a warning or error?
            pass

        self._a_fun = activation_fn
        # Initial values only make sense if at least some weights are
        # maintained by this network.
        assert(init_weights is None or \
               (not no_weights or not context_mod_no_weights))
        self._no_weights = no_weights
        self._dropout_rate = dropout_rate
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._use_context_mod = use_context_mod
        self._context_mod_inputs = context_mod_inputs
        self._no_last_layer_context_mod = no_last_layer_context_mod
        self._context_mod_no_weights = context_mod_no_weights
        self._context_mod_post_activation = context_mod_post_activation
        self._context_mod_gain_offset = context_mod_gain_offset
        self._out_fn = out_fn

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True if out_fn is None else False

        # Spectral normalization was rejected above, so the wrapper applied
        # to layers is always the identity.
        # FIXME once spectral norm is supported, make sure the implementation
        # is correct in all situations (e.g., what to do if weights are passed
        # to the forward method?).
        self._spec_norm = lambda x: x

        self._param_shapes = []
        self._weights = None if no_weights and context_mod_no_weights \
            else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not context_mod_no_weights else []

        if dropout_rate != -1:
            assert (dropout_rate >= 0. and dropout_rate <= 1.)
            self._dropout = nn.Dropout(p=dropout_rate)

        ### Define and initialize context mod weights.
        self._context_mod_layers = nn.ModuleList() if use_context_mod else None
        self._context_mod_shapes = [] if use_context_mod else None

        if use_context_mod:
            # Number of entries of `init_weights` consumed by context-mod
            # layers.
            cm_ind = 0
            cm_sizes = []
            if context_mod_inputs:
                cm_sizes.append(n_in)
            cm_sizes.extend(hidden_layers)
            if not no_last_layer_context_mod:
                cm_sizes.append(n_out)

            for size in cm_sizes:
                cmod_layer = ContextModLayer(
                    size,
                    no_weights=context_mod_no_weights,
                    apply_gain_offset=context_mod_gain_offset)
                self._context_mod_layers.append(cmod_layer)

                self.param_shapes.extend(cmod_layer.param_shapes)
                self._context_mod_shapes.extend(cmod_layer.param_shapes)
                if context_mod_no_weights:
                    self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
                else:
                    self._weights.extend(cmod_layer.weights)

                # FIXME ugly code. Move initialization somewhere else.
                if not context_mod_no_weights and init_weights is not None:
                    assert (len(cmod_layer.weights) == 2)
                    for ii in range(2):
                        # Compare against the layer's own weights (the
                        # running index `cm_ind` is just an integer).
                        assert(np.all(np.equal( \
                                list(init_weights[cm_ind].shape),
                                list(cmod_layer.weights[ii].shape))))
                        cmod_layer.weights[ii].data = init_weights[cm_ind]
                        cm_ind += 1

            if init_weights is not None:
                # Drop the entries that have just been consumed.
                init_weights = init_weights[cm_ind:]

        # Diagnostic output; only emitted if the user asked for output.
        if verbose:
            print('hidden_layers', hidden_layers)

        ### Define and initialize batch norm weights.
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        if use_batch_norm:
            if distill_bn_stats:
                self._hyper_shapes_distilled = []

            # Number of entries of `init_weights` consumed by batchnorm
            # layers.
            bn_ind = 0
            for size in hidden_layers:
                bn_layer = BatchNormLayer(size,
                                          affine=not no_weights,
                                          track_running_stats=bn_track_stats)
                self._batchnorm_layers.append(bn_layer)
                self._param_shapes.extend(bn_layer.param_shapes)

                if no_weights:
                    self._hyper_shapes_learned.extend(bn_layer.param_shapes)
                else:
                    self._weights.extend(bn_layer.weights)

                if distill_bn_stats:
                    self._hyper_shapes_distilled.extend( \
                        [list(p.shape) for p in bn_layer.get_stats(0)])

                # FIXME ugly code. Move initialization somewhere else.
                if not no_weights and init_weights is not None:
                    assert (len(bn_layer.weights) == 2)
                    for ii in range(2):
                        assert(np.all(np.equal( \
                                list(init_weights[bn_ind].shape),
                                list(bn_layer.weights[ii].shape))))
                        bn_layer.weights[ii].data = init_weights[bn_ind]
                        bn_ind += 1

            if init_weights is not None:
                init_weights = init_weights[bn_ind:]

        # Compute shapes of linear layers.
        linear_shapes = MLP.weight_shapes(n_in=n_in,
                                          n_out=n_out,
                                          hidden_layers=hidden_layers,
                                          use_bias=use_bias)
        self._param_shapes.extend(linear_shapes)

        num_weights = MainNetInterface.shapes_to_num_weights(
            self._param_shapes)

        if verbose:
            if use_context_mod:
                cm_num_weights = 0
                for cm_layer in self._context_mod_layers:
                    cm_num_weights += MainNetInterface.shapes_to_num_weights( \
                        cm_layer.param_shapes)

            print(
                'Creating an MLP with %d weights' % num_weights +
                (' (including %d weights associated with-' % cm_num_weights +
                 'context modulation)' if use_context_mod else '') + '.' +
                (' The network uses dropout.' if dropout_rate != -1 else '') +
                (' The network uses batchnorm.' if use_batch_norm else ''))

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if no_weights:
            self._hyper_shapes_learned.extend(linear_shapes)
            self._is_properly_setup()
            return

        ### Define and initialize linear weights.
        for dims in linear_shapes:
            self._weights.append(
                nn.Parameter(torch.Tensor(*dims), requires_grad=True))
            # One-dimensional shapes are bias vectors, all others are weight
            # matrices.
            if len(dims) == 1:
                self._layer_bias_vectors.append(self._weights[-1])
            else:
                self._layer_weight_tensors.append(self._weights[-1])

        if init_weights is not None:
            # Only the linear layer parameters should be left at this point.
            assert (len(init_weights) == len(linear_shapes))
            for i in range(len(init_weights)):
                assert (np.all(
                    np.equal(list(init_weights[i].shape), linear_shapes[i])))
                if use_bias:
                    # Entries alternate: weight matrix, then bias vector.
                    if i % 2 == 0:
                        self._layer_weight_tensors[i //
                                                   2].data = init_weights[i]
                    else:
                        self._layer_bias_vectors[i // 2].data = init_weights[i]
                else:
                    self._layer_weight_tensors[i].data = init_weights[i]
        else:
            for i in range(len(self._layer_weight_tensors)):
                if use_bias:
                    init_params(self._layer_weight_tensors[i],
                                self._layer_bias_vectors[i])
                else:
                    init_params(self._layer_weight_tensors[i])

        self._is_properly_setup()
    def __init__(self, in_shape=(32, 32, 3), num_classes=10, no_weights=False,
                 init_weights=None, use_context_mod=False,
                 context_mod_inputs=False, no_last_layer_context_mod=False,
                 context_mod_no_weights=False,
                 context_mod_post_activation=False,
                 context_mod_gain_offset=False,
                 context_mod_gain_softplus=False,
                 context_mod_apply_pixel_wise=False):
        """Create three locally-connected conv layers and two linear layers.

        Parameter shapes are always recorded in ``_param_shapes``. Depending
        on ``no_weights`` / ``context_mod_no_weights`` the corresponding
        tensors are either created internally (``_weights``) or only their
        shapes are listed in ``_hyper_shapes_learned``.

        Args:
            in_shape: Input shape; indexed as ``(height, width, channels)``.
            num_classes: Size of the final linear layer.
            no_weights: If ``True``, no conv/linear parameters are created.
            init_weights: Optional list of tensors, one per entry of
                ``self.weights``, used to overwrite the initial values.
            use_context_mod: Whether context-modulation layers are added.
            context_mod_inputs: Whether the network input gets its own
                context-modulation layer.
            no_last_layer_context_mod: If ``True``, the output layer gets no
                context-modulation layer.
            context_mod_no_weights: If ``True``, no context-mod parameters
                are created.
            context_mod_post_activation: Stored as attribute only.
            context_mod_gain_offset: Forwarded to ``ContextModLayer``.
            context_mod_gain_softplus: Forwarded to ``ContextModLayer``.
            context_mod_apply_pixel_wise: If ``False``, 3D context-mod shapes
                are reduced to one gain/shift per feature map.
        """
        super(BioConvNet, self).__init__(num_classes, True)

        assert len(in_shape) == 3
        # FIXME Not strictly required; merely a sanity check that the user
        # passes the shape in Tensorflow layout (channels last).
        assert in_shape[2] in (1, 3)
        # Custom initial values only make sense if some weights are internal.
        assert init_weights is None or \
            not (no_weights and context_mod_no_weights)

        self._in_shape = in_shape
        self._no_weights = no_weights
        self._use_context_mod = use_context_mod
        self._context_mod_inputs = context_mod_inputs
        self._no_last_layer_context_mod = no_last_layer_context_mod
        self._context_mod_no_weights = context_mod_no_weights
        self._context_mod_post_activation = context_mod_post_activation
        self._context_mod_gain_offset = context_mod_gain_offset
        self._context_mod_gain_softplus = context_mod_gain_softplus
        self._context_mod_apply_pixel_wise = context_mod_apply_pixel_wise

        self._has_bias = True
        self._has_fc_out = True
        # The last 2 entries of `weights` have to be the weight matrix and
        # the bias vector of the output layer.
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        if no_weights and context_mod_no_weights:
            self._weights = None
        else:
            self._weights = nn.ParameterList()
        if no_weights or context_mod_no_weights:
            self._hyper_shapes_learned = []
        else:
            self._hyper_shapes_learned = None
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        # Activation shapes that may receive a context-modulation layer.
        cm_shapes = []
        if context_mod_inputs:
            cm_shapes.append([in_shape[2], *in_shape[:2]])

        ### Locally-connected ("bio") convolutional layers.
        height, width = in_shape[0], in_shape[1]
        chans_in = in_shape[2]

        # Per layer: (out channels, kernel size, stride, padding).
        conv_specs = zip((64, 128, 256), (5, 5, 3), (2, 2, 1), (0, 0, 1))

        self._conv_layer = []

        for chans_out, kernel, stride, pad in conv_specs:
            layer = LocalConv2dLayer(chans_in, chans_out, height, width,
                kernel, stride=stride, padding=pad, no_weights=no_weights)
            self._conv_layer.append(layer)

            height = layer.out_height
            width = layer.out_width
            cm_shapes.append([chans_out, height, width])
            chans_in = chans_out

            self._param_shapes.extend(layer.param_shapes)
            if no_weights:
                self._hyper_shapes_learned.extend(layer.param_shapes)
                continue

            self._weights.extend(layer.weights)
            assert len(layer.weights) == 2
            self._layer_weight_tensors.append(layer.filters)
            self._layer_bias_vectors.append(layer.bias)

        ### Fully-connected layers.
        fc_in = height * width * chans_out
        assert fc_in == 6400
        fc_sizes = [1024, num_classes]
        last_fc = len(fc_sizes) - 1
        for idx, fc_out in enumerate(fc_sizes):
            weight_shape = [fc_out, fc_in]
            bias_shape = [fc_out]

            # The output layer is only modulated if the user did not opt out.
            if idx < last_fc or not no_last_layer_context_mod:
                cm_shapes.append([fc_out])

            fc_in = fc_out

            self._param_shapes.extend([weight_shape, bias_shape])
            if no_weights:
                self._hyper_shapes_learned.extend([weight_shape, bias_shape])
                continue

            fc_weight = nn.Parameter(torch.Tensor(*weight_shape),
                                     requires_grad=True)
            fc_bias = nn.Parameter(torch.Tensor(*bias_shape),
                                   requires_grad=True)
            init_params(fc_weight, fc_bias)

            self._weights.extend([fc_weight, fc_bias])
            self._layer_weight_tensors.append(fc_weight)
            self._layer_bias_vectors.append(fc_bias)

        ### Context-modulation layers and their weights.
        if use_context_mod:
            self._context_mod_layers = nn.ModuleList()
            self._context_mod_shapes = []
            self._context_mod_weights = nn.ParameterList()

            if not context_mod_apply_pixel_wise:
                # One scalar gain and shift per feature map only.
                cm_shapes = [[shp[0], 1, 1] if len(shp) == 3 else shp
                             for shp in cm_shapes]

            for shp in cm_shapes:
                cmod_layer = ContextModLayer(shp,
                    no_weights=context_mod_no_weights,
                    apply_gain_offset=context_mod_gain_offset,
                    apply_gain_softplus=context_mod_gain_softplus)
                self._context_mod_layers.append(cmod_layer)

                self._context_mod_shapes.extend(cmod_layer.param_shapes)
                if not context_mod_no_weights:
                    self._context_mod_weights.extend(cmod_layer.weights)

            # Context-mod shapes/weights always precede everything else in
            # the list attributes.
            self._param_shapes = self._context_mod_shapes + self._param_shapes
            if context_mod_no_weights:
                self._hyper_shapes_learned = self._context_mod_shapes + \
                    self._hyper_shapes_learned
            else:
                main_weights = self._weights
                self._weights = nn.ParameterList(self._context_mod_weights)
                for param in main_weights:
                    self._weights.append(param)
        else:
            self._context_mod_layers = None
            self._context_mod_shapes = None
            self._context_mod_weights = None

        ### Overwrite the initialization with user-provided values.
        if init_weights is not None:
            assert len(self.weights) == len(init_weights)
            for idx, given in enumerate(init_weights):
                assert np.all(np.equal(list(given.shape),
                                       list(self._weights[idx].shape)))
                self._weights[idx].data = given

        ### Inform the user about the created network.
        num_weights = MainNetInterface.shapes_to_num_weights(
            self._param_shapes)
        cm_info = ''
        if use_context_mod:
            cm_num_weights = MainNetInterface.shapes_to_num_weights(
                self._context_mod_shapes)
            cm_info = ' (including %d weights associated with-' \
                % cm_num_weights + 'context modulation)'

        print('Creating bio-plausible convnet with %d weights' % num_weights
              + cm_info + '.')

        self._is_properly_setup()
    def __init__(self, in_shape=(32, 32, 3), num_classes=10, n=4, k=10,
                 num_feature_maps=(16, 16, 32, 64), use_bias=True,
                 use_fc_bias=None, no_weights=False, use_batch_norm=True,
                 bn_track_stats=True, distill_bn_stats=False, dropout_rate=-1,
                 chw_input_format=False, verbose=True, **kwargs):
        """Set up all parameter shapes (and weights) of a Wide ResNet.

        Layers are identified by integer IDs ``l``: convolutional / linear
        layers use ``l mod 3 == 1``, batchnorm layers ``l mod 3 == 2`` and
        context-mod layers ``l mod 3 == 0``.

        Args:
            in_shape: Input shape; ``in_shape[2]`` is used as the number of
                input channels.
            num_classes: Output size of the final fully-connected layer.
            n: Each of the 3 groups consists of ``2*n`` convolutional layers.
            k: Widening factor; multiplies the last 3 entries of
                ``num_feature_maps``.
            num_feature_maps: 4 integers; feature maps of the first conv
                layer and of the 3 groups (before widening).
            use_bias: Whether convolutional layers have bias vectors.
            use_fc_bias: Whether the final linear layer has a bias vector.
                Defaults to ``use_bias`` if ``None``.
            no_weights: If ``True``, no conv/linear parameters are created;
                their shapes are listed in ``_hyper_shapes_learned`` instead.
            use_batch_norm: Whether batchnorm layers are added.
            bn_track_stats: Forwarded to ``_add_batchnorm_layers``.
            distill_bn_stats: Forwarded to ``_add_batchnorm_layers``.
            dropout_rate: ``-1`` disables dropout, otherwise a probability
                in ``[0, 1]``.
            chw_input_format: Stored as attribute only.
            verbose: Whether a summary of the network is printed.
            **kwargs: Context-mod options as understood by
                ``MainNetInterface._parse_context_mod_args``, plus the option
                ``context_mod_apply_pixel_wise`` (default ``False``).

        Raises:
            ValueError: On unknown keyword arguments or if
                ``num_feature_maps`` does not contain 4 entries.
        """
        super(WRN, self).__init__(num_classes, verbose)

        ### Parse or set context-mod arguments ###
        rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
        # `context_mod_apply_pixel_wise` is a conv-net specific option that is
        # consumed below. It must therefore not be rejected as unknown, even
        # if the generic parser reports it as a remaining keyword.
        if 'context_mod_apply_pixel_wise' in rem_kwargs:
            rem_kwargs.remove('context_mod_apply_pixel_wise')
        if len(rem_kwargs) > 0:
            raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
        # Since this is a conv-net, we may also want to add the following.
        if 'context_mod_apply_pixel_wise' not in kwargs.keys():
            kwargs['context_mod_apply_pixel_wise'] = False

        self._use_context_mod = kwargs['use_context_mod']
        self._context_mod_inputs = kwargs['context_mod_inputs']
        self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
        self._context_mod_no_weights = kwargs['context_mod_no_weights']
        self._context_mod_post_activation = \
            kwargs['context_mod_post_activation']
        self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
        self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
        self._context_mod_apply_pixel_wise = \
            kwargs['context_mod_apply_pixel_wise']

        ### Check or parse remaining arguments ###
        self._in_shape = in_shape
        self._n = n
        self._k = k
        if use_fc_bias is None:
            use_fc_bias = use_bias
        # Also, checkout attribute `_has_bias` below.
        self._use_bias = use_bias
        self._use_fc_bias = use_fc_bias
        self._no_weights = no_weights
        # Distilling batchnorm statistics requires that they are tracked.
        assert not use_batch_norm or (not distill_bn_stats or bn_track_stats)
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._dropout_rate = dropout_rate
        self._chw_input_format = chw_input_format

        # The original authors found that the best configuration uses this
        # kernel in all convolutional layers.
        self._kernel_size = (3, 3)
        if len(num_feature_maps) != 4:
            raise ValueError('Option "num_feature_maps" must be a list of 4 ' +
                             'integers.')
        self._filter_sizes = list(num_feature_maps)
        if k != 1:
            # Only the 3 groups are widened, not the first conv layer.
            for i in range(1, 4):
                self._filter_sizes[i] = k * num_feature_maps[i]
        # Strides used in the first layer of each convolutional group.
        self._strides = (1, 1, 2, 2)

        ### Set required class attributes ###
        # Note, we did overwrite the getter for attribute `has_bias`, as it is
        # not applicable if the values of `use_bias` and `use_fc_bias` differ.
        self._has_bias = use_bias if use_bias == use_fc_bias else False
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer!
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        self._param_shapes_meta = []
        self._internal_params = None if no_weights and \
            self._context_mod_no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not self._context_mod_no_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if dropout_rate != -1:
            assert 0. <= dropout_rate <= 1.
            self._dropout = nn.Dropout(p=dropout_rate)

        #################################
        ### Create context mod layers ###
        #################################
        self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
            else None

        if self._use_context_mod:
            cm_layer_inds = []
            cm_shapes = [] # Output shape of all layers.
            if self._context_mod_inputs:
                cm_shapes.append([in_shape[2], *in_shape[:2]])
                # We reserve layer zero for input context-mod. Otherwise, there
                # is no layer zero.
                cm_layer_inds.append(0)

            layer_out_shapes = self._compute_layer_out_sizes()
            cm_shapes.extend(layer_out_shapes)
            # All layer indices `l` with `l mod 3 == 0` are context-mod layers.
            cm_layer_inds.extend(range(3, 3*len(layer_out_shapes)+1, 3))
            if self._no_last_layer_context_mod:
                cm_shapes = cm_shapes[:-1]
                cm_layer_inds = cm_layer_inds[:-1]
            if not self._context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

        ###############################
        ### Create batchnorm layers ###
        ###############################
        if use_batch_norm:
            # One batchnorm layer for the first conv layer and one for each
            # of the `2*n` conv layers of every group.
            bn_sizes = []
            for i, s in enumerate(self._filter_sizes):
                if i == 0:
                    bn_sizes.append(s)
                else:
                    bn_sizes.extend([s] * (2*n))

            # All layer indices `l` with `l mod 3 == 2` are batchnorm layers.
            self._add_batchnorm_layers(bn_sizes, no_weights,
                bn_layers=list(range(2, 3*len(bn_sizes)+1, 3)),
                distill_bn_stats=distill_bn_stats,
                bn_track_stats=bn_track_stats)

        ######################################
        ### Create skip connection weights ###
        ######################################
        # We use 1x1 convolutional layers for residual blocks in case the
        # number of input and output feature maps disagrees. We also use 1x1
        # convolutions whenever a stride greater than 1 is applied. This is not
        # necessary in my opinion (as it adds extra weights that do not affect
        # the downsampling itself), but commonly done; for instance, in the
        # original PyTorch implementation.
        # Note, there may be maximally 3 1x1 layers added to the network.
        # Note, we use 1x1 conv layers without biases.
        skip_1x1_shapes = []
        self._group_has_1x1 = [False] * 3
        for i in range(1, 4):
            if self._filter_sizes[i-1] != self._filter_sizes[i] or \
                    self._strides[i] != 1:
                skip_1x1_shapes.append([self._filter_sizes[i],
                                        self._filter_sizes[i-1], 1, 1])
                self._group_has_1x1[i-1] = True

        for s in skip_1x1_shapes:
            if not no_weights:
                self._internal_params.append(nn.Parameter( \
                    torch.Tensor(*s), requires_grad=True))
                self._layer_weight_tensors.append(self._internal_params[-1])
                init_params(self._layer_weight_tensors[-1])
            else:
                self._hyper_shapes_learned.append(s)
                # Position of this shape within `param_shapes`; has to be
                # read before the shape is appended below.
                self._hyper_shapes_learned_ref.append(len(self.param_shapes))

            self._param_shapes.append(s)
            self._param_shapes_meta.append({
                'name': 'weight',
                'index': -1 if no_weights else \
                    len(self._internal_params)-1,
                'layer': -1
            })

        ############################################################
        ### Create convolutional layers and final linear weights ###
        ############################################################
        # Convolutional layers will get IDs `l` such that `l mod 3 == 1`.
        layer_id = 1
        for i in range(5):
            if i == 0: ### First layer.
                num = 1
                prev_fs = self._in_shape[2]
                curr_fs = self._filter_sizes[0]
            elif i == 4: ### Final fully-connected layer.
                num = 1
                curr_fs = num_classes
            else: # Group of residual blocks.
                num = 2 * n
                curr_fs = self._filter_sizes[i]

            for _ in range(num):
                if i == 4:
                    layer_shapes = [[curr_fs, prev_fs]]
                    if use_fc_bias:
                        layer_shapes.append([curr_fs])
                else:
                    layer_shapes = [[curr_fs, prev_fs, *self._kernel_size]]
                    if use_bias:
                        layer_shapes.append([curr_fs])

                for s in layer_shapes:
                    if not no_weights:
                        self._internal_params.append(nn.Parameter( \
                            torch.Tensor(*s), requires_grad=True))
                        # 1D shapes are bias vectors.
                        if len(s) == 1:
                            self._layer_bias_vectors.append( \
                                self._internal_params[-1])
                        else:
                            self._layer_weight_tensors.append( \
                                self._internal_params[-1])
                    else:
                        self._hyper_shapes_learned.append(s)
                        self._hyper_shapes_learned_ref.append( \
                            len(self.param_shapes))

                    self._param_shapes.append(s)
                    self._param_shapes_meta.append({
                        'name': 'weight' if len(s) != 1 else 'bias',
                        'index': -1 if no_weights else \
                            len(self._internal_params)-1,
                        'layer': layer_id
                    })

                prev_fs = curr_fs
                layer_id += 3

                # Initialize_weights
                if not no_weights:
                    init_params(self._layer_weight_tensors[-1],
                        self._layer_bias_vectors[-1] \
                        if len(layer_shapes) == 2 else None)

        ###########################
        ### Print infos to user ###
        ###########################
        if verbose:
            if self._use_context_mod:
                cm_param_shapes = []
                for cm_layer in self.context_mod_layers:
                    cm_param_shapes.extend(cm_layer.param_shapes)
                cm_num_params = \
                    MainNetInterface.shapes_to_num_weights(cm_param_shapes)

            print('Creating a WideResnet "%s" with %d weights' \
                  % (str(self), self.num_params)
                  + (' (including %d weights associated with-' % cm_num_params
                     + 'context modulation)' if self._use_context_mod else '')
                  + '.'
                  + (' The network uses batchnorm.' if use_batch_norm  else '')
                  + (' The network uses dropout.' if dropout_rate != -1 \
                     else ''))

        self._is_properly_setup(check_has_bias=False)
Exemplo n.º 8
0
    def __init__(self,
                 in_shape=(224, 224, 3),
                 num_classes=1000,
                 use_bias=True,
                 use_fc_bias=None,
                 num_feature_maps=(64, 64, 128, 256, 512),
                 blocks_per_group=(2, 2, 2, 2),
                 projection_shortcut=False,
                 bottleneck_blocks=False,
                 cutout_mod=False,
                 no_weights=False,
                 use_batch_norm=True,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 chw_input_format=False,
                 verbose=True,
                 **kwargs):
        """Set up all parameter shapes (and weights) of the ResNet.

        Layers are identified by integer IDs ``l``: convolutional / linear
        layers use ``l mod 3 == 1``, batchnorm layers ``l mod 3 == 2`` and
        context-mod layers ``l mod 3 == 0``. Layers belonging to projection
        shortcuts get IDs following those of the main layers.

        Args:
            in_shape: Input shape; ``in_shape[2]`` is used as the number of
                input channels.
            num_classes: Output size of the final fully-connected layer.
            use_bias: Whether convolutional layers have bias vectors.
            use_fc_bias: Whether the final linear layer has a bias vector.
                Defaults to ``use_bias`` if ``None``.
            num_feature_maps: 5 integers; feature maps of the first conv
                layer and of the 4 groups of residual blocks.
            blocks_per_group: 4 integers; number of residual blocks in each
                group.
            projection_shortcut: If ``True``, 1x1 conv weights (plus a
                batchnorm layer, if batchnorm is used) are created for skip
                connections at transitions between groups.
            bottleneck_blocks: If ``True``, every block consists of 3 conv
                layers (1x1, 3x3, 1x1 with 4 times as many feature maps)
                rather than 2 conv layers with 3x3 kernels.
            cutout_mod: If ``True``, the first conv layer uses a 3x3 kernel
                with padding 1 and stride 1 instead of 7x7 / 3 / 2.
            no_weights: If ``True``, no conv/linear parameters are created;
                their shapes are listed in ``_hyper_shapes_learned`` instead.
            use_batch_norm: Whether batchnorm layers are added.
            bn_track_stats: Forwarded to ``_add_batchnorm_layers``.
            distill_bn_stats: Forwarded to ``_add_batchnorm_layers``.
            chw_input_format: Stored as attribute only.
            verbose: Whether a summary of the network is printed.
            **kwargs: Context-mod options as understood by
                ``MainNetInterface._parse_context_mod_args``, plus the option
                ``context_mod_apply_pixel_wise`` (default ``False``).

        Raises:
            ValueError: On unknown keyword arguments, or if
                ``blocks_per_group`` / ``num_feature_maps`` do not have 4 / 5
                entries, respectively.
        """
        super(ResNetIN, self).__init__(num_classes, verbose)

        ### Parse or set context-mod arguments ###
        rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
        # This conv-net specific option is consumed below and therefore must
        # not be reported as unknown.
        if 'context_mod_apply_pixel_wise' in rem_kwargs:
            rem_kwargs.remove('context_mod_apply_pixel_wise')
        if len(rem_kwargs) > 0:
            raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
        # Since this is a conv-net, we may also want to add the following.
        if 'context_mod_apply_pixel_wise' not in kwargs.keys():
            kwargs['context_mod_apply_pixel_wise'] = False

        self._use_context_mod = kwargs['use_context_mod']
        self._context_mod_inputs = kwargs['context_mod_inputs']
        self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
        self._context_mod_no_weights = kwargs['context_mod_no_weights']
        self._context_mod_post_activation = \
            kwargs['context_mod_post_activation']
        self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
        self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
        self._context_mod_apply_pixel_wise = \
            kwargs['context_mod_apply_pixel_wise']

        ### Check or parse remaining arguments ###
        self._in_shape = in_shape
        self._projection_shortcut = projection_shortcut
        self._bottleneck_blocks = bottleneck_blocks
        self._cutout_mod = cutout_mod
        if use_fc_bias is None:
            use_fc_bias = use_bias
        # Also, checkout attribute `_has_bias` below.
        self._use_bias = use_bias
        self._use_fc_bias = use_fc_bias
        self._no_weights = no_weights
        # Distilling batchnorm statistics requires that they are tracked.
        assert not use_batch_norm or (not distill_bn_stats or bn_track_stats)
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._chw_input_format = chw_input_format

        if len(blocks_per_group) != 4:
            raise ValueError('Option "blocks_per_group" must be a list of 4 ' +
                             'integers.')
        self._num_blocks = blocks_per_group
        if len(num_feature_maps) != 5:
            raise ValueError('Option "num_feature_maps" must be a list of 5 ' +
                             'integers.')
        self._filter_sizes = list(num_feature_maps)

        # The first layer of group 3, 4 and 5 uses a strided convolution, so
        # the shortcut connections need to perform a downsampling operation. In
        # addition, whenever traversing from one group to the next, the number
        # of feature maps might change. In all these cases, the network might
        # benefit from smart shortcut connections, which means using projection
        # shortcuts, where a 1x1 conv is used for the mentioned skip connection.
        self._num_non_ident_skips = 3  # Strided convs: 2->3, 3->4 and 4->5
        # Number of feature maps at the output of the first group of blocks.
        fs1 = self._filter_sizes[1]
        if self._bottleneck_blocks:
            fs1 *= 4
        if self._filter_sizes[0] != fs1:
            self._num_non_ident_skips += 1  # Also handle 1->2.
        self._group_has_1x1 = [False] * 4
        if self._projection_shortcut:
            # Mark the last `_num_non_ident_skips` groups (counting backwards
            # from index 3).
            for i in range(3, 3 - self._num_non_ident_skips, -1):
                self._group_has_1x1[i] = True
        # Number of conv layers (excluding skip connections)
        self._num_main_conv_layers = 1 + int(np.sum([self._num_blocks[i] * \
            (3 if self._bottleneck_blocks else 2) for i in range(4)]))

        # The original architecture uses a 7x7 kernel in the first conv layer
        # and 3x3 or 1x1 kernels in all remaining layers.
        self._init_kernel_size = (7, 7)
        # All 3x3 layers have padding 1 and 1x1 layers have padding 0.
        self._init_padding = 3
        self._init_stride = 2

        if self._cutout_mod:
            # Replace the initial 7x7 strided conv by a 3x3 conv without
            # stride.
            self._init_kernel_size = (3, 3)
            self._init_padding = 1
            self._init_stride = 1

        ### Set required class attributes ###
        # Note, we did overwrite the getter for attribute `has_bias`, as it is
        # not applicable if the values of `use_bias` and `use_fc_bias` differ.
        self._has_bias = use_bias if use_bias == use_fc_bias else False
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer!
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        self._param_shapes_meta = []
        # No internal parameter list is needed if neither main nor context-mod
        # weights are maintained internally.
        self._internal_params = None if no_weights and \
            self._context_mod_no_weights else nn.ParameterList()
        # Conversely, `None` if all weights are maintained internally.
        self._hyper_shapes_learned = None \
            if not no_weights and not self._context_mod_no_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        #################################
        ### Create context mod layers ###
        #################################
        self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
            else None

        if self._use_context_mod:
            cm_layer_inds = []
            cm_shapes = []  # Output shape of all layers.
            if self._context_mod_inputs:
                cm_shapes.append([in_shape[2], *in_shape[:2]])
                # We reserve layer zero for input context-mod. Otherwise, there
                # is no layer zero.
                cm_layer_inds.append(0)

            layer_out_shapes = self._compute_layer_out_sizes()
            cm_shapes.extend(layer_out_shapes)
            # All layer indices `l` with `l mod 3 == 0` are context-mod layers.
            cm_layer_inds.extend(range(3, 3 * len(layer_out_shapes) + 1, 3))
            if self._no_last_layer_context_mod:
                cm_shapes = cm_shapes[:-1]
                cm_layer_inds = cm_layer_inds[:-1]
            if not self._context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

        ###############################
        ### Create batchnorm layers ###
        ###############################
        # Batchnorm layers get the indices 2, 5, 8, ... (see `bn_layers`
        # below).
        if use_batch_norm:
            bn_sizes = []
            for i, s in enumerate(self._filter_sizes):
                if i == 0:
                    bn_sizes.append(s)
                else:
                    # One batchnorm layer per conv layer of each block.
                    for _ in range(self._num_blocks[i - 1]):
                        if self._bottleneck_blocks:
                            bn_sizes.extend([s, s, 4 * s])
                        else:
                            bn_sizes.extend([s, s])

            # All layer indices `l` with `l mod 3 == 2` are batchnorm layers.
            bn_layers = list(range(2, 3 * len(bn_sizes) + 1, 3))

            # We also need a batchnorm layer per skip connection that uses 1x1
            # projections.
            if self._projection_shortcut:
                # `+ 1` accounts for the final fully-connected layer, whose ID
                # directly follows those of the main conv layers.
                bn_layer_ind_skip = 3 * (self._num_main_conv_layers + 1) + 2

                factor = 4 if self._bottleneck_blocks else 1
                for i in range(4):  # For each transition between conv groups.
                    if self._group_has_1x1[i]:
                        bn_sizes.append(self._filter_sizes[i + 1] * factor)
                        bn_layers.append(bn_layer_ind_skip)

                        bn_layer_ind_skip += 3

            self._add_batchnorm_layers(bn_sizes,
                                       no_weights,
                                       bn_layers=bn_layers,
                                       distill_bn_stats=distill_bn_stats,
                                       bn_track_stats=bn_track_stats)

        ######################################
        ### Create skip connection weights ###
        ######################################
        if self._projection_shortcut:
            # IDs of skip layers start right after the ID of the final
            # fully-connected layer.
            layer_ind_skip = 3 * (self._num_main_conv_layers + 1) + 1

            factor = 4 if self._bottleneck_blocks else 1

            n_in = self._filter_sizes[0]
            for i in range(4):  # For each transition between conv groups.
                if not self._group_has_1x1[i]:
                    continue

                n_out = self._filter_sizes[i + 1] * factor

                # 1x1 convolution without bias.
                skip_1x1_shape = [n_out, n_in, 1, 1]

                if not no_weights:
                    self._internal_params.append(nn.Parameter( \
                        torch.Tensor(*skip_1x1_shape), requires_grad=True))
                    self._layer_weight_tensors.append(
                        self._internal_params[-1])
                    # Placeholder, such that weight and bias lists stay
                    # aligned.
                    self._layer_bias_vectors.append(None)
                    init_params(self._layer_weight_tensors[-1])
                else:
                    self._hyper_shapes_learned.append(skip_1x1_shape)
                    # Position of this shape within `param_shapes`; has to be
                    # read before the shape is appended below.
                    self._hyper_shapes_learned_ref.append( \
                        len(self.param_shapes))

                self._param_shapes.append(skip_1x1_shape)
                self._param_shapes_meta.append({
                    'name': 'weight',
                    'index': -1 if no_weights else \
                        len(self._internal_params)-1,
                    'layer': layer_ind_skip
                })

                layer_ind_skip += 3
                n_in = n_out

        ############################################################
        ### Create convolutional layers and final linear weights ###
        ############################################################
        # Convolutional layers will get IDs `l` such that `l mod 3 == 1`.
        layer_id = 1
        n_per_block = 3 if self._bottleneck_blocks else 2
        # `i == 0`: first conv layer, `i in 1..4`: groups of residual blocks,
        # `i == 5`: final fully-connected layer.
        for i in range(6):
            if i == 0:  ### First layer.
                num = 1
                prev_fs = self._in_shape[2]
                curr_fs = self._filter_sizes[0]

                kernel_size = self._init_kernel_size
                #stride = self._init_stride
            elif i == 5:  ### Final fully-connected layer.
                num = 1
                curr_fs = num_classes

                kernel_size = None
            else:  # Group of residual blocks.
                num = self._num_blocks[i - 1] * n_per_block
                curr_fs = self._filter_sizes[i]

                kernel_size = (3, 3)  # depends on block structure!

            for n in range(num):
                if i == 5:
                    layer_shapes = [[curr_fs, prev_fs]]
                    if use_fc_bias:
                        layer_shapes.append([curr_fs])

                    prev_fs = curr_fs
                else:
                    if i > 0 and self._bottleneck_blocks:
                        # Position of the layer within its bottleneck block.
                        if n % 3 == 0:
                            fs = curr_fs
                            ks = (1, 1)
                        elif n % 3 == 1:
                            fs = curr_fs
                            ks = kernel_size
                        else:
                            # Last layer of a block has 4x as many feature
                            # maps.
                            fs = 4 * curr_fs
                            ks = (1, 1)
                    elif i > 0 and not self._bottleneck_blocks:
                        fs = curr_fs
                        ks = kernel_size
                    else:
                        fs = curr_fs
                        ks = kernel_size

                    layer_shapes = [[fs, prev_fs, *ks]]
                    if use_bias:
                        layer_shapes.append([fs])

                    prev_fs = fs

                for s in layer_shapes:
                    if not no_weights:
                        self._internal_params.append(nn.Parameter( \
                            torch.Tensor(*s), requires_grad=True))
                        # 1D shapes are bias vectors.
                        if len(s) == 1:
                            self._layer_bias_vectors.append( \
                                self._internal_params[-1])
                        else:
                            self._layer_weight_tensors.append( \
                                self._internal_params[-1])
                    else:
                        self._hyper_shapes_learned.append(s)
                        self._hyper_shapes_learned_ref.append( \
                            len(self.param_shapes))

                    self._param_shapes.append(s)
                    self._param_shapes_meta.append({
                        'name': 'weight' if len(s) != 1 else 'bias',
                        'index': -1 if no_weights else \
                            len(self._internal_params)-1,
                        'layer': layer_id
                    })

                layer_id += 3

                # Initialize_weights
                if not no_weights:
                    init_params(self._layer_weight_tensors[-1],
                        self._layer_bias_vectors[-1] \
                        if len(layer_shapes) == 2 else None)

        ###########################
        ### Print infos to user ###
        ###########################
        if verbose:
            if self._use_context_mod:
                cm_param_shapes = []
                for cm_layer in self.context_mod_layers:
                    cm_param_shapes.extend(cm_layer.param_shapes)
                cm_num_params = \
                    MainNetInterface.shapes_to_num_weights(cm_param_shapes)

            print('Creating a "%s" with %d weights' \
                  % (str(self), self.num_params)
                  + (' (including %d weights associated with-' % cm_num_params
                     + 'context modulation)' if self._use_context_mod else '')
                  + '.'
                  + (' The network uses batchnorm.' if use_batch_norm  else ''))

        self._is_properly_setup(check_has_bias=False)
Exemplo n.º 9
0
    def __init__(
            self,
            in_shape=(28, 28, 1),
            num_classes=10,
            verbose=True,
            arch='mnist_large',
            no_weights=False,
            init_weights=None,
            dropout_rate=-1,  #0.5
            **kwargs):
        """Set up parameter shapes (and, optionally, parameters) of a LeNet.

        Args:
            in_shape (tuple or list): Shape of an input image. The first two
                entries are the spatial size (28x28 is required if ``arch``
                starts with ``'mnist'``, 32x32 otherwise). The third entry is
                used as leading dimension of the input context-mod shape
                (presumably the number of channels).
            num_classes (int): Number of output units. If different from 10,
                the first dimension of the last two parameter shapes of the
                chosen architecture is replaced by this value.
            verbose (bool): Whether a summary should be printed.
            arch (str): Key into ``LeNet._ARCHITECTURES``.
            no_weights (bool): If ``True``, no conv-/fc-layer parameters are
                created; their shapes are added to the learned hyper shapes
                instead.
            init_weights (list, optional): Tensors used to initialize
                ``self.weights``. Length and shapes have to match
                ``self.weights``.
            dropout_rate (float): ``-1`` disables dropout. Otherwise, a
                dropout probability in ``[0, 1]``.
            **kwargs: Context-modulation options, which are parsed by
                ``MainNetInterface._parse_context_mod_args`` (which is
                expected to provide defaults for missing options). The
                option ``context_mod_apply_pixel_wise`` defaults to
                ``False`` if not provided.

        Raises:
            ValueError: If ``in_shape`` does not match the chosen architecture
                or if unknown keyword arguments were passed.
        """
        super(LeNet, self).__init__(num_classes, verbose)

        self._in_shape = in_shape
        assert arch in LeNet._ARCHITECTURES.keys()
        # Work on a per-instance copy of the shape lists. The class-level
        # architecture table is shared by all instances, so adapting the
        # output size in-place would silently change the architecture of
        # every network that is constructed afterwards.
        self._chosen_arch = [list(s) for s in LeNet._ARCHITECTURES[arch]]
        if num_classes != 10:
            self._chosen_arch[-2][0] = num_classes
            self._chosen_arch[-1][0] = num_classes

        # Sanity check, given current implementation.
        if arch.startswith('mnist'):
            if not in_shape[0] == in_shape[1] == 28:
                raise ValueError('MNIST LeNet architectures expect input ' +
                                 'images of size 28x28.')
        else:
            if not in_shape[0] == in_shape[1] == 32:
                raise ValueError('CIFAR LeNet architectures expect input ' +
                                 'images of size 32x32.')

        ### Parse or set context-mod arguments ###
        rem_kwargs = MainNetInterface._parse_context_mod_args(kwargs)
        if len(rem_kwargs) > 0:
            raise ValueError('Keyword arguments %s unknown.' % str(rem_kwargs))
        # Since this is a conv-net, we may also want to add the following.
        if 'context_mod_apply_pixel_wise' not in kwargs.keys():
            kwargs['context_mod_apply_pixel_wise'] = False

        self._use_context_mod = kwargs['use_context_mod']
        self._context_mod_inputs = kwargs['context_mod_inputs']
        self._no_last_layer_context_mod = kwargs['no_last_layer_context_mod']
        self._context_mod_no_weights = kwargs['context_mod_no_weights']
        self._context_mod_post_activation = \
            kwargs['context_mod_post_activation']
        self._context_mod_gain_offset = kwargs['context_mod_gain_offset']
        self._context_mod_gain_softplus = kwargs['context_mod_gain_softplus']
        self._context_mod_apply_pixel_wise = \
            kwargs['context_mod_apply_pixel_wise']

        ### Setup class attributes ###
        # Initial weights can only be used if at least some weights are
        # maintained internally.
        assert(init_weights is None or \
               (not no_weights or not self._context_mod_no_weights))
        self._no_weights = no_weights
        self._dropout_rate = dropout_rate

        self._has_bias = True
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer!
        self._mask_fc_out = True
        self._has_linear_out = True

        self._param_shapes = []
        self._param_shapes_meta = []
        self._internal_params = None if no_weights and \
            self._context_mod_no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not self._context_mod_no_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []
        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if dropout_rate != -1:
            assert (dropout_rate >= 0. and dropout_rate <= 1.)
            # FIXME `nn.Dropout2d` zeroes out whole feature maps. Is that really
            # desired here?
            self._drop_conv1 = nn.Dropout2d(p=dropout_rate)
            self._drop_conv2 = nn.Dropout2d(p=dropout_rate)
            self._drop_fc1 = nn.Dropout(p=dropout_rate)

        ### Define and initialize context mod layers/weights ###
        self._context_mod_layers = nn.ModuleList() if self._use_context_mod \
            else None

        if self._use_context_mod:
            cm_layer_inds = []
            cm_shapes = []  # Output shape of all context-mod layers.
            if self._context_mod_inputs:
                cm_shapes.append([in_shape[2], *in_shape[:2]])
                # We reserve layer zero for input context-mod. Otherwise, there
                # is no layer zero.
                cm_layer_inds.append(0)

            layer_out_shapes = self._compute_layer_out_sizes()
            # Context-modulation is applied after the pooling layers.
            # So we delete the shapes of the conv-layer outputs and keep the
            # ones of the pooling layer outputs.
            del layer_out_shapes[2]
            del layer_out_shapes[0]
            cm_shapes.extend(layer_out_shapes)
            cm_layer_inds.extend(range(2, 2 * len(layer_out_shapes) + 1, 2))
            if self._no_last_layer_context_mod:
                cm_shapes = cm_shapes[:-1]
                cm_layer_inds = cm_layer_inds[:-1]

            if not self._context_mod_apply_pixel_wise:
                # Only scalar gain and shift per feature map!
                for i, s in enumerate(cm_shapes):
                    if len(s) == 3:
                        cm_shapes[i] = [s[0], 1, 1]

            self._add_context_mod_layers(cm_shapes, cm_layers=cm_layer_inds)

        ### Define and add conv- and fc-layer weights.
        for i, s in enumerate(self._chosen_arch):
            if not no_weights:
                self._internal_params.append(
                    nn.Parameter(torch.Tensor(*s), requires_grad=True))
                # Shapes with a single dimension are bias vectors.
                if len(s) == 1:
                    self._layer_bias_vectors.append(self._internal_params[-1])
                else:
                    self._layer_weight_tensors.append(
                        self._internal_params[-1])
            else:
                self._hyper_shapes_learned.append(s)
                # Index that shape `s` will get within `param_shapes`.
                self._hyper_shapes_learned_ref.append(len(self.param_shapes))

            self._param_shapes.append(s)
            self._param_shapes_meta.append({
                'name':
                'weight' if len(s) != 1 else 'bias',
                'index':
                -1 if no_weights else len(self._internal_params) - 1,
                # Weight and bias of a layer are consecutive entries and
                # share the same (odd) layer index.
                'layer':
                2 * (i // 2) + 1
            })

        ### Initialize weights.
        if init_weights is not None:
            assert len(init_weights) == len(self.weights)
            for i in range(len(init_weights)):
                assert np.all(
                    np.equal(list(init_weights[i].shape),
                             self.weights[i].shape))
                self.weights[i].data = init_weights[i]
        else:
            for i in range(len(self._layer_weight_tensors)):
                init_params(self._layer_weight_tensors[i],
                            self._layer_bias_vectors[i])

        ### Print user info.
        if verbose:
            if self._use_context_mod:
                cm_param_shapes = []
                for cm_layer in self.context_mod_layers:
                    cm_param_shapes.extend(cm_layer.param_shapes)
                cm_num_weights = \
                    MainNetInterface.shapes_to_num_weights(cm_param_shapes)

            print('Creating a LeNet with %d weights' % self.num_params
                  + (' (including %d weights associated with-' % cm_num_weights
                     + 'context modulation)' if self._use_context_mod else '')
                  + '.'
                  + (' The network uses dropout.' if dropout_rate != -1 \
                     else ''))

        self._is_properly_setup()
Exemplo n.º 10
0
    def __init__(self,
                 n_in=1,
                 n_out=1,
                 hidden_layers=(10, 10),
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 no_weights=False,
                 init_weights=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 use_context_mod=False,
                 context_mod_inputs=False,
                 no_last_layer_context_mod=False,
                 context_mod_no_weights=False,
                 context_mod_post_activation=False,
                 context_mod_gain_offset=False,
                 context_mod_gain_softplus=False,
                 out_fn=None,
                 verbose=True):
        """Set up parameter shapes (and, optionally, parameters) of an MLP.

        Parameter shapes are registered in the following order: context-mod
        layers, batchnorm layers, linear layers.

        Args:
            n_in (int): Number of inputs.
            n_out (int): Number of outputs.
            hidden_layers (list or tuple): Size of each hidden layer.
            activation_fn: Activation function, stored for later use.
            use_bias (bool): Whether linear layers have bias vectors.
            no_weights (bool): If ``True``, no batchnorm or linear parameters
                are created; their shapes are added to the learned hyper
                shapes instead.
            init_weights (list, optional): Initial values. Entries are
                consumed in the order: internally maintained context-mod
                weights, internally maintained batchnorm weights, linear
                weights.
            dropout_rate (float): ``-1`` disables dropout. Otherwise, a
                dropout probability in ``[0, 1]``.
            use_spectral_norm (bool): Not implemented; must be ``False``.
            use_batch_norm (bool): If ``True``, a batchnorm layer is created
                for every hidden layer.
            bn_track_stats (bool): Passed as ``track_running_stats`` to the
                batchnorm layers.
            distill_bn_stats (bool): Only has an effect if ``use_batch_norm``.
                If ``True``, the shapes of the batchnorm statistics are
                collected in ``_hyper_shapes_distilled``.
            use_context_mod (bool): If ``True``, context-mod layers are
                created for the hidden layers (and for in- and outputs,
                depending on the next two options).
            context_mod_inputs (bool): Add a context-mod layer for the inputs.
            no_last_layer_context_mod (bool): Do not add a context-mod layer
                for the outputs.
            context_mod_no_weights (bool): If ``True``, context-mod layers do
                not maintain their own weights.
            context_mod_post_activation (bool): Swaps the layer indices that
                are assigned to context-mod and batchnorm parameters in
                ``param_shapes_meta``.
            context_mod_gain_offset (bool): Passed as ``apply_gain_offset`` to
                the context-mod layers.
            context_mod_gain_softplus (bool): Passed as
                ``apply_gain_softplus`` to the context-mod layers.
            out_fn (optional): Output function, stored for later use. The
                network is flagged as having a linear output iff ``None``.
            verbose (bool): Whether a summary should be printed.

        Raises:
            NotImplementedError: If ``use_spectral_norm`` is ``True``.
        """
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        # FIXME Spectral norm is incorrectly implemented. Function
        # `nn.utils.spectral_norm` needs to be called in the constructor, such
        # that sepc norm is wrapped around a module.
        if use_spectral_norm:
            raise NotImplementedError('Spectral normalization not yet ' +
                                      'implemented for this network.')

        if use_batch_norm and use_context_mod:
            # FIXME Does it make sense to have both enabled?
            # I.e., should we produce a warning or error?
            pass

        self._a_fun = activation_fn
        # Initial weights can only be used if at least some weights are
        # maintained internally.
        assert(init_weights is None or \
               (not no_weights or not context_mod_no_weights))
        self._no_weights = no_weights
        self._dropout_rate = dropout_rate
        #self._use_spectral_norm = use_spectral_norm
        self._use_batch_norm = use_batch_norm
        self._bn_track_stats = bn_track_stats
        self._distill_bn_stats = distill_bn_stats and use_batch_norm
        self._use_context_mod = use_context_mod
        self._context_mod_inputs = context_mod_inputs
        self._no_last_layer_context_mod = no_last_layer_context_mod
        self._context_mod_no_weights = context_mod_no_weights
        self._context_mod_post_activation = context_mod_post_activation
        self._context_mod_gain_offset = context_mod_gain_offset
        self._context_mod_gain_softplus = context_mod_gain_softplus
        self._out_fn = out_fn

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True if out_fn is None else False

        # Note, the spectral norm branches below are currently unreachable,
        # as `use_spectral_norm` causes an exception above.
        if use_spectral_norm and no_weights:
            raise ValueError('Cannot use spectral norm in a network without ' +
                             'parameters.')

        # FIXME make sure that this implementation is correct in all situations
        # (e.g., what to do if weights are passed to the forward method?).
        if use_spectral_norm:
            self._spec_norm = nn.utils.spectral_norm
        else:
            self._spec_norm = lambda x: x  # identity

        self._param_shapes = []
        self._param_shapes_meta = []
        self._weights = None if no_weights and context_mod_no_weights \
            else nn.ParameterList()
        self._hyper_shapes_learned = None \
            if not no_weights and not context_mod_no_weights else []
        self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
            is None else []

        if dropout_rate != -1:
            assert (dropout_rate >= 0. and dropout_rate <= 1.)
            self._dropout = nn.Dropout(p=dropout_rate)

        ### Define and initialize context mod weights.
        self._context_mod_layers = nn.ModuleList() if use_context_mod else None
        self._context_mod_shapes = [] if use_context_mod else None

        if use_context_mod:
            # Number of entries from `init_weights` consumed so far.
            cm_ind = 0
            cm_sizes = []
            if context_mod_inputs:
                cm_sizes.append(n_in)
            cm_sizes.extend(hidden_layers)
            if not no_last_layer_context_mod:
                cm_sizes.append(n_out)

            for n in cm_sizes:
                cmod_layer = ContextModLayer(
                    n,
                    no_weights=context_mod_no_weights,
                    apply_gain_offset=context_mod_gain_offset,
                    apply_gain_softplus=context_mod_gain_softplus)
                self._context_mod_layers.append(cmod_layer)

                self._param_shapes.extend(cmod_layer.param_shapes)
                assert len(cmod_layer.param_shapes) == 2
                self._param_shapes_meta.extend([
                    {'name': 'cm_scale',
                     'index': -1 if context_mod_no_weights else \
                         len(self._weights),
                     'layer': -1}, # 'layer' is set later.
                    {'name': 'cm_shift',
                     'index': -1 if context_mod_no_weights else \
                         len(self._weights)+1,
                     'layer': -1}, # 'layer' is set later.
                ])

                self._context_mod_shapes.extend(cmod_layer.param_shapes)
                if context_mod_no_weights:
                    self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
                else:
                    self._weights.extend(cmod_layer.weights)

                # FIXME ugly code. Move initialization somewhere else.
                if not context_mod_no_weights and init_weights is not None:
                    assert (len(cmod_layer.weights) == 2)
                    for ii in range(2):
                        # Shapes are compared against the weights of the
                        # layer (`cm_ind` is just an integer index into
                        # `init_weights`).
                        assert(np.all(np.equal( \
                                list(init_weights[cm_ind].shape),
                                list(cmod_layer.weights[ii].shape))))
                        cmod_layer.weights[ii].data = init_weights[cm_ind]
                        cm_ind += 1

            # Drop the entries that have been consumed by context-mod layers.
            if init_weights is not None:
                init_weights = init_weights[cm_ind:]

        if context_mod_no_weights:
            self._hyper_shapes_learned_ref = \
                list(range(len(self._param_shapes)))

        ### Define and initialize batch norm weights.
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        if use_batch_norm:
            if distill_bn_stats:
                self._hyper_shapes_distilled = []

            # Number of entries from `init_weights` consumed so far.
            bn_ind = 0
            for n in hidden_layers:
                bn_layer = BatchNormLayer(n,
                                          affine=not no_weights,
                                          track_running_stats=bn_track_stats)
                self._batchnorm_layers.append(bn_layer)

                self._param_shapes.extend(bn_layer.param_shapes)
                assert len(bn_layer.param_shapes) == 2
                self._param_shapes_meta.extend([
                    {
                        'name': 'bn_scale',
                        'index': -1 if no_weights else len(self._weights),
                        'layer': -1
                    },  # 'layer' is set later.
                    {
                        'name': 'bn_shift',
                        'index': -1 if no_weights else len(self._weights) + 1,
                        'layer': -1
                    },  # 'layer' is set later.
                ])

                if no_weights:
                    self._hyper_shapes_learned.extend(bn_layer.param_shapes)
                else:
                    self._weights.extend(bn_layer.weights)

                if distill_bn_stats:
                    self._hyper_shapes_distilled.extend( \
                        [list(p.shape) for p in bn_layer.get_stats(0)])

                # FIXME ugly code. Move initialization somewhere else.
                if not no_weights and init_weights is not None:
                    assert (len(bn_layer.weights) == 2)
                    for ii in range(2):
                        assert(np.all(np.equal( \
                                list(init_weights[bn_ind].shape),
                                list(bn_layer.weights[ii].shape))))
                        bn_layer.weights[ii].data = init_weights[bn_ind]
                        bn_ind += 1

            # Drop the entries that have been consumed by batchnorm layers.
            if init_weights is not None:
                init_weights = init_weights[bn_ind:]

        ### Compute shapes of linear layers.
        linear_shapes = MLP.weight_shapes(n_in=n_in,
                                          n_out=n_out,
                                          hidden_layers=hidden_layers,
                                          use_bias=use_bias)
        self._param_shapes.extend(linear_shapes)

        for i, s in enumerate(linear_shapes):
            self._param_shapes_meta.append({
                'name':
                'weight' if len(s) != 1 else 'bias',
                # The linear parameters are appended to `_weights` further
                # below, in the same order as `linear_shapes`.
                'index':
                -1 if no_weights else len(self._weights) + i,
                'layer':
                -1  # 'layer' is set later.
            })

        num_weights = MainNetInterface.shapes_to_num_weights(
            self._param_shapes)

        ### Set missing meta information of param_shapes.
        # Layer zero is reserved for input context-mod.
        offset = 1 if use_context_mod and context_mod_inputs else 0
        # Number of layer indices per linear layer (the linear layer itself
        # plus optional batchnorm and context-mod layers).
        shift = 1
        if use_batch_norm:
            shift += 1
        if use_context_mod:
            shift += 1

        cm_offset = 2 if context_mod_post_activation else 1
        bn_offset = 1 if context_mod_post_activation else 2

        cm_ind = 0
        bn_ind = 0
        layer_ind = 0

        for i, dd in enumerate(self._param_shapes_meta):
            if dd['name'].startswith('cm'):
                # Context-mod shapes come first, i.e., the first two entries
                # belong to the input context-mod layer (if existing).
                if offset == 1 and i in [0, 1]:
                    dd['layer'] = 0
                else:
                    if cm_ind < len(hidden_layers):
                        dd['layer'] = offset + cm_ind * shift + cm_offset
                    else:
                        assert cm_ind == len(hidden_layers) and \
                            not no_last_layer_context_mod
                        # No batchnorm in output layer.
                        dd['layer'] = offset + cm_ind * shift + 1

                    if dd['name'] == 'cm_shift':
                        cm_ind += 1

            elif dd['name'].startswith('bn'):
                dd['layer'] = offset + bn_ind * shift + bn_offset
                if dd['name'] == 'bn_shift':
                    bn_ind += 1

            else:
                dd['layer'] = offset + layer_ind * shift
                if not use_bias or dd['name'] == 'bias':
                    layer_ind += 1

        ### Uer information
        if verbose:
            if use_context_mod:
                cm_num_weights = 0
                for cm_layer in self._context_mod_layers:
                    cm_num_weights += MainNetInterface.shapes_to_num_weights( \
                        cm_layer.param_shapes)

            print(
                'Creating an MLP with %d weights' % num_weights +
                (' (including %d weights associated with-' % cm_num_weights +
                 'context modulation)' if use_context_mod else '') + '.' +
                (' The network uses dropout.' if dropout_rate != -1 else '') +
                (' The network uses batchnorm.' if use_batch_norm else ''))

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        if no_weights:
            self._hyper_shapes_learned.extend(linear_shapes)

            # All batchnorm and linear shapes have to be learned. Context-mod
            # shapes (which come first in `param_shapes`) only have to be
            # learned if the context-mod layers maintain no weights. Note,
            # the reference indices are also required if no context-mod is
            # used at all.
            if use_context_mod and not context_mod_no_weights:
                first_learned = len(self._context_mod_shapes)
            else:
                first_learned = 0
            self._hyper_shapes_learned_ref = \
                list(range(first_learned, len(self._param_shapes)))

            self._is_properly_setup()
            return

        ### Define and initialize linear weights.
        for dims in linear_shapes:
            self._weights.append(
                nn.Parameter(torch.Tensor(*dims), requires_grad=True))
            if len(dims) == 1:
                self._layer_bias_vectors.append(self._weights[-1])
            else:
                self._layer_weight_tensors.append(self._weights[-1])

        if init_weights is not None:
            assert (len(init_weights) == len(linear_shapes))
            for i in range(len(init_weights)):
                assert (np.all(
                    np.equal(list(init_weights[i].shape), linear_shapes[i])))
                if use_bias:
                    # Weight matrices and bias vectors alternate.
                    if i % 2 == 0:
                        self._layer_weight_tensors[i //
                                                   2].data = init_weights[i]
                    else:
                        self._layer_bias_vectors[i // 2].data = init_weights[i]
                else:
                    self._layer_weight_tensors[i].data = init_weights[i]
        else:
            for i in range(len(self._layer_weight_tensors)):
                if use_bias:
                    init_params(self._layer_weight_tensors[i],
                                self._layer_bias_vectors[i])
                else:
                    init_params(self._layer_weight_tensors[i])

        if self._num_context_mod_shapes() == 0:
            # Note, that might be the case if no hidden layers exist and no
            # input or output modulation is used.
            self._use_context_mod = False

        self._is_properly_setup()
    def __init__(self,
                 n_in,
                 n_out=1,
                 inp_chunk_dim=100,
                 out_chunk_dim=10,
                 cemb_size=8,
                 cemb_init_std=1.,
                 red_layers=(10, 10),
                 net_layers=(10, 10),
                 activation_fn=torch.nn.ReLU(),
                 use_bias=True,
                 dynamic_biases=None,
                 no_weights=False,
                 init_weights=None,
                 dropout_rate=-1,
                 use_spectral_norm=False,
                 use_batch_norm=False,
                 bn_track_stats=True,
                 distill_bn_stats=False,
                 verbose=True):
        """Set up an MLP that processes its input in chunks.

        The constructor creates one embedding per input chunk, a `reducer`
        MLP (input size ``inp_chunk_dim + cemb_size``, output size
        ``out_chunk_dim``) and the actual `network` MLP (input size
        ``out_chunk_dim`` times the number of chunks, output size ``n_out``).
        Parameter shapes and parameters of both MLPs are taken over into the
        attributes of this instance.

        Args:
            n_in (int): Input dimensionality. The number of chunks is
                ``ceil(n_in / inp_chunk_dim)``.
            n_out (int): Output dimensionality of the `network`.
            inp_chunk_dim (int): Size of an input chunk.
            out_chunk_dim (int): Output size of the `reducer`.
            cemb_size (int): Size of a chunk embedding.
            cemb_init_std (float): Standard deviation of the zero-mean normal
                distribution used to initialize chunk embeddings.
            red_layers (list or tuple): Hidden layers of the `reducer`.
            net_layers (list or tuple): Hidden layers of the `network`.
            activation_fn: Activation function passed to both MLPs.
            use_bias (bool): Passed to both MLPs.
            dynamic_biases (list, optional): Indices of `reducer` layers that
                should receive a dynamic bias. Not yet implemented; any value
                other than ``None`` causes an exception.
            no_weights (bool): If ``True``, no parameters are created; all
                shapes are added to the learned hyper shapes instead.
            init_weights (list, optional): Initial values for all entries of
                ``_weights`` (same order and shapes).
            dropout_rate (float): Passed to both MLPs.
            use_spectral_norm (bool): Passed to both MLPs.
            use_batch_norm (bool): Passed to both MLPs.
            bn_track_stats (bool): Passed to both MLPs.
            distill_bn_stats (bool): Passed to both MLPs.
            verbose (bool): Whether this constructor and the constructors of
                the underlying MLPs should print a summary.

        Raises:
            NotImplementedError: If ``dynamic_biases`` is not ``None``.
        """
        # FIXME find a way using super to handle multiple inheritance.
        nn.Module.__init__(self)
        MainNetInterface.__init__(self)

        self._n_in = n_in
        self._n_out = n_out
        self._inp_chunk_dim = inp_chunk_dim
        self._out_chunk_dim = out_chunk_dim
        self._cemb_size = cemb_size
        self._a_fun = activation_fn
        self._no_weights = no_weights

        self._has_bias = use_bias
        self._has_fc_out = True
        # We need to make sure that the last 2 entries of `weights` correspond
        # to the weight matrix and bias vector of the last layer.
        self._mask_fc_out = True
        self._has_linear_out = True  # Ensure that `out_fn` is `None`!

        self._param_shapes = []
        #self._param_shapes_meta = [] # TODO implement!
        self._weights = None if no_weights else nn.ParameterList()
        self._hyper_shapes_learned = None if not no_weights else []
        #self._hyper_shapes_learned_ref = None if self._hyper_shapes_learned \
        #    is None else [] # TODO implement.

        self._layer_weight_tensors = nn.ParameterList()
        self._layer_bias_vectors = nn.ParameterList()

        self._context_mod_layers = None
        self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None

        #################################
        ### Generate Chunk Embeddings ###
        #################################
        self._num_cembs = int(np.ceil(n_in / inp_chunk_dim))
        # Number of elements that are missing to fill the last chunk (or -1
        # if the input can be split evenly).
        last_chunk_size = n_in % inp_chunk_dim
        if last_chunk_size != 0:
            self._pad = inp_chunk_dim - last_chunk_size
        else:
            self._pad = -1

        cemb_shape = [self._num_cembs, cemb_size]
        self._param_shapes.append(cemb_shape)
        if no_weights:
            self._cembs = None
            self._hyper_shapes_learned.append(cemb_shape)
        else:
            self._cembs = nn.Parameter(data=torch.Tensor(*cemb_shape),
                                       requires_grad=True)
            nn.init.normal_(self._cembs, mean=0., std=cemb_init_std)

            self._weights.append(self._cembs)

        ############################
        ### Setup Dynamic Biases ###
        ############################
        self._has_dyn_bias = None
        if dynamic_biases is not None:
            assert np.all(np.array(dynamic_biases) >= 0) and \
                   np.all(np.array(dynamic_biases) < len(red_layers) + 1)
            dynamic_biases = np.sort(np.unique(dynamic_biases))

            # For each layer in the `reducer`, where we want to apply a dynamic
            # bias, we have to create a weight matrix for a corresponding
            # linear layer (we just ignore)
            # Note, the entries are parameters (or `None`) and not modules,
            # which is why a `ParameterList` has to be used.
            self._dyn_bias_weights = nn.ParameterList()
            self._has_dyn_bias = []

            for i in range(len(red_layers) + 1):
                if i in dynamic_biases:
                    self._has_dyn_bias.append(True)

                    trgt_dim = out_chunk_dim
                    if i < len(red_layers):
                        trgt_dim = red_layers[i]
                    trgt_shape = [trgt_dim, cemb_size]

                    self._param_shapes.append(trgt_shape)
                    if no_weights:
                        # Without internal weights, only the shape is
                        # registered (`_weights` is `None` in this case).
                        self._dyn_bias_weights.append(None)
                        self._hyper_shapes_learned.append(trgt_shape)
                    else:
                        self._dyn_bias_weights.append(nn.Parameter( \
                            torch.Tensor(*trgt_shape), requires_grad=True))
                        self._weights.append(self._dyn_bias_weights[-1])

                        init_params(self._dyn_bias_weights[-1])

                        self._layer_weight_tensors.append( \
                            self._dyn_bias_weights[-1])
                        self._layer_bias_vectors.append(None)
                else:
                    self._has_dyn_bias.append(False)
                    self._dyn_bias_weights.append(None)

        ################################
        ### Create `Reducer` Network ###
        ################################
        # Without dynamic biases, the chunk embedding is part of the input.
        red_inp_dim = inp_chunk_dim + \
            (cemb_size if dynamic_biases is None else 0)
        self._reducer = MLP(
            n_in=red_inp_dim,
            n_out=out_chunk_dim,
            hidden_layers=red_layers,
            activation_fn=activation_fn,
            use_bias=use_bias,
            no_weights=no_weights,
            init_weights=None,
            dropout_rate=dropout_rate,
            use_spectral_norm=use_spectral_norm,
            use_batch_norm=use_batch_norm,
            bn_track_stats=bn_track_stats,
            distill_bn_stats=distill_bn_stats,
            # We use context modulation to realize dynamic biases, since they
            # allow a different modulation per sample in the input mini-batch.
            # Hence, we can process several chunks in parallel with the reducer
            # network.
            use_context_mod=dynamic_biases is not None,
            context_mod_inputs=False,
            no_last_layer_context_mod=False,
            context_mod_no_weights=True,
            context_mod_post_activation=False,
            context_mod_gain_offset=False,
            context_mod_gain_softplus=False,
            out_fn=None,
            verbose=verbose)

        if dynamic_biases is not None:
            # FIXME We have to extract the param shapes from
            # `self._reducer.param_shapes`, as well as from
            # `self._reducer._hyper_shapes_learned` that belong to context-mod
            # layers. We may not add them to our own `param_shapes` attribute,
            # as these are not parameters (due to our misuse of the context-mod
            # layers).
            # Note, in the `forward` method, we need to supply context-mod
            # weights for all reducer networks, independent on whether they have
            # a dynamic bias or not. We can do so, by providing constant ones
            # for all gains and constance zero-shift for all layers without
            # dynamic biases (note, we need to ensure the correct batch dim!).
            raise NotImplementedError(
                'Dynamic biases are not yet implemented!')

        assert self._reducer._context_mod_layers is None

        ### Overtake all attributes from the underlying MLP.
        self._param_shapes.extend(self._reducer.param_shapes)
        if no_weights:
            self._hyper_shapes_learned.extend(
                self._reducer._hyper_shapes_learned)
        else:
            for p in self._reducer._weights:
                self._weights.append(p)

        for p in self._reducer._layer_weight_tensors:
            self._layer_weight_tensors.append(p)
        for p in self._reducer._layer_bias_vectors:
            self._layer_bias_vectors.append(p)

        if use_batch_norm:
            for p in self._reducer._batchnorm_layers:
                self._batchnorm_layers.append(p)

        if self._reducer._hyper_shapes_distilled is not None:
            self._hyper_shapes_distilled = \
                list(self._reducer._hyper_shapes_distilled)

        ###############################
        ### Create Actual `Network` ###
        ###############################
        # The `network` receives the concatenated outputs of all chunks.
        net_inp_dim = out_chunk_dim * self._num_cembs
        self._network = MLP(n_in=net_inp_dim,
                            n_out=n_out,
                            hidden_layers=net_layers,
                            activation_fn=activation_fn,
                            use_bias=use_bias,
                            no_weights=no_weights,
                            init_weights=None,
                            dropout_rate=dropout_rate,
                            use_spectral_norm=use_spectral_norm,
                            use_batch_norm=use_batch_norm,
                            bn_track_stats=bn_track_stats,
                            distill_bn_stats=distill_bn_stats,
                            use_context_mod=False,
                            out_fn=None,
                            verbose=verbose)

        ### Overtake all attributes from the underlying MLP.
        self._param_shapes.extend(self._network.param_shapes)
        if no_weights:
            self._hyper_shapes_learned.extend(
                self._network._hyper_shapes_learned)
        else:
            for p in self._network._weights:
                self._weights.append(p)

        for p in self._network._layer_weight_tensors:
            self._layer_weight_tensors.append(p)
        for p in self._network._layer_bias_vectors:
            self._layer_bias_vectors.append(p)

        if use_batch_norm:
            for p in self._network._batchnorm_layers:
                self._batchnorm_layers.append(p)

        if self._hyper_shapes_distilled is not None:
            assert self._network._hyper_shapes_distilled is not None
            self._hyper_shapes_distilled.extend(
                self._network._hyper_shapes_distilled)

        #####################################
        ### Takeover given Initialization ###
        #####################################
        if init_weights is not None:
            assert len(init_weights) == len(self._weights)
            for i in range(len(init_weights)):
                assert np.all(
                    np.equal(list(init_weights[i].shape),
                             self._param_shapes[i]))
                self._weights[i].data = init_weights[i]

        ######################
        ### Finalize Setup ###
        ######################
        if verbose:
            num_weights = \
                MainNetInterface.shapes_to_num_weights(self.param_shapes)
            print('Constructed MLP that processes dimensionality reduced ' +
                  'inputs through chunking. The network has a total of %d ' \
                  % num_weights + 'weights.')

        self._is_properly_setup()
Exemplo n.º 12
0
    def _add_fc_layers(self, in_sizes, out_sizes, no_weights, fc_layers=None):
        """Add fully-connected layers to the network.

        This method will set the weight requirements for fully-connected layers
        correctly. During the :meth:`forward` computation, those weights can be
        used in combination with :func:`torch.nn.functional.linear`.

        Note:
            This method should only be called inside the constructor of any
            class that implements this interface.

        Note:
            Bias weights are handled based on attribute :attr:`has_bias`.

        Note:
            This method will assumes attributes :attr:`param_shapes_meta` and
            :attr:`hyper_shapes_learned_ref` exist already.

        Note:
            Generated weights will be automatically added to attributes
            :attr:`layer_bias_vectors` and :attr:`layer_weight_tensors`.

        Note:
            Standard initialization will be applied to created weights.

        Args:
            in_sizes (list): List of intergers denoting the input size of each
                added fc-layer.
            out_sizes (list): List of intergers denoting the output size of each
                added fc-layer.
            no_weights (bool): If ``True``, fc-layers will be generated without
                internal parameters :attr:`internal_params`.
            fc_layers (list, optional): See attribute ``cm_layers`` of method
                :meth:`_add_context_mod_layers`.
        """
        assert len(in_sizes) == len(out_sizes)
        assert fc_layers is None or len(fc_layers) == len(in_sizes)
        assert self._param_shapes_meta is not None
        assert self._hyper_shapes_learned_ref is not None

        if self._layer_weight_tensors is None:
            self._layer_weight_tensors = torch.nn.ParameterList()
        if self._layer_bias_vectors is None:
            self._layer_bias_vectors = torch.nn.ParameterList()

        for i, n_in in enumerate(in_sizes):
            n_out = out_sizes[i]

            s_w = [n_out, n_in]
            s_b = [n_out] if self.has_bias else None

            for j, s in enumerate([s_w, s_b]):
                if s is None:
                    continue

                is_bias = True
                if j % 2 == 0:
                    is_bias = False

                if not no_weights:
                    self._internal_params.append(torch.nn.Parameter( \
                        torch.Tensor(*s), requires_grad=True))
                    if is_bias:
                        self._layer_bias_vectors.append( \
                            self._internal_params[-1])
                    else:
                        self._layer_weight_tensors.append( \
                            self._internal_params[-1])
                else:
                    self._hyper_shapes_learned.append(s)
                    self._hyper_shapes_learned_ref.append( \
                        len(self.param_shapes))

                self._param_shapes.append(s)
                self._param_shapes_meta.append({
                    'name': 'bias' if is_bias else 'weight',
                    'index': -1 if no_weights else len(self._internal_params)-1,
                    'layer': -1 if fc_layers is None else fc_layers[i]
                })

            if not no_weights:
                init_params(self._layer_weight_tensors[-1],
                    self._layer_bias_vectors[-1] if self.has_bias else None)