@torch.no_grad()
def evaluate(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
    # note: this prints statistics of the last batch's output only
    print(output.mean().item(), output.std().item())

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
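# ---------------------------------------------------------------------------
# The `accuracy` helper called throughout this file is not shown. Below is a
# minimal sketch in the style of the pytorch/examples top-k metric; the exact
# implementation here may differ (returning percentages is an assumption,
# consistent with the `acc1 /= 100` normalization in one validate() variant).
# ---------------------------------------------------------------------------
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (as a percentage) over a batch of logits."""
    maxk = max(topk)
    batch_size = target.size(0)
    # indices of the k highest-scoring classes per sample, shape [maxk, batch]
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size for k in topk]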
def validate(self, net, loader, loss_fn, amp_autocast=suppress, metric_name=None):
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()
    rmse_m = AverageMeter()

    net.eval()
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            if not self._misc_cfg.prefetcher:
                input = input.to(self.ctx[0])
                target = target.to(self.ctx[0])
            with amp_autocast():
                output = net(input)
                if self._problem_type == REGRESSION:
                    output = output.flatten()
                if isinstance(output, (tuple, list)):
                    output = output[0]
                if self._problem_type == REGRESSION:
                    if metric_name:
                        assert metric_name == 'rmse', f'{metric_name} metric not supported for regression.'
                    val_metric_score = rmse(output, target)
                else:
                    val_metric_score = accuracy(output, target, topk=(1, min(5, self.num_class)))
                # augmentation reduction
                reduce_factor = self._misc_cfg.tta
                if self._problem_type != REGRESSION and reduce_factor > 1:
                    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    target = target[0:target.size(0):reduce_factor]
                loss = loss_fn(output, target)
            reduced_loss = loss.data
            if self.found_gpu:
                torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            if self._problem_type == REGRESSION:
                rmse_score = val_metric_score
                rmse_m.update(rmse_score.item(), output.size(0))
            else:
                acc1, acc5 = val_metric_score
                acc1 /= 100
                acc5 /= 100
                top1_m.update(acc1.item(), output.size(0))
                top5_m.update(acc5.item(), output.size(0))

    if self._problem_type == REGRESSION:
        self._logger.info('[Epoch %d] validation: rmse=%f', self.epoch, rmse_m.avg)
        return {'loss': losses_m.avg, 'rmse': rmse_m.avg}
    else:
        self._logger.info('[Epoch %d] validation: top1=%f top5=%f', self.epoch, top1_m.avg, top5_m.avg)
        return {'loss': losses_m.avg, 'top1': top1_m.avg, 'top5': top5_m.avg}
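# For the regression branch above, `rmse` is assumed to be a plain
# root-mean-squared error over the flattened outputs; a minimal sketch
# (the helper name and reduction are assumptions, not this repo's code):
import torch

def rmse(output, target):
    # returns a 0-dim tensor, so callers can use .item() as above
    return torch.sqrt(torch.mean((output.float() - target.float()) ** 2))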
@torch.no_grad()
def evaluate(data_loader, model, device, amp=True, distill_token=False,
             choices=None, mode='super', retrain_config=None):
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    if mode == 'super':
        config = sample_configs(choices=choices)
        model_module = unwrap_model(model)
        model_module.set_sample_config(config=config)
    else:
        config = retrain_config
        model_module = unwrap_model(model)
        model_module.set_sample_config(config=config)

    print("sampled model config: {}".format(config))
    parameters = model_module.get_sampled_params_numel(config)
    print("sampled model parameters: {}".format(parameters))

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        if amp:
            with torch.cuda.amp.autocast():
                if distill_token:
                    output_cls, output_dis = model(images)
                    output = (output_cls + output_dis) / 2
                else:
                    output = model(images)
                loss = criterion(output, target)
        else:
            if distill_token:
                output_cls, output_dis = model(images)
                output = (output_cls + output_dis) / 2
            else:
                output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
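# `unwrap_model` above is assumed to strip a DataParallel/DistributedDataParallel
# wrapper so that supernet-specific methods like `set_sample_config` can be
# reached; a minimal sketch using the standard `.module` convention:
def unwrap_model(model):
    # DDP/DP wrap the real network in a `.module` attribute
    return model.module if hasattr(model, 'module') else model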
def validation_step(self, batch, batch_idx):
    samples, targets = batch
    targets = targets.long()
    outputs = self(samples)
    loss = F.cross_entropy(outputs, targets)
    acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
    self.log_dict({
        'validation loss': loss,
        'acc1': acc1,
        'acc5': acc5
    }, on_epoch=True)
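# A hypothetical way this LightningModule hook gets driven; `MyClassifier`
# and `val_loader` are placeholders, not names from this codebase:
import pytorch_lightning as pl

model = MyClassifier()  # the LightningModule that defines validation_step above
trainer = pl.Trainer(accelerator='gpu', devices=1)
trainer.validate(model, dataloaders=val_loader)  # calls validation_step per batch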
def validate(model, loader, criterion, log_freq=50):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            target = target.cuda()
            input = input.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4))
    logging.info(' * Acc@1 {:.1f} ({:.3f}) Acc@5 {:.1f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results
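# `AverageMeter` (used by nearly every loop in this file) is the classic
# running-average tracker; a minimal sketch matching the `.update(val, n)`,
# `.val`, `.avg`, and `.count` usage above (the distributed `.sync()` method
# some variants call is not shown):
class AverageMeter:
    """Tracks the latest value and the count-weighted running average."""
    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count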
@torch.no_grad()
def validate(args, config, data_loader, model, num_classes=1000):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output = model(images)

        if num_classes == 1000:
            output_num_classes = output.size(-1)
            if output_num_classes == 21841:
                output = remap_layer_22kto1k(output)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')

    acc1_meter.sync()
    acc5_meter.sync()
    logger.info(f' The number of validation samples is {int(acc1_meter.count)}')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
@torch.no_grad()
def validate(config, data_loader, model):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # For distributed training:
        # acc1 = reduce_tensor(acc1)
        # acc5 = reduce_tensor(acc5)
        # loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')

    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
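# The commented-out distributed branch above needs a `reduce_tensor` helper;
# a minimal sketch, assuming an all-reduce mean across ranks:
import torch.distributed as dist

def reduce_tensor(tensor):
    # sum the metric across all processes, then average by world size
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt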
def train_one_epoch_distill(config, model, model_teacher, data_loader, optimizer, epoch,
                            mixup_fn, lr_scheduler, criterion_soft=None, criterion_truth=None,
                            criterion_attn=None, criterion_hidden=None):
    layer_id_s_list = config.DISTILL.STUDENT_LAYER_LIST
    layer_id_t_list = config.DISTILL.TEACHER_LAYER_LIST
    model.train()
    optimizer.zero_grad()
    model_teacher.eval()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    loss_soft_meter = AverageMeter()
    loss_truth_meter = AverageMeter()
    loss_attn_meter = AverageMeter()
    loss_hidden_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    teacher_acc1_meter = AverageMeter()
    teacher_acc5_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        original_targets = targets
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if config.DISTILL.ATTN_LOSS and config.DISTILL.HIDDEN_LOSS:
            outputs, qkv_s, hidden_s = model(
                samples, layer_id_s_list, is_attn_loss=True, is_hidden_loss=True,
                is_hidden_org=config.DISTILL.HIDDEN_RELATION)
        elif config.DISTILL.ATTN_LOSS:
            outputs, qkv_s = model(
                samples, layer_id_s_list, is_attn_loss=True, is_hidden_loss=False,
                is_hidden_org=config.DISTILL.HIDDEN_RELATION)
        elif config.DISTILL.HIDDEN_LOSS:
            outputs, hidden_s = model(
                samples, layer_id_s_list, is_attn_loss=False, is_hidden_loss=True,
                is_hidden_org=config.DISTILL.HIDDEN_RELATION)
        else:
            outputs = model(samples)

        with torch.no_grad():
            acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
            if config.DISTILL.ATTN_LOSS or config.DISTILL.HIDDEN_LOSS:
                outputs_teacher, qkv_t, hidden_t = model_teacher(
                    samples, layer_id_t_list, is_attn_loss=True, is_hidden_loss=True)
            else:
                outputs_teacher = model_teacher(samples)
            teacher_acc1, teacher_acc5 = accuracy(outputs_teacher, original_targets, topk=(1, 5))

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            loss_truth = config.DISTILL.ALPHA * criterion_truth(outputs, targets)
            loss_soft = (1.0 - config.DISTILL.ALPHA) * criterion_soft(
                outputs / config.DISTILL.TEMPERATURE,
                outputs_teacher / config.DISTILL.TEMPERATURE)
            if config.DISTILL.ATTN_LOSS:
                loss_attn = config.DISTILL.QKV_LOSS_WEIGHT * criterion_attn(
                    qkv_s, qkv_t, config.DISTILL.AR)
            else:
                # zeros_like keeps the placeholder loss on the same device as loss_truth
                loss_attn = torch.zeros_like(loss_truth)
            if config.DISTILL.HIDDEN_LOSS:
                loss_hidden = config.DISTILL.HIDDEN_LOSS_WEIGHT * criterion_hidden(
                    hidden_s, hidden_t)
            else:
                loss_hidden = torch.zeros_like(loss_truth)
            loss = loss_truth + loss_soft + loss_attn + loss_hidden
            loss = loss / config.TRAIN.ACCUMULATION_STEPS
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            loss_truth = config.DISTILL.ALPHA * criterion_truth(outputs, targets)
            loss_soft = (1.0 - config.DISTILL.ALPHA) * criterion_soft(
                outputs / config.DISTILL.TEMPERATURE,
                outputs_teacher / config.DISTILL.TEMPERATURE)
            if config.DISTILL.ATTN_LOSS:
                loss_attn = config.DISTILL.QKV_LOSS_WEIGHT * criterion_attn(
                    qkv_s, qkv_t, config.DISTILL.AR)
            else:
                loss_attn = torch.zeros_like(loss_truth)
            if config.DISTILL.HIDDEN_LOSS:
                loss_hidden = config.DISTILL.HIDDEN_LOSS_WEIGHT * criterion_hidden(
                    hidden_s, hidden_t)
            else:
                loss_hidden = torch.zeros_like(loss_truth)
            loss = loss_truth + loss_soft + loss_attn + loss_hidden

            optimizer.zero_grad()
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        loss_soft_meter.update(loss_soft.item(), targets.size(0))
        loss_truth_meter.update(loss_truth.item(), targets.size(0))
        loss_attn_meter.update(loss_attn.item(), targets.size(0))
        loss_hidden_meter.update(loss_hidden.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()
        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))
        teacher_acc1_meter.update(teacher_acc1.item(), targets.size(0))
        teacher_acc5_meter.update(teacher_acc5.item(), targets.size(0))

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}\t'
                f'Teacher_Acc@1 {teacher_acc1_meter.avg:.3f} Teacher_Acc@5 {teacher_acc5_meter.avg:.3f}\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'loss_soft {loss_soft_meter.val:.4f} ({loss_soft_meter.avg:.4f})\t'
                f'loss_truth {loss_truth_meter.val:.4f} ({loss_truth_meter.avg:.4f})\t'
                f'loss_attn {loss_attn_meter.val:.4f} ({loss_attn_meter.avg:.4f})\t'
                f'loss_hidden {loss_hidden_meter.val:.4f} ({loss_hidden_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
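# `criterion_soft` above receives temperature-scaled student and teacher
# logits. A common choice for it (an assumption here, not necessarily what
# this repo uses) is a soft-target KL divergence; note that the callers above
# already divide both logit sets by TEMPERATURE before calling:
import torch.nn.functional as F

def soft_kl_loss(student_logits, teacher_logits):
    # KL(teacher || student) between softened distributions, averaged per sample
    log_p_student = F.log_softmax(student_logits, dim=-1)
    p_teacher = F.softmax(teacher_logits, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean')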
def validate(args):
    args.pretrained = args.pretrained or (not args.checkpoint)
    args.prefetcher = not args.no_prefetcher

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    logging.info(f'Validation data has {len(dataset)} images')
    args.num_classes = len(dataset.class_to_idx)
    logging.info(f'setting num classes to {args.num_classes}')

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        scriptable=args.torchscript,
        resnet_structure=args.resnet_structure,
        resnet_block=args.resnet_block,
        heaviest_network=args.heaviest_network,
        use_kernel_3=args.use_kernel_3,
        exp_r=args.exp_r,
        depth=args.depth,
        reduced_exp_ratio=args.reduced_exp_ratio,
        use_dedicated_pwl_se=args.use_dedicated_pwl_se,
        multipath_sampling=args.multipath_sampling,
        force_sync_gpu=args.force_sync_gpu,
        mobilenet_string=args.mobilenet_string if not args.transform_model_to_mobilenet else '',
        no_swish=args.no_swish,
        use_swish=args.use_swish)

    data_config = resolve_data_config(vars(args), model=model)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, True, strict=True)

    if 'mobilenasnet' in args.model and args.transform_model_to_mobilenet:
        model.eval()
        expected_latency = model.extract_expected_latency(
            file_name=args.lut_filename,
            batch_size=args.lut_measure_batch_size,
            iterations=args.repeat_measure,
            target=args.target_device)
        model.eval()
        model2, string_model = transform_model_to_mobilenet(
            model, mobilenet_string=args.mobilenet_string)
        del model
        model = model2
        model.eval()
        print('Model converted. Expected latency: {:0.2f}[ms]'.format(expected_latency * 1e3))
    elif args.normalize_weights:
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
        std = torch.tensor(IMAGENET_DEFAULT_STD).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        mean = torch.tensor(IMAGENET_DEFAULT_MEAN).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        W = model.conv_stem.weight.data
        bnw = model.bn1.weight.data
        bnb = model.bn1.bias.data
        model.conv_stem.weight.data = W / std
        bias = -bnw.data * (W.sum(dim=[-1, -2]) @ (mean / std).squeeze()) / (
            torch.sqrt(model.bn1.running_var + model.bn1.eps))
        model.bn1.bias.data = bnb + bias

    if args.fuse_bn:
        model = fuse_bn(model)

    if args.target_device == 'gpu':
        measure_time(model, batch_size=64, target='gpu')
        t = measure_time(model, batch_size=64, target='gpu')
    elif args.target_device == 'onnx':
        t = measure_time_onnx(model)
    else:
        measure_time(model)
        t = measure_time(model)

    param_count = sum([m.numel() for m in model.parameters()])
    flops = compute_flops(model, data_config['input_size'])
    logging.info(
        'Model {} created, param count: {}, flops: {}, Measured latency ({}): {:0.2f}[ms]'
        .format(args.model, param_count, flops / 1e9, args.target_device, t * 1e3))

    data_config = resolve_data_config(vars(args), model=model, verbose=False)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)
    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()
    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        squish=args.squish,
    )

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.cuda()
    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
        model(input)
        end = time.time()
        for i, (input, target) in enumerate(loader):
            if i == 0:
                end = time.time()
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
            if args.amp:
                input = input.half()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            k = min(5, args.num_classes)
            acc1, acc5 = accuracy(output.data, target, topk=(1, k))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])
    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning("Neither APEX or Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, train_mode='val', fold_num=args.fold_num,
                          load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    # top5 = AverageMeter()
    f1_m = AverageMeter()

    model.eval()
    last_idx = len(loader) - 1
    cuda = torch.device('cuda')
    temperature = nn.Parameter(torch.ones(1) * 1.5).to(cuda).detach().requires_grad_(True)
    m = nn.Sigmoid()
    nll_criterion = nn.CrossEntropyLoss().cuda()
    ece_criterion = _ECELoss().cuda()

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()

        logits_list = []
        target_list = []
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, _ = accuracy(output.detach(), target, topk=(1, 1))
            logits_list.append(output)
            target_list.append(target)

            best_f1 = 0.0
            best_th = 1.0
            if last_batch:
                logits = torch.cat(logits_list).cuda()
                targets = torch.cat(target_list).cuda()
                targets_cpu = targets.cpu().numpy()
                sigmoided = m(logits)[:, 1].cpu().numpy()
                # sweep the decision threshold for the best F1
                for i in range(1000, 0, -1):
                    th = i * 0.001
                    real_pred = (sigmoided >= th) * 1.0
                    f1 = f1_score(targets_cpu.squeeze(), real_pred.squeeze())
                    if f1 > best_f1:
                        best_f1 = f1
                        best_th = th

            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'thresh: {thresh:>7.4f} '
                    'f1: {f1:>7.4f}'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, thresh=best_th, f1=best_f1))

    print(best_th, best_f1)

    # for temp_scaling
    if args.temp_scaling:
        # before_temperature_ece = ece_criterion(logits, targets).item()
        # before_temperature_nll = nll_criterion(logits, targets).item()
        # print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece))
        # optimizer = optim.LBFGS([temperature], lr=0.01, max_iter=50)
        # def eval():
        #     unsqueezed_temperature = temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))
        #     loss = nll_criterion(logits/unsqueezed_temperature, targets)
        #     loss.backward()
        #     return loss
        # optimizer.step(eval)
        # unsqueezed_temperature = temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))
        # logits = logits/unsqueezed_temperature
        # after_temperature_nll = nll_criterion(logits, targets).item()
        # after_temperature_ece = ece_criterion(logits, targets).item()
        # print('Optimal temperature: %.3f' % temperature.item())
        # print('After temperature - NLL: %.3f, ECE: %.3f' % (after_temperature_nll, after_temperature_ece))
        sigmoided = m(logits)[:, 1].detach().cpu().numpy()
        temperature = nn.Parameter(torch.ones(1) * 11).to(cuda).detach().requires_grad_(False)
        logits = logits / temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))
        targets_cpu = targets.cpu().numpy()
        sigmoided = m(logits)[:, 1].detach().cpu().numpy()
        best_f1 = 0.0
        best_th = 1.0
        for i in range(1000, 0, -1):
            th = i * 0.001
            real_pred = (sigmoided >= th) * 1.0
            f1 = f1_score(targets_cpu.squeeze(), real_pred.squeeze())
            if f1 > best_f1:
                best_f1 = f1
                best_th = th
        print(best_th, best_f1)

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a = real_labels.get_accuracy(k=1)
    else:
        top1a = top1.avg
    # f1 always comes from the threshold sweep, regardless of label mode
    f1a = best_f1

    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        f1=f1a, f1_err=round(100 - f1a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])
    _logger.info(' * Acc@1 {:.3f} ({:.3f}) f1 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['f1'], results['f1_err']))
    return results
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()
    if args.fp16:
        model = model.half()

    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        fp16=args.fp16,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    c_matrix = np.zeros((40, 40), dtype=int)
    labels = np.arange(0, 40, 1)

    model.eval()
    end = time.time()
    with torch.no_grad():
        cf = open('results.csv', 'w')
        cv = open('results-parent.csv', 'w')
        writer = csv.writer(cf)
        writer_2 = csv.writer(cv)
        for i, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))
            c_matrix += cal_confusions(output, target, labels=labels)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            writer.writerow([i, round(top1.avg, 4)])

            # compute parent-class (super-class) accuracy
            if args.hier_classify:
                a = [i for i in range(0, 6)]
                b = [i for i in range(6, 14)]
                c = [i for i in range(14, 37)]
                d = [i for i in range(37, 40)]
                corrects = 0.
                corrects += c_matrix[a][:, a].sum()
                corrects += c_matrix[b][:, b].sum()
                corrects += c_matrix[c][:, c].sum()
                corrects += c_matrix[d][:, d].sum()
                writer_2.writerow([i, round(corrects / c_matrix.sum(), 4)])
                logging.info('parent precision: {}'.format(corrects / c_matrix.sum()))

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))
        cf.close()
        cv.close()

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])
    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    logging.info('confusion_matrix: \n {}'.format(c_matrix))
    logging.info('precision by confusion matrix: \n {}'.format(
        truediv(np.sum(np.diag(c_matrix)), np.sum(np.sum(c_matrix, axis=1)))))
    # with open('confusion_matrix.csv', 'w') as cf:
    #     writer = csv.writer(cf)
    #     for row in c_matrix:
    #         writer.writerow(row)
    #
    # diag = np.diag(c_matrix)
    # each_acc = truediv(diag, np.sum(c_matrix, axis=1))
    # writer.writerow(each_acc)
    return results
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    if args.legacy_jit:
        set_jit_legacy()

    # create model
    if 'inception' in args.model:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            aux_logits=True,  # ! add aux loss
            in_chans=3,
            scriptable=args.torchscript)
    else:
        model = create_model(
            args.model,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            in_chans=3,
            scriptable=args.torchscript)

    # ! add more layers to the classifier head
    if args.create_classifier_layerfc:
        model.global_pool, model.classifier = create_classifier_layerfc(
            model.num_features, model.num_classes)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    if args.has_eval_label:
        criterion = nn.CrossEntropyLoss().cuda()  # ! otherwise we don't have gold labels

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map, args=args)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            # @valid_labels is index numbering
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        # default interpolation is Image.BILINEAR, see
        # https://github.com/rwightman/pytorch-image-models/blob/470220b1f4c61ad7deb16dbfb8917089e842cd2a/timm/data/transforms.py#L43
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing,
        auto_augment=args.aa,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        args=args)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    topk = AverageMeter()
    prediction = None  # ! need to save output
    true_label = None

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            # ! the test set may not have real labels; save the true labels when we do
            if args.has_eval_label:
                if true_label is None:
                    true_label = target.cpu().data.numpy()
                else:
                    true_label = np.concatenate((true_label, target.cpu().data.numpy()), axis=0)

            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]  # ! some models return both output + aux output

            if valid_labels is not None:
                output = output[:, valid_labels]  # ! keep only valid labels; good to eval by class

            # ! save predictions; appending per batch is slow but simple
            # ! are names of files also sorted ?
            if prediction is None:
                prediction = output.cpu().data.numpy()  # batchsize x label
            else:
                prediction = np.concatenate((prediction, output.cpu().data.numpy()), axis=0)

            if real_labels is not None:
                real_labels.add_result(output)

            if args.has_eval_label:
                # measure accuracy and record loss
                loss = criterion(output, target)  # ! no gold standard on the test set otherwise
                acc1, acc5 = accuracy(output.data, target, topk=(1, args.topk))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1.item(), input.size(0))
                topk.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if args.has_eval_label and (batch_idx % args.log_freq == 0):
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@topk: {topk.val:>7.3f} ({topk.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, topk=topk))

    if not args.has_eval_label:
        top1a, topka = 0, 0  # just dummy values, because we don't know the ground labels
    else:
        if real_labels is not None:
            # real labels mode replaces topk values at the end
            top1a, topka = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=args.topk)
        else:
            top1a, topka = top1.avg, topk.avg

    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        topk=round(topka, 4), topk_err=round(100 - topka, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])
    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@topk {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['topk'], results['topk_err']))
    return results, prediction, true_label
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # amp_autocast = suppress  # do nothing
    # if args.amp:
    #     if has_native_amp:
    #         args.native_amp = True
    #     elif has_apex:
    #         args.apex_amp = True
    #     else:
    #         _logger.warning("Neither APEX or Native Torch AMP is available.")
    # assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    # if args.native_amp:
    #     amp_autocast = torch.cuda.amp.autocast
    #     _logger.info('Validating in mixed precision with native PyTorch AMP.')
    # elif args.apex_amp:
    #     _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    # else:
    #     _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)

    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    # model = model.cuda()
    # if args.apex_amp:
    #     model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # if args.num_gpu > 1:
    #     model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    # criterion = nn.CrossEntropyLoss().cuda()
    criterion = nn.CrossEntropyLoss()

    dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    # added for post-quantization calibration
    calib_dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    # also create a loader for the calibration dataset
    calib_loader = create_loader(
        calib_dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    print('Start calibration of quantization observers before post-quantization')
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.eval()

    # post-training static quantization
    if args.quant_option == 'static':
        qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
        # prepare: insert observers into the copied model
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # calibrate
        with torch.no_grad():
            # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
            input = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)
            model(input)
            end = time.time()
            # run the calibration loader through the observed model so the
            # observers record activation ranges (the original loop never
            # performed this forward pass, leaving `output` undefined)
            for batch_idx, (input, target) in enumerate(calib_loader):
                if args.channels_last:
                    input = input.contiguous(memory_format=torch.channels_last)
                output = model_prepared(input)
                if valid_labels is not None:
                    output = output[:, valid_labels]
                loss = criterion(output, target)
                if real_labels is not None:
                    real_labels.add_result(output)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1.item(), input.size(0))
                top5.update(acc5.item(), input.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if batch_idx % args.log_freq == 0:
                    _logger.info(
                        'Test: [{0:>4d}/{1}] '
                        'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                        'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                            batch_idx, len(loader), batch_time=batch_time,
                            rate_avg=input.size(0) / batch_time.avg,
                            loss=losses, top1=top1, top5=top5))
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)

    # post-training dynamic/weight-only quantization
    elif args.quant_option == 'dynamic':
        qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
        # prepare
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # no calibration needed for dynamic/weight-only quantization
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)
    else:
        _logger.warning("Invalid quantization option. Set option to default (static).")

    # fusion
    model_fused = quantize_fx.fuse_fx(model_to_quantize)
    # note: evaluation below runs on the fused model; `model_quantized` from
    # above is left unused here, matching the original control flow
    model = model_fused

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        # input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
        input = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            # if args.no_prefetcher:
            #     target = target.cuda()
            #     input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output (the `with amp_autocast():` wrapper is commented
            # out above, so the forward pass runs in float32)
            output = model(input)
            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)
            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg

    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])
    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(model, args)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    loader = create_loader(
        Dataset(args.data, load_bytes=args.tf_preprocessing),
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            target = target.cuda()
            input = input.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
        top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
        param_count=round(param_count / 1e6, 2))
    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results
def validate(args):
    # might as well try to validate something
    args.pretrained = False
    args.prefetcher = True

    # create model
    model = eval(args.model)(
        config_path=args.config_path,
        target_flops=args.target_flops,
        num_classes=args.num_classes,
        bn_momentum=args.bn_momentum,
        activation=args.activation,
        se=args.se)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, True)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    # model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    if args.lmdb:
        eval_dir = os.path.join(args.data, 'test_lmdb', 'test.lmdb')
        dataset_eval = ImageFolderLMDB(eval_dir, None, None)
    else:
        eval_dir = os.path.join(args.data, 'val')
        dataset_eval = Dataset(eval_dir)

    # crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    crop_pct = 1.0
    loader = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers)
        # crop_pct=crop_pct)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])
    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results
def train_one_epoch_distill_using_saved_logits(args, config, model, criterion, data_loader,
                                               optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler):
    model.train()
    set_bn_state(config, model)
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    scaler_meter = AverageMeter()
    meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()
    data_tic = time.time()

    num_classes = config.MODEL.NUM_CLASSES
    topk = config.DISTILL.LOGITS_TOPK

    for idx, ((samples, targets), (logits_index, logits_value, seeds)) in enumerate(data_loader):
        normal_global_idx = epoch * NORM_ITER_LEN + \
            (idx * NORM_ITER_LEN // num_steps)

        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets, seeds)
            original_targets = targets.argmax(dim=1)
        else:
            original_targets = targets
        meters['data_time'].update(time.time() - data_tic)

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)

        # recover teacher logits
        logits_index = logits_index.long()
        logits_value = logits_value.float()
        logits_index = logits_index.cuda(non_blocking=True)
        logits_value = logits_value.cuda(non_blocking=True)
        minor_value = (1.0 - logits_value.sum(-1, keepdim=True)) / (num_classes - topk)
        minor_value = minor_value.repeat_interleave(num_classes, dim=-1)
        outputs_teacher = minor_value.scatter_(-1, logits_index, logits_value)

        loss = criterion(outputs, outputs_teacher)
        loss = loss / config.TRAIN.ACCUMULATION_STEPS

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss, optimizer,
                                clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(),
                                create_graph=is_second_order,
                                update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            optimizer.zero_grad()
            lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        # compute accuracy
        real_batch_size = len(original_targets)
        acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
        meters['train_acc1'].update(acc1.item(), real_batch_size)
        meters['train_acc5'].update(acc5.item(), real_batch_size)
        teacher_acc1, teacher_acc5 = accuracy(outputs_teacher, original_targets, topk=(1, 5))
        meters['teacher_acc1'].update(teacher_acc1.item(), real_batch_size)
        meters['teacher_acc5'].update(teacher_acc5.item(), real_batch_size)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), real_batch_size)
        if is_valid_grad_norm(grad_norm):
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.time() - end)
        end = time.time()
        data_tic = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            extra_meters_str = ''
            for k, v in meters.items():
                extra_meters_str += f'{k} {v.val:.4f} ({v.avg:.4f})\t'
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
                f'{extra_meters_str}'
                f'mem {memory_used:.0f}MB')
            if is_main_process() and args.use_wandb:
                acc1_meter, acc5_meter = meters['train_acc1'], meters['train_acc5']
                wandb.log(
                    {
                        "train/acc@1": acc1_meter.val,
                        "train/acc@5": acc5_meter.val,
                        "train/loss": loss_meter.val,
                        "train/grad_norm": norm_meter.val,
                        "train/loss_scale": scaler_meter.val,
                        "train/lr": lr,
                    },
                    step=normal_global_idx)

    epoch_time = time.time() - start
    extra_meters_str = f'Train-Summary: [{epoch}/{config.TRAIN.EPOCHS}]\t'
    for k, v in meters.items():
        v.sync()
        extra_meters_str += f'{k} {v.val:.4f} ({v.avg:.4f})\t'
    logger.info(extra_meters_str)
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
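# How the top-k teacher-logit recovery above behaves on a toy case (shapes and
# values are illustrative only): the saved top-k probabilities are scattered
# back into a full distribution, and the leftover probability mass is spread
# uniformly over the remaining classes.
import torch

num_classes, topk = 5, 2
logits_index = torch.tensor([[0, 3]])        # saved top-k class ids
logits_value = torch.tensor([[0.6, 0.3]])    # saved top-k probabilities
minor = (1.0 - logits_value.sum(-1, keepdim=True)) / (num_classes - topk)
teacher = minor.repeat_interleave(num_classes, dim=-1)
teacher = teacher.scatter_(-1, logits_index, logits_value)
print(teacher)  # tensor([[0.6000, 0.0333, 0.0333, 0.3000, 0.0333]])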
def train_one_epoch(args, config, model, criterion, data_loader, optimizer, epoch,
                    mixup_fn, lr_scheduler, loss_scaler):
    model.train()
    set_bn_state(config, model)
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    scaler_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        normal_global_idx = epoch * NORM_ITER_LEN + \
            (idx * NORM_ITER_LEN // num_steps)

        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)
            original_targets = targets.argmax(dim=1)
        else:
            original_targets = targets

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)
            loss = criterion(outputs, targets)
        loss = loss / config.TRAIN.ACCUMULATION_STEPS

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss, optimizer,
                                clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(),
                                create_graph=is_second_order,
                                update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            optimizer.zero_grad()
            lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        with torch.no_grad():
            acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        if is_valid_grad_norm(grad_norm):
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
            if is_main_process() and args.use_wandb:
                wandb.log(
                    {
                        "train/acc@1": acc1_meter.val,
                        "train/acc@5": acc5_meter.val,
                        "train/loss": loss_meter.val,
                        "train/grad_norm": norm_meter.val,
                        "train/loss_scale": scaler_meter.val,
                        "train/lr": lr,
                    },
                    step=normal_global_idx)

    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
def validate(args):
    # create model
    from tinynet import tinynet
    if args.model_name == 'tinynet_a':
        args.r = 0.86
        args.w = 1.0
        args.d = 1.2
        ckpt_path = './models/tinynet_a.pth'
    elif args.model_name == 'tinynet_b':
        args.r = 0.84
        args.w = 0.75
        args.d = 1.1
        ckpt_path = './models/tinynet_b.pth'
    elif args.model_name == 'tinynet_c':
        args.r = 0.825
        args.w = 0.54
        args.d = 0.85
        ckpt_path = './models/tinynet_c.pth'
    elif args.model_name == 'tinynet_d':
        args.r = 0.68
        args.w = 0.54
        args.d = 0.695
        ckpt_path = './models/tinynet_d.pth'
    elif args.model_name == 'tinynet_e':
        args.r = 0.475
        args.w = 0.51
        args.d = 0.60
        ckpt_path = './models/tinynet_e.pth'
    else:
        # raising a bare string is a TypeError in Python 3
        raise ValueError('Unsupported model name.')

    model = tinynet(r=args.r, w=args.w, d=args.d)
    state_dict = torch.load(ckpt_path)
    model.load_state_dict(state_dict, strict=False)
    params = sum([param.numel() for param in model.parameters()])
    logging.info('Model %s created, #params: %d' % (args.model_name, params))

    data_config = resolve_data_config(vars(args), model=model)
    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    dataset = Dataset(args.data)
    data_loader = create_loader(dataset,
                                is_training=False,
                                input_size=data_config['input_size'],
                                batch_size=128,
                                use_prefetcher=False,
                                interpolation=data_config['interpolation'],
                                mean=data_config['mean'],
                                std=data_config['std'],
                                num_workers=4,
                                crop_pct=data_config['crop_pct'],
                                pin_memory=False)

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        for i, (input, target) in enumerate(data_loader):
            input = input.cuda()
            target = target.cuda()
            output = model(input)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))
            if i % 100 == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}] Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})'
                    .format(i, len(data_loader), loss=losses))
    logging.info(' * Acc@1 {:.3f} Acc@5 {:.3f}'.format(top1.avg, top5.avg))
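# The accuracy helper is also assumed rather than defined here. Its callers
# unpack per-k results and call .item(), and some loops below divide the
# result by 100, which implies it returns percentages. A sketch following the
# widely used reference implementation:
import torch

def accuracy(output, target, topk=(1,)):
    """Computes top-k accuracy (in percent) over a batch of logits."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)  # (B, maxk) highest-scoring classes
    pred = pred.t()                              # (maxk, B)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size for k in topk]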
def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch,
                    mixup_fn, lr_scheduler):
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        original_targets = targets

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        outputs = model(samples)
        with torch.no_grad():
            acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            loss = criterion(outputs, targets)
            loss = loss / config.TRAIN.ACCUMULATION_STEPS
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()
        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
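# get_grad_norm above is assumed to return the total norm of all parameter
# gradients when clipping is disabled. A sketch consistent with that usage
# (mirroring torch.nn.utils.clip_grad_norm_'s norm computation, without the
# clipping step):
import torch

def get_grad_norm(parameters, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    total_norm = 0.0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    return total_norm ** (1.0 / norm_type)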
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                input = input.cuda()
                target = target.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
                    '{0}: [{1:>4d}/{2}] '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name, batch_idx, last_idx,
                        batch_time=batch_time_m, loss=losses_m,
                        top1=top1_m, top5=top5_m))

    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    return metrics
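# reduce_tensor above is assumed to average a metric across distributed
# workers so every rank logs the same value. A minimal sketch using
# torch.distributed (requires an initialized process group):
import torch.distributed as dist

def reduce_tensor(tensor, world_size):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # sum across all ranks
    rt /= world_size                            # then average
    return rt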
def train_one_epoch(self, epoch, net, loader, optimizer, loss_fn,
                    lr_scheduler=None, output_dir=None, amp_autocast=suppress,
                    loss_scaler=None, model_ema=None, mixup_fn=None,
                    time_limit=math.inf):
    start_tic = time.time()
    if self._augmentation_cfg.mixup_off_epoch and epoch >= self._augmentation_cfg.mixup_off_epoch:
        if self._misc_cfg.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False
    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    losses_m = AverageMeter()
    train_metric_score_m = AverageMeter()

    net.train()

    num_updates = epoch * len(loader)
    self._time_elapsed += time.time() - start_tic
    tic = time.time()
    last_tic = time.time()
    train_metric_name = 'accuracy'
    batch_idx = 0
    for batch_idx, (input, target) in enumerate(loader):
        b_tic = time.time()
        if self._time_elapsed > time_limit:
            return {'train_acc': train_metric_score_m.avg,
                    'train_loss': losses_m.avg,
                    'time_limit': True}
        if self._problem_type == REGRESSION:
            target = target.to(torch.float32)
        if not self._misc_cfg.prefetcher:
            # prefetcher would move data to cuda by default
            input, target = input.to(self.ctx[0]), target.to(self.ctx[0])
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)

        with amp_autocast():
            output = net(input)
            if self._problem_type == REGRESSION:
                output = output.flatten()
            loss = loss_fn(output, target)

        if self._problem_type == REGRESSION:
            train_metric_name = 'rmse'
            train_metric_score = rmse(output, target)
        else:
            if output.shape == target.shape:
                train_metric_name = 'rmse'
                train_metric_score = rmse(output, target)
            else:
                train_metric_score = accuracy(output, target)[0] / 100

        losses_m.update(loss.item(), input.size(0))
        train_metric_score_m.update(train_metric_score.item(), output.size(0))

        optimizer.zero_grad()
        if loss_scaler is not None:
            loss_scaler(
                loss, optimizer,
                clip_grad=self._optimizer_cfg.clip_grad,
                clip_mode=self._optimizer_cfg.clip_mode,
                parameters=model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode),
                create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if self._optimizer_cfg.clip_grad is not None:
                dispatch_clip_grad(
                    model_parameters(net, exclude_head='agc' in self._optimizer_cfg.clip_mode),
                    value=self._optimizer_cfg.clip_grad,
                    mode=self._optimizer_cfg.clip_mode)
            optimizer.step()

        if model_ema is not None:
            model_ema.update(net)

        if self.found_gpu:
            torch.cuda.synchronize()

        num_updates += 1
        if (batch_idx + 1) % self._misc_cfg.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
            self._logger.info(
                'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f',
                epoch, batch_idx,
                self._train_cfg.batch_size * self._misc_cfg.log_interval / (time.time() - last_tic),
                train_metric_name, train_metric_score_m.avg, lr)
            last_tic = time.time()

            if self._misc_cfg.save_images and output_dir:
                torchvision.utils.save_image(
                    input,
                    os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                    padding=0, normalize=True)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
        self._time_elapsed += time.time() - b_tic

    throughput = int(self._train_cfg.batch_size * batch_idx / (time.time() - tic))
    self._logger.info('[Epoch %d] training: %s=%f',
                      epoch, train_metric_name, train_metric_score_m.avg)
    self._logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f',
                      epoch, throughput, time.time() - tic)
    end_time = time.time()
    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()
    self._time_elapsed += time.time() - end_time
    return {train_metric_name: train_metric_score_m.avg,
            'train_loss': losses_m.avg,
            'time_limit': False}
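# Several loops above treat loss_scaler as a callable that scales the loss,
# backprops, optionally clips, steps only on real update steps, and returns
# the gradient norm. A sketch of that protocol built on
# torch.cuda.amp.GradScaler; the class name is hypothetical and the real
# implementations (e.g. timm's NativeScaler) differ in detail.
import torch

class NativeScalerWithGradNorm:
    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None,
                 create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        norm = None
        if update_grad:
            # unscale first so the clipped/reported norm is in true units
            self._scaler.unscale_(optimizer)
            if clip_grad is not None:
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            self._scaler.step(optimizer)
            self._scaler.update()
        # None on pure accumulation steps, hence the is_valid_grad_norm check
        return norm

    def state_dict(self):
        return self._scaler.state_dict()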
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         scriptable=args.torchscript)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    # from torchvision.datasets import ImageNet
    # dataset = ImageNet(args.data, split='val')
    valdir = args.data
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = cvtransforms.Compose([
        cvtransforms.Resize(size=(256), interpolation='BILINEAR'),
        cvtransforms.CenterCrop(224),
        cvtransforms.ToTensor(),
        cvtransforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # loader = torch.utils.data.DataLoader(
    #     datasets.ImageFolder(valdir, transform, loader=opencv_loader),
    #     batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=False)
    loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize((256), interpolation=2),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=False)
    # loader_eval = loader.Loader('val', valdir, batch_size=args.batch_size,
    #                             num_workers=args.workers, shuffle=False)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        # input = torch.randn((args.batch_size,)).cuda()
        # model(input)
        end = time.time()
        for i, (input, target) in enumerate(loader):
            # if args.no_prefetcher:
            target = target.cuda()
            input = input.cuda()

            # compute output
            output, _ = model(input)
            # loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
            # losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(top1=round(top1.avg, 4),
                          top1_err=round(100 - top1.avg, 4),
                          top5=round(top5.avg, 4),
                          top5_err=round(100 - top5.avg, 4),
                          param_count=round(param_count / 1e6, 2))
    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results
def evaluate_SSL(data_loader, model, device, epoch, output_dir):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'
    save_recon = os.path.join(output_dir, 'reconstruction_samples')
    Path(save_recon).mkdir(parents=True, exist_ok=True)

    # switch to evaluation mode
    model.eval()
    print_freq = 50
    i = 0
    for imgs1, rots1, imgs2, rots2 in metric_logger.log_every(data_loader, print_freq, header):
        imgs1 = imgs1.to(device, non_blocking=True)
        imgs1_aug = distortImages(imgs1)  # Apply distortion
        rots1 = rots1.to(device, non_blocking=True)
        imgs2 = imgs2.to(device, non_blocking=True)
        imgs2_aug = distortImages(imgs2)
        rots2 = rots2.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            rot1_p, contrastive1_p, imgs1_recon, r_w, cn_w, rec_w = model(imgs1_aug)
            rot2_p, contrastive2_p, imgs2_recon, _, _, _ = model(imgs2_aug)
            rot_p = torch.cat([rot1_p, rot2_p], dim=0)
            rots = torch.cat([rots1, rots2], dim=0)
            loss = criterion(rot_p, rots)

        # rotation prediction is a 4-way task, so top-4 stands in for top-5
        acc1, acc5 = accuracy(rot_p, rots, topk=(1, 4))
        batch_size = imgs1.shape[0] * 2

        if i % print_freq == 0:
            print_out = save_recon + '/Test_epoch_' + str(epoch) + '_Iter' + str(i) + '.jpg'
            imagesToPrint = torch.cat([
                imgs1[0:min(15, batch_size)].cpu(),
                imgs1_aug[0:min(15, batch_size)].cpu(),
                imgs1_recon[0:min(15, batch_size)].cpu(),
                imgs2[0:min(15, batch_size)].cpu(),
                imgs2_aug[0:min(15, batch_size)].cpu(),
                imgs2_recon[0:min(15, batch_size)].cpu()
            ], dim=0)
            torchvision.utils.save_image(imagesToPrint, print_out,
                                         nrow=min(15, batch_size),
                                         normalize=True, range=(-1, 1))

        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
        i = i + 1

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def validate(args):
    _logger.info(f'\n\n ---------------EVALUATION {args.eps}------------------------------- \n\n')
    _logger.info("Argument parser collected the following arguments:")
    for arg in vars(args):
        _logger.info(f" {arg}:{getattr(args, arg)}")
    _logger.info("\n")

    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
        else:
            _logger.warning("Neither APEX nor Native Torch AMP is available.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast
        _logger.info('Validating in mixed precision with native PyTorch AMP.')
    elif args.apex_amp:
        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    else:
        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         global_pool=args.gp,
                         scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info(f'Model {args.model} created, param count: {param_count} ({(float(param_count)/(10.0**6)):.1f} M)')

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(root=args.data_dir,
                             name=args.dataset,
                             split=args.split,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           pin_memory=args.pin_mem,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_fgm_ae = AverageMeter()
    top5_fgm_ae = AverageMeter()
    top1_pgd_ae = AverageMeter()
    top5_pgd_ae = AverageMeter()

    model.eval()
    # with torch.no_grad():  # TODO Requires grad
    # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
    input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
    if args.channels_last:
        input = input.contiguous(memory_format=torch.channels_last)
    model(input)
    end = time.time()
    for batch_idx, (input, target) in enumerate(loader):
        if args.no_prefetcher:
            target = target.cuda()
            input = input.cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        # compute output
        with amp_autocast():
            output = model(input)

        if valid_labels is not None:
            output = output[:, valid_labels]
        loss = criterion(output, target)

        if real_labels is not None:
            real_labels.add_result(output)  # TODO <---------------------

        # Generate adversarial examples for current inputs
        input_fgm_ae = fast_gradient_method(
            model_fn=model,
            x=input,
            eps=args.eps,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        input_pgd_ae = projected_gradient_descent(
            model_fn=model,
            x=input,
            eps=args.eps,
            eps_iter=0.01,
            nb_iter=40,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )

        # Predict with Adversarial Examples
        with torch.no_grad():
            with amp_autocast():
                output_fgm_ae = model(input_fgm_ae)
                output_pgd_ae = model(input_pgd_ae)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(acc1.item(), input.size(0))
        top5.update(acc5.item(), input.size(0))

        acc1_fgm_ae, acc5_fgm_ae = accuracy(output_fgm_ae.detach(), target, topk=(1, 5))
        acc1_pgd_ae, acc5_pgd_ae = accuracy(output_pgd_ae.detach(), target, topk=(1, 5))
        top1_fgm_ae.update(acc1_fgm_ae.item(), input.size(0))
        top5_fgm_ae.update(acc5_fgm_ae.item(), input.size(0))
        top1_pgd_ae.update(acc1_pgd_ae.item(), input.size(0))
        top5_pgd_ae.update(acc5_pgd_ae.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % args.log_freq == 0:
            _logger.info(
                'Test: [{0:>4d}/{1}] '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time,
                    rate_avg=input.size(0) / batch_time.avg,
                    loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        raise NotImplementedError  # TODO Not modified for the adversarial examples mode
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
        top1a_fgm_ae, top5a_fgm_ae = top1_fgm_ae.avg, top5_fgm_ae.avg
        top1a_pgd_ae, top5a_pgd_ae = top1_pgd_ae.avg, top5_pgd_ae.avg

    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          top5=round(top5a, 4),
                          top5_err=round(100 - top5a, 4),
                          top1_fgm_ae=round(top1a_fgm_ae, 4),
                          top5_fgm_ae=round(top5a_fgm_ae, 4),
                          top1_pgd_ae=round(top1a_pgd_ae, 4),
                          top5_pgd_ae=round(top5a_pgd_ae, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          crop_pct=crop_pct,
                          interpolation=data_config['interpolation'])

    _logger.info(' * [Regular] Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    _logger.info(' * [FGM Adversarial Attack] Acc@1 {:.3f} Acc@5 {:.3f} '.format(
        results['top1_fgm_ae'], results['top5_fgm_ae']))
    _logger.info(' * [PGD Adversarial Attack] Acc@1 {:.3f} Acc@5 {:.3f} '.format(
        results['top1_pgd_ae'], results['top5_pgd_ae']))

    return results
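# fast_gradient_method / projected_gradient_descent above come from an
# external adversarial-attacks library (the argument names match CleverHans).
# For orientation only, a bare-bones L-infinity FGSM equivalent can be
# sketched in plain PyTorch; when no labels are given, the untargeted attack
# ascends the loss against the model's own predictions:
import torch
import torch.nn.functional as F

def fgsm(model_fn, x, eps):
    x = x.clone().detach().requires_grad_(True)
    logits = model_fn(x)
    loss = F.cross_entropy(logits, logits.argmax(dim=-1))
    loss.backward()
    # one signed-gradient step of size eps, then drop the graph
    return (x + eps * x.grad.sign()).detach()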
def _validate_one_epoch(self, epoch):
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    self.model.eval()

    if self.mode == 'super':
        config = self._sample_configs(choices=self.choices)
        model_module = unwrap_model(self.model)
        model_module.set_sample_config(config=config)
    else:
        config = self.retrain_config
        model_module = unwrap_model(self.model)
        model_module.set_sample_config(config=config)

    print("sampled model config: {}".format(config))
    parameters = model_module.get_sampled_params_numel(config)
    print("sampled model parameters: {}".format(parameters))

    for images, target in metric_logger.log_every(self.data_loader_val, 10, header):
        images = images.to(self.device, non_blocking=True)
        target = target.to(self.device, non_blocking=True)

        # compute output
        if self.amp:
            with torch.cuda.amp.autocast():
                output = self.model(images)
                loss = criterion(output, target)
        else:
            output = self.model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    self.max_accuracy = max(self.max_accuracy, metric_logger.meters['acc1'].global_avg)
    print(f'Max accuracy: {self.max_accuracy:.2f}%')

    val_status = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    log_stats = {
        **{f'val_{k}': v for k, v in val_status.items()},
        'epoch': epoch,
    }
    if self.output_dir and self._is_main_process():
        with (self.output_dir / "log.txt").open("a") as f:
            f.write(json.dumps(log_stats) + "\n")
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
        else:
            _logger.warning("Neither APEX nor Native Torch AMP is available, using FP32.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         in_chans=3,
                         global_pool=args.gp,
                         scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, data_config)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(root=args.data,
                             name=args.dataset,
                             split=args.split,
                             load_bytes=args.tf_preprocessing,
                             class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           pin_memory=args.pin_mem,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
            with amp_autocast():
                output = model(input)

            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(top1=round(top1a, 4),
                          top1_err=round(100 - top1a, 4),
                          top5=round(top5a, 4),
                          top5_err=round(100 - top5a, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          crop_pct=crop_pct,
                          interpolation=data_config['interpolation'])
    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results
def evaluate(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            # Conformer
            if isinstance(output, list):
                loss_list = [criterion(o, target) / len(output) for o in output]
                loss = sum(loss_list)
            # others
            else:
                loss = criterion(output, target)

        if isinstance(output, list):
            # Conformer
            acc1_head1 = accuracy(output[0], target, topk=(1,))[0]
            acc1_head2 = accuracy(output[1], target, topk=(1,))[0]
            acc1_total = accuracy(output[0] + output[1], target, topk=(1,))[0]
        else:
            # others
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        if isinstance(output, list):
            metric_logger.update(loss=loss.item())
            metric_logger.update(loss_0=loss_list[0].item())
            metric_logger.update(loss_1=loss_list[1].item())
            metric_logger.meters['acc1'].update(acc1_total.item(), n=batch_size)
            metric_logger.meters['acc1_head1'].update(acc1_head1.item(), n=batch_size)
            metric_logger.meters['acc1_head2'].update(acc1_head2.item(), n=batch_size)
        else:
            metric_logger.update(loss=loss.item())
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    if isinstance(output, list):
        print('* Acc@heads_top1 {heads_top1.global_avg:.3f} Acc@head_1 {head1_top1.global_avg:.3f} '
              'Acc@head_2 {head2_top1.global_avg:.3f} '
              'loss@total {losses.global_avg:.3f} loss@1 {loss_0.global_avg:.3f} loss@2 {loss_1.global_avg:.3f} '
              .format(heads_top1=metric_logger.acc1,
                      head1_top1=metric_logger.acc1_head1,
                      head2_top1=metric_logger.acc1_head2,
                      losses=metric_logger.loss,
                      loss_0=metric_logger.loss_0,
                      loss_1=metric_logger.loss_1))
    else:
        print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
              .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    model = create_model(args.model,
                         num_classes=args.num_classes,
                         in_chans=3,
                         pretrained=args.pretrained)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    logging.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.amp:
        model = amp.initialize(model.cuda(), opt_level='O1')
    else:
        model = model.cuda()

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    # from torchvision.datasets import ImageNet
    # dataset = ImageNet(args.data, split='val')
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(dataset,
                           input_size=data_config['input_size'],
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=data_config['interpolation'],
                           mean=data_config['mean'],
                           std=data_config['std'],
                           num_workers=args.workers,
                           crop_pct=crop_pct,
                           pin_memory=args.pin_mem,
                           tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            if args.no_prefetcher:
                target = target.cuda()
                input = input.cuda()
                if args.fp16:
                    input = input.half()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            # note: top-2 accuracy is recorded in the 'top5' meter here
            acc1, acc5 = accuracy(output.data, target, topk=(1, 2))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(top1=round(top1.avg, 4),
                          top1_err=round(100 - top1.avg, 4),
                          top5=round(top5.avg, 4),
                          top5_err=round(100 - top5.avg, 4),
                          param_count=round(param_count / 1e6, 2),
                          img_size=data_config['input_size'][-1],
                          crop_pct=crop_pct,
                          interpolation=data_config['interpolation'])
    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
    return results
def save_logits_one_epoch(config, model, data_loader, epoch, mixup_fn):
    model.eval()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()

    topk = config.DISTILL.LOGITS_TOPK
    logits_manager = data_loader.dataset.get_manager()

    for idx, ((samples, targets), (keys, seeds)) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets, seeds)
            original_targets = targets.argmax(dim=1)
        else:
            original_targets = targets

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)

        acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
        real_batch_size = len(samples)
        meters['teacher_acc1'].update(acc1.item(), real_batch_size)
        meters['teacher_acc5'].update(acc5.item(), real_batch_size)

        # save teacher logits
        softmax_prob = torch.softmax(outputs, -1)

        torch.cuda.synchronize()

        write_tic = time.time()
        values, indices = softmax_prob.topk(k=topk, dim=-1, largest=True, sorted=True)

        cpu_device = torch.device('cpu')
        values = values.detach().to(device=cpu_device, dtype=torch.float16)
        indices = indices.detach().to(device=cpu_device, dtype=torch.int16)

        seeds = seeds.numpy()
        values = values.numpy()
        indices = indices.numpy()

        # check data type
        assert seeds.dtype == np.int32, seeds.dtype
        assert indices.dtype == np.int16, indices.dtype
        assert values.dtype == np.float16, values.dtype

        for key, seed, indice, value in zip(keys, seeds, indices, values):
            bstr = seed.tobytes() + indice.tobytes() + value.tobytes()
            logits_manager.write(key, bstr)
        meters['write_time'].update(time.time() - write_tic)

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)

            extra_meters_str = ''
            for k, v in meters.items():
                extra_meters_str += f'{k} {v.val:.4f} ({v.avg:.4f})\t'
            logger.info(
                f'Save: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'{extra_meters_str}'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} save logits takes {datetime.timedelta(seconds=int(epoch_time))}")
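# Each record written above packs, in order: the int32 augmentation seed, the
# int16 top-k class indices, and the float16 top-k probabilities. A matching
# reader is then just np.frombuffer with the same dtypes and offsets. This is
# a sketch; the real storage layer sits behind logits_manager, and the
# function name and seed_count parameter here are illustrative.
import numpy as np

def decode_logits_record(bstr, topk, seed_count=1):
    seeds = np.frombuffer(bstr, dtype=np.int32, count=seed_count)
    off = seeds.nbytes
    indices = np.frombuffer(bstr, dtype=np.int16, count=topk, offset=off)
    off += indices.nbytes
    values = np.frombuffer(bstr, dtype=np.float16, count=topk, offset=off)
    return seeds, indices, values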