# ---- Example 1 ----
def vis(model, data_dir, img_size):
    """Plot one shuffled batch of validation pairs with their labels.

    The top row and the bottom row of any column form one pair; the
    printed 0/1 per column is the pair label (0 = dissimilar,
    1 = similar).
    """
    pair_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ])
    loader = DataLoader(
        PairDataset(
            data_path=data_dir,
            transform=pair_transform,
            split=SplitEnum.validation,
        ),
        shuffle=True,
        num_workers=0,
        batch_size=8,
    )
    batch = next(iter(loader))
    grid = torchvision.utils.make_grid(torch.cat((batch[0], batch[1]), 0))
    boxed_text_overlay_plot(grid, str(batch[2].numpy()))
# ---- Example 2 ----
def stest_one_versus_many(model, data_dir, img_size):
    """Plot model dissimilarity of one fixed anchor against ten samples.

    Takes the first image of the first test pair as the anchor, then
    compares it with the second image of the next ten pairs.
    """
    loader = DataLoader(
        PairDataset(
            data_dir,
            transform=transforms.Compose([
                transforms.Grayscale(),
                transforms.Resize(img_size),
                transforms.ToTensor(),
            ]),
            split=SplitEnum.testing,
        ),
        num_workers=0,
        batch_size=1,
        shuffle=True,
    )
    batches = iter(loader)
    anchor, *_ = next(batches)
    for _ in range(10):
        _, other, _ = next(batches)
        embeddings = model(
            to_tensor(anchor, device=global_torch_device()),
            to_tensor(other, device=global_torch_device()),
        )
        dis = torch.pairwise_distance(*embeddings).cpu().item()
        boxed_text_overlay_plot(
            torchvision.utils.make_grid(torch.cat((anchor, other), 0)),
            f"Dissimilarity: {dis:.2f}",
        )
# ---- Example 3 ----
def stest_many_versus_many(model, data_dir, img_size, threshold=0.5):
    """Plot truth vs. thresholded model distance for ten random pairs.

    A pair is called "Different" when the embedding distance exceeds
    ``threshold``.
    """
    loader = DataLoader(
        PairDataset(
            data_dir,
            transform=transforms.Compose([
                transforms.Grayscale(),
                transforms.Resize(img_size),
                transforms.ToTensor(),
            ]),
        ),
        num_workers=0,
        batch_size=1,
        shuffle=True,
    )
    samples = iter(loader)
    for _ in range(10):
        left, right, is_diff = next(samples)
        embeddings = model(
            to_tensor(left, device=global_torch_device()),
            to_tensor(right, device=global_torch_device()),
        )
        distance = torch.pairwise_distance(*embeddings).cpu().item()
        truth = "Different" if is_diff.cpu().item() else "Alike"
        verdict = "Different" if distance > threshold else "Alike"
        boxed_text_overlay_plot(
            torchvision.utils.make_grid(torch.cat((left, right), 0)),
            f"Truth: {truth},"
            f" Dissimilarity: {distance:.2f},"
            f" Verdict: {verdict}",
        )
# ---- Example 4 ----
 def sample(self, horizontal_merge: bool = False) -> None:
     """Plot three batches of nine triplets drawn from this dataset.

     :param horizontal_merge: join each triplet side by side along the
         width axis instead of stacking along the height axis.
     """
     batches = iter(
         DataLoader(
             self,
             batch_size=9,
             shuffle=True,
             num_workers=0,
             pin_memory=global_pin_memory(0),
         ))
     for _ in range(3):
         first, second, third, *labels = next(batches)
         # NCHW tensors -> NHWC arrays for plotting.
         planes = [
             numpy.transpose(t.numpy(), [0, 2, 3, 1])
             for t in (first, second, third)
         ]
         # For NHWC data dstack concatenates along W (side by side),
         # hstack along H (stacked vertically).
         if horizontal_merge:
             merged = numpy.dstack(planes)
         else:
             merged = numpy.hstack(planes)
         PairDataset.plot_images(merged, list(zip(*labels)))
# ---- Example 5 ----
    def sample(self, horizontal_merge: bool = False) -> None:
        """Plot three batches of nine image triplets from this dataset.

        :param horizontal_merge: join each triplet along the width axis
            instead of the height axis.
        """
        batches = iter(
            torch.utils.data.DataLoader(
                self, batch_size=9, shuffle=True, num_workers=1, pin_memory=False
            )
        )
        for _ in range(3):
            triplet = next(batches)
            # NCHW tensors -> NHWC arrays for plotting.
            a, b, c = (numpy.transpose(t.numpy(), [0, 2, 3, 1]) for t in triplet)
            if horizontal_merge:
                merged = numpy.dstack((a, b, c))  # concatenate along W
            else:
                merged = numpy.hstack((a, b, c))  # concatenate along H
            PairDataset.plot_images(merged)
def stest_many_versus_many2(model: Module,
                            data_dir: Path,
                            img_size: Tuple[int, int],
                            threshold=0.5):
    """Plot truth vs. thresholded model output for ten random pairs.

    :param model: network whose output is used directly as the scalar
        dissimilarity for a pair of inputs.
    :param data_dir: root directory of the pair dataset.
    :param img_size: size the grayscale images are resized to.
    :param threshold: distance above which a pair is called "Different".
    """
    pairs = iter(
        DataLoader(
            PairDataset(
                data_dir,
                transform=transforms.Compose([
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]),
            ),
            num_workers=4,
            batch_size=1,
            shuffle=True,
        ))
    for _ in range(10):
        left, right, is_diff = next(pairs)
        # NOTE(review): unlike the sibling stest_many_versus_many, this
        # treats the model output itself as the scalar distance (no
        # torch.pairwise_distance) — confirm the model passed here returns
        # a single tensor rather than an embedding pair.
        distance = model(
            to_tensor(left, device=global_torch_device()),
            to_tensor(right, device=global_torch_device()),
        ).cpu().item()
        truth = "Different" if is_diff.cpu().item() else "Alike"
        verdict = "Different" if distance > threshold else "Alike"
        boxed_text_overlay_plot(
            torchvision.utils.make_grid(torch.cat((left, right), 0)),
            f"Truth: {truth},"
            f" Dissimilarity: {distance:.2f},"
            f" Verdict: {verdict}",
        )
def train_siamese(
    model: Module,
    optimiser: Optimizer,
    criterion: callable,
    *,
    writer: Writer = MockWriter(),
    train_number_epochs: int,
    data_dir: Path,
    train_batch_size: int,
    model_name: str,
    save_path: Path,
    save_best: bool = False,
    img_size: Tuple[int, int],
    validation_interval: int = 1,
):
    """Train a siamese model on image pairs, validating periodically.

:param model: network mapping two image batches to a pair of embeddings.
:type model: Module
:param optimiser: optimiser stepping the model parameters.
:type optimiser: Optimizer
:param criterion: loss taking the model output and the 0/1 pair label.
:type criterion: callable
:param writer: metric sink for train/valid scalars.
:type writer: Writer
:param train_number_epochs: number of epochs to run.
:type train_number_epochs: int
:param data_dir: root directory of the pair dataset.
:type data_dir: Path
:param train_batch_size: batch size for both dataloaders.
:type train_batch_size: int
:param model_name: name used when saving parameters.
:type model_name: str
:param save_path: directory parameters are saved into.
:type save_path: Path
:param save_best: save parameters whenever validation loss improves.
:type save_best: bool
:param img_size: size the grayscale images are resized to.
:type img_size: Tuple[int, int]
:param validation_interval: run validation every this many batches.
:type validation_interval: int
:return: the trained model.
:rtype: Module
"""

    # Both splits use identical, stateless preprocessing.
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ])

    train_dataloader = DataLoader(
        PairDataset(
            data_path=data_dir,
            transform=transform,
            split=Split.Training,
        ),
        shuffle=True,
        num_workers=4,
        batch_size=train_batch_size,
    )

    valid_dataloader = DataLoader(
        PairDataset(
            data_path=data_dir,
            transform=transform,
            split=Split.Validation,
        ),
        shuffle=True,
        num_workers=4,
        batch_size=train_batch_size,
    )

    best = math.inf
    # Defined up front so the progress description below cannot raise a
    # NameError before the first validation pass (validation_interval > 1).
    train_loss = valid_loss = valid_accuracy = math.nan

    E = tqdm(range(train_number_epochs))
    batch_counter = count()

    for epoch in E:
        for tss in train_dataloader:
            # Bug fix: the counter was previously advanced a second time in
            # the validation gate below, so batch indices skipped values and
            # the gate tested a different number than the one logged here.
            batch_i = next(batch_counter)
            with TorchTrainSession(model):
                o = [t.to(global_torch_device()) for t in tss]
                optimiser.zero_grad()
                loss_contrastive = criterion(model(*o[:2]),
                                             o[2].to(dtype=torch.float))
                loss_contrastive.backward()
                optimiser.step()
                train_loss = loss_contrastive.cpu().item()
                writer.scalar("train_loss", train_loss, batch_i)
            if batch_i % validation_interval == 0:
                with TorchEvalSession(model):
                    for tsv in valid_dataloader:
                        ov = [t.to(global_torch_device()) for t in tsv]
                        v_o, fact = model(*ov[:2]), ov[2].to(dtype=torch.float)
                        valid_loss = criterion(v_o, fact).cpu().item()
                        valid_accuracy = (accuracy(distances=v_o,
                                                   is_diff=fact).cpu().item())
                        writer.scalar("valid_loss", valid_loss, batch_i)
                        if valid_loss < best:
                            best = valid_loss
                            print(f"new best {best}")
                            writer.blip("new_best", batch_i)
                            if save_best:
                                save_model_parameters(
                                    model,
                                    optimiser=optimiser,
                                    model_name=model_name,
                                    save_directory=save_path,
                                )
            E.set_description(
                f"Epoch number {epoch}, Current train loss {train_loss}, valid loss {valid_loss}, valid_accuracy {valid_accuracy}"
            )

    return model