def train(teacher_model, student_model, dataset_dict, is_regression, ckpt_dir_path, metric,
          device, device_ids, distributed, config, args, accelerator):
    """Train the student model and save it whenever the validation score improves.

    A plain training box is built when ``teacher_model`` is None; otherwise a
    distillation box is built from the teacher/student pair. After each epoch
    the student is evaluated on ``training_box.val_data_loader`` and the values
    of the returned metric dict are summed into one score. The model is saved
    only when that score is strictly greater than the best score seen so far
    (which starts at 0.0).

    Args:
        teacher_model: Teacher model, or None to train without distillation.
        student_model: Model being trained, evaluated and saved.
        dataset_dict: Datasets forwarded to the box factory.
        is_regression: Forwarded to ``evaluate``.
        ckpt_dir_path: Directory passed to ``save_pretrained``.
        metric: Metric object forwarded to ``evaluate``.
        device: Forwarded to the box factory.
        device_ids: Forwarded to the box factory.
        distributed: Whether the run is distributed; also gates LR scaling.
        config: Configuration mapping; ``config['train']`` must provide
            ``'log_freq'``.
        args: Namespace read for ``adjust_lr`` and ``world_size`` when
            ``distributed`` is truthy.
        accelerator: Object used for process synchronization, model
            unwrapping and saving.
    """
    logger.info('Start training')
    train_config = config['train']

    # Scale the learning rate by the world size only when explicitly requested
    # for a distributed run.
    if distributed and args.adjust_lr:
        lr_factor = args.world_size
    else:
        lr_factor = 1

    # Both factories share the same trailing positional arguments.
    shared_box_args = (student_model, dataset_dict, train_config,
                       device, device_ids, distributed, lr_factor, accelerator)
    if teacher_model is None:
        training_box = get_training_box(*shared_box_args)
    else:
        training_box = get_distillation_box(teacher_model, *shared_box_args)

    log_freq = train_config['log_freq']
    best_val_score = 0.0
    for epoch in range(training_box.num_epochs):
        train_one_epoch(training_box, epoch, log_freq)
        val_metrics = evaluate(student_model, training_box.val_data_loader, metric,
                               is_regression, accelerator, header='Validation: ')
        # Collapse all reported metric values into a single comparable number.
        val_score = sum(val_metrics.values())
        if val_score > best_val_score:
            logger.info('Updating ckpt at {}'.format(ckpt_dir_path))
            best_val_score = val_score
            # Synchronize processes before unwrapping and writing the model.
            accelerator.wait_for_everyone()
            model_to_save = accelerator.unwrap_model(student_model)
            model_to_save.save_pretrained(ckpt_dir_path, save_function=accelerator.save)

        training_box.post_process()
def train(teacher_model, student_model, dataset_dict, ckpt_file_path,
          device, device_ids, distributed, config, args):
    """Train the student model and checkpoint it whenever validation mAP improves.

    A plain training box is built when ``teacher_model`` is None; otherwise a
    distillation box is built from the teacher/student pair. If a file already
    exists at ``ckpt_file_path``, ``load_ckpt`` is called with the box's
    optimizer and LR scheduler, and the first value it returns is used as the
    initial best mAP (presumably this also restores optimizer/scheduler
    state -- confirm against ``load_ckpt``).

    Args:
        teacher_model: Teacher model, or None to train without distillation.
        student_model: Model being trained and evaluated.
        dataset_dict: Datasets forwarded to the box factory.
        ckpt_file_path: Checkpoint file to resume from and to write to.
        device: Forwarded to the box factory, training and evaluation.
        device_ids: Forwarded to the box factory and evaluation.
        distributed: Whether the run is distributed; gates LR scaling and the
            final barrier.
        config: Configuration mapping; ``config['train']`` must provide
            ``'log_freq'``. Also passed through to ``save_ckpt``.
        args: Namespace read for ``iou_types`` and ``start_epoch`` (plus
            ``adjust_lr`` / ``world_size`` when distributed). Also passed
            through to ``save_ckpt``.
    """
    logger.info('Start training')
    train_config = config['train']

    # Scale the learning rate by the world size only when explicitly requested
    # for a distributed run.
    if distributed and args.adjust_lr:
        lr_factor = args.world_size
    else:
        lr_factor = 1

    # Both factories share the same trailing positional arguments.
    shared_box_args = (student_model, dataset_dict, train_config,
                       device, device_ids, distributed, lr_factor)
    if teacher_model is None:
        training_box = get_training_box(*shared_box_args)
    else:
        training_box = get_distillation_box(teacher_model, *shared_box_args)

    best_val_map = 0.0
    optimizer = training_box.optimizer
    lr_scheduler = training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_map, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    iou_types = args.iou_types
    # Track the first requested IoU type; fall back to 'bbox' when none is given.
    if isinstance(iou_types, (list, tuple)) and len(iou_types) > 0:
        val_iou_type = iou_types[0]
    else:
        val_iou_type = 'bbox'

    # Checkpoints are written from the underlying module, not the wrapper.
    if module_util.check_if_wrapped(student_model):
        student_model_without_ddp = student_model.module
    else:
        student_model_without_ddp = student_model

    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        val_coco_evaluator = evaluate(student_model, training_box.val_data_loader, iou_types,
                                      device, device_ids, distributed,
                                      log_freq=log_freq, header='Validation:')
        # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ]
        val_map = val_coco_evaluator.coco_eval[val_iou_type].stats[0]
        # The best value is updated and the checkpoint written on the main process only.
        if val_map > best_val_map and is_main_process():
            logger.info('Best mAP ({}): {:.4f} -> {:.4f}'.format(val_iou_type, best_val_map, val_map))
            logger.info('Updating ckpt at {}'.format(ckpt_file_path))
            best_val_map = val_map
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler,
                      best_val_map, config, args, ckpt_file_path)

        training_box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))
    training_box.clean_modules()
def train(teacher_model, student_model, dataset_dict, ckpt_file_path,
          device, device_ids, distributed, config, args):
    """Train the student model and checkpoint it whenever top-1 accuracy improves.

    A plain training box is built when ``teacher_model`` is None; otherwise a
    distillation box is built from the teacher/student pair. If a file already
    exists at ``ckpt_file_path``, ``load_ckpt`` is called with the box's
    optimizer and LR scheduler, and the first value it returns is used as the
    initial best accuracy (presumably this also restores optimizer/scheduler
    state -- confirm against ``load_ckpt``).

    Args:
        teacher_model: Teacher model, or None to train without distillation.
        student_model: Model being trained and evaluated.
        dataset_dict: Datasets forwarded to the box factory.
        ckpt_file_path: Checkpoint file to resume from and to write to.
        device: Forwarded to the box factory, training and evaluation.
        device_ids: Forwarded to the box factory and evaluation.
        distributed: Whether the run is distributed; gates LR scaling and the
            final barrier.
        config: Configuration mapping; ``config['train']`` must provide
            ``'log_freq'``. Also passed through to ``save_ckpt``.
        args: Namespace read for ``start_epoch`` (plus ``adjust_lr`` /
            ``world_size`` when distributed). Also passed through to
            ``save_ckpt``.
    """
    logger.info('Start training')
    train_config = config['train']

    # Scale the learning rate by the world size only when explicitly requested
    # for a distributed run.
    if distributed and args.adjust_lr:
        lr_factor = args.world_size
    else:
        lr_factor = 1

    # Both factories share the same trailing positional arguments.
    shared_box_args = (student_model, dataset_dict, train_config,
                       device, device_ids, distributed, lr_factor)
    if teacher_model is None:
        training_box = get_training_box(*shared_box_args)
    else:
        training_box = get_distillation_box(teacher_model, *shared_box_args)

    best_val_top1_accuracy = 0.0
    optimizer = training_box.optimizer
    lr_scheduler = training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_top1_accuracy, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer,
                                                 lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']

    # Checkpoints are written from the underlying module, not the wrapper.
    if module_util.check_if_wrapped(student_model):
        student_model_without_ddp = student_model.module
    else:
        student_model_without_ddp = student_model

    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        val_top1_accuracy = evaluate(student_model, training_box.val_data_loader,
                                     device, device_ids, distributed,
                                     log_freq=log_freq, header='Validation:')
        # The best value is updated and the checkpoint written on the main process only.
        if val_top1_accuracy > best_val_top1_accuracy and is_main_process():
            logger.info('Updating ckpt (Best top1 accuracy: {:.4f} -> {:.4f})'.format(
                best_val_top1_accuracy, val_top1_accuracy))
            best_val_top1_accuracy = val_top1_accuracy
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler,
                      best_val_top1_accuracy, config, args, ckpt_file_path)

        training_box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))
    training_box.clean_modules()