Snorkeling with Snowflake

How to use Ray to execute massively parallel compute on a Snowflake dataset

Ray and its managed offering Anyscale have been making waves lately, primarily because of how straightforward they make it for developers to run massively parallel compute with super simple idioms.

I have been using Snowflake at Toplyne for a long time, and I have been looking for ways to pair Ray with Snowflake to unlock next-level compute power.

The Snowflake Python connector offers a basic API that lets us pull data from Snowflake in batches. So far, I have been using Python multithreading to patch together basic Snowflake workflows: boot up multiple threads and map individual batches of Snowflake data to them. This gives us a modest increase in throughput.

from concurrent.futures import ThreadPoolExecutor

from snowflake.connector import connect
from snowflake.connector.result_batch import ResultBatch

connect_args = {...}
query = "select * from SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.LINEITEM"

def _compute(_sf_batch: ResultBatch):
    arrow_table = _sf_batch.to_arrow()
    # do more compute on this arrow table

with connect(**connect_args) as conn:
    with conn.cursor() as cur:
        cur.execute(query)
        batches = cur.get_result_batches()

# The executor waits for all submitted work and shuts down
# automatically when the `with` block exits.
with ThreadPoolExecutor() as _texec:
    for batch in batches:
        _texec.submit(_compute, batch)

Multithreading is built into Python and works for a lot of general-purpose tooling. It so happens that we can marry Snowflake’s APIs with multithreading and patch together a workflow, but this approach offers limited horizontal scalability: everything still runs within a single process on a single machine.

Let's see what this workflow will look like in Ray:

import ray
import pyarrow as pa

from snowflake.connector import connect

connect_args = {...}
query = "select * from SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.LINEITEM"

def _compute(arrow_table: pa.Table) -> pa.Table:
    # do more compute on this arrow table and return the result
    return arrow_table

# SnowflakeDatasource is the custom datasource we implement below.
snowflake_datasource = SnowflakeDatasource(connect_args, query)
rds = ray.data.read_datasource(snowflake_datasource)
rds = rds.map_batches(_compute, batch_format="pyarrow")

This is what is so exciting about Ray. Simple idioms and maximum compute.

But wait, what is SnowflakeDatasource?

Ray’s ray-data library provides APIs to load data from different sources and ships with a bunch of general-purpose readers for well-known data sources. However, there is currently no built-in reader for Snowflake.

Fortunately, implementing a data source for Snowflake is pretty straightforward.

How do we go about it though? We should collect some data points first:

  1. Ray has a guide wherein the implementation of a Mongo Datasource is described.

  2. Anyscale’s GitHub repo has a fork of Ray data which has an implementation of SnowflakeDataSource as well.

  3. Additionally, Ray has documented the block API which is fundamental to Ray’s internal data representation.

The key observation is that individual Snowflake result batches map neatly onto Ray’s data blocks, as the quick sketch below shows. Based on this, we can start with our implementation.
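
Here is a minimal sketch of that mapping, assuming the connection returns Arrow result batches (the connector’s default). Connection arguments and the query are placeholders:

from snowflake.connector import connect

connect_args = {...}  # your Snowflake account, credentials, etc.
query = "select * from SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.LINEITEM"

with connect(**connect_args) as conn:
    with conn.cursor() as cur:
        cur.execute(query)
        # Lightweight descriptors of each result chunk; no data is pulled yet.
        batches = cur.get_result_batches()

# Each ResultBatch pulls its chunk on demand; to_arrow() yields a
# pyarrow.Table, which is exactly what Ray accepts as a data block.
first_block = batches[0].to_arrow()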

We need to implement two Ray classes and three methods to get the entire thing going (a skeleton follows this list):

  1. ray.data.datasource.datasource.Datasource [source]
    a) create_reader [source]

  2. ray.data.datasource.datasource.Reader [source]
    a) estimate_inmemory_data_size [source]
    b) get_read_tasks [source]
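
Here is a minimal skeleton of those pieces, assuming a Ray 2.x release where the Datasource/Reader API is available (import paths can differ slightly between Ray versions):

from typing import Optional

from ray.data.datasource import Datasource, Reader
from ray.data import ReadTask


class _SnowflakeDatasourceReader(Reader):
    def estimate_inmemory_data_size(self) -> Optional[int]:
        # total bytes we expect the result to occupy once in memory
        ...

    def get_read_tasks(self, parallelism: int) -> list[ReadTask]:
        # one lazy ReadTask per Snowflake result batch
        ...


class SnowflakeDatasource(Datasource):
    def create_reader(self, **read_args) -> Reader:
        # hand Ray a reader bound to our connection args and query
        ...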

Now that we know the which of the Ray-data API, we can dive deeper into the what and the why of each piece:

  1. Reader.get_read_tasks:
    a) Create a Snowflake connection.
    b) Execute the query.
    c) Fetch the snowflake ResultBatches.
    d) Generate read tasks.

    These read tasks are what eventually fetch the data batches from Snowflake. Since Ray’s APIs are lazy, the memory footprint of this step on the driver is minimal.

  2. Reader.estimate_inmemory_data_size:
    Estimate the total size the data will occupy once loaded into memory.

    For our use case, I’ll derive it from the uncompressed size of the Snowflake result batches, which roughly corresponds to the size of the resulting PyArrow tables.

  3. Datasource.create_reader:
    Create an instance of a reader which has implemented the above two methods.

Now that we know the what and the why, let’s get into the how. Most of the explanation lives in the comments of the code below.
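
Two names in the code below aren’t defined by Ray or the Snowflake connector: ray_data_logger and FIELD_TYPE_TO_PA_TYPE. Here is an assumed sketch of what they could look like; the type mapping is purely illustrative, covers only a few common Snowflake type codes, and should be adjusted for your own schemas:

import logging

import pyarrow as pa

# A plain module-level logger; name it however you like.
ray_data_logger = logging.getLogger("snowflake_datasource")

# Illustrative mapping from Snowflake column type codes (as reported in
# ResultBatch.schema / cursor descriptions) to PyArrow types. Extend as needed.
FIELD_TYPE_TO_PA_TYPE = {
    0: pa.decimal128(38, 0),   # FIXED (NUMBER)
    1: pa.float64(),           # REAL
    2: pa.string(),            # TEXT
    3: pa.date32(),            # DATE
    8: pa.timestamp("ns"),     # TIMESTAMP_NTZ
    11: pa.binary(),           # BINARY
    12: pa.time64("ns"),       # TIME
    13: pa.bool_(),            # BOOLEAN
}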

from functools import cached_property
from typing import Iterable, Optional

import pyarrow as pa

from ray.data.datasource import Reader
from ray.data.block import BlockMetadata
from ray.data import ReadTask

from snowflake.connector import connect
from snowflake.connector.result_batch import ResultBatch

# The reader performs the heavy lifting
class _SnowflakeDatasourceReader(Reader):
    def __init__(self, connection_args: dict, query: str):
        # connection info like snowflake account name & credentials.
        self._connection_args = connection_args

        # the query to execute.
        self._query = query

    # this property is reused both for creating the read tasks
    # and for calculating the in-memory size estimate
    @cached_property
    def _result_batches(self):
        # connect with snowflake
        with connect(**self._connection_args) as conn:
            # get the cursor
            with conn.cursor() as cur:
                cur.execute(self._query)
                # Get the result as batches.
                # This API has a minimal memory footprint because
                # the ResultBatch doesn't have any data. It only
                # tells us how to pull the data and what size/schema
                # to expect from this data once it lands.
                # The driver hence won't have a heavy memory footprint
                # and can safely do the work of creating the relevant
                # Blocks for Ray.
                batches = cur.get_result_batches()
        return batches

    def estimate_inmemory_data_size(self) -> Optional[int]:
        sz = None

        for batch in self._result_batches:
            sz = (sz or 0) + (batch.uncompressed_size or 0)

        ray_data_logger.info("Estimating in-memory data size %s", sz)
        return sz

    def get_read_tasks(self, parallelism: int) -> list[ReadTask]:
        read_tasks = []

        for batch in self._result_batches:
            # Map the batch metadata to the ray block metadata.
            metadata = BlockMetadata(
                num_rows=batch.rowcount,
                size_bytes=batch.uncompressed_size,
                schema=pa.schema(
                    [
                        pa.field(
                            s.name,
                            FIELD_TYPE_TO_PA_TYPE[s.type_code]
                        )
                        for s in batch.schema
                    ]
                ),
                input_files=None,
                exec_stats=None
            )

            # create a lazy handler that will load up the
            # ResultBatch in the worker and do the actual
            # pull from snowflake.
            _r_task = LazyReadTask(
                arrow_batch=batch,
                metadata=metadata
            )

            read_tasks.append(_r_task)

        return read_tasks

# This read task is what executes in the worker(s), pulls the data
# from snowflake, and returns a PyArrow table.
class LazyReadTask(ReadTask):
    def __init__(self, arrow_batch: ResultBatch, metadata: BlockMetadata):
        self._arrow_batch = arrow_batch
        self._metadata = metadata

    def _read_fn(self) -> Iterable[pa.Table]:
        ray_data_logger.debug(
            "Reading %s rows from Snowflake", self._metadata.num_rows
        )
        return [self._arrow_batch.to_arrow()]

Woah 😅, that is smooth.

Now let’s quickly wire up the data source itself, which is the piece that plugs into the Ray system.

from ray.data.datasource import Datasource, Reader

class SnowflakeDatasource(Datasource):
    def __init__(self, connection_args: dict, query: str):
        self._connection_args = connection_args
        self._query = query

    def create_reader(self, **read_args) -> Reader:
        # Yesss! This is the Reader you had just implemented.
        return _SnowflakeDatasourceReader(
            connection_args=self._connection_args,
            query=self._query
        )

# This is it. You are not missing anything.
# To reaffirm. This is it. You are not missing anything.
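
To put it all together, here is roughly how you would use it end to end. Connection arguments, the query, and the _compute function are placeholders for your own:

import ray
import pyarrow as pa

connect_args = {...}  # your Snowflake account, user, credentials, etc.
query = "select * from SNOWFLAKE_SAMPLE_DATA.TPCH_SF1.LINEITEM"

def _compute(arrow_table: pa.Table) -> pa.Table:
    # your per-batch compute goes here
    return arrow_table

ray.init()

snowflake_datasource = SnowflakeDatasource(connect_args, query)
rds = ray.data.read_datasource(snowflake_datasource)
rds = rds.map_batches(_compute, batch_format="pyarrow")

# Nothing has been pulled from Snowflake yet; execution is lazy.
# Consuming the dataset (count, take, iteration, writes) triggers the reads.
print(rds.count())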

That is all!

You now have a Snowflake data source. The next time you want to use some Ray goodness on Snowflake, you won’t be left wanting for a fast-reading data source.

You already have it right here.

Now please do the cool stuff and show it to me.

Appendix:

  1. The GitHub repo with my implementation.

  2. The Anyscale blog that motivated me: https://www.anyscale.com/blog/introducing-the-anyscale-snowflake-connector

  3. The corresponding Anyscale fork: https://github.com/anyscale/datasets-database/blob/master/python/ray/data/datasource/snowflake_datasource.py