# Source code for sqllineage.core.holders

import itertools

from sqllineage.core.graph import get_graph_operator_class
from sqllineage.core.graph_operator import GraphOperator
from sqllineage.core.metadata_provider import MetaDataProvider
from sqllineage.core.models import Column, Path, Schema, SubQuery, Table
from sqllineage.utils.constant import EdgeDirection, EdgeTag, EdgeType, NodeTag

# Model classes treated as table-level datasets; used in isinstance() filters when
# selecting table vertices (statement read/write, table lineage graph, tag lookups).
DATASET_CLASSES = (Path, Table)


class ColumnLineageMixin:
    """
    Mixin that extracts column-level lineage paths from a lineage graph.

    Classes using this mixin must provide ``self.go``, a :class:`GraphOperator`
    holding the lineage graph.
    """

    def get_column_lineage(
        self, exclude_path_ending_in_subquery=True, exclude_subquery_columns=False
    ) -> set[tuple[Column, ...]]:
        """
        Collect every lineage path from a source column to a target column.

        :param exclude_path_ending_in_subquery:  exclude_subquery rename to exclude_path_ending_in_subquery
               exclude column from SubQuery in the ending path, i.e. keep only paths whose
               last column has a :class:`sqllineage.core.models.Table` parent.
        :param exclude_subquery_columns: exclude column from SubQuery in the path.
               Paths left with a single column after this filtering are dropped.

        :return: a set of column tuples (:class:`sqllineage.models.Column`), each tuple being
                 one lineage path ordered from source column to target column.
        """
        self.go: GraphOperator  # For mypy attribute checking
        # filter all the column node in the graph
        column_graph = self.go.get_sub_graph(
            *[v for v in self.go.retrieve_vertices_by_props() if isinstance(v, Column)]
        )
        source_columns = column_graph.retrieve_source_vertices()
        target_columns = column_graph.retrieve_target_vertices()
        # handle column-level self-loop case like table-level: a self-looped column
        # counts both as a source and as a target
        selfloop_columns = column_graph.retrieve_selfloop_vertices()
        for column_group in (source_columns, target_columns):
            # set-backed membership keeps each check O(1) instead of scanning the list
            seen = set(column_group)
            for column in selfloop_columns:
                if column not in seen:
                    seen.add(column)
                    column_group.append(column)
        # if a column lineage path ends at SubQuery, then it should be pruned
        if exclude_path_ending_in_subquery:
            target_columns = {
                node for node in target_columns if isinstance(node.parent, Table)
            }
        columns = set()
        for source, target in itertools.product(source_columns, target_columns):
            for path in self.go.list_lineage_paths(source, target):
                if exclude_subquery_columns:
                    path = [
                        node for node in path if not isinstance(node.parent, SubQuery)
                    ]
                    # a lone column carries no lineage relationship
                    if len(path) <= 1:
                        continue
                columns.add(tuple(path))
        return columns


class SubQueryLineageHolder(ColumnLineageMixin):
    """
    SubQuery/Query Level Lineage Result.

    SubQueryLineageHolder will hold attributes like read, write, cte.

    Each of them is a set of vertices (:class:`sqllineage.core.models.Table`
    and/or :class:`sqllineage.core.models.SubQuery`) tagged accordingly in the graph.

    This is the most atomic representation of lineage result.
    """

    def __init__(self) -> None:
        self.go = get_graph_operator_class()()

    def __or__(self, other):
        # merges other's graph into this holder in place and returns this holder
        self.go.merge(other.go)
        return self

    def _property_getter(self, prop) -> set[SubQuery | Table]:
        """Return every vertex carrying the ``prop`` tag, as a set."""
        vertices: list[SubQuery | Table] = self.go.retrieve_vertices_by_props(
            **{prop: True}
        )
        return set(vertices)

    def _property_setter(self, value, prop) -> None:
        """Register ``value`` as a graph vertex tagged with ``prop``."""
        self.go.add_vertex_if_not_exist(value, **{prop: True})

    @property
    def read(self) -> set[SubQuery | Table]:
        return self._property_getter(NodeTag.READ)

    def add_read(self, value) -> None:
        self._property_setter(value, NodeTag.READ)
        # the same table can be added (in SQL: joined) multiple times with different alias
        if hasattr(value, "alias"):
            self.go.add_edge_if_not_exist(value, value.alias, EdgeType.HAS_ALIAS)

    @property
    def write(self) -> set[SubQuery | Table]:
        # SubQueryLineageHolder.write can return a single SubQuery or Table, or both when __or__ together.
        # This is different from StatementLineageHolder.write, where Table is the only possibility.
        return self._property_getter(NodeTag.WRITE)

    def add_write(self, value) -> None:
        self._property_setter(value, NodeTag.WRITE)

    @property
    def cte(self) -> set[SubQuery]:
        return self._property_getter(NodeTag.CTE)  # type: ignore

    def add_cte(self, value) -> None:
        self._property_setter(value, NodeTag.CTE)

    @property
    def write_columns(self) -> list[Column]:
        """
        return a list of columns that write table contains.
        It's either manually added via `add_write_column` if specified in DML
        or automatic added via `add_column_lineage` after parsing from SELECT

        Columns are ordered by the ``EdgeTag.INDEX`` attribute of their HAS_COLUMN edge
        (a missing index counts as 0; ``sorted`` is stable, so ties keep retrieval order).
        An empty list is returned when there is no target table.
        """
        tgt_cols = []
        if tgt_tbl := self._get_target_table():
            tbl_col_edges = self.go.retrieve_edges_by_vertex(
                tgt_tbl, EdgeDirection.OUT, EdgeType.HAS_COLUMN
            )
            tgt_col_with_idx: list[tuple[Column, int]] = sorted(
                [(e.target, e.attributes.get(EdgeTag.INDEX, 0)) for e in tbl_col_edges],
                key=lambda x: x[1],
            )
            tgt_cols = [x[0] for x in tgt_col_with_idx]
        return tgt_cols

    def add_write_column(self, *tgt_cols: Column) -> None:
        """
        in case of DML with column specified, like:

        .. code-block:: sql

            INSERT INTO tab1 (col1, col2)
            SELECT col3, col4

        this method is called to make sure tab1 has column col1 and col2 instead of col3 and col4

        Each column is re-parented to the write table and linked with its position
        stored as ``EdgeTag.INDEX``. Nothing happens when there is no write table.
        """
        # query the write set once instead of once for the check and again for the pick
        if write_tables := self.write:
            tgt_tbl = next(iter(write_tables))
            for idx, tgt_col in enumerate(tgt_cols):
                tgt_col.parent = tgt_tbl
                self.go.add_edge_if_not_exist(
                    tgt_tbl, tgt_col, EdgeType.HAS_COLUMN, **{EdgeTag.INDEX: idx}
                )

    def add_column_lineage(self, src: Column, tgt: Column) -> None:
        """
        link source column to target.

        Also links each column to its parent via a HAS_COLUMN edge; a source column
        without a parent gets no HAS_COLUMN edge.
        """
        self.go.add_edge_if_not_exist(src, tgt, EdgeType.LINEAGE)
        self.go.add_edge_if_not_exist(tgt.parent, tgt, EdgeType.HAS_COLUMN)
        # a source column may have no single resolved parent; never add None as a vertex
        # (NetworkX >= 2.6 rejects None nodes) -- same guard as in _replace_wildcard
        if src.parent is not None:
            self.go.add_edge_if_not_exist(src.parent, src, EdgeType.HAS_COLUMN)

    def get_table_columns(self, table: Table | SubQuery) -> list[Column]:
        """Return the non-wildcard columns linked to ``table`` via HAS_COLUMN edges."""
        return [
            edge.target
            for edge in self.go.retrieve_edges_by_vertex(
                table, EdgeDirection.OUT, EdgeType.HAS_COLUMN
            )
            if isinstance(edge.target, Column) and edge.target.raw_name != "*"
        ]

    def expand_wildcard(self, metadata_provider: MetaDataProvider) -> None:
        """
        Replace ``*`` columns written to the target table with concrete columns.

        Source columns come from the graph for SubQuery sources, or from
        ``metadata_provider`` for Table sources (skipped when it is falsy).
        A wildcard whose source columns cannot be determined is left untouched.
        """
        if tgt_table := self._get_target_table():
            for column in self.write_columns:
                if column.raw_name == "*":
                    tgt_wildcard = column
                    src_wildcards = self.get_source_columns(tgt_wildcard)
                    # Enable positional mapping only for UNION-of-* into a real table; avoid join/subquery cases
                    wildcard_in_union = (
                        isinstance(tgt_table, Table)
                        and len(self.write_columns) == 1
                        and len(src_wildcards) > 1
                    )
                    for src_wildcard in src_wildcards:
                        if source_table := src_wildcard.parent:
                            src_table_columns = []
                            if isinstance(source_table, SubQuery):
                                # the columns of SubQuery can be inferred from graph
                                src_table_columns = self.get_table_columns(source_table)
                            elif isinstance(source_table, Table) and metadata_provider:
                                # search by metadata service
                                src_table_columns = metadata_provider.get_table_columns(
                                    source_table
                                )
                            if src_table_columns:
                                self._replace_wildcard(
                                    tgt_table,
                                    src_table_columns,
                                    tgt_wildcard,
                                    src_wildcard,
                                    wildcard_in_union=wildcard_in_union,
                                )

    def get_alias_mapping_from_table_group(
        self, table_group: list[Path | Table | SubQuery]
    ) -> dict[str, Path | Table | SubQuery]:
        """
        A table can be referred to as alias, table name, or database_name.table_name, create the mapping here.
        For SubQuery, it's only alias then.

        On key collisions, qualified names win over unqualified names, which win over aliases.
        """
        alias_map = {
            edge.target: edge.source
            for edge in self.go.retrieve_edges_by_label(label=EdgeType.HAS_ALIAS)
            if edge.source in table_group
        }
        unqualified_map = {
            table.raw_name: table for table in table_group if isinstance(table, Table)
        }
        qualified_map = {
            str(table): table for table in table_group if isinstance(table, Table)
        }
        return alias_map | unqualified_map | qualified_map

    def _get_target_table(self) -> SubQuery | Table | None:
        """Return one vertex that is written but not read, or None if there is none."""
        table = None
        if write_only := self.write.difference(self.read):
            table = next(iter(write_only))
        return table

    def get_source_columns(self, node: Column) -> list[Column]:
        """Return the columns with a LINEAGE edge pointing into ``node``."""
        return [
            e.source
            for e in self.go.retrieve_edges_by_vertex(
                node, EdgeDirection.IN, EdgeType.LINEAGE
            )
            if isinstance(e.source, Column)
        ]

    def _replace_wildcard(
        self,
        tgt_table: Table | SubQuery,
        src_table_columns: list[Column],
        tgt_wildcard: Column,
        src_wildcard: Column,
        wildcard_in_union: bool = False,
    ) -> None:
        """
        Link every column in ``src_table_columns`` to a concrete column of ``tgt_table``.

        Target columns are matched by position when ``wildcard_in_union`` is set or the
        column counts agree; otherwise by ``raw_name``, creating the target column when
        missing. Wildcard vertices are then dropped unless they belong to a SubQuery.
        """
        target_columns = self.get_table_columns(tgt_table)
        use_positional = wildcard_in_union or (
            len(target_columns) == len(src_table_columns)
        )
        for idx, src_col in enumerate(src_table_columns):
            if use_positional and idx < len(target_columns):
                target_col = target_columns[idx]
            else:
                # otherwise, if target column with same name exists (union scenario), reuse it; or create a new one
                existing_col = next(
                    (c for c in target_columns if c.raw_name == src_col.raw_name),
                    None,
                )
                if existing_col is None:
                    new_column = Column(src_col.raw_name)
                    new_column.parent = tgt_table
                    self.go.add_edge_if_not_exist(
                        tgt_table, new_column, EdgeType.HAS_COLUMN
                    )
                    target_col = new_column
                    # keep local target_columns in sync to preserve order for the same call
                    target_columns.append(target_col)
                else:
                    target_col = existing_col
            # ensure source column node exists and link lineage
            if src_col.parent is not None:
                self.go.add_edge_if_not_exist(
                    src_col.parent, src_col, EdgeType.HAS_COLUMN
                )
            self.go.add_edge_if_not_exist(src_col, target_col, EdgeType.LINEAGE)
        # preserve SubQuery wildcards in the lineage graph to maintain the wildcard chain in case of partial expansion
        # otherwise, remove wildcard for Table
        if not isinstance(tgt_table, SubQuery):
            self.go.drop_vertices(tgt_wildcard)
        if src_wildcard.parent is not None and not isinstance(
            src_wildcard.parent, SubQuery
        ):
            self.go.drop_vertices(src_wildcard)
class StatementLineageHolder(SubQueryLineageHolder, ColumnLineageMixin):
    """
    Statement Level Lineage Result.

    Based on SubQueryLineageHolder, StatementLineageHolder holds extra attributes like drop and rename

    For drop, it is a set[:class:`sqllineage.core.models.Table`].

    For rename, it is a set[tuple[:class:`sqllineage.core.models.Table`, :class:`sqllineage.core.models.Table`]],
    with the first table being original table before renaming and the latter after renaming.

    Unlike SubQueryLineageHolder, read and write here only contain dataset vertices
    (instances of DATASET_CLASSES); SubQuery vertices are filtered out.
    """

    def __str__(self):
        lines = []
        for attr in ["read", "write", "cte", "drop", "rename"]:
            # every attribute is a fresh graph query, so evaluate it only once
            value = getattr(self, attr)
            rendered = sorted(value, key=lambda x: str(x)) if value else "[]"
            lines.append(f"table {attr}: {rendered}")
        return "\n".join(lines)

    def __repr__(self):
        return str(self)

    @property
    def read(self) -> set[Table]:  # type: ignore
        return {t for t in super().read if isinstance(t, DATASET_CLASSES)}

    @property
    def write(self) -> set[Table]:  # type: ignore
        return {t for t in super().write if isinstance(t, DATASET_CLASSES)}

    @property
    def drop(self) -> set[Table]:
        return self._property_getter(NodeTag.DROP)  # type: ignore

    def add_drop(self, value) -> None:
        self._property_setter(value, NodeTag.DROP)

    @property
    def rename(self) -> set[tuple[Table, Table]]:
        return {
            (e.source, e.target)
            for e in self.go.retrieve_edges_by_label(EdgeType.RENAME)
        }

    def add_rename(self, src: Table, tgt: Table) -> None:
        self.go.add_edge_if_not_exist(src, tgt, EdgeType.RENAME)

    @staticmethod
    def of(holder: SubQueryLineageHolder) -> "StatementLineageHolder":
        """Wrap ``holder``'s graph operator (shared, not copied) in a StatementLineageHolder."""
        stmt_holder = StatementLineageHolder()
        stmt_holder.go = holder.go
        return stmt_holder
class SQLLineageHolder(ColumnLineageMixin):
    def __init__(self, go: GraphOperator):
        """
        The combined lineage result in representation of Directed Acyclic Graph.

        :param go: the Graph Operator holding all the combined lineage result.
        """
        self.go = go
        # tag-based table sets are snapshotted from the graph once, at construction time
        self._selfloop_tables = self.__retrieve_tag_tables(NodeTag.SELFLOOP)
        self._sourceonly_tables = self.__retrieve_tag_tables(NodeTag.SOURCE_ONLY)
        self._targetonly_tables = self.__retrieve_tag_tables(NodeTag.TARGET_ONLY)

    @property
    def table_lineage_graph(self) -> GraphOperator:
        """
        The table level GraphOperator held by SQLLineageHolder

        A new sub graph of all dataset vertices is built on every access.
        """
        table_nodes = [
            v
            for v in self.go.retrieve_vertices_by_props()
            if isinstance(v, DATASET_CLASSES)
        ]
        return self.go.get_sub_graph(*table_nodes)

    @property
    def column_lineage_graph(self) -> GraphOperator:
        """
        The column level GraphOperator held by SQLLineageHolder

        A new sub graph of all Column vertices is built on every access.
        """
        column_nodes = [
            v for v in self.go.retrieve_vertices_by_props() if isinstance(v, Column)
        ]
        return self.go.get_sub_graph(*column_nodes)

    @property
    def source_tables(self) -> set[Table | Path]:
        """
        a set of source :class:`sqllineage.core.models.Table`: source vertices of the
        table lineage graph, plus self-loop and source-only tagged tables
        """
        source_tables = set(self.table_lineage_graph.retrieve_source_vertices())
        source_tables |= self._selfloop_tables
        source_tables |= self._sourceonly_tables
        return source_tables

    @property
    def target_tables(self) -> set[Table | Path]:
        """
        a set of target :class:`sqllineage.core.models.Table`: target vertices of the
        table lineage graph, plus self-loop and target-only tagged tables
        """
        target_tables = set(self.table_lineage_graph.retrieve_target_vertices())
        target_tables |= self._selfloop_tables
        target_tables |= self._targetonly_tables
        return target_tables

    @property
    def intermediate_tables(self) -> set[Table | Path]:
        """
        a set of intermediate :class:`sqllineage.core.models.Table`: tables with both
        incoming and outgoing edges in the table lineage graph, excluding self-loop tables
        """
        # table_lineage_graph rebuilds a sub graph on every access: build it once here
        table_graph = self.table_lineage_graph
        all_tables: list[Table | Path] = table_graph.retrieve_vertices_by_props()
        intermediate_tables = {
            table
            for table in all_tables
            if len(table_graph.retrieve_edges_by_vertex(table, EdgeDirection.IN)) > 0
            and len(table_graph.retrieve_edges_by_vertex(table, EdgeDirection.OUT)) > 0
        }
        # use the same self-loop snapshot as source_tables/target_tables
        intermediate_tables -= self._selfloop_tables
        return intermediate_tables

    def __retrieve_tag_tables(self, tag) -> set[Path | Table]:
        """Return the dataset vertices carrying the ``tag`` property."""
        return {
            vertex
            for vertex in self.go.retrieve_vertices_by_props(**{tag: True})
            if isinstance(vertex, DATASET_CLASSES)
        }

    @staticmethod
    def _is_isolated(go: GraphOperator, vertex) -> bool:
        """Return True when ``vertex`` has neither incoming nor outgoing edges in ``go``."""
        return (
            len(go.retrieve_edges_by_vertex(vertex, EdgeDirection.IN)) == 0
            and len(go.retrieve_edges_by_vertex(vertex, EdgeDirection.OUT)) == 0
        )

    @staticmethod
    def _merge_statement_holder(
        ngo: GraphOperator, holder: "StatementLineageHolder"
    ) -> None:
        """Merge one statement's graph into ``ngo`` and apply its drop/rename/read-write effect."""
        ngo.merge(holder.go)
        if dropped := holder.drop:
            for table in dropped:
                # only remove a dropped table that no longer takes part in any edge
                if SQLLineageHolder._is_isolated(ngo, table):
                    ngo.drop_vertices(table)
        elif renamed := holder.rename:
            for table_old, table_new in renamed:
                # move every edge of the old table onto the new table
                for edge in ngo.retrieve_edges_by_vertex(table_old, EdgeDirection.IN):
                    ngo.add_edge_if_not_exist(
                        edge.source, table_new, edge.label, **edge.attributes
                    )
                for edge in ngo.retrieve_edges_by_vertex(table_old, EdgeDirection.OUT):
                    ngo.add_edge_if_not_exist(
                        table_new, edge.target, edge.label, **edge.attributes
                    )
                ngo.drop_vertices(table_old)
                # remove possible self-loop edge created by rename
                ngo.drop_edge(table_new, table_new)
                if SQLLineageHolder._is_isolated(ngo, table_new):
                    ngo.drop_vertices(table_new)
        else:
            read, write = holder.read, holder.write
            if len(read) > 0 and len(write) == 0:
                # source only table comes from SELECT statement
                ngo.update_vertices(*read, **{NodeTag.SOURCE_ONLY: True})
            elif len(read) == 0 and len(write) > 0:
                # target only table comes from case like: 1) INSERT/UPDATE constant values; 2) CREATE TABLE
                ngo.update_vertices(*write, **{NodeTag.TARGET_ONLY: True})
            else:
                for source, target in itertools.product(read, write):
                    ngo.add_edge_if_not_exist(source, target, EdgeType.LINEAGE)

    @staticmethod
    def _resolve_column_lineages(ngo: GraphOperator, metadata_provider) -> None:
        """Re-point lineage of columns with several parent candidates to concrete source columns."""
        # find all the columns that we can't assign accurately to a parent table (with multiple parent candidates)
        unresolved_column_lineages = [
            (e.source, e.target)
            for e in ngo.retrieve_edges_by_label(label=EdgeType.LINEAGE)
            if isinstance(e.source, Column) and len(e.source.parent_candidates) > 1
        ]
        for unresolved_col, tgt_col in unresolved_column_lineages:
            # check if there's only one parent candidate contains the column with same name
            src_cols = []
            # check if source column exists in graph (either from subquery or from table created in prev statement)
            for parent in unresolved_col.parent_candidates:
                src_col_candidate = Column(unresolved_col.raw_name)
                src_col_candidate.parent = parent
                parent_columns = [
                    e.target
                    for e in ngo.retrieve_edges_by_vertex(
                        parent, EdgeDirection.OUT, EdgeType.HAS_COLUMN
                    )
                ]
                if src_col_candidate in parent_columns:
                    src_cols.append(src_col_candidate)
            # if not in graph, check if defined in table schema by metadata service
            if not src_cols and metadata_provider:
                for parent in unresolved_col.parent_candidates:
                    if (
                        isinstance(parent, Table)
                        and str(parent.schema) != Schema.unknown
                    ):
                        for parent_col in metadata_provider.get_table_columns(parent):
                            if unresolved_col.raw_name == parent_col.raw_name:
                                src_cols.append(parent_col)
            # Multiple sources is a correct case for JOIN with USING
            # It incorrect for JOIN with ON, but sql without specifying an alias in this case will be invalid
            for src_col in src_cols:
                ngo.add_edge_if_not_exist(src_col, tgt_col, EdgeType.LINEAGE)
            if src_cols:
                # only delete unresolved column when it's resolved
                ngo.drop_edge(unresolved_col, tgt_col)
        # when unresolved column got resolved, it will be orphan node, and we can remove it
        # convert unresolved_column_lineages to a set of cols, otherwise if an unresolved appears multiple times,
        # calling retrieve_edges_by_vertex using a deleted node would cause inconsistent behavior for different graph
        # operator, e.g. NetworkX 3.x would throw exception while NetworkX 2.x and Rustworkx would succeed silently
        for unresolved_col in {col for col, _ in unresolved_column_lineages}:
            if SQLLineageHolder._is_isolated(ngo, unresolved_col):
                ngo.drop_vertices(unresolved_col)

    @staticmethod
    def of(metadata_provider, *args: StatementLineageHolder) -> "SQLLineageHolder":
        """
        To assemble multiple :class:`sqllineage.core.holders.StatementLineageHolder` into
        :class:`sqllineage.core.holders.SQLLineageHolder`

        :param metadata_provider: metadata service used to resolve columns whose parent table
            is ambiguous; when falsy, the metadata lookup is skipped.
        :param args: statement level holders, merged in the given order.
        """
        ngo = get_graph_operator_class()()
        for holder in args:
            SQLLineageHolder._merge_statement_holder(ngo, holder)
        # selfloop table comes from cases like: INSERT INTO tbl (part='xx') SELECT * FROM tbl WHERE part = ''
        ngo.update_vertices(
            *ngo.retrieve_selfloop_vertices(), **{NodeTag.SELFLOOP: True}
        )
        SQLLineageHolder._resolve_column_lineages(ngo, metadata_provider)
        return SQLLineageHolder(ngo)