    def __init__(self,
                 student_model,
                 abfs,
                 device,
                 device_ids,
                 distributed,
                 sizes=None,
                 **kwargs):
        super().__init__()
        self.student_model = wrap_if_distributed(student_model, device,
                                                 device_ids, distributed)
        if sizes is None:
            sizes = [1, 7, 14, 28, 56]

        self.sizes = sizes
        abf_list = nn.ModuleList()
        num_abfs = len(abfs)
        io_path_pairs = list()
        for idx, abf_config in enumerate(abfs):
            abf = wrap_if_distributed(
                AttentionBasedFusion(uses_attention=idx < num_abfs - 1,
                                     **abf_config['params']), device,
                device_ids, distributed)
            abf_list.append(abf)
            io_path_pairs.append((abf_config['io'], abf_config['path']))

        self.abf_modules = abf_list[::-1]
        self.io_path_pairs = io_path_pairs[::-1]
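Every constructor in this listing delegates device placement to wrap_if_distributed, which is not shown here. A minimal sketch of what such a helper typically does, assuming PyTorch's DistributedDataParallel; the helper in the source library may handle frozen or parameter-less modules differently:

from torch.nn.parallel import DistributedDataParallel


def wrap_if_distributed(module, device, device_ids, distributed):
    # Move the module to its target device first.
    module = module.to(device)
    if distributed and any(p.requires_grad for p in module.parameters()):
        # Wrap only trainable modules; DDP rejects modules without
        # parameters that require gradients.
        return DistributedDataParallel(module, device_ids=device_ids)
    return module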
Example #2
    def __init__(self,
                 input_module,
                 feat_dim,
                 ss_module_ckpt,
                 device,
                 device_ids,
                 distributed,
                 freezes_ss_module=False,
                 teacher_model=None,
                 student_model=None,
                 **kwargs):
        super().__init__()
        is_teacher = teacher_model is not None
        if not is_teacher:
            student_model = wrap_if_distributed(student_model, device,
                                                device_ids, distributed)

        self.model = teacher_model if is_teacher else student_model
        self.is_teacher = is_teacher
        self.input_module_path = input_module['path']
        self.input_module_io = input_module['io']
        ss_module = nn.Sequential(nn.Linear(feat_dim, feat_dim),
                                  nn.ReLU(inplace=True),
                                  nn.Linear(feat_dim, feat_dim))
        self.ckpt_file_path = ss_module_ckpt
        if os.path.isfile(self.ckpt_file_path):
            map_location = {
                'cuda:0': 'cuda:{}'.format(device_ids[0])
            } if distributed else device
            load_module_ckpt(ss_module, map_location, self.ckpt_file_path)
        self.ss_module = ss_module if is_teacher and freezes_ss_module \
            else wrap_if_distributed(ss_module, device, device_ids, distributed)
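The map_location dictionary built above remaps storages saved from cuda:0 onto the GPU owned by the local process. load_module_ckpt is also not shown in this listing; a plausible sketch, assuming it simply restores a saved state dict (the real helper may additionally unwrap DistributedDataParallel):

import torch


def load_module_ckpt(module, map_location, ckpt_file_path):
    # Remap checkpoint storages (e.g. 'cuda:0' -> 'cuda:3') while loading,
    # then copy the parameters into the given module.
    state_dict = torch.load(ckpt_file_path, map_location=map_location)
    module.load_state_dict(state_dict)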
Example #3
    def __init__(self, student_model, input_module_path, translator_params,
                 device, device_ids, distributed, **kwargs):
        super().__init__()
        self.student_model = wrap_if_distributed(student_model, device,
                                                 device_ids, distributed)
        self.input_module_path = input_module_path
        self.translator = \
            wrap_if_distributed(Translator4FactorTransfer(**translator_params), device, device_ids, distributed)
Example #4
    def __init__(self, student_model, regressors, device, device_ids, distributed, **kwargs):
        super().__init__()
        self.student_model = wrap_if_distributed(student_model, device, device_ids, distributed)
        io_path_pairs = list()
        self.regressor_dict = nn.ModuleDict()
        for regressor_key, regressor_params in regressors.items():
            regressor = Regressor4VID(**regressor_params)
            self.regressor_dict[regressor_key] = wrap_if_distributed(regressor, device, device_ids, distributed)
            io_path_pairs.append((regressor_key, regressor_params['io'], regressor_params['path']))
        self.io_path_pairs = io_path_pairs
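For reference, a hypothetical regressors argument for the constructor above; each entry carries the io/path hook target plus whatever keyword arguments Regressor4VID accepts (the parameter names and values below are illustrative, not taken from the source):

regressors = {
    'regressor1': {
        # Hook target on the student model (illustrative path).
        'io': 'output',
        'path': 'layer3',
        # Remaining keys are forwarded via Regressor4VID(**regressor_params),
        # so the regressor must tolerate the extra 'io'/'path' entries.
        'in_channels': 256,
        'out_channels': 512,
    },
}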
Example #5
    def __init__(self, student_model, connectors, device, device_ids, distributed, **kwargs):
        super().__init__()
        self.student_model = wrap_if_distributed(student_model, device, device_ids, distributed)
        io_path_pairs = list()
        self.connector_dict = nn.ModuleDict()
        for connector_key, connector_params in connectors.items():
            connector = self.build_connector(connector_params['conv_params'], connector_params.get('bn_params', None))
            self.connector_dict[connector_key] = wrap_if_distributed(connector, device, device_ids, distributed)
            io_path_pairs.append((connector_key, connector_params['io'], connector_params['path']))
        self.io_path_pairs = io_path_pairs
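build_connector is referenced but not included in this listing; a minimal sketch, assuming it stacks a Conv2d with an optional BatchNorm2d (the source builder may support more options):

from torch import nn


def build_connector(conv_params, bn_params=None):
    # A connector projects student feature maps toward the teacher's
    # feature space: a convolution, optionally followed by batch norm.
    module_list = [nn.Conv2d(**conv_params)]
    if bn_params is not None:
        module_list.append(nn.BatchNorm2d(**bn_params))
    return nn.Sequential(*module_list)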
Example #6
    def __init__(self, input_module, linear_params, device, device_ids, distributed,
                 teacher_model=None, student_model=None, **kwargs):
        super().__init__()
        is_teacher = teacher_model is not None
        if not is_teacher:
            student_model = wrap_if_distributed(student_model, device, device_ids, distributed)

        self.model = teacher_model if is_teacher else student_model
        self.is_teacher = is_teacher
        self.input_module_path = input_module['path']
        self.input_module_io = input_module['io']
        self.linear = wrap_if_distributed(nn.Linear(**linear_params), device, device_ids, distributed)
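For completeness, illustrative values for the arguments above; the module path and feature dimensions are hypothetical and depend on the wrapped model:

input_module = {'path': 'model.avgpool', 'io': 'output'}   # where to hook the feature
linear_params = {'in_features': 512, 'out_features': 128}  # forwarded to nn.Linear(**linear_params)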
Example #7
    def __init__(self, input_module_path, linear_params, device, device_ids, distributed, power=2,
                 teacher_model=None, student_model=None, **kwargs):
        super().__init__()
        is_teacher = teacher_model is not None
        if not is_teacher:
            student_model = wrap_if_distributed(student_model, device, device_ids, distributed)

        self.model = teacher_model if is_teacher else student_model
        self.is_teacher = is_teacher
        self.empty = nn.Sequential()
        self.input_module_path = input_module_path
        linear = nn.Linear(**linear_params)
        self.normalizer = wrap_if_distributed(Normalizer4CRD(linear, power=power), device, device_ids, distributed)
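Normalizer4CRD is not defined in this listing; a sketch of such a module, assuming it flattens the hooked feature, applies the linear embedding, and divides by the L-power norm (details may differ from the source implementation):

from torch import nn


class Normalizer4CRD(nn.Module):
    def __init__(self, linear, power=2):
        super().__init__()
        self.linear = linear
        self.power = power

    def forward(self, x):
        # Flatten to (batch, features), embed, then normalize so the
        # embeddings are suitable for contrastive comparison.
        z = x.flatten(1)
        z = self.linear(z)
        norm = z.pow(self.power).sum(1, keepdim=True).pow(1.0 / self.power)
        return z.div(norm)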
Example #8
    def __init__(self, student_model, input_module, feat_dim, var_estimator_ckpt,
                 device, device_ids, distributed, **kwargs):
        super().__init__()
        self.student_model = wrap_if_distributed(student_model, device, device_ids, distributed)
        self.input_module_path = input_module['path']
        self.input_module_io = input_module['io']
        var_estimator = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.BatchNorm1d(feat_dim)
        )
        self.ckpt_file_path = var_estimator_ckpt
        if os.path.isfile(self.ckpt_file_path):
            map_location = {'cuda:0': 'cuda:{}'.format(device_ids[0])} if distributed else device
            load_module_ckpt(var_estimator, map_location, self.ckpt_file_path)
        self.var_estimator = wrap_if_distributed(var_estimator, device, device_ids, distributed)
Example #9
    def __init__(self, teacher_model, minimal, input_module_path,
                 paraphraser_params, paraphraser_ckpt, uses_decoder, device,
                 device_ids, distributed, **kwargs):
        super().__init__()
        if minimal is None:
            minimal = dict()

        special_teacher_model = build_special_module(
            minimal, teacher_model=teacher_model)
        model_type = 'original'
        teacher_ref_model = teacher_model
        if special_teacher_model is not None:
            teacher_ref_model = special_teacher_model
            model_type = type(teacher_ref_model).__name__

        self.teacher_model = redesign_model(teacher_ref_model, minimal,
                                            'teacher', model_type)
        self.input_module_path = input_module_path
        self.paraphraser = \
            wrap_if_distributed(Paraphraser4FactorTransfer(**paraphraser_params), device, device_ids, distributed)
        self.ckpt_file_path = paraphraser_ckpt
        if os.path.isfile(self.ckpt_file_path):
            map_location = {
                'cuda:0': 'cuda:{}'.format(device_ids[0])
            } if distributed else device
            load_module_ckpt(self.paraphraser, map_location,
                             self.ckpt_file_path)
        self.uses_decoder = uses_decoder
Example #10
    def __init__(self, student_model, input_module_path,
                 feature_adapter_params, affinity_adapter_params, device,
                 device_ids, distributed, **kwargs):
        super().__init__()
        self.student_model = wrap_if_distributed(student_model, device,
                                                 device_ids, distributed)
        self.input_module_path = input_module_path
        feature_adapter = nn.Sequential(
            nn.Conv2d(**feature_adapter_params['conv']),
            nn.BatchNorm2d(**feature_adapter_params['bn']),
            nn.ReLU(**feature_adapter_params['relu']))
        affinity_adapter = nn.Sequential(
            nn.Conv2d(**affinity_adapter_params['conv']))
        self.feature_adapter = wrap_if_distributed(feature_adapter, device,
                                                   device_ids, distributed)
        self.affinity_adapter = wrap_if_distributed(affinity_adapter, device,
                                                    device_ids, distributed)
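Finally, a hypothetical single-device instantiation of the wrapper above (the class name, module path, and adapter shapes are assumptions for illustration, not taken from the source):

from torchvision import models

student_model = models.resnet18(num_classes=100)
wrapper = Student4KTAAD(          # hypothetical name for the class defined above
    student_model=student_model,
    input_module_path='layer4',   # hook the last residual stage (illustrative)
    feature_adapter_params={
        'conv': {'in_channels': 512, 'out_channels': 512, 'kernel_size': 1},
        'bn': {'num_features': 512},
        'relu': {'inplace': True},
    },
    affinity_adapter_params={
        'conv': {'in_channels': 512, 'out_channels': 512, 'kernel_size': 1},
    },
    device='cpu',
    device_ids=None,
    distributed=False,
)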