Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 49 additions & 17 deletions datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,13 @@ use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream};
use arrow::array::{types::UInt64Type, *};
use arrow::compute::{
self, BatchCoalescer, SortOptions, concat_batches, filter_record_batch, is_not_null,
take,
take, take_arrays,
};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use arrow::error::ArrowError;
use arrow::ipc::reader::StreamReader;
use datafusion_common::config::SpillCompression;
use datafusion_common::{
DataFusionError, HashSet, JoinType, NullEquality, Result, exec_err, internal_err,
not_impl_err,
HashSet, JoinType, NullEquality, Result, exec_err, internal_err, not_impl_err,
};
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::memory_pool::MemoryReservation;
Expand Down Expand Up @@ -1248,13 +1246,19 @@ impl SortMergeJoinStream {
continue;
}

let mut left_columns = self
.streamed_batch
.batch
.columns()
.iter()
.map(|column| take(column, &left_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()?;
let mut left_columns = if let Some(range) = is_contiguous_range(&left_indices)
{
// When indices form a contiguous range (common for the streamed
// side which advances sequentially), use zero-copy slice instead
// of the O(n) take kernel.
self.streamed_batch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea - this is something that could be done for the probe side of the hash join as well.

.batch
.slice(range.start, range.len())
.columns()
.to_vec()
} else {
take_arrays(self.streamed_batch.batch.columns(), &left_indices, None)?
};

// The row indices of joined buffered batch
let right_indices: UInt64Array = chunk.buffered_indices.finish();
Expand Down Expand Up @@ -1577,6 +1581,30 @@ fn produce_buffered_null_batch(
)?))
}

/// Checks whether a `UInt64Array` holds a contiguous ascending run of indices
/// (e.g. \[3,4,5,6\]).
///
/// Returns `Some(start..start + len)` when every element is exactly one greater
/// than its predecessor, so the caller can use `slice` on that range instead
/// of gathering the same rows with a `take` kernel.
///
/// Returns `None` when:
/// - the array is empty,
/// - the array contains any nulls,
/// - the values have gaps, duplicates, or are not ascending, or
/// - the exclusive end of the range overflows `u64` or either bound does not
///   fit in `usize` (such indices cannot describe a valid slice).
#[inline]
fn is_contiguous_range(indices: &UInt64Array) -> Option<Range<usize>> {
    if indices.is_empty() || indices.null_count() > 0 {
        return None;
    }
    let values: &[u64] = indices.values();
    let start = *values.first()?;
    let last = *values.last()?;
    let len = u64::try_from(values.len()).ok()?;
    // Exclusive upper bound. `checked_add` rejects index runs that would wrap
    // around `u64` instead of panicking (debug) or silently wrapping (release).
    let end = start.checked_add(len)?;
    // Quick rejection: if the last element is not the expected one the run
    // cannot be contiguous. `len >= 1` here, so `end - 1` cannot underflow.
    if last != end - 1 {
        return None;
    }
    // Verify every element is sequential (handles duplicates and gaps that
    // happen to leave the first and last elements looking consistent).
    let sequential = values
        .iter()
        .zip(start..end)
        .all(|(&actual, expected)| actual == expected);
    if !sequential {
        return None;
    }
    // Convert without truncation; on targets where `usize` is narrower than
    // 64 bits an out-of-range index yields `None` rather than a wrong range.
    let range_start = usize::try_from(start).ok()?;
    let range_end = usize::try_from(end).ok()?;
    Some(range_start..range_end)
}

/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` by specific column indices
#[inline(always)]
fn fetch_right_columns_by_idxs(
Expand All @@ -1597,12 +1625,16 @@ fn fetch_right_columns_from_batch_by_idxs(
) -> Result<Vec<ArrayRef>> {
match &buffered_batch.batch {
// In memory batch
BufferedBatchState::InMemory(batch) => Ok(batch
.columns()
.iter()
.map(|column| take(column, &buffered_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()
.map_err(Into::<DataFusionError>::into)?),
// In memory batch
BufferedBatchState::InMemory(batch) => {
// When indices form a contiguous range (common in SMJ since the
// buffered side is scanned sequentially), use zero-copy slice.
if let Some(range) = is_contiguous_range(buffered_indices) {
Ok(batch.slice(range.start, range.len()).columns().to_vec())
} else {
Ok(take_arrays(batch.columns(), buffered_indices, None)?)
}
}
// If the batch was spilled to disk, less likely
BufferedBatchState::Spilled(spill_file) => {
let mut buffered_cols: Vec<ArrayRef> =
Expand Down