# NOTE(review): `torch` and `nn` are used below but are not imported in these
# visible lines — presumably they arrive via `from dataset import *`; confirm.
from torch.utils.data import DataLoader
from optparse import OptionParser
from torchvision.datasets import ImageFolder
from torchvision import transforms, utils,datasets
from utils import CenterLoss, AverageMeter, TopKAccuracyMetric, ModelCheckpoint, batch_augment
from models import WSDAN,inception_v3
from dataset import *

# Hard-coded to CUDA: there is no CPU fallback if no GPU is present.
device = torch.device("cuda")

# General loss functions
cross_entropy_loss = nn.CrossEntropyLoss()
center_loss = CenterLoss()

# loss and metric
# AverageMeter from the project's utils; presumably tracks a running loss average.
loss_container = AverageMeter(name='loss')
# Top-k accuracy trackers for k in (1, 3). By their names these are presumably
# for raw, crop-augmented and drop-augmented predictions — confirm in the
# training loop, which is not visible here.
raw_metric = TopKAccuracyMetric(topk=(1, 3))
crop_metric = TopKAccuracyMetric(topk=(1, 3))
drop_metric = TopKAccuracyMetric(topk=(1, 3))


def main():
    """Register the training command-line options.

    NOTE(review): in the visible code this only creates an ``OptionParser``
    and adds options. It never calls ``parse_args`` and it returns None. The
    rest of the training entry point appears to be missing from this chunk —
    confirm. This file also contains later ``def main()`` definitions, which
    rebind the name, so this definition is shadowed at import time.
    """
    parser = OptionParser()
    # Number of DataLoader worker processes.
    parser.add_option('-j', '--workers', dest='workers', default=16, type='int', help='number of data loading workers (default: 16)')
    parser.add_option('-e', '--epochs', dest='epochs', default=80, type='int', help='number of epochs (default: 80)')
    parser.add_option('-b', '--batch-size', dest='batch_size', default=16, type='int', help='batch size (default: 16)')
    # No `type=` here, so optparse stores a supplied value as a string; when
    # the option is omitted, the value is the boolean False.
    parser.add_option('-c', '--ckpt', dest='ckpt', default=False, help='load checkpoint model (default: False)')
    # Logging period, measured in iterations.
    parser.add_option('-v', '--verbose', dest='verbose', default=100, type='int', help='show information for each <verbose> iterations (default: 100)')
def main():
    """Evaluate a WS-DAN checkpoint on the CarDataset test split.

    The checkpoint path is the first command-line argument
    (``python3 eval.py <model.ckpt>``). For every test batch, the logits of
    the raw image and of its attention-guided crop are averaged. Running
    top-1 / top-5 accuracy is shown in a progress bar and logged once at the
    end.

    When the module-level ``visualize`` flag is truthy, three images are
    saved under the module-level ``savepath`` for every test image: the raw
    image, the attention-weighted image and a heat-map overlay.
    """
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # Only a missing positional argument means "show usage". Any other error
    # should surface rather than be swallowed.
    try:
        ckpt = sys.argv[1]
    except IndexError:
        logging.info('Usage: python3 eval.py <model.ckpt>')
        return

    ##################################
    # Dataset for testing
    ##################################
    test_dataset = CarDataset(phase='test', resize=448)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False,
                             num_workers=1, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes, M=32, net='inception_mixed_6e')

    # The checkpoint stores the weights under its 'state_dict' key.
    checkpoint = torch.load(ckpt)
    net.load_state_dict(checkpoint['state_dict'])
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    cudnn.benchmark = True
    net.to(device)
    net = nn.DataParallel(net)
    net.eval()

    ##################################
    # Prediction
    ##################################
    accuracy = TopKAccuracyMetric(topk=(1, 5))
    accuracy.reset()
    epoch_acc = None  # stays None if the loader yields no batches

    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)

            # WS-DAN: raw logits plus attention maps (feature matrix unused)
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask; final prediction averages both views
            crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1)
            y_pred_crop, _, _ = net(crop_image)
            pred = (y_pred_raw + y_pred_crop) / 2.

            if visualize:
                # Resize the attention maps to the input resolution. This is
                # the documented replacement for the deprecated
                # F.upsample_bilinear.
                # NOTE(review): the maps are presumably (B, 1, H', W') in eval
                # mode so they broadcast against the 3-channel image — confirm.
                attention_maps = F.interpolate(
                    attention_maps, size=(X.size(2), X.size(3)),
                    mode='bilinear', align_corners=True)
                # Normalize by the batch maximum, then take the sqrt to brighten
                # weak responses (assumes non-negative maps).
                attention_maps = torch.sqrt(attention_maps.cpu() / attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # Presumably inverts the (x - MEAN) / STD input normalization.
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                # Use a global image index so file names stay unique for any
                # batch size.
                first_idx = i * test_loader.batch_size
                for batch_idx in range(X.size(0)):
                    img_idx = first_idx + batch_idx
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(os.path.join(savepath, '%03d_raw.jpg' % img_idx))
                    raimg.save(os.path.join(savepath, '%03d_raw_atten.jpg' % img_idx))
                    haimg.save(os.path.join(savepath, '%03d_heat_atten.jpg' % img_idx))

            # Top K (running accuracy over all batches seen so far)
            epoch_acc = accuracy(pred, y)

            # end of this batch
            batch_info = 'Val Acc ({:.2f}, {:.2f})'.format(epoch_acc[0], epoch_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()

    # show information for the whole test set
    if epoch_acc is None:
        logging.warning('Test set is empty; no accuracy to report')
    else:
        logging.info('Accuracy: %.2f, %.2f' % (epoch_acc[0], epoch_acc[1]))
def main():
    """Evaluate a WS-DAN checkpoint on the CarDataset test split and export predictions.

    The checkpoint path is read from ``config.eval_ckpt``. The test loader
    yields ``(image, label, id)`` batches. For each batch, the logits of the
    raw image and of its attention-guided crop are averaged. Running raw and
    refined top-1 / top-5 accuracy is shown in a progress bar.

    After all batches, the argmax class of every sample is mapped to a label
    name via ``mapping('../training_labels.csv')``. The result is written
    once to ``<savepath>/predictions.csv`` with ``id`` and ``label`` columns.

    When the module-level ``visualize`` flag is truthy, raw, attention-
    weighted and heat-map images are also saved under ``savepath``.
    """
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # A missing config attribute means "not configured". Other errors should
    # surface rather than be swallowed.
    try:
        ckpt = config.eval_ckpt
    except AttributeError:
        logging.info('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    test_dataset = CarDataset('test')
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
                             num_workers=2, pin_memory=True)
    _, label2name = mapping('../training_labels.csv')

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes, M=config.num_attentions, net=config.net)

    # The checkpoint stores the weights under its 'state_dict' key.
    checkpoint = torch.load(ckpt)
    net.load_state_dict(checkpoint['state_dict'])
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 5))
    ref_accuracy = TopKAccuracyMetric(topk=(1, 5))
    raw_accuracy.reset()
    ref_accuracy.reset()

    net.eval()
    logits = []
    ids = []
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y, batch_ids) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)
            ids.extend(batch_ids)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask; refined prediction averages both views
            crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = net(crop_image)
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # Only collect logits here. The CSV is built once after the loop,
            # instead of re-concatenating everything on every batch.
            logits.append(y_pred.cpu())

            if visualize:
                # Resize the attention maps to the input resolution (documented
                # replacement for the deprecated F.upsample_bilinear).
                attention_maps = F.interpolate(
                    attention_maps, size=(X.size(2), X.size(3)),
                    mode='bilinear', align_corners=True)
                # Normalize by the batch maximum, then take the sqrt to brighten
                # weak responses (assumes non-negative maps).
                attention_maps = torch.sqrt(attention_maps.cpu() / attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # Presumably inverts the (x - MEAN) / STD input normalization.
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                # Global image index so file names stay unique across batches.
                first_idx = i * test_loader.batch_size
                for batch_idx in range(X.size(0)):
                    img_idx = first_idx + batch_idx
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(os.path.join(savepath, '%03d_raw.jpg' % img_idx))
                    raimg.save(os.path.join(savepath, '%03d_raw_atten.jpg' % img_idx))
                    haimg.save(os.path.join(savepath, '%03d_heat_atten.jpg' % img_idx))

            # Top K (running accuracy over all batches seen so far)
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_pred, y)

            # end of this batch
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0], epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()

    # Save the predictions once, after every batch has been scored.
    if not logits:
        logging.warning('Test set is empty; no predictions written')
        return
    prediction = torch.argmax(torch.cat(logits, dim=0), dim=1)
    submission = pd.DataFrame(
        [ids, [label2name[x] for x in prediction.tolist()]]).transpose()
    submission.columns = ['id', 'label']
    # Build the path with os.path.join like the image files above, so a
    # savepath without a trailing separator still writes inside the directory.
    if savepath:
        os.makedirs(savepath, exist_ok=True)
    submission.to_csv(os.path.join(savepath, 'predictions.csv'), index=False)
def main():
    """Evaluate a ``resnet34_plus`` (2-class) checkpoint on the validation split.

    The checkpoint path is read from ``config.eval_ckpt``, and the file is
    loaded directly as a state_dict. The evaluation split is the second
    dataset returned by ``get_trainval_datasets``. Running top-1 accuracy is
    shown in a progress bar and logged once at the end.
    """
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # A missing config attribute means "not configured". Other errors should
    # surface rather than be swallowed.
    try:
        ckpt = config.eval_ckpt
    except AttributeError:
        logging.info('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    _, test_dataset = get_trainval_datasets(config.tag, resize=config.image_size)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
                             num_workers=2, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = resnet34_plus(num_classes=2)

    # This checkpoint file is loaded as the state_dict itself (no wrapper key).
    state_dict = torch.load(ckpt)
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    # net(X) returns a single tensor that is used directly as logits. There
    # are no attention maps to draw, so visualization cannot run for this
    # model.
    if visualize:
        logging.warning('Attention visualization is not supported for this model; skipping')

    ##################################
    # Prediction
    ##################################
    # There is no crop/refine stage, so a single top-1 metric is reported.
    accuracy = TopKAccuracyMetric(topk=(1,))
    accuracy.reset()
    epoch_acc = None  # stays None if the loader yields no batches

    net.eval()
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for X, y in test_loader:
            X = X.to(device)
            y = y.to(device)

            y_pred = net(X)

            # Top K (running accuracy over all batches seen so far)
            epoch_acc = accuracy(y_pred, y)

            # end of this batch
            batch_info = 'Val Acc ({:.2f})'.format(epoch_acc[0])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()

    # show information for the whole split
    if epoch_acc is None:
        logging.warning('Test set is empty; no accuracy to report')
    else:
        logging.info('Accuracy: %.2f' % epoch_acc[0])