示例#1
0
 def test_conditional_wgan(self):
     """A conditional WGAN maps codes to (code, generated sample) pairs."""
     encoder = lambda x: x[1]
     gen = View([2]) @ ConvNet([[1, 2], [1, 1], [1]])
     crit = View([1]) @ ConvNet([[3, 1], [1, 1], [1]])
     cgan = WGAN.conditional(gen, crit, encoder)
     self.assertTrue(cgan.encoder == encoder)
     codes = torch.randn([10, 1, 1])
     pairs = cgan(codes)
     samples = cgan.gen(codes)
     self.assertTrue(tuple(pairs.shape) == (10, 3))
     self.assertTrue(tuple(samples.shape) == (10, 2))
     # calling the cgan concatenates the flattened code with the flattened sample
     expected = torch.cat([codes.flatten(1), samples.flatten(1)], dim=1)
     self.assertClose(pairs, expected)
示例#2
0
 def test_conditional_wgan_fit(self, writer):
     """Fit a conditional WGAN on (code, sample) pairs and check the output.

     Reads the module-level settings ``ns``, ``lr_gen``, ``lr_crit``, ``N``,
     ``Nb``, ``epochs`` and ``tag``, and runs on the "cuda" device.

     Args:
         writer: assigned to ``cgan.writer`` before fitting.
     """
     # the "section" E picks the second coordinate of each sample
     E = lambda x: x[:, 1:]
     G = Affine(1, 2)
     D = View([1]) @ ConvNet([[3, 12, 1], [1, 1, 1], [1, 1]]) @ View([3, 1])
     cgan = WGAN.conditional(G, D, E, ns, lr_crit)
     cgan.writer = writer
     #--- 2d-gaussian with (.3, .3) mean and (.05, .1) stdev
     mean = torch.tensor([.3, .3], device="cuda")
     devs = torch.tensor([.05, .1], device="cuda")
     x_true = mean + (torch.randn([N, Nb, 2], device="cuda") * devs)
     #--- 1d-gaussian codes with .3 mean and .1 stdev
     z = mean[1] + (torch.randn([N, Nb, 1], device="cuda") * devs[1])
     #--- fit on dataset of (code, sample) pairs
     dset = [(zi, xi) for zi, xi in zip(z, x_true)]
     cgan.cuda()
     print(f"\n\tn_gen = {cgan.n_gen} \tlr_gen = {lr_gen}")
     print(f"\tn_crit = {cgan.n_crit} \tlr_crit = {cgan.lr_crit}\n")
     cgan.fit(dset, lr=lr_gen, epochs=epochs, progress=True, tag=tag)
     #--- generate
     z = z.view([-1, 1])
     with torch.no_grad():
         x_gen = cgan.gen(z)
     #--- check section consistency
     self.assertClose(z, E(x_gen), tol=.1)
     #--- check mean and support
     print(f"\n\t=> x_gen.mean : {x_gen.mean([0])}")
     print(f"\t   x_gen.std  : {x_gen.std([0])}")
     # Flatten x_true to (N * Nb, 2) sample rows so that E keeps the second
     # coordinate. Viewing as (-1, 1) made E return an empty (M, 0) tensor,
     # and the comparison became meaningless.
     self.assertClose(z.mean([0]),
                      E(x_true.view([-1, 2])).mean([0]),
                      tol=.1)
     self.assertClose(x_gen.mean([0]), mean, tol=.1)
     # the first coordinate has a small true spread (.05), within tol of 0
     self.assertClose(x_gen[:, 0].std([0]), 0, tol=.1)
示例#3
0
 def test_view(self):
     """Composing View([32]) after a ConvNet flattens its output to (N, 32)."""
     net = ConvNet([[6, 32], [12, 1], [12]])
     flat = View([32]) @ net
     out_shape = tuple(flat(torch.randn([N, 6, 12])).shape)
     self.assertEqual((N, 32), out_shape)
示例#4
0
def generator(args):
    """Build the generator pipeline.

    A ``dz``-dimensional input goes through Linear(dz, 128), is viewed as
    (16, 8), passes through ``ConvNet(args.layers_G)``, and ends with an
    Affine(n, n) applied along dim -2. Here n is 2 when ``args.mask`` is
    truthy and 1 otherwise. The whole chain is wrapped in an outer Pipe.
    """
    n_chan = 2 if args.mask else 1
    channel_scale = Affine(n_chan, n_chan, dim=-2)
    body = Pipe(Linear(dz, 128), View([16, 8]), ConvNet(args.layers_G),
                channel_scale)
    return Pipe(body)
示例#5
0
    def test_shapes(self):
        """ConvNet output shapes for several layer specifications."""
        cases = [
            # Nc = 1 implicit in input
            ([[1, 3], [12, 4], [3]], [N, 12], (N, 3, 4)),
            # Npts = 1 not squeezed on output
            ([[2, 4, 8], [8, 4, 1], [4, 4]], [N, 2, 8], (N, 8, 1)),
            ([[2, 4], [3, 6], [2]], [N, 2, 3], (N, 4, 6)),
        ]
        for spec, in_shape, expect in cases:
            net = ConvNet(spec)
            out = net(torch.randn(in_shape))
            self.assertEqual(expect, tuple(out.shape))
示例#6
0
def critic(args):
    """Build the Lipschitz-wrapped critic.

    The pipeline is: a fixed Affine input map (weight .2, bias .1) along
    dim -2, a View to (n, 64), ``ConvNet(args.layers_D)`` with leaky ReLU,
    a View to (1,), and Linear(1, 1). Here n is 2 when ``args.mask`` is
    truthy and 1 otherwise. The result is wrapped in
    ``Lipschitz(..., args.beta)``, and the input map is also exposed as
    ``.scale_in``.
    """
    n_chan = 2 if args.mask else 1
    # vicreg encoder
    # NOTE(review): `encoder` is fetched but never added to the pipeline
    # below. Confirm whether it was meant to sit after the input map.
    twins = Twins.load(args.input).freeze()
    encoder = twins.model.module1
    # map input to non-saturating domain
    input_map = Affine(1, 1, dim=-2)
    with torch.no_grad():
        input_map.bias = nn.Parameter(torch.tensor([.1]))
        input_map.weight = nn.Parameter(torch.tensor([[.2]]))

    stages = [
        input_map,
        View([n_chan, 64]),
        ConvNet(args.layers_D, activation=F.leaky_relu),
        View([1]),
        Linear(1, 1),
    ]
    crit = Lipschitz(Pipe(*stages), args.beta)
    # module0 of the wrapped Pipe is the input map built above
    crit.scale_in = crit.model.module0
    return crit
示例#7
0
def main(model_state=None, Npulses=64, minutes=6):
    """Extract pulses from recordings that have a ``label`` period and save them.

    Relies on the module-level ``db``, ``label``, ``fs`` and ``dest``.
    Recordings that fail a check are listed under "y_quant" or "amp" in the
    saved dict. Recordings that raise during processing are listed under
    "errors".

    Args:
        model_state: optional path given to ``ConvNet.load``. When falsy,
            the identity is used inside the selection losses.
        Npulses: number of pulses ``select_pulses`` keeps per recording.
        minutes: window length; ``minutes * 6000`` points are read
            (presumably 100 points per second, matching the ``100 *``
            index scaling below; confirm).

    Raises:
        RuntimeError: if no recording passed the checks, so there is
            nothing to stack and save.
    """
    # model variation losses
    print(
        f"loading model from '{model_state}'" if model_state else "model = Id")
    model = ConvNet.load(model_state) if model_state else (lambda x: x)
    losses = [mean_loss(model), diff_loss(model)]
    loss = mixed_loss(losses)

    # files with {label}
    print(f"filtering files with '{label}' timestamps")
    keys = db.filter(lambda f: db.periods[f.key][label])

    # filters and peak detection
    print(f"extracting pulses from {len(keys)} recordings")
    Npts = minutes * 6000
    bp = bandpass(.6, 12, fs)
    argmin = Troughs(Npts, 50)

    # main loop
    out, good = [], []
    bad = {"y_quant": [], "amp": [], "errors": []}
    for k in tqdm.tqdm(keys):
        evts = db.periods[k]
        file = db.get(k)
        try:
            i0 = int(100 * (evts[label][0] - evts["start"]))
            icp = filter_spikes(file.icp(i0, Npts))[0]
            # filter once and reuse for both trough detection and selection
            filtered = bp(icp)
            troughs = argmin(filtered)
            if icp.shape[0] != Npts:
                raise RuntimeError("not enough points")
            if quantization_y(icp) >= .099:
                bad["y_quant"].append(k)
                continue
            segments = select_pulses(filtered, troughs, Npulses, loss)
            if amplitude_avg(segments[1]) <= 1:
                bad["amp"].append(k)
                continue
            out.append(segments)
            good.append(k)
        except Exception:
            # best effort: a failing recording is recorded, not fatal
            bad["errors"].append(k)
        finally:
            # always release the file, including on the `continue` paths
            file.close()

    #--- Save output

    if not out:
        raise RuntimeError(
            f"no valid recordings among {len(keys)}: nothing to save")

    print(f"saving output as '{dest}'")
    names = ["masks", "pulses", "means", "slopes"]
    # each `segments` entry holds these four fields in order: stack each
    # field across the kept recordings
    data = {name: torch.stack([seg[i] for seg in out])
            for i, name in enumerate(names)}
    data["keys"] = good
    data |= bad

    for n in names:
        print(f"  + {n}\t: {list(data[n].shape)} tensor")
    print(f"  + keys\t: {len(good)} list string")

    torch.save(data, str(dest))

    print(f"extracted {Npulses} pulses from {len(data['keys'])} recordings")
    print(f"  - {len(bad['y_quant'])} bad Y-quantizations encountered")
    print(f"  - {len(bad['amp'])} low amplitudes encountered")
    print(f"  - {len(bad['errors'])} errors encountered")
示例#8
0
import test
import torch

from revert.models import ConvNet, Pipe, Prod, View,\
                          Stack, Cat, Cut

N = 20  # leading (sample) dimension of the random test inputs
x = torch.randn([N, 6, 32])  # input fed to f1 / the composed models below
# same shape as the composed f2 @ f1 output expected in the tests below;
# presumably a regression target for test_loss (not fully visible here)
y = torch.randn([N, 24, 8])

# ConvNet specs are read as [channels], [points], [kernel widths], judging by
# the shape tests. So presumably f1: (N, 6, 32) -> (N, 12, 16) and
# f2: (N, 12, 16) -> (N, 24, 8), which lets f2 @ f1 compose.
f1 = ConvNet([[6, 12],  [32, 16], [4]])
f2 = ConvNet([[12, 24], [16, 8],  [4]])


class TestModule(test.TestCase):
    
    def test_matmul(self):
        """``f2 @ f1`` applies f1 then f2: (N, 6, 32) inputs give (N, 24, 8)."""
        composed = f2 @ f1
        self.assertEqual((N, 24, 8), tuple(composed(x).shape))

    def test_pipe(self):
        """``Pipe(f1, f2)`` chains f1 then f2, yielding shape (N, 24, 8)."""
        out_shape = tuple(Pipe(f1, f2)(x).shape)
        self.assertEqual((N, 24, 8), out_shape)

    def test_loss(self):
        f2.loss = lambda out, tgt : ((out - tgt)**2).sum()
        model = Pipe(f1, f2)
示例#9
0
import sys

import torch

import test
from revert.models import WGAN, WGANCritic, Clipped,\
                          ConvNet, View, Affine

ns = (5, 100)  # (n_gen, n_crit)
lr_gen, lr_crit = (5e-3, 5e-3)  # learning rates: generator, critic
clip = .5  # presumably the weight-clipping bound for `Clipped` critics; confirm
N, Nb = 512, 256  # N samples; Nb presumably a batch count/size for fit tests; confirm
epochs = 5
tag = "critic score"  # presumably the scalar name logged during fit; confirm

dx, dz = 6, 3  # sample and code dimensions
G = Affine(3, 6)  # generator: dz -> dx
# critic, applied right to left: view each sample as (dx, 1), ConvNet with
# channels dx -> 12 -> 1, then view as (1,)
D = View([1]) @ ConvNet([[dx, 12, 1], [1, 1, 1], [1, 1]]) @ View([dx, 1])

#--- hyperplane
# standard-normal samples, shifted so their overall mean is .2
x_true = torch.randn([N, dx])
x_true += (.2 - x_true.mean())
#--- generated distribution
z = torch.randn([N, dz])
x_gen = G(z)  # output of the untrained generator
#--- labels
# label 0 for the N generated samples, then 1 for the N true ones (same order as xs)
xs = torch.cat([x_gen, x_true])
ys = (torch.tensor([0, 1]).repeat_interleave(N).flatten())

# first command-line argument, or '' when none is given
options = sys.argv[1] if len(sys.argv) > 1 else ''


class TestWGAN(test.TestCase):
示例#10
0
    shifted, y = shift_all(stdev)(flows)

    data_dataset = TensorDataset(shifted, y)
    data_loader = DataLoader(data_dataset, shuffle=True, batch_size=1)

    return data_loader


#==================================================

#--- Models ---

Npts = 32
# per-layer specs, presumably [points, channels, kernel]; confirm against ConvNet
layers = [[Npts, 6, 8], [16, 6 * 12, 8], [8, 6 * 24, 8], [1, 6 * 12, 1]]

base = ConvNet(layers, pool='max')

# The head's input channels must equal the base's last-layer channels.
# Derive it from `layers` so the two cannot drift apart (value: 6 * 12).
dim_out = layers[-1][1]
dim_task = 6
head = ConvNet([[1, dim_out, 1], [1, dim_task, 1]])

convnet = Pipe(base, head)

# parse CLI arguments; resume from saved weights when an input path is given
args = read_args(arg_parser(prefix='convnet'))
if args.input:
    convnet = convnet.load(args.input)

#--- Main ---