
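Discord Bot with Ollama and SQLite

A complete guide to an asynchronous Discord bot that integrates with a local Ollama instance and tracks user history in SQLite.

This snippet provides a class-based implementation of a Discord bot that converses through a local LLM served by Ollama. It features asynchronous message handling, concurrency limits (semaphores), streamed responses from Ollama, and persistent user tracking in SQLite.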
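
To make the concurrency limit concrete, here is a minimal, standalone sketch of how an asyncio semaphore caps the number of simultaneous model calls. The names `fake_llm_call`, `run_demo`, and `MAX_CONCURRENT` are hypothetical and exist only for this illustration; the sleep stands in for a real request to Ollama.

```python
import asyncio

MAX_CONCURRENT = 2  # hypothetical cap on simultaneous model calls


async def fake_llm_call(i: int, sem: asyncio.Semaphore, state: dict) -> int:
    """Stand-in for a slow model request; records the peak concurrency seen."""
    async with sem:  # at most MAX_CONCURRENT coroutines get past this line at once
        state["running"] += 1
        state["peak"] = max(state["peak"], state["running"])
        await asyncio.sleep(0.01)  # simulate inference time
        state["running"] -= 1
    return i


async def run_demo(n_requests: int = 6) -> dict:
    """Fire n_requests at once and report how many actually ran in parallel."""
    sem = asyncio.Semaphore(MAX_CONCURRENT)
    state = {"running": 0, "peak": 0}
    state["results"] = await asyncio.gather(
        *(fake_llm_call(i, sem, state) for i in range(n_requests))
    )
    return state


if __name__ == "__main__":
    print(asyncio.run(run_demo()))
```

All six requests are scheduled immediately, but the semaphore lets only two run at a time; the rest wait their turn instead of piling onto the model.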

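Features

  • Async & concurrent: Per-channel asyncio queues and semaphores handle multiple channels at once while capping each channel at a configurable number of simultaneous LLM requests.
  • Streaming responses: Reads tokens from Ollama's streaming API as they arrive and posts the assembled reply to Discord.
  • Context awareness: Maintains a conversation history per user and channel, trimmed to a configurable limit.
  • Persistence: Records user activity (username, last message, last-seen time) in a local SQLite database.
  • Clean architecture: Split into Config, DatabaseManager, OllamaClient, WorkerManager, and DiscordBot classes, in line with SOLID principles.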
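
The context-awareness feature comes down to appending each message to a per-conversation list and keeping only the newest entries. The sketch below shows that idea in isolation; `append_and_trim` and the example key are hypothetical names for this illustration, not part of the bot's code.

```python
from typing import Dict, List

Message = Dict[str, str]


def append_and_trim(history: List[Message], message: Message, max_length: int) -> List[Message]:
    """Return a new list with `message` appended, keeping only the newest `max_length` items."""
    if max_length <= 0:
        # Guard: history[-0:] would return the whole list instead of an empty one
        raise ValueError("max_length must be positive")
    return (history + [message])[-max_length:]


# Histories are keyed by "channel_id:user_id" so each user has a separate context per channel
conversations: Dict[str, List[Message]] = {}
key = "channel123:user456"  # hypothetical IDs

for i in range(8):
    conversations[key] = append_and_trim(
        conversations.get(key, []),
        {"role": "user", "content": f"msg {i}"},
        max_length=6,
    )
```

After eight messages with a limit of six, the two oldest entries have been dropped, so the model only ever sees a bounded amount of context.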

Code

Save this file as discord_ollama_bot.py.

import discord
import httpx
import asyncio
import json
import sqlite3
import logging
from datetime import datetime, timezone
from typing import Dict, List, Optional, Set, Tuple

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
logger = logging.getLogger(__name__)

class Config:
    """Application configuration constants."""
    # REPLACE THIS WITH YOUR ACTUAL TOKEN
    DISCORD_TOKEN = "YOUR_DISCORD_BOT_TOKEN"

    # Ollama Configuration
    OLLAMA_MODEL = "llama3"  # Change to your preferred model
    OLLAMA_URL = "http://localhost:11434/v1/chat/completions"

    # Bot Settings
    MAX_HISTORY_LENGTH = 6
    MAX_CONCURRENT_WORKERS = 2
    DB_NAME = "bot_data.db"

class DatabaseManager:
    """Handles SQLite database interactions."""

    def __init__(self, db_name: str):
        self.db_name = db_name
        self._init_db()

    def _get_connection(self) -> sqlite3.Connection:
        return sqlite3.connect(self.db_name)

    def _init_db(self):
        """Initializes the database schema."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
            CREATE TABLE IF NOT EXISTS users (
                user_id TEXT PRIMARY KEY,
                user_name TEXT,
                last_message TEXT,
                last_seen TIMESTAMP
            )
            """)
            conn.commit()

    def register_user(self, user_id: str, user_name: str, last_message: str):
        """Upserts user data into the database."""
        try:
            with self._get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO users (user_id, user_name, last_message, last_seen)
                    VALUES (?, ?, ?, ?)
                    ON CONFLICT(user_id) DO UPDATE SET
                        user_name=excluded.user_name,
                        last_message=excluded.last_message,
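                        last_seen=excluded.last_seen
                """, (
                    user_id, user_name, last_message,
                    # Store an ISO 8601 string; sqlite3's default datetime adapter is deprecated since Python 3.12
                    datetime.now(timezone.utc).isoformat(),
                ))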
                conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Database error: {e}")

class OllamaClient:
    """Handles communication with the Ollama API."""

    def __init__(self, model: str, url: str):
        self.model = model
        self.url = url

    async def generate_response(self, messages: List[Dict[str, str]]) -> Tuple[str, str]:
        """
        Generates a response from Ollama.
        Returns a tuple of (full_text, error_message).
        """
        headers = {"Content-Type": "application/json"}
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": True,
            "keep_alive": -1
        }

        full_text = ""
        try:
            async with httpx.AsyncClient(timeout=None) as client:
                async with client.stream("POST", self.url, json=payload, headers=headers) as response:
                    if response.status_code != 200:
                        return "", f"API Error: {response.status_code}"

                    async for line in response.aiter_lines():
                        line = line.strip()
                        if not line or not line.startswith("data: "):
                            continue

                        data = line[len("data: "):]
                        if data == "[DONE]":
                            break

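                        try:
                            parsed = json.loads(data)
                            # "content" can be missing or null in some chunks
                            delta = parsed["choices"][0]["delta"].get("content") or ""
                            full_text += delta
                        except (json.JSONDecodeError, KeyError, IndexError):
                            # Skip malformed or unexpected chunks instead of aborting the stream
                            continue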

            return full_text, ""

        except httpx.RequestError as e:
            return "", f"Connection error: {str(e)}"

class WorkerManager:
    """Manages async workers for processing message queues per channel."""

    def __init__(self, process_callback):
        self.queues: Dict[str, asyncio.Queue] = {}
        self.semaphores: Dict[str, asyncio.Semaphore] = {}
        self.active_channels: Set[str] = set()
        self.process_callback = process_callback

    async def enqueue(self, channel_id: str, item):
        """Adds an item to the channel's processing queue."""
        if channel_id not in self.queues:
            self.queues[channel_id] = asyncio.Queue()

        if channel_id not in self.semaphores:
            self.semaphores[channel_id] = asyncio.Semaphore(Config.MAX_CONCURRENT_WORKERS)

        await self.queues[channel_id].put(item)
        await self._start_workers(channel_id)

    async def _start_workers(self, channel_id: str):
        """Starts worker tasks for a channel if not already running."""
        if channel_id in self.active_channels:
            return

        self.active_channels.add(channel_id)
        queue = self.queues[channel_id]
        semaphore = self.semaphores[channel_id]
        remaining_workers = Config.MAX_CONCURRENT_WORKERS

        async def worker():
            nonlocal remaining_workers
            try:
                while True:
                    try:
                        # Wait for a task; after 60 idle seconds the worker may exit
                        item = await asyncio.wait_for(queue.get(), timeout=60.0)
                    except asyncio.TimeoutError:
                        if queue.empty():
                            break
                        # An item arrived right at the timeout; go back and fetch it
                        continue

                    try:
                        async with semaphore:
                            await self.process_callback(item)
                    except Exception as e:
                        logger.error(f"Worker error in channel {channel_id}: {e}")
                    finally:
                        # Mark the item done even when processing fails
                        queue.task_done()
            finally:
                # When the last worker exits, clear the channel so the next enqueue()
                # starts a fresh pool. Otherwise the channel would stay marked active
                # with no workers left to drain its queue.
                remaining_workers -= 1
                if remaining_workers == 0:
                    self.active_channels.discard(channel_id)

        # Launch workers
        for _ in range(Config.MAX_CONCURRENT_WORKERS):
            asyncio.create_task(worker())

        # Note: task references are not retained here to keep the example simple.
        # A more robust solution would track tasks (e.g. for graceful shutdown).

class DiscordBot(discord.Client):
    """Main Bot Class."""

    def __init__(self):
        intents = discord.Intents.default()
        intents.messages = True
        intents.message_content = True
        super().__init__(intents=intents)

        self.db = DatabaseManager(Config.DB_NAME)
        self.ollama = OllamaClient(Config.OLLAMA_MODEL, Config.OLLAMA_URL)
        self.worker_manager = WorkerManager(self.process_message_task)

        # In-memory conversation history: { "channel_id:user_id": [messages] }
        self.conversations: Dict[str, List[Dict[str, str]]] = {}

    async def on_ready(self):
        logger.info(f"✅ Bot connected as {self.user}")

    async def process_message_task(self, task_data):
        """Callback function executed by workers."""
        message, prompt = task_data
        channel_id = str(message.channel.id)
        author_id = str(message.author.id)
        conversation_key = f"{channel_id}:{author_id}"

        # Initialize history if needed
        if conversation_key not in self.conversations:
            self.conversations[conversation_key] = []

        # Update history
        history = self.conversations[conversation_key]
        history.append({"role": "user", "content": prompt})

        # Trim history
        if len(history) > Config.MAX_HISTORY_LENGTH:
            history = history[-Config.MAX_HISTORY_LENGTH:]
            self.conversations[conversation_key] = history

        # Show typing indicator
        async with message.channel.typing():
            response_text, error = await self.ollama.generate_response(history)

        if error:
            await message.channel.send(f"❌ Error: {error}")
            return

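        response_text = response_text.strip()
        if not response_text:
            # Discord rejects empty messages, so report the problem instead
            await message.channel.send("⚠️ The model returned an empty response.")
            return

        # Discord rejects messages over 2000 characters, so send long replies in chunks
        limit = 2000
        chunks = [response_text[i:i + limit] for i in range(0, len(response_text), limit)]
        await message.reply(chunks[0], mention_author=True)
        for chunk in chunks[1:]:
            if chunk.strip():
                await message.channel.send(chunk)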

        # Update history with assistant response
        history.append({"role": "assistant", "content": response_text})

        # Log to DB
        self.db.register_user(author_id, message.author.name, prompt)

    async def on_message(self, message: discord.Message):
        if message.author.bot:
            return

        content = message.content.strip()
        channel_id = str(message.channel.id)
        author_id = str(message.author.id)

        # Command: Reset Context
        if content == "!reset":
            key = f"{channel_id}:{author_id}"
            self.conversations.pop(key, None)
            await message.channel.send("🔄 Conversation context reset.")
            return

        # Check if bot is mentioned or invoked via command
        is_mentioned = self.user in message.mentions
        is_command = content.lower().startswith("!ai ")

        prompt = ""
        if is_command:
            prompt = content[4:].strip()
        elif is_mentioned:
            # Remove the bot mention (both <@id> and the legacy nickname form <@!id>)
            bot_id = self.user.id
            prompt = content.replace(f"<@{bot_id}>", "").replace(f"<@!{bot_id}>", "").strip()

        if prompt:
            await self.worker_manager.enqueue(channel_id, (message, prompt))

if __name__ == "__main__":
    bot = DiscordBot()
    bot.run(Config.DISCORD_TOKEN)

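Setup

  1. Install dependencies:
     pip install discord.py httpx
  2. Ollama: Make sure Ollama is running (ollama serve) and pull a model (e.g. ollama pull llama3).
  3. Token: Replace DISCORD_TOKEN in the Config class with the actual bot token from the Discord Developer Portal. The code requests the privileged message content intent, so also enable Message Content Intent for the bot in the portal.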
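
Before starting the bot, it can help to confirm that Ollama is reachable and that the model has been pulled. The sketch below uses only the standard library and assumes Ollama's native `GET /api/tags` endpoint on the default port, returning a JSON object with a `models` list whose entries carry a `name` field; treat that response shape as an assumption and check it against your Ollama version. The helper names `parse_model_names` and `list_local_models` are hypothetical.

```python
import json
import urllib.request
from typing import List, Optional


def parse_model_names(payload: dict) -> List[str]:
    """Extract model names from a tags-style payload (assumed shape: {"models": [{"name": ...}]})."""
    return [m["name"] for m in payload.get("models", []) if "name" in m]


def list_local_models(base_url: str = "http://localhost:11434",
                      timeout: float = 3.0) -> Optional[List[str]]:
    """Return the locally available model names, or None if Ollama cannot be reached."""
    try:
        with urllib.request.urlopen(f"{base_url}/api/tags", timeout=timeout) as resp:
            return parse_model_names(json.load(resp))
    except (OSError, ValueError):
        # OSError covers connection failures and HTTP errors; ValueError covers invalid JSON
        return None


if __name__ == "__main__":
    models = list_local_models()
    if models is None:
        print("Ollama is not reachable. Is `ollama serve` running?")
    else:
        print("Available models:", models)
```

If the model set in `OLLAMA_MODEL` is missing from the printed list, pull it before launching the bot.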

Do not commit your real Discord token to a public repository. In production, load it from an environment variable (for example via python-dotenv).
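
As a minimal sketch of that advice, the token can be read from the environment at startup using only the standard library. The helper `load_discord_token` and the variable name `DISCORD_TOKEN` are assumptions chosen for this example; any variable name works as long as the code and the deployment agree on it.

```python
import os


def load_discord_token(env_var: str = "DISCORD_TOKEN") -> str:
    """Read the bot token from an environment variable and fail fast if it is missing."""
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(f"Set the {env_var} environment variable before starting the bot.")
    return token
```

With this in place, the last line of the script could become `bot.run(load_discord_token())`, and the placeholder in `Config` is no longer needed.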