# receivers.py -- signal receivers for semantic search.
from django.conf import settings
from django.db import transaction
from django.db.models import ExpressionWrapper, F, FloatField, Q, Sum, Value, Window
from django.db.models.functions import DenseRank
from django.dispatch import receiver
# pgvector distance expression and torque signals used by the receivers below.
from pgvector.django import CosineDistance
from torque.signals import search_filter, search_index_rebuilt, update_cache_document

from semantic_search.llm import llm
from semantic_search.models import SemanticSearchCacheDocument


@receiver(update_cache_document)
def update_semantic_cache_document(sender, **kwargs):
    """Rebuild the semantic-search rows for one cache document.

    Reads three entries from the signal kwargs:

    * ``cache_document`` -- the object the semantic rows are attached to via
      ``search_cache_document``.
    * ``filtered_data`` -- mapping of field name to value; merged last, so it
      overrides same-named values produced by the additional filters.
    * ``document_dict`` -- passed to ``document_value()`` of every entry in
      the optional ``SEMANTIC_SEARCH_ADDITIONAL_FILTERS`` setting.

    The merged data is rendered as ``"<name> is <value>. "`` sentences, handed
    to ``llm.get_embeddings``, and one ``SemanticSearchCacheDocument`` is
    stored per returned embedding (each row carries the full rendered text).

    Raises:
        KeyError: if any of the three kwargs is missing.
    """
    cache_document = kwargs["cache_document"]
    filtered_data = kwargs["filtered_data"]
    document_dict = kwargs["document_dict"]

    embedding_data = {}
    for extra_filter in getattr(settings, "SEMANTIC_SEARCH_ADDITIONAL_FILTERS", []):
        embedding_data[extra_filter.name()] = extra_filter.document_value(
            document_dict
        )
    embedding_data.update(filtered_data)

    sentences = []
    for name, value in embedding_data.items():
        label = name.replace("_", " ")
        if isinstance(value, list):
            # Every list item gets its own sentence, including falsy items.
            sentences.extend(f"{label} is {item}. " for item in value)
        elif value:
            # Scalar values are only included when truthy.
            sentences.append(f"{label} is {value}. ")
    data_text = "".join(sentences)

    # Compute embeddings before opening the transaction so it is not held open
    # during the embedding call. If this raises, the existing rows are left
    # untouched, exactly as a rollback would have left them.
    embeddings = llm.get_embeddings(data_text)

    semantic_search_cache_documents = [
        SemanticSearchCacheDocument(
            search_cache_document=cache_document,
            data=data_text,
            data_embedding=embedding,
        )
        for embedding in embeddings
    ]

    # Replace the old rows atomically: readers never observe the document
    # with its previous rows deleted but the new ones not yet inserted.
    with transaction.atomic():
        SemanticSearchCacheDocument.objects.filter(
            search_cache_document=cache_document
        ).delete()
        SemanticSearchCacheDocument.objects.bulk_create(semantic_search_cache_documents)


@receiver(search_index_rebuilt)
def rebuild_semantic_search_index(sender, **kwargs):
    """Receive ``search_index_rebuilt``; currently performs no work."""
    return None

@receiver(search_filter)
def semantic_filter(sender, **kwargs):
    """Filter cache documents by cosine distance to the query embeddings.

    Reads ``cache_documents`` (required) and ``qs`` (optional) from the signal
    kwargs. When ``qs`` is truthy it is embedded with
    ``llm.get_embeddings(qs, prompt_name="query")``; every returned embedding
    becomes a ``distance_<i>`` annotation holding the cosine distance to
    ``semantic_documents__data_embedding``. Rows are kept when at least one of
    those distances is within the configured threshold.

    Returns:
        The annotated and filtered queryset ordered by ``distance_0``, or
        ``None`` when ``qs`` is missing or falsy.

    Raises:
        KeyError: if ``cache_documents`` is missing from the kwargs.
    """
    # NOTE(review): the setting is named "similarity" but is used as an upper
    # bound on cosine *distance* -- confirm that this is the intended meaning.
    max_distance = getattr(settings, "SEMANTIC_SEARCH_SIMILARITY", 0.7)

    cache_documents = kwargs["cache_documents"]
    qs = kwargs.get("qs")

    if not qs:
        return None

    embeddings = llm.get_embeddings(qs, prompt_name="query")

    distances = {
        f"distance_{i}": CosineDistance(
            "semantic_documents__data_embedding", embedding
        )
        for i, embedding in enumerate(embeddings)
    }

    # OR the per-embedding thresholds together: a row matches when it is close
    # enough to any one of the query embeddings.
    filter_q = Q()
    for annotation in distances:
        filter_q |= Q(**{f"{annotation}__lte": max_distance})

    # NOTE(review): ordering relies on "distance_0" existing, i.e. on
    # get_embeddings returning at least one embedding for a non-empty query.
    return (
        cache_documents.annotate(**distances)
        .filter(filter_q)
        .order_by("distance_0")  # sorted by the first query's distance
    )