Example No. 1
@torch.no_grad()  # validation needs no gradients; saves memory and time
def validate_multi_class(config, data_loader, model):
    criterion = SoftTargetCrossEntropy()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    # calculate loss and F1 score
    tp_sum, fp_sum, fn_sum = 0, 0, 0

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        
        tp, fp, fn = tfpn(output, target)        
        # acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # acc1 = reduce_tensor(acc1)
        # acc5 = reduce_tensor(acc5)
        # loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        tp_sum += tp
        fp_sum += fp
        fn_sum += fn

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'precision {tp_sum/(tp_sum+fp_sum):.3f}\t'
                f'recall {tp_sum/(tp_sum+fn_sum):.3f}\t'
                f'Mem {memory_used:.0f}MB')
    # use the accumulated counts, not the last batch's, for the final metrics
    precision = tp_sum / (tp_sum + fp_sum)
    recall = tp_sum / (tp_sum + fn_sum)
    f1_score = 2 * precision * recall / (precision + recall)
    logger.info(f' * precision {precision:.3f} recall {recall:.3f} f1 {f1_score:.3f}')
    return f1_score, precision, recall
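
This example leans on two helpers that are not shown. `AverageMeter` is the standard running-average tracker used throughout these snippets; `tfpn` presumably counts true positives, false positives, and false negatives per batch. A minimal sketch, assuming multi-label logits and multi-hot targets (the `tfpn` signature and thresholding are guesses, not the original code):

import torch

class AverageMeter:
    """Track the latest value and the running average of a metric."""
    def __init__(self):
        self.val = self.avg = self.sum = 0.0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def tfpn(output, target, threshold=0.5):
    # hypothetical helper: TP/FP/FN counts for multi-label predictions
    pred = output.sigmoid() > threshold
    truth = target > threshold
    tp = (pred & truth).sum().item()
    fp = (pred & ~truth).sum().item()
    fn = (~pred & truth).sum().item()
    return tp, fp, fn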
Example No. 2
def validate(model, loader, criterion, log_freq=50):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            target = target.cuda()
            input = input.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    results = OrderedDict(top1=round(top1.avg, 4),
                          top1_err=round(100 - top1.avg, 4),
                          top5=round(top5.avg, 4),
                          top5_err=round(100 - top5.avg, 4))

    logging.info(' * Acc@1 {:.1f} ({:.3f}) Acc@5 {:.1f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
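
The `accuracy` helper used here (and in most examples below) is the conventional top-k routine from timm/pytorch-examples. A sketch of the usual implementation, returning percentages as tensors (hence the `.item()` calls above):

import torch

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy of `output` logits against integer `target` labels, in percent."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()                                   # (maxk, batch)
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size for k in topk]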
Example No. 3
def validate(args):
    setup_default_logging()

    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    args.pretrained = args.pretrained or not args.checkpoint  # might as well try to validate something
    args.prefetcher = not args.no_prefetcher

    # create model
    with set_layer_config(scriptable=args.torchscript):
        bench = create_model(
            args.model,
            bench_task='predict',
            num_classes=args.num_classes,
            pretrained=args.pretrained,
            redundant_bias=args.redundant_bias,
            soft_nms=args.soft_nms,
            checkpoint_path=args.checkpoint,
            checkpoint_ema=args.use_ema,
        )
    model_config = bench.config

    param_count = sum([m.numel() for m in bench.parameters()])
    print('Model %s created, param count: %d' % (args.model, param_count))

    bench = bench.cuda()

    amp_autocast = suppress
    if args.apex_amp:
        bench = amp.initialize(bench, opt_level='O1')
        print('Using NVIDIA APEX AMP. Validating in mixed precision.')
    elif args.native_amp:
        amp_autocast = torch.cuda.amp.autocast
        print('Using native Torch AMP. Validating in mixed precision.')
    else:
        print('AMP not enabled. Validating in float32.')

    if args.num_gpu > 1:
        bench = torch.nn.DataParallel(bench,
                                      device_ids=list(range(args.num_gpu)))

    dataset = create_dataset(args.dataset, args.root, args.split)
    input_config = resolve_input_config(args, model_config)
    loader = create_loader(dataset,
                           input_size=input_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=input_config['interpolation'],
                           fill_color=input_config['fill_color'],
                           mean=input_config['mean'],
                           std=input_config['std'],
                           num_workers=args.workers,
                           pin_mem=args.pin_mem)

    evaluator = create_evaluator(args.dataset, dataset, pred_yxyx=False)
    bench.eval()
    batch_time = AverageMeter()
    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            with amp_autocast():
                output = bench(input, img_info=target)
            evaluator.add_predictions(output, target)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0 or i == last_idx:
                print(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    .format(i,
                            len(loader),
                            batch_time=batch_time,
                            rate_avg=input.size(0) / batch_time.avg))

    mean_ap = 0.
    if dataset.parser.has_labels:
        mean_ap = evaluator.evaluate()
    else:
        evaluator.save(args.results)

    return mean_ap
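
A detail worth noting in this and the following examples: `amp_autocast = suppress` makes `contextlib.suppress` (called with no arguments) act as a do-nothing context manager, so the forward pass can always be written `with amp_autocast():` whether or not AMP is active. In isolation:

from contextlib import suppress
import torch

use_amp = torch.cuda.is_available()
amp_autocast = torch.cuda.amp.autocast if use_amp else suppress

with amp_autocast():
    pass  # forward pass goes here; a no-op context when AMP is disabled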
Example No. 4
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         global_pool=args.gp,
                         scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(
            model, 'num_classes'
        ), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(root=args.data,
                             name=args.dataset,
                             split=args.split,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           pin_memory=args.pin_mem,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) +
                            data_config['input_size']).cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(
            k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          top5=round(top5a, 4),
                          top5_err=round(100 - top5a, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          crop_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
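
The `valid_labels` filter above turns a set of allowed class ids into a boolean mask over the class dimension, then indexes the logits with it. The same operation in isolation:

import torch

num_classes = 5
allowed = {1, 3}                                    # ids read from the valid-labels file
mask = [i in allowed for i in range(num_classes)]   # [False, True, False, True, False]
logits = torch.randn(2, num_classes)
filtered = logits[:, mask]                          # shape (2, 2): columns 1 and 3 only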
Example No. 5
def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch,
                    mixup_fn, lr_scheduler):
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        outputs = model(samples)

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            loss = criterion(outputs, targets)
            loss = loss / config.TRAIN.ACCUMULATION_STEPS
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
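
`get_grad_norm` is not defined in the snippet; a reasonable assumption (matching the Swin Transformer utility this appears to come from) is the global L2 norm over all parameter gradients, i.e. the same quantity `clip_grad_norm_` returns:

import torch

def get_grad_norm(parameters, norm_type=2.0):
    # global gradient norm across all parameters that have gradients
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    return torch.norm(torch.stack([torch.norm(g, norm_type) for g in grads]), norm_type)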
Example No. 6
def validate(args):
    _logger.info(f'\n\n ---------------EVALUATION {args.eps}------------------------------- \n\n')
    _logger.info("Argument parser collected the following arguments:")
    for arg in vars(args):
        _logger.info(f"    {arg}:{getattr(args, arg)}")
    _logger.info("\n")

    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
        else:
            _logger.warning("Neither APEX or Native Torch AMP is available.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast
        _logger.info('Validating in mixed precision with native PyTorch AMP.')
    elif args.apex_amp:
        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    else:
        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])        
    _logger.info(
        f'Model {args.model} created, param count: {param_count} ({(float(param_count)/(10.0**6)):.1f} M)'
    )

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(
        root=args.data_dir, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_fgm_ae = AverageMeter()
    top5_fgm_ae = AverageMeter()
    top1_pgd_ae = AverageMeter()
    top5_pgd_ae = AverageMeter()

    model.eval()
    # no torch.no_grad() here: the adversarial attacks below need input gradients
    # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
    input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
    if args.channels_last:
        input = input.contiguous(memory_format=torch.channels_last)
    model(input)
    end = time.time()
    for batch_idx, (input, target) in enumerate(loader):
        if args.no_prefetcher:
            target = target.cuda()
            input = input.cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        # compute output
        with amp_autocast():
            output = model(input)

        if valid_labels is not None:
            output = output[:, valid_labels]
        loss = criterion(output, target)

        if real_labels is not None:
            real_labels.add_result(output)

        # Generate adversarial examples for the current inputs
        input_fgm_ae = fast_gradient_method(
            model_fn=model,
            x=input,
            eps=args.eps,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        input_pgd_ae = projected_gradient_descent(
            model_fn=model,
            x=input, 
            eps=args.eps, 
            eps_iter=0.01, 
            nb_iter=40, 
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        # Predict with Adversarial Examples
        with torch.no_grad():
            with amp_autocast():
                output_fgm_ae = model(input_fgm_ae)
                output_pgd_ae = model(input_pgd_ae)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(acc1.item(), input.size(0))
        top5.update(acc5.item(), input.size(0))

        acc1_fgm_ae, acc5_fgm_ae = accuracy(output_fgm_ae.detach(), target, topk=(1, 5))
        acc1_pgd_ae, acc5_pgd_ae = accuracy(output_pgd_ae.detach(), target, topk=(1, 5))
        top1_fgm_ae.update(acc1_fgm_ae.item(), input.size(0))
        top5_fgm_ae.update(acc5_fgm_ae.item(), input.size(0))
        top1_pgd_ae.update(acc1_pgd_ae.item(), input.size(0))
        top5_pgd_ae.update(acc5_pgd_ae.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % args.log_freq == 0:
            _logger.info(
                'Test: [{0:>4d}/{1}]  '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time,
                    rate_avg=input.size(0) / batch_time.avg,
                    loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        raise NotImplementedError  # TODO: not yet adapted for the adversarial-examples mode
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
        top1a_fgm_ae, top5a_fgm_ae = top1_fgm_ae.avg, top5_fgm_ae.avg
        top1a_pgd_ae, top5a_pgd_ae = top1_pgd_ae.avg, top5_pgd_ae.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        top1_fgm_ae=round(top1a_fgm_ae, 4),
        top5_fgm_ae=round(top5a_fgm_ae, 4),
        top1_pgd_ae=round(top1a_pgd_ae, 4),
        top5_pgd_ae=round(top5a_pgd_ae, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * [Regular] Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    _logger.info(' * [FGM Adversarial Attack] Acc@1 {:.3f}  Acc@5 {:.3f} '.format(
       results['top1_fgm_ae'], results['top5_fgm_ae']))
    _logger.info(' * [PGD Adversarial Attack] Acc@1 {:.3f}  Acc@5 {:.3f} '.format(
       results['top1_pgd_ae'], results['top5_pgd_ae']))

    return results
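
`fast_gradient_method` and `projected_gradient_descent` above come from the CleverHans PyTorch API. For orientation, the heart of the untargeted L∞ FGSM attack is a single signed-gradient step; a minimal sketch (not the CleverHans code itself, which also handles targeted attacks, other norms, and clipping):

import torch
import torch.nn.functional as F

def fgsm_linf(model_fn, x, eps):
    """One signed-gradient ascent step on the loss, bounded by eps in L-infinity."""
    x = x.clone().detach().requires_grad_(True)
    logits = model_fn(x)
    y = logits.argmax(dim=1)              # untargeted: attack the model's own predictions
    loss = F.cross_entropy(logits, y)
    grad, = torch.autograd.grad(loss, x)
    return (x + eps * grad.sign()).detach()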
Example No. 7
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
#    amp_autocast = suppress  # do nothing
#   if args.amp:
#        if has_native_amp:
#            args.native_amp = True
#        elif has_apex:
#            args.apex_amp = True
#        else:
#            _logger.warning("Neither APEX or Native Torch AMP is available.")
#    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
#    if args.native_amp:
#        amp_autocast = torch.cuda.amp.autocast
#        _logger.info('Validating in mixed precision with native PyTorch AMP.')
#   elif args.apex_amp:
#        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
#    else:
#        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

#    model = model.cuda()
#    if args.apex_amp:
#        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

#    if args.num_gpu > 1:
#        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

#   criterion = nn.CrossEntropyLoss().cuda()
    criterion = nn.CrossEntropyLoss()
    dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    # added for post quantization calibration

    calib_dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)
        

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    #Also create loader for calibration dataset
    calib_loader = create_loader(
        calib_dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    print('Starting calibration of quantization observers for post-training quantization')
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.eval()

    #post training static quantization
    if args.quant_option == 'static':
        qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
        # prepare: insert observers into the (already deep-copied, eval-mode) model
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # calibrate 
        with torch.no_grad():
            # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
            input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])) 
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)
            model_prepared(input)  # warm up the prepared model
            end = time.time()
            for batch_idx, (input, target) in enumerate(calib_loader):

                if args.channels_last:
                    input = input.contiguous(memory_format=torch.channels_last)

                # calibration forward pass: the observers record activation ranges
                output = model_prepared(input)

                if valid_labels is not None:
                    output = output[:, valid_labels]
                loss = criterion(output, target)

                if real_labels is not None:
                    real_labels.add_result(output)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1.item(), input.size(0))
                top5.update(acc5.item(), input.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if batch_idx % args.log_freq == 0:
                    _logger.info(
                        'Test: [{0:>4d}/{1}]  '
                        'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                        'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                        'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                            batch_idx, len(calib_loader), batch_time=batch_time,
                            rate_avg=input.size(0) / batch_time.avg,
                            loss=losses, top1=top1, top5=top5))        
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)           
    #post training dynamic/weight only quantization    
    elif args.quant_option == 'dynamic':    
        qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
        # prepare
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # no calibration needed for dynamic/weight-only quantization
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)       
    else:
        _logger.warning("Invalid quantization option; expected 'static' or 'dynamic'.")
    #
    # fusion
    #
    model_to_quantize = copy.deepcopy(model)
    model_fused = quantize_fx.fuse_fx(model_to_quantize)   

    model = model_fused

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
#        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])) 
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
 #           if args.no_prefetcher:
 #               target = target.cuda()
 #               input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output (CPU path; autocast is not used for the fused model)
            output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
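
Example 7 interleaves the FX graph-mode quantization flow with evaluation code; the flow itself reduces to three calls: `prepare_fx` inserts observers, calibration forward passes record activation statistics, and `convert_fx` swaps in quantized modules. A compressed sketch using the same older `qconfig_dict` API as the example (newer PyTorch versions also require `example_inputs`):

import copy
import torch
from torch.quantization import get_default_qconfig, quantize_fx

def static_quantize(model_fp32, calib_batches):
    model = copy.deepcopy(model_fp32).eval()
    qconfig_dict = {"": get_default_qconfig("qnnpack")}
    prepared = quantize_fx.prepare_fx(model, qconfig_dict)  # insert observers
    with torch.no_grad():
        for x in calib_batches:
            prepared(x)                                     # record activation ranges
    return quantize_fx.convert_fx(prepared)                 # convert to int8 modules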
Example No. 8
def validate(args):
    setup_default_logging()

    def setthresh():
        key = args.checkpoint.split("/")[-1].split("_")[0]
        if key in getthresholds.keys():
            return getthresholds[key]
        return [args.threshold] * 4

    threshs = setthresh()
    print(threshs)
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    bench = create_model(
        args.model,
        bench_task='predict',
        pretrained=args.pretrained,
        redundant_bias=args.redundant_bias,
        checkpoint_path=args.checkpoint,
        checkpoint_ema=args.use_ema,
    )
    input_size = bench.config.image_size

    param_count = sum([m.numel() for m in bench.parameters()])
    print('Model %s created, param count: %d' % (args.model, param_count))

    bench = bench.cuda()
    if has_amp:
        print('Using AMP mixed precision.')
        bench = amp.initialize(bench, opt_level='O1')
    else:
        print('AMP not installed, running network in FP32.')

    if args.num_gpu > 1:
        bench = torch.nn.DataParallel(bench,
                                      device_ids=list(range(args.num_gpu)))

    if 'test' in args.anno:
        annotation_path = os.path.join(args.data, 'annotations',
                                       f'image_info_{args.anno}.json')
        image_dir = args.anno
    else:
        annotation_path = os.path.join(args.data, 'annotations',
                                       f'instances_{args.anno}.json')
        image_dir = args.anno
    print(os.path.join(args.data, image_dir), annotation_path)
    dataset = CocoDetection(os.path.join(args.data, image_dir),
                            annotation_path)

    loader = create_loader(dataset,
                           input_size=input_size,
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=args.interpolation,
                           fill_color=args.fill_color,
                           num_workers=args.workers,
                           pin_mem=args.pin_mem,
                           mean=args.mean,
                           std=args.std)
    if 'test' in args.anno:
        threshold = float(args.threshold)
    else:
        threshold = .001
    img_ids = []
    results = []
    bench.eval()
    batch_time = AverageMeter()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            output = bench(input, target['img_scale'], target['img_size'])
            output = output.cpu()
            # print(target['img_id'])
            sample_ids = target['img_id'].cpu()

            for index, sample in enumerate(output):
                image_id = int(sample_ids[index])

                for det in sample:
                    score = float(det[4])
                    if score < threshold:  # stop when below this threshold, scores in descending order
                        coco_det = dict(image_id=image_id, category_id=-1)
                        img_ids.append(image_id)
                        results.append(coco_det)
                        break
                    coco_det = dict(image_id=image_id,
                                    bbox=det[0:4].tolist(),
                                    score=score,
                                    category_id=int(det[5]),
                                    sizes=target['img_size'].tolist()[0])
                    img_ids.append(image_id)
                    results.append(coco_det)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                print(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    .format(
                        i,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                    ))

    if 'test' in args.anno:
        from itertools import groupby
        results.sort(key=lambda x: x['image_id'])

        f = open(
            str(args.model) + "-" + str(args.anno) + "-" + str(min(threshs)) +
            ".txt", "w+")
        num_lines = 0
        for k, v in tqdm(groupby(results, key=lambda x: x['image_id'])):
            num_lines += 1
            f.write(getimageNamefromid(k) + ",")
            for i in v:
                if 1 <= i['category_id'] <= 4 and \
                        i['score'] >= threshs[i['category_id'] - 1]:
                    x1, y1, w, h = (float(v) for v in i['bbox'])
                    f.write(f"{round(i['category_id'])} {round(x1)} {round(y1)} "
                            f"{round(x1 + w)} {round(y1 + h)} ")
            f.write('\n')
            # print(i['category_id']," ",i['bbox'][0]," ",i['bbox'][1]," ",i['bbox'][2]," ",i['bbox'][3]," ")
        print("generated lines:", xxx)
        f.close()

    if 'test' not in args.anno:
        array_of_dm = []
        array_of_gt = []


        for _, item in tqdm(dataset):
            # if item["img_id"] == "1000780" :
            # print(item)
            for i in range(len(item['cls'])):
                # print(str(item["img_id"]),)
                array_of_gt.append(
                    BoundingBox(imageName=str(item["img_id"]),
                                classId=item["cls"][i],
                                x=item["bbox"][i][1] * item['img_scale'],
                                y=item["bbox"][i][0] * item['img_scale'],
                                w=item["bbox"][i][3] * item['img_scale'],
                                h=item["bbox"][i][2] * item['img_scale'],
                                typeCoordinates=CoordinatesType.Absolute,
                                bbType=BBType.GroundTruth,
                                format=BBFormat.XYX2Y2,
                                imgSize=(item['img_size'][0],
                                         item['img_size'][1])))

        for item in tqdm(results):
            if item["category_id"] >= 0:
                array_of_dm.append(
                    BoundingBox(imageName=str(item["image_id"]),
                                classId=item["category_id"],
                                classConfidence=item["score"],
                                x=item['bbox'][0],
                                y=item['bbox'][1],
                                w=item['bbox'][2],
                                h=item['bbox'][3],
                                typeCoordinates=CoordinatesType.Absolute,
                                bbType=BBType.Detected,
                                format=BBFormat.XYWH,
                                imgSize=(item['sizes'][0], item['sizes'][1])))
        myBoundingBoxes = BoundingBoxes()
        # # # # Add all bounding boxes to the BoundingBoxes object:
        for box in (array_of_gt):
            myBoundingBoxes.addBoundingBox(box)
        for dm in array_of_dm:
            myBoundingBoxes.addBoundingBox(dm)

        evaluator = Evaluator()
        f1res = []
        f1resd0 = []
        f1resd10 = []
        f1resd20 = []
        f1resd40 = []
        for conf in tqdm(range(210, 600, 1)):
            metricsPerClass = evaluator.GetPascalVOCMetrics(
                myBoundingBoxes, IOUThreshold=0.5, ConfThreshold=conf / 1000.0)

            totalTP = 0
            totalp = 0
            totalFP = 0
            tp = []
            fp = []
            ta = []
            # print('-------')
            for mc in metricsPerClass:
                tp.append(mc['total TP'])
                fp.append(mc['total FP'])
                ta.append(mc['total positives'])

                totalFP = totalFP + mc['total FP']
                totalTP = totalTP + mc['total TP']
                totalp = totalp + (mc['total positives'])

            # print(totalTP," ",totalFP," ",totalp)
            if totalTP + totalFP == 0:
                p = -1
            else:
                p = totalTP / (totalTP + totalFP)
            if totalp == 0:
                r = -1
            else:
                r = totalTP / (totalp)
            f1_dict = dict(tp=totalTP,
                           fp=totalFP,
                           totalp=totalp,
                           conf=conf / 1000.0,
                           prec=p,
                           rec=r,
                           f1score=(2 * p * r) / (p + r))
            f1res.append(f1_dict)
            # per-class results for the 4 classes (same F1 formula applied per class)
            for cls_idx, res in enumerate((f1resd0, f1resd10, f1resd20, f1resd40)):
                prec_c = tp[cls_idx] / (tp[cls_idx] + fp[cls_idx])
                rec_c = tp[cls_idx] / ta[cls_idx]
                res.append(
                    dict(tp=tp[cls_idx],
                         fp=fp[cls_idx],
                         totalp=ta[cls_idx],
                         conf=conf / 1000.0,
                         prec=prec_c,
                         rec=rec_c,
                         f1score=(2 * prec_c * rec_c) / (prec_c + rec_c)))

        sortedf1 = sorted(f1res, key=lambda k: k['f1score'], reverse=True)

        f1resd0 = sorted(f1resd0, key=lambda k: k['f1score'], reverse=True)
        f1resd10 = sorted(f1resd10, key=lambda k: k['f1score'], reverse=True)
        f1resd20 = sorted(f1resd20, key=lambda k: k['f1score'], reverse=True)
        f1resd40 = sorted(f1resd40, key=lambda k: k['f1score'], reverse=True)

        print(sortedf1[0])
        print("\n\n")
        print(f1resd0[0])
        print(f1resd10[0])
        print(f1resd20[0])
        print(f1resd40[0])
        # coco_results = dataset.coco.loadRes(args.results)
        # coco_eval = COCOeval(dataset.coco, coco_results, 'bbox')
        # coco_eval.params.imgIds = img_ids  # score only ids we've used
        # coco_eval.evaluate()
        # coco_eval.accumulate()
        # coco_eval.summarize()
        # print(coco_eval.eval['params'])

    json.dump(results, open(args.results, 'w'), indent=4)

    return results
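
The closing sweep in Example 8 scans confidence cutoffs and keeps the one with the best F1 = 2PR/(P+R). The same selection logic, reduced to a helper over per-threshold count dicts (field names follow the example; the zero-division guards are an addition):

def best_by_f1(points):
    # points: dicts with keys tp, fp, totalp (as built in the sweep above)
    scored = []
    for d in points:
        p = d['tp'] / (d['tp'] + d['fp']) if (d['tp'] + d['fp']) else 0.0
        r = d['tp'] / d['totalp'] if d['totalp'] else 0.0
        f1 = 2 * p * r / (p + r) if (p + r) else 0.0
        scored.append({**d, 'prec': p, 'rec': r, 'f1score': f1})
    return max(scored, key=lambda d: d['f1score'])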
Example No. 9
def main():
    setup_default_logging()
    args = parser.parse_args()
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(args.model,
                         num_classes=args.num_classes,
                         in_chans=3,
                         pretrained=args.pretrained,
                         checkpoint_path=args.checkpoint)

    logging.info('Model %s created, param count: %d' %
                 (args.model, sum([m.numel() for m in model.parameters()])))

    config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(
                                          args.num_gpu))).cuda()
    else:
        model = model.cuda()

    loader = create_loader(
        Dataset(args.data),
        input_size=config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,
        interpolation=config['interpolation'],
        mean=config['mean'],
        std=config['std'],
        num_workers=args.workers,
        crop_pct=1.0 if test_time_pool else config['crop_pct'])

    model.eval()

    k = min(args.topk, args.num_classes)
    batch_time = AverageMeter()
    end = time.time()
    topk_ids = []
    with torch.no_grad():
        for batch_idx, (input, _) in enumerate(loader):
            input = input.cuda()
            labels = model(input)
            topk = labels.topk(k)[1]
            topk_ids.append(topk.cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                logging.info(
                    'Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'
                    .format(batch_idx, len(loader), batch_time=batch_time))

    topk_ids = np.concatenate(topk_ids, axis=0).squeeze()

    savebase = "classification_result/"
    os.makedirs(savebase, exist_ok=True)

    classfile = "labels.txt"
    classpath = os.path.join(os.getcwd(), classfile)
    classlist = {}
    with open(classpath) as f:
        for idx, line in enumerate(f):
            classlist[idx] = line.rstrip('\n')

    filenames = loader.dataset.filenames()
    for filepath, label in zip(filenames, topk_ids):
        filename = os.path.basename(filepath)
        prediction = classlist[label[0]]
        savedir = os.path.join(savebase, prediction)
        savepath = os.path.join(savedir, filename)
        os.makedirs(savedir, exist_ok=True)
        copyfile(filepath, savepath)
        print('{0} : {1}'.format(filename, prediction))
Example No. 10
    def validate(self, net, loader, loss_fn, amp_autocast=suppress, metric_name=None):
        losses_m = AverageMeter()
        top1_m = AverageMeter()
        top5_m = AverageMeter()
        rmse_m = AverageMeter()

        net.eval()

        with torch.no_grad():
            for batch_idx, (input, target) in enumerate(loader):
                if not self._misc_cfg.prefetcher:
                    input = input.to(self.ctx[0])
                    target = target.to(self.ctx[0])

                with amp_autocast():
                    output = net(input)
                    if self._problem_type == REGRESSION:
                        output = output.flatten()
                if isinstance(output, (tuple, list)):
                    output = output[0]

                if self._problem_type == REGRESSION:
                    if metric_name:
                        assert metric_name == 'rmse', f'{metric_name} metric not supported for regression.'
                    val_metric_score = rmse(output, target)
                else:
                    val_metric_score = accuracy(output, target, topk=(1, min(5, self.num_class)))

                # augmentation reduction
                reduce_factor = self._misc_cfg.tta
                if self._problem_type != REGRESSION and reduce_factor > 1:
                    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    target = target[0:target.size(0):reduce_factor]

                loss = loss_fn(output, target)
                reduced_loss = loss.data

                if self.found_gpu:
                    torch.cuda.synchronize()

                losses_m.update(reduced_loss.item(), input.size(0))
                if self._problem_type == REGRESSION:
                    rmse_score = val_metric_score
                    rmse_m.update(rmse_score.item(), output.size(0))
                else:
                    acc1, acc5 = val_metric_score
                    acc1 /= 100
                    acc5 /= 100
                    top1_m.update(acc1.item(), output.size(0))
                    top5_m.update(acc5.item(), output.size(0))

        if self._problem_type == REGRESSION:
            self._logger.info('[Epoch %d] validation: rmse=%f', self.epoch, rmse_m.avg)
            return {'loss': losses_m.avg, 'rmse': rmse_m.avg}
        else:
            self._logger.info('[Epoch %d] validation: top1=%f top5=%f', self.epoch, top1_m.avg, top5_m.avg)
            return {'loss': losses_m.avg, 'top1': top1_m.avg, 'top5': top5_m.avg}
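
The test-time-augmentation reduction in Example 10 relies on the loader emitting `tta` consecutive augmented copies of each image; `unfold(0, r, r)` groups every `r` consecutive logit rows and `.mean(dim=2)` averages them back to one prediction per sample. In isolation:

import torch

r = 2                                        # augmented copies per sample
logits = torch.randn(6, 4)                   # 3 samples x 2 augmentations, 4 classes
merged = logits.unfold(0, r, r).mean(dim=2)  # shape (3, 4): averaged per sample
targets = torch.tensor([0, 0, 1, 1, 2, 2])[0::r]  # one target per sample group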
Example No. 11
    def train_one_epoch(
            self, epoch, net, loader, optimizer, loss_fn,
            lr_scheduler=None, output_dir=None, amp_autocast=suppress,
            loss_scaler=None, model_ema=None, mixup_fn=None, time_limit=math.inf):
        start_tic = time.time()
        if self._augmentation_cfg.mixup_off_epoch and epoch >= self._augmentation_cfg.mixup_off_epoch:
            if self._misc_cfg.prefetcher and loader.mixup_enabled:
                loader.mixup_enabled = False
            elif mixup_fn is not None:
                mixup_fn.mixup_enabled = False

        second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        losses_m = AverageMeter()
        train_metric_score_m = AverageMeter()

        net.train()

        num_updates = epoch * len(loader)
        self._time_elapsed += time.time() - start_tic
        tic = time.time()
        last_tic = time.time()
        train_metric_name = 'accuracy'
        batch_idx = 0
        for batch_idx, (input, target) in enumerate(loader):
            b_tic = time.time()
            if self._time_elapsed > time_limit:
                return {'train_acc': train_metric_score_m.avg, 'train_loss': losses_m.avg, 'time_limit': True}
            if self._problem_type == REGRESSION:
                target = target.to(torch.float32)
            if not self._misc_cfg.prefetcher:
                # prefetcher would move data to cuda by default
                input, target = input.to(self.ctx[0]), target.to(self.ctx[0])
                if mixup_fn is not None:
                    input, target = mixup_fn(input, target)

            with amp_autocast():
                output = net(input)
                if self._problem_type == REGRESSION:
                    output = output.flatten()
                loss = loss_fn(output, target)
            if self._problem_type == REGRESSION:
                train_metric_name = 'rmse'
                train_metric_score = rmse(output, target)
            else:
                if output.shape == target.shape:
                    train_metric_name = 'rmse'
                    train_metric_score = rmse(output, target)
                else:
                    train_metric_score = accuracy(output, target)[0] / 100

            losses_m.update(loss.item(), input.size(0))
            train_metric_score_m.update(train_metric_score.item(), output.size(0))

            optimizer.zero_grad()
            if loss_scaler is not None:
                loss_scaler(
                    loss, optimizer,
                    clip_grad=self._optimizer_cfg.clip_grad, clip_mode=self._optimizer_cfg.clip_mode,
                    parameters=model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode),
                    create_graph=second_order)
            else:
                loss.backward(create_graph=second_order)
                if self._optimizer_cfg.clip_grad is not None:
                    dispatch_clip_grad(
                        model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode),
                        value=self._optimizer_cfg.clip_grad, mode=self._optimizer_cfg.clip_mode)
                optimizer.step()

            if model_ema is not None:
                model_ema.update(net)

            if self.found_gpu:
                torch.cuda.synchronize()

            num_updates += 1
            if (batch_idx+1) % self._misc_cfg.log_interval == 0:
                lrl = [param_group['lr'] for param_group in optimizer.param_groups]
                lr = sum(lrl) / len(lrl)
                self._logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f',
                                  epoch, batch_idx,
                                  self._train_cfg.batch_size*self._misc_cfg.log_interval/(time.time()-last_tic),
                                  train_metric_name, train_metric_score_m.avg, lr)
                last_tic = time.time()

                if self._misc_cfg.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

            if lr_scheduler is not None:
                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

            self._time_elapsed += time.time() - b_tic

        throughput = int(self._train_cfg.batch_size * (batch_idx + 1) / (time.time() - tic))
        self._logger.info('[Epoch %d] training: %s=%f', epoch, train_metric_name, train_metric_score_m.avg)
        self._logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f', epoch, throughput, time.time()-tic)

        end_time = time.time()
        if hasattr(optimizer, 'sync_lookahead'):
            optimizer.sync_lookahead()

        self._time_elapsed += time.time() - end_time

        return {train_metric_name: train_metric_score_m.avg, 'train_loss': losses_m.avg, 'time_limit': False}
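
The rmse helper called by train_one_epoch is defined elsewhere; a minimal sketch consistent with how it is used here (tensors in, scalar tensor out, so .item() works) might be:

import torch

def rmse(output, target):
    # root-mean-squared error over the whole batch
    return torch.sqrt(torch.mean((output.float() - target.float()) ** 2))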
Example No. 12
def validate(args):
    # might as well try to validate something
    args.pretrained = False
    args.prefetcher = True

    # create model
    model = eval(args.model)(config_path=args.config_path,
                             target_flops=args.target_flops,
                             num_classes=args.num_classes,
                             bn_momentum=args.bn_momentum,
                             activation=args.activation,
                             se=args.se)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, True)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    #model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(
                                          args.num_gpu))).cuda()
    else:
        model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    if args.lmdb:
        eval_dir = os.path.join(args.data, 'test_lmdb', 'test.lmdb')
        dataset_eval = ImageFolderLMDB(eval_dir, None, None)
    else:
        eval_dir = os.path.join(args.data, 'val')
        dataset_eval = Dataset(eval_dir)

    #crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    crop_pct = 1.0
    loader = create_loader(dataset_eval,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           is_training=False,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers)
    # crop_pct=crop_pct)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    results = OrderedDict(top1=round(top1.avg, 4),
                          top1_err=round(100 - top1.avg, 4),
                          top5=round(top5.avg, 4),
                          top5_err=round(100 - top5.avg, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          crop_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
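
accuracy(output, target, topk=(1, 5)) appears in nearly every example here. A sketch matching the timm-style semantics these snippets assume (accuracy in percent, one value per requested k) is:

import torch

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent for each k in topk."""
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk highest logits per sample, shape (maxk, batch)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size
            for k in topk]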
Example No. 13
def validate(args):
    # create model
    from tinynet import tinynet
    if args.model_name == 'tinynet_a':
        args.r = 0.86
        args.w = 1.0
        args.d = 1.2
        ckpt_path = './models/tinynet_a.pth'
    elif args.model_name == 'tinynet_b':
        args.r = 0.84
        args.w = 0.75
        args.d = 1.1
        ckpt_path = './models/tinynet_b.pth'
    elif args.model_name == 'tinynet_c':
        args.r = 0.825
        args.w = 0.54
        args.d = 0.85
        ckpt_path = './models/tinynet_c.pth'
    elif args.model_name == 'tinynet_d':
        args.r = 0.68
        args.w = 0.54
        args.d = 0.695
        ckpt_path = './models/tinynet_d.pth'
    elif args.model_name == 'tinynet_e':
        args.r = 0.475
        args.w = 0.51
        args.d = 0.60
        ckpt_path = './models/tinynet_e.pth'
    else:
        raise ValueError('Unsupported model name.')

    model = tinynet(
        r=args.r,
        w=args.w,
        d=args.d,
    )

    state_dict = torch.load(ckpt_path)
    model.load_state_dict(state_dict, strict=False)

    params = sum([param.numel() for param in model.parameters()])
    logging.info('Model %s created, #params: %d' % (args.model_name, params))

    data_config = resolve_data_config(vars(args), model=model)

    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    dataset = Dataset(args.data)
    data_loader = create_loader(dataset,
                                is_training=False,
                                input_size=data_config['input_size'],
                                batch_size=128,
                                use_prefetcher=False,
                                interpolation=data_config['interpolation'],
                                mean=data_config['mean'],
                                std=data_config['std'],
                                num_workers=4,
                                crop_pct=data_config['crop_pct'],
                                pin_memory=False)

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        for i, (input, target) in enumerate(data_loader):
            input = input.cuda()
            target = target.cuda()

            output = model(input)
            loss = criterion(output, target)

            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            if i % 100 == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})'
                    .format(i, len(data_loader), loss=losses))

    logging.info(' * Acc@1 {:.3f} Acc@5 {:.3f}'.format(top1.avg, top5.avg))
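
The elif chain above maps each model name to a (resolution, width, depth) multiplier triple plus a checkpoint path; a lookup table expresses the same configuration more compactly and fails loudly on unknown names. A sketch using the same values:

TINYNET_CFGS = {
    # name: (r, w, d, checkpoint path)
    'tinynet_a': (0.86, 1.0, 1.2, './models/tinynet_a.pth'),
    'tinynet_b': (0.84, 0.75, 1.1, './models/tinynet_b.pth'),
    'tinynet_c': (0.825, 0.54, 0.85, './models/tinynet_c.pth'),
    'tinynet_d': (0.68, 0.54, 0.695, './models/tinynet_d.pth'),
    'tinynet_e': (0.475, 0.51, 0.60, './models/tinynet_e.pth'),
}

try:
    args.r, args.w, args.d, ckpt_path = TINYNET_CFGS[args.model_name]
except KeyError:
    raise ValueError('Unsupported model name: %s' % args.model_name)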
Example No. 14
def train_epoch_slim_gate(epoch,
                          model,
                          loader,
                          optimizer,
                          loss_fn,
                          args,
                          lr_scheduler=None,
                          saver=None,
                          output_dir='',
                          use_amp=False,
                          model_ema=None,
                          optimizer_step=1):
    start_chn_idx = 0
    num_gate = 1

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    acc_m = AverageMeter()
    flops_m = AverageMeter()
    ce_loss_m = AverageMeter()
    flops_loss_m = AverageMeter()
    acc_gate_m_l = [AverageMeter() for i in range(num_gate)]
    gate_loss_m_l = [AverageMeter() for i in range(num_gate)]
    model.train()
    for n, m in model.named_modules():  # Freeze bn
        if isinstance(m, nn.BatchNorm2d) or isinstance(m, DSBatchNorm2d):
            m.eval()

    for n, m in model.named_modules():
        if len(getattr(m, 'in_channels_list', [])) > 4:
            m.in_channels_list = m.in_channels_list[start_chn_idx:4]
            m.in_channels_list_tensor = torch.from_numpy(
                np.array(m.in_channels_list)).float().cuda()
        if len(getattr(m, 'out_channels_list', [])) > 4:
            m.out_channels_list = m.out_channels_list[start_chn_idx:4]
            m.out_channels_list_tensor = torch.from_numpy(
                np.array(m.out_channels_list)).float().cuda()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    model.apply(lambda m: add_mac_hooks(m))
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()

        if last_batch or (batch_idx + 1) % optimizer_step == 0:
            optimizer.zero_grad()
        # generate online labels
        with torch.no_grad():
            set_model_mode(model, 'smallest')
            output = model(input)
            conf_s, correct_s = accuracy(output, target, no_reduce=True)
            gate_target = [
                torch.LongTensor([0])
                if correct_s[0][idx] else torch.LongTensor([3])
                for idx in range(correct_s[0].size(0))
            ]
            gate_target = torch.stack(gate_target).squeeze(-1).cuda()
        # =============
        set_model_mode(model, 'dynamic')
        output = model(input)

        if hasattr(model, 'module'):
            model_ = model.module
        else:
            model_ = model

        #  SGS Loss
        gate_loss = 0
        gate_num = 0
        gate_loss_l = []
        gate_acc_l = []
        for n, m in model_.named_modules():
            if isinstance(m, MultiHeadGate):
                if getattr(m, 'keep_gate', None) is not None:
                    gate_num += 1
                    g_loss = loss_fn(m.keep_gate, gate_target)
                    gate_loss += g_loss
                    gate_loss_l.append(g_loss)
                    gate_acc_l.append(
                        accuracy(m.keep_gate, gate_target, topk=(1, ))[0])

        gate_loss /= gate_num

        #  MAdds Loss
        running_flops = add_flops(model)
        if isinstance(running_flops, torch.Tensor):
            running_flops = running_flops.float().mean().cuda()
        else:
            running_flops = torch.FloatTensor([running_flops]).cuda()
        flops_loss = (running_flops / 1e9)**2

        #  Target Loss, back-propagate through gumbel-softmax
        ce_loss = loss_fn(output, target)

        loss = gate_loss + ce_loss + 0.5 * flops_loss
        # loss = ce_loss
        acc1 = accuracy(output, target, topk=(1, ))[0]

        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        if last_batch or (batch_idx + 1) % optimizer_step == 0:
            optimizer.step()

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))
            acc_m.update(acc1.item(), input.size(0))
            flops_m.update(running_flops.item(), input.size(0))
            ce_loss_m.update(ce_loss.item(), input.size(0))
            flops_loss_m.update(flops_loss.item(), input.size(0))
        else:
            reduced_loss = reduce_tensor(loss.data, args.world_size)
            reduced_acc = reduce_tensor(acc1, args.world_size)
            reduced_flops = reduce_tensor(running_flops, args.world_size)
            reduced_loss_flops = reduce_tensor(flops_loss, args.world_size)
            reduced_ce_loss = reduce_tensor(ce_loss, args.world_size)
            reduced_acc_gate_l = reduce_list_tensor(gate_acc_l,
                                                    args.world_size)
            reduced_gate_loss_l = reduce_list_tensor(gate_loss_l,
                                                     args.world_size)
            losses_m.update(reduced_loss.item(), input.size(0))
            acc_m.update(reduced_acc.item(), input.size(0))
            flops_m.update(reduced_flops.item(), input.size(0))
            flops_loss_m.update(reduced_loss_flops.item(), input.size(0))
            ce_loss_m.update(reduced_ce_loss.item(), input.size(0))
            for i in range(num_gate):
                acc_gate_m_l[i].update(reduced_acc_gate_l[i].item(),
                                       input.size(0))
                gate_loss_m_l[i].update(reduced_gate_loss_l[i].item(),
                                        input.size(0))
        batch_time_m.update(time.time() - end)
        if (last_batch or batch_idx % args.log_interval
                == 0) and args.local_rank == 0 and batch_idx != 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
            print_gate_stats(model)
            logging.info(
                'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
                'CELoss: {celoss.val:>9.6f} ({celoss.avg:>6.4f})  '
                'GateLoss: {gate_loss[0].val:>6.4f} ({gate_loss[0].avg:>6.4f})  '
                'FlopsLoss: {flopsloss.val:>9.6f} ({flopsloss.avg:>6.4f})  '
                'TrainAcc: {acc.val:>9.6f} ({acc.avg:>6.4f})  '
                'GateAcc: {acc_gate[0].val:>6.4f}({acc_gate[0].avg:>6.4f})  '
                'Flops: {flops.val:>6.0f} ({flops.avg:>6.0f})  '
                'LR: {lr:.3e}  '
                'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'DataTime: {data_time.val:.3f} ({data_time.avg:.3f})\n'.format(
                    epoch,
                    batch_idx,
                    last_idx,
                    100. * batch_idx / last_idx,
                    loss=losses_m,
                    flopsloss=flops_loss_m,
                    acc=acc_m,
                    flops=flops_m,
                    celoss=ce_loss_m,
                    batch_time=batch_time_m,
                    rate=input.size(0) * args.world_size / batch_time_m.val,
                    rate_avg=input.size(0) * args.world_size /
                    batch_time_m.avg,
                    lr=lr,
                    data_time=data_time_m,
                    gate_loss=gate_loss_m_l,
                    acc_gate=acc_gate_m_l))

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(model,
                                optimizer,
                                args,
                                epoch,
                                model_ema=model_ema,
                                use_amp=use_amp,
                                batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates,
                                     metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])
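
reduce_tensor, used in the distributed branch above, averages a metric across workers before it is logged. A common implementation consistent with its use here is:

import torch.distributed as dist

def reduce_tensor(tensor, world_size):
    # sum the tensor across all processes, then divide by the number of workers
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt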
Example No. 15
def validate_gate(model, loader, loss_fn, args, log_suffix=''):
    start_chn_idx = 0
    num_gate = 1

    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()
    flops_m = AverageMeter()
    acc_gate_m_l = [AverageMeter() for i in range(num_gate)]
    model.eval()

    for n, m in model.named_modules():
        if len(getattr(m, 'in_channels_list', [])) > 4:
            m.in_channels_list = m.in_channels_list[start_chn_idx:4]
            m.in_channels_list_tensor = torch.from_numpy(
                np.array(m.in_channels_list)).float().cuda()
        if len(getattr(m, 'out_channels_list', [])) > 4:
            m.out_channels_list = m.out_channels_list[start_chn_idx:4]
            m.out_channels_list_tensor = torch.from_numpy(
                np.array(m.out_channels_list)).float().cuda()

    end = time.time()
    last_idx = len(loader) - 1
    model.apply(lambda m: add_mac_hooks(m))
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()
        # generate online labels
        with torch.no_grad():
            set_model_mode(model, 'smallest')
            output = model(input)
            conf_s, correct_s = accuracy(output, target, no_reduce=True)
            gate_target = [
                torch.LongTensor([0])
                if correct_s[0][idx] else torch.LongTensor([3])
                for idx in range(correct_s[0].size(0))
            ]
            gate_target = torch.stack(gate_target).squeeze(-1).cuda()
        # =============
        set_model_mode(model, 'dynamic')
        output = model(input)

        if hasattr(model, 'module'):
            model_ = model.module
        else:
            model_ = model

        gate_acc_l = []
        for n, m in model_.named_modules():
            if isinstance(m, MultiHeadGate):
                if getattr(m, 'keep_gate', None) is not None:
                    gate_acc_l.append(
                        accuracy(m.keep_gate, gate_target, topk=(1, ))[0])

        running_flops = add_flops(model)
        if isinstance(running_flops, torch.Tensor):
            running_flops = running_flops.float().mean().cuda()
        else:
            running_flops = torch.FloatTensor([running_flops]).cuda()

        loss = loss_fn(output, target)
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))
            prec1_m.update(prec1.item(), input.size(0))
            prec5_m.update(prec5.item(), input.size(0))
            flops_m.update(running_flops.item(), input.size(0))
        else:
            reduced_loss = reduce_tensor(loss.data, args.world_size)
            reduced_prec1 = reduce_tensor(prec1, args.world_size)
            reduced_prec5 = reduce_tensor(prec5, args.world_size)
            reduced_flops = reduce_tensor(running_flops, args.world_size)
            reduced_acc_gate_l = reduce_list_tensor(gate_acc_l,
                                                    args.world_size)
            torch.cuda.synchronize()
            losses_m.update(reduced_loss.item(), input.size(0))
            prec1_m.update(reduced_prec1.item(), input.size(0))
            prec5_m.update(reduced_prec5.item(), input.size(0))
            flops_m.update(reduced_flops.item(), input.size(0))
            for i in range(num_gate):
                acc_gate_m_l[i].update(reduced_acc_gate_l[i].item(),
                                       input.size(0))
        batch_time_m.update(time.time() - end)
        if (last_batch or batch_idx % args.log_interval
                == 0) and args.local_rank == 0 and batch_idx != 0:
            print_gate_stats(model)
            log_name = 'Test' + log_suffix
            logging.info(
                '{}: [{:>4d}/{} ({:>3.0f}%)]  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {prec1.val:>9.6f} ({prec1.avg:>6.4f})  '
                'Acc@5: {prec5.val:>9.6f} ({prec5.avg:>6.4f})  '
                'GateAcc: {acc_gate[0].val:>6.4f}({acc_gate[0].avg:>6.4f})  '
                'Flops: {flops.val:>6.0f} ({flops.avg:>6.0f})  '
                'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'DataTime: {data_time.val:.3f} ({data_time.avg:.3f})\n'.format(
                    log_name,
                    batch_idx,
                    last_idx,
                    100. * batch_idx / last_idx,
                    loss=losses_m,
                    prec1=prec1_m,
                    prec5=prec5_m,
                    flops=flops_m,
                    batch_time=batch_time_m,
                    rate=input.size(0) * args.world_size / batch_time_m.val,
                    rate_avg=input.size(0) * args.world_size /
                    batch_time_m.avg,
                    data_time=data_time_m,
                    acc_gate=acc_gate_m_l))

        end = time.time()
        # end for
    metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg),
                           ('prec5', prec5_m.avg), ('flops', flops_m.avg)])

    return metrics
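
Both gate functions build gate_target with a per-sample Python loop: label 0 when the smallest sub-network already classifies the sample correctly, label 3 otherwise. Assuming correct_s[0] is a 0/1 tensor of shape (batch,), the same labels can be produced in a single vectorized call; a sketch:

import torch

def make_gate_target(correct_s):
    # 0 -> route to the smallest sub-network, 3 -> route to the largest
    correct = correct_s[0].bool()
    return torch.where(correct,
                       torch.zeros_like(correct, dtype=torch.long),
                       torch.full_like(correct, 3, dtype=torch.long)).cuda()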
Example No. 16
def validate(model,
             loader,
             loss_fn,
             args,
             amp_autocast=suppress,
             log_suffix=''):
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                input = input.cuda()
                target = target.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor,
                                       reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or
                                         batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
                    '{0}: [{1:>4d}/{2}]  '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name,
                        batch_idx,
                        last_idx,
                        batch_time=batch_time_m,
                        loss=losses_m,
                        top1=top1_m,
                        top5=top5_m))

    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg),
                           ('top5', top5_m.avg)])

    return metrics
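
The test-time-augmentation reduction above hinges on output.unfold(0, r, r).mean(dim=2): each group of r consecutive augmented copies of the same image is stacked along a trailing dimension and averaged, and the targets are subsampled to one per group. A small shape demonstration with hypothetical sizes:

import torch

logits = torch.randn(8, 1000)      # 4 images x 2 augmented crops each
grouped = logits.unfold(0, 2, 2)   # shape (4, 1000, 2): copies grouped last
reduced = grouped.mean(dim=2)      # shape (4, 1000): averaged predictions
targets = torch.arange(8)[0:8:2]   # keep one target per group of copies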
Example No. 17
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(model, args)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    loader = create_loader(
        Dataset(args.data, load_bytes=args.tf_preprocessing),
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            target = target.cuda()
            input = input.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
        top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
        param_count=round(param_count / 1e6, 2))

    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
Example No. 18
def validate(args):
    setup_default_logging()

    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    bench = create_model(args.model,
                         bench_task='predict',
                         pretrained=args.pretrained,
                         redundant_bias=args.redundant_bias,
                         checkpoint_path=args.checkpoint,
                         checkpoint_ema=args.use_ema)
    input_size = bench.config.image_size

    param_count = sum([m.numel() for m in bench.parameters()])
    print('Model %s created, param count: %d' % (args.model, param_count))

    bench = bench.cuda()
    if has_amp:
        print('Using AMP mixed precision.')
        bench = amp.initialize(bench, opt_level='O1')
    else:
        print('AMP not installed, running network in FP32.')

    if args.num_gpu > 1:
        bench = torch.nn.DataParallel(bench,
                                      device_ids=list(range(args.num_gpu)))

    if 'test' in args.anno:
        annotation_path = os.path.join(args.data, 'annotations',
                                       f'image_info_{args.anno}.json')
        image_dir = args.anno
    else:
        annotation_path = os.path.join(args.data, 'annotations',
                                       f'instances_{args.anno}.json')
        image_dir = args.anno
    dataset = CocoDetection(os.path.join(args.data, image_dir),
                            annotation_path)

    loader = create_loader(dataset,
                           input_size=input_size,
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=args.interpolation,
                           fill_color=args.fill_color,
                           num_workers=args.workers,
                           mean=args.mean,
                           std=args.std,
                           pin_mem=args.pin_mem)

    img_ids = []
    results = []
    bench.eval()

    # grab one batch and reuse it as a fixed input for latency measurement
    for i, (input, target) in enumerate(loader, 1):
        dummy_inp = input
        tisc = target['img_scale']
        tisz = target['img_size']
        break

    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
        enable_timing=True)
    # GPU warm-up so cudnn autotuning and lazy initialization don't skew timings
    for _ in range(10):
        _ = bench(dummy_inp, tisc, tisz)

    print("starting")
    batch_time = AverageMeter()
    with torch.no_grad():
        for rep in range(2000):
            starter.record()
            _ = bench(dummy_inp, tisc, tisz)
            ender.record()
            # wait for the GPU to finish before reading the event timer
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)  # milliseconds
            batch_time.update(curr_time)

            print(
                'Test: [{0:>4d}/{1}]  '
                'Time: {batch_time.val:.3f}ms ({batch_time.avg:.3f}ms, {rate_avg:>7.2f}/s)  '
                .format(
                    rep,
                    2000,
                    batch_time=batch_time,
                    rate_avg=dummy_inp.size(0) / (batch_time.avg / 1000.0),
                ))

    # json.dump(results, open(args.results, 'w'), indent=4)
    # if 'test' not in args.anno:
    #     coco_results = dataset.coco.loadRes(args.results)
    #     coco_eval = COCOeval(dataset.coco, coco_results, 'bbox')
    #     coco_eval.params.imgIds = img_ids  # score only ids we've used
    #     coco_eval.evaluate()
    #     coco_eval.accumulate()
    #     coco_eval.summarize()

    return results
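
The benchmark above follows the standard pattern for timing GPU work: warm-up iterations, a torch.cuda.Event pair around the call, and a synchronize before reading the timer (elapsed_time returns milliseconds). Condensed into a reusable sketch around a hypothetical time_gpu helper:

import torch

def time_gpu(fn, warmup=10, reps=100):
    """Average latency of fn() in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn()  # warm-up: cudnn autotuning, lazy allocation, caches
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    total_ms = 0.0
    with torch.no_grad():
        for _ in range(reps):
            starter.record()
            fn()
            ender.record()
            torch.cuda.synchronize()  # wait for the GPU before reading the timer
            total_ms += starter.elapsed_time(ender)
    return total_ms / reps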
Example No. 19
def validate(args):
    setup_default_logging()

    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    if args.no_redundant_bias is None:
        args.redundant_bias = None
    else:
        args.redundant_bias = not args.no_redundant_bias

    # create model
    bench = create_model(
        args.model,
        bench_task='predict',
        pretrained=args.pretrained,
        redundant_bias=args.redundant_bias,
        checkpoint_path=args.checkpoint,
        checkpoint_ema=args.use_ema,
    )
    input_size = bench.config.image_size

    param_count = sum([m.numel() for m in bench.parameters()])
    print('Model %s created, param count: %d' % (args.model, param_count))

    bench = bench.cuda()
    if has_amp:
        print('Using AMP mixed precision.')
        bench = amp.initialize(bench, opt_level='O1')
    else:
        print('AMP not installed, running network in FP32.')

    if args.num_gpu > 1:
        bench = torch.nn.DataParallel(bench,
                                      device_ids=list(range(args.num_gpu)))

    if 'test' in args.anno:
        annotation_path = os.path.join(args.data, 'annotations',
                                       f'image_info_{args.anno}.json')
        image_dir = 'test2017'
    else:
        annotation_path = os.path.join(args.data, 'annotations',
                                       f'instances_{args.anno}.json')
        image_dir = args.anno
    dataset = CocoDetection(os.path.join(args.data, image_dir),
                            annotation_path)

    loader = create_loader(dataset,
                           input_size=input_size,
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=args.interpolation,
                           fill_color=args.fill_color,
                           num_workers=args.workers,
                           pin_mem=args.pin_mem)

    img_ids = []
    results = []
    bench.eval()
    batch_time = AverageMeter()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            output = bench(input, target['img_scale'], target['img_size'])
            output = output.cpu()
            sample_ids = target['img_id'].cpu()
            for index, sample in enumerate(output):
                image_id = int(sample_ids[index])
                for det in sample:
                    score = float(det[4])
                    if score < .001:  # stop when below this threshold, scores in descending order
                        break
                    coco_det = dict(image_id=image_id,
                                    bbox=det[0:4].tolist(),
                                    score=score,
                                    category_id=int(det[5]))
                    img_ids.append(image_id)
                    results.append(coco_det)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                print(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    .format(
                        i,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                    ))

    json.dump(results, open(args.results, 'w'), indent=4)
    if 'test' not in args.anno:
        coco_results = dataset.coco.loadRes(args.results)
        coco_eval = COCOeval(dataset.coco, coco_results, 'bbox')
        coco_eval.params.imgIds = img_ids  # score only ids we've used
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

    return results
def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler,
                train_data_loader, summary_writer, conf, local_rank):
    losses = AverageMeter()
    c_losses = AverageMeter()
    d_losses = AverageMeter()
    dices = AverageMeter()
    iterator = tqdm(train_data_loader)
    model.train()
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(current_epoch)
    for i, sample in enumerate(iterator):
        imgs = sample["image"].cuda()
        masks = sample["mask"].cuda().float()
        # if torch.sum(masks) < 100:
        #     continue
        centers = sample["center"].cuda().float()

        seg_mask, center_mask = model(imgs)
        with torch.no_grad():
            pred = torch.sigmoid(seg_mask)
            d = dice_round(pred[:, 0:1, ...].cpu(),
                           masks[:, 0:1, ...].cpu(),
                           t=0.5).item()
        dices.update(d, imgs.size(0))

        mask_loss = loss_functions["mask_loss"](seg_mask, masks)
        # if torch.isnan(mask_loss):
        #     print("nan loss, skipping!!!")
        #     optimizer.zero_grad()
        #     continue
        center_loss = loss_functions["center_loss"](center_mask, centers)
        center_loss *= 50
        loss = mask_loss + center_loss

        loss /= 2
        if current_epoch == 0:
            loss /= 10
        losses.update(loss.item(), imgs.size(0))
        d_losses.update(mask_loss.item(), imgs.size(0))

        c_losses.update(center_loss.item(), imgs.size(0))
        iterator.set_postfix({
            "lr": float(scheduler.get_lr()[-1]),
            "epoch": current_epoch,
            "loss": losses.avg,
            "dice": dices.avg,
            "d_loss": d_losses.avg,
            "c_loss": c_losses.avg,
        })
        optimizer.zero_grad()
        if conf['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        optimizer.step()
        torch.cuda.synchronize()

        if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
            scheduler.step(i + current_epoch * len(train_data_loader))

    if local_rank == 0:
        for idx, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            summary_writer.add_scalar('group{}/lr'.format(idx),
                                      float(lr),
                                      global_step=current_epoch)
        summary_writer.add_scalar('train/loss',
                                  float(losses.avg),
                                  global_step=current_epoch)
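
dice_round, used for the training dice metric above, is defined outside this snippet; a sketch consistent with its call site (thresholded sigmoid probabilities against binary masks, returning a scalar tensor) would be:

import torch

def dice_round(preds, targets, t=0.5, eps=1e-7):
    # hard Dice: threshold the probabilities, then 2|A n B| / (|A| + |B|)
    preds = (preds > t).float()
    intersection = (preds * targets).sum()
    return (2.0 * intersection + eps) / (preds.sum() + targets.sum() + eps)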
def main():
    setup_default_logging()
    args = parser.parse_args()
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(args.model,
                         num_classes=args.num_classes,
                         in_chans=3,
                         pretrained=args.pretrained,
                         checkpoint_path=args.checkpoint)

    _logger.info('Model %s created, param count: %d' %
                 (args.model, sum([m.numel() for m in model.parameters()])))

    config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = (
        model,
        False) if args.no_test_pool else apply_test_time_pool(model, config)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(
                                          args.num_gpu))).cuda()
    else:
        model = model.cuda()

    loader = create_loader(
        ImageDataset(args.data),
        input_size=config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,
        interpolation=config['interpolation'],
        mean=config['mean'],
        std=config['std'],
        num_workers=args.workers,
        crop_pct=1.0 if test_time_pool else config['crop_pct'])

    model.eval()

    k = min(args.topk, args.num_classes)
    batch_time = AverageMeter()
    end = time.time()
    topk_ids = []
    with torch.no_grad():
        for batch_idx, (input, _) in enumerate(loader):
            input = input.cuda()
            labels = model(input)
            topk = labels.topk(k)[1]
            topk_ids.append(topk.cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'
                    .format(batch_idx, len(loader), batch_time=batch_time))

    topk_ids = np.concatenate(topk_ids, axis=0)

    with open(os.path.join(args.output_dir, 'topk_ids.csv'), 'w') as out_file:
        filenames = loader.dataset.filenames(basename=True)
        for filename, label in zip(filenames, topk_ids):
            # write however many top-k ids were requested instead of assuming k == 5
            out_file.write('{0},{1}\n'.format(
                filename, ','.join(str(v) for v in label)))
Example No. 22
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         global_pool=args.gp,
                         scriptable=args.torchscript)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = (
        model, False) if args.no_test_pool else apply_test_time_pool(
            model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          train_mode='val',
                          fold_num=args.fold_num,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           pin_memory=args.pin_mem,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    # top5 = AverageMeter()
    f1_m = AverageMeter()

    model.eval()
    last_idx = len(loader) - 1
    cuda = torch.device('cuda')
    temperature = nn.Parameter(torch.ones(1) *
                               1.5).to(cuda).detach().requires_grad_(True)

    m = nn.Sigmoid()
    nll_criterion = nn.CrossEntropyLoss().cuda()
    ece_criterion = _ECELoss().cuda()

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) +
                            data_config['input_size']).cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()

        logits_list = []
        target_list = []

        # track the best f1/threshold found by the sweep on the last batch
        best_f1 = 0.0
        best_th = 1.0

        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, _ = accuracy(output.detach(), target, topk=(1, 1))

            logits_list.append(output)
            target_list.append(target)

            if last_batch:
                logits = torch.cat(logits_list).cuda()
                targets = torch.cat(target_list).cuda()

                targets_cpu = targets.cpu().numpy()
                sigmoided = m(logits)[:, 1].cpu().numpy()

                for i in range(1000, 0, -1):
                    th = i * 0.001
                    real_pred = (sigmoided >= th) * 1.0
                    f1 = f1_score(targets_cpu.squeeze(), real_pred.squeeze())

                    if f1 > best_f1:
                        best_f1 = f1
                        best_th = th

            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'thresh: {thresh:>7.4f}  '
                    'f1: {f1:>7.4f}'.format(batch_idx,
                                            len(loader),
                                            batch_time=batch_time,
                                            rate_avg=input.size(0) /
                                            batch_time.avg,
                                            loss=losses,
                                            top1=top1,
                                            thresh=best_th,
                                            f1=best_f1))

    print(best_th, best_f1)

    # for temperature scaling
    if args.temp_scaling:

        #         before_temperature_ece = ece_criterion(logits, targets).item()
        #         before_temperature_nll = nll_criterion(logits, targets).item()
        #         print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece))

        #         optimizer = optim.LBFGS([temperature], lr=0.01, max_iter=50)

        #         def eval():
        #             unsqueezed_temperature = temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))
        #             loss = nll_criterion(logits/unsqueezed_temperature, targets)
        #             loss.backward()
        #             return loss
        #         optimizer.step(eval)

        #         unsqueezed_temperature = temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))

        #         logits = logits/unsqueezed_temperature
        #         after_temperature_nll = nll_criterion(logits, targets).item()
        #         after_temperature_ece = ece_criterion(logits, targets).item()
        #         print('Optimal temperature: %.3f' % temperature.item())
        #         print('After temperature - NLL: %.3f, ECE: %.3f' % (after_temperature_nll, after_temperature_ece))

        # fixed temperature; the sigmoided scores are recomputed below after scaling
        temperature = nn.Parameter(torch.ones(1) *
                                   11).to(cuda).detach().requires_grad_(False)

        logits = logits / temperature.unsqueeze(1).expand(
            logits.size(0), logits.size(1))
        targets_cpu = targets.cpu().numpy()
        sigmoided = m(logits)[:, 1].detach().cpu().numpy()

        best_f1 = 0.0
        best_th = 1.0
        for i in range(1000, 0, -1):
            th = i * 0.001
            real_pred = (sigmoided >= th) * 1.0
            f1 = f1_score(targets_cpu.squeeze(), real_pred.squeeze())

            if f1 > best_f1:
                best_f1 = f1
                best_th = th

        print(best_th, best_f1)

    if real_labels is not None:
        # real labels mode replaces the top-1 value at the end
        top1a = real_labels.get_accuracy(k=1)
    else:
        top1a = top1.avg
    f1a = best_f1
    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          f1=round(f1a, 4),
                          f1_err=round(1 - f1a, 4),  # f1 is on a 0-1 scale
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          crop_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) f1 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['f1'],
        results['f1_err']))

    return results
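
The commented-out block above sketches temperature scaling: a single scalar T is fit on held-out logits by minimizing NLL with LBFGS, and logits are divided by T before the softmax/sigmoid to improve calibration. A self-contained version of that fitting step, under the same assumptions:

import torch
import torch.nn as nn
import torch.optim as optim

def fit_temperature(logits, targets, lr=0.01, max_iter=50):
    """Fit a scalar temperature on held-out validation logits by minimizing NLL."""
    temperature = nn.Parameter(torch.ones(1, device=logits.device) * 1.5)
    nll = nn.CrossEntropyLoss()
    optimizer = optim.LBFGS([temperature], lr=lr, max_iter=max_iter)

    def closure():
        optimizer.zero_grad()
        loss = nll(logits / temperature, targets)
        loss.backward()
        return loss

    optimizer.step(closure)
    # divide logits by the fitted temperature before softmax at inference time
    return temperature.item()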
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    if args.legacy_jit:
        set_jit_legacy()

    # create model
    if 'inception' in args.model:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            aux_logits=True,  # ! add aux loss
            in_chans=3,
            scriptable=args.torchscript)
    else:
        model = create_model(args.model,
                             pretrained=args.pretrained,
                             num_classes=args.num_classes,
                             in_chans=3,
                             scriptable=args.torchscript)

    # ! add more layer to classifier layer
    if args.create_classifier_layerfc:
        model.global_pool, model.classifier = create_classifier_layerfc(
            model.num_features, model.num_classes)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    if args.has_eval_label:
        criterion = nn.CrossEntropyLoss().cuda()  # ! don't have gold label

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)
    else:
        dataset = Dataset(args.data,
                          load_bytes=args.tf_preprocessing,
                          class_map=args.class_map,
                          args=args)

    if args.valid_labels:
        with open(args.valid_labels,
                  'r') as f:  # @valid_labels is index numbering
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True),
                                         real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']

    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config[
            'interpolation'],  # 'blank' is default Image.BILINEAR https://github.com/rwightman/pytorch-image-models/blob/470220b1f4c61ad7deb16dbfb8917089e842cd2a/timm/data/transforms.py#L43
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        auto_augment=args.aa,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        args=args)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    topk = AverageMeter()

    prediction = None  # ! need to save output
    true_label = None

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) +
                            data_config['input_size']).cuda()
        model(input)
        end = time.time()
        for batch_idx, (input,
                        target) in enumerate(loader):  # ! not have real label

            if args.has_eval_label:  # ! just save true labels anyway... why not
                if true_label is None:
                    true_label = target.cpu().data.numpy()
                else:
                    true_label = np.concatenate(
                        (true_label, target.cpu().data.numpy()), axis=0)

            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]  # some models return (output, aux_output); keep the main output

            if valid_labels is not None:
                output = output[:, valid_labels]  # keep only valid label columns, useful for per-class eval

            # accumulate raw predictions; per-batch concatenation is simple, if not the fastest option.
            # assumes file order from dataset.filenames() matches loader order (the loader is not shuffled).
            if prediction is None:
                prediction = output.cpu().data.numpy()  # batchsize x num_labels
            else:
                prediction = np.concatenate(
                    (prediction, output.cpu().data.numpy()), axis=0)

            if real_labels is not None:
                real_labels.add_result(output)

            if args.has_eval_label:
                # measure accuracy and record loss
                loss = criterion(output, target)  # only computed when gold labels exist
                acc1, acck = accuracy(output.data, target, topk=(1, args.topk))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1.item(), input.size(0))
                topk.update(acck.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if args.has_eval_label and (batch_idx % args.log_freq == 0):
                _logger.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@topk: {topk.val:>7.3f} ({topk.avg:>7.3f})'.format(
                        batch_idx,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        topk=topk))

    if not args.has_eval_label:
        top1a, topka = 0, 0  # dummy values; no ground-truth labels are available
    else:
        if real_labels is not None:
            # real labels mode replaces topk values at the end
            top1a, topka = real_labels.get_accuracy(
                k=1), real_labels.get_accuracy(k=args.topk)
        else:
            top1a, topka = top1.avg, topk.avg

    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          topk=round(topka, 4),
                          topk_err=round(100 - topka, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          crop_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@topk {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['topk'],
        results['topk_err']))

    return results, prediction, true_label
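A quick sketch of how the returned arrays might be post-processed (names such as `args` and the save path are illustrative assumptions, not part of the snippet above):

# hypothetical post-processing of validate()'s outputs
import numpy as np

results, prediction, true_label = validate(args)
pred_classes = prediction.argmax(axis=1)  # per-image predicted class ids
if true_label is not None:
    print('accuracy: %.4f' % (pred_classes == true_label).mean())
np.save('prediction.npy', prediction)  # keep the raw logits for later analysis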
Example #24
def main():
    setup_default_logging()
    args = parser.parse_args()
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(args.model,
                         num_classes=args.num_classes,
                         in_chans=3,
                         pretrained=args.pretrained,
                         checkpoint_path=args.checkpoint)

    _logger.info('Model %s created, param count: %d' %
                 (args.model, sum([m.numel() for m in model.parameters()])))

    config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = ((model, False) if args.no_test_pool
                             else apply_test_time_pool(model, config))

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(
                                          args.num_gpu))).cuda()
    else:
        model = model.cuda()

    loader = create_loader(
        Dataset(args.data, train_mode='test', fold_num=-1),
        input_size=config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,
        interpolation=config['interpolation'],
        mean=config['mean'],
        std=config['std'],
        num_workers=args.workers,
        crop_pct=1.0 if test_time_pool else config['crop_pct'])

    model.eval()

    #     k = min(args.topk, args.num_classes)
    batch_time = AverageMeter()
    end = time.time()
    topk_ids = []
    name_list = []
    sig_list = []
    logits_list = []
    m = torch.nn.Sigmoid()
    with torch.no_grad():
        for batch_idx, (input, _) in enumerate(loader):
            input = input.cuda()
            labels = model(input)
            logits_list.append(labels)
            sigmoided = m(labels)
            sig_list.append(
                np.expand_dims(sigmoided[:, 1].cpu().numpy(), axis=1))
            #             topk = labels.topk(k)[1]
            #             topk_ids.append(topk.cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'
                    .format(batch_idx, len(loader), batch_time=batch_time))


    #     topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
    #     logits = torch.cat(logits_list).cuda()
    #     temperature = nn.Parameter(torch.ones(1) * args.te).to(torch.device('cuda')).detach().requires_grad_(False)
    #     logits = logits / temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))
    #     temp_sigmoided = m(logits)[:, 1].detach().cpu().numpy()

    sig_list = np.vstack(sig_list)
    name_list = loader.dataset.filenames(basename=True)

    real_sigmoid = sig_list.squeeze()
    #     real_sigmoid = temp_sigmoided
    real_pred = ((sig_list >= args.thresh) * 1).squeeze()

    name_pred_dict = {}
    for idx in range(len(name_list)):
        name_pred_dict[name_list[idx]] = (real_pred[idx], real_sigmoid[idx])

    args.output_dir = os.path.dirname(args.checkpoint)
    with open(os.path.join(args.output_dir, 'prediction.tsv'), 'w') as out_file:
        #         filenames_int = [int(f.split('.')[0]) for f in filenames]
        #         for name, topk in zip(filenames_int, topk_ids):
        #             print(name,topk)
        #             i = i+1
        #             if i == 10:
        #                 break
        #         idx = np.argsort(filenames_int)
        #         topk_ids = topk_ids[idx]
        for name in name_list:
            out_file.write('{}\n'.format(str(name_pred_dict[name][0])))
    with open(os.path.join(args.output_dir, 'probability.tsv'), 'w') as out_file:
        for name in name_list:
            out_file.write('{}\n'.format(name_pred_dict[name][1]))

    copyfile(
        os.path.join(args.output_dir, 'prediction.tsv'),
        '/home/workspace/user-workspace/prediction/' + 'prediction_153_' +
        args.checkpoint.split('/')[-2] + '.tsv')
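For reference, a minimal sketch of the sigmoid-plus-threshold decision rule used above, on toy two-class logits (the 0.5 threshold stands in for `args.thresh`):

import torch

logits = torch.tensor([[2.0, -1.0], [-0.5, 1.5]])  # toy batch of 2-class logits
probs = torch.sigmoid(logits)[:, 1]                # probability of the positive class
preds = (probs >= 0.5).long()                      # mirrors `(sig_list >= args.thresh) * 1`
print(probs, preds)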
Example #25
def validate(config, data_loader, model):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                        f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
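`reduce_tensor` is assumed here to average a metric across distributed ranks; a typical implementation (a sketch, assuming torch.distributed has been initialized) looks like:

import torch.distributed as dist

def reduce_tensor(tensor, world_size=None):
    # sum the tensor across all ranks, then divide to get the mean
    world_size = world_size or dist.get_world_size()
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt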
Example #26
def main():
    start_endpoint = "http://localhost:3000/start"
    stop_endpoint = "http://localhost:3000/stop"
    setup_default_logging()
    args = parser.parse_args()
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

    output_dir = args.checkpoint.split('/')
    output_dir.pop(-1)
    output_dir = ('/').join(output_dir)

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint)

    logging.info('Model %s created, param count: %d' %
                 (args.model, sum([m.numel() for m in model.parameters()])))

    # config = resolve_data_config(vars(args), model=model)
    # model, test_time_pool = apply_test_time_pool(model, config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(
            model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

    dataset_eval = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True)

    data_config = resolve_data_config(vars(args), model=model)

    # CIFAR-100 channel statistics
    data_config['mean'] = (0.5071, 0.4865, 0.4409)
    data_config['std'] = (0.2673, 0.2564, 0.2762)

    loader = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=data_config['crop_pct']
    )

    model.eval()

    batch_time = AverageMeter()

    with torch.no_grad():
        idle_power = requests.post(url=start_endpoint)
        idle_json = idle_power.json()
        for batch_idx, (input, _) in enumerate(loader):
            input = input.cuda()

            tstart = time.time()
            output = model(input)
            tend = time.time()

            if batch_idx != 0:
                batch_time.update(tend - tstart)

                if batch_idx % args.log_freq == 0:
                    print('Predict: [{0}/{1}] Time {batch_time.val:.6f} ({batch_time.avg:.6f})'.format(
                        batch_idx, len(loader), batch_time=batch_time), end='\r')

    load_power = requests.post(url=stop_endpoint)
    load_json = load_power.json()
    fps = 1 / batch_time.avg  # batches per second; multiply by batch size for images per second
    inference_power = float(load_json['load']) - float(idle_json['idle'])
    stats = [{'FPS': [float(fps)]},
             {'Total_Power': [float(inference_power)]}]
    with open(os.path.join(output_dir, '{}_fps_cifar.yaml'.format(args.model)), 'w') as f:
        yaml.safe_dump(stats, f)
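Reading the dumped stats back is straightforward (a sketch; the file name assumes args.model == 'resnet50', and the printed values are illustrative only):

import yaml

with open('resnet50_fps_cifar.yaml') as f:
    stats = yaml.safe_load(f)
print(stats)  # e.g. [{'FPS': [107.3]}, {'Total_Power': [42.1]}]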
Example #27
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         scriptable=args.torchscript)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' %
                 (args.model, param_count))

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    # from torchvision.datasets import ImageNet
    # dataset = ImageNet(args.data, split='val')

    valdir = args.data
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # note: this OpenCV pipeline is unused below; it pairs with the
    # commented-out opencv_loader DataLoader further down
    transform = cvtransforms.Compose([
        cvtransforms.Resize(size=256, interpolation='BILINEAR'),
        cvtransforms.CenterCrop(224),
        cvtransforms.ToTensor(),
        cvtransforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # loader = torch.utils.data.DataLoader(
    #     datasets.ImageFolder(valdir, transform, loader=opencv_loader),
    #     batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=False)

    loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256, interpolation=2),  # 2 == PIL.Image.BILINEAR
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.workers,
                                         pin_memory=False)

    # loader_eval = loader.Loader('val', valdir, batch_size=args.batch_size, num_workers=args.workers, shuffle=False)

    batch_time = AverageMeter()
    losses = AverageMeter()  # kept for log formatting; the loss itself is never computed below
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        # input = torch.randn((args.batch_size,)).cuda()
        # model(input)
        end = time.time()
        for i, (input, target) in enumerate(loader):
            # if args.no_prefetcher:
            target = target.cuda()
            input = input.cuda()

            # compute output
            output, _ = model(input)
            # loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            # losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    results = OrderedDict(top1=round(top1.avg, 4),
                          top1_err=round(100 - top1.avg, 4),
                          top5=round(top5.avg, 4),
                          top5_err=round(100 - top5.avg, 4),
                          param_count=round(param_count / 1e6, 2))

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'],
        results['top5_err']))

    return results
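The AverageMeter used throughout these snippets is assumed to track a running average; a minimal sketch consistent with the .val/.avg accesses above:

class AverageMeter:
    """Tracks the most recent value and a count-weighted running average."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count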
Example #28
def train_one_epoch(epoch,
                    model,
                    loader,
                    optimizer,
                    loss_fn,
                    args,
                    lr_scheduler=None,
                    saver=None,
                    output_dir=None,
                    amp_autocast=suppress,
                    loss_scaler=None,
                    model_ema=None,
                    mixup_fn=None):

    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer,
                           'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        with amp_autocast():
            output = model(input)
            loss = loss_fn(output, target)

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if loss_scaler is not None:
            loss_scaler(loss,
                        optimizer,
                        clip_grad=args.clip_grad,
                        clip_mode=args.clip_mode,
                        parameters=model_parameters(model,
                                                    exclude_head='agc'
                                                    in args.clip_mode),
                        create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.clip_grad is not None:
                dispatch_clip_grad(model_parameters(model,
                                                    exclude_head='agc'
                                                    in args.clip_mode),
                                   value=args.clip_grad,
                                   mode=args.clip_mode)
            optimizer.step()

        if model_ema is not None:
            model_ema.update(model)

        torch.cuda.synchronize()
        num_updates += 1
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                _logger.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g})  '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'LR: {lr:.3e}  '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx,
                        len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size /
                        batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size /
                        batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir,
                                     'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates,
                                     metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])
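A minimal sketch of the loop that would drive train_one_epoch (start_epoch, loader_train, and the per-epoch scheduler call are assumptions based on common timm-style training scripts):

for epoch in range(start_epoch, args.epochs):
    train_metrics = train_one_epoch(
        epoch, model, loader_train, optimizer, loss_fn, args,
        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir)
    if lr_scheduler is not None:
        # per-epoch step; per-iteration updates happen via step_update above
        lr_scheduler.step(epoch + 1, train_metrics['loss'])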
Example #29
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    #from torchvision.datasets import ImageNet
    #dataset = ImageNet(args.data, split='val')
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))  # match the top5 meters and Acc@5 logging
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'])

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
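The accuracy helper is assumed to return top-k accuracies in percent; a standard implementation consistent with the calls above:

import torch

def accuracy(output, target, topk=(1,)):
    # top-k accuracy (in percent) for each k in `topk`
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size
            for k in topk]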
Example #30
def main():
    setup_default_logging()
    args = parser.parse_args()
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint)

    logging.info('Model %s created, param count: %d' %
                 (args.model, sum([m.numel() for m in model.parameters()])))

    config = resolve_data_config(vars(args), model=model)
    test_time_pool = False
    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()
    
    test_dataset = Dataset(args.data)
    class_mapper = {v: k for k, v in test_dataset.class_to_idx.items()}
    
    loader = create_loader(
        test_dataset,
        input_size=config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=config['interpolation'],
        mean=config['mean'],
        std=config['std'],
        num_workers=args.workers,
        crop_pct=1.0 if test_time_pool else config['crop_pct'])

    model.eval()

    k = min(args.topk, args.num_classes)
    batch_time = AverageMeter()
    end = time.time()
    topk_ids = []
    results  = []
    with torch.no_grad():
        for batch_idx, (input, _, path) in enumerate(loader):
            input = input.cuda()
            labels = model(input)
            topk = labels.topk(k)[1]
            topk_ids.append(topk.cpu().numpy())
            label_ = list(topk.cpu().numpy())[0][0]  # top-1 of the first sample only; assumes batch_size == 1
            results.append([path, class_mapper[label_]])
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                logging.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time))

    topk_ids = np.concatenate(topk_ids, axis=0).squeeze()

    with open(os.path.join(args.output_dir, 'topk_ids.csv'), 'w') as out_file:
        filenames = loader.dataset.filenames()
        for filename, label in zip(filenames, topk_ids):
            filename = os.path.basename(filename)
            out_file.write('{0},{1},{2},{3}\n'.format(
                filename, label[0], label[1], label[2]))  # assumes k >= 3

    if args.result_dir:
        # recreate the result directory, then make one subdirectory per class
        # (the original only created class subdirectories when the directory
        # already existed, so fresh runs had nowhere to copy into);
        # requires `import shutil`
        if os.path.isdir(args.result_dir):
            shutil.rmtree(args.result_dir)
        os.makedirs(args.result_dir)
        for _, v in class_mapper.items():
            os.makedirs(os.path.join(args.result_dir, v), exist_ok=True)
        for image in results:
            path = image[0][0]  # first path in the batch (batch_size == 1 assumed above)
            result = image[1]
            gt = path.split('/')[-2]
            print("\tPath : %s  \tResult : %s  \tGT : %s" % (path, result, gt))
            shutil.copy(path, os.path.join(args.result_dir, result))
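Reading the CSV written above back is a one-liner per row (a sketch; assumes the same three top-k columns, and the printed row is illustrative):

import csv

with open('topk_ids.csv') as f:
    for filename, *ids in csv.reader(f):
        print(filename, ids)  # e.g. img_001.jpg ['3', '17', '42']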