def main(args):
    config = yaml_util.load_yaml_file(args.config)
    if args.json is not None:
        main_util.overwrite_config(config, args.json)

    distributed, device_ids = main_util.init_distributed_mode(args.world_size, args.dist_url)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Teacher is frozen; only the non-frozen parts of the student are updated
    teacher_model = get_model(config['teacher_model'], device)
    module_util.freeze_module_params(teacher_model)
    student_model_config = config['student_model']
    student_model = get_model(student_model_config, device)
    freeze_modules(student_model, student_model_config)
    print('Updatable parameters: {}'.format(module_util.get_updatable_param_names(student_model)))

    distill_backbone_only = student_model_config['distill_backbone_only']
    train_config = config['train']
    train_sampler, train_data_loader, val_data_loader, test_data_loader = \
        data_util.get_coco_data_loaders(config['dataset'], train_config['batch_size'], distributed)
    if distributed:
        teacher_model = DataParallel(teacher_model, device_ids=device_ids)
        student_model = DistributedDataParallel(student_model, device_ids=device_ids)

    if args.distill:
        distill(teacher_model, student_model, train_sampler, train_data_loader, val_data_loader, device,
                distributed, distill_backbone_only, config, args)

    # Load the (best) student checkpoint before the final evaluation
    load_ckpt(config['student_model']['ckpt'],
              model=student_model.module if isinstance(student_model, DistributedDataParallel) else student_model)
    evaluate(teacher_model, student_model, test_data_loader, device,
             args.skip_teacher_eval, args.transform_bottleneck)
def train(train_loader, valid_loader, input_shape, config, device, distributed, device_ids):
    # Train the autoencoder while the (pre-trained) head model stays frozen
    ae_without_ddp, ae_type = ae_util.get_autoencoder(config, device)
    head_model = ae_util.get_head_model(config, input_shape, device)
    module_util.freeze_module_params(head_model)
    ckpt_file_path = config['autoencoder']['ckpt']
    start_epoch, best_valid_acc = resume_from_ckpt(ckpt_file_path, ae_without_ddp)
    if best_valid_acc is None:
        best_valid_acc = 0.0

    train_config = config['train']
    criterion_config = train_config['criterion']
    criterion = func_util.get_loss(criterion_config['type'], criterion_config['params'])
    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(ae_without_ddp, optim_config['type'], optim_config['params'])
    scheduler_config = train_config['scheduler']
    scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'], scheduler_config['params'])
    interval = train_config['interval']
    if interval <= 0:
        num_batches = len(train_loader)
        interval = num_batches // 20 if num_batches >= 20 else 1

    autoencoder = ae_without_ddp
    if distributed:
        autoencoder = DistributedDataParallel(ae_without_ddp, device_ids=device_ids)
        head_model = DataParallel(head_model, device_ids=device_ids)
    elif device.type == 'cuda':
        autoencoder = DataParallel(ae_without_ddp)
        head_model = DataParallel(head_model)

    end_epoch = start_epoch + train_config['epoch']
    start_time = time.time()
    for epoch in range(start_epoch, end_epoch):
        if distributed:
            train_loader.sampler.set_epoch(epoch)

        train_epoch(autoencoder, head_model, train_loader, optimizer, criterion, epoch, device, interval)
        valid_acc = validate(ae_without_ddp, valid_loader, config, device, distributed, device_ids)
        if valid_acc > best_valid_acc and main_util.is_main_process():
            print('Updating ckpt (Best top1 accuracy: {:.4f} -> {:.4f})'.format(best_valid_acc, valid_acc))
            best_valid_acc = valid_acc
            save_ckpt(ae_without_ddp, epoch, best_valid_acc, ckpt_file_path, ae_type)

        scheduler.step()
        if distributed:
            # Synchronize ranks at the end of each epoch; calling barrier() without an
            # initialized process group would raise in single-process runs
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    del head_model
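# For reference, a hypothetical 'train' section of the config read by train() above
# (and by distill() below). Only the key layout is taken from the code; the type
# names and parameter values are placeholder assumptions and must match whatever
# func_util.get_loss/get_optimizer/get_scheduler actually accept.
train_config_example = {
    'batch_size': 32,
    'epoch': 20,        # number of epochs to run, counted from the resumed epoch
    'interval': -1,     # <= 0 means "log roughly 20 times per epoch"
    'criterion': {'type': 'CrossEntropyLoss', 'params': {}},
    'optimizer': {'type': 'SGD', 'params': {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 5e-4}},
    'scheduler': {'type': 'MultiStepLR', 'params': {'milestones': [10, 15], 'gamma': 0.1}},
}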
def freeze_modules(student_model, student_model_config):
    if 'frozen_modules' in student_model_config:
        for student_path in student_model_config['frozen_modules']:
            student_module = module_util.get_module(student_model, student_path)
            module_util.freeze_module_params(student_module)
    elif 'unfrozen_modules' in student_model_config:
        module_util.freeze_module_params(student_model)
        for student_path in student_model_config['unfrozen_modules']:
            student_module = module_util.get_module(student_model, student_path)
            module_util.unfreeze_module_params(student_module)
def freeze_modules(student_model, student_model_config, reset_unfrozen=False):
    if 'frozen_modules' in student_model_config:
        for student_path in student_model_config['frozen_modules']:
            student_module = module_util.get_module(student_model, student_path)
            module_util.freeze_module_params(student_module)
    elif 'unfrozen_modules' in student_model_config:
        module_util.freeze_module_params(student_model)
        for student_path in student_model_config['unfrozen_modules']:
            student_module = module_util.get_module(student_model, student_path)
            module_util.unfreeze_module_params(student_module)
            if reset_unfrozen:
                print("Reinitializing module: {}".format(student_path))
                init_weights(student_module)
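# Usage sketch for freeze_modules(); the ResNet and module paths below are
# illustrative only, and module_util.get_module is assumed to resolve dotted
# attribute paths such as 'layer4.1.conv2'. Pass reset_unfrozen=True to also
# re-initialize the modules that remain trainable.
from torchvision import models

student_model = models.resnet18()
student_model_config = {
    # freeze everything, then re-enable the last residual block and the classifier
    'unfrozen_modules': ['layer4', 'fc'],
}
freeze_modules(student_model, student_model_config)
print([name for name, p in student_model.named_parameters() if p.requires_grad])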
def __init__(self, backbone, return_layers, in_channels_list, out_channels, ext_config):
    super().__init__()
    if ext_config.get('backbone_frozen', False):
        module_util.freeze_module_params(backbone)

    self.body = ExtIntermediateLayerGetter(backbone, return_layers=return_layers, ext_config=ext_config)
    self.fpn = FeaturePyramidNetwork(
        in_channels_list=in_channels_list,
        out_channels=out_channels,
        extra_blocks=LastLevelMaxPool(),
    )
    self.out_channels = out_channels
    self.split = False
def main(args):
    distributed, device_ids = main_util.init_distributed_mode(args.world_size, args.dist_url)
    config = yaml_util.load_yaml_file(args.config)
    if args.json is not None:
        main_util.overwrite_config(config, args.json)

    device = torch.device(args.device)
    print(args)
    print('Loading data')
    train_config = config['train']
    train_sampler, train_data_loader, val_data_loader, test_data_loader = \
        data_util.get_coco_data_loaders(config['dataset'], train_config['batch_size'], distributed)

    print('Creating model')
    model_config = config['model']
    model = get_model(model_config, device, strict=False)
    # Freeze the whole model, then unfreeze only the external classifier head
    module_util.freeze_module_params(model)
    ext_classifier = model.get_ext_classifier()
    module_util.unfreeze_module_params(ext_classifier)
    print('Updatable parameters: {}'.format(module_util.get_updatable_param_names(model)))
    model.train_ext()
    if distributed:
        model = nn.parallel.DistributedDataParallel(model, device_ids=device_ids)

    if args.train:
        print('Start training')
        start_time = time.time()
        ckpt_file_path = model_config['backbone']['ext_config']['ckpt']
        train(model, ext_classifier, train_sampler, train_data_loader, val_data_loader, device,
              distributed, config, args, ckpt_file_path)
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))
        load_ckpt(ckpt_file_path, model=ext_classifier)

    evaluate(model, test_data_loader, device=device, min_recall=args.min_recall, split_name='Test')
def experiment(config, device, args):
    # Load models
    teacher_model = get_model(config['teacher_model'], device)
    module_util.freeze_module_params(teacher_model)
    student_model_config = config['student_model']
    student_model = get_model(student_model_config, device, strict=True, require_weights=True)
    set_bottleneck_transformer(student_model, True)
    if args.dry_run:
        return None

    # Load dataset
    train_config = config['train']
    train_sampler, train_data_loader, val_data_loader, test_data_loader = \
        data_util.get_coco_data_loaders(config['dataset'], train_config['batch_size'], distributed=False)

    # Prepare dataframe
    header = ["Setting", "Iou50:95", "Iou50"]
    df = pd.DataFrame(columns=header)
    res = evaluate(teacher_model, test_data_loader, device)
    df = df.append({"Setting": "Teacher", "Iou50:95": res[0], "Iou50": res[1]}, ignore_index=True)

    post_bn = False
    if 'post_batch_norm' in config['train']:
        post_bn = config['train']['post_batch_norm']

    # Evaluate the student at every configured width multiplier (slimmable backbone)
    width_mult_list = [1.0]
    if "slimmable" in student_model_config['backbone']['params']:
        width_mult_list = student_model_config['backbone']['params']['width_mult_list']

    for width in width_mult_list:
        print('\n[Student model@width={}]'.format(width))
        set_width(student_model, width)
        if post_bn:
            ComputeBN(student_model, train_data_loader)
        res = evaluate(student_model, test_data_loader, device)
        df = df.append({"Setting": str(width), "Iou50:95": res[0], "Iou50": res[1]}, ignore_index=True)
    return df
def main(args):
    config = yaml_util.load_yaml_file(args.config)
    if args.json is not None:
        main_util.overwrite_config(config, args.json)

    distributed, device_ids = main_util.init_distributed_mode(args.world_size, args.dist_url)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    teacher_model = get_model(config['teacher_model'], device)
    module_util.freeze_module_params(teacher_model)
    student_model_config = config['student_model']
    student_model = get_model(student_model_config, device)
    freeze_modules(student_model, student_model_config)
    ckpt_file_path = config['student_model']['ckpt']

    train_config = config['train']
    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(student_model, optim_config['type'], optim_config['params'])
    scheduler_config = train_config['scheduler']
    lr_scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'], scheduler_config['params'])

    # Default when no checkpoint exists yet; otherwise restore the best validation mAP
    # along with the optimizer and scheduler states
    best_val_map = 0.0
    if file_util.check_if_exists(ckpt_file_path):
        best_val_map, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)
    save_ckpt(student_model, optimizer, lr_scheduler, best_val_map, config, args, ckpt_file_path)
def distill(train_loader, valid_loader, input_shape, aux_weight, config, device, distributed, device_ids):
    # Teacher is frozen; the student is trained to mimic it
    teacher_model_config = config['teacher_model']
    teacher_model, teacher_model_type = mimic_util.get_teacher_model(teacher_model_config, input_shape, device)
    module_util.freeze_module_params(teacher_model)
    student_model_config = config['student_model']
    student_model = mimic_util.get_student_model(teacher_model_type, student_model_config, config['dataset']['name'])
    student_model = student_model.to(device)
    start_epoch, best_valid_acc = mimic_util.resume_from_ckpt(student_model_config['ckpt'], student_model,
                                                              is_student=True)
    if best_valid_acc is None:
        best_valid_acc = 0.0

    train_config = config['train']
    criterion_config = train_config['criterion']
    criterion = func_util.get_loss(criterion_config['type'], criterion_config['params'])
    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(student_model, optim_config['type'], optim_config['params'])
    scheduler_config = train_config['scheduler']
    scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'], scheduler_config['params'])
    interval = train_config['interval']
    if interval <= 0:
        num_batches = len(train_loader)
        interval = num_batches // 20 if num_batches >= 20 else 1

    student_model_without_ddp = student_model
    if distributed:
        teacher_model = DataParallel(teacher_model, device_ids=device_ids)
        student_model = DistributedDataParallel(student_model, device_ids=device_ids)
        student_model_without_ddp = student_model.module

    ckpt_file_path = student_model_config['ckpt']
    end_epoch = start_epoch + train_config['epoch']
    start_time = time.time()
    for epoch in range(start_epoch, end_epoch):
        if distributed:
            train_loader.sampler.set_epoch(epoch)

        distill_one_epoch(student_model, teacher_model, train_loader, optimizer, criterion,
                          epoch, device, interval, aux_weight)
        valid_acc = validate(student_model, valid_loader, config, device, distributed, device_ids)
        if valid_acc > best_valid_acc and main_util.is_main_process():
            print('Updating ckpt (Best top1 accuracy: {:.4f} -> {:.4f})'.format(best_valid_acc, valid_acc))
            best_valid_acc = valid_acc
            save_ckpt(student_model_without_ddp, epoch, best_valid_acc, ckpt_file_path, teacher_model_type)

        scheduler.step()
        if distributed:
            # Synchronize ranks at the end of each epoch
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    del teacher_model
    del student_model
def freeze_modules(student_model, student_model_config):
    for student_path in student_model_config['frozen_modules']:
        student_module = module_util.get_module(student_model, student_path)
        module_util.freeze_module_params(student_module)
def main(args):
    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError('Apex currently only supports Python 3. Aborting.')
        if amp is None:
            raise RuntimeError('Failed to import apex. Please install apex from https://www.github.com/nvidia/apex '
                               'to enable mixed-precision training.')

    distributed, device_ids = main_util.init_distributed_mode(args.world_size, args.dist_url)
    print(args)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    config = yaml_util.load_yaml_file(args.config)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    dataset_config = config['dataset']
    input_shape = config['input_shape']
    train_config = config['train']
    test_config = config['test']
    train_data_loader, val_data_loader, test_data_loader = \
        dataset_util.get_data_loaders(dataset_config, batch_size=train_config['batch_size'],
                                      rough_size=train_config['rough_size'], reshape_size=input_shape[1:3],
                                      jpeg_quality=-1, test_batch_size=test_config['batch_size'],
                                      distributed=distributed)

    teacher_model_config = config['teacher_model']
    teacher_model, teacher_model_type = mimic_util.get_org_model(teacher_model_config, device)
    module_util.freeze_module_params(teacher_model)
    student_model = mimic_util.get_mimic_model_easily(config, device)
    student_model_config = config['mimic_model']

    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(student_model, optim_config['type'], optim_config['params'])
    use_apex = args.apex
    if use_apex:
        student_model, optimizer = amp.initialize(student_model, optimizer, opt_level=args.apex_opt_level)

    if distributed:
        teacher_model = DataParallel(teacher_model, device_ids=device_ids)
        student_model = DistributedDataParallel(student_model, device_ids=device_ids)

    start_epoch = args.start_epoch
    if not args.test_only:
        distill(teacher_model, student_model, train_data_loader, val_data_loader, device,
                distributed, start_epoch, config, args)

    # Load the distilled student checkpoint before evaluation
    student_model_without_ddp = \
        student_model.module if isinstance(student_model, DistributedDataParallel) else student_model
    load_ckpt(student_model_config['ckpt'], model=student_model_without_ddp, strict=True)
    if not args.student_only:
        evaluate(teacher_model, test_data_loader, device, title='[Teacher: {}]'.format(teacher_model_type))
    evaluate(student_model, test_data_loader, device, title='[Student: {}]'.format(student_model_config['type']))
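# Hypothetical argument parser covering the attributes read by main() above
# (args.config, args.device, args.world_size, args.dist_url, args.apex,
# args.apex_opt_level, args.start_epoch, args.test_only, args.student_only).
# Defaults and help strings are assumptions, not taken from the original script.
import argparse


def get_argparser():
    parser = argparse.ArgumentParser(description='Mimic distillation runner (sketch)')
    parser.add_argument('--config', required=True, help='YAML config file path')
    parser.add_argument('--device', default='cuda', help='device used for training and testing')
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='URL used to set up distributed training')
    parser.add_argument('--apex', action='store_true', help='use apex mixed-precision training')
    parser.add_argument('--apex_opt_level', default='O1', help='apex opt level: O0, O1, O2, or O3')
    parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
    parser.add_argument('--test_only', action='store_true', help='skip training and only evaluate')
    parser.add_argument('--student_only', action='store_true', help='evaluate the student model only')
    return parser


if __name__ == '__main__':
    main(get_argparser().parse_args())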