示例#1
0
def test_resume_checkpoint():
    """Check that CheckpointHandler can resume training from its last checkpoint.

    A first ``fit`` trains 2 epochs while keeping a single checkpoint. A second
    ``fit`` with ``resume_from_checkpoint=True`` and ``epochs=5`` is expected
    to train only the remaining epochs (``est.max_epoch == 3``) and to leave an
    ``epoch4batch20`` checkpoint behind.
    """
    with TemporaryDirectory() as ckpt_dir:
        prefix = 'test_net'
        base_path = os.path.join(ckpt_dir, prefix)
        train_data = _get_test_data()

        network = _get_test_network()
        criterion = loss.SoftmaxCrossEntropyLoss()
        accuracy = mx.gluon.metric.Accuracy()
        est = estimator.Estimator(network, loss=criterion, train_metrics=accuracy)

        def make_handler(**extra_kwargs):
            # Both runs share the directory, prefix, monitored metric and a
            # single-checkpoint retention limit.
            return event_handler.CheckpointHandler(model_dir=ckpt_dir,
                                                   model_prefix=prefix,
                                                   monitor=accuracy,
                                                   max_checkpoints=1,
                                                   **extra_kwargs)

        est.fit(train_data, event_handlers=[make_handler()], epochs=2)
        for suffix in ('.params', '.states'):
            assert os.path.isfile(base_path + '-epoch1batch8' + suffix)

        resume_handler = make_handler(resume_from_checkpoint=True)
        est.fit(train_data, event_handlers=[resume_handler], epochs=5)
        # Only the remaining 3 epochs are trained; the last checkpoint is epoch 4.
        assert est.max_epoch == 3
        assert os.path.isfile(base_path + '-epoch4batch20.states')
示例#2
0
def test_np_save_load_ndarrays():
    """Round-trip ``mx.np`` ndarrays through ``npx.save`` / ``npx.load``.

    Three container forms are covered: a single ndarray (loaded back as a
    one-element list), a list of ndarrays, and a dict of str -> ndarray. The
    shapes include zero-size and scalar (``()``) arrays.
    """
    shapes = [(2, 0, 1), (0, ), (), (), (0, 4), (), (3, 0, 0, 0), (2, 1),
              (0, 5, 0), (4, 5, 6), (0, 0, 0)]
    array_list = [_np.random.randint(0, 10, size=shape) for shape in shapes]
    array_list = [np.array(arr, dtype=arr.dtype) for arr in array_list]

    # test save/load single ndarray
    for arr in array_list:
        with TemporaryDirectory() as work_dir:
            fname = os.path.join(work_dir, 'dataset.npy')
            npx.save(fname, arr)
            arr_loaded = npx.load(fname)
            # A single saved ndarray comes back wrapped in a list.
            assert isinstance(arr_loaded, list)
            assert len(arr_loaded) == 1
            assert _np.array_equal(arr_loaded[0].asnumpy(), arr.asnumpy())

    # test save/load a list of ndarrays
    with TemporaryDirectory() as work_dir:
        fname = os.path.join(work_dir, 'dataset.npy')
        npx.save(fname, array_list)
        # Load through the same npx API as the other sections, and check the
        # result just loaded -- not the leftover ``arr_loaded`` from the loop
        # above, which made these checks vacuous.
        array_list_loaded = npx.load(fname)
        assert isinstance(array_list_loaded, list)
        assert len(array_list) == len(array_list_loaded)
        assert all(isinstance(arr, np.ndarray) for arr in array_list_loaded)
        for a1, a2 in zip(array_list, array_list_loaded):
            assert _np.array_equal(a1.asnumpy(), a2.asnumpy())

    # test save/load a dict of str->ndarray
    arr_dict = {str(i): arr for i, arr in enumerate(array_list)}
    with TemporaryDirectory() as work_dir:
        fname = os.path.join(work_dir, 'dataset.npy')
        npx.save(fname, arr_dict)
        arr_dict_loaded = npx.load(fname)
        assert isinstance(arr_dict_loaded, dict)
        assert len(arr_dict_loaded) == len(arr_dict)
        for k, v in arr_dict_loaded.items():
            assert k in arr_dict
            assert _np.array_equal(v.asnumpy(), arr_dict[k].asnumpy())
示例#3
0
def test_np_ndarray_pickle():
    """Check that an ``mx.np`` ndarray survives a pickle round trip through a file."""
    a = np.random.uniform(size=(4, 5))
    a_copy = a.copy()
    import pickle

    with TemporaryDirectory() as work_dir:
        fname = os.path.join(work_dir, 'np_ndarray_pickle_test_file')
        with open(fname, 'wb') as f:
            pickle.dump(a_copy, f)
        with open(fname, 'rb') as f:
            a_load = pickle.load(f)
        # ``same`` presumably returns the comparison result (mxnet.test_utils)
        # rather than raising; its result was previously discarded, so a
        # mismatch could never fail the test.
        assert same(a.asnumpy(), a_load.asnumpy())
示例#4
0
def test_buffer_load():
    """Check ``mx.nd.load_frombuffer`` against files written by ``mx.nd.save``.

    For several random trials, a list, a dict and a single NDArray are saved,
    reloaded from the raw file bytes, and compared element-wise with the
    originals. A buffer truncated by 10 bytes must raise ``MXNetError``.
    """
    nrepeat = 10
    with TemporaryDirectory(prefix='test_buffer_load_') as tmpdir:
        for repeat in range(nrepeat):
            # test load_buffer as list
            data = [random_ndarray(np.random.randint(1, 5)) for _ in range(10)]
            list_path = os.path.join(tmpdir, 'list_{0}.param'.format(repeat))
            mx.nd.save(list_path, data)
            with open(list_path, 'rb') as handle:
                list_bytes = handle.read()
                list_loaded = mx.nd.load_frombuffer(list_bytes)
                assert len(data) == len(list_loaded)
                for expected, actual in zip(data, list_loaded):
                    assert np.sum(expected.asnumpy() != actual.asnumpy()) == 0
                # a truncated buffer must be rejected
                assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer,
                             list_bytes[:-10])

            # test load_buffer as dict
            named = {'ndarray xx %s' % i: x for i, x in enumerate(data)}
            dict_path = os.path.join(tmpdir, 'dict_{0}.param'.format(repeat))
            mx.nd.save(dict_path, named)
            with open(dict_path, 'rb') as handle:
                dict_bytes = handle.read()
                dict_loaded = mx.nd.load_frombuffer(dict_bytes)
                assert len(dict_loaded) == len(named)
                for key, expected in named.items():
                    actual = dict_loaded[key]
                    assert np.sum(expected.asnumpy() != actual.asnumpy()) == 0
                # a truncated buffer must be rejected
                assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer,
                             dict_bytes[:-10])

            # a single saved ndarray is expected back as a one-element list
            lone = data[0]
            single_path = os.path.join(tmpdir, 'single_{0}.param'.format(repeat))
            mx.nd.save(single_path, lone)
            with open(single_path, 'rb') as handle:
                single_bytes = handle.read()
                single_loaded = mx.nd.load_frombuffer(single_bytes)
                assert len(single_loaded) == 1
                single_loaded = single_loaded[0]
                assert np.sum(lone.asnumpy() != single_loaded.asnumpy()) == 0
                # a truncated buffer must be rejected
                assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer,
                             single_bytes[:-10])
示例#5
0
def test_checkpoint_handler():
    """Check CheckpointHandler's epoch-based and batch-based saving.

    First run: ``save_best=True`` with ``epoch_period=1`` over one epoch must
    produce ``-best`` files and an ``epoch0batch4`` checkpoint.

    Second run: a hybridized network is checkpointed every 2 batches with
    ``max_checkpoints=2`` over 10 batches. It must leave ``-symbol.json`` plus
    the ``epoch1batch7`` and ``epoch2batch9`` checkpoints. Earlier ones such as
    ``epoch0batch0`` and any ``-best`` files must be absent.
    """
    with TemporaryDirectory() as tmpdir:
        model_prefix = 'test_epoch'
        file_path = os.path.join(tmpdir, model_prefix)
        test_data = _get_test_data()

        net = _get_test_network()
        ce_loss = loss.SoftmaxCrossEntropyLoss()
        acc = mx.gluon.metric.Accuracy()
        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
        checkpoint_handler = event_handler.CheckpointHandler(
            model_dir=tmpdir,
            model_prefix=model_prefix,
            monitor=acc,
            save_best=True,
            epoch_period=1)
        est.fit(test_data, event_handlers=[checkpoint_handler], epochs=1)
        assert checkpoint_handler.current_epoch == 1
        assert checkpoint_handler.current_batch == 4
        assert os.path.isfile(file_path + '-best.params')
        assert os.path.isfile(file_path + '-best.states')
        assert os.path.isfile(file_path + '-epoch0batch4.params')
        assert os.path.isfile(file_path + '-epoch0batch4.states')

        model_prefix = 'test_batch'
        file_path = os.path.join(tmpdir, model_prefix)
        net = _get_test_network(nn.HybridSequential())
        net.hybridize()
        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)
        checkpoint_handler = event_handler.CheckpointHandler(
            model_dir=tmpdir,
            model_prefix=model_prefix,
            epoch_period=None,
            batch_period=2,
            max_checkpoints=2)
        est.fit(test_data, event_handlers=[checkpoint_handler], batches=10)
        assert checkpoint_handler.current_batch == 10
        assert checkpoint_handler.current_epoch == 3
        # No monitor/save_best in this run, so no "-best" files may exist.
        # Use the same "-best" naming as above; without the hyphen these
        # checks looked for a file that is never written and could not fail.
        assert not os.path.isfile(file_path + '-best.params')
        assert not os.path.isfile(file_path + '-best.states')
        assert not os.path.isfile(file_path + '-epoch0batch0.params')
        assert not os.path.isfile(file_path + '-epoch0batch0.states')
        assert os.path.isfile(file_path + '-symbol.json')
        assert os.path.isfile(file_path + '-epoch1batch7.params')
        assert os.path.isfile(file_path + '-epoch1batch7.states')
        assert os.path.isfile(file_path + '-epoch2batch9.params')
        assert os.path.isfile(file_path + '-epoch2batch9.states')
def test_logging():
    """Check LoggingHandler's file output and counters (legacy ``metrics=`` API).

    The handler is given a file name and location inside a temporary
    directory. After a 3-epoch ``fit``, its batch index should be reset to 0,
    its epoch counter should read 3, and the log file should exist.
    """
    with TemporaryDirectory() as tmpdir:
        test_data = _get_test_data()
        log_name = 'test_log'
        log_path = os.path.join(tmpdir, log_name)

        network = _get_test_network()
        est = estimator.Estimator(network,
                                  loss=loss.SoftmaxCrossEntropyLoss(),
                                  metrics=mx.metric.Accuracy())
        train_metrics, val_metrics = est.prepare_loss_and_metrics()
        handler = event_handler.LoggingHandler(file_name=log_name,
                                               file_location=tmpdir,
                                               train_metrics=train_metrics,
                                               val_metrics=val_metrics)
        est.fit(test_data, event_handlers=[handler], epochs=3)
        assert handler.batch_index == 0
        assert handler.current_epoch == 3
        assert os.path.isfile(log_path)
示例#7
0
def test_logging():
    """Check that LoggingHandler drives an Estimator's logger during ``fit``.

    A ``logging.FileHandler`` is attached to ``est.logger`` so the log output
    goes to a file inside the temporary directory. After three epochs, the
    handler's counters and the log file are checked.
    """
    with TemporaryDirectory() as tmpdir:
        test_data = _get_test_data()
        file_name = 'test_log'
        output_dir = os.path.join(tmpdir, file_name)  # path of the log *file*

        net = _get_test_network()
        ce_loss = loss.SoftmaxCrossEntropyLoss()
        acc = mx.gluon.metric.Accuracy()
        est = estimator.Estimator(net, loss=ce_loss, train_metrics=acc)

        file_handler = logging.FileHandler(output_dir)
        est.logger.addHandler(file_handler)
        try:
            logging_handler = event_handler.LoggingHandler(metrics=est.train_metrics)
            est.fit(test_data, event_handlers=[logging_handler], epochs=3)
            assert logging_handler.batch_index == 0
            assert logging_handler.current_epoch == 3
            assert os.path.isfile(output_dir)
        finally:
            # Detach and close the file handler explicitly, so the log file is
            # not left open when tmpdir is removed. An open file can block
            # deletion on some platforms. ``del est`` alone relies on garbage
            # collection to release it.
            est.logger.removeHandler(file_handler)
            file_handler.close()
        del est  # Clean up estimator and logger before deleting tmpdir
def test_load_save_symbol():
    """Check the ``is_np_shape`` attr in symbol JSON and loading of legacy files.

    A FullyConnected symbol whose weight has a 0 input dim is serialized with
    np_shape on and off. Only the np_shape-on JSON may carry ``is_np_shape``.
    JSON written with np_shape off (1.5.0-style) must load and infer the same
    shapes whether np_shape is off or on at load time.
    """
    batch_size = 10
    num_hidden = 128
    num_features = 784

    def get_net():
        data = mx.sym.var('data')
        weight = mx.sym.var('weight', shape=(num_hidden, 0))
        return mx.sym.FullyConnected(data, weight, num_hidden=num_hidden)

    for np_shape_on in (False, True):
        with np_shape(np_shape_on):
            net_data = json.loads(get_net().tojson())
            assert "attrs" in net_data
            has_np_shape_attr = "is_np_shape" in net_data["attrs"]
            assert has_np_shape_attr == np_shape_on

        with TemporaryDirectory() as work_dir:
            fname = os.path.join(work_dir, 'test_sym.json')
            with open(fname, 'w') as fp:
                json.dump(net_data, fp)

            # With np_shape on, 0 indicates a zero-size dim, so the legacy
            # loading check below only applies to JSON written with it off.
            if np_shape_on:
                continue

            # test loading 1.5.0 symbol file since 1.6.0
            # w/ or w/o np_shape semantics
            for load_with_np_shape in (False, True):
                with np_shape(load_with_np_shape):
                    sym = mx.sym.load(fname)
                    arg_shapes, out_shapes, aux_shapes = sym.infer_shape(
                        data=(batch_size, num_features))
                    assert arg_shapes[0] == (batch_size, num_features)  # data
                    assert arg_shapes[1] == (num_hidden, num_features)  # weight
                    assert arg_shapes[2] == (num_hidden, )  # bias
                    assert out_shapes[0] == (batch_size, num_hidden)  # output
                    assert len(aux_shapes) == 0