Exemple #1
0
def test_project_repository_version(temp_workdir):
    """Check repository.json versioning across repeated ``replicate.init()`` calls.

    The first init writes ``.replicate/repository.json`` with version 1, a
    second init leaves that file unchanged, and once the file is rewritten to
    claim version 2, init raises ``NewerRepositoryVersion``.
    """
    with open("replicate.yaml", "w") as f:
        f.write("repository: file://.replicate")
    # The experiment object itself is not inspected; stop it straight away
    # (as the other tests here do) rather than leaving it running.
    replicate.init().stop()

    expected = """{
  "version": 1
}"""
    with open(".replicate/repository.json") as f:
        assert f.read() == expected

    # no error on second init
    replicate.init().stop()
    with open(".replicate/repository.json") as f:
        # repository.json shouldn't have changed
        assert f.read() == expected

    # Bump the stored version to 2; init must now refuse the repository.
    with open(".replicate/repository.json", "w") as f:
        f.write(
            """{
  "version": 2
}"""
        )
    with pytest.raises(NewerRepositoryVersion):
        replicate.init()
Exemple #2
0
def test_deprecated_repository_backwards_compatible(temp_workdir):
    """A legacy ``.replicate/storage`` dir is used until replicate.yaml overrides it.

    With no replicate.yaml present, an existing ``.replicate/storage``
    directory becomes the repository URL; after replicate.yaml names a
    repository, that URL is used instead.
    """

    def init_and_check(expected_url):
        # Start an experiment, verify which repository it resolved, then stop it.
        exp = replicate.init()
        assert isinstance(exp, Experiment)
        assert exp._project._repository_url == expected_url
        exp.stop()

    os.makedirs(".replicate/storage")
    init_and_check("file://.replicate/storage")

    with open("replicate.yaml", "w") as config_file:
        config_file.write("repository: file://foobar")
    init_and_check("file://foobar")
Exemple #3
0
def test_heartbeat(temp_workdir):
    """A running experiment writes a heartbeat file that ``stop()`` removes."""
    with open("replicate.yaml", "w") as f:
        f.write("repository: file://.replicate/")

    experiment = replicate.init()
    heartbeat_path = f".replicate/metadata/heartbeats/{experiment.id}.json"
    # Assert on wait()'s result (as test_is_running does) so a heartbeat that
    # never appears fails here, not as a FileNotFoundError on the next read.
    assert wait(
        lambda: os.path.exists(heartbeat_path),
        timeout_seconds=1,
        sleep_seconds=0.01,
    )
    with open(heartbeat_path) as fh:
        assert json.load(fh)["experiment_id"] == experiment.id
    experiment.stop()
    assert not os.path.exists(heartbeat_path)

    # check starting and stopping immediately doesn't do anything weird
    experiment = replicate.init()
    experiment.stop()
Exemple #4
0
    def test_end_to_end(self):
        """Round-trip params, metrics and files through init/checkpoint/get."""
        for name, contents in (
            ("replicate.yaml", 'repository: "file://.replicate"'),
            ("foo.txt", "foo"),
            ("bar.txt", "bar"),
        ):
            with open(name, "w") as fh:
                fh.write(contents)

        created = replicate.init(path=".", params={"myint": 10, "myfloat": 0.1})

        # Rewrite bar.txt after init; the checkpoint is expected to hold the new text.
        with open("bar.txt", "w") as fh:
            fh.write("barrrr")

        created.checkpoint(path="bar.txt", metrics={"value": 123.45})

        loaded = replicate.experiments.get(created.id)
        self.assertEqual(10, loaded.params["myint"])
        self.assertEqual(0.1, loaded.params["myfloat"])
        self.assertEqual(123.45, loaded.checkpoints[0].metrics["value"])

        for filename, expected_text in (("foo.txt", "foo"), ("bar.txt", "barrrr")):
            stored = loaded.checkpoints[0].open(filename)
            self.assertEqual(expected_text, stored.read().decode("utf-8"))

        # plot() is expected to raise ImportError in this environment
        # (presumably an optional plotting dependency is missing).
        with self.assertRaises(ImportError):
            loaded.plot("value")
Exemple #5
0
def train(learning_rate, num_epochs):
    """Train a small classifier on the Iris data set, tracked with Replicate.

    Args:
        learning_rate: Learning rate for the Adam optimizer.
        num_epochs: Number of full-batch training epochs.

    After every epoch the model is saved to ``model.pth`` and recorded as a
    Replicate checkpoint together with that epoch's training loss and
    validation accuracy.
    """
    # highlight-start
    # Create an "experiment". This represents a run of your training script.
    # It saves the training code at the given path and any hyperparameters.
    experiment = replicate.init(
        path=".",
        # highlight-start
        params={"learning_rate": learning_rate, "num_epochs": num_epochs},
    )
    # highlight-end

    print("Downloading data set...")
    iris = load_iris()
    train_features, val_features, train_labels, val_labels = train_test_split(
        iris.data,
        iris.target,
        train_size=0.8,
        test_size=0.2,
        random_state=0,
        stratify=iris.target,
    )
    train_features = torch.FloatTensor(train_features)
    val_features = torch.FloatTensor(val_features)
    train_labels = torch.LongTensor(train_labels)
    val_labels = torch.LongTensor(val_labels)

    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(4, 15), nn.ReLU(), nn.Linear(15, 3))
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        outputs = model(train_features)
        loss = criterion(outputs, train_labels)
        loss.backward()
        optimizer.step()
        train_loss = loss.item()

        with torch.no_grad():
            model.eval()
            output = model(val_features)
            # .item() turns the 0-d tensor into a plain Python float so the
            # metric below is an ordinary number rather than a torch tensor.
            acc = (output.argmax(1) == val_labels).float().mean().item()

        print(
            "Epoch {}, train loss: {:.3f}, validation accuracy: {:.3f}".format(
                epoch, train_loss, acc
            )
        )
        torch.save(model, "model.pth")
        # highlight-start
        # Create a checkpoint within the experiment.
        # This saves the metrics at that point, and makes a copy of the file
        # or directory given, which could be weights and any other artifacts.
        experiment.checkpoint(
            path="model.pth",
            step=epoch,
            metrics={"loss": train_loss, "accuracy": acc},
            primary_metric=("loss", "minimize"),
        )
Exemple #6
0
def test_is_running(temp_workdir):
    """``is_running()`` tracks heartbeat freshness and ``stop()``.

    Covers:
    - before the first heartbeat is written;
    - while heartbeats are being written;
    - after the heartbeat thread is killed but the last beat is still recent;
    - after the last beat is backdated beyond the miss tolerance;
    - after an explicit ``stop()``.
    """
    with open("replicate.yaml", "w") as f:
        f.write("repository: file://.replicate/")

    experiment = replicate.init()
    # Check whether the experiment is running before heartbeats are saved
    assert not experiment.is_running()

    heartbeat_path = f".replicate/metadata/heartbeats/{experiment.id}.json"
    assert wait(lambda: os.path.exists(heartbeat_path),
                timeout_seconds=2,
                sleep_seconds=0.01)

    # Check whether experiment is running after heartbeats are started
    assert experiment.is_running()

    # Heartbeats stopped; the last recorded beat is still recent, so the
    # experiment still counts as running.
    experiment._heartbeat.kill()
    assert experiment.is_running()

    # Modify heartbeat_metadata to record last heartbeat before last tolerable heartbeat
    with open(heartbeat_path) as fh:
        heartbeat_metadata = json.load(fh)
    heartbeat_metadata["last_heartbeat"] = rfc3339_datetime(
        datetime.datetime.utcnow() -
        HEARTBEAT_MISS_TOLERANCE * DEFAULT_REFRESH_INTERVAL)

    # The with-block guarantees the file is flushed and closed before
    # is_running() reads it, even if json.dump raises.
    with open(heartbeat_path, "w") as out_file:
        json.dump(heartbeat_metadata, out_file)

    assert not experiment.is_running()

    # New experiment to test is_running after stop()
    experiment = replicate.init()
    heartbeat_path = f".replicate/metadata/heartbeats/{experiment.id}.json"
    assert wait(lambda: os.path.exists(heartbeat_path),
                timeout_seconds=2,
                sleep_seconds=0.01)
    assert experiment.is_running()

    # Check is_running after stopping the experiment
    experiment.stop()
    assert not experiment.is_running()
Exemple #7
0
def test_broken_experiment(mock_save, temp_workdir):
    """``init()`` returns a BrokenExperiment instead of raising when saving fails.

    ``mock_save`` (presumably a patched save hook) raises on every call; the
    returned object must still accept ``checkpoint()`` and ``stop()``.
    """
    with open("replicate.yaml", "w") as config:
        config.write("repository: file://.replicate/")

    mock_save.side_effect = Exception()

    broken = replicate.init()  # must not propagate the save failure
    assert isinstance(broken, BrokenExperiment)
    broken.checkpoint()
    broken.stop()
def test_s3_experiment(temp_bucket, tmpdir):
    """Run an experiment against an S3 repository and verify its stored metadata.

    Writes a replicate.yaml pointing at ``temp_bucket``, records one
    checkpoint, then reads the experiment's metadata JSON back from the bucket
    and compares it with the expected document. Fields that depend on the
    machine (host, user, command, python_packages) are removed before the
    comparison.
    """
    replicate_yaml_contents = "repository: s3://{bucket}".format(bucket=temp_bucket)

    with open(os.path.join(tmpdir, "replicate.yaml"), "w") as f:
        f.write(replicate_yaml_contents)

    current_workdir = os.getcwd()
    try:
        os.chdir(tmpdir)
        experiment = replicate.init(path=".", params={"foo": "bar"})
        checkpoint = experiment.checkpoint(
            path=".", step=10, metrics={"loss": 1.1, "baz": "qux"}
        )

        # S3 object keys always use "/" as the separator; os.path.join would
        # produce "\" on Windows and point at a key that doesn't exist.
        actual_experiment_meta = s3_read_json(
            temp_bucket,
            f"metadata/experiments/{experiment.id}.json",
        )
        # TODO(andreas): actually check values of host and user
        assert "host" in actual_experiment_meta
        assert "user" in actual_experiment_meta
        del actual_experiment_meta["host"]
        del actual_experiment_meta["user"]
        del actual_experiment_meta["command"]
        del actual_experiment_meta["python_packages"]

        expected_experiment_meta = {
            "id": experiment.id,
            "created": experiment.created.isoformat() + "Z",
            "params": {"foo": "bar"},
            "config": {"repository": "s3://" + temp_bucket},
            "path": ".",
            "checkpoints": [
                {
                    "id": checkpoint.id,
                    "created": checkpoint.created.isoformat() + "Z",
                    "step": 10,
                    "metrics": {"loss": 1.1, "baz": "qux"},
                    "path": ".",
                    "primary_metric": None,
                }
            ],
            "replicate_version": replicate.__version__,
        }
        assert actual_experiment_meta == expected_experiment_meta

    finally:
        os.chdir(current_workdir)
def test_is_running(temp_workdir):
    """``is_running()`` is true once heartbeats flow and false after ``stop()``."""
    with open("replicate.yaml", "w") as cfg:
        cfg.write("repository: file://.replicate/")

    experiment = replicate.init()
    heartbeat_file = f".replicate/metadata/heartbeats/{experiment.id}.json"

    heartbeat_written = wait(
        lambda: os.path.exists(heartbeat_file),
        timeout_seconds=10,
        sleep_seconds=0.01,
    )
    assert heartbeat_written

    # Heartbeats have started, so the experiment reports itself as running.
    assert experiment.is_running()

    # Stopping ends the heartbeats and the running state.
    experiment.stop()
    assert not experiment.is_running()
Exemple #10
0
def test_init_and_checkpoint(temp_workdir):
    """Exercise init() and checkpoint() with a directory, a single file and no path.

    For each case, check the metadata JSON under
    ``.replicate/metadata/experiments`` and the contents (or absence) of the
    ``.tar.gz`` archive under ``.replicate/experiments`` or
    ``.replicate/checkpoints``.
    """

    def write_text(path, contents):
        # Create or overwrite *path* with *contents*.
        with open(path, "w") as fh:
            fh.write(contents)

    def read_text(path):
        # Read *path* and close the handle immediately, so nothing stays open
        # inside a TemporaryDirectory (which blocks its cleanup on Windows).
        with open(path) as fh:
            return fh.read()

    def load_metadata(experiment_id):
        # Parse the metadata JSON written for *experiment_id*.
        with open(".replicate/metadata/experiments/{}.json".format(
                experiment_id)) as fh:
            return json.load(fh)

    def extract(archive_path, dest):
        # Unpack the tarball at *archive_path* into directory *dest*.
        with tarfile.open(archive_path) as tar:
            tar.extractall(dest)

    write_text("replicate.yaml", "repository: file://.replicate/")
    write_text("train.py", "print(1 + 1)")
    write_text("README.md", "Hello")

    # basic experiment
    experiment = replicate.init(path=".",
                                params={"learning_rate": 0.002},
                                disable_heartbeat=True)

    assert len(experiment.id) == 64
    metadata = load_metadata(experiment.id)
    assert metadata["id"] == experiment.id
    assert metadata["params"] == {"learning_rate": 0.002}

    with tempfile.TemporaryDirectory() as tmpdir:
        extract(".replicate/experiments/{}.tar.gz".format(experiment.id), tmpdir)
        assert read_text(os.path.join(tmpdir, experiment.id,
                                      "train.py")) == "print(1 + 1)"
        assert os.path.exists(os.path.join(tmpdir, experiment.id, "README.md"))

    # checkpoint with a file
    write_text("weights", "1.2kg")

    checkpoint = experiment.checkpoint(path="weights",
                                       step=1,
                                       metrics={"validation_loss": 0.123})

    assert len(checkpoint.id) == 64
    metadata = load_metadata(experiment.id)
    assert len(metadata["checkpoints"]) == 1
    checkpoint_metadata = metadata["checkpoints"][0]
    assert checkpoint_metadata["id"] == checkpoint.id
    assert checkpoint_metadata["step"] == 1
    assert checkpoint_metadata["metrics"] == {"validation_loss": 0.123}

    with tempfile.TemporaryDirectory() as tmpdir:
        extract(".replicate/checkpoints/{}.tar.gz".format(checkpoint.id), tmpdir)
        assert read_text(os.path.join(tmpdir, checkpoint.id,
                                      "weights")) == "1.2kg"
        assert not os.path.exists(
            os.path.join(tmpdir, checkpoint.id, "train.py"))

    # checkpoint with a directory
    os.mkdir("data")
    write_text("data/weights", "1.3kg")

    checkpoint = experiment.checkpoint(path="data",
                                       step=1,
                                       metrics={"validation_loss": 0.123})

    with tempfile.TemporaryDirectory() as tmpdir:
        extract(".replicate/checkpoints/{}.tar.gz".format(checkpoint.id), tmpdir)
        assert read_text(os.path.join(tmpdir, checkpoint.id,
                                      "data/weights")) == "1.3kg"
        assert not os.path.exists(
            os.path.join(tmpdir, checkpoint.id, "train.py"))

    # checkpoint with no path: metadata is recorded but no archive is written
    checkpoint = experiment.checkpoint(path=None,
                                       step=1,
                                       metrics={"validation_loss": 0.123})
    metadata = load_metadata(experiment.id)
    assert metadata["checkpoints"][-1]["id"] == checkpoint.id
    assert not os.path.exists(".replicate/checkpoints/{}.tar.gz".format(
        checkpoint.id))

    # experiment with file: only that file ends up in the archive
    experiment = replicate.init(path="train.py",
                                params={"learning_rate": 0.002},
                                disable_heartbeat=True)
    with tempfile.TemporaryDirectory() as tmpdir:
        extract(".replicate/experiments/{}.tar.gz".format(experiment.id), tmpdir)
        assert read_text(os.path.join(tmpdir, experiment.id,
                                      "train.py")) == "print(1 + 1)"
        assert not os.path.exists(
            os.path.join(tmpdir, experiment.id, "README.md"))

    # experiment with no path: metadata only, no archive
    experiment = replicate.init(path=None,
                                params={"learning_rate": 0.002},
                                disable_heartbeat=True)
    metadata = load_metadata(experiment.id)
    assert metadata["id"] == experiment.id
    assert metadata["params"] == {"learning_rate": 0.002}
    assert not os.path.exists(".replicate/experiments/{}.tar.gz".format(
        experiment.id))
Exemple #11
0
def test_init_without_config_file(temp_workdir):
    """``init()`` raises ConfigNotFoundError when no replicate.yaml exists."""
    pytest.raises(ConfigNotFoundError, replicate.init)
Exemple #12
0
def test_init_with_config_file(temp_workdir):
    """With a replicate.yaml in place, ``init()`` yields a real Experiment."""
    config_text = "repository: file://.replicate/"
    with open("replicate.yaml", "w") as config:
        config.write(config_text)

    started = replicate.init()
    assert isinstance(started, Experiment)
    started.stop()
Exemple #13
0
 def on_pretrain_routine_start(self, trainer, pl_module):
     """Callback hook: start a Replicate experiment for this training run.

     The experiment records the current directory (``path="."``) and the
     callback's ``self.params``; ``trainer`` and ``pl_module`` are unused.
     """
     new_experiment = replicate.init(path=".", params=self.params)
     self.experiment = new_experiment
Exemple #14
0
def test_init_without_config_file(temp_workdir):
    """Without replicate.yaml, ``init()`` returns a BrokenExperiment."""
    assert isinstance(replicate.init(), BrokenExperiment)