コード例 #1
0
    def setup_teacher_student_models(self, teacher_config, student_config):
        """Build the teacher/student models from their configs and (re)register hooks.

        A model is (re)built when its config is non-empty, or when no model of
        that role has been set on this instance yet. After that, the frozen
        flags, the hook target pairs and the forward procedures are refreshed
        for both roles.

        Args:
            teacher_config: dict-like teacher configuration. The keys read
                directly here are ``frozen_modules``, ``requires_grad`` and
                ``forward_proc``; the config is also handed to
                ``build_special_module``, ``redesign_model`` and ``set_hooks``.
            student_config: dict-like student configuration, used the same way.
        """
        org_teacher = self.org_teacher_model
        if check_if_wrapped(org_teacher):
            org_teacher = org_teacher.module

        org_student = self.org_student_model
        if check_if_wrapped(org_student):
            org_student = org_student.module

        self.target_teacher_pairs.clear()
        self.target_student_pairs.clear()

        # Reference models default to the unwrapped originals and are replaced
        # below when a special module is built for the role.
        teacher_ref_model = org_teacher
        student_ref_model = org_student

        if len(teacher_config) > 0 or self.teacher_model is None:
            special_teacher = build_special_module(
                teacher_config, teacher_model=org_teacher, device=self.device,
                device_ids=self.device_ids, distributed=self.distributed
            )
            if special_teacher is None:
                teacher_type = 'original'
            else:
                teacher_ref_model = special_teacher
                teacher_type = type(special_teacher).__name__
            self.teacher_model = redesign_model(teacher_ref_model, teacher_config, 'teacher', teacher_type)

        if len(student_config) > 0 or self.student_model is None:
            special_student = build_special_module(
                student_config, student_model=org_student, device=self.device,
                device_ids=self.device_ids, distributed=self.distributed
            )
            if special_student is None:
                student_type = 'original'
            else:
                student_ref_model = special_student
                student_type = type(special_student).__name__
            self.student_model = redesign_model(student_ref_model, student_config, 'student', student_type)

        # A role counts as "frozen" if any module is listed as frozen or if
        # gradients are disabled for the whole model.
        teacher_frozen_modules = teacher_config.get('frozen_modules', [])
        self.teacher_any_frozen = \
            len(teacher_frozen_modules) > 0 or not teacher_config.get('requires_grad', True)
        student_frozen_modules = student_config.get('frozen_modules', [])
        self.student_any_frozen = \
            len(student_frozen_modules) > 0 or not student_config.get('requires_grad', True)

        teacher_pairs = set_hooks(self.teacher_model, teacher_ref_model, teacher_config, self.teacher_io_dict)
        self.target_teacher_pairs.extend(teacher_pairs)
        student_pairs = set_hooks(self.student_model, student_ref_model, student_config, self.student_io_dict)
        self.target_student_pairs.extend(student_pairs)

        teacher_proc_key = teacher_config.get('forward_proc', None)
        self.teacher_forward_proc = get_forward_proc_func(teacher_proc_key)
        student_proc_key = student_config.get('forward_proc', None)
        self.student_forward_proc = get_forward_proc_func(student_proc_key)
コード例 #2
0
def main(args):
    """Run training (unless ``args.test_only``) followed by evaluation on the test split.

    The YAML file at ``args.config`` supplies the ``datasets``, ``models`` and
    ``test`` sections read here. The student config is taken from
    ``models.student_model`` when present, otherwise from ``models.model``.
    The teacher is optional and is evaluated only when it exists and
    ``args.student_only`` is not set.

    Args:
        args: parsed command-line namespace. Attributes read here: ``log``,
            ``world_size``, ``dist_url``, ``seed``, ``config``, ``device``,
            ``log_config``, ``test_only``, ``iou_types`` and ``student_only``.
    """
    log_path = args.log
    if is_main_process() and log_path is not None:
        setup_log_file(os.path.expanduser(log_path))

    distributed, device_ids = init_distributed_mode(args.world_size, args.dist_url)
    logger.info(args)
    cudnn.benchmark = True
    set_seed(args.seed)
    config = yaml_util.load_yaml_file(os.path.expanduser(args.config))
    device = torch.device(args.device)
    dataset_dict = util.get_all_datasets(config['datasets'])
    models_config = config['models']

    # The teacher is optional: without a teacher config no teacher is loaded.
    teacher_model_config = models_config.get('teacher_model', None)
    teacher_model = None
    if teacher_model_config is not None:
        teacher_model = load_model(teacher_model_config, device)

    if 'student_model' in models_config:
        student_model_config = models_config['student_model']
    else:
        student_model_config = models_config['model']

    ckpt_file_path = student_model_config['ckpt']
    student_model = load_model(student_model_config, device)
    if args.log_config:
        logger.info(config)

    if not args.test_only:
        train(teacher_model, student_model, dataset_dict, ckpt_file_path,
              device, device_ids, distributed, config, args)
        # Reload the checkpoint saved during training into the unwrapped student.
        bare_student = student_model
        if module_util.check_if_wrapped(student_model):
            bare_student = student_model.module
        load_ckpt(student_model_config['ckpt'], model=bare_student, strict=True)

    test_config = config['test']
    loader_config = test_config['test_data_loader']
    test_dataset = dataset_dict[loader_config['dataset_id']]
    test_data_loader = util.build_data_loader(test_dataset, loader_config, distributed)
    log_freq = test_config.get('log_freq', 1000)
    iou_types = args.iou_types

    if not args.student_only and teacher_model is not None:
        teacher_title = '[Teacher: {}]'.format(teacher_model_config['name'])
        evaluate(teacher_model, test_data_loader, iou_types, device, device_ids, distributed,
                 log_freq=log_freq, title=teacher_title)

    student_title = '[Student: {}]'.format(student_model_config['name'])
    evaluate(student_model, test_data_loader, iou_types, device, device_ids, distributed,
             log_freq=log_freq, title=student_title)
コード例 #3
0
ファイル: util.py プロジェクト: lilujunai/torchdistill
def wrap_model(model, model_config, device, device_ids=None, distributed=False, any_frozen=False):
    """Move ``model`` to ``device`` and optionally wrap it in a (Distributed)DataParallel wrapper.

    Wrapping happens only when the config names a wrapper, the device type
    starts with ``'cuda'`` and the model is not already wrapped.

    Args:
        model: object exposing ``to(device)``.
        model_config: dict-like config or None; only its ``wrapper`` key is read.
        device: object whose ``type`` attribute is a string such as ``'cuda'``.
        device_ids: forwarded to the wrapper as ``device_ids``.
        distributed: when True and the wrapper is ``'DistributedDataParallel'``,
            ``DistributedDataParallel`` is used.
        any_frozen: forwarded as ``find_unused_parameters`` to
            ``DistributedDataParallel``.

    Returns:
        The wrapped model, or ``model`` itself when no wrapping applies.
    """
    wrapper = None if model_config is None else model_config.get('wrapper', None)
    model.to(device)

    # Nothing to do without a wrapper name, off CUDA, or for an already-wrapped model.
    if wrapper is None or not device.type.startswith('cuda') or check_if_wrapped(model):
        return model

    if wrapper == 'DistributedDataParallel' and distributed:
        return DistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=any_frozen)

    # A DistributedDataParallel request in a non-distributed run falls back to DataParallel.
    if wrapper in {'DataParallel', 'DistributedDataParallel'}:
        return DataParallel(model, device_ids=device_ids)
    return model
コード例 #4
0
 def add_hook(self, module, module_path, requires_input=True, requires_output=True):
     """Register a forward hook on the sub-module at ``module_path`` and record its handle.

     Args:
         module: root module; if it is wrapped, its ``.module`` attribute is used.
         module_path: path handed to ``get_module`` to locate the sub-module;
             also stored alongside the handle.
         requires_input: forwarded to ``register_forward_hook_with_dict``.
         requires_output: forwarded to ``register_forward_hook_with_dict``.
     """
     root = module
     if check_if_wrapped(module):
         root = module.module

     target_module = get_module(root, module_path)
     hook_handle = register_forward_hook_with_dict(
         target_module, module_path, requires_input, requires_output, self.io_dict
     )
     # Keep the handle together with its path in the instance's hook list.
     self.hook_list.append((module_path, hook_handle))
コード例 #5
0
ファイル: distillation.py プロジェクト: CV-IP/torchdistill
    def get_teacher_output(self, sample_batch, targets, supp_dict):
        """Return the teacher's outputs and extracted I/O dict, reading or writing caches when configured.

        Flow:
          1. If ``supp_dict['cached_data']`` is a dict, its ``teacher_outputs``
             and ``extracted_outputs`` are used (moved to the batch's device
             when that is not CPU). When the teacher is not updatable they are
             returned immediately.
          2. Otherwise the teacher forward procedure runs, under
             ``torch.no_grad()`` unless the teacher is updatable.
          3. If ``supp_dict['cache_file_path']`` is a list/tuple, one cache
             file per sample is written with that sample's teacher output and
             extracted I/O.

        Args:
            sample_batch: input batch for the teacher forward procedure; its
                ``device`` attribute is read when cached data is used.
            targets: passed through to the teacher forward procedure.
            supp_dict: dict-like supplementary info; the keys ``cached_data``
                and ``cache_file_path`` are read here.

        Returns:
            Tuple ``(teacher_outputs, extracted_teacher_io_dict)``.
        """
        cached_data = supp_dict.get('cached_data', None)
        cache_file_paths = supp_dict.get('cache_file_path', None)
        # Cache files are written only when a list/tuple of output paths is given.
        writes_cache = isinstance(cache_file_paths, (list, tuple))
        teacher_outputs = None
        cached_extracted_teacher_output_dict = None
        # Use cached data if available
        if cached_data is not None and isinstance(cached_data, dict):
            device = sample_batch.device
            teacher_outputs = cached_data['teacher_outputs']
            cached_extracted_teacher_output_dict = cached_data['extracted_outputs']
            if device.type != 'cpu':
                teacher_outputs = change_device(teacher_outputs, device)
                cached_extracted_teacher_output_dict = change_device(cached_extracted_teacher_output_dict, device)
            if not self.teacher_updatable:
                return teacher_outputs, cached_extracted_teacher_output_dict

        if teacher_outputs is None:
            if self.teacher_updatable:
                teacher_outputs = self.teacher_forward_proc(self.teacher_model, sample_batch, targets, supp_dict)
            else:
                with torch.no_grad():
                    teacher_outputs = self.teacher_forward_proc(self.teacher_model, sample_batch, targets, supp_dict)

        if cached_extracted_teacher_output_dict is not None:
            # Feed the cached extracted outputs back into a special teacher module
            if isinstance(self.teacher_model, SpecialModule) or \
                    (check_if_wrapped(self.teacher_model) and isinstance(self.teacher_model.module, SpecialModule)):
                self.teacher_io_dict.update(cached_extracted_teacher_output_dict)
                if isinstance(self.teacher_model, SpecialModule):
                    self.teacher_model.post_forward(self.teacher_io_dict)

            extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
            return teacher_outputs, extracted_teacher_io_dict

        # Deep copy of teacher info dict if teacher special module contains trainable module(s).
        # The snapshot is only read when cache files are written below, so it is
        # taken only in that case instead of on every call.
        teacher_io_dict4cache = copy.deepcopy(self.teacher_io_dict) \
            if self.teacher_updatable and writes_cache else None
        extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
        if isinstance(self.teacher_model, SpecialModule):
            self.teacher_model.post_forward(extracted_teacher_io_dict)

        update_io_dict(extracted_teacher_io_dict, extract_io_dict(self.teacher_io_dict, self.device))
        # Write cache files if output file paths (cache_file_paths) are given
        if writes_cache:
            if teacher_io_dict4cache is None:
                teacher_io_dict4cache = extracted_teacher_io_dict

            cpu_device = torch.device('cpu')
            # One cache file per sample: the i-th teacher output goes with the i-th path.
            for i, (teacher_output, cache_file_path) in enumerate(zip(teacher_outputs.cpu().numpy(), cache_file_paths)):
                sub_dict = extract_sub_model_output_dict(teacher_io_dict4cache, i)
                sub_dict = tensor2numpy2tensor(sub_dict, cpu_device)
                cache_dict = {'teacher_outputs': torch.Tensor(teacher_output), 'extracted_outputs': sub_dict}
                make_parent_dirs(cache_file_path)
                torch.save(cache_dict, cache_file_path)
        return teacher_outputs, extracted_teacher_io_dict
コード例 #6
0
def save_ckpt(model, optimizer, lr_scheduler, best_value, config, args,
              output_file_path):
    """Save a training checkpoint to ``output_file_path`` via ``save_on_master``.

    Args:
        model: model to save; if wrapped, the state of its ``.module`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        lr_scheduler: scheduler whose ``state_dict()`` is stored, or None to
            store None instead.
        best_value: best validation value so far, stored as ``best_value``.
        config: configuration object stored as ``config``.
        args: argument namespace stored as ``args``.
        output_file_path: destination path; its parent directories are created first.
    """
    make_parent_dirs(output_file_path)

    bare_model = model.module if check_if_wrapped(model) else model
    model_state = bare_model.state_dict()

    scheduler_state = None
    if lr_scheduler is not None:
        scheduler_state = lr_scheduler.state_dict()

    ckpt = {
        'model': model_state,
        'optimizer': optimizer.state_dict(),
        'best_value': best_value,
        'lr_scheduler': scheduler_state,
        'config': config,
        'args': args
    }
    save_on_master(ckpt, output_file_path)
コード例 #7
0
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device,
          device_ids, distributed, config, args):
    """Train the student, validating each epoch and checkpointing on mAP improvement.

    A plain training box is used when ``teacher_model`` is None, otherwise a
    distillation box. If a checkpoint already exists at ``ckpt_file_path``,
    the optimizer/scheduler state and the best value so far are restored from it.

    Args:
        teacher_model: teacher model, or None to train without distillation.
        student_model: model being trained (possibly wrapped).
        dataset_dict: datasets handed to the training/distillation box.
        ckpt_file_path: path the best checkpoint is read from and written to.
        device: device used for training and evaluation.
        device_ids: device ids forwarded to the box and to ``evaluate``.
        distributed: whether the run is distributed.
        config: full configuration; its ``train`` section is read here.
        args: argument namespace; ``world_size``, ``adjust_lr``, ``iou_types``
            and ``start_epoch`` are read here.
    """
    logger.info('Start training')
    train_config = config['train']

    lr_factor = 1
    if distributed and args.adjust_lr:
        lr_factor = args.world_size

    if teacher_model is None:
        training_box = get_training_box(student_model, dataset_dict, train_config,
                                        device, device_ids, distributed, lr_factor)
    else:
        training_box = get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                                            device, device_ids, distributed, lr_factor)

    best_val_map = 0.0
    optimizer = training_box.optimizer
    lr_scheduler = training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_map, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    iou_types = args.iou_types
    # Validate on the first requested IoU type, defaulting to 'bbox'.
    if isinstance(iou_types, (list, tuple)) and len(iou_types) > 0:
        val_iou_type = iou_types[0]
    else:
        val_iou_type = 'bbox'

    student_model_without_ddp = student_model
    if module_util.check_if_wrapped(student_model):
        student_model_without_ddp = student_model.module

    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        val_coco_evaluator = evaluate(student_model, training_box.val_data_loader, iou_types,
                                      device, device_ids, distributed,
                                      log_freq=log_freq, header='Validation:')
        # Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ]
        val_map = val_coco_evaluator.coco_eval[val_iou_type].stats[0]
        if val_map > best_val_map and is_main_process():
            logger.info('Best mAP ({}): {:.4f} -> {:.4f}'.format(val_iou_type, best_val_map, val_map))
            logger.info('Updating ckpt at {}'.format(ckpt_file_path))
            best_val_map = val_map
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler,
                      best_val_map, config, args, ckpt_file_path)
        training_box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))
    training_box.clean_modules()
コード例 #8
0
def distill(teacher_model, student_model, dataset_dict, device, device_ids,
            distributed, config, args):
    """Distill the teacher into the student, checkpointing on top-1 accuracy improvement.

    The checkpoint path comes from ``config['models']['student_model']['ckpt']``.
    If a checkpoint already exists there, the optimizer/scheduler state and the
    best value so far are restored from it.

    Args:
        teacher_model: teacher model.
        student_model: model being trained (possibly wrapped).
        dataset_dict: datasets handed to the distillation box.
        device: device used for training and evaluation.
        device_ids: device ids forwarded to the box and to ``evaluate``.
        distributed: whether the run is distributed.
        config: full configuration; its ``train`` and ``models`` sections are read here.
        args: argument namespace; ``world_size``, ``adjust_lr`` and
            ``start_epoch`` are read here.
    """
    logger.info('Start distillation')
    train_config = config['train']

    lr_factor = 1
    if distributed and args.adjust_lr:
        lr_factor = args.world_size

    distillation_box = get_distillation_box(teacher_model, student_model, dataset_dict,
                                            train_config, device, device_ids, distributed, lr_factor)
    ckpt_file_path = config['models']['student_model']['ckpt']
    best_val_top1_accuracy = 0.0
    optimizer = distillation_box.optimizer
    lr_scheduler = distillation_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_top1_accuracy, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    student_model_without_ddp = student_model
    if module_util.check_if_wrapped(student_model):
        student_model_without_ddp = student_model.module

    start_time = time.time()
    for epoch in range(args.start_epoch, distillation_box.num_epochs):
        distillation_box.pre_process(epoch=epoch)
        distill_one_epoch(distillation_box, device, epoch, log_freq)
        val_top1_accuracy = evaluate(student_model, distillation_box.val_data_loader,
                                     device, device_ids, distributed,
                                     log_freq=log_freq, header='Validation:')
        # The checkpoint is written only when validation accuracy improves.
        if val_top1_accuracy > best_val_top1_accuracy and is_main_process():
            logger.info('Updating ckpt (Best top1 accuracy: '
                        '{:.4f} -> {:.4f})'.format(best_val_top1_accuracy, val_top1_accuracy))
            best_val_top1_accuracy = val_top1_accuracy
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler,
                      best_val_top1_accuracy, config, args, ckpt_file_path)
        distillation_box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))
    distillation_box.clean_modules()
コード例 #9
0
def main(args):
    """Run distillation (unless ``args.test_only``) followed by evaluation on the test split.

    The YAML file at ``args.config`` supplies the ``datasets``, ``models`` and
    ``test`` sections read here. Both ``teacher_model`` and ``student_model``
    entries are required under ``models``. The teacher is evaluated unless
    ``args.student_only`` is set.

    Args:
        args: parsed command-line namespace. Attributes read here: ``log``,
            ``world_size``, ``dist_url``, ``config``, ``device``, ``sync_bn``,
            ``test_only`` and ``student_only``.
    """
    log_path = args.log
    if is_main_process() and log_path is not None:
        setup_log_file(os.path.expanduser(log_path))

    distributed, device_ids = init_distributed_mode(args.world_size, args.dist_url)
    logger.info(args)
    cudnn.benchmark = True
    config = yaml_util.load_yaml_file(os.path.expanduser(args.config))
    device = torch.device(args.device)
    dataset_dict = util.get_all_dataset(config['datasets'])
    models_config = config['models']

    teacher_model_config = models_config['teacher_model']
    teacher_model = get_model(teacher_model_config, device, distributed, False)
    student_model_config = models_config['student_model']
    student_model = get_model(student_model_config, device, distributed, args.sync_bn)

    if not args.test_only:
        distill(teacher_model, student_model, dataset_dict, device, device_ids,
                distributed, config, args)
        # Reload the checkpoint saved during distillation into the unwrapped student.
        bare_student = student_model
        if module_util.check_if_wrapped(student_model):
            bare_student = student_model.module
        load_ckpt(student_model_config['ckpt'], model=bare_student, strict=True)

    test_config = config['test']
    loader_config = test_config['test_data_loader']
    test_dataset = dataset_dict[loader_config['dataset_id']]
    test_data_loader = util.build_data_loader(test_dataset, loader_config, distributed)

    if not args.student_only:
        teacher_title = '[Teacher: {}]'.format(teacher_model_config['name'])
        evaluate(teacher_model, test_data_loader, device, device_ids, distributed,
                 title=teacher_title)

    student_title = '[Student: {}]'.format(student_model_config['name'])
    evaluate(student_model, test_data_loader, device, device_ids, distributed,
             title=student_title)
コード例 #10
0
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device,
          device_ids, distributed, config, args):
    """Train the student, validating each epoch and checkpointing on mIoU improvement.

    A plain training box is used when ``teacher_model`` is None, otherwise a
    distillation box. If a checkpoint already exists at ``ckpt_file_path``,
    the optimizer/scheduler state and the best value so far are restored from it.

    Args:
        teacher_model: teacher model, or None to train without distillation.
        student_model: model being trained (possibly wrapped).
        dataset_dict: datasets handed to the training/distillation box.
        ckpt_file_path: path the best checkpoint is read from and written to.
        device: device used for training and evaluation.
        device_ids: device ids forwarded to the box and to ``evaluate``.
        distributed: whether the run is distributed.
        config: full configuration; its ``train`` section is read here.
        args: argument namespace; ``world_size``, ``adjust_lr``,
            ``start_epoch`` and ``num_classes`` are read here.
    """
    logger.info('Start training')
    train_config = config['train']

    lr_factor = 1
    if distributed and args.adjust_lr:
        lr_factor = args.world_size

    if teacher_model is None:
        training_box = get_training_box(student_model, dataset_dict, train_config,
                                        device, device_ids, distributed, lr_factor)
    else:
        training_box = get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                                            device, device_ids, distributed, lr_factor)

    best_val_miou = 0.0
    optimizer = training_box.optimizer
    lr_scheduler = training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_miou, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    student_model_without_ddp = student_model
    if module_util.check_if_wrapped(student_model):
        student_model_without_ddp = student_model.module

    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        val_seg_evaluator = evaluate(student_model, training_box.val_data_loader,
                                     device, device_ids, distributed,
                                     num_classes=args.num_classes, log_freq=log_freq, header='Validation:')

        # Only the IoU values are used for model selection; their mean is the mIoU.
        val_acc_global, val_acc, val_iou = val_seg_evaluator.compute()
        val_miou = val_iou.mean().item()
        if val_miou > best_val_miou and is_main_process():
            logger.info('Updating ckpt (Best mIoU: {:.4f} -> {:.4f})'.format(best_val_miou, val_miou))
            best_val_miou = val_miou
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler,
                      best_val_miou, config, args, ckpt_file_path)
        training_box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))
    training_box.clean_modules()
コード例 #11
0
ファイル: util.py プロジェクト: wuxiaolianggit/torchdistill
def save_module_ckpt(module, ckpt_file_path):
    """Save the state dict of ``module`` (unwrapped if needed) to ``ckpt_file_path``.

    Parent directories are created on the main process only; the write itself
    is delegated to ``save_on_master``.

    Args:
        module: module to save; if wrapped, the state of its ``.module`` is stored.
        ckpt_file_path: destination checkpoint path.
    """
    if is_main_process():
        make_parent_dirs(ckpt_file_path)

    bare_module = module.module if check_if_wrapped(module) else module
    save_on_master(bare_module.state_dict(), ckpt_file_path)
コード例 #12
0
ファイル: util.py プロジェクト: wuxiaolianggit/torchdistill
def load_module_ckpt(module, map_location, ckpt_file_path):
    """Load a state dict from ``ckpt_file_path`` into ``module`` (unwrapped if needed).

    Args:
        module: module receiving the state; if wrapped, its ``.module`` is loaded into.
        map_location: forwarded to ``torch.load`` as ``map_location``.
        ckpt_file_path: path of the checkpoint file to read.
    """
    loaded_state = torch.load(ckpt_file_path, map_location=map_location)
    target = module.module if check_if_wrapped(module) else module
    target.load_state_dict(loaded_state)