def main_worker(gpu, ngpus_per_node, args):
    """Per-process training entry point (one process per GPU under mp.spawn).

    Sets up (optionally distributed) model placement, the SGD optimizer,
    checkpoint resume, the ImageNet-style train/val loaders, and the main
    epoch loop that trains, validates, and checkpoints the best top-1 model.

    Args:
        gpu: GPU index assigned to this worker (or None for CPU/default).
        ngpus_per_node: number of GPUs on this node; used to derive the
            global rank and to split batch size / workers per process.
        args: parsed CLI namespace (dist_url, rank, arch, lr, resume, ...).
    """
    # best top-1 accuracy is shared across resume/epoch loop via a module global
    global best_acc1
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            # rank comes from the launcher environment (torch.distributed.launch)
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model: torchvision pretrained arch, or the custom attention VGG
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        #print("=> creating model '{}'".format(args.arch))
        #model = models.__dict__[args.arch]()
        print("Creating Attn Model")
        # NOTE(review): args.arch is ignored in this branch — the attention
        # model is always built; checkpoints still record args.arch below.
        model = AttnVGG_before(num_classes=1000, attention=True, normalize_attn=True)
    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            # classic trick: parallelize only the conv features for these archs
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE(review): loads onto the checkpoint's original device —
            # consider map_location when resuming on a different GPU; verify.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    # fixed input sizes -> let cuDNN pick the fastest conv algorithms
    cudnn.benchmark = True
    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # standard ImageNet channel statistics
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    if args.distributed:
        # sampler shards the dataset across ranks; shuffle must then be off
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    if args.evaluate:
        # evaluation-only mode: single validation pass, no training
        validate(val_loader, model, criterion, args)
        return
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # re-seed the sampler so each epoch sees a different shard order
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)
        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)
        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        # only rank 0 on each node writes checkpoints to avoid clobbering
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
def main():
    """Train an attention-VGG classifier on the drawing dataset (CIFAR-100 setup).

    Builds train/test pipelines with CIFAR-100 normalization stats, constructs
    the AttnVGG model ('before'/'after' attention per ``opt.attn_mode``), and
    runs ``opt.epochs`` epochs of SGD with a halve-every-25-epochs schedule.
    Losses, accuracies, sample image grids and attention maps are written to
    TensorBoard under ``opt.outf``.

    Relies on module-level globals: ``opt`` (parsed CLI options) and
    ``base_seed`` (used by the DataLoader worker seeding hook).

    Fix vs. original: the test-image grid was archived from ``inputs`` (the
    last TRAINING batch) instead of ``images_test``, so the 'test/image' grid
    and test attention maps actually showed training images. Now archives the
    first test batch.
    """
    ## load data
    # CIFAR-100: 500 training images and 100 testing images per class
    print('\nloading the dataset ...\n')
    num_aug = 3  # passes over the training set per epoch (augmentation rounds)
    im_size = 32
    transform_train = transforms.Compose([
        transforms.RandomCrop(im_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # CIFAR-100 channel mean/std
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ])

    def _init_fn(worker_id):
        # deterministic, distinct seed per DataLoader worker
        random.seed(base_seed + worker_id)

    trainset = DrawDataset(transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=opt.batch_size,
                                              shuffle=True,
                                              num_workers=8,
                                              worker_init_fn=_init_fn)
    testset = DrawDataset(transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=5)
    print('done')

    ## load network
    print('\nloading the network ...\n')
    # use attention module?
    if not opt.no_attention:
        print('\nturn on attention ...\n')
    else:
        print('\nturn off attention ...\n')
    # (linear attn) insert attention before or after maxpooling?
    # (grid attn only supports "before" mode)
    if opt.attn_mode == 'before':
        print('\npay attention before maxpooling layers...\n')
        net = AttnVGG_before(im_size=im_size, num_classes=100,
                             attention=not opt.no_attention,
                             normalize_attn=opt.normalize_attn,
                             init='xavierUniform')
    elif opt.attn_mode == 'after':
        print('\npay attention after maxpooling layers...\n')
        net = AttnVGG_after(im_size=im_size, num_classes=100,
                            attention=not opt.no_attention,
                            normalize_attn=opt.normalize_attn,
                            init='xavierUniform')
    else:
        raise NotImplementedError("Invalid attention mode!")
    criterion = nn.CrossEntropyLoss()
    print('done')

    ## move to GPU
    print('\nmoving to GPU ...\n')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_ids = [0, 1]
    model = nn.DataParallel(net, device_ids=device_ids).to(device)
    criterion.to(device)
    print('done')

    ### optimizer: SGD, LR halved every 25 epochs
    optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9, weight_decay=5e-4)
    lr_lambda = lambda epoch: np.power(0.5, int(epoch / 25))
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    # training
    print('\nstart training ...\n')
    step = 0
    running_avg_accuracy = 0
    writer = SummaryWriter(opt.outf)
    for epoch in range(opt.epochs):
        images_disp = []  # [0]=first train batch, [1]=first test batch (for logging)
        # adjust learning rate
        # NOTE(review): stepping at epoch start is the pre-1.1 PyTorch idiom;
        # newer PyTorch warns and expects scheduler.step() after training.
        # Kept as-is to preserve the existing decay boundaries.
        scheduler.step()
        writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], epoch)
        print("\nepoch %d learning rate %f\n" % (epoch, optimizer.param_groups[0]['lr']))
        # run for one epoch
        for aug in range(num_aug):
            for i, data in enumerate(trainloader, 0):
                # warm up
                model.train()
                model.zero_grad()
                optimizer.zero_grad()
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                if (aug == 0) and (i == 0):
                    # archive images in order to save to logs
                    images_disp.append(inputs[0:36, :, :, :])
                # forward
                pred, __, __, __ = model(inputs)
                # backward
                loss = criterion(pred, labels)
                loss.backward()
                optimizer.step()
                # display results every 10 batches (extra eval-mode forward pass)
                if i % 10 == 0:
                    model.eval()
                    pred, __, __, __ = model(inputs)
                    predict = torch.argmax(pred, 1)
                    total = labels.size(0)
                    correct = torch.eq(predict, labels).sum().double().item()
                    accuracy = correct / total
                    # exponential moving average of batch accuracy
                    running_avg_accuracy = 0.9 * running_avg_accuracy + 0.1 * accuracy
                    writer.add_scalar('train/loss', loss.item(), step)
                    writer.add_scalar('train/accuracy', accuracy, step)
                    writer.add_scalar('train/running_avg_accuracy', running_avg_accuracy, step)
                    print(
                        "[epoch %d][aug %d/%d][%d/%d] loss %.4f accuracy %.2f%% running avg accuracy %.2f%%"
                        % (epoch, aug, num_aug - 1, i, len(trainloader) - 1,
                           loss.item(), (100 * accuracy), (100 * running_avg_accuracy)))
                step += 1
        # the end of each epoch: test & log
        print('\none epoch done, saving records ...\n')
        torch.save(model.state_dict(), os.path.join(opt.outf, 'net.pth'))
        if epoch == opt.epochs / 2:
            # one extra mid-training snapshot
            torch.save(model.state_dict(), os.path.join(opt.outf, 'net%d.pth' % epoch))
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            # log scalars
            for i, data in enumerate(testloader, 0):
                images_test, labels_test = data
                images_test, labels_test = images_test.to(
                    device), labels_test.to(device)
                if i == 0:
                    # archive images in order to save to logs
                    # BUGFIX: was `inputs[0:36]` (last training batch) — the
                    # 'test/image' grid and test attention maps showed train
                    # images. Archive the actual test batch instead.
                    images_disp.append(images_test[0:36, :, :, :])
                pred_test, __, __, __ = model(images_test)
                predict = torch.argmax(pred_test, 1)
                total += labels_test.size(0)
                correct += torch.eq(predict, labels_test).sum().double().item()
            writer.add_scalar('test/accuracy', correct / total, epoch)
            print("\n[epoch %d] accuracy on test data: %.2f%%\n" % (epoch, 100 * correct / total))
            # log images
            if opt.log_images:
                print('\nlog images ...\n')
                I_train = utils.make_grid(images_disp[0], nrow=6, normalize=True, scale_each=True)
                writer.add_image('train/image', I_train, epoch)
                if epoch == 0:
                    # test set is fixed; its grid only needs logging once
                    I_test = utils.make_grid(images_disp[1], nrow=6, normalize=True, scale_each=True)
                    writer.add_image('test/image', I_test, epoch)
            if opt.log_images and (not opt.no_attention):
                print('\nlog attention maps ...\n')
                # base factor
                if opt.attn_mode == 'before':
                    min_up_factor = 1
                else:
                    min_up_factor = 2
                # sigmoid or softmax
                if opt.normalize_attn:
                    vis_fun = visualize_attn_softmax
                else:
                    vis_fun = visualize_attn_sigmoid
                # training data
                __, c1, c2, c3 = model(images_disp[0])
                if c1 is not None:
                    attn1 = vis_fun(I_train, c1, up_factor=min_up_factor, nrow=6)
                    writer.add_image('train/attention_map_1', attn1, epoch)
                if c2 is not None:
                    attn2 = vis_fun(I_train, c2, up_factor=min_up_factor * 2, nrow=6)
                    writer.add_image('train/attention_map_2', attn2, epoch)
                if c3 is not None:
                    attn3 = vis_fun(I_train, c3, up_factor=min_up_factor * 4, nrow=6)
                    writer.add_image('train/attention_map_3', attn3, epoch)
                # test data (I_test computed at epoch 0 stays in scope)
                __, c1, c2, c3 = model(images_disp[1])
                if c1 is not None:
                    attn1 = vis_fun(I_test, c1, up_factor=min_up_factor, nrow=6)
                    writer.add_image('test/attention_map_1', attn1, epoch)
                if c2 is not None:
                    attn2 = vis_fun(I_test, c2, up_factor=min_up_factor * 2, nrow=6)
                    writer.add_image('test/attention_map_2', attn2, epoch)
                if c3 is not None:
                    attn3 = vis_fun(I_test, c3, up_factor=min_up_factor * 4, nrow=6)
                    writer.add_image('test/attention_map_3', attn3, epoch)
def main():
    """Train an attention-VGG on chest X-ray (or pacemaker) data with resume support.

    Chooses the dataset by ``opt.global_param``, builds the model, optionally
    restores a pre-trained state (either an explicit checkpoint path or a JSON
    'record' file written at the end of each epoch), then trains with per-epoch
    test evaluation and TensorBoard logging. Depends on module-level helpers:
    ``argparser``, ``_worker_init_fn_``, ``XRAY``, ``PacemakerDataset``,
    ``xray_loss``, ``visual_test_image_softmax``, ``visualize_attn_*``.
    """
    ## load data
    print('\nloading the dataset ...\n')
    if False:
        # TODO debug section, remove
        pass
    else:
        opt = argparser()
    print(opt)
    num_aug = 1  # single pass over training data per epoch
    raw_size = 1024  # NOTE(review): unused here — presumably the raw X-ray side length
    im_size = opt.image_size
    transform_train = transforms.Compose([
        transforms.Resize(im_size),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # 0.498, std: 0.185  (measured dataset statistics, single channel)
        transforms.Normalize((0.5, ), (0.185, ))
    ])
    transform_test = transforms.Compose([
        transforms.Resize(im_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.185, ))
    ])
    # dataset selection: anything but 'CIFAR' means the chest X-ray corpus
    if opt.global_param != 'CIFAR':
        xray = XRAY(transform_train, transform_test, force_pre_process=False, csv_file=opt.csv_path)
        trainset = xray.train_set
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=opt.batch_size, shuffle=True, num_workers=8, worker_init_fn=_worker_init_fn_)
        testset = xray.test_set
        testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=5)
        class_to_index = xray.class_to_index()
    else:
        trainset = PacemakerDataset(transform=transform_train, is_train=True)
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=opt.batch_size, shuffle=True, num_workers=8, worker_init_fn=_worker_init_fn_)
        testset = PacemakerDataset(transform=transform_test, is_train=False)
        testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=5)
        class_to_index = trainset.class_to_index()
    # NOTE(review): assumes one mapping entry is not a real class (hence -1) — confirm
    num_of_class = len(class_to_index.keys()) - 1
    device_ids = [0, 1]
    # loss selection: custom weighted loss when opt.loss names one, else BCE
    if opt.loss is not None and isinstance(opt.loss, str):
        criterion = xray_loss.Loss(opt.loss)
    else:
        criterion = nn.BCELoss()
        #criterion = nn.CrossEntropyLoss()
    print("criterion = %s" % type(criterion))
    print('done num_of_classes: %s [%s] , post crop size: %s' % (num_of_class, class_to_index, im_size))
    ## load network
    print('\nloading the network ...\n')
    # use attention module?
    if not opt.no_attention:
        print('\nturn on attention ...\n')
    else:
        print('\nturn off attention ...\n')
    # (linear attn) insert attention befroe or after maxpooling?
    # (grid attn only supports "before" mode)
    if opt.attn_mode == 'before':
        print('\npay attention before maxpooling layers...\n')
        net = AttnVGG_before(im_size=im_size, num_classes=num_of_class,
                             attention=not opt.no_attention,
                             normalize_attn=opt.normalize_attn,
                             init='xavierUniform',
                             _base_features=opt.base_feature_size,
                             dropout=opt.dropout)
    elif opt.attn_mode == 'after':
        print('\npay attention after maxpooling layers...\n')
        net = AttnVGG_after(im_size=im_size, num_classes=num_of_class,
                            attention=not opt.no_attention,
                            normalize_attn=opt.normalize_attn,
                            init='xavierUniform')
    else:
        raise NotImplementedError("Invalid attention mode!")
    print('done')
    ## move to GPU
    print('\nmoving to GPU ...\n')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_pre_trained = None
    record = None
    if opt.pre_train:
        try:
            # preferred path: explicit pre-train checkpoint from CLI
            if os.path.exists(opt.chest_xray_pretrain_path):
                model_pre_trained = torch.load(opt.chest_xray_pretrain_path)
                record = None
            else:
                assert opt.test_only is False, "cannot run test only mode without pre train data"
        except AttributeError:
            # fallback: opt has no chest_xray_pretrain_path — resume from the
            # JSON 'record' file written at the end of each training epoch
            record_path = os.path.join(opt.outf, 'record')
            if os.path.exists(record_path):
                with open(os.path.join(opt.outf, 'record'), 'r') as frecord:
                    # contents = record.read()
                    record = json.load(frecord)
                    # record = ast.literal_eval(contents)
                if os.path.exists(record['model']):
                    model_pre_trained = torch.load(record['model'])
                # NOTE(review): '%s' is not interpolated — print() receives two
                # arguments here (logging-style call on plain print)
                print("found pre trained data: %s", record)
    if model_pre_trained is not None:
        # state dict was saved from a DataParallel wrapper, so wrap first
        model = nn.DataParallel(net, device_ids=device_ids)
        model.load_state_dict(model_pre_trained)
        model = model.to(device)
    else:
        model = nn.DataParallel(net, device_ids=device_ids).to(device)
    criterion.to(device)
    print('done')
    if opt.test_only:
        # inference-only mode: visualize a single image and exit
        visual_test_image_softmax(model, opt.test_image, transform_test)
        return
    # resume bookkeeping: lr / epoch / global step come from the record if any
    if record is None:
        lr = opt.lr
        first_epoch = 0
        step = 0
    else:
        lr = record['lr']
        first_epoch = record['epoch']
        step = record['step']
    slow_lr = opt.slow_lr if hasattr(opt, 'slow_lr') else False
    ### optimizer
    if opt.global_param == 'CIFAR':
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        lr_lambda = lambda epoch: np.power(0.5, int(epoch / 25))
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    else:
        if opt.global_param == 'BCE':
            optimizer = optim.SGD(model.parameters(), lr=lr, momentum=opt.momentum, weight_decay=opt.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
        # halve LR every `rate` epochs, floored at 1e-4 of the base LR
        rate = 25 if not slow_lr else 50
        lr_lambda = lambda epoch: max(np.power(0.5, int(epoch / rate)), 1e-4)
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    # training
    start = time.time()
    print('\nstart training [%s]...\n' % start)
    running_avg_accuracy = 0
    writer = SummaryWriter(opt.outf)
    for epoch in range(first_epoch, first_epoch + opt.epochs):
        images_disp = []  # [0]=first train batch, [1]=batch archived for test logging
        # adjust learning rate (pre-1.1 PyTorch idiom: step at epoch start)
        scheduler.step()
        writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], epoch)
        print("\nepoch %d learning rate %f\n" % (epoch, optimizer.param_groups[0]['lr']))
        # run for one epoch
        for aug in range(num_aug):
            for i, data in enumerate(trainloader, 0):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                # warm up
                model.train()
                model.zero_grad()
                optimizer.zero_grad()
                if (aug == 0) and (i == 0):
                    # archive images in order to save to logs
                    print("input:", inputs.shape, "inputs[0:36, :, :, :] ->", inputs[0:36, :, :, :].shape)
                    images_disp.append(inputs[0:36, :, :, :])
                # forward
                pred, __, __, __ = model(inputs)
                # backward
                if isinstance(criterion, nn.BCELoss):
                    # BCELoss expects probabilities; model outputs logits
                    pred = torch.sigmoid(pred)
                loss = criterion(pred, labels)
                #print("loss: %s, pred: %s, labels: %s" % (loss, pred, labels))
                loss.backward()
                #print("post loss backward")
                optimizer.step()
                # display results every 10 batches (extra eval-mode forward)
                if i % 10 == 0:
                    model.eval()
                    pred, __, __, __ = model(inputs)
                    # thresholding/argmax depends on which loss is in use
                    if isinstance(criterion, nn.BCELoss):
                        predict = pred
                        predict[predict > 0.5] = 1
                        predict[predict <= 0.5] = 0
                    elif isinstance(criterion, nn.CrossEntropyLoss):
                        predict = torch.argmax(pred, 1)
                    elif isinstance(criterion, xray_loss.Loss):
                        predict = torch.sigmoid(pred)
                        predict[predict > 0.5] = 1
                        predict[predict <= 0.5] = 0
                    else:
                        raise Exception("{} what is this?".format(criterion))
                    #print("predict: ", predict.shape, "pred: ", pred.shape, "label: ", labels.shape, "input: ", inputs.shape)
                    #print("Train: predict: ", predict, "label: ", labels)
                    # multi-label accuracy: element-wise over batch x classes
                    total = labels.size(0) * labels.size(1)
                    correct = torch.eq(predict, labels).sum().double().item()
                    accuracy = correct / total
                    # print("accuracy:%s = correct:%s [pred:%s, predict:%s, labels:%s] / total:%s" % (accuracy, correct, pred, predict, labels, total))
                    running_avg_accuracy = 0.9 * running_avg_accuracy + 0.1 * accuracy
                    writer.add_scalar('train/loss', loss.item(), step)
                    writer.add_scalar('train/accuracy', accuracy, step)
                    writer.add_scalar('train/running_avg_accuracy', running_avg_accuracy, step)
                    print(
                        "[epoch %d][aug %d/%d][%d/%d] loss %.4f accuracy %.2f%% running avg accuracy %.2f%%"
                        % (epoch, aug, num_aug - 1, i, len(trainloader) - 1,
                           loss.item(), (100 * accuracy), (100 * running_avg_accuracy)))
                step += 1
        # the end of each epoch: test & log
        print('\none epoch done [took: %s], saving records ...\n' % (time.time() - start))
        state = os.path.join(opt.outf, 'net.pth')
        torch.save(model.state_dict(), state)
        # persist resume metadata alongside the weights
        # NOTE(review): `record` (the resume dict above) is shadowed here by
        # the file handle — rename one of them when next touching this code.
        with open(os.path.join(opt.outf, 'record'), 'w') as record:
            srecord = {
                "lr": optimizer.param_groups[0]['lr'],
                "epoch": epoch,
                "model": state,
                "step": step,
                'global_arg': str(opt)
            }
            json.dump(srecord, record)
        if epoch == opt.epochs / 2:
            # extra mid-training snapshot
            torch.save(model.state_dict(), os.path.join(opt.outf, 'net%d.pth' % epoch))
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            # log scalars
            # NOTE(review): archives the last TRAINING batch as images_disp[1],
            # which is later logged/visualized as 'test' — confirm intent.
            images_disp.append(inputs[0:36, :, :, :])
            if opt.global_param == 'PACEMAKER':
                # TODO not needed, remove it
                print("\n[epoch %d] log images for pacemaker" % epoch)
            else:
                for i, data in enumerate(testloader, 0):
                    images_test, labels_test = data
                    images_test, labels_test = images_test.to(
                        device), labels_test.to(device)
                    pred_test, __, __, __ = model(images_test)
                    pred_test = torch.sigmoid(pred_test)
                    #print("Test prediction: %s" % pred_test)
                    #assert not (isinstance(criterion, nn.BCELoss) or isinstance(criterion, nn.CrossEntropyLoss))
                    if isinstance(criterion, nn.BCELoss):
                        predict = pred_test
                        predict[predict > 0.5] = 1
                        predict[predict <= 0.5] = 0
                    elif isinstance(criterion, nn.CrossEntropyLoss):
                        predict = torch.argmax(pred_test, 1)
                    elif isinstance(criterion, xray_loss.Loss):
                        predict = pred_test
                        predict[predict > 0.5] = 1
                        predict[predict <= 0.5] = 0
                    else:
                        raise Exception("not sure how we reached here")
                    total += labels_test.size(0) * labels_test.size(1)
                    correct += torch.eq(predict, labels_test).sum().double().item()
                # kept inside the non-PACEMAKER branch: total stays 0 otherwise
                writer.add_scalar('test/accuracy', correct / total, epoch)
                print("\n[epoch %d] accuracy on test data: %.2f%%\n" % (epoch, 100 * correct / total))
            # log images
            if opt.log_images:
                print('\nlog images ...\n')
                I_train = utils.make_grid(images_disp[0], nrow=6, normalize=True, scale_each=True)
                writer.add_image('train/image', I_train, epoch)
                #if epoch == 0:
                if epoch == first_epoch:
                    I_test = utils.make_grid(images_disp[1], nrow=6, normalize=True, scale_each=True)
                    writer.add_image('test/image', I_test, epoch)
            if opt.log_images and (not opt.no_attention):
                print('\nlog attention maps ...\n')
                # base factor
                if opt.attn_mode == 'before':
                    min_up_factor = 1
                else:
                    min_up_factor = 2
                # sigmoid or softmax
                if opt.normalize_attn:
                    vis_fun = visualize_attn_softmax
                else:
                    vis_fun = visualize_attn_sigmoid
                # training data
                __, c1, c2, c3 = model(images_disp[0])
                if c1 is not None:
                    attn1 = vis_fun(I_train, c1, up_factor=min_up_factor, nrow=6)
                    writer.add_image('train/attention_map_1', attn1, epoch)
                if c2 is not None:
                    attn2 = vis_fun(I_train, c2, up_factor=min_up_factor * 2, nrow=6)
                    writer.add_image('train/attention_map_2', attn2, epoch)
                if c3 is not None:
                    attn3 = vis_fun(I_train, c3, up_factor=min_up_factor * 4, nrow=6)
                    writer.add_image('train/attention_map_3', attn3, epoch)
                # test data (I_test from the first resumed epoch stays in scope)
                __, c1, c2, c3 = model(images_disp[1])
                if c1 is not None:
                    attn1 = vis_fun(I_test, c1, up_factor=min_up_factor, nrow=6)
                    writer.add_image('test/attention_map_1', attn1, epoch)
                if c2 is not None:
                    attn2 = vis_fun(I_test, c2, up_factor=min_up_factor * 2, nrow=6)
                    writer.add_image('test/attention_map_2', attn2, epoch)
                if c3 is not None:
                    attn3 = vis_fun(I_test, c3, up_factor=min_up_factor * 4, nrow=6)
                    writer.add_image('test/attention_map_3', attn3, epoch)
        # restart the per-epoch timer
        start = time.time()
def main():
    """Visualize attention maps for one image (``opt.img``) or a directory
    (``opt.img_dir``) using a trained AttnVGG checkpoint (``opt.model``).

    Optionally saves raw heatmaps as .npy files under ``opt.output_dir`` and,
    when exactly one image is given, shows a 4-panel matplotlib figure
    (original + three attention layers). Relies on module-level globals:
    ``opt`` and ``device``.
    """
    im_size = 32
    # CIFAR-100 channel statistics used at training time
    mean, std = [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]
    transform_test = transforms.Compose([
        # ToTensor intentionally skipped: the image is converted to a float
        # tensor manually below, so only normalization is applied here.
        #transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    print('done')
    ## load network
    print('\nloading the network ...\n')
    # (linear attn) insert attention befroe or after maxpooling?
    # (grid attn only supports "before" mode)
    if opt.attn_mode == 'before':
        print('\npay attention before maxpooling layers...\n')
        net = AttnVGG_before(im_size=im_size, num_classes=100, attention=True, normalize_attn=opt.normalize_attn, init='xavierUniform')
    elif opt.attn_mode == 'after':
        print('\npay attention after maxpooling layers...\n')
        net = AttnVGG_after(im_size=im_size, num_classes=100, attention=True, normalize_attn=opt.normalize_attn, init='xavierUniform')
    else:
        raise NotImplementedError("Invalid attention mode!")
    print('done')
    ## load model
    print('\nloading the model ...\n')
    state_dict = torch.load(opt.model, map_location=str(device))
    # Remove 'module.' prefix (checkpoint was saved from a DataParallel wrapper)
    state_dict = {k[7:]: v for k, v in state_dict.items()}
    net.load_state_dict(state_dict)
    net = net.to(device)
    net.eval()
    print('done')
    model = net
    # base factor: attention maps are upsampled more in 'after' mode
    if opt.attn_mode == 'before':
        min_up_factor = 1
    else:
        min_up_factor = 2
    # sigmoid or softmax visualization, matching how the model was trained
    if opt.normalize_attn:
        vis_fun = visualize_attn_softmax
    else:
        vis_fun = visualize_attn_sigmoid
    if opt.output_dir:
        print("\nwill save heatmaps\n")
    # single explicit image beats directory scan
    if opt.img:
        img_dir = ""
        filenames = [opt.img]
    else:
        img_dir = opt.img_dir
        filenames = os.listdir(img_dir)
    # only pop up a matplotlib window when there is exactly one image
    display_fig = len(filenames) == 1
    with torch.no_grad():
        for filename in filenames:
            ## load image
            path = os.path.join(img_dir, filename)
            img = imread(path)
            if len(img.shape) == 2:
                # grayscale -> 3-channel by replication
                img = img[:, :, np.newaxis]
                img = np.concatenate([img, img, img], axis=2)
            img = np.array(Image.fromarray(img).resize((im_size, im_size)))
            orig_img = img.copy()  # keep un-normalized copy for display
            img = img.transpose(2, 0, 1)  # HWC -> CHW
            img = img / 255.
            img = torch.FloatTensor(img).to(device)
            image = transform_test(img)  # (3, 32, 32)
            # per-image output prefix for saved .npy heatmaps, if requested
            if opt.output_dir:
                file_prefix = os.path.join(
                    opt.output_dir,
                    os.path.splitext(os.path.basename(filename))[0])
            else:
                file_prefix = None
            batch = image[np.newaxis, :, :, :]
            # class scores discarded; only the three attention maps are used
            __, c1, c2, c3 = model(batch)
            if display_fig:
                fig, axs = plt.subplots(1, 4)
                axs[0].imshow(orig_img)
            if c1 is not None:
                attn1 = vis_fun(
                    img, c1, up_factor=min_up_factor, nrow=1,
                    hm_file=None if file_prefix is None else file_prefix + "_c1.npy")
                if display_fig:
                    axs[1].imshow(attn1.numpy().transpose(1, 2, 0))
            if c2 is not None:
                attn2 = vis_fun(
                    img, c2, up_factor=min_up_factor * 2, nrow=1,
                    hm_file=None if file_prefix is None else file_prefix + "_c2.npy")
                if display_fig:
                    axs[2].imshow(attn2.numpy().transpose(1, 2, 0))
            if c3 is not None:
                attn3 = vis_fun(
                    img, c3, up_factor=min_up_factor * 4, nrow=1,
                    hm_file=None if file_prefix is None else file_prefix + "_c3.npy")
                if display_fig:
                    axs[3].imshow(attn3.numpy().transpose(1, 2, 0))
            if display_fig:
                plt.show()
def main():
    """Rank the images in ``opt.image_dir`` by top-1 softmax confidence.

    Loads an AttnVGG checkpoint (stripping the 'module.' prefix left by
    DataParallel), classifies every image in the directory, and prints the
    ``opt.num_images`` most confident predictions, one per line, as
    ``<filename> <probability> <class-index>``.

    Relies on module-level globals: ``opt`` (CLI options) and ``device``.

    Cleanup vs. original: removed unused locals (``min_up_factor``,
    ``vis_fun``, ``orig_img`` — this variant does no attention visualization)
    and corrected the stale ``(3, 256, 256)`` shape comment (im_size is 32).
    """
    im_size = 32
    # CIFAR-100 channel statistics used at training time
    mean, std = [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]
    transform_test = transforms.Compose([
        # ToTensor intentionally skipped: the image is converted to a float
        # tensor manually below, so only normalization is applied here.
        transforms.Normalize(mean, std)
    ])
    print('done')
    ## load network
    print('\nloading the network ...\n')
    # (linear attn) insert attention before or after maxpooling?
    # (grid attn only supports "before" mode)
    if opt.attn_mode == 'before':
        print('\npay attention before maxpooling layers...\n')
        net = AttnVGG_before(im_size=im_size, num_classes=100, attention=True,
                             normalize_attn=opt.normalize_attn, init='xavierUniform')
    elif opt.attn_mode == 'after':
        print('\npay attention after maxpooling layers...\n')
        net = AttnVGG_after(im_size=im_size, num_classes=100, attention=True,
                            normalize_attn=opt.normalize_attn, init='xavierUniform')
    else:
        raise NotImplementedError("Invalid attention mode!")
    print('done')
    ## load model
    print('\nloading the model ...\n')
    state_dict = torch.load(opt.model, map_location=str(device))
    # Remove 'module.' prefix (checkpoint was saved from a DataParallel wrapper)
    state_dict = {k[7:]: v for k, v in state_dict.items()}
    net.load_state_dict(state_dict)
    net = net.to(device)
    net.eval()
    print('done')
    model = net
    results = []
    with torch.no_grad():
        for img_file in os.scandir(opt.image_dir):
            ## load image
            img = imread(img_file.path)
            if len(img.shape) == 2:
                # grayscale -> 3-channel by replication
                img = img[:, :, np.newaxis]
                img = np.concatenate([img, img, img], axis=2)
            img = np.array(Image.fromarray(img).resize((im_size, im_size)))
            img = img.transpose(2, 0, 1)  # HWC -> CHW
            img = img / 255.
            img = torch.FloatTensor(img).to(device)
            image = transform_test(img)  # (3, im_size, im_size)
            batch = image[np.newaxis, :, :, :]
            pred, __, __, __ = model(batch)
            # top-1 probability and class index for this image
            out, cls = torch.max(F.softmax(pred, dim=1), 1)
            results.append((out.item(), cls.item(), img_file.name))
    # most confident first (ties fall back to class index, then filename)
    sorted_results = sorted(results, reverse=True)
    print("\n".join(f"{result[2]} {result[0]} {result[1]}"
                    for result in sorted_results[:opt.num_images]))