def vis(model, data_dir, img_size):
    """Visualise one random validation batch of image pairs.

    The top row and the bottom row of any column form one pair; the
    printed labels correspond to the columns (0 = dissimilar, 1 = similar).

    :param model: unused here; kept for signature parity with the sibling test helpers
    :param data_dir: root directory of the pair dataset
    :param img_size: target (H, W) for ``transforms.Resize``
    """
    pair_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ])
    loader = DataLoader(
        PairDataset(
            data_path=data_dir,
            transform=pair_transform,
            split=SplitEnum.validation,
        ),
        shuffle=True,
        num_workers=0,
        batch_size=8,
    )
    batch = next(iter(loader))
    # Stack the two pair halves vertically so make_grid shows pairs per column.
    stacked = torch.cat((batch[0], batch[1]), 0)
    boxed_text_overlay_plot(
        torchvision.utils.make_grid(stacked),
        str(batch[2].numpy()),
    )
def stest_one_versus_many(model, data_dir, img_size):
    """Compare one fixed test image against ten random others.

    Draws an anchor image from the testing split, then plots it next to
    ten other sampled images together with the model's pairwise distance.

    :param model: siamese model returning a pair of embeddings
    :param data_dir: root directory of the pair dataset
    :param img_size: target (H, W) for ``transforms.Resize``
    """
    samples = iter(
        DataLoader(
            PairDataset(
                data_dir,
                transform=transforms.Compose([
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]),
                split=SplitEnum.testing,
            ),
            num_workers=0,
            batch_size=1,
            shuffle=True,
        )
    )
    anchor, *_ = next(samples)
    for _ in range(10):
        _, other, _ = next(samples)
        embeddings = model(
            to_tensor(anchor, device=global_torch_device()),
            to_tensor(other, device=global_torch_device()),
        )
        dissimilarity = torch.pairwise_distance(*embeddings).cpu().item()
        boxed_text_overlay_plot(
            torchvision.utils.make_grid(torch.cat((anchor, other), 0)),
            f"Dissimilarity: {dissimilarity:.2f}",
        )
def stest_many_versus_many(model, data_dir, img_size, threshold=0.5):
    """Plot ten random pairs with ground truth versus the model's verdict.

    The verdict is "Different" when the pairwise distance between the two
    embeddings exceeds ``threshold``, otherwise "Alike".

    :param model: siamese model returning a pair of embeddings
    :param data_dir: root directory of the pair dataset
    :param img_size: target (H, W) for ``transforms.Resize``
    :param threshold: distance above which a pair is called "Different"
    """
    samples = iter(
        DataLoader(
            PairDataset(
                data_dir,
                transform=transforms.Compose([
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]),
            ),
            num_workers=0,
            batch_size=1,
            shuffle=True,
        )
    )
    for _ in range(10):
        left, right, is_diff = next(samples)
        embeddings = model(
            to_tensor(left, device=global_torch_device()),
            to_tensor(right, device=global_torch_device()),
        )
        distance = torch.pairwise_distance(*embeddings).cpu().item()
        truth = "Different" if is_diff.cpu().item() else "Alike"
        verdict = "Different" if distance > threshold else "Alike"
        boxed_text_overlay_plot(
            torchvision.utils.make_grid(torch.cat((left, right), 0)),
            f"Truth: {truth},"
            f" Dissimilarity: {distance:.2f},"
            f" Verdict: {verdict}",
        )
def sample(self, horizontal_merge: bool = False) -> None:
    """Plot three random batches of nine labelled image triplets.

    :param horizontal_merge: selects numpy.dstack instead of numpy.hstack
        when merging the three images of a triplet.
        NOTE(review): dstack (depth) for "horizontal" looks swapped with
        hstack — confirm which visual layout is intended.
    """
    batches = iter(
        DataLoader(
            self,
            batch_size=9,
            shuffle=True,
            num_workers=0,
            pin_memory=global_pin_memory(0),
        )
    )
    merge = numpy.dstack if horizontal_merge else numpy.hstack
    for _ in range(3):
        first, second, third, *labels = next(batches)
        # NCHW tensors -> NHWC arrays for plotting.
        planes = tuple(
            numpy.transpose(t.numpy(), [0, 2, 3, 1]) for t in (first, second, third)
        )
        PairDataset.plot_images(merge(planes), list(zip(*labels)))
def sample(self, horizontal_merge: bool = False) -> None:
    """Plot three random batches of nine unlabelled image triplets.

    :param horizontal_merge: selects numpy.dstack instead of numpy.hstack
        when merging the three images of a triplet.
        NOTE(review): dstack (depth) for "horizontal" looks swapped with
        hstack — confirm which visual layout is intended.
    """
    batches = iter(
        torch.utils.data.DataLoader(
            self, batch_size=9, shuffle=True, num_workers=1, pin_memory=False
        )
    )
    merge = numpy.dstack if horizontal_merge else numpy.hstack
    for _ in range(3):
        first, second, third = next(batches)
        # NCHW tensors -> NHWC arrays for plotting.
        planes = tuple(
            numpy.transpose(t.numpy(), [0, 2, 3, 1]) for t in (first, second, third)
        )
        PairDataset.plot_images(merge(planes))
def stest_many_versus_many2(model: Module, data_dir: Path, img_size: Tuple[int, int], threshold=0.5):
    """Plot ten random pairs with ground truth versus the model's verdict.

    NOTE(review): unlike ``stest_many_versus_many``, the model's output is
    used directly as the distance (no ``torch.pairwise_distance``) —
    confirm this model variant returns a scalar dissimilarity.

    :param model: model whose forward pass yields a scalar dissimilarity
    :type model:
    :param data_dir: root directory of the pair dataset
    :type data_dir:
    :param img_size: target (H, W) for ``transforms.Resize``
    :type img_size:
    :param threshold: distance above which a pair is called "Different"
    :type threshold:
    """
    samples = iter(
        DataLoader(
            PairDataset(
                data_dir,
                transform=transforms.Compose([
                    transforms.Grayscale(),
                    transforms.Resize(img_size),
                    transforms.ToTensor(),
                ]),
            ),
            num_workers=4,
            batch_size=1,
            shuffle=True,
        )
    )
    for _ in range(10):
        left, right, is_diff = next(samples)
        distance = (
            model(
                to_tensor(left, device=global_torch_device()),
                to_tensor(right, device=global_torch_device()),
            )
            .cpu()
            .item()
        )
        truth = "Different" if is_diff.cpu().item() else "Alike"
        verdict = "Different" if distance > threshold else "Alike"
        boxed_text_overlay_plot(
            torchvision.utils.make_grid(torch.cat((left, right), 0)),
            f"Truth: {truth},"
            f" Dissimilarity: {distance:.2f},"
            f" Verdict: {verdict}",
        )
def train_siamese(
    model: Module,
    optimiser: Optimizer,
    criterion: callable,
    *,
    writer: Writer = MockWriter(),
    train_number_epochs: int,
    data_dir: Path,
    train_batch_size: int,
    model_name: str,
    save_path: Path,
    save_best: bool = False,
    img_size: Tuple[int, int],
    validation_interval: int = 1,
):
    """Train a siamese model with a contrastive-style criterion.

    Runs ``train_number_epochs`` epochs over the training split, logging
    the per-batch training loss, and every ``validation_interval`` batches
    sweeps the validation split, logging loss/accuracy and optionally
    saving the parameters whenever a new best validation loss is seen.

    :param model: siamese model; ``model(a, b)`` feeds the criterion
    :param optimiser: optimiser stepping ``model``'s parameters
    :param criterion: callable mapping (model output, target) to a loss tensor
    :param writer: metrics sink; defaults to a no-op MockWriter
        (NOTE: the default instance is shared across calls — harmless for a
        mock, but pass your own Writer for real logging)
    :param train_number_epochs: number of epochs to run
    :param data_dir: root directory of the pair dataset
    :param train_batch_size: batch size for both train and validation loaders
    :param model_name: name used when saving parameters
    :param save_path: directory passed to save_model_parameters
    :param save_best: save parameters on each new best validation loss
    :param img_size: target (H, W) for ``transforms.Resize``
    :param validation_interval: validate every N training batches
    :return: the trained model
    """
    pair_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ])

    def _loader(split):
        # Train and validation loaders differ only in the dataset split.
        return DataLoader(
            PairDataset(data_path=data_dir, transform=pair_transform, split=split),
            shuffle=True,
            num_workers=4,
            batch_size=train_batch_size,
        )

    # CONSISTENCY FIX: the original referenced Split.Training / Split.Validation,
    # while every sibling in this file uses SplitEnum with lowercase members
    # (SplitEnum.validation, SplitEnum.testing) — assumes SplitEnum.training
    # exists; confirm against the enum definition.
    train_dataloader = _loader(SplitEnum.training)
    valid_dataloader = _loader(SplitEnum.validation)

    best = math.inf
    epochs = tqdm(range(train_number_epochs))
    batch_counter = count()
    for epoch in epochs:
        for tss in train_dataloader:
            batch_i = next(batch_counter)
            with TorchTrainSession(model):
                o = [t.to(global_torch_device()) for t in tss]
                optimiser.zero_grad()
                loss_contrastive = criterion(model(*o[:2]), o[2].to(dtype=torch.float))
                loss_contrastive.backward()
                optimiser.step()
                train_loss = loss_contrastive.cpu().item()
                writer.scalar("train_loss", train_loss, batch_i)
            # BUG FIX: the original called batch_counter.__next__() here, which
            # advanced the counter a second time per batch (indices skipped by
            # two) and tested batch_i + 1 against the interval instead of
            # batch_i.  Reuse the index drawn above.
            if batch_i % validation_interval == 0:
                with TorchEvalSession(model):
                    # presumably TorchEvalSession disables grad tracking —
                    # TODO confirm, otherwise wrap in torch.no_grad()
                    for tsv in valid_dataloader:
                        ov = [t.to(global_torch_device()) for t in tsv]
                        v_o, fact = model(*ov[:2]), ov[2].to(dtype=torch.float)
                        valid_loss = criterion(v_o, fact).cpu().item()
                        valid_accuracy = accuracy(distances=v_o, is_diff=fact).cpu().item()
                        writer.scalar("valid_loss", valid_loss, batch_i)
                        if valid_loss < best:
                            best = valid_loss
                            print(f"new best {best}")
                            writer.blip("new_best", batch_i)
                            if save_best:
                                save_model_parameters(
                                    model,
                                    optimiser=optimiser,
                                    model_name=model_name,
                                    save_directory=save_path,
                                )
            epochs.set_description(
                f"Epoch number {epoch}, Current train loss {train_loss}, valid loss {valid_loss}, valid_accuracy {valid_accuracy}"
            )
    return model