Example #1
def main(args):
    config = yaml_util.load_yaml_file(args.config)
    if args.json is not None:
        main_util.overwrite_config(config, args.json)

    distributed, device_ids = main_util.init_distributed_mode(
        args.world_size, args.dist_url)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    teacher_model = get_model(config['teacher_model'], device)
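    # Freeze the teacher so only the student receives gradient updates.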
    module_util.freeze_module_params(teacher_model)
    student_model_config = config['student_model']
    student_model = get_model(student_model_config, device)
    freeze_modules(student_model, student_model_config)
    print('Updatable parameters: {}'.format(
        module_util.get_updatable_param_names(student_model)))
    distill_backbone_only = student_model_config['distill_backbone_only']
    train_config = config['train']
    train_sampler, train_data_loader, val_data_loader, test_data_loader = \
        data_util.get_coco_data_loaders(config['dataset'], train_config['batch_size'], distributed)
    if distributed:
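        # The frozen teacher only needs forward passes, so DataParallel is
        # enough; the trainable student needs DistributedDataParallel.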
        teacher_model = DataParallel(teacher_model, device_ids=device_ids)
        student_model = DistributedDataParallel(student_model,
                                                device_ids=device_ids)

    if args.distill:
        distill(teacher_model, student_model, train_sampler, train_data_loader,
                val_data_loader, device, distributed, distill_backbone_only,
                config, args)
        load_ckpt(
            config['student_model']['ckpt'],
            model=student_model.module if isinstance(
                student_model, DistributedDataParallel) else student_model)
    evaluate(teacher_model, student_model, test_data_loader, device,
             args.skip_teacher_eval, args.transform_bottleneck)
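
All of these snippets lean on module_util.freeze_module_params and its counterparts. The real module_util is not shown on this page, so the following is only a minimal sketch of what those helpers presumably do (signatures assumed):

import torch.nn as nn

def freeze_module_params(module: nn.Module):
    # Disable gradients so the optimizer never updates these weights.
    for param in module.parameters():
        param.requires_grad = False

def unfreeze_module_params(module: nn.Module):
    # Re-enable gradients for selectively trainable submodules.
    for param in module.parameters():
        param.requires_grad = True

def get_updatable_param_names(module: nn.Module):
    # Names of parameters that still receive gradient updates.
    return [name for name, param in module.named_parameters()
            if param.requires_grad]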
Example #2
def train(train_loader, valid_loader, input_shape, config, device, distributed,
          device_ids):
    ae_without_ddp, ae_type = ae_util.get_autoencoder(config, device)
    head_model = ae_util.get_head_model(config, input_shape, device)
    module_util.freeze_module_params(head_model)
    ckpt_file_path = config['autoencoder']['ckpt']
    start_epoch, best_valid_acc = resume_from_ckpt(ckpt_file_path,
                                                   ae_without_ddp)
    if best_valid_acc is None:
        best_valid_acc = 0.0

    train_config = config['train']
    criterion_config = train_config['criterion']
    criterion = func_util.get_loss(criterion_config['type'],
                                   criterion_config['params'])
    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(ae_without_ddp, optim_config['type'],
                                        optim_config['params'])
    scheduler_config = train_config['scheduler']
    scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'],
                                        scheduler_config['params'])
    interval = train_config['interval']
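    # A non-positive interval presumably means: report ~20 times per epoch.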
    if interval <= 0:
        num_batches = len(train_loader)
        interval = num_batches // 20 if num_batches >= 20 else 1

    autoencoder = ae_without_ddp
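    # Wrap for multi-GPU: DDP across processes when distributed, plain
    # DataParallel on a single CUDA node otherwise.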
    if distributed:
        autoencoder = DistributedDataParallel(ae_without_ddp,
                                              device_ids=device_ids)
        head_model = DataParallel(head_model, device_ids=device_ids)
    elif device.type == 'cuda':
        autoencoder = DataParallel(ae_without_ddp)
        head_model = DataParallel(head_model)

    end_epoch = start_epoch + train_config['epoch']
    start_time = time.time()
    for epoch in range(start_epoch, end_epoch):
        if distributed:
            train_loader.sampler.set_epoch(epoch)

        train_epoch(autoencoder, head_model, train_loader, optimizer,
                    criterion, epoch, device, interval)
        valid_acc = validate(ae_without_ddp, valid_loader, config, device,
                             distributed, device_ids)
        if valid_acc > best_valid_acc and main_util.is_main_process():
            print(
                'Updating ckpt (Best top1 accuracy: {:.4f} -> {:.4f})'.format(
                    best_valid_acc, valid_acc))
            best_valid_acc = valid_acc
            save_ckpt(ae_without_ddp, epoch, best_valid_acc, ckpt_file_path,
                      ae_type)
        scheduler.step()

    if distributed:
        # Synchronize workers before reporting the total time; calling
        # dist.barrier() without an initialized process group would fail.
        dist.barrier()
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    del head_model
Example #3
def freeze_modules(student_model, student_model_config):
    if 'frozen_modules' in student_model_config:
        for student_path in student_model_config['frozen_modules']:
            student_module = module_util.get_module(student_model, student_path)
            module_util.freeze_module_params(student_module)

    elif 'unfrozen_modules' in student_model_config:
        module_util.freeze_module_params(student_model)
        for student_path in student_model_config['unfrozen_modules']:
            student_module = module_util.get_module(student_model, student_path)
            module_util.unfreeze_module_params(student_module)
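
For reference, the student_model config consumed by freeze_modules might carry either key; a hypothetical section (module paths invented for illustration):

student_model_config = {
    # Freeze only the listed submodules...
    'frozen_modules': ['backbone.body.layer1', 'backbone.body.layer2'],
    # ...or instead freeze everything and re-enable only these:
    # 'unfrozen_modules': ['roi_heads.box_predictor'],
}
freeze_modules(student_model, student_model_config)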
Example #4
def freeze_modules(student_model, student_model_config, reset_unfrozen=False):
    if 'frozen_modules' in student_model_config:
        for student_path in student_model_config['frozen_modules']:
            student_module = module_util.get_module(student_model,
                                                    student_path)
            module_util.freeze_module_params(student_module)

    elif 'unfrozen_modules' in student_model_config:
        module_util.freeze_module_params(student_model)
        for student_path in student_model_config['unfrozen_modules']:
            student_module = module_util.get_module(student_model,
                                                    student_path)
            module_util.unfreeze_module_params(student_module)
            if reset_unfrozen:
                print("Reinitializing module: {}".format(student_path))
                init_weights(student_module)
Example #5
    def __init__(self, backbone, return_layers, in_channels_list, out_channels,
                 ext_config):
        super().__init__()
        if ext_config.get('backbone_frozen', False):
            module_util.freeze_module_params(backbone)

        self.body = ExtIntermediateLayerGetter(backbone,
                                               return_layers=return_layers,
                                               ext_config=ext_config)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=LastLevelMaxPool(),
        )
        self.out_channels = out_channels
        self.split = False
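
This constructor mirrors torchvision's BackboneWithFPN. A hypothetical instantiation (the enclosing class name ExtBackboneWithFPN and any ext_config keys beyond 'backbone_frozen' are assumptions):

import torchvision

backbone = torchvision.models.resnet50(pretrained=True)
return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
in_channels_list = [256, 512, 1024, 2048]  # ResNet-50 stage output widths
fpn_backbone = ExtBackboneWithFPN(backbone, return_layers, in_channels_list,
                                  out_channels=256,
                                  ext_config={'backbone_frozen': True})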
Example #6
def main(args):
    distributed, device_ids = main_util.init_distributed_mode(
        args.world_size, args.dist_url)
    config = yaml_util.load_yaml_file(args.config)
    if args.json is not None:
        main_util.overwrite_config(config, args.json)

    device = torch.device(args.device)
    print(args)
    print('Loading data')
    train_config = config['train']
    train_sampler, train_data_loader, val_data_loader, test_data_loader =\
        data_util.get_coco_data_loaders(config['dataset'], train_config['batch_size'], distributed)

    print('Creating model')
    model_config = config['model']
    model = get_model(model_config, device, strict=False)
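    # Freeze the whole detector, then unfreeze only its ext classifier head.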
    module_util.freeze_module_params(model)
    ext_classifier = model.get_ext_classifier()
    module_util.unfreeze_module_params(ext_classifier)
    print('Updatable parameters: {}'.format(
        module_util.get_updatable_param_names(model)))
    model.train_ext()
    if distributed:
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=device_ids)

    if args.train:
        print('Start training')
        start_time = time.time()
        ckpt_file_path = model_config['backbone']['ext_config']['ckpt']
        train(model, ext_classifier, train_sampler, train_data_loader,
              val_data_loader, device, distributed, config, args,
              ckpt_file_path)
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))
        load_ckpt(ckpt_file_path, model=ext_classifier)
    evaluate(model,
             test_data_loader,
             device=device,
             min_recall=args.min_recall,
             split_name='Test')
Example #7
def experiment(config, device, args):
    # Load Model
    teacher_model = get_model(config['teacher_model'], device)
    module_util.freeze_module_params(teacher_model)
    student_model_config = config['student_model']
    student_model = get_model(student_model_config, device, strict=True, require_weights=True)
    set_bottleneck_transformer(student_model, True)
    if args.dry_run:
        return None

    # Load dataset
    train_config = config['train']
    train_sampler, train_data_loader, val_data_loader, test_data_loader = \
        data_util.get_coco_data_loaders(config['dataset'], train_config['batch_size'], distributed=False)

    # Prepare results dataframe
    header = ["Setting", "IoU50:95", "IoU50"]
    df = pd.DataFrame(columns=header)
    res = evaluate(teacher_model, test_data_loader, device)
    # DataFrame.append was removed in pandas 2.0; build rows with pd.concat.
    df = pd.concat([df, pd.DataFrame([{"Setting": "Teacher",
                                       "IoU50:95": res[0], "IoU50": res[1]}])],
                   ignore_index=True)

    post_bn = False
    if 'post_batch_norm' in config['train']:
        post_bn = config['train']['post_batch_norm']

    width_mult_list = [1.0]
    if "slimmable" in student_model_config['backbone']['params']:
        width_mult_list = student_model_config['backbone']['params']['width_mult_list']
    for width in width_mult_list:
        print('\n[Student model@width={}]'.format(width))
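        # Switch the slimmable backbone to this width before evaluating.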
        set_width(student_model, width)
        if post_bn:
            ComputeBN(student_model, train_data_loader)
        res = evaluate(student_model, test_data_loader, device)
        df = pd.concat([df, pd.DataFrame([{"Setting": str(width),
                                           "IoU50:95": res[0], "IoU50": res[1]}])],
                       ignore_index=True)

    return df
Example #8
def main(args):
    config = yaml_util.load_yaml_file(args.config)
    if args.json is not None:
        main_util.overwrite_config(config, args.json)

    distributed, device_ids = main_util.init_distributed_mode(args.world_size, args.dist_url)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    teacher_model = get_model(config['teacher_model'], device)
    module_util.freeze_module_params(teacher_model)
    student_model_config = config['student_model']
    student_model = get_model(student_model_config, device)
    freeze_modules(student_model, student_model_config)
    ckpt_file_path = config['student_model']['ckpt']
    train_config = config['train']
    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(student_model, optim_config['type'], optim_config['params'])
    scheduler_config = train_config['scheduler']
    lr_scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'], scheduler_config['params'])
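    # If a checkpoint already exists, restore optimizer/scheduler state from it and re-save the checkpoint.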
    if file_util.check_if_exists(ckpt_file_path):
        best_val_map, _, _ = load_ckpt(ckpt_file_path, optimizer=optimizer, lr_scheduler=lr_scheduler)
        save_ckpt(student_model, optimizer, lr_scheduler, best_val_map, config, args, ckpt_file_path)
Example #9
def distill(train_loader, valid_loader, input_shape, aux_weight, config,
            device, distributed, device_ids):
    teacher_model_config = config['teacher_model']
    teacher_model, teacher_model_type = mimic_util.get_teacher_model(
        teacher_model_config, input_shape, device)
    module_util.freeze_module_params(teacher_model)
    student_model_config = config['student_model']
    student_model = mimic_util.get_student_model(teacher_model_type,
                                                 student_model_config,
                                                 config['dataset']['name'])
    student_model = student_model.to(device)
    start_epoch, best_valid_acc = mimic_util.resume_from_ckpt(
        student_model_config['ckpt'], student_model, is_student=True)
    if best_valid_acc is None:
        best_valid_acc = 0.0

    train_config = config['train']
    criterion_config = train_config['criterion']
    criterion = func_util.get_loss(criterion_config['type'],
                                   criterion_config['params'])
    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(student_model, optim_config['type'],
                                        optim_config['params'])
    scheduler_config = train_config['scheduler']
    scheduler = func_util.get_scheduler(optimizer, scheduler_config['type'],
                                        scheduler_config['params'])
    interval = train_config['interval']
    if interval <= 0:
        num_batches = len(train_loader)
        interval = num_batches // 20 if num_batches >= 20 else 1

    # Keep an unwrapped handle so checkpoints store plain (non-DDP) weights.
    student_model_without_ddp = student_model
    if distributed:
        teacher_model = DataParallel(teacher_model, device_ids=device_ids)
        student_model = DistributedDataParallel(student_model,
                                                device_ids=device_ids)
        student_model_without_ddp = student_model.module

    ckpt_file_path = student_model_config['ckpt']
    end_epoch = start_epoch + train_config['epoch']
    start_time = time.time()
    for epoch in range(start_epoch, end_epoch):
        if distributed:
            train_loader.sampler.set_epoch(epoch)

        distill_one_epoch(student_model, teacher_model, train_loader,
                          optimizer, criterion, epoch, device, interval,
                          aux_weight)
        valid_acc = validate(student_model, valid_loader, config, device,
                             distributed, device_ids)
        if valid_acc > best_valid_acc and main_util.is_main_process():
            print(
                'Updating ckpt (Best top1 accuracy: {:.4f} -> {:.4f})'.format(
                    best_valid_acc, valid_acc))
            best_valid_acc = valid_acc
            save_ckpt(student_model_without_ddp, epoch, best_valid_acc,
                      ckpt_file_path, teacher_model_type)
        scheduler.step()

    if distributed:
        # Synchronize workers before reporting the total time; calling
        # dist.barrier() without an initialized process group would fail.
        dist.barrier()
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    del teacher_model
    del student_model
Example #10
def freeze_modules(student_model, student_model_config):
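    # Unlike Example #3, this variant assumes 'frozen_modules' is present;
    # a missing key raises KeyError.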
    for student_path in student_model_config['frozen_modules']:
        student_module = module_util.get_module(student_model, student_path)
        module_util.freeze_module_params(student_module)
Example #11
def main(args):
    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError(
                'Apex currently only supports Python 3. Aborting.')
        if amp is None:
            raise RuntimeError(
                'Failed to import apex. Please install apex from https://www.github.com/nvidia/apex '
                'to enable mixed-precision training.')

    distributed, device_ids = main_util.init_distributed_mode(
        args.world_size, args.dist_url)
    print(args)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    config = yaml_util.load_yaml_file(args.config)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    dataset_config = config['dataset']
    input_shape = config['input_shape']
    train_config = config['train']
    test_config = config['test']
    train_data_loader, val_data_loader, test_data_loader =\
        dataset_util.get_data_loaders(dataset_config, batch_size=train_config['batch_size'],
                                      rough_size=train_config['rough_size'], reshape_size=input_shape[1:3],
                                      jpeg_quality=-1, test_batch_size=test_config['batch_size'],
                                      distributed=distributed)

    teacher_model_config = config['teacher_model']
    teacher_model, teacher_model_type = mimic_util.get_org_model(
        teacher_model_config, device)
    module_util.freeze_module_params(teacher_model)

    student_model = mimic_util.get_mimic_model_easily(config, device)
    student_model_config = config['mimic_model']

    optim_config = train_config['optimizer']
    optimizer = func_util.get_optimizer(student_model, optim_config['type'],
                                        optim_config['params'])
    use_apex = args.apex
    if use_apex:
        student_model, optimizer = amp.initialize(
            student_model, optimizer, opt_level=args.apex_opt_level)

    if distributed:
        teacher_model = DataParallel(teacher_model, device_ids=device_ids)
        student_model = DistributedDataParallel(student_model,
                                                device_ids=device_ids)

    start_epoch = args.start_epoch
    if not args.test_only:
        distill(teacher_model, student_model, train_data_loader,
                val_data_loader, device, distributed, start_epoch, config,
                args)
        student_model_without_ddp =\
            student_model.module if isinstance(student_model, DistributedDataParallel) else student_model
        load_ckpt(student_model_config['ckpt'],
                  model=student_model_without_ddp,
                  strict=True)

    if not args.student_only:
        evaluate(teacher_model,
                 test_data_loader,
                 device,
                 title='[Teacher: {}]'.format(teacher_model_type))
    evaluate(student_model,
             test_data_loader,
             device,
             title='[Student: {}]'.format(student_model_config['type']))