def forward_loss(model, criterion, input, target, meter):
    """Forward the model on one batch and return the mean loss.

    Also computes per-sample top-k error rates for every k in
    ``FLAGS.topk`` and caches them (and the loss) into ``meter``.

    Args:
        model: callable network, ``model(input) -> logits``.
        criterion: per-sample loss, reduced here with ``torch.mean``.
        input: input batch tensor; scaled to [0, 255] unless
            ``FLAGS.normalize`` is set.
        target: ground-truth class indices, shape ``(batch,)``.
        meter: dict-like of meters with ``cache``/``cache_list``, or
            ``None`` to skip caching.

    Returns:
        Scalar loss tensor (pre all-reduce).
    """
    # Dead `input = input` no-op and commented-out 8-bit quantization
    # removed; only the non-normalized path transforms the input.
    if not getattr(FLAGS, 'normalize', False):
        # Scale assumed-[0, 1] inputs to integer pixel levels in place.
        input = (255 * input).round_()
    output = model(input)
    loss = torch.mean(criterion(output, target))
    # Top-k: pred becomes (max_k, batch) of predicted class indices,
    # correct marks positions matching the target.
    _, pred = output.topk(max(FLAGS.topk))
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    correct_k = []
    for k in FLAGS.topk:
        # 1.0 per sample if the target is within the top-k predictions.
        correct_k.append(correct[:k].float().sum(0))
    # Pack [loss, correct@k1..., correct@k2...] so a single all-reduce
    # synchronizes everything across workers.
    res = torch.cat([loss.view(1)] + correct_k, dim=0)
    if getattr(FLAGS, 'distributed', False) and getattr(
            FLAGS, 'distributed_all_reduce', False):
        res = dist_all_reduce_tensor(res)
    res = res.cpu().detach().numpy()
    # Per-k segment length: everything after the leading loss entry,
    # split evenly across the k values.
    bs = (res.size - 1) // len(FLAGS.topk)
    for i, k in enumerate(FLAGS.topk):
        error_list = list(1. - res[1 + i * bs:1 + (i + 1) * bs])
        if meter is not None:
            meter['top{}_error'.format(k)].cache_list(error_list)
    if meter is not None:
        meter['loss'].cache(res[0])
    return loss
def forward_loss(
        model, criterion, input, target, meter,
        soft_target=None, soft_criterion=None,
        return_soft_target=False, return_acc=False):
    """Forward the model and return the loss (distillation-aware).

    Computes the loss against ``soft_target`` (knowledge distillation)
    when given, otherwise against the hard ``target``; top-k errors are
    all-reduced across workers and cached into ``meter``.

    Args:
        model: callable network, ``model(input) -> logits``.
        criterion: per-sample hard-label loss.
        input: input batch tensor (assumed already preprocessed).
        target: ground-truth class indices, shape ``(batch,)``.
        meter: dict-like of meters with ``cache``/``cache_list``, or
            ``None`` to skip caching.
        soft_target: optional teacher outputs for distillation.
        soft_criterion: loss used when ``soft_target`` is given.
        return_soft_target: if True, also return softmaxed logits.
        return_acc: if True, return ``(loss, top1_error)`` immediately
            (skips meter caching).

    Returns:
        ``loss``; or ``(loss, top1_error)`` when ``return_acc``; or
        ``(loss, softmax(output))`` when ``return_soft_target``.
    """
    output = model(input)
    if soft_target is not None:
        loss = torch.mean(soft_criterion(output, soft_target))
    else:
        loss = torch.mean(criterion(output, target))
    # Top-k: pred becomes (max_k, batch) of predicted class indices.
    _, pred = output.topk(max(FLAGS.topk))
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    correct_k = []
    for k in FLAGS.topk:
        correct_k.append(correct[:k].float().sum(0))
    # Pack [loss, correct@k1..., correct@k2...] into one tensor so a
    # single all-reduce synchronizes everything across workers.
    tensor = torch.cat([loss.view(1)] + correct_k, dim=0)
    # allreduce
    tensor = dist_all_reduce_tensor(tensor)
    # cache to meter
    tensor = tensor.cpu().detach().numpy()
    # BUG FIX: segment length was hard-coded as (size-1)//2, which is
    # only correct when len(FLAGS.topk) == 2; derive it from the actual
    # number of k values (matches the sibling forward_loss).
    bs = (tensor.size - 1) // len(FLAGS.topk)
    for i, k in enumerate(FLAGS.topk):
        error_list = list(1. - tensor[1 + i * bs:1 + (i + 1) * bs])
        if return_acc and k == 1:
            # Early exit with the batch-mean top-1 error; by design this
            # skips caching into the meter.
            top1_error = sum(error_list) / len(error_list)
            return loss, top1_error
        if meter is not None:
            meter['top{}_error'.format(k)].cache_list(error_list)
    if meter is not None:
        meter['loss'].cache(tensor[0])
    if return_soft_target:
        return loss, torch.nn.functional.softmax(output, dim=1)
    return loss