8 changes: 8 additions & 0 deletions src/main/cpp/jllama.cpp
@@ -317,6 +317,7 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
sparams.chat_template = "chatml";
}
}
+ctx_server->chat_template = sparams.chat_template;

// print sample chat example to make it clear which template is used
{
@@ -358,6 +359,13 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv
json json_params = json::parse(c_params);
const bool infill = json_params.contains("input_prefix") || json_params.contains("input_suffix");

+if (json_params.value("use_chat_template", false)) {
+    json chat;
+    chat.push_back({{"role", "system"}, {"content", ctx_server->system_prompt}});
+    chat.push_back({{"role", "user"}, {"content", json_params["prompt"]}});
+    json_params["prompt"] = format_chat(ctx_server->model, ctx_server->chat_template, chat);
+}

const int id_task = ctx_server->queue_tasks.get_new_id();
ctx_server->queue_results.add_waiting_task_id(id_task);
ctx_server->request_completion(id_task, -1, json_params, infill, false);
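
With this hunk, a completion request that sets "use_chat_template" to true has its plain prompt wrapped natively into a two-message chat (the configured system prompt plus the user prompt) and rendered through format_chat with the stored template before tokenization. A minimal usage sketch of that path, assuming the library's existing LlamaModel builder API; the model path and prompt are placeholders, and setUseChatTemplate is the setter added further down in this PR:

    // Sketch only: exercises the native use_chat_template branch above.
    import de.kherud.llama.InferenceParameters;
    import de.kherud.llama.LlamaModel;
    import de.kherud.llama.ModelParameters;

    public class ChatTemplateExample {
        public static void main(String[] args) {
            ModelParameters modelParams = new ModelParameters()
                    .setModelFilePath("models/model.gguf") // placeholder path
                    .setSystemPrompt("You are a helpful assistant.");
            try (LlamaModel model = new LlamaModel(modelParams)) {
                InferenceParameters inferParams = new InferenceParameters("What is a chat template?")
                        .setUseChatTemplate(true); // prompt is wrapped into system + user messages natively
                System.out.println(model.complete(inferParams));
            }
        }
    }

Note that the native code only builds a single system/user turn; multi-turn histories would still need to be rendered by the caller.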
3 changes: 3 additions & 0 deletions src/main/cpp/server.hpp
@@ -692,6 +692,8 @@ struct server_context
std::string name_user; // this should be the antiprompt
std::string name_assistant;

+std::string chat_template;

// slots / clients
std::vector<server_slot> slots;
json default_generation_settings_for_props;
@@ -2596,6 +2598,7 @@ static void server_params_parse(json jparams, server_params &sparams, gpt_params
params.no_kv_offload = json_value(jparams, "no_kv_offload", default_params.no_kv_offload);
server_log_json = !jparams.contains("log_format") || jparams["log_format"] == "json";
sparams.system_prompt = json_value(jparams, "system_prompt", default_sparams.system_prompt);
+sparams.chat_template = json_value(jparams, "chat_template", default_sparams.chat_template);

if (jparams.contains("n_gpu_layers"))
{
10 changes: 10 additions & 0 deletions src/main/java/de/kherud/llama/InferenceParameters.java
@@ -45,6 +45,7 @@ public final class InferenceParameters extends JsonParameters {
private static final String PARAM_STOP = "stop";
private static final String PARAM_SAMPLERS = "samplers";
private static final String PARAM_STREAM = "stream";
+private static final String PARAM_USE_CHAT_TEMPLATE = "use_chat_template";

public InferenceParameters(String prompt) {
// we always need a prompt
@@ -488,4 +489,13 @@ InferenceParameters setStream(boolean stream) {
parameters.put(PARAM_STREAM, String.valueOf(stream));
return this;
}

+/**
+ * Set whether generation should apply a chat template (default: false)
+ */
+public InferenceParameters setUseChatTemplate(boolean useChatTemplate) {
+    parameters.put(PARAM_USE_CHAT_TEMPLATE, String.valueOf(useChatTemplate));
+    return this;
+}

}
12 changes: 11 additions & 1 deletion src/main/java/de/kherud/llama/ModelParameters.java
@@ -69,6 +69,7 @@ public final class ModelParameters extends JsonParameters {
private static final String PARAM_NO_KV_OFFLOAD = "no_kv_offload";
private static final String PARAM_SYSTEM_PROMPT = "system_prompt";
private static final String PARAM_LOG_FORMAT = "log_format";
+private static final String PARAM_CHAT_TEMPLATE = "chat_template";

/**
* Set the RNG seed
@@ -579,7 +580,7 @@ public ModelParameters setNoKvOffload(boolean noKvOffload) {
* Set a system prompt to use
*/
public ModelParameters setSystemPrompt(String systemPrompt) {
-parameters.put(PARAM_SYSTEM_PROMPT, systemPrompt);
+parameters.put(PARAM_SYSTEM_PROMPT, toJsonString(systemPrompt));
return this;
}

@@ -600,4 +601,13 @@ public ModelParameters setLogFormat(LogFormat logFormat) {
}
return this;
}

+/**
+ * Set the chat template to use (default: empty)
+ */
+public ModelParameters setChatTemplate(String chatTemplate) {
+    parameters.put(PARAM_CHAT_TEMPLATE, toJsonString(chatTemplate));
+    return this;
+}

}
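
Taken together with the native changes, the template can now be pinned at model load time instead of relying on the GGUF metadata. A minimal sketch, assuming the builder API shown above; "chatml" mirrors the fallback the native loader applies when a model ships without a template, and the path is a placeholder:

    // Sketch: pin the chat template at load time.
    ModelParameters params = new ModelParameters()
            .setModelFilePath("models/model.gguf")          // placeholder path
            .setSystemPrompt("You are a concise assistant.")
            .setChatTemplate("chatml");                     // forwarded as "chat_template" to server_params_parse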