最近写了一个基于 OneBot v11
的机器人开发框架,也就是 SDK
,其中包含了蛮多东西,所以单独写篇文章分享一下其中的细节和原理。
整个项目说起来还是比较复杂,所以这里只捡出几个核心实现。这个项目用来参考学习,没有投入生产环境的打算 。当然如果你愿意或者喜欢本项目的风格,部署到生产环境是没问题的 :)
项目地址:https://github.com/kifuan/shirasu
使用方法
已上传至 pypi
,可用 pip
安装。
这里把 README
里的示例粘贴过来。
1 2 3 4 5 6 7 8 9 10 import asynciofrom shirasu import AddonPool, OneBotClientif __name__ == '__main__' : pool = AddonPool.from_modules( 'shirasu.addons.echo' , 'shirasu.addons.help' , ) asyncio.run(OneBotClient.listen(pool=pool))
至于插件的定义方法下文会介绍,接下来在入口文件同级目录下创建 shirasu.yml
配置 WebSocket
地址。
1 2 ws: ws://127.0.0.1:8080
随后打开一个 OneBot v11
具体实现,如 go-cqhttp
,以正向ws 的方式进行连接。给机器人发 /echo hello
它就会正常回复一个 hello
了。
Task 隐患
根据 python/cpython#91887 ,使用 asyncio.create_task
创建的任务都只有一个弱引用。也就是说,如果你不手动将创建的任务储存起来,GC
可以把没执行的任务直接回收掉 。所以就有了这个 issue
,下方回复中给出可以使用 asyncio.TaskGroup
但是需要 Python 3.11+。我的开发环境是 3.10,所以手动处理一下这个问题。
解决方式很简单,用一个 set
把没完成的任务存起来就行,大致代码如下:
1 2 3 4 5 6 7 def __init__ (self ): self._tasks: set [asyncio.Task] = set () def use (self ): task = asyncio.create_task(...) self._tasks.add(task) task.add_done_callback(self._tasks.discard)
这样就可以保证没完成的任务不会被 GC
删除 ,扫除了一个隐患。
获取 WS 响应数据
使用 Python 中的 websockets
库,每次收到消息时都会使用 asyncio.create_task
创建一个 task
并行处理,所以这里就会有一个问题,就是当你发送之后并不能确定 下一次收到的是否为当前调用结果 ,所以 OneBot
定义了一个 echo
字段,同一个请求和返回时这个字段将是相同 的。
这里提一下,在 OneBot v11 中并没有规定 echo
的数据类型,而 v12
中规定是 string
。至于具体实现,参考 go-cqhttp
源码 发现传什么数据都无所谓。下面的 j
是 gjson.Parse
返回的对象。
1 2 3 4 ret := c.apiCaller.Call(t, j.Get("params" )) if j.Get("echo" ).Exists() { ret["echo" ] = j.Get("echo" ).Value() }
但是最新标准规定要传字符串,那就用字符串了。
据此,需要创建一个工具类,利用 asyncio.Future
实现数据的延时获取,下方源码位于 shirasu/util/future_table.py
。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 import sysimport asynciofrom typing import Any class FutureTable : def __init__ (self ) -> None : self._future_id = 0 self._futures: dict [int , asyncio.Future] = {} def register (self ) -> int : self._future_id = (self._future_id + 1 ) % sys.maxsize self._futures[self._future_id] = asyncio.get_event_loop().create_future() return self._future_id def set (self, echo: int , data: dict [str , Any ] ) -> None : if future := self._futures.get(echo): future.set_result(data) async def get (self, future_id: int , timeout: float ) -> dict [str , Any ]: if not (future := self._futures.get(future_id)): raise KeyError(f'future id {future_id} does not exist' ) try : return await asyncio.wait_for(future, timeout) finally : del self._futures[future_id]
在每一次收数据时,如果检查到存在 echo
字段,就将这个 FutureTable
中设置上对应的相应值,从而达到调用 API 并获取响应数据 的目的。这里附上核心代码,源码位于shirasu/client/onebot.py
。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 async def call_action (self, action: str , **params: Any ) -> dict [str , Any ]: logger.info(f'Calling action {action} .' ) future_id = self._futures.register() await self._ws.send(ujson.dumps({ 'action' : action, 'params' : params, 'echo' : str (future_id), })) data = await self._futures.get(future_id, self._global_config.action_timeout) if data.get('status' ) == 'failed' : raise ClientActionError(data) return data.get('data' , {}) async def _handle (self, data: dict [str , Any ] ) -> None : if echo := data.get('echo' ): self._futures.set (int (echo), data) return ...
这个 _handle
就是每次获取到数据时建立的 task
。
简单依赖注入框架
我个人比较喜欢 Spring
那种依赖注入。
1 2 @Autowired private FooService fooService;
但是在 FastAPI
中使用时往往需要你这么做。
1 2 def use (foo: Foo = Depends(get_foo ) ) -> None : ...
其实也有别的框架,可以做到 Spring
那种类型的,但是我还是按照我的想法实现了一个简单的框架,源码在 shirasu/di.py
,这里就贴上使用方法。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 import asynciofrom datetime import datetimefrom shirasu.di import inject, provide@provide('now' ) async def provide_now () -> datetime: return datetime.now() @provide('today' ) async def provide_today (now: datetime ) -> int : await asyncio.sleep(.1 ) return now.day @inject() async def use_today (today: int ) -> None : print (today) @inject() async def use_now (now: datetime ) -> None : await asyncio.sleep(.1 ) print (now.year) await use_today()await use_now()
推荐使用 IPython
测试这些,它支持顶层 await
。如果非要写到 .py
文件里请自行包装一层 async def main
然后使用 asyncio.run
运行。
为了保持一致性和方便使用,这里均采用异步函数,目的是方便在 provider
中进行异步操作。现在看上去不太方便,那是因为没有真正到应用场景——需要依赖异步操作的时候。
我个人认为,依赖注入最重要的就是你的 provider
也可以有自己的依赖项,比如上方代码的 today
就依赖于 now
,如果后期它还有别的依赖项,可以只修改它本身,其它地方均无需修改 。
此外,这个框架还会根据你提供的 type hint
来判断 provider
返回的类型和你在参数后面写的类型是否一致,当然如果是子类也可以,如果不一致它就会打印一条警告信息。你可以直接不标注类型来跳过类型检测 。
接下来聊聊实现的核心逻辑,它是根据你 provide
的时候提供的字符串来判断。以下几个函数是实现的核心,如果读者想看源码可以直接去文章开头提到的仓库下 shirasu/di.py
找到完整代码。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 class DependencyInjector : """ Dependency injector based on parameter names. Note: positional-only arguments are not supported. """ def __init__ (self ) -> None : self._providers: dict [str , Callable [..., Awaitable[T]]] = {} async def _inject_func_args (self, func: Callable [..., Awaitable[T]], *inject_for: str ) -> dict [str , Any ]: params = inspect.signature(func).parameters if unknown_deps := [dep for dep in params if dep not in self._providers]: raise UnknownDependencyError(unknown_deps) if circular_deps := [dep for dep in params if dep in inject_for]: raise CircularDependencyError(circular_deps) args = dict (zip (params, await asyncio.gather(*( self._apply(self._providers[dep], dep, *inject_for) for dep in params )))) for dep, param in params.items(): anno = param.annotation if anno == inspect.Parameter.empty: continue if not isinstance (val := args[dep], expected := anno): module = inspect.getmodule(func) module_name = module.__name__ if module else '<unknown module>' module_func_name = f'{module_name} :{func.__name__} ' logger.warning(f'type mismatch for parameter {dep} in function {module_func_name} , ' f'real type: {type (val).__name__} , expected: {expected.__name__} ' ) return args async def _apply (self, func: Callable [..., Awaitable[T]], *apply_for: str ) -> T: injected_args = await self._inject_func_args(func, *apply_for) return await func(**injected_args)
通过可变长参数 apply_for
与 inject_for
来判断是否存在循环依赖,还在获取依赖的时候递归调用 _apply
从而达到前文提到的 provider
也可以有依赖项的目的。被 @inject()
包装的代码其实就是调用 _apply(func)
。
插件系统
这是面向用户的接口,源码位于 shirasu/addon
下。还有一部分内置插件位于 shirasu/addons
,基本都是为了我测试而写的。
实现起来没什么难度,这里就简单贴一下使用方法,以 shirasu/addons/square.py
为例,这是一个计算平方的插件。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 from pydantic import BaseModelfrom shirasu import Client, Addon, MessageEvent, commandclass SquareConfig (BaseModel ): precision: int = 2 square = Addon( name='square' , usage='/square number' , description='Calculates the square of given number.' , config_model=SquareConfig, ) @square.receive(command('square' ) ) async def handle_square (client: Client, event: MessageEvent, config: SquareConfig ) -> None : arg = event.arg try : result = round (float (arg) ** 2 , config.precision) await client.send(f'{result:g} ' ) except ValueError: await client.reject(f'Invalid number: {arg} ' )
可以看到内部使用 pydantic
进行数据配置,这个在下个章节细说,这里先跳过。
被 @square.receive(...)
装饰的函数默认就也会被注入,所以不需要手动写 @inject()
。
在加载插件的时候,使用了 importlib
这个内置库,原理就是扫描整个模块,把 Addon
的实例加载进来而已,源码位于 shirasu/addon/pool.py
。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 def load_module (self, module_name: str ) -> 'AddonPool' : try : module = importlib.import_module(module_name) except ImportError as e: raise LoadAddonError(f'failed to load addons in module {module_name} ' ) from e addons = [p for p in module.__dict__.values() if isinstance (p, Addon)] if not addons: raise LoadAddonError(f'no addons in module {module_name} ' ) for addon in addons: self.load(addon) return self
配置系统
本项目采用 YAML
与 pydantic
进行配置,相对于 .env
的优势不必多说,我本人也是比较喜欢 YAML
或者 JSON
的配置模式的。
为了加载 yml
文件,需要依赖 pyyaml
这个库,以下源码位于 shirasu/config.py
。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 import yamlfrom typing import Any from pathlib import Pathfrom pydantic import BaseModelclass GlobalConfig (BaseModel ): """ The global configuration. """ ws: str = 'ws://127.0.0.1:8080' addons: dict [str , dict [str , Any ]] = {} superusers: list [int ] = [] action_timeout: float = 30. command_prefixes: list [str ] = ['/' ] command_separator: str = '\\s+' def load_config (path: str | Path ) -> GlobalConfig: return GlobalConfig.parse_obj(yaml.safe_load(Path(path).read_text('utf8' )))
以下为各项解释:
ws
:正向 WebSocket
地址。
addons
:每个插件的配置项,下文会说明。
superusers
:超级用户,这可以使用 superuser()
这个 Rule
对象来指定某个插件只为超级用户开放。
action_timeout
:运行每个 ws action
的超时时间,单位为秒。
command_prefixes
:命令前缀,如果不想让命令前都加上 /
可以加一个空字符串到这个配置项。
command_separator
:命令间的分隔符,正则表达式,默认为所有空白字符。
对于 addons
这个配置项,这是每个插件具体的配置,如 square
这个插件配置它的精度,就在 yml
里这么写:
1 2 3 addons: square: precision: 3
在 shirasu/addon/addon.py
中,会为每个插件的 receiver
注入一个它自身的配置,如下:
1 2 3 4 def _provide_config (self ) -> None : async def provide (global_config: GlobalConfig ) -> Any : return self._config_model.parse_obj(global_config.addons.get(self._name, {})) di.provide('config' , provide, check_duplicate=False )
在调用 matcher
和 receiver
前都会先调用 _provide_config
来为它们提供插件的配置,其中 self._config_model
是 Type[pydantic.BaseModel]
,当插件被定义的时候需要指定 config_model
,就像上文中的 square
插件做到的那样。
这里顺带提一下,@provide()
装饰器只是对 di.provide
这个方法进行了包装,在框架内部有大量代码都直接使用 di.provide
来提供依赖。
测试系统
核心步骤就是发一条假信息,收一条信息,判断是否符合要求。
本项目提供了一个简单的方法进行单元测试,推荐使用 pytest
+ pytest-asyncio
进行单元测试。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 import pytestimport asynciofrom shirasu import MockClient, AddonPool@pytest.mark.asyncio async def test_square (): pool = AddonPool.from_modules('shirasu.addons.square' ) client = MockClient(pool) await client.post_message('/square 2' ) square2_msg = await client.get_message() assert square2_msg.plain_text == '4' await client.post_message('/square a' ) rejected_msg = await client.get_message_event() assert rejected_msg.is_rejected @pytest.mark.asyncio async def test_echo (): pool = AddonPool.from_modules('shirasu.addons.echo' ) client = MockClient(pool) await client.post_message('/echo hello' ) echo_msg = await client.get_message() assert echo_msg.plain_text == 'hello' await client.post_message('echo hello' ) with pytest.raises(asyncio.TimeoutError): await client.get_message()
如果不希望它有任何回复,可以捕捉 asyncio.TimeoutError
,下面是 shirasu/client/mock.py
中的核心代码。
1 2 3 4 5 6 7 8 _message_event_queue: asyncio.queues.Queue[MessageEvent] async def post_event (self, event: Event ) -> None : self.curr_event = event await self.apply_addons() async def get_message_event (self, timeout: float = .1 ) -> MessageEvent: return await asyncio.wait_for(self._message_event_queue.get(), timeout)
通过 asyncio
提供的 queue
与 asyncio.wait_for
来实现一个超时自动报错的效果。
虽然当调用 await client.post_message(...)
后就已经将消息添加到队列中 了,但在自定义函数 中可能会使用 asyncio.create_task
或者其它方式并行运行,最终还是使用了超时机制,而不是直接 get_nowait
,它会在没有值时抛出 QueueEmpty
异常。
特别感谢
名称来源于 BA 的白洲梓,可爱捏。