# Common imports assumed by the snippets below (each snippet originally comes from a different project).
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter


class controller(nn.Module):  # LSTM Controller
    def __init__(self, num_inputs, num_outputs, num_layers):
        super(controller, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_layers = num_layers
        self.lstm_network = nn.LSTM(input_size=self.num_inputs,
                                    hidden_size=self.num_outputs,
                                    num_layers=self.num_layers)

        # Parameters of the LSTM. The hidden state serves as the output of the network.
        self.h_init = Parameter(torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)  # Hidden state initialization
        self.c_init = Parameter(torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)  # Cell state initialization

        # Initialization of the LSTM parameters.
        for p in self.lstm_network.parameters():
            if p.dim() == 1:
                nn.init.constant_(p, 0)
            else:
                # Scaled uniform initialization (the factor of 5 follows common NTM implementations)
                stdev = 5 / (np.sqrt(self.num_inputs + self.num_outputs))
                nn.init.uniform_(p, -stdev, stdev)

    def create_hidden_state(self, batch_size):
        # Output: (num_layers x batch_size x num_outputs)
        h = self.h_init.clone().repeat(1, batch_size, 1)
        c = self.c_init.clone().repeat(1, batch_size, 1)
        return h, c

    def network_size(self):
        return self.num_inputs, self.num_outputs

    def forward(self, inp, prev_state):
        inp = inp.unsqueeze(0)  # inp dimension after unsqueeze: (1, batch_size, num_inputs)
        output, state = self.lstm_network(inp, prev_state)
        return output.squeeze(0), state
class Controller(nn.Module):
    def __init__(self, input_size, output_size, num_layers):
        super(Controller, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=output_size,
                            num_layers=num_layers)
        self.reset()
        self.lstm_h_bias = Parameter(torch.randn(self.num_layers, 1, self.output_size) * 0.05)
        self.lstm_c_bias = Parameter(torch.randn(self.num_layers, 1, self.output_size) * 0.05)

    def new_init_state(self, batch_size):
        # Dimension: (num_layers * num_directions, batch, hidden_size)
        lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c

    def reset(self):
        for p in self.lstm.parameters():
            if p.dim() == 1:
                nn.init.constant_(p, 0)
            else:
                stdev = 5 / (np.sqrt(self.input_size + self.output_size))
                nn.init.uniform_(p, -stdev, stdev)

    def forward(self, x, prev_state):
        out, state = self.lstm(x.unsqueeze(0), prev_state)
        return out.squeeze(0), state
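# Minimal usage sketch for the Controller above: the learned initial state is
# expanded to the batch size and then threaded through one step at a time.
# The batch size, input width and sequence length are illustrative assumptions,
# not part of the original code.
batch_size, seq_len = 4, 10
ctrl = Controller(input_size=8, output_size=16, num_layers=1)

state = ctrl.new_init_state(batch_size)      # learned (h, c), each of shape (1, 4, 16)
inputs = torch.randn(seq_len, batch_size, 8)

for t in range(seq_len):
    out, state = ctrl(inputs[t], state)      # out: (4, 16)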
class LSTMController(nn.Module):
    def __init__(self, input_size, output_size, controller_size, read_data_size, num_outputs, outp_layer_size):
        super(LSTMController, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.num_layers = controller_size
        self.num_outputs = output_size
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=output_size)
        self.output_layer = nn.Linear(outp_layer_size, num_outputs)

        # Note: create_new_state() below assumes the LSTM has self.num_layers layers;
        # nn.LSTM defaults to a single layer unless num_layers is passed explicitly.
        self.lstm_h_bias = Parameter(
            torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)
        self.lstm_c_bias = Parameter(
            torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)
        self.h_state = torch.zeros([1, output_size])

    def forward(self, inputs, hidden=None):
        inputs = inputs.unsqueeze(0)
        outp, state = self.lstm(inputs, hidden)
        self.h_state = outp.squeeze(0)
        return self.h_state, state

    def output(self, read_data):
        end_state = read_data
        # torch.nn.functional.sigmoid is deprecated; use torch.sigmoid instead
        output = torch.sigmoid(self.output_layer(end_state))
        return output

    def size(self):
        return self.input_size, self.output_size

    def create_new_state(self, batch_size):
        # Dimension: (num_layers * num_directions, batch, hidden_size)
        lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c
class Controller(StatefulComponent):
    def __init__(self, embedding_size, hidden_size, dictionary_size=None):
        # Configurations
        super().__init__()
        self.dictionary_size = dictionary_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size

        # Embedding layer (optional)
        self.embedding = nn.Embedding(
            self.dictionary_size, self.embedding_size) if dictionary_size else None

        # LSTM cell to extract features from input
        self.cell = nn.LSTMCell(self.embedding_size, self.hidden_size)

        # Learnable LSTM hidden state biases
        self.h_bias = Parameter(Tensor(self.hidden_size).normal_())
        self.c_bias = Parameter(Tensor(self.hidden_size).normal_())

        # States
        self.h = None
        self.c = None

    def forward(self, x, device=None):
        # Only infer the device from x when x is actually given; otherwise
        # x.device would raise in the no-input case below.
        if device is None and x is not None:
            device = x.device

        # supply the lstm cell with zeros in the embedding space (no input);
        # the deprecated torch.autograd.Variable wrapper is no longer needed
        if x is None:
            e = torch.zeros(
                self.expected_batch_size,
                self.embedding_size,
            ).type_as(self.h.data)
            if device is not None:
                e = e.to(device)
        # supply the lstm cell with an embedded input
        elif self.embedding:
            e = self.embedding(x)
        # supply the lstm cell with the input as-is (assuming that the input
        # is already in the embedding space)
        else:
            assert x.size() == (
                self.expected_batch_size, self.embedding_size
            ), 'Input should have size of {b}x{e}, while given {s}.'.format(
                b=self.expected_batch_size,
                e=self.embedding_size,
                s=x.size(),
            )
            e = x

        # run an lstm cell and update the states
        self.h, self.c = self.cell(e, (self.h, self.c))
        return self.h

    def reset(self, batch_size, device=None):
        super().reset(batch_size)
        self.h = self.h_bias.clone().repeat(batch_size, 1)
        self.c = self.c_bias.clone().repeat(batch_size, 1)
        if device is not None:
            self.h = self.h.to(device)
            self.c = self.c.to(device)
class LSTMBaseline(nn.Module):
    """An NTM controller based on LSTM."""

    def __init__(self, num_inputs, num_hidden, num_outputs, num_layers):
        super(LSTMBaseline, self).__init__()
        self.num_inputs = num_inputs
        self.num_hidden = num_hidden
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size=num_inputs,
                            hidden_size=num_hidden,
                            num_layers=num_layers)
        self.out = nn.Linear(num_hidden, num_outputs)

        # The hidden state is a learned parameter
        if torch.cuda.is_available():
            self.lstm_h_bias = Parameter(torch.randn(self.num_layers, 1, self.num_hidden).cuda() * 0.05)
            self.lstm_c_bias = Parameter(torch.randn(self.num_layers, 1, self.num_hidden).cuda() * 0.05)
        else:
            self.lstm_h_bias = Parameter(torch.randn(self.num_layers, 1, self.num_hidden) * 0.05)
            self.lstm_c_bias = Parameter(torch.randn(self.num_layers, 1, self.num_hidden) * 0.05)

        self.reset_parameters()

    def create_new_state(self, batch_size):
        # Dimension: (num_layers * num_directions, batch, hidden_size)
        lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c

    def reset_parameters(self):
        for p in self.lstm.parameters():
            if p.dim() == 1:
                nn.init.constant_(p, 0)
            else:
                stdev = 5 / (np.sqrt(self.num_inputs + self.num_hidden))
                nn.init.uniform_(p, -stdev, stdev)

    def init_sequence(self, batch_size):
        """Initializing the state."""
        self.previous_state = self.create_new_state(batch_size)

    def size(self):
        return self.num_inputs, self.num_hidden

    def forward(self, x):
        x = x.unsqueeze(0)
        outp, self.previous_state = self.lstm(x, self.previous_state)
        outp = self.out(outp)
        return outp.squeeze(0), self.previous_state

    def calculate_num_params(self):
        """Returns the total number of parameters."""
        num_params = 0
        for p in self.parameters():
            num_params += p.data.view(-1).size(0)
        return num_params
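# A small CPU-only smoke test for LSTMBaseline, assuming the module is driven one
# time-step at a time as in NTM-style training loops. The sizes are made up, and
# the test assumes no CUDA device (the class pins its initial state to the GPU
# when one is available, while the LSTM weights stay on the CPU).
model = LSTMBaseline(num_inputs=8, num_hidden=100, num_outputs=8, num_layers=1)
model.init_sequence(batch_size=4)

x = torch.randn(4, 8)
y, _ = model(x)                      # y: (4, 8), raw linear outputs
print(model.calculate_num_params())  # total number of trainable parameters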
class backward_controller(nn.Module):  # Backward LSTM to make the DNC bidirectional
    def __init__(self, num_inputs, num_outputs, num_layers):
        super(backward_controller, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_layers = num_layers
        self.lstm_network = nn.LSTM(input_size=self.num_inputs,
                                    hidden_size=self.num_outputs,
                                    num_layers=self.num_layers)

        # Parameters of the LSTM. The hidden state serves as the output of the network.
        self.h_init = Parameter(
            torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)  # Hidden state initialization
        self.c_init = Parameter(
            torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)  # Cell state initialization

        # Initialization of the LSTM parameters.
        for p in self.lstm_network.parameters():
            if p.dim() == 1:
                nn.init.constant_(p, 0)
            else:
                # Scaled uniform initialization (the factor of 5 follows common NTM implementations)
                stdev = 5 / (np.sqrt(self.num_inputs + self.num_outputs))
                nn.init.uniform_(p, -stdev, stdev)

    def create_hidden_state(self, batch_size):
        # Output: (num_layers x batch_size x num_outputs)
        h = self.h_init.clone().repeat(1, batch_size, 1)
        c = self.c_init.clone().repeat(1, batch_size, 1)
        return h, c

    def network_size(self):
        return self.num_inputs, self.num_outputs

    def forward(self, inp, prev_states):
        # inp dimension: (seq_len x batch_size x input_size)
        inp = inp[torch.arange(inp.shape[0] - 1, -1, -1), :, :]  # Reverse the input for the backward direction
        # Input to nn.LSTM must be of shape (seq_len x batch_size x input_size) in PyTorch
        output, state = self.lstm_network(inp, prev_states)
        # output = output[torch.arange(output.shape[0]-1, -1, -1), :, :]  # Reversing the 'output'.
        return output, state  # Output size is (seq_len x batch x hidden_size) as per the documentation
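# The index-based reversal used in backward_controller.forward is equivalent to
# torch.flip along the time dimension; a quick check with arbitrary shapes:
inp = torch.randn(5, 2, 3)  # (seq_len, batch, input_size)
rev = inp[torch.arange(inp.shape[0] - 1, -1, -1), :, :]
assert torch.equal(rev, torch.flip(inp, dims=[0]))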
class LSTMController(nn.Module):
    """An NTM controller based on LSTM."""

    def __init__(self, num_inputs, num_outputs, num_layers):
        super(LSTMController, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size=num_inputs,
                            hidden_size=num_outputs,
                            num_layers=num_layers)

        # The hidden state is a learned parameter
        if torch.cuda.is_available():
            self.lstm_h_bias = Parameter(
                torch.randn(self.num_layers, 1, self.num_outputs).cuda() * 0.05)
            self.lstm_c_bias = Parameter(
                torch.randn(self.num_layers, 1, self.num_outputs).cuda() * 0.05)
        else:
            self.lstm_h_bias = Parameter(
                torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)
            self.lstm_c_bias = Parameter(
                torch.randn(self.num_layers, 1, self.num_outputs) * 0.05)

        self.reset_parameters()

    def create_new_state(self, batch_size):
        # Dimension: (num_layers * num_directions, batch, hidden_size)
        lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c

    def reset_parameters(self):
        for p in self.lstm.parameters():
            if p.dim() == 1:
                nn.init.constant_(p, 0)
            else:
                stdev = 5 / (np.sqrt(self.num_inputs + self.num_outputs))
                nn.init.uniform_(p, -stdev, stdev)

    def size(self):
        return self.num_inputs, self.num_outputs

    def forward(self, x, prev_state):
        x = x.unsqueeze(0)
        outp, state = self.lstm(x, prev_state)
        return outp.squeeze(0), state
class Memory(nn.Module):
    def __init__(self, memory_size):
        super(Memory, self).__init__()
        self._memory_size = memory_size

        # Initialize memory bias
        initial_state = torch.ones(memory_size) * 1e-6
        self.register_buffer('initial_state', initial_state.data)

        # Initial read vector is a learnt parameter
        self.initial_read = Parameter(torch.randn(1, self._memory_size[1]) * 0.01)

    def get_size(self):
        return self._memory_size

    def reset(self, batch_size):
        self.memory = self.initial_state.clone().repeat(batch_size, 1, 1)

    def get_initial_read(self, batch_size):
        return self.initial_read.clone().repeat(batch_size, 1)

    def read(self):
        return self.memory

    def write(self, w, e, a):
        self.memory = self.memory * (1 - torch.matmul(w.unsqueeze(-1), e.unsqueeze(1)))
        self.memory = self.memory + torch.matmul(w.unsqueeze(-1), a.unsqueeze(1))
        return self.memory

    def size(self):
        return self._memory_size
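# Sketch of how this Memory is typically driven by NTM-style erase/add writes.
# The (N, M) memory size, the batch size, and the random write weighting below
# are assumptions for illustration only.
mem = Memory(memory_size=(128, 20))  # N = 128 slots, M = 20 features per slot
mem.reset(batch_size=4)

w = torch.softmax(torch.randn(4, 128), dim=1)  # write weights over the slots
e = torch.sigmoid(torch.randn(4, 20))          # erase vector in [0, 1]
a = torch.randn(4, 20)                         # add vector

mem.write(w, e, a)                      # memory <- memory * (1 - outer(w, e)) + outer(w, a)
r0 = mem.get_initial_read(batch_size=4) # learned initial read vector, shape (4, 20)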
class ControllerState(nn.Module):
    def __init__(self, controller):
        super(ControllerState, self).__init__()
        self.controller = controller
        self.device = torch.device("cpu")
        if torch.cuda.is_available():
            self.device = torch.device("cuda")

        # starting hidden state is a learned parameter
        self.lstm_h_bias = Parameter(torch.randn(self.controller.num_layers, 1, self.controller.num_outputs) * 0.05)
        self.lstm_c_bias = Parameter(torch.randn(self.controller.num_layers, 1, self.controller.num_outputs) * 0.05)
        self.to(self.device)

    def reset(self, batch_size):
        h = self.lstm_h_bias.clone().repeat(1, batch_size, 1)
        c = self.lstm_c_bias.clone().repeat(1, batch_size, 1)
        self.state = h, c
class LSTMController(nn.Module):
    def __init__(self, vector_length, hidden_size):
        super(LSTMController, self).__init__()
        self.layer = nn.LSTM(input_size=vector_length, hidden_size=hidden_size)

        # The hidden state is a learned parameter
        self.lstm_h_state = Parameter(torch.randn(1, 1, hidden_size) * 0.05)
        self.lstm_c_state = Parameter(torch.randn(1, 1, hidden_size) * 0.05)

        for p in self.layer.parameters():
            if p.dim() == 1:
                nn.init.constant_(p, 0)
            else:
                stdev = 5 / (np.sqrt(vector_length + hidden_size))
                nn.init.uniform_(p, -stdev, stdev)

    def forward(self, x, state):
        output, state = self.layer(x.unsqueeze(0), state)
        return output.squeeze(0), state

    def get_initial_state(self, batch_size):
        lstm_h = self.lstm_h_state.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_state.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c
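# Unrolling the single-layer LSTMController defined directly above over a full
# sequence, starting from its learned initial state. Dimensions are illustrative.
ctrl = LSTMController(vector_length=9, hidden_size=100)
state = ctrl.get_initial_state(batch_size=2)

sequence = torch.randn(20, 2, 9)  # (seq_len, batch, vector_length)
outputs = []
for x in sequence:
    out, state = ctrl(x, state)   # out: (2, 100)
    outputs.append(out)
outputs = torch.stack(outputs)    # (20, 2, 100)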
class LSTMController(Controller):
    """LSTM controller for the NTM."""

    def __init__(self, input_dim, output_dim, num_layers, batch_size, use_cuda):
        super(LSTMController, self).__init__(input_dim, output_dim, num_layers, batch_size)
        self.lstm = nn.LSTM(input_size=input_dim,
                            hidden_size=output_dim,
                            num_layers=num_layers)
        if use_cuda:
            self.lstm.cuda()

        # From https://github.com/fanxiao001/ift6135-assignment/blob/master/assignment3/NTM/controller.py
        self.lstm_h_bias = Parameter(
            torch.randn(self.num_layers, 1, self.output_dim) * 0.05)
        self.lstm_c_bias = Parameter(
            torch.randn(self.num_layers, 1, self.output_dim) * 0.05)
        self.reset_parameters()

    def forward(self, x, r, lstm_h, lstm_c):
        """Forward pass of the LSTM controller."""
        # Concatenate the previous read state with the input
        x = torch.cat((r, x.squeeze(0)), 1)
        # Feed into the controller together with the previous state
        output, state = self.lstm(x.unsqueeze(0), (lstm_h, lstm_c))
        return output, state

    def create_state(self, batch_size):
        # Dimension: (num_layers * num_directions, batch, hidden_size)
        # From https://github.com/fanxiao001/ift6135-assignment/blob/master/assignment3/NTM/controller.py
        lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c
def maybe_save(self, t, img: Parameter, content_image):
    from aikido.nn.modules.styletransfer.fileloader import deprocess

    should_save = self.save_iter > 0 and t % self.save_iter == 0 and t > 0
    should_save = should_save or t == 50  # FIXME
    if should_save:
        output_filename, file_extension = os.path.splitext(self.kun.file_name)
        filename = str(output_filename) + "_" + str(t) + str(file_extension)

        # disp = deprocess(img.squeeze(0).clone())
        disp = deprocess(img.clone(), self.kun)

        # Maybe perform postprocessing for color-independent style transfer
        if self.original_colors:
            disp = self.postprocess_colors(deprocess(content_image.clone(), self.kun), disp)

        disp.save(str(filename))
class Connection(AbstractConnection): # language=rst """ Specifies synapses between one or two populations of neurons. """ def __init__(self, source: Nodes, target: Nodes, nu: Optional[Union[float, Sequence[float]]] = None, weight_decay: float = 0.0, **kwargs) -> None: # language=rst """ Instantiates a :code:`Connection` object. :param source: A layer of nodes from which the connection originates. :param target: A layer of nodes to which the connection connects. :param nu: Learning rate for both pre- and post-synaptic events. :param weight_decay: Constant multiple to decay weights by on each iteration. Keyword arguments: :param LearningRule update_rule: Modifies connection parameters according to some rule. :param torch.Tensor w: Strengths of synapses. :param torch.Tensor b: Target population bias. :param float wmin: Minimum allowed value on the connection weights. :param float wmax: Maximum allowed value on the connection weights. :param float norm: Total weight per target neuron normalization constant. :param ByteTensor norm_by_max: Normalize the weight of a neuron by its max weight. :param ByteTensor norm_by_max_with_shadow_weights: Normalize the weight of a neuron by its max weight by original weights. """ super().__init__(source, target, nu, weight_decay, **kwargs) w = kwargs.get("w", None) if w is None: if self.wmin == -np.inf or self.wmax == np.inf: w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax) else: w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin) else: if self.wmin != -np.inf or self.wmax != np.inf: w = torch.clamp(w, self.wmin, self.wmax) self.w = Parameter(w, False) self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), False) if self.norm_by_max_from_shadow_weights: self.shadow_w = self.w.clone().detach() self.prev_w = self.w.clone().detach() def compute(self, s: torch.Tensor) -> torch.Tensor: # language=rst """ Compute pre-activations given spikes using connection weights. :param s: Incoming spikes. :return: Incoming spikes multiplied by synaptic weights (with or without decaying spike activation). """ # Compute multiplication of spike activations by connection weights and add bias. post = s.float().view(-1) @ self.w + self.b return post.view(*self.target.shape) def update(self, **kwargs) -> None: # language=rst """ Compute connection's update rule. """ super().update(**kwargs) def normalize(self) -> None: # language=rst """ Normalize weights so each target neuron has sum of connection weights equal to ``self.norm``. """ if self.norm is not None: w_abs_sum = self.w.abs().sum(0).unsqueeze(0) w_abs_sum[w_abs_sum == 0] = 1.0 self.w *= self.norm / w_abs_sum def normalize_by_max(self) -> None: # language=rst """ Normalize weights by the max weight of the target neuron. """ if self.norm_by_max: w_max = self.w.abs().max(0)[0] w_max[w_max == 0] = 1.0 self.w /= w_max def normalize_by_max_from_shadow_weights(self) -> None: # language=rst """ Normalize weights by the max weight of the target neuron. """ if self.norm_by_max_from_shadow_weights: self.shadow_w += self.w - self.prev_w w_max = self.shadow_w.abs().max(0)[0] w_max[w_max == 0] = 1.0 self.w = self.shadow_w / w_max self.prev_w = self.w.clone().detach() def reset_(self) -> None: # language=rst """ Contains resetting logic for the connection. """ super().reset_()
class ResNet(nn.Module): def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None): super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: # each element in the tuple indicates if we should replace # the 2x2 stride with a dilated convolution instead # ----------------------------- # modified replace_stride_with_dilation = [False, False, False] # ----------------------------- if len(replace_stride_with_dilation) != 3: raise ValueError("replace_stride_with_dilation should be None " "or a 3-element tuple, got {}".format( replace_stride_with_dilation)) self.groups = groups self.base_width = width_per_group #layer for RGB input self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) self.avgpool = nn.AdaptiveAvgPool2d(1) #parameter for selective self.selective_0 = Parameter(torch.Tensor(64, 1, 1)) self.selective_1 = Parameter(torch.Tensor(64, 1, 1)) self.selective_2 = Parameter(torch.Tensor(128, 1, 1)) self.selective_3 = Parameter(torch.Tensor(256, 1, 1)) self.selective_4 = Parameter(torch.Tensor(512, 1, 1)) #layer for merge self.att_d = SELayer(64) self.att_rgb = SELayer(64) self.att_d_layer1 = SELayer(64) self.att_rgb_layer1 = SELayer(64) self.att_d_layer2 = SELayer(128) self.att_rgb_layer2 = SELayer(128) self.att_d_layer3 = SELayer(256) self.att_rgb_layer3 = SELayer(256) self.att_d_layer4 = SELayer(512) self.att_rgb_layer4 = SELayer(512) #layer for depth layer self.inplanes = 64 self.conv1_d = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1_d = norm_layer(64) self.relu_d = nn.ReLU(inplace=True) self.maxpool_d = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1_d = self._make_layer(block, 64, layers[0]) self.layer2_d = self._make_layer( block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) self.layer3_d = self._make_layer( block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) self.layer4_d = self._make_layer( block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) nn.init.constant_(self.selective_0, 0.9) nn.init.constant_(self.selective_1, 0.7) nn.init.constant_(self.selective_2, 0.5) nn.init.constant_(self.selective_3, 0.3) nn.init.constant_(self.selective_4, 0.1) def _make_layer(self, block, planes, blocks, stride=1, dilate=False): norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), ) layers = [] layers.append( block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer)) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append( block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer)) return nn.Sequential(*layers) def forward(self, x, depth): #pdb.set_trace() x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) #output_rgb_0 = x depth = self.conv1_d(depth) depth = self.bn1_d(depth) depth = self.relu_d(depth) depth = self.maxpool_d(depth) #output_depth_0 = depth selective_0_d = (1 - self.relu(self.selective_0.clone())).unsqueeze(0) selective_0_r = self.relu(self.selective_0.clone()).unsqueeze(0) x = selective_0_d * self.att_d(depth) + selective_0_r * self.att_rgb( x) #merge x = self.layer1(x) depth = self.layer1_d(depth) #output_rgb_1 = x #output_depth_1 = depth selective_1_d = (1 - self.relu(self.selective_1.clone())).unsqueeze(0) selective_1_r = self.relu(self.selective_1.clone()).unsqueeze(0) x = selective_1_d * self.att_d_layer1( depth) + selective_1_r * self.att_rgb_layer1(x) #merge #output_fusion_1 = x x = self.layer2(x) depth = self.layer2_d(depth) selective_2_d = (1 - self.relu(self.selective_2.clone())).unsqueeze(0) selective_2_r = self.relu(self.selective_2.clone()).unsqueeze(0) x = selective_2_d * self.att_d_layer2( depth) + selective_2_r * self.att_rgb_layer2(x) #merge x = self.layer3(x) depth = self.layer3_d(depth) selective_3_d = (1 - self.relu(self.selective_3.clone())).unsqueeze(0) selective_3_r = self.relu(self.selective_3.clone()).unsqueeze(0) x = selective_3_d * self.att_d_layer3( depth) + selective_3_r * self.att_rgb_layer3(x) #merge #output_depth_3 = depth x = self.layer4(x) depth = self.layer4_d(depth) selective_4_d = (1 - self.relu(self.selective_4.clone())).unsqueeze(0) selective_4_r = self.relu(self.selective_4.clone()).unsqueeze(0) x = selective_4_d * self.att_d_layer4( depth) + selective_4_r * self.att_rgb_layer4(x) #merge #return x, output_depth_0,output_depth_1,output_depth_3, output_rgb_0, output_rgb_1, output_fusion_1 return x
def attack(model, criterion, img, label, eps, attack_type, iters, clean_clean_img=None):
    assert not model.training

    adv = img.clone().detach()
    adv = Parameter(adv, requires_grad=True)

    if attack_type == 'fgsm':
        iterations = 1
    else:
        iterations = iters

    if attack_type == 'pgd':
        step = 2 / 255
    else:
        step = eps / iterations

    noise = 0

    for j in range(iterations):
        outputs = None
        if aug_test is None:
            out_adv = model(normalize(adv.clone()))
            loss = criterion(out_adv, label)
            loss.backward()
        else:
            adv_aux = adv * (1.0 - aug_test_lambda)
            for i in range(aug_test):  # TODO: check why this loop uses so much memory
                adv_aux = adv_aux + aug_test_lambda * clean_clean_img[torch.randperm(label.size(0))]
                out = model(normalize(adv_aux))
                if outputs is None:
                    outputs = out
                else:
                    outputs += out
            out_adv = outputs / aug_test
            loss = criterion(out_adv, label)
            loss.backward()

        if attack_type == 'mim':
            adv_mean = torch.mean(torch.abs(adv.grad), dim=1, keepdim=True)
            adv_mean = torch.mean(torch.abs(adv_mean), dim=2, keepdim=True)
            adv_mean = torch.mean(torch.abs(adv_mean), dim=3, keepdim=True)
            adv.grad = adv.grad / adv_mean
            noise = noise + adv.grad
        else:
            assert adv.grad is not None
            noise = adv.grad

        # Optimization step
        adv.data = adv.data + step * noise.sign()
        # adv.data = adv.data + step * adv.grad.sign()

        if attack_type == 'pgd':
            adv.data = torch.where(adv.data > img.data + eps, img.data + eps, adv.data)
            adv.data = torch.where(adv.data < img.data - eps, img.data - eps, adv.data)
        adv.data.clamp_(0.0, 1.0)

        adv.grad.data.zero_()

    return adv.detach()
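# Hypothetical call site for attack(). The trained `model`, the `normalize`
# transform, the data `loader`, and the module-level globals `aug_test` /
# `aug_test_lambda` are all assumed to be defined elsewhere in the original
# project; the epsilon and iteration budget below are illustrative.
model.eval()
criterion = torch.nn.CrossEntropyLoss()

for img, label in loader:
    adv = attack(model, criterion, img, label,
                 eps=8 / 255, attack_type='pgd', iters=10)
    with torch.no_grad():
        acc = (model(normalize(adv)).argmax(1) == label).float().mean()
    print(f'robust batch accuracy: {acc.item():.3f}')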
class MarkovFlow(nn.Module): def __init__(self, args, num_dims): super(MarkovFlow, self).__init__() self.args = args self.device = args.device # Gaussian Variance self.var = Parameter(torch.zeros(num_dims, dtype=torch.float32)) if not args.train_var: self.var.requires_grad = False self.num_state = args.num_state self.num_dims = num_dims self.couple_layers = args.couple_layers self.cell_layers = args.cell_layers self.hidden_units = num_dims // 2 self.lstm_hidden_units = self.num_dims # transition parameters in log space self.tparams = Parameter(torch.Tensor(self.num_state, self.num_state)) self.prior_group = [self.tparams] # Gaussian means self.means = Parameter(torch.Tensor(self.num_state, self.num_dims)) if args.mode == "unsupervised" and args.freeze_prior: self.tparams.requires_grad = False if args.mode == "unsupervised" and args.freeze_mean: self.means.requires_grad = False if args.model == 'nice': self.proj_layer = NICETrans(self.couple_layers, self.cell_layers, self.hidden_units, self.num_dims, self.device) elif args.model == "lstmnice": self.proj_layer = LSTMNICE(self.args.lstm_layers, self.args.couple_layers, self.args.cell_layers, self.lstm_hidden_units, self.hidden_units, self.num_dims, self.device) if args.mode == "unsupervised" and args.freeze_proj: for param in self.proj_layer.parameters(): param.requires_grad = False if args.model == "gaussian": self.proj_group = [self.means, self.var] else: self.proj_group = list( self.proj_layer.parameters()) + [self.means, self.var] # prior self.pi = torch.zeros(self.num_state, dtype=torch.float32, requires_grad=False, device=self.device).fill_(1.0 / self.num_state) self.pi = torch.log(self.pi) def init_params(self, train_data): """ init_seed:(sents, masks) sents: (seq_length, batch_size, features) masks: (seq_length, batch_size) """ # initialize transition matrix params # self.tparams.data.uniform_().add_(1) self.tparams.data.uniform_() # load pretrained model if self.args.load_nice != '': self.load_state_dict(torch.load(self.args.load_nice), strict=True) self.means_init = self.means.clone() self.tparams_init = self.tparams.clone() self.proj_init = [ param.clone() for param in self.proj_layer.parameters() ] if self.args.init_var: self.init_var(train_data) if self.args.init_var_one: self.var.fill_(0.01) # self.means_init.requires_grad = False # self.tparams_init.requires_grad = False # for tensor in self.proj_init: # tensor.requires_grad = False return # load pretrained Gaussian baseline if self.args.load_gaussian != '': self.load_state_dict(torch.load(self.args.load_gaussian), strict=False) # fully unsupervised training if self.args.mode == "unsupervised" and self.args.load_nice == "": with torch.no_grad(): for iter_obj in train_data.data_iter(self.args.batch_size): sents = iter_obj.embed masks = iter_obj.mask sents, _ = self.transform(sents, iter_obj.mask) seq_length, _, features = sents.size() flat_sents = sents.view(-1, features) seed_mean = torch.sum( masks.view(-1, 1).expand_as(flat_sents) * flat_sents, dim=0) / masks.sum() seed_var = torch.sum( masks.view(-1, 1).expand_as(flat_sents) * ((flat_sents - seed_mean.expand_as(flat_sents))**2), dim=0) / masks.sum() self.var.copy_(seed_var) # self.var.fill_(0.02) # add noise to the pretrained Gaussian mean if self.args.load_gaussian != '' and self.args.model == 'nice': self.means.data.add_( seed_mean.data.expand_as(self.means.data)) elif self.args.load_gaussian == '' and self.args.load_nice == '': self.means.data.normal_().mul_(0.04) self.means.data.add_( 
seed_mean.data.expand_as(self.means.data)) return self.init_mean(train_data) self.var.fill_(1.0) self.init_var(train_data) if self.args.init_var_one: self.var.fill_(1.0) def init_mean(self, train_data): emb_dict = {} cnt_dict = Counter() for iter_obj in train_data.data_iter(self.args.batch_size): sents_t = iter_obj.embed sents_t, _ = self.transform(sents_t, iter_obj.mask) sents_t = sents_t.transpose(0, 1) pos_t = iter_obj.pos.transpose(0, 1) mask_t = iter_obj.mask.transpose(0, 1) for emb_s, tagid_s, mask_s in zip(sents_t, pos_t, mask_t): for tagid, emb, mask in zip(tagid_s, emb_s, mask_s): tagid = tagid.item() mask = mask.item() if tagid in emb_dict: emb_dict[tagid] = emb_dict[tagid] + emb * mask else: emb_dict[tagid] = emb * mask cnt_dict[tagid] += mask for tagid in emb_dict: self.means[tagid] = emb_dict[tagid] / cnt_dict[tagid] def init_var(self, train_data): cnt = 0 mean_sum = 0. var_sum = 0. for iter_obj in train_data.data_iter(batch_size=self.args.batch_size): sents, masks = iter_obj.embed, iter_obj.mask sents, _ = self.transform(sents, masks) seq_length, _, features = sents.size() flat_sents = sents.view(-1, features) mean_sum = mean_sum + torch.sum( masks.view(-1, 1).expand_as(flat_sents) * flat_sents, dim=0) cnt += masks.sum().item() mean = mean_sum / cnt for iter_obj in train_data.data_iter(batch_size=self.args.batch_size): sents, masks = iter_obj.embed, iter_obj.mask sents, _ = self.transform(sents, masks) seq_length, _, features = sents.size() flat_sents = sents.view(-1, features) var_sum = var_sum + torch.sum( masks.view(-1, 1).expand_as(flat_sents) * ((flat_sents - mean.expand_as(flat_sents))**2), dim=0) var = var_sum / cnt self.var.copy_(var) def _calc_log_density_c(self): # return -self.num_dims/2.0 * (math.log(2) + \ # math.log(np.pi)) - 0.5 * self.num_dims * (torch.log(self.var)) return -self.num_dims/2.0 * (math.log(2) + \ math.log(np.pi)) - 0.5 * torch.sum(torch.log(self.var)) def transform(self, x, masks=None): """ Args: x: (sent_length, batch_size, num_dims) """ jacobian_loss = torch.zeros(1, device=self.device, requires_grad=False) if self.args.model != 'gaussian': x, jacobian_loss_new = self.proj_layer(x, masks) jacobian_loss = jacobian_loss + jacobian_loss_new return x, jacobian_loss def MSE_loss(self): # diff1 = ((self.means - self.means_init) ** 2).sum() diff_prior = ((self.tparams - self.tparams_init)**2).sum() # diff = diff1 + diff2 diff_proj = 0. 
for i, param in enumerate(self.proj_layer.parameters()): diff_proj = diff_proj + ((self.proj_init[i] - param)**2).sum() diff_mean = ((self.means_init - self.means)**2).sum() return 0.5 * (self.args.beta_prior * diff_prior + self.args.beta_proj * diff_proj + self.args.beta_mean * diff_mean) def unsupervised_loss(self, sents, masks): """ Args: sents: (sent_length, batch_size, self.num_dims) masks: (sent_length, batch_size) Returns: Tensor1, Tensor2 Tensor1: negative log likelihood, shape ([]) Tensor2: jacobian loss, shape ([]) """ max_length, batch_size, _ = sents.size() sents, jacobian_loss = self.transform(sents, masks) assert self.var.data.min() > 0 self.logA = self._calc_logA() self.log_density_c = self._calc_log_density_c() alpha = self.pi + self._eval_density(sents[0]) for t in range(1, max_length): density = self._eval_density(sents[t]) mask_ep = masks[t].expand(self.num_state, batch_size) \ .transpose(0, 1) alpha = torch.mul(mask_ep, self._forward_cell(alpha, density)) + \ torch.mul(1-mask_ep, alpha) # calculate objective from log space objective = torch.sum(log_sum_exp(alpha, dim=1)) return -objective, jacobian_loss def supervised_loss(self, sents, tags, masks): """ Args: sents: (sent_length, batch_size, num_dims) masks: (sent_length, batch_size) tags: (sent_length, batch_size) Returns: Tensor1, Tensor2 Tensor1: negative log likelihood, shape ([]) Tensor2: jacobian loss, shape ([]) """ sent_len, batch_size, _ = sents.size() # (sent_length, batch_size, num_dims) sents, jacobian_loss = self.transform(sents, masks) # () log_density_c = self._calc_log_density_c() # (1, 1, num_state, num_dims) means = self.means.view(1, 1, self.num_state, self.num_dims) means = means.expand(sent_len, batch_size, self.num_state, self.num_dims) tag_id = tags.view(*tags.size(), 1, 1).expand(sent_len, batch_size, 1, self.num_dims) # (sent_len, batch_size, num_dims) means = torch.gather(means, dim=2, index=tag_id).squeeze(2) var = self.var.view(1, 1, self.num_dims) # (sent_len, batch_size) log_emission_prob = log_density_c - \ 0.5 * torch.sum((means-sents) ** 2 / var, dim=-1) log_emission_prob = torch.mul(masks, log_emission_prob).sum() # (num_state, num_state) log_trans = self._calc_logA() # (sent_len, batch_size, num_state, num_state) log_trans_prob = log_trans.view(1, 1, *log_trans.size()).expand( sent_len, batch_size, *log_trans.size()) # (sent_len-1, batch_size, 1, num_state) tag_id = tags.view(*tags.size(), 1, 1).expand(sent_len, batch_size, 1, self.num_state)[:-1] # (sent_len-1, batch_size, 1, num_state) log_trans_prob = torch.gather(log_trans_prob[:-1], dim=2, index=tag_id) # (sent_len-1, batch_size, 1, 1) tag_id = tags.view(*tags.size(), 1, 1)[1:] # (sent_len-1, batch_size) log_trans_prob = torch.gather(log_trans_prob, dim=3, index=tag_id).squeeze() log_trans_prob = torch.mul(masks[1:], log_trans_prob) log_trans_prior = self.pi.expand(batch_size, self.num_state) tag_id = tags[0].unsqueeze(dim=1) # (batch_size) log_trans_prior = torch.gather(log_trans_prior, dim=1, index=tag_id).sum() log_trans_prob = log_trans_prior + log_trans_prob.sum() return -(log_trans_prob + log_emission_prob), jacobian_loss def _calc_alpha(self, sents, masks): """ sents: (sent_length, batch_size, self.num_dims) masks: (sent_length, batch_size) Returns: output: (batch_size, sent_length, num_state) """ max_length, batch_size, _ = sents.size() alpha_all = [] alpha = self.pi + self._eval_density(sents[0]) alpha_all.append(alpha.unsqueeze(1)) for t in range(1, max_length): density = self._eval_density(sents[t]) mask_ep = 
masks[t].expand(self.num_state, batch_size) \ .transpose(0, 1) alpha = torch.mul(mask_ep, self._forward_cell(alpha, density)) + \ torch.mul(1-mask_ep, alpha) alpha_all.append(alpha.unsqueeze(1)) return torch.cat(alpha_all, dim=1) def _forward_cell(self, alpha, density): batch_size = len(alpha) ep_size = torch.Size([batch_size, self.num_state, self.num_state]) alpha = log_sum_exp(alpha.unsqueeze(dim=2).expand(ep_size) + self.logA.expand(ep_size) + density.unsqueeze(dim=1).expand(ep_size), dim=1) return alpha def _backward_cell(self, beta, density): """ density: (batch_size, num_state) beta: (batch_size, num_state) """ batch_size = len(beta) ep_size = torch.Size([batch_size, self.num_state, self.num_state]) beta = log_sum_exp(self.logA.expand(ep_size) + density.unsqueeze(dim=1).expand(ep_size) + beta.unsqueeze(dim=1).expand(ep_size), dim=2) return beta def _eval_density(self, words): """ Args: words: (batch_size, self.num_dims) Returns: Tensor1 Tensor1: the density tensor with shape (batch_size, num_state) """ batch_size = words.size(0) ep_size = torch.Size([batch_size, self.num_state, self.num_dims]) words = words.unsqueeze(dim=1).expand(ep_size) means = self.means.expand(ep_size) var = self.var.expand(ep_size) return self.log_density_c - \ 0.5 * torch.sum((means-words) ** 2 / var, dim=2) def _calc_logA(self): return (self.tparams - \ log_sum_exp(self.tparams, dim=1, keepdim=True) \ .expand(self.num_state, self.num_state)) def _calc_log_mul_emit(self): return self.emission - \ log_sum_exp(self.emission, dim=1, keepdim=True) \ .expand(self.num_state, self.vocab_size) def _viterbi(self, sents_var, masks): """ Args: sents_var: (sent_length, batch_size, num_dims) masks: (sent_length, batch_size) """ self.log_density_c = self._calc_log_density_c() self.logA = self._calc_logA() length, batch_size = masks.size() # (batch_size, num_state) delta = self.pi + self._eval_density(sents_var[0]) ep_size = torch.Size([batch_size, self.num_state, self.num_state]) index_all = [] # forward calculate delta for t in range(1, length): density = self._eval_density(sents_var[t]) delta_new = self.logA.expand(ep_size) + \ density.unsqueeze(dim=1).expand(ep_size) + \ delta.unsqueeze(dim=2).expand(ep_size) mask_ep = masks[t].view(-1, 1, 1).expand(ep_size) delta = mask_ep * delta_new + \ (1 - mask_ep) * delta.unsqueeze(dim=1).expand(ep_size) # index: (batch_size, num_state) delta, index = torch.max(delta, dim=1) index_all.append(index) assign_all = [] # assign: (batch_size) _, assign = torch.max(delta, dim=1) assign_all.append(assign.unsqueeze(dim=1)) # backward retrieve path # len(index_all) = length-1 for t in range(length - 2, -1, -1): assign_new = torch.gather(index_all[t], dim=1, index=assign.view(-1, 1)).squeeze(dim=1) assign_new = assign_new.float() assign = assign.float() assign = masks[t + 1] * assign_new + (1 - masks[t + 1]) * assign assign = assign.long() assign_all.append(assign.unsqueeze(dim=1)) assign_all = assign_all[-1::-1] return torch.cat(assign_all, dim=1) def test_supervised(self, test_data): """Evaluate tagging performance with token-level supervised accuracy Args: test_data: ConlluData object Returns: a scalar accuracy value """ total = 0.0 correct = 0.0 index_all = [] eval_tags = [] for iter_obj in test_data.data_iter(batch_size=self.args.batch_size, shuffle=False): sents_t = iter_obj.embed masks = iter_obj.mask tags_t = iter_obj.pos sents_t, _ = self.transform(sents_t, masks) # index: (batch_size, seq_length) index = self._viterbi(sents_t, masks) for index_s, tag_s, mask_s in zip(index, 
tags_t.transpose(0, 1), masks.transpose(0, 1)): for i in range(int(mask_s.sum().item())): if index_s[i].item() == tag_s[i].item(): correct += 1 total += 1 return correct / total def test_unsupervised(self, test_data, sentences=None, tagging=False, path=None, null_index=None): """Evaluate tagging performance with many-to-1 metric, VM score and 1-to-1 accuracy Args: test_data: ConlluData object tagging: output the predicted tags if True path: The output tag file path null_index: the null element location in Penn Treebank, only used for writing unsupervised tags for downstream parsing task Returns: Tuple1: (M1, VM score, 1-to-1 accuracy) """ total = 0.0 correct = 0.0 cnt_stats = {} match_dict = {} index_all = [] eval_tags = [] gold_vm = [] model_vm = [] for i in range(self.num_state): cnt_stats[i] = Counter() for iter_obj in test_data.data_iter(batch_size=self.args.batch_size, shuffle=False): total += iter_obj.mask.sum().item() sents_t = iter_obj.embed tags_t = iter_obj.pos masks = iter_obj.mask sents_t, _ = self.transform(sents_t, masks) # index: (batch_size, seq_length) index = self._viterbi(sents_t, masks) index_all += list(index) tags = [ tags_t[:int(masks[:, i].sum().item()), i] for i in range(index.size(0)) ] eval_tags += tags # count for (seq_gold_tags, seq_model_tags) in zip(tags, index): for (gold_tag, model_tag) in zip(seq_gold_tags, seq_model_tags): model_tag = model_tag.item() gold_tag = gold_tag.item() gold_vm += [gold_tag] model_vm += [model_tag] cnt_stats[model_tag][gold_tag] += 1 # evaluate one-to-one accuracy cost_matrix = np.zeros((self.num_state, self.num_state)) for i in range(self.num_state): for j in range(self.num_state): cost_matrix[i][j] = -cnt_stats[j][i] row_ind, col_ind = linear_sum_assignment(cost_matrix) for (seq_gold_tags, seq_model_tags) in zip(eval_tags, index_all): for (gold_tag, model_tag) in zip(seq_gold_tags, seq_model_tags): model_tag = model_tag.item() gold_tag = gold_tag.item() if col_ind[gold_tag] == model_tag: correct += 1 one2one = correct / total correct = 0. # match for tag in cnt_stats: if len(cnt_stats[tag]) != 0: match_dict[tag] = cnt_stats[tag].most_common(1)[0][0] # eval many2one for (seq_gold_tags, seq_model_tags) in zip(eval_tags, index_all): for (gold_tag, model_tag) in zip(seq_gold_tags, seq_model_tags): model_tag = model_tag.item() gold_tag = gold_tag.item() if match_dict[model_tag] == gold_tag: correct += 1 if tagging: write_conll(path, sentences, index_all, null_index) return correct / total, v_measure_score(gold_vm, model_vm), one2one
class Connection(AbstractConnection): # language=rst """ Specifies synapses between one or two populations of neurons. """ def __init__( self, source: Nodes, target: Nodes, impulse_amplitude: float, impulse_length: float, impulse_shape_factor: float = 0.9, invert: bool = False, nu: Optional[Union[float, Sequence[float]]] = None, reduction: Optional[callable] = None, weight_decay: float = 0.0, post_spike_weight_decay: float = 0.0, **kwargs ) -> None: # language=rst """ Instantiates a :code:`Connection` object. :param source: A layer of nodes from which the connection originates. :param target: A layer of nodes to which the connection connects. :param nu: Learning rate for both pre- and post-synaptic events. :param reduction: Method for reducing parameter updates along the minibatch dimension. :param weight_decay: Constant multiple to decay weights by on each iteration. Keyword arguments: :param LearningRule update_rule: Modifies connection parameters according to some rule. :param torch.Tensor w: Strengths of synapses. :param torch.Tensor b: Target population bias. :param float wmin: Minimum allowed value on the connection weights. :param float wmax: Maximum allowed value on the connection weights. :param float norm: Total weight per target neuron normalization constant. :param ByteTensor norm_by_max: Normalize the weight of a neuron by its max weight. :param ByteTensor norm_by_max_with_shadow_weights: Normalize the weight of a neuron by its max weight by original weights. """ super().__init__(source, target, impulse_amplitude, impulse_length, impulse_shape_factor, invert, nu, reduction, weight_decay, post_spike_weight_decay, **kwargs) w = kwargs.get("w", None) if w is None: if self.wmin == -np.inf or self.wmax == np.inf: w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax) else: w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin) else: if self.wmin != -np.inf or self.wmax != np.inf: w = torch.clamp(w, self.wmin, self.wmax) self.w = Parameter(w, False) self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), False) if self.norm_by_max_from_shadow_weights: self.shadow_w = self.w.clone().detach() self.prev_w = self.w.clone().detach() def update_impulse_state(self, s): self.impulse_state += (self.impulse_state > 0).float().view(-1) # adds 1 on where were spikes before s[self.impulse_state.unsqueeze(0) > 0] = 0 self.source.s[self.impulse_state.unsqueeze(0) > 0] = 0 self.impulse_state += (self.impulse_state == 0).float() * s.float().view(-1) # adds 1 on spikes impulse = self.impulse_curve() self.impulse_state *= (self.impulse_state < self.impulse_length).float() return impulse def impulse_curve(self): k = self.impulse_shape_factor if self.invert: impulse_value_2 = self.impulse_amplitude/(self.impulse_length*k - 1) #производная в точках, не находящихся в середине импульса impulse_value_1 = self.impulse_amplitude/(self.impulse_length*(1-k)) impulse_bias = 2*self.impulse_amplitude *(self.impulse_state > (self.impulse_length * (1-k)+ 0.5)).float().view(-1) *(self.impulse_state <= (self.impulse_length * (1-k)+1.5)).float().view(-1) impulse = (-impulse_value_1) * (self.impulse_state > 0).float().view(-1) * (self.impulse_state <= (self.impulse_length * (1-k) + 0.5)).float().view(-1) + (-impulse_value_2) * (self.impulse_state > (self.impulse_length * (1-k)+1.5)).float().view(-1) + impulse_bias return impulse else: impulse_value = self.impulse_amplitude/(self.impulse_length - 2 )/k #производная в точках, не находящихся в середине импульса impulse_bias = 
(2*self.impulse_amplitude *(abs(self.impulse_state - (self.impulse_length * k)) < 0.5).float().view(-1) + 2*self.impulse_amplitude*(abs(self.impulse_state - (self.impulse_length * k)) == 0.5).float().view(-1)*(self.impulse_state < self.impulse_length * k).float().view(-1)) impulse = impulse_value * (self.impulse_state > 0).float().view(-1) * (self.impulse_state < self.impulse_length * k).float().view(-1) + impulse_value/((1-k)/k) * (self.impulse_state > self.impulse_length * k).float().view(-1) - impulse_bias return impulse def compute(self, s: torch.Tensor) -> torch.Tensor: # language=rst """ Compute pre-activations given spikes using connection weights. :param s: Incoming spikes. :return: Incoming spikes multiplied by synaptic weights (with or without decaying spike activation). """ # Compute multiplication of spike activations by weights and add bias. # language=rst """ Compute pre-activations given spikes using connection weights. :param s: Incoming spikes. :return: Incoming spikes multiplied by synaptic weights (with or without decaying spike activation). """ impulse = self.update_impulse_state(s) self.a_pre += impulse # self.a_pre *= (self.impulse_state > 0).float() # # Compute multiplication of spike activations by connection weights. a_post = self.a_pre @ self.w return a_post.view(*self.target.shape) def update(self, **kwargs) -> None: # language=rst """ Compute connection's update rule. """ super().update(**kwargs) def normalize(self) -> None: # language=rst """ Normalize weights so each target neuron has sum of connection weights equal to ``self.norm``. """ if self.norm is not None: w_abs_sum = self.w.abs().sum(0).unsqueeze(0) w_abs_sum[w_abs_sum == 0] = 1.0 self.w *= self.norm / w_abs_sum def normalize_by_max(self) -> None: # language=rst """ Normalize weights by the max weight of the target neuron. """ if self.norm_by_max: w_max = self.w.abs().max(0)[0] w_max[w_max == 0] = 1.0 self.w /= w_max def normalize_by_max_from_shadow_weights(self) -> None: # language=rst """ Normalize weights by the max weight of the target neuron. """ if self.norm_by_max_from_shadow_weights: self.shadow_w += self.w - self.prev_w w_max = self.shadow_w.abs().max(0)[0] w_max[w_max == 0] = 1.0 self.w = self.shadow_w / w_max self.prev_w = self.w.clone().detach() def reset_(self) -> None: # language=rst """ Contains resetting logic for the connection. """ super().reset_() self.a_pre = torch.zeros_like(self.a_pre) self.impulse_state = torch.zeros_like(self.impulse_state)
class VonMisesFisherReparametrizedSample(nn.Module): def __init__(self, batch_shape, event_shape, eps): if isinstance(batch_shape, Number): batch_shape = torch.Size([batch_shape]) self.batch_shape = batch_shape if isinstance(event_shape, Number): event_shape = torch.Size([event_shape]) self.event_shape = event_shape assert len(event_shape) == 1 self.dim = int(event_shape[0]) super(VonMisesFisherReparametrizedSample, self).__init__() self.loc = Parameter(torch.Tensor(batch_shape + event_shape)) nn.init.kaiming_normal_(self.loc) self.loc.data /= torch.sum(self.loc.data ** 2, dim=-1, keepdim=True) ** 0.5 concentration_init = ml_kappa(dim=float(event_shape[0]), eps=eps) print('concentration init', concentration_init) self.softplus_inv_concentration = Parameter(torch.Tensor(batch_shape).uniform_(softplus_inv(concentration_init), softplus_inv(concentration_init))) # self.softplus_inv_concentration = Parameter(torch.Tensor(batch_shape)) # Too large kappa slow down rejection sampling, so we set upper bound, which is called in forward pass self.softplus_inv_concentration_upper_bound = softplus_inv(ml_kappa(dim=float(event_shape[0]), eps=2e-3)) self.beta_sample = None self.concentration = None self.gradient_correction_required = True self.softplus_inv_concentration_normal_mean = softplus_inv(ml_kappa(dim=float(event_shape[0]), eps=EPSILON)) self.softplus_inv_concentration_normal_std = 0.001 self.direction_init_method = None self.rsample = None self.loc_init_type = 'random' def forward(self, sample_shape, sample=False): # self.loc.data /= torch.sum(self.loc.data ** 2, dim=-1, keepdim=True) ** 0.5 if isinstance(sample_shape, Number): sample_shape = torch.Size([sample_shape]) if not sample: assert sample_shape == torch.Size([1]) return self.loc.unsqueeze(0) w_sample = self._rejection_sampling(sample_shape).unsqueeze(-1) assert (torch.abs(w_sample) < 1).all() spherical_section_sample = self._spherical_section_uniform_sampling(sample_shape) rsample_concentration = torch.cat([w_sample, (1 - w_sample ** 2) ** 0.5 * spherical_section_sample], dim=-1) rsample = self._householder_transformation(rsample_concentration, sample_shape) return rsample def sample_kld(self): pass def mode(self): return self.loc / torch.sum(self.loc ** 2, dim=-1, keepdim=True) ** 0.5 def parameter_adjustment(self): self.loc.data /= torch.sum(self.loc.data ** 2, dim=-1, keepdim=True) ** 0.5 self.softplus_inv_concentration.data = self.softplus_inv_concentration.data.clamp(max=self.softplus_inv_concentration_upper_bound) def reset_parameters(self, hyperparams={}): if 'vMF' in hyperparams.keys(): if 'direction' in hyperparams['vMF'].keys(): if type(hyperparams['vMF']['direction']) == str: if hyperparams['vMF']['direction'] == 'kaiming': self.direction_init_method = torch.nn.init.kaiming_normal_ elif hyperparams['vMF']['direction'] == 'kaiming_transpose': self.direction_init_method = kaiming_transpose elif hyperparams['vMF']['direction'] == 'orthogonal': self.direction_init_method = torch.nn.init.orthogonal_ self.direction_init_method(self.loc) elif type(hyperparams['vMF']['direction']) == torch.Tensor: self.loc.data.copy_(hyperparams['vMF']['direction']) self.loc_init_type = 'fixed' else: raise NotImplementedError self.loc.data /= torch.sum(self.loc.data ** 2, dim=-1, keepdim=True) ** 0.5 if 'softplus_inv_concentration_normal_mean' in hyperparams['vMF'].keys(): self.softplus_inv_concentration_normal_mean = hyperparams['vMF']['softplus_inv_concentration_normal_mean'] if 'softplus_inv_concentration_normal_mean_via_epsilon' in 
hyperparams['vMF'].keys(): epsilon = hyperparams['vMF']['softplus_inv_concentration_normal_mean_via_epsilon'] self.softplus_inv_concentration_normal_mean = softplus_inv(ml_kappa(dim=float(self.event_shape[0]), eps=epsilon)) if 'softplus_inv_concentration_normal_std' in hyperparams['vMF'].keys(): self.softplus_inv_concentration_normal_std = hyperparams['vMF']['softplus_inv_concentration_normal_std'] torch.nn.init.normal_(self.softplus_inv_concentration, self.softplus_inv_concentration_normal_mean*2, self.softplus_inv_concentration_normal_std) self.p_loc = self.loc.clone().detach().requires_grad_(False) # print(self.softplus_inv_concentration) def init_hyperparams_repr(self): if self.loc_init_type == 'random': loc_init_str = 'location ' + self.direction_init_method.__name__ elif self.loc_init_type == 'fixed': loc_init_str = 'location fixed' else: raise NotImplementedError return '%s, Softplus inverse(scale)~Normal(%.2E, %.2E)' % (loc_init_str, self.softplus_inv_concentration_normal_mean, self.softplus_inv_concentration_normal_std) def _rejection_sampling(self, sample_shape): """ :param concentration: tensor :param dim: scalar :return: """ # concentration = softplus(self.softplus_inv_concentration) concentration = softplus(self.softplus_inv_concentration) beta_param = 0.5 * (self.dim - 1) flattened_concentration = concentration.repeat(sample_shape + torch.Size([1] * concentration.dim())).view(-1) sqrt = (4 * flattened_concentration ** 2 + (self.dim - 1) ** 2) ** 0.5 b = (-2 * flattened_concentration + sqrt) / (self.dim - 1) # when concentration is too large compared to dim then b is zero due to underflow. so in this case, we use taylor approximation, bad_b means b == 0 or b is inf bad_b_ind = torch.isinf(b.detach()) + (b.detach() == 0) b[bad_b_ind] = sqrt_taylor_approximation(((self.dim - 1) / (2.0 * flattened_concentration[bad_b_ind])) ** 0.2) * 2.0 * flattened_concentration[bad_b_ind] / (self.dim - 1) a = (self.dim - 1 + 2 * flattened_concentration + sqrt) / 4.0 # in calculation of sqrt square of concentration may give inf inf_a_ind = torch.isinf(a.detach()) a[inf_a_ind] = ((sqrt_taylor_approximation(((self.dim - 1) / (2.0 * flattened_concentration[inf_a_ind])) ** 0.2) + 2) * 2 * flattened_concentration[inf_a_ind] + self.dim - 1) / 4.0 d = 4 * a * b / (1 + b) - (self.dim - 1) * math.log(self.dim - 1) rejected = torch.ones_like(flattened_concentration).byte() beta_sample = flattened_concentration.new(flattened_concentration.size()) if torch.isinf(d).any(): raise RuntimeError('w sampling in vMF generates infinite d') if (d != d).any(): print(flattened_concentration[d != d]) print(sqrt[d != d]) print(b[d != d]) raise RuntimeError('w sampling in vMF generates nan d') while rejected.any(): rejected_ind = rejected.nonzero().squeeze(1) beta_param_tensor = concentration.new_full((int(torch.sum(rejected)),), beta_param) eps_beta = Beta(beta_param_tensor, beta_param_tensor).rsample() eps_unif = torch.rand_like(eps_beta) b_sub = b[rejected] a_sub = a[rejected] d_sub = d[rejected] t_sub = 2 * a_sub * b_sub / (1 - (1 - b_sub) * eps_beta) criteria_sub = (self.dim - 1) * torch.log(t_sub) - t_sub + d_sub >= torch.log(eps_unif) beta_sample[rejected_ind[criteria_sub]] = eps_beta[criteria_sub] rejected[rejected_ind[criteria_sub]] = 0 # aug = torch.exp((self.dim - 1) * torch.log(t_sub) - t_sub + d_sub) # print('%3d %.4E~%.4E' % (rejected_ind.numel(), float(torch.min(aug)), float(torch.max(aug)))) self.beta_sample = beta_sample.reshape(torch.Size(sample_shape + concentration.size())) self.concentration = 
flattened_concentration.reshape(torch.Size(sample_shape + concentration.size())) vmf_sample = (1 - (1 + b.reshape(torch.Size(sample_shape + concentration.size()))) * self.beta_sample) \ / (1 - (1 - b.reshape(torch.Size(sample_shape + concentration.size()))) * self.beta_sample) return vmf_sample.clamp(min=-1.0 + 1e-7) def _spherical_section_uniform_sampling(self, sample_shape): shape = torch.Size(sample_shape + self.batch_shape + torch.Size([self.dim - 1])) sample = self.loc.new(shape).normal_() return sample / (sample ** 2).sum(dim=-1, keepdim=True) ** 0.5 def _householder_transformation(self, rsample_concentration, sample_shape): loc = self.loc / torch.sum(self.loc ** 2, dim=-1, keepdim=True) ** 0.5 u_prime = torch.zeros_like(loc).index_fill_(dim=-1, index=loc.new_zeros((1,)).long(), value=1) - loc u = (u_prime / (u_prime ** 2).sum(dim=-1, keepdim=True) ** 0.5).repeat(torch.Size(sample_shape + torch.Size([1] * loc.dim()))) return rsample_concentration - 2 * (rsample_concentration * u).sum(dim=-1, keepdim=True) * u def gradient_correction(self, integrand): dim = self.dim nu = float(dim) / 2.0 beta_sample = self.beta_sample.detach() concentration = self.concentration.detach() concentration.requires_grad_() correction_derivative = self._correction_derivative(concentration, beta_sample, dim) correction_deriv_grad = grad(correction_derivative, concentration, grad_outputs=torch.ones_like(correction_derivative))[0] concentration.requires_grad_(False) ## Bound from Thm4 (A new type of sharp bounds for ratios of modified Bessel functions), a bit loose upper bound for small nu, for larger nu, this works better?? # bessel_ratio_upper = concentration / ((nu - 1.0) + ((nu + 1.0) ** 2 + concentration ** 2) ** 0.5) # bessel_ratio_lower = concentration / ((nu - 0.5) + ((nu + 0.5) ** 2 + concentration ** 2) ** 0.5) ## Bound from Thm5 (A new type of sharp bounds for ratios of modified Bessel functions), This is better with larger kappa (less uncertainty case) concentration_sq = concentration ** 2 lambda0 = nu - 0.5 delta0 = (nu - 0.5) + lambda0 / (lambda0 ** 2 + concentration_sq) ** 0.5 / 2.0 bessel_ratio_upper = concentration / (delta0 + (delta0 ** 2 + concentration_sq) ** 0.5) lambda2 = nu + 0.5 delta2 = (nu - 0.5) + lambda2 / (lambda2 ** 2 + concentration_sq) ** 0.5 / 2.0 bessel_ratio_lower = concentration / (delta2 + (delta2 ** 2 + concentration_sq) ** 0.5) correction_bessel_ratio = -(bessel_ratio_upper + bessel_ratio_lower) / 2.0 correction = (integrand.detach() * (correction_bessel_ratio + correction_deriv_grad) * softplus_inv_derivative(concentration)).sum(dim=tuple(range(concentration.dim() - len(self.batch_shape)))) if torch.isinf(correction).any(): raise RuntimeError('vMF gradient correction is infinite. 
Check the argument of gradient_correction.') if (correction != correction).any(): if (concentration != concentration).any(): raise RuntimeError('vMF gradient correction is nan due to concentration.') elif (correction_bessel_ratio != correction_bessel_ratio).any(): raise RuntimeError('vMF gradient correction is nan due to correction bessel ratio.') elif (correction_deriv_grad != correction_deriv_grad).any(): raise RuntimeError('vMF gradient correction is nan due to correction derivative gradient.') elif (integrand != integrand).any(): raise RuntimeError('vMF gradient correction is nan due to integrand.') if (self.softplus_inv_concentration.grad.data != self.softplus_inv_concentration.grad.data).any(): raise RuntimeError('vMF backward is nan.') # self.softplus_inv_concentration.grad.data += correction self.softplus_inv_concentration.grad.data += correction.sum() # This is for updating concentration sampler parameters # self.log_concentration_rsample.backward(self.log_concentration.grad.data) @staticmethod def _correction_derivative(concentration, beta_sample, dim): b = (-2 * concentration + (4 * concentration ** 2 + (dim - 1) ** 2) ** 0.5) / (dim - 1) # when concentration is too large compared to dim then b is zero due to underflow. so in this case, we use taylor approximation bad_b_ind = torch.isinf(b.detach()) + (b.detach() == 0) b[bad_b_ind] = sqrt_taylor_approximation(((dim - 1) / (2.0 * concentration[bad_b_ind])) ** 0.2) * 2.0 * concentration[bad_b_ind] / (dim - 1) w = (1 - (1 + b) * beta_sample) / (1 - (1 - b) * beta_sample) one_w_ind = torch.abs(w.detach()) == 1 b_reciprocal = 1.0 / b[one_w_ind] w[one_w_ind] = -(1 + 2.0 * b_reciprocal + 2.0 * b_reciprocal ** 2) w = w.clamp(min=-1 + 1e-7) if (torch.abs(w) == 1).any(): raise RuntimeError('vMF gradient derivative is nan due to w.') return w * concentration + (dim - 3.0) / 2.0 * torch.log(1 - w ** 2) + torch.log(torch.abs(2 * b / ((b - 1) * beta_sample + 1) ** 2))
class LSTM(nn.Module):
    """An LSTM."""

    def __init__(self, num_inputs, num_outputs, num_hid, num_layers):
        super(LSTM, self).__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_hid = num_hid
        self.num_layers = num_layers
        self.prev_state = None

        print('!!! USING VANILLA LSTM !!!')

        self.lstm = nn.LSTM(input_size=self.num_inputs,
                            hidden_size=self.num_hid,
                            num_layers=self.num_layers)
        self.fc = nn.Linear(self.num_hid, self.num_outputs, bias=True)

        # The hidden state is a learned parameter
        self.lstm_h_bias = Parameter(
            torch.randn(self.num_layers, 1, self.num_hid) * 0.05)
        self.lstm_c_bias = Parameter(
            torch.randn(self.num_layers, 1, self.num_hid) * 0.05)

        self.reset_parameters()

    def create_new_state(self, batch_size):
        # Dimension: (num_layers * num_directions, batch, hidden_size)
        lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c

    def init_sequence(self, batch_size):
        """Initializing the state."""
        self.batch_size = batch_size
        self.prev_state = self.create_new_state(batch_size)

    def reset_parameters(self):
        # nn.init.constant / nn.init.uniform are deprecated; use the in-place variants
        for p in self.lstm.parameters():
            if p.dim() == 1:
                nn.init.constant_(p, 0)
            else:
                stdev = 5 / (np.sqrt(self.num_inputs + self.num_outputs))
                nn.init.uniform_(p, -stdev, stdev)
        for p in self.fc.parameters():
            if p.dim() == 1:
                nn.init.constant_(p, 0)
            else:
                stdev = 5 / (np.sqrt(self.num_inputs + self.num_outputs))
                nn.init.uniform_(p, -stdev, stdev)

    def size(self):
        return self.num_inputs, self.num_outputs

    def forward(self, x=None):
        if x is None:
            x = torch.zeros(self.batch_size, self.num_inputs)
        x = x.unsqueeze(0)
        o, self.prev_state = self.lstm(x, self.prev_state)
        o = self.fc(o)
        o = torch.sigmoid(o)  # F.sigmoid is deprecated
        return o.squeeze(0), self.prev_state

    def calculate_num_params(self):
        """Returns the total number of parameters."""
        num_params = 0
        for p in self.parameters():
            num_params += p.data.view(-1).size(0)
        return num_params
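# Quick smoke test of the vanilla LSTM baseline above; the copy-task-like
# dimensions are assumptions for illustration only.
lstm = LSTM(num_inputs=9, num_outputs=9, num_hid=100, num_layers=1)
lstm.init_sequence(batch_size=2)

for x in torch.randn(12, 2, 9):  # feed a sequence one step at a time
    y, _ = lstm(x)               # y in (0, 1) thanks to the sigmoid, shape (2, 9)

y_pad, _ = lstm()                # with no input, a zero vector is fed instead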