from collections.abc import Callable

import numpy as np
import torch

# NOTE: LogprobsTensors, InputBatch, SamplingParams, and the helpers
# get_prompt_logprobs_token_ids / compute_prompt_logprobs_with_chunking are
# assumed to be provided elsewhere in the surrounding codebase.


class PromptLogprobsWorker:
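    """Tracks which requests need prompt logprobs and assembles their results.

    With chunked prefill, a prompt may span several scheduler steps; the
    per-chunk LogprobsTensors are buffered in ``in_progress_prompt_logprobs``
    and merged once the final prompt chunk has been processed.
    """
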
    def __init__(self, max_num_reqs: int):
        self.max_num_reqs = max_num_reqs
        self.uses_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
        # req_id -> list of in-progress LogprobsTensors (one per prompt chunk).
        self.in_progress_prompt_logprobs: dict[str, list[LogprobsTensors]] = {}

    def add_request(
        self, req_id: str, req_idx: int, sampling_params: SamplingParams
    ) -> None:
        # For now, only the logprobs of the prompt tokens themselves are
        # supported (no top-k prompt logprobs).
        uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
        if uses_prompt_logprobs:
            self.uses_prompt_logprobs[req_idx] = True
            self.in_progress_prompt_logprobs[req_id] = []
        else:
            self.uses_prompt_logprobs[req_idx] = False

    def remove_request(self, req_id: str) -> None:
        self.in_progress_prompt_logprobs.pop(req_id, None)

    def compute_prompt_logprobs(
        self,
        logits_fn: Callable[[torch.Tensor], torch.Tensor],
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        # [max_num_reqs, max_model_len]
        prefill_token_ids: torch.Tensor,
        # [max_num_reqs]
        num_computed_tokens: torch.Tensor,
        # [max_num_reqs]
        prompt_lens: np.ndarray,
        # [max_num_reqs]
        prefill_lens: np.ndarray,
        # [max_num_reqs]
        num_computed_prefill_tokens: np.ndarray,
    ) -> dict[str, LogprobsTensors]:
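        """Return finished prompt logprobs for this step, keyed by request id.

        Requests whose prompts are still being chunk-prefilled are buffered
        internally and only returned once their final prompt chunk completes.
        """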
        idx_mapping_np = input_batch.idx_mapping_np
        needs_prompt_logprobs = self.uses_prompt_logprobs[idx_mapping_np]
        if not np.any(needs_prompt_logprobs):
            # Common case: No request asks for prompt logprobs.
            return {}
        prompt_lens = prompt_lens[idx_mapping_np]
        computed_prefill = num_computed_prefill_tokens[idx_mapping_np]
        # NOTE(woosuk): -1 because the last prompt token's hidden state
        # predicts the first sampled token, so it is not needed for prompt
        # logprobs.
        includes_prompt = computed_prefill < prompt_lens - 1
        # NOTE(woosuk): If the request was resumed after preemption, its prompt
        # logprobs must have already been computed before preemption. Skip.
        resumed_after_prompt = prompt_lens < prefill_lens[idx_mapping_np]
        needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
        if not np.any(needs_prompt_logprobs):
            return {}
        # Get the prompt logprobs token_ids.
        prompt_logprobs_token_ids = get_prompt_logprobs_token_ids(
            input_batch.num_tokens,
            input_batch.query_start_loc,
            input_batch.idx_mapping,
            num_computed_tokens,
            prefill_token_ids,
        )
        # Compute the prompt logprobs.
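        # compute_prompt_logprobs_with_chunking is assumed (from its name and
        # call site) to apply logits_fn to the selected hidden states in
        # chunks and return each token's logprob and rank; its definition is
        # not shown here.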
        prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
            prompt_logprobs_token_ids,
            hidden_states[: input_batch.num_tokens],
            logits_fn,
        )
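        # Requests whose prompts extend past the tokens scheduled this step
        # are still being chunk-prefilled; their logprobs are buffered rather
        # than returned below.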
        pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
        is_prompt_chunked = pos_after_step < prompt_lens
        query_start_loc_np = input_batch.query_start_loc_np
        # [num_tokens] -> [num_tokens, 1]: only the prompt token itself is
        # recorded per position (no top-k).
        prompt_token_ids = prompt_logprobs_token_ids.unsqueeze(-1)
        prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
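        # Slice each request's rows out of the flattened batch and either
        # buffer them (prompt still chunked) or emit the merged result.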
        for i, req_id in enumerate(input_batch.req_ids):
            if not needs_prompt_logprobs[i]:
                continue
            start_idx = query_start_loc_np[i]
            end_idx = query_start_loc_np[i + 1]
            assert start_idx < end_idx, (
                f"start_idx ({start_idx}) >= end_idx ({end_idx})"
            )
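            # For the final prompt chunk, drop the last query position: its
            # hidden state predicts the first sampled token, not a prompt
            # token.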
            if not is_prompt_chunked[i]:
                end_idx -= 1
            logprobs = LogprobsTensors(
                logprob_token_ids=prompt_token_ids[start_idx:end_idx],
                logprobs=prompt_logprobs[start_idx:end_idx],
                selected_token_ranks=prompt_ranks[start_idx:end_idx],
            )
            prompt_logprobs_list = self.in_progress_prompt_logprobs[req_id]
            if is_prompt_chunked[i]:
                # Prompt is chunked. Do not return the logprobs yet.
                prompt_logprobs_list.append(logprobs)
                continue
            if prompt_logprobs_list:
                # Merge the in-progress logprobs.
                prompt_logprobs_list.append(logprobs)
                logprobs = LogprobsTensors(
                    logprob_token_ids=torch.cat(
                        [x.logprob_token_ids for x in prompt_logprobs_list]
                    ),
                    logprobs=torch.cat(
                        [x.logprobs for x in prompt_logprobs_list]
                    ),
                    selected_token_ranks=torch.cat(
                        [x.selected_token_ranks for x in prompt_logprobs_list]
                    ),
                )
                prompt_logprobs_list.clear()
            prompt_logprobs_dict[req_id] = logprobs
        return prompt_logprobs_dict