Example #1
def generate_from_synth_model(
        model_path="data/model_files/handwriting_cond_best.pt",
        sentence_list=[
            "hello world !!",
            "this text is generated using an RNN model",
            "Welcome to Lyrebird!",
        ],
        bias=3.0,
        device=torch.device("cpu"),
):
    model = HandWritingSynthRNN()
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
    oh_encoder = pickle.load(open("data/one_hot_encoder.pkl", "rb"))
    sentences = [s.to(device) for s in oh_encoder.one_hot(sentence_list)]
    generated_samples, attn_vars = model.generate(sentences=sentences,
                                                  bias=bias,
                                                  device=device,
                                                  use_stopping=True)

    model_name = model_path.split("/")[-1].replace(".pt", "")
    for i in range(len(sentence_list)):
        plot_stroke(
            generated_samples[:, i, :].cpu().numpy(),
            save_name="samples/{}_{}.png".format(model_name, i),
        )
        print(f"generated strokes for: {sentence_list[i]}")
Example #2
def train_all_random_batch(rnn: GeneratorRNN,
                           optimizer: torch.optim.Optimizer,
                           data,
                           output_directory='./output',
                           tail=True):
    batch_size = 100
    i = 0
    model_dir = os.path.join(output_directory, 'models_batch_uncond')
    sample_dir = os.path.join(output_directory, 'batch_uncond')
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)
    pbar = tqdm()
    while True:
        index = np.random.choice(range(len(data) - batch_size), size=1)[0]
        batched_data = data[index:(index + batch_size)]
        if i % 50 == 0:
            for b in [0, 0.1, 1, 5]:
                generated_strokes = generate_sequence(rnn, 700, bias=b)
                file_name = os.path.join(sample_dir, '%d-%s.png' % (i, b))
                plot_stroke(generated_strokes, file_name)
                tqdm.write('Writing file: %s' % file_name)
            model_file_name = os.path.join(model_dir, '%d.pt' % i)
            torch.save(rnn.state_dict(), model_file_name)
        i += 1
        if tail:
            train_full_batch(rnn, optimizer, batched_data)
        else:
            train_truncated_batch(rnn, optimizer, batched_data)
        pbar.update(1)
    return
Example #3
    def generate(self, text, timesteps=1000, seed=None, filepath='samples/conditional.png'):
        self.build_model(seq_length=1, max_sentence_length=len(text))
        sample = np.zeros((1, timesteps + 1, 3), dtype='float32')
        char_index, _ = char_to_index()

        one_hot_text = tf.expand_dims(one_hot_encode(text, self.num_characters, char_index), 0)
        text_length = tf.expand_dims(tf.constant(len(text)), 0)

        input_states = self.initial_states(1)
        for i in range(timesteps):
            outputs, input_states, phi = self.model([sample[:,i:i+1,:], input_states, one_hot_text, text_length])
            input_states[-1] = tf.reshape(input_states[-1], (1, self.num_characters))
            sample[0,i+1] = self.sample(outputs, seed)

            # stopping heuristic
            finished = True
            phi_last = phi[0,0,-1]
            for phi_u in phi[0,0,:-1]:
                if phi_u.numpy() > phi_last.numpy():
                    finished = False
                    break

            # prevent early stopping
            if i < 100:
                finished = False

            if finished:
                break

        # remove first zeros and discard unused timesteps
        sample = sample[0,1:i]
        plot_stroke(sample, save_name=filepath)
        return sample
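The stopping heuristic in Example #3 ends generation once the attention weight on the final character is at least as large as every earlier character's weight, but never before 100 timesteps. A compact NumPy sketch of the same check (has_finished is a hypothetical helper, not part of the original class):

import numpy as np

def has_finished(phi, step, min_steps=100):
    # phi: attention weights of shape (1, 1, U); generation is finished when the
    # weight on the last character dominates every earlier character's weight.
    if step < min_steps:
        return False
    phi = np.asarray(phi)[0, 0]
    return bool(np.all(phi[:-1] <= phi[-1]))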
Example #4
def train_all(rnn: GeneratorRNN, optimizer: torch.optim.Optimizer, data):
    i = 0
    while True:
        for strokes in tqdm(data):
            if i % 50 == 0:
                for b in [0, 0.1, 1, 5]:
                    generated_strokes = generate_sequence(rnn, 700, bias=b)
                    file_name = 'output/uncond/%d-%s.png' % (i, b)
                    plot_stroke(generated_strokes, file_name)
                    tqdm.write('Writing file: %s' % file_name)
                torch.save(rnn.state_dict(), "output/models/%d.pt" % i)
            i += 1
            train(rnn, optimizer, strokes)
    return
Example #5
    def generate(self,
                 timesteps=400,
                 seed=None,
                 filepath='samples/unconditional.jpeg'):
        self.build_model(seq_length=1)
        sample = np.zeros((1, timesteps + 1, 3), dtype='float32')
        input_states = [tf.zeros((1, self.num_cells))] * 2 * self.num_layers

        for i in range(timesteps):
            outputs, input_states = self.model(
                [sample[:, i:i + 1, :], input_states])
            sample[0, i + 1] = self.sample(outputs, seed)

        # remove first zeros
        sample = sample[0, 1:]
        plot_stroke(sample, save_name=filepath)
        return sample
Example #6
def path_to_stroke(path_data, k=1, save_path="./mobile/"):
    """
        Convert SVG path data into stroke data with offset coordinates.
        Args:
            path_data: list of SVG path command strings
            k: downsample factor; the default of 1 means no downsampling
            save_path: directory path where the stroke .npy file is saved
    """

    save_path = Path(save_path)
    stroke = np.zeros((len(path_data), 3))
    i = 0
    while i < len(path_data):
        command = path_data[i][0]
        coord = path_data[i][1:].split(',')
        if command == 'M':
            stroke[i, 0] = 1.0
        elif command == 'L':
            stroke[i, 0] = 0.0
        stroke[i, 1] = float(coord[0])
        stroke[i, 2] = -float(coord[1])
        i += 1

    stroke[0, 0] = 0.0
    stroke[-1, 0] = 1.
    print("initial shape of data: ", stroke.shape)

    cuts = np.where(stroke[:, 0] == 1.)[0]
    print("EOS index:", cuts)

    start = 0
    down_sample_data = []
    for eos in cuts:
        down_sample_data.append(stroke[start:eos:k])
        down_sample_data.append(stroke[eos])
        start = eos + 1

    down_sample_stroke = np.vstack(down_sample_data)
    # convert absolute coordinates into offset
    down_sample_stroke[
        1:, 1:] = down_sample_stroke[1:, 1:] - down_sample_stroke[:-1, 1:]
    print("final shape of data: ", down_sample_stroke.shape)

    plot_stroke(down_sample_stroke, "img.png")
    np.save(save_path, down_sample_stroke, allow_pickle=True)
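A minimal usage sketch for path_to_stroke (the path strings and save path below are made up; each entry follows the "<M|L><x>,<y>" format the parser above expects):

path_data = [
    "M10,20", "L12,22", "L15,25",   # first pen-down segment ('M' starts a stroke)
    "M40,20", "L42,23", "L45,27",   # second segment after a pen lift
]
# Writes the (pen flag, x offset, y offset) array to an .npy file and plots it to img.png.
path_to_stroke(path_data, k=1, save_path="./mobile/style")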
Example #7
def main(config):
    logger = config.get_logger('experiments')

    # setup the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # setup data_loader instances
    data_loader = config.init_obj('data_loader', module_data)

    # build model architecture
    model = config.init_obj('arch',
                            module_arch,
                            char2idx=data_loader.dataset.char2idx,
                            device=device)
    logger.info(model)

    # Loading the weights of the model
    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    checkpoint = torch.load(config.resume, map_location=device)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)

    # prepare model for testing
    model = model.to(device)
    model.eval()

    with torch.no_grad():

        if str(model).startswith('Unconditional'):
            sampled_stroke = model.generate_unconditional_sample()
            plot_stroke(sampled_stroke)

        elif str(model).startswith('Conditional'):
            sampled_stroke = model.generate_conditional_sample('hello world')
            plot_stroke(sampled_stroke)

        elif str(model).startswith('Seq2Seq'):
            sent, stroke = data_loader.dataset[21]
            predicted_seq = model.recognize_sample(stroke)
            print('real text:      ',
                  data_loader.dataset.tensor2sentence(sent))
            print(
                'predicted text: ',
                data_loader.dataset.tensor2sentence(
                    torch.tensor(predicted_seq)))
Example #8
def submit_style_data():
    data = request.get_json()
    path = data["path"]
    text = data["text"]
    if path == "":
        return jsonify(
            dict({"redirect": url_for("draw"), "message": "Please enter some style"})
        )

    id = str(uuid.uuid4())
    session["id"] = id
    tmp_dir = os.path.join(flask_app.root_path, "static", "uploads", id)
    if not os.path.exists(tmp_dir):
        os.makedirs(tmp_dir)

    os.chmod(tmp_dir, 0o777)
    print(tmp_dir)
    # user agent info
    user_agent = request.user_agent
    print(user_agent.string)
    print(user_agent.platform)
    phones = ["android", "iphone"]
    down_sample = True

    if user_agent.platform in phones:
        down_sample = False

    text_path = os.path.join(tmp_dir, "inpText.txt")
    print(text_path)
    with open(text_path, "w") as f:
        f.write(text)
    f.close()

    stroke = path_string_to_stroke(
        path, str_len=len(list(text)), down_sample=down_sample
    )
    save_path = os.path.join(tmp_dir, "style.npy")
    np.save(save_path, stroke, allow_pickle=True)
    print(save_path)

    # plot the sequence
    plot_stroke(stroke.astype(np.float32), os.path.join(tmp_dir, "original.png"))

    return jsonify(dict({"redirect": url_for("generate"), "message": ""}))
Example #9
def generate_from_model(
        model_path="data/model_files/handwriting_uncond_best.pt",
        sample_length=600,
        num_sample=2,
        bias=0.5,
        device=torch.device("cpu"),
):
    """
    Generate num_sample samples, each of length sample_length, using a
    pretrained model.
    """
    handWritingRNN = HandWritingRNN()
    handWritingRNN.load_state_dict(torch.load(model_path, map_location=device))
    generated_samples = handWritingRNN.generate(device=device,
                                                length=sample_length,
                                                batch=num_sample,
                                                bias=bias)

    model_name = model_path.split("/")[-1].replace(".pt", "")
    for i in range(num_sample):
        plot_stroke(
            generated_samples[:, i, :].cpu().numpy(),
            save_name="samples/{}_{}.png".format(model_name, i),
        )
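A usage sketch for generate_from_model (the values shown are just the defaults above; the checkpoint path is assumed to exist in your checkout):

generate_from_model(
    model_path="data/model_files/handwriting_uncond_best.pt",
    sample_length=600,            # timesteps per generated sequence
    num_sample=2,                 # number of independent samples to draw
    bias=0.5,                     # higher bias gives cleaner, less varied strokes
    device=torch.device("cpu"),
)
# PNGs are written to samples/handwriting_uncond_best_0.png, _1.png, ...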
Example #10
def test_no_errors():
    rnn = GeneratorRNN(1)
    strokes = generate_sequence(rnn, 10)
    plot_stroke(strokes, 'strokes.png')

    rnn = GeneratorRNN(20)
    strokes = generate_sequence(rnn, 10)
    plot_stroke(strokes, 'strokes.png')

    rnn = GeneratorRNN(20)
    strokes = generate_sequence(rnn, 10, bias=10)
    plot_stroke(strokes, 'strokes.png')
Example #11
def test_conditioned_no_errors():
    d = {'a': 0, 'b': 1, 'c': 2}
    sentence_vec = sentence_to_vectors('abcabc', d)
    sentence_vec = Variable(torch.from_numpy(sentence_vec).float(),
                            requires_grad=False)

    rnn = ConditionedRNN(1, num_chars=3)
    strokes = generate_conditioned_sequence(rnn, 10, sentence_vec)
    plot_stroke(strokes, 'strokes.png')

    rnn = ConditionedRNN(20, num_chars=3)
    strokes = generate_conditioned_sequence(rnn, 10, sentence_vec)
    plot_stroke(strokes, 'strokes.png')

    rnn = ConditionedRNN(20, num_chars=3)
    strokes = generate_conditioned_sequence(rnn, 10, sentence_vec, bias=10)
    plot_stroke(strokes, 'strokes.png')
Example #12
def generate_handwriting(
    char_seq="hello world",
    real_text="",
    style_path="../app/static/mobile/style.npy",
    save_path="",
    app_path="",
    n_samples=1,
    bias=10.0,
):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_path = os.path.join(app_path, "../data/")
    model_path = os.path.join(app_path,
                              "../results/synthesis/best_model_synthesis.pt")
    # seed = 194
    # print("seed:",seed)
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # print(np.random.get_state())

    train_dataset = HandwritingDataset(data_path, split="train", text_req=True)

    prime = True
    is_map = False
    style = np.load(style_path, allow_pickle=True,
                    encoding="bytes").astype(np.float32)
    # plot the sequence
    # plot_stroke(style, os.path.join(save_path, "original.png"))

    print("Priming text: ", real_text)
    mean, std, style = data_normalization(style)
    style = torch.from_numpy(style).unsqueeze(0).to(device)
    # style = valid_offset_normalization(Global.train_mean, Global.train_std, style[None,:,:])
    # style = torch.from_numpy(style).to(device)
    print("Priming sequence size: ", style.shape)
    ytext = real_text + " " + char_seq + "  "

    for i in range(n_samples):
        gen_seq, phi = generate_conditional_sequence(
            model_path,
            char_seq,
            device,
            train_dataset.char_to_id,
            train_dataset.idx_to_char,
            bias,
            prime,
            style,
            real_text,
            is_map,
        )
        if is_map:
            plt.imshow(phi, cmap="viridis", aspect="auto")
            plt.colorbar()
            plt.xlabel("time steps")
            plt.yticks(np.arange(phi.shape[0]),
                       list(ytext),
                       rotation="horizontal")
            plt.margins(0.2)
            plt.subplots_adjust(bottom=0.15)
            plt.show()

        # denormalize the generated offsets using train set mean and std
        print("data denormalization...")
        end = style.shape[1]
        # gen_seq[:,:end] = data_denormalization(mean, std, gen_seq[:, :end])
        # gen_seq[:,end:] = data_denormalization(Global.train_mean, Global.train_std, gen_seq[:,end:])
        gen_seq = data_denormalization(Global.train_mean, Global.train_std,
                                       gen_seq)
        # plot the sequence
        print(gen_seq.shape)
        # plot_stroke(gen_seq[0, :end])
        plot_stroke(gen_seq[0],
                    os.path.join(save_path, "gen_stroke_" + str(i) + ".png"))
        print(save_path)
Example #13
        gen_seq = generate_conditional_sequence(model_path, args.char_seq,
                                                device,
                                                train_dataset.char_to_id,
                                                train_dataset.idx_to_char,
                                                args.bias, prime,
                                                style, real_text)

    gen_seq = data_denormalization(
        Global.train_mean, Global.train_std, gen_seq)
    gen_seq = np.squeeze(gen_seq)

    # plot the sequence
    if args.save_img:
        img_path = os.path.join(str(args.save_path),
                                "gen_img"+str(time.time())+".png")
        plot_stroke(gen_seq, save_name=img_path)
        print(f"Image saved as: {img_path}")

    if args.save_gif:
        gif_path = os.path.join(str(args.save_path),
                                "gen_img"+str(time.time())+".gif")
        plot_stroke_gif(gen_seq, save_name=gif_path)
        print(f"GIF saved as: {gif_path}")

    # Export generated sequence as json
    seq_list = gen_seq.tolist()
    json_file_path = os.path.join(str(args.save_path), "generated_seq.json")
    json.dump(seq_list,
              codecs.open(json_file_path, 'w', encoding='utf-8'),
              separators=(',', ':'), sort_keys=True, indent=4)
    print(f"Sequence saved to json: {json_file_path}")
Example #14
def train(device, args, data_path="data/"):
    """
    """
    random_seed = 42

    writer = SummaryWriter(log_dir=args.logdir, comment="")

    model_path = args.logdir + ("/unconditional_models/"
                                if args.uncond else "/conditional_models/")
    os.makedirs(model_path, exist_ok=True)

    strokes = np.load(data_path + "strokes.npy", encoding="latin1")
    sentences = ""
    with open(data_path + "sentences.txt") as f:
        sentences = f.readlines()
    sentences = [snt.replace("\n", "") for snt in sentences]
    # Instead of removing the newline symbols, should they be used instead?

    MAX_STROKE_LEN = 800
    strokes, sentences, MAX_SENTENCE_LEN = filter_long_strokes(
        strokes, sentences, MAX_STROKE_LEN, max_index=args.n_data)
    # print("Max sentence len after filter is: {}".format(MAX_SENTENCE_LEN))

    # dimension of one-hot representation
    N_CHAR = 57
    oh_encoder = OneHotEncoder(sentences, n_char=N_CHAR)
    pickle.dump(oh_encoder, open("data/one_hot_encoder.pkl", "wb"))
    sentences_oh = [s.to(device) for s in oh_encoder.one_hot(sentences)]

    # normalize strokes data and convert to pytorch tensors
    strokes = normalize_data(strokes)
    # plot_stroke(strokes[1])
    tstrokes = [torch.from_numpy(stroke).to(device) for stroke in strokes]

    # pytorch dataset
    dataset = HandWritingData(sentences_oh, tstrokes)

    # validating the padding lengths
    assert dataset.strokes_padded_len <= MAX_STROKE_LEN
    assert dataset.sentences_padded_len == MAX_SENTENCE_LEN

    # train - validation split
    train_split = 0.95
    train_size = int(train_split * len(dataset))
    validn_size = len(dataset) - train_size
    dataset_train, dataset_validn = torch.utils.data.random_split(
        dataset, [train_size, validn_size])

    dataloader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False)  # last batch may be smaller than batch_size
    dataloader_validn = DataLoader(dataset_validn,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   drop_last=False)

    common_model_structure = {
        "memory_cells": 400,
        "n_gaussians": 20,
        "num_layers": 2
    }
    model = (HandWritingRNN(**common_model_structure).to(device)
             if args.uncond else HandWritingSynthRNN(
                 n_char=N_CHAR,
                 n_gaussians_window=10,
                 kappa_factor=0.05,
                 **common_model_structure,
             ).to(device))
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)
    # optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-2,
    #                                   weight_decay=0, momentum=0)

    if args.resume is None:
        model.init_params()
    else:
        model.load_state_dict(torch.load(args.resume, map_location=device))
        print("Resuming trainig on {}".format(args.resume))
        # resume_optim_file = args.resume.split(".pt")[0] + "_optim.pt"
        # if os.path.exists(resume_optim_file):
        #     optimizer = torch.load(resume_optim_file, map_location=device)

    scheduler = ReduceLROnPlateau(optimizer,
                                  mode="min",
                                  factor=0.1**0.5,
                                  patience=10,
                                  verbose=True)

    best_batch_loss = 1e7
    for epoch in range(200):

        train_losses = []
        validation_iters = []
        validation_losses = []
        for i, (c_seq, x, masks, c_masks) in enumerate(dataloader_train):

            # make batch_first = false
            x = x.permute(1, 0, 2)
            masks = masks.permute(1, 0)

            # remove last point (prepending a dummy point (zeros) already done in data)
            inp_x = x[:-1]  # shape : (T, B, 3)
            masks = masks[:-1]  # shape: (B, T)
            # c_seq.shape: (B, MAX_SENTENCE_LEN, n_char), c_masks.shape: (B, MAX_SENTENCE_LEN)
            inputs = (inp_x, c_seq, c_masks)
            if args.uncond:
                inputs = (inp_x, )

            e, log_pi, mu, sigma, rho, *_ = model(*inputs)

            # remove first point from x to make it y
            loss = criterion(x[1:], e, log_pi, mu, sigma, rho, masks)
            train_losses.append(loss.detach().cpu().numpy())

            optimizer.zero_grad()

            loss.backward()

            # --- this may not be needed
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)

            optimizer.step()

            # do logging
            print("{},\t".format(loss))
            if i % 10 == 0:
                writer.add_scalar("Every_10th_batch_loss", loss,
                                  epoch * len(dataloader_train) + i)

            # save as best model if loss is better than previous best
            if loss < best_batch_loss:
                best_batch_loss = loss
                model_file = (
                    model_path +
                    f"handwriting_{('un' if args.uncond else '')}cond_best.pt")
                torch.save(model.state_dict(), model_file)
                optim_file = model_file.split(".pt")[0] + "_optim.pt"
                torch.save(optimizer, optim_file)

        epoch_avg_loss = np.array(train_losses).mean()
        scheduler.step(epoch_avg_loss)

        # ======================== do the per-epoch logging ========================
        writer.add_scalar("Avg_loss_for_epoch", epoch_avg_loss, epoch)
        print(f"Average training-loss for epoch {epoch} is: {epoch_avg_loss}")

        model_file = (
            model_path +
            f"handwriting_{('un' if args.uncond else '')}cond_ep{epoch}.pt")
        torch.save(model.state_dict(), model_file)
        optim_file = model_file.split(".pt")[0] + "_optim.pt"
        torch.save(optimizer, optim_file)

        # generate samples from model
        sample_count = 3
        sentences = ["welcome to lyrebird"
                     ] + ["abcd efgh vicki"] * (sample_count - 1)
        sentences = [s.to(device) for s in oh_encoder.one_hot(sentences)]

        if args.uncond:
            generated_samples = model.generate(600,
                                               batch=sample_count,
                                               device=device)
        else:
            generated_samples, attn_vars = model.generate(sentences,
                                                          device=device)

        figs = []
        # save png files of the generated models
        for i in range(sample_count):
            f = plot_stroke(
                generated_samples[:, i, :].cpu().numpy(),
                save_name=args.logdir +
                "/training_imgs/{}cond_ep{}_{}.png".format(
                    ("un" if args.uncond else ""), epoch, i),
            )
            figs.append(f)

        for i, f in enumerate(figs):
            writer.add_figure(f"samples/image_{i}", f, epoch)

        if not args.uncond:
            figs_phi = plot_phi(attn_vars["phi_list"])
            figs_kappa = plot_attn_scalar(attn_vars["kappa_list"])
            for i, (f_phi, f_kappa) in enumerate(zip(figs_phi, figs_kappa)):
                writer.add_figure(f"attention/phi_{i}", f_phi, epoch)
                writer.add_figure(f"attention/kappa_{i}", f_kappa, epoch)
Example #15
def save_stroke(new_stroke, name):

    print()
    new_stroke = new_stroke.squeeze().data
    new_stroke = np.array(new_stroke)
    lb_uts.plot_stroke(new_stroke, name)
Example #16
                nan = True
                print('exiting train @epoch : {}'.format(e))
                break

        mean_loss = np.mean(train_loss)
        print("Epoch {:03d}: Loss: {:.3f}".format(e, mean_loss))

        if e % 1 == 0:
            hws.model.save_weights(EPOCH_MODEL_PATH.format(e))

        if nan:
            break

except KeyboardInterrupt:
    pass

if not nan:
    hws.model.save_weights(MODEL_PATH)

# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
# '''''''''''''''''''''''''''''''EVALUATE'''''''''''''''''''''''''''''''
# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

verbose_sentence = "".join(D.encoder.inverse_transform(sentence)[0])
strokes1, _, _, _ = hws.infer(sentence,
                              inf_type='max',
                              verbose=verbose_sentence)
plot_stroke(strokes1)
import ipdb
ipdb.set_trace()
Example #17
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--strokes_path",
                        type=str,
                        default="../data/strokes.npy")
    parser.add_argument("--texts_path",
                        type=str,
                        default="../data/sentences.txt")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--hidden_size", type=int, default=200)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--model_dir", type=str, required=True)
    parser.add_argument("--grad_norm", type=float, default=10.0)
    parser.add_argument("--model_type",
                        choices=["prediction", "synthesis"],
                        required=True)

    parser.add_argument('--disable-cuda',
                        action='store_true',
                        help='Disable CUDA')

    args = parser.parse_args()
    args.device = None
    if not args.disable_cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
    else:
        args.device = torch.device('cpu')
    print("Device: {}".format(args.device))
    strokes = np.load(args.strokes_path, encoding="latin1", allow_pickle=True)
    with open(args.texts_path) as fin:
        texts = list(map(lambda x: x.strip(), fin))

    attention_scale = 1. / np.mean([
        len(stroke) / len(sentence)
        for stroke, sentence in zip(strokes, texts)
    ])
    print("Attention scale: {}".format(attention_scale))

    alphabet = {}
    for t in texts:
        for c in t:
            if c not in alphabet:
                alphabet[c] = len(alphabet)
    inv_alphabet = {y: x for x, y in alphabet.items()}
    test_text = "Welcome to lyrebird!"

    train_strokes, valid_strokes, train_texts, valid_texts = train_test_split(
        strokes, texts, test_size=0.1)
    no_text = args.model_type == "prediction"
    train_dataset = StrokesDataset(train_strokes, train_texts, alphabet,
                                   no_text)
    valid_dataset = StrokesDataset(valid_strokes, valid_texts, alphabet,
                                   no_text)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=16)
    valid_dataloader = DataLoader(valid_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=16)

    if args.model_type == "prediction":
        model = StrokesPrediction(hidden_size=args.hidden_size)
        model.to(args.device)
        model.sample(device=args.device)
    elif args.model_type == "synthesis":
        model = StrokesSynthesis(num_letters=len(alphabet),
                                 attention_scale=attention_scale,
                                 hidden_size=args.hidden_size)
        model.to(args.device)
        model.sample(torch.LongTensor([alphabet[x] for x in test_text
                                       ])[None, :].to(args.device),
                     device=args.device)
    else:
        raise ValueError("unknown model type")

    optmizer = torch.optim.AdamW(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-6)
    writer = SummaryWriter(args.model_dir)
    print("Starting to train...")
    global_step = 0
    for i in range(args.num_epochs):
        print("Epoch {} / {}".format(i + 1, args.num_epochs))
        for batch_X in tqdm.tqdm(train_dataloader):
            copy_to_device(batch_X, args.device)
            batch_X["strokes_inputs"] = batch_X["strokes_inputs"].permute(
                (1, 0, 2))
            batch_X["strokes_targets"] = batch_X["strokes_targets"].permute(
                (1, 0, 2))
            mixture_loss, end_loss = model.loss(**batch_X, device=args.device)
            loss = mixture_loss + end_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)
            optmizer.step()
            optmizer.zero_grad()
            writer.add_scalar("Mixture loss/train",
                              mixture_loss.cpu().detach().numpy(),
                              global_step=global_step)
            writer.add_scalar("End loss/train",
                              end_loss.cpu().detach().numpy(),
                              global_step=global_step)
            loss = loss.cpu().detach().numpy()
            print("Loss: {}".format(loss))
            writer.add_scalar("Loss/train", loss, global_step=global_step)
            global_step += 1

        model_path = os.path.join(args.model_dir, "epoch_{}.pt".format(i))
        save_model(model, model_path)

        valid_losses = []
        valid_end_losses = []
        valid_mixture_losses = []
        for batch_X in tqdm.tqdm(valid_dataloader):
            copy_to_device(batch_X, args.device)
            batch_X["strokes_inputs"] = batch_X["strokes_inputs"].permute(
                (1, 0, 2))
            batch_X["strokes_targets"] = batch_X["strokes_targets"].permute(
                (1, 0, 2))
            mixture_loss, end_loss = model.loss(**batch_X, device=args.device)
            loss = mixture_loss + end_loss
            valid_losses.append(loss.cpu().detach().numpy())
            valid_end_losses.append(end_loss.cpu().detach().numpy())
            valid_mixture_losses.append(mixture_loss.cpu().detach().numpy())

        writer.add_scalar("Mixture loss/valid",
                          np.mean(valid_mixture_losses),
                          global_step=i)
        writer.add_scalar("End loss/valid",
                          np.mean(valid_end_losses),
                          global_step=i)
        writer.add_scalar("Loss/valid", np.mean(valid_losses), global_step=i)

        if args.model_type == "synthesis":
            h_0, h_1, h_2, prev_w_t, prev_k = model.init_states(
                1, device=args.device)
            strokes_inputs = torch.FloatTensor(
                valid_dataset[0]["strokes_inputs"])[:, None, :].to(args.device)
            text_inputs = torch.LongTensor(
                valid_dataset[0]["text"])[None, :].to(args.device)
            text_lengths = torch.LongTensor([valid_dataset[0]["text_lengths"]
                                             ]).to(args.device)

            mixture_params, end_of_stroke_logits, h_0, h_1, h_2, prev_w_t, prev_k, alignments = model(
                strokes_inputs, text_inputs, text_lengths, h_0, h_1, h_2,
                prev_w_t, prev_k)
            first_alignment = alignments[:valid_dataset[0]["strokes_lengths"],
                                         0, :text_lengths[0]]
            first_alignment = first_alignment.detach().cpu().numpy().T
            writer.add_image("alignment", first_alignment[None, :])

            text = "".join([
                inv_alphabet[int(x)] for x in valid_dataset[0]["text"]
            ]).strip()
            sample = model.sample(text_inputs, device=args.device)
            img_dir = os.path.join(args.model_dir, "imgs")
            os.makedirs(img_dir, exist_ok=True)
            print(text)
            plot_stroke(sample,
                        save_name=os.path.join(img_dir, "epoch_{}".format(i)))
        else:
            sample = model.sample(device=args.device)
            img_dir = os.path.join(args.model_dir, "imgs")
            os.makedirs(img_dir, exist_ok=True)
            plot_stroke(sample,
                        save_name=os.path.join(img_dir, "epoch_{}".format(i)))
Example #18
import numpy as np
import sys

sys.path.append("../")
from utils import plot_stroke

style = np.load(
    "./static/uploads/default_style.npy", allow_pickle=True, encoding="bytes"
).astype(np.float32)
# plot the sequence
plot_stroke(style, "default.png")
Example #19
def train(args):
	dataset = Dataset(args.data_path, args.batch_size)
	data_loader = DataLoader(dataset, batch_size=1, 
		shuffle=True, collate_fn=_collate_fn)
	outputlayer_name = ['e', 'pi', 'mu1', 'mu2', 'sig1', 'sig2', 'ro'] # for gradient cliping
	pkl_file = open('char2int.pkl', 'rb')
	char2int = pickle.load(pkl_file)
	pkl_file.close()
	test_char = "welcome to lyrebird"
	char2array = torch.from_numpy(np.array([char2int[x] for x in test_char])).long().cuda()
	char2array = char2array.unsqueeze(0)
	epochs = args.num_epoch
	lr = args.lr
	max_len = 600
	use_cuda = torch.cuda.is_available()
	print_freq = 10
	if use_cuda:
		Model = ConditionalModel(use_cuda=use_cuda).cuda()
	else:
		Model = ConditionalModel()
	#train

	optimizer = torch.optim.Adam(Model.parameters(), lr=lr)
	plot_loss = []
	for epoch in range(epochs):
		#batch_loss = 0
		total_loss = 0
		start = time.time()
		for i, (data) in enumerate(data_loader):
			ys, char, ys_mask, char_mask = data
			if use_cuda:
				ys = ys.cuda()
				char = char.long().cuda()
				ys_mask = torch.from_numpy(ys_mask).long().cuda()
				char_mask = torch.from_numpy(char_mask).long().cuda()
			#print(ys)
			prev_state, prev_offset, prev_w = None, None, None
			e, pi, mu1, mu2, sig1, sig2, ro, _, _, _, _= Model(ys.permute(1,0,2)[:-1], char, char_mask, prev_state, prev_offset, prev_w)
			loss = Model.prediction_loss(e, pi, mu1, mu2, sig1, sig2, ro, ys.permute(1,0,2)[1:])
			optimizer.zero_grad()
			loss.backward()

			for name, param in Model.named_parameters():
				if 'lstm' in name:
					param.grad.data.clamp_(-10, 10)
				else:
					param.grad.data.clamp_(-100, 100)
			optimizer.step()
			plot_loss.append(loss.item())
			total_loss += loss.item()
			#print(data)
			if i % print_freq == 0:
				stroke = torch.tensor([1, 0, 0])
				print('Epoch {0} | Iter {1} | Average Loss {2:.3f} | '
                      'Current Loss {3:.6f} | {4:.1f} ms/batch'.format(
                          epoch + 1, i + 1, total_loss / (i + 1),
                          loss.item(), 1000 * (time.time() - start) / (i + 1)),
                      flush=True)

				prediction= Model.generate_samples(stroke, char2array, max_len)
				prediction = prediction.squeeze(0).cpu().numpy()
				print(prediction.shape)
				plot_stroke(prediction, 'generation_fix.png')
				torch.save(Model.state_dict(), "./checkpoint/synthesis_model")
Example #20
    # Train/Generate
    if args.train:
        optimizer = create_optimizer(rnn)
        if args.unconditioned:
            if args.batch:
                unconditioned.train_all_random_batch(rnn, optimizer,
                                                     normalized_data)
            else:
                unconditioned.train_all(rnn, optimizer, normalized_data)
        elif args.conditioned:
            if args.batch:
                conditioned.train_all_random_batch(rnn, optimizer,
                                                   normalized_data)
            else:
                conditioned.train_all(rnn, optimizer, normalized_data)
    elif args.generate:
        if args.unconditioned:
            print("Generating a random handwriting sample.")
            strokes = generator.generate_sequence(rnn, 700, 1)
        elif args.conditioned:
            target_sentence = args.output_text
            print("Generating handwriting for text: %s" % target_sentence)
            target_sentence = Variable(torch.from_numpy(
                sentence_to_vectors(target_sentence,
                                    alphabet_dict)).float().cuda(),
                                       requires_grad=False)
            strokes = generator.generate_conditioned_sequence(
                rnn, 2000, target_sentence, 3)
        plot_stroke(strokes, "output.png")
Example #21
def train(model, train_loader, valid_loader, batch_size, n_epochs, lr,
          patience, step_size, device, model_type, save_path):
    model_path = save_path + "model_" + model_type + ".pt"
    model = model.to(device)

    if os.path.isfile(model_path):
        model.load_state_dict(torch.load(model_path))
        print(f"[ACTION] Loaded model weights from '{model_path}'")
    else:
        print("[INFO] No saved weights found, training from scratch.")
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=0.1)

    train_losses = []
    valid_losses = []
    best_loss = math.inf
    best_epoch = 0
    k = 0
    for epoch in range(n_epochs):
        start_time = time.time()
        print(f"[Epoch {epoch + 1}/{n_epochs}]")
        print("[INFO] Training Model.....")
        train_loss = train_epoch(model, optimizer, epoch, train_loader, device,
                                 model_type)

        print("[INFO] Validating Model....")
        valid_loss = validation(model, valid_loader, device, epoch, model_type)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        print(f"[RESULT] Epoch {epoch + 1}/{n_epochs}"
              f"\tTrain loss: {train_loss:.3f}\tVal loss: {valid_loss:.3f}")

        if step_size != -1:
            scheduler.step()

        if valid_loss < best_loss:
            best_loss = valid_loss
            best_epoch = epoch + 1
            print("[SAVE] Saving weights at epoch: {}".format(epoch + 1))
            torch.save(model.state_dict(), model_path)
            if model_type == "prediction":
                gen_seq = generate_unconditional_seq(model_path,
                                                     700,
                                                     device,
                                                     bias=10.0,
                                                     style=None,
                                                     prime=False)

            else:
                gen_seq = generate_conditional_sequence(
                    model_path,
                    "Hello world!",
                    device,
                    train_loader.dataset.char_to_id,
                    train_loader.dataset.idx_to_char,
                    bias=10.0,
                    prime=False,
                    prime_seq=None,
                    real_text=None)

            # denormalize the generated offsets using train set mean and std
            gen_seq = data_denormalization(Global.train_mean, Global.train_std,
                                           gen_seq)

            # plot the sequence
            plot_stroke(
                gen_seq[0],
                save_name=save_path + model_type + "_seq_" + str(best_epoch) +
                ".png",
            )
            k = 0
        elif k > patience:
            print("Best model was saved at epoch: {}".format(best_epoch))
            print("Early stopping at epoch {}".format(epoch))
            break
        else:
            k += 1
        total_time_taken = time.time() - start_time
        print('Time taken per epoch: {:.2f}s\n'.format(total_time_taken))
Example #22
from argparse import ArgumentParser

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('text')
    parser.add_argument('-m', '--model', dest='model', default='conditional')
    parser.add_argument('-e', '--epoch', dest='epoch', type=int, required=True)
    parser.add_argument('-b',
                        '--sample-bias',
                        dest='sample_bias',
                        type=float,
                        default=0.0)
    parser.add_argument('-l',
                        '--stroke-length',
                        dest='stroke_length',
                        type=int,
                        default=None)
    parser.add_argument('-o', '--output', dest='output', default=None)
    args = parser.parse_args()

    stroke = generate_conditionally(text=args.text,
                                    model=args.model,
                                    epoch=args.epoch,
                                    sample_bias=args.sample_bias,
                                    stroke_length=args.stroke_length
                                    or len(args.text.replace(' ', '')) * 30)

    print('Stroke length: {}'.format(len(stroke)))

    plot_stroke(stroke, save_name=args.output)
Example #23
def train(model, train_loader, valid_loader, batch_size, n_epochs, lr,
          patience, step_size, device, model_type, save_path):
    model_path = save_path + "best_model_" + model_type + ".pt"
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=0.1)

    train_losses = []
    valid_losses = []
    best_loss = math.inf
    best_epoch = 0
    k = 0
    for epoch in range(n_epochs):
        print("training.....")
        train_loss = train_epoch(model, optimizer, epoch, train_loader, device,
                                 model_type)

        print("validation....")
        valid_loss = validation(model, valid_loader, device, epoch, model_type)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        print('Epoch {}: Train: avg. loss: {:.3f}'.format(
            epoch + 1, train_loss))
        print('Epoch {}: Valid: avg. loss: {:.3f}'.format(
            epoch + 1, valid_loss))

        if step_size != -1:
            scheduler.step()

        if valid_loss < best_loss:
            best_loss = valid_loss
            best_epoch = epoch + 1
            print('Saving best model at epoch {}'.format(epoch + 1))
            torch.save(model.state_dict(), model_path)
            if model_type == "prediction":
                gen_seq = generate_unconditional_seq(model_path,
                                                     700,
                                                     device,
                                                     bias=10.0,
                                                     style=None,
                                                     prime=False)

            else:
                gen_seq, phi = generate_conditional_sequence(
                    model_path,
                    "Hello world!",
                    device,
                    train_loader.dataset.char_to_id,
                    train_loader.dataset.idx_to_char,
                    bias=10.,
                    prime=False,
                    prime_seq=None,
                    real_text=None)

                plt.imshow(phi, cmap='viridis', aspect='auto')
                plt.colorbar()
                plt.xlabel("time steps")
                plt.yticks(np.arange(phi.shape[1]),
                           list("Hello world!  "),
                           rotation='horizontal')
                plt.margins(0.2)
                plt.subplots_adjust(bottom=0.15)
                plt.savefig(save_path + "heat_map" + str(best_epoch) + ".png")
                plt.close()
            # denormalize the generated offsets using train set mean and std
            gen_seq = data_denormalization(Global.train_mean, Global.train_std,
                                           gen_seq)

            # plot the sequence
            plot_stroke(gen_seq[0],
                        save_name=save_path + model_type + "_seq_" +
                        str(best_epoch) + ".png")
            k = 0
        elif k > patience:
            print("Best model was saved at epoch: {}".format(best_epoch))
            print("Early stopping at epoch {}".format(epoch))
            break
        else:
            k += 1
Example #24
    model_path = args.model_path
    model = args.model

    train_dataset = HandwritingDataset(args.data_path,
                                       split="train",
                                       text_req=args.text_req)

    if args.prime and args.file_path:
        style = np.load(args.file_path + "style.npy",
                        allow_pickle=True,
                        encoding="bytes").astype(np.float32)
        with open(args.file_path + "inpText.txt") as file:
            texts = file.read().splitlines()
        real_text = texts[0]
        # plot the sequence
        plot_stroke(style, save_name=args.save_path / "style.png")
        print(real_text)
        mean, std, _ = data_normalization(style)
        style = torch.from_numpy(style).unsqueeze(0).to(device)
        print(style.shape)
        ytext = real_text + " " + args.char_seq + "  "
    elif args.prime:
        strokes = np.load(args.data_path + "strokes.npy",
                          allow_pickle=True,
                          encoding="bytes")
        with open(args.data_path + "sentences.txt") as file:
            texts = file.read().splitlines()
        idx = 3949  # np.random.randint(0, len(strokes))
        print("Prime style index: ", idx)
        real_text = texts[idx]
        style = strokes[idx]
Example #25
                 num_workers=1,
                 collate_fn=pad_collate):
        self.data_dir = data_dir
        self.dataset = HandWritingDataset(data_dir)
        super().__init__(self.dataset, batch_size, shuffle, validation_split,
                         num_workers, collate_fn)


if __name__ == '__main__':
    data_directory = '../data'
    dataset = HandWritingDataset(data_directory)
    dataloader = HandWritingDataLoader(data_directory,
                                       batch_size=32,
                                       shuffle=False)

    # Test dataset
    print('Test dataset')
    print('size of the dataset: ', len(dataset))
    sent, stroke = dataset[0]
    print('sentence 1: ', dataset.tensor2sentence(sent))
    plot_stroke(stroke.numpy())

    # Test dataloader
    print('\nTest dataloader')
    batch = next(iter(dataloader))
    (sentences, sentences_mask, strokes, strokes_mask) = batch
    print('shape sentences:      ', sentences.shape)
    print('shape sentences_mask: ', sentences_mask.shape)
    print('shape strokes:        ', strokes.shape)
    print('shape strokes_mask:   ', strokes_mask.shape)
Example #26
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 26 23:01:13 2021

@author: prajw
"""
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from data import DataSynthesis
from models import HandWritingPrediction, HandWritingSynthesis
from utils import plot_stroke

d = DataSynthesis()
generator = d.batch_generator(shuffle=False)
_, sentence, target = next(generator)

hws = HandWritingSynthesis()
hws.make_model(load_weights='models/trained/model_synthesis_overfit.h5')

sentence = tf.dtypes.cast(d.prepare_text('independent expert body'), float)
strokes, windows, phis, kappas = hws.infer(sentence, seed=18)

plot_stroke(strokes)
Example #27
        model.eval()

        point, hidden = model(point, hidden)
        point[:, 0] = torch.sigmoid(point[:, 0])
        #         point[:, 0] = torch.Tensor([1 if x > 0.5 else 0 for x in point[:, 0]])
        point[:, 0] = torch.Tensor(
            np.random.binomial(1, point[:, 0].data.cpu().numpy())).to(device)
        stroke[k] = point
        point = point.unsqueeze(0)

    return np.array(stroke.squeeze().data.cpu())


if __name__ == "__main__":

    import argparse

    arg_parser = argparse.ArgumentParser(
        description="Generate a random stroke")
    arg_parser.add_argument("--random_seed",
                            "-r",
                            dest="random_seed",
                            required=False,
                            help="The random seed")

    args = arg_parser.parse_args()
    random_seed = args.random_seed if args.random_seed else 1
    stroke = generate_unconditionally(random_seed)
    print('Random seed:', random_seed)
    plot_stroke(stroke)