Ejemplo n.º 1
0
    def restore_from_object(self, obj):
        """Restore state from checkpoint bytes by materializing them on disk.

        A temporary checkpoint directory is obtained from
        ``FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)`` and
        remembered on ``self.temp_checkpoint_dir``. The checkpoint decoded from
        ``obj`` is written into it, and ``self.restore`` is then called with
        that directory.

        Args:
            obj: Serialized checkpoint bytes accepted by
                ``Checkpoint.from_bytes``.
        """
        target_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
        self.temp_checkpoint_dir = target_dir
        # NOTE(review): the temp dir is not removed here; presumably it is
        # cleaned up elsewhere via ``self.temp_checkpoint_dir`` -- confirm.
        Checkpoint.from_bytes(obj).to_directory(target_dir)
        self.restore(self.temp_checkpoint_dir)
Ejemplo n.º 2
0
    def testBytesCheckpointSerde(self):
        """Check serde of a checkpoint rebuilt from bytes.

        A bytes checkpoint is just a dict checkpoint constructed from pickled
        data, so the expected internal representation is the one reported by
        the dict checkpoint the bytes were produced from.
        """
        original = Checkpoint.from_dict({"checkpoint_data": 5})
        restored = Checkpoint.from_bytes(original.to_bytes())

        expected = original.get_internal_representation()
        self._testCheckpointSerde(restored, *expected)
Ejemplo n.º 3
0
    def restore_from_object(self, obj):
        """Restore training state from a serialized checkpoint object.

        ``obj`` is the value returned by ``save_to_object()``. It is decoded
        with ``Checkpoint.from_bytes``. ``self.restore`` is then called with
        the path yielded by the checkpoint's ``as_directory()`` context
        manager, inside that context.

        Args:
            obj: Serialized checkpoint bytes.
        """
        with Checkpoint.from_bytes(obj).as_directory() as checkpoint_dir:
            self.restore(checkpoint_dir)
Ejemplo n.º 4
0
    def test_dict_checkpoint_bytes(self):
        """Round-trip a dict checkpoint through ``bytes`` and back."""
        dict_checkpoint = self._prepare_dict_checkpoint()

        # Serializing must yield a raw ``bytes`` payload.
        payload = dict_checkpoint.to_bytes()
        self.assertIsInstance(payload, bytes)

        # The checkpoint rebuilt from the payload must carry a truthy data dict
        # and still satisfy the dict-checkpoint assertions.
        restored = Checkpoint.from_bytes(payload)
        self.assertTrue(restored._data_dict)
        self._assert_dict_checkpoint(restored)
Ejemplo n.º 5
0
def get_checkpoint_from_remote_node(
        checkpoint_path: str,
        node_ip: str,
        timeout: float = 300.0) -> Optional[Checkpoint]:
    """Fetch a checkpoint stored on another node of the Ray cluster.

    A zero-CPU ``_serialize_checkpoint`` task is scheduled on the node with
    IP ``node_ip``. The bytes it returns are converted back into a
    ``Checkpoint`` on the calling side.

    Args:
        checkpoint_path: Path of the checkpoint on the remote node.
        node_ip: IP address of the node holding the checkpoint.
        timeout: Seconds to wait for the remote task's result.

    Returns:
        The fetched checkpoint, or ``None`` if any of the following happens:

        - the node is not listed as alive;
        - the remote task does not finish within ``timeout``;
        - the remote task fails.

        A warning is logged in each of these cases.
    """
    # Imported locally: only needed to tell a timeout apart from a failure.
    from ray.exceptions import GetTimeoutError

    if not any(node["NodeManagerAddress"] == node_ip and node["Alive"]
               for node in ray.nodes()):
        logger.warning(
            f"Could not fetch checkpoint with path {checkpoint_path} from "
            f"node with IP {node_ip} because the node is not available "
            f"anymore.")
        return None

    # Requesting a small fraction of the ``node:<ip>`` resource constrains the
    # task to run on that specific node.
    fut = _serialize_checkpoint.options(
        resources={f"node:{node_ip}": 0.01},
        num_cpus=0,
    ).remote(checkpoint_path)

    try:
        checkpoint_data = ray.get(fut, timeout=timeout)
    except GetTimeoutError:
        # Report a timeout accurately instead of as a serialization failure.
        logger.warning(
            f"Could not fetch checkpoint with path {checkpoint_path} from "
            f"node with IP {node_ip} because serialization did not finish "
            f"within {timeout} seconds.")
        return None
    except Exception as e:
        # Best effort: any other failure of the remote task is logged and
        # reported to the caller as a missing checkpoint.
        logger.warning(
            f"Could not fetch checkpoint with path {checkpoint_path} from "
            f"node with IP {node_ip} because serialization failed: {e}")
        return None
    return Checkpoint.from_bytes(checkpoint_data)