Using an LLM API to Build a Custom LLM in LangChain

I recently needed to use an LLM API for a RAG application, but LangChain cannot talk to an arbitrary external LLM API out of the box; you have to implement a custom LLM class so that it can plug into the rest of LangChain's features.
This post shows how to wrap an external LLM API in a custom LangChain LLM class.

1. Install LangChain

pip install langchain
# or
conda install langchain -c conda-forge
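
The custom class below also makes its HTTP calls with requests (synchronous) and httpx (asynchronous streaming), so install those too if they are not already in your environment:

pip install requests httpx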

2. Create a Custom LLM Class

Create the custom class by subclassing langchain.llms.base.LLM and implementing the following methods:

  1. _call: calls the LLM API and returns the result as a string.
  2. _stream (optional): calls the LLM API and returns a generator, so the output can be streamed.
  3. _astream (optional): calls the LLM API and returns an async generator, for asynchronous streaming.

If _stream and _astream are not implemented, LangChain automatically falls back to _call.
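
For orientation, here is a minimal sketch of that structure (EchoLLM and its echo behaviour are placeholders for illustration, not part of the real implementation); the only strictly required members are _llm_type and _call:

from typing import Any, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class EchoLLM(LLM):
    """Toy custom LLM that just echoes the prompt back."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # A real implementation would call the external API here.
        return f"echo: {prompt}"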

For more details, see the LangChain documentation on custom LLMs.

Things that could probably be improved

  1. Since the API authenticates with a Bearer token, every call currently fetches a fresh JWT token first and then uses it to call the LLM API; caching the token would avoid the extra round trip (see the sketch after this list).
  2. In _stream and _astream, the LLM response is read as Server-Sent Events via requests and httpx respectively, and a generator is built on top of those event lines. I am not very familiar with Python's async, so there may well be a cleaner way to write this.
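
On the first point, one simple improvement is to cache the token and only fetch a new one shortly before it expires. A minimal sketch, assuming the /token endpoint also returns an expires_in field in seconds (adjust the field name to whatever your API actually returns):

import time
import requests


class TokenCache:
    """Fetch a Bearer token once and reuse it until it is about to expire."""

    def __init__(self, token_url: str, username: str, password: str, margin: int = 60):
        self.token_url = token_url
        self.username = username
        self.password = password
        self.margin = margin          # refresh this many seconds before expiry
        self._token = None
        self._expires_at = 0.0

    def get(self) -> str:
        # Reuse the cached token while it is still comfortably valid.
        if self._token and time.time() < self._expires_at - self.margin:
            return self._token
        r = requests.post(self.token_url,
                          data={"username": self.username, "password": self.password})
        r.raise_for_status()
        payload = r.json()
        self._token = payload["access_token"]
        # "expires_in" is an assumption about the token endpoint's response format.
        self._expires_at = time.time() + payload.get("expires_in", 1800)
        return self._token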

Essentially, this LLM class is just a wrapper around the external LLM API so that LangChain can use it. Once defined, you can instantiate the class directly and use it with the rest of LangChain, such as a RAG chain; a short usage sketch follows the full implementation below.

import json
import os
import httpx
import requests
from typing import Optional, List, Any, Mapping, Iterator, Dict, AsyncIterator
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.schema.output import GenerationChunk
from langchain.llms.base import LLM


class TaideLLM(LLM):

    llm_host: str = 'https://td.nchc.org.tw'
    llm_url: str = f'{llm_host}/api/v1'  # or whatever your REST path is...
    llm_type: str = "TAIDE/b.11.0.0"
    username: Optional[str] = os.getenv('TAIDE_USERNAME')
    password: Optional[str] = os.getenv('TAIDE_PASSWORD')

    @property
    def _llm_type(self) -> str:
        return self.llm_type

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"llmUrl": self.llm_url}

    def _get_token(self) -> str:
        # Exchange the username/password for a short-lived JWT access token.
        r = requests.post(self.llm_url + "/token",
                          data={"username": self.username,
                                "password": self.password})
        r.raise_for_status()
        return r.json()["access_token"]

    def _stream_response_to_generation_chunk(
        self,
        stream_response: Dict[str, Any],
    ) -> GenerationChunk:
        """Convert a single streamed response payload to a generation chunk."""
        if not stream_response["choices"]:
            return GenerationChunk(text="")
        return GenerationChunk(
            text=stream_response["choices"][0]["text"],
            generation_info=dict(
                finish_reason=stream_response["choices"][0].get("finish_reason", None),
                logprobs=stream_response["choices"][0].get("logprobs", None),
            ),
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        # Wrap the prompt in the model's instruction format and call /completions.
        question = prompt
        prompt_1 = f"[INST] {question} [/INST]"
        data = {
            "model": self.llm_type,
            "prompt": prompt_1,
            "temperature": 0.2,
            "top_p": 0.9,
            "presence_penalty": 1,
            "frequency_penalty": 1,
            "max_tokens": 2000
        }
        headers = {
            "Authorization": "Bearer " + self._get_token()
        }
        r = requests.post(self.llm_url + '/completions', headers=headers, json=data)
        r.raise_for_status()

        return r.json()["choices"][0]["text"]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:

        question = prompt
        prompt_1 = f"[INST] {question} [/INST]"
        data = {
            "model": self.llm_type,
            "prompt": prompt_1,
            "temperature": 0.2,
            "top_p": 0.9,
            "presence_penalty": 1,
            "frequency_penalty": 1,
            "max_tokens": 2000,
            "stream": True
        }

        headers = {
            "Authorization": "Bearer " + self._get_token()
        }

        with requests.post(self.llm_url + "/completions", stream=True, json=data, headers=headers) as r:
            r.raise_for_status()
            # The endpoint streams Server-Sent Events; payload lines look like "data: {...}".
            for raw_chunk in r.iter_lines(decode_unicode=True):
                if not raw_chunk or not raw_chunk.startswith("data: "):
                    continue
                json_str = raw_chunk[6:]  # strip the "data: " prefix to get the JSON part
                try:
                    # Parse the JSON payload and turn it into a GenerationChunk.
                    data_dict = json.loads(json_str)
                    chunk = self._stream_response_to_generation_chunk(data_dict)
                    yield chunk
                    if run_manager:
                        run_manager.on_llm_new_token(
                            chunk.text,
                            chunk=chunk,
                            verbose=self.verbose,
                            logprobs=(
                                chunk.generation_info["logprobs"]
                                if chunk.generation_info
                                else None
                            ),
                        )
                except json.JSONDecodeError:
                    # Skip payloads that are not valid JSON (e.g. end-of-stream markers).
                    continue

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:

        question = prompt
        prompt_1 = f"[INST] {question} [/INST]"
        data = {
            "model": self.llm_type,
            "prompt": prompt_1,
            "temperature": 0.2,
            "top_p": 0.9,
            "presence_penalty": 1,
            "frequency_penalty": 1,
            "max_tokens": 2000,
            "stream": True
        }

        headers = {
            "Authorization": "Bearer " + self._get_token()
        }

        async with httpx.AsyncClient(timeout=600) as client:
            async with client.stream("POST", f"{self.llm_url}/completions", json=data, headers=headers) as response:
                response.raise_for_status()
                async for raw_chunk in response.aiter_lines():
                    # Only the "data: ..." lines of the SSE stream carry payloads.
                    if not raw_chunk.startswith("data: "):
                        continue
                    json_str = raw_chunk[6:]  # strip the "data: " prefix
                    try:
                        data_dict = json.loads(json_str)
                        chunk = self._stream_response_to_generation_chunk(data_dict)
                        yield chunk
                        if run_manager:
                            await run_manager.on_llm_new_token(
                                chunk.text,
                                chunk=chunk,
                                verbose=self.verbose,
                                logprobs=(
                                    chunk.generation_info["logprobs"]
                                    if chunk.generation_info
                                    else None
                                ),
                            )
                    except json.JSONDecodeError:
                        continue
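
Once the class is defined, it behaves like any other LangChain LLM. A minimal usage sketch, assuming TAIDE_USERNAME and TAIDE_PASSWORD are set in the environment and using placeholder questions:

from langchain.prompts import PromptTemplate

llm = TaideLLM()

# Single blocking call
print(llm.invoke("Briefly introduce Taiwan."))

# Streaming: yields text pieces as they arrive from the API
for piece in llm.stream("Briefly introduce Taiwan."):
    print(piece, end="", flush=True)

# Plugging the LLM into a simple LCEL chain, e.g. as the generator in a RAG chain
prompt = PromptTemplate.from_template("Answer in three sentences: {question}")
chain = prompt | llm
print(chain.invoke({"question": "What is RAG?"}))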
