def setup_teacher_student_models(self, teacher_config, student_config):
    unwrapped_org_teacher_model = \
        self.org_teacher_model.module if check_if_wrapped(self.org_teacher_model) else self.org_teacher_model
    unwrapped_org_student_model = \
        self.org_student_model.module if check_if_wrapped(self.org_student_model) else self.org_student_model
    self.target_teacher_pairs.clear()
    self.target_student_pairs.clear()
    teacher_ref_model = unwrapped_org_teacher_model
    student_ref_model = unwrapped_org_student_model
    # Redesign the teacher model if a teacher config is given or no teacher model has been built yet
    if len(teacher_config) > 0 or (len(teacher_config) == 0 and self.teacher_model is None):
        model_type = 'original'
        special_teacher_model = \
            build_special_module(teacher_config, teacher_model=unwrapped_org_teacher_model, device=self.device,
                                 device_ids=self.device_ids, distributed=self.distributed)
        if special_teacher_model is not None:
            teacher_ref_model = special_teacher_model
            model_type = type(teacher_ref_model).__name__
        self.teacher_model = redesign_model(teacher_ref_model, teacher_config, 'teacher', model_type)

    # Redesign the student model likewise
    if len(student_config) > 0 or (len(student_config) == 0 and self.student_model is None):
        model_type = 'original'
        special_student_model = \
            build_special_module(student_config, student_model=unwrapped_org_student_model, device=self.device,
                                 device_ids=self.device_ids, distributed=self.distributed)
        if special_student_model is not None:
            student_ref_model = special_student_model
            model_type = type(student_ref_model).__name__
        self.student_model = redesign_model(student_ref_model, student_config, 'student', model_type)

    self.teacher_any_frozen = \
        len(teacher_config.get('frozen_modules', list())) > 0 or not teacher_config.get('requires_grad', True)
    self.student_any_frozen = \
        len(student_config.get('frozen_modules', list())) > 0 or not student_config.get('requires_grad', True)
    self.target_teacher_pairs.extend(
        set_hooks(self.teacher_model, teacher_ref_model, teacher_config, self.teacher_io_dict))
    self.target_student_pairs.extend(
        set_hooks(self.student_model, student_ref_model, student_config, self.student_io_dict))
    self.teacher_forward_proc = get_forward_proc_func(teacher_config.get('forward_proc', None))
    self.student_forward_proc = get_forward_proc_func(student_config.get('forward_proc', None))
def main(args):
    log_file_path = args.log
    if is_main_process() and log_file_path is not None:
        setup_log_file(os.path.expanduser(log_file_path))

    distributed, device_ids = init_distributed_mode(args.world_size, args.dist_url)
    logger.info(args)
    cudnn.benchmark = True
    set_seed(args.seed)
    config = yaml_util.load_yaml_file(os.path.expanduser(args.config))
    device = torch.device(args.device)
    dataset_dict = util.get_all_datasets(config['datasets'])
    models_config = config['models']
    teacher_model_config = models_config.get('teacher_model', None)
    teacher_model = load_model(teacher_model_config, device) if teacher_model_config is not None else None
    student_model_config = \
        models_config['student_model'] if 'student_model' in models_config else models_config['model']
    ckpt_file_path = student_model_config['ckpt']
    student_model = load_model(student_model_config, device)
    if args.log_config:
        logger.info(config)

    if not args.test_only:
        train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed,
              config, args)
        student_model_without_ddp = \
            student_model.module if module_util.check_if_wrapped(student_model) else student_model
        load_ckpt(student_model_config['ckpt'], model=student_model_without_ddp, strict=True)

    test_config = config['test']
    test_data_loader_config = test_config['test_data_loader']
    test_data_loader = util.build_data_loader(dataset_dict[test_data_loader_config['dataset_id']],
                                              test_data_loader_config, distributed)
    log_freq = test_config.get('log_freq', 1000)
    iou_types = args.iou_types
    if not args.student_only and teacher_model is not None:
        evaluate(teacher_model, test_data_loader, iou_types, device, device_ids, distributed, log_freq=log_freq,
                 title='[Teacher: {}]'.format(teacher_model_config['name']))
    evaluate(student_model, test_data_loader, iou_types, device, device_ids, distributed, log_freq=log_freq,
             title='[Student: {}]'.format(student_model_config['name']))
def wrap_model(model, model_config, device, device_ids=None, distributed=False, any_frozen=False):
    wrapper = model_config.get('wrapper', None) if model_config is not None else None
    model.to(device)
    if wrapper is not None and device.type.startswith('cuda') and not check_if_wrapped(model):
        if wrapper == 'DistributedDataParallel' and distributed:
            model = DistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=any_frozen)
        elif wrapper in {'DataParallel', 'DistributedDataParallel'}:
            # Fall back to DataParallel when not running in distributed mode
            model = DataParallel(model, device_ids=device_ids)
    return model
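# Illustrative call to wrap_model above: a minimal sketch assuming a single-node CUDA setup.
# The config dict and ResNet-18 model are hypothetical placeholders, not taken from an actual
# config file in this repository.
import torch
from torchvision.models import resnet18

example_model_config = {'wrapper': 'DataParallel'}
example_device = torch.device('cuda')
wrapped_model = wrap_model(resnet18(), example_model_config, example_device,
                           device_ids=None, distributed=False, any_frozen=False)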
def add_hook(self, module, module_path, requires_input=True, requires_output=True):
    unwrapped_module = module.module if check_if_wrapped(module) else module
    sub_module = get_module(unwrapped_module, module_path)
    handle = \
        register_forward_hook_with_dict(sub_module, module_path, requires_input, requires_output, self.io_dict)
    self.hook_list.append((module_path, handle))
def get_teacher_output(self, sample_batch, targets, supp_dict):
    cached_data = supp_dict.get('cached_data', None)
    cache_file_paths = supp_dict.get('cache_file_path', None)
    teacher_outputs = None
    cached_extracted_teacher_output_dict = None
    # Use cached data if available
    if cached_data is not None and isinstance(cached_data, dict):
        device = sample_batch.device
        teacher_outputs = cached_data['teacher_outputs']
        cached_extracted_teacher_output_dict = cached_data['extracted_outputs']
        if device.type != 'cpu':
            teacher_outputs = change_device(teacher_outputs, device)
            cached_extracted_teacher_output_dict = change_device(cached_extracted_teacher_output_dict, device)
        if not self.teacher_updatable:
            return teacher_outputs, cached_extracted_teacher_output_dict

    if teacher_outputs is None:
        if self.teacher_updatable:
            teacher_outputs = self.teacher_forward_proc(self.teacher_model, sample_batch, targets, supp_dict)
        else:
            with torch.no_grad():
                teacher_outputs = self.teacher_forward_proc(self.teacher_model, sample_batch, targets, supp_dict)

    if cached_extracted_teacher_output_dict is not None:
        if isinstance(self.teacher_model, SpecialModule) or \
                (check_if_wrapped(self.teacher_model) and isinstance(self.teacher_model.module, SpecialModule)):
            self.teacher_io_dict.update(cached_extracted_teacher_output_dict)
            if isinstance(self.teacher_model, SpecialModule):
                self.teacher_model.post_forward(self.teacher_io_dict)

        extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
        return teacher_outputs, extracted_teacher_io_dict

    # Deep copy of teacher info dict if teacher special module contains trainable module(s)
    teacher_io_dict4cache = copy.deepcopy(self.teacher_io_dict) \
        if self.teacher_updatable and isinstance(cache_file_paths, (list, tuple)) else None
    extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
    if isinstance(self.teacher_model, SpecialModule):
        self.teacher_model.post_forward(extracted_teacher_io_dict)
        update_io_dict(extracted_teacher_io_dict, extract_io_dict(self.teacher_io_dict, self.device))

    # Write cache files if output file paths (cache_file_paths) are given
    if isinstance(cache_file_paths, (list, tuple)):
        if teacher_io_dict4cache is None:
            teacher_io_dict4cache = extracted_teacher_io_dict

        cpu_device = torch.device('cpu')
        for i, (teacher_output, cache_file_path) in enumerate(zip(teacher_outputs.cpu().numpy(), cache_file_paths)):
            sub_dict = extract_sub_model_output_dict(teacher_io_dict4cache, i)
            sub_dict = tensor2numpy2tensor(sub_dict, cpu_device)
            cache_dict = {'teacher_outputs': torch.Tensor(teacher_output), 'extracted_outputs': sub_dict}
            make_parent_dirs(cache_file_path)
            torch.save(cache_dict, cache_file_path)
    return teacher_outputs, extracted_teacher_io_dict
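# A minimal sketch of the cache contract used by get_teacher_output above: each cache file stores a
# dict with 'teacher_outputs' and 'extracted_outputs' keys, and the data pipeline can hand it back
# via supp_dict['cached_data'] to skip the teacher forward pass. The file path below is hypothetical.
import torch

example_cache_file_path = './teacher_cache/sample_0.pt'
example_cache_dict = torch.load(example_cache_file_path, map_location='cpu')
example_supp_dict = {
    'cached_data': example_cache_dict,   # {'teacher_outputs': ..., 'extracted_outputs': ...}
    'cache_file_path': None              # no cache files need to be (re)written in this case
}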
def save_ckpt(model, optimizer, lr_scheduler, best_value, config, args, output_file_path):
    make_parent_dirs(output_file_path)
    model_state_dict = model.module.state_dict() if check_if_wrapped(model) else model.state_dict()
    lr_scheduler_state_dict = lr_scheduler.state_dict() if lr_scheduler is not None else None
    save_on_master({'model': model_state_dict, 'optimizer': optimizer.state_dict(), 'best_value': best_value,
                    'lr_scheduler': lr_scheduler_state_dict, 'config': config, 'args': args}, output_file_path)
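# A minimal sketch of reading back a checkpoint written by save_ckpt above; the file path is
# hypothetical, and the keys mirror the dict passed to save_on_master.
import torch

example_ckpt = torch.load('./ckpt/student_model.pt', map_location='cpu')
print(sorted(example_ckpt.keys()))  # ['args', 'best_value', 'config', 'lr_scheduler', 'model', 'optimizer']
student_state_dict = example_ckpt['model']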
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args):
    logger.info('Start training')
    train_config = config['train']
    lr_factor = args.world_size if distributed and args.adjust_lr else 1
    training_box = get_training_box(student_model, dataset_dict, train_config,
                                    device, device_ids, distributed, lr_factor) if teacher_model is None \
        else get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                                  device, device_ids, distributed, lr_factor)
    best_val_map = 0.0
    optimizer, lr_scheduler = training_box.optimizer, training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_map, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    iou_types = args.iou_types
    val_iou_type = iou_types[0] if isinstance(iou_types, (list, tuple)) and len(iou_types) > 0 else 'bbox'
    student_model_without_ddp = student_model.module if module_util.check_if_wrapped(student_model) else student_model
    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        val_coco_evaluator = \
            evaluate(student_model, training_box.val_data_loader, iou_types, device, device_ids, distributed,
                     log_freq=log_freq, header='Validation:')
        # Average Precision (AP) @[ IoU=0.50:0.95 | area=all | maxDets=100 ]
        val_map = val_coco_evaluator.coco_eval[val_iou_type].stats[0]
        if val_map > best_val_map and is_main_process():
            logger.info('Best mAP ({}): {:.4f} -> {:.4f}'.format(val_iou_type, best_val_map, val_map))
            logger.info('Updating ckpt at {}'.format(ckpt_file_path))
            best_val_map = val_map
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler, best_val_map, config, args, ckpt_file_path)
        training_box.post_process()

    if distributed:
        dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    training_box.clean_modules()
def distill(teacher_model, student_model, dataset_dict, device, device_ids, distributed, config, args):
    logger.info('Start distillation')
    train_config = config['train']
    lr_factor = args.world_size if distributed and args.adjust_lr else 1
    distillation_box = \
        get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                             device, device_ids, distributed, lr_factor)
    ckpt_file_path = config['models']['student_model']['ckpt']
    best_val_top1_accuracy = 0.0
    optimizer, lr_scheduler = distillation_box.optimizer, distillation_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_top1_accuracy, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    student_model_without_ddp = student_model.module if module_util.check_if_wrapped(student_model) else student_model
    start_time = time.time()
    for epoch in range(args.start_epoch, distillation_box.num_epochs):
        distillation_box.pre_process(epoch=epoch)
        distill_one_epoch(distillation_box, device, epoch, log_freq)
        val_top1_accuracy = evaluate(student_model, distillation_box.val_data_loader, device, device_ids,
                                     distributed, log_freq=log_freq, header='Validation:')
        if val_top1_accuracy > best_val_top1_accuracy and is_main_process():
            logger.info('Updating ckpt (Best top1 accuracy: '
                        '{:.4f} -> {:.4f})'.format(best_val_top1_accuracy, val_top1_accuracy))
            best_val_top1_accuracy = val_top1_accuracy
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler, best_val_top1_accuracy,
                      config, args, ckpt_file_path)
        distillation_box.post_process()

    if distributed:
        dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    distillation_box.clean_modules()
def main(args):
    log_file_path = args.log
    if is_main_process() and log_file_path is not None:
        setup_log_file(os.path.expanduser(log_file_path))

    distributed, device_ids = init_distributed_mode(args.world_size, args.dist_url)
    logger.info(args)
    cudnn.benchmark = True
    config = yaml_util.load_yaml_file(os.path.expanduser(args.config))
    device = torch.device(args.device)
    dataset_dict = util.get_all_dataset(config['datasets'])
    models_config = config['models']
    teacher_model_config = models_config['teacher_model']
    teacher_model = get_model(teacher_model_config, device, distributed, False)
    student_model_config = models_config['student_model']
    student_model = get_model(student_model_config, device, distributed, args.sync_bn)
    if not args.test_only:
        distill(teacher_model, student_model, dataset_dict, device, device_ids, distributed, config, args)
        student_model_without_ddp = \
            student_model.module if module_util.check_if_wrapped(student_model) else student_model
        load_ckpt(student_model_config['ckpt'], model=student_model_without_ddp, strict=True)

    test_config = config['test']
    test_data_loader_config = test_config['test_data_loader']
    test_data_loader = util.build_data_loader(dataset_dict[test_data_loader_config['dataset_id']],
                                              test_data_loader_config, distributed)
    if not args.student_only:
        evaluate(teacher_model, test_data_loader, device, device_ids, distributed,
                 title='[Teacher: {}]'.format(teacher_model_config['name']))
    evaluate(student_model, test_data_loader, device, device_ids, distributed,
             title='[Student: {}]'.format(student_model_config['name']))
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args):
    logger.info('Start training')
    train_config = config['train']
    lr_factor = args.world_size if distributed and args.adjust_lr else 1
    training_box = get_training_box(student_model, dataset_dict, train_config,
                                    device, device_ids, distributed, lr_factor) if teacher_model is None \
        else get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                                  device, device_ids, distributed, lr_factor)
    best_val_miou = 0.0
    optimizer, lr_scheduler = training_box.optimizer, training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_miou, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    student_model_without_ddp = student_model.module if module_util.check_if_wrapped(student_model) else student_model
    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        val_seg_evaluator = \
            evaluate(student_model, training_box.val_data_loader, device, device_ids, distributed,
                     num_classes=args.num_classes, log_freq=log_freq, header='Validation:')
        val_acc_global, val_acc, val_iou = val_seg_evaluator.compute()
        val_miou = val_iou.mean().item()
        if val_miou > best_val_miou and is_main_process():
            logger.info('Updating ckpt (Best mIoU: {:.4f} -> {:.4f})'.format(best_val_miou, val_miou))
            best_val_miou = val_miou
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler, best_val_miou, config, args, ckpt_file_path)
        training_box.post_process()

    if distributed:
        dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    training_box.clean_modules()
def save_module_ckpt(module, ckpt_file_path):
    if is_main_process():
        make_parent_dirs(ckpt_file_path)
    state_dict = module.module.state_dict() if check_if_wrapped(module) else module.state_dict()
    save_on_master(state_dict, ckpt_file_path)
def load_module_ckpt(module, map_location, ckpt_file_path):
    state_dict = torch.load(ckpt_file_path, map_location=map_location)
    if check_if_wrapped(module):
        module.module.load_state_dict(state_dict)
    else:
        module.load_state_dict(state_dict)
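# An illustrative round trip for save_module_ckpt and load_module_ckpt above; the auxiliary module
# and checkpoint path are hypothetical placeholders.
import torch
from torch import nn

example_aux_module = nn.Linear(256, 10)
example_ckpt_path = './ckpt/aux_module.pt'
save_module_ckpt(example_aux_module, example_ckpt_path)
load_module_ckpt(example_aux_module, torch.device('cpu'), example_ckpt_path)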