def test_cnn_sub_module(self):
    ss = generate_ss_cnn()
    sgm = SubGraphModule(ss, {'n_channels': 64})
    for i in range(100):
        sgm.set_individual(ss.generate_individual())
        y = torch.randn(32, 64, 16, 16, dtype=torch.float)
        x = torch.randn(32, 64, 16, 16, dtype=torch.float)
        res = sgm(x, y)
def test_run_sub_module(self):
    ss = generate_ss()
    sgm = SubGraphModule(ss, {'in_channels': 32, 'n_channels': 128})
    y = torch.randn(25, 128, dtype=torch.float)  # initial recurrent state
    for i in range(100):
        # Sample a new architecture and advance the recurrent state one step.
        sgm.set_individual(ss.generate_individual())
        x = torch.randn(25, 32, dtype=torch.float)
        y = sgm(x, y)
        y = y[-1]
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter

# SubGraphModule, SEBlock, Identity and Individual are project-local and are
# assumed to be imported from the surrounding repository.


class CnnSearchModule(nn.Module):
    def __init__(self, n_channels, ss, individual_index=0, se_block=True):
        super(CnnSearchModule, self).__init__()
        self.ss = ss
        self.n_channels = n_channels
        self.config_dict = {'n_channels': n_channels}
        self.sub_graph_module = SubGraphModule(
            ss, self.config_dict, individual_index=individual_index)
        if ss.single_block:
            self.n_inputs = len(ss.ocl[0].inputs)
        else:
            self.n_inputs = len(ss.ocl[individual_index][0].inputs)
        if se_block:
            self.se_block = SEBlock(n_channels, 8)
        else:
            self.se_block = Identity()
        self.bn = nn.BatchNorm2d(n_channels)
        self.relu = nn.ReLU()
        # One 1x1 convolution weight per entry of the search space's operation
        # configuration list (ocl).
        if self.ss.single_block:
            self.weights = [
                Parameter(torch.Tensor(n_channels, n_channels, 1, 1))
                for _ in range(len(ss.ocl))
            ]
        else:
            self.weights = [
                Parameter(torch.Tensor(n_channels, n_channels, 1, 1))
                for _ in range(len(ss.ocl[individual_index]))
            ]
        # The weights are kept in a plain Python list and registered manually
        # under the names w_0, w_1, ...
        for i, w in enumerate(self.weights):
            self.register_parameter('w_' + str(i), w)
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.n_channels * len(self.weights)
        stdv = 1. / math.sqrt(n)
        for w in self.weights:
            w.data.uniform_(-stdv, stdv)

    def forward(self, inputs_tensor, bypass_input):
        if self.n_inputs == 1:
            net = self.sub_graph_module(inputs_tensor)
        elif self.n_inputs == 2:
            net = self.sub_graph_module(inputs_tensor, bypass_input)
        # Concatenate the node outputs selected by avg_index (entries 0 and 1
        # are skipped) together with the matching 1x1 weights, project back to
        # n_channels, then apply BN and SE before the residual addition.
        net = torch.cat(
            [net[i] for i in self.sub_graph_module.avg_index if i > 1], dim=1)
        w = torch.cat([
            self.weights[i - 2]
            for i in self.sub_graph_module.avg_index if i > 1
        ], dim=1)
        net = self.bn(F.conv2d(self.relu(net), w, self.bias, 1, 0, 1, 1))
        return self.se_block(net) + inputs_tensor

    def set_individual(self, individual: Individual):
        self.sub_graph_module.set_individual(individual)

    def parameters(self):
        # Override: yield all parameters via named_parameters().
        for name, param in self.named_parameters():
            yield param
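# Usage sketch (not from the original source): one step of driving
# CnnSearchModule during an architecture search. It assumes the
# `generate_ss_cnn()` helper from the tests above provides the CNN search
# space; tensor shapes mirror test_cnn_sub_module.
def example_cnn_search_step():
    ss = generate_ss_cnn()
    module = CnnSearchModule(n_channels=64, ss=ss, se_block=True)
    x = torch.randn(32, 64, 16, 16, dtype=torch.float)       # feature maps
    bypass = torch.randn(32, 64, 16, 16, dtype=torch.float)  # bypass/skip input
    # Sample a candidate architecture and evaluate it with the shared weights.
    module.set_individual(ss.generate_individual())
    out = module(x, bypass)  # same shape as x; residual connection included
    return out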
class RnnSearchModule(nn.Module):
    def __init__(self, in_channels, n_channels, working_device, ss):
        super(RnnSearchModule, self).__init__()
        self.ss = ss
        self.in_channels = in_channels
        self.n_channels = n_channels
        self.working_device = working_device
        self.config_dict = {
            'in_channels': self.in_channels,
            'n_channels': self.n_channels
        }
        self.sub_graph_module = SubGraphModule(ss, self.config_dict)
        self.reset_parameters()

    def forward(self, inputs_tensor, state_tensor):
        # Input size: [time step, batch, features].
        state = state_tensor[0, :, :]
        outputs = []
        # Loop over time steps.
        for i in torch.split(inputs_tensor, split_size_or_sections=1, dim=0):
            output, state = self.cell(i, state)
            # (A disabled max-norm clipping block for the hidden state was
            # commented out here in the original source.)
            outputs.append(output)
        output = torch.stack(outputs, dim=0)
        return output, state.unsqueeze(dim=0)

    def cell(self, x, state):
        net = self.sub_graph_module(x.squeeze(dim=0), state)
        # Average the node outputs selected by avg_index; the averaged output
        # is also reused as the next state (net[-1] is assigned but not returned).
        output, state = torch.mean(torch.stack(
            [net[i] for i in self.sub_graph_module.avg_index], dim=-1),
            dim=-1), net[-1]
        return output, output

    def set_individual(self, individual: Individual):
        self.sub_graph_module.set_individual(individual)

    def init_state(self, batch_size=1):
        # Zero-initialized hidden state of shape [1, batch, n_channels].
        weight = next(self.parameters())
        return weight.new_zeros(1, batch_size, self.n_channels)

    def parameters(self):
        # Override: yield all parameters via named_parameters().
        for name, param in self.named_parameters():
            yield param

    def reset_parameters(self):
        init_range = 0.025
        for param in self.parameters():
            param.data.uniform_(-init_range, init_range)
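# Usage sketch (not from the original source): unrolling RnnSearchModule over a
# short sequence. It assumes the `generate_ss()` helper from the tests above
# provides an RNN search space; feature/hidden sizes mirror test_run_sub_module.
def example_rnn_rollout():
    ss = generate_ss()
    module = RnnSearchModule(in_channels=32, n_channels=128,
                             working_device='cpu', ss=ss)
    module.set_individual(ss.generate_individual())
    seq = torch.randn(10, 25, 32, dtype=torch.float)  # [time, batch, features]
    state = module.init_state(batch_size=25)          # [1, batch, n_channels]
    outputs, state = module(seq, state)               # outputs: [time, batch, n_channels]
    return outputs, state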
def test_sub_graph_build_rnn(self):
    ss = generate_ss()
    sgm = SubGraphModule(ss, {'in_channels': 32, 'n_channels': 128})
    sgm.set_individual(ss.generate_individual())