Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 73 additions & 11 deletions src/webserver/database/error_highlighting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@ struct NiceDatabaseError {
query_position: Option<SourceSpan>,
}

fn write_source_position_info(
f: &mut std::fmt::Formatter<'_>,
source_file: &Path,
query_position: Option<SourceSpan>,
) -> Result<(), std::fmt::Error> {
write!(f, "\n{}", source_file.display())?;
if let Some(query_position) = query_position {
let start_line = query_position.start.line;
let end_line = query_position.end.line;
if start_line == end_line {
write!(f, ": line {start_line}")?;
} else {
write!(f, ": lines {start_line} to {end_line}")?;
}
}
Ok(())
}

impl std::fmt::Display for NiceDatabaseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
Expand Down Expand Up @@ -51,22 +69,32 @@ impl std::fmt::Display for NiceDatabaseError {

impl NiceDatabaseError {
fn show_position_info(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(f, "\n{}", self.source_file.display())?;
let _: () = if let Some(query_position) = self.query_position {
let start_line = query_position.start.line;
let end_line = query_position.end.line;
if start_line == end_line {
write!(f, ": line {start_line}")?;
} else {
write!(f, ": lines {start_line} to {end_line}")?;
}
};
Ok(())
write_source_position_info(f, &self.source_file, self.query_position)
}
}

impl std::error::Error for NiceDatabaseError {}

#[derive(Debug)]
struct NicePositionedError {
source_file: PathBuf,
query_position: SourceSpan,
error: anyhow::Error,
}

impl std::fmt::Display for NicePositionedError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "In \"{}\": {}", self.source_file.display(), self.error)?;
write_source_position_info(f, &self.source_file, Some(self.query_position))
}
}

impl std::error::Error for NicePositionedError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(self.error.as_ref())
}
}

/// Display a database error without any position information
#[must_use]
pub fn display_db_error(
Expand Down Expand Up @@ -97,6 +125,19 @@ pub fn display_stmt_db_error(
})
}

#[must_use]
pub fn display_stmt_error(
source_file: &Path,
query_position: SourceSpan,
error: anyhow::Error,
) -> anyhow::Error {
anyhow::Error::new(NicePositionedError {
source_file: source_file.to_path_buf(),
query_position,
error,
})
}

/// Highlight a line with a character offset.
pub fn highlight_line_offset<W: std::fmt::Write>(msg: &mut W, line: &str, offset: usize) {
writeln!(msg, "{line}").unwrap();
Expand Down Expand Up @@ -124,6 +165,27 @@ pub fn quote_source_with_highlight(source: &str, line_num: u64, col_num: u64) ->
msg
}

#[test]
fn test_display_stmt_error_includes_file_and_line() {
let err = display_stmt_error(
Path::new("example.sql"),
SourceSpan {
start: super::sql::SourceLocation {
line: 12,
column: 3,
},
end: super::sql::SourceLocation {
line: 12,
column: 17,
},
},
anyhow::anyhow!("boom"),
);
let message = err.to_string();
assert!(message.contains("In \"example.sql\": boom"));
assert!(message.contains("example.sql: line 12"));
}

#[test]
fn test_quote_source_with_highlight() {
let source = "SELECT *\nFROM table\nWHERE <syntax error>";
Expand Down
26 changes: 22 additions & 4 deletions src/webserver/database/execute_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::path::Path;
use std::pin::Pin;

use super::csv_import::run_csv_import;
use super::error_highlighting::display_stmt_db_error;
use super::error_highlighting::{display_stmt_db_error, display_stmt_error};
use super::sql::{
DelayedFunctionCall, ParsedSqlFile, ParsedStatement, SimpleSelectValue, StmtWithParams,
};
Expand All @@ -17,6 +17,7 @@ use crate::webserver::database::sql_to_json::row_to_string;
use crate::webserver::http_request_info::ExecutionContext;
use crate::webserver::request_variables::SetVariablesMap;
use crate::webserver::single_or_vec::SingleOrVec;
use crate::webserver::ErrorWithStatus;

use super::syntax_tree::{extract_req_param, StmtParam};
use super::{error_highlighting::display_db_error, Database, DbItem};
Expand Down Expand Up @@ -57,7 +58,9 @@ pub fn stream_query_results_with_conn<'a>(
run_csv_import(connection, csv_import, request).await.with_context(|| format!("Failed to import the CSV file {:?} into the table {:?}", csv_import.uploaded_file, csv_import.table_name))?;
},
ParsedStatement::StmtWithParams(stmt) => {
let query = bind_parameters(stmt, request, db_connection).await?;
let query = bind_parameters(stmt, request, db_connection)
.await
.map_err(|e| with_stmt_position(source_file, stmt.query_position, e))?;
request.server_timing.record("bind_params");
let connection = take_connection(&request.app_state.db, db_connection, request).await?;
log::trace!("Executing query {:?}", query.sql);
Expand Down Expand Up @@ -93,8 +96,11 @@ pub fn stream_query_results_with_conn<'a>(
format!("Failed to set the {variable} variable to {value:?}")
)?;
},
ParsedStatement::StaticSimpleSelect(value) => {
for i in parse_dynamic_rows(DbItem::Row(exec_static_simple_select(value, request, db_connection).await?)) {
ParsedStatement::StaticSimpleSelect { values, query_position } => {
let row = exec_static_simple_select(values, request, db_connection)
.await
.map_err(|e| with_stmt_position(source_file, *query_position, e))?;
for i in parse_dynamic_rows(DbItem::Row(row)) {
yield i;
}
}
Expand All @@ -105,6 +111,18 @@ pub fn stream_query_results_with_conn<'a>(
.map(|res| res.unwrap_or_else(DbItem::Error))
}

fn with_stmt_position(
source_file: &Path,
query_position: super::sql::SourceSpan,
error: anyhow::Error,
) -> anyhow::Error {
if error.downcast_ref::<ErrorWithStatus>().is_some() {
error
} else {
display_stmt_error(source_file, query_position, error)
}
}

/// Transforms a stream of database items to stop processing after encountering the first error.
/// The error item itself is still emitted before stopping.
pub fn stop_at_first_error(
Expand Down
12 changes: 9 additions & 3 deletions src/webserver/database/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ pub(super) struct SourceLocation {
#[derive(Debug)]
pub(super) enum ParsedStatement {
StmtWithParams(StmtWithParams),
StaticSimpleSelect(Vec<(String, SimpleSelectValue)>),
StaticSimpleSelect {
values: Vec<(String, SimpleSelectValue)>,
query_position: SourceSpan,
},
SetVariable {
variable: StmtParam,
value: StmtWithParams,
Expand Down Expand Up @@ -217,7 +220,10 @@ fn parse_single_statement(
}
if let Some(static_statement) = extract_static_simple_select(&stmt, &params) {
log::debug!("Optimised a static simple select to avoid a trivial database query: {stmt} optimized to {static_statement:?}");
return Some(ParsedStatement::StaticSimpleSelect(static_statement));
return Some(ParsedStatement::StaticSimpleSelect {
values: static_statement,
query_position: extract_query_start(&stmt),
});
}

let delayed_functions = extract_toplevel_functions(&mut stmt);
Expand Down Expand Up @@ -1042,7 +1048,7 @@ mod test {
};
let parsed: Vec<ParsedStatement> = parse_sql(&db_info, dialect, sql).unwrap().collect();
match &parsed[..] {
[ParsedStatement::StaticSimpleSelect(q)] => assert_eq!(
[ParsedStatement::StaticSimpleSelect { values: q, .. }] => assert_eq!(
q,
&[
("component".into(), Static("text".into())),
Expand Down