def equals(self, other: "SQABase") -> bool:
    """Return whether `other` matches `self` on every comparable attribute.

    The ``id`` and ``_sa_instance_state`` attributes are never compared, and
    neither is any field for which ``is_foreign_key_field`` returns True.
    This equality check is used to decide whether a newly built object is the
    same as an existing one; the new object has not been inserted into the
    database yet, so its foreign key fields are always None.

    Args:
        other: The object to compare against.

    Returns:
        True if ``self.fields_equal(other, name)`` holds for every compared
        attribute; False as soon as one comparison fails.
    """
    skipped = ["id", "_sa_instance_state"]

    def _is_compared(name) -> bool:
        # Foreign keys are excluded: a not-yet-inserted object always has None
        # there, which would otherwise make every new object look different.
        return not (name in skipped or is_foreign_key_field(name))

    return all(
        self.fields_equal(other, name)
        for name in self.attributes
        if _is_compared(name)
    )
def testEncoders(self):
    """Check that every encoder's SQA object exposes the same fields as its source.

    For each entry in ``TEST_CASES`` a fake Python object is built and encoded.
    Attribute names on both sides are compared after stripping a leading
    underscore and after applying the class's entry in
    ``ENCODE_DECODE_FIELD_MAPS`` (if any).
    """
    for class_, fake_func, unbound_encode_func, _ in TEST_CASES:
        original_object = fake_func()

        # We can skip metrics and runners; the encoders will automatically
        # handle the addition of new fields to these classes.
        if isinstance(original_object, (Metric, Runner)):
            continue

        encode_func = unbound_encode_func.__get__(self.encoder)
        sqa_object = encode_func(original_object)

        if isinstance(original_object, AbandonedArm):
            # AbandonedArm is a NamedTuple, so its fields are read via
            # _asdict() rather than the instance __dict__.
            object_keys = original_object._asdict().keys()
        else:
            object_keys = original_object.__dict__.keys()
        object_keys = {remove_prefix(key, "_") for key in object_keys}

        sqa_keys = {
            remove_prefix(key, "_")
            for key in sqa_object.attributes
            if key not in ["id", "_sa_instance_state"]
            and not is_foreign_key_field(key)
        }

        # Account for fields that appear only in the Python object, fields
        # that appear only in the SQA object, and fields that appear in both
        # places but under different names.
        if class_ in ENCODE_DECODE_FIELD_MAPS:
            # Named `field_map` rather than `map` to avoid shadowing the builtin.
            field_map = ENCODE_DECODE_FIELD_MAPS[class_]
            sqa_keys.update(field_map.python_only)
            object_keys.update(field_map.encoded_only)
            for python, encoded in field_map.python_to_encoded.items():
                # Raises KeyError if the mapped encoded name is missing, which
                # surfaces a stale field map as a test error.
                sqa_keys.remove(encoded)
                sqa_keys.add(python)

        self.assertEqual(
            object_keys,
            sqa_keys,
            msg=f"Mismatch between Python and SQA representation in {class_}.",
        )
def update(self, other: "SQABase") -> None:
    """Merge `other` into `self`.

    Every attribute in ``self.attributes`` is visited in order:

    * Fields named in the optional ``immutable_fields`` attribute must already
      be equal on both objects; otherwise ``ImmutabilityError`` is raised.
    * Fields named in the optional ``ignore_during_update_fields`` attribute,
      ``id``, ``_sa_instance_state`` and foreign key fields are left untouched.
    * Every other field is copied via ``self.update_field(other, field)``.

    ``immutable_fields`` and ``ignore_during_update_fields`` may be any
    iterable of field names (list, tuple, set, ...) or None.

    Args:
        other: The object whose values are merged into ``self``.

    Raises:
        ImmutabilityError: If an immutable field differs between ``self`` and
            ``other``.
    """
    # Build sets from whatever iterable the subclass declared; `or ()` lets a
    # subclass set these attributes to None. The old `list + list` form raised
    # TypeError for tuple-, set- or None-valued declarations.
    ignore_during_update_fields = set(
        getattr(self, "ignore_during_update_fields", None) or ()
    ) | {"id", "_sa_instance_state"}
    immutable_fields = set(getattr(self, "immutable_fields", None) or ())

    for field in self.attributes:
        if field in immutable_fields:
            if self.fields_equal(other, field):
                continue
            raise ImmutabilityError(
                f"Cannot change `{field}` of {self.__class__.__name__}."
            )
        if (
            field in ignore_during_update_fields
            # We don't want to update foreign key fields, e.g. experiment_id.
            # The new object will always have a value of None for this field,
            # but we don't want to overwrite the value on the existing object.
            or is_foreign_key_field(field)
        ):
            continue
        self.update_field(other, field)