Example #1
0
def train(model_params, **training_kwargs):
    """Build a SCAEMNIST model and fit it with PyTorch Lightning.

    Args:
        model_params: keyword arguments forwarded to ``make_config`` to
            build the model configuration.
        **training_kwargs: overrides for the CLI training arguments
            returned by ``parse_args``.
    """
    model_config = make_config(**model_params)

    # CLI defaults, overridden by explicit keyword arguments.
    training_params = vars(parse_args())
    training_params.update(training_kwargs)

    # The model receives the merged model + training hyper-parameters.
    hparams = dict(model_config=model_config)
    hparams.update(training_params)
    model = SCAEMNIST(Namespace(**hparams))

    # Keep only the single best checkpoint.
    checkpoint_callback = ModelCheckpoint(save_top_k=1)
    training_params.update(checkpoint_callback=checkpoint_callback)

    # NOTE(review): single-GPU training is hard-coded here; confirm this is
    # intended for every call site before reusing this entry point.
    training_params.update(gpus=1, auto_select_gpus=True)

    trainer = Trainer(**training_params)
    trainer.fit(model)
Example #2
0
def get_mdl_obj():
    """Build and return a fresh SCAEMNIST model from the package hparams."""
    from torch_scae_experiments.wdata.hparams import model_params

    # Turn the raw hyper-parameter dict into a model config, then wrap it
    # in a Namespace as the model constructor expects.
    config = make_config(**model_params)
    return SCAEMNIST(Namespace(model_config=config))
    def test_scae(self):
        """Smoke-test the full SCAE pipeline: build every sub-module from
        the factory config, run one forward pass on random data, and compute
        the loss and accuracy (no assertions on the values themselves).
        """
        config = Namespace(**factory.make_config(**model_params))

        # Part (primary) capsule encoder: a CNN feature extractor feeding
        # the capsule image encoder.
        cnn_encoder = CNNEncoder(**config.pcae_cnn_encoder)
        part_encoder = CapsuleImageEncoder(encoder=cnn_encoder,
                                           **config.pcae_encoder)

        # Part decoder: learned templates rendered back into an image.
        template_generator = TemplateGenerator(
            **config.pcae_template_generator)
        part_decoder = TemplateBasedImageDecoder(**config.pcae_decoder)

        # Object capsule encoder (Set Transformer over part capsules).
        obj_encoder = SetTransformer(**config.ocae_encoder_set_transformer)

        # Object capsule decoder built around a single capsule layer.
        obj_decoder_capsule = CapsuleLayer(**config.ocae_decoder_capsule)
        obj_decoder = CapsuleObjectDecoder(obj_decoder_capsule)

        # Assemble the stacked capsule autoencoder from the parts above.
        scae = SCAE(part_encoder=part_encoder,
                    template_generator=template_generator,
                    part_decoder=part_decoder,
                    obj_encoder=obj_encoder,
                    obj_decoder=obj_decoder,
                    **config.scae)

        with torch.no_grad():
            batch_size = 24
            # Random inputs; assumes config.image_shape is (C, H, W) and
            # labels are class indices in [0, n_classes) — TODO confirm.
            image = torch.rand(batch_size, *config.image_shape)
            label = torch.randint(0, config.n_classes, (batch_size, ))
            # Autoencoder: the reconstruction target is the input itself.
            reconstruction_target = image

            res = scae(image=image)
            loss = scae.loss(res, reconstruction_target, label)
            accuracy = scae.calculate_accuracy(res, label)
Example #4
0
def train(model_params, **training_kwargs):
    """Construct a SCAECIFAR10 model and run a Lightning training session.

    ``model_params`` feeds ``make_config``; ``training_kwargs`` override the
    command-line defaults obtained from ``parse_args``.
    """
    config = make_config(**model_params)

    # Start from CLI defaults, then layer on explicit overrides.
    run_args = vars(parse_args())
    run_args.update(training_kwargs)

    # The merged hyper-parameters (model config + training args) drive the
    # model construction.
    model = SCAECIFAR10(Namespace(**{'model_config': config, **run_args}))

    # A 'save_top_k' entry becomes a ModelCheckpoint callback rather than a
    # direct Trainer argument.
    if 'save_top_k' in run_args:
        top_k = run_args.pop('save_top_k')
        run_args['checkpoint_callback'] = ModelCheckpoint(save_top_k=top_k)

    Trainer(**run_args).fit(model)
Example #5
0
def train(model_params, **training_kwargs):
    """Build a SCAEBONEAGE model and fit it with PyTorch Lightning.

    Args:
        model_params: keyword arguments forwarded to ``make_config``.
        **training_kwargs: overrides for the CLI arguments from
            ``parse_args``.  May include ``save_top_k`` (checkpointing) and
            ``using_colab`` (store checkpoints on mounted Google Drive).
    """
    model_config = make_config(**model_params)

    training_params = vars(parse_args())
    training_params.update(training_kwargs)

    hparams = dict(model_config=model_config)
    hparams.update(training_params)
    model = SCAEBONEAGE(Namespace(**hparams))

    # 'using_colab' is a routing flag for checkpoint storage, not a Trainer
    # argument — remove it unconditionally so Trainer(**training_params)
    # never receives it (previously it was only removed when 'save_top_k'
    # happened to be present as well).
    using_colab = training_params.pop('using_colab', False)

    if 'save_top_k' in training_params:
        save_top_k = training_params.pop('save_top_k')
        if using_colab:
            # Persist checkpoints on the mounted Google Drive.
            checkpoint_callback = ModelCheckpoint(
                filepath="/content/drive/My Drive/data/boneage",
                save_top_k=save_top_k)
        else:
            checkpoint_callback = ModelCheckpoint(save_top_k=save_top_k)
        # Bug fix: the original registered the callback only in the
        # non-colab branch, so the Drive-backed callback was created but
        # never handed to the Trainer.
        training_params.update(checkpoint_callback=checkpoint_callback)

    trainer = Trainer(**training_params)
    trainer.fit(model)
Example #6
0
from torch_scae_experiments.mnist.train import SCAEMNIST
from torch_scae_experiments.mnist.hparams import model_params
from torch_scae.factory import make_config
from argparse import Namespace
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torch import cuda
import pathlib

# Pin training to GPU 0 when CUDA is available; printing the current
# device index confirms the selection took effect.
if cuda.is_available():
    cuda.set_device(0)
    print(cuda.current_device())

# Model initialization: build the model configuration from the package's
# hyper-parameters, merge it with the training hyper-parameters below, and
# construct the Lightning module from the combined Namespace.
model_config = make_config(**model_params)
training_hparams = dict(data_dir=str(pathlib.Path.home() / 'torch-datasets'),
                        gpus=1,
                        batch_size=16,
                        num_workers=0,
                        max_epochs=100,
                        learning_rate=1e-4,
                        optimizer_type='RMSprop',
                        use_lr_scheduler=True,
                        # exponential LR decay factor applied per step/epoch
                        # by the module's scheduler — TODO confirm cadence
                        lr_scheduler_decay_rate=0.997,
                        model_config=model_config)
scaemnist = SCAEMNIST(Namespace(**training_hparams))

#Data preparation
data_dir = training_hparams['data_dir']
# train and validation datasets
mnist_train = MNIST(data_dir,