跳转至

service

Ariadne 的 launart 服务相关

ElizabethService 🔗

Bases: Service

ElizabethService, Ariadne 的直接后端

Source code in `src/graia/ariadne/service.py` (lines 67–258):
class ElizabethService(Service):
    """ElizabethService, the direct backend of Ariadne.

    Keeps the account -> connection table, registers the default dispatchers on
    the shared ``Broadcast`` and posts the application/account lifecycle events
    during its Launart ``preparing`` and ``cleanup`` stages.
    """

    id = "elizabeth.service"
    supported_interface_types = {ConnectionInterface}
    http_interface: AiohttpClientInterface
    connections: Dict[int, ConnectionMixin[U_Info]]
    broadcast: Broadcast

    def __init__(self) -> None:
        """Initialize ElizabethService.

        Starts with no connections, obtains the shared ``Broadcast`` via
        ``creart`` and appends ``ContextDispatcher`` / ``LaunartInterfaceDispatcher``
        to its prelude dispatchers and ``NoneDispatcher`` to its finale
        dispatchers, skipping any that are already present.
        """
        import creart

        self.connections = {}
        self.broadcast = creart.it(Broadcast)

        if ContextDispatcher not in self.broadcast.prelude_dispatchers:
            self.broadcast.prelude_dispatchers.append(ContextDispatcher)
        if LaunartInterfaceDispatcher not in self.broadcast.prelude_dispatchers:
            self.broadcast.prelude_dispatchers.append(LaunartInterfaceDispatcher)
        if NoneDispatcher not in self.broadcast.finale_dispatchers:
            self.broadcast.finale_dispatchers.append(NoneDispatcher)

        super().__init__()

    @staticmethod
    def base_telemetry() -> None:
        """Run the basic telemetry check.

        Logs the Ariadne ASCII logo followed by every distribution returned by
        ``get_dist_map`` together with its installed version.
        """
        output: List[str] = [""]
        dist_map: Dict[str, str] = get_dist_map()
        output.extend(
            " ".join(
                [
                    f"[blue]{name}[/]:" if name.startswith("graiax-") else f"[magenta]{name}[/]:",
                    f"[green]{version}[/]",
                ]
            )
            for name, version in dist_map.items()
        )
        output.sort()
        output.insert(0, f"[cyan]{ARIADNE_ASCII_LOGO}[/]")
        rich_output = "\n".join(output)
        # The markup is written with [] and converted to <> for the colored log;
        # the [] form is passed as ``alt``.
        logger.opt(colors=True).info(
            rich_output.replace("[", "<").replace("]", ">"), alt=rich_output, highlighter=None
        )

    @staticmethod
    async def check_update() -> None:
        """Check for updates.

        Concurrently runs the module-level ``check_update`` for every distribution
        from ``get_dist_map``, then logs either the available updates or an
        "up to date" message.
        """
        output: List[str] = []
        dist_map: Dict[str, str] = get_dist_map()
        async with ClientSession() as session:
            await asyncio.gather(
                *(check_update(session, name, version, output) for name, version in dist_map.items())
            )
        output.sort()
        if output:
            output = (
                ["", "[bold]", f"[red]{len(output)}[/] [yellow]update(s) available:[/]"] + output + ["[/]"]
            )
            rich_output = "\n".join(output)
            logger.opt(colors=True).warning(
                rich_output.replace("[", "<").replace("]", ">"), alt=rich_output, highlighter=None
            )
        else:
            logger.opt(colors=True).success("All dependencies up to date!", style="green")

    def add_infos(self, infos: Iterable[U_Info]) -> Tuple[List[ConnectionMixin], int]:
        """Create connections from the given Info objects.

        Args:
            infos (Iterable[U_Info]): Info objects, all for the same account.

        Returns:
            Tuple[List[ConnectionMixin], int]: the created connections and the account.

        Raises:
            AriadneConfigurationError: no Info given, the account is already
                registered, or the Infos belong to different accounts.
        """
        infos = list(infos)
        if not infos:
            raise AriadneConfigurationError("No configs provided")

        account: int = infos[0].account
        if account in self.connections:
            raise AriadneConfigurationError(f"Account {account} already exists")
        if len({i.account for i in infos}) != 1:
            raise AriadneConfigurationError("All configs must be for the same account")

        infos.sort(key=lambda x: isinstance(x, HttpClientInfo))
        # make sure the http client is the last one: add_info only accepts an
        # HttpClientConnection as the second connection of an account
        conns: List[ConnectionMixin] = [self.add_info(conf) for conf in infos]
        return conns, account

    def add_info(self, config: U_Info) -> ConnectionMixin:
        """Add a single Info.

        The first connection of an account is registered directly; a later
        ``HttpClientConnection`` for the same account becomes the fallback of the
        registered connection and shares its status.

        Raises:
            ValueError: the registered connection already has a fallback, or the
                new connection is not an ``HttpClientConnection``.
        """
        account: int = config.account
        connection = CONFIG_MAP[config.__class__](config)
        if account not in self.connections:
            self.connections[account] = connection
        elif isinstance(connection, HttpClientConnection):
            upstream_conn = self.connections[account]
            if upstream_conn.fallback:
                raise ValueError(f"{upstream_conn} already has fallback connection")
            connection.status = upstream_conn.status
            connection.is_hook = True
            upstream_conn.fallback = connection
        else:
            raise ValueError(f"Connection {self.connections[account]} conflicts with {connection}")
        return connection

    async def launch(self, mgr: Launart):
        """Launart entry point.

        ``preparing``: resolve the aiohttp interface, post ``ApplicationLaunch``
        (if a default account is configured) and ``AccountLaunch`` per connection,
        and install each connection's failure callback.
        ``cleanup``: post the shutdown events for connections that are still
        available, cancel pending ``Broadcast.Executor`` tasks and run the
        update check.
        """
        from .app import Ariadne
        from .context import enter_context
        from .event.lifecycle import (
            AccountConnectionFail,
            AccountLaunch,
            AccountShutdown,
            ApplicationLaunch,
            ApplicationShutdown,
        )

        def _make_disconnect_cb(account_app: Ariadne):
            """Build the connection-failure callback for one account.

            A factory is used so each callback captures its own ``Ariadne``
            instance; a closure over the loop variable would be late-bound and
            every callback would report the last account iterated.
            """

            def _disconnect_cb() -> None:
                self.broadcast.postEvent(AccountConnectionFail(account_app))

            return _disconnect_cb

        self.base_telemetry()
        async with self.stage("preparing"):
            self.http_interface = mgr.get_interface(AiohttpClientInterface)
            if "default_account" in Ariadne.options:
                app = Ariadne.current()
                with enter_context(app=app):
                    self.broadcast.postEvent(ApplicationLaunch(app))
            for conn in self.connections.values():
                app = Ariadne.current(conn.info.account)
                # Each callback keeps this iteration's app (see _make_disconnect_cb).
                conn._connection_fail = _make_disconnect_cb(app)

                with enter_context(app=app):
                    self.broadcast.postEvent(AccountLaunch(app))

        async with self.stage("cleanup"):
            logger.info("Elizabeth Service cleaning up...", style="dark_orange")
            if "default_account" in Ariadne.options:
                app = Ariadne.current()
                if app.connection.status.available:
                    with enter_context(app=app):
                        await self.broadcast.postEvent(ApplicationShutdown(app))
            for conn in self.connections.values():
                if conn.status.available:
                    app = Ariadne.current(conn.info.account)
                    with enter_context(app=app):
                        await self.broadcast.postEvent(AccountShutdown(app))

            # Cancel event executors that are still pending so shutdown is not blocked.
            for task in asyncio.all_tasks():
                if task.done():
                    continue
                coro: Coroutine = task.get_coro()  # type: ignore
                if coro.__qualname__ == "Broadcast.Executor":
                    task.cancel()
                    logger.debug(f"Cancelled {task.get_name()} (Broadcast.Executor)")

            logger.info("Checking for updates...", alt="[cyan]Checking for updates...[/]")
            await self.check_update()

    @property
    def client_session(self) -> ClientSession:
        """Get the aiohttp ClientSession.

        Returns:
            ClientSession: the session of the aiohttp client interface's service.
        """
        return self.http_interface.service.session

    @property
    def required(self):
        """Launart dependencies.

        The ``AiohttpClientInterface`` plus, for every connection, its
        ``dependencies`` and its ``id``.
        """
        dependencies = {AiohttpClientInterface}
        for conn in self.connections.values():
            dependencies |= conn.dependencies
            dependencies.add(conn.id)
        return dependencies

    @property
    def stages(self):
        """Launart stages used by ``launch``: ``preparing`` and ``cleanup``."""
        return {"preparing", "cleanup"}

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """Get the bound event loop.

        Returns:
            asyncio.AbstractEventLoop: the event loop.
        """
        return it(asyncio.AbstractEventLoop)

    @overload
    def get_interface(self, interface_type: Type[ConnectionInterface]) -> ConnectionInterface:
        ...

    @overload
    def get_interface(self, interface_type: type) -> None:
        ...

    def get_interface(self, interface_type: type):
        """Return a ``ConnectionInterface`` bound to this service.

        This only happens when ``interface_type`` is ``ConnectionInterface``.
        Any other type returns ``None``.
        """
        if interface_type is ConnectionInterface:
            return ConnectionInterface(self)

client_session property 🔗

获取 aiohttp 的 ClientSession

Returns:

| Name | Type | Description |
| --- | --- | --- |
| ClientSession | ClientSession | ClientSession 对象 |

loop property 🔗

获取绑定的事件循环

Returns:

| Type | Description |
| --- | --- |
| AbstractEventLoop | asyncio.AbstractEventLoop: 事件循环 |

__init__() 🔗

初始化 ElizabethService

Source code in `src/graia/ariadne/service.py` (lines 76–90):
def __init__(self) -> None:
    """Set up the service: no connections yet, the shared Broadcast, and its default dispatchers.

    ``ContextDispatcher`` and ``LaunartInterfaceDispatcher`` go into the prelude
    dispatchers and ``NoneDispatcher`` into the finale dispatchers, each only if
    it is not already registered.
    """
    import creart

    self.connections = {}
    bcc = creart.it(Broadcast)
    self.broadcast = bcc

    wanted = (
        (bcc.prelude_dispatchers, ContextDispatcher),
        (bcc.prelude_dispatchers, LaunartInterfaceDispatcher),
        (bcc.finale_dispatchers, NoneDispatcher),
    )
    for registry, dispatcher in wanted:
        if dispatcher not in registry:
            registry.append(dispatcher)

    super().__init__()

add_info(config) 🔗

添加单个 Info

Source code in `src/graia/ariadne/service.py` (lines 151–166):
def add_info(self, config: U_Info) -> ConnectionMixin:
    """Create a connection for one Info object and register it under its account.

    If the account already has a connection, the new one must be an
    ``HttpClientConnection``; it then becomes the fallback of the existing
    connection and shares its status.

    Raises:
        ValueError: the new connection conflicts with the registered one, or the
            registered one already has a fallback.
    """
    account: int = config.account
    connection = CONFIG_MAP[config.__class__](config)
    if account in self.connections:
        upstream_conn = self.connections[account]
        if not isinstance(connection, HttpClientConnection):
            raise ValueError(f"Connection {upstream_conn} conflicts with {connection}")
        if upstream_conn.fallback:
            raise ValueError(f"{upstream_conn} already has fallback connection")
        connection.status = upstream_conn.status
        connection.is_hook = True
        upstream_conn.fallback = connection
    else:
        self.connections[account] = connection
    return connection

add_infos(infos) 🔗

通过传入的 Info 对象创建 Connection

Source code in `src/graia/ariadne/service.py` (lines 134–149):
def add_infos(self, infos: Iterable[U_Info]) -> Tuple[List[ConnectionMixin], int]:
    """Build connections for several Info objects that share one account.

    Returns:
        Tuple[List[ConnectionMixin], int]: the created connections and the account.

    Raises:
        AriadneConfigurationError: no Info given, the account is already
            registered, or the Infos belong to different accounts.
    """
    pending = list(infos)
    if len(pending) == 0:
        raise AriadneConfigurationError("No configs provided")

    account: int = pending[0].account
    if account in self.connections:
        raise AriadneConfigurationError(f"Account {account} already exists")
    if any(info.account != account for info in pending):
        raise AriadneConfigurationError("All configs must be for the same account")

    # Stable sort: HttpClientInfo entries (key True) end up last, after the
    # primary connection they will be attached to.
    ordered = sorted(pending, key=lambda info: isinstance(info, HttpClientInfo))
    return [self.add_info(info) for info in ordered], account

base_telemetry() staticmethod 🔗

执行基础遥测检查

Source code in `src/graia/ariadne/service.py` (lines 92–111):
@staticmethod
def base_telemetry() -> None:
    """Log the Ariadne logo followed by each monitored distribution and its version."""
    lines: List[str] = [""]
    for dist_name, dist_version in get_dist_map().items():
        name_color = "blue" if dist_name.startswith("graiax-") else "magenta"
        lines.append(f"[{name_color}]{dist_name}[/]: [green]{dist_version}[/]")
    lines.sort()
    lines.insert(0, f"[cyan]{ARIADNE_ASCII_LOGO}[/]")
    rich_output = "\n".join(lines)
    # [] markup becomes <> for the colored log; the [] form is the ``alt`` text.
    colored = rich_output.replace("[", "<").replace("]", ">")
    logger.opt(colors=True).info(colored, alt=rich_output, highlighter=None)

check_update() async staticmethod 🔗

执行更新检查

Source code in `src/graia/ariadne/service.py` (lines 113–132):
@staticmethod
async def check_update() -> None:
    """Look up newer releases of all monitored distributions concurrently and log a summary."""
    pending_updates: List[str] = []
    installed = get_dist_map()
    async with ClientSession() as session:
        lookups = [
            check_update(session, dist, ver, pending_updates) for dist, ver in installed.items()
        ]
        await asyncio.gather(*lookups)
    if not pending_updates:
        logger.opt(colors=True).success("All dependencies up to date!", style="green")
        return
    pending_updates.sort()
    lines = [
        "",
        "[bold]",
        f"[red]{len(pending_updates)}[/] [yellow]update(s) available:[/]",
        *pending_updates,
        "[/]",
    ]
    rich_output = "\n".join(lines)
    colored = rich_output.replace("[", "<").replace("]", ">")
    logger.opt(colors=True).warning(colored, alt=rich_output, highlighter=None)

launch(mgr) async 🔗

Launart 启动点

Source code in `src/graia/ariadne/service.py` (lines 168–216):
async def launch(self, mgr: Launart):
    """Launart entry point.

    ``preparing``: resolve the aiohttp interface, post ``ApplicationLaunch``
    (if a default account is configured) and ``AccountLaunch`` per connection,
    and install each connection's failure callback.
    ``cleanup``: post the shutdown events for connections that are still
    available, cancel pending ``Broadcast.Executor`` tasks and run the
    update check.
    """
    from .app import Ariadne
    from .context import enter_context
    from .event.lifecycle import (
        AccountConnectionFail,
        AccountLaunch,
        AccountShutdown,
        ApplicationLaunch,
        ApplicationShutdown,
    )

    def _make_disconnect_cb(account_app: Ariadne):
        """Build the connection-failure callback for one account.

        A factory is used so each callback captures its own ``Ariadne``
        instance; a closure over the loop variable would be late-bound and
        every callback would report the last account iterated.
        """

        def _disconnect_cb() -> None:
            self.broadcast.postEvent(AccountConnectionFail(account_app))

        return _disconnect_cb

    self.base_telemetry()
    async with self.stage("preparing"):
        self.http_interface = mgr.get_interface(AiohttpClientInterface)
        if "default_account" in Ariadne.options:
            app = Ariadne.current()
            with enter_context(app=app):
                self.broadcast.postEvent(ApplicationLaunch(app))
        for conn in self.connections.values():
            app = Ariadne.current(conn.info.account)
            # Each callback keeps this iteration's app (see _make_disconnect_cb).
            conn._connection_fail = _make_disconnect_cb(app)

            with enter_context(app=app):
                self.broadcast.postEvent(AccountLaunch(app))

    async with self.stage("cleanup"):
        logger.info("Elizabeth Service cleaning up...", style="dark_orange")
        if "default_account" in Ariadne.options:
            app = Ariadne.current()
            if app.connection.status.available:
                with enter_context(app=app):
                    await self.broadcast.postEvent(ApplicationShutdown(app))
        for conn in self.connections.values():
            if conn.status.available:
                app = Ariadne.current(conn.info.account)
                with enter_context(app=app):
                    await self.broadcast.postEvent(AccountShutdown(app))

        # Cancel event executors that are still pending so shutdown is not blocked.
        for task in asyncio.all_tasks():
            if task.done():
                continue
            coro: Coroutine = task.get_coro()  # type: ignore
            if coro.__qualname__ == "Broadcast.Executor":
                task.cancel()
                logger.debug(f"Cancelled {task.get_name()} (Broadcast.Executor)")

        logger.info("Checking for updates...", alt="[cyan]Checking for updates...[/]")
        await self.check_update()

check_update(session, name, current, output) async 🔗

在线检查更新

Source code in `src/graia/ariadne/service.py` (lines 31–51):
async def check_update(session: ClientSession, name: str, current: str, output: List[str]) -> None:
    """Check online whether a newer release of one distribution exists.

    This is best-effort. Any failure is logged as a warning and the distribution
    is skipped, so one bad package cannot abort the other checks running
    alongside it. Failures include an unparsable installed version, a network
    error, a non-2xx response, malformed JSON, or an unparsable remote version.

    Args:
        session (ClientSession): aiohttp session used for the request.
        name (str): distribution name.
        current (str): currently installed version string.
        output (List[str]): receives one rich-markup line when an update exists.
    """
    try:
        # Parsed inside the try: a non-PEP 440 installed version is reported
        # like any other failure instead of raising out of this coroutine.
        current_version = Version(current)
        # HTTPS so the version metadata cannot be altered in transit.
        async with session.get(f"https://mirrors.aliyun.com/pypi/web/json/{name}") as resp:
            # Report e.g. "404 Not Found" instead of a confusing JSON decode error.
            resp.raise_for_status()
            data = await resp.text()
        result: str = json.loads(data)["info"]["version"]
        result_version = Version(result)
    except Exception as e:
        logger.warning(f"Failed to retrieve latest version of {name}: {e}")
        return
    if result_version > current_version:
        output.append(
            " ".join(
                [
                    f"[cyan]{name}[/]:" if name.startswith("graiax-") else f"[magenta]{name}[/]:",
                    f"[green]{current}[/]",
                    f"-> [yellow]{result}[/]",
                ]
            )
        )

get_dist_map() 🔗

获取与项目相关的发行字典

Source code in `src/graia/ariadne/service.py` (lines 54–64):
def get_dist_map() -> Dict[str, str]:
    """Get the distributions related to the project.

    Collects every installed distribution whose name starts with
    ``monitored_prefix``. Entries with an empty name or version are skipped.

    Returns:
        Dict[str, str]: distribution name -> version. When the same name is
        installed more than once, the highest version wins.
    """
    dist_map: Dict[str, str] = {}
    for dist in importlib.metadata.distributions():
        name: str = dist.metadata["Name"]
        version: str = dist.metadata["Version"]
        if not name or not version or not name.startswith(monitored_prefix):
            continue
        previous = dist_map.get(name)
        if previous is not None:
            # Compare as versions, not strings: lexicographically "0.10.0" < "0.9.0".
            try:
                is_newer = Version(version) > Version(previous)
            except ValueError:
                # Version raises a ValueError subclass for non-PEP 440 strings;
                # keep the previous string-based rule for those.
                is_newer = version > previous
            if not is_newer:
                continue
        dist_map[name] = version
    return dist_map