Example #1
0
def load_ckpt(ckpt_file_path,
              model=None,
              optimizer=None,
              lr_scheduler=None,
              strict=True):
    """Load a checkpoint file and restore the given objects from it.

    The checkpoint is read with ``torch.load`` onto the CPU. For each of
    ``model``, ``optimizer`` and ``lr_scheduler`` that is not ``None``, the
    state stored under the key ``'model'``, ``'optimizer'`` or
    ``'lr_scheduler'`` is loaded via ``load_state_dict``; a missing key is
    only logged.

    Args:
        ckpt_file_path: Path of the checkpoint file.
        model: Object whose ``load_state_dict`` receives ``ckpt['model']``.
        optimizer: Object whose ``load_state_dict`` receives
            ``ckpt['optimizer']``.
        lr_scheduler: Object whose ``load_state_dict`` receives
            ``ckpt['lr_scheduler']``.
        strict: Forwarded as ``strict`` to ``model.load_state_dict``.

    Returns:
        A 3-tuple ``(best_value, config, args)`` read from the checkpoint,
        defaulting to ``(0.0, None, None)`` for absent entries, or
        ``(None, None, None)`` when the checkpoint file does not exist.
    """
    if not check_if_exists(ckpt_file_path):
        logger.info('ckpt file is not found at `{}`'.format(ckpt_file_path))
        # Same arity as the success path so callers can always unpack
        # three values, whether or not the file exists.
        return None, None, None

    ckpt = torch.load(ckpt_file_path, map_location='cpu')
    if model is not None:
        if 'model' in ckpt:
            logger.info('Loading model parameters')
            model.load_state_dict(ckpt['model'], strict=strict)
        else:
            logger.info('No model parameters found')
    if optimizer is not None:
        if 'optimizer' in ckpt:
            logger.info('Loading optimizer parameters')
            optimizer.load_state_dict(ckpt['optimizer'])
        else:
            logger.info('No optimizer parameters found')
    if lr_scheduler is not None:
        if 'lr_scheduler' in ckpt:
            logger.info('Loading scheduler parameters')
            lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
        else:
            logger.info('No scheduler parameters found')
    return (ckpt.get('best_value', 0.0),
            ckpt.get('config', None),
            ckpt.get('args', None))
def main(args):
    """Copy selected entries of a checkpoint into a new file.

    Reads the checkpoint at ``args.src`` and writes to ``args.dst``.
    When exactly one key is requested and ``args.use_dict`` is falsy, the
    value stored under that key is saved directly. Otherwise the entries
    found for ``args.keys`` are collected into a dict and saved, provided
    at least one key was found. Keys missing from the checkpoint are
    reported with ``print``.
    """
    src_path = args.src
    if not check_if_exists(src_path):
        print('ckpt file is not found at `{}`'.format(src_path))
        return

    src_ckpt = torch.load(src_path, map_location='cpu')
    dst_path = args.dst
    requested_keys = args.keys

    # Single-key mode: save the bare object rather than a one-entry dict.
    if len(requested_keys) == 1 and not args.use_dict:
        sole_key = requested_keys[0]
        if sole_key not in src_ckpt:
            print('Parameter key `{}` was not found'.format(sole_key))
            return
        save_obj(src_ckpt[sole_key], dst_path)
        return

    extracted = {}
    for name in requested_keys:
        if name not in src_ckpt:
            print('Parameter key `{}` was not found'.format(name))
            continue
        extracted[name] = src_ckpt[name]

    # Nothing is written when none of the requested keys were present.
    if extracted:
        save_obj(extracted, dst_path)
Example #3
0
    def __getitem__(self, index):
        """Return the parent's item with cache information added.

        The cache file path is built from ``self.cache_dir_path``,
        ``self.idx2subath_func(index)`` and ``self.ext``. If that file
        exists, its content is loaded with ``torch.load`` and stored under
        ``supp_dict['cached_data']``. The path itself is always stored
        under ``supp_dict['cache_file_path']``.
        """
        sample, target, supp_dict = super().__getitem__(index)
        cache_dir = self.cache_dir_path
        file_name = self.idx2subath_func(index) + self.ext
        cache_file_path = os.path.join(cache_dir, file_name)

        if file_util.check_if_exists(cache_file_path):
            supp_dict['cached_data'] = torch.load(cache_file_path)

        supp_dict['cache_file_path'] = cache_file_path
        return sample, target, supp_dict
Example #4
0
def load_ckpt(ckpt_file_path,
              model=None,
              optimizer=None,
              lr_scheduler=None,
              strict=True):
    """Load a checkpoint from a file or URL and restore the given objects.

    If ``ckpt_file_path`` exists locally it is read with ``torch.load``;
    if it is a string starting with ``http://`` or ``https://`` it is
    fetched with ``torch.hub.load_state_dict_from_url``. Both load onto
    the CPU.

    For each of ``model``, ``optimizer`` and ``lr_scheduler`` that is not
    ``None``, the state stored under ``'model'``, ``'optimizer'`` or
    ``'lr_scheduler'`` is loaded. If the key is absent and the other two
    objects are ``None``, the whole checkpoint is treated as that object's
    state dict. Otherwise the absence is only logged.

    Args:
        ckpt_file_path: Local path or http(s) URL of the checkpoint.
        model: Object whose ``load_state_dict`` receives the model state.
        optimizer: Object whose ``load_state_dict`` receives the optimizer
            state.
        lr_scheduler: Object whose ``load_state_dict`` receives the
            scheduler state.
        strict: Forwarded unchanged as ``strict`` to
            ``model.load_state_dict``.

    Returns:
        A 3-tuple ``(best_value, config, args)`` read from the checkpoint,
        defaulting to ``(0.0, None, None)`` for absent entries, or
        ``(None, None, None)`` when the checkpoint cannot be located.
    """
    # NOTE(security): both loaders below unpickle the checkpoint content;
    # only pass paths/URLs that point to trusted checkpoints.
    if check_if_exists(ckpt_file_path):
        ckpt = torch.load(ckpt_file_path, map_location='cpu')
    elif isinstance(ckpt_file_path, str) and \
            ckpt_file_path.startswith(('https://', 'http://')):
        ckpt = torch.hub.load_state_dict_from_url(ckpt_file_path,
                                                  map_location='cpu',
                                                  progress=True)
    else:
        logger.info('ckpt file is not found at `{}`'.format(ckpt_file_path))
        return None, None, None

    if model is not None:
        if 'model' in ckpt:
            logger.info('Loading model parameters')
            # The former `if strict is None` / `else` branches made the
            # same call, so `strict` is simply forwarded as given.
            model.load_state_dict(ckpt['model'], strict=strict)
        elif optimizer is None and lr_scheduler is None:
            # No 'model' entry and nothing else requested: treat the whole
            # checkpoint as the model's state dict.
            logger.info('Loading model parameters only')
            model.load_state_dict(ckpt, strict=strict)
        else:
            logger.info('No model parameters found')

    if optimizer is not None:
        if 'optimizer' in ckpt:
            logger.info('Loading optimizer parameters')
            optimizer.load_state_dict(ckpt['optimizer'])
        elif model is None and lr_scheduler is None:
            logger.info('Loading optimizer parameters only')
            optimizer.load_state_dict(ckpt)
        else:
            logger.info('No optimizer parameters found')

    if lr_scheduler is not None:
        if 'lr_scheduler' in ckpt:
            logger.info('Loading scheduler parameters')
            lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
        elif model is None and optimizer is None:
            logger.info('Loading scheduler parameters only')
            lr_scheduler.load_state_dict(ckpt)
        else:
            logger.info('No scheduler parameters found')
    return (ckpt.get('best_value', 0.0),
            ckpt.get('config', None),
            ckpt.get('args', None))
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device,
          device_ids, distributed, config, args):
    """Train the student model and checkpoint it whenever validation mAP improves.

    A training box is built with ``get_training_box`` when ``teacher_model``
    is ``None``, otherwise with ``get_distillation_box``. If a checkpoint
    already exists at ``ckpt_file_path``, the optimizer and LR scheduler
    states and the best value so far are restored from it. Each epoch from
    ``args.start_epoch`` up to the box's ``num_epochs`` runs one training
    pass followed by validation. The checkpoint is rewritten when the
    validation score exceeds the best one and ``is_main_process()`` is
    true.
    """
    logger.info('Start training')
    train_config = config['train']

    # Learning-rate factor is the world size only for distributed runs
    # that asked for LR adjustment.
    lr_factor = 1
    if distributed and args.adjust_lr:
        lr_factor = args.world_size

    if teacher_model is None:
        training_box = get_training_box(student_model, dataset_dict, train_config,
                                        device, device_ids, distributed, lr_factor)
    else:
        training_box = get_distillation_box(teacher_model, student_model, dataset_dict,
                                            train_config, device, device_ids,
                                            distributed, lr_factor)

    best_val_map = 0.0
    optimizer = training_box.optimizer
    lr_scheduler = training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        # Only optimizer/scheduler states are restored here; no model is passed.
        best_val_map, _, _ = load_ckpt(ckpt_file_path,
                                       optimizer=optimizer,
                                       lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    iou_types = args.iou_types
    # The first requested IoU type drives model selection; default is 'bbox'.
    if isinstance(iou_types, (list, tuple)) and len(iou_types) > 0:
        val_iou_type = iou_types[0]
    else:
        val_iou_type = 'bbox'

    if module_util.check_if_wrapped(student_model):
        unwrapped_student = student_model.module
    else:
        unwrapped_student = student_model

    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        coco_evaluator = evaluate(student_model, training_box.val_data_loader,
                                  iou_types, device, device_ids, distributed,
                                  log_freq=log_freq, header='Validation:')
        # Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ]
        val_map = coco_evaluator.coco_eval[val_iou_type].stats[0]
        improved = val_map > best_val_map and is_main_process()
        if improved:
            logger.info('Best mAP ({}): {:.4f} -> {:.4f}'.format(
                val_iou_type, best_val_map, val_map))
            logger.info('Updating ckpt at {}'.format(ckpt_file_path))
            best_val_map = val_map
            save_ckpt(unwrapped_student, optimizer, lr_scheduler,
                      best_val_map, config, args, ckpt_file_path)
        training_box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = int(time.time() - start_time)
    elapsed_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(elapsed_str))
    training_box.clean_modules()
def distill(teacher_model, student_model, dataset_dict, device, device_ids,
            distributed, config, args):
    """Distill the student model and checkpoint it whenever top-1 accuracy improves.

    A distillation box is built with ``get_distillation_box``. The
    checkpoint path is read from ``config['models']['student_model']['ckpt']``.
    If a checkpoint already exists there, the optimizer and LR scheduler
    states and the best value so far are restored from it. Each epoch from
    ``args.start_epoch`` up to the box's ``num_epochs`` runs one
    distillation pass followed by validation. The checkpoint is rewritten
    when the validation result exceeds the best one and
    ``is_main_process()`` is true.
    """
    logger.info('Start distillation')
    train_config = config['train']

    # Learning-rate factor is the world size only for distributed runs
    # that asked for LR adjustment.
    lr_factor = 1
    if distributed and args.adjust_lr:
        lr_factor = args.world_size

    distillation_box = get_distillation_box(teacher_model, student_model,
                                            dataset_dict, train_config, device,
                                            device_ids, distributed, lr_factor)
    ckpt_file_path = config['models']['student_model']['ckpt']
    best_val_top1_accuracy = 0.0
    optimizer = distillation_box.optimizer
    lr_scheduler = distillation_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        # Only optimizer/scheduler states are restored here; no model is passed.
        best_val_top1_accuracy, _, _ = load_ckpt(ckpt_file_path,
                                                 optimizer=optimizer,
                                                 lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    if module_util.check_if_wrapped(student_model):
        unwrapped_student = student_model.module
    else:
        unwrapped_student = student_model

    start_time = time.time()
    for epoch in range(args.start_epoch, distillation_box.num_epochs):
        distillation_box.pre_process(epoch=epoch)
        distill_one_epoch(distillation_box, device, epoch, log_freq)
        val_top1_accuracy = evaluate(student_model,
                                     distillation_box.val_data_loader,
                                     device, device_ids, distributed,
                                     log_freq=log_freq, header='Validation:')
        improved = val_top1_accuracy > best_val_top1_accuracy and is_main_process()
        if improved:
            logger.info(
                'Updating ckpt (Best top1 accuracy: {:.4f} -> {:.4f})'.format(
                    best_val_top1_accuracy, val_top1_accuracy))
            best_val_top1_accuracy = val_top1_accuracy
            save_ckpt(unwrapped_student, optimizer, lr_scheduler,
                      best_val_top1_accuracy, config, args, ckpt_file_path)
        distillation_box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = int(time.time() - start_time)
    elapsed_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(elapsed_str))
    distillation_box.clean_modules()
Example #7
0
def load_tokenizer_and_model(model_config, task_name, prioritizes_ckpt=False):
    """Build the tokenizer and sequence-classification model described by ``model_config``.

    Args:
        model_config: Mapping read for ``'num_labels'``, ``'config_params'``,
            ``'tokenizer_params'``, ``'model_name_or_path'`` and, optionally,
            ``'from_tf'`` and ``'ckpt'``.
        task_name: Passed to ``AutoConfig.from_pretrained`` as
            ``finetuning_task``.
        prioritizes_ckpt: When true and ``model_config['ckpt']`` exists, the
            model is loaded from that checkpoint instead of
            ``model_config['model_name_or_path']``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    num_labels = model_config['num_labels']
    config_kwargs = model_config['config_params']
    config = AutoConfig.from_pretrained(**config_kwargs,
                                        num_labels=num_labels,
                                        finetuning_task=task_name)

    tokenizer_kwargs = model_config['tokenizer_params']
    tokenizer = AutoTokenizer.from_pretrained(**tokenizer_kwargs)

    from_tf = model_config.get('from_tf', False)
    # Use the checkpoint only when asked to and only if it is actually there.
    if prioritizes_ckpt and file_util.check_if_exists(model_config.get('ckpt', None)):
        model_source = model_config['ckpt']
    else:
        model_source = model_config['model_name_or_path']

    model = AutoModelForSequenceClassification.from_pretrained(
        model_source, from_tf=from_tf, config=config)
    return tokenizer, model
def train(teacher_model, student_model, dataset_dict, ckpt_file_path, device,
          device_ids, distributed, config, args):
    """Train the student model and checkpoint it whenever validation mIoU improves.

    A training box is built with ``get_training_box`` when ``teacher_model``
    is ``None``, otherwise with ``get_distillation_box``. If a checkpoint
    already exists at ``ckpt_file_path``, the optimizer and LR scheduler
    states and the best value so far are restored from it. Each epoch from
    ``args.start_epoch`` up to the box's ``num_epochs`` runs one training
    pass followed by validation. The checkpoint is rewritten when the mean
    of the IoU values returned by the evaluator exceeds the best one and
    ``is_main_process()`` is true.
    """
    logger.info('Start training')
    train_config = config['train']

    # Learning-rate factor is the world size only for distributed runs
    # that asked for LR adjustment.
    lr_factor = 1
    if distributed and args.adjust_lr:
        lr_factor = args.world_size

    if teacher_model is None:
        training_box = get_training_box(student_model, dataset_dict, train_config,
                                        device, device_ids, distributed, lr_factor)
    else:
        training_box = get_distillation_box(teacher_model, student_model, dataset_dict,
                                            train_config, device, device_ids,
                                            distributed, lr_factor)

    best_val_miou = 0.0
    optimizer = training_box.optimizer
    lr_scheduler = training_box.lr_scheduler
    if file_util.check_if_exists(ckpt_file_path):
        # Only optimizer/scheduler states are restored here; no model is passed.
        best_val_miou, _, _ = load_ckpt(ckpt_file_path,
                                        optimizer=optimizer,
                                        lr_scheduler=lr_scheduler)

    log_freq = train_config['log_freq']
    if module_util.check_if_wrapped(student_model):
        unwrapped_student = student_model.module
    else:
        unwrapped_student = student_model

    start_time = time.time()
    for epoch in range(args.start_epoch, training_box.num_epochs):
        training_box.pre_process(epoch=epoch)
        train_one_epoch(training_box, device, epoch, log_freq)
        seg_evaluator = evaluate(student_model, training_box.val_data_loader,
                                 device, device_ids, distributed,
                                 num_classes=args.num_classes,
                                 log_freq=log_freq, header='Validation:')

        # Only the third value returned by compute() is used for selection.
        _acc_global, _acc, iou_values = seg_evaluator.compute()
        val_miou = iou_values.mean().item()
        improved = val_miou > best_val_miou and is_main_process()
        if improved:
            logger.info('Updating ckpt (Best mIoU: {:.4f} -> {:.4f})'.format(
                best_val_miou, val_miou))
            best_val_miou = val_miou
            save_ckpt(unwrapped_student, optimizer, lr_scheduler,
                      best_val_miou, config, args, ckpt_file_path)
        training_box.post_process()

    if distributed:
        dist.barrier()

    elapsed_seconds = int(time.time() - start_time)
    elapsed_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(elapsed_str))
    training_box.clean_modules()