class Predictor:
    """Compute embeddings for a batch of images with a trained EmbeddingNet.

    Loads weights from ``checkpoint_path``, moves the model to GPU when one
    is available (CPU otherwise), and exposes :meth:`predict`, which maps a
    list of images to a numpy array of embeddings.
    """

    def __init__(self, checkpoint_path):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = EmbeddingNet()
        self.model.load_state_dict(torch.load(checkpoint_path))
        self.model.to(self.device)
        self.model.eval()
        # Preprocessing: force 3-channel RGB, resize to the model's expected
        # 224x224 input, convert to a float tensor in [0, 1].
        self.transform = transforms.Compose([
            transforms.Lambda(lambda image: image.convert('RGB')),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def _preprocess(self, image):
        """Convert one raw image (tensor/ndarray) to a normalized CHW tensor."""
        image = transforms.ToPILImage()(image)
        return self.transform(image)

    def predict(self, image_list):
        """Return embeddings for ``image_list`` as a numpy array of shape [N, D]."""
        # stack() adds the batch dim; equivalent to cat() of unsqueezed tensors.
        image_tensor = torch.stack([self._preprocess(im) for im in image_list], dim=0)
        with torch.no_grad():
            # FIX: was image_tensor.cuda(), which crashed on CPU-only hosts
            # despite self.device being selected for exactly that case.
            image_tensor = image_tensor.to(self.device)
            embeddings = self.model(image_tensor)
        return embeddings.cpu().numpy()
class TensorPredictor:
    """Embedding predictor that consumes image tensors directly.

    Unlike ``Predictor`` there is no PIL round-trip: inputs are CHW tensors,
    each resized to 224x224 via interpolation before batching.
    """

    def __init__(self, checkpoint_path):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = EmbeddingNet()
        self.model.load_state_dict(torch.load(checkpoint_path))
        self.model.to(self.device)
        self.model.eval()

    def predict(self, image_list):
        """Return embedding tensor for a list of CHW image tensors."""
        # NOTE(review): a channel reorder (BGR -> RGB?) was considered here
        # in an earlier revision but is intentionally not applied.
        resized = [
            F.interpolate(img.unsqueeze(0), size=(224, 224)).squeeze(0)
            for img in image_list
        ]
        batch = torch.stack(resized)
        with torch.no_grad():
            batch = batch.to(self.device)
            embeddings = self.model(batch)
        return embeddings
def main():
    """Train EmbeddingNet on semi-supervised MNIST and evaluate every epoch.

    Side effects: parses CLI arguments, streams metrics to a running visdom
    server, downloads MNIST to ./data/MNIST, and writes a checkpoint to
    ./saves/state_dict.<epoch> after each epoch.
    """
    # 1. argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--resume', type=int, default=0)
    opts = parser.parse_args()
    print(opts)

    # 2. device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 3. visdom (assumes a visdom server is already running -- TODO confirm)
    vis = visdom.Visdom()

    # 4. dataset
    mean, std = 0.1307, 0.3081
    transform = tfs.Compose([tfs.Normalize((mean, ), (std, ))])
    test_transform = tfs.Compose(
        [tfs.ToTensor(), tfs.Normalize((mean, ), (std, ))])
    train_set = MNIST('./data/MNIST', train=True, download=True, transform=None)
    train_set = SEMI_MNIST(train_set, transform=transform, num_samples=100)
    test_set = MNIST('./data/MNIST', train=False, download=True,
                     transform=test_transform)

    # 5. data loader
    train_loader = DataLoader(dataset=train_set,
                              shuffle=True,
                              batch_size=opts.batch_size,
                              num_workers=8,
                              pin_memory=True)
    test_loader = DataLoader(dataset=test_set,
                             shuffle=False,
                             batch_size=opts.batch_size)

    # 6. model
    model = EmbeddingNet().to(device)

    # 7. criterion
    criterion = MetricCrossEntropy().to(device)

    # 8. optimizer
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=opts.lr,
                                momentum=0.9,
                                weight_decay=5e-4)

    # 9. scheduler
    # NOTE(review): gamma=1 makes StepLR a no-op (lr never decays).
    # Kept as-is to preserve training behavior -- confirm it is intentional.
    scheduler = StepLR(optimizer=optimizer, step_size=50, gamma=1)

    # 10. resume
    if opts.resume:
        model.load_state_dict(
            torch.load('./saves/state_dict.{}'.format(opts.resume)))
        print("resume from {} epoch..".format(opts.resume - 1))
    else:
        print("no checkpoint to resume.. train from scratch.")

    for epoch in range(opts.resume, opts.epoch):
        # 11. train
        model.train()  # hoisted: was re-set on every batch
        for idx, (imgs, targets, samples, is_known) in enumerate(train_loader):
            # FIX: use the actual batch size. The last batch of an epoch can
            # be smaller than opts.batch_size (drop_last is not set), which
            # made the .view() below raise.
            batch_size = imgs.size(0)
            imgs = imgs.to(device)        # [N, 1, 28, 28]
            targets = targets.to(device)  # [N]
            samples = samples.to(device)  # presumably [N, 10, 1, 28, 28] -- verify
            is_known = is_known.to(device)
            samples = samples.view(batch_size * 10, 1, 28, 28)

            out_x = model(imgs)  # [N, D]
            out_z = model(samples).view(batch_size, 10, out_x.size(-1))
            loss = criterion(out_x, targets, out_z, is_known, 10, 1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            lr = optimizer.param_groups[0]['lr']
            if idx % 100 == 0:
                print('Epoch : {}\t'
                      'step : [{}/{}]\t'
                      'loss : {}\t'
                      'lr : {}\t'.format(epoch, idx, len(train_loader),
                                         loss, lr))
                vis.line(X=torch.ones((1, 1)) * idx + epoch * len(train_loader),
                         Y=torch.Tensor([loss.item()]).unsqueeze(0),
                         update='append',
                         win='loss',
                         opts=dict(x_label='step',
                                   y_label='loss',
                                   title='loss',
                                   legend=['total_loss']))
        torch.save(model.state_dict(), './saves/state_dict.{}'.format(epoch))

        # 12. test
        # FIX: the checkpoint was reloaded from disk on *every* test batch;
        # it is the state just saved above, so the weights in memory are
        # already identical -- the reload is dropped.
        model.eval()
        correct = 0
        with torch.no_grad():  # no gradients needed during evaluation
            for idx, (img, target) in enumerate(test_loader):
                img = img.to(device)       # [N, 1, 28, 28]
                target = target.to(device) # [N]
                output = model(img)        # [N, 10]
                output = torch.softmax(output, -1)
                pred, idx_ = output.max(-1)
                correct += torch.eq(target, idx_).sum()

        print('Epoch {} test : '.format(epoch))
        accuracy = correct.item() / len(test_set)
        print("accuracy : {:.4f}%".format(accuracy * 100.))
        vis.line(X=torch.ones((1, 1)) * epoch,
                 Y=torch.Tensor([accuracy]).unsqueeze(0),
                 update='append',
                 win='test',
                 opts=dict(x_label='epoch',
                           y_label='test_',
                           title='test_loss',
                           legend=['accuracy']))
        scheduler.step()