def generate_filters(table: t.Any, info: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any:
    """
    Translate the ``filter`` argument of a GraphQL query into SQLAlchemy
    filter expressions for ``table``.

    Each entry of ``kwargs["filter"]`` maps a field name to a dict with
    ``"value"`` and ``"operator"`` keys. Field names are converted with
    ``underscore`` before lookup. Column fields are compared directly.
    Relationship fields first load the related row whose ``id`` equals the
    supplied value, then compare against that object.

    :param table: The SQLAlchemy table being queried
    :param info: GraphQL info; ``info.context`` is used as the SQLAlchemy
        session for relationship lookups
    :param kwargs: GraphQL arguments; only ``filter`` is read
    :return: A list of SQLAlchemy filter expressions (empty when no filter
        is given)
    :raises KeyError: if a filter names neither a column nor a relationship
    """
    gql_filters = kwargs.get("filter")
    if gql_filters is None:
        return []

    mapper = get_mapper(table)
    expressions = []
    for raw_name, spec in gql_filters.items():
        value = spec["value"]
        name = underscore(raw_name)

        if name in table.c:
            comparator_key = table.c[name].type
        elif name in mapper.relationships:
            relationship = mapper.relationships[name]
            related_class = get_mapper(relationship.target).class_
            # Relationship filters arrive as ids; resolve them to the actual row.
            value = info.context.query(related_class).filter_by(id=value).one()
            comparator_key = relationship
        else:
            raise KeyError(name)

        compare = get_filter_comparator(comparator_key)
        expressions.append(
            compare(value, spec["operator"], getattr(mapper.class_, name))
        )
    return expressions
def generate_manager(self, table: t.Any) -> t.Optional[MagqlTableManager]:
    """
    Build a ``MagqlTableManager`` for ``table`` from this object's resolver
    factories.

    :param table: The SQLAlchemy table to manage
    :return: The new manager, or ``None`` (after logging a warning) when
        ``get_mapper`` raises ``ValueError`` for the table
    """
    try:
        get_mapper(table)
    except ValueError:
        logger = logging.getLogger(__name__)
        logger.warning(f"No mapper for table {table.name!r}.")
        return None

    resolvers = dict(
        create_resolver=self.create_resolver(table),
        update_resolver=self.update_resolver(table),
        delete_resolver=self.delete_resolver(table),
        single_resolver=self.single_resolver(table),
        many_resolver=self.many_resolver(table),
    )
    return MagqlTableManager(table, **resolvers)
def input_to_instance_values(self, input: t.Dict[str, t.Any], mapper: mapper, session: t.Any) -> t.Dict[str, t.Any]:
    """
    Convert user input into keyword values for creating/updating an instance.

    Keys are converted with ``underscore``. For ``ChoiceType`` columns, a
    value equal to a choice's label (second element) is replaced by that
    choice's stored value (first element). For relationships, a truthy value
    is treated as an id (or list of ids) and replaced by the matching related
    object(s). Every other value is passed through unchanged.

    :param input: The user's input dictionary of fields to set
    :param mapper: The mapper for the table being created/modified
    :param session: The SQLAlchemy session used to load related objects
    :return: A dict of attribute names to values for the instance
    """
    values: t.Dict[str, t.Any] = {}
    for raw_key, value in input.items():
        key = underscore(raw_key)

        if key in mapper.c:
            column_type = mapper.c[key].type
            if isinstance(column_type, ChoiceType):
                # Map a submitted label back to the first matching stored value.
                value = next(
                    (choice[0] for choice in column_type.choices if value == choice[1]),
                    value,
                )

        if key in mapper.relationships:
            related_class = get_mapper(mapper.relationships[key].target).class_
            related_query = session.query(related_class)
            if value:
                if isinstance(value, list):
                    value = related_query.filter(related_class.id.in_(value)).all()
                else:
                    value = related_query.filter(related_class.id == value).one()

        values[key] = value
    return values
def update(self, table=None, columns=None, messages=None):
    """
    Bulk-update rows of a named table from ``messages`` and commit.

    :param table: Name of a table in ``self.meta.tables``
    :param columns: Forwarded to ``self._generate_mappings``
    :param messages: Forwarded to ``self._generate_mappings``
    """
    table_obj = self.meta.tables[table]
    rows = self._generate_mappings(table_obj, columns=columns, messages=messages)
    self.session.bulk_update_mappings(get_mapper(table_obj), rows)
    self.session.commit()
def update(self, *, table: str, columns: List[str], messages: List[Dict[str, Any]]) -> None:
    """
    Bulk-update rows of the named table and commit the session.

    :param table: Name of a table in ``self.meta.tables``
    :param columns: Columns forwarded to ``self._generate_mappings``
    :param messages: Messages forwarded to ``self._generate_mappings``
    """
    target = self.meta.tables[table]
    rows = self._generate_mappings(target, columns=columns, messages=messages)
    target_mapper = get_mapper(target)
    self.session.bulk_update_mappings(target_mapper, rows)
    self.session.commit()
def process_products(products, session):
    """
    Upsert pharmaceutical product records into the ``pharmaceuticals`` table.

    Each source product dict is renamed to the table's column names. It is
    fingerprinted with an MD5 hash of its sorted-key JSON encoding; the hash
    is used for change detection only, not security. Rows are matched on
    ``registry_num`` (assumes the registry number identifies one product —
    TODO confirm):

    * no existing row: insert it, with ``created_at``/``updated_at`` set;
    * existing row whose ``hash`` differs: update it, keeping ``created_at``;
    * existing row with the same ``hash``: unchanged, skip it.

    The session is not committed here.

    :param products: Iterable of source product dicts (feed keys such as
        ``'nombreComercial'``, ``'noregistro'``)
    :param session: SQLAlchemy session used for queries and writes
    """
    products_table = metadata.tables['pharmaceuticals']
    # Resolve the mapper once rather than for every inserted record.
    products_mapper = get_mapper(products_table)
    for prod in products:
        timestamp = datetime.now(timezone.utc).replace(microsecond=0).isoformat()
        hashed = hashlib.md5(json.dumps(prod, sort_keys=True).encode('utf-8')).hexdigest()
        transformed_record = {
            "commercial_name": prod['nombreComercial'],
            "active": prod['Activo'],
            "inscription_date": prod['FECHA_INSCRIPCION'],
            "registry_num": prod['noregistro'],
            "formula": prod['formula'],
            "active_ingredient": prod['principio_activo'],
            "pharma_comp": prod['laboratorio'],
            "pharmaceutical_form": prod['NOMBRE_FORMA_FARMACEUTICA'],
            "level_price": prod['precio_nivel'],
            "concentration": prod['concentracion'],
            "min_presentation_price": prod['precio_presentacion_min'],
            "max_presentation_price": prod['precio_presentacion_max'],
            "range_price": prod['precio_rango'],
            "hash": hashed,
            "created_at": timestamp,
            "updated_at": timestamp,
        }
        _logger.info("Record {}".format(transformed_record))

        # The hash covers the whole source record, including 'noregistro'.
        # A row with an equal hash therefore always has the same registry
        # number, so match on the registry number and compare hashes to see
        # whether the stored row is stale.
        existing_query = session.query(products_table).filter(
            products_table.c.registry_num == transformed_record['registry_num'])
        existing = existing_query.first()
        if existing is None:
            session.bulk_insert_mappings(products_mapper, [transformed_record])
        elif existing.hash != hashed:
            _logger.info(existing)
            # Keep the original creation time on updates.
            del transformed_record['created_at']
            existing_query.update(transformed_record)
def __init__(self, table: t.Any):
    """
    Bind the resolver to a SQLAlchemy table.

    Stores ``table`` and the ORM class that ``get_mapper`` maps to it, then
    runs the parent initializer.

    :param table: A SQLAlchemy table
    :raises ValueError: when ``get_mapper`` cannot find a mapper for ``table``
    """
    self.table = table
    table_mapper = get_mapper(table)
    self.table_class = table_mapper.class_
    super().__init__()
def __call__(self, parent: t.Any, info: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Optional[t.Any]:
    """
    Pick the concrete GraphQL type for ``parent``.

    Tables in ``self.magql_name_to_table`` are tried in order. For the first
    one whose mapped class ``parent`` is an instance of, a member of
    ``info.return_type.of_type.types`` with the same magql name is returned.

    :raises Exception: if no matching type is found
    """
    for magql_name, table in self.magql_name_to_table.items():
        if not isinstance(parent, get_mapper(table).class_):
            continue
        candidates = info.return_type.of_type.types
        match = next((gql_type for gql_type in candidates if gql_type.name == magql_name), None)
        if match is not None:
            return match
    raise Exception("Type not found")
def generate_query(self, info: t.Any) -> t.Any:
    """
    Start an unfiltered query over the ORM class mapped to ``self.table``.

    :param info: GraphQL info; ``info.context`` is the SQLAlchemy session
    :return: ``session.query(ModelClass)``
    """
    session = info.context
    model_class = get_mapper(self.table).class_
    return session.query(model_class)
def retrieve_value(self, parent: t.Any, info: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any:
    """
    Load the single row of ``self.table`` whose id equals ``kwargs["id"]``.

    :param parent: Value returned by the parent resolver (unused here)
    :param info: GraphQL info; ``info.context`` is the SQLAlchemy session
    :param kwargs: GraphQL arguments; ``id`` selects the row
    :return: The matching instance (``.one()`` raises if there is not
        exactly one)
    """
    session = info.context
    model_class = get_mapper(self.table).class_
    record_id = kwargs["id"]
    matching = session.query(model_class).filter_by(id=record_id)
    return matching.one()
def retrieve_value(self, parent: None, info: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any:
    """
    Construct a new, empty instance of the mapped class for ``mutate`` to fill.

    The instance is built with no arguments and is not added to a session here.

    :param parent: Always None; mutations are top level
    :param info: GraphQL info dictionary (unused here)
    :param args: Unused; kept for overriding validate/call
    :param kwargs: User inputs (unused here)
    :return: A freshly constructed, unpopulated instance
    """
    model_class = get_mapper(self.table).class_
    # TODO: Replace with dictionary spread operator
    return model_class()
def generate_sorts(table: t.Any, info: t.Any, *args: t.Any, **kwargs: t.Any) -> t.List[t.Any]:
    """
    Translate the GraphQL ``sort`` argument into SQLAlchemy ORDER BY clauses.

    Each entry's first element is a string ``"<field>_<asc|desc>"``. It is
    split on the last underscore into a mapped attribute name and a direction.

    :param table: The SQLAlchemy table being queried
    :param kwargs: GraphQL arguments; only ``sort`` is read
    :return: List of ordering clauses (empty when no sort is given)
    :raises SortNotFoundError: if the direction is neither ``asc`` nor ``desc``
    """
    gql_sorts = kwargs.get("sort")
    if gql_sorts is None:
        return []

    model_class = get_mapper(table).class_
    order_by: t.List[t.Any] = []
    for entry in gql_sorts:
        field_name, direction = entry[0].rsplit("_", 1)
        attribute = getattr(model_class, field_name)
        if direction == "asc":
            order_by.append(attribute.asc())
        elif direction == "desc":
            order_by.append(attribute.desc())
        else:
            raise SortNotFoundError(field_name, direction)
    return order_by
def __init__(
    self,
    table: t.Any,
    magql_name: t.Optional[str] = None,
    create_resolver: t.Optional[CreateResolver] = None,
    update_resolver: t.Optional[UpdateResolver] = None,
    delete_resolver: t.Optional[DeleteResolver] = None,
    single_resolver: t.Optional[SingleResolver] = None,
    many_resolver: t.Optional[ManyResolver] = None,
):
    """
    Manage a single SQLAlchemy table.

    :param table: The table being managed
    :param magql_name: Optional name override; defaults to the camelized
        table name
    :param create_resolver: Optional override for the create resolver
    :param update_resolver: Optional override for the update resolver
    :param delete_resolver: Optional override for the delete resolver
    :param single_resolver: Optional override for the single resolver
    :param many_resolver: Optional override for the many resolver
    :raises ValueError: when ``get_mapper`` finds no mapper for ``table``
    """
    object_name = camelize(table.name) if magql_name is None else magql_name
    super().__init__(object_name)  # magql_object_name

    self.table_class = get_mapper(table).class_
    self.table = table
    self.table_name = table.name

    # Any missing (falsy) override falls back to the default resolver.
    self.create_resolver = create_resolver or CreateResolver(self.table)
    self.update_resolver = update_resolver or UpdateResolver(self.table)
    self.delete_resolver = delete_resolver or DeleteResolver(self.table)
    self.single_resolver = single_resolver or SingleResolver(self.table)
    self.many_resolver = many_resolver or ManyResolver(self.table)

    self.generate_magql_types()
def retrieve_value(self, parent: None, info: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any:
    """
    Load the existing instance that the update mutation will modify.

    This method only fetches the row whose id equals ``kwargs["id"]``; it
    does not change it.

    :param parent: Always None; mutations are top level
    :param info: GraphQL info; ``info.context`` is the SQLAlchemy session
    :param args: Unused; kept for overriding validate/call
    :param kwargs: User inputs; ``id`` selects the row
    :return: The matching instance (``.one()`` raises if there is not
        exactly one)
    """
    db_session = info.context
    model_class = get_mapper(self.table).class_
    target_id = kwargs["id"]
    return db_session.query(model_class).filter_by(id=target_id).one()
def pre_resolve(self, parent: t.Any, info: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Tuple[t.Any, t.Any, t.Any, t.Any]:
    """
    Normalize ``kwargs["input"]`` before the mutation resolves.

    The input is passed through ``self.input_to_instance_values``, which turns
    relationship ids into related objects and enum labels into stored values.

    :param parent: Always None; mutations are top level
    :param info: GraphQL info; ``info.context`` is the SQLAlchemy session
    :param args: Unused; kept for overriding validate/call
    :param kwargs: User inputs; ``input`` is replaced in place
    :return: ``(parent, info, args, kwargs)`` with the converted input
    """
    db_session = info.context
    table_mapper = get_mapper(self.table)
    kwargs["input"] = self.input_to_instance_values(
        kwargs["input"], table_mapper, db_session
    )
    return parent, info, args, kwargs
def ensure_unique_default_per_project(target, value, oldvalue, initiator):
    """Ensures that only one row in table is specified as the default.

    When ``value`` is truthy, looks up the row that currently has
    ``default == true()`` for ``target.project_id``. If that row is not
    ``target`` itself, its ``default`` flag is cleared and the session is
    committed.

    The ``(target, value, oldvalue, initiator)`` signature looks like a
    SQLAlchemy attribute "set" event listener — TODO confirm where it is
    registered.

    :param target: The instance whose attribute is being set
    :param value: The new value; only a truthy value triggers the check
    :param oldvalue: The previous value (unused here)
    :param initiator: The event initiator (unused here)
    """
    session = object_session(target)
    # Without a session there is nothing to query against; do nothing.
    if session is None:
        return
    # NOTE: despite its name this holds a Mapper (tests in this file compare
    # get_mapper(...) against sa.inspect(...)), not a mapped class.
    mapped_cls = get_mapper(target)
    if value:
        # NOTE(review): one_or_none() raises if more than one default already
        # exists for this project.
        previous_default = (session.query(mapped_cls).filter(
            mapped_cls.columns.default == true()).filter(
            mapped_cls.columns.project_id == target.project_id).one_or_none())
        if previous_default:
            # we want exclude updating the current default
            if previous_default.id != target.id:
                previous_default.default = False
                # NOTE(review): committing here commits everything pending in
                # the caller's session as a side effect — confirm intended.
                session.commit()
def __call__(self, parent: t.Any, info: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Optional[t.List[t.Any]]:
    """
    Preview which objects a delete would remove, without deleting anything.

    Finds the table in ``self.tables`` whose mapped class name equals
    ``kwargs["tableName"]`` and loads the row with ``kwargs["id"]``. It then
    marks that row deleted in the session, captures ``session.deleted`` and
    rolls the session back.

    :param parent: Parent value (unused here)
    :param info: GraphQL info; ``info.context`` is the SQLAlchemy session
    :param kwargs: ``tableName`` (mapped class name) and ``id`` of the row
    :return: The objects the session marked as deleted, or ``None`` if no
        mapped table has a class named ``tableName``
    """
    for table in self.tables:
        try:
            class_ = get_mapper(table).class_
        except ValueError:
            # Tables without a mapper cannot be targeted.
            continue
        # TODO: switch over frontend to class name
        if class_.__name__ != kwargs["tableName"]:
            continue

        id_ = kwargs["id"]
        session = info.context
        instance = session.query(class_).filter_by(id=id_).one()
        try:
            session.delete(instance)
            # session.deleted now holds the instance and, presumably, any
            # objects reached through delete cascades.
            return list(session.deleted)
        finally:
            # Always undo the trial delete, even if delete() raised, so it
            # can never leak into the caller's session.
            session.rollback()
    return None
def generate_subqueryloads(
    self,
    field_node: t.Any,
    load_path: t.Optional[t.Any] = None,
    mapper: t.Optional[t.Any] = None,
) -> t.List[t.Any]:
    """
    Build the subqueryload options needed to eager-load a query's selections.

    The top-level query then needs only one round of subqueryloads to load
    everything the GraphQL document will access. Recursively walks the
    selection set; each relationship field extends the load path, and a path
    is emitted once it reaches a node with no further relationship children.

    :param field_node: The document AST node whose selections determine which
        relationships are accessed
    :param load_path: The load path built so far, extended for each child
        relationship
    :param mapper: The mapper that ``field_node``'s selections belong to;
        defaults to the mapper of ``self.table``. The recursion passes the
        related entity's mapper so nested fields are resolved against the
        correct model.
    :return: A list of all subquery loads needed to eagerly load the data
        accessed by the query
    """
    if mapper is None:
        mapper = get_mapper(self.table)

    options: t.List[t.Any] = []
    for selection in field_node.selection_set.selections:
        # Selections without a sub-selection set are scalar fields.
        if selection.selection_set is None:
            continue
        field_name = js_underscore(selection.name.value)
        # Check against the current entity's relationships, not the root
        # table's, so nested relationships are found on the related model.
        if field_name not in mapper.relationships:
            continue
        related_mapper = mapper.relationships[field_name].mapper
        if load_path is None:
            extended_load_path = subqueryload(field_name)
        else:
            extended_load_path = load_path.subqueryload(field_name)
        options.extend(
            self.generate_subqueryloads(
                selection, extended_load_path, mapper=related_mapper
            )
        )

    # A node is a leaf if none of its children are relationships; the path
    # that reached it is then complete.
    if not options:
        return [load_path] if load_path is not None else []
    return options
def test_column(self):
    """A mapped table's Column resolves to the class mapper."""
    column = self.Building.__table__.c.id
    assert get_mapper(column) == sa.inspect(self.Building)
def test_table_alias(self):
    """An alias of a mapped table resolves to the class mapper."""
    aliased_table = sa.orm.aliased(self.Building.__table__)
    resolved = get_mapper(aliased_table)
    assert resolved == sa.inspect(self.Building)
def test_instrumented_attribute(self):
    """An instrumented class attribute resolves to the class mapper."""
    attribute = self.Building.id
    assert get_mapper(attribute) == sa.inspect(self.Building)
def test_class_alias(self):
    """An ORM alias of the class resolves to the class mapper."""
    building_alias = sa.orm.aliased(self.Building)
    assert get_mapper(building_alias) == sa.inspect(self.Building)
def test_mapper(self):
    """Passing the mapper itself returns the same mapper."""
    class_mapper = self.Building.__mapper__
    assert get_mapper(class_mapper) == sa.inspect(self.Building)
def test_column_entity(self):
    """A query's column entity resolves to the class mapper."""
    column_entity = self.session.query(self.Building.id)._entities[0]
    resolved = get_mapper(column_entity)
    assert resolved == sa.inspect(self.Building)
def test_declarative_class(self):
    """A declarative class resolves to its own mapper."""
    resolved = get_mapper(self.Building)
    assert resolved == sa.inspect(self.Building)
def test_column_of_an_alias(self):
    """A column taken from an aliased mapped table resolves to the class mapper."""
    alias_column = sa.orm.aliased(self.Building.__table__).c.id
    assert get_mapper(alias_column) == sa.inspect(self.Building)
def test_declarative_object(self):
    """An instance of a declarative class resolves to the class mapper."""
    instance = self.Building()
    assert get_mapper(instance) == sa.inspect(self.Building)
def test_table_alias(self):
    """An alias of a table that get_mapper cannot map raises ValueError."""
    table_alias = sa.orm.aliased(self.building)
    with raises(ValueError):
        get_mapper(table_alias)
def test_mapper_entity_with_class(self):
    """The entity of a query over the class resolves to the class mapper."""
    query = self.session.query(self.Building)
    first_entity = query._entities[0]
    assert get_mapper(first_entity) == sa.inspect(self.Building)
def test_table(self):
    """A mapped class's table resolves to the class mapper."""
    table = self.Building.__table__
    assert get_mapper(table) == sa.inspect(self.Building)
def test_table(self):
    """A table that get_mapper cannot map raises ValueError."""
    unmapped_table = self.building
    with raises(ValueError):
        get_mapper(unmapped_table)
def test_table(self, Building):
    """For this fixture's Building, get_mapper on its table raises ValueError."""
    table = Building.__table__
    with pytest.raises(ValueError):
        get_mapper(table)
def test_table(self, building):
    """A table that get_mapper cannot map raises ValueError."""
    unmapped_table = building
    with pytest.raises(ValueError):
        get_mapper(unmapped_table)
def test_table_alias(self, building):
    """An alias of a table that get_mapper cannot map raises ValueError."""
    table_alias = sa.orm.aliased(building)
    with pytest.raises(ValueError):
        get_mapper(table_alias)
def test_declarative_class(self, Building):
    """A declarative class resolves to its own mapper."""
    resolved = get_mapper(Building)
    assert resolved == sa.inspect(Building)
def test_mapper_entity_with_mapper(self, session, Building):
    """The entity of a query over a mapper resolves to that mapper."""
    query = session.query(Building.__mapper__)
    first_entity = query._entities[0]
    assert get_mapper(first_entity) == sa.inspect(Building)
def insert(self, *, table: str, messages: List[Dict[str, Any]]) -> None:
    """
    Bulk-insert rows built from ``messages`` into the named table and commit.

    :param table: Name of a table in ``self.meta.tables``
    :param messages: Messages forwarded to ``self._generate_mappings``
    """
    target = self.meta.tables[table]
    rows = self._generate_mappings(target, messages=messages)
    target_mapper = get_mapper(target)
    self.session.bulk_insert_mappings(target_mapper, rows)
    self.session.commit()