Example #1
0
def test_reset():
    """Check that ``symjax.reset_variables`` restores variables to their initial values.

    Two scalar variables are incremented by one on every call of ``f``. After
    ``n_steps`` calls, resetting with a name pattern must restore only the
    matching variable(s), while the others keep their accumulated value.
    """
    symjax.current_graph().reset()
    w = symjax.tensor.Variable(1.0, name="w", dtype="float32")
    x = symjax.tensor.Variable(2.0, name="x", dtype="float32")
    f = symjax.function(outputs=[w, x], updates={w: w + 1, x: x + 1})

    n_steps = 10
    initial = np.array([1, 2])
    for step in range(n_steps):
        # The outputs are the values held *before* this call's updates are
        # applied (the very first call returns the initial values).
        assert np.array_equal(np.array(f()), initial + step)

    # Reset only the w variable; x keeps all n_steps increments. Use the
    # explicit step count instead of the leaked loop variable.
    symjax.reset_variables("*w")
    assert np.array_equal(np.array(f()), np.array([1, 2 + n_steps]))
    # Reset all variables.
    symjax.reset_variables("*")
    assert np.array_equal(np.array(f()), initial)
Example #2
0
import matplotlib.pyplot as plt
# `import symjax.tensor as T` binds only `T`; the top-level `symjax` name used
# below (gradients, function, reset_variables, ...) must be imported explicitly.
import symjax
import symjax.tensor as T

# GRADIENT DESCENT
# Minimise (z - 1) ** 2 starting from z = 3 with a fixed step size of 0.1.
z = T.Variable(3.0, dtype="float32")
loss = (z - 1) ** 2
g_z = symjax.gradients(loss, [z])[0]
symjax.current_graph().add_updates({z: z - 0.1 * g_z})

train = symjax.function(outputs=[loss, z], updates=symjax.get_updates())

losses = list()
values = list()
for step in range(200):
    # Every 50 steps, put z back at its initial value to restart the descent.
    if (step + 1) % 50 == 0:
        symjax.reset_variables("*")
    a, b = train()
    losses.append(a)
    values.append(b)

plt.figure()

plt.subplot(121)
plt.plot(losses, "-x")
plt.ylabel("loss")
plt.xlabel("number of gradient updates")

plt.subplot(122)
plt.plot(values, "-x")
plt.axhline(1, c="red")
plt.ylabel("value")
plt.xlabel("number of gradient updates")
Example #3
0
                        })

# Simulate a short training run.
for _ in range(4):
    print(train())

# [0.]
# [8.]
# [16.]
# [24.]

# To reset the variables and train again from scratch, one option is to go
# through the graph g, since it holds all the variables:
g.reset()
# Another option is the global helper with a name pattern matching everything:
symjax.reset_variables("*")
# A narrower pattern resets only a subset, e.g. just the variables of layer2:
symjax.reset_variables("*layer2*")

# Every variable is back at its initial value, so training starts over.
for _ in range(2):
    print(train())

# [0.]
# [8.]

# Resetting is handy, but we may also want to save the model parameters, to
# resume training later or run other analyses. That is done as follows:
g.save_variables("model1_saved")
# this would save all variables as they are contained in g. Now say we want to