예제 #1
0
def test_migration_checkpoint_removal(start_connected_emptyhead_cluster,
                                      tmpdir, durable):
    """Verify a trial still terminates after its node and checkpoint are lost.

    The trial runs until it has a checkpoint, the first worker node is
    swapped for a new one, and the checkpoint's parent directory is deleted.
    The runner must nevertheless bring the trial to ``Trial.TERMINATED``.

    Args:
        start_connected_emptyhead_cluster: Cluster object the test adds
            nodes to and removes nodes from.
        tmpdir: Temporary directory; used as the ``file://`` upload target
            when ``durable`` is set, otherwise handed to
            ``custom_driver_logdir_callback``.
        durable: If truthy, the trial gets a ``remote_checkpoint_dir`` and a
            plain ``SyncerCallback``; otherwise no remote dir is used.
    """
    cluster = start_connected_emptyhead_cluster
    original_node = cluster.add_node(num_cpus=1)
    cluster.wait_for_nodes()

    # Pick the sync setup that matches the requested mode.
    if not durable:
        upload_dir = None
        sync_cb = custom_driver_logdir_callback(str(tmpdir))
    else:
        upload_dir = "file://" + str(tmpdir)
        sync_cb = SyncerCallback()

    runner = TrialRunner(BasicVariantGenerator(), callbacks=[sync_cb])
    trial_config = dict(
        stopping_criterion={"training_iteration": 4},
        checkpoint_freq=2,
        max_failures=2,
        remote_checkpoint_dir=upload_dir,
    )

    # The scenario under test: a trial that already owns a checkpoint.
    trial = Trial("__fake", **trial_config)
    runner.add_trial(trial)

    # Keep stepping (start, two results, one save) until a checkpoint exists.
    while not trial.has_checkpoint():
        runner.step()

    # Bring up a replacement node, then take the original one away.
    cluster.add_node(num_cpus=1)
    cluster.remove_node(original_node)
    cluster.wait_for_nodes()

    # Wipe the checkpoint's parent directory, as if it lived on the lost node.
    checkpoint_parent = os.path.dirname(trial.checkpoint.dir_or_data)
    shutil.rmtree(checkpoint_parent)

    if not durable:
        # Without remote storage, repoint the checkpoint at the same
        # relative location underneath tmpdir (the driver-side copy).
        rel_logdir = trial.relative_logdir
        rel_checkpoint = os.path.relpath(trial.checkpoint.dir_or_data,
                                         trial.logdir)
        trial.checkpoint.dir_or_data = os.path.join(tmpdir, rel_logdir,
                                                    rel_checkpoint)

    while not runner.is_finished():
        runner.step()
    assert trial.status == Trial.TERMINATED, runner.debug_string()
예제 #2
0
def test_trial_migration(start_connected_emptyhead_cluster, tmpdir, durable):
    """Check that trials migrate when their node is replaced.

    Three trials are run one after another on the same runner, each having
    its node swapped out mid-run:

    1. a checkpointing trial interrupted before its first checkpoint,
    2. a checkpointing trial interrupted after its first checkpoint,
    3. a trial configured without ``checkpoint_freq`` / ``max_failures``.

    The first two must reach ``Trial.TERMINATED``; the third must end in
    ``Trial.ERROR``, after which a further ``runner.step()`` must raise
    ``TuneError``. The trial state should also stay consistent with the
    checkpoint.

    Args:
        start_connected_emptyhead_cluster: Cluster object the test adds
            nodes to and removes nodes from.
        tmpdir: Temporary directory; used as the ``file://`` upload target
            when ``durable`` is set, otherwise handed to
            ``custom_driver_logdir_callback``.
        durable: If truthy, trials get a ``remote_checkpoint_dir`` and a
            plain ``SyncerCallback``; otherwise no remote dir is used.
    """
    cluster = start_connected_emptyhead_cluster

    def swap_node(stale_node):
        # Add a fresh single-CPU node first, then drop the stale one.
        fresh_node = cluster.add_node(num_cpus=1)
        cluster.remove_node(stale_node)
        cluster.wait_for_nodes()
        return fresh_node

    def run_until_finished():
        # Step the runner until it reports that nothing is left to do.
        while not runner.is_finished():
            runner.step()

    first_node = cluster.add_node(num_cpus=1)
    cluster.wait_for_nodes()

    # Pick the sync setup that matches the requested mode.
    if not durable:
        upload_dir = None
        sync_cb = custom_driver_logdir_callback(str(tmpdir))
    else:
        upload_dir = "file://" + str(tmpdir)
        sync_cb = SyncerCallback()

    runner = TrialRunner(BasicVariantGenerator(), callbacks=[sync_cb])
    checkpointing_config = dict(
        stopping_criterion={"training_iteration": 4},
        checkpoint_freq=2,
        max_failures=2,
        remote_checkpoint_dir=upload_dir,
    )

    # --- Case 1: node lost before the trial has been checkpointed ---
    early_trial = Trial("__fake", **checkpointing_config)
    runner.add_trial(early_trial)
    runner.step()  # Start trial
    runner.step()  # Process result
    assert early_trial.last_result
    second_node = swap_node(first_node)
    # TODO(ujvl): Node failure does not propagate until a step after it
    #  actually should. This is possibly a problem with `Cluster`.
    runner.step()
    runner.step()  # Recovery step

    # TODO(rliaw): This assertion is not critical but will not pass
    #   because checkpoint handling is messy and should be refactored
    #   rather than hotfixed.
    # assert t.last_result is None, "Trial result not restored correctly."

    # Remaining work: two results + save, twice over.
    run_until_finished()

    assert early_trial.status == Trial.TERMINATED, runner.debug_string()

    # --- Case 2: node lost after the trial has been checkpointed ---
    saved_trial = Trial("__fake", **checkpointing_config)
    runner.add_trial(saved_trial)
    # Keep stepping (start, two results, one save) until a checkpoint exists.
    while not saved_trial.has_checkpoint():
        runner.step()
    third_node = swap_node(second_node)
    run_until_finished()
    assert saved_trial.status == Trial.TERMINATED, runner.debug_string()

    # --- Case 3: node lost for a trial that never checkpoints ---
    plain_config = dict(
        stopping_criterion={"training_iteration": 3},
        remote_checkpoint_dir=upload_dir,
    )

    plain_trial = Trial("__fake", **plain_config)
    runner.add_trial(plain_trial)
    runner.step()  # Start trial
    runner.step()  # Process result 1
    swap_node(third_node)
    run_until_finished()
    assert plain_trial.status == Trial.ERROR, runner.debug_string()

    # With every trial finished, stepping again is expected to fail.
    with pytest.raises(TuneError):
        runner.step()