# echo.py
import asyncio
import json
from functools import partial
from html import escape
from time import time
from typing import Type, Optional

import openai
from maubot import Plugin, MessageEvent
from maubot.handlers import command
from mautrix.types import MediaMessageEventContent, TextMessageEventContent, MessageType, Format, RelatesTo, RelationType, RoomID
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper


  10. class Config(BaseProxyConfig):
  11. def do_update(self, helper: ConfigUpdateHelper) -> None:
  12. helper.copy("gpt_on")
  13. helper.copy("gpt_room_id")
  14. helper.copy("gpt_apikey")
  15. helper.copy("gpt_model_engine")
  16. helper.copy("gpt_temperature")
  17. helper.copy("gpt_top_p")
  18. helper.copy("gpt_presence_penalty")
  19. helper.copy("gpt_frequency_penalty")
  20. helper.copy("gpt_echo")
  21. helper.copy("gpt_stop")
  22. helper.copy("gpt_n")
  23. helper.copy("gpt_stream")
  24. helper.copy("gpt_logprobs")
  25. helper.copy("gpt_best_of")
  26. helper.copy("gpt_logit_bias")
  27. class EchoBot(Plugin):
  28. @classmethod
  29. def get_config_class(cls) -> Type[BaseProxyConfig]:
  30. return Config
  31. async def start(self) -> None:
  32. await super().start()
  33. self.config.load_and_update()
  34. self.http = self.client.api.session
  35. async def stop(self) -> None:
  36. await super().stop()
  37. @staticmethod
  38. def plural(num: float, unit: str, decimals: Optional[int] = None) -> str:
  39. num = round(num, decimals)
  40. if num == 1:
  41. return f"{num} {unit}"
  42. else:
  43. return f"{num} {unit}s"
  44. @classmethod
  45. def prettify_diff(cls, diff: int) -> str:
  46. if abs(diff) < 10 * 1_000:
  47. return f"{diff} ms"
  48. elif abs(diff) < 60 * 1_000:
  49. return cls.plural(diff / 1_000, 'second', decimals=1)
  50. minutes, seconds = divmod(diff / 1_000, 60)
  51. if abs(minutes) < 60:
  52. return f"{cls.plural(minutes, 'minute')} and {cls.plural(seconds, 'second')}"
  53. hours, minutes = divmod(minutes, 60)
  54. if abs(hours) < 24:
  55. return (f"{cls.plural(hours, 'hour')}, {cls.plural(minutes, 'minute')}"
  56. f" and {cls.plural(seconds, 'second')}")
  57. days, hours = divmod(hours, 24)
  58. return (f"{cls.plural(days, 'day')}, {cls.plural(hours, 'hour')}, "
  59. f"{cls.plural(minutes, 'minute')} and {cls.plural(seconds, 'second')}")
  60. @command.new("ping", help="Ping")
  61. @command.argument("message", pass_raw=True, required=False)
  62. async def ping_handler(self, evt: MessageEvent, message: str = "") -> None:
  63. diff = int(time() * 1000) - evt.timestamp
  64. pretty_diff = self.prettify_diff(diff)
  65. text_message = f'"{message[:20]}" took' if message else "took"
  66. html_message = f'"{escape(message[:20])}" took' if message else "took"
  67. content = TextMessageEventContent(
  68. msgtype=MessageType.NOTICE, format=Format.HTML,
  69. body=f"{evt.sender}: Pong! (ping {text_message} {pretty_diff} to arrive)",
  70. formatted_body=f"<a href='https://matrix.example.pl/#/{evt.sender}'>{evt.sender}</a>: Pong! "
  71. f"(<a href='https://matrix.example.pl/#/{evt.room_id}/{evt.event_id}'>ping</a> {html_message} "
  72. f"{pretty_diff} to arrive)",
  73. relates_to=RelatesTo(
  74. rel_type=RelationType("xyz.maubot.gpt.echo"),
  75. event_id=evt.event_id,
  76. ))
  77. pong_from = evt.sender.split(":", 1)[1]
  78. content.relates_to["from"] = pong_from
  79. content.relates_to["ms"] = diff
  80. content["pong"] = {
  81. "ms": diff,
  82. "from": pong_from,
  83. "ping": evt.event_id,
  84. }
  85. await evt.respond(content)
  86. @command.new("echo", help="Repeat a message")
  87. @command.argument("message", pass_raw=True)
  88. async def echo_handler(self, evt: MessageEvent, message: str) -> None:
  89. await evt.respond(message)
  90. @command.new("gpt", help="ChatGPT response")
  91. @command.argument("message", pass_raw=True, required=False)
  92. async def gpt_handler(self, evt: MessageEvent, message: str = "") -> None:
  93. if self.config["gpt_on"] and evt.room_id in self.config["gpt_room_id"]:
  94. openai.api_key = self.config["gpt_apikey"]
  95. resp = openai.Completion.create(engine=self.config["gpt_model_engine"],
  96. prompt=message,
  97. max_tokens=int(self.config["gpt_max_tokens"]),
  98. temperature=float(self.config["gpt_temperature"]),
  99. top_p=int(self.config["gpt_top_p"]),
  100. presence_penalty=int(self.config["gpt_presence_penalty"]),
  101. frequency_penalty=int(self.config["gpt_frequency_penalty"]),
  102. echo=self.config["gpt_echo"],
  103. stop=self.config["gpt_stop"],
  104. n=int(self.config["gpt_n"]),
  105. stream=self.config["gpt_stream"],
  106. # logprobs=self.config["gpt_logprobs"],
  107. best_of=int(self.config["gpt_best_of"]),
  108. logit_bias={}
  109. )
  110. n = len(resp.choices)
  111. if n == 1:
  112. html_message = resp.choices[0].text
  113. text_message = resp.choices[0].text
  114. else:
  115. texts = []
  116. for idx in range(0, n):
  117. html_message.append(resp.choices[idx].text)
  118. text_message.append(resp.choices[idx].text)
  119. else:
  120. html_message = f'chatGPT jest wyłączony.'
  121. text_message = 'chatGPT jest wyłączony.'
  122. content = TextMessageEventContent(
  123. msgtype=MessageType.NOTICE, format=Format.HTML,
  124. body=f"{evt.sender}: {text_message}",
  125. formatted_body=f"{evt.sender}: "
  126. f"{html_message}",
  127. relates_to=RelatesTo(
  128. rel_type=RelationType("xyz.maubot.gpt.echo"),
  129. event_id=evt.event_id,
  130. ))
  131. pong_from = evt.sender.split(":", 1)[1]
  132. content.relates_to["from"] = pong_from
  133. await evt.respond(content)