Example #1
    def __init__(
        self,
        config: ConfigSchema,
        model: Optional[MultiRelationEmbedder] = None,
        trainer: Optional[AbstractBatchProcessor] = None,
        evaluator: Optional[AbstractBatchProcessor] = None,
        rank: Rank = SINGLE_TRAINER,
        subprocess_init: Optional[Callable[[], None]] = None,
        stats_handler: StatsHandler = NOOP_STATS_HANDLER,
    ):

        super().__init__(
            config, model, trainer, evaluator, rank, subprocess_init, stats_handler
        )

        assert config.num_gpus > 0
        if not CPP_INSTALLED:
            raise RuntimeError(
                "GPU support requires C++ installation: "
                "install with C++ support by running "
                "`PBG_INSTALL_CPP=1 pip install .`"
            )

        if config.half_precision:
            for entity in config.entities:
                # need this for tensor cores to work
                assert config.entity_dimension(entity) % 8 == 0
            assert config.batch_size % 8 == 0
            assert config.num_batch_negs % 8 == 0
            assert config.num_uniform_negs % 8 == 0

        assert len(self.holder.lhs_unpartitioned_types) == 0
        assert len(self.holder.rhs_unpartitioned_types) == 0

        num_edge_chunks = self.iteration_manager.num_edge_chunks
        max_edges = 0
        for edge_path in config.edge_paths:
            edge_storage = EDGE_STORAGES.make_instance(edge_path)
            for lhs_part in range(self.holder.nparts_lhs):
                for rhs_part in range(self.holder.nparts_rhs):
                    num_edges = edge_storage.get_number_of_edges(lhs_part, rhs_part)
                    num_edges_per_chunk = div_roundup(num_edges, num_edge_chunks)
                    max_edges = max(max_edges, num_edges_per_chunk)
        self.shared_lhs = allocate_shared_tensor((max_edges,), dtype=torch.long)
        self.shared_rhs = allocate_shared_tensor((max_edges,), dtype=torch.long)
        self.shared_rel = allocate_shared_tensor((max_edges,), dtype=torch.long)

        # fork early for HOGWILD threads
        logger.info("Creating GPU workers...")
        torch.set_num_threads(1)
        self.gpu_pool = GPUProcessPool(
            config.num_gpus,
            subprocess_init,
            {s for ss in self.embedding_storage_freelist.values() for s in ss}
            | {
                self.shared_lhs.storage(),
                self.shared_rhs.storage(),
                self.shared_rel.storage(),
            },
        )
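
Every example on this page relies on allocate_shared_tensor handing out tensors whose storage lives in shared memory, so that worker processes forked afterwards can read and write them without copying. A minimal sketch of such a helper, together with the ceiling-division behaviour assumed for div_roundup above (neither is PBG's actual implementation), might look like this:

import torch

def allocate_shared_tensor_sketch(shape, dtype=torch.float) -> torch.Tensor:
    # Allocate an uninitialized tensor and move its storage into shared
    # memory so that forked worker processes can access it in place.
    tensor = torch.empty(shape, dtype=dtype)
    tensor.share_memory_()
    return tensor

def div_roundup_sketch(num: int, denom: int) -> int:
    # Assumed behaviour of div_roundup: ceiling division, so that
    # num_edge_chunks chunks of this size always cover all edges.
    return (num + denom - 1) // denom

The shared lhs/rhs/rel buffers above are sized for the largest chunk across all buckets, so every bucket can later be loaded into the same preallocated memory.
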
Example #2
def load_embeddings(hf: h5py.File) -> FloatTensorType:
    dataset: h5py.Dataset = hf[EMBEDDING_DATASET]
    embeddings = allocate_shared_tensor(dataset.shape, dtype=torch.float)
    # Needed because https://github.com/h5py/h5py/issues/870.
    if dataset.size > 0:
        dataset.read_direct(embeddings.numpy())
    return embeddings
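
A hedged usage sketch for this loader: write a small embeddings dataset to HDF5 and read it back into a shared tensor. The dataset name "embeddings" stands in for the EMBEDDING_DATASET constant and is only an assumption.

import h5py
import numpy as np
import torch

# Hypothetical round trip; "embeddings" stands in for EMBEDDING_DATASET.
with h5py.File("embeddings_example.h5", "w") as hf:
    hf.create_dataset("embeddings", data=np.random.rand(10, 4).astype(np.float32))

with h5py.File("embeddings_example.h5", "r") as hf:
    dataset = hf["embeddings"]
    embeddings = torch.empty(dataset.shape, dtype=torch.float).share_memory_()
    if dataset.size > 0:  # guard for the h5py issue referenced above
        dataset.read_direct(embeddings.numpy())

print(embeddings.shape)  # torch.Size([10, 4])
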
Example #3
 def get(self,
         key: str,
         dst: Optional[torch.Tensor] = None,
         shared: bool = False) -> Optional[torch.Tensor]:
     """Get a tensor from the server."""
     self._validate_get(key, dst=dst, shared=shared)
     cmd_rpc = torch.tensor(
         [GET_CMD, len(key), dst is None, 0, 0, 0], dtype=torch.long)
     metadata_pg = self._metadata_pg()
     td.send(cmd_rpc, self.server_rank, group=metadata_pg)
     td.send(_fromstring(key), self.server_rank, group=metadata_pg)
     if dst is None:
         meta = torch.full((2, ), -1, dtype=torch.long)
         td.recv(meta, src=self.server_rank, group=metadata_pg)
         ndim, ttype = meta
         if ndim.item() == -1:
             return None
         size = torch.full((ndim.item(), ), -1, dtype=torch.long)
         td.recv(size, src=self.server_rank, group=metadata_pg)
         dtype = _dtypes[ttype.item()]
         if shared:
             dst = allocate_shared_tensor(size.tolist(), dtype=dtype)
         else:
             dst = torch.empty(size.tolist(), dtype=dtype)
     start_t = time.monotonic()
     data_pgs = self._data_pgs()
     if data_pgs is None:
         td.recv(dst, src=self.server_rank)
     else:
         outstanding_work = []
         flattened_dst = dst.flatten()
         flattened_size = flattened_dst.shape[0]
         for idx, (pg, slice_) in enumerate(
                 zip(
                     data_pgs,
                     split_almost_equally(flattened_size,
                                          num_parts=len(data_pgs)),
                 )):
             outstanding_work.append(
                 td.irecv(
                     tensor=flattened_dst[slice_],
                     src=self.server_rank,
                     group=pg,
                     tag=idx,
                 ))
         for w in outstanding_work:
             w.wait()
     end_t = time.monotonic()
     if self.log_stats:
         stats_size = dst.numel() * dst.element_size()
         stats_time = end_t - start_t
         logger.debug(
             f"Received tensor {key} from server {self.server_rank}: "
             f"{stats_size:,} bytes "
             f"in {stats_time:,g} seconds "
             f"=> {stats_size / stats_time:,.0f} B/s")
     return dst
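
The striped receive above splits the flattened destination across several data process groups so that the pieces arrive in parallel. split_almost_equally is assumed to yield contiguous slices that cover the whole range with sizes differing by at most one element; a sketch under that assumption:

from typing import Iterator

def split_almost_equally_sketch(size: int, *, num_parts: int) -> Iterator[slice]:
    # Yield num_parts contiguous slices covering range(size), with sizes
    # differing by at most one element (assumed behaviour).
    base, extra = divmod(size, num_parts)
    start = 0
    for part in range(num_parts):
        length = base + (1 if part < extra else 0)
        yield slice(start, start + length)
        start += length

# list(split_almost_equally_sketch(10, num_parts=3))
# -> [slice(0, 4, None), slice(4, 7, None), slice(7, 10, None)]
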
Example #4
def load_embeddings(hf: h5py.File,
                    out: Optional[FloatTensorType] = None) -> FloatTensorType:
    dataset: h5py.Dataset = hf[EMBEDDING_DATASET]
    if out is None:
        out = allocate_shared_tensor(dataset.shape, dtype=torch.float)
    # Needed because https://github.com/h5py/h5py/issues/870.
    if dataset.size > 0:
        dataset.read_direct(out.numpy())
    return out
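
Compared with Example #2, the out parameter lets a caller hand in a preallocated (and possibly shared) buffer so that repeated loads reuse the same memory. A hypothetical usage, assuming an HDF5 file whose embedding dataset has shape (10, 4), such as the one written in the sketch after Example #2:

import h5py
import torch

buffer = torch.empty((10, 4), dtype=torch.float).share_memory_()
with h5py.File("embeddings_example.h5", "r") as hf:
    embs = load_embeddings(hf, out=buffer)  # load_embeddings from Example #4
assert embs is buffer  # the data lands in the preallocated shared buffer
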
Example #5
    def load_chunk_of_edges(
        self,
        lhs_p: int,
        rhs_p: int,
        chunk_idx: int = 0,
        num_chunks: int = 1,
    ) -> EdgeList:
        file_path = self.get_edges_file(lhs_p, rhs_p)
        try:
            with h5py.File(file_path, "r") as hf:
                if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                    raise RuntimeError(f"Version mismatch in edge file {file_path}")
                lhs_ds = hf["lhs"]
                rhs_ds = hf["rhs"]
                rel_ds = hf["rel"]

                num_edges = rel_ds.len()
                begin = int(chunk_idx * num_edges / num_chunks)
                end = int((chunk_idx + 1) * num_edges / num_chunks)
                chunk_size = end - begin

                lhs = allocate_shared_tensor((chunk_size,), dtype=torch.long)
                rhs = allocate_shared_tensor((chunk_size,), dtype=torch.long)
                rel = allocate_shared_tensor((chunk_size,), dtype=torch.long)

                # Needed because https://github.com/h5py/h5py/issues/870.
                if chunk_size > 0:
                    lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
                    rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
                    rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])

                lhsd = self.read_dynamic(hf, "lhsd", begin, end)
                rhsd = self.read_dynamic(hf, "rhsd", begin, end)

                return EdgeList(EntityList(lhs, lhsd),
                                EntityList(rhs, rhsd),
                                rel)
        except OSError as err:
            # h5py refuses to make it easy to figure out what went wrong. The errno
            # attribute is set to None. See https://github.com/h5py/h5py/issues/493.
            if f"errno = {errno.ENOENT}" in str(err):
                raise CouldNotLoadData() from err
            raise err
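
The chunk boundaries above are computed so that consecutive chunks tile the edge list exactly, with sizes differing by at most one edge. A small worked sketch of the same arithmetic:

from typing import Tuple

def chunk_bounds(num_edges: int, chunk_idx: int, num_chunks: int) -> Tuple[int, int]:
    # Same arithmetic as load_chunk_of_edges: chunk chunk_idx covers
    # [begin, end), and all chunks together cover [0, num_edges) exactly.
    begin = int(chunk_idx * num_edges / num_chunks)
    end = int((chunk_idx + 1) * num_edges / num_chunks)
    return begin, end

# With 10 edges split into 3 chunks:
# chunk_bounds(10, 0, 3) -> (0, 3)
# chunk_bounds(10, 1, 3) -> (3, 6)
# chunk_bounds(10, 2, 3) -> (6, 10)
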
Example #6
    def read_dynamic(
        hf: h5py.File,
        key: str,
        begin: int,
        end: int,
    ) -> TensorList:
        try:
            offsets_ds = hf[f"{key}_offsets"]
            data_ds = hf[f"{key}_data"]
        except LookupError:
            return TensorList.empty(num_tensors=end - begin)

        offsets = allocate_shared_tensor((end - begin + 1,), dtype=torch.long)
        offsets_ds.read_direct(offsets.numpy(), source_sel=np.s_[begin:end + 1])
        data_begin = offsets[0].item()
        data_end = offsets[-1].item()
        data = allocate_shared_tensor((data_end - data_begin,), dtype=torch.long)
        # Needed because https://github.com/h5py/h5py/issues/870.
        if data_end - data_begin > 0:
            data_ds.read_direct(data.numpy(), source_sel=np.s_[data_begin:data_end])

        offsets -= int(offsets[0])

        return TensorList(offsets, data)
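
Here {key}_offsets stores per-row boundaries into the flat {key}_data array, and after slicing a window of rows the offsets are rebased to start at zero. A small numeric sketch of the same slicing (illustrative data only):

import torch

# Hypothetical ragged data: three rows of lengths 2, 0 and 3.
data = torch.tensor([10, 11, 20, 21, 22], dtype=torch.long)
offsets = torch.tensor([0, 2, 2, 5], dtype=torch.long)

# Slice rows [1, 3) the way read_dynamic does: take offsets[1:4], load
# data[offsets[1]:offsets[3]] and rebase the offsets to start at zero.
begin, end = 1, 3
window_offsets = offsets[begin:end + 1].clone()
window_data = data[window_offsets[0].item():window_offsets[-1].item()].clone()
window_offsets -= int(window_offsets[0])
# window_offsets -> tensor([0, 0, 3]); window_data -> tensor([20, 21, 22])
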
Example #7
 def get(
     self,
     key: str,
     dst: Optional[torch.Tensor] = None,
     shared: bool = False,
 ) -> Optional[torch.Tensor]:
     """Get a tensor from the server.
     """
     cmd_rpc = torch.tensor([GET_CMD, len(key), dst is None, 0, 0, 0], dtype=torch.long)
     td.send(cmd_rpc, self.server_rank)
     td.send(_fromstring(key), self.server_rank)
     if dst is None:
         meta = torch.full((2,), -1, dtype=torch.long)
         td.recv(meta, src=self.server_rank)
         ndim, ttype = meta
         if ndim.item() == -1:
             return None
         size = torch.full((ndim.item(),), -1, dtype=torch.long)
         td.recv(size, src=self.server_rank)
         dtype = _dtypes[ttype.item()]
         if shared:
             dst = allocate_shared_tensor(size.tolist(), dtype=dtype)
         else:
             dst = torch.empty(size.tolist(), dtype=dtype)
     start_t = time.monotonic()
     td.recv(dst, src=self.server_rank)
     end_t = time.monotonic()
     if self.log_stats:
         stats_size = dst.numel() * dst.element_size()
         stats_time = end_t - start_t
         logger.debug(
             f"Received tensor {key} from server {self.server_rank}: "
             f"{stats_size:,} bytes "
             f"in {stats_time:,g} seconds "
             f"=> {stats_size / stats_time:,.0f} B/s")
     return dst
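
This is the simpler, single-process-group variant of Example #3: the metadata handshake (rank, dtype index, then the size vector) is the same, but the payload arrives in one td.recv instead of being striped over several data groups. A sketch of just the client-side allocation step, with a hypothetical dtype table standing in for _dtypes:

from typing import List
import torch

# Hypothetical ordering; the real _dtypes table is not shown on this page.
_DTYPES_SKETCH = [torch.float32, torch.float64, torch.int32, torch.int64]

def allocate_for_reply_sketch(ttype: int, size: List[int], shared: bool) -> torch.Tensor:
    # Once the client knows the dtype index and shape of the incoming
    # tensor, it allocates the destination, optionally in shared memory.
    dtype = _DTYPES_SKETCH[ttype]
    dst = torch.empty(size, dtype=dtype)
    return dst.share_memory_() if shared else dst
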
Example #8
    def __init__(  # noqa
        self,
        config: ConfigSchema,
        model: Optional[MultiRelationEmbedder] = None,
        trainer: Optional[AbstractBatchProcessor] = None,
        evaluator: Optional[AbstractBatchProcessor] = None,
        rank: Rank = SINGLE_TRAINER,
        subprocess_init: Optional[Callable[[], None]] = None,
        stats_handler: StatsHandler = NOOP_STATS_HANDLER,
    ):
        """Each epoch/pass, for each partition pair, loads in embeddings and edgelist
        from disk, runs HOGWILD training on them, and writes partitions back to disk.
        """
        tag_logs_with_process_name(f"Trainer-{rank}")
        self.config = config
        if config.verbose > 0:
            import pprint

            pprint.PrettyPrinter().pprint(config.to_dict())

        logger.info("Loading entity counts...")
        entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
        entity_counts: Dict[str, List[int]] = {}
        for entity, econf in config.entities.items():
            entity_counts[entity] = []
            for part in range(econf.num_partitions):
                entity_counts[entity].append(entity_storage.load_count(entity, part))

        # Figure out how many lhs and rhs partitions we need
        holder = self.holder = EmbeddingHolder(config)

        logger.debug(
            f"nparts {holder.nparts_lhs} {holder.nparts_rhs} "
            f"types {holder.lhs_partitioned_types} {holder.rhs_partitioned_types}"
        )

        # We know ahead of time that we will need 1-2 storages for each embedding type,
        # as well as the max size of this storage (num_entities x D).
        # We allocate these storages in advance in `embedding_storage_freelist`.
        # When we need storage for an entity type, we pop it from this free list,
        # and then add it back when we 'delete' the embedding table.
        embedding_storage_freelist: Dict[
            EntityName, Set[torch.FloatStorage]
        ] = defaultdict(set)
        for entity_type, counts in entity_counts.items():
            max_count = max(counts)
            num_sides = (
                (1 if entity_type in holder.lhs_partitioned_types else 0)
                + (1 if entity_type in holder.rhs_partitioned_types else 0)
                + (
                    1
                    if entity_type
                    in (holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types)
                    else 0
                )
            )
            for _ in range(num_sides):
                embedding_storage_freelist[entity_type].add(
                    allocate_shared_tensor(
                        (max_count, config.entity_dimension(entity_type)),
                        dtype=torch.float,
                    ).storage()
                )

        # create the handlers, threads, etc. for distributed training
        if config.num_machines > 1 or config.num_partition_servers > 0:
            if not 0 <= rank < config.num_machines:
                raise RuntimeError("Invalid rank for trainer")
            if not td.is_available():
                raise RuntimeError(
                    "The installed PyTorch version doesn't provide "
                    "distributed training capabilities."
                )
            ranks = ProcessRanks.from_num_invocations(
                config.num_machines, config.num_partition_servers
            )

            num_ps_groups = config.num_groups_for_partition_server
            groups: List[List[int]] = [ranks.trainers]  # barrier group
            groups += [
                ranks.trainers + ranks.partition_servers
            ] * num_ps_groups  # ps groups
            group_idxs_for_partition_servers = range(1, len(groups))

            if rank == SINGLE_TRAINER:
                logger.info("Setup lock server...")
                start_server(
                    LockServer(
                        num_clients=len(ranks.trainers),
                        nparts_lhs=holder.nparts_lhs,
                        nparts_rhs=holder.nparts_rhs,
                        entities_lhs=holder.lhs_partitioned_types,
                        entities_rhs=holder.rhs_partitioned_types,
                        entity_counts=entity_counts,
                        init_tree=config.distributed_tree_init_order,
                        stats_handler=stats_handler,
                    ),
                    process_name="LockServer",
                    init_method=config.distributed_init_method,
                    world_size=ranks.world_size,
                    server_rank=ranks.lock_server,
                    groups=groups,
                    subprocess_init=subprocess_init,
                )

            self.bucket_scheduler = DistributedBucketScheduler(
                server_rank=ranks.lock_server, client_rank=ranks.trainers[rank]
            )

            logger.info("Setup param server...")
            start_server(
                ParameterServer(num_clients=len(ranks.trainers)),
                process_name=f"ParamS-{rank}",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.parameter_servers[rank],
                groups=groups,
                subprocess_init=subprocess_init,
            )

            parameter_sharer = ParameterSharer(
                process_name=f"ParamC-{rank}",
                client_rank=ranks.parameter_clients[rank],
                all_server_ranks=ranks.parameter_servers,
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                groups=groups,
                subprocess_init=subprocess_init,
            )

            if config.num_partition_servers == -1:
                start_server(
                    ParameterServer(
                        num_clients=len(ranks.trainers),
                        group_idxs=group_idxs_for_partition_servers,
                        log_stats=True,
                    ),
                    process_name=f"PartS-{rank}",
                    init_method=config.distributed_init_method,
                    world_size=ranks.world_size,
                    server_rank=ranks.partition_servers[rank],
                    groups=groups,
                    subprocess_init=subprocess_init,
                )

            groups = init_process_group(
                rank=ranks.trainers[rank],
                world_size=ranks.world_size,
                init_method=config.distributed_init_method,
                groups=groups,
            )
            trainer_group, *groups_for_partition_servers = groups
            self.barrier_group = trainer_group

            if len(ranks.partition_servers) > 0:
                partition_client = PartitionClient(
                    ranks.partition_servers,
                    groups=groups_for_partition_servers,
                    log_stats=True,
                )
            else:
                partition_client = None
        else:
            self.barrier_group = None
            self.bucket_scheduler = SingleMachineBucketScheduler(
                holder.nparts_lhs, holder.nparts_rhs, config.bucket_order, stats_handler
            )
            parameter_sharer = None
            partition_client = None
            hide_distributed_logging()

        # fork early for HOGWILD threads
        logger.info("Creating workers...")
        self.num_workers = get_num_workers(config.workers)
        self.pool = create_pool(
            self.num_workers,
            subprocess_name=f"TWorker-{rank}",
            subprocess_init=subprocess_init,
        )

        checkpoint_manager = CheckpointManager(
            config.checkpoint_path,
            rank=rank,
            num_machines=config.num_machines,
            partition_client=partition_client,
            subprocess_name=f"BackgRW-{rank}",
            subprocess_init=subprocess_init,
        )
        self.checkpoint_manager = checkpoint_manager
        checkpoint_manager.register_metadata_provider(ConfigMetadataProvider(config))
        if rank == 0:
            checkpoint_manager.write_config(config)

        num_edge_chunks = get_num_edge_chunks(config)

        self.iteration_manager = IterationManager(
            config.num_epochs,
            config.edge_paths,
            num_edge_chunks,
            iteration_idx=checkpoint_manager.checkpoint_version,
        )
        checkpoint_manager.register_metadata_provider(self.iteration_manager)

        logger.info("Initializing global model...")
        if model is None:
            model = make_model(config)
        model.share_memory()
        loss_fn = LOSS_FUNCTIONS.get_class(config.loss_fn)(margin=config.margin)
        relation_weights = [relation.weight for relation in config.relations]
        if trainer is None:
            trainer = Trainer(
                model_optimizer=make_optimizer(config, model.parameters(), False),
                loss_fn=loss_fn,
                relation_weights=relation_weights,
            )
        if evaluator is None:
            eval_overrides = {}
            if config.eval_num_batch_negs is not None:
                eval_overrides["num_batch_negs"] = config.eval_num_batch_negs
            if config.eval_num_uniform_negs is not None:
                eval_overrides["num_uniform_negs"] = config.eval_num_uniform_negs

            evaluator = RankingEvaluator(
                loss_fn=loss_fn,
                relation_weights=relation_weights,
                overrides=eval_overrides,
            )

        if config.init_path is not None:
            self.loadpath_manager = CheckpointManager(config.init_path)
        else:
            self.loadpath_manager = None

        # load model from checkpoint or loadpath, if available
        state_dict, optim_state = checkpoint_manager.maybe_read_model()
        if state_dict is None and self.loadpath_manager is not None:
            state_dict, optim_state = self.loadpath_manager.maybe_read_model()
        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)
        if optim_state is not None:
            trainer.model_optimizer.load_state_dict(optim_state)

        logger.debug("Loading unpartitioned entities...")
        for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
            count = entity_counts[entity][0]
            s = embedding_storage_freelist[entity].pop()
            dimension = config.entity_dimension(entity)
            embs = torch.FloatTensor(s).view(-1, dimension)[:count]
            embs, optimizer = self._load_embeddings(entity, UNPARTITIONED, out=embs)
            holder.unpartitioned_embeddings[entity] = embs
            trainer.unpartitioned_optimizers[entity] = optimizer

        # start communicating shared parameters with the parameter server
        if parameter_sharer is not None:
            shared_parameters: Set[int] = set()
            for name, param in model.named_parameters():
                if id(param) in shared_parameters:
                    continue
                shared_parameters.add(id(param))
                key = f"model.{name}"
                logger.info(
                    f"Adding {key} ({param.numel()} params) to parameter server"
                )
                parameter_sharer.set_param(key, param.data)
            for entity, embs in holder.unpartitioned_embeddings.items():
                key = f"entity.{entity}"
                logger.info(f"Adding {key} ({embs.numel()} params) to parameter server")
                parameter_sharer.set_param(key, embs.data)

        # store everything in self
        self.model = model
        self.trainer = trainer
        self.evaluator = evaluator
        self.rank = rank
        self.entity_counts = entity_counts
        self.embedding_storage_freelist = embedding_storage_freelist
        self.stats_handler = stats_handler

        self.strict = False
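
The embedding_storage_freelist set up here is worth a closing note: one maximum-sized shared storage is preallocated per entity type and side, and whichever partition is currently loaded is just a view over a prefix of that storage (see the `torch.FloatTensor(s).view(-1, dimension)[:count]` line above). A minimal sketch of that pattern, with illustrative sizes:

import torch

dimension, max_count = 4, 100
freelist = {
    torch.empty((max_count, dimension), dtype=torch.float).share_memory_().storage()
}

# Borrow a storage, view only the rows the current partition needs,
# then put the same storage back when the table is released.
s = freelist.pop()
embs = torch.FloatTensor(s).view(-1, dimension)[:25]  # a (25, 4) view into shared memory
# ... train on embs ...
freelist.add(s)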