Example #1
def _initialize_weights(self):
    num_modules = len(list(self.modules()))
    for idx, m in enumerate(self.modules()):
        if self.map_to_semantic and idx == num_modules - 1:
            # The last module maps instance channels to semantic channels with fixed weights.
            assert m == self.conv1x1_instance_to_semantic
            copy_tensor(src=self.instance_to_semantic_mapping_matrix.view(
                self.n_instance_classes, self.n_semantic_classes, 1, 1),
                dest=self.conv1x1_instance_to_semantic.weight.data)
            self.conv1x1_instance_to_semantic.weight.requires_grad = False  # Freeze the mapping weights
        elif isinstance(m, nn.Conv2d):
            m.weight.data.zero_()
            # m.weight.data.normal_(0.0, 0.02)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            assert m.kernel_size[0] == m.kernel_size[1]
            if m.in_channels == m.out_channels:
                initial_weight = model_utils.get_upsampling_weight(
                    m.in_channels, m.out_channels, m.kernel_size[0])
            else:
                initial_weight = model_utils.get_non_symmetric_upsampling_weight(
                    m.in_channels, m.out_channels, m.kernel_size[0],
                    semantic_instance_class_list=self.semantic_instance_class_list)
            copy_tensor(src=initial_weight, dest=m.weight.data)
    if self.score_multiplier_init:
        # Initialize the 1x1 score multiplier as a scaled identity over channels.
        self.score_multiplier1x1.weight.data.zero_()
        for ch in range(self.score_multiplier1x1.weight.size(1)):
            self.score_multiplier1x1.weight.data[ch, ch] = self.score_multiplier_init
        self.score_multiplier1x1.bias.data.zero_()
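Example #1 seeds its ConvTranspose2d layers with model_utils.get_upsampling_weight. That helper isn't shown here, but the standard FCN-style bilinear upsampling kernel it presumably builds looks roughly like this (a sketch under that assumption, not the repository's exact code):

import numpy as np
import torch

def get_upsampling_weight(in_channels, out_channels, kernel_size):
    # Build a 2D bilinear interpolation kernel and place it on the channel diagonal,
    # so each channel is upsampled independently (assumes in_channels == out_channels,
    # which Example #1 checks before calling it).
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.from_numpy(weight).float()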
Example #2
def copy_from_vgg16_to_modules(self, features, vgg16):
    for l1, l2 in zip(vgg16.features, features):
        if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
            if l2 == self.conv1[0] and self.n_input_channels != 3:  # accommodate a different number of input channels
                assert self.n_input_channels > 3, \
                    'Only know how to initialize with >= 3 input channels'
                # Copy the pretrained RGB filters into the first three input channels.
                copy_tensor(src=l1.weight.data, dest=l2.weight.data[:, :3, ...])
                copy_tensor(src=l1.bias.data, dest=l2.bias.data)
            else:
                copy_conv(src_conv_module=l1, dest_conv_module=l2)
    # Copy the VGG16 classifier weights into the convolutionalized fc6/fc7 layers.
    for i, name in zip([0, 3], ['fc6', 'fc7']):
        l1 = vgg16.classifier[i]
        l2 = getattr(self, name)
        l2.weight.data.copy_(l1.weight.data.view(l2.weight.size()))
        l2.bias.data.copy_(l1.bias.data.view(l2.bias.size()))
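Examples #1 and #2 delegate the actual parameter copies to copy_tensor and copy_conv helpers from the same project. Minimal sketches of what they likely do, assuming they only shape-check and copy in place (the real helpers may differ):

def copy_tensor(src, dest):
    # Hypothetical sketch: copy src into dest in place after checking the shapes agree.
    assert src.size() == dest.size(), 'size mismatch: {} vs {}'.format(src.size(), dest.size())
    dest.copy_(src)

def copy_conv(src_conv_module, dest_conv_module):
    # Hypothetical sketch: copy weight (and bias, if present) between same-shape conv modules.
    copy_tensor(src=src_conv_module.weight.data, dest=dest_conv_module.weight.data)
    if src_conv_module.bias is not None:
        copy_tensor(src=src_conv_module.bias.data, dest=dest_conv_module.bias.data)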
Example #3
def copy_modules_from_semantic_to_instance(instance_model_dest, semantic_model, conv2dT_with_repeated_channels,
                                           conv2d_with_repeated_channels, module_names_to_ignore,
                                           module_types_to_ignore, n_semantic_classes, model_channel_semantic_ids):
    for module_name, my_module in instance_model_dest.named_children():
        if module_name in module_names_to_ignore:
            continue
        module_to_copy = getattr(semantic_model, module_name)
        if module_name in conv2d_with_repeated_channels:
            for p_name, my_p in my_module.named_parameters():
                p_to_copy = getattr(module_to_copy, p_name)
                if not all(my_p.size()[c] == p_to_copy.size()[c] for c in range(1, len(my_p.size()))):
                    raise ValueError('semantic model is formatted incorrectly at layer {}'.format(module_name))
                if DEBUG:
                    assert my_p.data.size(0) == len(model_channel_semantic_ids) \
                           and p_to_copy.data.size(0) == n_semantic_classes
                for inst_cls, sem_cls in enumerate(model_channel_semantic_ids):
                    # weird formatting because scalar -> scalar not implemented (must be FloatTensor,
                    # so we use slicing)
                    n_instances_this_class = float(sum(
                        [1 if sic == sem_cls else 0 for sic in model_channel_semantic_ids]))
                    copy_tensor(src=p_to_copy.data[sem_cls:(sem_cls + 1), ...] / n_instances_this_class,
                                dest=my_p.data[inst_cls:(inst_cls + 1), ...])
        elif module_name in conv2dT_with_repeated_channels:
            assert isinstance(module_to_copy, nn.ConvTranspose2d)
            # assert l1.weight.size() == l2.weight.size()
            # assert l1.bias.size() == l2.bias.size()
            for p_name, my_p in my_module.named_parameters():
                p_to_copy = getattr(module_to_copy, p_name)
                if not all(my_p.size()[c] == p_to_copy.size()[c]
                           for c in [0] + list(range(2, len(p_to_copy.size())))):
                    raise ValueError('semantic model formatted incorrectly for repeating params.')

                for inst_cls, sem_cls in enumerate(model_channel_semantic_ids):
                    # weird formatting because scalar -> scalar not implemented (must be FloatTensor,
                    # so we use slicing)
                    copy_tensor(src=p_to_copy.data[:, sem_cls:(sem_cls + 1), ...],
                                dest=my_p.data[:, inst_cls:(inst_cls + 1), ...])
        elif isinstance(my_module, nn.Conv2d) or isinstance(my_module, nn.ConvTranspose2d):
            assert type(module_to_copy) == type(my_module)
            for p_name, my_p in my_module.named_parameters():
                p_to_copy = getattr(module_to_copy, p_name)
                if not my_p.size() == p_to_copy.size():
                    raise ValueError('semantic model is formatted incorrectly at layer {}'.format(module_name))
                copy_tensor(src=p_to_copy.data, dest=my_p.data)
                assert torch.equal(my_p.data, p_to_copy.data)
        elif any(isinstance(my_module, module_type) for module_type in module_types_to_ignore):
            continue
        else:
            if not module_has_params(my_module):
                print('Skipping module of type {} (name: {}) because it has no params, but please add it to the '
                      'list of module types to ignore.'.format(type(my_module), module_name))
                continue
            else:
                raise Exception("Haven't handled copying of {}, of type {}".format(module_name, type(my_module)))
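Example #3's final branch calls module_has_params to decide whether a module it doesn't know how to copy can be skipped safely. A one-liner sketch, assuming it simply checks for registered parameters:

def module_has_params(module):
    # Hypothetical sketch: True if the module (or any submodule) registers at least one parameter.
    return len(list(module.parameters())) > 0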