# ObsiGate/backend/search.py

import logging
import math
import re
import unicodedata
from collections import defaultdict
from typing import List, Dict, Any, Optional, Tuple
from backend.indexer import index
logger = logging.getLogger("obsigate.search")
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
DEFAULT_SEARCH_LIMIT = 200
ADVANCED_SEARCH_DEFAULT_LIMIT = 50
SNIPPET_CONTEXT_CHARS = 120
MAX_SNIPPET_HIGHLIGHTS = 5
TITLE_BOOST = 3.0 # TF-IDF multiplier for title matches
PATH_BOOST = 1.5 # TF-IDF multiplier for path matches
TAG_BOOST = 2.0 # TF-IDF multiplier for tag matches
MIN_PREFIX_LENGTH = 2 # Minimum chars for prefix matching
SUGGEST_LIMIT = 10 # Default max suggestions returned
# Regex to tokenize text into alphanumeric words (Unicode-aware)
_WORD_RE = re.compile(r"[\w]+", re.UNICODE)
# ---------------------------------------------------------------------------
# Accent / Unicode normalization helpers
# ---------------------------------------------------------------------------
def normalize_text(text: str) -> str:
    """Return *text* lowercased with diacritical marks removed.

    Applies NFKD compatibility decomposition so base characters are
    separated from their combining marks, drops every nonspacing mark
    (Unicode category ``Mn``), then lowercases the remainder.
    For example ``"Éléphant"`` becomes ``"elephant"``.

    Args:
        text: Raw input string.

    Returns:
        Lowercased, accent-stripped string; ``""`` for falsy input.
    """
    if not text:
        return ""
    # NFKD splits each character into base char + combining mark(s)
    decomposed = unicodedata.normalize("NFKD", text)
    # Drop the combining marks ("Mn" = Mark, Nonspacing), keep the rest
    kept = [ch for ch in decomposed if unicodedata.category(ch) != "Mn"]
    return "".join(kept).lower()
def tokenize(text: str) -> List[str]:
    """Return the normalized word tokens contained in *text*.

    The input is accent-stripped and lowercased via :func:`normalize_text`
    before being split on word boundaries.

    Args:
        text: Raw text to tokenize.

    Returns:
        List of normalized word tokens (possibly empty).
    """
    normalized = normalize_text(text)
    return _WORD_RE.findall(normalized)
# ---------------------------------------------------------------------------
# Tag filter helper (unchanged for backward compat)
# ---------------------------------------------------------------------------
def _normalize_tag_filter(tag_filter: Optional[str]) -> List[str]:
"""Parse a comma-separated tag filter string into a clean list.
Strips whitespace and leading ``#`` from each tag.
Args:
tag_filter: Raw tag filter string (e.g. ``"docker,linux"``).
Returns:
List of normalised tag strings, empty list if input is falsy.
"""
if not tag_filter:
return []
return [tag.strip().lstrip("#") for tag in tag_filter.split(",") if tag.strip()]
# ---------------------------------------------------------------------------
# Snippet extraction helpers
# ---------------------------------------------------------------------------
def _extract_snippet(content: str, query: str, context_chars: int = SNIPPET_CONTEXT_CHARS) -> str:
    """Return a plain-text excerpt centered on the first match of *query*.

    The match is located case-insensitively; up to *context_chars*
    characters are kept on each side, with ``...`` markers wherever the
    excerpt is truncated. When *query* does not occur at all, the first
    200 characters of *content* are returned instead.

    Args:
        content: Full text to search within.
        query: The search term.
        context_chars: Number of context characters on each side.

    Returns:
        Snippet string, optionally prefixed/suffixed with ``...``.
    """
    match_at = content.lower().find(query.lower())
    if match_at < 0:
        return content[:200].strip()
    window_start = max(0, match_at - context_chars)
    window_end = min(len(content), match_at + len(query) + context_chars)
    excerpt = content[window_start:window_end].strip()
    leading = "..." if window_start > 0 else ""
    trailing = "..." if window_end < len(content) else ""
    return leading + excerpt + trailing
def _extract_highlighted_snippet(
    content: str,
    query_terms: List[str],
    context_chars: int = SNIPPET_CONTEXT_CHARS,
    max_highlights: int = MAX_SNIPPET_HIGHLIGHTS,
) -> str:
    """Extract a snippet and wrap matching terms in ``<mark>`` tags.

    Performs accent-normalized matching so ``"resume"`` highlights ``"résumé"``.
    Returns at most *max_highlights* highlighted regions to keep snippets concise.

    Args:
        content: Full text to search within.
        query_terms: Normalized search terms.
        context_chars: Number of context characters on each side.
        max_highlights: Maximum highlighted regions.

    Returns:
        HTML snippet string with ``<mark>`` highlights.
    """
    if not content or not query_terms:
        # Note: the no-term fallback is NOT HTML-escaped here (unlike the
        # no-match fallback below).
        return content[:200].strip() if content else ""
    norm_content = normalize_text(content)
    # Find best position — first occurrence of any query term.
    # len(content) doubles as the "no match found" sentinel.
    best_pos = len(content)
    for term in query_terms:
        pos = norm_content.find(term)
        if pos != -1 and pos < best_pos:
            best_pos = pos
    if best_pos == len(content):
        # No match found — return (escaped) beginning of content
        return _escape_html(content[:200].strip())
    # NOTE(review): best_pos is an offset into the NORMALIZED string but is
    # used below to slice the RAW content. NFKD can change string length
    # (ligatures expand, combining marks are dropped), so the window may
    # drift for non-ASCII text — confirm this is acceptable for vault data.
    start = max(0, best_pos - context_chars)
    end = min(len(content), best_pos + context_chars + 40)
    raw_snippet = content[start:end].strip()
    prefix = "..." if start > 0 else ""
    suffix = "..." if end < len(content) else ""
    # Highlight all term occurrences within the extracted window
    highlighted = _highlight_terms(raw_snippet, query_terms, max_highlights)
    return prefix + highlighted + suffix
def _highlight_terms(text: str, terms: List[str], max_highlights: int) -> str:
    """Wrap occurrences of *terms* in *text* with ``<mark>`` tags.

    Uses accent-normalized comparison so diacritical variants are matched.
    Escapes HTML in non-highlighted portions to prevent XSS.

    Args:
        text: Raw text snippet.
        terms: Normalized search terms.
        max_highlights: Cap on highlighted regions.

    Returns:
        HTML-safe string with ``<mark>`` wrapped matches.
    """
    if not terms or not text:
        return _escape_html(text)
    norm = normalize_text(text)
    # Collect (start, end) spans for all term matches.
    # NOTE(review): spans are found in the NORMALIZED text but applied to
    # the RAW text below — this assumes normalization preserves character
    # positions, which NFKD does not guarantee for all inputs (ligatures,
    # dropped combining marks). Confirm acceptable for the indexed corpus.
    spans: List[Tuple[int, int]] = []
    for term in terms:
        idx = 0
        while idx < len(norm):
            pos = norm.find(term, idx)
            if pos == -1:
                break
            spans.append((pos, pos + len(term)))
            # Advance by one (not len(term)) so overlapping matches are found
            idx = pos + 1
    if not spans:
        return _escape_html(text)
    # Merge overlapping/touching spans, then keep only the first few
    spans.sort()
    merged: List[Tuple[int, int]] = [spans[0]]
    for s, e in spans[1:]:
        if s <= merged[-1][1]:
            # Overlaps (or abuts) the previous span — extend it
            merged[-1] = (merged[-1][0], max(merged[-1][1], e))
        else:
            merged.append((s, e))
    merged = merged[:max_highlights]
    # Interleave escaped plain text with <mark>-wrapped matched regions
    parts: List[str] = []
    prev = 0
    for s, e in merged:
        if s > prev:
            parts.append(_escape_html(text[prev:s]))
        parts.append(f"<mark>{_escape_html(text[s:e])}</mark>")
        prev = e
    if prev < len(text):
        parts.append(_escape_html(text[prev:]))
    return "".join(parts)
def _escape_html(text: str) -> str:
"""Escape HTML special characters."""
return (
text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
)
# ---------------------------------------------------------------------------
# Inverted Index for TF-IDF
# ---------------------------------------------------------------------------
class InvertedIndex:
    """In-memory inverted index supporting TF-IDF scoring.

    Built lazily from the global ``index`` dict whenever a search or
    suggestion request detects that the underlying vault index has changed.
    The class is designed to be a singleton — use ``get_inverted_index()``.

    Attributes:
        word_index: ``{token: {doc_key: term_frequency}}``
        title_index: ``{token: [doc_key, ...]}``
        tag_norm_map: ``{normalized_tag: original_tag}``
        tag_prefix_index: ``{prefix: [original_tag, ...]}``
        title_norm_map: ``{normalized_title: [{"vault","path","title"}, ...]}``
        doc_count: Total number of indexed documents.
        _source_id: Fingerprint of the source index to detect staleness.
    """
    def __init__(self) -> None:
        self.word_index: Dict[str, Dict[str, int]] = defaultdict(dict)
        self.title_index: Dict[str, List[str]] = defaultdict(list)
        self.tag_norm_map: Dict[str, str] = {}
        self.tag_prefix_index: Dict[str, List[str]] = defaultdict(list)
        self.title_norm_map: Dict[str, List[Dict[str, str]]] = defaultdict(list)
        self.doc_count: int = 0
        # None until the first rebuild(), so the first is_stale() is True
        self._source_id: Optional[int] = None
    def is_stale(self) -> bool:
        """Check if the inverted index needs rebuilding.

        NOTE(review): the fingerprint is ``id(index)``, which only changes
        when this module's ``index`` binding refers to a *different* object.
        ``index`` is bound once at import time, so in-place mutations of
        that dict keep the same id and are never detected here; only the
        first call (``_source_id`` is None) is guaranteed to report stale.
        Confirm how ``backend.indexer`` refreshes the index.
        """
        current_id = id(index)
        return current_id != self._source_id
    def rebuild(self) -> None:
        """Rebuild inverted index from the global ``index`` dict.

        Tokenizes titles and content of every file, computes term frequencies,
        and builds auxiliary indexes for tag and title prefix suggestions.
        """
        logger.info("Rebuilding inverted index...")
        # Reset every structure before re-populating
        self.word_index = defaultdict(dict)
        self.title_index = defaultdict(list)
        self.tag_norm_map = {}
        self.tag_prefix_index = defaultdict(list)
        self.title_norm_map = defaultdict(list)
        self.doc_count = 0
        for vault_name, vault_data in index.items():
            for file_info in vault_data.get("files", []):
                # Document key format "vault::path" is shared with search()
                doc_key = f"{vault_name}::{file_info['path']}"
                self.doc_count += 1
                # --- Title tokens (deduplicated per document) ---
                title_tokens = tokenize(file_info.get("title", ""))
                for token in set(title_tokens):
                    self.title_index[token].append(doc_key)
                # --- Normalized title for prefix suggestions ---
                norm_title = normalize_text(file_info.get("title", ""))
                if norm_title:
                    self.title_norm_map[norm_title].append({
                        "vault": vault_name,
                        "path": file_info["path"],
                        "title": file_info["title"],
                    })
                # --- Content tokens (title included so it also counts
                #     toward the combined TF score) ---
                content = file_info.get("content", "")
                full_text = (file_info.get("title", "") + " " + content)
                tokens = tokenize(full_text)
                tf: Dict[str, int] = defaultdict(int)
                for token in tokens:
                    tf[token] += 1
                for token, freq in tf.items():
                    self.word_index[token][doc_key] = freq
            # --- Tag indexes (per vault, not per file) ---
            for tag in vault_data.get("tags", {}):
                norm_tag = normalize_text(tag)
                self.tag_norm_map[norm_tag] = tag
                # Build prefix entries for each prefix length ≥ MIN_PREFIX_LENGTH
                for plen in range(MIN_PREFIX_LENGTH, len(norm_tag) + 1):
                    prefix = norm_tag[:plen]
                    if tag not in self.tag_prefix_index[prefix]:
                        self.tag_prefix_index[prefix].append(tag)
        # Remember which index object this build came from (see is_stale)
        self._source_id = id(index)
        logger.info(
            "Inverted index built: %d documents, %d unique tokens, %d tags",
            self.doc_count,
            len(self.word_index),
            len(self.tag_norm_map),
        )
    def idf(self, term: str) -> float:
        """Inverse Document Frequency for a term.

        ``idf(t) = log((N + 1) / (1 + df(t)))`` where *df(t)* is the number
        of documents containing term *t* and *N* is ``doc_count``.

        Args:
            term: Normalized term.

        Returns:
            IDF score (0.0 for unknown terms).
        """
        df = len(self.word_index.get(term, {}))
        if df == 0:
            return 0.0
        return math.log((self.doc_count + 1) / (1 + df))
    def tf_idf(self, term: str, doc_key: str) -> float:
        """TF-IDF score for a term in a document.

        Uses raw term frequency (no log normalization) × IDF.

        Args:
            term: Normalized term.
            doc_key: ``"vault::path"`` document key.

        Returns:
            TF-IDF score (0.0 when the term is absent from the document).
        """
        tf = self.word_index.get(term, {}).get(doc_key, 0)
        if tf == 0:
            return 0.0
        return tf * self.idf(term)
# Singleton inverted index
# Process-wide singleton; always access through get_inverted_index().
_inverted_index = InvertedIndex()


def get_inverted_index() -> InvertedIndex:
    """Return the shared :class:`InvertedIndex`, rebuilding it first when stale."""
    if not _inverted_index.is_stale():
        return _inverted_index
    _inverted_index.rebuild()
    return _inverted_index
# ---------------------------------------------------------------------------
# Backward-compatible search (unchanged API)
# ---------------------------------------------------------------------------
def search(
    query: str,
    vault_filter: str = "all",
    tag_filter: Optional[str] = None,
    limit: int = DEFAULT_SEARCH_LIMIT,
) -> List[Dict[str, Any]]:
    """Full-text search across indexed vaults with relevance scoring.

    Scoring heuristics (when a text query is provided):

    - **+20** exact title match (case-insensitive)
    - **+10** partial title match
    - **+5** query found in file path
    - **+3** query matches a tag name (counted once per file)
    - **+1 per occurrence** in content (capped at 10)

    When only tag filters are active, all matching files receive score 1.
    Results are sorted descending by score and capped at *limit*.
    Uses the in-memory cached content from the index — **no disk I/O**.

    Args:
        query: Free-text search string.
        vault_filter: Vault name or ``"all"``.
        tag_filter: Comma-separated tag names to require (all must be present).
        limit: Maximum number of results to return.

    Returns:
        List of result dicts sorted by descending relevance score.
    """
    query = query.strip() if query else ""
    has_query = len(query) > 0
    selected_tags = _normalize_tag_filter(tag_filter)
    if not has_query and not selected_tags:
        return []
    query_lower = query.lower()
    results: List[Dict[str, Any]] = []
    for vault_name, vault_data in index.items():
        if vault_filter != "all" and vault_name != vault_filter:
            continue
        # Fix: use .get() consistently (the rest of this module does) so a
        # partially-built index entry without "files"/"tags"/"modified"
        # keys is tolerated instead of raising KeyError.
        for file_info in vault_data.get("files", []):
            file_tags = file_info.get("tags", [])
            # Tag filter: every selected tag must be present on the file
            if selected_tags and not all(tag in file_tags for tag in selected_tags):
                continue
            score = 0
            snippet = file_info.get("content_preview", "")
            if has_query:
                title_lower = file_info["title"].lower()
                if query_lower == title_lower:
                    # Exact title match (highest weight)
                    score += 20
                elif query_lower in title_lower:
                    # Partial title match
                    score += 10
                # Path match (folder/filename relevance)
                if query_lower in file_info["path"].lower():
                    score += 5
                # Tag name match — counted at most once per file
                for tag in file_tags:
                    if query_lower in tag.lower():
                        score += 3
                        break
                # Content match — uses cached content (no disk I/O)
                content = file_info.get("content", "")
                content_lower = content.lower()
                if query_lower in content_lower:
                    # Frequency-based scoring, capped to avoid over-weighting
                    occurrences = content_lower.count(query_lower)
                    score += min(occurrences, 10)
                    snippet = _extract_snippet(content, query)
            else:
                # Tag-only filter: every surviving file gets score 1
                score = 1
            if score > 0:
                results.append({
                    "vault": vault_name,
                    "path": file_info["path"],
                    "title": file_info["title"],
                    "tags": file_tags,
                    "score": score,
                    "snippet": snippet,
                    "modified": file_info.get("modified", ""),
                })
    results.sort(key=lambda item: -item["score"])
    return results[:limit]
# ---------------------------------------------------------------------------
# Advanced search with TF-IDF scoring
# ---------------------------------------------------------------------------
def _parse_advanced_query(raw_query: str) -> Dict[str, Any]:
    """Parse an advanced query string into structured filters and free text.

    Supported operators (prefix match is case-insensitive, values keep
    their original case):

    - ``tag:<name>`` or ``#<name>`` — tag filter
    - ``vault:<name>`` — vault filter
    - ``title:<text>`` — title filter
    - ``path:<text>`` — path filter
    - Anything else is treated as a free-text search term.

    Args:
        raw_query: Raw query string from the user.

    Returns:
        Dict with keys ``tags``, ``vault``, ``title``, ``path``, ``terms``.
    """
    result: Dict[str, Any] = {
        "tags": [],
        "vault": None,
        "title": None,
        "path": None,
        "terms": [],
    }
    if not raw_query:
        return result
    # Tokenize first so quoted operator values ('tag:"my tag"') stay intact
    for token in _split_query_tokens(raw_query):
        lowered = token.lower()
        if lowered.startswith("tag:"):
            value = token[4:].strip().lstrip("#")
            if value:
                result["tags"].append(value)
        elif token.startswith("#") and len(token) > 1:
            result["tags"].append(token[1:])
        elif lowered.startswith("vault:"):
            result["vault"] = token[6:].strip()
        elif lowered.startswith("title:"):
            result["title"] = token[6:].strip()
        elif lowered.startswith("path:"):
            result["path"] = token[5:].strip()
        else:
            result["terms"].append(token)
    return result
def _split_query_tokens(raw: str) -> List[str]:
"""Split a query string respecting quoted phrases.
``tag:"my tag" hello world`` → ``['tag:my tag', 'hello', 'world']``
Args:
raw: Raw query string.
Returns:
List of token strings.
"""
tokens: List[str] = []
i = 0
n = len(raw)
while i < n:
# Skip whitespace
while i < n and raw[i] == " ":
i += 1
if i >= n:
break
# Check for operator with quoted value, e.g., tag:"foo bar"
if i < n and raw[i] != '"':
# Read until space or quote
j = i
while j < n and raw[j] != " ":
if raw[j] == '"':
# Read quoted portion
j += 1
while j < n and raw[j] != '"':
j += 1
if j < n:
j += 1 # skip closing quote
else:
j += 1
token = raw[i:j].replace('"', "")
tokens.append(token)
i = j
else:
# Quoted token
i += 1 # skip opening quote
j = i
while j < n and raw[j] != '"':
j += 1
tokens.append(raw[i:j])
i = j + 1 # skip closing quote
return tokens
def advanced_search(
    query: str,
    vault_filter: str = "all",
    tag_filter: Optional[str] = None,
    limit: int = ADVANCED_SEARCH_DEFAULT_LIMIT,
    offset: int = 0,
    sort_by: str = "relevance",
) -> Dict[str, Any]:
    """Advanced full-text search with TF-IDF scoring, facets, and pagination.

    Parses the query for operators (``tag:``, ``vault:``, ``title:``,
    ``path:``); remaining tokens become TF-IDF scored free-text terms
    against the inverted index. Results include highlighted snippets
    with ``<mark>`` tags and faceted counts for tags and vaults.

    Performance fix: the prefix-match bonus previously rescanned the
    entire vocabulary once per document (O(docs × terms × vocab)); the
    expansions are now computed once per query term before the document
    loop, and per-file normalizations are hoisted out of the per-term loop.
    Scores are unchanged.

    Args:
        query: Raw query string (may include operators).
        vault_filter: Vault name or ``"all"`` (overridden by ``vault:`` op).
        tag_filter: Comma-separated tag names (merged with ``tag:`` ops).
        limit: Max results per page.
        offset: Pagination offset.
        sort_by: ``"relevance"`` or ``"modified"``.

    Returns:
        Dict with ``results``, ``total``, ``offset``, ``limit``, ``facets``.
    """
    query = query.strip() if query else ""
    parsed = _parse_advanced_query(query)
    # Merge explicit tag_filter with parsed tag: operators (dedup, keep order)
    all_tags = list(parsed["tags"])
    for t in _normalize_tag_filter(tag_filter):
        if t not in all_tags:
            all_tags.append(t)
    # Vault filter — parsed vault: overrides the parameter
    effective_vault = parsed["vault"] or vault_filter
    # Normalize free-text terms
    query_terms = [normalize_text(t) for t in parsed["terms"] if t.strip()]
    has_terms = len(query_terms) > 0
    if not has_terms and not all_tags and not parsed["title"] and not parsed["path"]:
        return {"results": [], "total": 0, "offset": offset, "limit": limit,
                "facets": {"tags": {}, "vaults": {}}}
    inv = get_inverted_index()
    # --- Hoist loop-invariant work out of the per-document loop ---
    all_tags_lower = [t.lower() for t in all_tags]
    norm_title_filter = normalize_text(parsed["title"]) if parsed["title"] else None
    norm_path_filter = normalize_text(parsed["path"]) if parsed["path"] else None
    # Indexed terms extending each query term, computed once per unique term
    prefix_expansions: Dict[str, List[str]] = {}
    for term in query_terms:
        if len(term) >= MIN_PREFIX_LENGTH and term not in prefix_expansions:
            prefix_expansions[term] = [
                indexed_term
                for indexed_term in inv.word_index
                if indexed_term.startswith(term) and indexed_term != term
            ]
    scored_results: List[Tuple[float, Dict[str, Any]]] = []
    facet_tags: Dict[str, int] = defaultdict(int)
    facet_vaults: Dict[str, int] = defaultdict(int)
    for vault_name, vault_data in index.items():
        if effective_vault != "all" and vault_name != effective_vault:
            continue
        for file_info in vault_data.get("files", []):
            doc_key = f"{vault_name}::{file_info['path']}"
            file_tags = file_info.get("tags", [])
            # --- Tag filter: every requested tag must be present ---
            if all_tags:
                file_tags_lower = [t.lower() for t in file_tags]
                if not all(t in file_tags_lower for t in all_tags_lower):
                    continue
            # --- Title filter (accent-insensitive substring) ---
            if norm_title_filter is not None:
                if norm_title_filter not in normalize_text(file_info.get("title", "")):
                    continue
            # --- Path filter (accent-insensitive substring) ---
            if norm_path_filter is not None:
                if norm_path_filter not in normalize_text(file_info.get("path", "")):
                    continue
            # --- Scoring ---
            score = 0.0
            if has_terms:
                # Normalize title/path/tags once per file, not once per term
                norm_title = normalize_text(file_info.get("title", ""))
                norm_path = normalize_text(file_info.get("path", ""))
                norm_tags = [normalize_text(tag) for tag in file_tags]
                for term in query_terms:
                    tfidf = inv.tf_idf(term, doc_key)
                    score += tfidf
                    # Boosts use substring containment on normalized fields
                    if term in norm_title:
                        score += tfidf * TITLE_BOOST
                    if term in norm_path:
                        score += tfidf * PATH_BOOST
                    # Tag boost applies at most once per term
                    if any(term in norm_tag for norm_tag in norm_tags):
                        score += tfidf * TAG_BOOST
                # Prefix-match bonus for partial words (half weight)
                for term in query_terms:
                    for indexed_term in prefix_expansions.get(term, ()):
                        if doc_key in inv.word_index[indexed_term]:
                            score += inv.tf_idf(indexed_term, doc_key) * 0.5
            else:
                # Filter-only search (tag/title/path): flat score
                score = 1.0
            if score > 0:
                # Build highlighted snippet from cached content
                content = file_info.get("content", "")
                if has_terms:
                    snippet = _extract_highlighted_snippet(content, query_terms)
                else:
                    snippet = _escape_html(content[:200].strip()) if content else ""
                scored_results.append((score, {
                    "vault": vault_name,
                    "path": file_info["path"],
                    "title": file_info["title"],
                    "tags": file_tags,
                    "score": round(score, 4),
                    "snippet": snippet,
                    "modified": file_info.get("modified", ""),
                }))
                # Facets count only files that made it into the result set
                facet_vaults[vault_name] += 1
                for tag in file_tags:
                    facet_tags[tag] += 1
    # Sort by recency or relevance
    if sort_by == "modified":
        scored_results.sort(key=lambda x: x[1].get("modified", ""), reverse=True)
    else:
        scored_results.sort(key=lambda x: -x[0])
    total = len(scored_results)
    page = scored_results[offset: offset + limit]
    return {
        "results": [r for _, r in page],
        "total": total,
        "offset": offset,
        "limit": limit,
        "facets": {
            "tags": dict(sorted(facet_tags.items(), key=lambda x: -x[1])[:20]),
            "vaults": dict(sorted(facet_vaults.items(), key=lambda x: -x[1])),
        },
    }
# ---------------------------------------------------------------------------
# Suggestion helpers
# ---------------------------------------------------------------------------
def suggest_titles(
    prefix: str,
    vault_filter: str = "all",
    limit: int = SUGGEST_LIMIT,
) -> List[Dict[str, str]]:
    """Suggest files whose title contains *prefix* (accent-insensitive).

    Matching is substring-based, not strictly prefix-based: any title
    containing the normalized input anywhere qualifies. Results are
    de-duplicated by ``vault::path``.

    Args:
        prefix: User-typed text (at least ``MIN_PREFIX_LENGTH`` chars).
        vault_filter: Vault name or ``"all"``.
        limit: Maximum suggestions.

    Returns:
        List of ``{"vault", "path", "title"}`` dicts.
    """
    if not prefix or len(prefix) < MIN_PREFIX_LENGTH:
        return []
    needle = normalize_text(prefix)
    inv = get_inverted_index()
    suggestions: List[Dict[str, str]] = []
    seen_keys: set = set()
    for norm_title, entries in inv.title_norm_map.items():
        if needle not in norm_title:
            continue
        for entry in entries:
            if vault_filter != "all" and entry["vault"] != vault_filter:
                continue
            dedupe_key = f"{entry['vault']}::{entry['path']}"
            if dedupe_key in seen_keys:
                continue
            seen_keys.add(dedupe_key)
            suggestions.append(entry)
            if len(suggestions) >= limit:
                return suggestions
    return suggestions
def suggest_tags(
    prefix: str,
    vault_filter: str = "all",
    limit: int = SUGGEST_LIMIT,
) -> List[Dict[str, Any]]:
    """Suggest tags containing *prefix* (accent-insensitive substring match).

    Args:
        prefix: User-typed text (with or without a leading ``#``).
        vault_filter: Vault name or ``"all"``.
        limit: Maximum suggestions.

    Returns:
        List of ``{"tag", "count"}`` dicts in descending-count order
        (order inherited from :func:`get_all_tags`).
    """
    # Bug fix: strip surrounding whitespace BEFORE removing the leading "#".
    # The old order (lstrip("#") then strip()) left the "#" in place for
    # inputs like " #docker", so such queries could never match a tag.
    # (The strip-then-lstrip order matches _parse_advanced_query's tag: op.)
    prefix = prefix.strip().lstrip("#") if prefix else ""
    if not prefix or len(prefix) < MIN_PREFIX_LENGTH:
        return []
    norm_prefix = normalize_text(prefix)
    matches: List[Dict[str, Any]] = []
    for tag, count in get_all_tags(vault_filter).items():
        if norm_prefix in normalize_text(tag):
            matches.append({"tag": tag, "count": count})
            if len(matches) >= limit:
                break
    return matches
# ---------------------------------------------------------------------------
# Backward-compatible tag aggregation (unchanged API)
# ---------------------------------------------------------------------------
def get_all_tags(vault_filter: Optional[str] = None) -> Dict[str, int]:
    """Aggregate tag counts across vaults, sorted by descending count.

    Args:
        vault_filter: Optional vault name to restrict to a single vault
            (any falsy value means "all vaults").

    Returns:
        Dict mapping tag names to their total occurrence count, ordered
        from most to least frequent.
    """
    totals: Dict[str, int] = defaultdict(int)
    for vault_name, vault_data in index.items():
        if vault_filter and vault_name != vault_filter:
            continue
        for tag, count in vault_data.get("tags", {}).items():
            totals[tag] += count
    return dict(sorted(totals.items(), key=lambda item: -item[1]))