def main(args):
    """Train a student model (optionally with a teacher), then evaluate on the test split.

    Steps:
    1. Set up the log file (main process only) and distributed mode.
    2. Load the YAML config, the datasets and the models.
    3. Unless ``args.test_only`` is set, train, then reload the student weights from
       the student checkpoint path.
    4. Evaluate the teacher (only if one is configured and ``args.student_only`` is
       not set), then evaluate the student.

    Args:
        args: parsed command-line arguments. This function reads ``log``,
            ``world_size``, ``dist_url``, ``seed``, ``config``, ``device``,
            ``log_config``, ``test_only``, ``iou_types`` and ``student_only``.
    """
    log_path = args.log
    if is_main_process() and log_path is not None:
        setup_log_file(os.path.expanduser(log_path))

    distributed, device_ids = init_distributed_mode(args.world_size, args.dist_url)
    logger.info(args)
    cudnn.benchmark = True
    set_seed(args.seed)

    config = yaml_util.load_yaml_file(os.path.expanduser(args.config))
    device = torch.device(args.device)
    dataset_dict = util.get_all_datasets(config['datasets'])

    models_config = config['models']
    teacher_model_config = models_config.get('teacher_model', None)
    if teacher_model_config is None:
        teacher_model = None
    else:
        teacher_model = load_model(teacher_model_config, device)

    # Configs without a teacher/student split name their single model 'model'.
    student_key = 'student_model' if 'student_model' in models_config else 'model'
    student_model_config = models_config[student_key]
    ckpt_file_path = student_model_config['ckpt']
    student_model = load_model(student_model_config, device)

    if args.log_config:
        logger.info(config)

    if not args.test_only:
        train(teacher_model, student_model, dataset_dict, ckpt_file_path,
              device, device_ids, distributed, config, args)
        # Load the checkpoint into the unwrapped module so state-dict keys line up.
        if module_util.check_if_wrapped(student_model):
            bare_student = student_model.module
        else:
            bare_student = student_model
        load_ckpt(student_model_config['ckpt'], model=bare_student, strict=True)

    test_config = config['test']
    loader_config = test_config['test_data_loader']
    test_dataset = dataset_dict[loader_config['dataset_id']]
    test_data_loader = util.build_data_loader(test_dataset, loader_config, distributed)
    log_freq = test_config.get('log_freq', 1000)
    iou_types = args.iou_types

    if not args.student_only and teacher_model is not None:
        teacher_title = '[Teacher: {}]'.format(teacher_model_config['name'])
        evaluate(teacher_model, test_data_loader, iou_types, device, device_ids, distributed,
                 log_freq=log_freq, title=teacher_title)

    student_title = '[Student: {}]'.format(student_model_config['name'])
    evaluate(student_model, test_data_loader, iou_types, device, device_ids, distributed,
             log_freq=log_freq, title=student_title)
def train_one_epoch(training_box, device, epoch, log_freq):
    """Run one epoch over ``training_box.train_data_loader``.

    For each batch, the function:
    1. Moves the samples and targets to ``device``.
    2. Computes the loss by calling ``training_box``.
    3. Aborts if the loss is not finite.
    4. Updates the parameters via ``training_box.update_params``.
    5. Logs the loss, the first param group's learning rate and the images/sec throughput.

    Args:
        training_box: box providing ``train_data_loader``, ``optimizer`` and
            ``update_params``; calling it on a batch returns the loss.
        device: device each batch is moved to with ``.to(device)``.
        epoch: epoch index, used only in the log header.
        log_freq: logging interval forwarded to ``MetricLogger.log_every``.

    Raises:
        ValueError: on the main process, when the loss is NaN or infinite.
    """
    metric_logger = MetricLogger(delimiter=' ')
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', SmoothedValue(window_size=10, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)
    for sample_batch, targets, supp_dict in \
            metric_logger.log_every(training_box.train_data_loader, log_freq, header):
        start_time = time.time()
        sample_batch, targets = sample_batch.to(device), targets.to(device)
        loss = training_box(sample_batch, targets, supp_dict)
        # Check BEFORE the parameter update. Stepping the optimizer on a NaN/inf
        # loss would push non-finite values into the weights before the error
        # surfaced. The check itself is equivalent to isnan(loss) or isinf(loss).
        if not torch.isfinite(loss) and is_main_process():
            raise ValueError('The training loop was broken due to loss = {}'.format(loss))

        training_box.update_params(loss)
        batch_size = sample_batch.shape[0]
        metric_logger.update(loss=loss.item(), lr=training_box.optimizer.param_groups[0]['lr'])
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids,
          distributed, config, args):
    """Train a detection student, with distillation when ``teacher_model`` is given.

    After each epoch, the student is validated. The checkpoint at ``ckpt_file_path`` is
    rewritten (on the main process) whenever the validation mAP improves. If
    ``ckpt_file_path`` already exists, the best mAP, optimizer state and LR-scheduler
    state are resumed from it first.

    Args:
        teacher_model: teacher network, or ``None`` for plain (non-distillation) training.
        student_model: model being trained; may be wrapped (see ``check_if_wrapped``).
        dataset_dict: datasets keyed by dataset id.
        ckpt_file_path: checkpoint path to resume from and save to.
        device: target device.
        device_ids: device ids, forwarded to the box factory and ``evaluate``.
        distributed: whether distributed mode is active.
        config: full config dict; its ``'train'`` section is used here.
        args: parsed CLI arguments. This function reads ``world_size``,
            ``adjust_lr``, ``iou_types`` and ``start_epoch``.
    """
    logger.info('Start training')
    train_config = config['train']
    lr_factor = args.world_size if distributed and args.adjust_lr else 1

    if teacher_model is None:
        box = get_training_box(student_model, dataset_dict, train_config,
                               device, device_ids, distributed, lr_factor)
    else:
        box = get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                                   device, device_ids, distributed, lr_factor)

    best_val_map = 0.0
    optimizer, lr_scheduler = box.optimizer, box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_map, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    iou_types = args.iou_types
    # The first IoU type drives checkpoint selection; fall back to 'bbox' when none is given.
    has_iou_types = isinstance(iou_types, (list, tuple)) and len(iou_types) > 0
    val_iou_type = iou_types[0] if has_iou_types else 'bbox'

    if module_util.check_if_wrapped(student_model):
        bare_student = student_model.module
    else:
        bare_student = student_model

    start_time = time.time()
    for epoch in range(args.start_epoch, box.num_epochs):
        box.pre_process(epoch=epoch)
        train_one_epoch(box, device, epoch, log_freq)
        coco_evaluator = evaluate(student_model, box.val_data_loader, iou_types, device,
                                  device_ids, distributed, log_freq=log_freq, header='Validation:')
        # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ]
        val_map = coco_evaluator.coco_eval[val_iou_type].stats[0]
        if val_map > best_val_map and is_main_process():
            logger.info('Best mAP ({}): {:.4f} -> {:.4f}'.format(val_iou_type, best_val_map, val_map))
            logger.info('Updating ckpt at {}'.format(ckpt_file_path))
            best_val_map = val_map
            save_ckpt(bare_student, optimizer, lr_scheduler, best_val_map, config, args, ckpt_file_path)
        box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(elapsed_seconds)))
    logger.info('Training time {}'.format(total_time_str))
    box.clean_modules()
def main(args):
    """Distill a student from a teacher (unless test-only), then evaluate both.

    Both ``'teacher_model'`` and ``'student_model'`` entries are required in
    ``config['models']``. After distillation, the student weights are reloaded from the
    student's ``'ckpt'`` path. The teacher is evaluated unless ``args.student_only``
    is set; the student is always evaluated.

    Args:
        args: parsed command-line arguments. This function reads ``log``,
            ``world_size``, ``dist_url``, ``config``, ``device``, ``test_only``,
            ``num_classes`` and ``student_only``.
    """
    log_path = args.log
    if is_main_process() and log_path is not None:
        setup_log_file(os.path.expanduser(log_path))

    distributed, device_ids = init_distributed_mode(args.world_size, args.dist_url)
    logger.info(args)
    cudnn.benchmark = True

    config = yaml_util.load_yaml_file(os.path.expanduser(args.config))
    device = torch.device(args.device)
    dataset_dict = util.get_all_dataset(config['datasets'])

    models_config = config['models']
    teacher_model_config = models_config['teacher_model']
    teacher_model = get_model(teacher_model_config, device)
    student_model_config = models_config['student_model']
    student_model = get_model(student_model_config, device)

    if not args.test_only:
        distill(teacher_model, student_model, dataset_dict, device, device_ids, distributed, config, args)
        # Reload into the unwrapped module so state-dict keys line up.
        if module_util.check_if_wrapped(student_model):
            bare_student = student_model.module
        else:
            bare_student = student_model
        load_ckpt(student_model_config['ckpt'], model=bare_student, strict=True)

    test_config = config['test']
    loader_config = test_config['test_data_loader']
    test_dataset = dataset_dict[loader_config['dataset_id']]
    test_data_loader = util.build_data_loader(test_dataset, loader_config, distributed)

    num_classes = args.num_classes
    if not args.student_only:
        teacher_title = '[Teacher: {}]'.format(teacher_model_config['name'])
        evaluate(teacher_model, test_data_loader, device, device_ids, distributed,
                 num_classes=num_classes, title=teacher_title)

    student_title = '[Student: {}]'.format(student_model_config['name'])
    evaluate(student_model, test_data_loader, device, device_ids, distributed,
             num_classes=num_classes, title=student_title)
def distill(teacher_model, student_model, dataset_dict, device, device_ids, distributed, config, args):
    """Distill ``teacher_model`` into ``student_model``, keeping the best-top-1 checkpoint.

    The checkpoint path is ``config['models']['student_model']['ckpt']``. If that file
    exists, the best top-1 accuracy, optimizer state and LR-scheduler state are resumed
    from it. After each epoch, the student is validated. The checkpoint is saved (on the
    main process) whenever validation top-1 accuracy improves.

    Args:
        teacher_model: teacher network.
        student_model: student network; may be wrapped (see ``check_if_wrapped``).
        dataset_dict: datasets keyed by dataset id.
        device: target device.
        device_ids: device ids, forwarded to the box factory and ``evaluate``.
        distributed: whether distributed mode is active.
        config: full config dict; its ``'train'`` and ``'models'`` sections are used.
        args: parsed CLI arguments. This function reads ``world_size``,
            ``adjust_lr`` and ``start_epoch``.
    """
    logger.info('Start distillation')
    train_config = config['train']
    lr_factor = args.world_size if distributed and args.adjust_lr else 1
    box = get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                               device, device_ids, distributed, lr_factor)

    ckpt_file_path = config['models']['student_model']['ckpt']
    best_top1 = 0.0
    optimizer, lr_scheduler = box.optimizer, box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_top1, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    if module_util.check_if_wrapped(student_model):
        bare_student = student_model.module
    else:
        bare_student = student_model

    start_time = time.time()
    for epoch in range(args.start_epoch, box.num_epochs):
        box.pre_process(epoch=epoch)
        distill_one_epoch(box, device, epoch, log_freq)
        top1 = evaluate(student_model, box.val_data_loader, device, device_ids, distributed,
                        log_freq=log_freq, header='Validation:')
        if top1 > best_top1 and is_main_process():
            message = 'Updating ckpt (Best top1 accuracy: {:.4f} -> {:.4f})'
            logger.info(message.format(best_top1, top1))
            best_top1 = top1
            save_ckpt(bare_student, optimizer, lr_scheduler, best_top1, config, args, ckpt_file_path)
        box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(elapsed_seconds)))
    logger.info('Training time {}'.format(total_time_str))
    box.clean_modules()
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids,
          distributed, config, args):
    """Train a segmentation student, with distillation when ``teacher_model`` is given.

    After each epoch, the student is validated and the mean of the per-class IoU
    returned by the evaluator's ``compute()`` is used as the selection metric. The
    checkpoint at ``ckpt_file_path`` is rewritten (on the main process) whenever
    that mean IoU improves. If the checkpoint already exists, the best mIoU,
    optimizer state and LR-scheduler state are resumed from it first.

    Args:
        teacher_model: teacher network, or ``None`` for plain training.
        student_model: model being trained; may be wrapped (see ``check_if_wrapped``).
        dataset_dict: datasets keyed by dataset id.
        ckpt_file_path: checkpoint path to resume from and save to.
        device: target device.
        device_ids: device ids, forwarded to the box factory and ``evaluate``.
        distributed: whether distributed mode is active.
        config: full config dict; its ``'train'`` section is used here.
        args: parsed CLI arguments. This function reads ``world_size``,
            ``adjust_lr``, ``start_epoch`` and ``num_classes``.
    """
    logger.info('Start training')
    train_config = config['train']
    lr_factor = args.world_size if distributed and args.adjust_lr else 1

    if teacher_model is None:
        box = get_training_box(student_model, dataset_dict, train_config,
                               device, device_ids, distributed, lr_factor)
    else:
        box = get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                                   device, device_ids, distributed, lr_factor)

    best_val_miou = 0.0
    optimizer, lr_scheduler = box.optimizer, box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_miou, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    if module_util.check_if_wrapped(student_model):
        bare_student = student_model.module
    else:
        bare_student = student_model

    start_time = time.time()
    for epoch in range(args.start_epoch, box.num_epochs):
        box.pre_process(epoch=epoch)
        train_one_epoch(box, device, epoch, log_freq)
        seg_evaluator = evaluate(student_model, box.val_data_loader, device, device_ids, distributed,
                                 num_classes=args.num_classes, log_freq=log_freq, header='Validation:')
        # compute() returns (global acc, per-class acc, per-class IoU); only the IoU is used here.
        _, _, per_class_iou = seg_evaluator.compute()
        val_miou = per_class_iou.mean().item()
        if val_miou > best_val_miou and is_main_process():
            logger.info('Updating ckpt (Best mIoU: {:.4f} -> {:.4f})'.format(best_val_miou, val_miou))
            best_val_miou = val_miou
            save_ckpt(bare_student, optimizer, lr_scheduler, best_val_miou, config, args, ckpt_file_path)
        box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(elapsed_seconds)))
    logger.info('Training time {}'.format(total_time_str))
    box.clean_modules()
def save_module_ckpt(module, ckpt_file_path):
    """Save ``module``'s state dict to ``ckpt_file_path``.

    On the main process, the parent directories are created first. If ``module`` is
    wrapped (per ``check_if_wrapped``), the inner ``module.module`` is saved, so the
    stored keys carry no wrapper prefix. Writing is delegated to ``save_on_master``.
    """
    if is_main_process():
        make_parent_dirs(ckpt_file_path)
    target = module.module if check_if_wrapped(module) else module
    save_on_master(target.state_dict(), ckpt_file_path)
def main(args):
    """Fine-tune/distill a text-classification student with Accelerate, then evaluate it.

    Steps:
    1. Set up logging and the ``Accelerator``, then load tokenizers and models.
    2. Build the datasets and adjust the LR config to the dataset size.
    3. Register a padding collate function.
    4. Unless ``args.test_only`` is set, train and save the student tokenizer to the
       checkpoint directory.
    5. Evaluate the teacher (if configured and ``args.student_only`` is not set).
    6. Reload the student from its config, evaluate it, and optionally write
       predictions for private datasets.

    Args:
        args: parsed command-line arguments. This function reads ``log``,
            ``world_size``, ``seed``, ``config``, ``task_name``, ``test_only``,
            ``student_only`` and ``private_output``.
    """
    log_path = args.log
    if is_main_process() and log_path is not None:
        setup_log_file(os.path.expanduser(log_path))

    world_size = args.world_size
    logger.info(args)
    cudnn.benchmark = True
    set_seed(args.seed)
    config = yaml_util.load_yaml_file(os.path.expanduser(args.config))

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    accelerator = Accelerator()
    distributed = accelerator.state.distributed_type == DistributedType.MULTI_GPU
    device_ids = [accelerator.device.index]
    if distributed:
        setup_for_distributed(is_main_process())

    logger.info(accelerator.state)
    device = accelerator.device

    # Only one process per machine (accelerator.is_local_main_process) logs verbosely;
    # every other process is limited to errors.
    is_local_main = accelerator.is_local_main_process
    logger.setLevel(logging.INFO if is_local_main else logging.ERROR)
    if is_local_main:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Load pretrained model(s) and tokenizer(s).
    task_name = args.task_name
    models_config = config['models']
    teacher_model_config = models_config.get('teacher_model', None)
    if teacher_model_config is None:
        teacher_model = None
    else:
        _, teacher_model = load_tokenizer_and_model(teacher_model_config, task_name, True)

    student_key = 'student_model' if 'student_model' in models_config else 'model'
    student_model_config = models_config[student_key]
    student_tokenizer, student_model = load_tokenizer_and_model(student_model_config, task_name)
    ckpt_dir_path = student_model_config['ckpt']

    # Datasets come from the config: either user-provided files or a GLUE task.
    dataset_dict, label_names_dict, is_regression = \
        get_all_datasets(config['datasets'], task_name, student_tokenizer, student_model)

    # The LR config depends on the dataset size (len(data_loader)).
    customize_lr_config(config, dataset_dict, world_size)

    # Register the collate function.
    pad_multiple = 8 if accelerator.use_fp16 else None
    register_collate_func(DataCollatorWithPadding(student_tokenizer, pad_to_multiple_of=pad_multiple))

    metric = get_metrics(task_name)

    if not args.test_only:
        train(teacher_model, student_model, dataset_dict, is_regression, ckpt_dir_path, metric,
              device, device_ids, distributed, config, args, accelerator)
        student_tokenizer.save_pretrained(ckpt_dir_path)

    test_config = config['test']
    loader_config = test_config['test_data_loader']
    test_dataset = dataset_dict[loader_config['dataset_id']]
    test_data_loader = util.build_data_loader(test_dataset, loader_config, distributed)
    test_data_loader = accelerator.prepare(test_data_loader)

    if not args.student_only and teacher_model is not None:
        teacher_model = teacher_model.to(accelerator.device)
        teacher_title = '[Teacher: {}]'.format(teacher_model_config['name'])
        evaluate(teacher_model, test_data_loader, metric, is_regression, accelerator, title=teacher_title)

    # Reload the best checkpoint (selected on validation) before the final evaluation.
    _, student_model = load_tokenizer_and_model(student_model_config, task_name, True)
    student_model = accelerator.prepare(student_model)
    student_title = '[Student: {}]'.format(student_model_config['name'])
    evaluate(student_model, test_data_loader, metric, is_regression, accelerator, title=student_title)

    # Predict on private dataset(s) only when both the config and the output dir are given.
    private_configs = config.get('private', None)
    private_output_dir_path = args.private_output
    if private_configs is not None and private_output_dir_path is not None and is_main_process():
        predict_private(student_model, dataset_dict, label_names_dict, is_regression, accelerator,
                        private_configs, private_output_dir_path)
def log_info(*args, **kwargs):
    """Forward to ``logger.info`` on the main process only.

    Pass ``force=True`` to log from any process. The ``force`` keyword is removed
    before the remaining keyword arguments are forwarded to ``logger.info``.
    """
    force = kwargs.pop('force', False)
    if not (is_main_process() or force):
        return
    logger.info(*args, **kwargs)