| """ |
| This module defines a class that creates a connection to Snowflake and runs queries (read and write). |
| """ |
| import json |
| import os |
| from snowflake.snowpark import Session |
| from dotenv import load_dotenv |
| import logging |
# Module-level logger named after this module (not the root logger), so its
# output can be filtered and configured independently by the application.
logger = logging.getLogger(__name__)
# Populate os.environ from a .env file (python-dotenv) at import time so the
# SNOWFLAKE_* settings can be read with os.getenv.
load_dotenv()
|
|
class SnowFlakeConn:
    """Thin wrapper around a Snowpark ``Session`` for running read and write queries.

    Connection parameters come from ``SNOWFLAKE_*`` environment variables
    (see ``connect_to_snowflake``). The session is opened in ``__init__``.
    Release it with ``close_connection()``, or use the object as a context
    manager (``with SnowFlakeConn() as conn: ...``), which closes the session
    on exit.

    The query helpers log failures, including the traceback, and return
    ``None`` instead of raising, so callers should check for ``None``.
    """

    def __init__(self):
        """Open a Snowpark session right away, using environment credentials."""
        self.session = self.connect_to_snowflake()

    def __enter__(self):
        """Return this connection so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Close the session when leaving a ``with`` block; exceptions are not suppressed."""
        self.close_connection()
        return False

    def connect_to_snowflake(self):
        """Create and return a Snowpark session built from environment variables.

        Reads SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT,
        SNOWFLAKE_ROLE, SNOWFLAKE_DATABASE, SNOWFLAKE_WAREHOUSE and
        SNOWFLAKE_SCHEMA. Unset variables are passed through as ``None``.

        :return: the session created by ``Session.builder.configs(...).create()``
        """
        conn = {
            "user": self.get_credential("SNOWFLAKE_USER"),
            "password": self.get_credential("SNOWFLAKE_PASSWORD"),
            "account": self.get_credential("SNOWFLAKE_ACCOUNT"),
            "role": self.get_credential("SNOWFLAKE_ROLE"),
            "database": self.get_credential("SNOWFLAKE_DATABASE"),
            "warehouse": self.get_credential("SNOWFLAKE_WAREHOUSE"),
            "schema": self.get_credential("SNOWFLAKE_SCHEMA"),
        }
        session = Session.builder.configs(conn).create()
        return session

    def get_credential(self, key, default=None):
        """Return the value of environment variable ``key``.

        :param key: environment variable name
        :param default: value returned when the variable is unset (``None`` by default)
        :return: the variable's value, or ``default``
        """
        return os.getenv(key, default)

    def run_read_query(self, query, data):
        """Execute a read query on Snowflake and return the rows as a pandas DataFrame.

        :param query: SQL query string to execute
        :param data: label for the table/data being read; used only in log messages
        :return: DataFrame with lower-cased column names, or ``None`` if anything failed
        """
        try:
            dataframe = self.session.sql(query).to_pandas()
            # Lower-case column names (the write path upper-cases them).
            dataframe.columns = dataframe.columns.str.lower()
        except Exception:
            # Failures are logged with a traceback instead of being raised.
            logger.exception("Error reading %s table", data)
            return None
        logger.info("reading %s table successfully", data)
        return dataframe

    def store_df_to_snowflake(self, table_name, dataframe, database="SOCIAL_MEDIA_DB",
                              schema="ML_FEATURES", overwrite=False):
        """Write a pandas DataFrame to a Snowflake table, creating the table if needed.

        The session's current database and schema are switched to ``database`` /
        ``schema`` before writing, and that switch persists for later queries
        on this session.

        :param table_name: target table name; whitespace is stripped and it is upper-cased
        :param dataframe: pandas DataFrame to write. It is not modified; a re-indexed
            copy with upper-cased string column names is written instead.
        :param database: database to switch the session to before writing
        :param schema: schema to switch the session to before writing
        :param overwrite: forwarded to ``Session.write_pandas``. When False, rows are
            appended; see the Snowpark docs for the replace semantics when True.
        :return: None. Failures are logged, not raised.
        """
        try:
            self.session.use_database(database)
            self.session.use_schema(schema)

            # reset_index returns a new frame, so the caller's DataFrame is untouched.
            frame = dataframe.reset_index(drop=True)
            # str() first: `.str.upper()` raises on integer labels and yields NaN
            # for non-string labels in a mixed object index.
            frame.columns = [str(col).upper() for col in frame.columns]

            self.session.write_pandas(
                df=frame,
                table_name=table_name.strip().upper(),
                auto_create_table=True,
                overwrite=overwrite,
                # Lets Snowflake interpret Parquet logical types on load
                # (see the Snowpark write_pandas docs).
                use_logical_type=True,
            )
        except Exception:
            logger.exception("Error in creating/updating/inserting table %s", table_name)
            return None
        logger.info("Data inserted into %s successfully.", table_name)
        return None

    def execute_sql_file(self, file_path):
        """Read a SQL file and execute its contents with a single ``session.sql`` call.

        NOTE(review): the whole file is sent as one SQL string. A file holding
        several ``;``-separated statements may be rejected; confirm before
        relying on multi-statement files.

        :param file_path: path to a UTF-8 encoded SQL file
        :return: the result of ``collect()``, or ``None`` if reading or executing failed
        """
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                sql_content = file.read()
            result = self.session.sql(sql_content).collect()
        except Exception:
            logger.exception("Error executing SQL file %s", file_path)
            return None
        logger.info("Successfully executed SQL from %s", file_path)
        return result

    def execute_query(self, query, description="query"):
        """Execute a SQL query and return its collected results.

        :param query: SQL query string
        :param description: human-readable label for the query, used in log messages
        :return: the result of ``collect()``, or ``None`` if execution failed
        """
        try:
            result = self.session.sql(query).collect()
        except Exception:
            logger.exception("Error executing %s", description)
            return None
        logger.info("Successfully executed %s", description)
        return result

    def get_data(self, data):
        """Placeholder, not implemented yet. Always returns ``None``."""
        # TODO: implement or remove; the method currently has no behavior.
        return None

    def close_connection(self):
        """Close the underlying Snowpark session."""
        self.session.close()
|
|
|
|