def expand_secret(self, serialized_unhashed_secret):
    """Expand a serialized secret into a scalar and its public point.

    The secret is hashed with the curve's expansion hash function
    (``self.expand_hashfnc``). The first ``element_octet_cnt`` octets of the
    digest are read as a little-endian integer. The result is masked with
    ``expand_bitwise_and`` and then ``expand_bitwise_or``, and multiplied
    with the generator ``G``.

    :param serialized_unhashed_secret: raw secret bytes, exactly
        ``element_octet_cnt`` octets long.
    :returns: tuple ``(scalar, self.G.scalar_mul(scalar))``.
    :raises InvalidInputException: if the secret has the wrong length or the
        hash output length does not match ``expand_hashlen``.
    """
    assert (isinstance(serialized_unhashed_secret, bytes))
    if len(serialized_unhashed_secret) != self.element_octet_cnt:
        raise InvalidInputException(
            "Do not know how to decode %s secret. Expected %d octets, but got %d." % (str(self), self.element_octet_cnt, len(serialized_unhashed_secret)))

    hashfnc = hashlib.new(self.expand_hashfnc)
    hashfnc.update(serialized_unhashed_secret)
    has_required_hashlen = "expand_hashlen" in self._domain_parameters
    if has_required_hashlen and (hashfnc.digest_size == 0):
        # Extendable-output functions (e.g. SHAKE) report digest_size == 0 in
        # hashlib; their output length has to be requested explicitly.
        hashed_secret = hashfnc.digest(self.expand_hashlen)
    else:
        hashed_secret = hashfnc.digest()

    if has_required_hashlen and (len(hashed_secret) != self._domain_parameters["expand_hashlen"]):
        # Report the algorithm name, not the repr of the hash object.
        raise InvalidInputException(
            "Expansion of secret requires %d byte hash function output, but %s provided %d bytes." % (self._domain_parameters["expand_hashlen"], hashfnc.name, len(hashed_secret)))

    scalar = int.from_bytes(hashed_secret[:self.element_octet_cnt], byteorder="little")
    scalar &= self.expand_bitwise_and
    scalar |= self.expand_bitwise_or
    return (scalar, self.G.scalar_mul(scalar))
def lookup(self, oid=None, name=None, on_error="none"):
    """Return the raw curve database entry for a curve OID or a curve name.

    Exactly one of ``oid`` and ``name`` must be supplied. A name is first
    mapped to its OID through ``self._OID_BY_NAME``.

    :param on_error: ``"none"`` returns None for unknown curves;
        ``"raise"`` raises CurveNotFoundException instead.
    :raises InvalidInputException: if neither or both of oid/name are given.
    """
    assert (on_error in ["none", "raise"])
    oid_given = oid is not None
    name_given = name is not None

    # Exactly one lookup key must be present.
    if not (oid_given or name_given):
        raise InvalidInputException(
            "Lookup from curve database needs either OID or name of curve to look up."
        )
    if oid_given and name_given:
        raise InvalidInputException(
            "Lookup from curve database needs either OID or name of curve to look up, not both. Given: OID %s and name %s." % (oid, name))

    db_key = oid if oid_given else self._OID_BY_NAME.get(name)
    curve_data = self._DB_DATA.get(db_key)
    if (curve_data is None) and (on_error == "raise"):
        if oid_given:
            raise CurveNotFoundException(
                "No such curve with OID %s in database." % (oid))
        raise CurveNotFoundException(
            "No such curve with name %s in database." % (name))
    return curve_data
def _verify_pkcs11_uri(self): uri = self._parameters["pkcs11uri"] if uri.startswith("pkcs11:"): # Literal PKCS#11 URI, leave as-is. pass elif uri.startswith("label="): # Replace with encoded version label = uri[6:] self._parameters["pkcs11uri"] = "pkcs11:object=%s;type=private" % ( urllib.parse.quote(label)) elif uri.startswith("id="): # Replace with encoded version key_id_str = uri[3:] try: key_id = baseint(key_id_str) key_bytes = key_id.to_bytes(length=16, byteorder="big").lstrip(bytes(1)) except (ValueError, OverflowError) as e: raise InvalidInputException( "Key ID '%s' is not a valid hex value or is too large: %s" % (key_id_str, e.__class__.__name__)) key_id_quoted = "".join("%%%02x" % (c) for c in key_bytes) self._parameters["pkcs11uri"] = "pkcs11:id=%s;type=private" % ( key_id_quoted) else: raise InvalidInputException( "For hardware keys, you need to either give a RFC7512-compliant pkcs11-scheme URI (starts with 'pkcs11:'), a key label in the form 'label=foobar' or a key id in the hex form like 'id=0xabc123' or in decimal form like 'id=11256099'. The supplied value '%s' is neither." % (uri))
def decode_point(self, serialized_point):
    """Decode a compressed twisted-Edwards point (y plus sign bit of x).

    The encoding is ``element_octet_cnt`` octets: y in little-endian order,
    with the most significant bit of the last octet holding the least
    significant bit of x. x is recovered from the curve equation
    x^2 = (1 - y^2) / (a - d * y^2) mod p.

    :raises InvalidInputException: on wrong length, y >= p, no square root,
        or an encoding with x == 0 but the x sign bit set.
    """
    assert (isinstance(serialized_point, bytes))
    if len(serialized_point) != self.element_octet_cnt:
        raise InvalidInputException(
            "Do not know how to decode %s point. Expected %d octets, but got %d." % (str(self), self.element_octet_cnt, len(serialized_point)))

    serialized_point = bytearray(serialized_point)
    # Split off the x sign bit before interpreting the rest as y.
    x_lsb = (serialized_point[-1] >> 7) & 1
    serialized_point[-1] &= 0x7f
    y = int.from_bytes(serialized_point, byteorder="little")
    if y >= self.p:
        raise InvalidInputException(
            "y coordinate of point must be smaller than p.")

    # x^2 = (1 - y^2) / (a - dy^2)
    numerator = (1 - y * y) % self.p
    # Reduce first so modinv always sees a value in [0, p).
    denominator = (self.a - self.d * y * y) % self.p
    x2 = numerator * NumberTheory.modinv(denominator, self.p)
    (x_pos, x_neg) = NumberTheory.sqrt_mod_p(x2, self.p)
    if (x_pos == 0) and (x_lsb == 1):
        # For x == 0 the "odd" root is p itself, which is not a valid
        # coordinate; such an encoding must be rejected.
        raise InvalidInputException(
            "x coordinate of point is zero, but its encoded least significant bit is set.")
    x = x_pos if (x_lsb == 0) else x_neg
    return self.point(x, y)
def check_single(self, arg, hint=None):
    """Ensure ``arg`` is a member of ``self._allowed_arguments``.

    :param hint: optional context (e.g. the command name) added to the error.
    :raises InvalidInputException: if the argument is not allowed.
    """
    if arg in self._allowed_arguments:
        return
    if hint is not None:
        raise InvalidInputException(
            "%s is not a valid argument for %s." % (arg, hint))
    raise InvalidInputException("%s is not a valid argument." % (arg))
def __init__(self, cmdname, args):
    """Read a PEM certificate chain, validate it and forge every certificate.

    The first certificate must be self-signed. Each following certificate
    must be signed by its predecessor. Certificates are then passed to
    ``self._forge_cert`` in chain order together with their issuer.

    :raises InvalidInputException: if the file holds no certificates, the
        first one is not self-signed, or the chain is broken.
    """
    BaseAction.__init__(self, cmdname, args)
    certs = X509Certificate.read_pemfile(self._args.crt_filename)
    if len(certs) == 0:
        raise InvalidInputException("File %s does not contain any certificates." % (self._args.crt_filename))
    if not certs[0].is_selfsigned:
        raise InvalidInputException("First certificate in chain (%s) is not self-signed." % (certs[0]))

    # (1-based issuer position, (issuer, subject)) for each adjacent pair.
    chain_links = list(enumerate(zip(certs, certs[1:]), 1))
    for (cert_id, (issuer, subject)) in chain_links:
        if not subject.signed_by(issuer):
            raise InvalidInputException("Certificate %d in file (%s) is not issuer for certificate %d (%s)." % (cert_id, issuer, cert_id + 1, subject))

    self._log.debug("Chain of %d certificates to forge.", len(certs))
    # The self-signed root is passed as both issuer and subject.
    self._forge_cert(0, certs[0], 0, certs[0])
    for (cert_subject_id, (issuer, subject)) in chain_links:
        self._forge_cert(cert_subject_id - 1, issuer, cert_subject_id, subject)
def _post_decode_hook(self):
    """Populate curve, private scalar and public point from ``self.asn1``.

    Requires that the decoded ECC private key carries both the curve OID
    (``parameters``) and the encoded public key (``publicKey``).

    :raises InvalidInputException: if either of them is missing.
    """
    if self.asn1["parameters"] is None:
        raise InvalidInputException(
            "ECC private key does not contain curve OID. Cannot proceed.")
    if self.asn1["publicKey"] is None:
        raise InvalidInputException(
            "ECC private key does not contain public key. Cannot proceed.")

    self._curve = CurveDB().instantiate(oid=OID.from_asn1(self.asn1["parameters"]))
    # The private scalar is stored as a big-endian octet string.
    self._d = int.from_bytes(self.asn1["privateKey"], byteorder="big")
    encoded_pubkey = ASN1Tools.bitstring2bytes(self.asn1["publicKey"])
    (self._x, self._y) = ECCTools.decode_enc_pubkey(encoded_pubkey)
def sqrt_mod_p(cls, x, p):
    """Calculate the square roots of x modulo an odd prime p.

    Uses closed-form exponentiation for p = 3 (mod 4), Atkin's method for
    p = 5 (mod 8), and the Tonelli-Shanks algorithm for p = 1 (mod 8).

    :returns: tuple ``(even_root, odd_root)``. For x = 0 mod p this is
        ``(0, p)``.
    :raises InvalidInputException: if x is not a quadratic residue mod p.
    :raises NotImplementedError: for even p.
    """
    x %= p
    if (p % 4) == 3:
        sqrt = pow(x, (p + 1) // 4, p)
    elif (p % 8) == 5:
        # Two possibilities, depending on whether x^((p-1)/4) is 1 or -1.
        sqrt_qr = pow(x, (p + 3) // 8, p)
        if pow(sqrt_qr, 2, p) == x:
            sqrt = sqrt_qr
        else:
            # 2 is a non-residue for p = 5 (mod 8), so 2^((p-1)/4) is a
            # square root of -1 that fixes up the candidate.
            sqrt = (sqrt_qr * pow(2, (p - 1) // 4, p)) % p
    elif (p % 8) == 1:
        if x == 0:
            sqrt = 0
        else:
            # Tonelli-Shanks: write p - 1 = q * 2^s with q odd.
            (q, s) = (p - 1, 0)
            while (q % 2) == 0:
                q //= 2
                s += 1
            # Any quadratic non-residue z (found via Euler's criterion).
            z = next((candidate for candidate in range(2, p) if pow(candidate, (p - 1) // 2, p) == p - 1), None)
            if z is None:
                raise InvalidInputException(
                    "Cannot find a quadratic non-residue modulo p, p is likely not prime.")
            (m, c, t, sqrt) = (s, pow(z, q, p), pow(x, q, p), pow(x, (q + 1) // 2, p))
            while t != 1:
                # Least i (0 < i < m) with t^(2^i) == 1; i == m means x is a
                # non-residue. m strictly decreases, so the loop terminates.
                (i, t_pow) = (0, t)
                while t_pow != 1:
                    t_pow = (t_pow * t_pow) % p
                    i += 1
                    if i == m:
                        raise InvalidInputException(
                            "Given input value has no quadratic residue modulo p.")
                b = pow(c, 1 << (m - i - 1), p)
                (m, c, t, sqrt) = (i, (b * b) % p, (t * b * b) % p, (sqrt * b) % p)
    else:
        raise NotImplementedError(
            "Cannot compute square roots modulo an even p (p %% 8 == %d)" % (p % 8))

    if ((sqrt * sqrt) % p) != (x % p):
        raise InvalidInputException(
            "Given input value has no quadratic residue modulo p.")
    if (sqrt & 1) == 0:
        return (sqrt, p - sqrt)
    else:
        return (p - sqrt, sqrt)
def check(self, args, hint=None):
    """Validate a collection of argument names.

    Every argument must be in ``self._allowed_arguments`` and every name in
    ``self._required_arguments`` must be present. All problems are collected
    into a single error message.

    :param hint: optional context (e.g. the command name) for the message.
    :raises InvalidInputException: on unknown or missing arguments.
    """
    supplied = set(args)
    unknown = supplied - self._allowed_arguments
    missing = self._required_arguments - supplied

    problems = []
    if len(unknown) == 1:
        (only_unknown, ) = unknown
        problems.append("unknown argument: %s" % (only_unknown))
    elif len(unknown) > 1:
        problems.append("%d unknown arguments: %s" % (len(unknown), ", ".join(sorted(unknown))))
    if len(missing) > 0:
        problems.append("required argument(s) missing: %s" % (", ".join(sorted(missing))))

    if len(problems) == 0:
        return

    if len(problems) == 1:
        msg = "There was an error with the arguments"
    else:
        msg = "There were %d error(s) with the arguments" % (len(problems))
    if hint is not None:
        msg += " supplied to %s" % (hint)
    msg += ": " + " / ".join(problems)
    if len(self._required_arguments) > 0:
        msg += " -- required are: %s" % (", ".join(sorted(self._required_arguments)))
    if len(self._optional_arguments) > 0:
        msg += " -- allowed are: %s" % (", ".join(sorted(self._optional_arguments)))
    raise InvalidInputException(msg)
def modinv(cls, a, m):
    """Return the inverse of a modulo m, reduced into [0, m).

    Relies on ``cls.egcd(a, m)`` returning ``(gcd, x, y)`` with x being the
    Bezout coefficient of a.

    :raises InvalidInputException: if gcd(a, m) != 1.
    """
    (divisor, bezout_a, _) = cls.egcd(a, m)
    if divisor == 1:
        return bezout_a % m
    raise InvalidInputException(
        "Modular inverse of %d mod %d does not exist" % (a, m))
def _parse(self, filename): with open(filename) as f: for (lineno, line) in enumerate(f, 1): line = line.rstrip("\r\n") try: self._parseline(line) except InvalidCAIndexFileEntry as e: raise InvalidInputException("Could not parse %s:%d (\"%s\") %s: %s" % (filename, lineno, line, e.__class__.__name__, str(e)))
def __init__(self, rdn_list):
    """Create an RDN from a non-empty sequence of ``(OID, DER bytes)`` items.

    Each item is decoded with ``self._decode_item`` and the results are
    stored as a tuple.

    :raises InvalidInputException: if ``rdn_list`` is empty.
    """
    assert (isinstance(rdn_list, (tuple, list)))
    assert (all(isinstance(entry[0], OID) for entry in rdn_list))
    assert (all(isinstance(entry[1], bytes) for entry in rdn_list))
    if not rdn_list:
        raise InvalidInputException("Empty RDN is not permitted.")
    decoded_items = [self._decode_item(oid, derdata) for (oid, derdata) in rdn_list]
    self._rdn_list = tuple(decoded_items)
def from_str(cls, oid_string):
    """Build an OID from its dotted-decimal string form, e.g. ``"1.2.840"``.

    :raises InvalidInputException: if any component is not an integer.
    """
    assert (isinstance(oid_string, str))
    try:
        components = list(map(int, oid_string.split(".")))
    except ValueError:
        raise InvalidInputException(
            "Cannot parse \"%s\" as a string OID." % (oid_string))
    return cls(components)
def pack(self, data):
    """Serialize ``data`` into exactly ``self._length`` bytes.

    Without a pad byte the input length must match exactly. With a pad
    byte, shorter input is right-padded with ``self._padbyte``.

    :raises InvalidInputException: if the input length is not acceptable.
    """
    length = self._length
    if self._padbyte is not None:
        if len(data) > length:
            raise InvalidInputException(
                "For packing of array of length %d with padding, at most %d bytes must be provided. Got %d bytes." % (length, length, len(data)))
        return data + bytes([self._padbyte]) * (length - len(data))

    # No padding configured: size must match exactly.
    if len(data) != length:
        raise InvalidInputException(
            "For packing of array of length %d without padding, %d bytes must be provided. Got %d bytes." % (length, length, len(data)))
    return data
def pack(self, value=None):
    """Serialize an integer into ``self._length_bytes`` big-endian octets.

    If no value is given, ``self._fixed_value`` is used when configured. With
    an enum class and strict mode, the value must be an instance of that
    enum. The value must lie within ``[self._minval, self._maxval]``.

    :raises InvalidInputTypeException: if the value is not an int.
    :raises InvalidInputException: on enum type mismatch or out-of-range value.
    """
    if (value is None) and (self._fixed_value is not None):
        value = self._fixed_value
    if not isinstance(value, int):
        raise InvalidInputTypeException(
            "%s requires int to be supplied for packing, got %s: %s" % (self.typename, type(value).__name__, str(value)))

    if self._enum_class is not None:
        if self._strict_enum and not isinstance(value, self._enum_class):
            raise InvalidInputException(
                "%s packing input must be of type %s." % (str(self), self._enum_class))
        value = int(value)

    if not (self._minval <= value <= self._maxval):
        raise InvalidInputException(
            "%s must be between %d and %d (given value was %d)." % (str(self), self._minval, self._maxval, value))
    return int.to_bytes(value, self._length_bytes, "big")
def create(cls, *key_values):
    """Build an RDN from alternating key/value arguments.

    Each key is mapped to its OID via ``OIDDB.RDNTypes.inverse``. Each value
    is wrapped in a UTF8String and DER-encoded. ``cls`` is then called with
    the resulting list of ``(oid, derdata)`` pairs.

    :raises InvalidInputException: on an odd number of arguments or an
        unknown key.
    """
    if (len(key_values) % 2) != 0:
        raise InvalidInputException(
            "Length of key/values must be evenly divisible by two.")

    def der_utf8(text):
        return pyasn1.codec.der.encoder.encode(UTF8String(text))

    rdn_list = []
    keys = key_values[0::2]
    values = key_values[1::2]
    for (key, value) in zip(keys, values):
        assert (isinstance(key, str))
        assert (isinstance(value, (str, bytes)))
        try:
            oid = OIDDB.RDNTypes.inverse(key)
        except KeyError:
            raise InvalidInputException(
                "Cannot create RDN with key '%s', cannot look up OID for it." % (key))
        rdn_list.append((oid, der_utf8(value)))
    return cls(rdn_list)
def isqrt(cls, value):
    """Return floor(sqrt(value)) for a non-negative integer.

    Iterative form of the bitwise recurrence
    isqrt(v) = 2 * isqrt(v >> 2) (+ 1 if that still squares to <= v).
    A recursive implementation descends value.bit_length() / 2 levels. That
    exceeds Python's default recursion limit (1000) for values above about
    2000 bits, such as 4096-bit RSA moduli.

    :raises InvalidInputException: if value is negative.
    """
    if value < 0:
        raise InvalidInputException("Cannot return isqrt(%d)" % (value))

    # Record the successively quartered values the recurrence visits.
    pending = []
    while value >= 2:
        pending.append(value)
        value >>= 2

    # Base case: isqrt(0) == 0 and isqrt(1) == 1.
    root = value
    for current in reversed(pending):
        candidate = (root << 1) + 1
        root = (candidate - 1) if (candidate * candidate) > current else candidate
    return root
def __init__(self, cmdname, args):
    """Read a window of the input file and run the selected hash searches.

    Hash functions are the defaults, all supported ones ("all"), or the
    explicitly named ones. Variable-length hash functions are run once per
    configured output length in bits.

    :raises InvalidInputException: on hash lengths that are not a multiple of
        8 or below 8 bits, or on a non-hexadecimal search pattern.
    """
    BaseAction.__init__(self, cmdname, args)

    # Choose the set of hash functions to run.
    if len(args.hash_alg) == 0:
        hash_fncs = self.get_default_hash_fncs()
    elif "all" in args.hash_alg:
        hash_fncs = self.get_supported_hash_fncs()
    else:
        hash_fncs = set(args.hash_alg)

    # Only the requested window of the file is analyzed.
    with open(args.filename, "rb") as infile:
        infile.seek(args.seek_offset)
        self._data = infile.read(args.analysis_length)

    hash_lengths = self._args.variable_hash_length
    if len(hash_lengths) == 0:
        hash_lengths = self.get_default_variable_hash_lengths_bits()
    if not all((value % 8) == 0 for value in hash_lengths):
        raise InvalidInputException("Not all hash lengths are evenly divisible by 8: %s" % (", ".join("%d" % (value) for value in hash_lengths)))
    if not all(value >= 8 for value in hash_lengths):
        raise InvalidInputException("Hash length must be at least 8 bit, not all are: %s" % (", ".join("%d" % (value) for value in hash_lengths)))

    search = self._args.search
    if (search is not None) and not set(search).issubset(set("abcdefABCDEF0123456789")):
        raise InvalidInputException("Search pattern may only contain hexadecimal characters, but '%s' was given." % (search))
    self._search = (search or "").lower()

    start_time = time.time()
    supported_by_name = {fnc.name: fnc for fnc in self.get_all_supported_hash_fncs()}
    for fnc_name in sorted(hash_fncs):
        if supported_by_name[fnc_name].variable_output_length:
            for bits in hash_lengths:
                self._run_hash_function(fnc_name, output_length_bits = bits)
        else:
            self._run_hash_function(fnc_name)
    self._log.debug("Hash search took %.1f secs.", time.time() - start_time)
def unpad_pkcs1(cls, data):
    """Strip PKCS#1 v1.5 block type 1 padding (``01 FF .. FF 00 payload``).

    Any number of 0xff octets, including none, is accepted before the 0x00
    separator.

    :returns: the payload following the first 0x00 separator.
    :raises InvalidInputException: if the data is empty, does not start with
        0x01, contains a padding octet other than 0xff/0x00, or has no
        separator.
    """
    if (len(data) == 0) or (data[0] != 1):
        raise InvalidInputException("PKCS#1 padding must start with 0x01")
    for offset in range(1, len(data)):
        octet = data[offset]
        if octet == 0x00:
            # Separator found, payload follows.
            return data[offset + 1:]
        if octet != 0xff:
            raise InvalidInputException(
                "PKCS#1 padding must be either 0xff or 0x00 at offset %d, was 0x%02x" % (offset, octet))
    raise InvalidInputException(
        "PKCS#1 padding does not seem to contain data.")
def _post_decode_hook(self):
    """Populate curve, prehash flag and raw private key from ``self.asn1``.

    The ``privateKey`` field is expected to wrap an inner octet string of
    the form ``04 <len> <key>``, with ``len == curve.element_octet_cnt``
    (presumably the RFC 8410 CurvePrivateKey encoding).

    :raises InvalidInputException: if the curve OID is missing or the inner
        private key encoding is malformed.
    """
    if self.asn1["privateKeyAlgorithm"]["algorithm"] is None:
        raise InvalidInputException(
            "EdDSA private key does not contain curve OID. Cannot proceed."
        )
    curve_oid = OID.from_asn1(
        self.asn1["privateKeyAlgorithm"]["algorithm"])
    pk_alg = PublicKeyAlgorithms.lookup("oid", curve_oid)
    self._curve = CurveDB().instantiate(oid=curve_oid)
    self._prehash = pk_alg.value.fixed_params["prehash"]

    private_key = bytes(self.asn1["privateKey"])
    octet_cnt = self.curve.element_octet_cnt
    if len(private_key) < 2:
        # Too short to even contain the 2-octet tag/length header.
        raise InvalidInputException(
            "EdDSA private key length expected to be %d octets, but was %d octets." % (octet_cnt + 2, len(private_key)))
    if (private_key[0] != 0x04) or (private_key[1] != octet_cnt):
        raise InvalidInputException(
            "EdDSA private key does not start with 04 %02x, but with %02x %02x." % (octet_cnt, private_key[0], private_key[1]))
    if len(private_key) != octet_cnt + 2:
        raise InvalidInputException(
            "EdDSA private key length expected to be %d octets, but was %d octets." % (octet_cnt + 2, len(private_key)))
    self._priv = private_key[2:]
def from_rfc2253_str(cls, string):
    """Parse one RDN string (``key=value+key=value``) into an RDN.

    The string is split on "+" into items and each item on "=" into a key
    and a value. The flattened pairs are passed to ``cls.create``.

    :raises InvalidInputException: if an item does not split into exactly
        two members.
    """
    flattened = []
    for item in _RFC2253_StringParser.split(string, control_char="+", reassemble=True):
        keyvalue = _RFC2253_StringParser.split(item, control_char="=", reassemble=True)
        if len(keyvalue) != 2:
            raise InvalidInputException(
                "RDN string item has %d members, expected 2." % (len(keyvalue)))
        flattened.extend(keyvalue)
    return cls.create(*flattened)
def private_to_public(cls, private_key_filename, public_key_filename):
    """Write the public key of a private key file via ``cls._EXECUTABLE``.

    First tries the ``rsa`` subcommand, then ``ec``. Each attempt is run with
    ``-pubout``, and the first one that succeeds wins (presumably
    ``_EXECUTABLE`` is the openssl binary -- confirm).

    :raises InvalidInputException: if neither subcommand succeeds.
    """
    for keytype in ("rsa", "ec"):
        command = [
            cls._EXECUTABLE, keytype, "-in", private_key_filename, "-pubout",
            "-out", public_key_filename
        ]
        if SubprocessExecutor(command, on_failure="pass").run().successful:
            return
    raise InvalidInputException(
        "File %s contained neither RSA nor ECC private key." % (private_key_filename))
def decode_point(self, serialized_point):
    """Decode an uncompressed point ``04 || X || Y`` (big-endian X and Y).

    X and Y are each ``element_octet_cnt`` octets long. Compressed or other
    formats are not supported.

    :returns: EllipticCurvePoint on this curve.
    :raises InvalidInputException: on empty input or wrong length.
    :raises UnsupportedEncodingException: if the first octet is not 0x04.
    """
    if len(serialized_point) == 0:
        raise InvalidInputException("Cannot decode empty serialized point.")
    if serialized_point[0] != 0x04:
        raise UnsupportedEncodingException(
            "Do not know how to decode serialized point in non-explicit point format 0x%x." % (serialized_point[0]))

    coord_len = self.element_octet_cnt
    expected_length = 1 + (2 * coord_len)
    if len(serialized_point) != expected_length:
        raise InvalidInputException(
            "Do not know how to decode explicit serialized point with length %d (expected %d = 1 + 2 * %d bytes)." % (len(serialized_point), expected_length, coord_len))

    x = int.from_bytes(serialized_point[1:1 + coord_len], byteorder="big")
    y = int.from_bytes(serialized_point[1 + coord_len:expected_length], byteorder="big")
    return EllipticCurvePoint(x=x, y=y, curve=self)
def __init__(self, cmdname, args):
    """Generate an RSA keypair with the properties selected in the arguments
    and write it as PEM to the output file.

    Options steer the prime selection:
      * close_q: p and q must have the same bit length;
      * gcd_n_phi_n: q is built as 2 * r * p + 1, so p divides both n and
        phi(n);
      * public_exponent == -1: e is chosen at random;
      * carmichael_totient: lambda(n) is used instead of phi(n).
    Combinations where gcd(totient, e) != 1 are retried unless
    accept_unusable_key is set.

    :raises UnfulfilledPrerequisitesException: if the output file exists
        without --force, or close_q is requested for a modulus length that
        cannot be split evenly.
    :raises InvalidInputException: if q_stepping is smaller than 1.
    """
    BaseAction.__init__(self, cmdname, args)
    if (not self._args.force) and os.path.exists(self._args.outfile):
        raise UnfulfilledPrerequisitesException(
            "File/directory %s already exists. Remove it first or use --force." % (self._args.outfile))

    if not self._args.gcd_n_phi_n:
        # Regular case: split the modulus bit length between p and q.
        self._primetype = "2msb"
        self._p_bitlen = self._args.bitlen // 2
        self._q_bitlen = self._args.bitlen - self._p_bitlen
    else:
        # gcd(n, phi(n)) mode: the drawn value r is later turned into
        # q = 2 * r * p + 1, which has about _q_bitlen + _p_bitlen + 1 bits,
        # so n ends up at about bitlen bits (assumes _select_q draws
        # _q_bitlen-bit candidates -- TODO confirm).
        self._primetype = "3msb"
        self._p_bitlen = self._args.bitlen // 3
        self._q_bitlen = self._args.bitlen - (2 * self._p_bitlen) - 1

    if (self._args.close_q) and (self._p_bitlen != self._q_bitlen):
        raise UnfulfilledPrerequisitesException(
            "Generating a close-q keypair with a %d modulus does't work, because p would have to be %d bit and q %d bit. Choose an even modulus bitlength." % (self._args.bitlen, self._p_bitlen, self._q_bitlen))

    if self._args.q_stepping < 1:
        raise InvalidInputException(
            "q-stepping value must be greater or equal to 1, was %d." % (self._args.q_stepping))

    self._log.debug("Selecting %s primes with p = %d bit and q = %d bit.", self._primetype, self._p_bitlen, self._q_bitlen)
    self._prime_db = PrimeDB(self._args.prime_db, generator_program=self._args.generator)
    p = None
    q = None
    # Draw (p, q, e) until gcd(totient, e) == 1 (or unusable keys are accepted).
    while True:
        if p is None:
            # New p: start a fresh stream of q candidates for it.
            p = self._prime_db.get(bitlen=self._p_bitlen, primetype=self._primetype)
            q_generator = self._select_q(p)
        if q is None:
            q = next(q_generator)

        if self._args.gcd_n_phi_n:
            # q = (2 * r * p) + 1
            # The drawn candidate becomes r; retry with the next candidate
            # when the resulting q is not (probably) prime.
            r = q
            q = 2 * r * p + 1
            if not NumberTheory.is_probable_prime(q):
                q = None
                continue

        # Always make p the smaller factor
        if p > q:
            (p, q) = (q, p)

        n = p * q
        if self._args.public_exponent == -1:
            # Random public exponent requested.
            e = random.randint(2, n - 1)
        else:
            e = self._args.public_exponent

        if self._args.carmichael_totient:
            totient = NumberTheory.lcm(p - 1, q - 1)
        else:
            totient = (p - 1) * (q - 1)
        gcd = NumberTheory.gcd(totient, e)
        if self._args.accept_unusable_key or (gcd == 1):
            break
        else:
            # Pair (phi(n), e) wasn't acceptable.
            # With a random e, looping again keeps p and q and only draws a new e.
            self._log.debug("gcd(totient, e) was %d, retrying.", gcd)
            if self._args.public_exponent != -1:
                # Public exponent e is fixed, need to choose another q.
                if p.bit_length() == q.bit_length():
                    # Can re-use q as next p
                    (p, q) = (q, None)
                    q_generator = self._select_q(p)
                else:
                    # When they differ in length, need to re-choose both values
                    (p, q) = (None, None)

    rsa_keypair = RSAPrivateKey.create(p=p, q=q, e=e, swap_e_d=self._args.switch_e_d, valid_only=not self._args.accept_unusable_key, carmichael_totient=self._args.carmichael_totient)
    rsa_keypair.write_pemfile(self._args.outfile)
    if self._args.verbose >= 1:
        # Print the generated key parameters for inspection.
        diff = q - p
        print("Generated %d bit RSA key:" % (rsa_keypair.n.bit_length()))
        print("p = 0x%x" % (rsa_keypair.p))
        if not self._args.gcd_n_phi_n:
            print("q = 0x%x" % (rsa_keypair.q))
        else:
            print("q = 2 * r * p + 1 = 0x%x" % (rsa_keypair.q))
            print("r = 0x%x" % (r))
        print("phi(n) = 0x%x" % (rsa_keypair.phi_n))
        print("lambda(n) = 0x%x" % (rsa_keypair.lambda_n))
        print("phi(n) / lambda(n) = gcd(p - 1, q - 1) = %d" % (rsa_keypair.phi_n // rsa_keypair.lambda_n))
        gcd_n_phin = NumberTheory.gcd(rsa_keypair.n, rsa_keypair.phi_n)
        if gcd_n_phin == rsa_keypair.p:
            print("gcd(n, phi(n)) = p")
        else:
            print("gcd(n, phi(n)) = 0x%x" % (gcd_n_phin))
        if self._args.close_q:
            print("q - p = %d (%d bit)" % (diff, diff.bit_length()))
        print("n = 0x%x" % (rsa_keypair.n))
        print("d = 0x%x" % (rsa_keypair.d))
        print("e = 0x%x" % (rsa_keypair.e))
def from_asn1(cls, asn1):
    """Instantiate a curve from explicit EC domain parameters.

    ``asn1`` is a decoded SpecifiedECDomain sequence (version 1). Prime
    fields yield a "prime" curve class. Characteristic-two fields with
    trinomial or pentanomial basis yield a "binary" curve class whose
    reduction polynomial is given as a list of exponents.

    :raises InvalidInputException: on unknown version, field or basis type,
        trailing data after inner DER structures, or Gaussian normal basis.
    :raises NotImplementedError: for field or basis types without support.
    """
    version = int(asn1["version"])
    if version != 1:
        raise InvalidInputException(
            "Attempted to decode the excplicit EC domain and saw unknown version %d." % (version))

    field_id = asn1["fieldID"]
    field_type = OID.from_asn1(field_id["fieldType"])
    field_type_id = OIDDB.ECFieldType.get(field_type)
    if field_type_id is None:
        raise InvalidInputException(
            "Encountered explicit EC domain parameters in unknown field with OID %s." % (str(field_type)))

    curve = asn1["curve"]
    domain_parameters = {
        "a": int.from_bytes(bytes(curve["a"]), byteorder="big"),
        "b": int.from_bytes(bytes(curve["b"]), byteorder="big"),
        "n": int(asn1["order"]),
    }
    # The cofactor is optional in the encoding.
    if asn1["cofactor"].hasValue():
        domain_parameters["h"] = int(asn1["cofactor"])
    base_point = bytes(asn1["base"])

    if field_type_id == "prime-field":
        (prime, tail) = pyasn1.codec.der.decoder.decode(
            bytes(field_id["parameters"]),
            asn1Spec=ECFieldParametersPrimeField())
        if len(tail) != 0:
            raise InvalidInputException(
                "Attempted to decode the excplicit EC domain and encountered %d bytes of trailing data of the prime basis Integer." % (len(tail)))
        domain_parameters["p"] = int(prime)
        return cls.get_class_for_curvetype("prime").instantiate(
            domain_parameters, base_point)

    if field_type_id != "characteristic-two-field":
        raise NotImplementedError(
            "Explicit EC domain parameter encoding for field type \"%s\" is not implemented." % (field_type_id))

    (two_field, tail) = pyasn1.codec.der.decoder.decode(
        bytes(field_id["parameters"]),
        asn1Spec=ECFieldParametersCharacteristicTwoField())
    if len(tail) != 0:
        raise InvalidInputException(
            "Attempted to decode the excplicit EC domain and encountered %d bytes of trailing data of the characteristic two field Sequence." % (len(tail)))

    basis_type = OID.from_asn1(two_field["basis"])
    basis_type_id = OIDDB.ECTwoFieldBasistype.get(basis_type)
    if basis_type_id is None:
        raise InvalidInputException(
            "Unknown two-field basis type with OID %s found in public key." % (str(basis_type)))

    # Field width is common to all two-fields.
    field_width = int(two_field["m"])
    domain_parameters.update({
        "m": field_width,
        "basis": basis_type_id,
    })

    if basis_type_id == "gnBasis":
        raise InvalidInputException(
            "Binary field explicit domain parameters with Gaussian polynomial basis is not implemented."
        )
    elif basis_type_id == "tpBasis":
        (params, tail) = ASN1Tools.redecode(
            two_field["parameters"],
            ECFieldParametersCharacteristicTwoFieldTrinomial())
        if len(tail) != 0:
            raise InvalidInputException(
                "Attempted to decode the excplicit EC domain and encountered %d bytes of trailing data of the characteristic two field trinomial basis." % (len(tail)))
        poly = [field_width, int(params), 0]
    elif basis_type_id == "ppBasis":
        (params, tail) = ASN1Tools.redecode(
            two_field["parameters"],
            ECFieldParametersCharacteristicTwoFieldPentanomial())
        if len(tail) != 0:
            raise InvalidInputException(
                "Attempted to decode the excplicit EC domain and encountered %d bytes of trailing data of the characteristic two field pentanomial basis." % (len(tail)))
        poly = [field_width, int(params["k1"]), int(params["k2"]), int(params["k3"]), 0]
    else:
        raise NotImplementedError("Binary field basis", basis_type_id)

    domain_parameters["poly"] = poly
    return cls.get_class_for_curvetype("binary").instantiate(
        domain_parameters, base_point)
def from_asn1(cls, oid_asn1):
    """Build an OID from a pyasn1 ObjectIdentifier via its dotted string.

    :raises InvalidInputException: if ``oid_asn1`` is not an ObjectIdentifier.
    """
    if isinstance(oid_asn1, pyasn1.type.univ.ObjectIdentifier):
        return cls.from_str(str(oid_asn1))
    raise InvalidInputException(
        "Tried to construct an OID from ASN.1 of type %s." % (oid_asn1.__class__.__name__))