Example #1
    def test_collision(self):
        _ = torchbearer.state_key('test')
        key_1 = torchbearer.state_key('test')
        key_2 = torchbearer.state_key('test')

        self.assertTrue('test' != str(key_1))
        self.assertTrue('test' != str(key_2))
Example #2
    def test_duplicate_string(self):
        _ = torchbearer.state_key('test_dup')
        key_1 = torchbearer.state_key('test_dup')
        key_2 = torchbearer.state_key('test_dup')

        self.assertTrue('test_dup_1' == str(key_1))
        self.assertTrue('test_dup_2' == str(key_2))
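Examples #1 and #2 document the collision-handling contract: the first key registered under a name keeps that name, and later registrations of the same name are disambiguated with numeric suffixes ('_1', '_2', ...). A minimal sketch of that behaviour (the 'metric' name is hypothetical and assumes it has not been registered earlier in the session):

import torchbearer

first = torchbearer.state_key('metric')   # str(first) == 'metric'
second = torchbearer.state_key('metric')  # str(second) == 'metric_1'
third = torchbearer.state_key('metric')   # str(third) == 'metric_2'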
Example #3
    def test_contains(self):
        s = State()

        key1 = torchbearer.state_key('test_a')
        key2 = torchbearer.state_key('test_b')

        s[key1] = 1
        s[key2] = 2

        self.assertTrue(s.__contains__(key1))
Example #4
    def test_update(self):
        s = State()

        key1 = torchbearer.state_key('test_a')
        key2 = torchbearer.state_key('test_b')

        new_s = {key1: 1, key2: 2}
        s.update(new_s)

        self.assertTrue(s.__contains__(key1))
        self.assertTrue(s[key1] == 1)
Example #5
    def test_delete(self):
        s = State()

        key1 = torchbearer.state_key('test_a')
        key2 = torchbearer.state_key('test_b')

        s[key1] = 1
        s[key2] = 2

        self.assertTrue(s.__contains__(key1))
        s.__delitem__(key1)
        self.assertFalse(s.__contains__(key1))
Example #6
    def test_warn(self):
        s = State()

        key1 = torchbearer.state_key('test_a')
        key2 = torchbearer.state_key('test_b')

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')

            s[key1] = 'key_1'
            s[key2] = 'key_2'
            s['bad_key'] = 'bad_key'
            self.assertTrue(len(w) == 1)
            self.assertTrue(
                'State was accessed with a string' in str(w[-1].message))
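This test shows why registered StateKey objects are preferred over raw strings: writing to State with a plain string still works, but it emits a warning because string access bypasses the collision protection demonstrated in Examples #1 and #2. A minimal sketch of the warning-free pattern (the 'my_data' name is hypothetical):

import torchbearer
from torchbearer.state import State

MY_DATA = torchbearer.state_key('my_data')

s = State()
s[MY_DATA] = 42       # StateKey access: no warning
value = s[MY_DATA]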
Example #7
    def test_key_added(self):
        key = torchbearer.state_key('key')

        self.assertTrue(key in torchbearer.STATE_KEYS)
Example #8
0
    def test_compare_to_string(self):
        key_1 = torchbearer.state_key('test_compare')
        self.assertEqual(key_1, 'test_compare')
Example #9
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchbearer import state_key
from torchbearer import callbacks

T = state_key('t')
T_SHUFFLED = state_key('t_shuffled')
MI = state_key('mi')


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class DoNothing(nn.Module):
    def forward(self, x):
        return x


def resample(x):
    return F.fold(F.unfold(x, kernel_size=2, stride=2),
                  (int(x.size(2) / 2), int(x.size(3) / 2)), 1)

Example #10
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from memory import Memory

import torchbearer
from torchbearer import Trial, callbacks
from torchbearer.cv_utils import DatasetValidationSplitter

import visualise

from torch.distributions import RelaxedBernoulli

MU = torchbearer.state_key('mu')
LOGVAR = torchbearer.state_key('logvar')
STAGES = torchbearer.state_key('stages')
MASKED_TARGET = torchbearer.state_key('masked')


class Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(Block, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        torch.nn.init.kaiming_uniform_(self.conv.weight)

    def forward(self, x):
        return self.conv(x)

Example #11
# Define constants
epochs = 200
batch_size = 64
lr = 0.0002
nworkers = 8
latent_dim = 100
sample_interval = 400
img_shape = (1, 28, 28)
adversarial_loss = torch.nn.BCELoss()
device = 'cuda'
valid = torch.ones(batch_size, 1, device=device)
fake = torch.zeros(batch_size, 1, device=device)

# Register state keys (optional)
GEN_IMGS = state_key('gen_imgs')
DISC_GEN = state_key('disc_gen')
DISC_GEN_DET = state_key('disc_gen_det')
DISC_REAL = state_key('disc_real')
G_LOSS = state_key('g_loss')
D_LOSS = state_key('d_loss')


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
Example #12
    from torch import distributions

    from .vae import VAE, LATENT
    from implementations.torchbearer_implementation import FMix

    import argparse

    parser = argparse.ArgumentParser(description='VAE Training')
    parser.add_argument('--mode', default='base', type=str, help='name of run')
    parser.add_argument('--i', default=1, type=int, help='iteration')
    parser.add_argument('--var', default=1, type=float, help='variance')
    parser.add_argument('--epochs', default=100, type=int, help='epochs')
    parser.add_argument('--dir', default='vaes', type=str, help='directory')
    args = parser.parse_args()

    KL = torchbearer.state_key('KL')
    NLL = torchbearer.state_key('NLL')
    SAMPLE = torchbearer.state_key('SAMPLE')

    # Data
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    inv_normalize = transforms.Normalize(
        (-0.4914 / 0.2023, -0.4822 / 0.1994, -0.4465 / 0.2010),
        (1 / 0.2023, 1 / 0.1994, 1 / 0.2010))
    transform_base = [transforms.ToTensor(), normalize]

    transform = [
        transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
        transforms.RandomHorizontalFlip()
    ] + transform_base
Example #13
    with open(str(directory) + '/' + name, "w") as f:
        f.write(' '.join(sys.argv))


def save_model_info(model, directory, name='model-info.txt'):
    with open(str(directory) + '/' + name, "w") as f:
        f.write(str(model))


# ArgumentParser subclass that doesn't print errors or exit
class FakeArgumentParser(argparse.ArgumentParser):
    def error(self, message):
        pass


ORIGINAL_Y_TRUE = torchbearer.state_key('original_y_true')


# Loader that replaces the target with the input image (for autoencoding), keeping the labels under a separate key
def autoenc_loader(state):
    image, label = torchbearer.deep_to(next(state[torchbearer.ITERATOR]), state[torchbearer.DEVICE],
                                       state[torchbearer.DATA_TYPE])
    state[torchbearer.X] = image
    state[torchbearer.Y_TRUE] = image
    state[ORIGINAL_Y_TRUE] = label


def _parse_schedule(sched):
    if '@' in sched:
        factor, schtype = sched.split('@')
        factor = float(factor)
Example #14
    def test_key_added(self):
        key = torchbearer.state_key('key')

        self.assertTrue('key' in torchbearer.state.__keys__)
Example #15
    def test_key_metric(self):
        key = torchbearer.state_key('test')
        state = {key: 4}

        self.assertDictEqual(key.process(state), {str(key): 4})
        self.assertDictEqual(key.process_final(state), {str(key): 4})
Example #16
    def test_key_repr(self):
        key = torchbearer.state_key('repr_test')
        self.assertEqual(str(key), 'repr_test')
        self.assertEqual(repr(key), 'repr_test')
Example #17
    def test_key_call(self):
        key = torchbearer.state_key('call_test')
        state = {key: 'test'}

        self.assertEqual(key(state), 'test')
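As the test shows, a StateKey is callable as a convenience: key(state) returns state[key]. A minimal usage sketch (the 'loss_scale' name is hypothetical):

import torchbearer

LOSS_SCALE = torchbearer.state_key('loss_scale')
state = {LOSS_SCALE: 0.5}

assert LOSS_SCALE(state) == 0.5   # equivalent to state[LOSS_SCALE]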
Example #18
    def test_duplicate(self):
        key = torchbearer.state_key(torchbearer.MODEL)

        self.assertTrue(torchbearer.MODEL != key)
Example #19
trainset = AutoEncoderMNIST(basetrainset)

valset = AutoEncoderMNIST(basevalset)

traingen = torch.utils.data.DataLoader(trainset,
                                       batch_size=BATCH_SIZE,
                                       shuffle=True,
                                       num_workers=8)

valgen = torch.utils.data.DataLoader(valset,
                                     batch_size=BATCH_SIZE,
                                     shuffle=True,
                                     num_workers=8)

# State keys
MU, LOGVAR = torchbearer.state_key('mu'), torchbearer.state_key('logvar')


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
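The snippet truncates after encode, but the MU and LOGVAR keys registered above indicate how the rest of the model would interact with torchbearer state: forward writes the encoder outputs into state so that a KL-divergence callback or metric can read them later. A minimal sketch of that continuation, assuming the standard VAE reparameterisation (not part of the original snippet):

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)    # log-variance -> standard deviation
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, state):
        mu, logvar = self.encode(x.view(-1, 784))
        state[MU], state[LOGVAR] = mu, logvar  # expose for loss callbacks/metrics
        z = self.reparameterize(mu, logvar)
        return torch.sigmoid(self.fc4(F.relu(self.fc3(z))))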
Example #20
# Define constants
epochs = 200
batch_size = 64
lr = 0.0002
nworkers = 8
latent_dim = 100
sample_interval = 400
img_shape = (1, 28, 28)
adversarial_loss = torch.nn.BCELoss()
device = 'cuda'
valid = torch.ones(batch_size, 1, device=device)
fake = torch.zeros(batch_size, 1, device=device)
batch = torch.randn(25, latent_dim).to(device)

# Register state keys (optional)
GEN_IMGS = state_key('gen_imgs')
DISC_GEN = state_key('disc_gen')
DISC_GEN_DET = state_key('disc_gen_det')
DISC_REAL = state_key('disc_real')
G_LOSS = state_key('g_loss')
D_LOSS = state_key('d_loss')

DISC_OPT = state_key('disc_opt')
GEN_OPT = state_key('gen_opt')
DISC_MODEL = state_key('disc_model')
DISC_IMGS = state_key('disc_imgs')
DISC_CRIT = state_key('disc_crit')


class Generator(nn.Module):
    def __init__(self):
Example #21
import torch
import torch.nn.functional as F
import torch.nn as nn

import torchbearer
from torchbearer import Trial
import torchbearer.callbacks as callbacks
import torchbearer.callbacks.imaging as imaging
from scattered_cifar import ScatteredCIFAR10
import torchvision.transforms as transforms

TRANSFORMED = torchbearer.state_key('transformed')


class FoveatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, input_size=96, pool=False):
        super().__init__()
        out_channels = int(out_channels / 4)

        if pool:
            self.conv1 = nn.Sequential(
                nn.AvgPool2d(2),
                nn.Conv2d(in_channels,
                          out_channels,
                          kernel_size=3,
                          stride=1,
                          padding=1))
        else:
            self.conv1 = nn.Conv2d(in_channels,
                                   out_channels,
                                   kernel_size=3,
Example #22
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import FashionMNIST
import torchbearer
import torchbearer.callbacks as callbacks
from torchbearer import Trial, state_key
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
MU = state_key('mu')
LOGVAR = state_key('logvar')


class VAE(nn.Module):
    def __init__(self, latent_size):
        super(VAE, self).__init__()
        self.latent_size = latent_size

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, 1, 2),   # B,  32, 28, 28
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),  # B,  32, 14, 14
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),  # B,  64,  7, 7
        )
Example #23
import unittest
from mock import Mock

import torch

import torchbearer
from torchbearer.variational import DivergenceBase, SimpleNormalUnitNormalKL, SimpleNormalSimpleNormalKL, SimpleNormal, SimpleWeibull, SimpleWeibullSimpleWeibullKL

key = torchbearer.state_key('divergence_test')


class TestDivergenceBase(unittest.TestCase):
    def test_on_criterion(self):
        divergence = DivergenceBase({'test': key}).with_sum_sum_reduction()
        divergence.compute = Mock(
            return_value=torch.ones((2, 2), requires_grad=True))

        state = {torchbearer.LOSS: torch.zeros(1, requires_grad=True), key: 1}

        divergence.on_criterion(state)
        self.assertTrue(state[torchbearer.LOSS].item() == 4)
        self.assertTrue(state[torchbearer.LOSS].requires_grad)
        divergence.compute.assert_called_once_with(test=1)

    def test_on_criterion_validation(self):
        divergence = DivergenceBase({'test': key}).with_sum_sum_reduction()
        divergence.compute = Mock(
            return_value=torch.ones((2, 2), requires_grad=True))

        state = {torchbearer.LOSS: torch.zeros(1, requires_grad=True), key: 1}
Example #24
import warnings

import torch
from torchbearer import Metric
import torchbearer
from torchbearer.metrics import CategoricalAccuracy, mean, running_mean

from dsketch.experiments.shared import utils
from dsketch.losses import chamfer
from dsketch.utils.mod_haussdorff import binary_image_to_points, mod_hausdorff_distance

HARDRASTER = torchbearer.state_key('hardraster')
SQ_DISTANCE_TRANSFORM = torchbearer.state_key('squared_distance_transform')
Y_PRED_CLASSES = torchbearer.state_key('y_pred_classes')


@running_mean
@mean
class ClassificationMetric(Metric):
    def __init__(self, classification_model):
        super().__init__("recon_class_acc")
        self.classification_model = classification_model

    def process(self, state):
        y_pred = self.classification_model(state[
            torchbearer.Y_PRED])  # take the reconstruction and classify it
        y_true = state[utils.ORIGINAL_Y_TRUE]

        if len(y_true.shape) == 2:
            _, y_true = torch.max(y_true, 1)
Example #25
import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchbearer
from torchbearer import cite

IMAGE = torchbearer.state_key('image')
""" State key under which to hold the image being ascended on """

_stanley2007compositional = """
@article{stanley2007compositional,
  title={Compositional pattern producing networks: A novel abstraction of development},
  author={Stanley, Kenneth O},
  journal={Genetic programming and evolvable machines},
  volume={8},
  number={2},
  pages={131--162},
  year={2007},
  publisher={Springer}
}
"""


def _correlate_color(image, correlation, max_norm):
    if image.size(0) == 4:
        alpha = image[-1].unsqueeze(0)
Example #26
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchbearer
import torchvision
from torchbearer import Trial, callbacks
from torchvision import transforms

import tb_modules as tm

MU = torchbearer.state_key('mu')
LOGVAR = torchbearer.state_key('logvar')


class Block(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0):
        super(Block, self).__init__()
        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size,
                              stride=stride,
                              padding=padding)
        torch.nn.init.kaiming_uniform_(self.conv.weight)
Example #27
    return torch.mean(torch.clamp(1 - y_pred.t() * y_true, min=0))


X, Y = make_blobs(n_samples=1024, centers=2, cluster_std=1.2, random_state=1)
X = (X - X.mean()) / X.std()
Y[np.where(Y == 0)] = -1
X, Y = torch.FloatTensor(X), torch.FloatTensor(Y)


delta = 0.01
x = np.arange(X[:, 0].min(), X[:, 0].max(), delta)
y = np.arange(X[:, 1].min(), X[:, 1].max(), delta)
x, y = np.meshgrid(x, y)
xy = list(map(np.ravel, [x, y]))

CONTOUR = torchbearer.state_key('contour')

def mypause(interval):
    backend = plt.rcParams['backend']
    if backend in matplotlib.rcsetup.interactive_bk:
        figManager = matplotlib._pylab_helpers.Gcf.get_active()
        if figManager is not None:
            canvas = figManager.canvas
            if canvas.figure.stale:
                canvas.draw_idle()
            canvas.start_event_loop(interval)
            return


@callbacks.on_start
def scatter(_):
Example #28
import torch
import torch.nn.functional as F

from visual.images import IMAGE
import torchbearer

LAYER_DICT = torchbearer.state_key('layer_dict')
""" StateKey under which to store a dictionary of layer outputs for a model. Keys in this dictionary can be accessed as 
strings in the `target` arguments of vision classes. 
"""


def _evaluate_target(state, target, channels=lambda x: x[:]):
    if isinstance(target, torchbearer.StateKey):
        return channels(state[target])
    else:
        return channels(state[LAYER_DICT][target])


class Criterion(object):
    """
    Abstract criterion object for visual gradient ascent.
    """
    def process(self, state):
        """ Calculates the criterion value

        Args:
            state: Torchbearer state
        """
        raise NotImplementedError
Example #29
import math

import torch
import torch.nn as nn
from torch import distributions
from torch.distributions import constraints, register_kl

import torchbearer
from torchbearer import state_key

LATENT = state_key('latent')


class LogitNormal(distributions.Normal):
    arg_constraints = {'loc': constraints.real, 'log_scale': constraints.real}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, log_scale, validate_args=None):
        self.log_scale = log_scale
        scale = distributions.transform_to(
            distributions.Normal.arg_constraints['scale'])(log_scale)
        super().__init__(loc, scale, validate_args=validate_args)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # compute the variance
        var = (self.scale**2)
        log_scale = self.log_scale
        return -((value - self.loc)**2) / (2 * var) - log_scale - math.log(
Example #30
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchbearer
import torchvision
from torchbearer import Trial, callbacks
from torchvision import transforms

import visualise
from memory import Memory
import tb_modules as tm

MU = torchbearer.state_key('mu')
LOGVAR = torchbearer.state_key('logvar')
STAGES = torchbearer.state_key('stages')


class Block(nn.Module):
    def __init__(self, in_planes, out_planes, stride=1, padding=0):
        super(Block, self).__init__()
        self.conv = nn.Conv2d(in_planes,
                              out_planes,
                              kernel_size=3,
                              padding=padding,
                              stride=stride,
                              bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        torch.nn.init.xavier_uniform_(self.conv.weight)
Example #31
    def test_compare_to_statekey(self):
        key_1 = torchbearer.state_key('test_compare_sk')
        key_2 = torchbearer.state_key('test_compare_sk_2')
        # Simulates the same key in different sessions, where the object hash changes
        key_2.key = 'test_compare_sk'
        self.assertEqual(key_1, key_2)
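Taken together with Example #8, this shows that StateKey equality is based on the underlying name string rather than object identity, which is what keeps keys stable across sessions, for example when a saved state is reloaded in a new process.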
Example #32
import torch
from torch.nn import Module

import torchbearer as tb

ESTIMATE = tb.state_key('est')


class Net(Module):
    def __init__(self, x):
        super().__init__()
        self.pars = torch.nn.Parameter(x)

    def f(self):
        """
        function to be minimised:
        f(x) = (x[0]-5)^2 + x[1]^2 + (x[2]-1)^2
        Solution:
        x = [5,0,1]
        """
        out = torch.zeros_like(self.pars)
        out[0] = self.pars[0] - 5
        out[1] = self.pars[1]
        out[2] = self.pars[2] - 1
        return torch.sum(out**2)

    def forward(self, _, state):
        state[ESTIMATE] = self.pars.detach().unsqueeze(1)
        return self.f()
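Because forward ignores its input and returns f(x) directly, the ESTIMATE key is the only place the current parameters are exposed during optimisation. A minimal sketch of reading it back with a callback (hedged: the decorator is standard torchbearer, but the wiring around it depends on how the Trial is set up):

import torchbearer.callbacks as callbacks

@callbacks.on_step_training
def print_estimate(state):
    # ESTIMATE was written by Net.forward during this step
    print(state[ESTIMATE].squeeze(1).tolist())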

Example #33
# Define constants
epochs = 200
batch_size = 64
lr = 0.0002
nworkers = 8
latent_dim = 100
sample_interval = 400
img_shape = (1, 28, 28)
adversarial_loss = torch.nn.BCELoss()
device = 'cuda'
valid = torch.ones(batch_size, 1, device=device)
fake = torch.zeros(batch_size, 1, device=device)

# Register state keys (optional)
GEN_IMGS = state_key('gen_imgs')
DISC_GEN = state_key('disc_gen')
DISC_GEN_DET = state_key('disc_gen_det')
DISC_REAL = state_key('disc_real')
G_LOSS = state_key('g_loss')
D_LOSS = state_key('d_loss')


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))