Example #1
def create_layer_of_buckets(
    nparts_lhs: int,
    nparts_rhs: int,
    layer_idx: int,
    *,
    generator: random.Random,
) -> List[Bucket]:
    """Return the layer of #LHS x #RHS matrix of the given index

    The i-th layer contains the buckets (lhs, rhs) such that min(lhs, rhs) == i.
    Buckets that are one the transpose of the other will be consecutive. Other
    than that, the order is random.

    """
    layer_p = Partition(layer_idx)
    pairs = [[Bucket(layer_p, layer_p)]]
    for idx in range(layer_idx + 1, max(nparts_lhs, nparts_rhs)):
        p = Partition(idx)
        pair = []
        if p < nparts_lhs:
            pair.append(Bucket(p, layer_p))
        if p < nparts_rhs:
            pair.append(Bucket(layer_p, p))
        generator.shuffle(pair)
        pairs.append(pair)
    generator.shuffle(pairs)
    return [b for p in pairs for b in p]
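Below is a minimal usage sketch for the function above. The Bucket and Partition definitions are stand-ins mirroring what torchbiggraph provides (a NamedTuple over two partitions and an int NewType); with those in scope, the function runs as-is.

import random
from typing import List, NamedTuple, NewType

Partition = NewType("Partition", int)

class Bucket(NamedTuple):
    lhs: Partition
    rhs: Partition

# Layer 0 of a 3 x 2 grid: every bucket (lhs, rhs) with min(lhs, rhs) == 0.
layer = create_layer_of_buckets(
    nparts_lhs=3, nparts_rhs=2, layer_idx=0,
    generator=random.Random(42),  # seeded so the shuffle is reproducible
)
assert all(min(b.lhs, b.rhs) == 0 for b in layer)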
Example #2
    def __init__(self, config: ConfigSchema, filter_paths: List[str]):
        super().__init__()
        if len(config.relations) != 1 or len(config.entities) != 1:
            raise RuntimeError("Filtered ranking evaluation should only be used "
                               "with dynamic relations and one entity type.")
        if not config.relations[0].all_negs:
            raise RuntimeError("Filtered Eval can only be done with all negatives.")

        entity, = config.entities.values()
        if entity.featurized:
            raise RuntimeError("Entity cannot be featurized for filtered eval.")
        if entity.num_partitions > 1:
            raise RuntimeError("Entity cannot be partitioned for filtered eval.")

        self.lhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
        self.rhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
        for path in filter_paths:
            logger.info(f"Building links map from path {path}")
            e_reader = EDGELIST_READERS.make_instance(path)
            # Assume unpartitioned.
            edges = e_reader.read_edgelist(Partition(0), Partition(0))
            for idx in range(len(edges)):
                # Assume non-featurized.
                cur_lhs = int(edges.lhs.to_tensor()[idx])
                # Assume dynamic relations.
                cur_rel = int(edges.rel[idx])
                # Assume non-featurized.
                cur_rhs = int(edges.rhs.to_tensor()[idx])

                self.lhs_map[cur_lhs, cur_rel].append(cur_rhs)
                self.rhs_map[cur_rhs, cur_rel].append(cur_lhs)

            logger.info(f"Done building links map from path {path}")
Example #3
    def __init__(self, config: ConfigSchema, filter_paths: List[str]):
        if len(config.relations) != 1 or len(config.entities) != 1:
            raise RuntimeError("Filtered ranking evaluation should only be used "
                               "with dynamic relations and one entity type.")
        if not config.relations[0].all_negs:
            raise RuntimeError("Filtered Eval can only be done with all negatives.")

        entity, = config.entities.values()
        if entity.featurized:
            raise RuntimeError("Entity cannot be featurized for filtered eval.")
        if entity.num_partitions > 1:
            raise RuntimeError("Entity cannot be partitioned for filtered eval.")

        self.lhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
        self.rhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
        for path in filter_paths:
            log("Building links map from path %s" % path)
            e_reader = EdgeReader(path)
            # Assume unpartitioned.
            lhs, rhs, rel = e_reader.read(Partition(0), Partition(0))
            num_edges = lhs.size(0)
            for i in range(num_edges):
                # Assume non-featurized.
                cur_lhs = lhs[i].collapse(is_featurized=False).item()
                cur_rel = rel[i].item()
                # Assume non-featurized.
                cur_rhs = rhs[i].collapse(is_featurized=False).item()

                self.lhs_map[cur_lhs, cur_rel].append(cur_rhs)
                self.rhs_map[cur_rhs, cur_rel].append(cur_lhs)

            log("Done building links map from path %s" % path)
Example #4
def create_buckets_ordered_lexicographically(
    nparts_lhs: int,
    nparts_rhs: int,
) -> List[Bucket]:
    """Return buckets in increasing LHS and, for the same LHS, in increasing RHS

    """
    buckets = [
        Bucket(Partition(lhs), Partition(rhs)) for lhs in range(nparts_lhs)
        for rhs in range(nparts_rhs)
    ]
    return buckets
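A quick check of the ordering, reusing the Bucket/Partition stand-ins from the sketch under Example #1:

buckets = create_buckets_ordered_lexicographically(2, 2)
assert [(b.lhs, b.rhs) for b in buckets] == [(0, 0), (0, 1), (1, 0), (1, 1)]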
Example #5
 def write_new_version(
     self,
     config: ConfigSchema,
     entity_counts: Dict[EntityName, List[int]],
     embedding_storage_freelist: Dict[EntityName, Set[torch.FloatStorage]],
 ) -> None:
     metadata = self.collect_metadata()
     new_version = self._version(True)
     if self.partition_client is not None:
         for entity, econf in config.entities.items():
             for part in range(self.rank, econf.num_partitions,
                               self.num_machines):
                 logger.debug(f"Getting {entity} {part}")
                 count = entity_counts[entity][part]
                 s = next(iter(embedding_storage_freelist[entity]))
                 out = torch.FloatTensor(s).view(-1,
                                                 config.dimension)[:count]
                 embs, serialized_optim_state = self.partition_client.get(
                     EntityName(entity), Partition(part), out=out)
                 logger.debug(f"Done getting {entity} {part}")
                 logger.debug(f"Saving {entity} {part} v{new_version}")
                 self.storage.save_entity_partition(new_version, entity,
                                                    part, embs,
                                                    serialized_optim_state,
                                                    metadata)
                 logger.debug(f"Done saving {entity} {part} v{new_version}")
Example #6
 def new_pass(self, is_first: bool = False) -> None:
     """Start a new epoch of training."""
     self.active = {}
     self.done = set()
     self.dirty = set()
     if self.init_tree and is_first:
         self.initialized_partitions = {Partition(0)}
     else:
         self.initialized_partitions = None
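When init_tree is set, the first pass starts with only partition 0 marked as initialized, so training can grow outward from it; on later passes the set is None, meaning no restriction. A hypothetical sketch of the check a scheduler could derive from this state (not taken from the source):

def may_schedule(bucket, initialized_partitions) -> bool:
    # With no restriction in place, any bucket is fair game.
    if initialized_partitions is None:
        return True
    # Otherwise at least one side must already be initialized, so that
    # fresh embeddings are always trained against initialized ones.
    return bucket.lhs in initialized_partitions or bucket.rhs in initialized_partitions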
Example #7
 def write_new_version(self, config: ConfigSchema) -> None:
     if self.background:
         self._sync()
     metadata = self.collect_metadata()
     new_version = self._version(True)
     if self.partition_client is not None:
         for entity, econf in config.entities.items():
             for part in range(self.rank, econf.num_partitions, self.num_machines):
                 vlog("Rank %d: getting %s %d" % (self.rank, entity, part))
                 embs, optim_state = \
                     self.partition_client.get(EntityName(entity), Partition(part))
                 vlog("Rank %d: saving %s %d to disk" % (self.rank, entity, part))
                 new_file_path = os.path.join(
                     self.path, "embeddings_%s_%d.v%d.h5" % (entity, part, new_version))
                 save_entity_partition(new_file_path, embs, optim_state, metadata)
Example #8
 def write_new_version(self, config: ConfigSchema) -> None:
     if self.background:
         self._sync()
     metadata = self.collect_metadata()
     new_version = self._version(True)
     if self.partition_client is not None:
         for entity, econf in config.entities.items():
             for part in range(self.rank, econf.num_partitions, self.num_machines):
                 logger.debug(f"Getting {entity} {part}")
                 embs, serialized_optim_state = \
                     self.partition_client.get(EntityName(entity), Partition(part))
                 logger.debug(f"Done getting {entity} {part}")
                 logger.debug(f"Saving {entity} {part} v{new_version}")
                 self.storage.save_entity_partition(
                     new_version, entity, part, embs, serialized_optim_state, metadata)
                 logger.debug(f"Done saving {entity} {part} v{new_version}")
Example #9
 def write_new_version(self, config: ConfigSchema) -> None:
     if self.background:
         self._sync()
     metadata = self.collect_metadata()
     new_version = self._version(True)
     if self.partition_client is not None:
         for entity, econf in config.entities.items():
             for part in range(self.rank, econf.num_partitions,
                               self.num_machines):
                 logger.debug(f"Getting {entity} {part}")
                 embs, optim_state = \
                     self.partition_client.get(EntityName(entity), Partition(part))
                 logger.debug(f"Done getting {entity} {part}")
                 new_file_path = os.path.join(
                     self.path,
                     f"embeddings_{entity}_{part}.v{new_version}.h5")
                 logger.debug(f"Saving {entity} {part} to {new_file_path}")
                 save_entity_partition(new_file_path, embs, optim_state,
                                       metadata)
                 logger.debug(
                     f"Done saving {entity} {part} to {new_file_path}")
Example #10
def create_buckets_ordered_by_affinity(
    nparts_lhs: int,
    nparts_rhs: int,
    *,
    generator: random.Random,
) -> List[Bucket]:
    """Try having consecutive buckets share as many partitions as possible.

    Start from a random bucket. While there are buckets left, try to choose the
    next one so that it has as many partitions in common as possible with the
    previous one. When multiple options are available, pick one randomly.

    """
    if nparts_lhs <= 0 or nparts_rhs <= 0:
        return []

    # This is our "source of truth" on what buckets we haven't outputted yet. It
    # can be queried in constant time.
    remaining: Set[Bucket] = set()
    # These are our random orders: we shuffle them once and then pop from the
    # end. Each bucket appears in several of them. They are updated lazily,
    # which means they may contain buckets that have already been outputted.
    all_buckets: List[Bucket] = []
    buckets_per_partition: List[List[Bucket]] = \
        [[] for _ in range(max(nparts_lhs, nparts_rhs))]

    for lhs in range(nparts_lhs):
        for rhs in range(nparts_rhs):
            b = Bucket(Partition(lhs), Partition(rhs))
            remaining.add(b)
            all_buckets.append(b)
            buckets_per_partition[lhs].append(b)
            buckets_per_partition[rhs].append(b)

    generator.shuffle(all_buckets)
    for buckets in buckets_per_partition:
        generator.shuffle(buckets)

    b = all_buckets.pop()
    remaining.remove(b)
    order = [b]

    while remaining:
        transposed_b = Bucket(b.rhs, b.lhs)
        if transposed_b in remaining:
            remaining.remove(transposed_b)
            order.append(transposed_b)
            if not remaining:
                break

        same_as_lhs = buckets_per_partition[b.lhs]
        same_as_rhs = buckets_per_partition[b.rhs]
        while len(same_as_lhs) > 0 or len(same_as_rhs) > 0:
            chosen, = generator.choices(
                [same_as_lhs, same_as_rhs],
                weights=[len(same_as_lhs), len(same_as_rhs)],
            )
            next_b = chosen.pop()
            if next_b in remaining:
                break
        else:
            while True:
                next_b = all_buckets.pop()
                if next_b in remaining:
                    break
        remaining.remove(next_b)
        order.append(next_b)
        b = next_b

    return order
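A seeded run over a small grid, again with the Bucket/Partition stand-ins from the sketch under Example #1. Every bucket appears exactly once, and most consecutive pairs share a partition (the fallback to a random bucket means this is not guaranteed for every pair):

order = create_buckets_ordered_by_affinity(3, 3, generator=random.Random(0))
assert len(order) == 9 and len(set(order)) == 9  # each bucket exactly once
shared = sum(
    bool({a.lhs, a.rhs} & {b.lhs, b.rhs})
    for a, b in zip(order, order[1:])
)
print(f"{shared} of {len(order) - 1} consecutive pairs share a partition")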
Example #11
def create_buckets_ordered_by_affinity(
    nparts_lhs: int,
    nparts_rhs: int,
    *,
    generator: random.Random,
) -> List[Bucket]:
    """Try having consecutive buckets share as many partitions as possible.

    Start from a random bucket. While there are buckets left, try to choose the
    next one so that it has as many partitions in common as possible with the
    previous one. When multiple options are available, pick one randomly.

    """
    if nparts_lhs <= 0 or nparts_rhs <= 0:
        return []

    # TODO Change this function to use the same cost model as the LockServer
    # when computing affinity (based on the number of entities to save and load)
    # rather than just the number of partitions in common. Pay attention to keep
    # the complexity of this algorithm linear in the number of buckets. This
    # comment is too short to give a full description, but the idea is that only
    # a few transitions are possible between a bucket and the next: the one that
    # preserves all (ent, part) pairs, the one that preserves only the lhs ones,
    # only the rhs ones, only the intersection of the two, or none at all. So we
    # can keep a dict from sets of (ent, part) to lists of buckets, and insert
    # each bucket into four of those lists, namely the ones for all its (ent,
    # part), its lhs ones, its rhs ones and the intersection of its lhs and rhs
    # ones. Then, when looking for the next bucket, we figure out the transition
    # that is cheapest (among the options defined above), determine the set of
    # (ent, part) we need to move to in order to achieve that transition type
    # and we look up in the dict to find a bucket containing those (ent, part).

    # This is our "source of truth" on what buckets we haven't outputted yet. It
    # can be queried in constant time.
    remaining: Set[Bucket] = set()
    # These are our random orders: we shuffle them once and then pop from the
    # end. Each bucket appears in several of them. They are updated lazily,
    # which means they may contain buckets that have already been outputted.
    all_buckets: List[Bucket] = []
    buckets_per_partition: List[List[Bucket]] = \
        [[] for _ in range(max(nparts_lhs, nparts_rhs))]

    for lhs in range(nparts_lhs):
        for rhs in range(nparts_rhs):
            b = Bucket(Partition(lhs), Partition(rhs))
            remaining.add(b)
            all_buckets.append(b)
            buckets_per_partition[lhs].append(b)
            buckets_per_partition[rhs].append(b)

    generator.shuffle(all_buckets)
    for buckets in buckets_per_partition:
        generator.shuffle(buckets)

    b = all_buckets.pop()
    remaining.remove(b)
    order = [b]

    while remaining:
        transposed_b = Bucket(b.rhs, b.lhs)
        if transposed_b in remaining:
            remaining.remove(transposed_b)
            order.append(transposed_b)
            if not remaining:
                break

        same_as_lhs = buckets_per_partition[b.lhs]
        same_as_rhs = buckets_per_partition[b.rhs]
        while len(same_as_lhs) > 0 or len(same_as_rhs) > 0:
            chosen, = generator.choices(
                [same_as_lhs, same_as_rhs],
                weights=[len(same_as_lhs), len(same_as_rhs)],
            )
            next_b = chosen.pop()
            if next_b in remaining:
                break
        else:
            while True:
                next_b = all_buckets.pop()
                if next_b in remaining:
                    break
        remaining.remove(next_b)
        order.append(next_b)
        b = next_b

    return order
Example #12
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
       embeddings.
    """

    if evaluator is None:
        evaluator = RankingEvaluator()

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(num_workers)

    if model is None:
        model = make_model(config)
    model.share_memory()

    state_dict, _ = checkpoint_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)

    model.eval()

    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)

    all_stats: List[Stats] = []
    for edge_path_idx, edge_path in enumerate(config.edge_paths):
        log("Starting edge path %d / %d (%s)"
            % (edge_path_idx + 1, len(config.edge_paths), edge_path))
        edge_reader = EdgeReader(edge_path)

        all_edge_path_stats = []
        last_lhs, last_rhs = None, None
        for bucket in create_buckets_ordered_lexicographically(nparts_lhs, nparts_rhs):
            tic = time.time()
            # log("%s: Loading entities" % (bucket,))

            if last_lhs != bucket.lhs:
                for e in lhs_partitioned_types:
                    model.clear_embeddings(e, Side.LHS)
                    embs = load_embeddings(e, bucket.lhs)
                    model.set_embeddings(e, embs, Side.LHS)
            if last_rhs != bucket.rhs:
                for e in rhs_partitioned_types:
                    model.clear_embeddings(e, Side.RHS)
                    embs = load_embeddings(e, bucket.rhs)
                    model.set_embeddings(e, embs, Side.RHS)
            last_lhs, last_rhs = bucket.lhs, bucket.rhs

            # log("%s: Loading edges" % (bucket,))
            edges = edge_reader.read(bucket.lhs, bucket.rhs)
            num_edges = len(edges)

            load_time = time.time() - tic
            tic = time.time()
            # log("%s: Launching and waiting for workers" % (bucket,))
            all_bucket_stats = pool.map(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=evaluator,
                    edges=edges[s],
                )
                for s in split_almost_equally(num_edges, num_parts=num_workers)
            ])

            compute_time = time.time() - tic
            log("%s: Processed %d edges in %.2g s (%.2gM/sec); load time: %.2g s"
                % (bucket, num_edges, compute_time,
                   num_edges / compute_time / 1e6, load_time))

            total_bucket_stats = Stats.sum(all_bucket_stats)
            all_edge_path_stats.append(total_bucket_stats)
            mean_bucket_stats = total_bucket_stats.average()
            log("Stats for edge path %d / %d, bucket %s: %s"
                % (edge_path_idx + 1, len(config.edge_paths), bucket,
                   mean_bucket_stats))

            yield edge_path_idx, bucket, mean_bucket_stats

        total_edge_path_stats = Stats.sum(all_edge_path_stats)
        all_stats.append(total_edge_path_stats)
        mean_edge_path_stats = total_edge_path_stats.average()
        log("")
        log("Stats for edge path %d / %d: %s"
            % (edge_path_idx + 1, len(config.edge_paths), mean_edge_path_stats))
        log("")

        yield edge_path_idx, None, mean_edge_path_stats

    mean_stats = Stats.sum(all_stats).average()
    log("")
    log("Stats: %s" % mean_stats)
    log("")

    yield None, None, mean_stats

    pool.close()
    pool.join()
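The function is a generator: it yields (edge_path_idx, bucket, stats) per bucket, (edge_path_idx, None, stats) per edge path, and finally (None, None, stats) for the overall summary. A minimal driver sketch (obtaining config is assumed to happen elsewhere):

def run_eval(config):
    for edge_path_idx, bucket, stats in do_eval_and_report_stats(config):
        if edge_path_idx is None:
            print(f"overall: {stats}")
        elif bucket is None:
            print(f"edge path {edge_path_idx + 1}: {stats}")
        else:
            print(f"edge path {edge_path_idx + 1}, bucket {bucket}: {stats}")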
Example #13
    def _maybe_write_checkpoint(
        self,
        epoch_idx: int,
        edge_path_idx: int,
        edge_chunk_idx: int,
    ) -> None:
        config = self.config

        # Preserving a checkpoint requires two steps:
        # - create a snapshot (w/ symlinks) after it's first written;
        # - don't delete it once the following one is written.
        # These two happen in two successive iterations of the main loop: the
        # one just before and the one just after the epoch boundary.
        preserve_old_checkpoint = should_preserve_old_checkpoint(
            self.iteration_manager, config.checkpoint_preservation_interval)
        preserve_new_checkpoint = should_preserve_old_checkpoint(
            self.iteration_manager + 1,
            config.checkpoint_preservation_interval)

        # Write metadata: for multiple machines, write from rank-0
        logger.info(
            f"Finished epoch {epoch_idx + 1} / {self.iteration_manager.num_epochs}, "
            f"edge path {edge_path_idx + 1} / {self.iteration_manager.num_edge_paths}, "
            f"edge chunk {edge_chunk_idx + 1} / {self.iteration_manager.num_edge_chunks}"
        )
        if self.rank == 0:
            for entity, econfig in config.entities.items():
                if econfig.num_partitions == 1:
                    embs = self.holder.unpartitioned_embeddings[entity]
                    optimizer = self.trainer.unpartitioned_optimizers[entity]
                    self.checkpoint_manager.write(
                        entity, Partition(0), embs.detach(),
                        OptimizerStateDict(optimizer.state_dict()))

            logger.info("Writing the metadata")
            state_dict: ModuleStateDict = ModuleStateDict(
                self.model.state_dict())
            self.checkpoint_manager.write_model(
                state_dict,
                OptimizerStateDict(self.trainer.model_optimizer.state_dict()))

            logger.info("Writing the training stats")
            all_stats_dicts: List[Dict[str, Any]] = []
            for stats in self.bucket_scheduler.get_stats_for_pass():
                stats_dict = {
                    "epoch_idx": epoch_idx,
                    "edge_path_idx": edge_path_idx,
                    "edge_chunk_idx": edge_chunk_idx,
                    "lhs_partition": stats.lhs_partition,
                    "rhs_partition": stats.rhs_partition,
                    "index": stats.index,
                    "stats": stats.train.to_dict(),
                }
                if stats.eval_before is not None:
                    stats_dict["eval_stats_before"] = stats.eval_before.to_dict()
                if stats.eval_after is not None:
                    stats_dict["eval_stats_after"] = stats.eval_after.to_dict()
                all_stats_dicts.append(stats_dict)
            self.checkpoint_manager.append_stats(all_stats_dicts)

        logger.info("Writing the checkpoint")
        self.checkpoint_manager.write_new_version(
            config, self.entity_counts, self.embedding_storage_freelist)

        dist_logger.info(
            "Waiting for other workers to write their parts of the checkpoint")
        self._barrier()
        dist_logger.info("All parts of the checkpoint have been written")

        logger.info("Switching to the new checkpoint version")
        self.checkpoint_manager.switch_to_new_version()

        dist_logger.info(
            "Waiting for other workers to switch to the new checkpoint version"
        )
        self._barrier()
        dist_logger.info(
            "All workers have switched to the new checkpoint version")

        # After all the machines have finished committing checkpoints, we
        # either remove the old checkpoint or preserve it.
        if preserve_new_checkpoint:
            # Add 1 so the index is a multiple of the interval, it looks nicer.
            self.checkpoint_manager.preserve_current_version(
                config, epoch_idx + 1)
        if not preserve_old_checkpoint:
            self.checkpoint_manager.remove_old_version(config)
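For reference, the shape of one record appended to the stats log above (values are illustrative, not from a real run):

example_stats_dict = {
    "epoch_idx": 0,
    "edge_path_idx": 0,
    "edge_chunk_idx": 0,
    "lhs_partition": 1,
    "rhs_partition": 2,
    "index": 3,
    "stats": {"loss": 0.1, "count": 1000},  # stats.train.to_dict()
    # "eval_stats_before" / "eval_stats_after" are present only when the
    # bucket was evaluated before / after training.
}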
Example #14
    def __init__(
        self,
        config: ConfigSchema,
        model: Optional[MultiRelationEmbedder] = None,
        trainer: Optional[AbstractBatchProcessor] = None,
        evaluator: Optional[AbstractBatchProcessor] = None,
        rank: Rank = RANK_ZERO,
        subprocess_init: Optional[Callable[[], None]] = None,
    ):
        """Each epoch/pass, for each partition pair, loads in embeddings and edgelist
        from disk, runs HOGWILD training on them, and writes partitions back to disk.
        """
        tag_logs_with_process_name(f"Trainer-{rank}")
        self.config = config
        if config.verbose > 0:
            import pprint
            pprint.PrettyPrinter().pprint(config.to_dict())

        logger.info("Loading entity counts...")
        entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
        entity_counts: Dict[str, List[int]] = {}
        for entity, econf in config.entities.items():
            entity_counts[entity] = []
            for part in range(econf.num_partitions):
                entity_counts[entity].append(
                    entity_storage.load_count(entity, part))

        # Figure out how many lhs and rhs partitions we need
        holder = self.holder = EmbeddingHolder(config)

        logger.debug(
            f"nparts {holder.nparts_lhs} {holder.nparts_rhs} "
            f"types {holder.lhs_partitioned_types} {holder.rhs_partitioned_types}"
        )

        # We know ahead of time that we will need 1-2 storages for each embedding type,
        # as well as the max size of this storage (num_entities x D).
        # We allocate these storages in advance in `embedding_storage_freelist`.
        # When we need storage for an entity type, we pop it from this free list,
        # and then add it back when we 'delete' the embedding table.
        embedding_storage_freelist: Dict[
            EntityName, Set[torch.FloatStorage]] = defaultdict(set)
        for entity_type, counts in entity_counts.items():
            max_count = max(counts)
            num_sides = (
                (1 if entity_type in holder.lhs_partitioned_types else 0) +
                (1 if entity_type in holder.rhs_partitioned_types else 0) +
                (1 if entity_type in (holder.lhs_unpartitioned_types |
                                      holder.rhs_unpartitioned_types) else 0))
            for _ in range(num_sides):
                embedding_storage_freelist[entity_type].add(
                    allocate_shared_tensor((max_count, config.dimension),
                                           dtype=torch.float).storage())

        # create the handlers, threads, etc. for distributed training
        if config.num_machines > 1 or config.num_partition_servers > 0:
            if not 0 <= rank < config.num_machines:
                raise RuntimeError("Invalid rank for trainer")
            if not td.is_available():
                raise RuntimeError(
                    "The installed PyTorch version doesn't provide "
                    "distributed training capabilities.")
            ranks = ProcessRanks.from_num_invocations(
                config.num_machines, config.num_partition_servers)

            num_ps_groups = config.num_groups_for_partition_server
            groups: List[List[int]] = [ranks.trainers]  # barrier group
            groups += [ranks.trainers + ranks.partition_servers
                       ] * num_ps_groups  # ps groups
            group_idxs_for_partition_servers = range(1, len(groups))

            if rank == RANK_ZERO:
                logger.info("Setup lock server...")
                start_server(
                    LockServer(
                        num_clients=len(ranks.trainers),
                        nparts_lhs=holder.nparts_lhs,
                        nparts_rhs=holder.nparts_rhs,
                        entities_lhs=holder.lhs_partitioned_types,
                        entities_rhs=holder.rhs_partitioned_types,
                        entity_counts=entity_counts,
                        init_tree=config.distributed_tree_init_order,
                    ),
                    process_name="LockServer",
                    init_method=config.distributed_init_method,
                    world_size=ranks.world_size,
                    server_rank=ranks.lock_server,
                    groups=groups,
                    subprocess_init=subprocess_init,
                )

            self.bucket_scheduler = DistributedBucketScheduler(
                server_rank=ranks.lock_server,
                client_rank=ranks.trainers[rank],
            )

            logger.info("Setup param server...")
            start_server(
                ParameterServer(num_clients=len(ranks.trainers)),
                process_name=f"ParamS-{rank}",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.parameter_servers[rank],
                groups=groups,
                subprocess_init=subprocess_init,
            )

            parameter_sharer = ParameterSharer(
                process_name=f"ParamC-{rank}",
                client_rank=ranks.parameter_clients[rank],
                all_server_ranks=ranks.parameter_servers,
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                groups=groups,
                subprocess_init=subprocess_init,
            )

            if config.num_partition_servers == -1:
                start_server(
                    ParameterServer(
                        num_clients=len(ranks.trainers),
                        group_idxs=group_idxs_for_partition_servers,
                        log_stats=True,
                    ),
                    process_name=f"PartS-{rank}",
                    init_method=config.distributed_init_method,
                    world_size=ranks.world_size,
                    server_rank=ranks.partition_servers[rank],
                    groups=groups,
                    subprocess_init=subprocess_init,
                )

            groups = init_process_group(
                rank=ranks.trainers[rank],
                world_size=ranks.world_size,
                init_method=config.distributed_init_method,
                groups=groups,
            )
            trainer_group, *groups_for_partition_servers = groups
            self.barrier_group = trainer_group

            if len(ranks.partition_servers) > 0:
                partition_client = PartitionClient(
                    ranks.partition_servers,
                    groups=groups_for_partition_servers,
                    log_stats=True,
                )
            else:
                partition_client = None
        else:
            self.barrier_group = None
            self.bucket_scheduler = SingleMachineBucketScheduler(
                holder.nparts_lhs, holder.nparts_rhs, config.bucket_order)
            parameter_sharer = None
            partition_client = None
            hide_distributed_logging()

        # fork early for HOGWILD threads
        logger.info("Creating workers...")
        self.num_workers = get_num_workers(config.workers)
        self.pool = create_pool(
            self.num_workers,
            subprocess_name=f"TWorker-{rank}",
            subprocess_init=subprocess_init,
        )

        checkpoint_manager = CheckpointManager(
            config.checkpoint_path,
            rank=rank,
            num_machines=config.num_machines,
            partition_client=partition_client,
            subprocess_name=f"BackgRW-{rank}",
            subprocess_init=subprocess_init,
        )
        self.checkpoint_manager = checkpoint_manager
        checkpoint_manager.register_metadata_provider(
            ConfigMetadataProvider(config))
        if rank == 0:
            checkpoint_manager.write_config(config)

        num_edge_chunks = get_num_edge_chunks(config)

        self.iteration_manager = IterationManager(
            config.num_epochs,
            config.edge_paths,
            num_edge_chunks,
            iteration_idx=checkpoint_manager.checkpoint_version)
        checkpoint_manager.register_metadata_provider(self.iteration_manager)

        logger.info("Initializing global model...")
        if model is None:
            model = make_model(config)
        model.share_memory()
        if trainer is None:
            trainer = Trainer(
                model_optimizer=make_optimizer(config, model.parameters(),
                                               False),
                loss_fn=config.loss_fn,
                margin=config.margin,
                relations=config.relations,
            )
        if evaluator is None:
            evaluator = TrainingRankingEvaluator(
                override_num_batch_negs=config.eval_num_batch_negs,
                override_num_uniform_negs=config.eval_num_uniform_negs,
            )

        if config.init_path is not None:
            self.loadpath_manager = CheckpointManager(config.init_path)
        else:
            self.loadpath_manager = None

        # load model from checkpoint or loadpath, if available
        state_dict, optim_state = checkpoint_manager.maybe_read_model()
        if state_dict is None and self.loadpath_manager is not None:
            state_dict, optim_state = self.loadpath_manager.maybe_read_model()
        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)
        if optim_state is not None:
            trainer.model_optimizer.load_state_dict(optim_state)

        logger.debug("Loading unpartitioned entities...")
        for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
            count = entity_counts[entity][0]
            s = embedding_storage_freelist[entity].pop()
            embs = torch.FloatTensor(s).view(-1, config.dimension)[:count]
            embs, optimizer = self._load_embeddings(entity,
                                                    Partition(0),
                                                    out=embs)
            holder.unpartitioned_embeddings[entity] = embs
            trainer.unpartitioned_optimizers[entity] = optimizer

        # start communicating shared parameters with the parameter server
        if parameter_sharer is not None:
            shared_parameters: Set[int] = set()
            for name, param in model.named_parameters():
                if id(param) in shared_parameters:
                    continue
                shared_parameters.add(id(param))
                key = f"model.{name}"
                logger.info(
                    f"Adding {key} ({param.numel()} params) to parameter server"
                )
                parameter_sharer.set_param(key, param.data)
            for entity, embs in holder.unpartitioned_embeddings.items():
                key = f"entity.{entity}"
                logger.info(
                    f"Adding {key} ({embs.numel()} params) to parameter server"
                )
                parameter_sharer.set_param(key, embs.data)

        # store everything in self
        self.model = model
        self.trainer = trainer
        self.evaluator = evaluator
        self.rank = rank
        self.entity_counts = entity_counts
        self.embedding_storage_freelist = embedding_storage_freelist

        self.strict = False
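The freelist sizing above reserves one storage per side on which an entity type can appear. A standalone restatement of that count with tiny made-up type sets:

lhs_partitioned_types = {"user"}
rhs_partitioned_types = {"user", "item"}
lhs_unpartitioned_types = {"tag"}
rhs_unpartitioned_types = set()

def num_storages_needed(entity_type: str) -> int:
    # One storage if the type is partitioned on the LHS, one if partitioned
    # on the RHS, and one shared storage if it is unpartitioned on either side.
    return (
        (1 if entity_type in lhs_partitioned_types else 0)
        + (1 if entity_type in rhs_partitioned_types else 0)
        + (1 if entity_type in (lhs_unpartitioned_types | rhs_unpartitioned_types) else 0)
    )

assert num_storages_needed("user") == 2
assert num_storages_needed("tag") == 1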
Example #15
File: eval.py Project: delldu/BigGraph
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
       embeddings.
    """
    tag_logs_with_process_name(f"Evaluator")

    if evaluator is None:
        evaluator = RankingEvaluator()

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName,
                        part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(
        num_workers,
        subprocess_name="EvalWorker",
        subprocess_init=subprocess_init,
    )

    if model is None:
        model = make_model(config)
    model.share_memory()

    state_dict, _ = checkpoint_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)

    model.eval()

    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)

    all_stats: List[Stats] = []
    for edge_path_idx, edge_path in enumerate(config.edge_paths):
        logger.info(
            f"Starting edge path {edge_path_idx + 1} / {len(config.edge_paths)} "
            f"({edge_path})")
        edge_storage = EDGE_STORAGES.make_instance(edge_path)

        all_edge_path_stats = []
        last_lhs, last_rhs = None, None
        for bucket in create_buckets_ordered_lexicographically(
                nparts_lhs, nparts_rhs):
            tic = time.time()
            # logger.info(f"{bucket}: Loading entities")

            if last_lhs != bucket.lhs:
                for e in lhs_partitioned_types:
                    model.clear_embeddings(e, Side.LHS)
                    embs = load_embeddings(e, bucket.lhs)
                    model.set_embeddings(e, embs, Side.LHS)
            if last_rhs != bucket.rhs:
                for e in rhs_partitioned_types:
                    model.clear_embeddings(e, Side.RHS)
                    embs = load_embeddings(e, bucket.rhs)
                    model.set_embeddings(e, embs, Side.RHS)
            last_lhs, last_rhs = bucket.lhs, bucket.rhs

            # logger.info(f"{bucket}: Loading edges")
            edges = edge_storage.load_edges(bucket.lhs, bucket.rhs)
            num_edges = len(edges)

            load_time = time.time() - tic
            tic = time.time()
            # logger.info(f"{bucket}: Launching and waiting for workers")
            future_all_bucket_stats = pool.map_async(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=evaluator,
                    edges=edges[s],
                )
                for s in split_almost_equally(num_edges, num_parts=num_workers)
            ])
            all_bucket_stats = \
                get_async_result(future_all_bucket_stats, pool)

            compute_time = time.time() - tic
            logger.info(
                f"{bucket}: Processed {num_edges} edges in {compute_time:.2g} s "
                f"({num_edges / compute_time / 1e6:.2g}M/sec); "
                f"load time: {load_time:.2g} s")

            total_bucket_stats = Stats.sum(all_bucket_stats)
            all_edge_path_stats.append(total_bucket_stats)
            mean_bucket_stats = total_bucket_stats.average()
            logger.info(
                f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}, "
                f"bucket {bucket}: {mean_bucket_stats}")

            yield edge_path_idx, bucket, mean_bucket_stats

        total_edge_path_stats = Stats.sum(all_edge_path_stats)
        all_stats.append(total_edge_path_stats)
        mean_edge_path_stats = total_edge_path_stats.average()
        logger.info("")
        logger.info(
            f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}: "
            f"{mean_edge_path_stats}")
        logger.info("")

        yield edge_path_idx, None, mean_edge_path_stats

    mean_stats = Stats.sum(all_stats).average()
    logger.info("")
    logger.info(f"Stats: {mean_stats}")
    logger.info("")

    yield None, None, mean_stats

    pool.close()
    pool.join()
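The worker fan-out above relies on split_almost_equally to hand each worker a near-equal slice of the edges. The real helper lives in torchbiggraph; this stand-in just illustrates the contract:

from typing import Iterator

def split_almost_equally(size: int, *, num_parts: int) -> Iterator[slice]:
    # Yield num_parts contiguous slices covering range(size), whose
    # lengths differ by at most one.
    base, extra = divmod(size, num_parts)
    start = 0
    for i in range(num_parts):
        length = base + (1 if i < extra else 0)
        yield slice(start, start + length)
        start += length

assert list(split_almost_equally(10, num_parts=3)) == \
    [slice(0, 4), slice(4, 7), slice(7, 10)]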
Example #16
def train_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = RANK_ZERO,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[int, Optional[Stats], Stats, Optional[Stats]], None,
               None]:
    """Each epoch/pass, for each partition pair, loads in embeddings and edgelist
    from disk, runs HOGWILD training on them, and writes partitions back to disk.
    """
    tag_logs_with_process_name(f"Trainer-{rank}")

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    logger.info("Loading entity counts...")
    entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
    entity_counts: Dict[str, List[int]] = {}
    for entity, econf in config.entities.items():
        entity_counts[entity] = []
        for part in range(econf.num_partitions):
            entity_counts[entity].append(
                entity_storage.load_count(entity, part))

    # Figure out how many lhs and rhs partitions we need
    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)
    logger.debug(f"nparts {nparts_lhs} {nparts_rhs} "
                 f"types {lhs_partitioned_types} {rhs_partitioned_types}")
    total_buckets = nparts_lhs * nparts_rhs

    sync: AbstractSynchronizer
    bucket_scheduler: AbstractBucketScheduler
    parameter_sharer: Optional[ParameterSharer]
    partition_client: Optional[PartitionClient]
    if config.num_machines > 1:
        if not 0 <= rank < config.num_machines:
            raise RuntimeError("Invalid rank for trainer")
        if not td.is_available():
            raise RuntimeError("The installed PyTorch version doesn't provide "
                               "distributed training capabilities.")
        ranks = ProcessRanks.from_num_invocations(config.num_machines,
                                                  config.num_partition_servers)

        if rank == RANK_ZERO:
            logger.info("Setup lock server...")
            start_server(
                LockServer(
                    num_clients=len(ranks.trainers),
                    nparts_lhs=nparts_lhs,
                    nparts_rhs=nparts_rhs,
                    lock_lhs=len(lhs_partitioned_types) > 0,
                    lock_rhs=len(rhs_partitioned_types) > 0,
                    init_tree=config.distributed_tree_init_order,
                ),
                process_name="LockServer",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.lock_server,
                groups=[ranks.trainers],
                subprocess_init=subprocess_init,
            )

        bucket_scheduler = DistributedBucketScheduler(
            server_rank=ranks.lock_server,
            client_rank=ranks.trainers[rank],
        )

        logger.info("Setup param server...")
        start_server(
            ParameterServer(num_clients=len(ranks.trainers)),
            process_name=f"ParamS-{rank}",
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            server_rank=ranks.parameter_servers[rank],
            groups=[ranks.trainers],
            subprocess_init=subprocess_init,
        )

        parameter_sharer = ParameterSharer(
            process_name=f"ParamC-{rank}",
            client_rank=ranks.parameter_clients[rank],
            all_server_ranks=ranks.parameter_servers,
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=[ranks.trainers],
            subprocess_init=subprocess_init,
        )

        if config.num_partition_servers == -1:
            start_server(
                ParameterServer(num_clients=len(ranks.trainers),
                                log_stats=True),
                process_name=f"PartS-{rank}",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.partition_servers[rank],
                groups=[ranks.trainers],
                subprocess_init=subprocess_init,
            )

        if len(ranks.partition_servers) > 0:
            partition_client = PartitionClient(ranks.partition_servers,
                                               log_stats=True)
        else:
            partition_client = None

        groups = init_process_group(
            rank=ranks.trainers[rank],
            world_size=ranks.world_size,
            init_method=config.distributed_init_method,
            groups=[ranks.trainers],
        )
        trainer_group, = groups
        sync = DistributedSynchronizer(trainer_group)

    else:
        sync = DummySynchronizer()
        bucket_scheduler = SingleMachineBucketScheduler(
            nparts_lhs, nparts_rhs, config.bucket_order)
        parameter_sharer = None
        partition_client = None
        hide_distributed_logging()

    # fork early for HOGWILD threads
    logger.info("Creating workers...")
    num_workers = get_num_workers(config.workers)
    pool = create_pool(
        num_workers,
        subprocess_name=f"TWorker-{rank}",
        subprocess_init=subprocess_init,
    )

    def make_optimizer(params: Iterable[torch.nn.Parameter],
                       is_emb: bool) -> Optimizer:
        params = list(params)
        if len(params) == 0:
            optimizer = DummyOptimizer()
        elif is_emb:
            optimizer = RowAdagrad(params, lr=config.lr)
        else:
            if config.relation_lr is not None:
                lr = config.relation_lr
            else:
                lr = config.lr
            optimizer = Adagrad(params, lr=lr)
        optimizer.share_memory()
        return optimizer

    # background_io is only supported in single-machine mode
    background_io = config.background_io and config.num_machines == 1

    checkpoint_manager = CheckpointManager(
        config.checkpoint_path,
        background=background_io,
        rank=rank,
        num_machines=config.num_machines,
        partition_client=partition_client,
        subprocess_name=f"BackgRW-{rank}",
        subprocess_init=subprocess_init,
    )
    checkpoint_manager.register_metadata_provider(
        ConfigMetadataProvider(config))
    checkpoint_manager.write_config(config)

    if config.num_edge_chunks is not None:
        num_edge_chunks = config.num_edge_chunks
    else:
        num_edge_chunks = get_num_edge_chunks(config.edge_paths, nparts_lhs,
                                              nparts_rhs,
                                              config.max_edges_per_chunk)
    iteration_manager = IterationManager(
        config.num_epochs,
        config.edge_paths,
        num_edge_chunks,
        iteration_idx=checkpoint_manager.checkpoint_version)
    checkpoint_manager.register_metadata_provider(iteration_manager)

    if config.init_path is not None:
        loadpath_manager = CheckpointManager(config.init_path)
    else:
        loadpath_manager = None

    def load_embeddings(
        entity: EntityName,
        part: Partition,
        strict: bool = False,
        force_dirty: bool = False,
    ) -> Tuple[torch.nn.Parameter, Optional[OptimizerStateDict]]:
        if strict:
            embs, optim_state = checkpoint_manager.read(
                entity, part, force_dirty=force_dirty)
        else:
            # Strict is only false during the first iteration, because in that
            # case the checkpoint may not contain any data (unless a previous
            # run was resumed) so we fall back on initial values.
            embs, optim_state = checkpoint_manager.maybe_read(
                entity, part, force_dirty=force_dirty)
            if embs is None and loadpath_manager is not None:
                embs, optim_state = loadpath_manager.maybe_read(entity, part)
            if embs is None:
                embs, optim_state = init_embs(entity,
                                              entity_counts[entity][part],
                                              config.dimension,
                                              config.init_scale)
        assert embs.is_shared()
        return torch.nn.Parameter(embs), optim_state

    logger.info("Initializing global model...")

    if model is None:
        model = make_model(config)
    model.share_memory()
    if trainer is None:
        trainer = Trainer(
            global_optimizer=make_optimizer(model.parameters(), False),
            loss_fn=config.loss_fn,
            margin=config.margin,
            relations=config.relations,
        )
    if evaluator is None:
        evaluator = TrainingRankingEvaluator(
            override_num_batch_negs=config.eval_num_batch_negs,
            override_num_uniform_negs=config.eval_num_uniform_negs,
        )
    eval_batch_size = round_up_to_nearest_multiple(config.batch_size,
                                                   config.eval_num_batch_negs)

    state_dict, optim_state = checkpoint_manager.maybe_read_model()

    if state_dict is None and loadpath_manager is not None:
        state_dict, optim_state = loadpath_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)
    if optim_state is not None:
        trainer.global_optimizer.load_state_dict(optim_state)

    logger.debug("Loading unpartitioned entities...")
    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs, optim_state = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)
            optimizer = make_optimizer([embs], True)
            if optim_state is not None:
                optimizer.load_state_dict(optim_state)
            trainer.entity_optimizers[(entity, Partition(0))] = optimizer

    # start communicating shared parameters with the parameter server
    if parameter_sharer is not None:
        parameter_sharer.share_model_params(model)

    strict = False

    def swap_partitioned_embeddings(
        old_b: Optional[Bucket],
        new_b: Optional[Bucket],
    ):
        # 0. given the old and new buckets, construct data structures to keep
        #    track of old and new embedding (entity, part) tuples

        io_bytes = 0
        logger.info(f"Swapping partitioned embeddings {old_b} {new_b}")

        types = ([(e, Side.LHS) for e in lhs_partitioned_types] +
                 [(e, Side.RHS) for e in rhs_partitioned_types])
        old_parts = {(e, old_b.get_partition(side)): side
                     for e, side in types if old_b is not None}
        new_parts = {(e, new_b.get_partition(side)): side
                     for e, side in types if new_b is not None}

        to_checkpoint = set(old_parts) - set(new_parts)
        preserved = set(old_parts) & set(new_parts)

        # 1. checkpoint embeddings that will not be used in the next pair
        #
        if old_b is not None:  # there are previous embeddings to checkpoint
            logger.info("Writing partitioned embeddings")
            for entity, part in to_checkpoint:
                side = old_parts[(entity, part)]
                side_name = side.pick("lhs", "rhs")
                logger.debug(f"Checkpointing ({entity} {part} {side_name})")
                embs = model.get_embeddings(entity, side)
                optim_key = (entity, part)
                optim_state = OptimizerStateDict(
                    trainer.entity_optimizers[optim_key].state_dict())
                # ignore optim state
                io_bytes += embs.numel() * embs.element_size()
                checkpoint_manager.write(entity, part, embs.detach(),
                                         optim_state)
                if optim_key in trainer.entity_optimizers:
                    del trainer.entity_optimizers[optim_key]
                # these variables are holding large objects; let them be freed
                del embs
                del optim_state

            bucket_scheduler.release_bucket(old_b)

        # 2. copy old embeddings that will be used in the next pair
        #    into a temporary dictionary
        #
        tmp_emb = {
            x: model.get_embeddings(x[0], old_parts[x])
            for x in preserved
        }

        for entity, _ in types:
            model.clear_embeddings(entity, Side.LHS)
            model.clear_embeddings(entity, Side.RHS)

        if new_b is None:  # there are no new embeddings to load
            return io_bytes

        bucket_logger = BucketLogger(logger, bucket=new_b)

        # 3. load new embeddings into the model/optimizer, either from disk
        #    or the temporary dictionary
        #
        bucket_logger.info("Loading entities")
        for entity, side in types:
            part = new_b.get_partition(side)
            part_key = (entity, part)
            if part_key in tmp_emb:
                bucket_logger.debug(
                    f"Loading ({entity}, {part}) from preserved")
                embs, optim_state = tmp_emb[part_key], None
            else:
                bucket_logger.debug(f"Loading ({entity}, {part})")

                force_dirty = bucket_scheduler.check_and_set_dirty(
                    entity, part)
                embs, optim_state = load_embeddings(entity,
                                                    part,
                                                    strict=strict,
                                                    force_dirty=force_dirty)
                # ignore optim state
                io_bytes += embs.numel() * embs.element_size()

            model.set_embeddings(entity, embs, side)
            tmp_emb[part_key] = embs

            optim_key = (entity, part)
            if optim_key not in trainer.entity_optimizers:
                bucket_logger.debug(f"Resetting optimizer {optim_key}")
                optimizer = make_optimizer([embs], True)
                if optim_state is not None:
                    bucket_logger.debug("Setting optim state")
                    optimizer.load_state_dict(optim_state)

                trainer.entity_optimizers[optim_key] = optimizer

        return io_bytes

    if rank == RANK_ZERO:
        for stats_dict in checkpoint_manager.maybe_read_stats():
            index: int = stats_dict["index"]
            stats: Stats = Stats.from_dict(stats_dict["stats"])
            eval_stats_before: Optional[Stats] = None
            if "eval_stats_before" in stats_dict:
                eval_stats_before = Stats.from_dict(
                    stats_dict["eval_stats_before"])
            eval_stats_after: Optional[Stats] = None
            if "eval_stats_after" in stats_dict:
                eval_stats_after = Stats.from_dict(
                    stats_dict["eval_stats_after"])
            yield (index, eval_stats_before, stats, eval_stats_after)

    # Start of the main training loop.
    for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
        logger.info(
            f"Starting epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
            f"edge path {edge_path_idx + 1} / {iteration_manager.num_edge_paths}, "
            f"edge chunk {edge_chunk_idx + 1} / {iteration_manager.num_edge_chunks}"
        )
        edge_storage = EDGE_STORAGES.make_instance(iteration_manager.edge_path)
        logger.info(f"Edge path: {iteration_manager.edge_path}")

        sync.barrier()
        dist_logger.info("Lock client new epoch...")
        bucket_scheduler.new_pass(
            is_first=iteration_manager.iteration_idx == 0)
        sync.barrier()

        remaining = total_buckets
        cur_b = None
        while remaining > 0:
            old_b = cur_b
            io_time = 0.
            io_bytes = 0
            cur_b, remaining = bucket_scheduler.acquire_bucket()
            logger.info(f"still in queue: {remaining}")
            if cur_b is None:
                if old_b is not None:
                    # if you couldn't get a new pair, release the lock
                    # to prevent a deadlock!
                    tic = time.time()
                    io_bytes += swap_partitioned_embeddings(old_b, None)
                    io_time += time.time() - tic
                time.sleep(1)  # don't hammer td
                continue

            bucket_logger = BucketLogger(logger, bucket=cur_b)

            tic = time.time()

            io_bytes += swap_partitioned_embeddings(old_b, cur_b)

            current_index = \
                (iteration_manager.iteration_idx + 1) * total_buckets - remaining

            next_b = bucket_scheduler.peek()
            if next_b is not None and background_io:
                # Ensure the previous bucket finished writing to disk.
                checkpoint_manager.wait_for_marker(current_index - 1)

                bucket_logger.debug("Prefetching")
                for entity in lhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.lhs)
                for entity in rhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.rhs)

                checkpoint_manager.record_marker(current_index)

            bucket_logger.debug("Loading edges")
            edges = edge_storage.load_chunk_of_edges(
                cur_b.lhs, cur_b.rhs, edge_chunk_idx,
                iteration_manager.num_edge_chunks)
            num_edges = len(edges)
            # this might be off in the case of tensorlist or extra edge fields
            io_bytes += edges.lhs.tensor.numel() * edges.lhs.tensor.element_size()
            io_bytes += edges.rhs.tensor.numel() * edges.rhs.tensor.element_size()
            io_bytes += edges.rel.numel() * edges.rel.element_size()

            bucket_logger.debug("Shuffling edges")
            # Fix a seed to get the same permutation every time; have it
            # depend on all and only what affects the set of edges.
            g = torch.Generator()
            g.manual_seed(
                hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs)))
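            # The seed tuple contains only ints (Partition values are plain
            # ints), so hash() is deterministic here; PYTHONHASHSEED
            # randomization only affects str and bytes keys.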

            num_eval_edges = int(num_edges * config.eval_fraction)
            if num_eval_edges > 0:
                edge_perm = torch.randperm(num_edges, generator=g)
                eval_edge_perm = edge_perm[-num_eval_edges:]
                num_edges -= num_eval_edges
                edge_perm = edge_perm[torch.randperm(num_edges)]
            else:
                edge_perm = torch.randperm(num_edges)
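            # The held-out eval edges are the tail of the seeded permutation,
            # so the eval set is identical on every pass over this bucket;
            # the remaining training indices get a fresh, unseeded reshuffle
            # each time.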

            # HOGWILD evaluation before training
            eval_stats_before: Optional[Stats] = None
            if num_eval_edges > 0:
                bucket_logger.debug(
                    "Waiting for workers to perform evaluation")
                future_all_eval_stats_before = pool.map_async(
                    call, [
                        partial(
                            process_in_batches,
                            batch_size=eval_batch_size,
                            model=model,
                            batch_processor=evaluator,
                            edges=edges,
                            indices=eval_edge_perm[s],
                        ) for s in split_almost_equally(eval_edge_perm.size(0),
                                                        num_parts=num_workers)
                    ])
                all_eval_stats_before = \
                    get_async_result(future_all_eval_stats_before, pool)
                eval_stats_before = Stats.sum(all_eval_stats_before).average()
                bucket_logger.info(
                    f"Stats before training: {eval_stats_before}")

            io_time += time.time() - tic
            tic = time.time()
            # HOGWILD training
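            # Each worker trains on a disjoint slice of the permuted edges,
            # updating the shared-memory model without locks (HOGWILD).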
            bucket_logger.debug("Waiting for workers to perform training")
            # FIXME should we only delay if iteration_idx == 0?
            future_all_stats = pool.map_async(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=trainer,
                    edges=edges,
                    indices=edge_perm[s],
                    delay=config.hogwild_delay
                    if epoch_idx == 0 and rank > 0 else 0,
                ) for rank, s in enumerate(
                    split_almost_equally(edge_perm.size(0),
                                         num_parts=num_workers))
            ])
            all_stats = get_async_result(future_all_stats, pool)
            stats = Stats.sum(all_stats).average()
            compute_time = time.time() - tic

            bucket_logger.info(
                f"bucket {total_buckets - remaining} / {total_buckets} : "
                f"Processed {num_edges} edges in {compute_time:.2f} s "
                f"( {num_edges / compute_time / 1e6:.2g} M/sec ); "
                f"io: {io_time:.2f} s ( {io_bytes / io_time / 1e6:.2f} MB/sec )"
            )
            bucket_logger.info(f"{stats}")

            # HOGWILD eval after training
            eval_stats_after: Optional[Stats] = None
            if num_eval_edges > 0:
                bucket_logger.debug(
                    "Waiting for workers to perform evaluation")
                future_all_eval_stats_after = pool.map_async(
                    call, [
                        partial(
                            process_in_batches,
                            batch_size=eval_batch_size,
                            model=model,
                            batch_processor=evaluator,
                            edges=edges,
                            indices=eval_edge_perm[s],
                        ) for s in split_almost_equally(eval_edge_perm.size(0),
                                                        num_parts=num_workers)
                    ])
                all_eval_stats_after = \
                    get_async_result(future_all_eval_stats_after, pool)
                eval_stats_after = Stats.sum(all_eval_stats_after).average()
                bucket_logger.info(f"Stats after training: {eval_stats_after}")

            # Add train/eval metrics to queue
            stats_dict = {
                "index": current_index,
                "stats": stats.to_dict(),
            }
            if eval_stats_before is not None:
                stats_dict["eval_stats_before"] = eval_stats_before.to_dict()
            if eval_stats_after is not None:
                stats_dict["eval_stats_after"] = eval_stats_after.to_dict()
            checkpoint_manager.append_stats(stats_dict)
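            # Persisting the stats here is what enables the rank-zero replay
            # at the top of this generator when a run is resumed.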
            yield current_index, eval_stats_before, stats, eval_stats_after

        swap_partitioned_embeddings(cur_b, None)

        # Distributed Processing: all machines can leave the barrier now.
        sync.barrier()

        # Preserving a checkpoint requires two steps:
        # - create a snapshot (w/ symlinks) after it's first written;
        # - don't delete it once the following one is written.
        # These two happen in two successive iterations of the main loop: the
        # one just before and the one just after the epoch boundary.
        preserve_old_checkpoint = should_preserve_old_checkpoint(
            iteration_manager, config.checkpoint_preservation_interval)
        preserve_new_checkpoint = should_preserve_old_checkpoint(
            iteration_manager + 1, config.checkpoint_preservation_interval)

        # Write metadata: for multiple machines, write from rank-0
        logger.info(
            f"Finished epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
            f"edge path {edge_path_idx + 1} / {iteration_manager.num_edge_paths}, "
            f"edge chunk {edge_chunk_idx + 1} / {iteration_manager.num_edge_chunks}"
        )
        if rank == 0:
            for entity, econfig in config.entities.items():
                if econfig.num_partitions == 1:
                    embs = model.get_embeddings(entity, Side.LHS)
                    optimizer = trainer.entity_optimizers[(entity,
                                                           Partition(0))]

                    checkpoint_manager.write(
                        entity, Partition(0), embs.detach(),
                        OptimizerStateDict(optimizer.state_dict()))

            sanitized_state_dict: ModuleStateDict = {}
            for k, v in ModuleStateDict(model.state_dict()).items():
                if k.startswith('lhs_embs') or k.startswith('rhs_embs'):
                    # skipping state that's an entity embedding
                    continue
                sanitized_state_dict[k] = v

            logger.info("Writing the metadata")
            checkpoint_manager.write_model(
                sanitized_state_dict,
                OptimizerStateDict(trainer.global_optimizer.state_dict()),
            )

        logger.info("Writing the checkpoint")
        checkpoint_manager.write_new_version(config)

        dist_logger.info(
            "Waiting for other workers to write their parts of the checkpoint")
        sync.barrier()
        dist_logger.info("All parts of the checkpoint have been written")

        logger.info("Switching to the new checkpoint version")
        checkpoint_manager.switch_to_new_version()

        dist_logger.info(
            "Waiting for other workers to switch to the new checkpoint version"
        )
        sync.barrier()
        dist_logger.info(
            "All workers have switched to the new checkpoint version")

        # After all the machines have finished committing
        # checkpoints, we either remove the old checkpoint
        # or preserve it.
        if preserve_new_checkpoint:
            # Add 1 so the index is a multiple of the interval; it looks nicer.
            checkpoint_manager.preserve_current_version(config, epoch_idx + 1)
        if not preserve_old_checkpoint:
            checkpoint_manager.remove_old_version(config)

        # now we're sure that all partition files exist,
        # so be strict about loading them
        strict = True

    # quiescence
    pool.close()
    pool.join()

    sync.barrier()

    checkpoint_manager.close()
    if loadpath_manager is not None:
        loadpath_manager.close()

    # FIXME join distributed workers (not really necessary)

    logger.info("Exiting")
Example #17
def train_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = RANK_ZERO,
) -> Generator[Tuple[int, Optional[Stats], Stats, Optional[Stats]], None, None]:
    """Each epoch/pass, for each partition pair, loads in embeddings and edgelist
    from disk, runs HOGWILD training on them, and writes partitions back to disk.
    """

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    log("Loading entity counts...")
    if maybe_old_entity_path(config.entity_path):
        log("WARNING: It may be that your entity path contains files using the "
            "old format. See D14241362 for how to update them.")
    entity_counts: Dict[str, List[int]] = {}
    for entity, econf in config.entities.items():
        entity_counts[entity] = []
        for part in range(econf.num_partitions):
            with open(os.path.join(
                config.entity_path, "entity_count_%s_%d.txt" % (entity, part)
            ), "rt") as tf:
                entity_counts[entity].append(int(tf.read().strip()))

    # Figure out how many lhs and rhs partitions we need
    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)
    vlog("nparts %d %d types %s %s" %
         (nparts_lhs, nparts_rhs, lhs_partitioned_types, rhs_partitioned_types))
    total_buckets = nparts_lhs * nparts_rhs

    sync: AbstractSynchronizer
    bucket_scheduler: AbstractBucketScheduler
    parameter_sharer: Optional[ParameterSharer]
    partition_client: Optional[PartitionClient]
    if config.num_machines > 1:
        if not 0 <= rank < config.num_machines:
            raise RuntimeError("Invalid rank for trainer")
        if not td.is_available():
            raise RuntimeError("The installed PyTorch version doesn't provide "
                               "distributed training capabilities.")
        ranks = ProcessRanks.from_num_invocations(
            config.num_machines, config.num_partition_servers)

        if rank == RANK_ZERO:
            log("Setup lock server...")
            start_server(
                LockServer(
                    num_clients=len(ranks.trainers),
                    nparts_lhs=nparts_lhs,
                    nparts_rhs=nparts_rhs,
                    lock_lhs=len(lhs_partitioned_types) > 0,
                    lock_rhs=len(rhs_partitioned_types) > 0,
                    init_tree=config.distributed_tree_init_order,
                ),
                server_rank=ranks.lock_server,
                world_size=ranks.world_size,
                init_method=config.distributed_init_method,
                groups=[ranks.trainers],
            )

        bucket_scheduler = DistributedBucketScheduler(
            server_rank=ranks.lock_server,
            client_rank=ranks.trainers[rank],
        )

        log("Setup param server...")
        start_server(
            ParameterServer(num_clients=len(ranks.trainers)),
            server_rank=ranks.parameter_servers[rank],
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=[ranks.trainers],
        )

        parameter_sharer = ParameterSharer(
            client_rank=ranks.parameter_clients[rank],
            all_server_ranks=ranks.parameter_servers,
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=[ranks.trainers],
        )

        if config.num_partition_servers == -1:
            start_server(
                ParameterServer(num_clients=len(ranks.trainers)),
                server_rank=ranks.partition_servers[rank],
                world_size=ranks.world_size,
                init_method=config.distributed_init_method,
                groups=[ranks.trainers],
            )

        if len(ranks.partition_servers) > 0:
            partition_client = PartitionClient(ranks.partition_servers)
        else:
            partition_client = None

        groups = init_process_group(
            rank=ranks.trainers[rank],
            world_size=ranks.world_size,
            init_method=config.distributed_init_method,
            groups=[ranks.trainers],
        )
        trainer_group, = groups
        sync = DistributedSynchronizer(trainer_group)
        dlog = log

    else:
        sync = DummySynchronizer()
        bucket_scheduler = SingleMachineBucketScheduler(
            nparts_lhs, nparts_rhs, config.bucket_order)
        parameter_sharer = None
        partition_client = None
        dlog = lambda msg: None

    # fork early for HOGWILD threads
    log("Creating workers...")
    num_workers = get_num_workers(config.workers)
    pool = create_pool(num_workers)

    def make_optimizer(params: Iterable[torch.nn.Parameter], is_emb: bool) -> Optimizer:
        params = list(params)
        if len(params) == 0:
            optimizer = DummyOptimizer()
        elif is_emb:
            optimizer = RowAdagrad(params, lr=config.lr)
        else:
            if config.relation_lr is not None:
                lr = config.relation_lr
            else:
                lr = config.lr
            optimizer = Adagrad(params, lr=lr)
        optimizer.share_memory()
        return optimizer
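
    # Note on make_optimizer: embeddings get RowAdagrad (a single Adagrad
    # accumulator per embedding row), all other parameters plain Adagrad;
    # share_memory() places the optimizer state in shared memory so HOGWILD
    # worker processes update it in place.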

    # background_io is only supported in single-machine mode
    background_io = config.background_io and config.num_machines == 1

    checkpoint_manager = CheckpointManager(
        config.checkpoint_path,
        background=background_io,
        rank=rank,
        num_machines=config.num_machines,
        partition_client=partition_client,
    )
    checkpoint_manager.register_metadata_provider(ConfigMetadataProvider(config))
    checkpoint_manager.write_config(config)

    iteration_manager = IterationManager(
        config.num_epochs, config.edge_paths, config.num_edge_chunks,
        iteration_idx=checkpoint_manager.checkpoint_version)
    checkpoint_manager.register_metadata_provider(iteration_manager)

    if config.init_path is not None:
        loadpath_manager = CheckpointManager(config.init_path)
    else:
        loadpath_manager = None

    def load_embeddings(
        entity: EntityName,
        part: Partition,
        strict: bool = False,
        force_dirty: bool = False,
    ) -> Tuple[torch.nn.Parameter, Optional[OptimizerStateDict]]:
        if strict:
            embs, optim_state = checkpoint_manager.read(entity, part,
                                                        force_dirty=force_dirty)
        else:
            # Strict is only false during the first iteration, because in that
            # case the checkpoint may not contain any data (unless a previous
            # run was resumed) so we fall back on initial values.
            embs, optim_state = checkpoint_manager.maybe_read(entity, part,
                                                              force_dirty=force_dirty)
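            # Fall back first to the init_path checkpoint (if any), then to
            # fresh random initialization.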
            if embs is None and loadpath_manager is not None:
                embs, optim_state = loadpath_manager.maybe_read(entity, part)
            if embs is None:
                embs, optim_state = init_embs(entity, entity_counts[entity][part],
                                              config.dimension, config.init_scale)
        assert embs.is_shared()
        return torch.nn.Parameter(embs), optim_state

    log("Initializing global model...")

    if model is None:
        model = make_model(config)
    model.share_memory()
    if trainer is None:
        trainer = Trainer(
            global_optimizer=make_optimizer(model.parameters(), False),
            loss_fn=config.loss_fn,
            margin=config.margin,
            relations=config.relations,
        )
    if evaluator is None:
        evaluator = TrainingRankingEvaluator(
            override_num_batch_negs=config.eval_num_batch_negs,
            override_num_uniform_negs=config.eval_num_uniform_negs,
        )
    eval_batch_size = round_up_to_nearest_multiple(config.batch_size, config.eval_num_batch_negs)

    state_dict, optim_state = checkpoint_manager.maybe_read_model()

    if state_dict is None and loadpath_manager is not None:
        state_dict, optim_state = loadpath_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)
    if optim_state is not None:
        trainer.global_optimizer.load_state_dict(optim_state)

    vlog("Loading unpartitioned entities...")
    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
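            # Unpartitioned entities stay resident for the whole run: the
            # same shared tensor serves as both the LHS and RHS embeddings.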
            embs, optim_state = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)
            optimizer = make_optimizer([embs], True)
            if optim_state is not None:
                optimizer.load_state_dict(optim_state)
            trainer.entity_optimizers[(entity, Partition(0))] = optimizer

    # start communicating shared parameters with the parameter server
    if parameter_sharer is not None:
        parameter_sharer.share_model_params(model)

    strict = False

    def swap_partitioned_embeddings(
        old_b: Optional[Bucket],
        new_b: Optional[Bucket],
    ):
        # 0. given the old and new buckets, construct data structures to keep
        #    track of old and new embedding (entity, part) tuples

        io_bytes = 0
        log("Swapping partitioned embeddings %s %s" % (old_b, new_b))

        types = ([(e, Side.LHS) for e in lhs_partitioned_types]
                 + [(e, Side.RHS) for e in rhs_partitioned_types])
        old_parts = {(e, old_b.get_partition(side)): side
                     for e, side in types if old_b is not None}
        new_parts = {(e, new_b.get_partition(side)): side
                     for e, side in types if new_b is not None}
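        # Note: the `if ... is not None` guard runs before the key expression
        # is evaluated, so these comprehensions yield empty dicts (rather
        # than raising) when the corresponding bucket is None.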

        to_checkpoint = set(old_parts) - set(new_parts)
        preserved = set(old_parts) & set(new_parts)

        # 1. checkpoint embeddings that will not be used in the next pair
        #
        if old_b is not None:  # there are previous embeddings to checkpoint
            log("Writing partitioned embeddings")
            for entity, part in to_checkpoint:
                side = old_parts[(entity, part)]
                vlog("Checkpointing (%s %d %s)" %
                     (entity, part, side.pick("lhs", "rhs")))
                embs = model.get_embeddings(entity, side)
                optim_key = (entity, part)
                optim_state = OptimizerStateDict(trainer.entity_optimizers[optim_key].state_dict())
                io_bytes += embs.numel() * embs.element_size()  # ignore optim state
                checkpoint_manager.write(entity, part, embs.detach(), optim_state)
                if optim_key in trainer.entity_optimizers:
                    del trainer.entity_optimizers[optim_key]
                # these variables are holding large objects; let them be freed
                del embs
                del optim_state

            bucket_scheduler.release_bucket(old_b)

        # 2. copy old embeddings that will be used in the next pair
        #    into a temporary dictionary
        #
        tmp_emb = {x: model.get_embeddings(x[0], old_parts[x]) for x in preserved}

        for entity, _ in types:
            model.clear_embeddings(entity, Side.LHS)
            model.clear_embeddings(entity, Side.RHS)

        if new_b is None:  # there are no new embeddings to load
            return io_bytes

        # 3. load new embeddings into the model/optimizer, either from disk
        #    or the temporary dictionary
        #
        log("Loading entities")
        for entity, side in types:
            part = new_b.get_partition(side)
            part_key = (entity, part)
            if part_key in tmp_emb:
                vlog("Loading (%s, %d) from preserved" % (entity, part))
                embs, optim_state = tmp_emb[part_key], None
            else:
                vlog("Loading (%s, %d)" % (entity, part))

                force_dirty = bucket_scheduler.check_and_set_dirty(entity, part)
                embs, optim_state = load_embeddings(
                    entity, part, strict=strict, force_dirty=force_dirty)
                io_bytes += embs.numel() * embs.element_size()  # ignore optim state

            model.set_embeddings(entity, embs, side)
            tmp_emb[part_key] = embs

            optim_key = (entity, part)
            if optim_key not in trainer.entity_optimizers:
                vlog("Resetting optimizer %s" % (optim_key,))
                optimizer = make_optimizer([embs], True)
                if optim_state is not None:
                    vlog("Setting optim state")
                    optimizer.load_state_dict(optim_state)

                trainer.entity_optimizers[optim_key] = optimizer

        return io_bytes

    # Start of the main training loop.
    for epoch_idx, edge_path_idx, edge_chunk_idx \
            in iteration_manager.remaining_iterations():
        log("Starting epoch %d / %d edge path %d / %d edge chunk %d / %d" %
            (epoch_idx + 1, iteration_manager.num_epochs,
             edge_path_idx + 1, iteration_manager.num_edge_paths,
             edge_chunk_idx + 1, iteration_manager.num_edge_chunks))
        edge_reader = EdgeReader(iteration_manager.edge_path)
        log("edge_path= %s" % iteration_manager.edge_path)

        sync.barrier()
        dlog("Lock client new epoch...")
        bucket_scheduler.new_pass(is_first=iteration_manager.iteration_idx == 0)
        sync.barrier()

        remaining = total_buckets
        cur_b = None
        while remaining > 0:
            old_b = cur_b
            io_time = 0.
            io_bytes = 0
            cur_b, remaining = bucket_scheduler.acquire_bucket()
            print('still in queue: %d' % remaining, file=sys.stderr)
            if cur_b is None:
                if old_b is not None:
                    # if you couldn't get a new pair, release the lock
                    # to prevent a deadlock!
                    tic = time.time()
                    io_bytes += swap_partitioned_embeddings(old_b, None)
                    io_time += time.time() - tic
                time.sleep(1)  # don't hammer td
                continue

            def log_status(msg, always=False):
                f = log if always else vlog
                f("%s: %s" % (cur_b, msg))

            tic = time.time()

            io_bytes += swap_partitioned_embeddings(old_b, cur_b)

            current_index = \
                (iteration_manager.iteration_idx + 1) * total_buckets - remaining

            next_b = bucket_scheduler.peek()
            if next_b is not None and background_io:
                # Ensure the previous bucket finished writing to disk.
                checkpoint_manager.wait_for_marker(current_index - 1)

                log_status("Prefetching")
                for entity in lhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.lhs)
                for entity in rhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.rhs)

                checkpoint_manager.record_marker(current_index)

            log_status("Loading edges")
            edges = edge_reader.read(
                cur_b.lhs, cur_b.rhs, edge_chunk_idx, config.num_edge_chunks)
            num_edges = len(edges)
            # this might be off in the case of tensorlist or extra edge fields
            io_bytes += edges.lhs.tensor.numel() * edges.lhs.tensor.element_size()
            io_bytes += edges.rhs.tensor.numel() * edges.rhs.tensor.element_size()
            io_bytes += edges.rel.numel() * edges.rel.element_size()

            log_status("Shuffling edges")
            # Fix a seed to get the same permutation every time; have it
            # depend on all and only what affects the set of edges.
            g = torch.Generator()
            g.manual_seed(hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs)))

            num_eval_edges = int(num_edges * config.eval_fraction)
            if num_eval_edges > 0:
                edge_perm = torch.randperm(num_edges, generator=g)
                eval_edge_perm = edge_perm[-num_eval_edges:]
                num_edges -= num_eval_edges
                edge_perm = edge_perm[torch.randperm(num_edges)]
            else:
                edge_perm = torch.randperm(num_edges)

            # HOGWILD evaluation before training
            eval_stats_before: Optional[Stats] = None
            if num_eval_edges > 0:
                log_status("Waiting for workers to perform evaluation")
                all_eval_stats_before = pool.map(call, [
                    partial(
                        process_in_batches,
                        batch_size=eval_batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges,
                        indices=eval_edge_perm[s],
                    )
                    for s in split_almost_equally(eval_edge_perm.size(0),
                                                  num_parts=num_workers)
                ])
                eval_stats_before = Stats.sum(all_eval_stats_before).average()
                log("stats before %s: %s" % (cur_b, eval_stats_before))

            io_time += time.time() - tic
            tic = time.time()
            # HOGWILD training
            log_status("Waiting for workers to perform training")
            # FIXME should we only delay if iteration_idx == 0?
            all_stats = pool.map(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=trainer,
                    edges=edges,
                    indices=edge_perm[s],
                    delay=config.hogwild_delay if epoch_idx == 0 and rank > 0 else 0,
                )
                for rank, s in enumerate(split_almost_equally(edge_perm.size(0),
                                                              num_parts=num_workers))
            ])
            stats = Stats.sum(all_stats).average()
            compute_time = time.time() - tic

            log_status(
                "bucket %d / %d : Processed %d edges in %.2f s "
                "( %.2g M/sec ); io: %.2f s ( %.2f MB/sec )" %
                (total_buckets - remaining, total_buckets,
                 num_edges, compute_time, num_edges / compute_time / 1e6,
                 io_time, io_bytes / io_time / 1e6),
                always=True)
            log_status("%s" % stats, always=True)

            # HOGWILD eval after training
            eval_stats_after: Optional[Stats] = None
            if num_eval_edges > 0:
                log_status("Waiting for workers to perform evaluation")
                all_eval_stats_after = pool.map(call, [
                    partial(
                        process_in_batches,
                        batch_size=eval_batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges,
                        indices=eval_edge_perm[s],
                    )
                    for s in split_almost_equally(eval_edge_perm.size(0),
                                                  num_parts=num_workers)
                ])
                eval_stats_after = Stats.sum(all_eval_stats_after).average()
                log("stats after %s: %s" % (cur_b, eval_stats_after))

            # Add train/eval metrics to queue
            yield current_index, eval_stats_before, stats, eval_stats_after

        swap_partitioned_embeddings(cur_b, None)

        # Distributed Processing: all machines can leave the barrier now.
        sync.barrier()

        # Write metadata: for multiple machines, write from rank-0
        log("Finished epoch %d path %d pass %d; checkpointing global state."
            % (epoch_idx + 1, edge_path_idx + 1, edge_chunk_idx + 1))
        log("My rank: %d" % rank)
        if rank == 0:
            for entity, econfig in config.entities.items():
                if econfig.num_partitions == 1:
                    embs = model.get_embeddings(entity, Side.LHS)
                    optimizer = trainer.entity_optimizers[(entity, Partition(0))]

                    checkpoint_manager.write(
                        entity, Partition(0),
                        embs.detach(), OptimizerStateDict(optimizer.state_dict()))

            sanitized_state_dict: ModuleStateDict = {}
            for k, v in ModuleStateDict(model.state_dict()).items():
                if k.startswith('lhs_embs') or k.startswith('rhs_embs'):
                    # skipping state that's an entity embedding
                    continue
                sanitized_state_dict[k] = v

            log("Writing metadata...")
            checkpoint_manager.write_model(
                sanitized_state_dict,
                OptimizerStateDict(trainer.global_optimizer.state_dict()),
            )

        log("Writing the checkpoint...")
        checkpoint_manager.write_new_version(config)

        dlog("Waiting for other workers to write their parts of the checkpoint: rank %d" % rank)
        sync.barrier()
        dlog("All parts of the checkpoint have been written")

        log("Switching to new checkpoint version...")
        checkpoint_manager.switch_to_new_version()

        dlog("Waiting for other workers to switch to the new checkpoint version: rank %d" % rank)
        sync.barrier()
        dlog("All workers have switched to the new checkpoint version")

        # After all the machines have finished committing
        # checkpoints, we remove the old checkpoints.
        checkpoint_manager.remove_old_version(config)

        # now we're sure that all partition files exist,
        # so be strict about loading them
        strict = True

    # quiescence
    pool.close()
    pool.join()

    sync.barrier()

    checkpoint_manager.close()
    if loadpath_manager is not None:
        loadpath_manager.close()

    # FIXME join distributed workers (not really necessary)

    log("Exiting")