def test_conditional_wgan(self):
    """Check the wiring of ``WGAN.conditional`` on tiny untrained networks.

    Verifies that the encoder is stored as given, that the generator maps
    codes of shape (10, 1, 1) to samples of shape (10, 2), and that calling
    the conditional GAN returns the flattened code concatenated with the
    generated sample, i.e. an output of shape (10, 1 + 2).
    """
    # E is only stored and compared here, never called
    E = lambda x: x[1]
    G = View([2]) @ ConvNet([[1, 2], [1, 1], [1]])
    D = View([1]) @ ConvNet([[3, 1], [1, 1], [1]])
    cgan = WGAN.conditional(G, D, E)
    self.assertIs(cgan.encoder, E)
    z = torch.randn([10, 1, 1])
    p_gen = cgan(z)
    x_gen = cgan.gen(z)
    # assertEqual reports both shapes on failure, unlike assertTrue(a == b)
    self.assertEqual(tuple(p_gen.shape), (10, 3))
    self.assertEqual(tuple(x_gen.shape), (10, 2))
    # cgan(z) is expected to be [code | generated sample] along dim 1
    self.assertClose(p_gen, torch.cat(
        [z.flatten(1), x_gen.flatten(1)], dim=1))
def test_conditional_wgan_fit(self, writer): E = lambda x: x[:, 1:] G = Affine(1, 2) D = View([1]) @ ConvNet([[3, 12, 1], [1, 1, 1], [1, 1]]) @ View([3, 1]) cgan = WGAN.conditional(G, D, E, ns, lr_crit) cgan.writer = writer #--- 2d-gaussian with (.3, .3) mean and (.05, .1) stdev mean = torch.tensor([.3, .3], device="cuda") devs = torch.tensor([.05, .1], device="cuda") x_true = mean + (torch.randn([N, Nb, 2], device="cuda") * devs) #--- 1d-gaussian codes with .3 mean and .1 stdev z = mean[1] + (torch.randn([N, Nb, 1], device="cuda") * devs[1]) #--- fit on dataset of (code, sample) pairs dset = [(zi, xi) for zi, xi in zip(z, x_true)] cgan.cuda() print(f"\n\tn_gen = {cgan.n_gen} \tlr_gen = {lr_gen}") print(f"\tn_crit = {cgan.n_crit} \tlr_crit = {cgan.lr_crit}\n") cgan.fit(dset, lr=lr_gen, epochs=epochs, progress=True, tag=tag) #--- generate z = z.view([-1, 1]) with torch.no_grad(): x_gen = cgan.gen(z) #--- check section consistency self.assertClose(z, E(x_gen), tol=.1) #--- check mean and support print(f"\n\t=> x_gen.mean : {x_gen.mean([0])}") print(f"\t x_gen.std : {x_gen.std([0])}") self.assertClose(z.mean([0]), E(x_true.view([-1, 1])).mean([0]), tol=.1) self.assertClose(x_gen.mean([0]), mean, tol=.1) self.assertClose(x_gen[:, 0].std([0]), 0, tol=.1)
def test_view(self):
    """A ``View([32])`` composed after a ConvNet reshapes its output to (N, 32)."""
    # the convolution is built first, then wrapped by the reshaping view
    backbone = ConvNet([[6, 32], [12, 1], [12]])
    flattened = View([32]) @ backbone
    out_shape = tuple(flattened(torch.randn([N, 6, 12])).shape)
    self.assertEqual((N, 32), out_shape)
def generator(args):
    """Build the generator network from parsed command-line ``args``.

    The pipeline is:
    - ``Linear(dz, 128)``;
    - ``View([16, 8])``;
    - ``ConvNet(args.layers_G)``;
    - an ``Affine(n, n)`` map applied along dim -2, where ``n`` is 2 when
      ``args.mask`` is truthy and 1 otherwise.

    The result is wrapped in one extra outer ``Pipe``. ``dz`` is read from
    module scope.
    """
    n_channels = 2 if args.mask else 1
    # built before the other layers so parameter-initialisation order is unchanged
    rescale_output = Affine(n_channels, n_channels, dim=-2)
    inner = Pipe(Linear(dz, 128),
                 View([16, 8]),
                 ConvNet(args.layers_G),
                 rescale_output)
    return Pipe(inner)
def test_shapes(self):
    """Check ConvNet output shapes for a few layer specifications."""
    cases = [
        # Nc = 1 implicit in input: a (N, 12) tensor is accepted
        ([[1, 3], [12, 4], [3]], [N, 12], (N, 3, 4)),
        # Npts = 1 not squeezed on output
        ([[2, 4, 8], [8, 4, 1], [4, 4]], [N, 2, 8], (N, 8, 1)),
        ([[2, 4], [3, 6], [2]], [N, 2, 3], (N, 4, 6)),
    ]
    # network, input and assertion are produced in the same order per case
    for spec, in_shape, expected in cases:
        net = ConvNet(spec)
        got = tuple(net(torch.randn(in_shape)).shape)
        self.assertEqual(expected, got)
def critic(args):
    """Build the Lipschitz-constrained critic from parsed command-line ``args``.

    The network is a ``Pipe`` of:
    - a fixed-initialised ``Affine(1, 1)`` input scaling (weight .2, bias .1);
    - ``View([c, 64])``;
    - ``ConvNet(args.layers_D)`` with leaky-ReLU activations;
    - ``View([1])``;
    - ``Linear(1, 1)``.

    Here ``c`` is 2 when ``args.mask`` is truthy and 1 otherwise. The pipe is
    wrapped in ``Lipschitz(..., args.beta)``. The input scaling layer is
    re-exposed as the ``scale_in`` attribute of the returned module.
    """
    n_channels = 2 if args.mask else 1

    # NOTE(review): the frozen VICReg twins / encoder are loaded but the
    # encoder is never used below -- confirm whether it should be in the pipe.
    frozen = Twins.load(args.input).freeze()
    _unused_encoder = frozen.model.module1

    # input affine map meant to move inputs into a non-saturating range
    # NOTE(review): Affine(1, 1, dim=-2) while the view below uses
    # n_channels (2 with a mask) -- confirm Affine broadcasts here.
    input_scaling = Affine(1, 1, dim=-2)
    with torch.no_grad():
        input_scaling.bias = nn.Parameter(torch.tensor([.1]))
        input_scaling.weight = nn.Parameter(torch.tensor([[.2]]))

    stages = (input_scaling,
              View([n_channels, 64]),
              ConvNet(args.layers_D, activation=F.leaky_relu),
              View([1]),
              Linear(1, 1))
    constrained = Lipschitz(Pipe(*stages), args.beta)
    # presumably module0 of the wrapped pipe is the input scaling layer
    constrained.scale_in = constrained.model.module0
    return constrained
def _process_recording(file, evts, Npts, Npulses, bp, argmin, loss):
    """Extract pulses from one open recording.

    Returns ``("good", segments)`` on success. Returns ``("y_quant", None)``
    or ``("amp", None)`` when the recording is rejected.

    Raises RuntimeError when fewer than ``Npts`` points are available, and
    propagates any exception raised by the helpers.
    """
    # offset of the `label` timestamp from the recording start, scaled by 100
    # (NOTE(review): presumably a 100 Hz sampling rate)
    i0 = int(100 * (evts[label][0] - evts["start"]))
    icp = filter_spikes(file.icp(i0, Npts))[0]
    if icp.shape[0] != Npts:
        raise RuntimeError("not enough points")
    # band-pass once and reuse for both trough detection and pulse selection
    filtered = bp(icp)
    troughs = argmin(filtered)
    if quantization_y(icp) >= .099:
        return "y_quant", None
    segments = select_pulses(filtered, troughs, Npulses, loss)
    if amplitude_avg(segments[1]) <= 1:
        return "amp", None
    return "good", segments


def main(model_state=None, Npulses=64, minutes=6):
    """Extract ``Npulses`` pulses per labelled recording and save them to ``dest``.

    For every recording in ``db`` whose periods contain ``label``:
    - read ``minutes * 6000`` ICP points starting at the ``label`` timestamp;
    - filter spikes, band-pass, and locate troughs;
    - select pulses with ``select_pulses``.

    A recording is rejected when:
    - its Y-quantization is >= .099;
    - the average amplitude of ``segments[1]`` is <= 1;
    - any exception is raised while processing it.

    Every opened file is closed.

    Args:
        model_state: path given to ``ConvNet.load``; when falsy, the identity
            is used as the model inside the pulse-selection losses.
        Npulses: number of pulses requested from ``select_pulses``.
        minutes: window length, ``minutes * 6000`` points.

    Raises:
        RuntimeError: if no recording passes the filters (nothing to save).
    """
    # model variation losses
    print(
        f"loading model from '{model_state}'" if model_state else "model = Id")
    model = ConvNet.load(model_state) if model_state else (lambda x: x)
    loss = mixed_loss([mean_loss(model), diff_loss(model)])

    # files with {label}
    print(f"filtering files with '{label}' timestamps")
    keys = db.filter(lambda f: db.periods[f.key][label])

    # filters and peak detection
    print(f"extracting pulses from {len(keys)} recordings")
    Npts = minutes * 6000
    bp = bandpass(.6, 12, fs)
    argmin = Troughs(Npts, 50)

    # main loop
    out, good = [], []
    bad = {"y_quant": [], "amp": [], "errors": []}
    for k in tqdm.tqdm(keys):
        evts = db.periods[k]
        file = db.get(k)
        try:
            status, segments = _process_recording(
                file, evts, Npts, Npulses, bp, argmin, loss)
        except Exception:
            # best effort: a failing recording is recorded, not fatal
            status, segments = "errors", None
        finally:
            # always release the file, including on early rejection
            file.close()
        if status == "good":
            out.append(segments)
            good.append(k)
        else:
            bad[status].append(k)

    if not out:
        raise RuntimeError(
            f"no valid recordings out of {len(keys)}: nothing to save")

    #--- Save output
    print(f"saving output as '{dest}'")
    names = ["masks", "pulses", "means", "slopes"]
    # the i-th item of each `segments` is stacked under names[i]
    data = {name: torch.stack([s[i] for s in out])
            for i, name in enumerate(names)}
    data["keys"] = good
    data |= bad
    for n in names:
        print(f" + {n}\t: {list(data[n].shape)} tensor")
    print(f" + keys\t: {len(good)} list string")
    torch.save(data, f'{dest}')
    print(f"extracted {Npulses} pulses from {len(data['keys'])} recordings")
    print(f" - {len(bad['y_quant'])} bad Y-quantizations encountered")
    print(f" - {len(bad['amp'])} low amplitudes encountered")
    print(f" - {len(bad['errors'])} errors encountered")
# Shape tests for composing ConvNet modules with `@` and `Pipe`.
import test
import torch
from revert.models import ConvNet, Pipe, Prod, View,\
    Stack, Cat, Cut

# shared fixtures: batch size, an input batch, and two ConvNets
N = 20
x = torch.randn([N, 6, 32])
# NOTE(review): `y` is unused in the visible tests; presumably a target for
# test_loss -- confirm.
y = torch.randn([N, 24, 8])
# per the assertions below, f2 applied after f1 maps x to shape (N, 24, 8)
f1 = ConvNet([[6, 12], [32, 16], [4]])
f2 = ConvNet([[12, 24], [16, 8], [4]])


class TestModule(test.TestCase):
    """Composition tests on the module-level f1 / f2 ConvNets."""

    def test_matmul(self):
        # `f2 @ f1` composes the modules (f1 first, then f2)
        model = f2 @ f1
        result = tuple(model(x).shape)
        expect = (N, 24, 8)
        self.assertEqual(expect, result)

    def test_pipe(self):
        # Pipe(f1, f2) should give the same output shape as f2 @ f1
        model = Pipe(f1, f2)
        result = tuple(model(x).shape)
        expect = (N, 24, 8)
        self.assertEqual(expect, result)

    def test_loss(self):
        # NOTE(review): this mutates the shared module-level f2 (visible to
        # other tests) and asserts nothing -- looks unfinished.
        f2.loss = lambda out, tgt : ((out - tgt)**2).sum()
        model = Pipe(f1, f2)
import torch import sys from revert.models import WGAN, WGANCritic, Clipped,\ ConvNet, View, Affine ns = (5, 100) # (n_gen, n_crit) lr_gen, lr_crit = (5e-3, 5e-3) clip = .5 N, Nb = 512, 256 epochs = 5 tag = "critic score" dx, dz = 6, 3 G = Affine(3, 6) D = View([1]) @ ConvNet([[dx, 12, 1], [1, 1, 1], [1, 1]]) @ View([dx, 1]) #--- hyperplane x_true = torch.randn([N, dx]) x_true += (.2 - x_true.mean()) #--- generated distribution z = torch.randn([N, dz]) x_gen = G(z) #--- labels xs = torch.cat([x_gen, x_true]) ys = (torch.tensor([0, 1]).repeat_interleave(N).flatten()) options = sys.argv[1] if len(sys.argv) > 1 else '' class TestWGAN(test.TestCase):
# NOTE(review): the next four lines are the tail of a data-loading function
# whose `def` line precedes this chunk; they wrap the shifted flows and their
# targets into a shuffled DataLoader with batch_size=1.
    shifted, y = shift_all(stdev)(flows)
    data_dataset = TensorDataset(shifted, y)
    data_loader = DataLoader(data_dataset, shuffle=True, batch_size=1)
    return data_loader

#==================================================
#--- Models ---

# NOTE(review): each row below is a triple; the layer-spec format is not
# established in this chunk -- confirm against ConvNet's documentation.
Npts = 32
layers = [[Npts, 6, 8],
          [16, 6 * 12, 8],
          [8, 6 * 24, 8],
          [1, 6 * 12, 1]]
base = ConvNet(layers, pool='max')

# dim_out equals the middle entry of the last `layers` row (6 * 12)
dim_out = 6 * 12
dim_task = 6
head = ConvNet([[1, dim_out, 1], [1, dim_task, 1]])
convnet = Pipe(base, head)

# find the path to save (parses command-line options)
args = read_args(arg_parser(prefix='convnet'))
# restore saved weights when an input path is given
if args.input:
    convnet = convnet.load(args.input)

#--- Main ---