Example #1
import torch
from torchvision import transforms

# EmbeddingNet is assumed to be the project-specific embedding model whose weights are loaded below.
class Predictor:

  def __init__(self, checkpoint_path):

    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    self.model = EmbeddingNet()
    self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))

    self.model.to(self.device)
    self.model.eval()
    
    self.transform = transforms.Compose([
        transforms.Lambda(lambda image: image.convert('RGB')),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

  def _preprocess(self, image):
    # Convert the raw image to PIL, then apply the RGB/resize/to-tensor pipeline.
    image = transforms.ToPILImage()(image)
    image = self.transform(image)

    return image

  
  def predict(self, image_list):
    
    image_tensor = torch.cat([self._preprocess(im).unsqueeze(0) for im in image_list], dim=0)

    with torch.no_grad():
      image_tensor = image_tensor.to(self.device)
      embeddings = self.model(image_tensor)

    return embeddings.cpu().numpy()
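A minimal usage sketch, assuming EmbeddingNet is importable and a trained checkpoint exists at the hypothetical path ./saves/embedding_net.pth; the inputs are HxWxC uint8 arrays:

# Hypothetical usage; the checkpoint path and random images are placeholders.
import numpy as np

predictor = Predictor('./saves/embedding_net.pth')
images = [np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(4)]
embeddings = predictor.predict(images)  # numpy array of shape [4, embedding_dim]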
Example #2
import torch
import torch.nn.functional as F

# EmbeddingNet is assumed to be the project-specific embedding model whose weights are loaded below.
class TensorPredictor:

  def __init__(self, checkpoint_path):

    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    self.model = EmbeddingNet()
    self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))

    self.model.to(self.device)
    self.model.eval()    

  def predict(self, image_list):
    
    tensor_list = []
    
    for image_tensor in image_list:
      # image_tensor = torch.index_select(image_tensor, 2, torch.tensor([2, 1, 0], device=self.device))  # BGR -> RGB ?
      # Resize each [C, H, W] tensor to 224 x 224 before batching.
      image_tensor = F.interpolate(torch.unsqueeze(image_tensor, 0), size=(224, 224))[0]
      tensor_list.append(image_tensor)
    
    input_tensor = torch.stack(tensor_list)

    with torch.no_grad():
      input_tensor = input_tensor.to(self.device)      
      embeddings = self.model(input_tensor)
      
    return embeddings
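A minimal usage sketch for the tensor-based variant, assuming float [C, H, W] tensors already scaled to [0, 1]; the checkpoint path is again a placeholder:

# Hypothetical usage; the checkpoint path and random tensors are placeholders.
predictor = TensorPredictor('./saves/embedding_net.pth')
images = [torch.rand(3, 480, 640) for _ in range(4)]  # list of [C, H, W] tensors
embeddings = predictor.predict(images)  # tensor of shape [4, embedding_dim] on predictor.device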
import argparse

import torch
import visdom
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import transforms as tfs
from torchvision.datasets import MNIST

# EmbeddingNet, SEMI_MNIST and MetricCrossEntropy are assumed to be project-specific modules.


def main():
    # 1. argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--resume', type=int, default=0)
    opts = parser.parse_args()
    print(opts)

    # 2. device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 3. visdom
    vis = visdom.Visdom()

    # 4. dataset
    mean, std = 0.1307, 0.3081

    transform = tfs.Compose([tfs.Normalize((mean, ), (std, ))])
    test_transform = tfs.Compose(
        [tfs.ToTensor(), tfs.Normalize((mean, ), (std, ))])

    train_set = MNIST('./data/MNIST',
                      train=True,
                      download=True,
                      transform=None)

    train_set = SEMI_MNIST(train_set, transform=transform, num_samples=100)

    test_set = MNIST('./data/MNIST',
                     train=False,
                     download=True,
                     transform=test_transform)

    # 5. data loader
    train_loader = DataLoader(dataset=train_set,
                              shuffle=True,
                              batch_size=opts.batch_size,
                              num_workers=8,
                              pin_memory=True)

    test_loader = DataLoader(
        dataset=test_set,
        shuffle=False,
        batch_size=opts.batch_size,
    )

    # 6. model
    model = EmbeddingNet().to(device)

    # 7. criterion
    criterion = MetricCrossEntropy().to(device)

    # 8. optimizer
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=opts.lr,
                                momentum=0.9,
                                weight_decay=5e-4)

    # 9. scheduler
    scheduler = StepLR(optimizer=optimizer, step_size=50, gamma=1)  # gamma=1 keeps the learning rate constant
    # 10. resume
    if opts.resume:
        # Resume from the checkpoint written at the end of epoch opts.resume - 1.
        model.load_state_dict(
            torch.load('./saves/state_dict.{}'.format(opts.resume - 1)))
        print("resume from {} epoch..".format(opts.resume - 1))
    else:
        print("no checkpoint to resume.. train from scratch.")

    # --
    for epoch in range(opts.resume, opts.epoch):

        # 11. train
        for idx, (imgs, targets, samples, is_known) in enumerate(train_loader):
            model.train()
            imgs = imgs.to(device)  # [N, 1, 28, 28]
            targets = targets.to(device)  # [N]
            samples = samples.to(device)  # reshaped below to [N * 10, 1, 28, 28]
            is_known = is_known.to(device)

            # Use the actual batch size; the last batch can be smaller than opts.batch_size.
            batch_size = imgs.size(0)
            samples = samples.view(batch_size * 10, 1, 28, 28)
            out_x = model(imgs)  # [N, 10]
            out_z = model(samples).view(batch_size, 10,
                                        out_x.size(-1))  # [N, 10, D]
            loss = criterion(out_x, targets, out_z, is_known, 10, 1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            for param_group in optimizer.param_groups:
                lr = param_group['lr']

            if idx % 100 == 0:
                print('Epoch : {}\t'
                      'step : [{}/{}]\t'
                      'loss : {:.4f}\t'
                      'lr   : {}\t'.format(epoch, idx, len(train_loader),
                                           loss.item(), lr))

                vis.line(X=torch.ones(
                    (1, 1)) * idx + epoch * len(train_loader),
                         Y=torch.Tensor([loss.item()]).unsqueeze(0),
                         update='append',
                         win='loss',
                         opts=dict(x_label='step',
                                   y_label='loss',
                                   title='loss',
                                   legend=['total_loss']))

        torch.save(model.state_dict(), './saves/state_dict.{}'.format(epoch))

        # 12. test
        # Load the checkpoint once per epoch (not once per batch) and disable gradients for evaluation.
        model.load_state_dict(
            torch.load('./saves/state_dict.{}'.format(epoch)))
        model.eval()

        correct = 0
        avg_loss = 0
        with torch.no_grad():
            for idx, (img, target) in enumerate(test_loader):

                img = img.to(device)  # [N, 1, 28, 28]
                target = target.to(device)  # [N]
                output = model(img)  # [N, 10]

                output = torch.softmax(output, -1)
                pred, idx_ = output.max(-1)
                correct += torch.eq(target, idx_).sum()
                #loss = criterion(output, target)
                #avg_loss += loss.item()

        print('Epoch {} test : '.format(epoch))
        accuracy = correct.item() / len(test_set)
        print("accuracy : {:.4f}%".format(accuracy * 100.))
        #avg_loss = avg_loss / len(test_loader)
        #print("avg_loss : {:.4f}".format(avg_loss))

        vis.line(X=torch.ones((1, 1)) * epoch,
                 Y=torch.Tensor([accuracy]).unsqueeze(0),
                 update='append',
                 win='test',
                 opts=dict(x_label='epoch',
                           y_label='accuracy',
                           title='test accuracy',
                           legend=['accuracy']))
        scheduler.step()
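The training loop writes checkpoints to ./saves, so that directory has to exist; a minimal entry point (a sketch, not part of the original script, which also assumes a visdom server is already running) could be:

import os

if __name__ == '__main__':
    os.makedirs('./saves', exist_ok=True)  # torch.save above writes ./saves/state_dict.<epoch>
    main()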