def from_dict(cls, d, chains: Optional[List[ChainConfig]] = None):
    """Load a config from ``d`` and normalise its shard selection.

    The parent ``from_dict`` performs the generic loading. Afterwards exactly
    one of two attributes must be populated on the result:

    * ``FULL_SHARD_ID_LIST`` - hex strings, converted here to ints.
    * ``CHAIN_MASK_LIST`` (legacy) - chain masks, expanded here into
      ``FULL_SHARD_ID_LIST`` and then deleted from the config.

    Args:
        d: Serialized config, passed unchanged to the parent ``from_dict``.
        chains: Chain configs; required only for the legacy mask form, where
            every chain must have ``SHARD_SIZE == 1`` (enforced via ``check``).

    Returns:
        The loaded config with ``FULL_SHARD_ID_LIST`` holding ints.

    Raises:
        ValueError: If both attributes are set, if neither is set, or if a
            legacy mask list is supplied without ``chains``.
    """
    config = super().from_dict(d)

    shard_ids = getattr(config, "FULL_SHARD_ID_LIST", None)
    chain_mask = getattr(config, "CHAIN_MASK_LIST", None)

    # bail if both full shard ID list and chain mask list exist
    if shard_ids and chain_mask:
        raise ValueError(
            "Can only have either FULL_SHARD_ID_LIST or CHAIN_MASK_LIST")

    if shard_ids:
        # parse from hex to int
        config.FULL_SHARD_ID_LIST = [int(h, 16) for h in shard_ids]
        return config

    if not chain_mask:
        raise ValueError(
            "Missing FULL_SHARD_ID_LIST (or CHAIN_MASK_LIST as legacy config)"
        )

    if chains is None:
        raise ValueError(
            "Can't handle legacy CHAIN_MASK_LIST without chain configs"
        )

    # a simple way to be backward compatible with hard-coded shard ID
    # e.g. chain mask 4 => 0x00000001, 0x00040001
    # note that this only works if every chain has 1 shard only
    check(all(chain.SHARD_SIZE == 1 for chain in chains))

    # Collect the chains matched by *every* mask. Rebinding the list inside
    # the loop would silently drop all masks except the last one. Order is
    # first-seen and duplicates (overlapping masks) are skipped.
    full_shard_ids = []
    seen = set()
    for m in chain_mask:
        bit_mask = (1 << (int_left_most_bit(m) - 1)) - 1
        wanted = m & bit_mask
        for chain_id in range(len(chains)):
            if chain_id & bit_mask != wanted:
                continue
            # Same value as int("0x{:04x}0001".format(chain_id), 16):
            # chain id in the upper bits, fixed 0x0001 in the low 16 bits.
            full_shard_id = (chain_id << 16) | 1
            if full_shard_id not in seen:
                seen.add(full_shard_id)
                full_shard_ids.append(full_shard_id)

    config.FULL_SHARD_ID_LIST = full_shard_ids
    delattr(config, "CHAIN_MASK_LIST")
    return config
def repr_shard(full_shard_id: Optional[int]):
    """Return a human-readable label for a full shard ID.

    ``None`` is rendered as ``"ROOT"``. Otherwise the bits above the low 16
    are reported as the chain, and the low 16 bits minus
    ``1 << (int_left_most_bit(low16) - 1)`` are reported as the shard.
    """
    if full_shard_id is not None:
        chain_id = full_shard_id >> 16
        low16 = full_shard_id & 0xFFFF
        # NOTE(review): presumably strips the leading size-marker bit from the
        # low 16 bits - confirm against int_left_most_bit's definition.
        marker = 1 << (int_left_most_bit(low16) - 1)
        return "CHAIN %d SHARD %d" % (chain_id, low16 - marker)
    return "ROOT"
def create(shard_size, reshard_vote=False):
    """Build a ``ShardInfo`` from a shard size and a reshard-vote flag.

    ``shard_size`` must satisfy ``is_p2`` (checked with ``assert``). The value
    handed to ``ShardInfo`` is ``int_left_most_bit(shard_size) - 1`` with bit
    31 additionally set when ``reshard_vote`` is truthy.
    """
    assert is_p2(shard_size)
    vote_bit = (1 if reshard_vote else 0) << 31
    size_field = int_left_most_bit(shard_size) - 1
    return ShardInfo(size_field + vote_bit)
def iterate(self, shard_size):
    """Yield the IDs below ``shard_size`` whose low bits match this value.

    The low ``int_left_most_bit(self.value) - 1`` bits of ``self.value`` are
    kept fixed; the remaining higher bits count up from zero.
    """
    size_bits = int_left_most_bit(shard_size)
    fixed_bits = int_left_most_bit(self.value) - 1
    low_mask = (1 << fixed_bits) - 1
    prefix_count = 1 << (size_bits - fixed_bits - 1)
    for prefix in range(prefix_count):
        # self.value is re-read on each step, as in the original loop.
        yield (prefix << fixed_bits) + (low_mask & self.value)
def contain_shard_id(self, shard_id):
    """Return whether ``shard_id`` agrees with ``self.value`` on its low bits.

    Only the low ``int_left_most_bit(self.value) - 1`` bits are compared.
    """
    compared_bits = int_left_most_bit(self.value) - 1
    low_mask = (1 << compared_bits) - 1
    candidate = low_mask & shard_id
    expected = self.value & low_mask
    return candidate == expected
def get_shard_size(self):
    """Return ``1 << (int_left_most_bit(self.value) - 1)``.

    NOTE(review): presumably the power of two at the highest set bit of
    ``self.value`` - confirm against int_left_most_bit's definition.
    """
    exponent = int_left_most_bit(self.value) - 1
    return 1 << exponent
def contain_full_shard_id(self, full_shard_id: int):
    """Return whether the chain part of ``full_shard_id`` matches this value.

    The chain ID is taken as ``full_shard_id >> 16``. It is compared with
    ``self.value`` on the low ``int_left_most_bit(self.value) - 1`` bits only.
    """
    compared_bits = int_left_most_bit(self.value) - 1
    low_mask = (1 << compared_bits) - 1
    chain_part = low_mask & (full_shard_id >> 16)
    expected = self.value & low_mask
    return chain_part == expected
def get_shard_size(self):
    """Return the shard size derived from the low 16 bits of ``self.value``.

    Computed as ``1 << (int_left_most_bit(low16) - 1)``.
    NOTE(review): presumably the power of two at the highest set bit of the
    low 16 bits - confirm against int_left_most_bit's definition.
    """
    low16 = self.value & 0xFFFF
    exponent = int_left_most_bit(low16) - 1
    return 1 << exponent