Beispiel #1
0
def test_monitored_session(script_mode):
    """Train inside a ``MonitoredSession`` and check that tensors were saved.

    Any existing hook is deleted and the default graph is reset first, then
    100 training steps are run inside a ``SagemakerSimulator`` context that is
    given a JSON config.

    Args:
        script_mode: When truthy, an ``smd.SessionHook`` writing to
            ``sim.out_dir`` is passed to the session explicitly. Otherwise the
            session is constructed with no arguments and the test still
            expects ``smd.get_hook()`` to return a hook afterwards
            (presumably one created from the simulated JSON config --
            confirm against ``SagemakerSimulator``).
    """
    smd.del_hook()
    tf.reset_default_graph()
    sagemaker_config = """
            {
                "S3OutputPath": "s3://sagemaker-test",
                "LocalPath": "/opt/ml/output/tensors",
                "HookParameters" : {
                    "save_interval": "100"
                }
            }
            """
    num_steps = 100
    batch_size = 32

    with SagemakerSimulator(json_file_contents=sagemaker_config) as sim:
        train_op, X, Y = get_train_op_and_placeholders()
        init = tf.global_variables_initializer()
        mnist = get_data()

        # Only script mode passes a hook; the other path must call
        # MonitoredSession() with no arguments at all.
        session_kwargs = (
            {"hooks": [smd.SessionHook(out_dir=sim.out_dir)]} if script_mode else {}
        )

        with tf.train.MonitoredSession(**session_kwargs) as sess:
            sess.run(init)
            for _ in range(num_steps):
                images, labels = mnist.train.next_batch(batch_size)
                sess.run(train_op, feed_dict={X: images, Y: labels})

        # Verify that a hook exists and that the trial contains saved data.
        trial = smd.create_trial(path=sim.out_dir)
        assert smd.get_hook() is not None, "Hook was not created."
        assert len(trial.steps()) > 0, "Nothing saved at any step."
        assert len(trial.tensor_names()) > 0, "Tensors were not saved."
Beispiel #2
0
def test_uninit_sess_run(out_dir):
    """Run the initializer and 100 training steps under a hooked session.

    An ``smd.SessionHook`` created with ``include_collections=["weights"]``
    is attached to a ``MonitoredSession``. Afterwards the trial read back
    from ``out_dir`` must report at least one step, at least one tensor name,
    and at least one tensor name in the "weights" collection.

    Args:
        out_dir: Output location handed to the hook and later to
            ``smd.create_trial``.
    """
    num_steps = 100
    batch_size = 32

    train_op, X, Y = get_train_op_and_placeholders()
    init = tf.global_variables_initializer()
    mnist = get_data()
    weights_hook = smd.SessionHook(out_dir, include_collections=["weights"])

    with tf.train.MonitoredSession(hooks=[weights_hook]) as sess:
        sess.run(init)
        for _ in range(num_steps):
            images, labels = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={X: images, Y: labels})

    # The ``with`` block has exited; inspect what was written to out_dir.
    trial = smd.create_trial(path=out_dir)
    assert len(trial.steps()) > 0, "Nothing saved at any step."
    assert len(trial.tensor_names()) > 0, "Tensors were not saved."
    assert len(trial.tensor_names(collection="weights")) > 0