Example #1
    def _coordinate_train(self, edges, eval_edge_idxs, epoch_idx) -> Stats:
        assert self.config.num_gpus == 0, "GPU training not supported"

        if eval_edge_idxs is not None:
            num_train_edges = len(edges) - len(eval_edge_idxs)
            train_edge_idxs = torch.arange(len(edges))
            train_edge_idxs[eval_edge_idxs] = torch.arange(num_train_edges, len(edges))
            train_edge_idxs = train_edge_idxs[:num_train_edges]
            edge_perm = train_edge_idxs[torch.randperm(num_train_edges)]
        else:
            edge_perm = torch.randperm(len(edges))

        future_all_stats = self.pool.map_async(
            call,
            [
                partial(
                    process_in_batches,
                    batch_size=self.config.batch_size,
                    model=self.model,
                    batch_processor=self.trainer,
                    edges=edges,
                    indices=edge_perm[s],
                    # FIXME should we only delay if iteration_idx == 0?
                    delay=self.config.hogwild_delay
                    if epoch_idx == 0 and self.rank > 0
                    else 0,
                )
                for rank, s in enumerate(
                    split_almost_equally(edge_perm.size(0), num_parts=self.num_workers)
                )
            ],
        )
        all_stats = get_async_result(future_all_stats, self.pool)
        return Stats.sum(all_stats).average()
Example #2
 def _coordinate_eval(self, edges, eval_edge_idxs) -> Optional[Stats]:
     eval_batch_size = round_up_to_nearest_multiple(
         self.config.batch_size, self.config.eval_num_batch_negs
     )
     if eval_edge_idxs is not None:
         self.bucket_logger.debug("Waiting for workers to perform evaluation")
         future_all_eval_stats = self.pool.map_async(
             call,
             [
                 partial(
                     process_in_batches,
                     batch_size=eval_batch_size,
                     model=self.model,
                     batch_processor=self.evaluator,
                     edges=edges,
                     indices=eval_edge_idxs[s],
                 )
                 for s in split_almost_equally(
                     eval_edge_idxs.size(0), num_parts=self.num_workers
                 )
             ],
         )
         all_eval_stats = get_async_result(future_all_eval_stats, self.pool)
         return Stats.sum(all_eval_stats).average()
     else:
         return None
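
The eval batch size above is padded up so that each evaluation batch divides evenly into the configured number of in-batch negatives. The real round_up_to_nearest_multiple helper ships with the library; the one-liner below is only an assumption about its behavior, inferred from its name and how it is used here.

def round_up_to_nearest_multiple(value: int, factor: int) -> int:
    # Smallest multiple of `factor` that is >= `value`,
    # e.g. round_up_to_nearest_multiple(1000, 384) == 1152.
    return ((value + factor - 1) // factor) * factor
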
Example #3
 def test_fewer(self):
     self.assertEqual(
         list(split_almost_equally(23, num_parts=4)),
         [slice(0, 6),
          slice(6, 12),
          slice(12, 18),
          slice(18, 23)],
     )
Example #4
 def test_more(self):
     self.assertEqual(
         list(split_almost_equally(25, num_parts=4)),
         [slice(0, 7),
          slice(7, 13),
          slice(13, 19),
          slice(19, 25)],
     )
Example #5
 def test_exact(self):
     self.assertEqual(
         list(split_almost_equally(24, num_parts=4)),
         [slice(0, 6),
          slice(6, 12),
          slice(12, 18),
          slice(18, 24)],
     )
Example #6
 def test_more(self):
     self.assertEqual(
         list(split_almost_equally(25, num_parts=4)),
         [slice(0, 7),
          slice(7, 14),
          slice(14, 21),
          slice(21, 25)],
     )
Example #7
 def store(self,
           key: str,
           src: torch.Tensor,
           accum: bool = False,
           overwrite: bool = True) -> None:
     """Store or accumulate a tensor on the server."""
     self._validate_store(key, src, accum=accum, overwrite=overwrite)
     cmd_rpc = torch.tensor(
         [
             STORE_CMD,
             len(key),
             -1 if accum else src.ndimension(),
             int(accum),
             int(overwrite),
             _dtypes.index(src.dtype),
         ],
         dtype=torch.long,
     )
     metadata_pg = self._metadata_pg()
     td.send(cmd_rpc, self.server_rank, group=metadata_pg)
     td.send(_fromstring(key), self.server_rank, group=metadata_pg)
     if not accum:
         td.send(
             torch.tensor(list(src.size()), dtype=torch.long),
             self.server_rank,
             group=metadata_pg,
         )
     start_t = time.monotonic()
     data_pgs = self._data_pgs()
     if data_pgs is None:
         td.send(src, dst=self.server_rank)
     else:
         outstanding_work = []
         flattened_src = src.flatten()
         flattened_size = flattened_src.shape[0]
         for idx, (pg, slice_) in enumerate(
                 zip(
                     data_pgs,
                     split_almost_equally(flattened_size,
                                          num_parts=len(data_pgs)),
                 )):
             outstanding_work.append(
                 td.isend(
                     tensor=flattened_src[slice_],
                     dst=self.server_rank,
                     group=pg,
                     tag=idx,
                 ))
         for w in outstanding_work:
             w.wait()
     end_t = time.monotonic()
     if self.log_stats:
         stats_size = src.numel() * src.element_size()
         stats_time = end_t - start_t
         logger.debug(f"Sent tensor {key} to server {self.server_rank}: "
                      f"{stats_size:,} bytes "
                      f"in {stats_time:,g} seconds "
                      f"=> {stats_size / stats_time:,.0f} B/s")
Example #8
 def get(self,
         key: str,
         dst: Optional[torch.Tensor] = None,
         shared: bool = False) -> Optional[torch.Tensor]:
     """Get a tensor from the server."""
     self._validate_get(key, dst=dst, shared=shared)
     cmd_rpc = torch.tensor(
         [GET_CMD, len(key), dst is None, 0, 0, 0], dtype=torch.long)
     metadata_pg = self._metadata_pg()
     td.send(cmd_rpc, self.server_rank, group=metadata_pg)
     td.send(_fromstring(key), self.server_rank, group=metadata_pg)
     if dst is None:
         meta = torch.full((2, ), -1, dtype=torch.long)
         td.recv(meta, src=self.server_rank, group=metadata_pg)
         ndim, ttype = meta
         if ndim.item() == -1:
             return None
         size = torch.full((ndim.item(), ), -1, dtype=torch.long)
         td.recv(size, src=self.server_rank, group=metadata_pg)
         dtype = _dtypes[ttype.item()]
         if shared:
             dst = allocate_shared_tensor(size.tolist(), dtype=dtype)
         else:
             dst = torch.empty(size.tolist(), dtype=dtype)
     start_t = time.monotonic()
     data_pgs = self._data_pgs()
     if data_pgs is None:
         td.recv(dst, src=self.server_rank)
     else:
         outstanding_work = []
         flattened_dst = dst.flatten()
         flattened_size = flattened_dst.shape[0]
         for idx, (pg, slice_) in enumerate(
                 zip(
                     data_pgs,
                     split_almost_equally(flattened_size,
                                          num_parts=len(data_pgs)),
                 )):
             outstanding_work.append(
                 td.irecv(
                     tensor=flattened_dst[slice_],
                     src=self.server_rank,
                     group=pg,
                     tag=idx,
                 ))
         for w in outstanding_work:
             w.wait()
     end_t = time.monotonic()
     if self.log_stats:
         stats_size = dst.numel() * dst.element_size()
         stats_time = end_t - start_t
         logger.debug(
             f"Received tensor {key} from server {self.server_rank}: "
             f"{stats_size:,} bytes "
             f"in {stats_time:,g} seconds "
             f"=> {stats_size / stats_time:,.0f} B/s")
     return dst
Example #9
 def test_so_few_that_last_slice_would_underflow(self):
     # All slices have the same size, which is the ratio size/num_parts
     # rounded up. This however may cause earlier slices to get so many
     # elements that later ones end up being empty. We need to be careful
     # about not returning negative slices in that case.
     self.assertEqual(
         list(split_almost_equally(5, num_parts=4)),
         [slice(0, 2), slice(2, 4),
          slice(4, 5), slice(5, 5)],
     )
     self.assertEqual(
         list(split_almost_equally(6, num_parts=5)),
         [slice(0, 2),
          slice(2, 4),
          slice(4, 6),
          slice(6, 6),
          slice(6, 6)],
     )
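
Taken together, these tests pin down the slicing contract. Note that the two test_more variants (Examples #4 and #6) expect different splits of 25 into 4 parts, so they come from different revisions of the helper. A minimal sketch that satisfies the ceil-based variant (Example #6 and the underflow test above) could look like the following; the real implementation ships with the library and may differ.

from typing import Iterable


def split_almost_equally(size: int, *, num_parts: int) -> Iterable[slice]:
    # Every slice spans ceil(size / num_parts) elements; the bounds are
    # clamped to `size` so trailing slices become empty instead of negative.
    part_size = -(-size // num_parts)  # ceil division
    for i in range(num_parts):
        start = min(i * part_size, size)
        yield slice(start, min(start + part_size, size))
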
Example #10
    def handle_store(
        self,
        rank: int,
        key: str,
        ndim: int,
        accum: int,
        overwrite: int,
        ttype: int,
    ) -> None:
        if ndim == -1:
            assert key in self.parameters
            size = self.parameters[key].size()
        else:
            size = torch.empty((ndim, ), dtype=torch.long)
            td.recv(size, src=rank)
            size = size.tolist()
        dtype = _dtypes[ttype]
        if not accum and overwrite and key in self.parameters:
            # avoid holding onto 2x the memory
            del self.parameters[key]
        data = torch.empty(size, dtype=dtype)

        start_t = time.monotonic()
        if self.groups is None:
            td.recv(tensor=data, src=rank)
        else:
            outstanding_work = []
            flattened_data = data.flatten()
            flattened_size = flattened_data.shape[0]
            for idx, (pg, slice_) in enumerate(
                    zip(
                        self.groups,
                        split_almost_equally(flattened_size,
                                             num_parts=len(self.groups)))):
                outstanding_work.append(
                    td.irecv(tensor=flattened_data[slice_],
                             src=rank,
                             group=pg,
                             tag=idx))
            for w in outstanding_work:
                w.wait()
        end_t = time.monotonic()
        if self.log_stats:
            stats_size = data.numel() * data.element_size()
            stats_time = end_t - start_t
            logger.debug(f"Received tensor {key} from client {rank}: "
                         f"{stats_size:,} bytes "
                         f"in {stats_time:,g} seconds "
                         f"=> {stats_size / stats_time:,.0f} B/s")

        if accum:
            self.parameters[key] += data
        elif (key not in self.parameters) or overwrite:
            self.parameters[key] = data
Example #11
    def handle_get(self, rank: int, key: str, send_size: int) -> None:
        metadata_pg = self._metadata_pg()
        if key not in self.parameters:
            assert send_size, "Key %s not found" % key
            td.send(torch.tensor([-1, -1], dtype=torch.long),
                    rank,
                    group=metadata_pg)
            return

        data = self.parameters[key]
        if send_size:
            type_idx = _dtypes.index(data.dtype)
            td.send(
                torch.tensor([data.ndimension(), type_idx], dtype=torch.long),
                rank,
                group=metadata_pg,
            )
            td.send(
                torch.tensor(list(data.size()), dtype=torch.long),
                rank,
                group=metadata_pg,
            )

        start_t = time.monotonic()
        data_pgs = self._data_pgs()
        if data_pgs is None:
            td.send(data, dst=rank)
        else:
            outstanding_work = []
            flattened_data = data.flatten()
            flattened_size = flattened_data.shape[0]
            for idx, (pg, slice_) in enumerate(
                    zip(
                        data_pgs,
                        split_almost_equally(flattened_size,
                                             num_parts=len(data_pgs)),
                    )):
                outstanding_work.append(
                    td.isend(tensor=flattened_data[slice_],
                             dst=rank,
                             group=pg,
                             tag=idx))
            for w in outstanding_work:
                w.wait()
        end_t = time.monotonic()
        if self.log_stats:
            stats_size = data.numel() * data.element_size()
            stats_time = end_t - start_t
            logger.debug(f"Sent tensor {key} to client {rank}: "
                         f"{stats_size:,} bytes "
                         f"in {stats_time:,g} seconds "
                         f"=> {stats_size / stats_time:,.0f} B/s")
Example #12
    def _coordinate_train(self, edges, edge_perm, epoch_idx) -> Stats:
        assert self.config.num_gpus == 0, "GPU training not supported"

        future_all_stats = self.pool.map_async(
            call,
            [
                partial(
                    process_in_batches,
                    batch_size=self.config.batch_size,
                    model=self.model,
                    batch_processor=self.trainer,
                    edges=edges,
                    indices=edge_perm[s],
                    # FIXME should we only delay if iteration_idx == 0?
                    delay=self.config.hogwild_delay
                    if epoch_idx == 0 and self.rank > 0 else 0,
                ) for rank, s in enumerate(
                    split_almost_equally(edge_perm.size(0),
                                         num_parts=self.num_workers))
            ])
        all_stats = get_async_result(future_all_stats, self.pool)
        return Stats.sum(all_stats).average()
Example #13
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
       embeddings.
    """

    if evaluator is None:
        evaluator = RankingEvaluator()

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(num_workers)

    if model is None:
        model = make_model(config)
    model.share_memory()

    state_dict, _ = checkpoint_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)

    model.eval()

    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)

    all_stats: List[Stats] = []
    for edge_path_idx, edge_path in enumerate(config.edge_paths):
        log("Starting edge path %d / %d (%s)"
            % (edge_path_idx + 1, len(config.edge_paths), edge_path))
        edge_reader = EdgeReader(edge_path)

        all_edge_path_stats = []
        last_lhs, last_rhs = None, None
        for bucket in create_buckets_ordered_lexicographically(nparts_lhs, nparts_rhs):
            tic = time.time()
            # log("%s: Loading entities" % (bucket,))

            if last_lhs != bucket.lhs:
                for e in lhs_partitioned_types:
                    model.clear_embeddings(e, Side.LHS)
                    embs = load_embeddings(e, bucket.lhs)
                    model.set_embeddings(e, embs, Side.LHS)
            if last_rhs != bucket.rhs:
                for e in rhs_partitioned_types:
                    model.clear_embeddings(e, Side.RHS)
                    embs = load_embeddings(e, bucket.rhs)
                    model.set_embeddings(e, embs, Side.RHS)
            last_lhs, last_rhs = bucket.lhs, bucket.rhs

            # log("%s: Loading edges" % (bucket,))
            edges = edge_reader.read(bucket.lhs, bucket.rhs)
            num_edges = len(edges)

            load_time = time.time() - tic
            tic = time.time()
            # log("%s: Launching and waiting for workers" % (bucket,))
            all_bucket_stats = pool.map(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=evaluator,
                    edges=edges[s],
                )
                for s in split_almost_equally(num_edges, num_parts=num_workers)
            ])

            compute_time = time.time() - tic
            log("%s: Processed %d edges in %.2g s (%.2gM/sec); load time: %.2g s"
                % (bucket, num_edges, compute_time,
                   num_edges / compute_time / 1e6, load_time))

            total_bucket_stats = Stats.sum(all_bucket_stats)
            all_edge_path_stats.append(total_bucket_stats)
            mean_bucket_stats = total_bucket_stats.average()
            log("Stats for edge path %d / %d, bucket %s: %s"
                % (edge_path_idx + 1, len(config.edge_paths), bucket,
                   mean_bucket_stats))

            yield edge_path_idx, bucket, mean_bucket_stats

        total_edge_path_stats = Stats.sum(all_edge_path_stats)
        all_stats.append(total_edge_path_stats)
        mean_edge_path_stats = total_edge_path_stats.average()
        log("")
        log("Stats for edge path %d / %d: %s"
            % (edge_path_idx + 1, len(config.edge_paths), mean_edge_path_stats))
        log("")

        yield edge_path_idx, None, mean_edge_path_stats

    mean_stats = Stats.sum(all_stats).average()
    log("")
    log("Stats: %s" % mean_stats)
    log("")

    yield None, None, mean_stats

    pool.close()
    pool.join()
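
Since do_eval_and_report_stats is a generator, callers drive it with a loop. A hypothetical driver (not part of the library, assuming a ConfigSchema instance named config) that tells apart per-bucket, per-edge-path and overall stats could look like this:

for edge_path_idx, bucket, stats in do_eval_and_report_stats(config):
    if edge_path_idx is None:
        print(f"overall: {stats}")  # final yield: grand average
    elif bucket is None:
        print(f"edge path {edge_path_idx}: {stats}")  # per-edge-path average
    else:
        print(f"edge path {edge_path_idx}, bucket {bucket}: {stats}")
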
Example #14
def train_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = RANK_ZERO,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[int, Optional[Stats], Stats, Optional[Stats]], None,
               None]:
    """Each epoch/pass, for each partition pair, loads in embeddings and edgelist
    from disk, runs HOGWILD training on them, and writes partitions back to disk.
    """
    tag_logs_with_process_name(f"Trainer-{rank}")

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    logger.info("Loading entity counts...")
    entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
    entity_counts: Dict[str, List[int]] = {}
    for entity, econf in config.entities.items():
        entity_counts[entity] = []
        for part in range(econf.num_partitions):
            entity_counts[entity].append(
                entity_storage.load_count(entity, part))

    # Figure out how many lhs and rhs partitions we need
    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)
    logger.debug(f"nparts {nparts_lhs} {nparts_rhs} "
                 f"types {lhs_partitioned_types} {rhs_partitioned_types}")
    total_buckets = nparts_lhs * nparts_rhs

    sync: AbstractSynchronizer
    bucket_scheduler: AbstractBucketScheduler
    parameter_sharer: Optional[ParameterSharer]
    partition_client: Optional[PartitionClient]
    if config.num_machines > 1:
        if not 0 <= rank < config.num_machines:
            raise RuntimeError("Invalid rank for trainer")
        if not td.is_available():
            raise RuntimeError("The installed PyTorch version doesn't provide "
                               "distributed training capabilities.")
        ranks = ProcessRanks.from_num_invocations(config.num_machines,
                                                  config.num_partition_servers)

        if rank == RANK_ZERO:
            logger.info("Setup lock server...")
            start_server(
                LockServer(
                    num_clients=len(ranks.trainers),
                    nparts_lhs=nparts_lhs,
                    nparts_rhs=nparts_rhs,
                    lock_lhs=len(lhs_partitioned_types) > 0,
                    lock_rhs=len(rhs_partitioned_types) > 0,
                    init_tree=config.distributed_tree_init_order,
                ),
                process_name="LockServer",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.lock_server,
                groups=[ranks.trainers],
                subprocess_init=subprocess_init,
            )

        bucket_scheduler = DistributedBucketScheduler(
            server_rank=ranks.lock_server,
            client_rank=ranks.trainers[rank],
        )

        logger.info("Setup param server...")
        start_server(
            ParameterServer(num_clients=len(ranks.trainers)),
            process_name=f"ParamS-{rank}",
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            server_rank=ranks.parameter_servers[rank],
            groups=[ranks.trainers],
            subprocess_init=subprocess_init,
        )

        parameter_sharer = ParameterSharer(
            process_name=f"ParamC-{rank}",
            client_rank=ranks.parameter_clients[rank],
            all_server_ranks=ranks.parameter_servers,
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=[ranks.trainers],
            subprocess_init=subprocess_init,
        )

        if config.num_partition_servers == -1:
            start_server(
                ParameterServer(num_clients=len(ranks.trainers),
                                log_stats=True),
                process_name=f"PartS-{rank}",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.partition_servers[rank],
                groups=[ranks.trainers],
                subprocess_init=subprocess_init,
            )

        if len(ranks.partition_servers) > 0:
            partition_client = PartitionClient(ranks.partition_servers,
                                               log_stats=True)
        else:
            partition_client = None

        groups = init_process_group(
            rank=ranks.trainers[rank],
            world_size=ranks.world_size,
            init_method=config.distributed_init_method,
            groups=[ranks.trainers],
        )
        trainer_group, = groups
        sync = DistributedSynchronizer(trainer_group)

    else:
        sync = DummySynchronizer()
        bucket_scheduler = SingleMachineBucketScheduler(
            nparts_lhs, nparts_rhs, config.bucket_order)
        parameter_sharer = None
        partition_client = None
        hide_distributed_logging()

    # fork early for HOGWILD threads
    logger.info("Creating workers...")
    num_workers = get_num_workers(config.workers)
    pool = create_pool(
        num_workers,
        subprocess_name=f"TWorker-{rank}",
        subprocess_init=subprocess_init,
    )

    def make_optimizer(params: Iterable[torch.nn.Parameter],
                       is_emb: bool) -> Optimizer:
        params = list(params)
        if len(params) == 0:
            optimizer = DummyOptimizer()
        elif is_emb:
            optimizer = RowAdagrad(params, lr=config.lr)
        else:
            if config.relation_lr is not None:
                lr = config.relation_lr
            else:
                lr = config.lr
            optimizer = Adagrad(params, lr=lr)
        optimizer.share_memory()
        return optimizer

    # background_io is only supported in single-machine mode
    background_io = config.background_io and config.num_machines == 1

    checkpoint_manager = CheckpointManager(
        config.checkpoint_path,
        background=background_io,
        rank=rank,
        num_machines=config.num_machines,
        partition_client=partition_client,
        subprocess_name=f"BackgRW-{rank}",
        subprocess_init=subprocess_init,
    )
    checkpoint_manager.register_metadata_provider(
        ConfigMetadataProvider(config))
    checkpoint_manager.write_config(config)

    if config.num_edge_chunks is not None:
        num_edge_chunks = config.num_edge_chunks
    else:
        num_edge_chunks = get_num_edge_chunks(config.edge_paths, nparts_lhs,
                                              nparts_rhs,
                                              config.max_edges_per_chunk)
    iteration_manager = IterationManager(
        config.num_epochs,
        config.edge_paths,
        num_edge_chunks,
        iteration_idx=checkpoint_manager.checkpoint_version)
    checkpoint_manager.register_metadata_provider(iteration_manager)

    if config.init_path is not None:
        loadpath_manager = CheckpointManager(config.init_path)
    else:
        loadpath_manager = None

    def load_embeddings(
        entity: EntityName,
        part: Partition,
        strict: bool = False,
        force_dirty: bool = False,
    ) -> Tuple[torch.nn.Parameter, Optional[OptimizerStateDict]]:
        if strict:
            embs, optim_state = checkpoint_manager.read(
                entity, part, force_dirty=force_dirty)
        else:
            # Strict is only false during the first iteration, because in that
            # case the checkpoint may not contain any data (unless a previous
            # run was resumed) so we fall back on initial values.
            embs, optim_state = checkpoint_manager.maybe_read(
                entity, part, force_dirty=force_dirty)
            if embs is None and loadpath_manager is not None:
                embs, optim_state = loadpath_manager.maybe_read(entity, part)
            if embs is None:
                embs, optim_state = init_embs(entity,
                                              entity_counts[entity][part],
                                              config.dimension,
                                              config.init_scale)
        assert embs.is_shared()
        return torch.nn.Parameter(embs), optim_state

    logger.info("Initializing global model...")

    if model is None:
        model = make_model(config)
    model.share_memory()
    if trainer is None:
        trainer = Trainer(
            global_optimizer=make_optimizer(model.parameters(), False),
            loss_fn=config.loss_fn,
            margin=config.margin,
            relations=config.relations,
        )
    if evaluator is None:
        evaluator = TrainingRankingEvaluator(
            override_num_batch_negs=config.eval_num_batch_negs,
            override_num_uniform_negs=config.eval_num_uniform_negs,
        )
    eval_batch_size = round_up_to_nearest_multiple(config.batch_size,
                                                   config.eval_num_batch_negs)

    state_dict, optim_state = checkpoint_manager.maybe_read_model()

    if state_dict is None and loadpath_manager is not None:
        state_dict, optim_state = loadpath_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)
    if optim_state is not None:
        trainer.global_optimizer.load_state_dict(optim_state)

    logger.debug("Loading unpartitioned entities...")
    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs, optim_state = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)
            optimizer = make_optimizer([embs], True)
            if optim_state is not None:
                optimizer.load_state_dict(optim_state)
            trainer.entity_optimizers[(entity, Partition(0))] = optimizer

    # start communicating shared parameters with the parameter server
    if parameter_sharer is not None:
        parameter_sharer.share_model_params(model)

    strict = False

    def swap_partitioned_embeddings(
        old_b: Optional[Bucket],
        new_b: Optional[Bucket],
    ):
        # 0. given the old and new buckets, construct data structures to keep
        #    track of old and new embedding (entity, part) tuples

        io_bytes = 0
        logger.info(f"Swapping partitioned embeddings {old_b} {new_b}")

        types = ([(e, Side.LHS) for e in lhs_partitioned_types] +
                 [(e, Side.RHS) for e in rhs_partitioned_types])
        old_parts = {(e, old_b.get_partition(side)): side
                     for e, side in types if old_b is not None}
        new_parts = {(e, new_b.get_partition(side)): side
                     for e, side in types if new_b is not None}

        to_checkpoint = set(old_parts) - set(new_parts)
        preserved = set(old_parts) & set(new_parts)

        # 1. checkpoint embeddings that will not be used in the next pair
        #
        if old_b is not None:  # there are previous embeddings to checkpoint
            logger.info("Writing partitioned embeddings")
            for entity, part in to_checkpoint:
                side = old_parts[(entity, part)]
                side_name = side.pick("lhs", "rhs")
                logger.debug(f"Checkpointing ({entity} {part} {side_name})")
                embs = model.get_embeddings(entity, side)
                optim_key = (entity, part)
                optim_state = OptimizerStateDict(
                    trainer.entity_optimizers[optim_key].state_dict())
                # ignore optim state
                io_bytes += embs.numel() * embs.element_size()
                checkpoint_manager.write(entity, part, embs.detach(),
                                         optim_state)
                if optim_key in trainer.entity_optimizers:
                    del trainer.entity_optimizers[optim_key]
                # these variables are holding large objects; let them be freed
                del embs
                del optim_state

            bucket_scheduler.release_bucket(old_b)

        # 2. copy old embeddings that will be used in the next pair
        #    into a temporary dictionary
        #
        tmp_emb = {
            x: model.get_embeddings(x[0], old_parts[x])
            for x in preserved
        }

        for entity, _ in types:
            model.clear_embeddings(entity, Side.LHS)
            model.clear_embeddings(entity, Side.RHS)

        if new_b is None:  # there are no new embeddings to load
            return io_bytes

        bucket_logger = BucketLogger(logger, bucket=new_b)

        # 3. load new embeddings into the model/optimizer, either from disk
        #    or the temporary dictionary
        #
        bucket_logger.info("Loading entities")
        for entity, side in types:
            part = new_b.get_partition(side)
            part_key = (entity, part)
            if part_key in tmp_emb:
                bucket_logger.debug(
                    f"Loading ({entity}, {part}) from preserved")
                embs, optim_state = tmp_emb[part_key], None
            else:
                bucket_logger.debug(f"Loading ({entity}, {part})")

                force_dirty = bucket_scheduler.check_and_set_dirty(
                    entity, part)
                embs, optim_state = load_embeddings(entity,
                                                    part,
                                                    strict=strict,
                                                    force_dirty=force_dirty)
                # ignore optim state
                io_bytes += embs.numel() * embs.element_size()

            model.set_embeddings(entity, embs, side)
            tmp_emb[part_key] = embs

            optim_key = (entity, part)
            if optim_key not in trainer.entity_optimizers:
                bucket_logger.debug(f"Resetting optimizer {optim_key}")
                optimizer = make_optimizer([embs], True)
                if optim_state is not None:
                    bucket_logger.debug("Setting optim state")
                    optimizer.load_state_dict(optim_state)

                trainer.entity_optimizers[optim_key] = optimizer

        return io_bytes

    if rank == RANK_ZERO:
        for stats_dict in checkpoint_manager.maybe_read_stats():
            index: int = stats_dict["index"]
            stats: Stats = Stats.from_dict(stats_dict["stats"])
            eval_stats_before: Optional[Stats] = None
            if "eval_stats_before" in stats_dict:
                eval_stats_before = Stats.from_dict(
                    stats_dict["eval_stats_before"])
            eval_stats_after: Optional[Stats] = None
            if "eval_stats_after" in stats_dict:
                eval_stats_after = Stats.from_dict(
                    stats_dict["eval_stats_after"])
            yield (index, eval_stats_before, stats, eval_stats_after)

    # Start of the main training loop.
    for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
        logger.info(
            f"Starting epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
            f"edge path {edge_path_idx + 1} / {iteration_manager.num_edge_paths}, "
            f"edge chunk {edge_chunk_idx + 1} / {iteration_manager.num_edge_chunks}"
        )
        edge_storage = EDGE_STORAGES.make_instance(iteration_manager.edge_path)
        logger.info(f"Edge path: {iteration_manager.edge_path}")

        sync.barrier()
        dist_logger.info("Lock client new epoch...")
        bucket_scheduler.new_pass(
            is_first=iteration_manager.iteration_idx == 0)
        sync.barrier()

        remaining = total_buckets
        cur_b = None
        while remaining > 0:
            old_b = cur_b
            io_time = 0.
            io_bytes = 0
            cur_b, remaining = bucket_scheduler.acquire_bucket()
            logger.info(f"still in queue: {remaining}")
            if cur_b is None:
                if old_b is not None:
                    # if you couldn't get a new pair, release the lock
                    # to prevent a deadlock!
                    tic = time.time()
                    io_bytes += swap_partitioned_embeddings(old_b, None)
                    io_time += time.time() - tic
                time.sleep(1)  # don't hammer td
                continue

            bucket_logger = BucketLogger(logger, bucket=cur_b)

            tic = time.time()

            io_bytes += swap_partitioned_embeddings(old_b, cur_b)

            current_index = \
                (iteration_manager.iteration_idx + 1) * total_buckets - remaining

            next_b = bucket_scheduler.peek()
            if next_b is not None and background_io:
                # Ensure the previous bucket finished writing to disk.
                checkpoint_manager.wait_for_marker(current_index - 1)

                bucket_logger.debug("Prefetching")
                for entity in lhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.lhs)
                for entity in rhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.rhs)

                checkpoint_manager.record_marker(current_index)

            bucket_logger.debug("Loading edges")
            edges = edge_storage.load_chunk_of_edges(
                cur_b.lhs, cur_b.rhs, edge_chunk_idx,
                iteration_manager.num_edge_chunks)
            num_edges = len(edges)
            # this might be off in the case of tensorlist or extra edge fields
            io_bytes += edges.lhs.tensor.numel() * edges.lhs.tensor.element_size()
            io_bytes += edges.rhs.tensor.numel() * edges.rhs.tensor.element_size()
            io_bytes += edges.rel.numel() * edges.rel.element_size()

            bucket_logger.debug("Shuffling edges")
            # Fix a seed to get the same permutation every time; have it
            # depend on all and only what affects the set of edges.
            g = torch.Generator()
            g.manual_seed(
                hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs)))

            num_eval_edges = int(num_edges * config.eval_fraction)
            if num_eval_edges > 0:
                edge_perm = torch.randperm(num_edges, generator=g)
                eval_edge_perm = edge_perm[-num_eval_edges:]
                num_edges -= num_eval_edges
                edge_perm = edge_perm[torch.randperm(num_edges)]
            else:
                edge_perm = torch.randperm(num_edges)

            # HOGWILD evaluation before training
            eval_stats_before: Optional[Stats] = None
            if num_eval_edges > 0:
                bucket_logger.debug(
                    "Waiting for workers to perform evaluation")
                future_all_eval_stats_before = pool.map_async(
                    call, [
                        partial(
                            process_in_batches,
                            batch_size=eval_batch_size,
                            model=model,
                            batch_processor=evaluator,
                            edges=edges,
                            indices=eval_edge_perm[s],
                        ) for s in split_almost_equally(eval_edge_perm.size(0),
                                                        num_parts=num_workers)
                    ])
                all_eval_stats_before = \
                    get_async_result(future_all_eval_stats_before, pool)
                eval_stats_before = Stats.sum(all_eval_stats_before).average()
                bucket_logger.info(
                    f"Stats before training: {eval_stats_before}")

            io_time += time.time() - tic
            tic = time.time()
            # HOGWILD training
            bucket_logger.debug("Waiting for workers to perform training")
            # FIXME should we only delay if iteration_idx == 0?
            future_all_stats = pool.map_async(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=trainer,
                    edges=edges,
                    indices=edge_perm[s],
                    delay=config.hogwild_delay
                    if epoch_idx == 0 and rank > 0 else 0,
                ) for rank, s in enumerate(
                    split_almost_equally(edge_perm.size(0),
                                         num_parts=num_workers))
            ])
            all_stats = get_async_result(future_all_stats, pool)
            stats = Stats.sum(all_stats).average()
            compute_time = time.time() - tic

            bucket_logger.info(
                f"bucket {total_buckets - remaining} / {total_buckets} : "
                f"Processed {num_edges} edges in {compute_time:.2f} s "
                f"( {num_edges / compute_time / 1e6:.2g} M/sec ); "
                f"io: {io_time:.2f} s ( {io_bytes / io_time / 1e6:.2f} MB/sec )"
            )
            bucket_logger.info(f"{stats}")

            # HOGWILD eval after training
            eval_stats_after: Optional[Stats] = None
            if num_eval_edges > 0:
                bucket_logger.debug(
                    "Waiting for workers to perform evaluation")
                future_all_eval_stats_after = pool.map_async(
                    call, [
                        partial(
                            process_in_batches,
                            batch_size=eval_batch_size,
                            model=model,
                            batch_processor=evaluator,
                            edges=edges,
                            indices=eval_edge_perm[s],
                        ) for s in split_almost_equally(eval_edge_perm.size(0),
                                                        num_parts=num_workers)
                    ])
                all_eval_stats_after = \
                    get_async_result(future_all_eval_stats_after, pool)
                eval_stats_after = Stats.sum(all_eval_stats_after).average()
                bucket_logger.info(f"Stats after training: {eval_stats_after}")

            # Add train/eval metrics to queue
            stats_dict = {
                "index": current_index,
                "stats": stats.to_dict(),
            }
            if eval_stats_before is not None:
                stats_dict["eval_stats_before"] = eval_stats_before.to_dict()
            if eval_stats_after is not None:
                stats_dict["eval_stats_after"] = eval_stats_after.to_dict()
            checkpoint_manager.append_stats(stats_dict)
            yield current_index, eval_stats_before, stats, eval_stats_after

        swap_partitioned_embeddings(cur_b, None)

        # Distributed Processing: all machines can leave the barrier now.
        sync.barrier()

        # Preserving a checkpoint requires two steps:
        # - create a snapshot (w/ symlinks) after it's first written;
        # - don't delete it once the following one is written.
        # These two happen in two successive iterations of the main loop: the
        # one just before and the one just after the epoch boundary.
        preserve_old_checkpoint = should_preserve_old_checkpoint(
            iteration_manager, config.checkpoint_preservation_interval)
        preserve_new_checkpoint = should_preserve_old_checkpoint(
            iteration_manager + 1, config.checkpoint_preservation_interval)

        # Write metadata: for multiple machines, write from rank-0
        logger.info(
            f"Finished epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
            f"edge path {edge_path_idx + 1} / {iteration_manager.num_edge_paths}, "
            f"edge chunk {edge_chunk_idx + 1} / {iteration_manager.num_edge_chunks}"
        )
        if rank == 0:
            for entity, econfig in config.entities.items():
                if econfig.num_partitions == 1:
                    embs = model.get_embeddings(entity, Side.LHS)
                    optimizer = trainer.entity_optimizers[(entity,
                                                           Partition(0))]

                    checkpoint_manager.write(
                        entity, Partition(0), embs.detach(),
                        OptimizerStateDict(optimizer.state_dict()))

            sanitized_state_dict: ModuleStateDict = {}
            for k, v in ModuleStateDict(model.state_dict()).items():
                if k.startswith('lhs_embs') or k.startswith('rhs_embs'):
                    # skipping state that's an entity embedding
                    continue
                sanitized_state_dict[k] = v

            logger.info("Writing the metadata")
            checkpoint_manager.write_model(
                sanitized_state_dict,
                OptimizerStateDict(trainer.global_optimizer.state_dict()),
            )

        logger.info("Writing the checkpoint")
        checkpoint_manager.write_new_version(config)

        dist_logger.info(
            "Waiting for other workers to write their parts of the checkpoint")
        sync.barrier()
        dist_logger.info("All parts of the checkpoint have been written")

        logger.info("Switching to the new checkpoint version")
        checkpoint_manager.switch_to_new_version()

        dist_logger.info(
            "Waiting for other workers to switch to the new checkpoint version"
        )
        sync.barrier()
        dist_logger.info(
            "All workers have switched to the new checkpoint version")

        # After all the machines have finished committing
        # checkpoints, we either remove the old checkpoints
        # or we preserve it
        if preserve_new_checkpoint:
            # Add 1 so the index is a multiple of the interval, it looks nicer.
            checkpoint_manager.preserve_current_version(config, epoch_idx + 1)
        if not preserve_old_checkpoint:
            checkpoint_manager.remove_old_version(config)

        # now we're sure that all partition files exist,
        # so be strict about loading them
        strict = True

    # quiescence
    pool.close()
    pool.join()

    sync.barrier()

    checkpoint_manager.close()
    if loadpath_manager is not None:
        loadpath_manager.close()

    # FIXME join distributed workers (not really necessary)

    logger.info("Exiting")
Example #15
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
       embeddings.
    """
    tag_logs_with_process_name(f"Evaluator")

    if evaluator is None:
        evaluator = RankingEvaluator(
            loss_fn=LOSS_FUNCTIONS.get_class(
                config.loss_fn)(margin=config.margin),
            relation_weights=[
                relation.weight for relation in config.relations
            ],
        )

    if config.verbose > 0:
        import pprint

        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName,
                        part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    holder = EmbeddingHolder(config)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(num_workers,
                       subprocess_name="EvalWorker",
                       subprocess_init=subprocess_init)

    if model is None:
        model = make_model(config)
    model.share_memory()

    state_dict, _ = checkpoint_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)

    model.eval()

    for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
        embs = load_embeddings(entity, UNPARTITIONED)
        holder.unpartitioned_embeddings[entity] = embs

    all_stats: List[Stats] = []
    for edge_path_idx, edge_path in enumerate(config.edge_paths):
        logger.info(
            f"Starting edge path {edge_path_idx + 1} / {len(config.edge_paths)} "
            f"({edge_path})")
        edge_storage = EDGE_STORAGES.make_instance(edge_path)

        all_edge_path_stats = []
        # FIXME This order assumes higher affinity on the left-hand side, as it's
        # the one changing more slowly. Make this adaptive to the actual affinity.
        for bucket in create_buckets_ordered_lexicographically(
                holder.nparts_lhs, holder.nparts_rhs):
            tic = time.perf_counter()
            # logger.info(f"{bucket}: Loading entities")

            old_parts = set(holder.partitioned_embeddings.keys())
            new_parts = {(e, bucket.lhs)
                         for e in holder.lhs_partitioned_types
                         } | {(e, bucket.rhs)
                              for e in holder.rhs_partitioned_types}
            for entity, part in old_parts - new_parts:
                del holder.partitioned_embeddings[entity, part]
            for entity, part in new_parts - old_parts:
                embs = load_embeddings(entity, part)
                holder.partitioned_embeddings[entity, part] = embs

            model.set_all_embeddings(holder, bucket)

            # logger.info(f"{bucket}: Loading edges")
            edges = edge_storage.load_edges(bucket.lhs, bucket.rhs)
            num_edges = len(edges)

            load_time = time.perf_counter() - tic
            tic = time.perf_counter()
            # logger.info(f"{bucket}: Launching and waiting for workers")
            future_all_bucket_stats = pool.map_async(
                call,
                [
                    partial(
                        process_in_batches,
                        batch_size=config.batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges[s],
                    ) for s in split_almost_equally(num_edges,
                                                    num_parts=num_workers)
                ],
            )
            all_bucket_stats = get_async_result(future_all_bucket_stats, pool)

            compute_time = time.perf_counter() - tic
            logger.info(
                f"{bucket}: Processed {num_edges} edges in {compute_time:.2g} s "
                f"({num_edges / compute_time / 1e6:.2g}M/sec); "
                f"load time: {load_time:.2g} s")

            total_bucket_stats = Stats.sum(all_bucket_stats)
            all_edge_path_stats.append(total_bucket_stats)
            mean_bucket_stats = total_bucket_stats.average()
            logger.info(
                f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}, "
                f"bucket {bucket}: {mean_bucket_stats}")

            model.clear_all_embeddings()

            yield edge_path_idx, bucket, mean_bucket_stats

        total_edge_path_stats = Stats.sum(all_edge_path_stats)
        all_stats.append(total_edge_path_stats)
        mean_edge_path_stats = total_edge_path_stats.average()
        logger.info("")
        logger.info(
            f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}: "
            f"{mean_edge_path_stats}")
        logger.info("")

        yield edge_path_idx, None, mean_edge_path_stats

    mean_stats = Stats.sum(all_stats).average()
    logger.info("")
    logger.info(f"Stats: {mean_stats}")
    logger.info("")

    yield None, None, mean_stats

    pool.close()
    pool.join()
Example #16
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
       embeddings.
    """
    tag_logs_with_process_name(f"Evaluator")

    if evaluator is None:
        evaluator = RankingEvaluator()

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName,
                        part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(
        num_workers,
        subprocess_name="EvalWorker",
        subprocess_init=subprocess_init,
    )

    if model is None:
        model = make_model(config)
    model.share_memory()

    state_dict, _ = checkpoint_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)

    model.eval()

    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)

    all_stats: List[Stats] = []
    for edge_path_idx, edge_path in enumerate(config.edge_paths):
        logger.info(
            f"Starting edge path {edge_path_idx + 1} / {len(config.edge_paths)} "
            f"({edge_path})")
        edge_storage = EDGE_STORAGES.make_instance(edge_path)

        all_edge_path_stats = []
        last_lhs, last_rhs = None, None
        for bucket in create_buckets_ordered_lexicographically(
                nparts_lhs, nparts_rhs):
            tic = time.time()
            # logger.info(f"{bucket}: Loading entities")

            if last_lhs != bucket.lhs:
                for e in lhs_partitioned_types:
                    model.clear_embeddings(e, Side.LHS)
                    embs = load_embeddings(e, bucket.lhs)
                    model.set_embeddings(e, embs, Side.LHS)
            if last_rhs != bucket.rhs:
                for e in rhs_partitioned_types:
                    model.clear_embeddings(e, Side.RHS)
                    embs = load_embeddings(e, bucket.rhs)
                    model.set_embeddings(e, embs, Side.RHS)
            last_lhs, last_rhs = bucket.lhs, bucket.rhs

            # logger.info(f"{bucket}: Loading edges")
            edges = edge_storage.load_edges(bucket.lhs, bucket.rhs)
            num_edges = len(edges)

            load_time = time.time() - tic
            tic = time.time()
            # logger.info(f"{bucket}: Launching and waiting for workers")
            future_all_bucket_stats = pool.map_async(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=evaluator,
                    edges=edges[s],
                )
                for s in split_almost_equally(num_edges, num_parts=num_workers)
            ])
            all_bucket_stats = \
                get_async_result(future_all_bucket_stats, pool)

            compute_time = time.time() - tic
            logger.info(
                f"{bucket}: Processed {num_edges} edges in {compute_time:.2g} s "
                f"({num_edges / compute_time / 1e6:.2g}M/sec); "
                f"load time: {load_time:.2g} s")

            total_bucket_stats = Stats.sum(all_bucket_stats)
            all_edge_path_stats.append(total_bucket_stats)
            mean_bucket_stats = total_bucket_stats.average()
            logger.info(
                f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}, "
                f"bucket {bucket}: {mean_bucket_stats}")

            yield edge_path_idx, bucket, mean_bucket_stats

        total_edge_path_stats = Stats.sum(all_edge_path_stats)
        all_stats.append(total_edge_path_stats)
        mean_edge_path_stats = total_edge_path_stats.average()
        logger.info("")
        logger.info(
            f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}: "
            f"{mean_edge_path_stats}")
        logger.info("")

        yield edge_path_idx, None, mean_edge_path_stats

    mean_stats = Stats.sum(all_stats).average()
    logger.info("")
    logger.info(f"Stats: {mean_stats}")
    logger.info("")

    yield None, None, mean_stats

    pool.close()
    pool.join()
Example #17
    def _coordinate_train(self, edges, eval_edge_idxs, epoch_idx) -> Stats:
        tk = TimeKeeper()

        config = self.config
        holder = self.holder
        cur_b = self.cur_b
        bucket_logger = self.bucket_logger
        num_edges = len(edges)
        if cur_b.lhs == cur_b.rhs and config.num_gpus > 1:
            num_subparts = 2 * config.num_gpus
        else:
            num_subparts = config.num_gpus

        edges_lhs = edges.lhs.tensor
        edges_rhs = edges.rhs.tensor
        edges_rel = edges.rel
        eval_edges_lhs = None
        eval_edges_rhs = None
        eval_edges_rel = None
        assert edges.weight is None, "Edge weights not implemented in GPU mode yet"
        if eval_edge_idxs is not None:
            bucket_logger.debug("Removing eval edges")
            tk.start("remove_eval")
            num_eval_edges = len(eval_edge_idxs)
            eval_edges_lhs = edges_lhs[eval_edge_idxs]
            eval_edges_rhs = edges_rhs[eval_edge_idxs]
            eval_edges_rel = edges_rel[eval_edge_idxs]
            edges_lhs[eval_edge_idxs] = edges_lhs[-num_eval_edges:].clone()
            edges_rhs[eval_edge_idxs] = edges_rhs[-num_eval_edges:].clone()
            edges_rel[eval_edge_idxs] = edges_rel[-num_eval_edges:].clone()
            edges_lhs = edges_lhs[:-num_eval_edges]
            edges_rhs = edges_rhs[:-num_eval_edges]
            edges_rel = edges_rel[:-num_eval_edges]
            bucket_logger.debug(
                f"Time spent removing eval edges: {tk.stop('remove_eval'):.4f} s"
            )

        bucket_logger.debug("Splitting edges into sub-buckets")
        tk.start("mapping_edges")
        # randomly permute the entities, to get a random subbucketing
        perm_holder = {}
        rev_perm_holder = {}
        for (entity, part), embs in holder.partitioned_embeddings.items():
            perm = _C.randperm(self.entity_counts[entity][part],
                               os.cpu_count())
            _C.shuffle(embs, perm, os.cpu_count())
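            # Shuffle the Adagrad accumulator rows with the same permutation so
            # the optimizer state stays aligned with its embedding rows;
            # rev_perm is kept to undo the shuffle after training.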
            optimizer = self.trainer.partitioned_optimizers[entity, part]
            (optimizer_state, ) = optimizer.state.values()
            _C.shuffle(optimizer_state["sum"], perm, os.cpu_count())
            perm_holder[entity, part] = perm
            rev_perm = _C.reverse_permutation(perm, os.cpu_count())
            rev_perm_holder[entity, part] = rev_perm

        subpart_slices: Dict[Tuple[EntityName, Partition, SubPartition],
                             slice] = {}
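        # Map each (entity, partition, sub-partition) to the contiguous,
        # near-equal slice of (shuffled) entity indices it owns.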
        for entity_name, part in holder.partitioned_embeddings.keys():
            num_entities = self.entity_counts[entity_name][part]
            for subpart, subpart_slice in enumerate(
                    split_almost_equally(num_entities,
                                         num_parts=num_subparts)):
                subpart_slices[entity_name, part, subpart] = subpart_slice
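        # Next, split the edge list into one sub-bucket per
        # (lhs_subpart, rhs_subpart) pair, with entity indices remapped through
        # the permutations above.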

        subbuckets = _C.sub_bucket(
            edges_lhs,
            edges_rhs,
            edges_rel,
            [self.entity_counts[r.lhs][cur_b.lhs] for r in config.relations],
            [perm_holder[r.lhs, cur_b.lhs] for r in config.relations],
            [self.entity_counts[r.rhs][cur_b.rhs] for r in config.relations],
            [perm_holder[r.rhs, cur_b.rhs] for r in config.relations],
            self.shared_lhs,
            self.shared_rhs,
            self.shared_rel,
            num_subparts,
            num_subparts,
            os.cpu_count(),
            config.dynamic_relations,
        )
        bucket_logger.debug("Time spent splitting edges into sub-buckets: "
                            f"{tk.stop('mapping_edges'):.4f} s")
        bucket_logger.debug("Done splitting edges into sub-buckets")
        bucket_logger.debug(f"{subpart_slices}")

        tk.start("scheduling")
        busy_gpus: Set[int] = set()
        all_stats: List[Stats] = []
        if cur_b.lhs != cur_b.rhs:  # Graph is bipartite!!
            gpu_schedules = build_bipartite_schedule(num_subparts)
        else:
            gpu_schedules = build_nonbipartite_schedule(num_subparts)
        for s in gpu_schedules:
            s.append(None)
            s.append(None)
        index_in_schedule = [0 for _ in range(self.gpu_pool.num_gpus)]
        locked_parts = set()
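        # Each per-GPU schedule is an ordered list of (lhs_subpart, rhs_subpart)
        # pairs; the two trailing None entries are sentinels so the
        # this_bucket / next_bucket lookups below never run past the end.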

        def schedule(gpu_idx: GPURank) -> None:
            if gpu_idx in busy_gpus:
                return
            this_bucket = gpu_schedules[gpu_idx][index_in_schedule[gpu_idx]]
            next_bucket = gpu_schedules[gpu_idx][index_in_schedule[gpu_idx] +
                                                 1]
            if this_bucket is None:
                return
            subparts = {(e, cur_b.lhs, this_bucket[0])
                        for e in holder.lhs_partitioned_types
                        } | {(e, cur_b.rhs, this_bucket[1])
                             for e in holder.rhs_partitioned_types}
            if any(k in locked_parts for k in subparts):
                return
            for k in subparts:
                locked_parts.add(k)
            busy_gpus.add(gpu_idx)
            bucket_logger.debug(
                f"GPU #{gpu_idx} gets {this_bucket[0]}, {this_bucket[1]}")
            for embs in holder.partitioned_embeddings.values():
                assert embs.is_shared()
            self.gpu_pool.schedule(
                gpu_idx,
                SubprocessArgs(
                    lhs_types=holder.lhs_partitioned_types,
                    rhs_types=holder.rhs_partitioned_types,
                    lhs_part=cur_b.lhs,
                    rhs_part=cur_b.rhs,
                    lhs_subpart=this_bucket[0],
                    rhs_subpart=this_bucket[1],
                    next_lhs_subpart=next_bucket[0]
                    if next_bucket is not None else None,
                    next_rhs_subpart=next_bucket[1]
                    if next_bucket is not None else None,
                    trainer=self.trainer,
                    model=self.model,
                    all_embs=holder.partitioned_embeddings,
                    subpart_slices=subpart_slices,
                    subbuckets=subbuckets,
                    batch_size=config.batch_size,
                    lr=config.lr,
                ),
            )

        for gpu_idx in range(self.gpu_pool.num_gpus):
            schedule(gpu_idx)

        while busy_gpus:
            gpu_idx, result = self.gpu_pool.wait_for_next()
            assert gpu_idx == result.gpu_idx
            all_stats.append(result.stats)
            busy_gpus.remove(gpu_idx)
            this_bucket = gpu_schedules[gpu_idx][index_in_schedule[gpu_idx]]
            next_bucket = gpu_schedules[gpu_idx][index_in_schedule[gpu_idx] +
                                                 1]
            subparts = {(e, cur_b.lhs, this_bucket[0])
                        for e in holder.lhs_partitioned_types
                        } | {(e, cur_b.rhs, this_bucket[1])
                             for e in holder.rhs_partitioned_types}
            for k in subparts:
                locked_parts.remove(k)
            index_in_schedule[gpu_idx] += 1
            if next_bucket is None:
                bucket_logger.debug(f"GPU #{gpu_idx} finished its schedule")
            for gpu_idx in range(config.num_gpus):
                schedule(gpu_idx)

        assert len(all_stats) == num_subparts * num_subparts
        time_spent_scheduling = tk.stop("scheduling")
        bucket_logger.debug(
            f"Time spent scheduling sub-buckets: {time_spent_scheduling:.4f} s"
        )
        bucket_logger.info(
            f"Speed: {num_edges / time_spent_scheduling:,.0f} edges/sec")

        tk.start("rev_perm")

        for (entity, part), embs in holder.partitioned_embeddings.items():
            rev_perm = rev_perm_holder[entity, part]
            optimizer = self.trainer.partitioned_optimizers[entity, part]
            _C.shuffle(embs, rev_perm, os.cpu_count())
            (state, ) = optimizer.state.values()
            _C.shuffle(state["sum"], rev_perm, os.cpu_count())

        bucket_logger.debug(
            f"Time spent mapping embeddings back from sub-buckets: {tk.stop('rev_perm'):.4f} s"
        )

        if eval_edge_idxs is not None:
            bucket_logger.debug("Restoring eval edges")
            tk.start("restore_eval")
            edges.lhs.tensor[eval_edge_idxs] = eval_edges_lhs
            edges.rhs.tensor[eval_edge_idxs] = eval_edges_rhs
            edges.rel[eval_edge_idxs] = eval_edges_rel
            bucket_logger.debug(
                f"Time spent restoring eval edges: {tk.stop('restore_eval'):.4f} s"
            )

        logger.debug(
            f"_coordinate_train: Time unaccounted for: {tk.unaccounted():.4f} s"
        )

        return Stats.sum(all_stats).average()
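
A minimal sketch, for intuition only, of the kind of per-GPU schedule consumed above. toy_bipartite_schedule is a hypothetical stand-in, not the library's build_bipartite_schedule: it only illustrates that each GPU walks its own list of (lhs_subpart, rhs_subpart) pairs arranged so that, at any step, no two GPUs need the same sub-partition. (The locked_parts check above makes correctness independent of the schedule; a poor schedule merely reduces parallelism.) The caller then appends the two None sentinels.

from typing import List, Tuple

def toy_bipartite_schedule(num_gpus: int) -> List[List[Tuple[int, int]]]:
    # Latin-square layout: at step t, GPU g handles lhs sub-partition
    # (g + t) % n and rhs sub-partition g, so every pair is covered exactly
    # once and no two GPUs share an lhs or rhs sub-partition at the same step.
    n = num_gpus
    return [[((gpu + step) % n, gpu) for step in range(n)] for gpu in range(n)]

# toy_bipartite_schedule(2) -> [[(0, 0), (1, 0)], [(1, 1), (0, 1)]]
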
Exemplo n.º 18
0
def train_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = RANK_ZERO,
) -> Generator[Tuple[int, Optional[Stats], Stats, Optional[Stats]], None, None]:
    """Each epoch/pass, for each partition pair, loads in embeddings and edgelist
    from disk, runs HOGWILD training on them, and writes partitions back to disk.
    """

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    log("Loading entity counts...")
    if maybe_old_entity_path(config.entity_path):
        log("WARNING: It may be that your entity path contains files using the "
            "old format. See D14241362 for how to update them.")
    entity_counts: Dict[str, List[int]] = {}
    for entity, econf in config.entities.items():
        entity_counts[entity] = []
        for part in range(econf.num_partitions):
            with open(os.path.join(
                config.entity_path, "entity_count_%s_%d.txt" % (entity, part)
            ), "rt") as tf:
                entity_counts[entity].append(int(tf.read().strip()))

    # Figure out how many lhs and rhs partitions we need
    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)
    vlog("nparts %d %d types %s %s" %
         (nparts_lhs, nparts_rhs, lhs_partitioned_types, rhs_partitioned_types))
    total_buckets = nparts_lhs * nparts_rhs

    sync: AbstractSynchronizer
    bucket_scheduler: AbstractBucketScheduler
    parameter_sharer: Optional[ParameterSharer]
    partition_client: Optional[PartitionClient]
    if config.num_machines > 1:
        if not 0 <= rank < config.num_machines:
            raise RuntimeError("Invalid rank for trainer")
        if not td.is_available():
            raise RuntimeError("The installed PyTorch version doesn't provide "
                               "distributed training capabilities.")
        ranks = ProcessRanks.from_num_invocations(
            config.num_machines, config.num_partition_servers)

        if rank == RANK_ZERO:
            log("Setup lock server...")
            start_server(
                LockServer(
                    num_clients=len(ranks.trainers),
                    nparts_lhs=nparts_lhs,
                    nparts_rhs=nparts_rhs,
                    lock_lhs=len(lhs_partitioned_types) > 0,
                    lock_rhs=len(rhs_partitioned_types) > 0,
                    init_tree=config.distributed_tree_init_order,
                ),
                server_rank=ranks.lock_server,
                world_size=ranks.world_size,
                init_method=config.distributed_init_method,
                groups=[ranks.trainers],
            )

        bucket_scheduler = DistributedBucketScheduler(
            server_rank=ranks.lock_server,
            client_rank=ranks.trainers[rank],
        )

        log("Setup param server...")
        start_server(
            ParameterServer(num_clients=len(ranks.trainers)),
            server_rank=ranks.parameter_servers[rank],
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=[ranks.trainers],
        )

        parameter_sharer = ParameterSharer(
            client_rank=ranks.parameter_clients[rank],
            all_server_ranks=ranks.parameter_servers,
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=[ranks.trainers],
        )

        if config.num_partition_servers == -1:
            start_server(
                ParameterServer(num_clients=len(ranks.trainers)),
                server_rank=ranks.partition_servers[rank],
                world_size=ranks.world_size,
                init_method=config.distributed_init_method,
                groups=[ranks.trainers],
            )

        if len(ranks.partition_servers) > 0:
            partition_client = PartitionClient(ranks.partition_servers)
        else:
            partition_client = None

        groups = init_process_group(
            rank=ranks.trainers[rank],
            world_size=ranks.world_size,
            init_method=config.distributed_init_method,
            groups=[ranks.trainers],
        )
        trainer_group, = groups
        sync = DistributedSynchronizer(trainer_group)
        dlog = log

    else:
        sync = DummySynchronizer()
        bucket_scheduler = SingleMachineBucketScheduler(
            nparts_lhs, nparts_rhs, config.bucket_order)
        parameter_sharer = None
        partition_client = None
        dlog = lambda msg: None

    # fork early for HOGWILD threads
    log("Creating workers...")
    num_workers = get_num_workers(config.workers)
    pool = create_pool(num_workers)

    def make_optimizer(params: Iterable[torch.nn.Parameter], is_emb: bool) -> Optimizer:
        params = list(params)
        if len(params) == 0:
            optimizer = DummyOptimizer()
        elif is_emb:
            optimizer = RowAdagrad(params, lr=config.lr)
        else:
            if config.relation_lr is not None:
                lr = config.relation_lr
            else:
                lr = config.lr
            optimizer = Adagrad(params, lr=lr)
        optimizer.share_memory()
        return optimizer

    # background_io is only supported in single-machine mode
    background_io = config.background_io and config.num_machines == 1

    checkpoint_manager = CheckpointManager(
        config.checkpoint_path,
        background=background_io,
        rank=rank,
        num_machines=config.num_machines,
        partition_client=partition_client,
    )
    checkpoint_manager.register_metadata_provider(ConfigMetadataProvider(config))
    checkpoint_manager.write_config(config)

    iteration_manager = IterationManager(
        config.num_epochs, config.edge_paths, config.num_edge_chunks,
        iteration_idx=checkpoint_manager.checkpoint_version)
    checkpoint_manager.register_metadata_provider(iteration_manager)

    if config.init_path is not None:
        loadpath_manager = CheckpointManager(config.init_path)
    else:
        loadpath_manager = None

    def load_embeddings(
        entity: EntityName,
        part: Partition,
        strict: bool = False,
        force_dirty: bool = False,
    ) -> Tuple[torch.nn.Parameter, Optional[OptimizerStateDict]]:
        if strict:
            embs, optim_state = checkpoint_manager.read(entity, part,
                                                        force_dirty=force_dirty)
        else:
            # Strict is only False during the first iteration: at that point the
            # checkpoint may not contain any data yet (unless a previous run was
            # resumed), so we fall back on initial values.
            embs, optim_state = checkpoint_manager.maybe_read(entity, part,
                                                              force_dirty=force_dirty)
            if embs is None and loadpath_manager is not None:
                embs, optim_state = loadpath_manager.maybe_read(entity, part)
            if embs is None:
                embs, optim_state = init_embs(entity, entity_counts[entity][part],
                                              config.dimension, config.init_scale)
        assert embs.is_shared()
        return torch.nn.Parameter(embs), optim_state

    log("Initializing global model...")

    if model is None:
        model = make_model(config)
    model.share_memory()
    if trainer is None:
        trainer = Trainer(
            global_optimizer=make_optimizer(model.parameters(), False),
            loss_fn=config.loss_fn,
            margin=config.margin,
            relations=config.relations,
        )
    if evaluator is None:
        evaluator = TrainingRankingEvaluator(
            override_num_batch_negs=config.eval_num_batch_negs,
            override_num_uniform_negs=config.eval_num_uniform_negs,
        )
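    # Round the eval batch size up to a multiple of eval_num_batch_negs,
    # presumably so each eval batch chunks evenly into groups that share
    # in-batch negatives.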
    eval_batch_size = round_up_to_nearest_multiple(config.batch_size, config.eval_num_batch_negs)

    state_dict, optim_state = checkpoint_manager.maybe_read_model()

    if state_dict is None and loadpath_manager is not None:
        state_dict, optim_state = loadpath_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)
    if optim_state is not None:
        trainer.global_optimizer.load_state_dict(optim_state)

    vlog("Loading unpartitioned entities...")
    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs, optim_state = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)
            optimizer = make_optimizer([embs], True)
            if optim_state is not None:
                optimizer.load_state_dict(optim_state)
            trainer.entity_optimizers[(entity, Partition(0))] = optimizer

    # start communicating shared parameters with the parameter server
    if parameter_sharer is not None:
        parameter_sharer.share_model_params(model)

    strict = False

    def swap_partitioned_embeddings(
        old_b: Optional[Bucket],
        new_b: Optional[Bucket],
    ):
        # 0. given the old and new buckets, construct data structures to keep
        #    track of old and new embedding (entity, part) tuples

        io_bytes = 0
        log("Swapping partitioned embeddings %s %s" % (old_b, new_b))

        types = ([(e, Side.LHS) for e in lhs_partitioned_types]
                 + [(e, Side.RHS) for e in rhs_partitioned_types])
        old_parts = {(e, old_b.get_partition(side)): side
                     for e, side in types if old_b is not None}
        new_parts = {(e, new_b.get_partition(side)): side
                     for e, side in types if new_b is not None}

        to_checkpoint = set(old_parts) - set(new_parts)
        preserved = set(old_parts) & set(new_parts)

        # 1. checkpoint embeddings that will not be used in the next pair
        #
        if old_b is not None:  # there are previous embeddings to checkpoint
            log("Writing partitioned embeddings")
            for entity, part in to_checkpoint:
                side = old_parts[(entity, part)]
                vlog("Checkpointing (%s %d %s)" %
                     (entity, part, side.pick("lhs", "rhs")))
                embs = model.get_embeddings(entity, side)
                optim_key = (entity, part)
                optim_state = OptimizerStateDict(trainer.entity_optimizers[optim_key].state_dict())
                io_bytes += embs.numel() * embs.element_size()  # ignore optim state
                checkpoint_manager.write(entity, part, embs.detach(), optim_state)
                if optim_key in trainer.entity_optimizers:
                    del trainer.entity_optimizers[optim_key]
                # these variables are holding large objects; let them be freed
                del embs
                del optim_state

            bucket_scheduler.release_bucket(old_b)

        # 2. copy old embeddings that will be used in the next pair
        #    into a temporary dictionary
        #
        tmp_emb = {x: model.get_embeddings(x[0], old_parts[x]) for x in preserved}

        for entity, _ in types:
            model.clear_embeddings(entity, Side.LHS)
            model.clear_embeddings(entity, Side.RHS)

        if new_b is None:  # there are no new embeddings to load
            return io_bytes

        # 3. load new embeddings into the model/optimizer, either from disk
        #    or the temporary dictionary
        #
        log("Loading entities")
        for entity, side in types:
            part = new_b.get_partition(side)
            part_key = (entity, part)
            if part_key in tmp_emb:
                vlog("Loading (%s, %d) from preserved" % (entity, part))
                embs, optim_state = tmp_emb[part_key], None
            else:
                vlog("Loading (%s, %d)" % (entity, part))

                force_dirty = bucket_scheduler.check_and_set_dirty(entity, part)
                embs, optim_state = load_embeddings(
                    entity, part, strict=strict, force_dirty=force_dirty)
                io_bytes += embs.numel() * embs.element_size()  # ignore optim state

            model.set_embeddings(entity, embs, side)
            tmp_emb[part_key] = embs

            optim_key = (entity, part)
            if optim_key not in trainer.entity_optimizers:
                vlog("Resetting optimizer %s" % (optim_key,))
                optimizer = make_optimizer([embs], True)
                if optim_state is not None:
                    vlog("Setting optim state")
                    optimizer.load_state_dict(optim_state)

                trainer.entity_optimizers[optim_key] = optimizer

        return io_bytes

    # Start of the main training loop.
    for epoch_idx, edge_path_idx, edge_chunk_idx \
            in iteration_manager.remaining_iterations():
        log("Starting epoch %d / %d edge path %d / %d edge chunk %d / %d" %
            (epoch_idx + 1, iteration_manager.num_epochs,
             edge_path_idx + 1, iteration_manager.num_edge_paths,
             edge_chunk_idx + 1, iteration_manager.num_edge_chunks))
        edge_reader = EdgeReader(iteration_manager.edge_path)
        log("edge_path= %s" % iteration_manager.edge_path)

        sync.barrier()
        dlog("Lock client new epoch...")
        bucket_scheduler.new_pass(is_first=iteration_manager.iteration_idx == 0)
        sync.barrier()

        remaining = total_buckets
        cur_b = None
        while remaining > 0:
            old_b = cur_b
            io_time = 0.
            io_bytes = 0
            cur_b, remaining = bucket_scheduler.acquire_bucket()
            print('still in queue: %d' % remaining, file=sys.stderr)
            if cur_b is None:
                if old_b is not None:
                    # if you couldn't get a new pair, release the lock
                    # to prevent a deadlock!
                    tic = time.time()
                    io_bytes += swap_partitioned_embeddings(old_b, None)
                    io_time += time.time() - tic
                time.sleep(1)  # don't hammer td (torch.distributed)
                continue

            def log_status(msg, always=False):
                f = log if always else vlog
                f("%s: %s" % (cur_b, msg))

            tic = time.time()

            io_bytes += swap_partitioned_embeddings(old_b, cur_b)

            current_index = (
                (iteration_manager.iteration_idx + 1) * total_buckets - remaining
            )

            next_b = bucket_scheduler.peek()
            if next_b is not None and background_io:
                # Ensure the previous bucket finished writing to disk.
                checkpoint_manager.wait_for_marker(current_index - 1)

                log_status("Prefetching")
                for entity in lhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.lhs)
                for entity in rhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.rhs)

                checkpoint_manager.record_marker(current_index)

            log_status("Loading edges")
            edges = edge_reader.read(
                cur_b.lhs, cur_b.rhs, edge_chunk_idx, config.num_edge_chunks)
            num_edges = len(edges)
            # this might be off in the case of tensorlist or extra edge fields
            io_bytes += edges.lhs.tensor.numel() * edges.lhs.tensor.element_size()
            io_bytes += edges.rhs.tensor.numel() * edges.rhs.tensor.element_size()
            io_bytes += edges.rel.numel() * edges.rel.element_size()

            log_status("Shuffling edges")
            # Fix a seed to get the same permutation every time; have it
            # depend on all and only what affects the set of edges.
            g = torch.Generator()
            g.manual_seed(hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs)))
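            # (hash() of a tuple of ints is not subject to PYTHONHASHSEED
            # randomization, so this seed is stable across runs.)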

            num_eval_edges = int(num_edges * config.eval_fraction)
            if num_eval_edges > 0:
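                # Hold out the last num_eval_edges entries of the seeded
                # permutation for evaluation; the remaining training indices
                # get an extra (unseeded) shuffle.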
                edge_perm = torch.randperm(num_edges, generator=g)
                eval_edge_perm = edge_perm[-num_eval_edges:]
                num_edges -= num_eval_edges
                edge_perm = edge_perm[torch.randperm(num_edges)]
            else:
                edge_perm = torch.randperm(num_edges)

            # HOGWILD evaluation before training
            eval_stats_before: Optional[Stats] = None
            if num_eval_edges > 0:
                log_status("Waiting for workers to perform evaluation")
                all_eval_stats_before = pool.map(call, [
                    partial(
                        process_in_batches,
                        batch_size=eval_batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges,
                        indices=eval_edge_perm[s],
                    )
                    for s in split_almost_equally(eval_edge_perm.size(0),
                                                  num_parts=num_workers)
                ])
                eval_stats_before = Stats.sum(all_eval_stats_before).average()
                log("stats before %s: %s" % (cur_b, eval_stats_before))

            io_time += time.time() - tic
            tic = time.time()
            # HOGWILD training
            log_status("Waiting for workers to perform training")
            # FIXME should we only delay if iteration_idx == 0?
            all_stats = pool.map(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=trainer,
                    edges=edges,
                    indices=edge_perm[s],
                    delay=config.hogwild_delay if epoch_idx == 0 and rank > 0 else 0,
                )
                for rank, s in enumerate(split_almost_equally(edge_perm.size(0),
                                                              num_parts=num_workers))
            ])
            stats = Stats.sum(all_stats).average()
            compute_time = time.time() - tic

            log_status(
                "bucket %d / %d : Processed %d edges in %.2f s "
                "( %.2g M/sec ); io: %.2f s ( %.2f MB/sec )" %
                (total_buckets - remaining, total_buckets,
                 num_edges, compute_time, num_edges / compute_time / 1e6,
                 io_time, io_bytes / io_time / 1e6),
                always=True)
            log_status("%s" % stats, always=True)

            # HOGWILD eval after training
            eval_stats_after: Optional[Stats] = None
            if num_eval_edges > 0:
                log_status("Waiting for workers to perform evaluation")
                all_eval_stats_after = pool.map(call, [
                    partial(
                        process_in_batches,
                        batch_size=eval_batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges,
                        indices=eval_edge_perm[s],
                    )
                    for s in split_almost_equally(eval_edge_perm.size(0),
                                                  num_parts=num_workers)
                ])
                eval_stats_after = Stats.sum(all_eval_stats_after).average()
                log("stats after %s: %s" % (cur_b, eval_stats_after))

            # Yield train/eval metrics for this bucket back to the caller
            yield current_index, eval_stats_before, stats, eval_stats_after

        swap_partitioned_embeddings(cur_b, None)

        # Distributed Processing: all machines can leave the barrier now.
        sync.barrier()

        # Write metadata: for multiple machines, write from rank-0
        log("Finished epoch %d path %d pass %d; checkpointing global state."
            % (epoch_idx + 1, edge_path_idx + 1, edge_chunk_idx + 1))
        log("My rank: %d" % rank)
        if rank == 0:
            for entity, econfig in config.entities.items():
                if econfig.num_partitions == 1:
                    embs = model.get_embeddings(entity, Side.LHS)
                    optimizer = trainer.entity_optimizers[(entity, Partition(0))]

                    checkpoint_manager.write(
                        entity, Partition(0),
                        embs.detach(), OptimizerStateDict(optimizer.state_dict()))

            sanitized_state_dict: ModuleStateDict = {}
            for k, v in ModuleStateDict(model.state_dict()).items():
                if k.startswith('lhs_embs') or k.startswith('rhs_embs'):
                    # skipping state that's an entity embedding
                    continue
                sanitized_state_dict[k] = v

            log("Writing metadata...")
            checkpoint_manager.write_model(
                sanitized_state_dict,
                OptimizerStateDict(trainer.global_optimizer.state_dict()),
            )

        log("Writing the checkpoint...")
        checkpoint_manager.write_new_version(config)

        dlog("Waiting for other workers to write their parts of the checkpoint: rank %d" % rank)
        sync.barrier()
        dlog("All parts of the checkpoint have been written")

        log("Switching to new checkpoint version...")
        checkpoint_manager.switch_to_new_version()

        dlog("Waiting for other workers to switch to the new checkpoint version: rank %d" % rank)
        sync.barrier()
        dlog("All workers have switched to the new checkpoint version")

        # After all the machines have finished committing
        # checkpoints, we remove the old checkpoints.
        checkpoint_manager.remove_old_version(config)

        # now we're sure that all partition files exist,
        # so be strict about loading them
        strict = True

    # quiescence
    pool.close()
    pool.join()

    sync.barrier()

    checkpoint_manager.close()
    if loadpath_manager is not None:
        loadpath_manager.close()

    # FIXME join distributed workers (not really necessary)

    log("Exiting")
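
A minimal, hypothetical driver for the generator above (not part of the listing): it simply exhausts train_and_report_stats, discarding the per-bucket stats that a monitoring script could otherwise log or plot.

def train(config: ConfigSchema, rank: Rank = RANK_ZERO) -> None:
    # Each yielded item is (bucket index, eval stats before training,
    # training stats, eval stats after training); here we just drive the
    # generator to completion and discard them.
    for _ in train_and_report_stats(config, rank=rank):
        pass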