def build_dataset(is_train, config):
    """Create the train or validation dataset described by ``config``.

    Args:
        is_train: selects the "train" split when true, the "val" split otherwise.
        config: config object; reads ``DATA.DATASET``, ``DATA.DATA_PATH``,
            ``DATA.LOAD_TAR``, ``DATA.ZIP_MODE`` and ``DATA.CACHE_MODE``.

    Returns:
        Tuple ``(dataset, nb_classes)``; ``nb_classes`` is always 1000 (ImageNet).

    Raises:
        NotImplementedError: when ``DATA.DATASET`` is not ``'imagenet'``.
    """
    transform = build_transform(is_train, config)

    if config.DATA.DATASET != 'imagenet':
        raise NotImplementedError("We only support ImageNet Now.")

    data_cfg = config.DATA
    split = 'train' if is_train else 'val'

    if data_cfg.LOAD_TAR:
        # The whole split is read from a single "<split>.tar" under DATA_PATH.
        dataset = DatasetTar(os.path.join(data_cfg.DATA_PATH, f'{split}.tar'),
                             transform=transform)
    elif data_cfg.ZIP_MODE:
        # Zip layout: annotation file "<split>_map.txt", image prefix "<split>.zip@/".
        # The configured cache mode is only used for training; validation uses 'part'.
        cache_mode = data_cfg.CACHE_MODE if is_train else 'part'
        dataset = CachedImageFolder(data_cfg.DATA_PATH,
                                    f'{split}_map.txt',
                                    f'{split}.zip@/',
                                    transform,
                                    cache_mode=cache_mode)
    else:
        # Plain directory "<DATA_PATH>/<split>" read with torchvision's ImageFolder.
        dataset = datasets.ImageFolder(os.path.join(data_cfg.DATA_PATH, split),
                                       transform=transform)

    return dataset, 1000
def validate(args):
    """Evaluate a pretrained JAX model on ``args.data`` and report top-1/top-5 accuracy.

    Reads ``args.model``, ``args.no_jit``, ``args.data`` and ``args.batch_size``
    (plus whatever ``resolve_data_config`` picks up from ``vars(args)``).

    Returns:
        dict with float ``top1`` and ``top5`` accuracies in percent.
    """
    model, variables = create_model(args.model, pretrained=True, rng=jax.random.PRNGKey(0))
    print(f'Created {args.model} model. Validating...')

    def forward(images, labels):
        return eval_forward(model, variables, images, labels)

    # --no_jit runs the forward pass eagerly instead of through jax.jit.
    eval_step = forward if args.no_jit else jax.jit(forward)

    is_tar = os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data)
    dataset = DatasetTar(args.data) if is_tar else Dataset(args.data)

    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=8,
        crop_pct=data_config['crop_pct'])

    batch_time = AverageMeter()
    hits_top1 = hits_top5 = 0
    seen = 0
    start_time = prev_time = time.time()
    for step, (images, labels) in enumerate(loader):
        # Move the channel axis last (NCHW -> NHWC) before handing the batch to JAX.
        batch = images.numpy().transpose(0, 2, 3, 1)
        n_top1, n_top5 = eval_step(batch, labels.numpy())
        hits_top1 += n_top1
        hits_top5 += n_top5
        n = batch.shape[0]
        seen += n

        batch_time.update(time.time() - prev_time)
        if step and step % 20 == 0:
            print(
                f'Test: [{step:>4d}/{len(loader)}] '
                f'Rate: {n / batch_time.val:>5.2f}/s ({n / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * hits_top1 / seen:>7.3f} '
                f'Acc@5: {100 * hits_top5 / seen:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * hits_top1 / seen
    acc_5 = 100 * hits_top5 / seen
    print(
        f'Validation complete. {seen / (prev_time - start_time):>5.2f} img/s. '
        f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))
def validate(args):
    """Evaluate a timm model on ``args.data`` and return accuracy metrics.

    Builds the model (pretrained, or from ``args.checkpoint``), applies
    test-time pooling and optional TorchScript / APEX AMP / DataParallel, then
    runs one pass over the dataset under ``torch.no_grad()``.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        # NOTE(review): torch.jit.optimized_execution is a context manager; calling
        # it without ``with`` does not change any setting.
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)

    # With test-time pooling active, crop_pct is forced to 1.0.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss; the second k must be 5 because the
            # value is reported as Acc@5 / top5 below
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
def validate(args):
    """Run inference over ``args.data`` and optionally score it against gold labels.

    When ``args.has_eval_label`` is false, no loss or accuracy is computed (the
    reported top-1/top-k are dummy zeros), but the model outputs are still
    collected.

    Returns:
        Tuple ``(results, prediction, true_label)``:
          * ``results``: OrderedDict with ``top1``, ``top1_err``, ``topk``,
            ``topk_err``, ``param_count`` (millions), ``img_size``,
            ``cropt_pct`` and ``interpolation``.
          * ``prediction``: numpy array of model outputs stacked along axis 0
            in loader order, or ``None`` if the loader yielded no batches.
          * ``true_label``: numpy array of targets stacked along axis 0 in
            loader order when ``args.has_eval_label`` is set, otherwise
            ``None`` (also ``None`` for an empty loader).
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    if args.legacy_jit:
        set_jit_legacy()

    # create model
    if 'inception' in args.model:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            aux_logits=True,  # inception models are built with the auxiliary head
            in_chans=3,
            scriptable=args.torchscript)
    else:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            in_chans=3,
            scriptable=args.torchscript)

    # Optionally replace the pooling/classifier with the extended classifier head.
    if args.create_classifier_layerfc:
        model.global_pool, model.classifier = create_classifier_layerfc(
            model.num_features, model.num_classes)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        # NOTE(review): torch.jit.optimized_execution is a context manager; calling
        # it without ``with`` does not change any setting.
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    # The loss is only needed (and only defined) when gold labels are available.
    if args.has_eval_label:
        criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map,
                          args=args)

    if args.valid_labels:
        # args.valid_labels is a file with one class index per line; build a
        # boolean mask over all num_classes outputs.
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
        valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        # 'blank' is default Image.BILINEAR https://github.com/rwightman/pytorch-image-models/blob/470220b1f4c61ad7deb16dbfb8917089e842cd2a/timm/data/transforms.py#L43
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        auto_augment=args.aa,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        args=args)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    topk = AverageMeter()

    # Per-batch numpy chunks; concatenated once after the loop. Concatenating on
    # every batch would copy all previous rows each time (quadratic overall).
    prediction_chunks = []
    label_chunks = []

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) + data_config['input_size']).cuda()
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            if args.has_eval_label:
                # Keep the gold labels so callers can score predictions themselves.
                label_chunks.append(target.cpu().data.numpy())

            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            if isinstance(output, (tuple, list)):
                # Some models (e.g. with an auxiliary head) return several outputs;
                # only the first one is scored.
                output = output[0]
            if valid_labels is not None:
                output = output[:, valid_labels]  # keep only the valid class columns

            # One (batch_size x num_outputs) array per batch.
            prediction_chunks.append(output.cpu().data.numpy())

            if real_labels is not None:
                real_labels.add_result(output)

            if args.has_eval_label:
                # measure accuracy and record loss
                loss = criterion(output, target)
                acc1, acck = accuracy(output.data, target, topk=(1, args.topk))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1.item(), input.size(0))
                topk.update(acck.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if args.has_eval_label and (batch_idx % args.log_freq == 0):
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@topk: {topk.val:>7.3f} ({topk.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, topk=topk))

    prediction = np.concatenate(prediction_chunks, axis=0) if prediction_chunks else None
    true_label = np.concatenate(label_chunks, axis=0) if label_chunks else None

    if not args.has_eval_label:
        top1a, topka = 0, 0  # placeholders: no gold labels to score against
    elif real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, topka = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=args.topk)
    else:
        top1a, topka = top1.avg, topk.avg

    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        topk=round(topka, 4), topk_err=round(100 - topka, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@topk {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['topk'], results['topk_err']))

    return results, prediction, true_label
def validate(args):
    """Evaluate a timm model, or a precompiled TorchScript model, on ``args.data``.

    With ``args.neuron`` set, the model is loaded with
    ``torch.jit.load(args.checkpoint)`` and kept off CUDA. Otherwise a timm
    model is created (optionally loading ``args.checkpoint``) and moved to the
    GPU, with optional APEX / native AMP, channels-last and DataParallel.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (millions; ``None`` when ``args.neuron`` is set),
        ``img_size``, ``cropt_pct`` and ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    if args.neuron:
        # args.checkpoint is the path of a serialized TorchScript archive here.
        model = torch.jit.load(args.checkpoint)
    else:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            in_chans=3,
            global_pool=args.gp,
            scriptable=args.torchscript)
        if args.checkpoint:
            load_checkpoint(model, args.checkpoint, args.use_ema)

    # The parameter count is not computed for Neuron models. Keep the name bound
    # so the results dict below does not raise NameError on that path.
    param_count = None
    if not args.neuron:
        param_count = sum([m.numel() for m in model.parameters()])
        _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        # NOTE(review): torch.jit.optimized_execution is a context manager; calling
        # it without ``with`` does not change any setting.
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if not args.neuron:
        model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1 and not args.neuron:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss()
    if not args.neuron:
        criterion = criterion.cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        # One class index per line -> boolean mask over all num_classes outputs.
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
        valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) + data_config['input_size'])
        if not args.neuron:
            input = input.cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            if args.no_prefetcher and not args.neuron:
                target = target.cuda()
                input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output; autocast is never applied to Neuron models
            with (suppress() if args.neuron else amp_autocast()):
                output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg

    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=None if param_count is None else round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
if is_server(): DATA_ROOT = './.data/vision/imagenet' else: # local settings DATA_ROOT = './' DATA_FILENAME = 'ILSVRC2012_img_val.tar' TAR_PATH = os.path.join(DATA_ROOT, DATA_FILENAME) for m in model_list: model_name = m['model'] # create model from name model = create_model(model_name, pretrained=True) param_count = sum([m.numel() for m in model.parameters()]) print('Model %s, %s created. Param count: %d' % (model_name, m['paper_model_name'], param_count)) dataset = DatasetTar(TAR_PATH) filenames = [os.path.splitext(f)[0] for f in dataset.filenames()] # get appropriate transform for model's default pretrained config data_config = resolve_data_config(m['args'], model=model, verbose=True) test_time_pool = False if m['ttp']: model, test_time_pool = apply_test_time_pool(model, data_config) data_config['crop_pct'] = 1.0 batch_size = m['batch_size'] loader = create_loader( dataset, input_size=data_config['input_size'], batch_size=batch_size, use_prefetcher=True,
def _best_f1_threshold(targets, probs):
    """Return ``(threshold, f1)`` maximizing binary F1 over thresholds 1.000 .. 0.001.

    Thresholds are tried from high to low in steps of 0.001. A lower threshold
    replaces the current best only if its F1 is strictly higher, so ties keep
    the higher threshold. Returns ``(1.0, 0.0)`` if no threshold reaches F1 > 0.

    Args:
        targets: numpy array of binary ground-truth labels (presumably 0/1).
        probs: numpy array of positive-class scores, one per target.
    """
    best_f1 = 0.0
    best_th = 1.0
    flat_targets = targets.squeeze()
    for i in range(1000, 0, -1):
        th = i * 0.001
        preds = ((probs >= th) * 1.0).squeeze()
        f1 = f1_score(flat_targets, preds)
        if f1 > best_f1:
            best_f1 = f1
            best_th = th
    return best_th, best_f1


def validate(args):
    """Evaluate a binary classifier on ``args.data`` and tune an F1 decision threshold.

    After one pass over the loader, the logits of all batches are concatenated.
    The sigmoid of column 1 is used as the positive score, and the threshold
    maximizing F1 is searched. With ``args.temp_scaling``, the logits are first
    divided by a fixed temperature of 11 and the search is repeated. Both
    ``(threshold, f1)`` pairs are printed.

    Returns:
        OrderedDict with ``top1``, ``top1_err`` (percent), ``f1``, ``f1_err``
        (``1 - f1``; F1 is a fraction in [0, 1]), ``param_count`` (millions),
        ``img_size``, ``cropt_pct`` and ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning(
                "Neither APEX or Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    if args.no_test_pool:
        test_time_pool = False
    else:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    if args.torchscript:
        # NOTE(review): torch.jit.optimized_execution is a context manager; calling
        # it without ``with`` does not change any setting.
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, train_mode='val', fold_num=args.fold_num,
                          load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        # One class index per line -> boolean mask over all num_classes outputs.
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
        valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    sigmoid = nn.Sigmoid()
    logits_list = []
    target_list = []
    # Defaults reported when the loader yields no batches.
    best_th, best_f1 = 1.0, 0.0

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size, ) + data_config['input_size']).cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, _ = accuracy(output.detach(), target, topk=(1, 1))
            logits_list.append(output)
            target_list.append(target)
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1))

        # The threshold search needs every batch, so it runs once after the loop.
        if logits_list:
            logits = torch.cat(logits_list)
            targets_cpu = torch.cat(target_list).cpu().numpy()

            probs = sigmoid(logits)[:, 1].cpu().numpy()
            best_th, best_f1 = _best_f1_threshold(targets_cpu, probs)
            print(best_th, best_f1)

            if args.temp_scaling:
                # Fixed temperature scaling (T = 11); the float32 temperature tensor
                # also promotes half-precision logits to float32.
                temperature = torch.full((1,), 11.0, device=logits.device)
                probs = sigmoid(logits / temperature)[:, 1].cpu().numpy()
                best_th, best_f1 = _best_f1_threshold(targets_cpu, probs)
                print(best_th, best_f1)

    if real_labels is not None:
        # real labels mode replaces the top-1 value at the end
        top1a = real_labels.get_accuracy(k=1)
    else:
        top1a = top1.avg
    f1a = best_f1

    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        # F1 is a fraction in [0, 1] (top1 is a percentage), so its error is 1 - f1.
        f1=f1a, f1_err=round(1 - f1a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) f1 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['f1'], results['f1_err']))

    return results
def validate(args):
    """Evaluate a 40-class model on ``args.data`` and record per-batch accuracy to CSV.

    For every batch, writes ``[batch_index, running top-1]`` to ``results.csv``.
    When ``args.hier_classify`` is set, also writes
    ``[batch_index, running parent-group accuracy]`` to ``results-parent.csv``,
    using 4 fixed groups of the 40 classes. Logs the final 40x40 confusion
    matrix.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    model = create_model(args.model, num_classes=args.num_classes, in_chans=3,
                         pretrained=args.pretrained)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()
    if args.fp16:
        model = model.half()

    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        fp16=args.fp16,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    c_matrix = np.zeros((40, 40), dtype=int)
    labels = np.arange(0, 40, 1)

    # Parent (coarse) groups of the 40 fine classes, as contiguous index ranges.
    # Built once here instead of on every batch.
    parent_groups = [list(range(0, 6)), list(range(6, 14)),
                     list(range(14, 37)), list(range(37, 40))]

    model.eval()
    end = time.time()
    # The with-statement closes both CSV files even if evaluation raises;
    # newline='' is what the csv module expects for writer files.
    with torch.no_grad(), \
            open('results.csv', 'w', newline='') as cf, \
            open('results-parent.csv', 'w', newline='') as cv:
        writer = csv.writer(cf)
        writer_2 = csv.writer(cv)
        for i, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))
            c_matrix += cal_confusions(output, target, labels=labels)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            writer.writerow([i, round(top1.avg, 4)])

            # Parent-class accuracy: the fraction of confusion-matrix entries whose
            # row and column fall in the same parent group.
            if args.hier_classify:
                corrects = 0.
                for group in parent_groups:
                    corrects += c_matrix[group][:, group].sum()
                writer_2.writerow([i, round(corrects / c_matrix.sum(), 4)])
                logging.info('parent precision: {}'.format(corrects / c_matrix.sum()))

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    logging.info('confusion_matrix: \n {}'.format(c_matrix))
    # Overall accuracy recomputed from the confusion matrix: trace / total count.
    logging.info('precision by confusion matrix: \n {}'.format(
        truediv(np.sum(np.diag(c_matrix)), np.sum(np.sum(c_matrix, axis=1)))))

    return results
def validate(args):
    """Evaluate the model class named by ``args.model`` on ``args.data``.

    Forces ``args.pretrained = False`` and ``args.prefetcher = True``. The model
    is built by evaluating ``args.model`` as a Python expression and calling the
    result with no arguments, then optionally loading ``args.checkpoint``.

    Returns:
        OrderedDict with ``top1``, ``top1_err``, ``top5``, ``top5_err``,
        ``param_count`` (millions), ``img_size``, ``cropt_pct`` and
        ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = False
    args.prefetcher = True

    # SECURITY NOTE(review): eval() runs arbitrary code taken from the command
    # line; a lookup in an explicit model registry would be safer.
    net = eval(args.model)()
    if args.checkpoint:
        load_checkpoint(net, args.checkpoint, False)

    param_count = sum(p.numel() for p in net.parameters())
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=net)
    net, test_time_pool = apply_test_time_pool(net, data_config, args)

    if args.num_gpu <= 1:
        net = net.cuda()
    else:
        net = torch.nn.DataParallel(net, device_ids=list(range(args.num_gpu))).cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    from_tar = os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data)
    dataset_cls = DatasetTar if from_tar else Dataset
    dataset = dataset_cls(args.data, load_bytes=False)

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct)

    batch_time, losses, top1, top5 = (AverageMeter() for _ in range(4))

    net.eval()
    end = time.time()
    with torch.no_grad():
        # The prefetcher is always on here, so batches are not moved to the GPU
        # explicitly in this loop.
        for step, (images, target) in enumerate(loader):
            output = net(images)
            loss = criterion(output, target)

            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            batch_time.update(time.time() - end)
            end = time.time()

            if step % args.log_freq != 0:
                continue
            logging.info(
                'Test: [{0:>4d}/{1}] '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    step, len(loader),
                    batch_time=batch_time,
                    rate_avg=n / batch_time.avg,
                    loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4),
        top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4),
        top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results
def _load_checkpoint_from_path_or_url(path):
    """Load a checkpoint onto the CPU from a local file or an ``https`` URL."""
    if path.startswith('https'):
        return torch.hub.load_state_dict_from_url(
            path, map_location='cpu', check_hash=True)
    return torch.load(path, map_location='cpu')


def _interpolate_pos_embed(model, checkpoint_model):
    """Resize ``checkpoint_model['pos_embed']`` in place to ``model``'s patch grid.

    The leading extra tokens (class / distillation tokens) are copied
    unchanged; only the patch position tokens are resampled bicubically.
    Both the checkpoint grid and the new grid are treated as square.
    Does nothing if the checkpoint has no ``pos_embed`` entry.
    """
    if 'pos_embed' not in checkpoint_model:
        return
    pos_embed_checkpoint = checkpoint_model['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches ** 0.5)
    # class_token and dist_token are kept unchanged
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)


def main(args):
    """Train a classifier, or only evaluate it when ``args.eval`` is set.

    This function:
    - builds the train/val datasets and loaders and optional Mixup/CutMix;
    - builds the model, optionally fine-tuned from ``args.finetune``;
    - builds the EMA copy, optimizer, LR scheduler and loss. The loss is
      always wrapped in ``DistillationLoss``, with a teacher when
      ``args.distillation_type != 'none'``.
    - runs the epoch loop, evaluating after every epoch. It writes
      ``checkpoint.pth`` and appends JSON stats to ``log.txt`` under
      ``args.output_dir``.

    Side effects on ``args``: ``lr`` is replaced by the linearly scaled
    learning rate (``lr * batch_size * world_size / 512``). ``nb_classes``
    is set from the dataset. ``start_epoch`` is set when resuming.
    """
    utils.init_distributed_mode(args)

    print(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    if args.load_tar:
        train_dir = os.path.join(args.data_path, 'train.tar')
        train_transform = build_transform(True, args)
        dataset_train = DatasetTar(train_dir, transform=train_transform)
        args.nb_classes = 1000
        val_transform = build_transform(False, args)
        eval_dir = os.path.join(args.data_path, 'val.tar')
        dataset_val = DatasetTar(eval_dir, transform=val_transform)
    else:
        # Positional call so it works whether build_dataset names its second
        # parameter ``args`` or ``config``.
        dataset_train, args.nb_classes = build_dataset(True, args)
        dataset_val, _ = build_dataset(False, args)

    # Distributed-style samplers are used unconditionally (even for a single
    # process); explicit num_replicas/rank are passed so no process group is
    # required.
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )

    if args.finetune:
        checkpoint = _load_checkpoint_from_path_or_url(args.finetune)
        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        # Drop classifier weights whose shapes do not match the new model
        # (e.g. a different number of classes); they stay randomly initialised.
        for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
            if (k in checkpoint_model and k in state_dict
                    and checkpoint_model[k].shape != state_dict[k].shape):
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        _interpolate_pos_embed(model, checkpoint_model)

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling relative to a reference global batch size of 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if mixup_active:
        # mixup_fn (Mixup *or* CutMix) emits soft targets and already applies
        # label smoothing, so a soft-target loss is needed whenever it is on.
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    teacher_model = None
    if args.distillation_type != 'none':
        print(f"Creating teacher model: {args.teacher_model}")
        # teacher_pretrained is True when args.teacher_path is empty
        teacher_pretrained = not bool(args.teacher_path)
        teacher_model = create_model(
            args.teacher_model,
            pretrained=teacher_pretrained,
            num_classes=args.nb_classes,
            global_pool='avg',
        )
        if not teacher_pretrained:
            checkpoint = _load_checkpoint_from_path_or_url(args.teacher_path)
            teacher_model.load_state_dict(checkpoint['model'])
        teacher_model.to(device)
        teacher_model.eval()

    # wrap the criterion in our custom DistillationLoss, which
    # just dispatches to the original criterion if args.distillation_type is 'none'
    criterion = DistillationLoss(
        criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
    )

    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = _load_checkpoint_from_path_or_url(args.resume)
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                # Checkpoints written without EMA store None (or no entry).
                if checkpoint.get('model_ema') is not None:
                    utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
                else:
                    print('Warning: checkpoint has no EMA weights; EMA model not restored.')
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        # The train sampler is always a DistributedSampler/RASampler, even
        # without distributed mode, so re-seed it every epoch or it repeats
        # the same shuffle order.
        if hasattr(data_loader_train.sampler, 'set_epoch'):
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            set_training_mode=not args.finetune  # keep in eval mode during finetuning
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                # get_state_dict presumably calls .state_dict() on its
                # argument, so it cannot be given None when EMA is disabled.
                'model_ema': get_state_dict(model_ema) if model_ema is not None else None,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }, output_dir / 'checkpoint.pth')

        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def validate(args):
    """Build, optionally transform, time and evaluate a model on a dataset.

    Steps:
    1. Load the dataset (folder or ``.tar``) and set ``args.num_classes``
       from its class map.
    2. Create the model from the architecture options on ``args`` and
       optionally load ``args.checkpoint``.
    3. Then do one of:
       - convert a ``mobilenasnet`` model to MobileNet form
         (``args.transform_model_to_mobilenet``), or
       - fold fixed ImageNet mean/std input normalisation into the stem
         conv + BN (``args.normalize_weights``).
    4. Optionally fuse BN.
    5. Measure latency, parameter count and FLOPs.
    6. Run the evaluation loop on the GPU.

    Side effects: overwrites ``args.pretrained``, ``args.prefetcher`` and
    ``args.num_classes``.

    Returns:
        OrderedDict with top-1 / top-k accuracy and error (percent; k is
        ``min(5, num_classes)`` but the keys are still ``top5``), parameter
        count in millions, image size, crop percentage (key ``cropt_pct``,
        spelled that way for existing consumers) and interpolation mode.
    """
    args.pretrained = args.pretrained or (not args.checkpoint)
    args.prefetcher = not args.no_prefetcher

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    logging.info(f'Validation data has {len(dataset)} images')
    args.num_classes = len(dataset.class_to_idx)
    logging.info(f'setting num classes to {args.num_classes}')

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        scriptable=args.torchscript,
        resnet_structure=args.resnet_structure,
        resnet_block=args.resnet_block,
        heaviest_network=args.heaviest_network,
        use_kernel_3=args.use_kernel_3,
        exp_r=args.exp_r,
        depth=args.depth,
        reduced_exp_ratio=args.reduced_exp_ratio,
        use_dedicated_pwl_se=args.use_dedicated_pwl_se,
        multipath_sampling=args.multipath_sampling,
        force_sync_gpu=args.force_sync_gpu,
        mobilenet_string=args.mobilenet_string if not args.transform_model_to_mobilenet else '',
        no_swish=args.no_swish,
        use_swish=args.use_swish)
    data_config = resolve_data_config(vars(args), model=model)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, True, strict=True)

    if 'mobilenasnet' in args.model and args.transform_model_to_mobilenet:
        model.eval()
        expected_latency = model.extract_expected_latency(
            file_name=args.lut_filename,
            batch_size=args.lut_measure_batch_size,
            iterations=args.repeat_measure,
            target=args.target_device)
        model.eval()
        converted_model, _ = transform_model_to_mobilenet(
            model, mobilenet_string=args.mobilenet_string)
        del model
        model = converted_model
        model.eval()
        print('Model converted. Expected latency: {:0.2f}[ms]'.format(
            expected_latency * 1e3))
    elif args.normalize_weights:
        # Fold the per-channel input normalisation (x - mean) / std into the
        # stem. Divide the conv weights by std, and move the constant
        # -sum(W * mean / std) term into bn1's bias, scaled by
        # bn1.weight / sqrt(running_var + eps).
        # NOTE(review): if conv_stem pads with zeros, border outputs are
        # only approximately preserved.
        imagenet_mean = (0.485, 0.456, 0.406)
        imagenet_std = (0.229, 0.224, 0.225)
        std = torch.tensor(imagenet_std).view(1, -1, 1, 1)
        mean = torch.tensor(imagenet_mean).view(1, -1, 1, 1)
        stem_w = model.conv_stem.weight.data
        bn_weight = model.bn1.weight.data
        bn_bias = model.bn1.bias.data
        model.conv_stem.weight.data = stem_w / std
        # Uses the *original* (unscaled) stem weights.
        bias = -bn_weight * (stem_w.sum(dim=[-1, -2]) @ (mean / std).squeeze()) / (
            torch.sqrt(model.bn1.running_var + model.bn1.eps))
        model.bn1.bias.data = bn_bias + bias

    if args.fuse_bn:
        model = fuse_bn(model)

    # For 'gpu' and the default target the first measurement is discarded
    # (presumably a warm-up) and the second one is reported.
    if args.target_device == 'gpu':
        measure_time(model, batch_size=64, target='gpu')
        t = measure_time(model, batch_size=64, target='gpu')
    elif args.target_device == 'onnx':
        t = measure_time_onnx(model)
    else:
        measure_time(model)
        t = measure_time(model)

    param_count = sum(p.numel() for p in model.parameters())
    flops = compute_flops(model, data_config['input_size'])
    logging.info(
        'Model {} created, param count: {}, flops: {}, Measured latency ({}): {:0.2f}[ms]'
        .format(args.model, param_count, flops / 1e9, args.target_device, t * 1e3))

    data_config = resolve_data_config(vars(args), model=model, verbose=False)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        # torch.jit.optimized_execution is a context manager; calling it
        # without ``with`` (as done previously) had no effect, so it is omitted.
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    # With test-time pooling active, no centre crop is applied.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        squish=args.squish,
    )

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # The second accuracy metric is top-k with k capped at the class count.
    topk = (1, min(5, args.num_classes))

    model.cuda()
    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        images = torch.randn((args.batch_size, ) + data_config['input_size']).cuda()
        model(images)
        end = time.time()
        for i, (images, target) in enumerate(loader):
            if i == 0:
                # Exclude loader start-up from the first batch's timing.
                end = time.time()
            if args.no_prefetcher:
                target = target.cuda()
                images = images.cuda()
            if args.amp:
                images = images.half()

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data, target, topk=topk)
            batch_size = images.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=batch_size / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results