import os

import numpy as np

# The model classes, batch generators, decode() and plotting helpers used
# below are assumed to be importable from this project's own modules.


def train(model_folder, num_epochs, learning_rate, batch_size, stroke_length,
          steps_per_char, save_every):
    """Train the conditional (character-conditioned) stroke model."""
    # Each character spans `steps_per_char` RNN steps, so the character
    # window length follows from the stroke window length.
    char_seq_length = stroke_length // steps_per_char

    train_data, val_data, _, metadata = get_preprocessed_data_splits()
    train_data_gen = conditional_batch_generator(train_data,
                                                 stroke_length=stroke_length,
                                                 char_length=char_seq_length,
                                                 batch_size=batch_size)
    epoch_size = len(train_data[0]) // batch_size

    model = ConditionalStrokeModel(model_folder,
                                   learning_rate=learning_rate,
                                   batch_size=batch_size,
                                   rnn_steps=stroke_length,
                                   is_train=True,
                                   char_dict=metadata['char_dict'],
                                   char_seq_len=char_seq_length)

    for epoch in range(num_epochs):
        for _ in range(epoch_size):
            input_batch, target_batch, char_batch = next(train_data_gen)
            model.train(input_batch, target_batch, char_batch)

        # Periodically checkpoint, sample from the model, and plot the result.
        if epoch != 0 and epoch % save_every == 0:
            model.save()

            # Reload in inference mode (batch of 1, one step at a time) to
            # generate a sample stroke sequence.
            model = ConditionalStrokeModel.load(model_folder,
                                                batch_size=1,
                                                rnn_steps=1,
                                                is_train=False,
                                                char_seq_len=30)
            strokes = decode(model)

            plotting.plot_stroke(
                strokes, os.path.join(model_folder, 'test{}'.format(epoch)))

            # Restore the training-mode model from the checkpoint and resume.
            model = ConditionalStrokeModel.load(model_folder,
                                                learning_rate=learning_rate,
                                                batch_size=batch_size,
                                                rnn_steps=stroke_length,
                                                is_train=True,
                                                char_seq_len=char_seq_length)

        _print_loss(model, train_data, epoch, stroke_length, char_seq_length,
                    batch_size, 'Train')
        _print_loss(model, val_data, epoch, stroke_length, char_seq_length,
                    batch_size, '    Val')
    model.save()
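

# A minimal usage sketch for the conditional trainer above; the folder name
# and hyperparameter values are illustrative assumptions, not settings taken
# from this project.
if __name__ == '__main__':
    train(model_folder='models/conditional',
          num_epochs=30,
          learning_rate=0.001,
          batch_size=32,
          stroke_length=300,   # 300 steps / 25 steps per char -> 12 chars
          steps_per_char=25,
          save_every=5)
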
def train(model_folder, learning_rate, num_epochs, num_layers, num_steps,
          hidden_size, lr_decay, batch_size):
    """Train the character-level language model."""
    train_data, _, test_data, metadata = get_preprocessed_data_splits(
        sentences_to_int=True)

    train_data_gen = character_batch_generator(train_data,
                                               batch_size=batch_size)

    vocab_size = len(metadata['char_dict'])
    epoch_size = len(train_data[0]) // batch_size

    model = CharacterGenModel(model_folder,
                              batch_size,
                              num_steps,
                              vocab_size,
                              num_layers,
                              hidden_size,
                              char_dict=metadata['char_dict'])

    for i in range(num_epochs):
        # Time-based decay; because the update is multiplicative, the decay
        # compounds from epoch to epoch.
        learning_rate *= 1. / (1. + lr_decay * i)
        model.sess.run(model.lr_update, {model.new_lr: learning_rate})

        print('Epoch {}, lr {}'.format(i, model.sess.run(model.lr)))

        total_cost = 0
        for _ in range(epoch_size):
            input_batch, target_batch = next(train_data_gen)
            cost = model.train(input_batch, target_batch)[0]
            total_cost += cost

        model.save()
        train_perplexity = np.exp(total_cost / epoch_size)
        print('Epoch: {} Train perplexity: {}'.format(i, train_perplexity))

    # Evaluate on the held-out test set. Build the generator once, outside
    # the loop, so successive batches are drawn; the original re-created the
    # generator each iteration and so re-read the first batch every time.
    total_cost = 0
    test_epochs = len(test_data[0]) // batch_size
    test_data_gen = character_batch_generator(test_data, batch_size=batch_size)
    for _ in range(test_epochs):
        test_input, test_target = next(test_data_gen)
        cost = model.sess.run(model.cost, {
            model.inputs: test_input,
            model.targets: test_target
        })
        total_cost += cost

    # The cost is an average cross-entropy in nats, so perplexity is its exp
    # and bits-per-character divides by ln(2).
    test_perplexity = np.exp(total_cost / test_epochs)
    print('Test perplexity: {}'.format(test_perplexity))
    print('BPC: {}'.format(total_cost / test_epochs / np.log(2)))
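

# A worked check of the metrics printed above: since perplexity is computed
# as np.exp of the average cost, the cost is a cross-entropy in nats, and
# bits-per-character is that value divided by ln(2). The cost used here is a
# made-up example.
import numpy as np

avg_cost_nats = 1.2                 # hypothetical average test cost in nats
perplexity = np.exp(avg_cost_nats)  # exp(1.2) ~= 3.32
bpc = avg_cost_nats / np.log(2)     # 1.2 / ln(2) ~= 1.73 bits per character
print('perplexity={:.2f}, bpc={:.2f}'.format(perplexity, bpc))
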
def train(model_folder, num_epochs, learning_rate, batch_size, stroke_length,
          save_every):
    """Train the unconditional stroke model."""
    train_data, val_data, _, _ = get_preprocessed_data_splits()
    train_data_gen = stroke_batch_generator(train_data,
                                            length=stroke_length,
                                            batch_size=batch_size)

    model = UnconditionalStrokeModel(model_folder,
                                     learning_rate=learning_rate,
                                     batch_size=batch_size,
                                     rnn_steps=stroke_length,
                                     is_train=True)
    epoch_size = len(train_data[0]) // batch_size

    for epoch in range(num_epochs):
        for _ in range(epoch_size):
            input_batch, target_batch = next(train_data_gen)
            model.train(input_batch, target_batch)

        # Periodically checkpoint, sample from the model, and plot the result.
        if epoch != 0 and epoch % save_every == 0:
            model.save()

            # Reload in inference mode (batch of 1, one step at a time) to
            # generate a sample stroke sequence.
            model = UnconditionalStrokeModel.load(model_folder,
                                                  batch_size=1,
                                                  rnn_steps=1,
                                                  is_train=False)

            strokes = decode(model, 650, epoch)
            plotting.plot_stroke(
                strokes, os.path.join(model_folder, 'test{}'.format(epoch)))

            # Restore the training-mode model from the checkpoint and resume.
            model = UnconditionalStrokeModel.load(model_folder,
                                                  learning_rate=learning_rate,
                                                  batch_size=batch_size,
                                                  rnn_steps=stroke_length,
                                                  is_train=True)

        _print_loss(model, train_data, epoch, stroke_length, batch_size,
                    'Train')
        _print_loss(model, val_data, epoch, stroke_length, batch_size,
                    '    Val')
    model.save()
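

# A minimal usage sketch for the unconditional trainer above; the values are
# illustrative assumptions, not settings shipped with this project.
if __name__ == '__main__':
    train(model_folder='models/unconditional',
          num_epochs=30,
          learning_rate=0.001,
          batch_size=32,
          stroke_length=300,
          save_every=5)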