def test_sgd_momentum():
    data, data_shape, label, label_shape = get_input()
    mlp = MLP()
    opt = SGD(mlp.parameters(), lr=0.01, momentum=0.9)

    # Maintain a NumPy mirror of each parameter's momentum slot.
    slots = TensorDict()
    for param in mlp.parameters():
        slots[param] = np.zeros(param.shape).astype(np.float32)

    for _ in range(3):
        data.set_value(np.random.random(data_shape).astype(np.float32))
        label.set_value(np.random.randint(0, 10, label_shape))
        pred = mlp(data)
        loss = F.square_loss(pred, label.reshape(-1, 1))
        opt.zero_grad()
        opt.backward(loss)

        orig_params = TensorDict()
        grads = TensorDict()
        for param in mlp.parameters():
            orig_params[param] = np.copy(param.numpy())
            grads[param] = np.copy(param.grad.numpy())
        opt.step()

        # Replay the momentum update in NumPy and compare against the optimizer.
        for param in mlp.parameters():
            slot = slots[param]
            orig_param = orig_params[param]
            slot *= 0.9
            slot -= param.grad.numpy() * 0.01
            assertTensorClose(param.numpy(), orig_param + slot)
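# A minimal NumPy sketch of the update rule the assertions above replay
# (a hypothetical reference helper, not part of the test suite or of
# MegEngine's API), assuming the classic SGD-with-momentum form:
#     slot <- momentum * slot - lr * grad
#     param <- param + slot
def _reference_momentum_step(param, grad, slot, lr=0.01, momentum=0.9):
    slot = momentum * slot - lr * grad
    return param + slot, slot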
def test_sgd_simple():
    data, data_shape, label, label_shape = get_input()
    mlp = MLP()
    opt = SGD(mlp.parameters(), lr=0.01, weight_decay=0.1)

    for idx in range(3):
        data.set_value(np.random.random(data_shape).astype(np.float32))
        label.set_value(np.random.randint(0, 10, label_shape))
        pred = mlp(data)
        loss = F.square_loss(pred, label.reshape(-1, 1))
        # Alternate between clearing gradients via the optimizer and via the
        # module; both paths must leave the optimizer in a usable state.
        if idx % 2:
            opt.zero_grad()
        else:
            mlp.zero_grad()
        opt.backward(loss)

        grads = TensorDict()
        orig_params = TensorDict()
        for param in mlp.parameters():
            grad = F.grad(loss, param, use_virtual_grad=False)
            assertTensorClose(grad.numpy(), param.grad.numpy())
            grads[param] = np.copy(grad.numpy())
            orig_params[param] = np.copy(param.numpy())
        opt.step()

        for param in mlp.parameters():
            assertTensorClose(
                param.numpy(), orig_params[param] * 0.999 - grads[param] * 0.01
            )
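# Where the 0.999 factor in the assertion comes from: with L2 weight decay
# folded into the gradient, a single SGD step is
#     param <- param - lr * (grad + weight_decay * param)
#           =  param * (1 - 0.01 * 0.1) - 0.01 * grad
#           =  param * 0.999 - grad * 0.01
# which is exactly what the final assertTensorClose checks.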
def test_update_lr():
    data, data_shape, label, label_shape = get_input()
    mlp = MLP()
    opt = SGD(mlp.parameters(), lr=0.01)
    pred = mlp(data)
    loss = F.square_loss(pred, label.reshape(-1, 1))
    opt.zero_grad()
    opt.backward(loss)
    opt.step()

    # Raise the learning rate from 0.01 to 0.03 through the param groups.
    for group in opt.param_groups:
        group["lr"] += 0.02

    for _ in range(3):
        data.set_value(np.random.random(data_shape).astype(np.float32))
        label.set_value(np.random.randint(0, 10, label_shape))
        pred = mlp(data)
        loss = F.square_loss(pred, label.reshape(-1, 1))
        opt.zero_grad()
        opt.backward(loss)
        for param in mlp.parameters():
            grad = F.grad(loss, param, use_virtual_grad=False)
            assertTensorClose(grad.numpy(), param.grad.numpy())

        orig_params = []
        for param in mlp.parameters():
            orig_params.append(np.copy(param.numpy()))
        opt.step()

        # Every step after the change must use the new rate of 0.03.
        for param, orig_param in zip(mlp.parameters(), orig_params):
            assertTensorClose(param.numpy(), orig_param - param.grad.numpy() * 0.03)
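# The test relies on the torch.optim-style idiom of mutating hyperparameters
# in place through opt.param_groups (assuming MegEngine mirrors that API, as
# the code above suggests):
#     for group in opt.param_groups:
#         group["lr"] = new_lr
# No optimizer re-construction is needed; the next opt.step() picks it up.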
def test_blur():
    net = Blur()
    data = tensor(np.random.random((32, 16)).astype("float32"))
    opt = SGD(net.parameters(requires_grad=True), lr=0.1)
    opt.zero_grad()
    loss = net(data)
    opt.backward(loss.sum())
def test_compile_multi_times_static():
    return  # XXX: rewrite or remove this test
    with Graph() as cg:
        cg.set_option("eager_evaluation", False)
        data = Input("data", shape=(2, 28))
        label = Input("label", shape=(2,), dtype=np.int32)
        mlp = MLP()
        opt = SGD(mlp.parameters(requires_grad=True), lr=0.01)

        pred0 = mlp(data)
        pred = F.softmax(pred0)
        loss = F.square_loss(pred, label.reshape(2, 1))
        opt.zero_grad()
        grads = opt.backward(loss)
        opt.step()

        f0 = compile(pred, None)
        f1 = compile([pred, loss], grads, copy=True)

        data = np.random.random((2, 28)).astype(np.float32)
        label = np.random.randint(0, 10, (2,)).astype(np.float32)
        out0 = f0(data=data)
        out1 = f1(data=data, label=label)
        assertTensorClose(out0[0], out1[0])

        # Compiling the same graph again without copy invalidates functions
        # compiled earlier, so calling f0 must now fail.
        _ = compile([pred, loss], grads, copy=False)
        with pytest.raises(mgb.MegBrainError):
            f0(data=data)
def test_optimizer_serialization():
    data, data_shape, label, label_shape = get_input()
    mlp = MLP()
    opt = SGD(mlp.parameters(), lr=0.01, momentum=0.9)
    slots = TensorDict()
    for param in mlp.parameters():
        slots[param] = np.zeros(param.shape).astype(np.float32)

    pred = mlp(data)
    loss = F.square_loss(pred, label.reshape(-1, 1))
    opt.zero_grad()
    opt.backward(loss)
    opt.step()
    # Mirror the momentum slots after the first step.
    for param in mlp.parameters():
        slot = slots[param]
        slot *= 0.9
        slot -= param.grad.numpy() * 0.01

    with BytesIO() as fout:
        save(opt.state_dict(), fout)
        fout.seek(0)
        state_dict = load(fout)
        # opt1 is constructed with different hyperparameters; loading the
        # state dict must restore lr=0.01, momentum=0.9 and the slots.
        opt1 = SGD(mlp.parameters(), lr=0.02, momentum=0.8)
        opt1.load_state_dict(state_dict)

        data.set_value(np.random.random(data_shape).astype(np.float32))
        label.set_value(np.random.randint(0, 10, label_shape))
        pred = mlp(data)
        loss = F.square_loss(pred, label.reshape(-1, 1))
        opt1.zero_grad()
        opt1.backward(loss)
        orig_params = TensorDict()
        for param in mlp.parameters():
            orig_params[param] = np.copy(param.numpy())
        opt1.step()

        # The restored optimizer must continue the same momentum trajectory.
        for param in mlp.parameters():
            orig_param = orig_params[param]
            slot = slots[param]
            slot *= 0.9
            slot -= param.grad.numpy() * 0.01
            assertTensorClose(param.numpy(), orig_param + slot)
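# Sketch of the round-trip the test depends on (assuming save/load are the
# pickle-style helpers imported by this module):
#     buf = BytesIO(); save(opt.state_dict(), buf); buf.seek(0)
#     opt1.load_state_dict(load(buf))
# If the momentum slots were not serialized, the step after loading would
# match a fresh momentum buffer rather than the replayed one, and the final
# assertion would fail.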
def test_release_memory():
    mnist_datasets = load_mnist_datasets()
    data_train, label_train = mnist_datasets["train"]

    batch_size = 15000
    data_shape = (batch_size, 1, 28, 28)
    label_shape = (batch_size,)

    data = nn.Input("data", shape=data_shape, dtype=np.float32)
    label = nn.Input(
        "label", shape=label_shape, dtype=np.int32, value=np.zeros(label_shape)
    )

    net = MnistNet()
    opt = SGD(net.parameters(), lr=0.01)

    pred = F.softmax(net(data))
    loss = F.cross_entropy(pred, label)
    opt.zero_grad()
    opt.backward(loss)
    add_updates = opt.step()

    # Release cached device memory before compiling the large training function.
    mge.graph._default_graph.get_default().clear_device_memory()
    f = mge.graph.compile(loss, add_updates)

    for _ in range(3):
        train_loss = 0.0
        for i in range(0, data_train.shape[0], batch_size):
            opt.zero_grad()
            data = data_train[i : i + batch_size, :, :, :]
            label = label_train[i : i + batch_size]
            loss = f(data=data, label=label)[0]
            train_loss += loss[0]
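# The pattern above is the static-graph training idiom this test exercises:
# opt.step() returns the parameter-update operations ("add_updates"), which
# are compiled together with the loss, so each call f(data=..., label=...)
# both evaluates the loss and applies the updates for one batch.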
def test_compile_multi_times_eager():
    return  # XXX: rewrite or remove this test
    data = Input("data", shape=(2, 28))
    label = Input("label", shape=(2,), dtype=np.int32)
    mlp = MLP()
    opt = SGD(mlp.parameters(requires_grad=True), lr=0.01)

    pred0 = mlp(data)
    pred = F.softmax(pred0)
    loss = F.square_loss(pred, label.reshape(2, 1))
    opt.zero_grad()
    grads = opt.backward(loss)
    opt.step()

    f0 = compile(pred, None)
    f1 = compile([pred, loss], grads, copy=False)

    for _ in range(3):
        data = np.random.random((2, 28)).astype(np.float32)
        label = np.random.randint(0, 10, (2,)).astype(np.float32)
        out0 = f0(data=data)
        out1 = f1(data=data, label=label)
        assertTensorClose(out0[0], out1[0])