Example #1
0
    def save(self, **kwargs) -> None:
        """Persist every model in this collection together with its optimiser.

        Each ``(key, model)`` pair from ``self.models`` is matched
        positionally with an optimiser from ``self.optimisers``
        (NOTE(review): assumes both mappings share insertion order — confirm)
        and handed to ``save_model_parameters``.

        :param kwargs: extra keyword arguments forwarded verbatim to
            ``save_model_parameters`` (e.g. a save directory).
        :return: None
        """
        paired = zip(self.models.items(), self.optimisers.values())
        for (key, mdl), opt in paired:
            save_model_parameters(
                mdl,
                optimiser=opt,
                model_name=self.model_name(key, mdl),
                **kwargs,
            )
    def main():
        """Train or evaluate the pair siamese network on MNIST.

        Configuration is hard-coded below: when ``train`` is True the model
        is (optionally) restored, trained with a BCE loss under TensorBoard
        logging, then saved; otherwise a saved checkpoint is loaded and
        evaluated via ``stest_many_versus_many2``.

        :return: None
        """
        data_dir = Path.home() / "Data" / "mnist_png"
        train_batch_size = 64
        train_number_epochs = 100
        save_path = PROJECT_APP_PATH.user_data / "models"
        model_name = "pair_siamese_mnist"
        load_prev = True
        train = False
        img_size = (28, 28)
        model = PairRankingSiamese(img_size).to(global_torch_device())
        optimiser = optim.Adam(model.parameters(), lr=3e-4)

        if train:
            if load_prev:
                # BUGFIX: was `model, optimer = ...` — the restored optimiser
                # was bound to a typo'd throwaway name, so training resumed
                # with a fresh optimiser. Rebind `optimiser` instead.
                model, optimiser = load_model_parameters(
                    model,
                    optimiser=optimiser,
                    model_name=model_name,
                    model_directory=save_path,
                )

            with TensorBoardPytorchWriter(
                PROJECT_APP_PATH.user_log / model_name / str(time.time())
            ) as writer:
                # Allow Ctrl-C to stop training early and still fall through
                # to the save below.
                with suppress(KeyboardInterrupt):
                    model = train_siamese(
                        model,
                        optimiser,
                        nn.BCELoss().to(global_torch_device()),
                        train_number_epochs=train_number_epochs,
                        data_dir=data_dir,
                        train_batch_size=train_batch_size,
                        model_name=model_name,
                        save_path=save_path,
                        writer=writer,
                        img_size=img_size,
                    )
            save_model_parameters(
                model,
                optimiser=optimiser,
                model_name=model_name,
                save_directory=save_path,
            )
        else:
            model = load_model_parameters(
                model, model_name=model_name, model_directory=save_path
            )
            print("loaded best val")
            stest_many_versus_many2(model, data_dir, img_size)
Example #3
0
    def main():
        """Entry point: train or evaluate the triplet siamese MNIST model.

        All configuration is hard-coded locals. With ``train`` False (the
        default here) a saved checkpoint is loaded and evaluated via
        ``stest_many_versus_many``; with ``train`` True the model is
        (optionally) restored, trained with a triplet-margin loss, and saved.

        :return: None
        """
        img_size = (28, 28)
        data_dir = Path.home() / "Data" / "mnist_png"
        save_path = PROJECT_APP_PATH.user_data / "models"
        model_name = "triplet_siamese_mnist"
        train_batch_size = 64
        train_number_epochs = 100
        load_prev = True
        train = False

        model = NLetConvNet(img_size).to(global_torch_device())
        optimiser = optim.Adam(model.parameters(), lr=3e-4)

        if not train:  # evaluation-only path
            model = load_model_parameters(
                model, model_name=model_name, model_directory=save_path
            )
            print("loaded best val")
            stest_many_versus_many(model, data_dir, img_size)
            return

        if load_prev:
            model, optimiser = load_model_parameters(
                model,
                optimiser=optimiser,
                model_name=model_name,
                model_directory=save_path,
            )

        with TensorBoardPytorchWriter():
            # Swallow Ctrl-C so the save below still runs after a manual stop.
            with IgnoreInterruptSignal():
                model = train_siamese(
                    model,
                    optimiser,
                    TripletMarginLoss().to(global_torch_device()),
                    train_number_epochs=train_number_epochs,
                    data_dir=data_dir,
                    train_batch_size=train_batch_size,
                    model_name=model_name,
                    save_path=save_path,
                    img_size=img_size,
                )
        save_model_parameters(
            model,
            optimiser=optimiser,
            model_name=model_name,
            save_directory=save_path,
        )
Example #4
0
def train_siamese(
    model,
    optimiser,
    criterion,
    *,
    writer: Writer = MockWriter(),  # NOTE: default instance is shared across calls
    train_number_epochs,
    data_dir,
    train_batch_size,
    model_name,
    save_path,
    save_best=False,
    img_size,
    validation_interval: int = 1,
):
    """Train a triplet siamese network, validating every few batches.

    :param model: network that maps a triplet of input tensors to outputs
        consumed by ``criterion`` (e.g. embeddings).
    :param optimiser: torch optimiser stepping ``model``'s parameters.
    :param criterion: triplet loss called with the unpacked model outputs.
    :param writer: metrics sink receiving train/valid scalars per batch.
    :param train_number_epochs: number of passes over the training split.
    :param data_dir: root directory of the triplet dataset.
    :param train_batch_size: batch size for both train and validation loaders.
    :param model_name: name used when checkpointing the best model.
    :param save_path: directory checkpoints are written to.
    :param save_best: when True, save parameters each time validation loss improves.
    :param img_size: target (H, W) the grayscale inputs are resized to.
    :param validation_interval: run validation every this many training batches.
    :return: the trained model.
    """
    # Same preprocessing for both splits — build it once.
    transform = transforms.Compose(
        [
            transforms.Grayscale(),
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ]
    )

    train_dataloader = DataLoader(
        TripletDataset(
            data_path=data_dir, transform=transform, split=SplitEnum.training
        ),
        shuffle=True,
        num_workers=0,
        batch_size=train_batch_size,
    )

    valid_dataloader = DataLoader(
        TripletDataset(
            data_path=data_dir, transform=transform, split=SplitEnum.validation
        ),
        shuffle=True,
        num_workers=0,
        batch_size=train_batch_size,
    )

    best = math.inf

    E = tqdm(range(0, train_number_epochs))
    batch_counter = count()
    # Initialised so the progress-bar description below can never raise
    # NameError before the first train step / validation pass has run.
    a = a_v = valid_acc = math.nan

    for epoch in E:
        for tss in train_dataloader:
            batch_i = next(batch_counter)
            with TorchTrainSession(model):
                optimiser.zero_grad()
                loss_contrastive = criterion(
                    *model(*[t.to(global_torch_device()) for t in tss])
                )
                loss_contrastive.backward()
                optimiser.step()
                a = loss_contrastive.cpu().item()
                writer.scalar("train_loss", a, batch_i)
            # BUGFIX: was `batch_counter.__next__() % validation_interval`,
            # which consumed an extra counter value per batch (counter advanced
            # by 2) and tested a different number than `batch_i`.
            if batch_i % validation_interval == 0:
                with TorchEvalSession(model):
                    for tsv in valid_dataloader:
                        o = model(*[t.to(global_torch_device()) for t in tsv])
                        a_v = criterion(*o).cpu().item()
                        valid_positive_acc = (
                            accuracy(
                                distances=pairwise_distance(o[0], o[1]), is_diff=0
                            )
                            .cpu()
                            .item()
                        )
                        valid_negative_acc = (
                            accuracy(
                                distances=pairwise_distance(o[0], o[2]), is_diff=1
                            )
                            .cpu()
                            .item()
                        )
                        valid_acc = numpy.mean(
                            (valid_negative_acc, valid_positive_acc)
                        )
                        writer.scalar("valid_loss", a_v, batch_i)
                        writer.scalar(
                            "valid_positive_acc", valid_positive_acc, batch_i
                        )
                        writer.scalar(
                            "valid_negative_acc", valid_negative_acc, batch_i
                        )
                        writer.scalar("valid_acc", valid_acc, batch_i)
                        if a_v < best:
                            best = a_v
                            print(f"new best {best}")
                            if save_best:
                                save_model_parameters(
                                    model,
                                    optimiser=optimiser,
                                    model_name=model_name,
                                    save_directory=save_path,
                                )
            E.set_description(
                f"Epoch number {epoch}, Current train loss {a}, valid loss {a_v}, valid acc {valid_acc}"
            )

    return model
def train_siamese(
    model: Module,
    optimiser: Optimizer,
    criterion: callable,
    *,
    writer: Writer = MockWriter(),  # NOTE: default instance is shared across calls
    train_number_epochs: int,
    data_dir: Path,
    train_batch_size: int,
    model_name: str,
    save_path: Path,
    save_best: bool = False,
    img_size: Tuple[int, int],
    validation_interval: int = 1,
):
    """Train a pair siamese network, validating every few batches.

    :param model: network mapping a pair of input tensors to a distance/score
        compared against the pair's 0/1 label.
    :param optimiser: torch optimiser stepping ``model``'s parameters.
    :param criterion: loss called as ``criterion(model(a, b), label)``.
    :param writer: metrics sink receiving train/valid scalars per batch.
    :param train_number_epochs: number of passes over the training split.
    :param data_dir: root directory of the pair dataset.
    :param train_batch_size: batch size for both train and validation loaders.
    :param model_name: name used when checkpointing the best model.
    :param save_path: directory checkpoints are written to.
    :param save_best: when True, save parameters each time validation loss improves.
    :param img_size: target (H, W) the grayscale inputs are resized to.
    :param validation_interval: run validation every this many training batches.
    :return: the trained model.
    """
    # Same preprocessing for both splits — build it once.
    transform = transforms.Compose(
        [
            transforms.Grayscale(),
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ]
    )

    train_dataloader = DataLoader(
        PairDataset(data_path=data_dir, transform=transform, split=Split.Training),
        shuffle=True,
        num_workers=4,
        batch_size=train_batch_size,
    )

    valid_dataloader = DataLoader(
        PairDataset(data_path=data_dir, transform=transform, split=Split.Validation),
        shuffle=True,
        num_workers=4,
        batch_size=train_batch_size,
    )

    best = math.inf

    E = tqdm(range(0, train_number_epochs))
    batch_counter = count()
    # Initialised so the progress-bar description below can never raise
    # NameError before the first train step / validation pass has run.
    train_loss = valid_loss = valid_accuracy = math.nan

    for epoch in E:
        for tss in train_dataloader:
            batch_i = next(batch_counter)
            with TorchTrainSession(model):
                o = [t.to(global_torch_device()) for t in tss]
                optimiser.zero_grad()
                loss_contrastive = criterion(model(*o[:2]), o[2].to(dtype=torch.float))
                loss_contrastive.backward()
                optimiser.step()
                train_loss = loss_contrastive.cpu().item()
                writer.scalar("train_loss", train_loss, batch_i)
            # BUGFIX: was `batch_counter.__next__() % validation_interval`,
            # which consumed an extra counter value per batch (counter advanced
            # by 2) and tested a different number than `batch_i`.
            if batch_i % validation_interval == 0:
                with TorchEvalSession(model):
                    for tsv in valid_dataloader:
                        ov = [t.to(global_torch_device()) for t in tsv]
                        v_o, fact = model(*ov[:2]), ov[2].to(dtype=torch.float)
                        valid_loss = criterion(v_o, fact).cpu().item()
                        # NOTE(review): valid_accuracy is shown in the progress
                        # bar but never written to `writer` — confirm intended.
                        valid_accuracy = (
                            accuracy(distances=v_o, is_diff=fact).cpu().item()
                        )
                        writer.scalar("valid_loss", valid_loss, batch_i)
                        if valid_loss < best:
                            best = valid_loss
                            print(f"new best {best}")
                            writer.blip("new_best", batch_i)
                            if save_best:
                                save_model_parameters(
                                    model,
                                    optimiser=optimiser,
                                    model_name=model_name,
                                    save_directory=save_path,
                                )
            E.set_description(
                f"Epoch number {epoch}, Current train loss {train_loss}, valid loss {valid_loss}, valid_accuracy {valid_accuracy}"
            )

    return model