import numpy as np
import symjax
import symjax.tensor


def test_reset():
    symjax.current_graph().reset()
    w = symjax.tensor.Variable(1.0, name="w", dtype="float32")
    x = symjax.tensor.Variable(2.0, name="x", dtype="float32")
    f = symjax.function(outputs=[w, x], updates={w: w + 1, x: x + 1})

    # each call returns the current values, then applies the updates
    for i in range(10):
        print(i)
        assert np.array_equal(np.array(f()), np.array([1, 2]) + i)

    # reset only the w variable
    symjax.reset_variables("*w")
    assert np.array_equal(np.array(f()), np.array([1, 2 + i + 1]))

    # reset all variables
    symjax.reset_variables("*")
    assert np.array_equal(np.array(f()), np.array([1, 2]))
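# Aside: the patterns passed to reset_variables above are shell-style
# wildcards matched against variable names. A minimal standalone sketch of
# that matching rule using Python's fnmatch module (an assumption about the
# pattern semantics, not symjax's actual internals):
from fnmatch import fnmatch

names = ["w", "x", "layer2/w"]
print([n for n in names if fnmatch(n, "*w")])  # ['w', 'layer2/w']
print([n for n in names if fnmatch(n, "*")])   # ['w', 'x', 'layer2/w']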
import symjax
import symjax.tensor as T
import matplotlib.pyplot as plt

# GRADIENT DESCENT
z = T.Variable(3.0, dtype="float32")
loss = (z - 1) ** 2
g_z = symjax.gradients(loss, [z])[0]
symjax.current_graph().add_updates({z: z - 0.1 * g_z})

train = symjax.function(outputs=[loss, z], updates=symjax.get_updates())

losses = list()
values = list()
for i in range(200):
    # restart from the initial value every 50 updates
    if (i + 1) % 50 == 0:
        symjax.reset_variables("*")
    a, b = train()
    losses.append(a)
    values.append(b)

plt.figure()

plt.subplot(121)
plt.plot(losses, "-x")
plt.ylabel("loss")
plt.xlabel("number of gradient updates")

plt.subplot(122)
plt.plot(values, "-x")
plt.axhline(1, c="red")
plt.ylabel("value")
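# For reference, the update above is z <- z - 0.1 * d/dz (z - 1)^2
# = z - 0.2 * (z - 1), so the gap to the optimum z = 1 shrinks by a
# factor of 0.8 per step. A pure-NumPy sketch of the same dynamics,
# independent of symjax:
import numpy as np

z_val = 3.0
for _ in range(50):
    z_val -= 0.2 * (z_val - 1.0)
print(np.isclose(z_val, 1.0, atol=1e-4))  # True: z has converged to 1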
})

# pretend we train for a while
for i in range(4):
    print(train())
# [0.]
# [8.]
# [16.]
# [24.]

# now say we wanted to reset the variables and retrain, we can do so
# either with g, as it contains all the variables
g.reset()
# or we can do
symjax.reset_variables("*")
# or, if we wanted to only reset (say) the variables from layer2
symjax.reset_variables("*layer2*")

# now that all has been reset, let's pretend to train for a while again
for i in range(2):
    print(train())
# [0.]
# [8.]

# resetting is nice, but we might also want to save the model parameters, to
# keep training later or do some other analyses. We can do so as follows:
g.save_variables("model1_saved")
# this would save all variables as they are contained in g. Now say we want to
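# (a minimal sketch of restoring the saved values, assuming a load_variables
# counterpart to save_variables on the graph object; the exact method name
# is an assumption, check the symjax Graph API)
g.load_variables("model1_saved")  # ASSUMED API: restores the values saved above
for i in range(2):
    print(train())  # training resumes from the restored parameters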