Beispiel #1
0
 def set_masks(self, mask_to_set):
     """Install one mask per masked module and apply it to that module's weight.

     Args:
         mask_to_set: Indexable sequence of mask tensors, one per module for
             which ``is_masked_module`` is true, in ``self.modules()`` order.
             Each mask is copied onto the device of the mask it replaces, so
             the module never shares storage with the caller's tensors.

     Raises:
         AssertionError: If the number of masks differs from the number of
             masked modules. Checked before any module is modified.
     """
     masked = [m for m in self.modules() if is_masked_module(m)]
     # Validate before touching any module so a length mismatch cannot leave
     # the model partially masked. AssertionError (not ValueError) keeps the
     # exception type existing callers may rely on; an explicit raise is not
     # stripped under ``python -O`` the way an ``assert`` statement is.
     if len(masked) != len(mask_to_set):
         raise AssertionError(
             f"expected {len(masked)} masks, got {len(mask_to_set)}")
     for m, mask in zip(masked, mask_to_set):
         # One copying transfer straight to the target device (no host hop).
         m.mask = mask.to(m.mask.device, copy=True)
         # Zero the weight entries the mask switches off, in place.
         m.weight.data *= m.mask
Beispiel #2
0
 def set_weights(self, weights_to_set):
     """Replace the weight tensor of every masked module.

     Args:
         weights_to_set: Indexable sequence of tensors, one per module for
             which ``is_masked_module`` is true, in ``self.modules()`` order.
             Each tensor is copied onto the device of the weight it replaces,
             so the module never shares storage with the caller's tensors.

     Raises:
         AssertionError: If the number of tensors differs from the number of
             masked modules. Checked before any weight is replaced.
     """
     masked = [m for m in self.modules() if is_masked_module(m)]
     # Validate before touching any module so a length mismatch cannot leave
     # the model partially updated. AssertionError (not ValueError) keeps the
     # exception type existing callers may rely on; an explicit raise is not
     # stripped under ``python -O`` the way an ``assert`` statement is.
     if len(masked) != len(weights_to_set):
         raise AssertionError(
             f"expected {len(masked)} weight tensors, got {len(weights_to_set)}")
     for m, w in zip(masked, weights_to_set):
         # One copying transfer straight to the target device (no host hop).
         m.weight.data = w.to(m.weight.data.device, copy=True)
Beispiel #3
0
def extract_param_modules(network):
    """Return the parameter-carrying submodules of ``network``, in order.

    A module is kept when ``is_base_module``, ``is_masked_module`` or
    ``is_batch_norm`` accepts it (tested in that order). The result follows
    ``network.modules()`` traversal order.
    """
    return [
        module
        for module in network.modules()
        if is_base_module(module)
        or is_masked_module(module)
        or is_batch_norm(module)
    ]
Beispiel #4
0
 def reinit(self):
     """Re-initialise masked modules and ``nn.BatchNorm2d`` layers in place.

     Masked modules have their weight passed to ``self.init_scheme``. That
     helper presumably initialises the tensor in place, since its return
     value is discarded. Their bias, if present, is zeroed.

     ``nn.BatchNorm2d`` layers have their affine scale set to 1 and their
     shift set to 0. Layers built with ``affine=False`` have no such
     parameters (they are None) and are left alone.
     """
     for m in self.modules():
         if is_masked_module(m):
             self.init_scheme(m.weight.data)
             if m.bias is not None:
                 nn.init.zeros_(m.bias.data)
         elif isinstance(m, nn.BatchNorm2d):
             # NOTE(review): running_mean / running_var are not reset here, so
             # eval-mode outputs still use the old statistics. Confirm this is
             # intended.
             if m.weight is not None:
                 nn.init.ones_(m.weight.data)
             if m.bias is not None:
                 nn.init.zeros_(m.bias.data)
Beispiel #5
0
    def get_bn_weights(self):
        """Return BatchNorm scale factors aligned with the masked modules.

        Walks ``self.modules()`` in order:

        * Every masked module reserves a ``None`` slot.
        * Every batch-norm layer then replaces the most recent slot with
          ``weight / sqrt(running_var + 1e-10)``. This is computed on CPU
          tensors detached from autograd.

        The result has one entry per masked module. A masked module with no
        batch-norm layer after it keeps ``None``.

        Raises:
            IndexError: If a batch-norm layer is reached before any masked
                module has reserved a slot.
        """
        scales = []
        for module in self.modules():
            if is_masked_module(module):
                scales.append(None)
            if is_batch_norm(module):
                # Discard the slot reserved by the preceding masked module.
                scales.pop()
                running_var = module.running_var.cpu().detach()
                gamma = module.weight.cpu().detach()
                # 1e-10 only guards against division by zero.
                scales.append(gamma / torch.sqrt(running_var + 1e-10))
        return scales
Beispiel #6
0
def _fold_batch_norm_into_conv(bn, conv):
    """Write ``bn``'s eval-mode transform into the convolution ``conv``.

    In eval mode a BatchNorm layer computes, per channel ``c``:

        y_c = (x_c - mean_c) / sqrt(var_c + eps) * gamma_c + beta_c

    This is the affine map ``y_c = scale_c * x_c + shift_c``, where
    ``scale_c = gamma_c / sqrt(var_c + eps)`` and
    ``shift_c = beta_c - mean_c * scale_c``. A 1x1 convolution with weight
    ``diag(scale)`` and bias ``shift`` reproduces it.

    Requires ``bn`` to hold running statistics. Assumes ``conv`` has a bias
    parameter. Its weight is replaced by a ``(C, C, 1, 1)`` tensor.
    """
    std = (bn.running_var.data + bn.eps).sqrt()
    # affine=False layers have no gamma/beta; they behave as gamma=1, beta=0.
    scale = std.reciprocal() if bn.weight is None else bn.weight.data / std
    shift = -bn.running_mean.data * scale
    if bn.bias is not None:
        shift = shift + bn.bias.data
    conv.bias.data = shift
    weight = scale.diag()
    conv.weight.data = weight.view(weight.shape[0], weight.shape[1], 1, 1)


def copy_network(network, network_seq):
    """Copy parameters from ``network`` into ``network_seq`` module by module.

    Both networks are flattened with ``extract_param_modules`` and paired up
    in order. Modules are handled as follows:

    * An ``nn.BatchNorm2d`` in ``network`` must pair with an ``nn.Conv2d``
      in ``network_seq``. The BN's eval-mode transform, including ``eps``,
      is folded into that conv.
    * Every other module is copied with ``load_state_dict``. Masked modules
      have their ``'mask'`` entry dropped first.

    Args:
        network: Source model.
        network_seq: Target model, modified in place.

    Returns:
        ``network_seq``.

    Raises:
        AssertionError: If the module lists differ in length, or a
            ``BatchNorm2d`` is not paired with a ``Conv2d``.
    """
    modules = extract_param_modules(network)
    modules_seq = extract_param_modules(network_seq)

    assert len(modules) == len(modules_seq)

    for m, m_seq in zip(modules, modules_seq):
        if isinstance(m, nn.BatchNorm2d):
            assert isinstance(m_seq, nn.Conv2d)
            _fold_batch_norm_into_conv(m, m_seq)
            continue

        state_dict = m.state_dict()
        if is_masked_module(m):
            # Presumably the target module has no 'mask' entry. A strict
            # load_state_dict (the default) would reject it as unexpected.
            del state_dict['mask']
        m_seq.load_state_dict(state_dict)

    return network_seq
Beispiel #7
0
 def get_masks(self):
     """Return the masks of all masked modules, in ``self.modules()`` order.

     Each mask is moved to CPU and detached from autograd.
     """
     return [
         module.mask.cpu().detach()
         for module in self.modules()
         if is_masked_module(module)
     ]
Beispiel #8
0
 def get_weights(self):
     """Return the weights of all masked modules, in ``self.modules()`` order.

     Each weight is taken via ``.data``, moved to CPU and detached.
     On a CPU model the returned tensors share storage with the live weights.
     """
     return [
         module.weight.data.cpu().detach()
         for module in self.modules()
         if is_masked_module(module)
     ]