Using custom Spark clusters (and interactive PySpark!) with SageMaker

Amazon EMR seems like the natural choice for running production Spark clusters on AWS, but it’s less suited to development because it doesn’t support interactive PySpark sessions (at least as of the time of writing). Rolling a custom Spark cluster therefore seems to be the only option, particularly if you’re developing with SageMaker.

This isn’t actually as daunting as it sounds. You can do it in three easy steps:

  1. Roll a custom Spark cluster with flintrock
  2. Install Spark binaries on your SageMaker notebook instance
  3. Install PySpark and connect to your cluster from SageMaker

Rolling a custom cluster with flintrock

Flintrock is a simple command-line tool that allows you to orchestrate and administrate Spark clusters on EC2 with minimal configuration and hassle. Once you have a minimal configuration defined, spinning up a brand new cluster with tens of nodes takes less than five minutes.
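For reference, a minimal config.yml might look something like the sketch below - every value here is a placeholder, so substitute your own key pair, region, AMI and instance details:

```yaml
services:
  spark:
    version: 2.4.4  # Match this to the Spark version you'll install on SageMaker

provider: ec2

providers:
  ec2:
    key-name: my-key-pair
    identity-file: /path/to/my-key-pair.pem
    instance-type: m5.large
    region: us-east-1
    ami: ami-xxxxxxxx  # An Amazon Linux AMI in your region
    user: ec2-user

launch:
  num-slaves: 10
```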

I won’t go into deep detail about the flintrock setup (for that, see Heather Miller’s excellent and very detailed post here), but one thing to note is that it’s important to create a custom security group for your cluster (that is, aside from the flintrock and flintrock-* groups that flintrock creates itself) so that you can add your SageMaker notebook instance to the same group later. Otherwise, you’ll end up with siloed resources that can’t talk to one another. You might be tempted to simply add your notebook to the flintrock-* security group, but this can cause problems later if you try to terminate your cluster.
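As a rough sketch, creating such a group with the AWS CLI might look like the following (the group name and VPC ID are placeholders). The self-referencing ingress rule lets all members of the group - the cluster nodes and, later, your notebook instance - talk to one another:

```
# Create a dedicated security group for the cluster (placeholder name and VPC ID)
group_id=$(aws ec2 create-security-group \
    --group-name "my-spark-cluster" \
    --description "Shared group for Spark cluster and SageMaker notebook" \
    --vpc-id "vpc-xxxxxxxx" | jq -r '.GroupId')

# Allow all traffic between members of the group itself
aws ec2 authorize-security-group-ingress \
    --group-id "$group_id" \
    --protocol all \
    --source-group "$group_id"
```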

Once your cluster has been created, run the following script to find the private IP address of your cluster master on the subnet where you’ll place your SageMaker notebook instance - you’ll need this later in order to point PySpark at your cluster master.

# Get the public DNS of your cluster master
master_dns=$(flintrock --config "config.yml" describe "my-cluster" | grep "master" | awk -F "master: " '{ print $2}')

# Look up the private IP address of your cluster master
aws ec2 describe-instances --filters "Name=dns-name,Values=$master_dns" | jq -r '.Reservations[0].Instances[0].PrivateIpAddress'

Installing Spark binaries on SageMaker

Next, create a SageMaker notebook instance and assign it to the same VPC, subnet and security group that you specified when creating your cluster with flintrock. Start the instance, open a terminal and run the following script to set the environment variables and install the Spark binaries:

sudo -i <<'EOF'

# Set environment variables
SPARK_LOCAL_IP="A.B.C.D" # This should be the IP address of your SageMaker notebook instance on the same subnet that the cluster is running on
PYSPARK_PYTHON="python3" # Or whatever Python you like
SPARK_HOME="/opt/spark" # Where Spark will be installed
url="https://archive.apache.org/dist/spark/spark-2.4.4/spark-2.4.4-bin-hadoop2.7.tgz" # Pick the release that matches your cluster
env_file="/etc/profile.d/spark.sh" # Where the exports will be persisted

touch "$env_file"
{
  echo "export SPARK_HOME=$SPARK_HOME"
  echo "export PATH=$SPARK_HOME/bin:\$PATH"
  echo "export SPARK_LOCAL_IP=$SPARK_LOCAL_IP"
  echo "export PYSPARK_PYTHON=$PYSPARK_PYTHON"
} >> "$env_file"
source "$env_file"

# Download and install Spark
filename="$(basename "$url")"
if [ ! -d "$SPARK_HOME" ]; then
  wget "$url"
  tar -xzf "$filename"
  mv "${filename%.*}" "$SPARK_HOME"
fi

# Restart the Jupyter server for changes to take effect
initctl restart jupyter-server --no-wait
EOF


If you use lifecycle configuration scripts, you can add the code between the EOF markers above to your lifecycle configuration to automatically install Spark each time you start your instance, although you’ll need to figure out how to compute the value of SPARK_LOCAL_IP programmatically (in my case, I filter the output of hostname -I based on the CIDR block assigned to my VPC).
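One way to sketch that computation is a small Python helper that filters the addresses reported by hostname -I against the VPC’s network - the CIDR block below is a placeholder, so use your VPC’s own:

```python
import ipaddress
import subprocess

def pick_local_ip(addresses, vpc_cidr):
    """Return the first address that falls inside the given CIDR block, or None."""
    network = ipaddress.ip_network(vpc_cidr)
    return next((a for a in addresses if ipaddress.ip_address(a) in network), None)

if __name__ == '__main__':
    try:
        # `hostname -I` lists every IP address assigned to the instance
        ips = subprocess.check_output(['hostname', '-I'], text=True).split()
        print(pick_local_ip(ips, '172.31.0.0/16'))  # Placeholder CIDR block
    except (OSError, subprocess.CalledProcessError):
        pass  # hostname -I is Linux-specific
```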

Installing PySpark and connecting to the cluster

Next, select (or create) a conda environment to work with and install the pyspark package:

# Activate your conda environment
source /home/ec2-user/anaconda3/bin/activate my-env

# Install the pyspark package
pip install pyspark

Finally, create a new Jupyter notebook using the kernel from your updated conda environment and try a word count example to test things out:

from pyspark.sql import SparkSession

master_ip = '...' # Whatever your cluster master private IP was earlier
spark = SparkSession.builder \
                    .master(f'spark://{master_ip}:7077') \
                    .getOrCreate()
sc = spark.sparkContext

text = 'the quick brown fox jumps over the lazy dog'
rdd = sc.parallelize([text])
counts = rdd.flatMap(lambda line: line.split()) \
            .map(lambda word: (word, 1)) \
            .reduceByKey(lambda x, y: x + y)
print(counts.collect())
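As a sanity check, the same word count can be computed in plain Python without a cluster, so you know what the RDD pipeline should produce:

```python
from collections import Counter

text = 'the quick brown fox jumps over the lazy dog'
expected = Counter(text.split())
print(expected['the'])  # 2 - 'the' is the only word that appears twice
```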