Exemplo n.º 1
0
def predict_char(prefix, num_predicts, model, vocab):
    """Generate text by greedily extending ``prefix`` one token at a time.

    The model state is reset for a batch of one, warmed up on ``prefix``, and
    then run for ``num_predicts`` further steps, each time feeding back the
    index of the highest-scoring token from the previous step.

    Args:
        prefix: Non-empty sequence of tokens (e.g. a string of characters)
            that ``vocab`` can map to indices.
        num_predicts: Number of tokens to generate after the prefix.
        model: Callable exposing ``init_state(batch_size=...)``. It is called
            with the one-hot encoding of a ``(1, 1)`` index array and must
            return scores whose ``argmax(axis=1)`` holds exactly one element.
        vocab: Object supporting ``len(vocab)``, ``vocab[token] -> index`` and
            ``vocab.idx_to_token[index] -> token``.

    Returns:
        The prefix followed by the generated tokens, joined into one string.

    Raises:
        IndexError: If ``prefix`` is an empty sequence.
    """
    model.init_state(batch_size=1)
    vocab_size = len(vocab)
    outputs = [vocab[prefix[0]]]

    def get_input():
        # Always encode the most recently appended token as a (1, 1) batch.
        return one_hot(jn.array([outputs[-1]]).reshape(1, 1), vocab_size)

    for token in prefix[1:]:  # Warm up the state with the prefix
        model(get_input())
        outputs.append(vocab[token])
    for _ in range(num_predicts):  # Predict num_predicts steps
        scores = model(get_input())
        # Collapse to a 0-d array before int(): converting a 1-d array (even
        # one of size 1) to a Python scalar is deprecated in NumPy and raises
        # TypeError in recent JAX releases. reshape(()) still fails loudly if
        # the model returned more than one prediction.
        outputs.append(int(scores.argmax(axis=1).reshape(())))
    return ''.join(vocab.idx_to_token[i] for i in outputs)
Exemplo n.º 2
0
)
# Training/evaluation loop: for every epoch, run train_op on randomly sampled
# mini-batches, then measure accuracy over the test set.
# NOTE(review): this excerpt looks truncated -- `tensorboard` is opened here
# but never used in the visible lines; presumably `summary` is written to it
# further down. Confirm against the full script.
with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard:
    for epoch in range(num_train_epochs):
        # Train one epoch. A fresh Summary is created per epoch, so values
        # logged below do not carry over between epochs.
        summary = Summary()
        # Progress-bar range stepping from 0 to train_size in increments of
        # `batch` (presumably tqdm's trange -- confirm the import).
        loop = trange(0,
                      train_size,
                      batch,
                      leave=False,
                      unit='img',
                      unit_scale=batch,
                      desc='Epoch %d/%d' % (1 + epoch, num_train_epochs))
        for it in loop:
            # `it` is not used: each step draws `batch` random indices in
            # [0, train_size) rather than slicing sequentially, so (assuming
            # `np` is NumPy) an example may be picked more than once.
            sel = np.random.randint(size=(batch, ), low=0, high=train_size)
            x, xl = train.image[sel], train.label[sel]
            # Presumably turns integer class labels into one-hot vectors of
            # width nclass -- confirm against the one_hot helper in use.
            xl = one_hot(xl, nclass)
            v = train_op(x, xl)
            # Only the first element of train_op's result is logged.
            summary.scalar('losses/xe', float(v[0]))

        # Eval: walk the test set in consecutive slices of `batch` items and
        # count the predictions whose argmax over axis 1 equals the label.
        accuracy = 0
        for it in trange(0,
                         test.image.shape[0],
                         batch,
                         leave=False,
                         desc='Evaluating'):
            x = test.image[it:it + batch]
            xl = test.label[it:it + batch]
            accuracy += (np.argmax(predict(x), axis=1) == xl).sum()
        # Convert the count of correct predictions into a fraction of the
        # test set; it is logged as a percentage.
        accuracy /= test.image.shape[0]
        summary.scalar('eval/accuracy', 100 * accuracy)
Exemplo n.º 3
0
# Hyperparameters for the RNN experiment. `batch_size` and `num_steps` are
# used below but defined outside this excerpt.
num_epochs = 500
num_hiddens = 256
lr = 0.0001  # presumably the learning rate -- confirm at the use site
theta = 1  # NOTE(review): meaning not visible here (possibly a gradient-clipping threshold) -- confirm
# Show which devices JAX can see from this process.
print(jax.local_devices())

# Load the data iterator and vocabulary; 'char' presumably selects
# character-level tokens -- confirm against load_shakespeare.
train_iter, vocab = load_shakespeare(batch_size, num_steps, 'char')
vocab_size = len(vocab)

# The vocabulary size is passed twice; presumably these are the model's input
# and output widths -- confirm against the RNN constructor.
model = RNN(num_hiddens, vocab_size, vocab_size)
model_vars = model.vars()
model.init_state(batch_size)

# Sample call for forward pass
# X holds placeholder ids 0 .. batch_size * num_steps - 1 (not real data),
# laid out as (num_steps, batch_size) after the transpose.
# NOTE(review): these ids can exceed vocab_size; how one_hot treats
# out-of-range ids depends on the helper -- this call only checks shapes.
X = jn.arange(batch_size * num_steps).reshape(batch_size, num_steps).T
X_one_hot = one_hot(X, vocab_size)
Z = model(X_one_hot)
print("X_one_hot.shape:", X_one_hot.shape)
print("Z.shape:", Z.shape)


def predict_char(prefix, num_predicts, model, vocab):
    model.init_state(batch_size=1)
    outputs = [vocab[prefix[0]]]
    get_input = lambda: one_hot(
        jn.array([outputs[-1]]).reshape(1, 1), len(vocab))
    for y in prefix[1:]:  # Warmup state with prefix
        model(get_input())
        outputs.append(vocab[y])
    for _ in range(num_predicts):  # Predict num_predicts steps
        Y = model(get_input())