def val(args, model=None, current_epoch=0):
    """Evaluate top-1/top-5 accuracy on the validation split and print it.

    For every batch the top-5 predicted labels are also written through
    ``SAVE_ATTEN.save_top_5_pred_labels`` into ``args.save_atten_dir``.

    Args:
        args: parsed options; this function reads ``tencrop``, ``num_classes``
            and ``save_atten_dir``, plus whatever ``get_model`` and
            ``data_loader`` consume.
        model: model to evaluate; built with ``get_model(args)`` when None.
        current_epoch: unused here; kept for interface compatibility.
    """
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1.reset()
    top5.reset()

    if model is None:
        # get_model returns a (model, optimizer) pair elsewhere in this file.
        model, _ = get_model(args)
    model.eval()
    _, val_loader = data_loader(args, test_path=True)
    save_atten = SAVE_ATTEN(save_dir=args.save_atten_dir)

    tencrop = args.tencrop == 'True'
    global_counter = 0
    for dat in tqdm(val_loader):
        img_path, img, label_in = dat
        global_counter += 1
        if tencrop:
            bs, ncrops, c, h, w = img.size()
            img = img.view(-1, c, h, w)
            # img.view(-1, ...) lays crops out sample-major
            # (s0c0, s0c1, ..., s1c0, ...); repeat each label ncrops times in
            # that same order so every crop is paired with its own label.
            label = label_in.view(-1, 1).repeat(1, ncrops).view(-1)
        else:
            label = label_in

        img, label = img.cuda(), label.cuda()
        img_var, label_var = Variable(img), Variable(label)
        logits = model(img_var, label_var)
        logits0 = F.softmax(logits[0], dim=1)
        if tencrop:
            # Average the class probabilities over the crops of each image.
            logits0 = logits0.view(bs, ncrops, -1).mean(1)

        # Calculate classification results against the per-image labels.
        prec1_1, prec5_1 = Metrics.accuracy(logits0.cpu().data, label_in.long(), topk=(1, 5))
        top1.update(prec1_1[0], img.size()[0])
        top5.update(prec5_1[0], img.size()[0])

        # Only the 5 best classes are saved, so there is no need to rank all of them.
        _, pred_labels = torch.topk(logits0, k=min(5, args.num_classes), dim=1)
        pred_np_labels = pred_labels.cpu().data.numpy()
        save_atten.save_top_5_pred_labels(pred_np_labels, img_path, global_counter)

    print('Top1: %s Top5: %s' % (top1.avg, top5.avg))
def val(args, model=None, current_epoch=0): top1 = AverageMeter() top5 = AverageMeter() top1.reset() top5.reset() if model is None: model, _ = get_model(args) model.eval() train_loader, val_loader = data_loader(args, test_path=True) save_atten = SAVE_ATTEN(save_dir='../save_bins/') global_counter = 0 prob = None gt = None for idx, dat in tqdm(enumerate(val_loader)): img_path, img, label_in = dat global_counter += 1 if args.tencrop == 'True': bs, ncrops, c, h, w = img.size() img = img.view(-1, c, h, w) label_input = label_in.repeat(10, 1) label = label_input.view(-1) else: label = label_in img, label = img.cuda(), label.cuda() img_var, label_var = Variable(img), Variable(label) logits = model(img_var, label_var) logits0 = logits[0] logits0 = F.softmax(logits0, dim=1) if args.tencrop == 'True': logits0 = logits0.view(bs, ncrops, -1).mean(1) # Calculate classification results if args.onehot == 'True': val_mAP, prob, gt = cal_mAP(logits0, label_var, prob, gt) # print val_mAP else: prec1_1, prec5_1 = Metrics.accuracy(logits0.cpu().data, label_in.long(), topk=(1, 5)) # prec3_1, prec5_1 = Metrics.accuracy(logits[1].data, label.long(), topk=(1,5)) top1.update(prec1_1[0], img.size()[0]) top5.update(prec5_1[0], img.size()[0]) # model.module.save_erased_img(img_path) last_featmaps = model.module.get_localization_maps() np_last_featmaps = last_featmaps.cpu().data.numpy() # Save 100 sample masked images by heatmaps # if idx < 100/args.batch_size: save_atten.get_masked_img(img_path, np_last_featmaps, label_in.numpy(), size=(0, 0), maps_in_dir=True, only_map=True) # save_atten.save_heatmap_segmentation(img_path, np_last_featmaps, label.cpu().numpy(), # save_dir='./save_bins/heatmaps', size=(0,0), maskedimg=True) # save_atten.get_masked_img(img_path, np_last_featmaps, label_in.numpy(),size=(0,0), # maps_in_dir=True, save_dir='../heatmaps',only_map=True ) # np_scores, pred_labels = torch.topk(logits0,k=args.num_classes,dim=1) # # print pred_labels.size(), label.size() # 
pred_np_labels = pred_labels.cpu().data.numpy() # save_atten.save_top_5_pred_labels(pred_np_labels[:,:5], img_path, global_counter) # # pred_np_labels[:,0] = label.cpu().numpy() #replace the first label with gt label # # save_atten.save_top_5_atten_maps(np_last_featmaps, pred_np_labels, img_path) if args.onehot == 'True': print val_mAP print 'AVG:', np.mean(val_mAP) else: print('Top1:', top1.avg, 'Top5:', top5.avg)
def train(args):
    """Train the model returned by ``get_model`` until ``args.epoch`` epochs.

    Resumes counting from ``args.current_epoch`` / ``args.global_counter``.
    The run configuration, any learning-rate change reported by
    ``my_optim.reduce_lr``, and one summary row per epoch are appended to
    ``<args.snapshot_dir>/train_record.csv``. A checkpoint is written with
    ``save_checkpoint`` after every epoch.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model, optimizer = get_model(args)
    model.train()
    train_loader, _ = data_loader(args)

    record_path = os.path.join(args.snapshot_dir, 'train_record.csv')
    with open(record_path, 'a') as fw:
        config = json.dumps(vars(args), indent=4, separators=(',', ':'))
        fw.write(config)
        fw.write('#epoch,loss,pred@1,pred@5\n')

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch
    steps_per_epoch = len(train_loader)
    end = time.time()
    max_iter = total_epoch * steps_per_epoch
    print('Max iter: %s' % max_iter)
    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        top1.reset()
        top5.reset()
        batch_time.reset()
        res = my_optim.reduce_lr(args, optimizer, current_epoch)

        if res:
            # The record file opened above is already closed here; reopen it
            # so the learning-rate change is logged instead of raising
            # "I/O operation on closed file".
            with open(record_path, 'a') as fw:
                for g in optimizer.param_groups:
                    fw.write('Epoch:%d, %f\n' % (current_epoch, g['lr']))

        for idx, dat in enumerate(train_loader):
            img_path, img, label = dat
            global_counter += 1
            img, label = img.cuda(), label.cuda()
            img_var, label_var = Variable(img), Variable(label)

            logits = model(img_var, label_var)
            loss_val, = model.module.get_loss(logits, label_var)

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            if args.onehot != 'True':
                logits1 = torch.squeeze(logits[0])
                prec1_1, prec5_1 = Metrics.accuracy(logits1.data, label.long(), topk=(1, 5))
                top1.update(prec1_1[0], img.size()[0])
                top5.update(prec5_1[0], img.size()[0])

            losses.update(loss_val.data[0], img.size()[0])
            batch_time.update(time.time() - end)
            end = time.time()

            # Meters restart every 1000 global steps, so displayed averages and
            # the per-epoch CSV row cover only the steps since the last reset.
            if global_counter % 1000 == 0:
                losses.reset()
                top1.reset()
                top5.reset()

            if global_counter % args.disp_interval == 0:
                # Steps still to run: the rest of this epoch after step idx,
                # plus every full epoch after the current one.
                remaining_steps = ((total_epoch - current_epoch - 1) * steps_per_epoch
                                   + (steps_per_epoch - idx - 1))
                eta_seconds = remaining_steps * batch_time.avg
                eta_str = "{:0>8}".format(datetime.timedelta(seconds=int(eta_seconds)))
                eta_seconds_epoch = steps_per_epoch * batch_time.avg
                eta_str_epoch = "{:0>8}".format(datetime.timedelta(seconds=int(eta_seconds_epoch)))
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'ETA {eta_str}({eta_str_epoch})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        current_epoch, idx + 1, steps_per_epoch,
                        batch_time=batch_time, eta_str=eta_str, eta_str_epoch=eta_str_epoch,
                        loss=losses, top1=top1, top5=top5))

        # Checkpoint after every epoch.
        save_checkpoint(args,
                        {
                            'epoch': current_epoch,
                            'arch': 'resnet',
                            'global_counter': global_counter,
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        is_best=False,
                        filename='%s_epoch_%d_glo_step_%d.pth.tar'
                                 % (args.dataset, current_epoch, global_counter))

        with open(record_path, 'a') as fw:
            fw.write('%d,%.4f,%.3f,%.3f\n' % (current_epoch, losses.avg, top1.avg, top5.avg))

        current_epoch += 1