def load_ckpt(ckpt_file_path, model=None, optimizer=None, lr_scheduler=None, strict=True):
    # Load a checkpoint and restore model/optimizer/scheduler states when they are present
    if not check_if_exists(ckpt_file_path):
        logger.info('ckpt file is not found at `{}`'.format(ckpt_file_path))
        # Match the three-value return at the end of this function
        return None, None, None

    ckpt = torch.load(ckpt_file_path, map_location='cpu')
    if model is not None:
        if 'model' in ckpt:
            logger.info('Loading model parameters')
            model.load_state_dict(ckpt['model'], strict=strict)
        else:
            logger.info('No model parameters found')

    if optimizer is not None:
        if 'optimizer' in ckpt:
            logger.info('Loading optimizer parameters')
            optimizer.load_state_dict(ckpt['optimizer'])
        else:
            logger.info('No optimizer parameters found')

    if lr_scheduler is not None:
        if 'lr_scheduler' in ckpt:
            logger.info('Loading scheduler parameters')
            lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
        else:
            logger.info('No scheduler parameters found')
    return ckpt.get('best_value', 0.0), ckpt.get('config', None), ckpt.get('args', None)
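# Minimal sketch of the save_ckpt counterpart assumed by the training functions below. Its actual
# implementation is not shown in this section, so the function name (save_ckpt_sketch) and the
# directory-creation handling are illustrative assumptions; it simply writes the same keys that
# load_ckpt above reads back ('model', 'optimizer', 'lr_scheduler', 'best_value', 'config', 'args').
def save_ckpt_sketch(model, optimizer, lr_scheduler, best_value, config, args, output_file_path):
    import os
    import torch
    parent_dir = os.path.dirname(output_file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    ckpt = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict() if lr_scheduler is not None else None,
        'best_value': best_value,
        'config': config,
        'args': args
    }
    torch.save(ckpt, output_file_path)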
def main(args):
    in_ckpt_file_path = args.src
    if not check_if_exists(in_ckpt_file_path):
        print('ckpt file is not found at `{}`'.format(in_ckpt_file_path))
        return

    src_ckpt = torch.load(in_ckpt_file_path, map_location='cpu')
    out_ckpt_file_path = args.dst
    # If exactly one key is requested and dict wrapping is not forced, save that entry as-is
    if len(args.keys) == 1 and not args.use_dict:
        key = args.keys[0]
        if key in src_ckpt:
            save_obj(src_ckpt[key], out_ckpt_file_path)
        else:
            print('Parameter key `{}` was not found'.format(key))
        return

    # Otherwise, copy every requested key into a new dict and save it
    dst_ckpt = dict()
    for key in args.keys:
        if key in src_ckpt:
            dst_ckpt[key] = src_ckpt[key]
        else:
            print('Parameter key `{}` was not found'.format(key))

    if len(dst_ckpt) > 0:
        save_obj(dst_ckpt, out_ckpt_file_path)
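# Hedged usage sketch for the checkpoint-entry extraction above: the argument names (src, dst,
# keys, use_dict) follow how `args` is read in main(), but this parser is an assumption, not the
# script's actual CLI definition.
def get_argparser_sketch():
    import argparse
    parser = argparse.ArgumentParser(description='Extract entries from a checkpoint file')
    parser.add_argument('--src', required=True, help='source ckpt file path')
    parser.add_argument('--dst', required=True, help='output file path')
    parser.add_argument('--keys', nargs='+', required=True, help='checkpoint keys to extract (e.g., model)')
    parser.add_argument('--use_dict', action='store_true',
                        help='save extracted entries as a dict even when a single key is given')
    return parser

# Example (hypothetical script name): python extract_entry.py --src ./ckpt.pt --dst ./model_only.pt --keys model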
def __getitem__(self, index):
    sample, target, supp_dict = super().__getitem__(index)
    cache_file_path = os.path.join(self.cache_dir_path, self.idx2subpath_func(index) + self.ext)
    # If a cached entry already exists for this index, attach it to the supplementary dict
    if file_util.check_if_exists(cache_file_path):
        cached_data = torch.load(cache_file_path)
        supp_dict['cached_data'] = cached_data

    # Always expose the cache path so downstream code can write the entry when it is missing
    supp_dict['cache_file_path'] = cache_file_path
    return sample, target, supp_dict
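# Hedged sketch of the write side assumed by the cacheable __getitem__ above: when no cached
# entry was found, downstream code can compute the data (e.g., intermediate teacher outputs) and
# store it at the path carried in supp_dict['cache_file_path']. The function name and the nature
# of `computed_data` are illustrative assumptions, not the repository's actual caching hook.
def cache_if_missing_sketch(supp_dict, computed_data):
    import os
    import torch
    if 'cached_data' in supp_dict:
        return supp_dict['cached_data']
    cache_file_path = supp_dict['cache_file_path']
    parent_dir = os.path.dirname(cache_file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(computed_data, cache_file_path)
    return computed_data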
def load_ckpt(ckpt_file_path, model=None, optimizer=None, lr_scheduler=None, strict=True):
    # Load a checkpoint from a local file or a URL, then restore whichever states are present
    if check_if_exists(ckpt_file_path):
        ckpt = torch.load(ckpt_file_path, map_location='cpu')
    elif isinstance(ckpt_file_path, str) and \
            (ckpt_file_path.startswith('https://') or ckpt_file_path.startswith('http://')):
        ckpt = torch.hub.load_state_dict_from_url(ckpt_file_path, map_location='cpu', progress=True)
    else:
        logger.info('ckpt file is not found at `{}`'.format(ckpt_file_path))
        return None, None, None

    if model is not None:
        if 'model' in ckpt:
            logger.info('Loading model parameters')
            if strict is None:
                # Fall back to load_state_dict's default (strict=True) when strict is unspecified
                model.load_state_dict(ckpt['model'])
            else:
                model.load_state_dict(ckpt['model'], strict=strict)
        elif optimizer is None and lr_scheduler is None:
            # The checkpoint is a bare state dict rather than a wrapped dict
            logger.info('Loading model parameters only')
            model.load_state_dict(ckpt, strict=strict)
        else:
            logger.info('No model parameters found')

    if optimizer is not None:
        if 'optimizer' in ckpt:
            logger.info('Loading optimizer parameters')
            optimizer.load_state_dict(ckpt['optimizer'])
        elif model is None and lr_scheduler is None:
            logger.info('Loading optimizer parameters only')
            optimizer.load_state_dict(ckpt)
        else:
            logger.info('No optimizer parameters found')

    if lr_scheduler is not None:
        if 'lr_scheduler' in ckpt:
            logger.info('Loading scheduler parameters')
            lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
        elif model is None and optimizer is None:
            logger.info('Loading scheduler parameters only')
            lr_scheduler.load_state_dict(ckpt)
        else:
            logger.info('No scheduler parameters found')
    return ckpt.get('best_value', 0.0), ckpt.get('config', None), ckpt.get('args', None)
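# Hedged usage sketch for load_ckpt above: the checkpoint path is a placeholder, and the model,
# optimizer, and scheduler objects are assumed to be constructed elsewhere. An https:// URL also
# works because torch.hub.load_state_dict_from_url is used when no local file exists.
def resume_example_sketch(model, optimizer, lr_scheduler):
    ckpt_file_path = './checkpoints/student_ckpt.pt'  # hypothetical path
    best_value, train_config, train_args = \
        load_ckpt(ckpt_file_path, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, strict=True)
    return best_value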
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args):
    logger.info('Start training')
    train_config = config['train']
    lr_factor = args.world_size if distributed and args.adjust_lr else 1
    training_box = get_training_box(student_model, dataset_dict, train_config,
                                    device, device_ids, distributed, lr_factor) if teacher_model is None \
        else get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                                  device, device_ids, distributed, lr_factor)
    best_val_map = 0.0
    optimizer, lr_scheduler = training_box.optimizer, training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        # Resume optimizer/scheduler states and the best validation mAP from the checkpoint
        best_val_map, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    iou_types = args.iou_types
    val_iou_type = iou_types[0] if isinstance(iou_types, (list, tuple)) and len(iou_types) > 0 else 'bbox'
    student_model_without_ddp = student_model.module if module_util.check_if_wrapped(student_model) else student_model
    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        val_coco_evaluator = \
            evaluate(student_model, training_box.val_data_loader, iou_types, device, device_ids, distributed,
                     log_freq=log_freq, header='Validation:')
        # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ]
        val_map = val_coco_evaluator.coco_eval[val_iou_type].stats[0]
        if val_map > best_val_map and is_main_process():
            logger.info('Best mAP ({}): {:.4f} -> {:.4f}'.format(val_iou_type, best_val_map, val_map))
            logger.info('Updating ckpt at {}'.format(ckpt_file_path))
            best_val_map = val_map
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler, best_val_map, config, args, ckpt_file_path)
        training_box.post_process()

    if distributed:
        dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    training_box.clean_modules()
def distill(teacher_model, student_model, dataset_dict, device, device_ids, distributed, config, args):
    logger.info('Start distillation')
    train_config = config['train']
    lr_factor = args.world_size if distributed and args.adjust_lr else 1
    distillation_box = \
        get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                             device, device_ids, distributed, lr_factor)
    ckpt_file_path = config['models']['student_model']['ckpt']
    best_val_top1_accuracy = 0.0
    optimizer, lr_scheduler = distillation_box.optimizer, distillation_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_top1_accuracy, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    student_model_without_ddp = student_model.module if module_util.check_if_wrapped(student_model) else student_model
    start_time = time.time()
    for epoch in range(args.start_epoch, distillation_box.num_epochs):
        distillation_box.pre_process(epoch=epoch)
        distill_one_epoch(distillation_box, device, epoch, log_freq)
        val_top1_accuracy = evaluate(student_model, distillation_box.val_data_loader, device, device_ids,
                                     distributed, log_freq=log_freq, header='Validation:')
        if val_top1_accuracy > best_val_top1_accuracy and is_main_process():
            logger.info('Updating ckpt (Best top1 accuracy: '
                        '{:.4f} -> {:.4f})'.format(best_val_top1_accuracy, val_top1_accuracy))
            best_val_top1_accuracy = val_top1_accuracy
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler, best_val_top1_accuracy,
                      config, args, ckpt_file_path)
        distillation_box.post_process()

    if distributed:
        dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    distillation_box.clean_modules()
def load_tokenizer_and_model(model_config, task_name, prioritizes_ckpt=False):
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    num_labels = model_config['num_labels']
    config_config = model_config['config_params']
    config = AutoConfig.from_pretrained(**config_config, num_labels=num_labels, finetuning_task=task_name)
    tokenizer_config = model_config['tokenizer_params']
    tokenizer = AutoTokenizer.from_pretrained(**tokenizer_config)
    from_tf = model_config.get('from_tf', False)
    # Prefer a local fine-tuned checkpoint over the pretrained model name when requested and available
    model_name_or_path = model_config['ckpt'] \
        if prioritizes_ckpt and file_util.check_if_exists(model_config.get('ckpt', None)) \
        else model_config['model_name_or_path']
    model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, from_tf=from_tf, config=config)
    return tokenizer, model
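# Hedged example of a model_config dict matching the keys read by load_tokenizer_and_model above
# ('num_labels', 'config_params', 'tokenizer_params', 'from_tf', 'ckpt', 'model_name_or_path').
# The concrete values (model name, label count, checkpoint directory, task name) are illustrative
# assumptions, not the repository's actual configuration.
example_model_config = {
    'model_name_or_path': 'bert-base-uncased',
    'num_labels': 2,
    'config_params': {'pretrained_model_name_or_path': 'bert-base-uncased'},
    'tokenizer_params': {'pretrained_model_name_or_path': 'bert-base-uncased', 'use_fast': True},
    'from_tf': False,
    'ckpt': './checkpoints/bert_base_uncased_sst2/'
}
# tokenizer, model = load_tokenizer_and_model(example_model_config, task_name='sst2', prioritizes_ckpt=True)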
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device, device_ids, distributed, config, args):
    logger.info('Start training')
    train_config = config['train']
    lr_factor = args.world_size if distributed and args.adjust_lr else 1
    training_box = get_training_box(student_model, dataset_dict, train_config,
                                    device, device_ids, distributed, lr_factor) if teacher_model is None \
        else get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                                  device, device_ids, distributed, lr_factor)
    best_val_miou = 0.0
    optimizer, lr_scheduler = training_box.optimizer, training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        best_val_miou, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    student_model_without_ddp = student_model.module if module_util.check_if_wrapped(student_model) else student_model
    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        val_seg_evaluator = \
            evaluate(student_model, training_box.val_data_loader, device, device_ids, distributed,
                     num_classes=args.num_classes, log_freq=log_freq, header='Validation:')
        val_acc_global, val_acc, val_iou = val_seg_evaluator.compute()
        val_miou = val_iou.mean().item()
        if val_miou > best_val_miou and is_main_process():
            logger.info('Updating ckpt (Best mIoU: {:.4f} -> {:.4f})'.format(best_val_miou, val_miou))
            best_val_miou = val_miou
            save_ckpt(student_model_without_ddp, optimizer, lr_scheduler, best_val_miou,
                      config, args, ckpt_file_path)
        training_box.post_process()

    if distributed:
        dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    training_box.clean_modules()