Example #1
0
    def __init__(self, n_channels, ss, individual_index=0, se_block=True):
        """Build the searchable CNN cell.

        Args:
            n_channels: Channel count of the feature maps; also the size of
                each ``(n_channels, n_channels, 1, 1)`` weight block.
            ss: Search space. ``ss.single_block`` selects whether ``ss.ocl``
                itself or ``ss.ocl[individual_index]`` is this cell's
                operation list.
            individual_index: Forwarded to ``SubGraphModule``; also used to
                index ``ss.ocl`` when the search space is not single-block.
            se_block: If True use ``SEBlock(n_channels, 8)``, otherwise
                ``Identity()``.
        """
        super(CnnSearchModule, self).__init__()

        self.ss = ss
        self.n_channels = n_channels
        self.config_dict = {'n_channels': n_channels}
        self.sub_graph_module = SubGraphModule(
            ss, self.config_dict, individual_index=individual_index)

        # Operation list describing this cell (shared list for single-block
        # search spaces, per-individual list otherwise).
        ocl = ss.ocl if ss.single_block else ss.ocl[individual_index]
        # Number of tensors the sub-graph is called with in forward().
        self.n_inputs = len(ocl[0].inputs)

        self.se_block = SEBlock(n_channels, 8) if se_block else Identity()

        self.bn = nn.BatchNorm2d(n_channels)
        self.relu = nn.ReLU()
        # One 1x1 weight block per entry of the operation list.
        self.weights = [
            Parameter(torch.Tensor(n_channels, n_channels, 1, 1))
            for _ in range(len(ocl))
        ]
        # Register each block as 'w_0', 'w_1', ... so it is tracked by the
        # module (state_dict, named_parameters); the plain list keeps
        # positional access for forward().
        for i, w in enumerate(self.weights):
            self.register_parameter('w_' + str(i), w)
        self.register_parameter('bias', None)
        self.reset_parameters()
Example #2
0
    def test_cnn_sub_module(self):
        """Smoke test: a CNN SubGraphModule runs forward for 100 random individuals."""
        ss = generate_ss_cnn()
        module = SubGraphModule(ss, {'n_channels': 64})
        shape = (32, 64, 16, 16)

        for _ in range(100):
            module.set_individual(ss.generate_individual())
            # Draw the second argument first to keep the original RNG order.
            bypass = torch.randn(*shape, dtype=torch.float)
            inputs = torch.randn(*shape, dtype=torch.float)
            module(inputs, bypass)
Example #3
0
 def test_run_sub_module(self):
     """Smoke test: repeatedly run a SubGraphModule, feeding its last output back in."""
     ss = generate_ss()
     module = SubGraphModule(ss, {'in_channels': 32, 'n_channels': 128})
     state = torch.randn(25, 128, dtype=torch.float)
     for _ in range(100):
         module.set_individual(ss.generate_individual())
         step_input = torch.randn(25, 32, dtype=torch.float)
         # Only the final node's output is carried into the next iteration.
         state = module(step_input, state)[-1]
Example #4
0
    def __init__(self, in_channels, n_channels, working_device, ss):
        """Store the configuration and build the wrapped SubGraphModule.

        Args:
            in_channels: Input feature size, forwarded to ``SubGraphModule``.
            n_channels: State/output size, forwarded to ``SubGraphModule``.
            working_device: Stored as ``self.working_device``.
            ss: Search space handed to ``SubGraphModule``.

        Calls ``self.reset_parameters()`` once the sub-graph exists.
        """
        super().__init__()

        self.ss = ss
        self.in_channels, self.n_channels = in_channels, n_channels
        self.working_device = working_device
        # Settings forwarded to the sub-graph builder.
        config = {
            'in_channels': in_channels,
            'n_channels': n_channels,
        }
        self.config_dict = config
        self.sub_graph_module = SubGraphModule(ss, config)

        self.reset_parameters()
Example #5
0
class CnnSearchModule(nn.Module):
    """Searchable CNN cell that wraps a ``SubGraphModule``.

    The sub-graph outputs whose ``avg_index`` entries are greater than 1 are
    concatenated along dim 1, passed through ReLU, a 1x1 convolution assembled
    from the matching weight blocks, batch norm and the (optional) SE block,
    and finally added to the input as a residual.

    Args:
        n_channels: Channel count of the feature maps; also the size of each
            ``(n_channels, n_channels, 1, 1)`` weight block.
        ss: Search space. ``ss.single_block`` selects whether ``ss.ocl``
            itself or ``ss.ocl[individual_index]`` is this cell's operation
            list.
        individual_index: Forwarded to ``SubGraphModule``; also used to index
            ``ss.ocl`` when the search space is not single-block.
        se_block: If True use ``SEBlock(n_channels, 8)``, otherwise
            ``Identity()``.
    """

    def __init__(self, n_channels, ss, individual_index=0, se_block=True):
        super(CnnSearchModule, self).__init__()

        self.ss = ss
        self.n_channels = n_channels
        self.config_dict = {'n_channels': n_channels}
        self.sub_graph_module = SubGraphModule(
            ss, self.config_dict, individual_index=individual_index)

        # Operation list describing this cell (shared list for single-block
        # search spaces, per-individual list otherwise).
        ocl = ss.ocl if ss.single_block else ss.ocl[individual_index]
        # Number of tensors the sub-graph is called with in forward().
        self.n_inputs = len(ocl[0].inputs)

        self.se_block = SEBlock(n_channels, 8) if se_block else Identity()

        self.bn = nn.BatchNorm2d(n_channels)
        self.relu = nn.ReLU()
        # One 1x1 weight block per entry of the operation list.
        self.weights = [
            Parameter(torch.Tensor(n_channels, n_channels, 1, 1))
            for _ in range(len(ocl))
        ]
        # Register each block as 'w_0', 'w_1', ... so it is tracked by the
        # module (state_dict, named_parameters); the plain list keeps
        # positional access for forward().
        for i, w in enumerate(self.weights):
            self.register_parameter('w_' + str(i), w)
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every weight block from U(-stdv, stdv).

        ``stdv = 1 / sqrt(n_channels * number_of_weight_blocks)``.
        """
        n = self.n_channels * len(self.weights)
        stdv = 1. / math.sqrt(n)
        for w in self.weights:
            w.data.uniform_(-stdv, stdv)

    def forward(self, inputs_tensor, bypass_input):
        """Run the cell.

        Args:
            inputs_tensor: Cell input; also added to the result as a residual.
            bypass_input: Second sub-graph input, used only when the
                sub-graph takes two inputs.

        Returns:
            ``se_block(bn(conv1x1(relu(merged)))) + inputs_tensor``.

        Raises:
            ValueError: If the sub-graph expects a number of inputs other
                than 1 or 2.
        """
        if self.n_inputs == 1:
            net = self.sub_graph_module(inputs_tensor)
        elif self.n_inputs == 2:
            net = self.sub_graph_module(inputs_tensor, bypass_input)
        else:
            raise ValueError(
                'CnnSearchModule supports sub-graphs with 1 or 2 inputs, '
                'got {}'.format(self.n_inputs))

        # Indices 0 and 1 are excluded; index i >= 2 pairs with weight block
        # i - 2 (presumably 0/1 are the input nodes -- TODO confirm).
        selected = [i for i in self.sub_graph_module.avg_index if i > 1]
        net = torch.cat([net[i] for i in selected], dim=1)
        w = torch.cat([self.weights[i - 2] for i in selected], dim=1)
        net = self.bn(F.conv2d(self.relu(net), w, self.bias, 1, 0, 1, 1))
        return self.se_block(net) + inputs_tensor

    def set_individual(self, individual: Individual):
        """Configure the wrapped sub-graph for ``individual``."""
        self.sub_graph_module.set_individual(individual)

    def parameters(self, recurse=True):
        """Yield the module's parameters, like ``nn.Module.parameters``.

        ``recurse`` is accepted so callers using the standard
        ``nn.Module`` signature keep working.
        """
        for _, param in self.named_parameters(recurse=recurse):
            yield param
Example #6
0
class RnnSearchModule(nn.Module):
    """Searchable recurrent cell that wraps a ``SubGraphModule``.

    At each time step the sub-graph receives the current input and the
    previous state; the outputs listed in ``avg_index`` are averaged to form
    the step output, which is also used as the next state.

    Args:
        in_channels: Input feature size, forwarded to ``SubGraphModule``.
        n_channels: State/output size, forwarded to ``SubGraphModule``.
        working_device: Stored as ``self.working_device``; not used by the
            methods of this class.
        ss: Search space handed to ``SubGraphModule``.
    """

    def __init__(self, in_channels, n_channels, working_device, ss):
        super(RnnSearchModule, self).__init__()

        self.ss = ss
        self.in_channels = in_channels
        self.n_channels = n_channels
        self.working_device = working_device
        self.config_dict = {
            'in_channels': self.in_channels,
            'n_channels': self.n_channels
        }
        self.sub_graph_module = SubGraphModule(ss, self.config_dict)

        self.reset_parameters()

    def forward(self, inputs_tensor, state_tensor):
        """Run the cell over every time step of ``inputs_tensor``.

        Args:
            inputs_tensor: Time-major sequence ([time step, batch, features]);
                it is split along dim 0 into single steps.
            state_tensor: Initial state; only ``state_tensor[0]`` is used
                (``init_state`` produces shape ``(1, batch, n_channels)``).

        Returns:
            ``(output, state)``: per-step outputs stacked along dim 0, and the
            final state with a leading dimension of size 1 re-added.
        """
        state = state_tensor[0, :, :]
        outputs = []

        # Loop over time steps; each chunk keeps a leading dim of size 1.
        for step_input in torch.split(inputs_tensor, split_size_or_sections=1,
                                      dim=0):
            output, state = self.cell(step_input, state)
            outputs.append(output)
        output = torch.stack(outputs, dim=0)

        return output, state.unsqueeze(dim=0)

    def cell(self, x, state):
        """Run a single time step.

        Args:
            x: One time step with a leading dim of size 1 (squeezed here).
            state: Previous state.

        Returns:
            ``(output, next_state)`` where ``output`` is the mean of the
            sub-graph outputs selected by ``avg_index``.
        """
        net = self.sub_graph_module(x.squeeze(dim=0), state)
        output = torch.mean(
            torch.stack([net[i] for i in self.sub_graph_module.avg_index],
                        dim=-1),
            dim=-1)
        # NOTE(review): the averaged output doubles as the next state; net[-1]
        # is not used. Confirm this is intended before changing it.
        return output, output

    def set_individual(self, individual: Individual):
        """Configure the wrapped sub-graph for ``individual``."""
        self.sub_graph_module.set_individual(individual)

    def init_state(self, batch_size=1):  # model init state
        """Return zeros of shape ``(1, batch_size, n_channels)``.

        ``new_zeros`` takes dtype and device from the module's first
        parameter.
        """
        weight = next(self.parameters())
        return weight.new_zeros(1, batch_size, self.n_channels)

    def parameters(self, recurse=True):
        """Yield the module's parameters, like ``nn.Module.parameters``.

        ``recurse`` is accepted so callers using the standard
        ``nn.Module`` signature keep working.
        """
        for _, param in self.named_parameters(recurse=recurse):
            yield param

    def reset_parameters(self):
        """Initialise every parameter from U(-0.025, 0.025)."""
        init_range = 0.025
        for param in self.parameters():
            param.data.uniform_(-init_range, init_range)
Example #7
0
    def test_sub_graph_build_rnn(self):
        """Build a SubGraphModule (in_channels=32, n_channels=128) from
        ``generate_ss()`` and set a randomly generated individual on it."""
        ss = generate_ss()
        sgm = SubGraphModule(ss, {'in_channels': 32, 'n_channels': 128})

        sgm.set_individual(ss.generate_individual())