diff --git a/scripts/llm/analysis_support.py b/scripts/llm/analysis_support.py
index 6a5752056b4ec658c2b9aaa126ed0c11ab5ecb81..e2c7a8d86b84d083e2b2fd72e3bbdb53c30312d7 100755
--- a/scripts/llm/analysis_support.py
+++ b/scripts/llm/analysis_support.py
@@ -9,13 +9,13 @@ Prerequisites:
 1. Optionally, set up a `config.py` file to provide configuration
    - copy `config.py.tmpl` to `config.py` and fill in the values
 2. Install dependencies:
-   - `torqueclient`, `jinja2`, `requests` Python libraries
+   - `torqueclient`, `jinja2`, `requests`, `nltk` Python libraries
    - `pandoc`
 
 Installation:
 
     $ python -m venv venv
     $ source ./venv/bin/activate
-    $ pip3 install torqueclient jinja2 requests
+    $ pip3 install torqueclient jinja2 requests nltk
     $ sudo apt-get install pandoc
 
 Usage:
@@ -63,6 +63,7 @@ except ImportError:
 
 from dataclasses import asdict, dataclass, field
 import re
+import nltk
 import requests
 from requests.adapters import HTTPAdapter, Retry
 import textwrap
@@ -81,6 +82,7 @@ logging.getLogger("requests").setLevel(logging.WARNING)
 logging.getLogger("urllib3").setLevel(logging.WARNING)
 logging.getLogger("mwclient").setLevel(logging.ERROR)
 
+nltk.download("punkt_tab", quiet=True)
 
 parser = argparse.ArgumentParser(
     prog="analysis_support.py",
@@ -93,6 +95,21 @@ parser.add_argument(
     action="store_true",
     help="Output the response from the LLM rather than committing back to Torque",
 )
+parser.add_argument(
+    "-e",
+    "--evaluate",
+    action="store_true",
+    help="Evaluate the similarity between the LLM and LFC analyses",
+)
+parser.add_argument(
+    "-s",
+    "--similarity",
+    type=float,
+    default=os.getenv(
+        "SEMANTIC_SIMILARITY", getattr(config, "SEMANTIC_SIMILARITY", 0.77)
+    ),
+    help="Parts of the analysis are considered similar if the score is greater than or equal to this value",
+)
 parser.add_argument(
     "-l",
     "--log-level",
@@ -469,6 +486,50 @@ class SearchResult:
     extra_snippets: list[str] = field(default_factory=list)
 
 
+@dataclass
+class SearchResults:
+    results: list[SearchResult]
+
+    def __str__(self):
+        if len(self.results) > 0:
+            return "The following results were returned when searching:\n" + "\n".join(
+                [
+                    f" * {result.title}\n"
+                    f"   - {result.description}\n"
+                    + "\n".join(
+                        [f"   - {snippet}" for snippet in result.extra_snippets]
+                    )
+                    for result in self.results
+                ]
+            )
+        else:
+            return "No results found."
+
+
+@dataclass
+class EmbeddingRequest:
+    input: str | list[str]
+
+
+@dataclass
+class EvaluationStats:
+    total: int = 0
+    similar: int = 0
+
+    def __add__(self, other):
+        return EvaluationStats(
+            total=self.total + other.total,
+            similar=self.similar + other.similar,
+        )
+
+    @property
+    def percent(self):
+        return round((self.similar / self.total) * 100) if self.total else 0
+
+    def __str__(self):
+        return f"{self.similar} / {self.total} = {self.percent}%"
+
+
 class APIClient:
     endpoint: str
     api_key: str
@@ -517,6 +578,8 @@ class Brave(APIClient):
         )
 
     def search(self, query):
+        logging.info(f"    Searching for {query}...")
+
         response = self.make_request(
             "web/search",
             params={"q": query},
@@ -526,15 +589,18 @@ class Brave(APIClient):
             raise ValueError("No web pages in response")
 
         results = response["web"]["results"]
-        return [
-            SearchResult(
-                title=result["title"],
-                url=result["url"],
-                description=result["description"],
-                extra_snippets=result.get("extra_snippets", []),
-            )
-            for result in results
-        ]
+
+        return SearchResults(
+            results=[
+                SearchResult(
+                    title=result["title"],
+                    url=result["url"],
+                    description=result["description"],
+                    extra_snippets=result.get("extra_snippets", []),
+                )
+                for result in results
+            ]
+        )
 
 
 class LLM(APIClient):
@@ -570,6 +636,13 @@ class LLM(APIClient):
         summary = response["result"]["summary"]
         return AnalysisResponse(value=summary["content"], id=summary["id"])
 
+    def get_embeddings(self, text):
+        response = self.make_request(
+            "embeddings",
+            EmbeddingRequest(input=text),
+        )
+        return [embedding["embedding"] for embedding in response["data"]]
+
 
 class MarkdownRenderer:
     template_path: str
@@ -664,8 +737,77 @@ def clean_text(text):
     return text
 
 
-def generate_analysis_support(llm, proposal, search_engine):
+def split_text(text):
+    return nltk.sent_tokenize(text)
+
+
+def remove_backslashes(text):
+    # Strip backslash escapes from punctuation (e.g. "\-" becomes "-")
+    text = re.sub(r"\\([,.\-()$:+?])", r"\1", text)
+    return text
+
+
+def dot_product(a, b):
+    # Embeddings are assumed to be unit-normalized, so the dot product
+    # of two vectors is their cosine similarity.
+    return sum(x * y for x, y in zip(a, b))
+
+
+def calculate_evaluation_stats(llm, llm_key, llm_value, proposal, **kwargs):
+    if "LFC Analysis" not in proposal.keys():
+        logging.error("LFC Analysis not found, not evaluating")
+        return EvaluationStats()
+
+    key_mapping = {
+        "Project Overview": "Overview",
+        "Diversity, Equity, Inclusion, and Accessibility": "Diversity, Equity and Inclusion",
+    }
+
+    lfc_key = key_mapping.get(llm_key, llm_key)
+    lfc_value = remove_backslashes(proposal["LFC Analysis"].get(lfc_key, ""))
+
+    llm_parts = split_text(llm_value)
+    lfc_parts = split_text(lfc_value)
+
+    stats = EvaluationStats(total=len(lfc_parts))
+
+    if not llm_parts or not lfc_parts:
+        return stats
+
+    llm_embeddings = llm.get_embeddings(llm_parts)
+    lfc_embeddings = llm.get_embeddings(lfc_parts)
+
+    for lfc_part, lfc_embedding in zip(lfc_parts, lfc_embeddings):
+        # Find the LLM sentence most similar to this LFC sentence
+        max_score, llm_part = max(
+            (
+                (
+                    dot_product(llm_embedding, lfc_embedding),
+                    llm_part,
+                )
+                for llm_part, llm_embedding in zip(llm_parts, llm_embeddings)
+            ),
+            key=lambda item: item[0],
+        )
+
+        if max_score >= kwargs.get("similarity", 0.77):
+            logging.debug(
+                f"    * Similar sentence found ({max_score}):\n"
+                f"      - LLM: {llm_part}\n"
+                f"      - LFC: {lfc_part}\n"
+            )
+
+            stats.similar += 1
+
+    logging.info(f"    - Similarity: {stats}")
+
+    return stats
+
+
+def generate_analysis_support(llm, proposal, search_engine, **kwargs):
     llm_analysis = {}
+    summary_stats = {}
 
     for name, section in sections.items():
         logging.info(f"  * {name}...")
@@ -673,32 +815,16 @@ def generate_analysis_support(llm, proposal, search_engine):
         prompts = []
 
         for prompt in section.prompts.values():
-            text = clean_text(
-                LLMProposal(proposal).render_markdown(prompt.template_blocks)
-            )
-
             if name == "Reputational Risks":
-                logging.info("    Searching for controversies...")
-                query = proposal["Organization Name"] + " controversy"
-                results = search_engine.search(query)
-
-                if len(results) > 0:
-                    text = f"""
-                    The following results were returned when searching for "{query}":
-                    """
-
-                    for result in results:
-                        text += f"""
-                        {result.title}
-                        {result.description}
-                        {" ".join([snippet for snippet in result.extra_snippets])}
-                        """
-
-                    text = clean_text(text)
+                text = str(
+                    search_engine.search(f"{proposal['Organization Name']} controversy")
+                )
+            else:
+                text = LLMProposal(proposal).render_markdown(prompt.template_blocks)
 
             prompts.append(
                 AnalysisRequest(
-                    text=text,
+                    text=clean_text(text),
                     considerations=prompt.sent,
                 )
             )
@@ -717,13 +843,18 @@ def generate_analysis_support(llm, proposal, search_engine):
 
         logging.info(f"{wrap_text(value, indent='    ')}")
 
+        if kwargs.get("evaluate"):
+            summary_stats[name] = calculate_evaluation_stats(
+                llm, name, value, proposal, **kwargs
+            )
+
     logging.debug("")
     logging.debug("*** Prompt and Analysis Support: ***")
     logging.debug("")
     logging.debug(pprint.pformat(llm_analysis))
     logging.debug("")
 
-    return llm_analysis
+    return llm_analysis, summary_stats
 
 
 def cli():
@@ -764,6 +895,8 @@ def cli():
     logging.info("Generating Analysis Support for:")
     logging.info("")
 
+    summary_stats = {}
+
     for proposal_id in args.proposals:
         logging.info(f"  * #{proposal_id}")
 
@@ -774,16 +907,31 @@ def cli():
             continue
 
         try:
-            llm_analysis = generate_analysis_support(llm, proposal, brave)
+            llm_analysis, proposal_stats = generate_analysis_support(
+                llm, proposal, brave, **vars(args)
+            )
         except Exception as e:
             logging.error(f"Error generating analysis support: {e}")
             continue
 
+        for section, section_stats in proposal_stats.items():
+            if section not in summary_stats:
+                summary_stats[section] = EvaluationStats()
+
+            summary_stats[section] += section_stats
+
         if not args.dry_run:
             # Setting this variable on a torqueclient proposal saves the data back
             # out to the server
             proposal["LLM LFC Analysis"] = llm_analysis
 
+    if args.evaluate:
+        logging.info("")
+        logging.info("Evaluation Summary (similar / total):")
+
+        for section, stats in summary_stats.items():
+            logging.info(f"  * {section}: {stats}")
+
     logging.info("")
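
Note on the scoring in `calculate_evaluation_stats`: the raw `dot_product` only
behaves like cosine similarity (the score the 0.77 `--similarity` default is
meant for) if the embeddings endpoint returns unit-normalized vectors. If that
guarantee is unknown for the deployed model, a defensive variant can normalize
explicitly. This is a minimal sketch, not part of the patch; `cosine_similarity`
is a hypothetical helper:

    import math

    def cosine_similarity(a, b):
        # Normalize both vectors so the score always falls in [-1, 1],
        # regardless of how the embeddings API scales its output.
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(x * x for x in b))
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)

Swapping this in for `dot_product` would keep the `--similarity` threshold
comparable across embedding models.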