Skip to content
162 changes: 139 additions & 23 deletions datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@
// under the License.

use crate::logical_plan::consumer::{from_substrait_func_args, SubstraitConsumer};
use datafusion::common::{not_impl_err, plan_err, substrait_err, DFSchema, ScalarValue};
use datafusion::common::Result;
use datafusion::common::{
not_impl_err, plan_err, substrait_err, DFSchema, DataFusionError, ScalarValue,
};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{expr, BinaryExpr, Expr, Like, Operator};
use std::vec::Drain;
use substrait::proto::expression::ScalarFunction;
use substrait::proto::function_argument::ArgType;

pub async fn from_scalar_function(
consumer: &impl SubstraitConsumer,
f: &ScalarFunction,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
) -> Result<Expr> {
let Some(fn_signature) = consumer
.get_extensions()
.functions
Expand All @@ -48,29 +52,14 @@ pub async fn from_scalar_function(
args,
)))
} else if let Some(op) = name_to_op(fn_name) {
if f.arguments.len() < 2 {
if args.len() < 2 {
return not_impl_err!(
"Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}",
f.arguments.len()
);
}
// Some expressions are binary in DataFusion but take in a variadic number of args in Substrait.
// In those cases we iterate through all the arguments, applying the binary expression against them all
let combined_expr = args
.into_iter()
.fold(None, |combined_expr: Option<Expr>, arg: Expr| {
Some(match combined_expr {
Some(expr) => Expr::BinaryExpr(BinaryExpr {
left: Box::new(expr),
op,
right: Box::new(arg),
}),
None => arg,
})
})
.unwrap();

Ok(combined_expr)
// In those cases we build a balanced tree of BinaryExprs
arg_list_to_binary_op_tree(op, args)
} else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) {
builder.build(consumer, f, input_schema).await
} else {
Expand Down Expand Up @@ -124,6 +113,47 @@ pub fn name_to_op(name: &str) -> Option<Operator> {
}
}

/// Build a balanced tree of binary operations from a binary operator and a list of arguments.
///
/// At every level the argument list is cut in two: the left subtree receives the
/// first `len / 2` arguments and the right subtree receives the remaining ones.
/// Argument order is preserved, and the depth of the resulting tree grows
/// logarithmically (rather than linearly) with the number of arguments.
///
/// For example, `OR` `(a, b, c, d, e)` will be converted to: `OR(OR(a, b), OR(c, OR(d, e)))`.
///
/// A single argument is returned as-is, without being wrapped in a binary expression.
///
/// # Errors
///
/// `args` must not be empty: a Substrait error is returned when it is.
fn arg_list_to_binary_op_tree(op: Operator, mut args: Vec<Expr>) -> Result<Expr> {
    // Capture the length before draining: the helper is driven by an explicit
    // element count rather than by the iterator running out.
    let n_args = args.len();
    let mut drained_args = args.drain(..);
    arg_list_to_binary_op_tree_inner(op, &mut drained_args, n_args)
}

/// Recursive worker behind [`arg_list_to_binary_op_tree`].
///
/// Consumes exactly `take_len` elements from the front of `args` and combines them
/// with `op`. Passing an explicit count (instead of wrapping the iterator in
/// `Iterator::take` at each level) keeps the iterator type fixed, so the recursion
/// does not build a `Take<Take<Take<...>>>` type.
///
/// Returns a Substrait error when `take_len` is 0, or when `args` runs out of
/// elements before `take_len` of them have been consumed.
fn arg_list_to_binary_op_tree_inner(
    op: Operator,
    args: &mut Drain<Expr>,
    take_len: usize,
) -> Result<Expr> {
    match take_len {
        0 => substrait_err!("Cannot build binary operation tree with 0 arguments"),
        // Leaf: hand back the next argument untouched.
        1 => args.next().ok_or_else(|| {
            DataFusionError::Substrait(
                "Expected one more available element in iterator, found none".to_string(),
            )
        }),
        _ => {
            // Split into two halves; the right half gets the extra element
            // when `take_len` is odd.
            let left_len = take_len / 2;
            let right_len = take_len - left_len;
            // The left subtree must be built first: both calls pull from the
            // same iterator, so evaluation order determines argument order.
            let left = Box::new(arg_list_to_binary_op_tree_inner(op, args, left_len)?);
            let right = Box::new(arg_list_to_binary_op_tree_inner(op, args, right_len)?);
            Ok(Expr::BinaryExpr(BinaryExpr { left, op, right }))
        }
    }
}

/// Build [`Expr`] from its name and required inputs.
struct BuiltinExprBuilder {
expr_name: String,
Expand All @@ -146,7 +176,7 @@ impl BuiltinExprBuilder {
consumer: &impl SubstraitConsumer,
f: &ScalarFunction,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
) -> Result<Expr> {
match self.expr_name.as_str() {
"like" => Self::build_like_expr(consumer, false, f, input_schema).await,
"ilike" => Self::build_like_expr(consumer, true, f, input_schema).await,
Expand All @@ -166,7 +196,7 @@ impl BuiltinExprBuilder {
fn_name: &str,
f: &ScalarFunction,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
) -> Result<Expr> {
if f.arguments.len() != 1 {
return substrait_err!("Expect one argument for {fn_name} expr");
}
Expand Down Expand Up @@ -200,7 +230,7 @@ impl BuiltinExprBuilder {
case_insensitive: bool,
f: &ScalarFunction,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
) -> Result<Expr> {
let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
if f.arguments.len() != 2 && f.arguments.len() != 3 {
return substrait_err!("Expect two or three arguments for `{fn_name}` expr");
Expand Down Expand Up @@ -254,3 +284,89 @@ impl BuiltinExprBuilder {
}))
}
}

#[cfg(test)]
mod tests {
    use super::arg_list_to_binary_op_tree;
    use crate::extensions::Extensions;
    use crate::logical_plan::consumer::tests::TEST_SESSION_STATE;
    use crate::logical_plan::consumer::{DefaultSubstraitConsumer, SubstraitConsumer};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::{DFSchema, Result, ScalarValue};
    use datafusion::logical_expr::{Expr, Operator};
    use insta::assert_snapshot;
    use substrait::proto::expression::literal::LiteralType;
    use substrait::proto::expression::{Literal, RexType, ScalarFunction};
    use substrait::proto::function_argument::ArgType;
    use substrait::proto::{Expression, FunctionArgument};

    /// Consuming a binary operation with a very large argument list must not
    /// crash the consumer.
    #[tokio::test]
    async fn test_binary_op_large_argument_list() -> Result<()> {
        // Only one extension function is needed: anchor 0 is registered as `or`.
        let mut extensions = Extensions::default();
        extensions.functions.insert(0, String::from("or:bool_bool"));
        let consumer = DefaultSubstraitConsumer::new(&extensions, &TEST_SESSION_STATE);

        // One boolean `true` literal argument, built inside-out.
        let true_literal = Literal {
            nullable: false,
            type_variation_reference: 0,
            literal_type: Some(LiteralType::Boolean(true)),
        };
        let true_expression = Expression {
            rex_type: Some(RexType::Literal(true_literal)),
        };
        let true_argument = FunctionArgument {
            arg_type: Some(ArgType::Value(true_expression)),
        };

        // The call under test is effectively OR(true, true, ..., true).
        let func = ScalarFunction {
            function_reference: 0,
            arguments: vec![true_argument; 50000],
            ..Default::default()
        };

        // Trivial input schema with a single boolean column.
        let fields = vec![Field::new("a", DataType::Boolean, false)];
        let df_schema = DFSchema::try_from(Schema::new(fields)).unwrap();

        // Reaching `Ok(())` means the expression was consumed without crashing.
        let _expr = consumer.consume_scalar_function(&func, &df_schema).await?;
        Ok(())
    }

    /// Wrap every integer of `integers` in an `Int64` literal expression.
    fn int64_literals(integers: &[i64]) -> Vec<Expr> {
        integers
            .iter()
            .copied()
            .map(|value| Expr::Literal(ScalarValue::Int64(Some(value))))
            .collect()
    }

    #[test]
    fn arg_list_to_binary_op_tree_1_arg() -> Result<()> {
        let args = int64_literals(&[1]);
        let tree = arg_list_to_binary_op_tree(Operator::Or, args)?;
        assert_snapshot!(tree.to_string(), @"Int64(1)");
        Ok(())
    }

    #[test]
    fn arg_list_to_binary_op_tree_2_args() -> Result<()> {
        let args = int64_literals(&[1, 2]);
        let tree = arg_list_to_binary_op_tree(Operator::Or, args)?;
        assert_snapshot!(tree.to_string(), @"Int64(1) OR Int64(2)");
        Ok(())
    }

    #[test]
    fn arg_list_to_binary_op_tree_3_args() -> Result<()> {
        let args = int64_literals(&[1, 2, 3]);
        let tree = arg_list_to_binary_op_tree(Operator::Or, args)?;
        assert_snapshot!(tree.to_string(), @"Int64(1) OR Int64(2) OR Int64(3)");
        Ok(())
    }

    #[test]
    fn arg_list_to_binary_op_tree_4_args() -> Result<()> {
        let args = int64_literals(&[1, 2, 3, 4]);
        let tree = arg_list_to_binary_op_tree(Operator::Or, args)?;
        assert_snapshot!(tree.to_string(), @"Int64(1) OR Int64(2) OR Int64(3) OR Int64(4)");
        Ok(())
    }
}