Apache Spark's RewriteDistinctAggregate Rule
Introduction
These are some of my notes on the optimizer rule RewriteDistinctAggregate. At the time I jotted down these notes, Apache Spark was at version 3.3.x, but I think these notes largely still hold true (as of some preview release of 4.1.0).
These notes help to remind me what RewriteDistinctAggregate is trying to achieve, so that I can understand the context of the rather complex code. Because these notes are for me and how my brain works, they are somewhat (and purposefully) imprecise. My descriptions of the query transformations performed by RewriteDistinctAggregate use SQL syntax rather than plan trees because that's just easier for me to grok. In real life, RewriteDistinctAggregate does not add UNION ALL operators to the query plan, but instead adds an Expand operator. Since there is no SQL syntax for the Expand operator, I use UNION ALL in the descriptions.
This rule goes against my preconceived notion about optimizer rules, namely that they are optional. My previous thought was that optimizer rules make a query more efficient, but the query should otherwise work without them. But if your query has multiple distinct aggregations of different expressions, then this rule must run for the query to return correct results.
For the rest of this post, assume this data (stolen from the rule's top-level comment and then expanded):
create or replace temp view data(key, cat1, cat2, id, value) as values
("a", "ca1", "cb1", 0, 10),
("a", "ca1", "cb2", 2, 5),
("b", "ca1", "cb1", 2, 13),
("b", "ca1", "cb1", 4, 17),
("b", "ca1", "cb3", 5, 2),
("c", "ca1", "cb2", 4, 3);
First case
This case has more than one distinct aggregation, each over a different expression:
select
key,
count(distinct cat1) as cat_cnt1,
count(distinct cat2) as cat_cnt2,
sum(value) as total
from
data
group by
key;
For each key, we want to sum up all values of value, and we also want to count the distinct values of cat1 and of cat2. We can't do that in a single pass: making the rows unique on cat1, for example, would lose data we still need for cat2 and value.
First we need to create 3 groups of rows (groups 0, 1, 2, as indicated by a new column, gid). The view expand1 looks like this:
-- expand output: key, cat1, cat2, gid, value
create or replace temp view expand1 as
select
key,
null as cat1,
null as cat2,
0 as gid,
cast(value as bigint) as value
from data
union all
select
key,
cat1,
null as cat2,
1 as gid,
null as value
from data
union all
select
key,
null as cat1,
cat2,
2 as gid,
null as value
from data;
Output of expand1, if queried:
+---+----+----+---+-----+
|key|cat1|cat2|gid|value|
+---+----+----+---+-----+
|a  |null|null|0  |10   |
|a  |null|null|0  |5    |
|b  |null|null|0  |13   |
|b  |null|null|0  |17   |
|b  |null|null|0  |2    |
|c  |null|null|0  |3    |
|a  |ca1 |null|1  |null |
|a  |ca1 |null|1  |null |
|b  |ca1 |null|1  |null |
|b  |ca1 |null|1  |null |
|b  |ca1 |null|1  |null |
|c  |ca1 |null|1  |null |
|a  |null|cb1 |2  |null |
|a  |null|cb2 |2  |null |
|b  |null|cb1 |2  |null |
|b  |null|cb1 |2  |null |
|b  |null|cb3 |2  |null |
|c  |null|cb2 |2  |null |
+---+----+----+---+-----+
Note that all the rows with gid = 0 have null values for cat1 and cat2, but all the actual values (null or not) for value. Similarly, all rows with gid = 1 have null values for cat2 and value, but all the actual values (null or not) for cat1, etc.
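To make the Expand step concrete, here is a minimal plain-Python sketch (not Spark code) that builds the same 18 rows as expand1 from the data view. The helper name `expand` and the tuple layout `(key, cat1, cat2, gid, value)` are my own illustrative choices, with `None` standing in for SQL null. The rows come out in the same order as the UNION ALL version; as far as I know, Spark's Expand operator instead emits all of an input row's projections together, but row order doesn't matter to the aggregation that follows.

```python
from typing import List, Optional, Tuple

# The data view: (key, cat1, cat2, id, value)
DATA = [
    ("a", "ca1", "cb1", 0, 10),
    ("a", "ca1", "cb2", 2, 5),
    ("b", "ca1", "cb1", 2, 13),
    ("b", "ca1", "cb1", 4, 17),
    ("b", "ca1", "cb3", 5, 2),
    ("c", "ca1", "cb2", 4, 3),
]

# One expand1 row: (key, cat1, cat2, gid, value); None plays the role of SQL null
ExpandRow = Tuple[str, Optional[str], Optional[str], int, Optional[int]]


def expand(rows) -> List[ExpandRow]:
    """Hypothetical helper mirroring expand1: every input row appears once
    per group, with the columns that don't belong to that group nulled out."""
    out: List[ExpandRow] = []
    # gid 0: input of the non-distinct aggregation sum(value)
    out += [(key, None, None, 0, value) for key, _c1, _c2, _id, value in rows]
    # gid 1: child of count(distinct cat1)
    out += [(key, cat1, None, 1, None) for key, cat1, _c2, _id, _v in rows]
    # gid 2: child of count(distinct cat2)
    out += [(key, None, cat2, 2, None) for key, _c1, cat2, _id, _v in rows]
    return out


expanded = expand(DATA)
for row in expanded:
    print(row)
```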
The first aggregation (view agg1) will do two things for us: sum up the expression value and get distinct values for cat1 and cat2 (via the group by key, cat1, cat2, gid):
create or replace temp view agg1 as
select key, cat1, cat2, gid, sum(value) total
from expand1
group by key, cat1, cat2, gid;
Output of agg1, if queried:
+---+----+----+---+-----+
|key|cat1|cat2|gid|total|
+---+----+----+---+-----+
|a  |null|null|0  |15   |
|b  |null|null|0  |32   |
|c  |null|null|0  |3    |
|a  |ca1 |null|1  |null |
|b  |ca1 |null|1  |null |
|c  |ca1 |null|1  |null |
|a  |null|cb1 |2  |null |
|a  |null|cb2 |2  |null |
|b  |null|cb1 |2  |null |
|b  |null|cb3 |2  |null |
|c  |null|cb2 |2  |null |
+---+----+----+---+-----+
Rows with gid = 0 will have null values for cat1 and cat2. So for those rows, grouping by key, cat1, cat2, and gid is essentially the same as grouping by key alone.
However, by including cat1 and cat2 in the group-by, we get the distinct values of cat1 (in group 1) for each value of key, as well as the distinct values of cat2 (in group 2) for each value of key.
The second and final aggregation will perform the distinct aggregations (without the distinct keyword, since the group-by in agg1 already provided distinct values for cat1 and cat2).
The second aggregation also provides the final sum of value by taking the first (and only) non-null sum of value among the rows where gid = 0. Since we're grouping by key, we do this for each value of key.
select
key,
count(cat1) filter (where gid = 1) as cat_cnt1,
count(cat2) filter (where gid = 2) as cat_cnt2,
first(total, true) filter (where gid = 0) as total
from agg1
group by key;
Output of second and final aggregation, which matches the output of the example query:
+---+--------+--------+-----+
|key|cat_cnt1|cat_cnt2|total|
+---+--------+--------+-----+
|c  |1       |1       |3    |
|b  |1       |2       |32   |
|a  |1       |2       |15   |
+---+--------+--------+-----+
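The two aggregations can be checked end to end with another plain-Python sketch (again not Spark code; `expand`, `agg1`, `final_agg`, and `direct` are hypothetical helpers). It runs the rewritten plan on the sample data and compares the result with a naive evaluation of the original query. Here `expand` emits all of a row's projections together, which only changes row order. `None` stands in for SQL null, and the sums follow SQL semantics: nulls are ignored, and a group with no non-null input sums to null.

```python
from typing import Dict, Optional, Tuple

# The data view: (key, cat1, cat2, id, value)
DATA = [
    ("a", "ca1", "cb1", 0, 10),
    ("a", "ca1", "cb2", 2, 5),
    ("b", "ca1", "cb1", 2, 13),
    ("b", "ca1", "cb1", 4, 17),
    ("b", "ca1", "cb3", 5, 2),
    ("c", "ca1", "cb2", 4, 3),
]


def expand(rows):
    # Same three projections as expand1: (key, cat1, cat2, gid, value)
    out = []
    for key, cat1, cat2, _id, value in rows:
        out.append((key, None, None, 0, value))
        out.append((key, cat1, None, 1, None))
        out.append((key, None, cat2, 2, None))
    return out


def agg1(expanded) -> Dict[tuple, Optional[int]]:
    # group by key, cat1, cat2, gid; sum(value) as total
    totals: Dict[tuple, Optional[int]] = {}
    for key, cat1, cat2, gid, value in expanded:
        group = (key, cat1, cat2, gid)
        totals.setdefault(group, None)
        if value is not None:
            totals[group] = (totals[group] or 0) + value
    return totals


def final_agg(agg1_rows) -> Dict[str, Tuple[int, int, Optional[int]]]:
    # group by key; count(cat1) filter (gid = 1), count(cat2) filter (gid = 2),
    # first(total, true) filter (gid = 0)
    result: Dict[str, Tuple[int, int, Optional[int]]] = {}
    for (key, cat1, cat2, gid), total in agg1_rows.items():
        cnt1, cnt2, tot = result.get(key, (0, 0, None))
        if gid == 1 and cat1 is not None:
            cnt1 += 1
        if gid == 2 and cat2 is not None:
            cnt2 += 1
        if gid == 0 and tot is None:
            tot = total
        result[key] = (cnt1, cnt2, tot)
    return result


def direct(rows):
    # Naive evaluation of the original query (the sample data has no nulls)
    out = {}
    for key in sorted({r[0] for r in rows}):
        sub = [r for r in rows if r[0] == key]
        out[key] = (len({r[1] for r in sub}), len({r[2] for r in sub}), sum(r[4] for r in sub))
    return out


rewritten = final_agg(agg1(expand(DATA)))
print(rewritten)  # {'a': (1, 2, 15), 'b': (1, 2, 32), 'c': (1, 1, 3)}
assert rewritten == direct(DATA)
```

Note that agg1 produces 11 groups here, matching the 11 rows of the agg1 output above.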
Note that case #1 doesn't need a non-distinct aggregation (e.g., sum(value)). It just needs more than one distinct aggregation. So the following also qualifies as case #1:
-- also included in the first case
select
key,
count(distinct cat1) as cat_cnt1,
count(distinct cat2) as cat_cnt2
from
data
group by
key;
Here is what happens if there is more than one non-distinct aggregation, for example:
select
key,
count(distinct cat1) as cat_cnt1,
count(distinct cat2) as cat_cnt2,
sum(value) as total,
sum(id) as sum_id
from
data
group by
key;
This is logically broken down into these queries:
create or replace temp view expand1 as
select
key,
null as cat1,
null as cat2,
0 as gid,
value,
id
from data
union all
select
key,
cat1,
null as cat2,
1 as gid,
null as value,
null as id
from data
union all
select
key,
null as cat1,
cat2,
2 as gid,
null as value,
null as id
from data;
create or replace temp view agg1 as
select key, cat1, cat2, gid, sum(value) total, sum(id) as sum_id
from expand1
group by key, cat1, cat2, gid;
select key,
count(cat1) filter (where gid = 1) as cat_cnt1,
count(cat2) filter (where gid = 2) as cat_cnt2,
first(total, true) filter (where gid = 0) as total,
first(sum_id, true) filter (where gid = 0) as sum_id
from agg1
group by key;
Note that the values for both expressions id and value are included in group 0. All values that serve as input to non-distinct aggregations go to group 0.
Here is what happens if the same expression shows up in both a distinct and a non-distinct aggregation, for example:
select
key,
count(distinct cat1) as cat_cnt1,
count(distinct cat2) as cat_cnt2,
sum(value) as total,
max(cat2) as max_cat2
from
data
group by
key;
In this query, cat2 shows up both in a distinct aggregation as well as a non-distinct aggregation. This case is logically broken down into these queries:
create or replace temp view expand1 as
select
key,
null as cat1,
null as cat2,
0 as gid,
value,
cat2 as cat2_2
from data
union all
select
key,
cat1,
null as cat2,
1 as gid,
null as value,
null as cat2_2
from data
union all
select
key,
null as cat1,
cat2,
2 as gid,
null as value,
null as cat2_2
from data;
create or replace temp view agg1 as
select key, cat1, cat2, gid, sum(value) total,
max(cat2_2) as max_cat2
from expand1
group by key, cat1, cat2, gid;
select key,
count(cat1) filter (where gid = 1) as cat_cnt1,
count(cat2) filter (where gid = 2) as cat_cnt2,
first(total, true) filter (where gid = 0) as total,
first(max_cat2, true) filter (where gid = 0) as max_cat2
from agg1
group by key;
Note that cat2 shows up in both group 0 (along with value) and group 2, but in two different positions in the projection: position 5 (zero-based) in group 0, and position 2 in group 2.
To reiterate, the values for all non-distinct aggregations go into group 0.
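The column bookkeeping in the two examples above can be sketched as a small function that lays out the Expand output as group-by columns, then distinct-aggregation children, then gid, then the inputs of non-distinct aggregations. This is a hypothetical helper, not the rule's code: it represents expressions as SQL-ish strings, assumes each distinct aggregation has a single column as its child, and borrows this post's `_2` suffix for an input that clashes with a distinct child (the rule itself works with Catalyst attributes rather than column names).

```python
from typing import List, Tuple


def build_projections(
    group_by: List[str],
    distinct_children: List[str],
    regular_children: List[str],
) -> Tuple[List[str], List[List[str]]]:
    """Hypothetical helper: return the Expand output column names and one
    projection (list of SQL expressions) per group."""
    # Inputs of non-distinct aggregations get their own slots after gid; a name
    # that clashes with a distinct child is renamed (this post's "_2" convention)
    regular_names = [c + "_2" if c in distinct_children else c for c in regular_children]
    layout = group_by + distinct_children + ["gid"] + regular_names

    # Group 0: distinct children are null, all non-distinct inputs pass through
    projections = [group_by + ["null"] * len(distinct_children) + ["0"] + regular_children]
    # Group i: only the i-th distinct child passes through
    for gid, child in enumerate(distinct_children, start=1):
        slots = [c if c == child else "null" for c in distinct_children]
        projections.append(group_by + slots + [str(gid)] + ["null"] * len(regular_children))
    return layout, projections


# count(distinct cat1), count(distinct cat2), sum(value), max(cat2)
layout, projections = build_projections(["key"], ["cat1", "cat2"], ["value", "cat2"])
print(layout)  # ['key', 'cat1', 'cat2', 'gid', 'value', 'cat2_2']
for projection in projections:
    print(projection)
```

Calling `build_projections(["key"], ["cat1", "cat2"], ["value", "id"])` instead gives the layout `key, cat1, cat2, gid, value, id` of the sum(id) example, with both value and id passed through only in group 0.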
Second case
The second case has two distinct aggregations and a non-distinct aggregation with a filter, for example:
select
key,
count(distinct cat1) as cat1_cnt,
count(distinct cat2) as cat2_cnt,
sum(value) filter (where id > 1) as total
from
data
group by
key;
The second case is very much like the first case, except we need to pass through id in group 0 (we also include an id column in groups 1 and 2, but in those rows the value of id is always null). We must do this because id is needed in an aggregation filter in view agg1.
-- expand output: key, cat1, cat2, gid, value, id
-- (id is included because it is used in the *regular* aggregation filter).
create or replace temp view expand1 as
select
key,
null as cat1,
null as cat2,
0 as gid,
cast(value as bigint) as value,
id
from data
union all
select
key,
cat1,
null as cat2,
1 as gid,
null as value,
null as id
from data
union all
select
key,
null as cat1,
cat2,
2 as gid,
null as value,
null as id
from data;
-- Rows with gid = 0 have null values for cat1 and cat2, so for them
-- this is essentially a group by key. They also pass through value
-- and id, so we can compute the filtered sum of value here.
create or replace temp view agg1 as
select key, cat1, cat2, gid,
sum(value) filter (where id > 1) as total
from expand1
group by key, cat1, cat2, gid;
select key,
count(cat1) filter (where gid = 1) as cat1_cnt,
count(cat2) filter (where gid = 2) as cat2_cnt,
first(total, true) filter (where gid = 0) as total
from agg1
group by key;
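A plain-Python sketch of the second case (hypothetical helpers, not Spark code; `None` stands in for SQL null) shows where the filter gets evaluated: in agg1, over the gid = 0 rows, because those are the only rows where Expand passed id through. In groups 1 and 2, id is null, so `id > 1` is null and the row is skipped, which is harmless since value is null there anyway.

```python
from typing import Dict, Optional, Tuple

# The data view: (key, cat1, cat2, id, value)
DATA = [
    ("a", "ca1", "cb1", 0, 10),
    ("a", "ca1", "cb2", 2, 5),
    ("b", "ca1", "cb1", 2, 13),
    ("b", "ca1", "cb1", 4, 17),
    ("b", "ca1", "cb3", 5, 2),
    ("c", "ca1", "cb2", 4, 3),
]


def expand_with_id(rows):
    # (key, cat1, cat2, gid, value, id): id is passed through only in group 0
    out = []
    for key, cat1, cat2, _id, value in rows:
        out.append((key, None, None, 0, value, _id))
        out.append((key, cat1, None, 1, None, None))
        out.append((key, None, cat2, 2, None, None))
    return out


def rewritten(rows) -> Dict[str, Tuple[int, int, Optional[int]]]:
    # agg1: group by key, cat1, cat2, gid; sum(value) filter (where id > 1)
    agg1: Dict[tuple, Optional[int]] = {}
    for key, cat1, cat2, gid, value, _id in expand_with_id(rows):
        group = (key, cat1, cat2, gid)
        agg1.setdefault(group, None)
        # A null id makes the filter null, so the row is skipped
        if _id is not None and _id > 1 and value is not None:
            agg1[group] = (agg1[group] or 0) + value

    # final: group by key; counts per gid, first non-null total from gid 0
    final: Dict[str, Tuple[int, int, Optional[int]]] = {}
    for (key, cat1, cat2, gid), total in agg1.items():
        cnt1, cnt2, tot = final.get(key, (0, 0, None))
        if gid == 1 and cat1 is not None:
            cnt1 += 1
        if gid == 2 and cat2 is not None:
            cnt2 += 1
        if gid == 0 and tot is None:
            tot = total
        final[key] = (cnt1, cnt2, tot)
    return final


def direct(rows):
    # Naive evaluation of the original query (the sample data has no nulls)
    out = {}
    for key in sorted({r[0] for r in rows}):
        sub = [r for r in rows if r[0] == key]
        kept = [r[4] for r in sub if r[3] > 1]
        out[key] = (len({r[1] for r in sub}), len({r[2] for r in sub}), sum(kept) if kept else None)
    return out


result = rewritten(DATA)
print(result)  # {'a': (1, 2, 5), 'b': (1, 2, 32), 'c': (1, 1, 3)}
assert result == direct(DATA)
```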
Third case
The third case has a distinct aggregation with a filter. This case does not require a non-distinct aggregation or multiple distinct aggregations. For example:
select
key,
count(distinct cat1) filter (where id > 1) as cat1_cnt,
count(distinct cat2) filter (where id > 2) as cat2_cnt,
sum(value) filter (where id > 3) as total
from
data
group by
key;
This case is logically broken down into several queries, starting with the following:
-- expand output is key, cat1, cat2, gid, cond1, cond2, value, id
create or replace temp view expand1 as
select
key,
null as cat1,
null as cat2,
0 as gid,
null as cond1,
null as cond2,
cast(value as bigint) as value,
id -- only because the non-distinct aggregation uses it
from data
union all
select
key,
cat1,
null as cat2,
1 as gid,
(id > 1) as cond1,
null as cond2,
null as value,
null as id
from data
union all
select
key,
null as cat1,
cat2,
2 as gid,
null as cond1,
(id > 2) as cond2,
null as value,
null as id
from data;
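id is included in the group 0 projection only because it is used in the filter of the regular (non-distinct) aggregation. If sum had not used it, id would not be projected at all, because the filters of the distinct aggregations are already evaluated in expand1: group 1 carries cond1 and group 2 carries cond2.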
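That rule of thumb can be written down as a tiny hypothetical helper (my own simplification, not the rule's code): group 0 carries the input of each non-distinct aggregation plus any columns its FILTER clause references. The `RegularAgg` type and the `group0_columns` function are illustrative names, and filters are represented only by the set of columns they reference.

```python
from typing import FrozenSet, List, NamedTuple


class RegularAgg(NamedTuple):
    """A non-distinct aggregation such as sum(value) filter (where id > 3)."""
    child: str                                 # input column, e.g. "value"
    filter_refs: FrozenSet[str] = frozenset()  # columns its FILTER clause references


def group0_columns(aggs: List[RegularAgg]) -> List[str]:
    """Hypothetical helper: columns that group 0 of the Expand must carry,
    in first-seen order and without duplicates."""
    columns: List[str] = []
    for agg in aggs:
        for col in [agg.child, *sorted(agg.filter_refs)]:
            if col not in columns:
                columns.append(col)
    return columns


# sum(value) filter (where id > 3): id must be carried along with value
print(group0_columns([RegularAgg("value", frozenset({"id"}))]))  # ['value', 'id']
# sum(value) without a filter: only value is needed
print(group0_columns([RegularAgg("value")]))  # ['value']
```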
create or replace temp view agg1 as
select
key,
cat1,
cat2,
gid,
max(cond1) as max_cond1,
max(cond2) as max_cond2,
sum(value) filter (where id > 3) as total
from
expand1
group by
key, cat1, cat2, gid;
select
key,
count(cat1) filter (where gid = 1 and max_cond1) as cat1_cnt,
count(cat2) filter (where gid = 2 and max_cond2) as cat2_cnt,
first(total, true) filter (where gid = 0) as total
from
agg1
group by
key;
Rows with gid = 0 have null values for cat1, cat2, max_cond1, and max_cond2, so for them agg1 is essentially a group by key. They also pass through value and id, so we can compute the filtered sum of value there. We use max(cond1) because we care only whether at least one row for a particular key/cat1 combination satisfies cond1. The same applies to max(cond2).
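Here is the third case as a plain-Python sketch (hypothetical helpers, not Spark code; `None` stands in for SQL null, and agg1 and the final aggregation are folded into one loop for brevity). The interesting part is max(cond1): for key a, the two ca1 rows have cond1 values false (id 0) and true (id 2), so the single (a, ca1) group survives the filter in the final aggregation and is counted once, just as count(distinct cat1) filter (where id > 1) would count it.

```python
from typing import Dict, List, Optional, Tuple

# The data view: (key, cat1, cat2, id, value)
DATA = [
    ("a", "ca1", "cb1", 0, 10),
    ("a", "ca1", "cb2", 2, 5),
    ("b", "ca1", "cb1", 2, 13),
    ("b", "ca1", "cb1", 4, 17),
    ("b", "ca1", "cb3", 5, 2),
    ("c", "ca1", "cb2", 4, 3),
]


def expand3(rows):
    # (key, cat1, cat2, gid, cond1, cond2, value, id), as in expand1 above
    out = []
    for key, cat1, cat2, _id, value in rows:
        out.append((key, None, None, 0, None, None, value, _id))
        out.append((key, cat1, None, 1, _id > 1, None, None, None))
        out.append((key, None, cat2, 2, None, _id > 2, None, None))
    return out


def sql_max(values) -> Optional[bool]:
    # max() ignores nulls and returns null for an all-null group; True > False
    present = [v for v in values if v is not None]
    return max(present) if present else None


def rewritten(rows) -> Dict[str, Tuple[int, int, Optional[int]]]:
    # agg1: group by key, cat1, cat2, gid
    groups: Dict[tuple, List[tuple]] = {}
    for key, cat1, cat2, gid, cond1, cond2, value, _id in expand3(rows):
        groups.setdefault((key, cat1, cat2, gid), []).append((cond1, cond2, value, _id))

    final: Dict[str, Tuple[int, int, Optional[int]]] = {}
    for (key, cat1, cat2, gid), members in groups.items():
        max_cond1 = sql_max(m[0] for m in members)
        max_cond2 = sql_max(m[1] for m in members)
        # sum(value) filter (where id > 3); a null id makes the filter null
        kept = [m[2] for m in members if m[3] is not None and m[3] > 3 and m[2] is not None]
        total = sum(kept) if kept else None

        # final aggregation, grouped by key (a null filter counts as false)
        cnt1, cnt2, tot = final.get(key, (0, 0, None))
        if gid == 1 and cat1 is not None and max_cond1:
            cnt1 += 1
        if gid == 2 and cat2 is not None and max_cond2:
            cnt2 += 1
        if gid == 0 and tot is None:
            tot = total
        final[key] = (cnt1, cnt2, tot)
    return final


def direct(rows):
    # Naive evaluation of the original query (the sample data has no nulls)
    out = {}
    for key in sorted({r[0] for r in rows}):
        sub = [r for r in rows if r[0] == key]
        kept = [r[4] for r in sub if r[3] > 3]
        out[key] = (
            len({r[1] for r in sub if r[3] > 1}),
            len({r[2] for r in sub if r[3] > 2}),
            sum(kept) if kept else None,
        )
    return out


result = rewritten(DATA)
print(result)  # {'a': (1, 0, None), 'b': (1, 2, 19), 'c': (1, 1, 3)}
assert result == direct(DATA)
```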
Note that even the following qualifies as case #3:
select
key,
count(distinct cat1) filter (where id > 1) as cat1_cnt
from
data
group by
key;