前言
最近写了一个基于 OneBot v11
的机器人开发框架,也就是 SDK
,其中包含了蛮多东西,所以单独写篇文章聊聊。
整个项目说起来还是比较复杂,所以这里只捡出几个核心实现。这个项目用来参考学习,没有投入生产环境的打算。当然如果你愿意或者喜欢本项目的风格,部署到生产环境是没问题的 :)
项目地址:https://github.com/kifuan/shirasu
使用方法
还是比较直观的,这里把 README
里的示例粘贴过来。
import asyncio
from shirasu import AddonPool, OneBotClient
if __name__ == '__main__':
pool = AddonPool.from_modules(
'shirasu.addons.echo',
'shirasu.addons.help',
)
asyncio.run(OneBotClient.listen(pool=pool))
至于插件的定义方法下文会介绍,接下来在入口文件同级目录下创建 shirasu.yml
配置 WebSocket
地址。
# The WebSocket server URL(not reverse WebSocket).
ws: ws://127.0.0.1:8080
随后打开一个 OneBot v11
具体实现,如 go-cqhttp
,以正向ws的方式进行连接。给机器人发 /echo hello
它就会正常回复一个 hello
了。
Task 隐患
根据 python/cpython#91887,使用 asyncio.create_task
创建的任务都只有一个弱引用。也就是说,如果你不手动将创建的任务储存起来,GC
可以把没执行的任务直接回收掉。所以就有了这个 issue
,下方回复中给出可以使用 asyncio.TaskGroup
但是需要 Python 3.11+。我的开发环境是 3.10,所以手动处理一下这个问题。
解决方式很简单,用一个 set
把没完成的任务存起来就行,大致代码如下:
def __init__(self):
self._tasks: set[asyncio.Task] = set()
def use(self):
task = asyncio.create_task(...)
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)
这样就可以保证没完成的任务不会被 GC
删除,扫除了一个隐患。
获取 WS 响应数据
使用 Python 中的 websockets
库,每次收到消息时都会使用 asyncio.create_task
创建一个 task
并行处理,所以这里就会有一个问题,就是当你发送之后并不能确定下一次收到的是否为当前调用结果,所以 OneBot
定义了一个 echo
字段,同一个请求和返回时这个字段将是相同的。
这里提一下,在 OneBot v11 中并没有规定 echo
的数据类型,而 v12
中规定是 string
。至于具体实现,参考 go-cqhttp
源码发现传什么数据都无所谓。下面的 j
是 gjson.Parse
返回的对象。
ret := c.apiCaller.Call(t, j.Get("params"))
if j.Get("echo").Exists() {
ret["echo"] = j.Get("echo").Value()
}
但是最新标准规定要传字符串,那就用字符串了。
据此,需要创建一个工具类,利用 asyncio.Future
实现数据的延时获取,下方源码位于 shirasu/util/future_table.py
。
import sys
import asyncio
from typing import Any
class FutureTable:
def __init__(self) -> None:
self._future_id = 0
self._futures: dict[int, asyncio.Future] = {}
def register(self) -> int:
self._future_id = (self._future_id + 1) % sys.maxsize
self._futures[self._future_id] = asyncio.get_event_loop().create_future()
return self._future_id
def set(self, echo: int, data: dict[str, Any]) -> None:
if future := self._futures.get(echo):
future.set_result(data)
async def get(self, future_id: int, timeout: float) -> dict[str, Any]:
if not (future := self._futures.get(future_id)):
raise KeyError(f'future id {future_id} does not exist')
try:
return await asyncio.wait_for(future, timeout)
finally:
del self._futures[future_id]
在每一次收数据时,如果检查到存在 echo
字段,就将这个 FutureTable
中设置上对应的相应值,从而达到调用 API 并获取响应数据的目的。这里附上核心代码,源码位于shirasu/client/onebot.py
。
async def call_action(self, action: str, **params: Any) -> dict[str, Any]:
logger.info(f'Calling action {action}.')
future_id = self._futures.register()
await self._ws.send(ujson.dumps({
'action': action,
'params': params,
'echo': str(future_id),
}))
data = await self._futures.get(future_id, self._global_config.action_timeout)
if data.get('status') == 'failed':
raise ClientActionError(data)
return data.get('data', {})
async def _handle(self, data: dict[str, Any]) -> None:
if echo := data.get('echo'):
self._futures.set(int(echo), data)
return
...
这个 _handle
就是每次获取到数据时建立的 task
。
简单依赖注入框架
我个人比较喜欢 Spring
那种依赖注入。
@Autowired
private FooService fooService;
但是在 FastAPI
中使用时往往需要你这么做。
def use(foo: Foo = Depends(get_foo)) -> None:
...
其实也有别的框架,可以做到 Spring
那种类型的,但是我还是按照我的想法实现了一个简单的框架,源码在 shirasu/di.py
,这里就贴上使用方法。
import asyncio
from datetime import datetime
from shirasu.di import inject, provide
@provide('now')
async def provide_now() -> datetime:
return datetime.now()
@provide('today')
async def provide_today(now: datetime) -> int:
await asyncio.sleep(.1)
return now.day
@inject()
async def use_today(today: int) -> None:
print(today)
@inject()
async def use_now(now: datetime) -> None:
await asyncio.sleep(.1)
print(now.year)
# 1
await use_today()
# 2023
await use_now()
推荐使用
IPython
测试这些,它支持顶层await
。如果非要写到.py
文件里请自行包装一层async def main
然后使用asyncio.run
运行。
为了保持一致性和方便使用,这里均采用异步函数,目的是方便在 provider
中进行异步操作。现在看上去不太方便,那是因为没有真正到应用场景——需要依赖异步操作的时候。
我个人认为,依赖注入最重要的就是你的 provider
也可以有自己的依赖项,比如上方代码的 today
就依赖于 now
,如果后期它还有别的依赖项,可以只修改它本身,其它地方均无需修改。
此外,这个框架还会根据你提供的 type hint
来判断 provider
返回的类型和你在参数后面写的类型是否一致,当然如果是子类也可以,如果不一致它就会打印一条警告信息。你可以直接不标注类型来跳过类型检测。
接下来聊聊实现的核心逻辑,它是根据你 provide
的时候提供的字符串来判断。以下几个函数是实现的核心,如果读者想看源码可以直接去文章开头提到的仓库下 shirasu/di.py
找到完整代码。
class DependencyInjector:
"""
Dependency injector based on parameter names.
Note: positional-only arguments are not supported.
"""
def __init__(self) -> None:
self._providers: dict[str, Callable[..., Awaitable[T]]] = {}
async def _inject_func_args(self, func: Callable[..., Awaitable[T]], *inject_for: str) -> dict[str, Any]:
params = inspect.signature(func).parameters
# Check unknown dependencies.
if unknown_deps := [dep for dep in params if dep not in self._providers]:
raise UnknownDependencyError(unknown_deps)
# Check circular dependencies.
if circular_deps := [dep for dep in params if dep in inject_for]:
raise CircularDependencyError(circular_deps)
args = dict(zip(params, await asyncio.gather(*(
self._apply(self._providers[dep], dep, *inject_for)
for dep in params
))))
# Check types of injected parameters.
for dep, param in params.items():
anno = param.annotation
# Skip untyped parameters.
if anno == inspect.Parameter.empty:
continue
if not isinstance(val := args[dep], expected := anno):
module = inspect.getmodule(func)
module_name = module.__name__ if module else '<unknown module>'
module_func_name = f'{module_name}:{func.__name__}'
logger.warning(f'type mismatch for parameter {dep} in function {module_func_name}, '
f'real type: {type(val).__name__}, expected: {expected.__name__}')
return args
async def _apply(self, func: Callable[..., Awaitable[T]], *apply_for: str) -> T:
injected_args = await self._inject_func_args(func, *apply_for)
return await func(**injected_args)
通过可变长参数 apply_for
与 inject_for
来判断是否存在循环依赖,还在获取依赖的时候递归调用 _apply
从而达到前文提到的 provider
也可以有依赖项的目的。被 @inject()
包装的代码其实就是调用 _apply(func)
。
插件系统
这是面向用户的接口,源码位于 shirasu/addon
下。还有一部分内置插件位于 shirasu/addons
,基本都是为了我测试而写的。
实现起来没什么难度,这里就简单贴一下使用方法,以 shirasu/addons/square.py
为例,这是一个计算平方的插件。
from pydantic import BaseModel
from shirasu import Client, Addon, MessageEvent, command
class SquareConfig(BaseModel):
precision: int = 2
square = Addon(
name='square',
usage='/square number',
description='Calculates the square of given number.',
config_model=SquareConfig,
)
@square.receive(command('square'))
async def handle_square(client: Client, event: MessageEvent, config: SquareConfig) -> None:
arg = event.arg
try:
result = round(float(arg) ** 2, config.precision)
await client.send(f'{result:g}')
except ValueError:
await client.reject(f'Invalid number: {arg}')
可以看到内部使用 pydantic
进行数据配置,这个在下个章节细说,这里先跳过。
被 @square.receive(...)
装饰的函数默认就也会被注入,所以不需要手动写 @inject()
。
在加载插件的时候,使用了 importlib
这个内置库,原理就是扫描整个模块,把 Addon
的实例加载进来而已,源码位于 shirasu/addon/pool.py
。
def load_module(self, module_name: str) -> 'AddonPool':
try:
module = importlib.import_module(module_name)
except ImportError as e:
raise LoadAddonError(f'failed to load addons in module {module_name}') from e
addons = [p for p in module.__dict__.values() if isinstance(p, Addon)]
if not addons:
raise LoadAddonError(f'no addons in module {module_name}')
for addon in addons:
self.load(addon)
return self
配置系统
本项目采用 YAML
与 pydantic
进行配置,相对于 .env
的优势不必多说,我本人也是比较喜欢 YAML
或者 JSON
的配置模式的。
为了加载 yml
文件,需要依赖 pyyaml
这个库,以下源码位于 shirasu/config.py
。
import yaml
from typing import Any
from pathlib import Path
from pydantic import BaseModel
class GlobalConfig(BaseModel):
"""
The global configuration.
"""
ws: str = 'ws://127.0.0.1:8080'
addons: dict[str, dict[str, Any]] = {}
superusers: list[int] = []
action_timeout: float = 30.
command_prefixes: list[str] = ['/']
command_separator: str = '\\s+'
def load_config(path: str | Path) -> GlobalConfig:
return GlobalConfig.parse_obj(yaml.safe_load(Path(path).read_text('utf8')))
以下为各项解释:
ws
:正向WebSocket
地址。addons
:每个插件的配置项,下文会说明。superusers
:超级用户,这可以使用superuser()
这个Rule
对象来指定某个插件只为超级用户开放。action_timeout
:运行每个ws action
的超时时间,单位为秒。command_prefixes
:命令前缀,如果不想让命令前都加上/
可以加一个空字符串到这个配置项。command_separator
:命令间的分隔符,正则表达式,默认为所有空白字符。
对于 addons
这个配置项,这是每个插件具体的配置,如 square
这个插件配置它的精度,就在 yml
里这么写:
addons:
square:
precision: 3
在 shirasu/addon/addon.py
中,会为每个插件的 receiver
注入一个它自身的配置,如下:
def _provide_config(self) -> None:
async def provide(global_config: GlobalConfig) -> Any:
return self._config_model.parse_obj(global_config.addons.get(self._name, {}))
di.provide('config', provide, check_duplicate=False)
在调用 matcher
和 receiver
前都会先调用 _provide_config
来为它们提供插件的配置,其中 self._config_model
是 Type[pydantic.BaseModel]
,当插件被定义的时候需要指定 config_model
,就像上文中的 square
插件做到的那样。
这里顺带提一下,@provide()
装饰器只是对 di.provide
这个方法进行了包装,在框架内部有大量代码都直接使用 di.provide
来提供依赖。
测试系统
核心步骤就是发一条假信息,收一条信息,判断是否符合要求。
本项目提供了一个简单的方法进行单元测试,推荐使用 pytest
+ pytest-asyncio
进行单元测试。
import pytest
import asyncio
from shirasu import MockClient, AddonPool
@pytest.mark.asyncio
async def test_square():
pool = AddonPool.from_modules('shirasu.addons.square')
client = MockClient(pool)
await client.post_message('/square 2')
square2_msg = await client.get_message()
assert square2_msg.plain_text == '4'
await client.post_message('/square a')
rejected_msg = await client.get_message_event()
assert rejected_msg.is_rejected
@pytest.mark.asyncio
async def test_echo():
pool = AddonPool.from_modules('shirasu.addons.echo')
client = MockClient(pool)
await client.post_message('/echo hello')
echo_msg = await client.get_message()
assert echo_msg.plain_text == 'hello'
await client.post_message('echo hello')
with pytest.raises(asyncio.TimeoutError):
await client.get_message()
如果不希望它有任何回复,可以捕捉 asyncio.TimeoutError
,下面是 shirasu/client/mock.py
中的核心代码。
_message_event_queue: asyncio.queues.Queue[MessageEvent]
async def post_event(self, event: Event) -> None:
self.curr_event = event
await self.apply_addons()
async def get_message_event(self, timeout: float = .1) -> MessageEvent:
return await asyncio.wait_for(self._message_event_queue.get(), timeout)
通过 asyncio
提供的 queue
与 asyncio.wait_for
来实现一个超时自动报错的效果。
虽然当调用 await client.post_message(...)
后就已经将消息添加到队列中了,但在自定义函数中可能会使用 asyncio.create_task
或者其它方式并行运行,最终还是使用了超时机制,而不是直接 get_nowait
,它会在没有值时抛出 QueueEmpty
异常。
特别感谢
- nonebot-adapter-onebot:部分代码参考 NB 实现,NB 就是 NB!
- go-cqhttp:开发过程中主要参考
go-cqhttp
的文档。 - voidbot:最核心的功能都在这200行代码中实现了,对我帮助很大,但他写的是同步逻辑。
- arashi:这个仓库目的也是提供一个功能比较完善的最小实现。