Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,13 +706,6 @@ impl Unparser<'_> {
Ok(())
}
LogicalPlan::Union(union) => {
if union.inputs.len() != 2 {
return not_impl_err!(
"UNION ALL expected 2 inputs, but found {}",
union.inputs.len()
);
}

// Covers cases where the UNION is a subquery and the projection is at the top level
if select.already_projected() {
return self.derive_with_dialect_alias(
Expand All @@ -729,12 +722,22 @@ impl Unparser<'_> {
.map(|input| self.select_to_sql_expr(input, query))
.collect::<Result<Vec<_>>>()?;

let union_expr = SetExpr::SetOperation {
op: ast::SetOperator::Union,
set_quantifier: ast::SetQuantifier::All,
left: Box::new(input_exprs[0].clone()),
right: Box::new(input_exprs[1].clone()),
};
if input_exprs.len() < 2 {
return internal_err!("UNION operator requires at least 2 inputs");
}

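// Build the union expression tree bottom-up by folding over the inputs in reverse order.
// Because of the `rev`, the accumulator `a` holds the already-built tail and becomes the
// right operand, while each earlier input `b` becomes the left operand.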
let union_expr = input_exprs
.into_iter()
.rev()
.reduce(|a, b| SetExpr::SetOperation {
op: ast::SetOperator::Union,
set_quantifier: ast::SetQuantifier::All,
left: Box::new(b),
right: Box::new(a),
})
.unwrap();

let Some(query) = query.as_mut() else {
return internal_err!(
Expand Down
56 changes: 52 additions & 4 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use arrow_schema::*;
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{assert_contains, DFSchema, DFSchemaRef, Result, TableReference};
use datafusion_expr::test::function_stub::{
count_udaf, max_udaf, min_udaf, sum, sum_udaf,
};
use datafusion_expr::{
col, lit, table_scan, wildcard, Expr, Extension, LogicalPlan, LogicalPlanBuilder,
UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan,
LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
};
use datafusion_functions::unicode;
use datafusion_functions_aggregate::grouping::grouping_udaf;
Expand All @@ -42,7 +42,7 @@ use std::{fmt, vec};

use crate::common::{MockContextProvider, MockSessionState};
use datafusion_expr::builder::{
table_scan_with_filter_and_fetch, table_scan_with_filters,
project, table_scan_with_filter_and_fetch, table_scan_with_filters,
};
use datafusion_functions::core::planner::CoreFunctionPlanner;
use datafusion_functions_nested::extract::array_element_udf;
Expand Down Expand Up @@ -1615,3 +1615,51 @@ fn test_unparse_extension_to_sql() -> Result<()> {
}
Ok(())
}

/// Checks how a `LogicalPlan::Union` node is unparsed depending on its input count.
///
/// * Four projected inputs must render as a single chain of `UNION ALL` selects,
///   in the same order as the inputs.
/// * A single input must be rejected with the "at least 2 inputs" error.
#[test]
fn test_unparse_optimized_multi_union() -> Result<()> {
    let unparser = Unparser::default();

    let df_schema = Arc::new(DFSchema::try_from(Schema::new(vec![
        Field::new("x", DataType::Int32, false),
        Field::new("y", DataType::Utf8, false),
    ]))?);

    // One-row relation that every projected literal row is built on top of.
    let empty = LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: true,
        schema: Arc::clone(&df_schema),
    });

    // Projects the literals `x_val AS x, y_val AS y` over the empty relation.
    let literal_row = |x_val: i32, y_val: &str| -> Result<Arc<LogicalPlan>> {
        let projected = project(
            empty.clone(),
            vec![lit(x_val).alias("x"), lit(y_val).alias("y")],
        )?;
        Ok(projected.into())
    };

    // Case 1: more than two inputs are chained with UNION ALL.
    let multi_union = LogicalPlan::Union(Union {
        inputs: vec![
            literal_row(1, "a")?,
            literal_row(1, "b")?,
            literal_row(2, "a")?,
            literal_row(2, "c")?,
        ],
        schema: Arc::clone(&df_schema),
    });

    let expected_sql = "SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y";
    let actual_sql = unparser.plan_to_sql(&multi_union)?.to_string();
    assert_eq!(actual_sql, expected_sql);

    // Case 2: a union with a single input must produce an error.
    let single_union = LogicalPlan::Union(Union {
        inputs: vec![literal_row(1, "a")?],
        schema: Arc::clone(&df_schema),
    });

    match plan_to_sql(&single_union) {
        Err(err) => {
            assert_contains!(err.to_string(), "UNION operator requires at least 2 inputs");
        }
        Ok(_) => panic!("Expected error"),
    }

    Ok(())
}