def main():
    """Train the attention VGG on ``DrawDataset`` and log progress to TensorBoard.

    All settings are read from the module-level ``opt`` namespace
    (``batch_size``, ``lr``, ``epochs``, ``outf``, ``attn_mode``,
    ``no_attention``, ``normalize_attn``, ``log_images``).

    Per epoch the function:
      1. runs ``num_aug`` passes over the training loader,
      2. saves the weights to ``<opt.outf>/net.pth`` (plus a snapshot at the
         halfway epoch),
      3. evaluates on the test loader and logs the accuracy,
      4. optionally logs image grids and attention maps.

    Raises:
        NotImplementedError: if ``opt.attn_mode`` is neither ``'before'`` nor
            ``'after'``.
    """
    ## load data
    print('\nloading the dataset ...\n')
    num_aug = 3
    im_size = 32
    # Normalization constants are the CIFAR-100 statistics the script was
    # originally written for.
    transform_train = transforms.Compose([
        transforms.RandomCrop(im_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ])

    def _init_fn(worker_id):
        # Give every loader worker its own seed; `base_seed` is expected to be
        # defined at module level.
        random.seed(base_seed + worker_id)

    trainset = DrawDataset(transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True,
                                              num_workers=8, worker_init_fn=_init_fn)
    testset = DrawDataset(transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=5)
    print('done')

    ## load network
    print('\nloading the network ...\n')
    # use attention module?
    if not opt.no_attention:
        print('\nturn on attention ...\n')
    else:
        print('\nturn off attention ...\n')
    # (linear attn) insert attention before or after maxpooling?
    # (grid attn only supports "before" mode)
    if opt.attn_mode == 'before':
        print('\npay attention before maxpooling layers...\n')
        net = AttnVGG_before(im_size=im_size, num_classes=100, attention=not opt.no_attention,
                             normalize_attn=opt.normalize_attn, init='xavierUniform')
    elif opt.attn_mode == 'after':
        print('\npay attention after maxpooling layers...\n')
        net = AttnVGG_after(im_size=im_size, num_classes=100, attention=not opt.no_attention,
                            normalize_attn=opt.normalize_attn, init='xavierUniform')
    else:
        raise NotImplementedError("Invalid attention mode!")
    criterion = nn.CrossEntropyLoss()
    print('done')

    ## move to GPU
    print('\nmoving to GPU ...\n')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_ids = [0, 1]  # NOTE(review): hard-coded for a two-GPU machine
    model = nn.DataParallel(net, device_ids=device_ids).to(device)
    criterion.to(device)
    print('done')

    ### optimizer
    optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9, weight_decay=5e-4)
    # Halve the learning rate every 25 epochs.
    lr_lambda = lambda epoch: np.power(0.5, int(epoch / 25))
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    # training
    print('\nstart training ...\n')
    step = 0
    running_avg_accuracy = 0
    writer = SummaryWriter(opt.outf)
    for epoch in range(opt.epochs):
        images_disp = []
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('train/learning_rate', current_lr, epoch)
        print("\nepoch %d learning rate %f\n" % (epoch, current_lr))

        # run for one epoch
        for aug in range(num_aug):
            for i, data in enumerate(trainloader, 0):
                # The periodic accuracy check below switches to eval mode, so
                # training mode is restored on every iteration.
                model.train()
                optimizer.zero_grad()
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                if (aug == 0) and (i == 0):
                    # archive images in order to save to logs
                    images_disp.append(inputs[0:36, :, :, :])
                # forward
                pred, __, __, __ = model(inputs)
                # backward
                loss = criterion(pred, labels)
                loss.backward()
                optimizer.step()
                # display results
                if i % 10 == 0:
                    model.eval()
                    # Metrics only: no autograd graph is needed here.
                    with torch.no_grad():
                        pred, __, __, __ = model(inputs)
                    predict = torch.argmax(pred, 1)
                    total = labels.size(0)
                    correct = torch.eq(predict, labels).sum().double().item()
                    accuracy = correct / total
                    running_avg_accuracy = 0.9 * running_avg_accuracy + 0.1 * accuracy
                    writer.add_scalar('train/loss', loss.item(), step)
                    writer.add_scalar('train/accuracy', accuracy, step)
                    writer.add_scalar('train/running_avg_accuracy', running_avg_accuracy, step)
                    print(
                        "[epoch %d][aug %d/%d][%d/%d] loss %.4f accuracy %.2f%% running avg accuracy %.2f%%"
                        % (epoch, aug, num_aug - 1, i, len(trainloader) - 1, loss.item(),
                           (100 * accuracy), (100 * running_avg_accuracy)))
                step += 1

        # the end of each epoch: test & log
        print('\none epoch done, saving records ...\n')
        torch.save(model.state_dict(), os.path.join(opt.outf, 'net.pth'))
        # Integer division so the halfway snapshot is also taken when
        # opt.epochs is odd (the float comparison never matched then).
        if epoch == opt.epochs // 2:
            torch.save(model.state_dict(), os.path.join(opt.outf, 'net%d.pth' % epoch))
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            # log scalars
            for i, data in enumerate(testloader, 0):
                images_test, labels_test = data
                images_test, labels_test = images_test.to(device), labels_test.to(device)
                if i == 0:
                    # Archive *test* images for the logs. The previous code
                    # stored the last training batch here by mistake.
                    images_disp.append(images_test[0:36, :, :, :])
                pred_test, __, __, __ = model(images_test)
                predict = torch.argmax(pred_test, 1)
                total += labels_test.size(0)
                correct += torch.eq(predict, labels_test).sum().double().item()
            writer.add_scalar('test/accuracy', correct / total, epoch)
            print("\n[epoch %d] accuracy on test data: %.2f%%\n" % (epoch, 100 * correct / total))

            # log images
            if opt.log_images:
                print('\nlog images ...\n')
                I_train = utils.make_grid(images_disp[0], nrow=6, normalize=True, scale_each=True)
                writer.add_image('train/image', I_train, epoch)
                if epoch == 0:
                    # The test loader is not shuffled, so the grid built in the
                    # first epoch is reused afterwards.
                    I_test = utils.make_grid(images_disp[1], nrow=6, normalize=True, scale_each=True)
                    writer.add_image('test/image', I_test, epoch)

            if opt.log_images and (not opt.no_attention):
                print('\nlog attention maps ...\n')
                # base factor
                min_up_factor = 1 if opt.attn_mode == 'before' else 2
                # sigmoid or softmax
                vis_fun = visualize_attn_softmax if opt.normalize_attn else visualize_attn_sigmoid
                splits = (('train', I_train, images_disp[0]),
                          ('test', I_test, images_disp[1]))
                for split, grid, batch in splits:
                    __, c1, c2, c3 = model(batch)
                    # The up-sampling factor doubles with each attention level.
                    for level, attn in enumerate((c1, c2, c3)):
                        if attn is None:
                            continue
                        attn_img = vis_fun(grid, attn, up_factor=min_up_factor * 2 ** level, nrow=6)
                        writer.add_image('%s/attention_map_%d' % (split, level + 1), attn_img, epoch)

        # Step the schedule only after the epoch's optimizer updates; calling
        # it first skips the initial value of the schedule on PyTorch >= 1.1.
        scheduler.step()

    writer.close()
def main():
    """Train (or test) the attention VGG on the X-ray / pacemaker data.

    Options come from ``argparser()``. The function builds the data loaders,
    the network, the loss and the optimizer, optionally restores weights
    (from ``opt.chest_xray_pretrain_path`` or from the ``record`` file written
    into ``opt.outf`` by a previous run), and then either runs a single
    visual test (``opt.test_only``) or trains for ``opt.epochs`` epochs.

    After every epoch the weights and a JSON ``record`` file (learning rate,
    epoch, step, model path, options) are written to ``opt.outf``; test
    accuracy, image grids and attention maps are logged to TensorBoard.

    Raises:
        NotImplementedError: if ``opt.attn_mode`` is neither ``'before'`` nor
            ``'after'``.
        ValueError: if test-only mode is requested but the pre-trained weights
            file does not exist.
    """
    ## load data
    print('\nloading the dataset ...\n')
    opt = argparser()
    print(opt)
    num_aug = 1
    im_size = opt.image_size
    # Original author's note on the dataset statistics: "0.498, std: 0.185".
    transform_train = transforms.Compose([
        transforms.Resize(im_size),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.185, ))
    ])
    transform_test = transforms.Compose([
        transforms.Resize(im_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.185, ))
    ])
    if opt.global_param != 'CIFAR':
        xray = XRAY(transform_train, transform_test, force_pre_process=False,
                    csv_file=opt.csv_path)
        trainset = xray.train_set
        testset = xray.test_set
        class_to_index = xray.class_to_index()
    else:
        trainset = PacemakerDataset(transform=transform_train, is_train=True)
        testset = PacemakerDataset(transform=transform_test, is_train=False)
        class_to_index = trainset.class_to_index()
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=opt.batch_size, shuffle=True, num_workers=8,
        worker_init_fn=_worker_init_fn_)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=opt.batch_size, shuffle=False, num_workers=5)
    # NOTE(review): one mapping entry is excluded from the class count;
    # presumably a non-class key - confirm against the dataset classes.
    num_of_class = len(class_to_index) - 1
    device_ids = [0, 1]  # NOTE(review): hard-coded for a two-GPU machine

    if isinstance(opt.loss, str):
        criterion = xray_loss.Loss(opt.loss)
    else:
        criterion = nn.BCELoss()
    print("criterion = %s" % type(criterion))
    print('done num_of_classes: %s [%s] , post crop size: %s' %
          (num_of_class, class_to_index, im_size))

    ## load network
    print('\nloading the network ...\n')
    # use attention module?
    if not opt.no_attention:
        print('\nturn on attention ...\n')
    else:
        print('\nturn off attention ...\n')
    # (linear attn) insert attention before or after maxpooling?
    # (grid attn only supports "before" mode)
    if opt.attn_mode == 'before':
        print('\npay attention before maxpooling layers...\n')
        net = AttnVGG_before(im_size=im_size, num_classes=num_of_class,
                             attention=not opt.no_attention,
                             normalize_attn=opt.normalize_attn, init='xavierUniform',
                             _base_features=opt.base_feature_size, dropout=opt.dropout)
    elif opt.attn_mode == 'after':
        print('\npay attention after maxpooling layers...\n')
        net = AttnVGG_after(im_size=im_size, num_classes=num_of_class,
                            attention=not opt.no_attention,
                            normalize_attn=opt.normalize_attn, init='xavierUniform')
    else:
        raise NotImplementedError("Invalid attention mode!")
    print('done')

    ## move to GPU
    print('\nmoving to GPU ...\n')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_pre_trained = None
    record = None
    if opt.pre_train:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        try:
            pretrain_path = opt.chest_xray_pretrain_path
        except AttributeError:
            # No explicit checkpoint option: resume from the record file of a
            # previous run, if there is one.
            record_path = os.path.join(opt.outf, 'record')
            if os.path.exists(record_path):
                with open(record_path, 'r') as frecord:
                    record = json.load(frecord)
                if os.path.exists(record['model']):
                    # map_location lets a GPU checkpoint load on a CPU-only host.
                    model_pre_trained = torch.load(record['model'], map_location=device)
                    print("found pre trained data: %s" % (record,))
        else:
            if os.path.exists(pretrain_path):
                model_pre_trained = torch.load(pretrain_path, map_location=device)
            elif opt.test_only:
                raise ValueError("cannot run test only mode without pre train data")

    model = nn.DataParallel(net, device_ids=device_ids)
    if model_pre_trained is not None:
        model.load_state_dict(model_pre_trained)
    model = model.to(device)
    criterion.to(device)
    print('done')

    if opt.test_only:
        visual_test_image_softmax(model, opt.test_image, transform_test)
        return

    if record is None:
        lr, first_epoch, step = opt.lr, 0, 0
    else:
        # NOTE(review): the record stores the index of the last *completed*
        # epoch, so a resumed run repeats that epoch index - confirm intent.
        lr, first_epoch, step = record['lr'], record['epoch'], record['step']
    slow_lr = getattr(opt, 'slow_lr', False)

    ### optimizer
    if opt.global_param == 'CIFAR':
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        lr_lambda = lambda epoch: np.power(0.5, int(epoch / 25))
    else:
        if opt.global_param == 'BCE':
            optimizer = optim.SGD(model.parameters(), lr=lr, momentum=opt.momentum,
                                  weight_decay=opt.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
        rate = 25 if not slow_lr else 50
        # Halve every `rate` epochs, but never drop below 1e-4 of the base rate.
        lr_lambda = lambda epoch: max(np.power(0.5, int(epoch / rate)), 1e-4)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def _to_predictions(logits):
        """Turn raw network outputs into label predictions for the active loss."""
        if isinstance(criterion, nn.BCELoss):
            # The loss is fed sigmoid(logits), so the 0.5 threshold must be
            # applied to probabilities as well, not to raw logits.
            return (torch.sigmoid(logits) > 0.5).to(logits.dtype)
        if isinstance(criterion, nn.CrossEntropyLoss):
            return torch.argmax(logits, 1)
        if isinstance(criterion, xray_loss.Loss):
            return (torch.sigmoid(logits) > 0.5).to(logits.dtype)
        raise Exception("{} what is this?".format(criterion))

    # training
    start = time.time()
    print('\nstart training [%s]...\n' % start)
    running_avg_accuracy = 0
    writer = SummaryWriter(opt.outf)
    for epoch in range(first_epoch, first_epoch + opt.epochs):
        images_disp = []
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('train/learning_rate', current_lr, epoch)
        print("\nepoch %d learning rate %f\n" % (epoch, current_lr))

        # run for one epoch
        for aug in range(num_aug):
            for i, data in enumerate(trainloader, 0):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                # The periodic accuracy check below switches to eval mode, so
                # training mode is restored on every iteration.
                model.train()
                optimizer.zero_grad()
                if (aug == 0) and (i == 0):
                    # archive images in order to save to logs
                    print("input:", inputs.shape, "inputs[0:36, :, :, :] ->",
                          inputs[0:36, :, :, :].shape)
                    images_disp.append(inputs[0:36, :, :, :])
                # forward
                pred, __, __, __ = model(inputs)
                # backward
                if isinstance(criterion, nn.BCELoss):
                    # BCELoss expects probabilities, not logits.
                    pred = torch.sigmoid(pred)
                loss = criterion(pred, labels)
                loss.backward()
                optimizer.step()
                # display results
                if i % 10 == 0:
                    model.eval()
                    # Metrics only: no autograd graph is needed here.
                    with torch.no_grad():
                        pred, __, __, __ = model(inputs)
                        predict = _to_predictions(pred)
                    # numel() also covers 1-D (class index) labels.
                    total = labels.numel()
                    correct = torch.eq(predict, labels).sum().double().item()
                    accuracy = correct / total
                    running_avg_accuracy = 0.9 * running_avg_accuracy + 0.1 * accuracy
                    writer.add_scalar('train/loss', loss.item(), step)
                    writer.add_scalar('train/accuracy', accuracy, step)
                    writer.add_scalar('train/running_avg_accuracy', running_avg_accuracy, step)
                    print(
                        "[epoch %d][aug %d/%d][%d/%d] loss %.4f accuracy %.2f%% running avg accuracy %.2f%%"
                        % (epoch, aug, num_aug - 1, i, len(trainloader) - 1, loss.item(),
                           (100 * accuracy), (100 * running_avg_accuracy)))
                step += 1

        # the end of each epoch: test & log
        print('\none epoch done [took: %s], saving records ...\n' % (time.time() - start))
        state = os.path.join(opt.outf, 'net.pth')
        torch.save(model.state_dict(), state)
        # Separate name for the file handle so the resume `record` dict is
        # not shadowed.
        with open(os.path.join(opt.outf, 'record'), 'w') as frecord:
            srecord = {
                "lr": optimizer.param_groups[0]['lr'],
                "epoch": epoch,
                "model": state,
                "step": step,
                'global_arg': str(opt)
            }
            json.dump(srecord, frecord)
        # Integer division so the snapshot is also taken for odd opt.epochs.
        if epoch == opt.epochs // 2:
            torch.save(model.state_dict(), os.path.join(opt.outf, 'net%d.pth' % epoch))

        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            # log scalars
            if opt.global_param == 'PACEMAKER':
                # TODO not needed, remove it
                print("\n[epoch %d] log images for pacemaker" % epoch)
            else:
                for i, data in enumerate(testloader, 0):
                    images_test, labels_test = data
                    images_test, labels_test = images_test.to(device), labels_test.to(device)
                    if i == 0:
                        # Archive real test images for the logs.
                        images_disp.append(images_test[0:36, :, :, :])
                    pred_test, __, __, __ = model(images_test)
                    predict = _to_predictions(pred_test)
                    total += labels_test.numel()
                    correct += torch.eq(predict, labels_test).sum().double().item()
            if len(images_disp) < 2:
                # Evaluation was skipped: fall back to the last training batch
                # so the image and attention logging below still has a second
                # entry.
                images_disp.append(inputs[0:36, :, :, :])
            if total > 0:
                # Guarded so a skipped evaluation cannot divide by zero.
                writer.add_scalar('test/accuracy', correct / total, epoch)
                print("\n[epoch %d] accuracy on test data: %.2f%%\n" %
                      (epoch, 100 * correct / total))

            # log images
            if opt.log_images:
                print('\nlog images ...\n')
                I_train = utils.make_grid(images_disp[0], nrow=6, normalize=True, scale_each=True)
                writer.add_image('train/image', I_train, epoch)
                if epoch == first_epoch:
                    I_test = utils.make_grid(images_disp[1], nrow=6, normalize=True,
                                             scale_each=True)
                    writer.add_image('test/image', I_test, epoch)

            if opt.log_images and (not opt.no_attention):
                print('\nlog attention maps ...\n')
                # base factor
                min_up_factor = 1 if opt.attn_mode == 'before' else 2
                # sigmoid or softmax
                vis_fun = visualize_attn_softmax if opt.normalize_attn else visualize_attn_sigmoid
                splits = (('train', I_train, images_disp[0]),
                          ('test', I_test, images_disp[1]))
                for split, grid, batch in splits:
                    __, c1, c2, c3 = model(batch)
                    # The up-sampling factor doubles with each attention level.
                    for level, attn in enumerate((c1, c2, c3)):
                        if attn is None:
                            continue
                        attn_img = vis_fun(grid, attn, up_factor=min_up_factor * 2 ** level, nrow=6)
                        writer.add_image('%s/attention_map_%d' % (split, level + 1), attn_img, epoch)

        # Step the schedule only after the epoch's optimizer updates; calling
        # it first skips the initial value of the schedule on PyTorch >= 1.1.
        scheduler.step()
        # Restart the timer so the next "took" message covers a single epoch.
        start = time.time()

    writer.close()
def main():
    """Compute attention heat maps for one image or for a directory of images.

    Settings come from the module-level ``opt`` namespace (``attn_mode``,
    ``normalize_attn``, ``model``, ``img``, ``img_dir``, ``output_dir``) and
    the module-level ``device``. Each image is resized to 32x32, normalized
    and passed through the network. Every attention map the network returns
    is rendered with the selected visualization function. When
    ``opt.output_dir`` is set, that function is given a ``.npy`` path per
    attention level. When exactly one image is processed, the image and its
    maps are also shown with matplotlib.

    Raises:
        NotImplementedError: if ``opt.attn_mode`` is neither ``'before'`` nor
            ``'after'``.
    """
    im_size = 32
    mean, std = [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]
    # Images are turned into float tensors by hand below, so only the
    # normalization step is composed here.
    transform_test = transforms.Compose([
        transforms.Normalize(mean, std)
    ])
    print('done')

    ## load network
    print('\nloading the network ...\n')
    # (linear attn) insert attention before or after maxpooling?
    # (grid attn only supports "before" mode)
    if opt.attn_mode == 'before':
        print('\npay attention before maxpooling layers...\n')
        net = AttnVGG_before(im_size=im_size, num_classes=100, attention=True,
                             normalize_attn=opt.normalize_attn, init='xavierUniform')
    elif opt.attn_mode == 'after':
        print('\npay attention after maxpooling layers...\n')
        net = AttnVGG_after(im_size=im_size, num_classes=100, attention=True,
                            normalize_attn=opt.normalize_attn, init='xavierUniform')
    else:
        raise NotImplementedError("Invalid attention mode!")
    print('done')

    ## load model
    print('\nloading the model ...\n')
    state_dict = torch.load(opt.model, map_location=str(device))
    # Remove the 'module.' prefix added by nn.DataParallel, but only where it
    # is present, so checkpoints saved without the wrapper still load.
    prefix = 'module.'
    state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                  for k, v in state_dict.items()}
    net.load_state_dict(state_dict)
    net = net.to(device)
    net.eval()
    print('done')
    model = net

    # base factor
    min_up_factor = 1 if opt.attn_mode == 'before' else 2
    # sigmoid or softmax
    vis_fun = visualize_attn_softmax if opt.normalize_attn else visualize_attn_sigmoid

    if opt.output_dir:
        print("\nwill save heatmaps\n")
        # Make sure the target directory exists before anything is written.
        os.makedirs(opt.output_dir, exist_ok=True)

    if opt.img:
        img_dir = ""
        filenames = [opt.img]
    else:
        img_dir = opt.img_dir
        # Sorted for a reproducible processing order.
        filenames = sorted(os.listdir(img_dir))
    display_fig = len(filenames) == 1

    with torch.no_grad():
        for filename in filenames:
            ## load image
            path = os.path.join(img_dir, filename)
            img = imread(path)
            if len(img.shape) == 2:
                # Grayscale: replicate the single channel three times.
                img = img[:, :, np.newaxis]
                img = np.concatenate([img, img, img], axis=2)
            elif img.shape[2] > 3:
                # Drop the alpha channel; normalization expects 3 channels.
                img = img[:, :, :3]
            img = np.array(Image.fromarray(img).resize((im_size, im_size)))
            orig_img = img.copy()
            img = img.transpose(2, 0, 1)
            img = img / 255.
            img = torch.FloatTensor(img).to(device)
            image = transform_test(img)  # (3, im_size, im_size)

            if opt.output_dir:
                file_prefix = os.path.join(
                    opt.output_dir, os.path.splitext(os.path.basename(filename))[0])
            else:
                file_prefix = None

            batch = image.unsqueeze(0)
            __, c1, c2, c3 = model(batch)

            if display_fig:
                fig, axs = plt.subplots(1, 4)
                axs[0].imshow(orig_img)

            # The up-sampling factor doubles with each attention level.
            for level, attn in enumerate((c1, c2, c3)):
                if attn is None:
                    continue
                hm_file = None if file_prefix is None else "%s_c%d.npy" % (file_prefix, level + 1)
                attn_img = vis_fun(img, attn, up_factor=min_up_factor * 2 ** level, nrow=1,
                                   hm_file=hm_file)
                if display_fig:
                    # .cpu() keeps this working if the map is on the GPU.
                    axs[level + 1].imshow(attn_img.cpu().numpy().transpose(1, 2, 0))

            if display_fig:
                plt.show()
def main():
    """Rank the images in ``opt.image_dir`` by the network's top-class confidence.

    Settings come from the module-level ``opt`` namespace (``attn_mode``,
    ``normalize_attn``, ``model``, ``image_dir``, ``num_images``) and the
    module-level ``device``. Each regular file in the directory is resized to
    32x32, normalized and classified. The ``opt.num_images`` most confident
    predictions are printed as ``<file name> <probability> <class index>``.

    Raises:
        NotImplementedError: if ``opt.attn_mode`` is neither ``'before'`` nor
            ``'after'``.
    """
    im_size = 32
    mean, std = [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]
    # Images are turned into float tensors by hand below, so only the
    # normalization step is composed here.
    transform_test = transforms.Compose([
        transforms.Normalize(mean, std)
    ])
    print('done')

    ## load network
    print('\nloading the network ...\n')
    # (linear attn) insert attention before or after maxpooling?
    # (grid attn only supports "before" mode)
    if opt.attn_mode == 'before':
        print('\npay attention before maxpooling layers...\n')
        net = AttnVGG_before(im_size=im_size, num_classes=100, attention=True,
                             normalize_attn=opt.normalize_attn, init='xavierUniform')
    elif opt.attn_mode == 'after':
        print('\npay attention after maxpooling layers...\n')
        net = AttnVGG_after(im_size=im_size, num_classes=100, attention=True,
                            normalize_attn=opt.normalize_attn, init='xavierUniform')
    else:
        raise NotImplementedError("Invalid attention mode!")
    print('done')

    ## load model
    print('\nloading the model ...\n')
    state_dict = torch.load(opt.model, map_location=str(device))
    # Remove the 'module.' prefix added by nn.DataParallel, but only where it
    # is present, so checkpoints saved without the wrapper still load.
    prefix = 'module.'
    state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                  for k, v in state_dict.items()}
    net.load_state_dict(state_dict)
    net = net.to(device)
    net.eval()
    print('done')
    model = net

    results = []
    # The context manager closes the scandir iterator deterministically.
    with torch.no_grad(), os.scandir(opt.image_dir) as entries:
        for img_file in entries:
            if not img_file.is_file():
                # Sub-directories cannot be read as images.
                continue
            ## load image
            img = imread(img_file.path)
            if len(img.shape) == 2:
                # Grayscale: replicate the single channel three times.
                img = img[:, :, np.newaxis]
                img = np.concatenate([img, img, img], axis=2)
            img = np.array(Image.fromarray(img).resize((im_size, im_size)))
            img = img.transpose(2, 0, 1)
            img = img / 255.
            img = torch.FloatTensor(img).to(device)
            image = transform_test(img)  # (3, im_size, im_size)

            batch = image.unsqueeze(0)
            pred, __, __, __ = model(batch)
            # Highest softmax probability and the class it belongs to.
            out, cls = torch.max(F.softmax(pred, dim=1), 1)
            results.append((out.item(), cls.item(), img_file.name))

    # Tuples sort by probability first, so the most confident come first.
    sorted_results = sorted(results, reverse=True)
    print("\n".join(f"{result[2]} {result[0]} {result[1]}"
                    for result in sorted_results[:opt.num_images]))