@torch.no_grad()
def check_logits_one_epoch(config, model, data_loader, epoch, mixup_fn):
    model.eval()
    num_steps = len(data_loader)
    batch_time = AverageMeter()
    meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()
    topk = config.DISTILL.LOGITS_TOPK

    for idx, ((samples, targets),
              (saved_logits_index, saved_logits_value,
               seeds)) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            # replay the same mixup that was used when the logits were saved,
            # using the stored seeds
            samples, targets = mixup_fn(samples, targets, seeds)

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)

        softmax_prob = torch.softmax(outputs, -1)
        torch.cuda.synchronize()

        values, indices = softmax_prob.topk(k=topk,
                                            dim=-1,
                                            largest=True,
                                            sorted=True)

        # mean absolute error between the recomputed and the saved top-k probabilities
        meters['error'].update(
            (values - saved_logits_value.cuda()).abs().mean().item())
        # fraction of top-k indices that no longer match the saved ones
        meters['diff_rate'].update(
            torch.count_nonzero(
                (indices != saved_logits_index.cuda())).item() /
            indices.numel())

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)

            extra_meters_str = ''
            for k, v in meters.items():
                extra_meters_str += f'{k} {v.val:.4f} ({v.avg:.4f})\t'
            logger.info(
                f'Check: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'{extra_meters_str}'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} check logits takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
@torch.no_grad()
def validate(config, data_loader, model, logger):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                        f'Mem {memory_used:.0f}MB')

    # aggregate the meters across distributed processes before reporting
    loss_meter.sync()
    acc1_meter.sync()
    acc5_meter.sync()
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch,
                    mixup_fn, lr_scheduler):
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        original_targets = targets

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        outputs = model(samples)
        with torch.no_grad():
            acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            loss = criterion(outputs, targets)
            loss = loss / config.TRAIN.ACCUMULATION_STEPS
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()
        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
def train_one_epoch_distill(config, model, model_teacher, data_loader,
                            optimizer, epoch, mixup_fn, lr_scheduler,
                            criterion_soft=None, criterion_truth=None,
                            criterion_attn=None, criterion_hidden=None):
    layer_id_s_list = config.DISTILL.STUDENT_LAYER_LIST
    layer_id_t_list = config.DISTILL.TEACHER_LAYER_LIST
    model.train()
    optimizer.zero_grad()
    model_teacher.eval()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    loss_soft_meter = AverageMeter()
    loss_truth_meter = AverageMeter()
    loss_attn_meter = AverageMeter()
    loss_hidden_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    teacher_acc1_meter = AverageMeter()
    teacher_acc5_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        # keep the un-mixed labels for accuracy computation
        original_targets = targets

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        # student forward pass; also return the QKV / hidden states of the
        # selected student layers when the corresponding losses are enabled
        if config.DISTILL.ATTN_LOSS and config.DISTILL.HIDDEN_LOSS:
            outputs, qkv_s, hidden_s = model(
                samples, layer_id_s_list,
                is_attn_loss=True,
                is_hidden_loss=True,
                is_hidden_org=config.DISTILL.HIDDEN_RELATION)
        elif config.DISTILL.ATTN_LOSS:
            outputs, qkv_s = model(
                samples, layer_id_s_list,
                is_attn_loss=True,
                is_hidden_loss=False,
                is_hidden_org=config.DISTILL.HIDDEN_RELATION)
        elif config.DISTILL.HIDDEN_LOSS:
            outputs, hidden_s = model(
                samples, layer_id_s_list,
                is_attn_loss=False,
                is_hidden_loss=True,
                is_hidden_org=config.DISTILL.HIDDEN_RELATION)
        else:
            outputs = model(samples)

        with torch.no_grad():
            acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))

            # frozen teacher forward pass; no gradients are needed
            if config.DISTILL.ATTN_LOSS or config.DISTILL.HIDDEN_LOSS:
                outputs_teacher, qkv_t, hidden_t = model_teacher(
                    samples, layer_id_t_list,
                    is_attn_loss=True,
                    is_hidden_loss=True)
            else:
                outputs_teacher = model_teacher(samples)
            teacher_acc1, teacher_acc5 = accuracy(outputs_teacher,
                                                  original_targets,
                                                  topk=(1, 5))

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            # ground-truth loss plus temperature-scaled soft loss from the teacher
            loss_truth = config.DISTILL.ALPHA * criterion_truth(
                outputs, targets)
            loss_soft = (1.0 - config.DISTILL.ALPHA) * criterion_soft(
                outputs / config.DISTILL.TEMPERATURE,
                outputs_teacher / config.DISTILL.TEMPERATURE)
            if config.DISTILL.ATTN_LOSS:
                loss_attn = config.DISTILL.QKV_LOSS_WEIGHT * criterion_attn(
                    qkv_s, qkv_t, config.DISTILL.AR)
            else:
                loss_attn = torch.zeros(loss_truth.shape)
            if config.DISTILL.HIDDEN_LOSS:
                loss_hidden = config.DISTILL.HIDDEN_LOSS_WEIGHT * criterion_hidden(
                    hidden_s, hidden_t)
            else:
                loss_hidden = torch.zeros(loss_truth.shape)
            loss = loss_truth + loss_soft + loss_attn + loss_hidden
            loss = loss / config.TRAIN.ACCUMULATION_STEPS

            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())

            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            loss_truth = config.DISTILL.ALPHA * criterion_truth(
                outputs, targets)
            loss_soft = (1.0 - config.DISTILL.ALPHA) * criterion_soft(
                outputs / config.DISTILL.TEMPERATURE,
                outputs_teacher / config.DISTILL.TEMPERATURE)
            if config.DISTILL.ATTN_LOSS:
                loss_attn = config.DISTILL.QKV_LOSS_WEIGHT * criterion_attn(
                    qkv_s, qkv_t, config.DISTILL.AR)
            else:
                loss_attn = torch.zeros(loss_truth.shape)
            if config.DISTILL.HIDDEN_LOSS:
                loss_hidden = config.DISTILL.HIDDEN_LOSS_WEIGHT * criterion_hidden(
                    hidden_s, hidden_t)
            else:
                loss_hidden = torch.zeros(loss_truth.shape)
            loss = loss_truth + loss_soft + loss_attn + loss_hidden

            optimizer.zero_grad()
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        loss_soft_meter.update(loss_soft.item(), targets.size(0))
        loss_truth_meter.update(loss_truth.item(), targets.size(0))
        loss_attn_meter.update(loss_attn.item(), targets.size(0))
        loss_hidden_meter.update(loss_hidden.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()
        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))
        teacher_acc1_meter.update(teacher_acc1.item(), targets.size(0))
        teacher_acc5_meter.update(teacher_acc5.item(), targets.size(0))

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}\t'
                f'Teacher_Acc@1 {teacher_acc1_meter.avg:.3f} Teacher_Acc@5 {teacher_acc5_meter.avg:.3f}\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'loss_soft {loss_soft_meter.val:.4f} ({loss_soft_meter.avg:.4f})\t'
                f'loss_truth {loss_truth_meter.val:.4f} ({loss_truth_meter.avg:.4f})\t'
                f'loss_attn {loss_attn_meter.val:.4f} ({loss_attn_meter.avg:.4f})\t'
                f'loss_hidden {loss_hidden_meter.val:.4f} ({loss_hidden_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
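# The `criterion_soft` used in `train_one_epoch_distill` above is constructed
# elsewhere and passed in; the class below is only an illustrative sketch of a
# temperature-scaled soft-target criterion compatible with the call signature
# `criterion_soft(student_logits / T, teacher_logits / T)`. It is an assumption,
# not necessarily the criterion this repository uses; classic KD formulations
# additionally rescale the result by T**2.
class SoftTargetKLLossSketch(torch.nn.Module):
    """KL divergence between softened teacher and student distributions (sketch)."""

    def forward(self, student_logits, teacher_logits):
        # Both inputs are assumed to be already divided by the temperature.
        log_p_student = torch.nn.functional.log_softmax(student_logits, dim=-1)
        p_teacher = torch.nn.functional.softmax(teacher_logits, dim=-1)
        # 'batchmean' returns the mean KL per sample, a scalar that can be
        # summed with `loss_truth`, `loss_attn`, and `loss_hidden` above.
        return torch.nn.functional.kl_div(log_p_student, p_teacher,
                                          reduction='batchmean')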
@torch.no_grad()
def validate(args, config, data_loader, model, num_classes=1000):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output = model(images)
        if num_classes == 1000:
            output_num_classes = output.size(-1)
            # a 21841-way head means the model was trained on ImageNet-22k;
            # remap its logits to the 1000 ImageNet-1k classes before scoring
            if output_num_classes == 21841:
                output = remap_layer_22kto1k(output)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                        f'Mem {memory_used:.0f}MB')

    acc1_meter.sync()
    acc5_meter.sync()
    logger.info(
        f' The number of validation samples is {int(acc1_meter.count)}')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
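# `remap_layer_22kto1k` is provided elsewhere in the repository. The sketch
# below only illustrates the assumed idea behind the call in `validate` above:
# when a model with an ImageNet-22k head (21841 classes) is scored on
# ImageNet-1k, its logits are reduced to the 1000 classes via a precomputed
# index map. The name `in22k_to_in1k_indices` is hypothetical.
def remap_22k_logits_to_1k_sketch(output, in22k_to_in1k_indices):
    # in22k_to_in1k_indices: LongTensor of shape (1000,), each entry the
    # 22k-class id corresponding to one ImageNet-1k class
    return output[:, in22k_to_in1k_indices]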
def train_one_epoch_distill_using_saved_logits(args, config, model, criterion,
                                               data_loader, optimizer, epoch,
                                               mixup_fn, lr_scheduler,
                                               loss_scaler):
    model.train()
    set_bn_state(config, model)
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    scaler_meter = AverageMeter()
    meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()
    data_tic = time.time()

    num_classes = config.MODEL.NUM_CLASSES
    topk = config.DISTILL.LOGITS_TOPK

    for idx, ((samples, targets),
              (logits_index, logits_value, seeds)) in enumerate(data_loader):
        normal_global_idx = epoch * NORM_ITER_LEN + \
            (idx * NORM_ITER_LEN // num_steps)

        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets, seeds)
            original_targets = targets.argmax(dim=1)
        else:
            original_targets = targets
        meters['data_time'].update(time.time() - data_tic)

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)

        # recover teacher logits
        logits_index = logits_index.long()
        logits_value = logits_value.float()
        logits_index = logits_index.cuda(non_blocking=True)
        logits_value = logits_value.cuda(non_blocking=True)
        minor_value = (1.0 - logits_value.sum(-1, keepdim=True)) / (
            num_classes - topk)
        minor_value = minor_value.repeat_interleave(num_classes, dim=-1)
        outputs_teacher = minor_value.scatter_(-1, logits_index, logits_value)

        loss = criterion(outputs, outputs_teacher)
        loss = loss / config.TRAIN.ACCUMULATION_STEPS

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss,
                                optimizer,
                                clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(),
                                create_graph=is_second_order,
                                update_grad=(idx + 1) %
                                config.TRAIN.ACCUMULATION_STEPS == 0)
        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            optimizer.zero_grad()
            lr_scheduler.step_update(
                (epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        # compute accuracy
        real_batch_size = len(original_targets)
        acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
        meters['train_acc1'].update(acc1.item(), real_batch_size)
        meters['train_acc5'].update(acc5.item(), real_batch_size)
        teacher_acc1, teacher_acc5 = accuracy(outputs_teacher,
                                              original_targets,
                                              topk=(1, 5))
        meters['teacher_acc1'].update(teacher_acc1.item(), real_batch_size)
        meters['teacher_acc5'].update(teacher_acc5.item(), real_batch_size)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), real_batch_size)
        if is_valid_grad_norm(grad_norm):
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.time() - end)
        end = time.time()
        data_tic = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)

            extra_meters_str = ''
            for k, v in meters.items():
                extra_meters_str += f'{k} {v.val:.4f} ({v.avg:.4f})\t'
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
                f'{extra_meters_str}'
                f'mem {memory_used:.0f}MB')

            if is_main_process() and args.use_wandb:
                acc1_meter, acc5_meter = meters['train_acc1'], meters[
                    'train_acc5']
                wandb.log(
                    {
                        "train/acc@1": acc1_meter.val,
                        "train/acc@5": acc5_meter.val,
                        "train/loss": loss_meter.val,
                        "train/grad_norm": norm_meter.val,
                        "train/loss_scale": scaler_meter.val,
                        "train/lr": lr,
                    },
                    step=normal_global_idx)

    epoch_time = time.time() - start
    extra_meters_str = f'Train-Summary: [{epoch}/{config.TRAIN.EPOCHS}]\t'
    for k, v in meters.items():
        v.sync()
        extra_meters_str += f'{k} {v.val:.4f} ({v.avg:.4f})\t'
    logger.info(extra_meters_str)
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
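# Minimal, self-contained sketch of the teacher-logit recovery performed in
# `train_one_epoch_distill_using_saved_logits` above: the saved top-k
# probabilities are scattered onto a uniform "floor" that spreads the leftover
# probability mass over the remaining classes. The function name and the toy
# numbers are illustrative only.
def recover_teacher_probs_sketch(logits_index, logits_value, num_classes):
    # logits_index: (batch, topk) int64, logits_value: (batch, topk) float
    topk = logits_value.shape[-1]
    # probability left over after the saved top-k entries, shared uniformly
    minor_value = (1.0 - logits_value.sum(-1, keepdim=True)) / (num_classes - topk)
    dense = minor_value.repeat_interleave(num_classes, dim=-1)
    # place the saved top-k probabilities back at their class positions
    return dense.scatter_(-1, logits_index, logits_value)

# Toy usage: 1 sample, 5 classes, top-2 saved.
#   recover_teacher_probs_sketch(torch.tensor([[3, 0]]),
#                                torch.tensor([[0.6, 0.3]]), num_classes=5)
#   -> tensor([[0.3000, 0.0333, 0.0333, 0.6000, 0.0333]])  (rows sum to 1)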
def train_one_epoch(args, config, model, criterion, data_loader, optimizer,
                    epoch, mixup_fn, lr_scheduler, loss_scaler):
    model.train()
    set_bn_state(config, model)
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    scaler_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        normal_global_idx = epoch * NORM_ITER_LEN + \
            (idx * NORM_ITER_LEN // num_steps)

        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)
            original_targets = targets.argmax(dim=1)
        else:
            original_targets = targets

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)
        loss = criterion(outputs, targets)
        loss = loss / config.TRAIN.ACCUMULATION_STEPS

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss,
                                optimizer,
                                clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(),
                                create_graph=is_second_order,
                                update_grad=(idx + 1) %
                                config.TRAIN.ACCUMULATION_STEPS == 0)
        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            optimizer.zero_grad()
            lr_scheduler.step_update(
                (epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        with torch.no_grad():
            acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        if is_valid_grad_norm(grad_norm):
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')

            if is_main_process() and args.use_wandb:
                wandb.log(
                    {
                        "train/acc@1": acc1_meter.val,
                        "train/acc@5": acc5_meter.val,
                        "train/loss": loss_meter.val,
                        "train/grad_norm": norm_meter.val,
                        "train/loss_scale": scaler_meter.val,
                        "train/lr": lr,
                    },
                    step=normal_global_idx)

    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
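# `is_valid_grad_norm`, used by the two loss_scaler-based training loops above,
# is defined elsewhere in the repository. The sketch below only illustrates the
# assumed behaviour implied by its use: skip the grad-norm meter update when
# the AMP loss scaler skipped the step and the returned norm is None or
# non-finite. The suffix `_sketch` marks it as an assumption, not the real helper.
import math


def is_valid_grad_norm_sketch(num):
    # `num` may be None (no optimizer step this iteration), a float, or a
    # 0-dim tensor returned by clip_grad_norm_
    if num is None:
        return False
    return math.isfinite(float(num))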
@torch.no_grad()
def save_logits_one_epoch(config, model, data_loader, epoch, mixup_fn):
    model.eval()
    num_steps = len(data_loader)
    batch_time = AverageMeter()
    meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()
    topk = config.DISTILL.LOGITS_TOPK
    logits_manager = data_loader.dataset.get_manager()

    for idx, ((samples, targets), (keys, seeds)) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            # apply mixup with the stored seeds so the same augmentation can
            # be replayed when the saved logits are consumed
            samples, targets = mixup_fn(samples, targets, seeds)
            original_targets = targets.argmax(dim=1)
        else:
            original_targets = targets

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)

        acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
        real_batch_size = len(samples)
        meters['teacher_acc1'].update(acc1.item(), real_batch_size)
        meters['teacher_acc5'].update(acc5.item(), real_batch_size)

        # save teacher logits
        softmax_prob = torch.softmax(outputs, -1)

        torch.cuda.synchronize()

        write_tic = time.time()
        values, indices = softmax_prob.topk(k=topk,
                                            dim=-1,
                                            largest=True,
                                            sorted=True)

        cpu_device = torch.device('cpu')
        values = values.detach().to(device=cpu_device, dtype=torch.float16)
        indices = indices.detach().to(device=cpu_device, dtype=torch.int16)

        seeds = seeds.numpy()
        values = values.numpy()
        indices = indices.numpy()

        # check data type
        assert seeds.dtype == np.int32, seeds.dtype
        assert indices.dtype == np.int16, indices.dtype
        assert values.dtype == np.float16, values.dtype

        for key, seed, indice, value in zip(keys, seeds, indices, values):
            bstr = seed.tobytes() + indice.tobytes() + value.tobytes()
            logits_manager.write(key, bstr)
        meters['write_time'].update(time.time() - write_tic)

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)

            extra_meters_str = ''
            for k, v in meters.items():
                extra_meters_str += f'{k} {v.val:.4f} ({v.avg:.4f})\t'
            logger.info(
                f'Save: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'{extra_meters_str}'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} save logits takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
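# Companion sketch for `save_logits_one_epoch`: decoding one record written by
# `logits_manager.write(key, bstr)`. The byte layout follows the write path
# above (a 4-byte int32 seed, then `topk` int16 indices, then `topk` float16
# values); the reader in the actual dataset code may differ, so treat this
# function, and its name, as an illustration only.
def decode_saved_logits_record_sketch(bstr, topk):
    seed = np.frombuffer(bstr, dtype=np.int32, count=1, offset=0)[0]
    indices = np.frombuffer(bstr, dtype=np.int16, count=topk, offset=4)
    values = np.frombuffer(bstr, dtype=np.float16, count=topk,
                           offset=4 + 2 * topk)
    return seed, indices, values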