@@ -2941,9 +2941,6 @@ struct llama_sbatch_seq {
     llama_seq_id * seq_id;
     size_t offset;
     size_t length;
-
-    // helper for smoother batch API transition -- can be deprecated in the future
-    llama_seq_id all_seq_id; // used if seq_id == NULL
 };
 
 // sequence-length-aware batch splitting
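For orientation, a minimal sketch of llama_sbatch_seq as it looks after this hunk; the n_seq_id member is assumed from the surrounding code (it is referenced below as seq.n_seq_id), and the field comments are editorial:

struct llama_sbatch_seq {
    int32_t n_seq_id;      // how many sequences this run of tokens belongs to
    llama_seq_id * seq_id; // pointer into the caller-provided per-token seq_id array
    size_t offset;         // first token of the run, in sorted-id order
    size_t length;         // number of consecutive tokens in the run
};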
@@ -3038,30 +3035,18 @@ struct llama_sbatch {
         } else {
             ubatch.embd = nullptr;
         }
-        // from here on, the else branches are deprecated;
-        // they are helpers for smoother batch API transition
-        if (batch->pos) {
-            if (ubatch.equal_seqs) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
-                }
-            } else {
-                // simple split
-                ubatch.pos = batch->pos + seq.offset;
-            }
-        } else {
+        if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
-                llama_pos bi = ids[seq.offset + i];
-                ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
+                ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
             }
+        } else {
+            // simple split
+            ubatch.pos = batch->pos + seq.offset;
         }
         if (ubatch.equal_seqs) {
             ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
             if (seq.seq_id) {
                 ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
-            } else {
-                GGML_ASSERT(seq.n_seq_id == 1);
-                ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
             }
         } else {
             // simple split
@@ -3074,10 +3059,6 @@ struct llama_sbatch {
             }
             if (batch->seq_id) {
                 ubatch.seq_id = batch->seq_id + seq.offset;
-            } else {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
-                }
             }
         }
         if (logits_all) {
@@ -3196,7 +3177,6 @@ struct llama_sbatch {
             s.seq_id = nullptr;
             s.offset = 0;
             s.length = n_tokens;
-            s.all_seq_id = batch.all_seq_id;
             return;
         }
         std::sort(ids.begin(), ids.end(),
@@ -3219,7 +3199,7 @@ struct llama_sbatch {
                     if (batch.pos) {
                         return batch.pos[a] < batch.pos[b];
                     }
-                    // no pos, sort by id (assuming batch.all_pos_1 is positive)
+                    // no pos, sort by id
                     return a < b;
                 }
                 // shared prompts go first
@@ -3229,30 +3209,25 @@ struct llama_sbatch {
         // init seq
         llama_sbatch_seq * last_seq = nullptr;
 
-        if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
-            for (size_t i = 0; i < n_tokens; ++i) {
-                const size_t bi = ids[i];
-                const int32_t n_seqs = batch.n_seq_id[bi];
-                llama_seq_id * seq_ids = batch.seq_id[bi];
-                if (last_seq != nullptr) {
-                    bool same = n_seqs == last_seq->n_seq_id;
-                    for (int32_t j = 0; same && j < n_seqs; ++j) {
-                        if (seq_ids[j] != last_seq->seq_id[j]) {
-                            same = false;
-                        }
-                    }
-                    if (same) {
-                        last_seq->length += 1;
-                        continue;
+        for (size_t i = 0; i < n_tokens; ++i) {
+            const size_t bi = ids[i];
+            const int32_t n_seqs = batch.n_seq_id[bi];
+            llama_seq_id * seq_ids = batch.seq_id[bi];
+            if (last_seq != nullptr) {
+                bool same = n_seqs == last_seq->n_seq_id;
+                for (int32_t j = 0; same && j < n_seqs; ++j) {
+                    if (seq_ids[j] != last_seq->seq_id[j]) {
+                        same = false;
                     }
                 }
-                llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
-                seq.push_back(new_seq);
-                last_seq = &seq.back();
+                if (same) {
+                    last_seq->length += 1;
+                    continue;
+                }
             }
-        } else {
-            llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
+            llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
             seq.push_back(new_seq);
+            last_seq = &seq.back();
         }
         // keep shared prompts first at the end, then sort by length descending.
         std::sort(seq.begin(), seq.end(),
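To make the grouping loop above concrete, here is a small illustrative trace (not part of the commit; the token and sequence layout is invented for the example):

// hypothetical batch of 5 tokens after the std::sort above
// (shared prompts first, then by seq_id, then by pos):
//
//   index into ids :   0      1      2     3     4
//   seq-id set     : {0,1}  {0,1}   {0}   {0}   {1}
//
// walking ids and merging consecutive tokens with an identical seq-id set
// yields three llama_sbatch_seq entries:
//
//   {n_seq_id = 2, seq_id -> {0,1}, offset = 0, length = 2}   // shared prompt
//   {n_seq_id = 1, seq_id -> {0},   offset = 2, length = 2}
//   {n_seq_id = 1, seq_id -> {1},   offset = 4, length = 1}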
@@ -21069,9 +21044,7 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
 
 struct llama_batch llama_batch_get_one(
             llama_token * tokens,
-                int32_t   n_tokens,
-              llama_pos   pos_0,
-           llama_seq_id   seq_id) {
+                int32_t   n_tokens) {
     return {
         /*n_tokens       =*/ n_tokens,
         /*tokens         =*/ tokens,
@@ -21080,9 +21053,6 @@ struct llama_batch llama_batch_get_one(
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
         /*logits         =*/ nullptr,
-        /*all_pos_0      =*/ pos_0,
-        /*all_pos_1      =*/ 1,
-        /*all_seq_id     =*/ seq_id,
     };
 }
 
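A minimal usage sketch of the reduced llama_batch_get_one signature, assuming a valid llama_context * ctx and an already tokenized prompt (the helper name decode_prompt is an assumption, not part of the API):

#include <vector>
#include "llama.h"

// positions, sequence ids and logits flags are no longer passed to
// llama_batch_get_one; llama_decode fills them in (sequence 0, positions
// continuing after the end of the KV cache, logits only for the last token)
static int decode_prompt(struct llama_context * ctx, std::vector<llama_token> & tokens) {
    struct llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
    return llama_decode(ctx, batch); // 0 on success, negative on error
}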
@@ -21095,9 +21065,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
         /*logits         =*/ nullptr,
-        /*all_pos_0      =*/ 0,
-        /*all_pos_1      =*/ 0,
-        /*all_seq_id     =*/ 0,
     };
 
     if (embd) {
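For batches built with llama_batch_init, the caller now has to fill the per-token arrays explicitly, since there is no all_pos_0/all_seq_id fallback left. A hedged sketch of such a helper (the name push_token and its parameters are assumptions; it is similar in spirit to the llama_batch_add helper in common/common.cpp):

// append one token to a batch created with
// llama_batch_init(n_tokens_alloc, /*embd*/ 0, /*n_seq_max*/ 1)
static void push_token(struct llama_batch & batch, llama_token tok, llama_pos pos,
                       llama_seq_id seq, bool want_logits) {
    const int32_t i = batch.n_tokens;
    batch.token   [i]    = tok;
    batch.pos     [i]    = pos;
    batch.n_seq_id[i]    = 1;
    batch.seq_id  [i][0] = seq;
    batch.logits  [i]    = want_logits;
    batch.n_tokens++;
}

A batch filled this way reaches llama_decode with all of pos, n_seq_id, seq_id and logits populated, so the allocator below leaves it untouched; llama_batch_free still releases the arrays afterwards.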
@@ -21133,10 +21100,58 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.logits) free(batch.logits);
 }
 
+// temporarily allocates memory for missing fields of the input batch, if needed
+struct llama_batch_allocr {
+    static const llama_seq_id default_seq_id = 0;
+    std::array<llama_seq_id, 1> seq_id_0 = {default_seq_id};
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id *> seq_id;
+    std::vector<int8_t>         logits;
+    // fill in the fields that llama_batch_get_one leaves empty
+    struct llama_batch get_fulfilled_batch(struct llama_context * ctx, struct llama_batch in_batch) {
+        struct llama_batch batch = in_batch;
+        if (!batch.pos) {
+            // determine the last position of the default sequence in the KV cache
+            llama_pos last_pos = -1;
+            for (const auto & cell : ctx->kv_self.cells) {
+                if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) {
+                    last_pos = std::max(last_pos, cell.pos);
+                }
+            }
+            pos.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                pos[i] = last_pos + 1 + i; // continue right after the cached context
+            }
+            batch.pos = pos.data();
+        }
+        if (!batch.n_seq_id) {
+            n_seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                n_seq_id[i] = seq_id_0.size();
+            }
+            batch.n_seq_id = n_seq_id.data();
+        }
+        if (!batch.seq_id) {
+            seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                seq_id[i] = seq_id_0.data();
+            }
+            batch.seq_id = seq_id.data();
+        }
+        if (!batch.logits) {
+            // request logits only for the last token
+            logits.resize(batch.n_tokens);
+            logits[logits.size() - 1] = true;
+            batch.logits = logits.data();
+        }
+        return batch;
+    }
+};
+
 int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    const int ret = llama_encode_internal(*ctx, batch);
+    llama_batch_allocr batch_allocr;
+    const int ret = llama_encode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch));
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
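A condensed restatement of what get_fulfilled_batch guarantees before the batch reaches the internal encode/decode path; the explicit check is hypothetical and not part of the commit:

// fields left null by the caller are replaced with these defaults:
//   fb.pos      -- positions continuing after sequence 0's last KV cache entry
//   fb.n_seq_id -- 1 for every token
//   fb.seq_id   -- every token mapped to sequence 0
//   fb.logits   -- requested only for the last token
struct llama_batch fb = batch_allocr.get_fulfilled_batch(ctx, batch);
GGML_ASSERT(fb.pos && fb.n_seq_id && fb.seq_id && fb.logits);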
@@ -21147,7 +21162,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    const int ret = llama_decode_internal(*ctx, batch);
+    llama_batch_allocr batch_allocr;
+    const int ret = llama_decode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch));
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }