示例#1
0
class controller(nn.Module):
    """LSTM controller whose starting hidden and cell states are learned.

    The LSTM hidden state doubles as the controller output, so
    ``num_outputs`` is used as the LSTM hidden size.
    """

    def __init__(self, num_inputs, num_outputs, num_layers):
        super().__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_layers = num_layers

        self.lstm_network = nn.LSTM(
            input_size=self.num_inputs,
            hidden_size=self.num_outputs,
            num_layers=self.num_layers,
        )

        # Learned initial states, stored for a batch of one and tiled on demand.
        state_shape = (self.num_layers, 1, self.num_outputs)
        self.h_init = Parameter(torch.randn(*state_shape) * 0.05)
        self.c_init = Parameter(torch.randn(*state_shape) * 0.05)

        # 1-D parameters (biases) are zeroed; every other parameter is drawn
        # uniformly from [-bound, bound].
        bound = 5 / (np.sqrt(self.num_inputs + self.num_outputs))
        for weight in self.lstm_network.parameters():
            if weight.dim() == 1:
                nn.init.constant_(weight, 0)
            else:
                nn.init.uniform_(weight, -bound, bound)

    def create_hidden_state(self, batch_size):
        """Return ``(h, c)``, each of shape (num_layers, batch_size, num_outputs)."""
        tiling = (1, batch_size, 1)
        hidden = self.h_init.clone().repeat(*tiling)
        cell = self.c_init.clone().repeat(*tiling)
        return hidden, cell

    def network_size(self):
        """Return ``(num_inputs, num_outputs)``."""
        return self.num_inputs, self.num_outputs

    def forward(self, inp, prev_state):
        """Run one step: a leading length-1 axis is added to ``inp`` and removed from the output."""
        output, state = self.lstm_network(inp.unsqueeze(0), prev_state)
        return output.squeeze(0), state
示例#2
0
class Controller(nn.Module):
    """Multi-layer LSTM controller with a learned initial ``(h, c)`` state."""

    def __init__(self, input_size, output_size, num_layers):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=output_size,
            num_layers=num_layers,
        )

        # LSTM weights are re-initialised before the state biases are drawn.
        self.reset()

        bias_shape = (self.num_layers, 1, self.output_size)
        self.lstm_h_bias = Parameter(torch.randn(*bias_shape) * 0.05)
        self.lstm_c_bias = Parameter(torch.randn(*bias_shape) * 0.05)

    def new_init_state(self, batch_size):
        """Tile the learned biases to shape (num_layers, batch_size, output_size)."""
        tiling = (1, batch_size, 1)
        return (self.lstm_h_bias.clone().repeat(*tiling),
                self.lstm_c_bias.clone().repeat(*tiling))

    def reset(self):
        """Zero the 1-D LSTM parameters and draw the others uniformly in [-limit, limit]."""
        limit = 5 / (np.sqrt(self.input_size + self.output_size))
        for param in self.lstm.parameters():
            if param.dim() != 1:
                nn.init.uniform_(param, -limit, limit)
                continue
            nn.init.constant_(param, 0)

    def forward(self, x, prev_state):
        """Run one LSTM step on ``x`` (a leading length-1 axis is added and removed)."""
        step_input = x.unsqueeze(0)
        out, state = self.lstm(step_input, prev_state)
        return out.squeeze(0), state
class LSTMController(nn.Module):
    """LSTM controller with a learned initial state and a sigmoid output head.

    Args:
        input_size: Feature size of each LSTM input.
        output_size: LSTM hidden size (also stored as ``self.num_outputs``).
        controller_size: Number of stacked LSTM layers.
        read_data_size: Accepted for interface compatibility; not used here.
        num_outputs: Output feature size of ``output_layer``.
        outp_layer_size: Input feature size of ``output_layer``.
    """

    def __init__(self, input_size, output_size, controller_size,
                 read_data_size, num_outputs, outp_layer_size):
        super(LSTMController, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.num_layers = controller_size
        # NOTE: this is the LSTM hidden size, not the ``num_outputs`` argument.
        self.num_outputs = output_size
        # The LSTM depth must match the leading dimension of the learned
        # initial state below; otherwise the state returned by
        # ``create_new_state`` is rejected by the LSTM when
        # ``controller_size != 1``.
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=output_size,
                            num_layers=self.num_layers)
        self.output_layer = nn.Linear(outp_layer_size, num_outputs)
        self.lstm_h_bias = Parameter(
            torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)
        self.lstm_c_bias = Parameter(
            torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)
        # Last controller output; overwritten on every forward call.
        self.h_state = torch.zeros([1, output_size])

    def forward(self, inputs, hidden=None):
        """Run one LSTM step.

        A leading length-1 axis is added to ``inputs`` and removed from the
        output. Returns ``(output, state)`` and caches the output in
        ``self.h_state``. With ``hidden=None`` the LSTM uses its default
        (zero) initial state.
        """
        inputs = inputs.unsqueeze(0)
        outp, state = self.lstm(inputs, hidden)
        self.h_state = outp.squeeze(0)
        return self.h_state, state

    def output(self, read_data):
        """Apply ``output_layer`` followed by a sigmoid to ``read_data``."""
        end_state = read_data
        output = torch.nn.functional.sigmoid(self.output_layer(end_state))
        return output

    def size(self):
        """Return ``(input_size, output_size)``."""
        return self.input_size, self.output_size

    def create_new_state(self, batch_size):
        """Tile the learned biases to shape (num_layers, batch_size, hidden_size)."""
        # Dimension: (num_layers * num_directions, batch, hidden_size)
        lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c
示例#4
0
class Controller(StatefulComponent):
    """Stateful LSTM-cell controller with an optional embedding layer.

    Args:
        embedding_size: Input feature size of the LSTM cell.
        hidden_size: Hidden size of the LSTM cell.
        dictionary_size: If truthy, an ``nn.Embedding`` of this vocabulary
            size is created and applied to the input in ``forward``.
    """

    def __init__(self, embedding_size, hidden_size, dictionary_size=None):
        # Configurations
        super().__init__()
        self.dictionary_size = dictionary_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size

        # Embedding layer (optional)
        self.embedding = nn.Embedding(
            self.dictionary_size,
            self.embedding_size) if dictionary_size else None

        # LSTM cell to extract features from input
        self.cell = nn.LSTMCell(self.embedding_size, self.hidden_size)

        # Learnable LSTM hidden state biases
        self.h_bias = Parameter(Tensor(self.hidden_size).normal_())
        self.c_bias = Parameter(Tensor(self.hidden_size).normal_())

        # States (populated by ``reset``)
        self.h = None
        self.c = None

    def forward(self, x, device=None):
        """Advance the LSTM cell by one step and return the new hidden state.

        Args:
            x: Input for this step, or ``None`` to feed zeros in the
                embedding space.
            device: Only used when ``x`` is ``None``: the zero input is moved
                there. When omitted, the zero input stays where
                ``type_as(self.h)`` placed it.

        ``reset`` must have been called first so that ``self.h``/``self.c``
        hold tensors. ``self.expected_batch_size`` is not defined in this
        class (presumably maintained by ``StatefulComponent`` — confirm).
        """
        if x is None:
            # Supply the LSTM cell with zeros in the embedding space (no
            # input). The device is resolved here rather than from
            # ``x.device``, which does not exist when ``x`` is None.
            e = torch.zeros(
                self.expected_batch_size,
                self.embedding_size,
            ).type_as(self.h.data)
            if device is not None:
                e = e.to(device)
        elif self.embedding is not None:
            # Supply the LSTM cell with an embedded input.
            e = self.embedding(x)
        else:
            # Supply the LSTM cell with the input as-is (assuming that the
            # input is already in the embedding space).
            assert x.size() == (
                self.expected_batch_size, self.embedding_size
            ), 'Input should have size of {b}x{e}, while given {s}.'.format(
                b=self.expected_batch_size,
                e=self.embedding_size,
                s=x.size(),
            )
            e = x

        # Run the LSTM cell and update the states.
        self.h, self.c = self.cell(e, (self.h, self.c))
        return self.h

    def reset(self, batch_size, device=None):
        """Reset both states to the learned biases tiled to ``batch_size`` rows."""
        super().reset(batch_size)
        self.h = self.h_bias.clone().repeat(batch_size, 1)
        self.c = self.c_bias.clone().repeat(batch_size, 1)
        if device is not None:
            self.h = self.h.to(device)
            self.c = self.c.to(device)
示例#5
0
class LSTMBaseline(nn.Module):
    """LSTM followed by a linear read-out, keeping its recurrent state internally."""

    def __init__(self, num_inputs, num_hidden, num_outputs, num_layers):
        super().__init__()

        self.num_inputs = num_inputs
        self.num_hidden = num_hidden
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size=num_inputs,
            hidden_size=num_hidden,
            num_layers=num_layers,
        )
        self.out = nn.Linear(num_hidden, num_outputs)

        def learned_state():
            # Drawn on the CPU, moved to the GPU when one is available,
            # then scaled.
            noise = torch.randn(self.num_layers, 1, self.num_hidden)
            if torch.cuda.is_available():
                noise = noise.cuda()
            return Parameter(noise * 0.05)

        # The initial hidden and cell states are learned parameters.
        self.lstm_h_bias = learned_state()
        self.lstm_c_bias = learned_state()

        self.reset_parameters()

    def create_new_state(self, batch_size):
        """Tile the learned biases to shape (num_layers, batch_size, num_hidden)."""
        tiling = (1, batch_size, 1)
        return (self.lstm_h_bias.clone().repeat(*tiling),
                self.lstm_c_bias.clone().repeat(*tiling))

    def reset_parameters(self):
        """Zero the 1-D LSTM parameters and draw the others uniformly in [-limit, limit]."""
        limit = 5 / (np.sqrt(self.num_inputs + self.num_hidden))
        for param in self.lstm.parameters():
            if param.dim() == 1:
                nn.init.constant_(param, 0)
            else:
                nn.init.uniform_(param, -limit, limit)

    def init_sequence(self, batch_size):
        """Start a new sequence by resetting the stored state to the learned biases."""
        self.previous_state = self.create_new_state(batch_size)

    def size(self):
        """Return ``(num_inputs, num_hidden)``."""
        return self.num_inputs, self.num_hidden

    def forward(self, x):
        """Run one step from the stored state; ``init_sequence`` must be called beforehand."""
        hidden_seq, self.previous_state = self.lstm(
            x.unsqueeze(0), self.previous_state)
        projected = self.out(hidden_seq)
        return projected.squeeze(0), self.previous_state

    def calculate_num_params(self):
        """Return the total number of parameter elements."""
        return sum(p.data.view(-1).size(0) for p in self.parameters())
class backward_controller(nn.Module):
    """Backward-direction LSTM: consumes the input sequence in reversed time order."""

    def __init__(self, num_inputs, num_outputs, num_layers):
        super().__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_layers = num_layers

        self.lstm_network = nn.LSTM(
            input_size=self.num_inputs,
            hidden_size=self.num_outputs,
            num_layers=self.num_layers,
        )

        # Learned initial hidden/cell states; the hidden state serves as the
        # output of the network.
        state_shape = (self.num_layers, 1, self.num_outputs)
        self.h_init = Parameter(torch.randn(*state_shape) * 0.05)
        self.c_init = Parameter(torch.randn(*state_shape) * 0.05)

        # 1-D parameters are zeroed, the rest are drawn uniformly in
        # [-bound, bound].
        bound = 5 / (np.sqrt(self.num_inputs + self.num_outputs))
        for weight in self.lstm_network.parameters():
            if weight.dim() == 1:
                nn.init.constant_(weight, 0)
            else:
                nn.init.uniform_(weight, -bound, bound)

    def create_hidden_state(self, batch_size):
        """Return ``(h, c)``, each of shape (num_layers, batch_size, num_outputs)."""
        tiling = (1, batch_size, 1)
        hidden = self.h_init.clone().repeat(*tiling)
        cell = self.c_init.clone().repeat(*tiling)
        return hidden, cell

    def network_size(self):
        """Return ``(num_inputs, num_outputs)``."""
        return self.num_inputs, self.num_outputs

    def forward(self, inp, prev_states):
        """Run the LSTM over ``inp`` with its first axis reversed.

        ``inp`` is indexed as a 3-D tensor (seq_len x batch_size x
        input_size). The returned output is left in reversed order; it is not
        flipped back.
        """
        last = inp.shape[0] - 1
        reversed_steps = torch.arange(last, -1, -1)
        flipped = inp[reversed_steps, :, :]
        output, state = self.lstm_network(flipped, prev_states)
        return output, state
示例#7
0
文件: controller.py 项目: phymucs/NSM
class LSTMController(nn.Module):
    """LSTM-based NTM controller with a learned initial state."""

    def __init__(self, num_inputs, num_outputs, num_layers):
        super().__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size=num_inputs,
            hidden_size=num_outputs,
            num_layers=num_layers,
        )

        def learned_state():
            # Drawn on the CPU, moved to the GPU when one is available,
            # then scaled.
            noise = torch.randn(self.num_layers, 1, self.num_outputs)
            if torch.cuda.is_available():
                noise = noise.cuda()
            return Parameter(noise * 0.05)

        # The initial hidden and cell states are learned parameters.
        self.lstm_h_bias = learned_state()
        self.lstm_c_bias = learned_state()

        self.reset_parameters()

    def create_new_state(self, batch_size):
        """Tile the learned biases to shape (num_layers, batch_size, num_outputs)."""
        tiling = (1, batch_size, 1)
        return (self.lstm_h_bias.clone().repeat(*tiling),
                self.lstm_c_bias.clone().repeat(*tiling))

    def reset_parameters(self):
        """Zero the 1-D LSTM parameters and draw the others uniformly in [-limit, limit]."""
        limit = 5 / (np.sqrt(self.num_inputs + self.num_outputs))
        for param in self.lstm.parameters():
            if param.dim() != 1:
                nn.init.uniform_(param, -limit, limit)
                continue
            nn.init.constant_(param, 0)

    def size(self):
        """Return ``(num_inputs, num_outputs)``."""
        return self.num_inputs, self.num_outputs

    def forward(self, x, prev_state):
        """Run one LSTM step on ``x`` (a leading length-1 axis is added and removed)."""
        outp, state = self.lstm(x.unsqueeze(0), prev_state)
        return outp.squeeze(0), state
示例#8
0
class Memory(nn.Module):
    """Batched memory matrix with erase/add writes and a learned initial read vector."""

    def __init__(self, memory_size):
        super().__init__()
        self._memory_size = memory_size

        # Constant initial memory content, stored as a buffer (not trained).
        self.register_buffer('initial_state',
                             (torch.ones(memory_size) * 1e-6).data)

        # The initial read vector is a learned parameter.
        read_width = self._memory_size[1]
        self.initial_read = Parameter(torch.randn(1, read_width) * 0.01)

    def get_size(self):
        """Return the ``memory_size`` given at construction."""
        return self._memory_size

    def reset(self, batch_size):
        """Re-create ``self.memory`` as ``batch_size`` copies of the initial state."""
        self.memory = self.initial_state.clone().repeat(batch_size, 1, 1)

    def get_initial_read(self, batch_size):
        """Return the learned initial read vector tiled to ``batch_size`` rows."""
        return self.initial_read.clone().repeat(batch_size, 1)

    def read(self):
        """Return the current memory tensor."""
        return self.memory

    def write(self, w, e, a):
        """Erase then add: ``memory * (1 - w e^T) + w a^T``; returns the new memory."""
        column_weights = w.unsqueeze(-1)
        erase_term = torch.matmul(column_weights, e.unsqueeze(1))
        add_term = torch.matmul(column_weights, a.unsqueeze(1))
        self.memory = self.memory * (1 - erase_term) + add_term
        return self.memory

    def size(self):
        """Return the ``memory_size`` given at construction."""
        return self._memory_size
示例#9
0
class ControllerState(nn.Module):
    """Learned initial LSTM state for a controller, tiled per batch on ``reset``."""

    def __init__(self, controller):
        super().__init__()
        self.controller = controller

        # Use the GPU when one is available, otherwise the CPU.
        self.device = torch.device("cpu")
        if torch.cuda.is_available():
            self.device = torch.device("cuda")

        # The starting hidden and cell states are learned parameters sized
        # from the controller's layer count and output width.
        bias_shape = (self.controller.num_layers, 1,
                      self.controller.num_outputs)
        self.lstm_h_bias = Parameter(torch.randn(*bias_shape) * 0.05)
        self.lstm_c_bias = Parameter(torch.randn(*bias_shape) * 0.05)

        self.to(self.device)

    def reset(self, batch_size):
        """Set ``self.state`` to the learned biases repeated along the batch axis."""
        tiling = (1, batch_size, 1)
        self.state = (self.lstm_h_bias.clone().repeat(*tiling),
                      self.lstm_c_bias.clone().repeat(*tiling))
示例#10
0
文件: controller.py 项目: clemkoa/ntm
class LSTMController(nn.Module):
    """Single-layer LSTM controller with a learned initial state."""

    def __init__(self, vector_length, hidden_size):
        super().__init__()
        self.layer = nn.LSTM(input_size=vector_length, hidden_size=hidden_size)

        # The initial hidden and cell states are learned parameters.
        state_shape = (1, 1, hidden_size)
        self.lstm_h_state = Parameter(torch.randn(*state_shape) * 0.05)
        self.lstm_c_state = Parameter(torch.randn(*state_shape) * 0.05)

        # 1-D parameters are zeroed, the rest are drawn uniformly in
        # [-bound, bound].
        bound = 5 / (np.sqrt(vector_length + hidden_size))
        for weight in self.layer.parameters():
            if weight.dim() != 1:
                nn.init.uniform_(weight, -bound, bound)
                continue
            nn.init.constant_(weight, 0)

    def forward(self, x, state):
        """Run one LSTM step on ``x`` (a leading length-1 axis is added and removed)."""
        step_input = x.unsqueeze(0)
        output, state = self.layer(step_input, state)
        return output.squeeze(0), state

    def get_initial_state(self, batch_size):
        """Return the learned ``(h, c)`` tiled to shape (1, batch_size, hidden_size)."""
        tiling = (1, batch_size, 1)
        return (self.lstm_h_state.clone().repeat(*tiling),
                self.lstm_c_state.clone().repeat(*tiling))
示例#11
0
class LSTMController(Controller):
    """LSTM controller for the NTM."""

    def __init__(self, input_dim, output_dim, num_layers, batch_size,
                 use_cuda):
        super().__init__(input_dim, output_dim, num_layers, batch_size)

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=output_dim,
            num_layers=num_layers,
        )
        if use_cuda:
            self.lstm.cuda()

        # Learned initial hidden/cell states, from
        # https://github.com/fanxiao001/ift6135-assignment/blob/master/assignment3/NTM/controller.py
        bias_shape = (self.num_layers, 1, self.output_dim)
        self.lstm_h_bias = Parameter(torch.randn(*bias_shape) * 0.05)
        self.lstm_c_bias = Parameter(torch.randn(*bias_shape) * 0.05)

        self.reset_parameters()

    def forward(self, x, r, lstm_h, lstm_c):
        """Run one LSTM step on the previous read ``r`` concatenated with ``x``.

        ``x`` has its leading axis squeezed before concatenation along dim 1.
        The LSTM output is returned with its leading length-1 axis intact.
        """
        joined = torch.cat((r, x.squeeze(0)), 1)
        previous = (lstm_h, lstm_c)
        output, state = self.lstm(joined.unsqueeze(0), previous)
        return output, state

    def create_state(self, batch_size):
        """Tile the learned biases to shape (num_layers, batch_size, output_dim)."""
        tiling = (1, batch_size, 1)
        return (self.lstm_h_bias.clone().repeat(*tiling),
                self.lstm_c_bias.clone().repeat(*tiling))
示例#12
0
    def maybe_save(self, t, img: Parameter, content_image):
        """Write the current image to disk when iteration ``t`` is a save point.

        A save happens every ``self.save_iter`` iterations (for ``t > 0``) and
        additionally at ``t == 50``. The output name is ``self.kun.file_name``
        with ``_<t>`` inserted before the extension.
        """
        from aikido.nn.modules.styletransfer.fileloader import deprocess

        periodic = self.save_iter > 0 and t % self.save_iter == 0 and t > 0
        if not (periodic or t == 50):  # FIXME: iteration 50 is hard-coded
            return

        stem, extension = os.path.splitext(self.kun.file_name)
        filename = str(stem) + "_" + str(t) + str(extension)

        disp = deprocess(img.clone(), self.kun)

        # Optional post-processing for colour-independent style transfer.
        if self.original_colors:
            content_disp = deprocess(content_image.clone(), self.kun)
            disp = self.postprocess_colors(content_disp, disp)

        disp.save(str(filename))
示例#13
0
class Connection(AbstractConnection):
    # language=rst
    """
    Specifies synapses between one or two populations of neurons.
    """
    def __init__(self,
                 source: Nodes,
                 target: Nodes,
                 nu: Optional[Union[float, Sequence[float]]] = None,
                 weight_decay: float = 0.0,
                 **kwargs) -> None:
        # language=rst
        """
        Instantiates a :code:`Connection` object.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param weight_decay: Constant multiple to decay weights by on each iteration.

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to some rule.
        :param torch.Tensor w: Strengths of synapses.
        :param torch.Tensor b: Target population bias.
        :param float wmin: Minimum allowed value on the connection weights.
        :param float wmax: Maximum allowed value on the connection weights.
        :param float norm: Total weight per target neuron normalization constant.
        :param ByteTensor norm_by_max: Normalize the weight of a neuron by its max weight.
        :param ByteTensor norm_by_max_with_shadow_weights: Normalize the weight of a neuron by its max weight by
                                                           original weights.

        NOTE(review): this class reads the last flag as ``self.norm_by_max_from_shadow_weights``; the keyword
        name accepted by ``AbstractConnection`` is not visible here -- confirm which spelling is correct.
        """
        super().__init__(source, target, nu, weight_decay, **kwargs)

        w = kwargs.get("w", None)
        if w is None:
            if self.wmin == -np.inf or self.wmax == np.inf:
                # At least one bound is infinite: draw from U[0, 1) and clamp.
                w = torch.clamp(torch.rand(source.n, target.n), self.wmin,
                                self.wmax)
            else:
                # Both bounds finite: draw uniformly from [wmin, wmax).
                w = self.wmin + torch.rand(source.n,
                                           target.n) * (self.wmax - self.wmin)
        else:
            if self.wmin != -np.inf or self.wmax != np.inf:
                w = torch.clamp(w, self.wmin, self.wmax)

        # Weights and bias are stored as parameters with requires_grad=False.
        self.w = Parameter(w, False)

        self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), False)

        if self.norm_by_max_from_shadow_weights:
            # Unnormalized copy of the weights plus a snapshot of the last normalized weights.
            self.shadow_w = self.w.clone().detach()
            self.prev_w = self.w.clone().detach()

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        :param s: Incoming spikes.
        :return: Incoming spikes multiplied by synaptic weights (with or without decaying spike activation).
        """
        # Compute multiplication of spike activations by connection weights and add bias.
        post = s.float().view(-1) @ self.w + self.b
        return post.view(*self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule.
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        # language=rst
        """
        Normalize weights so each target neuron has sum of connection weights equal to ``self.norm``.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            # Avoid division by zero for target neurons with all-zero weights.
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w *= self.norm / w_abs_sum

    def normalize_by_max(self) -> None:
        # language=rst
        """
        Normalize weights by the max weight of the target neuron.
        """
        if self.norm_by_max:
            w_max = self.w.abs().max(0)[0]
            # Avoid division by zero for target neurons with all-zero weights.
            w_max[w_max == 0] = 1.0
            self.w /= w_max

    def normalize_by_max_from_shadow_weights(self) -> None:
        # language=rst
        """
        Normalize weights by the max weight of the target neuron, computed on the shadow weights.

        The change applied to ``self.w`` since the previous call is accumulated into ``self.shadow_w``;
        ``self.w`` is then set to the shadow weights divided by their per-target maximum absolute value.
        """
        if self.norm_by_max_from_shadow_weights:
            self.shadow_w += self.w - self.prev_w
            w_max = self.shadow_w.abs().max(0)[0]
            # Avoid division by zero for target neurons with all-zero weights.
            w_max[w_max == 0] = 1.0
            # Update the existing parameter in place (like the other normalize methods) instead of
            # rebinding ``self.w`` to a plain tensor: that would drop the ``Parameter`` and, if the
            # base class is a ``torch.nn.Module``, raise ``TypeError`` on assignment.
            self.w.copy_(self.shadow_w / w_max)
            self.prev_w = self.w.clone().detach()

    def reset_(self) -> None:
        # language=rst
        """
        Contains resetting logic for the connection.
        """
        super().reset_()
class ResNet(nn.Module):
    """Two-branch ResNet backbone that fuses an RGB stream with a depth stream.

    Each branch has its own stem (conv/bn/relu/maxpool) and four residual
    stages. After the stem and after every stage the two streams are passed
    through their own ``SELayer`` (defined elsewhere) and blended with a
    learned per-channel gate ``selective_k``; the blended tensor continues
    down the RGB branch while the depth branch continues unfused.

    ``num_classes`` is accepted but not used in the code shown here, and
    ``avgpool`` is created but not applied in ``forward``.
    """

    def __init__(self,
                 block,
                 layers,
                 num_classes=1000,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None):
        """Build both branches, the fusion gates and the attention layers.

        Args:
            block: Residual block class; must expose ``expansion``.
            layers: Number of blocks in each of the four stages.
            num_classes: Unused here.
            zero_init_residual: If true, zero the weight of the last norm
                layer of every ``Bottleneck``/``BasicBlock``.
            groups: Forwarded to every block.
            width_per_group: Forwarded to every block as ``base_width``.
            replace_stride_with_dilation: Three flags, one per strided stage
                (stages 2-4); a true flag swaps that stage's stride for
                dilation. ``None`` means all false.
            norm_layer: Normalisation layer factory; defaults to
                ``nn.BatchNorm2d``.

        Raises:
            ValueError: If ``replace_stride_with_dilation`` does not have
                exactly three elements.
        """
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running state consumed and updated by ``_make_layer``.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # Each element indicates whether the stride of the corresponding
            # stage should be replaced with a dilated convolution instead.
            replace_stride_with_dilation = [False, False, False]

        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # RGB branch: stem takes a 3-channel input.
        self.conv1 = nn.Conv2d(3,
                               self.inplanes,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Learned per-channel fusion gates, one per fusion point. They are
        # allocated uninitialised here and filled with constants at the end
        # of this constructor.
        self.selective_0 = Parameter(torch.Tensor(64, 1, 1))
        self.selective_1 = Parameter(torch.Tensor(64, 1, 1))
        self.selective_2 = Parameter(torch.Tensor(128, 1, 1))
        self.selective_3 = Parameter(torch.Tensor(256, 1, 1))
        self.selective_4 = Parameter(torch.Tensor(512, 1, 1))

        # Attention layers applied to each stream right before it is merged.
        self.att_d = SELayer(64)
        self.att_rgb = SELayer(64)

        self.att_d_layer1 = SELayer(64)
        self.att_rgb_layer1 = SELayer(64)

        self.att_d_layer2 = SELayer(128)
        self.att_rgb_layer2 = SELayer(128)

        self.att_d_layer3 = SELayer(256)
        self.att_rgb_layer3 = SELayer(256)

        self.att_d_layer4 = SELayer(512)
        self.att_rgb_layer4 = SELayer(512)

        # Depth branch: stem takes a 1-channel input. ``inplanes`` is reset
        # to 64 so ``_make_layer`` starts from the stem width again;
        # ``self.dilation`` is NOT reset and carries over from the RGB branch.
        self.inplanes = 64
        self.conv1_d = nn.Conv2d(1,
                                 64,
                                 kernel_size=7,
                                 stride=2,
                                 padding=3,
                                 bias=False)
        self.bn1_d = norm_layer(64)
        self.relu_d = nn.ReLU(inplace=True)
        self.maxpool_d = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_d = self._make_layer(block, 64, layers[0])
        self.layer2_d = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0])
        self.layer3_d = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1])
        self.layer4_d = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2])

        # Weight initialisation for every conv and norm layer created above.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

        # Initial gate values: the RGB share starts at 0.9 at the stem and
        # falls to 0.1 at the last stage (the depth share is 1 - gate).
        nn.init.constant_(self.selective_0, 0.9)
        nn.init.constant_(self.selective_1, 0.7)
        nn.init.constant_(self.selective_2, 0.5)
        nn.init.constant_(self.selective_3, 0.3)
        nn.init.constant_(self.selective_4, 0.1)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` blocks as an ``nn.Sequential``.

        Reads and updates ``self.inplanes`` and ``self.dilation``, so the
        order of calls matters. A 1x1-conv + norm downsample path is added to
        the first block when the stride or channel count changes.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # The first block carries the stride/downsample and the dilation that
        # was in effect before this stage.
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation,
                      norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x, depth):
        """Run both branches and return the fused feature map of the last stage.

        Args:
            x: RGB input (fed to a 3-input-channel convolution).
            depth: Depth input (fed to a 1-input-channel convolution).

        At every fusion point the gate is ``relu(selective_k)``; the RGB
        stream is weighted by the gate and the depth stream by ``1 - gate``.
        ``self.relu`` is in-place, so it is applied to a clone of the gate
        parameter rather than to the parameter itself.
        """

        # RGB stem.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        # Depth stem.
        depth = self.conv1_d(depth)
        depth = self.bn1_d(depth)
        depth = self.relu_d(depth)
        depth = self.maxpool_d(depth)
        # Fusion 0 (after the stems).
        selective_0_d = (1 - self.relu(self.selective_0.clone())).unsqueeze(0)
        selective_0_r = self.relu(self.selective_0.clone()).unsqueeze(0)
        x = selective_0_d * self.att_d(depth) + selective_0_r * self.att_rgb(
            x)  # merge

        # Stage 1 and fusion 1.
        x = self.layer1(x)
        depth = self.layer1_d(depth)
        selective_1_d = (1 - self.relu(self.selective_1.clone())).unsqueeze(0)
        selective_1_r = self.relu(self.selective_1.clone()).unsqueeze(0)
        x = selective_1_d * self.att_d_layer1(
            depth) + selective_1_r * self.att_rgb_layer1(x)  # merge

        # Stage 2 and fusion 2.
        x = self.layer2(x)
        depth = self.layer2_d(depth)
        selective_2_d = (1 - self.relu(self.selective_2.clone())).unsqueeze(0)
        selective_2_r = self.relu(self.selective_2.clone()).unsqueeze(0)
        x = selective_2_d * self.att_d_layer2(
            depth) + selective_2_r * self.att_rgb_layer2(x)  # merge

        # Stage 3 and fusion 3.
        x = self.layer3(x)
        depth = self.layer3_d(depth)
        selective_3_d = (1 - self.relu(self.selective_3.clone())).unsqueeze(0)
        selective_3_r = self.relu(self.selective_3.clone()).unsqueeze(0)
        x = selective_3_d * self.att_d_layer3(
            depth) + selective_3_r * self.att_rgb_layer3(x)  # merge

        # Stage 4 and fusion 4.
        x = self.layer4(x)
        depth = self.layer4_d(depth)
        selective_4_d = (1 - self.relu(self.selective_4.clone())).unsqueeze(0)
        selective_4_r = self.relu(self.selective_4.clone()).unsqueeze(0)
        x = selective_4_d * self.att_d_layer4(
            depth) + selective_4_r * self.att_rgb_layer4(x)  # merge

        return x
def attack(model,
           criterion,
           img,
           label,
           eps,
           attack_type,
           iters,
           clean_clean_img=None):
    """Craft adversarial examples from ``img`` by signed-gradient ascent.

    At every iteration the loss ``criterion(model(normalize(adv)), label)``
    is back-propagated to the input, the input is moved by ``step`` along
    the sign of the (possibly accumulated) gradient, and the result is
    clamped to ``[0, 1]``.

    The function reads the module-level names ``normalize``, ``aug_test``
    and ``aug_test_lambda``; they are not defined in this function.

    Args:
        model: Network under attack. Must not be in training mode.
        criterion: Callable ``criterion(outputs, label)`` returning a
            scalar loss tensor.
        img: Clean input batch. The ``'mim'`` branch averages over dims
            1, 2 and 3, so a 4-D tensor is required there.
        label: Targets handed to ``criterion``.
        eps: Perturbation budget.
        attack_type: ``'fgsm'`` runs one step of size ``eps``; ``'pgd'``
            uses a fixed step of ``2 / 255`` and projects back into the
            ``eps``-ball around ``img``; ``'mim'`` accumulates
            mean-normalised gradients; any other value runs ``iters``
            plain steps of size ``eps / iters``.
        iters: Number of iterations (ignored for ``'fgsm'``).
        clean_clean_img: Batch of clean images mixed into the adversarial
            input when ``aug_test`` is not ``None``.

    Returns:
        The adversarial batch, detached from the autograd graph.

    Raises:
        AssertionError: If ``model`` is in training mode.
        ValueError: If ``aug_test`` is enabled but ``clean_clean_img`` is
            not supplied.
    """
    assert not model.training

    use_aug = aug_test is not None
    if use_aug and clean_clean_img is None:
        raise ValueError(
            "clean_clean_img is required when aug_test is enabled")

    adv = Parameter(img.clone().detach(), requires_grad=True)

    iterations = 1 if attack_type == 'fgsm' else iters
    step = 2 / 255 if attack_type == 'pgd' else eps / iterations

    # Gradient accumulator for 'mim'; initialised on every path so the loop
    # never depends on which branch above was taken.
    noise = 0

    for _step in range(iterations):
        if not use_aug:
            out_adv = model(normalize(adv.clone()))
        else:
            outputs = None
            adv_aux = adv * (1.0 - aug_test_lambda)
            # NOTE(review): adv_aux keeps accumulating mixed-in images
            # across draws, and every forward pass stays in the autograd
            # graph until backward() -- the likely cause of the high
            # memory use noted in the original TODO. Confirm intent.
            for _draw in range(aug_test):
                adv_aux = adv_aux + aug_test_lambda * clean_clean_img[
                    torch.randperm(label.size(0))]
                out = model(normalize(adv_aux))
                # Out-of-place add: never mutate a tensor autograd may
                # still need for the backward pass.
                outputs = out if outputs is None else outputs + out
            out_adv = outputs / aug_test

        loss = criterion(out_adv, label)
        loss.backward()

        assert adv.grad is not None
        if attack_type == 'mim':
            # Per-sample mean absolute gradient, reduced over dims 1..3.
            grad_scale = adv.grad.abs().mean(dim=1, keepdim=True)
            grad_scale = grad_scale.mean(dim=2, keepdim=True)
            grad_scale = grad_scale.mean(dim=3, keepdim=True)
            # A vanishing gradient would otherwise yield 0 / 0 = NaN and
            # poison the whole sample.
            grad_scale = torch.clamp(
                grad_scale, min=torch.finfo(grad_scale.dtype).tiny)
            adv.grad = adv.grad / grad_scale
            noise = noise + adv.grad
        else:
            noise = adv.grad

        # Optimization step
        adv.data = adv.data + step * noise.sign()

        if attack_type == 'pgd':
            # Project back into the eps-ball around the clean input.
            adv.data = torch.where(adv.data > img.data + eps, img.data + eps,
                                   adv.data)
            adv.data = torch.where(adv.data < img.data - eps, img.data - eps,
                                   adv.data)
        adv.data.clamp_(0.0, 1.0)

        adv.grad.data.zero_()

    return adv.detach()
class MarkovFlow(nn.Module):
    def __init__(self, args, num_dims):
        super(MarkovFlow, self).__init__()

        self.args = args
        self.device = args.device

        # Gaussian Variance
        self.var = Parameter(torch.zeros(num_dims, dtype=torch.float32))

        if not args.train_var:
            self.var.requires_grad = False

        self.num_state = args.num_state
        self.num_dims = num_dims
        self.couple_layers = args.couple_layers
        self.cell_layers = args.cell_layers
        self.hidden_units = num_dims // 2
        self.lstm_hidden_units = self.num_dims

        # transition parameters in log space
        self.tparams = Parameter(torch.Tensor(self.num_state, self.num_state))

        self.prior_group = [self.tparams]

        # Gaussian means
        self.means = Parameter(torch.Tensor(self.num_state, self.num_dims))

        if args.mode == "unsupervised" and args.freeze_prior:
            self.tparams.requires_grad = False

        if args.mode == "unsupervised" and args.freeze_mean:
            self.means.requires_grad = False

        if args.model == 'nice':
            self.proj_layer = NICETrans(self.couple_layers, self.cell_layers,
                                        self.hidden_units, self.num_dims,
                                        self.device)
        elif args.model == "lstmnice":
            self.proj_layer = LSTMNICE(self.args.lstm_layers,
                                       self.args.couple_layers,
                                       self.args.cell_layers,
                                       self.lstm_hidden_units,
                                       self.hidden_units, self.num_dims,
                                       self.device)

        if args.mode == "unsupervised" and args.freeze_proj:
            for param in self.proj_layer.parameters():
                param.requires_grad = False

        if args.model == "gaussian":
            self.proj_group = [self.means, self.var]
        else:
            self.proj_group = list(
                self.proj_layer.parameters()) + [self.means, self.var]

        # prior
        self.pi = torch.zeros(self.num_state,
                              dtype=torch.float32,
                              requires_grad=False,
                              device=self.device).fill_(1.0 / self.num_state)

        self.pi = torch.log(self.pi)

    def init_params(self, train_data):
        """
        Initialise the transition, mean and variance parameters.

        Depending on ``self.args`` one of three paths is taken:

        * ``load_nice`` set: load a full checkpoint (strict), snapshot the
          means, transition parameters and projection parameters into
          ``means_init`` / ``tparams_init`` / ``proj_init`` (read by
          ``MSE_loss``), optionally re-estimate or reset the variance,
          and return.
        * ``mode == "unsupervised"``: estimate the masked mean and variance
          of the transformed embeddings from the FIRST batch only, copy
          the variance into ``self.var``, seed the means, and return.
        * otherwise: set the means with ``init_mean`` and the variance
          with ``init_var``.

        Args:
            train_data: object exposing ``data_iter(batch_size)``; each
                yielded item provides
                ``embed``: (seq_length, batch_size, features) and
                ``mask``: (seq_length, batch_size).

        NOTE(review): several statements update parameters in place
        (``self.var.fill_``, ``self.var.copy_``) outside any ``no_grad``
        block here; presumably the caller wraps this call in
        ``torch.no_grad()`` -- confirm.
        """

        # initialize transition matrix params
        # self.tparams.data.uniform_().add_(1)
        self.tparams.data.uniform_()

        # load pretrained model
        if self.args.load_nice != '':
            self.load_state_dict(torch.load(self.args.load_nice), strict=True)

            # Snapshots of the loaded values, compared against in MSE_loss.
            self.means_init = self.means.clone()
            self.tparams_init = self.tparams.clone()
            self.proj_init = [
                param.clone() for param in self.proj_layer.parameters()
            ]

            if self.args.init_var:
                self.init_var(train_data)

            # Note: this path resets the variance to 0.01, whereas the
            # fall-through path at the end of the method uses 1.0.
            if self.args.init_var_one:
                self.var.fill_(0.01)

            # self.means_init.requires_grad = False
            # self.tparams_init.requires_grad = False
            # for tensor in self.proj_init:
            #     tensor.requires_grad = False

            return

        # load pretrained Gaussian baseline
        if self.args.load_gaussian != '':
            self.load_state_dict(torch.load(self.args.load_gaussian),
                                 strict=False)

        # fully unsupervised training
        # (load_nice is necessarily '' here because of the early return above)
        if self.args.mode == "unsupervised" and self.args.load_nice == "":
            with torch.no_grad():
                for iter_obj in train_data.data_iter(self.args.batch_size):
                    sents = iter_obj.embed
                    masks = iter_obj.mask
                    sents, _ = self.transform(sents, iter_obj.mask)
                    seq_length, _, features = sents.size()
                    flat_sents = sents.view(-1, features)
                    # Masked per-dimension mean of this batch.
                    seed_mean = torch.sum(
                        masks.view(-1, 1).expand_as(flat_sents) * flat_sents,
                        dim=0) / masks.sum()
                    # Masked per-dimension variance around that mean.
                    seed_var = torch.sum(
                        masks.view(-1, 1).expand_as(flat_sents) *
                        ((flat_sents - seed_mean.expand_as(flat_sents))**2),
                        dim=0) / masks.sum()
                    self.var.copy_(seed_var)
                    # self.var.fill_(0.02)

                    # add noise to the pretrained Gaussian mean
                    if self.args.load_gaussian != '' and self.args.model == 'nice':
                        # Shift the loaded means by the batch mean.
                        self.means.data.add_(
                            seed_mean.data.expand_as(self.means.data))
                    elif self.args.load_gaussian == '' and self.args.load_nice == '':
                        # No checkpoint: small random means centred on the
                        # batch mean.
                        self.means.data.normal_().mul_(0.04)
                        self.means.data.add_(
                            seed_mean.data.expand_as(self.means.data))

                    # Only the first batch is used as the seed.
                    return

        # Reached when not in unsupervised mode (or when the iterator above
        # yielded nothing): derive the means from gold tags and the variance
        # from the data.
        self.init_mean(train_data)
        self.var.fill_(1.0)
        self.init_var(train_data)

        if self.args.init_var_one:
            self.var.fill_(1.0)

    def init_mean(self, train_data):
        emb_dict = {}
        cnt_dict = Counter()
        for iter_obj in train_data.data_iter(self.args.batch_size):
            sents_t = iter_obj.embed
            sents_t, _ = self.transform(sents_t, iter_obj.mask)
            sents_t = sents_t.transpose(0, 1)
            pos_t = iter_obj.pos.transpose(0, 1)
            mask_t = iter_obj.mask.transpose(0, 1)

            for emb_s, tagid_s, mask_s in zip(sents_t, pos_t, mask_t):
                for tagid, emb, mask in zip(tagid_s, emb_s, mask_s):
                    tagid = tagid.item()
                    mask = mask.item()
                    if tagid in emb_dict:
                        emb_dict[tagid] = emb_dict[tagid] + emb * mask
                    else:
                        emb_dict[tagid] = emb * mask

                    cnt_dict[tagid] += mask

        for tagid in emb_dict:
            self.means[tagid] = emb_dict[tagid] / cnt_dict[tagid]

    def init_var(self, train_data):
        cnt = 0
        mean_sum = 0.
        var_sum = 0.
        for iter_obj in train_data.data_iter(batch_size=self.args.batch_size):
            sents, masks = iter_obj.embed, iter_obj.mask
            sents, _ = self.transform(sents, masks)
            seq_length, _, features = sents.size()
            flat_sents = sents.view(-1, features)
            mean_sum = mean_sum + torch.sum(
                masks.view(-1, 1).expand_as(flat_sents) * flat_sents, dim=0)
            cnt += masks.sum().item()

        mean = mean_sum / cnt

        for iter_obj in train_data.data_iter(batch_size=self.args.batch_size):
            sents, masks = iter_obj.embed, iter_obj.mask
            sents, _ = self.transform(sents, masks)
            seq_length, _, features = sents.size()
            flat_sents = sents.view(-1, features)
            var_sum = var_sum + torch.sum(
                masks.view(-1, 1).expand_as(flat_sents) *
                ((flat_sents - mean.expand_as(flat_sents))**2),
                dim=0)
        var = var_sum / cnt
        self.var.copy_(var)

    def _calc_log_density_c(self):
        # return -self.num_dims/2.0 * (math.log(2) + \
        #         math.log(np.pi)) - 0.5 * self.num_dims * (torch.log(self.var))

        return -self.num_dims/2.0 * (math.log(2) + \
                math.log(np.pi)) - 0.5 * torch.sum(torch.log(self.var))

    def transform(self, x, masks=None):
        """
        Args:
            x: (sent_length, batch_size, num_dims)
        """
        jacobian_loss = torch.zeros(1, device=self.device, requires_grad=False)

        if self.args.model != 'gaussian':
            x, jacobian_loss_new = self.proj_layer(x, masks)
            jacobian_loss = jacobian_loss + jacobian_loss_new

        return x, jacobian_loss

    def MSE_loss(self):
        # diff1 = ((self.means - self.means_init) ** 2).sum()
        diff_prior = ((self.tparams - self.tparams_init)**2).sum()

        # diff = diff1 + diff2
        diff_proj = 0.

        for i, param in enumerate(self.proj_layer.parameters()):
            diff_proj = diff_proj + ((self.proj_init[i] - param)**2).sum()

        diff_mean = ((self.means_init - self.means)**2).sum()

        return 0.5 * (self.args.beta_prior * diff_prior + self.args.beta_proj *
                      diff_proj + self.args.beta_mean * diff_mean)

    def unsupervised_loss(self, sents, masks):
        """Negative marginal log-likelihood via the forward algorithm.

        Also refreshes the cached ``self.logA`` and ``self.log_density_c``.

        Args:
            sents: (sent_length, batch_size, self.num_dims)
            masks: (sent_length, batch_size)

        Returns: Tensor1, Tensor2
            Tensor1: negative log likelihood, shape ([])
            Tensor2: jacobian loss returned by ``transform``
        """
        max_length, batch_size, _ = sents.size()
        sents, jacobian_loss = self.transform(sents, masks)

        assert self.var.data.min() > 0

        self.logA = self._calc_logA()
        self.log_density_c = self._calc_log_density_c()

        # Forward recursion in log space; masked steps keep the old alpha.
        alpha = self.pi + self._eval_density(sents[0])
        for t in range(1, max_length):
            step_density = self._eval_density(sents[t])
            # (batch_size, num_state) copy of the per-sentence mask
            keep = masks[t].unsqueeze(1).expand(batch_size, self.num_state)
            advanced = self._forward_cell(alpha, step_density)
            alpha = keep * advanced + (1 - keep) * alpha

        # Sum out the final state, then sum over the batch.
        log_likelihood = torch.sum(log_sum_exp(alpha, dim=1))

        return -log_likelihood, jacobian_loss

    def supervised_loss(self, sents, tags, masks):
        """Negative joint log-likelihood of the observations and gold tags.

        The loss is the sum of three terms: the uniform log prior of the
        first tag of every sentence, the masked log transition scores
        between consecutive gold tags, and the masked Gaussian log emission
        densities of every token under its gold state.

        Gold entries are selected by direct tensor indexing. This keeps the
        (sent_len-1, batch_size) shape of the transition term intact for
        every batch size; a bare ``squeeze()`` would also drop the batch
        axis when ``batch_size == 1`` and make the subsequent mask
        multiplication broadcast to (sent_len-1, sent_len-1).

        Args:
            sents: (sent_length, batch_size, num_dims)
            masks: (sent_length, batch_size)
            tags:  (sent_length, batch_size) integer state ids

        Returns: Tensor1, Tensor2
            Tensor1: negative log likelihood, shape ([])
            Tensor2: jacobian loss returned by ``transform``

        """
        # (sent_length, batch_size, num_dims)
        sents, jacobian_loss = self.transform(sents, masks)

        # ()
        log_density_c = self._calc_log_density_c()

        # (sent_len, batch_size, num_dims): mean of each token's gold state
        means = self.means[tags]

        var = self.var.view(1, 1, self.num_dims)

        # (sent_len, batch_size)
        log_emission_prob = log_density_c - \
                       0.5 * torch.sum((means-sents) ** 2 / var, dim=-1)

        log_emission_prob = torch.mul(masks, log_emission_prob).sum()

        # (num_state, num_state)
        log_trans = self._calc_logA()

        # (sent_len-1, batch_size): score of moving from tag[t] to tag[t+1]
        log_trans_prob = log_trans[tags[:-1], tags[1:]]

        # A transition counts only when its destination token is unmasked.
        log_trans_prob = torch.mul(masks[1:], log_trans_prob).sum()

        # () log prior of the first tag, summed over the batch
        log_trans_prior = self.pi[tags[0]].sum()

        log_trans_prob = log_trans_prior + log_trans_prob

        return -(log_trans_prob + log_emission_prob), jacobian_loss

    def _calc_alpha(self, sents, masks):
        """
        sents: (sent_length, batch_size, self.num_dims)
        masks: (sent_length, batch_size)

        Returns:
            output: (batch_size, sent_length, num_state)

        """
        max_length, batch_size, _ = sents.size()

        alpha_all = []
        alpha = self.pi + self._eval_density(sents[0])
        alpha_all.append(alpha.unsqueeze(1))
        for t in range(1, max_length):
            density = self._eval_density(sents[t])
            mask_ep = masks[t].expand(self.num_state, batch_size) \
                      .transpose(0, 1)
            alpha = torch.mul(mask_ep, self._forward_cell(alpha, density)) + \
                    torch.mul(1-mask_ep, alpha)
            alpha_all.append(alpha.unsqueeze(1))

        return torch.cat(alpha_all, dim=1)

    def _forward_cell(self, alpha, density):
        """Advance the forward variable by one time step in log space.

        Args:
            alpha: (batch_size, num_state) scores of the previous step.
            density: (batch_size, num_state) emission log densities of the
                current step.

        Returns:
            (batch_size, num_state) scores of the current step, obtained by
            summing out the previous state (dim 1).
        """
        n_batch = alpha.size(0)
        pair_shape = torch.Size([n_batch, self.num_state, self.num_state])

        # Axis 1 indexes the previous state, axis 2 the current state.
        previous = alpha.unsqueeze(dim=2).expand(pair_shape)
        transition = self.logA.expand(pair_shape)
        emission = density.unsqueeze(dim=1).expand(pair_shape)

        return log_sum_exp(previous + transition + emission, dim=1)

    def _backward_cell(self, beta, density):
        """Move the backward variable one time step back in log space.

        Args:
            density: (batch_size, num_state)
            beta: (batch_size, num_state)

        Returns:
            (batch_size, num_state) scores obtained by summing out the
            next state (dim 2).
        """
        n_batch = beta.size(0)
        pair_shape = torch.Size([n_batch, self.num_state, self.num_state])

        # Axis 1 indexes the current state, axis 2 the next state.
        transition = self.logA.expand(pair_shape)
        emission = density.unsqueeze(dim=1).expand(pair_shape)
        future = beta.unsqueeze(dim=1).expand(pair_shape)

        return log_sum_exp(transition + emission + future, dim=2)

    def _eval_density(self, words):
        """
        Args:
            words: (batch_size, self.num_dims)

        Returns: Tensor1
            Tensor1: the density tensor with shape (batch_size, num_state)
        """

        batch_size = words.size(0)
        ep_size = torch.Size([batch_size, self.num_state, self.num_dims])
        words = words.unsqueeze(dim=1).expand(ep_size)
        means = self.means.expand(ep_size)
        var = self.var.expand(ep_size)

        return self.log_density_c - \
               0.5 * torch.sum((means-words) ** 2 / var, dim=2)

    def _calc_logA(self):
        """Return the row-normalised log transition matrix.

        Subtracts from ``self.tparams`` the log-sum-exp of each of its rows
        (dim 1); result shape is (num_state, num_state).
        """
        row_norm = log_sum_exp(self.tparams, dim=1, keepdim=True)
        return self.tparams - row_norm.expand(self.num_state, self.num_state)

    def _calc_log_mul_emit(self):
        """Return ``self.emission`` normalised along dim 1 in log space.

        NOTE(review): ``self.emission`` and ``self.vocab_size`` are not
        assigned in this class's ``__init__``; this helper only works if
        they are set elsewhere -- confirm it is still in use.
        """
        row_norm = log_sum_exp(self.emission, dim=1, keepdim=True)
        return self.emission - row_norm.expand(self.num_state,
                                               self.vocab_size)

    def _viterbi(self, sents_var, masks):
        """
        Decode the highest-scoring state sequence of every sentence.

        Refreshes the cached ``self.log_density_c`` and ``self.logA`` before
        decoding.

        Args:
            sents_var: (sent_length, batch_size, num_dims)
            masks: (sent_length, batch_size)

        Returns:
            Long tensor of state ids with shape (batch_size, sent_length).
        """

        self.log_density_c = self._calc_log_density_c()
        self.logA = self._calc_logA()

        length, batch_size = masks.size()

        # (batch_size, num_state)
        delta = self.pi + self._eval_density(sents_var[0])

        ep_size = torch.Size([batch_size, self.num_state, self.num_state])
        index_all = []

        # forward calculate delta
        for t in range(1, length):
            density = self._eval_density(sents_var[t])
            # delta_new[b, i, j]: score of being in state i at t-1 and
            # moving to state j at t.
            delta_new = self.logA.expand(ep_size) + \
                    density.unsqueeze(dim=1).expand(ep_size) + \
                    delta.unsqueeze(dim=2).expand(ep_size)
            mask_ep = masks[t].view(-1, 1, 1).expand(ep_size)
            # Masked steps copy the old delta into every "previous state"
            # row, so the max below returns it unchanged.
            delta = mask_ep * delta_new + \
                    (1 - mask_ep) * delta.unsqueeze(dim=1).expand(ep_size)

            # index: (batch_size, num_state)
            # best previous state (back-pointer) for every current state
            delta, index = torch.max(delta, dim=1)
            index_all.append(index)

        assign_all = []
        # assign: (batch_size)
        # best final state
        _, assign = torch.max(delta, dim=1)
        assign_all.append(assign.unsqueeze(dim=1))

        # backward retrieve path
        # len(index_all) = length-1
        for t in range(length - 2, -1, -1):
            # index_all[t] holds the back-pointers of the step t -> t+1.
            assign_new = torch.gather(index_all[t],
                                      dim=1,
                                      index=assign.view(-1, 1)).squeeze(dim=1)

            # Blend in float so the mask can select between the followed
            # back-pointer (unmasked step) and the unchanged assignment
            # (masked step), then cast back to integer ids.
            assign_new = assign_new.float()
            assign = assign.float()
            assign = masks[t + 1] * assign_new + (1 - masks[t + 1]) * assign
            assign = assign.long()

            assign_all.append(assign.unsqueeze(dim=1))

        # The path was collected from last to first; restore time order.
        assign_all = assign_all[-1::-1]

        return torch.cat(assign_all, dim=1)

    def test_supervised(self, test_data):
        """Token-level tagging accuracy of the Viterbi decoding.

        For every sentence only the first ``mask.sum()`` positions are
        compared against the gold tags.

        Args:
            test_data: ConlluData object

        Returns: a scalar accuracy value

        """
        correct = 0.0
        total = 0.0

        batches = test_data.data_iter(batch_size=self.args.batch_size,
                                      shuffle=False)
        for batch in batches:
            embed, masks, gold = batch.embed, batch.mask, batch.pos

            embed, _ = self.transform(embed, masks)

            # predicted: (batch_size, seq_length)
            predicted = self._viterbi(embed, masks)

            # Walk the batch sentence by sentence.
            rows = zip(predicted, gold.transpose(0, 1), masks.transpose(0, 1))
            for pred_seq, gold_seq, mask_seq in rows:
                n_tokens = int(mask_seq.sum().item())
                for pos in range(n_tokens):
                    if pred_seq[pos].item() == gold_seq[pos].item():
                        correct += 1
                    total += 1

        return correct / total

    def test_unsupervised(self,
                          test_data,
                          sentences=None,
                          tagging=False,
                          path=None,
                          null_index=None):
        """Evaluate tagging performance with
        many-to-1 metric, VM score and 1-to-1
        accuracy

        Args:
            test_data: ConlluData object
            sentences: forwarded to ``write_conll`` when ``tagging`` is True
            tagging: output the predicted tags if True
            path: The output tag file path
            null_index: the null element location in Penn
                        Treebank, only used for writing unsupervised
                        tags for downstream parsing task

        Returns:
            Tuple1: (M1, VM score, 1-to-1 accuracy)

        """

        total = 0.0
        correct = 0.0
        # cnt_stats[model_tag][gold_tag]: co-occurrence count
        cnt_stats = {}
        # match_dict[model_tag]: most frequent gold tag for that model tag
        match_dict = {}

        # Per-sentence predictions and (length-truncated) gold tags, kept
        # for the two accuracy passes below.
        index_all = []
        eval_tags = []

        # Flat token-level label lists for the V-measure.
        gold_vm = []
        model_vm = []

        for i in range(self.num_state):
            cnt_stats[i] = Counter()

        for iter_obj in test_data.data_iter(batch_size=self.args.batch_size,
                                            shuffle=False):
            total += iter_obj.mask.sum().item()
            sents_t = iter_obj.embed
            tags_t = iter_obj.pos
            masks = iter_obj.mask

            sents_t, _ = self.transform(sents_t, masks)

            # index: (batch_size, seq_length)
            index = self._viterbi(sents_t, masks)

            index_all += list(index)

            # Gold tags of each sentence, cut to its unmasked length.
            tags = [
                tags_t[:int(masks[:, i].sum().item()), i]
                for i in range(index.size(0))
            ]
            eval_tags += tags

            # count
            # zip stops at the (shorter) gold sequence, so padded
            # predictions are ignored.
            for (seq_gold_tags, seq_model_tags) in zip(tags, index):
                for (gold_tag, model_tag) in zip(seq_gold_tags,
                                                 seq_model_tags):
                    model_tag = model_tag.item()
                    gold_tag = gold_tag.item()
                    gold_vm += [gold_tag]
                    model_vm += [model_tag]
                    cnt_stats[model_tag][gold_tag] += 1

        # evaluate one-to-one accuracy
        # Rows are gold tags, columns are model tags; counts are negated
        # because linear_sum_assignment minimises the total cost.
        cost_matrix = np.zeros((self.num_state, self.num_state))
        for i in range(self.num_state):
            for j in range(self.num_state):
                cost_matrix[i][j] = -cnt_stats[j][i]

        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        for (seq_gold_tags, seq_model_tags) in zip(eval_tags, index_all):
            for (gold_tag, model_tag) in zip(seq_gold_tags, seq_model_tags):
                model_tag = model_tag.item()
                gold_tag = gold_tag.item()

                # col_ind[gold_tag] is the model tag assigned to gold_tag.
                if col_ind[gold_tag] == model_tag:
                    correct += 1

        one2one = correct / total

        # Reset the counter for the many-to-one pass.
        correct = 0.

        # match
        for tag in cnt_stats:
            if len(cnt_stats[tag]) != 0:
                match_dict[tag] = cnt_stats[tag].most_common(1)[0][0]

        # eval many2one
        for (seq_gold_tags, seq_model_tags) in zip(eval_tags, index_all):
            for (gold_tag, model_tag) in zip(seq_gold_tags, seq_model_tags):
                model_tag = model_tag.item()
                gold_tag = gold_tag.item()
                if match_dict[model_tag] == gold_tag:
                    correct += 1

        if tagging:
            write_conll(path, sentences, index_all, null_index)

        return correct / total, v_measure_score(gold_vm, model_vm), one2one
示例#17
0
class Connection(AbstractConnection):
    # language=rst
    """
    Specifies synapses between one or two populations of neurons.
    """

    def __init__(
        self,
        source: Nodes,
        target: Nodes,
        impulse_amplitude: float,
        impulse_length: float,
        impulse_shape_factor: float = 0.9,
        invert: bool = False,
        nu: Optional[Union[float, Sequence[float]]] = None,
        reduction: Optional[callable] = None,
        weight_decay: float = 0.0,
        post_spike_weight_decay: float = 0.0,
        **kwargs
    ) -> None:
        # language=rst
        """
        Instantiates a :code:`Connection` object.

        All positional and keyword arguments are forwarded to the parent
        constructor; this class then creates the weight matrix ``w`` and the
        bias ``b`` as non-trainable parameters.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param impulse_amplitude: Amplitude scale read by ``impulse_curve``.
        :param impulse_length: Impulse duration read by ``impulse_curve`` and
            ``update_impulse_state`` (presumably in simulation steps -- confirm).
        :param impulse_shape_factor: Shape factor ``k`` read by ``impulse_curve``.
        :param invert: Selects the inverted branch of ``impulse_curve``.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param reduction: Method for reducing parameter updates along the minibatch dimension.
        :param weight_decay: Constant multiple to decay weights by on each iteration.
        :param post_spike_weight_decay: Forwarded to the parent constructor;
            not used directly in this class.

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to some rule.
        :param torch.Tensor w: Strengths of synapses.
        :param torch.Tensor b: Target population bias.
        :param float wmin: Minimum allowed value on the connection weights.
        :param float wmax: Maximum allowed value on the connection weights.
        :param float norm: Total weight per target neuron normalization constant.
        :param ByteTensor norm_by_max: Normalize the weight of a neuron by its max weight.
        :param ByteTensor norm_by_max_with_shadow_weights: Normalize the weight of a neuron by its max weight by
                                                           original weights.
            NOTE(review): the code reads ``self.norm_by_max_from_shadow_weights``;
            the keyword name documented here may be out of date -- confirm.
        """
        super().__init__(source, target, impulse_amplitude, impulse_length, impulse_shape_factor, invert, nu, reduction, weight_decay, post_spike_weight_decay, **kwargs)

        w = kwargs.get("w", None)
        if w is None:
            # No weights supplied: draw them at random.
            if self.wmin == -np.inf or self.wmax == np.inf:
                # At least one bound is infinite: uniform [0, 1) samples,
                # clamped to the bounds.
                w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
            else:
                # Both bounds finite: uniform samples scaled to [wmin, wmax).
                w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
        else:
            # Weights supplied: clamp them when any bound is finite.
            if self.wmin != -np.inf or self.wmax != np.inf:
                w = torch.clamp(w, self.wmin, self.wmax)

                
                
                
                
                
                
        # Second positional argument of Parameter is requires_grad.
        self.w = Parameter(w, False)
        self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), False)

        if self.norm_by_max_from_shadow_weights:
            # Copies used by normalize_by_max_from_shadow_weights.
            self.shadow_w = self.w.clone().detach()
            self.prev_w = self.w.clone().detach()
    
    
    def update_impulse_state(self, s):
        """
        Advance the per-neuron impulse counters and return the impulse value.

        ``self.impulse_state`` holds, for every source neuron, how far its
        current impulse has progressed (0 means no impulse is running).

        NOTE: ``s`` and ``self.source.s`` are modified in place -- spikes of
        neurons whose impulse is still running are cleared.

        :param s: Incoming spikes; indexed with
            ``self.impulse_state.unsqueeze(0)``, so presumably shaped
            (1, n) -- confirm.
        :return: The value of ``impulse_curve()`` for the updated state.
        """
        # Advance every running impulse by one step.
        self.impulse_state += (self.impulse_state > 0).float().view(-1) # adds 1 where an impulse is already running
        # Ignore new spikes of neurons whose impulse is still running.
        s[self.impulse_state.unsqueeze(0) > 0] = 0
        self.source.s[self.impulse_state.unsqueeze(0) > 0] = 0
        # Start an impulse (state becomes 1) for idle neurons that spiked.
        self.impulse_state += (self.impulse_state == 0).float() * s.float().view(-1) # adds 1 on spikes
        impulse = self.impulse_curve()
        # An impulse that reached impulse_length is finished: reset to 0.
        self.impulse_state *= (self.impulse_state < self.impulse_length).float()
        return impulse

    def impulse_curve(self):
        """
        Return the per-step increment of the impulse for the current state.

        The returned value is a piecewise-constant function of
        ``self.impulse_state`` (zero where the state is 0). ``compute`` adds
        it to ``self.a_pre`` on every step, so it acts as the derivative of
        the impulse shape.

        With ``L = impulse_length``, ``A = impulse_amplitude`` and
        ``k = impulse_shape_factor``:

        * ``invert`` branch: ``-A / (L * (1 - k))`` while
          ``0 < state <= L * (1 - k) + 0.5``, a jump of ``+2 * A`` for the
          following unit-wide interval, then ``-A / (L * k - 1)``.
        * normal branch: ``A / (L - 2) / k`` while ``0 < state < L * k`` and
          that value divided by ``(1 - k) / k`` once ``state > L * k``, with
          ``2 * A`` subtracted around ``state == L * k``.
        """
        
        k = self.impulse_shape_factor
        
        if self.invert:
            
            impulse_value_2 = self.impulse_amplitude/(self.impulse_length*k - 1) # derivative at points that are not in the middle of the impulse
        
            impulse_value_1 = self.impulse_amplitude/(self.impulse_length*(1-k))
        
            # One-off jump of +2*A in the unit-wide window after L*(1-k)+0.5.
            impulse_bias =  2*self.impulse_amplitude *(self.impulse_state > (self.impulse_length * (1-k)+ 0.5)).float().view(-1) *(self.impulse_state <= (self.impulse_length * (1-k)+1.5)).float().view(-1)
                                               
            # Slope before the jump, slope after the jump, plus the jump itself.
            impulse = (-impulse_value_1) * (self.impulse_state > 0).float().view(-1) * (self.impulse_state <= (self.impulse_length * (1-k) + 0.5)).float().view(-1) + (-impulse_value_2)  * (self.impulse_state > (self.impulse_length * (1-k)+1.5)).float().view(-1) + impulse_bias   
        
            return impulse
            
        else:
            
            impulse_value = self.impulse_amplitude/(self.impulse_length - 2 )/k # derivative at points that are not in the middle of the impulse
    
            # 2*A where the state is within 0.5 of L*k (exactly 0.5 away
            # counts only on the side below L*k).
            impulse_bias =  (2*self.impulse_amplitude *(abs(self.impulse_state - (self.impulse_length * k)) < 0.5).float().view(-1) + 2*self.impulse_amplitude*(abs(self.impulse_state - (self.impulse_length * k)) == 0.5).float().view(-1)*(self.impulse_state < self.impulse_length * k).float().view(-1))
                                               
            # Slope before L*k, rescaled slope after L*k, minus the jump.
            impulse = impulse_value * (self.impulse_state > 0).float().view(-1) * (self.impulse_state < self.impulse_length * k).float().view(-1) + impulse_value/((1-k)/k)  * (self.impulse_state > self.impulse_length * k).float().view(-1) - impulse_bias
        
        
            return impulse
      
    
    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        The incoming spikes first advance the impulse state; the resulting
        impulse increment is accumulated into ``self.a_pre``, which is
        zeroed for neurons without a running impulse and then multiplied by
        the weight matrix. The bias ``self.b`` is not applied here.

        :param s: Incoming spikes (modified in place by
            ``update_impulse_state``).
        :return: Accumulated impulse activations multiplied by synaptic
            weights, reshaped to the target layer's shape.
        """
        self.a_pre += self.update_impulse_state(s)

        # Keep the trace only where an impulse is currently running.
        self.a_pre *= (self.impulse_state > 0).float()

        # Compute multiplication of spike activations by connection weights.
        return (self.a_pre @ self.w).view(*self.target.shape)



    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule.

        Delegates to the parent class's ``update``, forwarding all keyword
        arguments unchanged.
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        # language=rst
        """
        Rescale ``self.w`` in place so that the absolute weights feeding each
        target neuron (summed over dim 0) add up to ``self.norm``. Does
        nothing when ``self.norm`` is ``None``.
        """
        if self.norm is None:
            return

        column_mass = self.w.abs().sum(0, keepdim=True)
        # Avoid dividing by zero for targets with all-zero weights.
        column_mass.masked_fill_(column_mass == 0, 1.0)
        self.w *= self.norm / column_mass

    def normalize_by_max(self) -> None:
        # language=rst
        """
        Divide ``self.w`` in place by the largest absolute weight of each
        target neuron (maximum over dim 0). Does nothing unless
        ``self.norm_by_max`` is truthy.
        """
        if not self.norm_by_max:
            return

        column_peak = self.w.abs().max(dim=0).values
        # Avoid dividing by zero for targets with all-zero weights.
        column_peak.masked_fill_(column_peak == 0, 1.0)
        self.w /= column_peak

    def normalize_by_max_from_shadow_weights(self) -> None:
        # language=rst
        """
        Normalize weights by the max weight of the target neuron, using
        un-normalized "shadow" weights as the source of truth.

        The change applied to ``self.w`` since the previous call is first
        accumulated into ``self.shadow_w``; ``self.w`` is then overwritten
        with the shadow weights divided by their per-target maximum absolute
        value (maximum over dim 0). Does nothing unless
        ``self.norm_by_max_from_shadow_weights`` is truthy.

        The result is copied into ``self.w`` in place, so ``self.w`` remains
        the parameter object created in ``__init__``. Rebinding the
        attribute to a plain tensor would replace that parameter (and is
        rejected with a ``TypeError`` by ``torch.nn.Module`` if the parent
        class is one).
        """
        if not self.norm_by_max_from_shadow_weights:
            return

        # Fold the latest weight change into the shadow copy.
        self.shadow_w += self.w - self.prev_w
        w_max = self.shadow_w.abs().max(0)[0]
        # Avoid dividing by zero for targets with all-zero weights.
        w_max[w_max == 0] = 1.0
        self.w.copy_(self.shadow_w / w_max)
        self.prev_w = self.w.clone().detach()

    def reset_(self) -> None:
        # language=rst
        """
        Reset the connection: run the parent's reset logic, then replace
        ``a_pre`` and ``impulse_state`` with fresh zero tensors of the same
        shape, dtype and device.
        """
        super().reset_()
        for attr in ("a_pre", "impulse_state"):
            setattr(self, attr, torch.zeros_like(getattr(self, attr)))
示例#18
0
class VonMisesFisherReparametrizedSample(nn.Module):
    """
    Reparametrized sampler for a von Mises-Fisher (vMF) distribution.

    Learnable state:

    * ``loc`` -- shape ``batch_shape + event_shape``; rescaled to unit L2 norm
      along the last axis wherever this class normalizes it.
    * ``softplus_inv_concentration`` -- shape ``batch_shape``; mapped through
      ``softplus`` to obtain the concentration used for sampling.

    ``forward(..., sample=True)`` draws the first coordinate by rejection
    sampling, draws a uniform unit direction for the remaining ``dim - 1``
    coordinates and reflects the result with a Householder transformation built
    from ``loc``. ``gradient_correction`` then adds an extra term to
    ``softplus_inv_concentration.grad`` using the values cached by the sampler.

    The class relies on names defined elsewhere in the project: ``ml_kappa``,
    ``softplus``, ``softplus_inv``, ``softplus_inv_derivative``,
    ``sqrt_taylor_approximation``, ``kaiming_transpose`` and ``EPSILON``.
    """

    def __init__(self, batch_shape, event_shape, eps):
        """
        :param batch_shape: Number or ``torch.Size``; a number is wrapped into a
            one-element ``torch.Size``.
        :param event_shape: Number or ``torch.Size`` of length one; its single
            entry is the dimension of the sampled vectors.
        :param eps: Forwarded to ``ml_kappa`` to obtain the initial concentration.
        """
        # Set up the nn.Module machinery before any attribute is assigned.
        super(VonMisesFisherReparametrizedSample, self).__init__()
        if isinstance(batch_shape, Number):
            batch_shape = torch.Size([batch_shape])
        self.batch_shape = batch_shape
        if isinstance(event_shape, Number):
            event_shape = torch.Size([event_shape])
        self.event_shape = event_shape
        assert len(event_shape) == 1
        self.dim = int(event_shape[0])
        self.loc = Parameter(torch.Tensor(batch_shape + event_shape))
        nn.init.kaiming_normal_(self.loc)
        # Project the freshly initialized location onto the unit sphere.
        self.loc.data /= torch.sum(self.loc.data ** 2, dim=-1, keepdim=True) ** 0.5
        concentration_init = ml_kappa(dim=float(event_shape[0]), eps=eps)
        print('concentration init', concentration_init)
        # uniform_ with identical bounds fills the tensor with that constant.
        softplus_inv_init = softplus_inv(concentration_init)
        self.softplus_inv_concentration = Parameter(torch.Tensor(batch_shape).uniform_(softplus_inv_init, softplus_inv_init))
        # Too large kappa slows down rejection sampling, so an upper bound is kept;
        # it is applied in parameter_adjustment().
        self.softplus_inv_concentration_upper_bound = softplus_inv(ml_kappa(dim=float(event_shape[0]), eps=2e-3))
        # Caches written by _rejection_sampling and read by gradient_correction.
        self.beta_sample = None
        self.concentration = None
        self.gradient_correction_required = True
        # Mean / std used by reset_parameters to re-draw softplus_inv_concentration.
        self.softplus_inv_concentration_normal_mean = softplus_inv(ml_kappa(dim=float(event_shape[0]), eps=EPSILON))
        self.softplus_inv_concentration_normal_std = 0.001
        self.direction_init_method = None
        self.rsample = None
        self.loc_init_type = 'random'

    def forward(self, sample_shape, sample=False):
        """
        Return the location or a reparametrized sample.

        :param sample_shape: Number or ``torch.Size``; a number is wrapped into a
            one-element ``torch.Size``.
        :param sample: When falsy, ``sample_shape`` must equal ``torch.Size([1])``
            and ``self.loc`` with a new leading axis is returned as stored (it is
            not re-normalized here). When truthy, a sample is drawn.
        :return: Tensor with ``sample_shape`` leading axes followed by the shape of
            ``self.loc``.
        """
        if isinstance(sample_shape, Number):
            sample_shape = torch.Size([sample_shape])
        if not sample:
            assert sample_shape == torch.Size([1])
            return self.loc.unsqueeze(0)
        w_sample = self._rejection_sampling(sample_shape).unsqueeze(-1)
        assert (torch.abs(w_sample) < 1).all()
        spherical_section_sample = self._spherical_section_uniform_sampling(sample_shape)
        # First coordinate is w; the rest is a unit direction scaled by sqrt(1 - w^2),
        # so the concatenated vector has unit norm.
        rsample_concentration = torch.cat([w_sample, (1 - w_sample ** 2) ** 0.5 * spherical_section_sample], dim=-1)
        rsample = self._householder_transformation(rsample_concentration, sample_shape)
        return rsample

    def sample_kld(self):
        """Placeholder; does nothing and returns ``None``."""
        pass

    def mode(self):
        """Return ``self.loc`` rescaled to unit L2 norm along the last axis."""
        return self.loc / torch.sum(self.loc ** 2, dim=-1, keepdim=True) ** 0.5

    def parameter_adjustment(self):
        """
        Re-normalize ``loc`` in place and clamp ``softplus_inv_concentration`` from
        above at ``softplus_inv_concentration_upper_bound``.
        """
        self.loc.data /= torch.sum(self.loc.data ** 2, dim=-1, keepdim=True) ** 0.5
        self.softplus_inv_concentration.data = self.softplus_inv_concentration.data.clamp(max=self.softplus_inv_concentration_upper_bound)

    def reset_parameters(self, hyperparams=None):
        """
        Re-initialize the parameters, optionally steered by ``hyperparams['vMF']``.

        Recognized keys of ``hyperparams['vMF']``:

        * ``'direction'`` -- ``'kaiming'``, ``'kaiming_transpose'`` or
          ``'orthogonal'`` to re-draw ``loc`` with that initializer, or a tensor
          that is copied into ``loc``. ``loc`` is normalized afterwards.
        * ``'softplus_inv_concentration_normal_mean'`` -- overrides the stored mean.
        * ``'softplus_inv_concentration_normal_mean_via_epsilon'`` -- overrides the
          stored mean with ``softplus_inv(ml_kappa(dim, eps))``.
        * ``'softplus_inv_concentration_normal_std'`` -- overrides the stored std.

        ``softplus_inv_concentration`` is always re-drawn from a normal
        distribution and ``self.p_loc`` is set to a detached copy of ``loc``.

        :param hyperparams: Optional mapping; ``None`` is treated as empty.
        :raises NotImplementedError: If ``'direction'`` is an unknown string or is
            neither a string nor a tensor.
        """
        if hyperparams is None:
            hyperparams = {}
        if 'vMF' in hyperparams:
            vmf_hyperparams = hyperparams['vMF']
            if 'direction' in vmf_hyperparams:
                direction = vmf_hyperparams['direction']
                if isinstance(direction, str):
                    # Kept as an if/elif chain so project-level initializers are
                    # only looked up when they are actually requested.
                    if direction == 'kaiming':
                        self.direction_init_method = torch.nn.init.kaiming_normal_
                    elif direction == 'kaiming_transpose':
                        self.direction_init_method = kaiming_transpose
                    elif direction == 'orthogonal':
                        self.direction_init_method = torch.nn.init.orthogonal_
                    else:
                        # Previously an unknown name silently reused the last
                        # initializer (or tried to call None).
                        raise NotImplementedError('Unknown vMF direction initialization: %r' % (direction,))
                    self.direction_init_method(self.loc)
                    self.loc_init_type = 'random'
                elif isinstance(direction, torch.Tensor):
                    self.loc.data.copy_(direction)
                    self.loc_init_type = 'fixed'
                else:
                    raise NotImplementedError

                self.loc.data /= torch.sum(self.loc.data ** 2, dim=-1, keepdim=True) ** 0.5

            if 'softplus_inv_concentration_normal_mean' in vmf_hyperparams:
                self.softplus_inv_concentration_normal_mean = vmf_hyperparams['softplus_inv_concentration_normal_mean']
            if 'softplus_inv_concentration_normal_mean_via_epsilon' in vmf_hyperparams:
                epsilon = vmf_hyperparams['softplus_inv_concentration_normal_mean_via_epsilon']
                self.softplus_inv_concentration_normal_mean = softplus_inv(ml_kappa(dim=float(self.event_shape[0]), eps=epsilon))
            if 'softplus_inv_concentration_normal_std' in vmf_hyperparams:
                self.softplus_inv_concentration_normal_std = vmf_hyperparams['softplus_inv_concentration_normal_std']
        # NOTE(review): the draw is centred at twice the stored mean, whereas
        # init_hyperparams_repr reports the stored mean itself -- confirm the factor
        # of 2 is intended.
        torch.nn.init.normal_(self.softplus_inv_concentration, self.softplus_inv_concentration_normal_mean*2, self.softplus_inv_concentration_normal_std)
        self.p_loc = self.loc.clone().detach().requires_grad_(False)

    def init_hyperparams_repr(self):
        """
        Describe how the parameters were initialized.

        :return: String naming the location initialization followed by the stored
            mean and std of the ``softplus_inv_concentration`` initializer.
        :raises NotImplementedError: If ``loc_init_type`` is neither ``'random'``
            nor ``'fixed'``.
        """
        if self.loc_init_type == 'random':
            init_method = self.direction_init_method
            if init_method is None:
                # reset_parameters never chose an initializer, so loc still carries
                # the kaiming_normal_ initialization applied in __init__.
                init_method = torch.nn.init.kaiming_normal_
            loc_init_str = 'location ' + init_method.__name__
        elif self.loc_init_type == 'fixed':
            loc_init_str = 'location fixed'
        else:
            raise NotImplementedError
        return '%s, Softplus inverse(scale)~Normal(%.2E, %.2E)' % (loc_init_str, self.softplus_inv_concentration_normal_mean, self.softplus_inv_concentration_normal_std)

    def _rejection_sampling(self, sample_shape):
        """
        Draw the first coordinate of the vMF sample by rejection sampling.

        The concentration is ``softplus(self.softplus_inv_concentration)``. As a
        side effect the accepted Beta draws and the tiled concentration are cached
        in ``self.beta_sample`` and ``self.concentration`` for
        ``gradient_correction``.

        :param sample_shape: ``torch.Size`` of leading sample axes.
        :return: Tensor of shape ``sample_shape + concentration.size()``, clamped
            from below at ``-1 + 1e-7``.
        :raises RuntimeError: If the intermediate quantity ``d`` contains inf or NaN.
        """
        concentration = softplus(self.softplus_inv_concentration)
        beta_param = 0.5 * (self.dim - 1)
        flattened_concentration = concentration.repeat(sample_shape + torch.Size([1] * concentration.dim())).view(-1)
        sqrt = (4 * flattened_concentration ** 2 + (self.dim - 1) ** 2) ** 0.5

        b = (-2 * flattened_concentration + sqrt) / (self.dim - 1)
        # When concentration is too large compared to dim, b is zero due to
        # underflow; in that case a Taylor approximation is used. "bad b" means
        # b == 0 or b is inf.
        # NOTE(review): the exponent 0.2 is kept exactly as found; confirm it
        # against the definition of sqrt_taylor_approximation.
        bad_b_ind = torch.isinf(b.detach()) + (b.detach() == 0)
        b[bad_b_ind] = sqrt_taylor_approximation(((self.dim - 1) / (2.0 * flattened_concentration[bad_b_ind])) ** 0.2) * 2.0 * flattened_concentration[bad_b_ind] / (self.dim - 1)

        a = (self.dim - 1 + 2 * flattened_concentration + sqrt) / 4.0
        # Squaring the concentration inside sqrt may overflow to inf.
        inf_a_ind = torch.isinf(a.detach())
        a[inf_a_ind] = ((sqrt_taylor_approximation(((self.dim - 1) / (2.0 * flattened_concentration[inf_a_ind])) ** 0.2) + 2) * 2 * flattened_concentration[inf_a_ind] + self.dim - 1) / 4.0

        d = 4 * a * b / (1 + b) - (self.dim - 1) * math.log(self.dim - 1)
        # Mask of entries still waiting for an accepted proposal.
        rejected = torch.ones_like(flattened_concentration).byte()
        beta_sample = flattened_concentration.new(flattened_concentration.size())
        if torch.isinf(d).any():
            raise RuntimeError('w sampling in vMF generates infinite d')
        if (d != d).any():
            print(flattened_concentration[d != d])
            print(sqrt[d != d])
            print(b[d != d])
            raise RuntimeError('w sampling in vMF generates nan d')
        while rejected.any():
            rejected_ind = rejected.nonzero().squeeze(1)
            beta_param_tensor = concentration.new_full((int(torch.sum(rejected)),), beta_param)
            eps_beta = Beta(beta_param_tensor, beta_param_tensor).rsample()
            eps_unif = torch.rand_like(eps_beta)
            b_sub = b[rejected]
            a_sub = a[rejected]
            d_sub = d[rejected]
            t_sub = 2 * a_sub * b_sub / (1 - (1 - b_sub) * eps_beta)
            # Acceptance test, carried out in log space.
            criteria_sub = (self.dim - 1) * torch.log(t_sub) - t_sub + d_sub >= torch.log(eps_unif)
            beta_sample[rejected_ind[criteria_sub]] = eps_beta[criteria_sub]
            rejected[rejected_ind[criteria_sub]] = 0
        output_shape = torch.Size(sample_shape + concentration.size())
        self.beta_sample = beta_sample.reshape(output_shape)
        self.concentration = flattened_concentration.reshape(output_shape)
        b_reshaped = b.reshape(output_shape)
        vmf_sample = (1 - (1 + b_reshaped) * self.beta_sample) \
                     / (1 - (1 - b_reshaped) * self.beta_sample)
        return vmf_sample.clamp(min=-1.0 + 1e-7)

    def _spherical_section_uniform_sampling(self, sample_shape):
        """
        Draw unit vectors of length ``dim - 1`` by normalizing standard normal
        draws.

        :param sample_shape: ``torch.Size`` of leading sample axes.
        :return: Tensor of shape ``sample_shape + batch_shape + (dim - 1,)``.
        """
        shape = torch.Size(sample_shape + self.batch_shape + torch.Size([self.dim - 1]))
        sample = self.loc.new(shape).normal_()
        return sample / (sample ** 2).sum(dim=-1, keepdim=True) ** 0.5

    def _householder_transformation(self, rsample_concentration, sample_shape):
        """
        Reflect samples with the Householder transformation whose axis is the
        normalized difference between the first basis vector and the normalized
        ``loc``.

        :param rsample_concentration: Samples with ``sample_shape`` leading axes
            followed by the shape of ``loc``.
        :param sample_shape: ``torch.Size`` of leading sample axes.
        :return: Reflected samples, same shape as ``rsample_concentration``.
        """
        loc = self.loc / torch.sum(self.loc ** 2, dim=-1, keepdim=True) ** 0.5
        # e1 - loc, where e1 has a one at index 0 of the last axis.
        u_prime = torch.zeros_like(loc).index_fill_(dim=-1, index=loc.new_zeros((1,)).long(), value=1) - loc
        u = (u_prime / (u_prime ** 2).sum(dim=-1, keepdim=True) ** 0.5).repeat(torch.Size(sample_shape + torch.Size([1] * loc.dim())))
        return rsample_concentration - 2 * (rsample_concentration * u).sum(dim=-1, keepdim=True) * u

    def gradient_correction(self, integrand):
        """
        Add a correction term to ``self.softplus_inv_concentration.grad``.

        Uses ``self.beta_sample`` and ``self.concentration`` cached by
        ``_rejection_sampling`` and reads ``softplus_inv_concentration.grad``, so a
        sample must have been drawn and the gradient populated beforehand.

        :param integrand: Tensor multiplied (detached) into the correction; it is
            broadcast against the cached concentration.
        :raises RuntimeError: If the correction is infinite, if a NaN correction
            can be traced to one of its inputs, or if the existing gradient
            contains NaN.
        """
        dim = self.dim
        nu = float(dim) / 2.0
        beta_sample = self.beta_sample.detach()
        concentration = self.concentration.detach()
        concentration.requires_grad_()
        correction_derivative = self._correction_derivative(concentration, beta_sample, dim)
        correction_deriv_grad = grad(correction_derivative, concentration, grad_outputs=torch.ones_like(correction_derivative))[0]
        concentration.requires_grad_(False)

        # Bessel-function ratio bounds from Thm5 of "A new type of sharp bounds for
        # ratios of modified Bessel functions"; per the original author these work
        # better with larger kappa (less uncertainty).
        concentration_sq = concentration ** 2
        lambda0 = nu - 0.5
        delta0 = (nu - 0.5) + lambda0 / (lambda0 ** 2 + concentration_sq) ** 0.5 / 2.0
        bessel_ratio_upper = concentration / (delta0 + (delta0 ** 2 + concentration_sq) ** 0.5)
        lambda2 = nu + 0.5
        delta2 = (nu - 0.5) + lambda2 / (lambda2 ** 2 + concentration_sq) ** 0.5 / 2.0
        bessel_ratio_lower = concentration / (delta2 + (delta2 ** 2 + concentration_sq) ** 0.5)

        # The midpoint of the two bounds stands in for the ratio itself.
        correction_bessel_ratio = -(bessel_ratio_upper + bessel_ratio_lower) / 2.0
        # Sum over the leading sample axes, leaving the batch axes.
        correction = (integrand.detach() * (correction_bessel_ratio + correction_deriv_grad) * softplus_inv_derivative(concentration)).sum(dim=tuple(range(concentration.dim() - len(self.batch_shape))))
        if torch.isinf(correction).any():
            raise RuntimeError('vMF gradient correction is infinite. Check the argument of gradient_correction.')
        if (correction != correction).any():
            if (concentration != concentration).any():
                raise RuntimeError('vMF gradient correction is nan due to concentration.')
            elif (correction_bessel_ratio != correction_bessel_ratio).any():
                raise RuntimeError('vMF gradient correction is nan due to correction bessel ratio.')
            elif (correction_deriv_grad != correction_deriv_grad).any():
                raise RuntimeError('vMF gradient correction is nan due to correction derivative gradient.')
            elif (integrand != integrand).any():
                raise RuntimeError('vMF gradient correction is nan due to integrand.')
        if (self.softplus_inv_concentration.grad.data != self.softplus_inv_concentration.grad.data).any():
            raise RuntimeError('vMF backward is nan.')

        # The correction is reduced to a scalar and added to every gradient entry.
        self.softplus_inv_concentration.grad.data += correction.sum()

    @staticmethod
    def _correction_derivative(concentration, beta_sample, dim):
        """
        Evaluate the quantity that ``gradient_correction`` differentiates with
        respect to ``concentration``.

        :param concentration: Concentration tensor.
        :param beta_sample: Accepted Beta draws, same shape as ``concentration``.
        :param dim: Dimension of the sampled vectors.
        :return: Tensor of the same shape as ``concentration``.
        :raises RuntimeError: If ``|w|`` equals one after clamping.
        """
        b = (-2 * concentration + (4 * concentration ** 2 + (dim - 1) ** 2) ** 0.5) / (dim - 1)
        # When concentration is too large compared to dim, b is zero due to
        # underflow; in that case a Taylor approximation is used.
        # NOTE(review): exponent 0.2 kept exactly as found; confirm against
        # sqrt_taylor_approximation.
        bad_b_ind = torch.isinf(b.detach()) + (b.detach() == 0)
        b[bad_b_ind] = sqrt_taylor_approximation(((dim - 1) / (2.0 * concentration[bad_b_ind])) ** 0.2) * 2.0 * concentration[bad_b_ind] / (dim - 1)
        w = (1 - (1 + b) * beta_sample) / (1 - (1 - b) * beta_sample)
        # Entries where |w| is exactly 1 are replaced before the clamp and the
        # log(1 - w^2) term below.
        one_w_ind = torch.abs(w.detach()) == 1
        b_reciprocal = 1.0 / b[one_w_ind]
        w[one_w_ind] = -(1 + 2.0 * b_reciprocal + 2.0 * b_reciprocal ** 2)
        w = w.clamp(min=-1 + 1e-7)
        if (torch.abs(w) == 1).any():
            raise RuntimeError('vMF gradient derivative is nan due to w.')
        return w * concentration + (dim - 3.0) / 2.0 * torch.log(1 - w ** 2) + torch.log(torch.abs(2 * b / ((b - 1) * beta_sample + 1) ** 2))
示例#19
0
class LSTM(nn.Module):
    """
    An LSTM followed by a sigmoid-activated linear read-out.

    The recurrent state is kept on the module between calls to ``forward``; call
    ``init_sequence`` at the start of every sequence to reset it to the learned
    initial hidden and cell states.
    """
    def __init__(self, num_inputs, num_outputs, num_hid, num_layers):
        """
        :param num_inputs: Size of each input vector.
        :param num_outputs: Size of each output vector (after the linear layer).
        :param num_hid: Hidden size of the LSTM.
        :param num_layers: Number of stacked LSTM layers.
        """
        super(LSTM, self).__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_hid = num_hid
        self.num_layers = num_layers
        self.prev_state = None
        # Set by init_sequence; needed by forward() when no input is supplied.
        self.batch_size = None

        print('!!! USING VANILLA LSTM !!!')
        self.lstm = nn.LSTM(input_size=self.num_inputs,
                            hidden_size=self.num_hid,
                            num_layers=self.num_layers)
        self.fc = nn.Linear(self.num_hid, self.num_outputs, bias=True)

        # The initial hidden and cell states are learned parameters.
        self.lstm_h_bias = Parameter(
            torch.randn(self.num_layers, 1, self.num_hid) * 0.05)
        self.lstm_c_bias = Parameter(
            torch.randn(self.num_layers, 1, self.num_hid) * 0.05)

        self.reset_parameters()

    def create_new_state(self, batch_size):
        """
        Tile the learned initial state across a batch.

        :param batch_size: Number of sequences in the batch.
        :return: ``(h, c)``, each of shape ``(num_layers, batch_size, num_hid)``.
        """
        # Dimension: (num_layers * num_directions, batch, hidden_size)
        lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c

    def init_sequence(self, batch_size):
        """Initializing the state."""
        self.batch_size = batch_size
        self.prev_state = self.create_new_state(batch_size)

    def reset_parameters(self):
        """
        Re-initialize the LSTM and linear layer: 1-D parameters (biases) are set to
        zero, all others are drawn uniformly from ``[-stdev, stdev]`` with
        ``stdev = 5 / sqrt(num_inputs + num_outputs)``.
        """
        stdev = 5 / (np.sqrt(self.num_inputs + self.num_outputs))
        # LSTM first, then the linear layer, so random draws happen in the same
        # order as before.
        for module in (self.lstm, self.fc):
            for p in module.parameters():
                if p.dim() == 1:
                    nn.init.constant_(p, 0)
                else:
                    nn.init.uniform_(p, -stdev, stdev)

    def size(self):
        """Return ``(num_inputs, num_outputs)``."""
        return self.num_inputs, self.num_outputs

    def forward(self, x=None):
        """
        Advance the network by one time step.

        :param x: Input of shape ``(batch_size, num_inputs)``. When ``None``, an
            all-zero input matching the dtype and device of the learned state is
            used; this requires ``init_sequence`` to have been called.
        :return: ``(output, state)`` where ``output`` has shape
            ``(batch_size, num_outputs)`` with values in ``[0, 1]`` and ``state`` is
            the updated ``(h, c)`` pair, which is also stored on the module.
        :raises RuntimeError: If ``x`` is ``None`` and ``init_sequence`` was never
            called.
        """
        if x is None:
            if self.batch_size is None:
                raise RuntimeError(
                    'init_sequence() must be called before forward() without an input.')
            # new_zeros follows the parameters' dtype/device, so the default input
            # stays valid after .double() or .to(device).
            x = self.lstm_h_bias.new_zeros(self.batch_size, self.num_inputs)

        # nn.LSTM expects a leading time axis; this is a single step.
        x = x.unsqueeze(0)
        o, self.prev_state = self.lstm(x, self.prev_state)
        o = self.fc(o)
        o = torch.sigmoid(o)

        return o.squeeze(0), self.prev_state

    def calculate_num_params(self):
        """Returns the total number of parameters."""
        return sum(p.numel() for p in self.parameters())