def create_network():
    """Build a tiny linear-regression problem and a ready-to-train SVRGModule.

    Draws 10 random integer samples with 2 features each (values in [1, 5)),
    labels them with the exact linear map ``y = 1*x0 + 2*x1`` and wraps the
    pair in a shuffled ``NDArrayIter`` with batch size 5. A single-unit
    FullyConnected layer feeding a LinearRegressionOutput is wrapped in an
    ``SVRGModule`` (``update_freq=2``), bound to the iterator's shapes,
    initialised with all-ones parameters and given plain SGD (learning rate
    0.01) on a local kvstore.

    Returns:
        tuple: ``(data_iter, module)``.
    """
    features = np.random.randint(1, 5, [10, 2])
    coefficients = np.array([1.0, 2.0])
    targets = features.dot(coefficients)

    data_iter = mx.io.NDArrayIter(features, targets, batch_size=5,
                                  shuffle=True, label_name='lin_reg_label')

    # Graph: data -> fc1 (one hidden unit) -> linear-regression loss.
    data_var = mx.sym.Variable('data')
    label_var = mx.symbol.Variable('lin_reg_label')
    fc = mx.sym.FullyConnected(data=data_var, name='fc1', num_hidden=1)
    loss = mx.sym.LinearRegressionOutput(data=fc, label=label_var, name="lro")

    module = SVRGModule(symbol=loss, data_names=['data'],
                        label_names=['lin_reg_label'], update_freq=2)
    module.bind(data_shapes=data_iter.provide_data,
                label_shapes=data_iter.provide_label)
    module.init_params(initializer=mx.init.One(), allow_missing=False,
                       force_init=False, allow_extra=False)
    module.init_optimizer(kvstore='local', optimizer='sgd',
                          optimizer_params=(('learning_rate', 0.01),),
                          force_init=False)
    return data_iter, module
def test_svrgmodule_reshape():
    """Check that an SVRGModule keeps working after ``reshape``.

    A 4-unit FullyConnected module is bound on two CPU contexts with a
    ``(3, 4)`` input and run through one forward/backward/update step. It is
    then reshaped to a ``(2, 4)`` input and stepped again. After each step
    the output shape must equal the input shape.
    """
    data = mx.sym.Variable("data")
    sym = mx.sym.FullyConnected(data=data, num_hidden=4, name='fc')

    dshape = (3, 4)
    mod = SVRGModule(sym,
                     data_names=["data"],
                     label_names=None,
                     context=[mx.cpu(0), mx.cpu(1)],
                     update_freq=2)
    mod.bind(data_shapes=[('data', dshape)])
    mod.init_params()
    # NOTE(review): ``_mod_aux`` is a private SVRGModule attribute that is
    # initialised explicitly here, presumably because ``mod.init_params()``
    # does not cover it -- confirm against SVRGModule.
    mod._mod_aux.init_params()
    mod.init_optimizer(optimizer_params={"learning_rate": 1.0})

    data_batch = mx.io.DataBatch(data=[mx.nd.ones(dshape)], label=None)
    mod.forward(data_batch)
    mod.backward([mx.nd.ones(dshape)])
    mod.update()
    assert mod.get_outputs()[0].shape == dshape

    # Shrink the batch dimension and repeat the step on the reshaped module.
    dshape = (2, 4)
    mod.reshape(data_shapes=[('data', dshape)])
    mod.forward(mx.io.DataBatch(data=[mx.nd.ones(dshape)], label=None))
    mod.backward([mx.nd.ones(dshape)])
    mod.update()
    assert mod.get_outputs()[0].shape == dshape


def create_module_with_sgd():
    """Build a plain-SGD Module and an SVRGModule for the same regression task.

    Draws 100 random integer samples with 2 features each (values in [1, 5)),
    labelled by the exact linear map ``y = 1*x0 + 2*x1``, and wraps them in a
    shuffled ``NDArrayIter`` with batch size 10. Both modules share the same
    FullyConnected -> LinearRegressionOutput symbol. Both are bound to the
    iterator's shapes, initialised with all-ones parameters and given SGD with
    learning rate 0.01 on a local kvstore, so they start from identical state.
    The SVRG module uses ``update_freq=2``.

    Returns:
        tuple: ``(data_iter, sgd_module, svrg_module)``.
    """
    train_data = np.random.randint(1, 5, [100, 2])
    weights = np.array([1.0, 2.0])
    train_label = train_data.dot(weights)

    di = mx.io.NDArrayIter(train_data, train_label, batch_size=10,
                           shuffle=True, label_name='lin_reg_label')
    X = mx.sym.Variable('data')
    Y = mx.symbol.Variable('lin_reg_label')
    fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1',
                                                  num_hidden=1)
    lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y,
                                        name="lro")

    # Baseline: a regular Module trained with plain SGD.
    reg_mod = mx.mod.Module(symbol=lro,
                            data_names=['data'],
                            label_names=['lin_reg_label'])
    reg_mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
    reg_mod.init_params(initializer=mx.init.One(), allow_missing=False,
                        force_init=False, allow_extra=False)
    reg_mod.init_optimizer(kvstore='local', optimizer='sgd',
                           optimizer_params=(('learning_rate', 0.01),))

    # SVRG variant configured identically apart from ``update_freq``.
    svrg_mod = SVRGModule(symbol=lro,
                          data_names=['data'],
                          label_names=['lin_reg_label'],
                          update_freq=2)
    svrg_mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
    svrg_mod.init_params(initializer=mx.init.One(), allow_missing=False,
                         force_init=False, allow_extra=False)
    svrg_mod.init_optimizer(kvstore='local', optimizer='sgd',
                            optimizer_params=(('learning_rate', 0.01),))

    return di, reg_mod, svrg_mod
# Example #4
# 0
def test_module_save_load():
    """Round-trip SVRGModule checkpoints through ``save_checkpoint``/``load``.

    Covers two configurations:
    - a single-device module with ``update_freq=2``;
    - a module on two CPU contexts with ``update_freq=3``.

    Each module is updated once, checkpointed with its optimizer states,
    reloaded with ``SVRGModule.load`` and re-bound. The reloaded symbol's JSON
    must match the original's. Checkpoint files are written to a scratch
    directory that is removed when the test ends, whether it passes or fails.
    """
    import tempfile
    import os
    import shutil

    x = mx.sym.Variable("data")
    y = mx.sym.Variable("softmax_label")
    # NOTE(review): ``y`` is passed positionally, so it fills FullyConnected's
    # second input (``weight`` in MXNet's operator signature) rather than
    # acting as a label. Kept as-is to preserve the test setup -- confirm intent.
    net = mx.sym.FullyConnected(x, y, num_hidden=1)

    mod = SVRGModule(symbol=net,
                     data_names=['data'],
                     label_names=['softmax_label'],
                     update_freq=2)
    mod.bind(data_shapes=[('data', (1, 1))])
    mod.init_params()
    mod.init_optimizer(optimizer='sgd',
                       optimizer_params={'learning_rate': 0.1})
    mod.update()

    # Scratch directory for checkpoint files. mkdtemp() never cleans up after
    # itself, so the directory is removed explicitly in ``finally``.
    tmp = tempfile.mkdtemp()
    try:
        tmp_file = os.path.join(tmp, 'svrg_test_output')
        mod.save_checkpoint(tmp_file, 0, save_optimizer_states=True)

        mod2 = SVRGModule.load(tmp_file,
                               0,
                               load_optimizer_states=True,
                               data_names=('data', ))
        mod2.bind(data_shapes=[('data', (1, 1))])
        mod2.init_optimizer(optimizer_params={'learning_rate': 0.1})
        assert mod._symbol.tojson() == mod2._symbol.tojson()

        # Multi-device: same round trip on two CPU contexts, reusing the same
        # checkpoint prefix and epoch.
        mod3 = SVRGModule(symbol=net,
                          data_names=['data'],
                          label_names=['softmax_label'],
                          update_freq=3,
                          context=[mx.cpu(0), mx.cpu(1)])
        mod3.bind(data_shapes=[('data', (10, 10))])
        mod3.init_params()
        mod3.init_optimizer(optimizer_params={'learning_rate': 1.0})
        mod3.update()
        mod3.save_checkpoint(tmp_file, 0, save_optimizer_states=True)

        mod4 = SVRGModule.load(tmp_file,
                               0,
                               load_optimizer_states=True,
                               data_names=('data', ))
        mod4.bind(data_shapes=[('data', (10, 10))])
        mod4.init_optimizer(optimizer_params={'learning_rate': 1.0})
        assert mod3._symbol.tojson() == mod4._symbol.tojson()
    finally:
        shutil.rmtree(tmp, ignore_errors=True)
def test_module_save_load():
    """Round-trip SVRGModule checkpoints through ``save_checkpoint``/``load``.

    Covers two configurations:
    - a single-device module with ``update_freq=2``;
    - a module on two CPU contexts with ``update_freq=3``.

    Each module is updated once, checkpointed with its optimizer states,
    reloaded with ``SVRGModule.load`` and re-bound. The reloaded symbol's JSON
    must match the original's. Checkpoint files are written to a scratch
    directory that is removed when the test ends, whether it passes or fails.
    """
    import tempfile
    import os
    import shutil

    x = mx.sym.Variable("data")
    y = mx.sym.Variable("softmax_label")
    # NOTE(review): ``y`` is passed positionally, so it fills FullyConnected's
    # second input (``weight`` in MXNet's operator signature) rather than
    # acting as a label. Kept as-is to preserve the test setup -- confirm intent.
    net = mx.sym.FullyConnected(x, y, num_hidden=1)

    mod = SVRGModule(symbol=net, data_names=['data'], label_names=['softmax_label'], update_freq=2)
    mod.bind(data_shapes=[('data', (1, 1))])
    mod.init_params()
    mod.init_optimizer(optimizer='sgd', optimizer_params={'learning_rate': 0.1})
    mod.update()

    # Scratch directory for checkpoint files. mkdtemp() never cleans up after
    # itself, so the directory is removed explicitly in ``finally``.
    tmp = tempfile.mkdtemp()
    try:
        tmp_file = os.path.join(tmp, 'svrg_test_output')
        mod.save_checkpoint(tmp_file, 0, save_optimizer_states=True)

        mod2 = SVRGModule.load(tmp_file, 0, load_optimizer_states=True, data_names=('data', ))
        mod2.bind(data_shapes=[('data', (1, 1))])
        mod2.init_optimizer(optimizer_params={'learning_rate': 0.1})
        assert mod._symbol.tojson() == mod2._symbol.tojson()

        # Multi-device: same round trip on two CPU contexts, reusing the same
        # checkpoint prefix and epoch.
        mod3 = SVRGModule(symbol=net, data_names=['data'], label_names=['softmax_label'], update_freq=3,
                          context=[mx.cpu(0), mx.cpu(1)])
        mod3.bind(data_shapes=[('data', (10, 10))])
        mod3.init_params()
        mod3.init_optimizer(optimizer_params={'learning_rate': 1.0})
        mod3.update()
        mod3.save_checkpoint(tmp_file, 0, save_optimizer_states=True)

        mod4 = SVRGModule.load(tmp_file, 0, load_optimizer_states=True, data_names=('data', ))
        mod4.bind(data_shapes=[('data', (10, 10))])
        mod4.init_optimizer(optimizer_params={'learning_rate': 1.0})
        assert mod3._symbol.tojson() == mod4._symbol.tojson()
    finally:
        shutil.rmtree(tmp, ignore_errors=True)
def test_svrgmodule_reshape():
    """An SVRGModule must produce correctly shaped outputs across a reshape.

    The module wraps a 4-unit FullyConnected layer on two CPU contexts. It
    takes one training step at input shape ``(3, 4)``, is reshaped to
    ``(2, 4)`` and takes another step. Each step checks that the first output
    has the same shape as the input.
    """
    fc_sym = mx.sym.FullyConnected(data=mx.sym.Variable("data"), num_hidden=4, name='fc')
    module = SVRGModule(fc_sym,
                        data_names=["data"],
                        label_names=None,
                        context=[mx.cpu(0), mx.cpu(1)],
                        update_freq=2)

    def _train_step(shape):
        # One forward/backward/update pass on an all-ones batch, followed by
        # a check that the output shape tracks the input shape.
        batch = mx.io.DataBatch(data=[mx.nd.ones(shape)], label=None)
        module.forward(batch)
        module.backward([mx.nd.ones(shape)])
        module.update()
        assert module.get_outputs()[0].shape == shape

    initial_shape = (3, 4)
    module.bind(data_shapes=[('data', initial_shape)])
    module.init_params()
    module._mod_aux.init_params()
    module.init_optimizer(optimizer_params={"learning_rate": 1.0})
    _train_step(initial_shape)

    smaller_shape = (2, 4)
    module.reshape(data_shapes=[('data', smaller_shape)])
    _train_step(smaller_shape)