# Shared imports for the examples on this page. Helpers such as
# best_available_device, MNISTBaseline, make_mnist_model, make_text_model,
# reptile_grad, and iterate_mini_datasets, plus constants like OUT_PATH,
# BATCH, META_BATCH, AVG_SIZE, GRID_SIZE, INNER_LR, and DATASET, come from
# the surrounding repo.
import itertools
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import optim


def main():
    device = torch.device(best_available_device())
    model = MNISTBaseline()
    if os.path.exists(OUT_PATH):
        model.load_state_dict(torch.load(OUT_PATH))
    model.to(device)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    samples = iterate_mini_datasets()
    last_n = []
    for i in itertools.count():
        input_batch = []
        output_batch = []
        # Teacher forcing: condition each step on the previous target
        # pixel, with a zero placeholder at the first position.
        for inputs, outputs in [next(samples) for _ in range(BATCH)]:
            shifted_outputs = torch.cat(
                [torch.zeros_like(outputs[:1]), outputs[:-1]], dim=0)
            ins = torch.cat([inputs, shifted_outputs.long()], dim=-1)
            input_batch.append(ins)
            output_batch.append(outputs)
        # Stack along dim=1 to get the [seq_len, batch, ...] layout the
        # sequence model expects.
        inputs = torch.stack(input_batch, dim=1).to(device)
        outputs = torch.stack(output_batch, dim=1).to(device)
        logits, _ = model(inputs)
        loss = F.binary_cross_entropy_with_logits(logits, outputs)
        last_n.append(loss.item())
        last_n = last_n[-AVG_SIZE:]

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Checkpoint on the CPU every step so the saved weights load
        # without a GPU.
        model.to(torch.device('cpu'))
        torch.save(model.state_dict(), OUT_PATH)
        model.to(device)

        print('step %d: loss=%f last_%d=%f' %
              (i, loss.item(), AVG_SIZE, np.mean(last_n)))
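
Two helpers used throughout these examples, best_available_device() and
iterate_mini_datasets(), are defined elsewhere in the repo. The sketch
below shows plausible minimal versions, assuming each MNIST digit is
served as a (pixel-coordinate, pixel-value) sequence pair; the exact
tensor layout and binarization threshold are guesses, not the repo's
actual code.

import torchvision

def best_available_device():
    # Prefer CUDA when present; otherwise fall back to the CPU.
    return 'cuda' if torch.cuda.is_available() else 'cpu'

def iterate_mini_datasets(train=True):
    # Hypothetical sketch: treat each MNIST digit as one "mini dataset",
    # yielding ([784, 2] integer pixel coordinates, [784, 1] binarized
    # pixel values).
    data = torchvision.datasets.MNIST('data', train=train, download=True)
    coords = torch.tensor([[i // 28, i % 28] for i in range(28 * 28)])
    while True:
        for img, _ in data:
            pixels = torch.from_numpy(np.array(img)).view(-1, 1)
            yield coords, (pixels > 127).float()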
Example #2
def main():
    device = torch.device(best_available_device())
    model = make_mnist_model()
    model.load_state_dict(torch.load(OUT_PATH))
    model.to(device)
    history = []
    # Evaluation: run the Reptile inner loop on each held-out mini
    # dataset and report the running mean of the inner-loop losses.
    for i, (inputs, outputs) in enumerate(iterate_mini_datasets(train=False)):
        losses = reptile_grad(model, [(inputs, outputs)], INNER_LR)
        history.append(np.mean(losses))
        print('step %d: loss=%f' % (i, np.mean(history)))
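
Examples #2 and #5 both call reptile_grad, which is likewise defined
elsewhere in the repo. The sketch below shows the standard Reptile
computation under the signature used above (model, list of (inputs,
outputs) tasks, inner learning rate), written with the MNIST model's
(logits, state) return and BCE loss; the repo's real implementation may
differ, e.g. by taking inner-loop steps per sequence element, and the
text variant would use cross-entropy instead.

import copy

def reptile_grad(model, tasks, inner_lr, inner_steps=1):
    # Sketch only: run inner-loop SGD from the shared initialization on
    # each task, then leave (init - adapted), averaged over tasks, in
    # the parameters' .grad fields for the outer optimizer to apply.
    init_state = copy.deepcopy(model.state_dict())
    meta_grads = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    losses = []
    for inputs, outputs in tasks:
        model.load_state_dict(init_state)
        inner_opt = optim.SGD(model.parameters(), lr=inner_lr)
        for _ in range(inner_steps):
            logits, _ = model(inputs)
            loss = F.binary_cross_entropy_with_logits(logits, outputs)
            inner_opt.zero_grad()
            loss.backward()
            inner_opt.step()
            losses.append(loss.item())
        for n, p in model.named_parameters():
            meta_grads[n] += (init_state[n] - p.detach()) / len(tasks)
    # Restore the initialization and expose the meta-gradient.
    model.load_state_dict(init_state)
    for n, p in model.named_parameters():
        p.grad = meta_grads[n]
    return losses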
Example #3
def main():
    device = torch.device(best_available_device())

    model = make_mnist_model()
    model.load_state_dict(torch.load(OUT_PATH))
    model.to(device)

    # Tile GRID_SIZE x GRID_SIZE generated digits into one RGB image.
    grid = np.zeros([28 * GRID_SIZE, 28 * GRID_SIZE, 3], dtype=np.uint8)
    for i in range(GRID_SIZE):
        for j in range(GRID_SIZE):
            print('generating tile %d,%d' % (i, j))
            grid[28*i:28*(i+1), 28*j:28*(j+1)] = generate_single(model, device)

    Image.fromarray(grid).save('samples.png')
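
generate_single isn't shown on this page either. Below is a hypothetical
sketch mirroring Example #6's loop: sample one pixel at a time and take
an inner-loop SGD step on each sample, so the model adapts to its own
output as it draws. The per-step input layout ([1, 1, row/col/previous
pixel]) and the (logits, state) return are assumptions carried over from
Examples #1 and #4, and the real code presumably restores the
meta-learned weights between tiles, which is omitted here.

def generate_single(model, device):
    opt = optim.SGD(model.parameters(), lr=INNER_LR)
    pixels = []
    prev = 0
    for idx in range(28 * 28):
        # Feed the pixel's (row, col) plus the previously sampled pixel.
        ins = torch.tensor([[[idx // 28, idx % 28, prev]]], device=device)
        logits, _ = model(ins)
        prob = torch.sigmoid(logits.detach()).item()
        prev = int(np.random.rand() < prob)
        pixels.append(prev)
        # Inner-loop step: fit the model to the pixel it just emitted.
        target = torch.full_like(logits, float(prev))
        loss = F.binary_cross_entropy_with_logits(logits, target)
        opt.zero_grad()
        loss.backward()
        opt.step()
    tile = np.array(pixels, dtype=np.uint8).reshape(28, 28) * 255
    return np.stack([tile] * 3, axis=-1)  # grayscale tile -> RGB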
Example #4
def main():
    device = torch.device(best_available_device())
    model = MNISTBaseline()
    model.load_state_dict(torch.load(OUT_PATH))
    model.to(device)

    samples = iterate_mini_datasets(train=False)
    history = []
    # Same teacher-forced batching as Example #1, but on held-out data
    # and with no optimizer steps.
    for i in itertools.count():
        input_batch = []
        output_batch = []
        for inputs, outputs in [next(samples) for _ in range(BATCH)]:
            shifted_outputs = torch.cat([torch.zeros_like(outputs[:1]), outputs[:-1]], dim=0)
            ins = torch.cat([inputs, shifted_outputs.long()], dim=-1)
            input_batch.append(ins)
            output_batch.append(outputs)
        inputs = torch.stack(input_batch, dim=1).to(device)
        outputs = torch.stack(output_batch, dim=1).to(device)
        logits, _ = model(inputs)
        loss = F.binary_cross_entropy_with_logits(logits, outputs)
        history.append(loss.item())
        print('samples=%d loss=%f' % (i * BATCH, np.mean(history)))
Example #5
def main():
    device = torch.device(best_available_device())
    model = make_text_model()
    if os.path.exists(OUT_PATH):
        model.load_state_dict(torch.load(OUT_PATH))
    model.to(device)
    outer_opt = optim.Adam(model.parameters(), lr=1e-3)
    mini_batches = iterate_mini_datasets(DATASET)

    last_n = []
    for i in itertools.count():
        batch = [next(mini_batches) for _ in range(META_BATCH)]
        outer_opt.zero_grad()
        # reptile_grad runs the inner loop on each task in the meta-batch
        # and leaves the Reptile meta-gradient in the parameters' .grad
        # fields, so the outer Adam step moves the shared initialization
        # toward the adapted weights.
        losses = reptile_grad(model, batch, INNER_LR)
        outer_opt.step()
        loss = np.mean(losses)
        last_n.append(loss)
        last_n = last_n[-AVG_SIZE:]
        model.cpu()
        torch.save(model.state_dict(), OUT_PATH)
        model.to(device)
        print('step %d: loss=%f last_%d=%f' %
              (i, loss, AVG_SIZE, np.mean(last_n)))
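
The outer step here works because reptile_grad leaves (initial -
adapted) weights in .grad, so a plain optimizer step moves the shared
initialization toward each task's adapted weights. A toy check of that
arithmetic (all values made up):

w = torch.nn.Parameter(torch.tensor([1.0]))
adapted = torch.tensor([0.2])     # pretend the inner loop ended here
w.grad = w.detach() - adapted     # Reptile meta-gradient: 0.8
optim.SGD([w], lr=0.5).step()     # w <- 1.0 - 0.5 * 0.8
print(w.data)                     # tensor([0.6000])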
Example #6
def main():
    device = torch.device(best_available_device())
    model = make_text_model()
    model.load_state_dict(torch.load(OUT_PATH))
    model.to(device)
    opt = optim.SGD(model.parameters(), lr=INNER_LR)

    sequence = []

    for i in range(128):
        # Condition on the position index alone; the model adapts online
        # to its own samples as generation proceeds.
        inputs = torch.from_numpy(np.array([[i]])).to(device).long()
        logits = model(inputs)
        probs = F.softmax(logits[0], dim=0).detach().cpu().numpy()
        sample = np.random.choice(np.arange(256), p=probs)
        sequence.append(int(sample))
        if sample == 0:
            break  # byte 0 acts as an end-of-sequence marker
        # Inner-loop step: fit the model to the byte it just emitted.
        targets = torch.from_numpy(np.array([sample])).to(device).long()
        loss = F.cross_entropy(logits, targets)
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Clamp bytes to at most 0x79 so the sequence decodes as valid ASCII.
    print(str(bytes([min(0x79, x) for x in sequence]), 'ascii'))
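
make_text_model is also defined elsewhere; Example #6 only requires a
model that maps a single position index to 256 byte logits. A minimal
hypothetical stand-in consistent with that usage (the real architecture
is presumably a proper sequence model):

import torch.nn as nn

def make_text_model(num_positions=128, hidden=128):
    return nn.Sequential(
        nn.Embedding(num_positions, hidden),  # (1, 1) -> (1, 1, hidden)
        nn.Flatten(0, 1),                     # (1, 1, hidden) -> (1, hidden)
        nn.Linear(hidden, 256),               # logits over byte values
    )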