Ejemplo n.º 1
0
def test_from_pretrained(tmpdir):
    """Check that ``VAE``/``AE`` load pretrained weights by name and reject unknown names.

    ``tmpdir`` is accepted as a parameter but is not used by the test body.
    """
    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    # Keep the exception object instead of a bare flag so a failure reports
    # the real cause rather than only a generic message.
    load_error = None

    try:
        vae = vae.from_pretrained('cifar10-resnet18')
        vae = vae.from_pretrained(
            'stl10-resnet18'
        )  # try loading weights not compatible with exact architecture

        ae = ae.from_pretrained('cifar10-resnet18')
    except Exception as err:
        load_error = err

    assert load_error is None, f"error in loading weights: {load_error!r}"

    # Check each model on its own: when both calls share one ``try`` the first
    # KeyError ends the block and the second model is never exercised.
    for model, bad_name in ((vae, 'abc'), (ae, 'xyz')):
        keyerror = False

        try:
            model.from_pretrained(bad_name)
        except KeyError:
            keyerror = True

        assert keyerror is True, "KeyError not raised when provided with illegal checkpoint name"
Ejemplo n.º 2
0
def test_from_pretrained(datadir):
    """Check that pretrained ``VAE``/``AE`` weights load, run a forward pass, and reject unknown names.

    Args:
        datadir: directory passed to ``CIFAR10DataModule`` as ``data_dir``.
    """
    dm = CIFAR10DataModule(data_dir=datadir, batch_size=4)
    # prepare_data() must run before setup(): Lightning documents
    # prepare_data as the download step and setup as the step that consumes it.
    dm.prepare_data()
    dm.setup()

    data_loader = dm.val_dataloader()

    vae = VAE(input_height=32)
    ae = AE(input_height=32)

    assert len(VAE.pretrained_weights_available()) > 0
    assert len(AE.pretrained_weights_available()) > 0

    # Keep the exception object instead of a bare flag so a failure reports
    # the real cause rather than only a generic message.
    load_error = None

    try:
        vae = vae.from_pretrained('cifar10-resnet18')

        # test forward method on pre-trained weights; next(iter(...)) fails
        # loudly on an empty loader instead of silently skipping the check
        x, _ = next(iter(data_loader))
        vae(x)

        vae = vae.from_pretrained(
            'stl10-resnet18'
        )  # try loading weights not compatible with exact architecture

        ae = ae.from_pretrained('cifar10-resnet18')

        # test forward method on pre-trained weights
        x, _ = next(iter(data_loader))
        ae(x)

    except Exception as err:
        load_error = err

    assert load_error is None, f"error in loading weights: {load_error!r}"

    # Check each model on its own: when both calls share one ``try`` the first
    # KeyError ends the block and the second model is never exercised.
    for model, bad_name in ((vae, 'abc'), (ae, 'xyz')):
        keyerror = False

        try:
            model.from_pretrained(bad_name)
        except KeyError:
            keyerror = True

        assert keyerror is True, "KeyError not raised when provided with illegal checkpoint name"
Ejemplo n.º 3
0
def test_ae(tmpdir):
    """Smoke-test ``AE``: one ``fit`` and one ``test`` call with ``fast_dev_run`` enabled.

    Args:
        tmpdir: directory used both as the model's ``data_dir`` and as the
            trainer's ``default_root_dir``.
    """
    reset_seed()

    autoencoder = AE(data_dir=tmpdir, batch_size=2)
    runner = pl.Trainer(default_root_dir=tmpdir, fast_dev_run=True)

    runner.fit(autoencoder)
    runner.test(autoencoder)
Ejemplo n.º 4
0
    def __init__(self, pretrained_path):
        """Build a Faster R-CNN detector whose backbone is a frozen, pre-trained AE encoder.

        Args:
            pretrained_path: checkpoint path handed to ``AE.load_from_checkpoint``.
        """
        super().__init__()
        self.pretrained_path = pretrained_path
        self.output_dim = 800 * 800

        # --- Pre-trained model -------------------------------------------
        # Load the autoencoder checkpoint and freeze it; only its encoder is
        # kept as the detection backbone.
        pretrained_ae = AE.load_from_checkpoint(pretrained_path)
        pretrained_ae.freeze()

        encoder = pretrained_ae.encoder
        # NOTE(review): c3_only presumably limits the encoder to a single
        # feature map -- confirm against the encoder implementation.
        encoder.c3_only = True
        # The channel count is set explicitly on the backbone before it is
        # given to FasterRCNN.
        encoder.out_channels = 32
        self.backbone = encoder

        # --- Faster R-CNN ------------------------------------------------
        rpn_anchors = AnchorGenerator(
            sizes=((32, 64, 128, 256, 512), ),
            aspect_ratios=((0.5, 1.0, 2.0), ),
        )
        box_pooler = torchvision.ops.MultiScaleRoIAlign(
            featmap_names=['0'],
            output_size=7,
            sampling_ratio=2,
        )
        self.fast_rcnn = FasterRCNN(
            self.backbone,
            num_classes=9,
            rpn_anchor_generator=rpn_anchors,
            box_roi_pool=box_pooler,
        )

        # Flag recording that the encoder starts frozen, for unfreezing it later.
        self.frozen = True
Ejemplo n.º 5
0
def test_ae(tmpdir, datadir, dm_cls):
    """Smoke-test fitting ``AE`` on a datamodule with ``fast_dev_run`` enabled.

    Args:
        tmpdir: trainer ``default_root_dir``.
        datadir: ``data_dir`` passed to the datamodule.
        dm_cls: datamodule class to instantiate.
    """
    seed_everything()

    datamodule = dm_cls(data_dir=datadir, batch_size=4)
    # The last entry of the datamodule's size() is used as the input height.
    last_dim = datamodule.size()[-1]
    autoencoder = AE(input_height=last_dim)

    trainer_options = dict(
        fast_dev_run=True,
        default_root_dir=tmpdir,
        gpus=None,
    )
    runner = Trainer(**trainer_options)
    runner.fit(autoencoder, datamodule=datamodule)
Ejemplo n.º 6
0
def test_ae(tmpdir, datadir, dm_cls):
    """Fit ``AE`` on a datamodule in ``fast_dev_run`` mode and assert ``fit`` returns 1.

    Args:
        tmpdir: trainer ``default_root_dir``.
        datadir: ``data_dir`` passed to the datamodule.
        dm_cls: datamodule class to instantiate.
    """
    seed_everything()

    datamodule = dm_cls(data_dir=datadir, batch_size=4)
    # The last entry of the datamodule's size() is used as the input height.
    last_dim = datamodule.size()[-1]
    autoencoder = AE(input_height=last_dim)

    runner = pl.Trainer(
        fast_dev_run=True,
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=None,
    )

    outcome = runner.fit(autoencoder, datamodule)
    assert outcome == 1
Ejemplo n.º 7
0
def basic_ae_example():
    """Train a pl_bolts ``AE`` on CIFAR-10 with DDP on two GPUs, then load and freeze pretrained weights."""
    from pl_bolts.models.autoencoders import AE
    from pytorch_lightning.plugins import DDPPlugin

    def _cifar10(train):
        # Dataset root is the empty string; the data is downloaded if needed.
        return torchvision.datasets.CIFAR10(
            "",
            train=train,
            download=True,
            transform=torchvision.transforms.ToTensor(),
        )

    def _loader(dataset):
        # Alternative kept from the original example: additionally pass
        # persistent_workers=True to the DataLoader.
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=32,
                                           num_workers=12)

    # Data.
    train_dataset = _cifar10(train=True)
    val_dataset = _cifar10(train=False)

    train_loader = _loader(train_dataset)
    val_loader = _loader(val_dataset)

    # Model.
    model = AE(input_height=32)
    # Override any part of this AE to build your own variation:
    #
    #   class MyAEFlavor(AE):
    #       def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
    #           encoder = YourSuperFancyEncoder(...)
    #           return encoder
    #
    #   model = MyAEFlavor(...)

    # Fit.
    ddp_plugin = DDPPlugin(find_unused_parameters=False)
    trainer = pl.Trainer(gpus=2, accelerator="ddp", plugins=ddp_plugin)
    trainer.fit(model, train_loader, val_loader)

    # --------------------
    # CIFAR-10 pretrained model: list the available weight names, load one
    # of them, and freeze the resulting model.
    ae = AE(input_height=32)
    print(AE.pretrained_weights_available())
    ae = ae.from_pretrained("cifar10-resnet18")

    ae.freeze()
                    required=True,
                    help='root directory that contains captions')
parser.add_argument('--fasttext_model',
                    type=str,
                    required=True,
                    help='pretrained fastText model (binary file)')
parser.add_argument('--max_nwords',
                    type=int,
                    default=50,
                    help='maximum number of words (default: 50)')
parser.add_argument('--img_model',
                    type=str,
                    required=True,
                    help='pretrained autoencoder model')
args = parser.parse_args()

if __name__ == '__main__':
    # Directory created for the vectorized captions: "<caption_root>_vec".
    # NOTE(review): presumably convert_and_save3 writes into it -- confirm.
    vec_root = args.caption_root + '_vec'

    # Test the exact path that will be created. Deriving the parent with
    # str.replace() strips every occurrence of the basename (e.g.
    # "data/data"), and a bare relative name leads to os.listdir(''),
    # which raises FileNotFoundError.
    if not os.path.exists(vec_root):
        os.makedirs(vec_root)

        print('Loading a pretrained image model...')
        img_model = AE.load_from_checkpoint(checkpoint_path=args.img_model)
        # Chain the autoencoder's encoder and its fc layer, in eval mode.
        model = nn.Sequential(img_model.encoder, img_model.fc)
        model = model.eval()

        print('Loading a pretrained fastText model...')
        word_embedding = fastText.load_model(args.fasttext_model)

        print('Making vectorized caption data files...')
        ConvertCapVec().convert_and_save3(args.caption_root, word_embedding,
                                          args.max_nwords, model)
Ejemplo n.º 9
0
import argparse

import pytorch_lightning as pl

from pl_bolts.models.autoencoders import AE

from ae_datasets import iCLEVRDataModule
from pytorch_lightning.callbacks import ModelCheckpoint

# Build one CLI parser carrying the options of the model, the data module
# and the Lightning Trainer, in that order.
parser = argparse.ArgumentParser()
for _add_args in (
        AE.add_model_specific_args,
        iCLEVRDataModule.add_dataset_specific_args,
        pl.Trainer.add_argparse_args,
):
    parser = _add_args(parser)

opt = parser.parse_args()

# The full parsed namespace is passed as keyword arguments to both the
# model and the data module.
_cli_options = vars(opt)
model = AE(**_cli_options)
data = iCLEVRDataModule(**_cli_options)

# Checkpoints are written by a ModelCheckpoint with an epoch-numbered name.
_checkpointing = ModelCheckpoint(filepath="{epoch:02d}.ckpt")
trainer = pl.Trainer.from_argparse_args(
    opt,
    log_gpu_memory='all',
    checkpoint_callback=_checkpointing,
)
trainer.fit(model, data)