Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ use datafusion_common::tree_node::{
};
use datafusion_common::utils::get_at_indices;
use datafusion_common::{
internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, HashMap,
Result, TableReference,
internal_err, plan_err, Column, DFSchema, DFSchemaRef, HashMap, Result,
TableReference,
};

#[cfg(not(feature = "sql"))]
Expand Down Expand Up @@ -66,6 +66,23 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
}
}

/// Internal helper that enumerates every subset of `0..len` as a list of indices.
///
/// Subsets are produced by counting a bitmask from `0` up to `2^len - 1`; bit `i`
/// of the mask being set means index `i` belongs to that subset. Consequently the
/// first item is always the empty subset, the last one contains every index, and
/// the indices inside each subset are in ascending order.
///
/// # Panics
///
/// Panics if `len >= 64`, because the bitmask is a `u64` and `1 << len` would
/// overflow (which, without overflow checks, silently produces a wrong range).
fn powerset_indices(len: usize) -> impl Iterator<Item = Vec<usize>> {
    // Fail loudly and identically in every build profile instead of relying on
    // the shift-overflow check that only exists when overflow checks are enabled.
    assert!(
        len < 64,
        "powerset_indices supports at most 63 elements, got {}",
        len
    );

    (0..(1u64 << len)).map(|mask| {
        // The number of set bits is exactly the size of this subset.
        let mut indices = Vec::with_capacity(mask.count_ones() as usize);
        let mut bitset = mask;
        while bitset > 0 {
            // Position of the lowest set bit is the next (smallest) member index.
            indices.push(bitset.trailing_zeros() as usize);
            // Clear the lowest set bit.
            bitset &= bitset - 1;
        }
        indices
    })
}

/// The [power set] (or powerset) of a set S is the set of all subsets of S, \
/// including the empty set and S itself.
///
Expand All @@ -83,26 +100,14 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
/// and hence the power set of S is {{}, {x}, {y}, {z}, {x, y}, {x, z}, {y, z}, {x, y, z}}.
///
/// [power set]: https://en.wikipedia.org/wiki/Power_set
fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> {
pub fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>> {
if slice.len() >= 64 {
return Err("The size of the set must be less than 64.".into());
return plan_err!("The size of the set must be less than 64");
}

let mut v = Vec::new();
for mask in 0..(1 << slice.len()) {
let mut ss = vec![];
let mut bitset = mask;
while bitset > 0 {
let rightmost: u64 = bitset & !(bitset - 1);
let idx = rightmost.trailing_zeros();
let item = slice.get(idx as usize).unwrap();
ss.push(item);
// zero the trailing bit
bitset &= bitset - 1;
}
v.push(ss);
}
Ok(v)
Ok(powerset_indices(slice.len())
.map(|indices| indices.iter().map(|&idx| &slice[idx]).collect())
.collect())
}

/// check the number of expressions contained in the grouping_set
Expand Down Expand Up @@ -207,8 +212,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
grouping_sets.iter().map(|e| e.iter().collect()).collect()
}
Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
let grouping_sets = powerset(group_exprs)
.map_err(|e| plan_datafusion_err!("{}", e))?;
let grouping_sets = powerset(group_exprs)?;
check_grouping_sets_size_limit(grouping_sets.len())?;
grouping_sets
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
use crate::logical_plan::producer::{
from_aggregate_function, substrait_field_ref, SubstraitProducer,
};
use datafusion::common::{internal_err, not_impl_err, DFSchemaRef, DataFusionError};
use datafusion::common::{internal_err, not_impl_err, DFSchemaRef};
use datafusion::logical_expr::expr::Alias;
use datafusion::logical_expr::utils::powerset;
use datafusion::logical_expr::{Aggregate, Distinct, Expr, GroupingSet};
use substrait::proto::aggregate_rel::{Grouping, Measure};
use substrait::proto::rel::RelType;
Expand Down Expand Up @@ -91,10 +92,22 @@ pub fn to_substrait_groupings(
let groupings = match exprs.len() {
1 => match &exprs[0] {
Expr::GroupingSet(gs) => match gs {
GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented(
"GroupingSet CUBE is not yet supported".to_string(),
)),
GroupingSet::GroupingSets(sets) => Ok(sets
GroupingSet::Cube(set) => {
// Generate power set of grouping expressions
let cube_sets = powerset(set)?;
cube_sets
.iter()
.map(|set| {
parse_flat_grouping_exprs(
producer,
&set.iter().map(|v| (*v).clone()).collect::<Vec<_>>(),
schema,
&mut ref_group_exprs,
)
})
.collect::<datafusion::common::Result<Vec<_>>>()
}
GroupingSet::GroupingSets(sets) => sets
.iter()
.map(|set| {
parse_flat_grouping_exprs(
Expand All @@ -104,14 +117,13 @@ pub fn to_substrait_groupings(
&mut ref_group_exprs,
)
})
.collect::<datafusion::common::Result<Vec<_>>>()?),
.collect::<datafusion::common::Result<Vec<_>>>(),
GroupingSet::Rollup(set) => {
let mut sets: Vec<Vec<Expr>> = vec![vec![]];
for i in 0..set.len() {
sets.push(set[..=i].to_vec());
}
Ok(sets
.iter()
sets.iter()
.rev()
.map(|set| {
parse_flat_grouping_exprs(
Expand All @@ -121,7 +133,7 @@ pub fn to_substrait_groupings(
&mut ref_group_exprs,
)
})
.collect::<datafusion::common::Result<Vec<_>>>()?)
.collect::<datafusion::common::Result<Vec<_>>>()
}
},
_ => Ok(vec![parse_flat_grouping_exprs(
Expand Down
20 changes: 20 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,26 @@ async fn aggregate_grouping_rollup() -> Result<()> {
Ok(())
}

/// Plans a `GROUP BY CUBE (a, c)` aggregate and snapshot-checks the resulting plan.
///
/// The expected snapshot shows the CUBE written out as explicit grouping sets:
/// `()`, `(data.a)`, `(data.c)` and `(data.a, data.c)`.
#[tokio::test]
async fn aggregate_grouping_cube() -> Result<()> {
// NOTE(review): the two `true` flags are interpreted by `generate_plan_from_sql`,
// which is defined outside this view — confirm their meaning there.
let plan = generate_plan_from_sql(
"SELECT a, c, avg(b) FROM data GROUP BY CUBE (a, c)",
true,
true,
)
.await?;

// Inline snapshot of the plan text; the grouping sets listed here are the
// power set of the two CUBE columns.
assert_snapshot!(
plan,
@r#"
Projection: data.a, data.c, avg(data.b)
Aggregate: groupBy=[[GROUPING SETS ((), (data.a), (data.c), (data.a, data.c))]], aggr=[[avg(data.b)]]
TableScan: data projection=[a, b, c]
"#
);
Ok(())
}

#[tokio::test]
async fn multilayer_aggregate() -> Result<()> {
let plan = generate_plan_from_sql(
Expand Down