import time

import numpy as np

import cgt
from cgt import nn


def test_adagrad():
    # Run 10 Adagrad steps on two identical shared scalars and check that they
    # stay in lockstep and that their trajectory matches reference values
    # precomputed with Torch.
    results = []
    for scale in scales:
        A = cgt.shared(1.0)
        B = cgt.shared(1.0)
        updates = nn.adagrad(f(A, scale) + f(B, scale), [A, B], learning_rate=0.1)
        do_update = cgt.function([], [], updates=updates)
        for _ in range(10):
            do_update()
        assert np.allclose(A.op.get_value(), B.op.get_value())
        results.append(A.op.get_value().copy())
    assert np.allclose(results, torch_values['adagrad'])
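# The test above depends on module-level fixtures (`scales`, `f`, `torch_values`)
# defined elsewhere in the test file. Below is a minimal sketch of what they
# might look like; the quadratic objective and the scale values are assumptions,
# and the Torch-derived reference trajectories are not reproduced here.
scales = [0.1, 1.0, 10.0]  # assumed loss scalings exercised by each optimizer test

def f(x, scale):
    # Assumed objective: a scaled quadratic, so A and B follow identical
    # trajectories toward zero and the optimizer's behaviour is easy to check.
    return scale * x * x

# Placeholder for reference trajectories precomputed with Torch's optim package;
# the actual numbers live elsewhere and are deliberately not fabricated here.
torch_values = {'adagrad': []}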
def main():
    # Symbolic inputs: a batch of 2212-dimensional feature rows and integer class labels.
    X = cgt.matrix(name='data', dtype=cgt.floatX, fixed_shape=(None, 2212))
    y = cgt.vector("y", dtype='i8')

    # Negative log-likelihood loss, minimized with Adagrad.
    model = build_nn(X)
    loss = -cgt.mean(categorical.loglik(y, model))
    updates = nn.adagrad(loss, nn.get_parameters(loss), 0.01)

    # Error rate and loss for evaluation (no parameter updates).
    y_nodrop = cgt.argmax(model, axis=1)
    cost_nodrop = -cgt.mean(categorical.loglik(y, model))
    err_nodrop = cgt.cast(cgt.not_equal(y_nodrop, y), cgt.floatX).mean()

    # Compiled functions: `train` applies the Adagrad updates, `computeloss` only evaluates.
    train = cgt.function(inputs=[X, y], outputs=[], updates=updates)
    computeloss = cgt.function(inputs=[X, y], outputs=[err_nodrop, cost_nodrop])

    batch_size = 20
    Xdata, ydata = load_data()

    # 5200 training rows, the remaining 373 held out for testing.
    Xtrain = Xdata[0:5200]
    ytrain = ydata[0:5200]
    Xtest = Xdata[5200:5573]
    ytest = ydata[5200:5573]

    # Shuffle the training set once before the epoch loop.
    sortinds = np.random.permutation(5200)
    Xtrain = Xtrain[sortinds]
    ytrain = ytrain[sortinds]

    print fmt_row(10, ["Epoch", "Train NLL", "Train Err", "Test NLL", "Test Err", "Epoch Time"])
    for i_epoch in xrange(20):
        tstart = time.time()
        # Minibatch pass over the training set.
        for start in xrange(0, Xtrain.shape[0], batch_size):
            end = start + batch_size
            train(Xtrain[start:end], ytrain[start:end])
        elapsed = time.time() - tstart
        trainerr, trainloss = computeloss(Xtrain[:len(Xtest)], ytrain[:len(Xtest)])
        testerr, testloss = computeloss(Xtest, ytest)
        print fmt_row(10, [i_epoch, trainloss, trainerr, testloss, testerr, elapsed])
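# `build_nn`, `load_data`, `fmt_row`, and `categorical` are defined elsewhere in
# the script. The sketch below shows one plausible `build_nn` and `fmt_row`,
# assuming CGT's `nn.Affine`, `nn.rectify`, and `nn.softmax` helpers; the hidden
# size and number of classes are illustrative guesses, not the script's actual
# architecture.
def build_nn(X, n_hid=256, n_classes=10):
    # Two-layer MLP over the 2212-dimensional input, returning class probabilities
    # suitable for categorical.loglik.
    h = nn.rectify(nn.Affine(2212, n_hid)(X))
    probs = nn.softmax(nn.Affine(n_hid, n_classes)(h))
    return probs

def fmt_row(width, values):
    # Pad every column to a fixed width so the per-epoch log lines up.
    return " | ".join(str(x)[:width].ljust(width) for x in values)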