Scale Apache Airflow on ECS

At $day_job we use Apache Airflow to process aerial imagery. Each flight is processed by a dynamically generated DAG, and these DAGs vary in size: some have a few hundred tasks, others have thousands. Depending on the time of year, we might be processing just a couple of flights at the same time, or around 30.

Our Airflow instance runs on AWS ECS, and since our processing patterns are so uneven, we originally leveraged Application Auto Scaling based on CPU usage. We needed to scale workers on demand as new DAGs start and kill them when things are quieter.

Scaling on CPU usage certainly works, but it was too slow for our purposes. For example, one of the workers was GPU-intensive, but it didn’t stress the CPU enough to trigger the scale-up event with the responsiveness we needed.

How did we improve this? We realized that this isn’t a web app where the number of visitors can be totally unpredictable. We actually know how many tasks are waiting to be executed! Why don’t we just scale the workers based on the Airflow task queue size?

Our original idea was to write a totally custom scaling engine, starting and stopping the tasks ourselves through the ECS API. Madness. There had to be a better way… and it turns out there is.

ECS supports a scaling policy type called target tracking, which is exactly what we need: it adds or removes tasks to keep a chosen metric close to a target value, and that metric can be a custom one.

All we needed then was to turn the queue size into that metric. The only caveat is that the policy reads its metric data from CloudWatch. The solution? CloudWatch custom metrics.
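For reference, here’s a minimal sketch of what that policy could look like with boto3’s Application Auto Scaling client. The cluster, service, and policy names are made up, the target value is illustrative, and the metric namespace and dimension match what we publish below:

import boto3

autoscaling = boto3.client("application-autoscaling")

# Attach a target tracking policy to the ECS service that runs the workers.
# Resource and policy names are illustrative, not our actual setup.
autoscaling.put_scaling_policy(
    PolicyName="airflow-gpu-workers-queue-size",
    ServiceNamespace="ecs",
    ResourceId="service/airflow-cluster/gpu-workers",
    ScalableDimension="ecs:service:DesiredCount",
    PolicyType="TargetTrackingScaling",
    TargetTrackingScalingPolicyConfiguration={
        # Roughly how many queued/running Airflow tasks we want per ECS task.
        "TargetValue": 5.0,
        "CustomizedMetricSpecification": {
            "Namespace": "Airflow",
            "MetricName": "gpu_queue_size",
            "Dimensions": [{"Name": "Environment", "Value": "production"}],
            "Statistic": "Average",
        },
        "ScaleOutCooldown": 60,
        "ScaleInCooldown": 120,
    },
)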

And how can we publish the task queue size to CloudWatch? Well, I can think of several ways, but in the end we used a Lambda function that runs every two minutes to write the custom metrics. The code that actually writes the data looks like this (we have two environments for Airflow, staging and production):

import logging

import boto3

logger = logging.getLogger(__name__)


def publish_cloudwatch_metric(environment: str, metric_name: str, metric_value: int):
    """Publish one data point for a custom metric in the Airflow namespace."""
    cloudwatch = boto3.client("cloudwatch")
    response = cloudwatch.put_metric_data(
        MetricData=[
            {
                "MetricName": metric_name,
                "Dimensions": [
                    {
                        "Name": "Environment",
                        "Value": environment,
                    }
                ],
                "Unit": "None",
                "Value": metric_value,
            },
        ],
        Namespace="Airflow",
    )
    if response["ResponseMetadata"]["HTTPStatusCode"] == 200:
        logger.info(f"Published {environment}/{metric_name}: {metric_value}")
    else:
        logger.error(
            f"Error publishing {environment}/{metric_name}: {metric_value} - {response['ResponseMetadata']}"
        )

The task queue size can be read straight from the Airflow database (I’m sure newer versions of Airflow have a nice REST API for this).

import psycopg2.extras


def get_active_airflow_tasks(queue_names: list, environment: str) -> int:
    """Retrieves the number of active (queued or running) Airflow tasks from running DAGs for the given queues."""
    # get_airflow_db_connection is our helper that returns a psycopg2 connection
    # to the metadata database of the given environment.
    airflow_conn = get_airflow_db_connection(environment)
    with airflow_conn:
        with airflow_conn.cursor(
            cursor_factory=psycopg2.extras.NamedTupleCursor
        ) as cursor:
            query = """SELECT COUNT(DISTINCT ti.task_id)
                        FROM task_instance ti
                                 INNER JOIN dag_run dr ON ti.dag_id = dr.dag_id
                        WHERE ti.state IN ('queued', 'running')
                          AND ti.queue IN %(queue_names)s
                          AND dr.state = 'running'"""
            cursor.execute(query, {"queue_names": tuple(queue_names)})
            active_tasks = cursor.fetchone()

            return active_tasks.count

We have several queues based on the type of task; this function reads the count for a given set of queues.
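For completeness, this is roughly what the Lambda handler that glues the two functions together could look like. The queue and metric names here are placeholders, not our actual configuration:

# Map each custom metric to the Airflow queues it should count.
# These names are illustrative.
QUEUE_METRICS = {
    "gpu_queue_size": ["gpu"],
    "default_queue_size": ["default"],
}


def lambda_handler(event, context):
    # Triggered on a schedule (every two minutes in our case).
    for environment in ("staging", "production"):
        for metric_name, queue_names in QUEUE_METRICS.items():
            active_tasks = get_active_airflow_tasks(queue_names, environment)
            publish_cloudwatch_metric(environment, metric_name, active_tasks)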

Obviously, we capped the maximum number of ECS tasks at values that make sense for our use case; there isn’t a 1:1 correspondence between queue size and ECS tasks.
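That cap lives in the scalable target registration, which is where Application Auto Scaling keeps the minimum and maximum capacity for the service. Again, a sketch with made-up names and limits:

import boto3

autoscaling = boto3.client("application-autoscaling")

# Register the ECS service as a scalable target with hard lower and upper
# bounds; the target tracking policy can never scale past these limits.
autoscaling.register_scalable_target(
    ServiceNamespace="ecs",
    ResourceId="service/airflow-cluster/gpu-workers",
    ScalableDimension="ecs:service:DesiredCount",
    MinCapacity=1,
    MaxCapacity=10,
)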

And that’s about it. Once we replaced the old scaling policy with the new target tracking policy, things improved significantly: workers scale up and down much faster than before, which means our DAGs take less overall time to run.