def _main():
    """Score the DFDC test videos with a two-model ensemble and write submission.csv.

    Gradients are disabled globally, an Xception classifier and a WS-DAN
    classifier are restored from their checkpoints, and for every batch the
    class-1 softmax outputs are blended as ``0.25 * xception + 0.75 * wsdan``
    and passed to ``loader.feedback``.  Afterwards the sample submission is
    re-read and one ``name,score`` line is written per listed file, taking the
    score from ``loader.score``.
    """
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True

    video_root = "../input/deepfake-detection-challenge/test_videos"
    sample_csv = "../input/deepfake-detection-challenge/sample_submission.csv"

    detector = FaceDetector()
    detector.load_checkpoint("../input/pretrained/RetinaFace-Resnet50-fixed.pth")
    loader = DFDCLoader(video_root, detector, T.ToTensor())

    # Xception: weights are loaded first, then the model is moved to the GPU.
    xception_net = xception(num_classes=2, pretrained=False)
    state = torch.load("../input/pretrained/xception.pth")
    xception_net.load_state_dict(state["state_dict"])
    xception_net = xception_net.cuda()
    xception_net.eval()

    # WS-DAN: the model is moved to the GPU first, then the weights are loaded.
    wsdan_net = WSDAN(num_classes=2, M=8, net="xception", pretrained=False).cuda()
    state = torch.load("../input/pretrained/wsdan.pth")
    wsdan_net.load_state_dict(state["state_dict"])
    wsdan_net.eval()

    # Per-channel statistics used to standardize the WS-DAN input.
    channel_mean = torch.Tensor([.4479, .3744, .3473]).view(1, 3, 1, 1).cuda()
    channel_std = torch.Tensor([.2537, .2502, .2424]).view(1, 3, 1, 1).cuda()

    for frames in loader:
        frames = frames.cuda(non_blocking=True)

        # Xception branch: resize to 299 and map x -> (x - 0.5) * 2 in place.
        resized = F.interpolate(frames, size=299, mode="bilinear")
        resized.sub_(0.5).mul_(2.0)
        xception_prob = xception_net(resized).softmax(-1)[:, 1].cpu().numpy()

        # WS-DAN branch: standardize, keep only the first of the three outputs.
        standardized = (frames - channel_mean) / channel_std
        wsdan_logits, _, _ = wsdan_net(standardized)
        wsdan_prob = wsdan_logits.softmax(-1)[:, 1].cpu().numpy()

        loader.feedback(0.25 * xception_prob + 0.75 * wsdan_prob)

    with open(sample_csv) as template, open("submission.csv", "w") as out:
        # The header line is copied through unchanged.
        out.write(next(template))
        for row in template:
            name = row.partition(",")[0]
            print("%s,%.6f" % (name, loader.score[name]), file=out)
def main():
    """Train WS-DAN from command-line options.

    Parses the options, prepares the save directory and file logging, builds
    the train/validation/test loaders, optionally restores a checkpoint, and
    then runs the epoch loop (``train`` + ``validate`` + scheduler step +
    checkpoint callback).

    Checkpoint handling: when ``--ckpt`` is given the file is read once.
    ``logs`` and the start epoch come from the checkpoint; when
    ``--init 0`` is passed the start epoch is instead parsed from the
    checkpoint file name (``<epoch>.<ext>``).
    """
    parser = OptionParser()
    parser.add_option('-j', '--workers', dest='workers', default=16, type='int',
                      help='number of data loading workers (default: 16)')
    parser.add_option('-e', '--epochs', dest='epochs', default=80, type='int',
                      help='number of epochs (default: 80)')
    parser.add_option('-b', '--batch-size', dest='batch_size', default=16, type='int',
                      help='batch size (default: 16)')
    parser.add_option('-c', '--ckpt', dest='ckpt', default=False,
                      help='load checkpoint model (default: False)')
    parser.add_option('-v', '--verbose', dest='verbose', default=100, type='int',
                      help='show information for each <verbose> iterations (default: 100)')
    parser.add_option('--lr', '--learning-rate', dest='lr', default=1e-3, type='float',
                      help='learning rate (default: 1e-3)')
    parser.add_option('--sf', '--save-freq', dest='save_freq', default=1, type='int',
                      help='saving frequency of .ckpt models (default: 1)')
    parser.add_option('--sd', '--save-dir', dest='save_dir', default='./models/wsdan/',
                      help='saving directory of .ckpt models (default: ./models/wsdan)')
    parser.add_option('--ln', '--log-name', dest='log_name', default='train.log',
                      help='log name (default: train.log)')
    parser.add_option('--mn', '--model-name', dest='model_name', default='model.ckpt',
                      help='model name (default:model.ckpt)')
    parser.add_option('--init', '--initial-training', dest='initial_training', default=1, type='int',
                      help='train from 1-beginning or 0-resume training (default: 1)')
    options, _ = parser.parse_args()

    ##################################
    # Initialize saving directory
    ##################################
    os.makedirs(options.save_dir, exist_ok=True)

    ##################################
    # Logging setting
    ##################################
    logging.basicConfig(
        filename=os.path.join(options.save_dir, options.log_name),
        filemode='w',
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    ##################################
    # Load dataset
    ##################################
    image_size = (256, 256)
    num_classes = 4
    transform = transforms.Compose([
        transforms.Resize(size=image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # NOTE(review): the train CSV name ends in "Metada.csv" while the other two
    # end in "Metadata.csv" - confirm this matches the file on disk.
    train_dataset = CustomDataset(data_root='/mnt/HDD/RFW/train/data/',
                                  csv_file='data/RFW_Train40k_Images_Metada.csv',
                                  transform=transform)
    val_dataset = CustomDataset(data_root='/mnt/HDD/RFW/train/data/',
                                csv_file='data/RFW_Val4k_Images_Metadata.csv',
                                transform=transform)
    # The test dataset/loader are built but not consumed in this function; they
    # are kept so a missing test CSV still fails fast as before.
    test_dataset = CustomDataset(data_root='/mnt/HDD/RFW/test/data/',
                                 csv_file='data/RFW_Test_Images_Metadata.csv',
                                 transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=options.batch_size, shuffle=True,
                              num_workers=options.workers, pin_memory=True)
    validate_loader = DataLoader(val_dataset, batch_size=options.batch_size * 4, shuffle=False,
                                 num_workers=options.workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=options.batch_size * 4, shuffle=False,
                             num_workers=options.workers, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    logs = {}
    start_epoch = 0
    num_attentions = 32
    net = WSDAN(num_classes=num_classes, M=num_attentions, net='inception_mixed_6e', pretrained=True)

    # feature_center: size of (#classes, #attention_maps * #channel_features)
    feature_center = torch.zeros(num_classes, num_attentions * net.num_features).to(device)

    if options.ckpt:
        # Read the checkpoint once (previously it was loaded and applied twice).
        checkpoint = torch.load(options.ckpt)

        # Epoch counter and logged values come from the checkpoint ...
        logs = checkpoint['logs']
        start_epoch = int(logs['epoch'])
        # ... unless resuming (--init 0), where the file name carries the epoch.
        if options.initial_training == 0:
            epoch_name = (options.ckpt.split('/')[-1]).split('.')[0]
            start_epoch = int(epoch_name)

        net.load_state_dict(checkpoint['state_dict'])
        logging.info('Network loaded from {}'.format(options.ckpt))

        if 'feature_center' in checkpoint:
            feature_center = checkpoint['feature_center'].to(device)
            logging.info('feature_center loaded from {}'.format(options.ckpt))

    logging.info('Network weights save to {}'.format(options.save_dir))

    ##################################
    # Use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Optimizer, LR Scheduler
    ##################################
    learning_rate = logs['lr'] if 'lr' in logs else options.lr
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

    ##################################
    # ModelCheckpoint
    ##################################
    callback_monitor = 'val_{}'.format(raw_metric.name)
    callback = ModelCheckpoint(savepath=os.path.join(options.save_dir, options.model_name),
                               monitor=callback_monitor,
                               mode='max')
    if callback_monitor in logs:
        callback.set_best_score(logs[callback_monitor])
    else:
        callback.reset()

    ##################################
    # TRAINING
    ##################################
    logging.info('')
    logging.info('Start training: Total epochs: {}, Batch size: {}, Training size: {}, Validation size: {}'.
                 format(options.epochs, options.batch_size, len(train_dataset), len(val_dataset)))

    for epoch in range(start_epoch, options.epochs):
        callback.on_epoch_begin()

        logs['epoch'] = epoch + 1
        logs['lr'] = optimizer.param_groups[0]['lr']

        logging.info('Epoch {:03d}, Learning Rate {:g}'.format(epoch + 1, optimizer.param_groups[0]['lr']))

        pbar = tqdm(total=len(train_loader), unit=' batches')
        pbar.set_description('Epoch {}/{}'.format(epoch + 1, options.epochs))

        train(logs=logs,
              data_loader=train_loader,
              net=net,
              feature_center=feature_center,
              optimizer=optimizer,
              pbar=pbar)
        validate(logs=logs,
                 data_loader=validate_loader,
                 net=net,
                 pbar=pbar)

        # ReduceLROnPlateau needs the monitored value; other schedulers do not.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(logs['val_loss'])
        else:
            scheduler.step()

        callback.on_epoch_end(logs, net, feature_center=feature_center)
        pbar.close()
def predict(image_path, model_param_path, save_path, img_save_name, resize=(224, 224), gen_hm=False):
    """Classify one image with WS-DAN and look up its ground-truth label.

    The image is resized to ``resize``, normalized, and passed through the
    network twice: once as-is and once as an attention-guided crop.  The two
    logit sets are averaged and turned into probabilities.

    Args:
        image_path: path of the image to classify.
        model_param_path: path of the saved ``state_dict``.  If the path
            contains ``'gpu'`` the model runs on CUDA when available.
        save_path: directory for the optional attention images.
        img_save_name: file-name stem for the optional attention images.
        resize: (height, width) passed to ``transforms.Resize``.
        gen_hm: when true, save raw / raw-attention / heat-attention JPEGs.

    Returns:
        ``(y_pred, label)`` where ``y_pred`` is the softmax output and
        ``label`` is a tensor built from the ``healthy``,
        ``multiple_diseases``, ``rust`` and ``scab`` columns of
        ``../data/train.csv`` for the row whose ``image_id`` equals the image
        file stem, or ``None`` when no such row exists (this used to raise
        ``UnboundLocalError``).
    """
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(size=(int(resize[0]), int(resize[1]))),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    X = transform(image).unsqueeze(0)

    net = WSDAN(num_classes=4)
    # Load onto the CPU first so GPU-saved weights can still be read on a
    # CPU-only machine; the model is moved to the chosen device below.
    net.load_state_dict(torch.load(model_param_path, map_location="cpu"))
    net.eval()

    device = torch.device("cpu")
    if 'gpu' in model_param_path:
        print("please make sure your computer has a GPU")
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            # Previously this only printed and then failed on X.to(device).
            print("No GPU in the environment")
    net.to(device)
    X = X.to(device)

    # WS-DAN
    y_pred_raw, _, attention_maps = net(X)
    attention_maps = torch.mean(attention_maps, dim=1, keepdim=True)

    # Augmentation with crop_mask
    crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05)
    y_pred_crop, _, _ = net(crop_image)

    y_pred = (y_pred_raw + y_pred_crop) / 2.
    # Explicit dim: the implicit-dim form is deprecated.
    y_pred = F.softmax(y_pred, dim=1)

    if gen_hm:
        # Same result as the deprecated F.upsample_bilinear.
        attention_maps = F.interpolate(attention_maps, size=(X.size(2), X.size(3)),
                                       mode='bilinear', align_corners=True)
        attention_maps = torch.sqrt(attention_maps.cpu() / attention_maps.max().item())

        # get heat attention maps
        heat_attention_maps = generate_heatmap(attention_maps)

        # raw_image, heat_attention, raw_attention
        raw_image = X.cpu() * STD + MEAN
        heat_attention_image = raw_image * 0.4 + heat_attention_maps * 0.6
        raw_attention_image = raw_image * attention_maps

        for batch_idx in range(X.size(0)):
            rimg = ToPILImage(raw_image[batch_idx])
            raimg = ToPILImage(raw_attention_image[batch_idx])
            haimg = ToPILImage(heat_attention_image[batch_idx])
            rimg.save(os.path.join(save_path, '{}_raw.jpg'.format(img_save_name)))
            raimg.save(os.path.join(save_path, '{}_raw_atten.jpg'.format(img_save_name)))
            haimg.save(os.path.join(save_path, '{}_heat_atten.jpg'.format(img_save_name)))

    # Ground-truth lookup by file stem (extension length no longer assumed).
    df = pd.read_csv("../data/train.csv")
    image_id = os.path.splitext(os.path.basename(image_path))[0]
    label = None
    for row_idx in df.index[df['image_id'] == image_id]:
        label = torch.tensor(
            df.loc[row_idx, ['healthy', 'multiple_diseases', 'rust', 'scab']])
        break  # first match wins, as before

    return y_pred, label
def main():
    """Evaluate a WS-DAN checkpoint on the CarDataset test split and write predictions.csv.

    For every batch the raw prediction and the prediction on an
    attention-guided crop are averaged; the averaged logits are collected and,
    once all batches are processed, the arg-max class of every sample is mapped
    through ``label2name`` and written to ``savepath + 'predictions.csv'``.
    When the module-level ``visualize`` flag is set, attention images are
    saved per sample.

    NOTE(review): the source layout was flattened; this assumes the
    visualization and accuracy code runs once per batch, as the
    "end of this batch" comment suggests.
    """
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    try:
        ckpt = config.eval_ckpt
    except AttributeError:
        logging.info('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    test_dataset = CarDataset('test')
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
                             num_workers=2, pin_memory=True)
    _, label2name = mapping('../training_labels.csv')

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes, M=config.num_attentions, net=config.net)

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    net.load_state_dict(checkpoint['state_dict'])
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 5))
    ref_accuracy = TopKAccuracyMetric(topk=(1, 5))
    raw_accuracy.reset()
    ref_accuracy.reset()

    net.eval()
    logits = []
    ids = []
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y, sample_ids) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)
            ids.extend(sample_ids)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = net(crop_image)
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # Keep the averaged logits for the submission written after the loop.
            logits.append(y_pred.cpu())

            if visualize:
                # reshape attention maps (same result as the deprecated F.upsample_bilinear)
                attention_maps = F.interpolate(attention_maps, size=(X.size(2), X.size(3)),
                                               mode='bilinear', align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() / attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                for batch_idx in range(X.size(0)):
                    sample_no = i * config.batch_size + batch_idx
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(os.path.join(savepath, '%03d_raw.jpg' % sample_no))
                    raimg.save(os.path.join(savepath, '%03d_raw_atten.jpg' % sample_no))
                    haimg.save(os.path.join(savepath, '%03d_heat_atten.jpg' % sample_no))

            # Top K
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_pred, y)

            # end of this batch
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0], epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()

    # Write the submission once; it used to be rebuilt and rewritten per batch.
    # Nothing is written for an empty loader, which also matches the old behavior.
    if logits:
        prediction = torch.argmax(torch.cat(logits, dim=0), dim=1)
        submission = pd.DataFrame(
            [ids, [label2name[x] for x in prediction.numpy()]]).transpose()
        submission.columns = ['id', 'label']
        submission.to_csv(savepath + 'predictions.csv', index=False)
# NOTE(review): this fragment begins in the middle of a call whose opening
# (presumably a training DataLoader construction) is not visible here, and the
# original nesting of the statements below could not be recovered - confirm the
# layout against the full file.
batch_size=args.batch_size, num_workers=1, shuffle = True) #sampler=sampler

# Validation loader over data_orig_val, only when VALID is truthy.
if VALID:
    val_loader = torch.utils.data.DataLoader(
        data_orig_val,
        batch_size=2, shuffle=False, num_workers=1)
    # Second validation loader read from an ImageFolder rooted at
    # args.data_crop + VALID_IMAGES, only when args.data_crop is set.
    # NOTE(review): nesting under `if VALID:` is an assumption - confirm.
    if args.data_crop:
        val_loader_crop = torch.utils.data.DataLoader(
            datasets.ImageFolder(args.data_crop + VALID_IMAGES,
                                 transform=data_transforms_val),
            batch_size=2, shuffle=False, num_workers=1)

# NOTE(review): the device is fixed to CUDA here, and feature_center is placed
# on it, while the model below is moved to the GPU only when use_cuda is truthy.
device = torch.device("cuda")
print("define wsdan")
model = WSDAN(num_classes=args.num_classes, M=num_attentions, net=NET, pretrained=True)
# One row per class, num_attentions * model.num_features columns, initialized to zero.
feature_center = torch.zeros(args.num_classes, num_attentions * model.num_features).to(device)
center_loss = CenterLoss()
cross_entropy_loss = nn.CrossEntropyLoss()
# Optionally restore weights; the loaded object is passed directly to load_state_dict.
if args.model:
    print("loading pretrained model")
    checkpoint = torch.load(args.model)
    model.load_state_dict(checkpoint)
if use_cuda:
    print('Using GPU')
    model.cuda()
else:
    print('Using CPU')
def main():
    """Evaluate the WS-DAN checkpoint given as ``sys.argv[1]`` on the CarDataset test split.

    Each image is scored twice (as-is and as an attention-guided crop); the
    two logit sets are averaged and fed to a top-1/top-5 accuracy metric.
    When the module-level ``visualize`` flag is set, attention images are
    written under ``savepath``.  The last value returned by the metric is
    logged at the end.
    """
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    try:
        ckpt = sys.argv[1]
    except IndexError:
        # No checkpoint path on the command line.
        logging.info('Usage: python3 eval.py <model.ckpt>')
        return

    ##################################
    # Dataset for testing
    ##################################
    test_dataset = CarDataset(phase='test', resize=448)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes, M=32, net='inception_mixed_6e')

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    net.load_state_dict(checkpoint['state_dict'])
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    cudnn.benchmark = True
    net.to(device)
    net = nn.DataParallel(net)
    net.eval()

    ##################################
    # Prediction
    ##################################
    accuracy = TopKAccuracyMetric(topk=(1, 5))
    accuracy.reset()

    # Stays None when the loader yields nothing, so the final log cannot
    # fail on an unbound name.
    epoch_acc = None

    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1)
            y_pred_crop, _, _ = net(crop_image)
            pred = (y_pred_raw + y_pred_crop) / 2.

            if visualize:
                # reshape attention maps (same result as the deprecated F.upsample_bilinear)
                attention_maps = F.interpolate(attention_maps, size=(X.size(2), X.size(3)),
                                               mode='bilinear', align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() / attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                for batch_idx in range(X.size(0)):
                    # The loader batch size is 1, so i + batch_idx is the sample index.
                    sample_no = i + batch_idx
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(os.path.join(savepath, '%03d_raw.jpg' % sample_no))
                    raimg.save(os.path.join(savepath, '%03d_raw_atten.jpg' % sample_no))
                    haimg.save(os.path.join(savepath, '%03d_heat_atten.jpg' % sample_no))

            # Top K
            epoch_acc = accuracy(pred, y)

            # end of this batch
            batch_info = 'Val Acc ({:.2f}, {:.2f})'.format(epoch_acc[0], epoch_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()

    # show information for this epoch
    if epoch_acc is None:
        logging.warning('Test loader produced no batches; accuracy not computed')
        return
    logging.info('Accuracy: %.2f, %.2f' % (epoch_acc[0], epoch_acc[1]))
def main():
    """Train WS-DAN on CarDataset using the settings in ``config``.

    Steps: create the save directory if missing, log to a file inside it,
    build the loaders, construct the network (optionally restoring a
    checkpoint together with its logs and feature centers), then run the
    epoch loop of ``train`` / ``validate`` / scheduler step / checkpoint
    callback.

    Note that the validation loader is built over the same dataset object as
    the training loader.
    """
    # ---- saving directory ----
    save_dir = config.save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # ---- logging ----
    log_format = '%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s'
    logging.basicConfig(filename=os.path.join(config.save_dir, config.log_name),
                        filemode='w',
                        format=log_format,
                        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # ---- datasets and loaders ----
    full_train_dataset = CarDataset('train')
    n = len(full_train_dataset)
    train_dataset = full_train_dataset
    validate_dataset = full_train_dataset

    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              num_workers=config.workers,
                              pin_memory=True)
    validate_loader = DataLoader(validate_dataset,
                                 batch_size=config.batch_size * 4,
                                 shuffle=False,
                                 num_workers=config.workers,
                                 pin_memory=True)
    num_classes = full_train_dataset.num_classes

    # ---- model ----
    logs = {}
    start_epoch = 0
    net = WSDAN(num_classes=num_classes, M=config.num_attentions, net=config.net, pretrained=True)

    # One row per class, num_attentions * num_features columns.
    center_width = config.num_attentions * net.num_features
    feature_center = torch.zeros(num_classes, center_width).to(device)

    if config.ckpt:
        checkpoint = torch.load(config.ckpt)

        # Resume the epoch counter and the logged values.
        logs = checkpoint['logs']
        start_epoch = int(logs['epoch'])

        # Restore the weights.
        net.load_state_dict(checkpoint['state_dict'])
        logging.info('Network loaded from {}'.format(config.ckpt))

        # Restore the feature centers when the checkpoint carries them.
        if 'feature_center' in checkpoint:
            feature_center = checkpoint['feature_center'].to(device)
            logging.info('feature_center loaded from {}'.format(config.ckpt))

    logging.info('Network weights save to {}'.format(config.save_dir))

    # ---- device placement ----
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    # ---- optimizer and LR schedule ----
    if 'lr' in logs:
        learning_rate = logs['lr']
    else:
        learning_rate = config.learning_rate
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

    # ---- checkpoint callback ----
    callback_monitor = 'val_{}'.format(raw_metric.name)
    checkpoint_path = os.path.join(config.save_dir, config.model_name)
    callback = ModelCheckpoint(savepath=checkpoint_path, monitor=callback_monitor, mode='max')
    if callback_monitor not in logs:
        callback.reset()
    else:
        callback.set_best_score(logs[callback_monitor])

    # ---- training ----
    logging.info(
        'Start training: Total epochs: {}, Batch size: {}, Training size: {}, Validation size: {}'
        .format(config.epochs, config.batch_size, len(train_dataset), len(validate_dataset)))
    logging.info('')

    plateau_cls = torch.optim.lr_scheduler.ReduceLROnPlateau
    for epoch in range(start_epoch, config.epochs):
        callback.on_epoch_begin()

        epoch_no = epoch + 1
        logs['epoch'] = epoch_no
        logs['lr'] = optimizer.param_groups[0]['lr']

        logging.info('Epoch {:03d}, Learning Rate {:g}'.format(
            epoch_no, optimizer.param_groups[0]['lr']))

        pbar = tqdm(total=len(train_loader), unit=' batches')
        pbar.set_description('Epoch {}/{}'.format(epoch_no, config.epochs))

        train(logs=logs,
              data_loader=train_loader,
              net=net,
              feature_center=feature_center,
              optimizer=optimizer,
              pbar=pbar)
        validate(logs=logs, data_loader=validate_loader, net=net, pbar=pbar)

        # Plateau schedulers need the monitored loss; others step unconditionally.
        if not isinstance(scheduler, plateau_cls):
            scheduler.step()
        else:
            scheduler.step(logs['val_loss'])

        callback.on_epoch_end(logs, net, feature_center=feature_center)
        pbar.close()
def _load_label_lookup(csv_path):
    """Read two-column rows ``k, v`` from ``csv_path`` and return the dict ``{v: k}``."""
    lookup = {}
    with open(csv_path, 'r', newline='') as handle:
        for row in csv.reader(handle):
            if not row:
                continue  # tolerate blank lines
            k, v = row
            lookup[v] = k
    return lookup


def main(result_arr):
    """Evaluate the configured WS-DAN checkpoint and collect per-sample predictions.

    Each batch is scored twice (as-is and as an attention-guided crop) and the
    logits are averaged.  For every sample the arg-max class index and its
    entry in the label lookup are appended to ``result_arr``.

    Args:
        result_arr: list that receives one ``(class_index, label)`` tuple per
            evaluated sample.  ``label`` is ``None`` when the lookup has no
            entry for the index.

    NOTE(review): the original line
    ``result.append(y_pred, d[y_pred)]`` did not parse and referenced an
    undefined ``result``.  This rewrite assumes the second CSV column holds
    the class index as text and the first column the label - confirm against
    out_dict.csv.
    """
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    try:
        ckpt = config.eval_ckpt
    except AttributeError:
        logging.info('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    _, test_dataset = get_trainval_datasets(config.tag, resize=config.image_size)

    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
                             num_workers=2, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes, M=config.num_attentions, net=config.net)

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    net.load_state_dict(checkpoint['state_dict'])
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 5))
    ref_accuracy = TopKAccuracyMetric(topk=(1, 5))
    raw_accuracy.reset()
    ref_accuracy.reset()

    # Read the lookup once; it used to be re-read (and left open) per batch.
    label_lookup = _load_label_lookup('/home/naman/Documents/Assignment_Job/out_dict.csv')

    net.eval()
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = net(crop_image)
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # Record (class index, label) for every sample of the batch.
            for class_idx in y_pred.argmax(dim=1).cpu().tolist():
                result_arr.append((class_idx, label_lookup.get(str(class_idx))))

            if visualize:
                # reshape attention maps (same result as the deprecated F.upsample_bilinear)
                attention_maps = F.interpolate(attention_maps, size=(X.size(2), X.size(3)),
                                               mode='bilinear', align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() / attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                for batch_idx in range(X.size(0)):
                    sample_no = i * config.batch_size + batch_idx
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(os.path.join(savepath, '%03d_raw.jpg' % sample_no))
                    raimg.save(os.path.join(savepath, '%03d_raw_atten.jpg' % sample_no))
                    haimg.save(os.path.join(savepath, '%03d_heat_atten.jpg' % sample_no))

            # Top K
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_pred, y)

            # end of this batch
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0], epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()
def init_model(num_classes, use_pretrained):
    """Construct a WSDAN model and return it together with the constant 224.

    Args:
        num_classes: forwarded to ``WSDAN`` as ``num_classes``.
        use_pretrained: forwarded to ``WSDAN`` as ``pretrained``.

    Returns:
        ``(model, 224)``.  The integer is presumably the expected input image
        size - confirm against the caller.
    """
    size = 224
    network = WSDAN(num_classes=num_classes, pretrained=use_pretrained)
    return network, size
parser.add_argument(
    '--output-folder',
    type=str,
    default='attention_dataset',
    metavar='D',
    help="folder where data is located. train_images/ and val_images/ need to be found in the folder")
args = parser.parse_args()

# Restore the trained WS-DAN weights and move the model to the GPU.
model_path = 'experiment/wsdanp_retrain_model_9.pth'
NET = 'inception_mixed_7c'  # alternative backbone tag: "inception_mixed_6e"
num_attentions = 32
model = WSDAN(num_classes=20, M=num_attentions, net=NET, pretrained=False)
checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint)
model.cuda()


def generate_heatmap(attention_maps):
    """Turn channel 0 of ``attention_maps`` into a three-channel heat map.

    With ``a = attention_maps[:, 0, ...]`` the stacked channels (along dim 1)
    are:

    * R: ``a``
    * G: ``a`` where ``a < 0.5``, otherwise ``1 - a``
    * B: ``1 - a``
    """
    first = attention_maps[:, 0, ...]
    below_half = (first < 0.5).float()
    at_or_above_half = (first >= 0.5).float()

    red = first
    green = first * below_half + (1. - first) * at_or_above_half
    blue = 1. - first
    return torch.stack([red, green, blue], dim=1)


tf = transforms.ToTensor()
def main():
    """Evaluate a WS-DAN checkpoint on the RFW test set and optionally draw a confusion matrix.

    Each batch is scored twice (as-is and as an attention-guided crop) and the
    logits are averaged.  Top-1 predictions and ground-truth labels are
    collected over the whole loader; the values returned by the raw and
    refined accuracy metrics are fed into two ``AverageMeter`` objects whose
    averages are printed at the end.  When the module-level ``visualize`` flag
    is set, attention images are written under ``savepath``.

    Fixes relative to the previous version:

    * visualization file names use a running sample counter; the old stride
      ``options.batch_size`` did not match the loader batch size
      (``options.batch_size * 4``), so files were overwritten;
    * help texts state the real defaults;
    * ``--cm false`` / ``--cm 0`` now disable the confusion matrix.
    """
    parser = OptionParser()
    parser.add_option('--gpu', dest='GPU', default=0, type='int',
                      help='GPU Id (default: 0)')
    parser.add_option('--evalckpt', '--eval-ckpt', dest='eval_ckpt', default='models/wsdan/003.ckpt',
                      help='saved models are in ckpt directory')
    parser.add_option('-b', '--batch-size', dest='batch_size', default=64, type='int',
                      help='batch size (default: 64)')
    parser.add_option('-j', '--workers', dest='workers', default=4, type='int',
                      help='number of data loading workers (default: 4)')
    parser.add_option('--na', '--num-attentions', dest='num_attentions', default=32, type='int',
                      help='number of attentions')
    parser.add_option('--cm', '--confusion_matrix', dest='confusion_matrix', default=True,
                      help='if you want to create confusion matrix')
    options, _ = parser.parse_args()

    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # The option always exists (it has a default); only an empty value is an error.
    ckpt = options.eval_ckpt
    if not ckpt:
        logging.info('Set ckpt for evaluation options')
        return

    # A value given on the command line arrives as a string, so "False" would
    # otherwise be truthy.
    want_confusion_matrix = options.confusion_matrix
    if isinstance(want_confusion_matrix, str):
        want_confusion_matrix = want_confusion_matrix.strip().lower() not in ('', '0', 'false', 'no', 'off')

    # Dataset for testing
    transform = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    test_dataset = CustomDataset(
        data_root='/mnt/HDD/DatasetOriginals/RFW/test/data/',
        csv_file='data/RFW_Test_Images_Metadata.csv',
        transform=transform)

    test_loader = DataLoader(test_dataset,
                             batch_size=options.batch_size * 4,
                             shuffle=False,
                             num_workers=options.workers,
                             pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=4, M=32, net='inception_mixed_6e')

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    net.load_state_dict(checkpoint['state_dict'])
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 3))
    ref_accuracy = TopKAccuracyMetric(topk=(1, 3))
    raw_accuracy.reset()
    ref_accuracy.reset()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top_refined = AverageMeter('Acc@1', ':6.2f')

    net.eval()
    y_pred, y_true = [], []
    samples_seen = 0  # global index of the first sample of the current batch
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for X, y in test_loader:
            y_true += list(y.numpy())
            X = X.to(device)
            y = y.to(device)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = net(crop_image)
            y_predicted = (y_pred_raw + y_pred_crop) / 2.
            _, pred = y_predicted.topk(1, 1, True, True)
            y_pred += list(pred.cpu().numpy())

            if visualize:
                # reshape attention maps (same result as the deprecated F.upsample_bilinear)
                attention_maps = F.interpolate(attention_maps, size=(X.size(2), X.size(3)),
                                               mode='bilinear', align_corners=True)
                attention_maps = torch.sqrt(attention_maps.cpu() / attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                for batch_idx in range(X.size(0)):
                    sample_no = samples_seen + batch_idx
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(os.path.join(savepath, '%03d_raw.jpg' % sample_no))
                    raimg.save(os.path.join(savepath, '%03d_raw_atten.jpg' % sample_no))
                    haimg.save(os.path.join(savepath, '%03d_heat_atten.jpg' % sample_no))

            # Top K
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_predicted, y)
            # NOTE(review): if TopKAccuracyMetric already returns a running
            # value, averaging it again here double-counts - confirm.
            top1.update(epoch_raw_acc[0], X.size(0))
            top_refined.update(epoch_ref_acc[0], X.size(0))

            # end of this batch
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0], epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)
            samples_seen += X.size(0)

        pbar.close()

    print(' * Raw Accuracy {top1.avg:.3f}'.format(top1=top1))
    print(' * Refined Accuracy {top1.avg:.3f}'.format(top1=top_refined))
    print(len(y_pred), len(y_true))

    if want_confusion_matrix:
        file_name = 'source/wsdan_confusion_matrix.svg'
        draw_confusion_matrix(np.asarray(y_true), np.asarray(y_pred), file_name)
def main_worker(local_rank, ngpus_per_node, args):
    """Per-process entry point for distributed WS-DAN training on DfdcDataset.

    Joins the NCCL process group (one process per GPU on this machine), builds
    distributed samplers and loaders, wraps the network in SyncBatchNorm +
    DistributedDataParallel, optionally restores a checkpoint, and runs the
    epoch loop.  Rank 0 configures file logging and writes one checkpoint per
    epoch; all ranks synchronize on a barrier afterwards.

    Args:
        local_rank: rank of this process; also used as its CUDA device index.
        ngpus_per_node: world size passed to the process group.
        args: unused here; kept for the spawn signature.
    """
    if local_rank == 0:
        # The log file and the checkpoints are written here, so the directory
        # must exist.
        os.makedirs(settings.save_dir, exist_ok=True)
        logging.basicConfig(
            filename=os.path.join(settings.save_dir, settings.log_name),
            filemode='w',
            format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: '
                   '%(message)s',
            level=logging.INFO)
    warnings.filterwarnings("ignore")

    dist.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:22465',
                            world_size=ngpus_per_node,
                            rank=local_rank)
    torch.cuda.set_device(local_rank)

    train_dataset = DfdcDataset(phase='train', datapath=settings.datapath, resize=settings.image_size)
    validate_dataset = DfdcDataset(phase='val', datapath=settings.datapath, resize=settings.image_size)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    validate_sampler = torch.utils.data.distributed.DistributedSampler(validate_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=settings.batch_size,
                              sampler=train_sampler,
                              pin_memory=True,
                              num_workers=settings.workers)
    validate_loader = DataLoader(validate_dataset,
                                 batch_size=settings.batch_size,
                                 sampler=validate_sampler,
                                 pin_memory=True,
                                 num_workers=settings.workers)

    num_classes = train_dataset.num_classes
    logs = {}
    start_epoch = 0

    net = WSDAN(num_classes=num_classes,
                M=settings.num_attentions,
                net=settings.net,
                pretrained=settings.pretrained)
    # Read before wrapping: the DDP wrapper exposes the model under .module.
    num_features = net.num_features

    net = nn.SyncBatchNorm.convert_sync_batchnorm(net).to(local_rank)
    net = nn.parallel.DistributedDataParallel(net,
                                              device_ids=[local_rank],
                                              output_device=local_rank,
                                              find_unused_parameters=True)

    center_loss = CenterLoss().to(local_rank)
    cross_entropy_loss = nn.CrossEntropyLoss().to(local_rank)
    feature_center = torch.zeros(num_classes, settings.num_attentions * num_features).to(local_rank)

    if settings.ckpt:
        # Map the stored tensors onto this process's GPU.
        loc = 'cuda:{}'.format(local_rank)
        checkpoint = torch.load(settings.ckpt, map_location=loc)
        logs = checkpoint['logs']
        start_epoch = int(logs['epoch'])
        net.module.load_state_dict(checkpoint['state_dict'])
        if 'feature_center' in checkpoint:
            # Stored centers are L2-normalized along the last dim on load.
            feature_center = F.normalize(checkpoint['feature_center'], dim=-1)

    learning_rate = logs['lr'] if 'lr' in logs else settings.learning_rate
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

    for epoch in range(start_epoch, settings.epochs):
        logs['epoch'] = epoch + 1
        logs['lr'] = optimizer.param_groups[0]['lr']

        # Pass the epoch to the sampler and advance the dataset's own epoch state.
        train_sampler.set_epoch(epoch)
        train_sampler.dataset.next_epoch()

        train(logs=logs,
              data_loader=train_loader,
              net=net,
              cross_entropy_loss=cross_entropy_loss,
              center_loss=center_loss,
              feature_center=feature_center,
              optimizer=optimizer,
              ngpus_per_node=ngpus_per_node,
              local_rank=local_rank)
        validate(logs=logs,
                 data_loader=validate_loader,
                 cross_entropy_loss=cross_entropy_loss,
                 net=net,
                 ngpus_per_node=ngpus_per_node,
                 local_rank=local_rank)
        scheduler.step()

        if local_rank == 0:
            # os.path.join keeps the file inside save_dir even when the
            # setting has no trailing separator (the log path does the same).
            torch.save(
                {
                    'logs': logs,
                    'state_dict': net.module.state_dict(),
                    'feature_center': feature_center
                },
                os.path.join(settings.save_dir, 'ckpt_%s.pth' % epoch))
        dist.barrier()