def _load_embeddings(
    self,
    entity: EntityName,
    part: Partition,
    out: FloatTensorType,
    strict: bool = False,
    force_dirty: bool = False,
) -> Tuple[torch.nn.Parameter, Optimizer]:
    """Fetch the embeddings (and optimizer state) for one entity partition.

    Resolution order in the non-strict case: checkpoint -> load path ->
    fresh random initialization. In the strict case the checkpoint must
    provide the data. The result is wrapped in a ``torch.nn.Parameter``
    and paired with a freshly built optimizer; any recovered optimizer
    state is loaded into it.
    """
    if strict:
        # The checkpoint is required to contain this partition.
        tensor, state = self.checkpoint_manager.read(
            entity, part, out=out, force_dirty=force_dirty
        )
    else:
        # Strict is only false during the first iteration, because in that
        # case the checkpoint may not contain any data (unless a previous
        # run was resumed) so we fall back on initial values.
        tensor, state = self.checkpoint_manager.maybe_read(
            entity, part, out=out, force_dirty=force_dirty
        )
        if tensor is None and self.loadpath_manager is not None:
            # Second chance: a user-provided load path with pretrained values.
            tensor, state = self.loadpath_manager.maybe_read(entity, part, out=out)
        if tensor is None:
            # Nothing on disk: fill the caller-provided buffer in place
            # with scaled random values and start the optimizer from scratch.
            tensor = out
            fast_approx_rand(tensor)
            tensor.mul_(self.config.init_scale)
            state = None
    params = torch.nn.Parameter(tensor)
    optimizer = make_optimizer(self.config, [params], True)
    if state is not None:
        optimizer.load_state_dict(state)
    return params, optimizer
def init_embs(
    entity: EntityName,
    entity_count: int,
    dim: int,
    scale: float,
) -> Tuple[FloatTensorType, None]:
    """Initialize embeddings of size entity_count x dim.

    Draws a flat random tensor, reshapes it to (entity_count, dim) and
    scales it in place. The second element of the returned tuple is the
    (absent) optimizer state, hence always ``None``.
    """
    # FIXME: Use multi-threaded instead of fast_approx_rand
    vlog("Initializing %s" % entity)
    flat = fast_approx_rand(entity_count * dim)
    embs = flat.view(entity_count, dim)
    embs.mul_(scale)  # in-place scaling; mul_ returns the same tensor
    return embs, None
def init_embs(
    entity: EntityName,
    entity_count: int,
    relation_dim: int,
    scale: float,
) -> Tuple[FloatTensorType, None]:
    """Initialize embeddings of size entity_count x dim.

    Produces a (entity_count, relation_dim) tensor of scaled random
    values; the accompanying optimizer state is always ``None``.
    """
    # FIXME: Use multi-threaded instead of fast_approx_rand
    logger.debug(f"Initializing {entity}")
    flat = fast_approx_rand(entity_count * relation_dim)
    embs = flat.view(entity_count, relation_dim)
    embs.mul_(scale)  # in-place; mul_ hands back the same tensor
    return embs, None