Commit ea4d370

feat: Add GeneratedOutput enum for direct JSON returns (#1395)
feat: allow LlmGenerationClient::generate() to return JSON directly
1 parent 8a88925 commit ea4d370
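
In short: LlmGenerateResponse previously exposed a single text: String that callers had to re-parse; after this commit it carries a GeneratedOutput enum, so providers that already hold parsed JSON can return it directly. Condensed from the mod.rs diff below:

#[derive(Debug)]
pub enum GeneratedOutput {
    Json(serde_json::Value), // structured output, already parsed by the provider client
    Text(String),            // plain text, behavior unchanged
}

#[derive(Debug)]
pub struct LlmGenerateResponse {
    pub output: GeneratedOutput, // was: pub text: String
}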

File tree

7 files changed: +79 −23 lines

rust/cocoindex/src/llm/anthropic.rs
rust/cocoindex/src/llm/bedrock.rs
rust/cocoindex/src/llm/gemini.rs
rust/cocoindex/src/llm/mod.rs
rust/cocoindex/src/llm/ollama.rs
rust/cocoindex/src/llm/openai.rs
rust/cocoindex/src/ops/functions/extract_by_llm.rs

rust/cocoindex/src/llm/anthropic.rs

Lines changed: 8 additions & 7 deletions
@@ -2,7 +2,7 @@ use crate::prelude::*;
 use base64::prelude::*;
 
 use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
+    GeneratedOutput, LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
     ToJsonSchemaOptions, detect_image_mime_type,
 };
 use anyhow::Context;
@@ -126,22 +126,21 @@ impl LlmGenerationClient for Client {
                 }
             }
         }
-        let text = if let Some(json) = extracted_json {
-            // Try strict JSON serialization first
-            serde_json::to_string(&json)?
+        let json_value = if let Some(json) = extracted_json {
+            json
         } else {
            // Fallback: try text if no tool output found
            match &mut resp_json["content"][0]["text"] {
                serde_json::Value::String(s) => {
                    // Try strict JSON parsing first
                    match utils::deser::from_json_str::<serde_json::Value>(s) {
-                        Ok(_) => std::mem::take(s),
+                        Ok(value) => value,
                        Err(e) => {
                            // Try permissive json5 parsing as fallback
                            match json5::from_str::<serde_json::Value>(s) {
                                Ok(value) => {
                                    println!("[Anthropic] Used permissive JSON5 parser for output");
-                                    serde_json::to_string(&value)?
+                                    value
                                }
                                Err(e2) => {
                                    return Err(anyhow::anyhow!(format!(
@@ -160,7 +159,9 @@ impl LlmGenerationClient for Client {
            }
        };
 
-        Ok(LlmGenerateResponse { text })
+        Ok(LlmGenerateResponse {
+            output: GeneratedOutput::Json(json_value),
+        })
    }
 
    fn json_schema_options(&self) -> ToJsonSchemaOptions {
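
The Anthropic fallback path above keeps its two-stage parse but now returns the parsed value instead of re-serializing it to a string. A free-standing sketch of that ladder (parse_llm_json is an illustrative name, not part of the commit, and serde_json::from_str stands in for the crate-internal utils::deser::from_json_str; the json5 crate is the one used above):

use serde_json::Value;

// Strict serde_json parse first; fall back to permissive json5, mirroring the
// Anthropic client above. Both error messages are surfaced if both fail.
fn parse_llm_json(s: &str) -> anyhow::Result<Value> {
    match serde_json::from_str::<Value>(s) {
        Ok(value) => Ok(value),
        Err(e) => match json5::from_str::<Value>(s) {
            Ok(value) => {
                println!("Used permissive JSON5 parser for output");
                Ok(value)
            }
            Err(e2) => Err(anyhow::anyhow!(
                "strict JSON parse failed ({e}); json5 fallback failed ({e2})"
            )),
        },
    }
}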

rust/cocoindex/src/llm/bedrock.rs

Lines changed: 20 additions & 5 deletions
@@ -2,7 +2,7 @@ use crate::prelude::*;
 use base64::prelude::*;
 
 use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
+    GeneratedOutput, LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
     ToJsonSchemaOptions, detect_image_mime_type,
 };
 use anyhow::Context;
@@ -83,6 +83,7 @@ impl LlmGenerationClient for Client {
        }
 
        // Handle structured output using tool schema
+        let has_json_schema = request.output_format.is_some();
        if let Some(OutputFormat::JsonSchema { schema, name }) = request.output_format.as_ref() {
            let schema_json = serde_json::to_value(schema)?;
            payload["toolConfig"] = serde_json::json!({
@@ -134,7 +135,7 @@ impl LlmGenerationClient for Client {
        let message = &output["message"];
        let content = &message["content"];
 
-        let text = if let Some(content_array) = content.as_array() {
+        let generated_output = if let Some(content_array) = content.as_array() {
            // Look for tool use first (structured output)
            let mut extracted_json: Option<serde_json::Value> = None;
            for item in content_array {
@@ -148,7 +149,19 @@ impl LlmGenerationClient for Client {
 
            if let Some(json) = extracted_json {
                // Return the structured output as JSON
-                serde_json::to_string(&json)?
+                GeneratedOutput::Json(json)
+            } else if has_json_schema {
+                // If JSON schema was requested but no tool output found, try parsing text as JSON
+                let mut text_parts = Vec::new();
+                for item in content_array {
+                    if let Some(text) = item.get("text") {
+                        if let Some(text_str) = text.as_str() {
+                            text_parts.push(text_str);
+                        }
+                    }
+                }
+                let text = text_parts.join("");
+                GeneratedOutput::Json(serde_json::from_str(&text)?)
            } else {
                // Fall back to text content
                let mut text_parts = Vec::new();
@@ -159,13 +172,15 @@ impl LlmGenerationClient for Client {
                    }
                }
            }
-            text_parts.join("")
+            GeneratedOutput::Text(text_parts.join(""))
            }
        } else {
            return Err(anyhow::anyhow!("No content found in Bedrock response"));
        };
 
-        Ok(LlmGenerateResponse { text })
+        Ok(LlmGenerateResponse {
+            output: generated_output,
+        })
    }
 
    fn json_schema_options(&self) -> ToJsonSchemaOptions {
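
When a schema was requested but Bedrock returned no tool-use block, the new branch above stitches the text items together and parses the result as JSON, propagating any parse error via the ? operator. The gathering step as a stand-alone sketch (gather_text is an illustrative name, not part of the commit):

use serde_json::Value;

// Concatenate every "text" field found in a Bedrock-style content array.
// Items without a string "text" field are skipped, as in the client above.
fn gather_text(content_array: &[Value]) -> String {
    content_array
        .iter()
        .filter_map(|item| item.get("text").and_then(|t| t.as_str()))
        .collect::<Vec<_>>()
        .join("")
}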

rust/cocoindex/src/llm/gemini.rs

Lines changed: 19 additions & 4 deletions
@@ -1,8 +1,8 @@
 use crate::prelude::*;
 
 use crate::llm::{
-    LlmEmbeddingClient, LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
-    ToJsonSchemaOptions, detect_image_mime_type,
+    GeneratedOutput, LlmEmbeddingClient, LlmGenerateRequest, LlmGenerateResponse,
+    LlmGenerationClient, OutputFormat, ToJsonSchemaOptions, detect_image_mime_type,
 };
 use base64::prelude::*;
 use google_cloud_aiplatform_v1 as vertexai;
@@ -134,6 +134,7 @@ impl LlmGenerationClient for AiStudioClient {
        }
 
        // If structured output is requested, add schema and responseMimeType
+        let has_json_schema = request.output_format.is_some();
        if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
            let schema_json = serde_json::to_value(schema)?;
            payload["generationConfig"] = serde_json::json!({
@@ -162,7 +163,13 @@
            _ => bail!("No text in response"),
        };
 
-        Ok(LlmGenerateResponse { text })
+        let output = if has_json_schema {
+            GeneratedOutput::Json(serde_json::from_str(&text)?)
+        } else {
+            GeneratedOutput::Text(text)
+        };
+
+        Ok(LlmGenerateResponse { output })
    }
 
    fn json_schema_options(&self) -> ToJsonSchemaOptions {
@@ -331,6 +338,7 @@ impl LlmGenerationClient for VertexAiClient {
        });
 
        // Compose generation config
+        let has_json_schema = request.output_format.is_some();
        let mut generation_config = None;
        if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
            let schema_json = serde_json::to_value(schema)?;
@@ -367,7 +375,14 @@
        else {
            bail!("No text in response");
        };
-        Ok(super::LlmGenerateResponse { text })
+
+        let output = if has_json_schema {
+            super::GeneratedOutput::Json(serde_json::from_str(&text)?)
+        } else {
+            super::GeneratedOutput::Text(text)
+        };
+
+        Ok(super::LlmGenerateResponse { output })
    }
 
    fn json_schema_options(&self) -> ToJsonSchemaOptions {
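
Both Gemini clients follow the same recipe, and the Ollama and OpenAI diffs below repeat it: record has_json_schema before the request is consumed, then branch once the reply text is in hand. Pulled out as a sketch (wrap_reply is an illustrative name; in the commit this logic is inlined in each generate()):

use crate::llm::{GeneratedOutput, LlmGenerateResponse};

// Shared wrap-up step: parse the provider's reply as JSON only when the
// caller asked for a JSON schema; otherwise pass the text through untouched.
fn wrap_reply(text: String, has_json_schema: bool) -> anyhow::Result<LlmGenerateResponse> {
    let output = if has_json_schema {
        GeneratedOutput::Json(serde_json::from_str(&text)?)
    } else {
        GeneratedOutput::Text(text)
    };
    Ok(LlmGenerateResponse { output })
}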

rust/cocoindex/src/llm/mod.rs

Lines changed: 7 additions & 1 deletion
@@ -74,9 +74,15 @@ pub struct LlmGenerateRequest<'a> {
     pub output_format: Option<OutputFormat<'a>>,
 }
 
+#[derive(Debug)]
+pub enum GeneratedOutput {
+    Json(serde_json::Value),
+    Text(String),
+}
+
 #[derive(Debug)]
 pub struct LlmGenerateResponse {
-    pub text: String,
+    pub output: GeneratedOutput,
 }
 
 #[async_trait]
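
A quick sanity sketch of the new shape (hypothetical test, not part of the commit):

#[test]
fn generated_output_holds_parsed_json() {
    let out = GeneratedOutput::Json(serde_json::json!({"ok": true}));
    match out {
        GeneratedOutput::Json(v) => assert_eq!(v["ok"], true),
        GeneratedOutput::Text(_) => panic!("expected the Json variant"),
    }
}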

rust/cocoindex/src/llm/ollama.rs

Lines changed: 9 additions & 3 deletions
@@ -89,6 +89,7 @@ impl LlmGenerationClient for Client {
        &self,
        request: super::LlmGenerateRequest<'req>,
    ) -> Result<super::LlmGenerateResponse> {
+        let has_json_schema = request.output_format.is_some();
        let req = OllamaRequest {
            model: request.model,
            prompt: request.user_prompt.as_ref(),
@@ -109,9 +110,14 @@
            .await
            .context("Ollama API error")?;
        let json: OllamaResponse = res.json().await?;
-        Ok(super::LlmGenerateResponse {
-            text: json.response,
-        })
+
+        let output = if has_json_schema {
+            super::GeneratedOutput::Json(serde_json::from_str(&json.response)?)
+        } else {
+            super::GeneratedOutput::Text(json.response)
+        };
+
+        Ok(super::LlmGenerateResponse { output })
    }
 
    fn json_schema_options(&self) -> super::ToJsonSchemaOptions {

rust/cocoindex/src/llm/openai.rs

Lines changed: 8 additions & 1 deletion
@@ -184,6 +184,7 @@ where
        &self,
        request: super::LlmGenerateRequest<'req>,
    ) -> Result<super::LlmGenerateResponse> {
+        let has_json_schema = request.output_format.is_some();
        let request = &request;
        let response = retryable::run(
            || async {
@@ -203,7 +204,13 @@
            .and_then(|choice| choice.message.content)
            .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
 
-        Ok(super::LlmGenerateResponse { text })
+        let output = if has_json_schema {
+            super::GeneratedOutput::Json(serde_json::from_str(&text)?)
+        } else {
+            super::GeneratedOutput::Text(text)
+        };
+
+        Ok(super::LlmGenerateResponse { output })
    }
 
    fn json_schema_options(&self) -> super::ToJsonSchemaOptions {

rust/cocoindex/src/ops/functions/extract_by_llm.rs

Lines changed: 8 additions & 2 deletions
@@ -1,5 +1,6 @@
 use crate::llm::{
-    LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat, new_llm_generation_client,
+    GeneratedOutput, LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat,
+    new_llm_generation_client,
 };
 use crate::ops::sdk::*;
 use crate::prelude::*;
@@ -117,7 +118,12 @@ impl SimpleFunctionExecutor for Executor {
            }),
        };
        let res = self.client.generate(req).await?;
-        let json_value: serde_json::Value = utils::deser::from_json_str(res.text.as_str())?;
+        let json_value = match res.output {
+            GeneratedOutput::Json(json) => json,
+            GeneratedOutput::Text(text) => {
+                bail!("Expected JSON response but got text: {}", text)
+            }
+        };
        let value = self.value_extractor.extract_value(json_value)?;
        Ok(value)
    }
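
This caller previously re-parsed res.text; it now consumes the enum directly and bails on an unexpected Text variant. If more call sites appear, the match could move onto the enum itself; a hypothetical convenience, not in this commit:

impl GeneratedOutput {
    // Hypothetical helper: collapse the enum back to JSON, failing loudly on
    // plain text, exactly as extract_by_llm.rs does inline above.
    pub fn into_json(self) -> anyhow::Result<serde_json::Value> {
        match self {
            GeneratedOutput::Json(json) => Ok(json),
            GeneratedOutput::Text(text) => {
                anyhow::bail!("Expected JSON response but got text: {}", text)
            }
        }
    }
}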
