Skip to content

vllm.entrypoints.openai.tool_parsers.hermes_tool_parser ΒΆ

logger module-attribute ΒΆ

logger = init_logger(__name__)

Hermes2ProToolParser ΒΆ

Bases: ToolParser

Source code in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
@ToolParserManager.register_module("hermes")
class Hermes2ProToolParser(ToolParser):
    """Tool parser for Hermes 2 Pro-style models.

    These models emit tool calls as JSON wrapped in literal
    ``<tool_call> ... </tool_call>`` tags. This parser supports both
    one-shot extraction (:meth:`extract_tool_calls`) and incremental
    streaming extraction (:meth:`extract_tool_calls_streaming`), where the
    tag tokens may arrive split across several deltas and are therefore
    re-assembled through a small text buffer.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Mistral tokenizers wrap the underlying tokenizer; unwrap so the
        # encode/decode calls below behave like a plain HF tokenizer.
        if isinstance(self.model_tokenizer, MistralTokenizer):
            logger.error(
                "Detected Mistral tokenizer when using a Hermes model")
            self.model_tokenizer = self.model_tokenizer.tokenizer

        # --- streaming state ---
        # whether the current tool call's name has been emitted yet
        self.current_tool_name_sent: bool = False
        # partially-parsed JSON of every tool call seen so far, used for
        # diffing between successive streaming steps
        self.prev_tool_call_arr: list[dict] = []
        # index of the tool call currently being streamed (-1 = none yet)
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = [
        ]  # map what has been streamed for each tool so far to a list

        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"

        # Two alternatives: a closed <tool_call>...</tool_call> pair, or an
        # unterminated <tool_call>... running to end-of-output.
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
        self.scratch_pad_regex = re.compile(
            r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")
        self.tool_call_start_token_ids = self.model_tokenizer.encode(
            self.tool_call_start_token, add_special_tokens=False)
        self.tool_call_end_token_ids = self.model_tokenizer.encode(
            self.tool_call_end_token, add_special_tokens=False)

        # Per-token text pieces of the start/end tags, used by
        # tool_call_delta_buffer() to recognize a tag that arrives split
        # across several streamed deltas.
        self.tool_call_start_token_array = [
            self.model_tokenizer.decode([token_id])
            for token_id in self.tool_call_start_token_ids
        ]

        self.tool_call_end_token_array = [
            self.model_tokenizer.decode([token_id])
            for token_id in self.tool_call_end_token_ids
        ]

        # Text withheld from the stream while a possible tag is mid-flight.
        self.buffered_delta_text = ""

    # Very simple idea: when encountering tokens like <, tool, _call, >,
    # <, /, tool, _call, >, store them in a buffer.
    # When the last token is encountered, empty the buffer and return it.
    # If a token appears in an incorrect sequence while storing in the buffer,
    # return the preceding buffer along with the token.
    def tool_call_delta_buffer(self, delta_text: str) -> str:
        """Buffer pieces of a (possibly split) tool-call tag.

        Returns "" while a tag may still be forming, and the accumulated
        buffered text (plus ``delta_text``) once the tag completes or the
        sequence turns out not to be a tag.
        """
        # If the sequence of tool_call_start or tool_call_end tokens is not yet
        # complete, fill the buffer with the token and return "".
        if (delta_text in self.tool_call_start_token_array
                or delta_text in self.tool_call_end_token_array):
            # If delta_text is the last token of tool_call_start_token or
            # tool_call_end_token, empty the buffer and return
            # the buffered text + delta_text.
            if (delta_text == self.tool_call_start_token_array[-1]
                    or delta_text == self.tool_call_end_token_array[-1]):
                buffered_text = self.buffered_delta_text
                self.buffered_delta_text = ""
                return buffered_text + delta_text
            else:
                self.buffered_delta_text = self.buffered_delta_text + delta_text
                return ""
        else:
            # Not a tag piece: flush anything previously buffered so no
            # text is ever lost.
            if self.buffered_delta_text:
                buffered_text = self.buffered_delta_text
                self.buffered_delta_text = ""
                return buffered_text + delta_text
            else:
                return delta_text

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Force special tokens through detokenization when tools are used."""
        if request.tools and request.tool_choice != 'none':
            # do not skip special tokens because the tool_call tokens are
            # marked "special" in some models. Since they are skipped
            # prior to the call to the tool parser, it breaks tool calling.
            request.skip_special_tokens = False
        return request

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all tool calls from a complete (non-streamed) output.

        Returns the parsed tool calls plus any content that preceded the
        first tool-call tag; on any parse failure the raw output is
        returned as plain content.
        """

        # sanity check; avoid unnecessary processing
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        else:

            try:
                # there are two possible captures - between tags, or between a
                # tag and end-of-string so the result of
                # findall is an array of tuples where one is a function call and
                # the other is None
                function_call_tuples = (
                    self.tool_call_regex.findall(model_output))

                # load the JSON, and then use it to build the Function and
                # Tool Call
                raw_function_calls = [
                    json.loads(match[0] if match[0] else match[1])
                    for match in function_call_tuples
                ]
                tool_calls = [
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=function_call["name"],
                            # function call args are JSON but as a string
                            arguments=json.dumps(function_call["arguments"],
                                                 ensure_ascii=False)))
                    for function_call in raw_function_calls
                ]

                # everything before the first tool-call tag is plain content
                content = model_output[:model_output.
                                       find(self.tool_call_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if content else None)

            except Exception:
                logger.exception(
                    "Error in extracting tool call from response.")
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally extract tool-call deltas from a streamed output.

        Returns a DeltaMessage carrying either plain content or a
        tool-call delta, or None when nothing should be emitted for this
        step.
        """
        # 1. All tokens are parsed based on _text, not token_ids.
        # 2. All incoming text data is processed by the tool_call_delta_buffer
        #    function for buffering before being used for parsing.

        delta_text = self.tool_call_delta_buffer(delta_text)
        # If the last characters of previous_text
        # match self.buffered_delta_text, remove only the matching part.
        if (len(previous_text) >= len(self.buffered_delta_text)
                and previous_text[-len(self.buffered_delta_text):]
                == self.buffered_delta_text):
            previous_text = previous_text[:-len(self.buffered_delta_text)]
            current_text = previous_text + delta_text

        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)
        # check to see if we should be streaming a tool call - is there a
        # tool call start tag in the text generated so far?
        if self.tool_call_start_token not in current_text:
            logger.debug("No tool call tokens found!")
            return DeltaMessage(content=delta_text)

        try:

            # figure out where we are in the parsing by counting tool call
            # start & end tags
            prev_tool_start_count = previous_text.count(
                self.tool_call_start_token)
            prev_tool_end_count = previous_text.count(self.tool_call_end_token)
            cur_tool_start_count = current_text.count(
                self.tool_call_start_token)
            cur_tool_end_count = current_text.count(self.tool_call_end_token)
            tool_call_portion = None
            text_portion = None

            # case: if we're generating text, OR rounding out a tool call
            if (cur_tool_start_count == cur_tool_end_count
                    and prev_tool_end_count == cur_tool_end_count
                    and self.tool_call_end_token not in delta_text):
                logger.debug("Generating text content! skipping tool parsing.")
                return DeltaMessage(content=delta_text)

            if self.tool_call_end_token in delta_text:
                logger.debug("tool_call_end_token in delta_text")
                full_text = current_text + delta_text
                tool_call_portion = full_text.split(
                    self.tool_call_start_token)[-1].split(
                        self.tool_call_end_token)[0].rstrip()
                delta_text = delta_text.split(
                    self.tool_call_end_token)[0].rstrip()
                text_portion = delta_text.split(
                    self.tool_call_end_token)[-1].lstrip()

            # case: if tool open & close tag counts don't match, we're doing
            # something with tools with this diff; the cases below act as the
            # imaginary "else" block here.
            # flags for partial JSON parsing. exported constants from
            # "Allow" are handled via BIT MASK
            flags = Allow.ALL if self.current_tool_name_sent \
                else Allow.ALL & ~Allow.STR

            # case -- we're starting a new tool call
            if (cur_tool_start_count > cur_tool_end_count
                    and cur_tool_start_count > prev_tool_start_count):
                if len(delta_token_ids) > 1:
                    tool_call_portion = current_text.split(
                        self.tool_call_start_token)[-1]
                else:
                    tool_call_portion = None
                    delta = None

                text_portion = None

                # set cursors and state appropriately
                self.current_tool_id += 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("Starting on a new tool %s", self.current_tool_id)

            # case -- we're updating an existing tool call
            elif (cur_tool_start_count > cur_tool_end_count
                  and cur_tool_start_count == prev_tool_start_count):

                # get the portion of the text that's the tool call
                tool_call_portion = current_text.split(
                    self.tool_call_start_token)[-1]
                text_portion = None

            # case -- the current tool call is being closed.
            elif (cur_tool_start_count == cur_tool_end_count
                  and cur_tool_end_count >= prev_tool_end_count):
                if (self.prev_tool_call_arr is None
                        or len(self.prev_tool_call_arr) == 0):
                    logger.debug(
                        "attempting to close tool call, but no tool call")
                    return None
                diff = self.prev_tool_call_arr[self.current_tool_id].get(
                    "arguments")
                if diff:
                    # NOTE: a previous revision attempted a unicode_escape
                    # round-trip on `diff` here, guarded by `diff is str`,
                    # which is always False (identity check against the
                    # type) and whose result was overwritten below anyway;
                    # the dead transformation has been removed.
                    if ('"}' not in delta_text):
                        return None
                    end_loc = delta_text.rindex('"}')
                    diff = delta_text[:end_loc] + '"}'
                    logger.debug(
                        "Finishing tool and found diff that had not "
                        "been streamed yet: %s", diff)
                    self.streamed_args_for_tool[self.current_tool_id] \
                        += diff
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=diff).model_dump(
                                              exclude_none=True))
                    ])

            # case -- otherwise we're just generating text
            else:
                text = delta_text.replace(self.tool_call_start_token, "")
                text = text.replace(self.tool_call_end_token, "")
                delta = DeltaMessage(tool_calls=[], content=text)
                return delta

            try:

                current_tool_call = partial_json_parser.loads(
                    tool_call_portion or "{}",
                    flags) if tool_call_portion else None
                logger.debug("Parsed tool call %s", current_tool_call)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None
            except json.decoder.JSONDecodeError:
                logger.debug("unable to parse JSON")
                return None

            # case - we haven't sent the tool name yet. If it's available, send
            #   it. otherwise, wait until it's available.
            if not self.current_tool_name_sent:
                if (current_tool_call is None):
                    return None
                function_name: Union[str, None] = current_tool_call.get("name")
                if function_name:
                    self.current_tool_name_sent = True
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=make_tool_call_id(),
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                else:
                    return None
            # case -- otherwise, send the tool call delta

            # if the tool call portion is None, send the delta as text
            if tool_call_portion is None:
                # if there's text but not tool calls, send that -
                # otherwise None to skip chunk
                delta = DeltaMessage(content=delta_text) \
                    if text_portion is not None else None
                return delta

            # now, the nitty-gritty of tool calls
            # now we have the portion to parse as tool call.

            logger.debug("Trying to parse current tool call with ID %s",
                         self.current_tool_id)

            # if we're starting a new tool call, push an empty object in as
            #   a placeholder for the arguments
            if len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})

            # main logic for tool parsing here - compare prev. partially-parsed
            #   JSON to the current partially-parsed JSON
            prev_arguments = (
                self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
            cur_arguments = current_tool_call.get("arguments")

            logger.debug("diffing old arguments: %s", prev_arguments)
            logger.debug("against new ones: %s", cur_arguments)

            # case -- no arguments have been created yet. skip sending a delta.
            if not cur_arguments and not prev_arguments:
                logger.debug("Skipping text %s - no arguments", delta_text)
                delta = None

            # case -- prev arguments are defined, but none are now.
            #   probably impossible, but not a fatal error - just keep going
            elif not cur_arguments and prev_arguments:
                logger.error("should be impossible to have arguments reset "
                             "mid-call. skipping streaming anything.")
                delta = None

            # case -- we now have the first info about arguments available from
            #   autocompleting the JSON
            elif cur_arguments and not prev_arguments:
                # extract the content after {"name": ..., "arguments":
                #   directly from tool_call_portion as cur_arguments_json,
                #   since cur_arguments may differ from the original text
                #   due to partial JSON parsing
                #   for example, tool_call_portion =
                #     {"name": "search", "arguments": {"search_request": {"
                #   but cur_arguments =
                #     {"search_request": {}}
                function_name = current_tool_call.get("name")
                match = re.search(
                    r'\{"name":\s*"' +
                    re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
                    tool_call_portion.strip(), re.DOTALL)
                if match:
                    cur_arguments_json = match.group(1)
                else:
                    cur_arguments_json = json.dumps(cur_arguments,
                                                    ensure_ascii=False)

                logger.debug("finding %s in %s", delta_text,
                             cur_arguments_json)

                # get the location where previous args differ from current.
                if (delta_text not in cur_arguments_json):
                    return None
                args_delta_start_loc = cur_arguments_json. \
                                           rindex(delta_text) + \
                                           len(delta_text)

                # use that to find the actual delta
                arguments_delta = cur_arguments_json[:args_delta_start_loc]
                logger.debug("First tokens in arguments received: %s",
                             arguments_delta)

                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=arguments_delta).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[self.current_tool_id] \
                    += arguments_delta

            # last case -- we have an update to existing arguments.
            elif cur_arguments and prev_arguments:
                # judge whether the tool_call_portion is a complete JSON
                try:
                    json.loads(tool_call_portion)
                    is_complete_json = True
                except Exception:
                    is_complete_json = False

                # if the delta_text ends with a '}' and tool_call_portion is a
                #   complete JSON, then the last '}' does not belong to the
                #   arguments, so we should trim it off
                if isinstance(delta_text, str) \
                    and len(delta_text.rstrip()) >= 1 \
                    and delta_text.rstrip()[-1] == '}' \
                    and is_complete_json:
                    delta_text = delta_text.rstrip()[:-1]

                logger.debug("got diff %s", delta_text)

                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=delta_text).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[self.current_tool_id] \
                    += delta_text

            # handle saving the state for the current tool into
            # the "prev" list for use in diffing for the next iteration
            if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
                self.prev_tool_call_arr[self.current_tool_id] = \
                    current_tool_call
            else:
                self.prev_tool_call_arr.append(current_tool_call)

            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            return None  # do not stream a delta. skip this token ID.

buffered_delta_text instance-attribute ΒΆ

buffered_delta_text = ''

current_tool_id instance-attribute ΒΆ

current_tool_id: int = -1

current_tool_name_sent instance-attribute ΒΆ

current_tool_name_sent: bool = False

model_tokenizer instance-attribute ΒΆ

model_tokenizer = tokenizer

prev_tool_call_arr instance-attribute ΒΆ

prev_tool_call_arr: list[dict] = []

scratch_pad_regex instance-attribute ΒΆ

scratch_pad_regex = compile(
    "<scratch_pad>(.*?)</scratch_pad>", DOTALL
)

streamed_args_for_tool instance-attribute ΒΆ

streamed_args_for_tool: list[str] = []

tool_call_end_token instance-attribute ΒΆ

tool_call_end_token: str = '</tool_call>'

tool_call_end_token_array instance-attribute ΒΆ

tool_call_end_token_array = [
    (decode([token_id]))
    for token_id in (tool_call_end_token_ids)
]

tool_call_end_token_ids instance-attribute ΒΆ

tool_call_end_token_ids = encode(
    tool_call_end_token, add_special_tokens=False
)

tool_call_regex instance-attribute ΒΆ

tool_call_regex = compile(
    "<tool_call>(.*?)</tool_call>|<tool_call>(.*)", DOTALL
)

tool_call_start_token instance-attribute ΒΆ

tool_call_start_token: str = '<tool_call>'

tool_call_start_token_array instance-attribute ΒΆ

tool_call_start_token_array = [
    (decode([token_id]))
    for token_id in (tool_call_start_token_ids)
]

tool_call_start_token_ids instance-attribute ΒΆ

tool_call_start_token_ids = encode(
    tool_call_start_token, add_special_tokens=False
)

__init__ ΒΆ

__init__(tokenizer: AnyTokenizer)
Source code in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
def __init__(self, tokenizer: AnyTokenizer):
    super().__init__(tokenizer)

    if isinstance(self.model_tokenizer, MistralTokenizer):
        logger.error(
            "Detected Mistral tokenizer when using a Hermes model")
        self.model_tokenizer = self.model_tokenizer.tokenizer

    self.current_tool_name_sent: bool = False
    self.prev_tool_call_arr: list[dict] = []
    self.current_tool_id: int = -1
    self.streamed_args_for_tool: list[str] = [
    ]  # map what has been streamed for each tool so far to a list

    self.tool_call_start_token: str = "<tool_call>"
    self.tool_call_end_token: str = "</tool_call>"

    self.tool_call_regex = re.compile(
        r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
    self.scratch_pad_regex = re.compile(
        r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction.")
    self.tool_call_start_token_ids = self.model_tokenizer.encode(
        self.tool_call_start_token, add_special_tokens=False)
    self.tool_call_end_token_ids = self.model_tokenizer.encode(
        self.tool_call_end_token, add_special_tokens=False)

    self.tool_call_start_token_array = [
        self.model_tokenizer.decode([token_id])
        for token_id in self.tool_call_start_token_ids
    ]

    self.tool_call_end_token_array = [
        self.model_tokenizer.decode([token_id])
        for token_id in self.tool_call_end_token_ids
    ]

    self.buffered_delta_text = ""

adjust_request ΒΆ

adjust_request(
    request: ChatCompletionRequest,
) -> ChatCompletionRequest
Source code in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
def adjust_request(
        self, request: ChatCompletionRequest) -> ChatCompletionRequest:
    if request.tools and request.tool_choice != 'none':
        # do not skip special tokens because the tool_call tokens are
        # marked "special" in some models. Since they are skipped
        # prior to the call to the tool parser, it breaks tool calling.
        request.skip_special_tokens = False
    return request

extract_tool_calls ΒΆ

extract_tool_calls(
    model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation
Source code in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:

    # sanity check; avoid unnecessary processing
    if self.tool_call_start_token not in model_output:
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

    else:

        try:
            # there are two possible captures - between tags, or between a
            # tag and end-of-string so the result of
            # findall is an array of tuples where one is a function call and
            # the other is None
            function_call_tuples = (
                self.tool_call_regex.findall(model_output))

            # load the JSON, and then use it to build the Function and
            # Tool Call
            raw_function_calls = [
                json.loads(match[0] if match[0] else match[1])
                for match in function_call_tuples
            ]
            tool_calls = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(function_call["arguments"],
                                             ensure_ascii=False)))
                for function_call in raw_function_calls
            ]

            content = model_output[:model_output.
                                   find(self.tool_call_start_token)]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None)

        except Exception:
            logger.exception(
                "Error in extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

extract_tool_calls_streaming ΒΆ

extract_tool_calls_streaming(
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]
Source code in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
    """Incrementally extract tool-call deltas from a streaming response.

    Notes:
    1. All parsing is based on the ``*_text`` arguments; token ids are
       only used for logging and the single-token start heuristic below.
    2. All incoming text is first routed through
       ``tool_call_delta_buffer`` so tag sequences split across deltas
       are reassembled before parsing.

    Returns a ``DeltaMessage`` carrying plain content or a tool-call
    delta, or ``None`` when nothing should be streamed for this chunk.
    """
    delta_text = self.tool_call_delta_buffer(delta_text)
    # If the last characters of previous_text match the buffered delta
    # text, remove only the matching part so it is not counted twice.
    if (len(previous_text) >= len(self.buffered_delta_text)
            and previous_text[-len(self.buffered_delta_text):]
            == self.buffered_delta_text):
        previous_text = previous_text[:-len(self.buffered_delta_text)]
        current_text = previous_text + delta_text

    logger.debug("delta_text: %s", delta_text)
    logger.debug("delta_token_ids: %s", delta_token_ids)
    # check to see if we should be streaming a tool call - is there a
    # tool call start token anywhere in the text generated so far?
    if self.tool_call_start_token not in current_text:
        logger.debug("No tool call tokens found!")
        return DeltaMessage(content=delta_text)

    try:

        # figure out where we are in the parsing by counting tool call
        # start & end tags
        prev_tool_start_count = previous_text.count(
            self.tool_call_start_token)
        prev_tool_end_count = previous_text.count(self.tool_call_end_token)
        cur_tool_start_count = current_text.count(
            self.tool_call_start_token)
        cur_tool_end_count = current_text.count(self.tool_call_end_token)
        tool_call_portion = None
        text_portion = None

        # case: if we're generating text, OR rounding out a tool call
        if (cur_tool_start_count == cur_tool_end_count
                and prev_tool_end_count == cur_tool_end_count
                and self.tool_call_end_token not in delta_text):
            logger.debug("Generating text content! skipping tool parsing.")
            return DeltaMessage(content=delta_text)

        if self.tool_call_end_token in delta_text:
            logger.debug("tool_call_end_token in delta_text")
            # NOTE(review): current_text normally already ends with
            # delta_text, so full_text repeats it; the duplicated tail
            # is discarded by the split on the end token below.
            full_text = current_text + delta_text
            tool_call_portion = full_text.split(
                self.tool_call_start_token)[-1].split(
                    self.tool_call_end_token)[0].rstrip()
            delta_text = delta_text.split(
                self.tool_call_end_token)[0].rstrip()
            text_portion = delta_text.split(
                self.tool_call_end_token)[-1].lstrip()

        # if tool open & close tag counts don't match, we're doing
        # something with tools with this diff - the cases below form an
        # imaginary "else" block for the early returns above.
        # flags for partial JSON parsing. exported constants from
        # "Allow" are handled via BIT MASK
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR

        # case -- we're starting a new tool call
        if (cur_tool_start_count > cur_tool_end_count
                and cur_tool_start_count > prev_tool_start_count):
            if len(delta_token_ids) > 1:
                tool_call_portion = current_text.split(
                    self.tool_call_start_token)[-1]
            else:
                # a lone start-token delta carries no parseable payload
                tool_call_portion = None
                delta = None

            text_portion = None

            # set cursors and state appropriately
            self.current_tool_id += 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            logger.debug("Starting on a new tool %s", self.current_tool_id)

        # case -- we're updating an existing tool call
        elif (cur_tool_start_count > cur_tool_end_count
              and cur_tool_start_count == prev_tool_start_count):

            # get the portion of the text that's the tool call
            tool_call_portion = current_text.split(
                self.tool_call_start_token)[-1]
            text_portion = None

        # case -- the current tool call is being closed.
        elif (cur_tool_start_count == cur_tool_end_count
              and cur_tool_end_count >= prev_tool_end_count):
            if (self.prev_tool_call_arr is None
                    or len(self.prev_tool_call_arr) == 0):
                logger.debug(
                    "attempting to close tool call, but no tool call")
                return None
            diff = self.prev_tool_call_arr[self.current_tool_id].get(
                "arguments")
            if diff:
                # FIX: removed dead statement
                #   diff = diff.encode('utf-8').decode('unicode_escape')
                #       if diff is str else diff
                # `diff is str` is an identity check against the *type*
                # and is always False, and `diff` is unconditionally
                # reassigned below, so the statement had no effect.
                if '"}' not in delta_text:
                    return None
                end_loc = delta_text.rindex('"}')
                diff = delta_text[:end_loc] + '"}'
                logger.debug(
                    "Finishing tool and found diff that had not "
                    "been streamed yet: %s", diff)
                self.streamed_args_for_tool[self.current_tool_id] \
                    += diff
                return DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=diff).model_dump(
                                          exclude_none=True))
                ])

        # case -- otherwise we're just generating text
        else:
            text = delta_text.replace(self.tool_call_start_token, "")
            text = text.replace(self.tool_call_end_token, "")
            delta = DeltaMessage(tool_calls=[], content=text)
            return delta

        try:
            # `or "{}"` removed: this branch only runs when
            # tool_call_portion is truthy, so the fallback was dead.
            current_tool_call = partial_json_parser.loads(
                tool_call_portion, flags) if tool_call_portion else None
            logger.debug("Parsed tool call %s", current_tool_call)
        except partial_json_parser.core.exceptions.MalformedJSON:
            logger.debug('not enough tokens to parse into JSON yet')
            return None
        except json.decoder.JSONDecodeError:
            logger.debug("unable to parse JSON")
            return None

        # case - we haven't sent the tool name yet. If it's available, send
        #   it. otherwise, wait until it's available.
        if not self.current_tool_name_sent:
            if current_tool_call is None:
                return None
            function_name: Union[str, None] = current_tool_call.get("name")
            if function_name:
                self.current_tool_name_sent = True
                return DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  type="function",
                                  id=make_tool_call_id(),
                                  function=DeltaFunctionCall(
                                      name=function_name).model_dump(
                                          exclude_none=True))
                ])
            else:
                return None
        # case -- otherwise, send the tool call delta

        # if the tool call portion is None, send the delta as text
        if tool_call_portion is None:
            # if there's text but not tool calls, send that -
            # otherwise None to skip chunk
            delta = DeltaMessage(content=delta_text) \
                if text_portion is not None else None
            return delta

        # now, the nitty-gritty of tool calls
        # now we have the portion to parse as tool call.

        logger.debug("Trying to parse current tool call with ID %s",
                     self.current_tool_id)

        # if we're starting a new tool call, push an empty object in as
        #   a placeholder for the arguments
        if len(self.prev_tool_call_arr) <= self.current_tool_id:
            self.prev_tool_call_arr.append({})

        # main logic for tool parsing here - compare prev. partially-parsed
        #   JSON to the current partially-parsed JSON
        prev_arguments = (
            self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
        cur_arguments = current_tool_call.get("arguments")

        logger.debug("diffing old arguments: %s", prev_arguments)
        logger.debug("against new ones: %s", cur_arguments)

        # case -- no arguments have been created yet. skip sending a delta.
        if not cur_arguments and not prev_arguments:
            logger.debug("Skipping text %s - no arguments", delta_text)
            delta = None

        # case -- prev arguments are defined, but none are now.
        #   probably impossible, but not a fatal error - just keep going
        elif not cur_arguments and prev_arguments:
            logger.error("should be impossible to have arguments reset "
                         "mid-call. skipping streaming anything.")
            delta = None

        # case -- we now have the first info about arguments available from
        #   autocompleting the JSON
        elif cur_arguments and not prev_arguments:
            # extract the content after {"name": ..., "arguments":
            #   directly from tool_call_portion as cur_arguments_json,
            #   since cur_arguments may differ from the original text
            #   due to partial JSON parsing
            #   for example, tool_call_portion =
            #     {"name": "search", "arguments": {"search_request": {"
            #   but cur_arguments =
            #     {"search_request": {}}
            function_name = current_tool_call.get("name")
            match = re.search(
                r'\{"name":\s*"' +
                re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
                tool_call_portion.strip(), re.DOTALL)
            if match:
                cur_arguments_json = match.group(1)
            else:
                cur_arguments_json = json.dumps(cur_arguments,
                                                ensure_ascii=False)

            logger.debug("finding %s in %s", delta_text,
                         cur_arguments_json)

            # get the location where previous args differ from current.
            if delta_text not in cur_arguments_json:
                return None
            args_delta_start_loc = cur_arguments_json. \
                                       rindex(delta_text) + \
                                       len(delta_text)

            # use that to find the actual delta
            arguments_delta = cur_arguments_json[:args_delta_start_loc]
            logger.debug("First tokens in arguments received: %s",
                         arguments_delta)

            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              function=DeltaFunctionCall(
                                  arguments=arguments_delta).model_dump(
                                      exclude_none=True))
            ])
            self.streamed_args_for_tool[self.current_tool_id] \
                += arguments_delta

        # last case -- we have an update to existing arguments.
        elif cur_arguments and prev_arguments:
            # judge whether the tool_call_portion is a complete JSON
            try:
                json.loads(tool_call_portion)
                is_complete_json = True
            except Exception:
                is_complete_json = False

            # if the delta_text ends with a '}' and tool_call_portion is a
            #   complete JSON, then the last '}' does not belong to the
            #   arguments, so we should trim it off
            if isinstance(delta_text, str) \
                and len(delta_text.rstrip()) >= 1 \
                and delta_text.rstrip()[-1] == '}' \
                and is_complete_json:
                delta_text = delta_text.rstrip()[:-1]

            logger.debug("got diff %s", delta_text)

            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              function=DeltaFunctionCall(
                                  arguments=delta_text).model_dump(
                                      exclude_none=True))
            ])
            self.streamed_args_for_tool[self.current_tool_id] \
                += delta_text

        # handle saving the state for the current tool into
        # the "prev" list for use in diffing for the next iteration
        if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
            self.prev_tool_call_arr[self.current_tool_id] = \
                current_tool_call
        else:
            self.prev_tool_call_arr.append(current_tool_call)

        return delta

    except Exception:
        logger.exception("Error trying to handle streaming tool call.")
        return None  # do not stream a delta. skip this token ID.

tool_call_delta_buffer ΒΆ

tool_call_delta_buffer(delta_text: str)
Source code in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
def tool_call_delta_buffer(self, delta_text: str):
    """Buffer partial tool-call tag tokens until a full tag is seen.

    While delta_text is an intermediate token of the multi-token
    tool_call_start / tool_call_end tag sequences, accumulate it in
    self.buffered_delta_text and emit "". Once the tag's final token
    (or any non-tag token) arrives, flush the accumulated buffer
    together with delta_text.
    """
    start_tokens = self.tool_call_start_token_array
    end_tokens = self.tool_call_end_token_array

    is_tag_token = (delta_text in start_tokens
                    or delta_text in end_tokens)
    is_tag_final = (delta_text == start_tokens[-1]
                    or delta_text == end_tokens[-1])

    if is_tag_token and not is_tag_final:
        # Mid-tag: keep accumulating, emit nothing downstream yet.
        self.buffered_delta_text += delta_text
        return ""

    # Tag completed, or ordinary text arrived: flush anything buffered
    # (flushing an empty buffer just yields delta_text unchanged).
    flushed, self.buffered_delta_text = self.buffered_delta_text, ""
    return flushed + delta_text