def __init__(self, student_model, abfs, device, device_ids, distributed, sizes=None, **kwargs):
    super().__init__()
    self.student_model = wrap_if_distributed(student_model, device, device_ids, distributed)
    if sizes is None:
        sizes = [1, 7, 14, 28, 56]

    self.sizes = sizes
    abf_list = nn.ModuleList()
    num_abfs = len(abfs)
    io_path_pairs = list()
    for idx, abf_config in enumerate(abfs):
        # All but the last configured ABF enable the attention mechanism
        abf = wrap_if_distributed(
            AttentionBasedFusion(uses_attention=idx < num_abfs - 1, **abf_config['params']),
            device, device_ids, distributed
        )
        abf_list.append(abf)
        io_path_pairs.append((abf_config['io'], abf_config['path']))

    # Reverse both lists so the last configured entry is processed first
    self.abf_modules = abf_list[::-1]
    self.io_path_pairs = io_path_pairs[::-1]

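# A hedged sketch of the `abfs` config the constructor above consumes, based on the
# keys it reads ('params', 'io', 'path'). The AttentionBasedFusion parameter names
# and the module paths below are illustrative assumptions, not values from the source.
example_abfs = [
    {'params': {'in_channels': 256, 'mid_channels': 64, 'out_channels': 256},
     'io': 'output', 'path': 'layer3'},
    {'params': {'in_channels': 512, 'mid_channels': 64, 'out_channels': 512},
     'io': 'output', 'path': 'layer4'},
]
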
def __init__(self, input_module, feat_dim, ss_module_ckpt, device, device_ids, distributed,
             freezes_ss_module=False, teacher_model=None, student_model=None, **kwargs):
    super().__init__()
    is_teacher = teacher_model is not None
    if not is_teacher:
        student_model = wrap_if_distributed(student_model, device, device_ids, distributed)

    self.model = teacher_model if is_teacher else student_model
    self.is_teacher = is_teacher
    self.input_module_path = input_module['path']
    self.input_module_io = input_module['io']
    # Two-layer MLP head for the self-supervision task
    ss_module = nn.Sequential(
        nn.Linear(feat_dim, feat_dim),
        nn.ReLU(inplace=True),
        nn.Linear(feat_dim, feat_dim)
    )
    self.ckpt_file_path = ss_module_ckpt
    if os.path.isfile(self.ckpt_file_path):
        # Remap tensors saved on cuda:0 to this process's device when distributed
        map_location = {'cuda:0': 'cuda:{}'.format(device_ids[0])} if distributed else device
        load_module_ckpt(ss_module, map_location, self.ckpt_file_path)

    # Keep the module unwrapped for a teacher with a frozen head; otherwise wrap it
    self.ss_module = ss_module if is_teacher and freezes_ss_module \
        else wrap_if_distributed(ss_module, device, device_ids, distributed)

def __init__(self, student_model, input_module_path, translator_params, device, device_ids, distributed, **kwargs):
    super().__init__()
    self.student_model = wrap_if_distributed(student_model, device, device_ids, distributed)
    self.input_module_path = input_module_path
    self.translator = wrap_if_distributed(
        Translator4FactorTransfer(**translator_params), device, device_ids, distributed
    )

def __init__(self, student_model, regressors, device, device_ids, distributed, **kwargs):
    super().__init__()
    self.student_model = wrap_if_distributed(student_model, device, device_ids, distributed)
    io_path_pairs = list()
    self.regressor_dict = nn.ModuleDict()
    for regressor_key, regressor_params in regressors.items():
        # 'io' and 'path' ride along in regressor_params; Regressor4VID is
        # expected to absorb these extra keys via **kwargs
        regressor = Regressor4VID(**regressor_params)
        self.regressor_dict[regressor_key] = wrap_if_distributed(regressor, device, device_ids, distributed)
        io_path_pairs.append((regressor_key, regressor_params['io'], regressor_params['path']))

    self.io_path_pairs = io_path_pairs

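# A hedged sketch of the `regressors` config iterated above: each value bundles the
# Regressor4VID constructor arguments with the 'io'/'path' capture point. All names
# and sizes here are illustrative assumptions.
example_regressors = {
    'regressor1': {'in_channels': 256, 'middle_channels': 256, 'out_channels': 256,
                   'eps': 1e-5, 'init_pred_var': 5.0, 'io': 'output', 'path': 'layer3'}
}
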
def __init__(self, student_model, connectors, device, device_ids, distributed, **kwargs):
    super().__init__()
    self.student_model = wrap_if_distributed(student_model, device, device_ids, distributed)
    io_path_pairs = list()
    self.connector_dict = nn.ModuleDict()
    for connector_key, connector_params in connectors.items():
        # Assemble a conv connector, with an optional batch-norm layer
        connector = self.build_connector(connector_params['conv_params'],
                                         connector_params.get('bn_params', None))
        self.connector_dict[connector_key] = wrap_if_distributed(connector, device, device_ids, distributed)
        io_path_pairs.append((connector_key, connector_params['io'], connector_params['path']))

    self.io_path_pairs = io_path_pairs

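# A hedged sketch of the `connectors` config consumed above: 'conv_params' presumably
# feeds nn.Conv2d and the optional 'bn_params' nn.BatchNorm2d, while 'io'/'path'
# identify where the student feature is captured. Channel sizes and the module path
# below are assumptions for illustration only.
example_connectors = {
    'connector1': {
        'conv_params': {'in_channels': 128, 'out_channels': 256, 'kernel_size': 1, 'bias': False},
        'bn_params': {'num_features': 256},
        'io': 'output',
        'path': 'layer2'
    }
}
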
def __init__(self, input_module, linear_params, device, device_ids, distributed,
             teacher_model=None, student_model=None, **kwargs):
    super().__init__()
    is_teacher = teacher_model is not None
    if not is_teacher:
        student_model = wrap_if_distributed(student_model, device, device_ids, distributed)

    self.model = teacher_model if is_teacher else student_model
    self.is_teacher = is_teacher
    self.input_module_path = input_module['path']
    self.input_module_io = input_module['io']
    self.linear = wrap_if_distributed(nn.Linear(**linear_params), device, device_ids, distributed)

def __init__(self, input_module_path, linear_params, device, device_ids, distributed, power=2,
             teacher_model=None, student_model=None, **kwargs):
    super().__init__()
    is_teacher = teacher_model is not None
    if not is_teacher:
        student_model = wrap_if_distributed(student_model, device, device_ids, distributed)

    self.model = teacher_model if is_teacher else student_model
    self.is_teacher = is_teacher
    self.empty = nn.Sequential()
    self.input_module_path = input_module_path
    linear = nn.Linear(**linear_params)
    self.normalizer = wrap_if_distributed(Normalizer4CRD(linear, power=power), device, device_ids, distributed)

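# Usage sketch for the teacher/student switch in the constructor above, assuming it
# belongs to a class named Linear4CRD defined in this module; the torchvision models,
# module path, and feature sizes are assumptions for illustration only.
from torchvision.models import resnet18, resnet50

# Teacher branch: passing teacher_model skips the distributed wrapping of the model
teacher_wrapper = Linear4CRD('avgpool', {'in_features': 2048, 'out_features': 128},
                             device='cpu', device_ids=[], distributed=False,
                             teacher_model=resnet50())
# Student branch: the student model goes through wrap_if_distributed
student_wrapper = Linear4CRD('avgpool', {'in_features': 512, 'out_features': 128},
                             device='cpu', device_ids=[], distributed=False,
                             student_model=resnet18())
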
def __init__(self, student_model, input_module, feat_dim, var_estimator_ckpt, device, device_ids,
             distributed, **kwargs):
    super().__init__()
    self.student_model = wrap_if_distributed(student_model, device, device_ids, distributed)
    self.input_module_path = input_module['path']
    self.input_module_io = input_module['io']
    var_estimator = nn.Sequential(
        nn.Linear(feat_dim, feat_dim),
        nn.BatchNorm1d(feat_dim)
    )
    self.ckpt_file_path = var_estimator_ckpt
    if os.path.isfile(self.ckpt_file_path):
        # Remap tensors saved on cuda:0 to this process's device when distributed
        map_location = {'cuda:0': 'cuda:{}'.format(device_ids[0])} if distributed else device
        load_module_ckpt(var_estimator, map_location, self.ckpt_file_path)

    self.var_estimator = wrap_if_distributed(var_estimator, device, device_ids, distributed)

def __init__(self, teacher_model, minimal, input_module_path, paraphraser_params, paraphraser_ckpt,
             uses_decoder, device, device_ids, distributed, **kwargs):
    super().__init__()
    if minimal is None:
        minimal = dict()

    # If the `minimal` config specifies a special module, use it as the reference
    # model when redesigning the teacher; otherwise keep the original teacher
    special_teacher_model = build_special_module(minimal, teacher_model=teacher_model)
    model_type = 'original'
    teacher_ref_model = teacher_model
    if special_teacher_model is not None:
        teacher_ref_model = special_teacher_model
        model_type = type(teacher_ref_model).__name__

    self.teacher_model = redesign_model(teacher_ref_model, minimal, 'teacher', model_type)
    self.input_module_path = input_module_path
    self.paraphraser = wrap_if_distributed(
        Paraphraser4FactorTransfer(**paraphraser_params), device, device_ids, distributed
    )
    self.ckpt_file_path = paraphraser_ckpt
    if os.path.isfile(self.ckpt_file_path):
        # Remap tensors saved on cuda:0 to this process's device when distributed
        map_location = {'cuda:0': 'cuda:{}'.format(device_ids[0])} if distributed else device
        load_module_ckpt(self.paraphraser, map_location, self.ckpt_file_path)

    self.uses_decoder = uses_decoder

def __init__(self, student_model, input_module_path, feature_adapter_params, affinity_adapter_params,
             device, device_ids, distributed, **kwargs):
    super().__init__()
    self.student_model = wrap_if_distributed(student_model, device, device_ids, distributed)
    self.input_module_path = input_module_path
    # Feature adapter: conv -> batch norm -> ReLU
    feature_adapter = nn.Sequential(
        nn.Conv2d(**feature_adapter_params['conv']),
        nn.BatchNorm2d(**feature_adapter_params['bn']),
        nn.ReLU(**feature_adapter_params['relu'])
    )
    # Affinity adapter: a single conv layer
    affinity_adapter = nn.Sequential(
        nn.Conv2d(**affinity_adapter_params['conv'])
    )
    self.feature_adapter = wrap_if_distributed(feature_adapter, device, device_ids, distributed)
    self.affinity_adapter = wrap_if_distributed(affinity_adapter, device, device_ids, distributed)

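# A hedged example of the adapter configs unpacked above: the 'conv', 'bn', and 'relu'
# entries map directly onto nn.Conv2d, nn.BatchNorm2d, and nn.ReLU keyword arguments.
# The channel sizes are illustrative assumptions, not values from the source.
example_feature_adapter_params = {
    'conv': {'in_channels': 512, 'out_channels': 2048, 'kernel_size': 1},
    'bn': {'num_features': 2048},
    'relu': {'inplace': True}
}
example_affinity_adapter_params = {
    'conv': {'in_channels': 512, 'out_channels': 2048, 'kernel_size': 1}
}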