def redesign_model(org_model,
                   model_config,
                   model_label,
                   model_type='original'):
    logger.info('[{} model]'.format(model_label))
    frozen_module_path_set = set(model_config.get('frozen_modules', list()))
    module_paths = model_config.get('sequential', list())
    # No 'sequential' entry: keep the original architecture and only freeze the requested modules
    if not isinstance(module_paths, list) or len(module_paths) == 0:
        logger.info('Using the {} {} model'.format(model_type, model_label))
        if len(frozen_module_path_set) > 0:
            logger.info('Frozen module(s): {}'.format(frozen_module_path_set))

        # Entries of the form 'instance(ClassName)' freeze every module of that torch.nn class;
        # other entries are treated as module paths within org_model
        isinstance_str = 'instance('
        for frozen_module_path in frozen_module_path_set:
            if frozen_module_path.startswith(
                    isinstance_str) and frozen_module_path.endswith(')'):
                target_cls = nn.__dict__[
                    frozen_module_path[len(isinstance_str):-1]]
                for m in org_model.modules():
                    if isinstance(m, target_cls):
                        freeze_module_params(m)
            else:
                module = get_module(org_model, frozen_module_path)
                freeze_module_params(module)
        return org_model

    logger.info('Redesigning the {} model with {}'.format(
        model_label, module_paths))
    if len(frozen_module_path_set) > 0:
        logger.info('Frozen module(s): {}'.format(frozen_module_path_set))

    module_dict = OrderedDict()
    adaptation_dict = model_config.get('adaptations', dict())

    for frozen_module_path in frozen_module_path_set:
        module = get_module(org_model, frozen_module_path)
        freeze_module_params(module)

    # Rebuild the model as an ordered container; a '+' prefix marks an adaptation module
    # instantiated from the 'adaptations' config instead of being taken from org_model
    for module_path in module_paths:
        if module_path.startswith('+'):
            module_path = module_path[1:]
            adaptation_config = adaptation_dict[module_path]
            module = get_adaptation_module(adaptation_config['type'],
                                           **adaptation_config['params'])
        else:
            module = get_module(org_model, module_path)

        if module_path in frozen_module_path_set:
            freeze_module_params(module)

        add_submodule(module, module_path, module_dict)
    return build_sequential_container(module_dict)
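
# --- Usage sketch (not from the original source) ---
# A hypothetical model_config illustrating the keys redesign_model reads:
# 'sequential' lists the module paths to keep (a '+' prefix marks an adaptation
# module defined under 'adaptations'), and 'frozen_modules' names modules whose
# parameters are frozen. The module paths and the adaptation type/params below
# are illustrative assumptions, not values from the original repository.
example_model_config = {
    'sequential': ['conv1', 'bn1', 'relu', 'layer1', '+adapter'],
    'frozen_modules': ['conv1', 'bn1'],
    'adaptations': {
        'adapter': {'type': 'ConvReg',  # hypothetical adaptation module name
                    'params': {'num_input_channels': 64, 'num_output_channels': 128}}
    }
}
# student_model = redesign_model(org_model, example_model_config, model_label='student')
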
    def add_hook(self,
                 module,
                 module_path,
                 requires_input=True,
                 requires_output=True):
        # Unwrap DataParallel/DistributedDataParallel before resolving the module path
        unwrapped_module = module.module if check_if_wrapped(module) else module
        sub_module = get_module(unwrapped_module, module_path)
        handle = register_forward_hook_with_dict(sub_module, module_path, requires_input,
                                                 requires_output, self.io_dict)
        self.hook_list.append((module_path, handle))
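
# --- Hedged sketch (assumption, not the original implementation) ---
# register_forward_hook_with_dict is presumably a thin wrapper around
# nn.Module.register_forward_hook that caches each forward pass's input/output
# in io_dict under the module path; a minimal stand-in could look like this.
def _forward_hook_with_dict_sketch(sub_module, module_path, requires_input,
                                   requires_output, io_dict):
    io_dict[module_path] = dict()

    def forward_hook(module, inputs, output):
        # inputs is the tuple of positional arguments passed to forward()
        if requires_input:
            io_dict[module_path]['input'] = inputs
        if requires_output:
            io_dict[module_path]['output'] = output

    # The returned handle lets add_hook's caller remove the hook later
    return sub_module.register_forward_hook(forward_hook)
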
    def setup(self, train_config):
        # Set up train and val data loaders
        self.setup_data_loaders(train_config)

        # Define model used in this stage
        model_config = train_config.get('model', dict())
        self.setup_model(model_config)

        # Define loss function used in this stage
        self.setup_loss(train_config)

        # Wrap models if necessary
        self.model =\
            wrap_model(self.model, model_config, self.device, self.device_ids, self.distributed,
                       self.model_any_frozen)

        if not model_config.get('requires_grad', True):
            logger.info('Freezing the whole model')
            freeze_module_params(self.model)

        # Set up optimizer and scheduler
        optim_config = train_config.get('optimizer', dict())
        optimizer_reset = False
        if len(optim_config) > 0:
            optim_params_config = optim_config['params']
            # Guard against configs that omit 'lr' (consistent with the distillation setup below)
            if 'lr' in optim_params_config:
                optim_params_config['lr'] *= self.lr_factor
            module_wise_params_configs = optim_config.get('module_wise_params', list())
            if len(module_wise_params_configs) > 0:
                trainable_module_list = list()
                for module_wise_params_config in module_wise_params_configs:
                    module_wise_params_dict = dict()
                    module_wise_params_dict.update(module_wise_params_config['params'])
                    if 'lr' in module_wise_params_dict:
                        module_wise_params_dict['lr'] *= self.lr_factor

                    module = get_module(self, module_wise_params_config['module'])
                    module_wise_params_dict['params'] = module.parameters()
                    trainable_module_list.append(module_wise_params_dict)
            else:
                trainable_module_list = nn.ModuleList([self.model])

            self.optimizer = get_optimizer(trainable_module_list, optim_config['type'], optim_params_config)
            optimizer_reset = True

        scheduler_config = train_config.get('scheduler', None)
        if scheduler_config is not None and len(scheduler_config) > 0:
            self.lr_scheduler = get_scheduler(self.optimizer, scheduler_config['type'], scheduler_config['params'])
        elif optimizer_reset:
            self.lr_scheduler = None

        # Set up apex if you require mixed-precision training
        self.apex = False
        apex_config = train_config.get('apex', None)
        if apex_config is not None and apex_config.get('requires', False):
            if sys.version_info < (3, 0):
                raise RuntimeError('Apex currently only supports Python 3. Aborting.')
            if amp is None:
                raise RuntimeError('Failed to import apex. Please install apex from https://www.github.com/nvidia/apex '
                                   'to enable mixed-precision training.')
            self.model, self.optimizer =\
                amp.initialize(self.model, self.optimizer, opt_level=apex_config['opt_level'])
            self.apex = True
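
# --- Usage sketch (assumption; key names mirror the lookups in setup() above) ---
# A hypothetical train_config fragment covering the optimizer, scheduler, and
# apex entries read by setup(); the concrete types and values are illustrative.
example_train_config = {
    'optimizer': {
        'type': 'SGD',
        'params': {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 5e-4},
        'module_wise_params': []  # empty: the whole self.model is optimized
    },
    'scheduler': {
        'type': 'MultiStepLR',
        'params': {'milestones': [30, 60], 'gamma': 0.1}
    },
    'apex': {'requires': False, 'opt_level': 'O1'}
}
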
    def setup(self, train_config):
        # Set up train and val data loaders
        self.setup_data_loaders(train_config)

        # Define teacher and student models used in this stage
        teacher_config = train_config.get('teacher', dict())
        student_config = train_config.get('student', dict())
        self.setup_teacher_student_models(teacher_config, student_config)

        # Define loss function used in this stage
        self.setup_loss(train_config)

        # Freeze parameters if specified
        self.teacher_updatable = True
        if not teacher_config.get('requires_grad', True):
            logger.info('Freezing the whole teacher model')
            freeze_module_params(self.teacher_model)
            self.teacher_updatable = False

        if not student_config.get('requires_grad', True):
            logger.info('Freezing the whole student model')
            freeze_module_params(self.student_model)

        # Wrap models if necessary
        teacher_unused_parameters = teacher_config.get(
            'find_unused_parameters', self.teacher_any_frozen)
        teacher_any_updatable = len(
            get_updatable_param_names(self.teacher_model)) > 0
        self.teacher_model =\
            wrap_model(self.teacher_model, teacher_config, self.device, self.device_ids, self.distributed,
                       teacher_unused_parameters, teacher_any_updatable)
        student_unused_parameters = student_config.get(
            'find_unused_parameters', self.student_any_frozen)
        student_any_updatable = len(
            get_updatable_param_names(self.student_model)) > 0
        self.student_model =\
            wrap_model(self.student_model, student_config, self.device, self.device_ids, self.distributed,
                       student_unused_parameters, student_any_updatable)

        # Set up optimizer and scheduler
        optim_config = train_config.get('optimizer', dict())
        optimizer_reset = False
        if len(optim_config) > 0:
            optim_params_config = optim_config['params']
            if 'lr' in optim_params_config:
                optim_params_config['lr'] *= self.lr_factor

            module_wise_params_configs = optim_config.get(
                'module_wise_params', list())
            if len(module_wise_params_configs) > 0:
                trainable_module_list = list()
                for module_wise_params_config in module_wise_params_configs:
                    module_wise_params_dict = dict()
                    if isinstance(
                            module_wise_params_config.get('params', None),
                            dict):
                        module_wise_params_dict.update(
                            module_wise_params_config['params'])

                    if 'lr' in module_wise_params_dict:
                        module_wise_params_dict['lr'] *= self.lr_factor

                    target_model = \
                        self.teacher_model if module_wise_params_config.get('is_teacher', False) else self.student_model
                    module = get_module(target_model,
                                        module_wise_params_config['module'])
                    module_wise_params_dict['params'] = module.parameters()
                    trainable_module_list.append(module_wise_params_dict)
            else:
                trainable_module_list = nn.ModuleList([self.student_model])
                if self.teacher_updatable:
                    logger.info(
                        'Note that you are training some/all of the modules in the teacher model'
                    )
                    trainable_module_list.append(self.teacher_model)

            filters_params = optim_config.get('filters_params', True)
            self.optimizer = \
                get_optimizer(trainable_module_list, optim_config['type'], optim_params_config, filters_params)
            self.optimizer.zero_grad()
            self.max_grad_norm = optim_config.get('max_grad_norm', None)
            self.grad_accum_step = optim_config.get('grad_accum_step', 1)
            optimizer_reset = True

        scheduler_config = train_config.get('scheduler', None)
        if scheduler_config is not None and len(scheduler_config) > 0:
            self.lr_scheduler = get_scheduler(self.optimizer,
                                              scheduler_config['type'],
                                              scheduler_config['params'])
            self.scheduling_step = scheduler_config.get('scheduling_step', 0)
        elif optimizer_reset:
            self.lr_scheduler = None
            self.scheduling_step = None

        # Set up accelerator/apex if necessary
        self.apex = False
        apex_config = train_config.get('apex', None)
        if self.accelerator is not None:
            if self.teacher_updatable:
                self.teacher_model, self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
                    self.accelerator.prepare(self.teacher_model, self.student_model, self.optimizer,
                                             self.train_data_loader, self.val_data_loader)
            else:
                self.teacher_model = self.teacher_model.to(
                    self.accelerator.device)
                if self.accelerator.state.use_fp16:
                    self.teacher_model = self.teacher_model.half()

                self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
                    self.accelerator.prepare(self.student_model, self.optimizer,
                                             self.train_data_loader, self.val_data_loader)
        elif apex_config is not None and apex_config.get('requires', False):
            if sys.version_info < (3, 0):
                raise RuntimeError(
                    'Apex currently only supports Python 3. Aborting.')
            if amp is None:
                raise RuntimeError(
                    'Failed to import apex. Please install apex from https://www.github.com/nvidia/apex '
                    'to enable mixed-precision training.')
            self.student_model, self.optimizer =\
                amp.initialize(self.student_model, self.optimizer, opt_level=apex_config['opt_level'])
            self.apex = True
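
# --- Usage sketch (assumption; key names mirror the lookups in setup() above) ---
# A hypothetical optimizer entry showing per-module parameter groups: each
# 'module_wise_params' item names a module path, optionally marks it as a
# teacher module via 'is_teacher', and may override hyperparameters such as
# 'lr'. The module paths below are illustrative, not from the original config.
example_optim_config = {
    'type': 'Adam',
    'params': {'lr': 1e-3},
    'module_wise_params': [
        {'module': 'backbone', 'params': {'lr': 1e-4}},                       # resolved in the student model
        {'module': 'projector', 'is_teacher': True, 'params': {'lr': 1e-5}}   # resolved in the teacher model
    ],
    'filters_params': True,
    'max_grad_norm': 1.0,
    'grad_accum_step': 1
}
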
def extract_module(org_model, sub_model, module_path):
    if module_path.startswith('+'):
        return get_module(sub_model, module_path[1:])
    return get_module(org_model, module_path)
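
# --- Usage note (assumption) ---
# Following the '+' prefix convention used by redesign_model above, a path such
# as '+adapter' refers to an adaptation module that exists only in the
# redesigned sub_model, while plain paths are resolved in the original model.
# module = extract_module(org_model, sub_model, '+adapter')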