Source code for sqllineage.core.holders

import itertools
from typing import List, Set, Tuple, Union

import networkx as nx
from networkx import DiGraph

from sqllineage.core.models import Column, Path, SubQuery, Table
from sqllineage.utils.constant import EdgeTag, EdgeType, NodeTag


class ColumnLineageMixin:
    def get_column_lineage(self, exclude_subquery=True) -> Set[Tuple[Column, ...]]:
        self.graph: DiGraph  # For mypy attribute checking
        # filter all the column node in the graph
        column_nodes = [n for n in self.graph.nodes if isinstance(n, Column)]
        column_graph = self.graph.subgraph(column_nodes)
        source_columns = {column for column, deg in column_graph.in_degree if deg == 0}
        # if a column lineage path ends at SubQuery, then it should be pruned
        target_columns = {
            for node, deg in column_graph.out_degree
            if isinstance(node, Column) and deg == 0
        if exclude_subquery:
            target_columns = {
                node for node in target_columns if isinstance(node.parent, Table)
        columns = set()
        for source, target in itertools.product(source_columns, target_columns):
            simple_paths = list(nx.all_simple_paths(self.graph, source, target))
            for path in simple_paths:
        return columns

[docs]class SubQueryLineageHolder(ColumnLineageMixin): """ SubQuery/Query Level Lineage Result. SubQueryLineageHolder will hold attributes like read, write, cte Each of them is a Set[:class:`sqllineage.models.Table`]. This is the most atomic representation of lineage result. """ def __init__(self) -> None: self.graph = nx.DiGraph() def __or__(self, other): self.graph = nx.compose(self.graph, other.graph) return self def _property_getter(self, prop) -> Set[Union[SubQuery, Table]]: return {t for t, attr in self.graph.nodes(data=True) if attr.get(prop) is True} def _property_setter(self, value, prop) -> None: self.graph.add_node(value, **{prop: True}) @property def read(self) -> Set[Union[SubQuery, Table]]: return self._property_getter(NodeTag.READ) def add_read(self, value) -> None: self._property_setter(value, NodeTag.READ) # the same table can be added (in SQL: joined) multiple times with different alias if hasattr(value, "alias"): self.graph.add_edge(value, value.alias, type=EdgeType.HAS_ALIAS) @property def write(self) -> Set[Union[SubQuery, Table]]: return self._property_getter(NodeTag.WRITE) def add_write(self, value) -> None: self._property_setter(value, NodeTag.WRITE) @property def cte(self) -> Set[SubQuery]: return self._property_getter(NodeTag.CTE) # type: ignore def add_cte(self, value) -> None: self._property_setter(value, NodeTag.CTE) @property def write_columns(self) -> List[Column]: """ in case of DML with column specified, like INSERT INTO tab1 (col1, col2) SELECT ... write_columns tell us that tab1 has column col1 and col2 in that order. """ tgt_cols = [] if self.write: tgt_tbl = list(self.write)[0] tgt_col_with_idx: List[Tuple[Column, int]] = sorted( [ (col, attr.get(EdgeTag.INDEX, 0)) for tbl, col, attr in self.graph.out_edges(tgt_tbl, data=True) if attr["type"] == EdgeType.HAS_COLUMN ], key=lambda x: x[1], ) tgt_cols = [x[0] for x in tgt_col_with_idx] return tgt_cols def add_write_column(self, *tgt_cols: Column) -> None: if self.write: tgt_tbl = list(self.write)[0] for idx, tgt_col in enumerate(tgt_cols): tgt_col.parent = tgt_tbl if tgt_col in self.write_columns: # for DDL with PARTITIONED BY (col) or CLUSTERED BY (col), column can be added multiple times break self.graph.add_edge( tgt_tbl, tgt_col, type=EdgeType.HAS_COLUMN, **{EdgeTag.INDEX: idx} ) def add_column_lineage(self, src: Column, tgt: Column) -> None: self.graph.add_edge(src, tgt, type=EdgeType.LINEAGE) self.graph.add_edge(tgt.parent, tgt, type=EdgeType.HAS_COLUMN) if src.parent is not None: # starting NetworkX v2.6, None is not allowed as node, see self.graph.add_edge(src.parent, src, type=EdgeType.HAS_COLUMN)
[docs]class StatementLineageHolder(SubQueryLineageHolder, ColumnLineageMixin): """ Statement Level Lineage Result. Based on SubQueryLineageHolder, StatementLineageHolder holds extra attributes like drop and rename For drop, it is a Set[:class:`sqllineage.models.Table`]. For rename, it a Set[Tuple[:class:`sqllineage.models.Table`, :class:`sqllineage.models.Table`]], with the first table being original table before renaming and the latter after renaming. """ def __str__(self): return "\n".join( f"table {attr}: {sorted(getattr(self, attr), key=lambda x: str(x)) if getattr(self, attr) else '[]'}" for attr in ["read", "write", "cte", "drop", "rename"] ) def __repr__(self): return str(self) @property def read(self) -> Set[Table]: # type: ignore return {t for t in super().read if isinstance(t, DATASET_CLASSES)} @property def write(self) -> Set[Table]: # type: ignore return {t for t in super().write if isinstance(t, DATASET_CLASSES)} @property def drop(self) -> Set[Table]: return self._property_getter(NodeTag.DROP) # type: ignore def add_drop(self, value) -> None: self._property_setter(value, NodeTag.DROP) @property def rename(self) -> Set[Tuple[Table, Table]]: return { (src, tgt) for src, tgt, attr in self.graph.edges(data=True) if attr.get("type") == EdgeType.RENAME } def add_rename(self, src: Table, tgt: Table) -> None: self.graph.add_edge(src, tgt, type=EdgeType.RENAME) @staticmethod def of(holder: SubQueryLineageHolder) -> "StatementLineageHolder": stmt_holder = StatementLineageHolder() stmt_holder.graph = holder.graph return stmt_holder
[docs]class SQLLineageHolder(ColumnLineageMixin): def __init__(self, graph: DiGraph): """ The combined lineage result in representation of Directed Acyclic Graph. :param graph: the Directed Acyclic Graph holding all the combined lineage result. """ self.graph = graph self._selfloop_tables = self.__retrieve_tag_tables(NodeTag.SELFLOOP) self._sourceonly_tables = self.__retrieve_tag_tables(NodeTag.SOURCE_ONLY) self._targetonly_tables = self.__retrieve_tag_tables(NodeTag.TARGET_ONLY) @property def table_lineage_graph(self) -> DiGraph: """ The table level DiGraph held by SQLLineageHolder """ table_nodes = [n for n in self.graph.nodes if isinstance(n, DATASET_CLASSES)] return self.graph.subgraph(table_nodes) @property def column_lineage_graph(self) -> DiGraph: """ The column level DiGraph held by SQLLineageHolder """ column_nodes = [n for n in self.graph.nodes if isinstance(n, Column)] return self.graph.subgraph(column_nodes) @property def source_tables(self) -> Set[Table]: """ a list of source :class:`sqllineage.models.Table` """ source_tables = { table for table, deg in self.table_lineage_graph.in_degree if deg == 0 }.intersection( {table for table, deg in self.table_lineage_graph.out_degree if deg > 0} ) source_tables |= self._selfloop_tables source_tables |= self._sourceonly_tables return source_tables @property def target_tables(self) -> Set[Table]: """ a list of target :class:`sqllineage.models.Table` """ target_tables = { table for table, deg in self.table_lineage_graph.out_degree if deg == 0 }.intersection( {table for table, deg in self.table_lineage_graph.in_degree if deg > 0} ) target_tables |= self._selfloop_tables target_tables |= self._targetonly_tables return target_tables @property def intermediate_tables(self) -> Set[Table]: """ a list of intermediate :class:`sqllineage.models.Table` """ intermediate_tables = { table for table, deg in self.table_lineage_graph.in_degree if deg > 0 }.intersection( {table for table, deg in self.table_lineage_graph.out_degree if deg > 0} ) intermediate_tables -= self.__retrieve_tag_tables(NodeTag.SELFLOOP) return intermediate_tables def __retrieve_tag_tables(self, tag) -> Set[Union[Path, Table]]: return { table for table, attr in self.graph.nodes(data=True) if attr.get(tag) is True and isinstance(table, DATASET_CLASSES) } @staticmethod def _build_digraph(*args: StatementLineageHolder) -> DiGraph: """ To assemble multiple :class:`sqllineage.holders.StatementLineageHolder` into :class:`sqllineage.holders.SQLLineageHolder` """ g = DiGraph() for holder in args: g = nx.compose(g, holder.graph) if holder.drop: for table in holder.drop: if g.has_node(table) and[table] == 0: g.remove_node(table) elif holder.rename: for table_old, table_new in holder.rename: g = nx.relabel_nodes(g, {table_old: table_new}) g.remove_edge(table_new, table_new) if[table_new] == 0: g.remove_node(table_new) else: read, write =, holder.write if len(read) > 0 and len(write) == 0: # source only table comes from SELECT statement nx.set_node_attributes( g, {table: True for table in read}, NodeTag.SOURCE_ONLY ) elif len(read) == 0 and len(write) > 0: # target only table comes from case like: 1) INSERT/UPDATE constant values; 2) CREATE TABLE nx.set_node_attributes( g, {table: True for table in write}, NodeTag.TARGET_ONLY ) else: for source, target in itertools.product(read, write): g.add_edge(source, target, type=EdgeType.LINEAGE) nx.set_node_attributes( g, {table: True for table in {e[0] for e in nx.selfloop_edges(g)}}, NodeTag.SELFLOOP, ) # find all the columns that we can't assign accurately to a parent table (with multiple parent candidates) unresolved_cols = [ (s, t) for s, t in g.edges if isinstance(s, Column) and len(s.parent_candidates) > 1 ] for unresolved_col, tgt_col in unresolved_cols: # check if there's only one parent candidate contains the column with same name src_cols = [] for parent in unresolved_col.parent_candidates: src_col = Column(unresolved_col.raw_name) src_col.parent = parent if g.has_edge(parent, src_col): src_cols.append(src_col) if len(src_cols) == 1: g.add_edge(src_cols[0], tgt_col, type=EdgeType.LINEAGE) g.remove_edge(unresolved_col, tgt_col) # when unresolved column got resolved, it will be orphan node, and we can remove it for node in [n for n, deg in if deg == 0]: if isinstance(node, Column) and len(node.parent_candidates) > 1: g.remove_node(node) return g
[docs] @staticmethod def of(*args: StatementLineageHolder) -> "SQLLineageHolder": """ To assemble multiple :class:`sqllineage.holders.StatementLineageHolder` into :class:`sqllineage.holders.SQLLineageHolder` """ g = SQLLineageHolder._build_digraph(*args) return SQLLineageHolder(g)