def restore(self, restored_tensors, unused_restored_shapes):
    """Rebuild the associated tree's state from checkpointed tensors.

    Args:
      restored_tensors: the tensors loaded from a checkpoint. Only the first
        element is consumed; it is passed to `tree_deserialize` as the
        serialized tree.
      unused_restored_shapes: the shapes this object should conform to after
        restore. Ignored, because shapes are not meaningful for trees.

    Returns:
      The op built by `gen_model_ops.tree_deserialize`, i.e. the operation that
      restores the state of the tree variable.
    """
    # Every op built below is gated on `self._create_op` (presumably the op
    # that creates the tree resource -- confirm against the class), so the
    # deserialization cannot run ahead of it.
    gating_ops = [self._create_op]
    with ops.control_dependencies(gating_ops):
        # Operands are gathered inside the dependency scope, in the same
        # order the call would evaluate them, so any ops they create are
        # also gated.
        tree_handle = self._tree_handle
        serialized_tree = restored_tensors[0]
        params_proto = self.params.serialized_params_proto
        deserialize_op = gen_model_ops.tree_deserialize(
            tree_handle, serialized_tree, params=params_proto)
    return deserialize_op