def process_state(self, state):
    """Turn a raw RGB frame into a resized, normalized grayscale array.

    Args:
        state: Image frame laid out channels-last (height, width, 3), e.g.
            (210, 160, 3). Presumably holds values in 0-255 given the final
            divide by 255 -- confirm with the caller.

    Returns:
        A float32 NumPy array of shape (RESIZE, RESIZE), copied back to CPU.
    """
    frame = torch.tensor(data=state, dtype=torch.float, device=self.device)

    # Channels-last (H, W, 3) -> channels-first (3, H, W), the layout the
    # torchvision transforms below operate on.
    frame = frame.permute(2, 0, 1)

    # Collapse RGB into one luminance channel: (3, H, W) -> (1, H, W).
    gray = rgb_to_grayscale(frame)

    # Shrink to (1, RESIZE, RESIZE), then drop the singleton channel axis.
    small = Resize([RESIZE, RESIZE])(gray).squeeze(0)

    # Scale into 0-1 (assuming 0-255 input values).
    normalized = small.div(255)

    return normalized.detach().cpu().numpy()
Esempio n. 2
0
def predict(image_path, checkpoint_path, save_path):
    """Segment a grayscale image with a trained U-Net and save the binary mask.

    Args:
        image_path: Path of the input image; it is read as single-channel
            grayscale via OpenCV.
        checkpoint_path: Path of a saved ``state_dict`` for ``Unet``.
        save_path: Destination file for the mask image. (Previously this
            argument was ignored and ``result.png`` was always written; that
            name is still used as a fallback when ``save_path`` is falsy.)

    Returns:
        The saved mask as a ``uint8`` NumPy array of shape (H, W) matching the
        input image, with values 0 or 255.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
        OSError: If OpenCV fails to write the mask file.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = Unet()
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (or pass weights_only=True where the torch version allows).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # load_state_dict copies values into the model's existing parameters, so
    # the model itself must be moved to the target device explicitly.
    model.to(device)
    # Inference mode: disables dropout and uses running batch-norm statistics.
    model.eval()

    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    size = image.shape[:2]  # (height, width) of the original image

    image = Compose([
        ToTensor(),
        Resize((512, 512)),
    ])(image)
    batch = image.unsqueeze(0).to(device)  # add the batch dimension

    with torch.no_grad():
        prob = model(batch)[0]  # first (only) item of the batch

    # Resize the model output back to the input size *before* thresholding,
    # so interpolation cannot reintroduce non-binary edge values.
    # Assumes the output is a probability map in [0, 1] -- confirm vs Unet.
    prob = Resize(size)(prob)
    mask = (prob > 0.5).to(torch.uint8) * 255
    # First output channel; assumes a single-channel mask -- confirm vs Unet.
    mask = mask[0].cpu().numpy()

    out_path = save_path or 'result.png'
    if not cv2.imwrite(out_path, mask):
        raise OSError(f"Could not write mask to: {out_path}")
    return mask