354 lines
12 KiB
Python
354 lines
12 KiB
Python
"""
MCP Server exposing search and read tools for the indexed emails.
"""
|
|
|
|
import logging
|
|
import os
|
|
from typing import Any, Optional
|
|
|
|
import qdrant_client as qdrant_pkg
|
|
from dateutil import parser as date_parser
|
|
from email.utils import parseaddr
|
|
from dotenv import load_dotenv
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse
|
|
from fastmcp import FastMCP
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.http import models
|
|
from fastembed import TextEmbedding
|
|
|
|
# Load environment variables via python-dotenv before any setting below is read.
load_dotenv()

# Root logging: INFO level, each record carries timestamp, level and logger name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Verbose MCP diagnostics (useful for -32602 / validation flow):
# raise the MCP-related loggers to DEBUG regardless of the root level.
for _mcp_logger_name in ("mcp", "mcp.server", "mcp.server.lowlevel.server"):
    logging.getLogger(_mcp_logger_name).setLevel(logging.DEBUG)
|
|
|
|
# --- Runtime configuration, read once from the environment at import time ---

# Required: address of the Qdrant instance and the collection to query.
QDRANT_URL = os.environ.get("QDRANT_URL", "")
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "")

# Optional: embedding model name handed to TextEmbedding.
EMBEDDING_MODEL_NAME = os.environ.get("EMBEDDING_MODEL_NAME", "BAAI/bge-small-en-v1.5")

# Optional: maximum number of hits returned by a search (SEARCH_LIMIT, default 50).
# Validated here so a bad value stops the process at startup with a clear
# message instead of surfacing later as a failure on every search request.
_raw_search_limit = os.environ.get("SEARCH_LIMIT", "50")
try:
    DEFAULT_LIMIT = int(_raw_search_limit)
except ValueError:
    raise ValueError(
        f"SEARCH_LIMIT environment variable must be an integer, got {_raw_search_limit!r}."
    ) from None
if DEFAULT_LIMIT < 1:
    # A limit below 1 can never yield results, and a negative value would also
    # distort the `[:DEFAULT_LIMIT]` slice applied to fallback search results.
    raise ValueError(
        f"SEARCH_LIMIT environment variable must be >= 1, got {DEFAULT_LIMIT}."
    )

if not QDRANT_URL:
    raise ValueError("QDRANT_URL environment variable is required.")
if not COLLECTION_NAME:
    raise ValueError("COLLECTION_NAME environment variable is required.")

logger.info(f"Starting MCP server with collection: {COLLECTION_NAME}")

# Initialize FastMCP server
mcp = FastMCP("mcp-maildir")
|
|
|
|
@mcp.custom_route("/health", methods=["GET"])
async def health_check(request: Request):
    """Answer GET /health with a constant JSON ``{"status": "ok"}`` body.

    Meant for Kubernetes liveness/readiness probes. The incoming request is
    not inspected and no backend (Qdrant, embedding model) is contacted.
    """
    status_body = {"status": "ok"}
    return JSONResponse(status_body)
|
|
|
|
# Lazily-created, process-wide singletons.
# Both slots start empty; get_qdrant_client() and get_embedding_model()
# populate them on first use and hand back the cached object afterwards.
_qdrant_client: Optional[QdrantClient]
_embedding_model: Optional[TextEmbedding]
_qdrant_client = _embedding_model = None
|
|
|
|
|
|
def get_qdrant_client() -> QdrantClient:
    """Return the shared Qdrant client, creating it on the first call.

    The client is constructed from ``QDRANT_URL`` and stored in the
    module-level ``_qdrant_client`` so that later calls reuse it.
    """
    global _qdrant_client

    # Fast path: a client was already created by an earlier call.
    if _qdrant_client is not None:
        return _qdrant_client

    logger.info(f"Connecting to Qdrant at {QDRANT_URL}")
    _qdrant_client = QdrantClient(url=QDRANT_URL)
    return _qdrant_client
|
|
|
|
|
|
def get_embedding_model() -> TextEmbedding:
    """Return the shared ``TextEmbedding`` model, loading it on the first call.

    The model named by ``EMBEDDING_MODEL_NAME`` is instantiated once and
    stored in the module-level ``_embedding_model`` for reuse.
    """
    global _embedding_model

    # Fast path: the model was already loaded by an earlier call.
    if _embedding_model is not None:
        return _embedding_model

    logger.info(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
    _embedding_model = TextEmbedding(model_name=EMBEDDING_MODEL_NAME)
    return _embedding_model
|
|
|
|
|
|
def log_qdrant_diagnostics(client: QdrantClient) -> None:
    """Log which search-related methods the given Qdrant client exposes.

    Emits one INFO record with the client's class name, the installed
    ``qdrant_client`` package version ("unknown" when the package has no
    ``__version__``), and whether ``search``, ``query_points`` and ``scroll``
    are available. Helps diagnose API/version mismatches.
    """
    package_version = getattr(qdrant_pkg, "__version__", "unknown")
    has_search, has_query_points, has_scroll = (
        hasattr(client, method_name)
        for method_name in ("search", "query_points", "scroll")
    )
    logger.info(
        "Qdrant diagnostics: client_type=%s, qdrant_client_version=%s, has_search=%s, has_query_points=%s, has_scroll=%s",
        type(client).__name__,
        package_version,
        has_search,
        has_query_points,
        has_scroll,
    )
|
|
|
|
|
|
def normalize_email_address(value: Optional[str]) -> str:
    """Reduce a header-style address to its bare, lower-cased email address.

    ``"Display Name <User@Host>"`` becomes ``"user@host"``. When no address
    can be extracted, the input itself is stripped and lower-cased instead.
    ``None`` and the empty string both yield ``""``.
    """
    if value:
        bare_address = parseaddr(value)[1]
        chosen = bare_address if bare_address else value
        return chosen.strip().lower()
    return ""
|
|
|
|
|
|
def payload_matches_participant(payload: dict[str, Any], participant: str) -> bool:
    """Tell whether ``participant`` appears as sender or receiver in ``payload``.

    The participant is normalized with ``normalize_email_address`` and then
    compared against the ``sender``, ``sender_raw``, ``receiver`` and
    ``receiver_raw`` payload fields in two ways:

    * exact equality with the field's normalized address, and
    * substring containment in the field's stripped, lower-cased text
      (covers values stored as ``"Display Name <email@domain>"``).

    Args:
        payload: Point payload; any of the four fields may be missing or None.
        participant: Address (bare or with display name) to look for.

    Returns:
        True when any field matches. Always False when ``participant``
        normalizes to an empty string.
    """
    needle = normalize_email_address(participant)
    if not needle:
        # An empty needle is a substring of every string and also equals the
        # normalized form of every missing field, so it would match any payload.
        return False

    for field_name in ("sender", "sender_raw", "receiver", "receiver_raw"):
        field_value = payload.get(field_name)
        if field_value is None:
            # Skip absent/None fields: str(None) would otherwise become
            # searchable text and produce bogus substring hits.
            continue
        if needle == normalize_email_address(field_value):
            return True
        if needle in str(field_value).strip().lower():
            return True
    return False
|
|
|
|
|
|
def _parse_inclusive_end(value: str) -> str:
    """Parse an end-of-range date string into an ISO-8601 upper bound.

    When ``value`` carries no time of day (e.g. ``"2024-01-31"``) the bound is
    pushed to the last microsecond of that day so the whole day is included.
    When a time of day is present the value is used exactly as given.
    """
    from datetime import datetime  # local import: only needed by this helper

    # Same base dateutil uses by default (today at midnight), so date fields
    # missing from `value` are filled in exactly as a plain parse() would.
    day_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
    day_end = day_start.replace(hour=23, minute=59, second=59, microsecond=999999)

    as_given = date_parser.parse(value, default=day_start)
    padded = date_parser.parse(value, default=day_end)

    # The hour can only differ between the two parses if `value` did not
    # specify one, i.e. it was a date without a time component.
    chosen = padded if padded.hour != as_given.hour else as_given
    return chosen.isoformat()


def build_filter(participant: Optional[str], start_date: Optional[str], end_date: Optional[str]) -> Optional[models.Filter]:
    """Build the Qdrant payload filter for the optional search constraints.

    Args:
        participant: Email address that must equal the ``sender`` OR the
            ``receiver`` payload field (after normalization). Falsy = no
            participant constraint.
        start_date: Lower bound (inclusive) for the ``date`` field, in any
            format dateutil can parse. Falsy = open-ended.
        end_date: Upper bound (inclusive) for the ``date`` field. A date
            without a time of day covers that entire day. Falsy = open-ended.

    Returns:
        A ``models.Filter`` whose ``must`` list holds every requested
        constraint, or ``None`` when no constraint was requested.

    Raises:
        ValueError: propagated from dateutil when a date string cannot be
            parsed (dateutil's ParserError is a ValueError subclass).
    """
    # Holds both plain field conditions and the nested OR-filter for participant.
    conditions: list[models.FieldCondition | models.Filter] = []

    if participant:
        normalized_participant = normalize_email_address(participant)
        # Qdrant expresses OR as a nested Filter with `should` clauses; nesting
        # it under `must` makes "sender OR receiver" a hard requirement.
        conditions.append(
            models.Filter(
                should=[
                    models.FieldCondition(
                        key=field_name,
                        match=models.MatchValue(value=normalized_participant),
                    )
                    for field_name in ("sender", "receiver")
                ]
            )
        )

    if start_date or end_date:
        gte = date_parser.parse(start_date).isoformat() if start_date else None
        lte = _parse_inclusive_end(end_date) if end_date else None
        conditions.append(
            models.FieldCondition(
                key="date",
                range=models.DatetimeRange(gte=gte, lte=lte),
            )
        )

    return models.Filter(must=conditions) if conditions else None
|
|
|
|
|
|
def format_search_result(point: Any) -> dict[str, Any]:
    """Turn a Qdrant point into the compact dict returned by search.

    Copies ``message_id``, ``date``, ``sender``, ``receiver`` and ``subject``
    from the point's payload (``None`` when absent), ``attachments``
    (``[]`` when absent) and the point's ``score`` attribute (``None`` when
    the point has none). A missing/empty payload is treated as ``{}``.
    """
    data = point.payload or {}
    summary: dict[str, Any] = {
        field_name: data.get(field_name)
        for field_name in ("message_id", "date", "sender", "receiver", "subject")
    }
    summary["attachments"] = data.get("attachments", [])
    summary["score"] = getattr(point, "score", None)
    return summary
|
|
|
|
|
|
def vector_search_points(
    client: QdrantClient,
    *,
    query_vector: list[float],
    query_filter: Optional[models.Filter],
    limit: int,
) -> list[Any]:
    """Run a vector search on ``COLLECTION_NAME`` using whichever API the client has.

    ``client.query_points`` is preferred; its ``points`` attribute is returned,
    or the response itself when it is already a list, or ``[]`` otherwise.
    When ``query_points`` is unavailable, ``client.search`` is used instead.
    Payloads are requested, vectors are not.

    Raises:
        AttributeError: when the client has neither ``query_points`` nor ``search``.
    """
    # Arguments shared by both API variants; only the vector keyword differs.
    shared_kwargs = {
        "collection_name": COLLECTION_NAME,
        "query_filter": query_filter,
        "limit": limit,
        "with_payload": True,
        "with_vectors": False,
    }

    if hasattr(client, "query_points"):
        response = client.query_points(query=query_vector, **shared_kwargs)
        found = getattr(response, "points", None)
        if found is not None:
            return found
        # No `.points` attribute: accept a bare list, otherwise report no hits.
        return response if isinstance(response, list) else []

    if hasattr(client, "search"):
        return client.search(query_vector=query_vector, **shared_kwargs)

    raise AttributeError("Qdrant client exposes neither 'query_points' nor 'search'.")
|
|
|
|
|
|
@mcp.tool()
def search_emails(
    query: str,
    participant: str = "",
    start_date: str = "",
    end_date: str = "",
):
    """
    Performs a hybrid search (Semantic + Exact filtering on metadata).
    """
    # NOTE(review): FastMCP presumably publishes the docstring above as the
    # tool description, so its wording is kept unchanged; details live here.
    #
    # Flow: embed `query`, search Qdrant with the optional participant/date
    # filter, and return {"query", "filters", "count", "results"}. Any failure
    # is logged and reported as {"error": ...} instead of being raised.
    logger.info(
        "Tool search_emails input diagnostics: query=%r(type=%s,len=%d), participant=%r(type=%s), start_date=%r(type=%s), end_date=%r(type=%s)",
        query,
        type(query).__name__,
        len(query) if isinstance(query, str) else -1,
        participant,
        type(participant).__name__,
        start_date,
        type(start_date).__name__,
        end_date,
        type(end_date).__name__,
    )

    # Empty strings mean "filter not supplied"; map them to None internally.
    wanted_participant = None
    if participant:
        wanted_participant = normalize_email_address(participant)
    since = start_date or None
    until = end_date or None

    logger.info(
        "Tool search_emails normalized filters: participant=%r, start_date=%r, end_date=%r",
        wanted_participant,
        since,
        until,
    )
    try:
        embedder = get_embedding_model()
        client = get_qdrant_client()
        log_qdrant_diagnostics(client)

        embeddings = list(embedder.embed([query]))
        query_vector = embeddings[0].tolist()
        query_filter = build_filter(
            participant=wanted_participant, start_date=since, end_date=until
        )
        logger.info("Tool search_emails built filter: %s", query_filter)

        points = vector_search_points(
            client,
            query_vector=query_vector,
            query_filter=query_filter,
            limit=DEFAULT_LIMIT,
        )

        # Backward-compatible fallback for legacy indexed payloads where participant may
        # still be stored as full "Display Name <email@domain>": search again
        # with the date filter only, over a wider window, and match the
        # participant in Python.
        if wanted_participant and not points:
            logger.info("No result with exact participant filter, trying fallback matching...")
            dates_only_filter = build_filter(participant=None, start_date=since, end_date=until)
            candidates = vector_search_points(
                client,
                query_vector=query_vector,
                query_filter=dates_only_filter,
                limit=max(DEFAULT_LIMIT * 5, 50),
            )
            matching = [
                candidate
                for candidate in candidates
                if payload_matches_participant(candidate.payload or {}, wanted_participant)
            ]
            points = matching[:DEFAULT_LIMIT]

        logger.info(f"Found {len(points)} results")
        # The raw, caller-supplied filter values are echoed back unchanged.
        echoed_filters = {
            "participant": participant,
            "start_date": start_date,
            "end_date": end_date,
        }
        return {
            "query": query,
            "filters": echoed_filters,
            "count": len(points),
            "results": [format_search_result(point) for point in points],
        }
    except Exception as exc:
        logger.error(f"Error in search_emails: {exc}", exc_info=True)
        return {"error": f"search_emails failed: {exc}"}
|
|
|
|
|
|
@mcp.tool()
def read_email(message_id: str):
    """
    Returns the full text content (cleaned of HTML) of a specific email.
    """
    # NOTE(review): FastMCP presumably publishes the docstring above as the
    # tool description, so its wording is kept unchanged; details live here.
    #
    # Looks up the first point whose `message_id` payload field equals the
    # argument and returns its metadata plus `body_text`. A missing email or
    # any failure is reported as {"error": ...} instead of being raised.
    logger.info(
        "Tool read_email input diagnostics: message_id=%r(type=%s,len=%d)",
        message_id,
        type(message_id).__name__,
        len(message_id) if isinstance(message_id, str) else -1,
    )
    try:
        client = get_qdrant_client()

        by_message_id = models.Filter(
            must=[
                models.FieldCondition(
                    key="message_id",
                    match=models.MatchValue(value=message_id),
                )
            ]
        )
        # scroll() returns (points, next_offset); only the first hit is needed.
        records, _next_offset = client.scroll(
            collection_name=COLLECTION_NAME,
            scroll_filter=by_message_id,
            limit=1,
            with_payload=True,
            with_vectors=False,
        )

        if not records:
            logger.warning(f"No email found for message_id={message_id}")
            return {"error": f"No email found for message_id={message_id}"}

        data = records[0].payload or {}
        email = {
            field_name: data.get(field_name)
            for field_name in ("message_id", "date", "sender", "receiver", "subject")
        }
        email["attachments"] = data.get("attachments", [])
        email["body_text"] = data.get("body_text", "")
        return email
    except Exception as exc:
        logger.error(f"Error in read_email: {exc}", exc_info=True)
        return {"error": f"read_email failed: {exc}"}
|
|
|
|
|
|
if __name__ == "__main__":
    # Warm up the embedding model so the first request does not pay the load cost.
    logger.info("Initializing models before starting server...")
    try:
        get_embedding_model()
    except Exception as e:
        # Best effort: the server still starts, and get_embedding_model() will
        # try to load the model again on first use because the cached slot
        # stays empty. Log the traceback, as the tool handlers do.
        logger.error(f"Failed to load models: {e}", exc_info=True)
    else:
        logger.info("Models loaded successfully.")

    logger.info("Starting SSE server on 0.0.0.0:8000...")
    # Start the MCP server using SSE (Server-Sent Events) over HTTP
    mcp.run(transport="sse", host="0.0.0.0", port=8000)
|