Example #1
# Assumed imports for this snippet (from the efficientdet-pytorch / effdet package and torch):
# import torch
# from torch.hub import load_state_dict_from_url
# from effdet import EfficientDet, get_efficientdet_config
def get_net():
    # config = get_efficientdet_config('tf_efficientdet_d4')
    config = get_efficientdet_config('tf_efficientdet_d0')
    print(config)
    net = EfficientDet(config, pretrained_backbone=False)

    # print the total parameter count
    param_count = sum(p.numel() for p in net.parameters())
    print(param_count)
    # checkpoint = torch.load('/content/drive/MyDrive/efficientdet-pytorch/efficientdet_d4-5b370b7a.pth')
    # checkpoint = torch.load('/content/drive/MyDrive/efficientdet-pytorch/efficientdet_d0-d92fd44f.pth')
    # net.load_state_dict(checkpoint)
    state_dict = load_state_dict_from_url(config.url,
                                          progress=False,
                                          map_location='cpu')
    net.load_state_dict(state_dict, strict=True)

    # replace the classification head for a single-class (1 category) dataset
    net.reset_head(num_classes=1)

    return net
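A quick sanity check of the network returned by get_net() — a minimal sketch, not part of the original example; the 512x512 input size for tf_efficientdet_d0 and the per-level (class, box) output structure are assumptions about the effdet model rather than something the snippet guarantees:

import torch

net = get_net()
net.eval()
with torch.no_grad():
    # the raw EfficientDet forward is assumed to return per-FPN-level
    # class and box tensors (no box decoding or NMS at this stage)
    class_out, box_out = net(torch.randn(1, 3, 512, 512))
print(len(class_out), class_out[0].shape)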
Example #2
# The helpers used below (get_efficientdet_config, EfficientDet, DetBenchEval,
# CocoDetection, create_loader, load_checkpoint, AverageMeter, ...) are assumed
# to be imported from the effdet / timm packages as in the original repository.
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    config = get_efficientdet_config(args.model)
    model = EfficientDet(config)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint)

    param_count = sum([m.numel() for m in model.parameters()])
    print('Model %s created, param count: %d' % (args.model, param_count))

    bench = DetBenchEval(model, config)
    bench = bench.cuda()
    if has_amp:
        print('Using AMP mixed precision.')
        bench = amp.initialize(bench, opt_level='O1')
    else:
        print('AMP not installed, running network in FP32.')

    if args.num_gpu > 1:
        bench = torch.nn.DataParallel(bench,
                                      device_ids=list(range(args.num_gpu)))

    if 'test' in args.anno:
        annotation_path = os.path.join(args.data, 'annotations',
                                       f'image_info_{args.anno}.json')
        image_dir = 'test2017'
    else:
        annotation_path = os.path.join(args.data, 'annotations',
                                       f'instances_{args.anno}.json')
        image_dir = args.anno
    dataset = CocoDetection(os.path.join(args.data, image_dir),
                            annotation_path)

    loader = create_loader(dataset,
                           input_size=config.image_size,
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=args.interpolation,
                           num_workers=args.workers)

    img_ids = []
    results = []
    model.eval()
    batch_time = AverageMeter()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            output = bench(input, target['scale'])
            output = output.cpu()
            sample_ids = target['img_id'].cpu()
            for index, sample in enumerate(output):
                image_id = int(sample_ids[index])
                for det in sample:
                    score = float(det[4])
                    if score < 0.001:  # stop when below this threshold; scores are sorted in descending order
                        break
                    coco_det = dict(image_id=image_id,
                                    bbox=det[0:4].tolist(),
                                    score=score,
                                    category_id=int(det[5]))
                    img_ids.append(image_id)
                    results.append(coco_det)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                print(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    .format(
                        i,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                    ))

    with open(args.results, 'w') as f:
        json.dump(results, f, indent=4)
    if 'test' not in args.anno:
        coco_results = dataset.coco.loadRes(args.results)
        coco_eval = COCOeval(dataset.coco, coco_results, 'bbox')
        coco_eval.params.imgIds = img_ids  # score only ids we've used
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

    return results
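# Note (illustrative, not from the original code): each entry appended to
# `results` above follows the COCO detection-results format that COCOeval
# expects, e.g.
#   {"image_id": 42, "bbox": [x, y, width, height], "score": 0.91, "category_id": 1}
# where the bbox values are whatever layout the detection bench emits
# (COCO scoring assumes [x, y, width, height]).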
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on 1 GPU.')

    #torch.manual_seed(args.seed + args.rank)

    # create model
    config = get_efficientdet_config(args.model)
    config.redundant_bias = args.redundant_bias  # redundant conv + BN bias layers (True to match official models)
    model = EfficientDet(config)
    if args.initial_checkpoint:
        load_checkpoint(model, args.initial_checkpoint)
    # Replace the final classification conv for a 5-class dataset:
    # output channels = 9 anchors per location x 5 classes.
    config.num_classes = 5
    model.class_net.predict.conv_pw = create_conv2d(config.fpn_channels,
                                                    9 * 5,
                                                    1,
                                                    padding=config.pad_type,
                                                    bias=True)
    variance_scaling(model.class_net.predict.conv_pw.weight)
    # Focal-loss prior initialization: bias = -log((1 - pi) / pi) with pi = 0.01,
    # so every class starts at roughly 1% predicted probability.
    model.class_net.predict.conv_pw.bias.data.fill_(-math.log((1 - 0.01) / 0.01))
    model = DetBenchTrain(model, config)

    model.cuda()
    print(model.model.class_net.predict.conv_pw)
    # FIXME create model factory, pretrained zoo
    # model = create_model(
    #     args.model,
    #     pretrained=args.pretrained,
    #     num_classes=args.num_classes,
    #     drop_rate=args.drop,
    #     drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
    #     drop_path_rate=args.drop_path,
    #     drop_block_rate=args.drop_block,
    #     global_pool=args.gp,
    #     bn_tf=args.bn_tf,
    #     bn_momentum=args.bn_momentum,
    #     bn_eps=args.bn_eps,
    #     checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    optimizer = create_optimizer(args, model)
    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(_unwrap_bench(model),
                                                       args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '')
        #resume=args.resume)  # FIXME bit of a mess with bench
        if args.resume:
            load_checkpoint(_unwrap_bench(model_ema),
                            args.resume,
                            use_ema=True)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.'
                    )
            except Exception:
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    train_anno_set = 'train_small'
    train_annotation_path = os.path.join(args.data, 'annotations_small',
                                         'train_annotations.json')
    train_image_dir = train_anno_set
    dataset_train = CocoDetection(os.path.join(args.data, train_image_dir),
                                  train_annotation_path)

    # FIXME cutmix/mixup worth investigating?
    # collate_fn = None
    # if args.prefetcher and args.mixup > 0:
    #     collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=config.image_size,
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        #re_prob=args.reprob,  # FIXME add back various augmentations
        #re_mode=args.remode,
        #re_count=args.recount,
        #re_split=args.resplit,
        #color_jitter=args.color_jitter,
        #auto_augment=args.aa,
        interpolation=args.train_interpolation,
        #mean=data_config['mean'],
        #std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        #collate_fn=collate_fn,
        pin_mem=args.pin_mem,
    )

    #train_anno_set = 'valid_small'
    #train_annotation_path = os.path.join(args.data, 'annotations_small', f'valid_annotations.json')
    #train_image_dir = train_anno_set
    # NOTE: the eval loader reuses the training split here; the commented-out
    # lines above switch it to a separate validation split.
    dataset_eval = CocoDetection(os.path.join(args.data, train_image_dir),
                                 train_annotation_path)

    loader_eval = create_loader(
        dataset_eval,
        input_size=config.image_size,
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=args.interpolation,
        #mean=data_config['mean'],
        #std=data_config['std'],
        num_workers=args.workers,
        #distributed=args.distributed,
        pin_mem=args.pin_mem,
    )

    evaluator = COCOEvaluator(dataset_eval.coco, distributed=args.distributed)

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join(
            [datetime.now().strftime("%Y%m%d-%H%M%S"), args.model])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = eval_metric == 'loss'
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=decreasing)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info(
                        "Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, args, evaluator)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast',
                                                         'reduce'):
                    distribute_bn(model_ema, args.world_size,
                                  args.dist_bn == 'reduce')

                ema_eval_metrics = validate(model_ema.ema,
                                            loader_eval,
                                            args,
                                            log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if saver is not None:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    _unwrap_bench(model),
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=_unwrap_bench(model_ema),
                    metric=save_metric,
                    use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
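
Entry points like the one above are typically wired to the command line with the standard guard below; a minimal sketch, assuming the script is executed directly (this guard is not shown in the excerpt):

if __name__ == '__main__':
    main()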