def test_project_repository_version(temp_workdir):
    """Check how ``replicate.init()`` handles ``.replicate/repository.json``.

    The first init must write the version file, a second init must leave it
    untouched, and a version file containing version 2 must make init raise
    ``NewerRepositoryVersion``.
    """
    version_file = ".replicate/repository.json"
    expected = """{ "version": 1 }"""

    with open("replicate.yaml", "w") as config:
        config.write("repository: file://.replicate")

    # First init writes the version file.
    experiment = replicate.init()
    with open(version_file) as fh:
        contents = fh.read()
    assert contents == expected

    # no error on second init
    experiment = replicate.init()
    with open(version_file) as fh:
        contents = fh.read()
    # repository.json shouldn't have changed
    assert contents == expected

    # Overwrite the version file with version 2 and expect init to refuse it.
    with open(version_file, "w") as fh:
        fh.write("""{ "version": 2 }""")
    with pytest.raises(NewerRepositoryVersion):
        replicate.init()
def test_deprecated_repository_backwards_compatible(temp_workdir):
    """Check repository URL resolution with and without ``replicate.yaml``.

    With only a ``.replicate/storage`` directory present, init must yield an
    ``Experiment`` whose repository URL is ``file://.replicate/storage``. Once
    ``replicate.yaml`` names a repository, that URL must be used instead.
    """
    # No config file yet, only the storage directory.
    os.makedirs(".replicate/storage")
    legacy = replicate.init()
    assert isinstance(legacy, Experiment)
    legacy_url = legacy._project._repository_url
    assert legacy_url == "file://.replicate/storage"
    legacy.stop()

    # Now add a config file naming a different repository.
    with open("replicate.yaml", "w") as config:
        config.write("repository: file://foobar")
    configured = replicate.init()
    assert isinstance(configured, Experiment)
    configured_url = configured._project._repository_url
    assert configured_url == "file://foobar"
    configured.stop()
def test_heartbeat(temp_workdir):
    """Check that a heartbeat file is written while running and removed on stop.

    After ``replicate.init()`` a JSON file must appear under
    ``.replicate/metadata/heartbeats/`` carrying the experiment id, and
    ``experiment.stop()`` must delete it again.
    """
    with open("replicate.yaml", "w") as f:
        f.write("repository: file://.replicate/")

    experiment = replicate.init()
    heartbeat_path = f".replicate/metadata/heartbeats/{experiment.id}.json"
    wait(lambda: os.path.exists(heartbeat_path), timeout_seconds=1, sleep_seconds=0.01)

    # Read through a context manager so the handle is closed before stop()
    # is expected to delete the file.
    with open(heartbeat_path) as heartbeat_file:
        heartbeat = json.load(heartbeat_file)
    assert heartbeat["experiment_id"] == experiment.id

    experiment.stop()
    assert not os.path.exists(heartbeat_path)

    # check starting and stopping immediately doesn't do anything weird
    experiment = replicate.init()
    experiment.stop()
def test_end_to_end(self):
    """Create an experiment and checkpoint, reload them, and verify contents.

    Checks that params, checkpoint metrics and the stored files survive a
    round trip through ``replicate.experiments.get``, and that ``plot``
    raises ``ImportError``.
    """
    # Files that exist when the experiment is created.
    initial_files = (
        ("replicate.yaml", 'repository: "file://.replicate"'),
        ("foo.txt", "foo"),
        ("bar.txt", "bar"),
    )
    for filename, text in initial_files:
        with open(filename, "w") as fh:
            fh.write(text)

    created = replicate.init(path=".", params={"myint": 10, "myfloat": 0.1})

    # Change bar.txt before checkpointing so the checkpoint holds the new text.
    with open("bar.txt", "w") as fh:
        fh.write("barrrr")
    created.checkpoint(path="bar.txt", metrics={"value": 123.45})

    # Reload the experiment from the repository and inspect what was stored.
    experiment = replicate.experiments.get(created.id)
    self.assertEqual(10, experiment.params["myint"])
    self.assertEqual(0.1, experiment.params["myfloat"])
    self.assertEqual(123.45, experiment.checkpoints[0].metrics["value"])

    for filename, expected_text in (("foo.txt", "foo"), ("bar.txt", "barrrr")):
        stored = experiment.checkpoints[0].open(filename)
        self.assertEqual(expected_text, stored.read().decode("utf-8"))

    # NOTE(review): presumably plotting needs an optional dependency that is
    # absent in the test environment - confirm.
    with self.assertRaises(ImportError):
        experiment.plot("value")
def train(learning_rate, num_epochs):
    """Train a small classifier on the Iris data set, tracking it in Replicate.

    Args:
        learning_rate: Learning rate handed to the Adam optimizer.
        num_epochs: Number of full-batch training steps to run; a checkpoint
            is recorded after each one.
    """
    # highlight-start
    # Create an "experiment". This represents a run of your training script.
    # It saves the training code at the given path and any hyperparameters.
    experiment = replicate.init(
        path=".",
        # highlight-start
        params={"learning_rate": learning_rate, "num_epochs": num_epochs},
    )
    # highlight-end

    print("Downloading data set...")
    dataset = load_iris()
    x_train, x_val, y_train, y_val = train_test_split(
        dataset.data,
        dataset.target,
        train_size=0.8,
        test_size=0.2,
        random_state=0,
        stratify=dataset.target,
    )
    x_train = torch.FloatTensor(x_train)
    x_val = torch.FloatTensor(x_val)
    y_train = torch.LongTensor(y_train)
    y_val = torch.LongTensor(y_val)

    # Seed before building the model so weight initialization is repeatable.
    torch.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(4, 15),
        nn.ReLU(),
        nn.Linear(15, 3),
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        # One full-batch optimization step.
        model.train()
        optimizer.zero_grad()
        logits = model(x_train)
        loss = criterion(logits, y_train)
        loss.backward()
        optimizer.step()

        # Validation accuracy, computed without tracking gradients.
        with torch.no_grad():
            model.eval()
            val_logits = model(x_val)
            num_correct = (val_logits.argmax(1) == y_val).float().sum()
            acc = num_correct / len(y_val)

        print(
            f"Epoch {epoch}, train loss: {loss.item():.3f}, "
            f"validation accuracy: {acc:.3f}"
        )
        torch.save(model, "model.pth")
        # highlight-start
        # Create a checkpoint within the experiment.
        # This saves the metrics at that point, and makes a copy of the file
        # or directory given, which could include weights and any other
        # artifacts.
        experiment.checkpoint(
            path="model.pth",
            step=epoch,
            metrics={"loss": loss.item(), "accuracy": acc},
            primary_metric=("loss", "minimize"),
        )
def test_is_running(temp_workdir):
    """Exercise ``Experiment.is_running()`` across the heartbeat lifecycle.

    Covers: before a heartbeat file exists, once it exists, after the
    heartbeat is killed but still recent, after the recorded heartbeat is
    back-dated to the tolerance limit, and after ``stop()``.
    """
    with open("replicate.yaml", "w") as f:
        f.write("repository: file://.replicate/")

    experiment = replicate.init()

    # Check whether the experiment is running before heartbeats are saved
    # NOTE(review): this relies on the heartbeat file not having been written
    # yet, which looks timing-dependent - confirm.
    assert not experiment.is_running()

    heartbeat_path = f".replicate/metadata/heartbeats/{experiment.id}.json"
    assert wait(lambda: os.path.exists(heartbeat_path), timeout_seconds=2, sleep_seconds=0.01)

    # Check whether experiment is running after heartbeats are started
    assert experiment.is_running()

    # Heartbeats stopped
    experiment._heartbeat.kill()
    assert experiment.is_running()

    # Modify heartbeat_metadata to record last heartbeat before last tolerable heartbeat
    # Both file accesses use context managers so the handles are closed even
    # if parsing or serialization raises.
    with open(heartbeat_path) as in_file:
        heartbeat_metadata = json.load(in_file)
    heartbeat_metadata["last_heartbeat"] = rfc3339_datetime(
        datetime.datetime.utcnow()
        - HEARTBEAT_MISS_TOLERANCE * DEFAULT_REFRESH_INTERVAL
    )
    with open(heartbeat_path, "w") as out_file:
        json.dump(heartbeat_metadata, out_file)
    assert not experiment.is_running()

    # New experiment to test is_running after stop()
    experiment = replicate.init()
    heartbeat_path = f".replicate/metadata/heartbeats/{experiment.id}.json"
    assert wait(lambda: os.path.exists(heartbeat_path), timeout_seconds=2, sleep_seconds=0.01)
    assert experiment.is_running()

    # Check is_running after stopping the experiment
    experiment.stop()
    assert not experiment.is_running()
def test_broken_experiment(mock_save, temp_workdir):
    """Check that init degrades to ``BrokenExperiment`` when saving fails.

    ``mock_save`` is made to raise; ``replicate.init()`` must still return,
    and the returned object must accept ``checkpoint()`` and ``stop()``.
    """
    # Make the mocked save fail on every call.
    mock_save.side_effect = Exception()

    with open("replicate.yaml", "w") as config:
        config.write("repository: file://.replicate/")

    # Shouldn't raise an exception
    broken = replicate.init()
    assert isinstance(broken, BrokenExperiment)
    broken.checkpoint()
    broken.stop()
def test_s3_experiment(temp_bucket, tmpdir):
    """Run an experiment against an S3 repository and verify stored metadata.

    Creates an experiment and one checkpoint from inside ``tmpdir``, reads the
    experiment's metadata JSON back with ``s3_read_json`` and compares it
    with the expected structure. The working directory is restored afterwards.
    """
    config_path = os.path.join(tmpdir, "replicate.yaml")
    with open(config_path, "w") as config:
        config.write(f"repository: s3://{temp_bucket}")

    original_cwd = os.getcwd()
    try:
        os.chdir(tmpdir)
        experiment = replicate.init(path=".", params={"foo": "bar"})
        checkpoint = experiment.checkpoint(
            path=".", step=10, metrics={"loss": 1.1, "baz": "qux"}
        )

        metadata_key = os.path.join(
            "metadata", "experiments", experiment.id + ".json"
        )
        actual_experiment_meta = s3_read_json(temp_bucket, metadata_key)

        # TODO(andreas): actually check values of host and user
        assert "host" in actual_experiment_meta
        assert "user" in actual_experiment_meta
        # Drop the fields whose values are not compared below.
        for ignored_key in ("host", "user", "command", "python_packages"):
            del actual_experiment_meta[ignored_key]

        expected_checkpoint_meta = {
            "id": checkpoint.id,
            "created": checkpoint.created.isoformat() + "Z",
            "step": 10,
            "metrics": {"loss": 1.1, "baz": "qux"},
            "path": ".",
            "primary_metric": None,
        }
        expected_experiment_meta = {
            "id": experiment.id,
            "created": experiment.created.isoformat() + "Z",
            "params": {"foo": "bar"},
            "config": {"repository": "s3://" + temp_bucket},
            "path": ".",
            "checkpoints": [expected_checkpoint_meta],
            "replicate_version": replicate.__version__,
        }
        assert actual_experiment_meta == expected_experiment_meta
    finally:
        os.chdir(original_cwd)
def test_is_running(temp_workdir):
    """Check ``is_running()`` is true once heartbeating and false after stop."""
    with open("replicate.yaml", "w") as config:
        config.write("repository: file://.replicate/")

    experiment = replicate.init()
    heartbeat_path = ".replicate/metadata/heartbeats/{}.json".format(experiment.id)

    def heartbeat_written():
        return os.path.exists(heartbeat_path)

    heartbeat_appeared = wait(
        heartbeat_written, timeout_seconds=10, sleep_seconds=0.01
    )
    assert heartbeat_appeared

    # Check whether experiment is running after heartbeats are started
    assert experiment.is_running()

    # Heartbeats stopped
    experiment.stop()
    assert not experiment.is_running()
def test_init_and_checkpoint(temp_workdir):
    """Verify what ``init`` and ``checkpoint`` store for each kind of ``path``.

    Covers an experiment created from a directory, a single file and no path,
    plus checkpoints taken from a file, a directory and no path. For each
    case the metadata JSON and the contents (or absence) of the ``.tar.gz``
    archive under ``.replicate/`` are checked.
    """

    def load_metadata(experiment_id):
        """Return the parsed metadata JSON stored for ``experiment_id``."""
        path = ".replicate/metadata/experiments/{}.json".format(experiment_id)
        with open(path) as fh:
            return json.load(fh)

    def extract_archive(archive_path, destination):
        """Unpack the tar archive at ``archive_path`` into ``destination``."""
        with tarfile.open(archive_path) as tar:
            tar.extractall(destination)

    def read_text(path):
        """Return the text of ``path``, closing the handle before returning.

        Closing matters here because the files live in temporary directories
        that are deleted as soon as the enclosing ``with`` block exits.
        """
        with open(path) as fh:
            return fh.read()

    with open("replicate.yaml", "w") as f:
        f.write("repository: file://.replicate/")
    with open("train.py", "w") as fh:
        fh.write("print(1 + 1)")
    with open("README.md", "w") as fh:
        fh.write("Hello")

    # basic experiment
    experiment = replicate.init(
        path=".", params={"learning_rate": 0.002}, disable_heartbeat=True
    )
    assert len(experiment.id) == 64
    metadata = load_metadata(experiment.id)
    assert metadata["id"] == experiment.id
    assert metadata["params"] == {"learning_rate": 0.002}

    with tempfile.TemporaryDirectory() as tmpdir:
        extract_archive(
            ".replicate/experiments/{}.tar.gz".format(experiment.id), tmpdir
        )
        assert (
            read_text(os.path.join(tmpdir, experiment.id, "train.py"))
            == "print(1 + 1)"
        )
        assert os.path.exists(os.path.join(tmpdir, experiment.id, "README.md"))

    # checkpoint with a file
    with open("weights", "w") as fh:
        fh.write("1.2kg")
    checkpoint = experiment.checkpoint(
        path="weights", step=1, metrics={"validation_loss": 0.123}
    )
    assert len(checkpoint.id) == 64
    metadata = load_metadata(experiment.id)
    assert len(metadata["checkpoints"]) == 1
    checkpoint_metadata = metadata["checkpoints"][0]
    assert checkpoint_metadata["id"] == checkpoint.id
    assert checkpoint_metadata["step"] == 1
    assert checkpoint_metadata["metrics"] == {"validation_loss": 0.123}

    with tempfile.TemporaryDirectory() as tmpdir:
        extract_archive(
            ".replicate/checkpoints/{}.tar.gz".format(checkpoint.id), tmpdir
        )
        assert read_text(os.path.join(tmpdir, checkpoint.id, "weights")) == "1.2kg"
        assert not os.path.exists(os.path.join(tmpdir, checkpoint.id, "train.py"))

    # checkpoint with a directory
    os.mkdir("data")
    with open("data/weights", "w") as fh:
        fh.write("1.3kg")
    checkpoint = experiment.checkpoint(
        path="data", step=1, metrics={"validation_loss": 0.123}
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        extract_archive(
            ".replicate/checkpoints/{}.tar.gz".format(checkpoint.id), tmpdir
        )
        assert (
            read_text(os.path.join(tmpdir, checkpoint.id, "data/weights"))
            == "1.3kg"
        )
        assert not os.path.exists(os.path.join(tmpdir, checkpoint.id, "train.py"))

    # checkpoint with no path
    checkpoint = experiment.checkpoint(
        path=None, step=1, metrics={"validation_loss": 0.123}
    )
    metadata = load_metadata(experiment.id)
    assert metadata["checkpoints"][-1]["id"] == checkpoint.id
    assert not os.path.exists(
        ".replicate/checkpoints/{}.tar.gz".format(checkpoint.id)
    )

    # experiment with file
    experiment = replicate.init(
        path="train.py", params={"learning_rate": 0.002}, disable_heartbeat=True
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        extract_archive(
            ".replicate/experiments/{}.tar.gz".format(experiment.id), tmpdir
        )
        assert (
            read_text(os.path.join(tmpdir, experiment.id, "train.py"))
            == "print(1 + 1)"
        )
        assert not os.path.exists(os.path.join(tmpdir, experiment.id, "README.md"))

    # experiment with no path!
    experiment = replicate.init(
        path=None, params={"learning_rate": 0.002}, disable_heartbeat=True
    )
    metadata = load_metadata(experiment.id)
    assert metadata["id"] == experiment.id
    assert metadata["params"] == {"learning_rate": 0.002}
    assert not os.path.exists(
        ".replicate/experiments/{}.tar.gz".format(experiment.id)
    )
def test_init_without_config_file(temp_workdir):
    """``replicate.init()`` must raise ``ConfigNotFoundError`` here.

    This test writes no ``replicate.yaml`` before calling init.
    """
    pytest.raises(ConfigNotFoundError, replicate.init)
def test_init_with_config_file(temp_workdir):
    """With a ``replicate.yaml`` present, init must return an ``Experiment``."""
    config_text = "repository: file://.replicate/"
    with open("replicate.yaml", "w") as config:
        config.write(config_text)

    started = replicate.init()
    assert isinstance(started, Experiment)
    started.stop()
def on_pretrain_routine_start(self, trainer, pl_module):
    """Create a Replicate experiment and keep it on ``self.experiment``.

    The experiment is initialized with path ``"."`` and ``self.params``.
    ``trainer`` and ``pl_module`` are accepted but not used here.
    NOTE(review): the signature looks like a PyTorch Lightning callback
    hook - confirm against the enclosing class.
    """
    hyperparameters = self.params
    self.experiment = replicate.init(path=".", params=hyperparameters)
def test_init_without_config_file(temp_workdir):
    """Without writing ``replicate.yaml``, init must return a ``BrokenExperiment``."""
    result = replicate.init()
    is_broken = isinstance(result, BrokenExperiment)
    assert is_broken