def test_collision(self):
    """Re-registering a taken key name must produce keys whose names differ from it."""
    # Claim 'test' first so both of the following registrations collide with it.
    _ = torchbearer.state_key('test')
    key_1 = torchbearer.state_key('test')
    key_2 = torchbearer.state_key('test')
    # assertNotEqual shows both operands on failure, unlike assertTrue(a != b).
    self.assertNotEqual('test', str(key_1))
    self.assertNotEqual('test', str(key_2))
def test_duplicate_string(self):
    """Colliding key names get numeric suffixes (_1, _2, ...) in registration order."""
    # The first registration takes the bare name; later ones are suffixed.
    _ = torchbearer.state_key('test_dup')
    key_1 = torchbearer.state_key('test_dup')
    key_2 = torchbearer.state_key('test_dup')
    self.assertEqual('test_dup_1', str(key_1))
    self.assertEqual('test_dup_2', str(key_2))
def test_contains(self):
    """Every StateKey written into a State is reported as contained in it."""
    s = State()
    key1 = torchbearer.state_key('test_a')
    key2 = torchbearer.state_key('test_b')
    s[key1] = 1
    s[key2] = 2
    # `in` dispatches to State.__contains__; assertIn gives a readable failure.
    self.assertIn(key1, s)
    # key2 was also written, so check it as well rather than leaving it unverified.
    self.assertIn(key2, s)
def test_update(self):
    """State.update copies every key/value pair from the given mapping."""
    s = State()
    key1 = torchbearer.state_key('test_a')
    key2 = torchbearer.state_key('test_b')
    new_s = {key1: 1, key2: 2}
    s.update(new_s)
    self.assertIn(key1, s)
    self.assertEqual(s[key1], 1)
    # Both entries of the mapping should have been applied, not just the first.
    self.assertIn(key2, s)
    self.assertEqual(s[key2], 2)
def test_delete(self):
    """Deleting a key removes exactly that key from the State."""
    s = State()
    key1 = torchbearer.state_key('test_a')
    key2 = torchbearer.state_key('test_b')
    s[key1] = 1
    s[key2] = 2
    self.assertIn(key1, s)
    # `del` dispatches to State.__delitem__.
    del s[key1]
    self.assertNotIn(key1, s)
    # The deletion must be targeted: the other key survives.
    self.assertIn(key2, s)
def test_warn(self):
    """Writing a plain-string key warns once; StateKey writes do not warn."""
    s = State()
    key1 = torchbearer.state_key('test_a')
    key2 = torchbearer.state_key('test_b')
    with warnings.catch_warnings(record=True) as w:
        # 'always' ensures the warning is recorded even if it fired earlier in the run.
        warnings.simplefilter('always')
        s[key1] = 'key_1'
        s[key2] = 'key_2'
        s['bad_key'] = 'bad_key'
    # Only the string-keyed write should have produced a warning.
    self.assertEqual(len(w), 1)
    self.assertIn('State was accessed with a string', str(w[-1].message))
def test_key_added(self):
    """state_key registers the returned key in torchbearer.STATE_KEYS."""
    key = torchbearer.state_key('key')
    self.assertIn(key, torchbearer.STATE_KEYS)
def test_compare_to_string(self):
    """A StateKey compares equal to the plain string it was created from."""
    name = 'test_compare'
    self.assertEqual(torchbearer.state_key(name), name)
import os

# Set before torch is imported below (presumably so CUDA initialisation only
# sees device 1).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchbearer import state_key
from torchbearer import callbacks

T = state_key('t')
T_SHUFFLED = state_key('t_shuffled')
MI = state_key('mi')


class Flatten(nn.Module):
    """Collapse every dimension after the first into one: (N, ...) -> (N, -1)."""

    def forward(self, x):
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. after
        # a permute/transpose; it still returns a view whenever view would.
        return x.reshape(x.size(0), -1)


class DoNothing(nn.Module):
    """Identity module: forward returns its input unchanged."""

    def forward(self, x):
        return x


def resample(x):
    """Rearrange each non-overlapping 2x2 spatial patch of ``x`` into channels.

    ``x`` is indexed on dims 2 and 3, so it is assumed to be 4-D (N, C, H, W).
    ``F.unfold`` with kernel 2 / stride 2 extracts the patches, and ``F.fold``
    with kernel size 1 lays them back out on an (H // 2, W // 2) grid.
    """
    # Integer division matches the number of patches unfold produces per axis.
    out_h = x.size(2) // 2
    out_w = x.size(3) // 2
    patches = F.unfold(x, kernel_size=2, stride=2)
    return F.fold(patches, (out_h, out_w), 1)
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from memory import Memory
import torchbearer
from torchbearer import Trial, callbacks
from torchbearer.cv_utils import DatasetValidationSplitter
import visualise
from torch.distributions import RelaxedBernoulli

# Torchbearer state keys registered by this module.
MU = torchbearer.state_key('mu')
LOGVAR = torchbearer.state_key('logvar')
STAGES = torchbearer.state_key('stages')
MASKED_TARGET = torchbearer.state_key('masked')


class Block(nn.Module):
    """A single Conv2d layer whose weight is Kaiming-uniform initialised."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding)
        # NOTE(review): `torch` is not imported in these lines; presumably it is
        # imported earlier in the file - confirm.
        torch.nn.init.kaiming_uniform_(conv.weight)
        self.conv = conv

    def forward(self, x):
        """Apply the convolution to ``x``."""
        return self.conv(x)
# Define constants epochs = 200 batch_size = 64 lr = 0.0002 nworkers = 8 latent_dim = 100 sample_interval = 400 img_shape = (1, 28, 28) adversarial_loss = torch.nn.BCELoss() device = 'cuda' valid = torch.ones(batch_size, 1, device=device) fake = torch.zeros(batch_size, 1, device=device) # Register state keys (optional) GEN_IMGS = state_key('gen_imgs') DISC_GEN = state_key('disc_gen') DISC_GEN_DET = state_key('disc_gen_det') DISC_REAL = state_key('disc_real') G_LOSS = state_key('g_loss') D_LOSS = state_key('d_loss') class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() def block(in_feat, out_feat, normalize=True): layers = [nn.Linear(in_feat, out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat, 0.8))
from torch import distributions
from .vae import VAE, LATENT
from implementations.torchbearer_implementation import FMix
import argparse

# Command-line configuration for the VAE run.
parser = argparse.ArgumentParser(description='VAE Training')
parser.add_argument('--mode', default='base', type=str, help='name of run')
parser.add_argument('--i', default=1, type=int, help='iteration')
# NOTE(review): '--var' reuses the help text 'iteration'; likely a copy-paste slip.
parser.add_argument('--var', default=1, type=float, help='iteration')
parser.add_argument('--epochs', default=100, type=int, help='epochs')
parser.add_argument('--dir', default='vaes', type=str, help='directory')
args = parser.parse_args()

# Torchbearer state keys.
# NOTE(review): `torchbearer` and `transforms` are not imported in these lines;
# presumably imported earlier in the file - confirm.
KL = torchbearer.state_key('KL')
NLL = torchbearer.state_key('NLL')
SAMPLE = torchbearer.state_key('SAMPLE')

# Data
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.2023, 0.1994, 0.2010)
normalize = transforms.Normalize(_norm_mean, _norm_std)
# Inverse of `normalize`: (y - (-m/s)) / (1/s) == y * s + m per channel.
inv_normalize = transforms.Normalize(
    tuple(-m / s for m, s in zip(_norm_mean, _norm_std)),
    tuple(1 / s for s in _norm_std))

transform_base = [transforms.ToTensor(), normalize]
# Light augmentation applied before the base conversion/normalisation.
transform = [
    transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
    transforms.RandomHorizontalFlip(),
] + transform_base
with open(str(directory) + '/' + name, "w") as f: f.write(' '.join(sys.argv)) def save_model_info(model, directory, name='model-info.txt'): with open(str(directory) + '/' + name, "w") as f: f.write(str(model)) # argparse that doesn't show errors and exit class FakeArgumentParser(argparse.ArgumentParser): def error(self, message): pass ORIGINAL_Y_TRUE = torchbearer.state_key('original_y_true') # Loader that transforms the target into the input, but maintains the labels in a separate key def autoenc_loader(state): image, label = torchbearer.deep_to(next(state[torchbearer.ITERATOR]), state[torchbearer.DEVICE], state[torchbearer.DATA_TYPE]) state[torchbearer.X] = image state[torchbearer.Y_TRUE] = image state[ORIGINAL_Y_TRUE] = label def _parse_schedule(sched): if '@' in sched: factor, schtype = sched.split('@') factor = float(factor)
def test_key_added(self):
    """state_key records the key name in torchbearer.state.__keys__."""
    key = torchbearer.state_key('key')
    self.assertIn('key', torchbearer.state.__keys__)
def test_key_metric(self):
    """A StateKey acts as a metric: process/process_final report its value under its name."""
    key = torchbearer.state_key('test')
    state = {key: 4}
    expected = {str(key): 4}
    for processor in (key.process, key.process_final):
        self.assertDictEqual(processor(state), expected)
def test_key_repr(self):
    """Both str() and repr() of a StateKey render as its bare name."""
    key = torchbearer.state_key('repr_test')
    for render in (str, repr):
        self.assertEqual(render(key), 'repr_test')
def test_key_call(self):
    """Calling a StateKey with a state mapping returns the value stored under it."""
    key = torchbearer.state_key('call_test')
    self.assertEqual(key({key: 'test'}), 'test')
def test_duplicate(self):
    """Registering a key with a built-in key's name yields a distinct key."""
    key = torchbearer.state_key(torchbearer.MODEL)
    # assertNotEqual evaluates `torchbearer.MODEL != key`, as the original did,
    # but reports both values on failure.
    self.assertNotEqual(torchbearer.MODEL, key)
trainset = AutoEncoderMNIST(basetrainset) valset = AutoEncoderMNIST(basevalset) traingen = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8) valgen = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8) # State keys MU, LOGVAR = torchbearer.state_key('mu'), torchbearer.state_key('logvar') class VAE(nn.Module): def __init__(self): super(VAE, self).__init__() self.fc1 = nn.Linear(784, 400) self.fc21 = nn.Linear(400, 20) self.fc22 = nn.Linear(400, 20) self.fc3 = nn.Linear(20, 400) self.fc4 = nn.Linear(400, 784) def encode(self, x): h1 = F.relu(self.fc1(x)) return self.fc21(h1), self.fc22(h1)
# Define constants epochs = 200 batch_size = 64 lr = 0.0002 nworkers = 8 latent_dim = 100 sample_interval = 400 img_shape = (1, 28, 28) adversarial_loss = torch.nn.BCELoss() device = 'cuda' valid = torch.ones(batch_size, 1, device=device) fake = torch.zeros(batch_size, 1, device=device) batch = torch.randn(25, latent_dim).to(device) # Register state keys (optional) GEN_IMGS = state_key('gen_imgs') DISC_GEN = state_key('disc_gen') DISC_GEN_DET = state_key('disc_gen_det') DISC_REAL = state_key('disc_real') G_LOSS = state_key('g_loss') D_LOSS = state_key('d_loss') DISC_OPT = state_key('disc_opt') GEN_OPT = state_key('gen_opt') DISC_MODEL = state_key('disc_model') DISC_IMGS = state_key('disc_imgs') DISC_CRIT = state_key('disc_crit') class Generator(nn.Module): def __init__(self):
import torch import torch.nn.functional as F import torch.nn as nn import torchbearer from torchbearer import Trial import torchbearer.callbacks as callbacks import torchbearer.callbacks.imaging as imaging from scattered_cifar import ScatteredCIFAR10 import torchvision.transforms as transforms TRANSFORMED = torchbearer.state_key('transformed') class FoveatedConv2d(nn.Module): def __init__(self, in_channels, out_channels, input_size=96, pool=False): super().__init__() out_channels = int(out_channels / 4) if pool: self.conv1 = nn.Sequential( nn.AvgPool2d(2), nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)) else: self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
import numpy as np import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torchvision from torchvision import transforms from torchvision.utils import make_grid from torchvision.datasets import FashionMNIST import torchbearer import torchbearer.callbacks as callbacks from torchbearer import Trial, state_key import os os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' MU = state_key('mu') LOGVAR = state_key('logvar') class VAE(nn.Module): def __init__(self, latent_size): super(VAE, self).__init__() self.latent_size = latent_size self.encoder = nn.Sequential( nn.Conv2d(1, 32, 4, 1, 2), # B, 32, 28, 28 nn.ReLU(True), nn.Conv2d(32, 32, 4, 2, 1), # B, 32, 14, 14 nn.ReLU(True), nn.Conv2d(32, 64, 4, 2, 1), # B, 64, 7, 7 )
import unittest from mock import Mock import torch import torchbearer from torchbearer.variational import DivergenceBase, SimpleNormalUnitNormalKL, SimpleNormalSimpleNormalKL, SimpleNormal, SimpleWeibull, SimpleWeibullSimpleWeibullKL key = torchbearer.state_key('divergence_test') class TestDivergenceBase(unittest.TestCase): def test_on_criterion(self): divergence = DivergenceBase({'test': key}).with_sum_sum_reduction() divergence.compute = Mock( return_value=torch.ones((2, 2), requires_grad=True)) state = {torchbearer.LOSS: torch.zeros(1, requires_grad=True), key: 1} divergence.on_criterion(state) self.assertTrue(state[torchbearer.LOSS].item() == 4) self.assertTrue(state[torchbearer.LOSS].requires_grad) divergence.compute.assert_called_once_with(test=1) def test_on_criterion_validation(self): divergence = DivergenceBase({'test': key}).with_sum_sum_reduction() divergence.compute = Mock( return_value=torch.ones((2, 2), requires_grad=True)) state = {torchbearer.LOSS: torch.zeros(1, requires_grad=True), key: 1}
import warnings import torch from torchbearer import Metric import torchbearer from torchbearer.metrics import CategoricalAccuracy, mean, running_mean from dsketch.experiments.shared import utils from dsketch.losses import chamfer from dsketch.utils.mod_haussdorff import binary_image_to_points, mod_hausdorff_distance HARDRASTER = torchbearer.state_key('hardraster') SQ_DISTANCE_TRANSFORM = torchbearer.state_key('squared_distance_transform') Y_PRED_CLASSES = torchbearer.state_key('y_pred_classes') @running_mean @mean class ClassificationMetric(Metric): def __init__(self, classification_model): super().__init__("recon_class_acc") self.classification_model = classification_model def process(self, state): y_pred = self.classification_model(state[ torchbearer.Y_PRED]) # take the reconstruction and classify it y_true = state[utils.ORIGINAL_Y_TRUE] if len(y_true.shape) == 2: _, y_true = torch.max(y_true, 1)
import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torchbearer from torchbearer import cite IMAGE = torchbearer.state_key('image') """ State key under which to hold the image being ascended on """ _stanley2007compositional = """ @article{stanley2007compositional, title={Compositional pattern producing networks: A novel abstraction of development}, author={Stanley, Kenneth O}, journal={Genetic programming and evolvable machines}, volume={8}, number={2}, pages={131--162}, year={2007}, publisher={Springer} } """ def _correlate_color(image, correlation, max_norm): if image.size(0) == 4: alpha = image[-1].unsqueeze(0)
import os import torch import torch.nn as nn import torch.optim as optim import torchbearer import torchvision from torchbearer import Trial, callbacks from torchvision import transforms import tb_modules as tm MU = torchbearer.state_key('mu') LOGVAR = torchbearer.state_key('logvar') class Block(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): super(Block, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding) torch.nn.init.kaiming_uniform_(self.conv.weight)
return torch.mean(torch.clamp(1 - y_pred.t() * y_true, min=0)) X, Y = make_blobs(n_samples=1024, centers=2, cluster_std=1.2, random_state=1) X = (X - X.mean()) / X.std() Y[np.where(Y == 0)] = -1 X, Y = torch.FloatTensor(X), torch.FloatTensor(Y) delta = 0.01 x = np.arange(X[:, 0].min(), X[:, 0].max(), delta) y = np.arange(X[:, 1].min(), X[:, 1].max(), delta) x, y = np.meshgrid(x, y) xy = list(map(np.ravel, [x, y])) CONTOUR = torchbearer.state_key('contour') def mypause(interval): backend = plt.rcParams['backend'] if backend in matplotlib.rcsetup.interactive_bk: figManager = matplotlib._pylab_helpers.Gcf.get_active() if figManager is not None: canvas = figManager.canvas if canvas.figure.stale: canvas.draw_idle() canvas.start_event_loop(interval) return @callbacks.on_start def scatter(_):
import torch
import torch.nn.functional as F

from visual.images import IMAGE
import torchbearer

LAYER_DICT = torchbearer.state_key('layer_dict')
"""
StateKey under which to store a dictionary of layer outputs for a model.
Keys in this dictionary can be accessed as strings in the `target` arguments
of vision classes.
"""


def _evaluate_target(state, target, channels=lambda x: x[:]):
    """Look up ``target`` and pass the result through ``channels``.

    A :class:`torchbearer.StateKey` target is read directly from ``state``.
    Any other target is used as a key into the layer-output dictionary stored
    in ``state`` under ``LAYER_DICT``.
    """
    if not isinstance(target, torchbearer.StateKey):
        return channels(state[LAYER_DICT][target])
    return channels(state[target])


class Criterion(object):
    """Abstract criterion object for visual gradient ascent."""

    def process(self, state):
        """Calculate the criterion value.

        Args:
            state: Torchbearer state

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
import math import torch import torch.nn as nn from torch import distributions from torch.distributions import constraints, register_kl import torchbearer from torchbearer import state_key LATENT = state_key('latent') class LogitNormal(distributions.Normal): arg_constraints = {'loc': constraints.real, 'log_scale': constraints.real} support = constraints.real has_rsample = True def __init__(self, loc, log_scale, validate_args=None): self.log_scale = log_scale scale = distributions.transform_to( distributions.Normal.arg_constraints['scale'])(log_scale) super().__init__(loc, scale, validate_args=validate_args) def log_prob(self, value): if self._validate_args: self._validate_sample(value) # compute the variance var = (self.scale**2) log_scale = self.log_scale return -((value - self.loc)**2) / (2 * var) - log_scale - math.log(
import os import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torchbearer import torchvision from torchbearer import Trial, callbacks from torchvision import transforms import visualise from memory import Memory import tb_modules as tm MU = torchbearer.state_key('mu') LOGVAR = torchbearer.state_key('logvar') STAGES = torchbearer.state_key('stages') class Block(nn.Module): def __init__(self, in_planes, out_planes, stride=1, padding=0): super(Block, self).__init__() self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=padding, stride=stride, bias=False) self.bn = nn.BatchNorm2d(out_planes) torch.nn.init.xavier_uniform_(self.conv.weight)
def test_compare_to_statekey(self):
    """Two StateKey objects with the same key string compare equal."""
    first = torchbearer.state_key('test_compare_sk')
    second = torchbearer.state_key('test_compare_sk_2')
    # Give the second (distinct) object the first one's key string. This mimics
    # the same key recreated in another session, where the object hash differs.
    second.key = 'test_compare_sk'
    self.assertEqual(first, second)
import torch
from torch.nn import Module

import torchbearer as tb

# State key under which forward() stores a detached copy of the parameters.
ESTIMATE = tb.state_key('est')


class Net(Module):
    """Module whose single parameter tensor is the point being optimised.

    Args:
        x: initial value, wrapped in ``torch.nn.Parameter`` as ``self.pars``.
    """

    def __init__(self, x):
        super().__init__()
        self.pars = torch.nn.Parameter(x)

    def f(self):
        """
        function to be minimised: f(x) = (x[0]-5)^2 + x[1]^2 + (x[2]-1)^2
        Solution: x = [5,0,1]

        Only the first three entries of ``self.pars`` contribute; any remaining
        entries of the residual stay zero.
        """
        p = self.pars
        residual = torch.zeros_like(p)
        residual[0] = p[0] - 5
        residual[1] = p[1]
        residual[2] = p[2] - 1
        return residual.pow(2).sum()

    def forward(self, _, state):
        """Record the current estimate in ``state`` and return the objective.

        The first (input) argument is ignored.
        """
        snapshot = self.pars.detach().unsqueeze(1)
        state[ESTIMATE] = snapshot
        return self.f()