Example #1
0
def test_no_errors():
    """Smoke test: build generators, sample from them and plot the result.

    Three configurations are run one after another; each one overwrites
    ``strokes.png``. Nothing is asserted -- the test passes as long as no
    call raises.
    """
    # (GeneratorRNN constructor argument, extra kwargs for generate_sequence)
    configurations = (
        (1, {}),
        (20, {}),
        (20, {'bias': 10}),
    )
    for constructor_arg, sampling_kwargs in configurations:
        network = GeneratorRNN(constructor_arg)
        sampled = generate_sequence(network, 10, **sampling_kwargs)
        plot_stroke(sampled, 'strokes.png')
Example #2
0
def train_all_random_batch(rnn: GeneratorRNN,
                           optimizer: torch.optim.Optimizer,
                           data,
                           output_directory='./output',
                           tail=True,
                           batch_size=100):
    """Train ``rnn`` forever on randomly positioned contiguous batches.

    Each iteration picks a uniformly random start offset, slices
    ``batch_size`` consecutive items out of ``data`` and runs one training
    step on them. Every 50 iterations (starting with iteration 0) sample
    plots are written for several bias values and the model's
    ``state_dict`` is saved.

    This function does not return; stop it by interrupting the process.

    Args:
        rnn: Model to train; also used to generate the sample plots.
        optimizer: Optimizer passed through to the training step.
        data: Sliceable sequence of training items supporting ``len()``.
        output_directory: Root folder. Samples go to
            ``<output_directory>/batch_uncond`` and checkpoints to
            ``<output_directory>/models_batch_uncond``; both are created
            if missing.
        tail: When true, each step calls ``train_full_batch``; otherwise
            ``train_truncated_batch``.
        batch_size: Number of consecutive items per training batch.

    Raises:
        ValueError: If ``batch_size`` is not positive or ``data`` holds
            fewer than ``batch_size`` items.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be positive, got %r' % (batch_size,))
    if len(data) < batch_size:
        raise ValueError(
            'need at least %d items to form a batch, got %d'
            % (batch_size, len(data)))

    model_dir = os.path.join(output_directory, 'models_batch_uncond')
    sample_dir = os.path.join(output_directory, 'batch_uncond')
    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(sample_dir, exist_ok=True)

    # Valid start offsets are 0 .. len(data) - batch_size inclusive, so a
    # dataset of exactly batch_size items yields the single offset 0.
    num_starts = len(data) - batch_size + 1
    train_step = train_full_batch if tail else train_truncated_batch

    i = 0
    # The context manager closes the progress bar when the loop is
    # interrupted (e.g. KeyboardInterrupt) or a step raises.
    with tqdm() as pbar:
        while True:
            index = np.random.randint(num_starts)
            batched_data = data[index:(index + batch_size)]
            if i % 50 == 0:
                for b in [0, 0.1, 1, 5]:
                    generated_strokes = generate_sequence(rnn, 700, bias=b)
                    file_name = os.path.join(sample_dir,
                                             '%d-%s.png' % (i, b))
                    plot_stroke(generated_strokes, file_name)
                    tqdm.write('Writing file: %s' % file_name)
                model_file_name = os.path.join(model_dir, '%d.pt' % i)
                torch.save(rnn.state_dict(), model_file_name)
            i += 1
            train_step(rnn, optimizer, batched_data)
            pbar.update(1)
Example #3
0
def test_correct_distribution():
    """
    Print the loss of two separately constructed models on their own samples
    and on each other's samples, so that the loss function can be compared
    by eye against the sequence generation function.

    Nothing is asserted; the printed values must be inspected manually.
    """
    model_a = GeneratorRNN(1)
    model_b = GeneratorRNN(1)
    sample_a = generate_sequence(model_a, 20, bias=10000)
    sample_b = generate_sequence(model_b, 20, bias=10000)

    # Evaluation order: own samples first, then the crossed pairs.
    pairings = (
        ("loss11", model_a, sample_a),
        ("loss22", model_b, sample_b),
        ("loss12", model_a, sample_b),
        ("loss21", model_b, sample_a),
    )
    losses = {}
    for label, model, sample in pairings:
        losses[label] = unconditioned.compute_loss(model, sample)

    # Reporting order groups the losses by the sample they were scored on.
    for label in ("loss11", "loss21", "loss22", "loss12"):
        print(label, losses[label])
Example #4
0
 def test_generate_sequence(self):
     """Generating 6 items from seed ('a',) yields the pinned sequence."""
     result = generate_sequence(self.cfd, ('a',), 6)
     # The pinned output differs per major version because Python 3
     # changed its seeding method.
     if six.PY2:
         letters = 'abbbab'
     elif six.PY3:
         letters = 'ababac'
     self.assertEqual(result, list(letters))
Example #5
0
def train_all(rnn: GeneratorRNN, optimizer: torch.optim.Optimizer, data,
              output_directory='output'):
    """Train ``rnn`` forever, one item of ``data`` per step.

    ``data`` is iterated over and over again. Every 50 steps (starting with
    step 0) sample plots are written for several bias values and the
    model's ``state_dict`` is saved.

    This function does not return; stop it by interrupting the process.

    Args:
        rnn: Model to train; also used to generate the sample plots.
        optimizer: Optimizer passed through to ``train``.
        data: Iterable of training items; each one is handed to ``train``.
        output_directory: Root folder. Samples go to
            ``<output_directory>/uncond`` and checkpoints to
            ``<output_directory>/models``; both are created if missing.
    """
    import os  # local import: this block previously used no os helpers

    sample_dir = os.path.join(output_directory, 'uncond')
    model_dir = os.path.join(output_directory, 'models')
    # Without these the first checkpoint fails on a fresh checkout because
    # the target folders do not exist yet.
    os.makedirs(sample_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    i = 0
    while True:
        for strokes in tqdm(data):
            if i % 50 == 0:
                for b in [0, 0.1, 1, 5]:
                    generated_strokes = generate_sequence(rnn, 700, bias=b)
                    file_name = os.path.join(sample_dir,
                                             '%d-%s.png' % (i, b))
                    plot_stroke(generated_strokes, file_name)
                    tqdm.write('Writing file: %s' % file_name)
                torch.save(rnn.state_dict(),
                           os.path.join(model_dir, '%d.pt' % i))
            i += 1
            train(rnn, optimizer, strokes)
Example #6
0
def train_model(model, epochs, criterion, optimizer):
    """Run ``epochs`` rounds of 100 training steps on freshly generated data.

    Each step draws a new ``(input, target)`` pair from
    ``generate_sequence(sq_len)``, feeds the input to ``model`` one slice at
    a time starting from an all-zero ``(h, c)`` state, and scores the output
    of the last slice against the target with ``criterion``. The loss value,
    and then the output together with the target, are printed after every
    step.

    ``sq_len``, ``batch_size`` and ``hid_size`` are read from the enclosing
    scope. Returns ``None``.
    """
    for _epoch in range(epochs):
        for _batch in range(100):
            inputs, targets = generate_sequence(sq_len)
            # In-place swap of the first two dimensions, so the loop below
            # walks along what was originally dimension 1.
            inputs.transpose_(0, 1)

            state = (
                torch.zeros(1, batch_size, hid_size),
                torch.zeros(1, batch_size, hid_size),
            )
            for frame in inputs:
                output, state = model(frame.unsqueeze(0), state)

            # Only the output produced for the final slice is scored.
            output = output.squeeze(0)
            loss = criterion(output, targets)

            model.zero_grad()
            loss.backward()
            optimizer.step()

            print(loss.item())
            print(output, targets)
Example #7
0
def test_biased_generation():
    """Print the element-wise absolute difference of two biased samples.

    Both sequences are drawn from the same model with ``bias=10000``.
    Nothing is asserted; the printed difference must be inspected manually.
    """
    network = GeneratorRNN(20)
    first = generate_sequence(network, 20, bias=10000)
    second = generate_sequence(network, 20, bias=10000)
    print(np.abs(first - second))
Example #8
0
def train_model(model, epochs, criterion, optimizer_1, optimizer_2):
    """Jointly train a recurrent cell and a predictor of its hidden states.

    For every step a fresh ``(input, target)`` pair is drawn from
    ``generate_sequence(sq_len)``. The sequence is unrolled through ``net``;
    before each time step the hidden state may be blended with the
    predictor's estimate computed from one earlier hidden state. Two
    optimisation steps follow: ``optimizer_1`` on ``criterion(output,
    target)`` and ``optimizer_2`` on the predictor loss.

    The body reads ``net``, ``predictor``, ``tensor``, ``sq_len``,
    ``batch_size`` and ``hid_size`` from the enclosing scope.

    NOTE(review): the ``model`` argument is never referenced -- the body
    uses ``net`` instead. Confirm whether ``net`` should be ``model``.

    Args:
        model: Unused in the body (see note above).
        epochs: Number of outer iterations; each runs 100 training steps.
        criterion: Callable scoring the final ``output`` against the target.
        optimizer_1: Optimizer stepped after the main loss.
        optimizer_2: Optimizer stepped after the predictor loss.

    Returns:
        None. Loss values, the final output and the target are printed.
    """

    for epoch in range(epochs):
        # NOTE(review): assigned but never read or accumulated.
        epoch_loss = 0

        for _ in range(100):
            inp_x, inp_y = generate_sequence(sq_len)
            # In-place swap of the first two dimensions; inp_x[i] below then
            # indexes what was originally dimension 1.
            inp_x.transpose_(0, 1)
            h = torch.zeros(batch_size, hid_size)
            c = torch.zeros(batch_size, hid_size)
            # Hidden state as it was on entry to each time step (not detached).
            past_states = []

            for i in range(sq_len):
                best_norm = 1e8
                best_id = -1
                past_states.append(h)

                # Among all earlier steps j < i, pick the one for which the
                # predictor's second output has the smallest norm.
                for j in range(i):
                    # Step distance scaled by the sequence length.
                    time_diff = tensor([[(i - j) / (1.0 * sq_len)]])
                    # time_diff has a leading dimension of 1, so this cat
                    # presumably requires batch_size == 1 -- TODO confirm.
                    inp = torch.cat((past_states[j].detach(), time_diff),
                                    dim=1)
                    pred_mean, pred_var = predictor(inp)
                    pred_mean, pred_var = pred_mean.detach(), pred_var.detach()
                    norm = torch.norm(pred_var)

                    if norm < best_norm:
                        best_norm = norm
                        best_id = j
                        best_inp = inp

                # best_id stays -1 at i == 0 (no earlier state) or when no
                # candidate's norm is below the initial 1e8.
                if best_id != -1:
                    # Re-run the predictor on the chosen input; this time only
                    # pred_var is detached, pred_mean keeps its graph.
                    pred_mean, pred_var = predictor(best_inp)
                    pred_var = pred_var.detach()
                    # look into different ways to combine past into present
                    # Blend weight: element-wise min of (0.5 + pred_var) and 1.
                    alpha = torch.min(0.5 + pred_var, torch.ones(1, hid_size))
                    h = alpha * h + (1 - alpha) * pred_mean

                output, (h, c) = net(inp_x[i], (h, c))

            # Only the output of the last time step is scored.
            loss = criterion(output, inp_y)
            optimizer_1.zero_grad()
            # retain_graph=True keeps the graph alive; presumably needed
            # because p_loss below back-propagates through the non-detached
            # past_states -- TODO confirm.
            loss.backward(retain_graph=True)
            optimizer_1.step()

            # Build predictor training pairs for every (j, i) with j < i:
            # input is state j plus the scaled step distance, target is state i.
            arr_x = []
            arr_y = []
            for i in range(sq_len):
                for j in range(i):
                    time_diff = tensor([[(i - j) / (1.0 * sq_len)]])
                    # Unlike the selection loop above, past_states[j] is not
                    # detached here.
                    inp = torch.cat((past_states[j], time_diff), dim=1)
                    arr_x.append(inp)
                    arr_y.append(past_states[i])
            # NOTE(review): torch.stack raises on an empty list, which
            # happens when sq_len < 2 -- confirm sq_len is always >= 2.
            p_x = torch.stack(arr_x)
            p_y = torch.stack(arr_y)

            pred_mean, pred_var = predictor(p_x)
            # NOTE(review): resembles a Gaussian negative log-likelihood, but
            # the residual is divided by pred_var itself -- confirm whether
            # pred_var is meant to be a variance or a standard deviation.
            p_loss = torch.log(2 * math.pi * pred_var) + torch.pow(
                (pred_mean - p_y) / (pred_var + 1e-7), 2)
            p_loss = p_loss.mean()

            optimizer_2.zero_grad()
            p_loss.backward()
            optimizer_2.step()

            loss_val = loss.item()
            p_val = p_loss.item()
            print(loss_val, p_val)
            print(output, inp_y)
Example #9
0
    # Train/Generate
    # Dispatch on the parsed command-line flags: either train the model
    # (train takes precedence over generate) or generate one sample and plot it.
    if args.train:
        optimizer = create_optimizer(rnn)
        # unconditioned is checked before conditioned; batch selects the
        # random-batch trainer over the plain train_all of the chosen module.
        if args.unconditioned:
            if args.batch:
                unconditioned.train_all_random_batch(rnn, optimizer,
                                                     normalized_data)
            else:
                unconditioned.train_all(rnn, optimizer, normalized_data)
        elif args.conditioned:
            if args.batch:
                conditioned.train_all_random_batch(rnn, optimizer,
                                                   normalized_data)
            else:
                conditioned.train_all(rnn, optimizer, normalized_data)
    elif args.generate:
        if args.unconditioned:
            print("Generating a random handwriting sample.")
            # Third positional argument is presumably the sampling bias --
            # TODO confirm against generator.generate_sequence.
            strokes = generator.generate_sequence(rnn, 700, 1)
        elif args.conditioned:
            target_sentence = args.output_text
            print("Generating handwriting for text: %s" % target_sentence)
            # Convert the text to a float tensor and move it to the GPU;
            # the .cuda() call means this branch needs a CUDA device.
            target_sentence = Variable(torch.from_numpy(
                sentence_to_vectors(target_sentence,
                                    alphabet_dict)).float().cuda(),
                                       requires_grad=False)
            strokes = generator.generate_conditioned_sequence(
                rnn, 2000, target_sentence, 3)
        # NOTE(review): if neither unconditioned nor conditioned is set,
        # strokes is not assigned in this branch -- confirm the argument
        # parser guarantees one of them.
        plot_stroke(strokes, "output.png")
0
 def test_generate_sequence_assertion(self):
     """generate_sequence raises AssertionError for these arguments."""
     self.assertRaises(
         AssertionError,
         generate_sequence,
         self.cfd, ('c',), 4,
         condition_length=3,
     )
Example #11
0
 def test_generate_sequence_terminates(self):
     """Asking for 3 items from seed ('c',) yields only ['c', 'd']."""
     result = generate_sequence(self.cfd, ('c',), 3)
     self.assertEqual(result, ['c', 'd'])