def main(args):
    """Run the training/evaluation pipeline configured by ``args``.

    Steps, in order: initialise (possibly distributed) mode, load the YAML config
    (optionally overridden by ``args.json``), build the data loaders and the model,
    wrap the model in DistributedDataParallel when distributed, train when
    ``args.train`` is set, and finally evaluate on the test loader.

    Args:
        args: Parsed command-line namespace; reads ``world_size``, ``dist_url``,
            ``config``, ``json``, ``device`` and ``train`` here, and is forwarded
            to ``train``.
    """
    distributed, device_ids = main_util.init_distributed_mode(args.world_size, args.dist_url)
    config = yaml_util.load_yaml_file(args.config)
    if args.json is not None:
        # JSON overrides are applied on top of the YAML config.
        main_util.overwrite_config(config, args.json)

    device = torch.device(args.device)
    print(args)

    print('Loading data')
    batch_size = config['train']['batch_size']
    loaders = data_util.get_coco_data_loaders(config['dataset'], batch_size, distributed)
    train_sampler, train_data_loader, val_data_loader, test_data_loader = loaders

    print('Creating model')
    model_config = config['model']
    model = get_model(model_config, device)
    print('Model Created')
    if distributed:
        model = nn.parallel.DistributedDataParallel(model, device_ids=device_ids)

    if args.train:
        print('Start training')
        began = time.time()
        train(model, train_sampler, train_data_loader, val_data_loader, device, distributed,
              config, args, model_config['ckpt'])
        elapsed = datetime.timedelta(seconds=int(time.time() - began))
        print('Training time {}'.format(str(elapsed)))

    main_util.evaluate(model, test_data_loader, device=device)
def evaluate(teacher_model, student_model, test_data_loader, device, student_only, use_bottleneck_transformer):
    """Evaluate the teacher (unless ``student_only``) and then the student on ``test_data_loader``.

    Before evaluating, both models have ``distill_backbone_only`` switched off, and the
    student's ``backbone.body.layer1.use_bottleneck_transformer`` flag is set to
    ``use_bottleneck_transformer``.

    Args:
        teacher_model: Teacher model, possibly wrapped in ``DataParallel``.
        student_model: Student model, possibly wrapped in ``DistributedDataParallel``.
        test_data_loader: Loader passed to ``main_util.evaluate``.
        device: Device passed to ``main_util.evaluate``.
        student_only: When true, the teacher is not evaluated.
        use_bottleneck_transformer: Value assigned to the student's layer1 flag.
    """
    # Flags live on the underlying modules, not on the parallel wrappers.
    if isinstance(teacher_model, DataParallel):
        bare_teacher = teacher_model.module
    else:
        bare_teacher = teacher_model
    if isinstance(student_model, DistributedDataParallel):
        bare_student = student_model.module
    else:
        bare_student = student_model

    bare_teacher.distill_backbone_only = False
    bare_student.distill_backbone_only = False
    bare_student.backbone.body.layer1.use_bottleneck_transformer = use_bottleneck_transformer

    if not student_only:
        print('[Teacher model]')
        main_util.evaluate(teacher_model, test_data_loader, device=device)
    print('\n[Student model]')
    main_util.evaluate(student_model, test_data_loader, device=device)
def train(model, train_sampler, train_data_loader, val_data_loader, device, distributed, config, args,
          ckpt_file_path):
    """Train ``model`` and keep the checkpoint with the best validation bbox mAP.

    The optimizer and LR scheduler are built from ``config['train']``. If
    ``ckpt_file_path`` exists, their state and the best mAP recorded so far are
    restored from it. Each epoch trains on ``train_data_loader``, evaluates on
    ``val_data_loader``, saves a checkpoint when the validation mAP improves,
    and then advances the LR scheduler exactly once.

    Args:
        model: Model to train (possibly wrapped by the caller for distributed use).
        train_sampler: Sampler of the training loader; ``set_epoch`` is called on it
            each epoch when ``distributed`` is true.
        train_data_loader: Training data loader.
        val_data_loader: Validation data loader used for checkpoint selection.
        device: Device passed to the training and evaluation helpers.
        distributed: Whether torch.distributed is in use.
        config: Full experiment config; ``config['train']`` is read here.
        args: Command-line arguments, forwarded to ``save_ckpt``.
        ckpt_file_path: Checkpoint path to resume from and to save to.
    """
    train_config = config['train']
    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(model, optim_config['type'], optim_config['params'])
    scheduler_config = train_config['scheduler']
    lr_scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'], scheduler_config['params'])
    best_val_map = 0.0
    if file_util.check_if_exists(ckpt_file_path):
        best_val_map, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    num_epochs = train_config['num_epochs']
    log_freq = train_config['log_freq']
    # Only one process writes the checkpoint; otherwise every rank would write
    # the same path concurrently.
    is_main_process = not distributed or dist.get_rank() == 0
    start_time = time.time()
    for epoch in range(num_epochs):
        if distributed:
            # assumes a DistributedSampler-style sampler that is re-seeded per epoch
            train_sampler.set_epoch(epoch)

        train_model(model, optimizer, train_data_loader, device, epoch, log_freq)

        # evaluate after every epoch
        coco_evaluator = main_util.evaluate(model, val_data_loader, device=device)
        # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ]
        val_map = coco_evaluator.coco_eval['bbox'].stats[0]
        if val_map > best_val_map and is_main_process:
            print('Updating ckpt (Best BBox mAP: {:.4f} -> {:.4f})'.format(best_val_map, val_map))
            best_val_map = val_map
            save_ckpt(model, optimizer, lr_scheduler, best_val_map, config, args, ckpt_file_path)

        # Advance the LR schedule exactly once per epoch.
        lr_scheduler.step()
        if distributed:
            # The default process group only exists in distributed mode.
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def evaluate(model, test_data_loader, device):
    """Run ``main_util.evaluate`` on ``model`` and print the stats of every IoU type.

    For each entry of the evaluator's ``coco_eval`` mapping, prints a separator line,
    the entry's ``stats`` and then its IoU-type key.

    Returns:
        The ``stats`` of the ``'bbox'`` entry.
    """
    evaluator = main_util.evaluate(model, test_data_loader, device=device)
    results_by_type = evaluator.coco_eval
    for iou_type in results_by_type:
        print("***************************************")
        print(results_by_type[iou_type].stats)
        print(iou_type)
    return results_by_type['bbox'].stats
def distill(teacher_model, student_model, train_sampler, train_data_loader, val_data_loader, device,
            distributed, distill_backbone_only, config, args):
    """Distill ``teacher_model`` into ``student_model``, keeping the best student checkpoint.

    The optimizer (over the student) and LR scheduler are built from ``config['train']``.
    If the student checkpoint (``config['student_model']['ckpt']``) exists, the best mAP
    and the optimizer/scheduler state are restored from it. Training then resumes
    at the scheduler's ``last_epoch``. Each epoch runs ``distill_model`` with the
    teacher in eval mode and the student in train mode, then evaluates the student
    on ``val_data_loader``. The student checkpoint is saved on the main process
    whenever the validation bbox mAP improves.

    Args:
        teacher_model: Teacher, possibly wrapped in ``DataParallel``.
        student_model: Student, possibly wrapped in ``DistributedDataParallel``.
        train_sampler: Sampler of the training loader; ``set_epoch`` is called when distributed.
        train_data_loader: Training data loader.
        val_data_loader: Validation data loader used for checkpoint selection.
        device: Device passed to the distillation and evaluation helpers.
        distributed: Whether torch.distributed is in use.
        distill_backbone_only: Value assigned to both models' ``distill_backbone_only``
            flag during the training phase of each epoch.
        config: Full experiment config; ``train`` and ``student_model`` sections are read.
        args: Command-line arguments; ``transform_bottleneck`` is read, and ``args``
            is forwarded to ``save_ckpt``.
    """
    train_config = config['train']
    distillation_box = DistillationBox(teacher_model, student_model, train_config['criterion'])
    ckpt_file_path = config['student_model']['ckpt']
    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(student_model, optim_config['type'], optim_config['params'])
    scheduler_config = train_config['scheduler']
    lr_scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'], scheduler_config['params'])
    use_bottleneck_transformer = args.transform_bottleneck
    best_val_map = 0.0
    if file_util.check_if_exists(ckpt_file_path):
        best_val_map, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    num_epochs = train_config['num_epochs']
    log_freq = train_config['log_freq']
    teacher_model_without_dp = teacher_model.module if isinstance(teacher_model, DataParallel) else teacher_model
    student_model_without_ddp = \
        student_model.module if isinstance(student_model, DistributedDataParallel) else student_model
    start_time = time.time()
    # The scheduler is stepped once per epoch, so after a restore its last_epoch is the
    # number of epochs already done; resume there instead of re-running num_epochs
    # epochs on top of an advanced LR schedule.
    for epoch in range(lr_scheduler.last_epoch, num_epochs):
        if distributed:
            train_sampler.set_epoch(epoch)

        # Training phase: teacher frozen in eval mode, bottleneck transformer disabled.
        teacher_model.eval()
        student_model.train()
        teacher_model_without_dp.distill_backbone_only = distill_backbone_only
        student_model_without_ddp.distill_backbone_only = distill_backbone_only
        student_model_without_ddp.backbone.body.layer1.use_bottleneck_transformer = False
        distill_model(distillation_box, train_data_loader, optimizer, log_freq, device, epoch)

        # Validation phase: full student with the requested bottleneck setting.
        student_model_without_ddp.distill_backbone_only = False
        student_model_without_ddp.backbone.body.layer1.use_bottleneck_transformer = use_bottleneck_transformer
        coco_evaluator = main_util.evaluate(student_model, val_data_loader, device=device)
        # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ]
        val_map = coco_evaluator.coco_eval['bbox'].stats[0]
        if val_map > best_val_map and misc_util.is_main_process():
            print('Updating ckpt (Best BBox mAP: {:.4f} -> {:.4f})'.format(
                best_val_map, val_map))
            best_val_map = val_map
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler, best_val_map, config, args,
                      ckpt_file_path)

        lr_scheduler.step()
        if distributed:
            # The default process group only exists in distributed mode.
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def evaluate(teacher_model, student_model, test_data_loader, train_data_loader, device, student_only,
             use_bottleneck_transformer, student_model_config, post_bn=False):
    """Evaluate the teacher (unless ``student_only``) and the student on ``test_data_loader``.

    If the student's ``backbone.params`` config contains a ``slimmable`` key, the student
    is evaluated once per entry of ``width_mult_list``, after ``set_width``. When
    ``post_bn`` is true, ``ComputeBN`` is run with ``train_data_loader`` before each
    of those evaluations. Otherwise the student is evaluated once.

    Args:
        teacher_model: Teacher, possibly wrapped in ``DataParallel``.
        student_model: Student, possibly wrapped in ``DistributedDataParallel``.
        test_data_loader: Loader passed to ``main_util.evaluate``.
        train_data_loader: Loader passed to ``ComputeBN`` when ``post_bn`` is set.
        device: Device passed to ``main_util.evaluate``.
        student_only: When true, the teacher is not evaluated.
        use_bottleneck_transformer: Passed to ``set_bottleneck_transformer`` for the student.
        student_model_config: Student config; ``['backbone']['params']`` is read.
        post_bn: Whether to call ``ComputeBN`` before each per-width evaluation.
    """
    unwrapped_teacher = teacher_model.module if isinstance(teacher_model, DataParallel) else teacher_model
    unwrapped_student = (student_model.module if isinstance(student_model, DistributedDataParallel)
                         else student_model)
    for bare_model in (unwrapped_teacher, unwrapped_student):
        bare_model.distill_backbone_only = False
    set_bottleneck_transformer(unwrapped_student, use_bottleneck_transformer)

    if not student_only:
        print('[Teacher model]')
        main_util.evaluate(teacher_model, test_data_loader, device=device)

    backbone_params = student_model_config['backbone']['params']
    if "slimmable" not in backbone_params:
        main_util.evaluate(student_model, test_data_loader, device=device)
        return

    for width in backbone_params['width_mult_list']:
        print('\n[Student model@width={}]'.format(width))
        set_width(student_model, width)
        if post_bn:
            # presumably recalibrates BatchNorm statistics for this width — confirm in ComputeBN
            ComputeBN(student_model, train_data_loader)
        main_util.evaluate(student_model, test_data_loader, device=device)
def distill(teacher_model, student_model, train_sampler, train_data_loader, val_data_loader, device,
            distributed, distill_backbone_only, config, args):
    """Distill ``teacher_model`` into a (possibly slimmable) ``student_model``.

    The optimizer (over the student) and LR scheduler are built from ``config['train']``.
    If the student checkpoint exists, the best mAP is restored from it. The
    optimizer/scheduler state is restored too, unless ``args.ignore_optimizer`` is set.
    Training starts at the scheduler's ``last_epoch``.

    Each epoch runs ``distill_model``, then validates the student at every selected
    width and averages the bbox mAPs. The student checkpoint is saved on the main
    process whenever that average improves.

    Args:
        teacher_model: Teacher, possibly wrapped in ``DataParallel``.
        student_model: Student, possibly wrapped in ``DistributedDataParallel``.
        train_sampler: Sampler of the training loader; ``set_epoch`` is called when distributed.
        train_data_loader: Training loader (also used by ``ComputeBN`` when post-BN is on).
        val_data_loader: Validation data loader used for checkpoint selection.
        device: Device passed to the distillation and evaluation helpers.
        distributed: Whether torch.distributed is in use.
        distill_backbone_only: Value assigned to both models' ``distill_backbone_only``
            flag during the training phase of each epoch.
        config: Full experiment config; ``train`` and ``student_model`` sections are read.
        args: Command-line arguments; ``transform_bottleneck`` and ``ignore_optimizer``
            are read, and ``args`` is forwarded to ``save_ckpt``.

    Raises:
        ValueError: If the student backbone is slimmable but ``width_mult_list`` is empty.
    """
    train_config = config['train']
    student_config = config['student_model']
    distillation_box = DistillationBox(teacher_model, student_model, train_config['criterion'], student_config)
    ckpt_file_path = student_config['ckpt']
    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(student_model, optim_config['type'], optim_config['params'])
    scheduler_config = train_config['scheduler']
    lr_scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'], scheduler_config['params'])
    use_bottleneck_transformer = args.transform_bottleneck
    best_val_map = 0.0
    if file_util.check_if_exists(ckpt_file_path):
        if args.ignore_optimizer:
            # Restore only the best mAP; optimizer and scheduler start fresh.
            best_val_map, _, _ = load_ckpt(ckpt_file_path, optimizer=None, lr_scheduler=None)
        else:
            best_val_map, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    num_epochs = train_config['num_epochs']
    log_freq = train_config['log_freq']
    teacher_model_without_dp = teacher_model.module if isinstance(teacher_model, DataParallel) else teacher_model
    student_model_without_ddp = \
        student_model.module if isinstance(student_model, DistributedDataParallel) else student_model
    post_bn = train_config.get('post_batch_norm', False)

    # Widths to validate the student at; loop-invariant, so computed once up front.
    backbone_params = student_config['backbone']['params']
    if 'slimmable' in backbone_params:
        width_mult_list = backbone_params['width_mult_list']
        # Validate at a fixed subset of widths (presumably to bound per-epoch cost).
        width_list = [w for w in (0.25, 0.5, 0.75, 1.0) if w in width_mult_list]
        if not width_list:
            # None of the preferred widths is configured: use every configured width
            # rather than averaging over an empty list (division by zero).
            width_list = list(width_mult_list)
    else:
        width_list = [1.0]
    if not width_list:
        raise ValueError('width_mult_list must not be empty for a slimmable student backbone')

    start_time = time.time()
    for epoch in range(lr_scheduler.last_epoch, num_epochs):
        if distributed:
            train_sampler.set_epoch(epoch)

        # Training phase: teacher frozen in eval mode, bottleneck transformer disabled.
        teacher_model.eval()
        student_model.train()
        teacher_model_without_dp.distill_backbone_only = distill_backbone_only
        student_model_without_ddp.distill_backbone_only = distill_backbone_only
        set_bottleneck_transformer(student_model_without_ddp, False)
        distill_model(distillation_box, train_data_loader, optimizer, log_freq, device, epoch)

        # Validation phase: full student with the requested bottleneck setting.
        student_model_without_ddp.distill_backbone_only = False
        set_bottleneck_transformer(student_model_without_ddp, use_bottleneck_transformer)
        val_map = 0.0
        for width in width_list:
            set_width(student_model, width)
            if post_bn:
                ComputeBN(student_model, train_data_loader)
            print('\n[Student model@width={}]'.format(width))
            coco_evaluator = main_util.evaluate(student_model, val_data_loader, device=device)
            # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ]
            val_map += coco_evaluator.coco_eval['bbox'].stats[0]
        # Checkpoints are selected by the mean mAP over the validated widths.
        val_map /= len(width_list)
        print('BBox mAP: {:.4f}'.format(val_map))
        if val_map > best_val_map and misc_util.is_main_process():
            print('Updating ckpt (Best BBox mAP: {:.4f} -> {:.4f})'.format(
                best_val_map, val_map))
            best_val_map = val_map
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler, best_val_map, config, args,
                      ckpt_file_path)

        lr_scheduler.step()
        if distributed:
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))