Source code for sqllineage.core.holders

import itertools
from typing import Dict, List, Optional, Set, Tuple, Union

import networkx as nx
from networkx import DiGraph

from sqllineage.core.metadata_provider import MetaDataProvider
from sqllineage.core.models import Column, Path, Schema, SubQuery, Table
from sqllineage.utils.constant import EdgeTag, EdgeType, NodeTag

DATASET_CLASSES = (Path, Table)


class ColumnLineageMixin:
    """Mixin providing column level lineage for holders that expose ``self.graph``."""

    def get_column_lineage(
        self, exclude_path_ending_in_subquery=True, exclude_subquery_columns=False
    ) -> Set[Tuple[Column, ...]]:
        """
        Collect every simple path leading from a source column to a target column.

        :param exclude_path_ending_in_subquery: when truthy, only columns whose parent is a
               :class:`sqllineage.core.models.Table` are accepted as the end of a path.
        :param exclude_subquery_columns: when truthy, columns whose parent is a SubQuery are
               stripped from each path; a path left with fewer than two columns is discarded.

        :return: a set of column tuples (:class:`sqllineage.core.models.Column`), one per path
        """
        self.graph: DiGraph  # For mypy attribute checking
        # restrict the view to column nodes so degrees only count column-to-column edges
        col_graph = self.graph.subgraph(
            [candidate for candidate in self.graph.nodes if isinstance(candidate, Column)]
        )
        sources = {col for col, degree in col_graph.in_degree if degree == 0}

        targets = set()
        for col, degree in col_graph.out_degree:
            if degree != 0 or not isinstance(col, Column):
                continue
            # if a column lineage path ends at SubQuery, then it should be pruned
            if exclude_path_ending_in_subquery and not isinstance(col.parent, Table):
                continue
            targets.add(col)

        lineage = set()
        for src, tgt in itertools.product(sources, targets):
            for path in nx.all_simple_paths(self.graph, src, tgt):
                if exclude_subquery_columns:
                    path = [
                        step for step in path if not isinstance(step.parent, SubQuery)
                    ]
                    if len(path) <= 1:
                        # nothing meaningful is left once SubQuery columns are dropped
                        continue
                lineage.add(tuple(path))
        return lineage


class SubQueryLineageHolder(ColumnLineageMixin):
    """
    SubQuery/Query Level Lineage Result.

    The holder keeps a single directed graph; attributes like read, write and cte are
    derived from the tags stored on its nodes. Each of them is a
    Set[:class:`sqllineage.core.models.Table`].

    This is the most atomic representation of lineage result.
    """

    def __init__(self) -> None:
        self.graph = nx.DiGraph()

    def __or__(self, other):
        """Compose ``other``'s graph into this holder and return this same holder."""
        self.graph = nx.compose(self.graph, other.graph)
        return self

    def _property_getter(self, prop) -> Set[Union[SubQuery, Table]]:
        """Return every node whose attribute ``prop`` is exactly ``True``."""
        tagged = set()
        for node, attrs in self.graph.nodes(data=True):
            if attrs.get(prop) is True:
                tagged.add(node)
        return tagged

    def _property_setter(self, value, prop) -> None:
        """Add ``value`` as a node (if needed) and set its attribute ``prop`` to ``True``."""
        self.graph.add_node(value, **{prop: True})

    @property
    def read(self) -> Set[Union[SubQuery, Table]]:
        return self._property_getter(NodeTag.READ)

    def add_read(self, value) -> None:
        self._property_setter(value, NodeTag.READ)
        if not hasattr(value, "alias"):
            return
        # the same table can be added (in SQL: joined) multiple times with different alias
        self.graph.add_edge(value, value.alias, type=EdgeType.HAS_ALIAS)

    @property
    def write(self) -> Set[Union[SubQuery, Table]]:
        # SubQueryLineageHolder.write can return a single SubQuery or Table, or both when __or__ together.
        # This is different from StatementLineageHolder.write, where Table is the only possibility.
        return self._property_getter(NodeTag.WRITE)

    def add_write(self, value) -> None:
        self._property_setter(value, NodeTag.WRITE)

    @property
    def cte(self) -> Set[SubQuery]:
        return self._property_getter(NodeTag.CTE)  # type: ignore

    def add_cte(self, value) -> None:
        self._property_setter(value, NodeTag.CTE)

    @property
    def write_columns(self) -> List[Column]:
        """
        return a list of columns that write table contains, ordered by the index stored on
        the HAS_COLUMN edge (edges without an index count as 0).

        It's either manually added via `add_write_column` if specified in DML
        or automatic added via `add_column_lineage` after parsing from SELECT
        """
        tgt_tbl = self._get_target_table()
        if not tgt_tbl:
            return []
        indexed: List[Tuple[Column, int]] = []
        for _, col, attr in self.graph.out_edges(tgt_tbl, data=True):
            if attr["type"] == EdgeType.HAS_COLUMN:
                indexed.append((col, attr.get(EdgeTag.INDEX, 0)))
        indexed.sort(key=lambda pair: pair[1])
        return [col for col, _ in indexed]

    def add_write_column(self, *tgt_cols: Column) -> None:
        """
        in case of DML with column specified, like:

        .. code-block:: sql

            INSERT INTO tab1 (col1, col2)
            SELECT col3, col4

        this method is called to make sure tab1 has column col1 and col2 instead of col3 and col4
        """
        written = self.write
        if not written:
            return
        tgt_tbl = next(iter(written))
        for position, tgt_col in enumerate(tgt_cols):
            tgt_col.parent = tgt_tbl
            if tgt_col in self.write_columns:
                # for DDL with PARTITIONED BY (col) or CLUSTERED BY (col), column can be added multiple times
                break
            self.graph.add_edge(
                tgt_tbl, tgt_col, type=EdgeType.HAS_COLUMN, **{EdgeTag.INDEX: position}
            )

    def add_column_lineage(self, src: Column, tgt: Column) -> None:
        """
        link source column to target, and link each column to its parent.
        """
        self.graph.add_edge(src, tgt, type=EdgeType.LINEAGE)
        self.graph.add_edge(tgt.parent, tgt, type=EdgeType.HAS_COLUMN)
        src_parent = src.parent
        if src_parent is None:
            # starting NetworkX v2.6, None is not allowed as node,
            # see https://github.com/networkx/networkx/pull/4892
            return
        self.graph.add_edge(src_parent, src, type=EdgeType.HAS_COLUMN)

    def get_table_columns(self, table: Union[Table, SubQuery]) -> List[Column]:
        """Return the non-wildcard columns attached to ``table`` via HAS_COLUMN edges."""
        found = []
        for _, col, edge_type in self.graph.out_edges(nbunch=table, data="type"):
            if edge_type != EdgeType.HAS_COLUMN:
                continue
            if isinstance(col, Column) and col.raw_name != "*":
                found.append(col)
        return found

    def expand_wildcard(self, metadata_provider: MetaDataProvider) -> None:
        """
        Replace each ``*`` column written to the target table with the concrete columns of
        the tables/subqueries that wildcard reads from, when those columns can be found.
        """
        tgt_table = self._get_target_table()
        if not tgt_table:
            return
        for tgt_wildcard in self.write_columns:
            if tgt_wildcard.raw_name != "*":
                continue
            for src_wildcard in self.get_source_columns(tgt_wildcard):
                source_table = src_wildcard.parent
                if not source_table:
                    continue
                if isinstance(source_table, SubQuery):
                    # the columns of SubQuery can be inferred from graph
                    src_table_columns = self.get_table_columns(source_table)
                elif isinstance(source_table, Table) and metadata_provider:
                    # search by metadata service
                    src_table_columns = metadata_provider.get_table_columns(
                        source_table
                    )
                else:
                    src_table_columns = []
                if src_table_columns:
                    self._replace_wildcard(
                        tgt_table,
                        src_table_columns,
                        tgt_wildcard,
                        src_wildcard,
                    )

    def get_alias_mapping_from_table_group(
        self, table_group: List[Union[Path, Table, SubQuery]]
    ) -> Dict[str, Union[Path, Table, SubQuery]]:
        """
        A table can be referred to as alias, table name, or database_name.table_name, create the mapping here.
        For SubQuery, it's only alias then.
        """
        mapping: Dict[str, Union[Path, Table, SubQuery]] = {}
        # aliases go in first so that the table-name keys added below win on a clash
        for src, alias, attr in self.graph.edges(data=True):
            if attr.get("type") == EdgeType.HAS_ALIAS and src in table_group:
                mapping[alias] = src
        for table in table_group:
            if isinstance(table, Table):
                mapping[table.raw_name] = table
        for table in table_group:
            if isinstance(table, Table):
                mapping[str(table)] = table
        return mapping

    def _get_target_table(self) -> Optional[Union[SubQuery, Table]]:
        """Return one node that is written but not read, or ``None`` when there is none."""
        write_only = self.write.difference(self.read)
        return next(iter(write_only), None)

    def get_source_columns(self, node: Column) -> List[Column]:
        """Return the columns with a LINEAGE edge pointing at ``node``."""
        upstream = []
        for src, _, edge_type in self.graph.in_edges(nbunch=node, data="type"):
            if edge_type == EdgeType.LINEAGE and isinstance(src, Column):
                upstream.append(src)
        return upstream

    def _replace_wildcard(
        self,
        tgt_table: Union[Table, SubQuery],
        src_table_columns: List[Column],
        tgt_wildcard: Column,
        src_wildcard: Column,
    ) -> None:
        """
        Add one target column (and its lineage edge) per source column not already present
        on ``tgt_table``, then drop both wildcard nodes from the graph.
        """
        existing = self.get_table_columns(tgt_table)
        for src_col in src_table_columns:
            expanded = Column(src_col.raw_name)
            expanded.parent = tgt_table
            if expanded in existing or src_col.raw_name == "*":
                continue
            self.graph.add_edge(tgt_table, expanded, type=EdgeType.HAS_COLUMN)
            self.graph.add_edge(src_col.parent, src_col, type=EdgeType.HAS_COLUMN)
            self.graph.add_edge(src_col, expanded, type=EdgeType.LINEAGE)
        # remove wildcard
        for wildcard in (tgt_wildcard, src_wildcard):
            if self.graph.has_node(wildcard):
                self.graph.remove_node(wildcard)


class StatementLineageHolder(SubQueryLineageHolder, ColumnLineageMixin):
    """
    Statement Level Lineage Result.

    Based on SubQueryLineageHolder, StatementLineageHolder holds extra attributes like drop and rename

    For drop, it is a Set[:class:`sqllineage.core.models.Table`].

    For rename, it a Set[Tuple[:class:`sqllineage.core.models.Table`, :class:`sqllineage.core.models.Table`]],
    with the first table being original table before renaming and the latter after renaming.
    """

    def __str__(self):
        rendered = []
        for attr in ("read", "write", "cte", "drop", "rename"):
            members = getattr(self, attr)
            # an empty collection is shown as the literal text [] rather than being sorted
            shown = sorted(members, key=str) if members else "[]"
            rendered.append(f"table {attr}: {shown}")
        return "\n".join(rendered)

    def __repr__(self):
        return str(self)

    @property
    def read(self) -> Set[Table]:  # type: ignore
        # keep only dataset nodes (DATASET_CLASSES) out of everything tagged as read
        datasets = set()
        for node in super().read:
            if isinstance(node, DATASET_CLASSES):
                datasets.add(node)
        return datasets

    @property
    def write(self) -> Set[Table]:  # type: ignore
        # keep only dataset nodes (DATASET_CLASSES) out of everything tagged as write
        datasets = set()
        for node in super().write:
            if isinstance(node, DATASET_CLASSES):
                datasets.add(node)
        return datasets

    @property
    def drop(self) -> Set[Table]:
        return self._property_getter(NodeTag.DROP)  # type: ignore

    def add_drop(self, value) -> None:
        self._property_setter(value, NodeTag.DROP)

    @property
    def rename(self) -> Set[Tuple[Table, Table]]:
        """Return the (original, renamed) pairs recorded as RENAME edges."""
        pairs = set()
        for src, tgt, attr in self.graph.edges(data=True):
            if attr.get("type") == EdgeType.RENAME:
                pairs.add((src, tgt))
        return pairs

    def add_rename(self, src: Table, tgt: Table) -> None:
        self.graph.add_edge(src, tgt, type=EdgeType.RENAME)

    @staticmethod
    def of(holder: SubQueryLineageHolder) -> "StatementLineageHolder":
        """Wrap ``holder``'s graph (the same object, not a copy) in a StatementLineageHolder."""
        wrapped = StatementLineageHolder()
        wrapped.graph = holder.graph
        return wrapped


class SQLLineageHolder(ColumnLineageMixin):
    def __init__(self, graph: DiGraph):
        """
        The combined lineage result in representation of Directed Acyclic Graph.

        :param graph: the Directed Acyclic Graph holding all the combined lineage result.
        """
        self.graph = graph
        # tag lookups are taken once here, from the graph as it is at construction time
        self._selfloop_tables = self.__retrieve_tag_tables(NodeTag.SELFLOOP)
        self._sourceonly_tables = self.__retrieve_tag_tables(NodeTag.SOURCE_ONLY)
        self._targetonly_tables = self.__retrieve_tag_tables(NodeTag.TARGET_ONLY)

    @property
    def table_lineage_graph(self) -> DiGraph:
        """
        The table level DiGraph held by SQLLineageHolder
        """
        return self.graph.subgraph(
            [node for node in self.graph.nodes if isinstance(node, DATASET_CLASSES)]
        )

    @property
    def column_lineage_graph(self) -> DiGraph:
        """
        The column level DiGraph held by SQLLineageHolder
        """
        return self.graph.subgraph(
            [node for node in self.graph.nodes if isinstance(node, Column)]
        )

    @property
    def source_tables(self) -> Set[Table]:
        """
        a set of source :class:`sqllineage.core.models.Table`
        """
        table_graph = self.table_lineage_graph
        no_upstream = {tbl for tbl, deg in table_graph.in_degree if deg == 0}
        has_downstream = {tbl for tbl, deg in table_graph.out_degree if deg > 0}
        return (
            (no_upstream & has_downstream)
            | self._selfloop_tables
            | self._sourceonly_tables
        )

    @property
    def target_tables(self) -> Set[Table]:
        """
        a set of target :class:`sqllineage.core.models.Table`
        """
        table_graph = self.table_lineage_graph
        no_downstream = {tbl for tbl, deg in table_graph.out_degree if deg == 0}
        has_upstream = {tbl for tbl, deg in table_graph.in_degree if deg > 0}
        return (
            (no_downstream & has_upstream)
            | self._selfloop_tables
            | self._targetonly_tables
        )

    @property
    def intermediate_tables(self) -> Set[Table]:
        """
        a set of intermediate :class:`sqllineage.core.models.Table`
        """
        table_graph = self.table_lineage_graph
        has_upstream = {tbl for tbl, deg in table_graph.in_degree if deg > 0}
        has_downstream = {tbl for tbl, deg in table_graph.out_degree if deg > 0}
        return (has_upstream & has_downstream) - self.__retrieve_tag_tables(
            NodeTag.SELFLOOP
        )

    def __retrieve_tag_tables(self, tag) -> Set[Union[Path, Table]]:
        """Return the dataset nodes whose attribute ``tag`` is exactly ``True``."""
        tagged = set()
        for node, attrs in self.graph.nodes(data=True):
            if attrs.get(tag) is True and isinstance(node, DATASET_CLASSES):
                tagged.add(node)
        return tagged

    @staticmethod
    def _build_digraph(
        metadata_provider: MetaDataProvider, *args: StatementLineageHolder
    ) -> DiGraph:
        """
        Compose the graphs of all statement holders, in order, into one graph, then tag
        self-loop tables and try to resolve columns that have several parent candidates.
        """
        graph = DiGraph()
        for holder in args:
            graph = nx.compose(graph, holder.graph)
            if dropped := holder.drop:
                for table in dropped:
                    if graph.has_node(table) and graph.degree[table] == 0:
                        graph.remove_node(table)
            elif renamed := holder.rename:
                for table_old, table_new in renamed:
                    graph = nx.relabel_nodes(graph, {table_old: table_new})
                    graph.remove_edge(table_new, table_new)
                    if graph.degree[table_new] == 0:
                        graph.remove_node(table_new)
            else:
                read, write = holder.read, holder.write
                if read and not write:
                    # source only table comes from SELECT statement
                    nx.set_node_attributes(
                        graph, {table: True for table in read}, NodeTag.SOURCE_ONLY
                    )
                elif write and not read:
                    # target only table comes from case like:
                    # 1) INSERT/UPDATE constant values; 2) CREATE TABLE
                    nx.set_node_attributes(
                        graph, {table: True for table in write}, NodeTag.TARGET_ONLY
                    )
                else:
                    for source, target in itertools.product(read, write):
                        graph.add_edge(source, target, type=EdgeType.LINEAGE)

        selfloop_nodes = {edge[0] for edge in nx.selfloop_edges(graph)}
        nx.set_node_attributes(
            graph, {table: True for table in selfloop_nodes}, NodeTag.SELFLOOP
        )

        # find all the columns that we can't assign accurately to a parent table
        # (with multiple parent candidates)
        ambiguous_edges = [
            (src, tgt)
            for src, tgt in graph.edges
            if isinstance(src, Column) and len(src.parent_candidates) > 1
        ]
        for unresolved_col, tgt_col in ambiguous_edges:
            resolved = []
            # check if source column exists in graph
            # (either from subquery or from table created in prev statement)
            for parent in unresolved_col.parent_candidates:
                candidate = Column(unresolved_col.raw_name)
                candidate.parent = parent
                if graph.has_edge(parent, candidate):
                    resolved.append(candidate)
            # if not in graph, check if defined in table schema by metadata service
            if not resolved and metadata_provider:
                for parent in unresolved_col.parent_candidates:
                    if not isinstance(parent, Table):
                        continue
                    if str(parent.schema) == Schema.unknown:
                        continue
                    for known_col in metadata_provider.get_table_columns(parent):
                        if unresolved_col.raw_name == known_col.raw_name:
                            resolved.append(known_col)
            # Multiple sources is a correct case for JOIN with USING
            # It incorrect for JOIN with ON, but sql without specifying an alias in this case will be invalid
            for src_col in resolved:
                graph.add_edge(src_col, tgt_col, type=EdgeType.LINEAGE)
            if resolved:
                # only delete unresolved column when it's resolved
                graph.remove_edge(unresolved_col, tgt_col)

        # when unresolved column got resolved, it will be orphan node, and we can remove it
        orphans = [node for node, deg in graph.degree if deg == 0]
        for node in orphans:
            if isinstance(node, Column) and len(node.parent_candidates) > 1:
                graph.remove_node(node)
        return graph

    @staticmethod
    def of(metadata_provider, *args: StatementLineageHolder) -> "SQLLineageHolder":
        """
        To assemble multiple :class:`sqllineage.core.holders.StatementLineageHolder` into
        :class:`sqllineage.core.holders.SQLLineageHolder`
        """
        return SQLLineageHolder(
            SQLLineageHolder._build_digraph(metadata_provider, *args)
        )