def predict_char(prefix, num_predicts, model, vocab):
    """Greedily generate tokens from ``model`` after priming it with ``prefix``.

    The model state is reset for a batch of one. Each token of ``prefix`` except
    the last is fed in to warm up the recurrent state. Then ``num_predicts``
    tokens are produced by repeatedly feeding the most recent token and taking
    the argmax of the model output.

    Args:
        prefix: Non-empty sequence of tokens (e.g. a string) to condition on.
        num_predicts: Number of tokens to generate after ``prefix``.
        model: Object exposing ``init_state(batch_size=...)`` and a call that
            maps a one-hot input to scores; the argmax is taken over axis 1.
        vocab: Supports ``vocab[token] -> index``, ``len(vocab)`` and
            ``vocab.idx_to_token[index] -> token``.

    Returns:
        ``prefix`` followed by the generated tokens, joined into one string.

    Raises:
        ValueError: If ``prefix`` is empty (there is no token to seed from).
    """
    if not prefix:
        # Validate before resetting the model so bad input leaves state alone.
        raise ValueError('predict_char requires a non-empty prefix')
    model.init_state(batch_size=1)
    vocab_size = len(vocab)
    outputs = [vocab[prefix[0]]]

    def get_input():
        # One-hot encode the most recent token as a (1, 1) index array.
        return one_hot(jn.array([outputs[-1]]).reshape(1, 1), vocab_size)

    # Warm-up: feed each known token and record the next prefix token (the
    # model's output is discarded). The last prefix token is appended here but
    # not fed; the generation loop below feeds it first.
    for token in prefix[1:]:
        model(get_input())
        outputs.append(vocab[token])

    for _ in range(num_predicts):
        scores = model(get_input())
        # reshape(()) yields a 0-d array; int() on a shape-(1,) array is
        # deprecated in NumPy >= 1.25 and rejected by newer JAX releases.
        outputs.append(int(scores.argmax(axis=1).reshape(())))
    return ''.join([vocab.idx_to_token[i] for i in outputs])
) with SummaryWriter(os.path.join(logdir, 'tb')) as tensorboard: for epoch in range(num_train_epochs): # Train one epoch summary = Summary() loop = trange(0, train_size, batch, leave=False, unit='img', unit_scale=batch, desc='Epoch %d/%d' % (1 + epoch, num_train_epochs)) for it in loop: sel = np.random.randint(size=(batch, ), low=0, high=train_size) x, xl = train.image[sel], train.label[sel] xl = one_hot(xl, nclass) v = train_op(x, xl) summary.scalar('losses/xe', float(v[0])) # Eval accuracy = 0 for it in trange(0, test.image.shape[0], batch, leave=False, desc='Evaluating'): x = test.image[it:it + batch] xl = test.label[it:it + batch] accuracy += (np.argmax(predict(x), axis=1) == xl).sum() accuracy /= test.image.shape[0] summary.scalar('eval/accuracy', 100 * accuracy)
num_epochs = 500 num_hiddens = 256 lr = 0.0001 theta = 1 print(jax.local_devices()) train_iter, vocab = load_shakespeare(batch_size, num_steps, 'char') vocab_size = len(vocab) model = RNN(num_hiddens, vocab_size, vocab_size) model_vars = model.vars() model.init_state(batch_size) # Sample call for forward pass X = jn.arange(batch_size * num_steps).reshape(batch_size, num_steps).T X_one_hot = one_hot(X, vocab_size) Z = model(X_one_hot) print("X_one_hot.shape:", X_one_hot.shape) print("Z.shape:", Z.shape) def predict_char(prefix, num_predicts, model, vocab): model.init_state(batch_size=1) outputs = [vocab[prefix[0]]] get_input = lambda: one_hot( jn.array([outputs[-1]]).reshape(1, 1), len(vocab)) for y in prefix[1:]: # Warmup state with prefix model(get_input()) outputs.append(vocab[y]) for _ in range(num_predicts): # Predict num_predicts steps Y = model(get_input())