def test_get_ckpt_db_name(self):
    """Verify that the checkpoint manager builds the expected per-node
    checkpoint db path (<tmpdir>/<node_name>.<epoch>) for each node.
    """
    # Create the temp dir *before* entering the try block: if mkdtemp()
    # itself failed inside the try, the finally clause would raise
    # NameError on the unbound `tmpdir` instead of the real error.
    tmpdir = tempfile.mkdtemp()
    try:
        num_nodes = 3
        checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
        with Job() as job:
            for node_id in range(num_nodes):
                build_pipeline(node_id)
        # Compile only after the Job context has closed, so the job is
        # fully built before nodes_to_checkpoint() is queried.
        compiled_job = job.compile(LocalSession)
        checkpoint.init(compiled_job.nodes_to_checkpoint())

        for node_id in range(num_nodes):
            epoch = 5
            node_name = 'trainer:%d' % node_id
            expected_db_name = tmpdir + '/' + node_name + '.5'
            # assertEquals is a deprecated alias (removed in Python 3.12);
            # use assertEqual.
            self.assertEqual(
                checkpoint.get_ckpt_db_name(node_name, epoch),
                expected_db_name)
    finally:
        shutil.rmtree(tmpdir)
def test_ckpt_name_and_load_model_from_ckpts(self):
    """End-to-end checkpoint test: (1) checkpoint db names are generated
    correctly, (2) running the job writes checkpoints, and (3) model
    blobs can be loaded back from checkpoints of completed epochs only.
    """
    try:
        num_nodes = 3
        # First, check if the checkpoint name generation mechanism is
        # correct.
        tmpdir = tempfile.mkdtemp()
        checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
        with Cluster():
            with Job() as job:
                for node_id in range(num_nodes):
                    build_pipeline(node_id)
            compiled_job = job.compile(LocalSession)
            checkpoint.init(compiled_job.nodes_to_checkpoint())

            for node_id in range(num_nodes):
                epoch = 5
                node_name = 'trainer_%d' % node_id
                expected_db_name = tmpdir + '/' + node_name + '.5'
                # assertEquals is a deprecated alias (removed in
                # Python 3.12); use assertEqual.
                self.assertEqual(
                    checkpoint.get_ckpt_db_name(node_name, epoch),
                    expected_db_name)
        shutil.rmtree(tmpdir)

        # Next, check mechanism to load model from checkpoints.
        tmpdir = tempfile.mkdtemp()
        workspace.ResetWorkspace()
        # Run each node's pipeline to completion so that checkpoints for
        # every epoch get written to tmpdir.
        for node_id in range(num_nodes):
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            with Cluster():
                with Job() as job:
                    build_pipeline(node_id)
                compiled_job = job.compile(LocalSession)
                job_runner = JobRunner(compiled_job, checkpoint)
                num_epochs = job_runner(session)
            self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
            # There are 12 global blobs after finishing up the job runner.
            # (only blobs on init_group are checkpointed)
            self.assertEqual(len(ws.blobs), 12)

        # Start from an empty workspace so any blobs seen below must have
        # come from the checkpoint load.
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        self.assertEqual(len(ws.blobs), 0)
        model_blob_names = ['trainer_1/task_2/GivenTensorInt64Fill:0',
                            'trainer_2/task_2/GivenTensorInt64Fill:0']
        checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
        with Cluster():
            with Job() as job:
                for node_id in range(num_nodes):
                    build_pipeline(node_id)
            compiled_job = job.compile(LocalSession)
            job_runner = JobRunner(compiled_job, checkpoint)
            job_runner.load_blobs_from_checkpoints(
                blob_names=model_blob_names, epoch=1, session=session)

            # Check that we can successfully load from checkpoints of
            # epochs 1 to 4, but not epoch 5.
            for epoch in range(1, 5):
                self.assertTrue(
                    job_runner.load_blobs_from_checkpoints(
                        blob_names=model_blob_names,
                        epoch=epoch,
                        session=session))
                # Check that all the model blobs are loaded.
                for blob_name in model_blob_names:
                    self.assertTrue(ws.has_blob(blob_name))
                    self.assertEqual(
                        ws.fetch_blob(blob_name),
                        np.array([EXPECTED_TOTALS[epoch - 1]]))
            # Epoch 5's checkpoint was never written, so loading it
            # must fail.
            self.assertFalse(
                job_runner.load_blobs_from_checkpoints(
                    blob_names=model_blob_names, epoch=5, session=session))
    finally:
        shutil.rmtree(tmpdir)