SQLglot – Get required input columns, output columns and input tables from an SQL query

SQLglot is cool.

Date de publication

23 août 2024

“my” script:



from sqlglot import parse_one, exp

query = """
select
    sales.order_id as id,
    p.product_name,
    sum(p.price) as sales_volume
from sales
right join products as p
    on sales.product_id=p.product_id
group by id, p.product_name;

"""

# Parse the query a single time and reuse the resulting tree for all three
# traversals below; they only read from it.
parsed_query = parse_one(query)

# Input columns: every column reference found anywhere in the query (select
# list, join condition, group by). Only the bare column name is kept, so
# `sales.product_id` and `p.product_id` both collapse to 'product_id'.
# NOTE: `group by id` refers to the select alias, but it is still reported as
# a column reference, which is why 'id' shows up in this list.
input_column_names = [
    column.alias_or_name for column in parsed_query.find_all(exp.Column)
]

# Sorted so the printed order is the same on every run (iterating a set of
# strings gives no stable order).
unique_input_column_names = sorted(set(input_column_names))
# prints ['id', 'order_id', 'price', 'product_id', 'product_name']
print(unique_input_column_names)

# Output columns: the alias (or, when there is no alias, the name) of each
# projection of every SELECT found in the query.
output_column_names = [
    projection.alias_or_name
    for select in parsed_query.find_all(exp.Select)
    for projection in select.expressions
]

unique_output_column_names = sorted(set(output_column_names))
# prints ['id', 'product_name', 'sales_volume']
print(unique_output_column_names)

# Input tables: the name of every table referenced in the query; the table
# alias (`p`) is not what gets reported.
table_names = [table.name for table in parsed_query.find_all(exp.Table)]

unique_table_names = sorted(set(table_names))
# prints ['products', 'sales']
print(unique_table_names)
    

reference #1: https://github.com/tobymao/sqlglot

Metadata: You can explore SQL with expression helpers to do things like find columns and tables in a query:

from sqlglot import parse_one, exp

# Print the name of every column reference in the query (a and b).
# Each loop below parses its own query string from scratch.
for column in parse_one("SELECT a, b + 1 AS c FROM d").find_all(exp.Column):
    print(column.alias_or_name)

# Print the projections of every SELECT statement in the query (a and c):
# the alias is used when a projection has one, otherwise its name.
for select in parse_one("SELECT a, b + 1 AS c FROM d").find_all(exp.Select):
    for projection in select.expressions:
        print(projection.alias_or_name)

# Print the name of every table referenced in the query (x, y, z).
for table in parse_one("SELECT * FROM x JOIN y JOIN z").find_all(exp.Table):
    print(table.name)

reference #2: https://stackoverflow.com/questions/70745252/how-to-extract-column-names-from-sql-query-using-python

import sqlglot
import sqlglot.expressions as exp

# Sample query whose output (projection) column names are extracted below.
query = """
select
    sales.order_id as id,
    p.product_name,
    sum(p.price) as sales_volume
from sales
right join products as p
    on sales.product_id=p.product_id
group by id, p.product_name;

"""

# Collected output column names, in the order the projections appear.
column_names = []

# Only the first SELECT found in the parsed query is inspected (`find`, not
# `find_all`). Projections that are neither an Alias nor a Column are skipped,
# since there is no `else` branch.
for expression in sqlglot.parse_one(query).find(exp.Select).args["expressions"]:
    if isinstance(expression, exp.Alias):
        # Aliased projection: record the text of its "alias" argument.
        column_names.append(expression.text("alias"))
    elif isinstance(expression, exp.Column):
        # Plain column: record the text of its "this" argument
        # (presumably the bare column name without the table qualifier).
        column_names.append(expression.text("this"))

print(column_names)