Example #1
def make_model(config: ConfigSchema) -> MultiRelationEmbedder:
    if config.dynamic_relations:
        if len(config.relations) != 1:
            raise RuntimeError(
                "Dynamic relations are enabled, so there should only be one "
                "entry in config.relations with config for all relations.")
        try:
            relation_type_storage = RELATION_TYPE_STORAGES.make_instance(
                config.entity_path)
            num_dynamic_rels = relation_type_storage.load_count()
        except CouldNotLoadData:
            raise RuntimeError(
                "Dynamic relations are enabled, so there should be a file called "
                "dynamic_rel_count.txt in the entity path with their count.")
    else:
        num_dynamic_rels = 0

    if config.num_batch_negs > 0 and config.batch_size % config.num_batch_negs != 0:
        raise RuntimeError(
            "Batch size (%d) must be a multiple of num_batch_negs (%d)" %
            (config.batch_size, config.num_batch_negs))

    lhs_operators: List[Optional[Union[AbstractOperator,
                                       AbstractDynamicOperator]]] = []
    rhs_operators: List[Optional[Union[AbstractOperator,
                                       AbstractDynamicOperator]]] = []
    for r in config.relations:
        lhs_operators.append(
            instantiate_operator(r.operator, Side.LHS, num_dynamic_rels,
                                 config.entity_dimension(r.lhs)))
        rhs_operators.append(
            instantiate_operator(r.operator, Side.RHS, num_dynamic_rels,
                                 config.entity_dimension(r.rhs)))

    comparator_class = COMPARATORS.get_class(config.comparator)
    comparator = comparator_class()

    if config.bias:
        comparator = BiasedComparator(comparator)

    return MultiRelationEmbedder(
        config.dimension,
        config.relations,
        config.entities,
        num_uniform_negs=config.num_uniform_negs,
        num_batch_negs=config.num_batch_negs,
        disable_lhs_negs=config.disable_lhs_negs,
        disable_rhs_negs=config.disable_rhs_negs,
        lhs_operators=lhs_operators,
        rhs_operators=rhs_operators,
        comparator=comparator,
        global_emb=config.global_emb,
        max_norm=config.max_norm,
        num_dynamic_rels=num_dynamic_rels,
        half_precision=config.half_precision,
    )
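
A minimal usage sketch for the factory above, assuming the config object is loaded with PBG's parse_config helper; the config file path is hypothetical.

# Hedged usage sketch: build the model from a loaded config.
# Assumes parse_config from torchbiggraph.config; the path below is hypothetical.
from torchbiggraph.config import parse_config

config = parse_config("configs/example_config.py")
model = make_model(config)
model.share_memory()  # share parameters before spawning HOGWILD workers (as in Example #6)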
Example #2
def init_embeddings(target: str, config: ConfigSchema, *, version: int = 0):
    with open(os.path.join(target, "checkpoint_version.txt"), "xt") as tf:
        tf.write("%d" % version)
    for entity_name, entity in config.entities.items():
        for partition in range(entity.num_partitions):
            with open(
                    os.path.join(
                        config.entity_path,
                        "entity_count_%s_%d.txt" % (entity_name, partition),
                    ),
                    "rt",
            ) as tf:
                entity_count = int(tf.read().strip())
            with h5py.File(
                    os.path.join(
                        target,
                        "embeddings_%s_%d.v%d.h5" %
                        (entity_name, partition, version),
                    ),
                    "x",
            ) as hf:
                hf.attrs["format_version"] = 1
                hf.create_dataset(
                    "embeddings",
                    data=np.random.randn(entity_count,
                                         config.entity_dimension(entity_name)),
                )
    with h5py.File(os.path.join(target, "model.v%d.h5" % version), "x") as hf:
        hf.attrs["format_version"] = 1
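
A short usage sketch for the initializer above; it assumes the entity_count_*_*.txt files already exist under config.entity_path, since they are read to size the random embedding tables.

# Hedged usage sketch: write a fresh v0 checkpoint with randomly initialized
# embeddings into the checkpoint directory. Assumes numpy and h5py are in
# scope, as in the snippet above.
init_embeddings(config.checkpoint_path, config, version=0)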
Example #3
    def write_new_version(
        self,
        config: ConfigSchema,
        entity_counts: Dict[EntityName, List[int]],
        embedding_storage_freelist: Dict[EntityName, Set[torch.FloatStorage]],
    ) -> None:
        metadata = self.collect_metadata()
        new_version = self._version(True)
        if self.partition_client is not None:
            for entity, econf in config.entities.items():
                dimension = config.entity_dimension(entity)
                for part in range(self.rank, econf.num_partitions,
                                  self.num_machines):
                    logger.debug(f"Getting {entity} {part}")
                    count = entity_counts[entity][part]
                    s = next(iter(embedding_storage_freelist[entity]))
                    out = torch.FloatTensor(s).view(-1, dimension)[:count]
                    embs, serialized_optim_state = self.partition_client.get(
                        entity, part, out=out)
                    logger.debug(f"Done getting {entity} {part}")
                    logger.debug(f"Saving {entity} {part} v{new_version}")
                    self.storage.save_entity_partition(
                        new_version,
                        entity,
                        part,
                        embs,
                        serialized_optim_state,
                        metadata,
                    )
                    logger.debug(f"Done saving {entity} {part} v{new_version}")
Example #4
    def __init__(
        self,
        config: ConfigSchema,
        model: Optional[MultiRelationEmbedder] = None,
        trainer: Optional[AbstractBatchProcessor] = None,
        evaluator: Optional[AbstractBatchProcessor] = None,
        rank: Rank = SINGLE_TRAINER,
        subprocess_init: Optional[Callable[[], None]] = None,
        stats_handler: StatsHandler = NOOP_STATS_HANDLER,
    ):

        super().__init__(
            config, model, trainer, evaluator, rank, subprocess_init, stats_handler
        )

        assert config.num_gpus > 0
        if not CPP_INSTALLED:
            raise RuntimeError(
                "GPU support requires C++ installation: "
                "install with C++ support by running "
                "`PBG_INSTALL_CPP=1 pip install .`"
            )

        if config.half_precision:
            for entity in config.entities:
                # need this for tensor cores to work
                assert config.entity_dimension(entity) % 8 == 0
            assert config.batch_size % 8 == 0
            assert config.num_batch_negs % 8 == 0
            assert config.num_uniform_negs % 8 == 0

        assert len(self.holder.lhs_unpartitioned_types) == 0
        assert len(self.holder.rhs_unpartitioned_types) == 0

        num_edge_chunks = self.iteration_manager.num_edge_chunks
        max_edges = 0
        for edge_path in config.edge_paths:
            edge_storage = EDGE_STORAGES.make_instance(edge_path)
            for lhs_part in range(self.holder.nparts_lhs):
                for rhs_part in range(self.holder.nparts_rhs):
                    num_edges = edge_storage.get_number_of_edges(lhs_part, rhs_part)
                    num_edges_per_chunk = div_roundup(num_edges, num_edge_chunks)
                    max_edges = max(max_edges, num_edges_per_chunk)
        self.shared_lhs = allocate_shared_tensor((max_edges,), dtype=torch.long)
        self.shared_rhs = allocate_shared_tensor((max_edges,), dtype=torch.long)
        self.shared_rel = allocate_shared_tensor((max_edges,), dtype=torch.long)

        # fork early for HOGWILD threads
        logger.info("Creating GPU workers...")
        torch.set_num_threads(1)
        self.gpu_pool = GPUProcessPool(
            config.num_gpus,
            subprocess_init,
            {s for ss in self.embedding_storage_freelist.values() for s in ss}
            | {
                self.shared_lhs.storage(),
                self.shared_rhs.storage(),
                self.shared_rel.storage(),
            },
        )
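
The shared edge buffers above are sized by the largest per-chunk edge count; div_roundup is ceiling division. A sketch of that helper under that assumption (the real helper ships with PBG):

# Sketch of the ceiling-division helper assumed by the buffer sizing above.
# e.g. div_roundup(10, 3) == 4: 10 edges split into 3 chunks need room
# for at most 4 edges per chunk.
def div_roundup(num: int, denom: int) -> int:
    return (num + denom - 1) // denom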
Example #5
    def assertCheckpointWritten(self, config: ConfigSchema, *,
                                version: int) -> None:
        with open(
                os.path.join(config.checkpoint_path, "checkpoint_version.txt"),
                "rt") as tf:
            self.assertEqual(version, int(tf.read().strip()))

        with open(os.path.join(config.checkpoint_path, "config.json"),
                  "rt") as tf:
            self.assertEqual(json.load(tf), config.to_dict())

        with h5py.File(
                os.path.join(config.checkpoint_path, "model.v%d.h5" % version),
                "r") as hf:
            self.assertHasMetadata(hf, config)
            self.assertIsModelParameters(hf["model"])
            self.assertIsOptimStateDict(hf["optimizer/state_dict"])

        with open(os.path.join(config.checkpoint_path, "training_stats.json"),
                  "rt") as tf:
            for line in tf:
                self.assertIsStatsDict(json.loads(line))

        for entity_name, entity in config.entities.items():
            for partition in range(entity.num_partitions):
                with open(
                        os.path.join(
                            config.entity_path,
                            "entity_count_%s_%d.txt" %
                            (entity_name, partition),
                        ),
                        "rt",
                ) as tf:
                    entity_count = int(tf.read().strip())
                with h5py.File(
                        os.path.join(
                            config.checkpoint_path,
                            "embeddings_%s_%d.v%d.h5" %
                            (entity_name, partition, version),
                        ),
                        "r",
                ) as hf:
                    self.assertHasMetadata(hf, config)
                    self.assertIsEmbeddings(
                        hf["embeddings"],
                        entity_count,
                        config.entity_dimension(entity_name),
                    )
                    self.assertIsOptimStateDict(hf["optimizer/state_dict"])
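
A hedged sketch of how this assertion helper might be used in a test case; train() is assumed to be PBG's training entry point and _make_test_config is a hypothetical fixture that builds a small ConfigSchema.

    # Hedged usage sketch; _make_test_config is hypothetical and train() is
    # assumed to be the entry point that writes checkpoints to
    # config.checkpoint_path.
    def test_checkpoint_files(self):
        config = self._make_test_config()
        train(config)
        self.assertCheckpointWritten(config, version=1)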
Example #6
    def __init__(  # noqa
        self,
        config: ConfigSchema,
        model: Optional[MultiRelationEmbedder] = None,
        trainer: Optional[AbstractBatchProcessor] = None,
        evaluator: Optional[AbstractBatchProcessor] = None,
        rank: Rank = SINGLE_TRAINER,
        subprocess_init: Optional[Callable[[], None]] = None,
        stats_handler: StatsHandler = NOOP_STATS_HANDLER,
    ):
        """Each epoch/pass, for each partition pair, loads in embeddings and edgelist
        from disk, runs HOGWILD training on them, and writes partitions back to disk.
        """
        tag_logs_with_process_name(f"Trainer-{rank}")
        self.config = config
        if config.verbose > 0:
            import pprint

            pprint.PrettyPrinter().pprint(config.to_dict())

        logger.info("Loading entity counts...")
        entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
        entity_counts: Dict[str, List[int]] = {}
        for entity, econf in config.entities.items():
            entity_counts[entity] = []
            for part in range(econf.num_partitions):
                entity_counts[entity].append(entity_storage.load_count(entity, part))

        # Figure out how many lhs and rhs partitions we need
        holder = self.holder = EmbeddingHolder(config)

        logger.debug(
            f"nparts {holder.nparts_lhs} {holder.nparts_rhs} "
            f"types {holder.lhs_partitioned_types} {holder.rhs_partitioned_types}"
        )

        # We know ahead of time that we will need 1-2 storages for each embedding type,
        # as well as the max size of this storage (num_entities x D).
        # We allocate these storages in advance in `embedding_storage_freelist`.
        # When we need storage for an entity type, we pop it from this free list,
        # and then add it back when we 'delete' the embedding table.
        embedding_storage_freelist: Dict[
            EntityName, Set[torch.FloatStorage]
        ] = defaultdict(set)
        for entity_type, counts in entity_counts.items():
            max_count = max(counts)
            num_sides = (
                (1 if entity_type in holder.lhs_partitioned_types else 0)
                + (1 if entity_type in holder.rhs_partitioned_types else 0)
                + (
                    1
                    if entity_type
                    in (holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types)
                    else 0
                )
            )
            for _ in range(num_sides):
                embedding_storage_freelist[entity_type].add(
                    allocate_shared_tensor(
                        (max_count, config.entity_dimension(entity_type)),
                        dtype=torch.float,
                    ).storage()
                )

        # create the handlers, threads, etc. for distributed training
        if config.num_machines > 1 or config.num_partition_servers > 0:
            if not 0 <= rank < config.num_machines:
                raise RuntimeError("Invalid rank for trainer")
            if not td.is_available():
                raise RuntimeError(
                    "The installed PyTorch version doesn't provide "
                    "distributed training capabilities."
                )
            ranks = ProcessRanks.from_num_invocations(
                config.num_machines, config.num_partition_servers
            )

            num_ps_groups = config.num_groups_for_partition_server
            groups: List[List[int]] = [ranks.trainers]  # barrier group
            groups += [
                ranks.trainers + ranks.partition_servers
            ] * num_ps_groups  # ps groups
            group_idxs_for_partition_servers = range(1, len(groups))

            if rank == SINGLE_TRAINER:
                logger.info("Setup lock server...")
                start_server(
                    LockServer(
                        num_clients=len(ranks.trainers),
                        nparts_lhs=holder.nparts_lhs,
                        nparts_rhs=holder.nparts_rhs,
                        entities_lhs=holder.lhs_partitioned_types,
                        entities_rhs=holder.rhs_partitioned_types,
                        entity_counts=entity_counts,
                        init_tree=config.distributed_tree_init_order,
                        stats_handler=stats_handler,
                    ),
                    process_name="LockServer",
                    init_method=config.distributed_init_method,
                    world_size=ranks.world_size,
                    server_rank=ranks.lock_server,
                    groups=groups,
                    subprocess_init=subprocess_init,
                )

            self.bucket_scheduler = DistributedBucketScheduler(
                server_rank=ranks.lock_server, client_rank=ranks.trainers[rank]
            )

            logger.info("Setup param server...")
            start_server(
                ParameterServer(num_clients=len(ranks.trainers)),
                process_name=f"ParamS-{rank}",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.parameter_servers[rank],
                groups=groups,
                subprocess_init=subprocess_init,
            )

            parameter_sharer = ParameterSharer(
                process_name=f"ParamC-{rank}",
                client_rank=ranks.parameter_clients[rank],
                all_server_ranks=ranks.parameter_servers,
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                groups=groups,
                subprocess_init=subprocess_init,
            )

            if config.num_partition_servers == -1:
                start_server(
                    ParameterServer(
                        num_clients=len(ranks.trainers),
                        group_idxs=group_idxs_for_partition_servers,
                        log_stats=True,
                    ),
                    process_name=f"PartS-{rank}",
                    init_method=config.distributed_init_method,
                    world_size=ranks.world_size,
                    server_rank=ranks.partition_servers[rank],
                    groups=groups,
                    subprocess_init=subprocess_init,
                )

            groups = init_process_group(
                rank=ranks.trainers[rank],
                world_size=ranks.world_size,
                init_method=config.distributed_init_method,
                groups=groups,
            )
            trainer_group, *groups_for_partition_servers = groups
            self.barrier_group = trainer_group

            if len(ranks.partition_servers) > 0:
                partition_client = PartitionClient(
                    ranks.partition_servers,
                    groups=groups_for_partition_servers,
                    log_stats=True,
                )
            else:
                partition_client = None
        else:
            self.barrier_group = None
            self.bucket_scheduler = SingleMachineBucketScheduler(
                holder.nparts_lhs, holder.nparts_rhs, config.bucket_order, stats_handler
            )
            parameter_sharer = None
            partition_client = None
            hide_distributed_logging()

        # fork early for HOGWILD threads
        logger.info("Creating workers...")
        self.num_workers = get_num_workers(config.workers)
        self.pool = create_pool(
            self.num_workers,
            subprocess_name=f"TWorker-{rank}",
            subprocess_init=subprocess_init,
        )

        checkpoint_manager = CheckpointManager(
            config.checkpoint_path,
            rank=rank,
            num_machines=config.num_machines,
            partition_client=partition_client,
            subprocess_name=f"BackgRW-{rank}",
            subprocess_init=subprocess_init,
        )
        self.checkpoint_manager = checkpoint_manager
        checkpoint_manager.register_metadata_provider(ConfigMetadataProvider(config))
        if rank == 0:
            checkpoint_manager.write_config(config)

        num_edge_chunks = get_num_edge_chunks(config)

        self.iteration_manager = IterationManager(
            config.num_epochs,
            config.edge_paths,
            num_edge_chunks,
            iteration_idx=checkpoint_manager.checkpoint_version,
        )
        checkpoint_manager.register_metadata_provider(self.iteration_manager)

        logger.info("Initializing global model...")
        if model is None:
            model = make_model(config)
        model.share_memory()
        loss_fn = LOSS_FUNCTIONS.get_class(config.loss_fn)(margin=config.margin)
        relation_weights = [relation.weight for relation in config.relations]
        if trainer is None:
            trainer = Trainer(
                model_optimizer=make_optimizer(config, model.parameters(), False),
                loss_fn=loss_fn,
                relation_weights=relation_weights,
            )
        if evaluator is None:
            eval_overrides = {}
            if config.eval_num_batch_negs is not None:
                eval_overrides["num_batch_negs"] = config.eval_num_batch_negs
            if config.eval_num_uniform_negs is not None:
                eval_overrides["num_uniform_negs"] = config.eval_num_uniform_negs

            evaluator = RankingEvaluator(
                loss_fn=loss_fn,
                relation_weights=relation_weights,
                overrides=eval_overrides,
            )

        if config.init_path is not None:
            self.loadpath_manager = CheckpointManager(config.init_path)
        else:
            self.loadpath_manager = None

        # load model from checkpoint or loadpath, if available
        state_dict, optim_state = checkpoint_manager.maybe_read_model()
        if state_dict is None and self.loadpath_manager is not None:
            state_dict, optim_state = self.loadpath_manager.maybe_read_model()
        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)
        if optim_state is not None:
            trainer.model_optimizer.load_state_dict(optim_state)

        logger.debug("Loading unpartitioned entities...")
        for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
            count = entity_counts[entity][0]
            s = embedding_storage_freelist[entity].pop()
            dimension = config.entity_dimension(entity)
            embs = torch.FloatTensor(s).view(-1, dimension)[:count]
            embs, optimizer = self._load_embeddings(entity, UNPARTITIONED, out=embs)
            holder.unpartitioned_embeddings[entity] = embs
            trainer.unpartitioned_optimizers[entity] = optimizer

        # start communicating shared parameters with the parameter server
        if parameter_sharer is not None:
            shared_parameters: Set[int] = set()
            for name, param in model.named_parameters():
                if id(param) in shared_parameters:
                    continue
                shared_parameters.add(id(param))
                key = f"model.{name}"
                logger.info(
                    f"Adding {key} ({param.numel()} params) to parameter server"
                )
                parameter_sharer.set_param(key, param.data)
            for entity, embs in holder.unpartitioned_embeddings.items():
                key = f"entity.{entity}"
                logger.info(f"Adding {key} ({embs.numel()} params) to parameter server")
                parameter_sharer.set_param(key, embs.data)

        # store everything in self
        self.model = model
        self.trainer = trainer
        self.evaluator = evaluator
        self.rank = rank
        self.entity_counts = entity_counts
        self.embedding_storage_freelist = embedding_storage_freelist
        self.stats_handler = stats_handler

        self.strict = False
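
A compact sketch of the embedding-storage freelist pattern described in the comments above; the helper names are hypothetical, since the real logic is inlined in the trainer.

    # Hedged sketch of the freelist pattern (hypothetical helper names):
    # shared storages are preallocated once, checked out when an embedding
    # table is loaded, and returned when the table is dropped, so training
    # never allocates new shared memory.
    def _checkout_embeddings(self, entity: EntityName, count: int) -> torch.FloatTensor:
        s = self.embedding_storage_freelist[entity].pop()
        dimension = self.config.entity_dimension(entity)
        return torch.FloatTensor(s).view(-1, dimension)[:count]

    def _release_embeddings(self, entity: EntityName, embs: torch.FloatTensor) -> None:
        self.embedding_storage_freelist[entity].add(embs.storage())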