def train(model_params, **training_kwargs):
    """Train an ``SCAEMNIST`` model with a PyTorch Lightning ``Trainer``.

    Args:
        model_params: Mapping of keyword arguments forwarded to
            ``make_config`` to build the model configuration.
        **training_kwargs: Overrides applied on top of the values returned
            by ``parse_args()``. An optional ``save_top_k`` entry is consumed
            here to configure ``ModelCheckpoint`` (default ``1``); every
            other entry is forwarded to ``Trainer``.
    """
    model_config = make_config(**model_params)

    training_params = vars(parse_args())
    training_params.update(training_kwargs)

    # The model receives the full set of hyper-parameters (including any
    # ``save_top_k``) together with the model configuration.
    hparams = dict(model_config=model_config)
    hparams.update(training_params)
    model = SCAEMNIST(Namespace(**hparams))

    print("___________")
    print("loading model")

    # To resume a previous run, pass ``resume_from_checkpoint=<ckpt path>``
    # through ``training_kwargs``; it is forwarded to ``Trainer`` below.

    # ``save_top_k`` configures the checkpoint callback, not the Trainer, so
    # it must be removed before the remaining parameters are splatted into
    # ``Trainer(...)``. Fall back to keeping a single checkpoint when the
    # value is absent or unset.
    save_top_k = training_params.pop('save_top_k', None)
    if save_top_k is None:
        save_top_k = 1
    checkpoint_callback = ModelCheckpoint(save_top_k=save_top_k)
    training_params.update(checkpoint_callback=checkpoint_callback)

    # This entry point always overrides the GPU settings: one GPU, picked
    # automatically.
    training_params.update(gpus=1, auto_select_gpus=True)

    trainer = Trainer(**training_params)
    trainer.fit(model)
def get_mdl_obj():
    """Return a fresh ``SCAEMNIST`` built from the ``wdata`` model parameters.

    The hyper-parameter namespace handed to the model contains a single
    entry, ``model_config``, produced by ``make_config``.
    """
    # Imported at call time, inside the function, exactly as before.
    from torch_scae_experiments.wdata.hparams import model_params as wdata_params

    config = make_config(**wdata_params)
    return SCAEMNIST(Namespace(model_config=config))
def test_scae(self):
    """Smoke test: assemble an ``SCAE`` from the factory config and run it.

    Builds every sub-module from the matching config section, then runs a
    forward pass, the loss and the accuracy computation on random inputs
    with gradients disabled. Nothing is asserted on the results; the test
    passes if no call raises.
    """
    cfg = Namespace(**factory.make_config(**model_params))

    # Modules configured by the ``pcae_*`` sections.
    encoder_cnn = CNNEncoder(**cfg.pcae_cnn_encoder)
    pcae_encoder = CapsuleImageEncoder(encoder=encoder_cnn, **cfg.pcae_encoder)
    templates = TemplateGenerator(**cfg.pcae_template_generator)
    pcae_decoder = TemplateBasedImageDecoder(**cfg.pcae_decoder)

    # Modules configured by the ``ocae_*`` sections.
    ocae_encoder = SetTransformer(**cfg.ocae_encoder_set_transformer)
    ocae_decoder = CapsuleObjectDecoder(CapsuleLayer(**cfg.ocae_decoder_capsule))

    scae = SCAE(
        part_encoder=pcae_encoder,
        template_generator=templates,
        part_decoder=pcae_decoder,
        obj_encoder=ocae_encoder,
        obj_decoder=ocae_decoder,
        **cfg.scae,
    )

    with torch.no_grad():
        n_samples = 24
        images = torch.rand(n_samples, *cfg.image_shape)
        labels = torch.randint(0, cfg.n_classes, (n_samples,))

        outputs = scae(image=images)
        # The input image doubles as the reconstruction target.
        scae.loss(outputs, images, labels)
        scae.calculate_accuracy(outputs, labels)
def train(model_params, **training_kwargs):
    """Fit an ``SCAECIFAR10`` model using a PyTorch Lightning ``Trainer``.

    Args:
        model_params: Mapping of keyword arguments passed to ``make_config``.
        **training_kwargs: Values layered over the result of
            ``parse_args()``. If the merged parameters contain
            ``save_top_k``, it is turned into a ``ModelCheckpoint`` callback
            and removed before the rest is handed to ``Trainer``.
    """
    config = make_config(**model_params)

    # Command-line values first, explicit keyword overrides second.
    training_params = {**vars(parse_args()), **training_kwargs}

    # The model sees the model config plus every training parameter.
    model = SCAECIFAR10(Namespace(**{'model_config': config, **training_params}))

    if 'save_top_k' in training_params:
        top_k = training_params.pop('save_top_k')
        training_params['checkpoint_callback'] = ModelCheckpoint(save_top_k=top_k)

    Trainer(**training_params).fit(model)
def train(model_params, **training_kwargs):
    """Train an ``SCAEBONEAGE`` model with a PyTorch Lightning ``Trainer``.

    Args:
        model_params: Mapping of keyword arguments forwarded to
            ``make_config`` to build the model configuration.
        **training_kwargs: Overrides applied on top of the values returned
            by ``parse_args()``. Two entries are consumed here instead of
            being forwarded to ``Trainer``:

            * ``save_top_k`` - when present, a ``ModelCheckpoint`` callback
              keeping that many checkpoints is installed.
            * ``using_colab`` - when present (together with ``save_top_k``),
              checkpoints are written to the Google Drive path below.
    """
    model_config = make_config(**model_params)

    training_params = vars(parse_args())
    training_params.update(training_kwargs)

    # The model receives every training parameter (including the two
    # consumed below) together with the model configuration.
    hparams = dict(model_config=model_config)
    hparams.update(training_params)
    model = SCAEBONEAGE(Namespace(**hparams))

    # ``using_colab`` is a flag for this function only. Strip it whether or
    # not ``save_top_k`` was given, so it never leaks into ``Trainer(...)``.
    # NOTE(review): as before, only the key's presence is checked, so
    # ``using_colab=False`` still selects the Drive path - confirm intent.
    on_colab = 'using_colab' in training_params
    training_params.pop('using_colab', None)

    if 'save_top_k' in training_params:
        checkpoint_kwargs = dict(save_top_k=training_params.pop('save_top_k'))
        if on_colab:
            checkpoint_kwargs.update(
                filepath="/content/drive/My Drive/data/boneage")
        checkpoint_callback = ModelCheckpoint(**checkpoint_kwargs)
        training_params.update(checkpoint_callback=checkpoint_callback)

    trainer = Trainer(**training_params)
    trainer.fit(model)
from torch_scae_experiments.mnist.train import SCAEMNIST from torch_scae_experiments.mnist.hparams import model_params from torch_scae.factory import make_config from argparse import Namespace from torch.utils.data import DataLoader, random_split from torchvision.datasets import MNIST from torch import cuda import pathlib if cuda.is_available(): cuda.set_device(0) print(cuda.current_device()) #Model Initialization model_config = make_config(**model_params) training_hparams = dict(data_dir=str(pathlib.Path.home() / 'torch-datasets'), gpus=1, batch_size=16, num_workers=0, max_epochs=100, learning_rate=1e-4, optimizer_type='RMSprop', use_lr_scheduler=True, lr_scheduler_decay_rate=0.997, model_config=model_config) scaemnist = SCAEMNIST(Namespace(**training_hparams)) #Data preparation data_dir = training_hparams['data_dir'] # train and validation datasets mnist_train = MNIST(data_dir,