Commit 5e7fb1e6 authored by Chris Zubak-Skees

Add an evaluation metric for Analysis Support

parent 053a2cc1
Merge request !192: Add evaluation metric(s) for Analysis Support
@@ -9,13 +9,13 @@ Prerequisites:
 1. Optionally, set up a `config.py` file to provide configuration
    - copy `config.py.tmpl` to `config.py` and fill in the values
 2. Install dependencies:
-   - `torqueclient`, `jinja2`, `requests` Python libraries
+   - `torqueclient`, `jinja2`, `requests`, `nltk` Python libraries
    - `pandoc`

 Installation:

 $ python -m venv venv
 $ source ./venv/bin/activate
-$ pip3 install torqueclient jinja2 requests
+$ pip3 install torqueclient jinja2 requests nltk
 $ sudo apt-get install pandoc

 Usage:
@@ -63,6 +63,7 @@ except ImportError:
 from dataclasses import asdict, dataclass, field
 import re

+import nltk
 import requests
 from requests.adapters import HTTPAdapter, Retry
 import textwrap
@@ -81,6 +82,7 @@ logging.getLogger("requests").setLevel(logging.WARNING)
 logging.getLogger("urllib3").setLevel(logging.WARNING)
 logging.getLogger("mwclient").setLevel(logging.ERROR)

+nltk.download("punkt_tab", quiet=True)

 parser = argparse.ArgumentParser(
     prog="analysis_support.py",
@@ -93,6 +95,21 @@ parser.add_argument(
     action="store_true",
     help="Output the response from the LLM rather than committing back to Torque",
 )
+parser.add_argument(
+    "-e",
+    "--evaluate",
+    action="store_true",
+    help="Evaluate the similarity between the LLM and LFC analyses",
+)
+parser.add_argument(
+    "-s",
+    "--similarity",
+    type=float,
+    default=os.getenv(
+        "SEMANTIC_SIMILARITY", getattr(config, "SEMANTIC_SIMILARITY", 0.77)
+    ),
+    help="Parts of the analysis are considered similar if the score is greater than this value",
+)
 parser.add_argument(
     "-l",
     "--log-level",
@@ -469,6 +486,50 @@ class SearchResult:
     extra_snippets: list[str] = field(default_factory=list)


+@dataclass
+class SearchResults:
+    results: list[SearchResult]
+
+    def __str__(self):
+        if len(self.results) > 0:
+            return "The following results were returned when searching:" + "\n".join(
+                [
+                    f" * {result.title}\n"
+                    f"   - {result.description}\n"
+                    + "\n".join(
+                        [f"   - {snippet}" for snippet in result.extra_snippets]
+                    )
+                    for result in self.results
+                ]
+            )
+        else:
+            return "No results found."
+
+
+@dataclass
+class EmbeddingRequest:
+    input: str | list[str]
+
+
+@dataclass
+class EvaluationStats:
+    total: int = 0
+    similar: int = 0
+
+    def __add__(self, other):
+        return EvaluationStats(
+            total=self.total + other.total,
+            similar=self.similar + other.similar,
+        )
+
+    @property
+    def percent(self):
+        return round((self.similar / self.total) * 100)
+
+    def __str__(self):
+        return f"{self.similar} / {self.total} = {self.percent}%"
+
+
 class APIClient:
     endpoint: str
     api_key: str
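
A minimal sketch of how `EvaluationStats` values combine across proposals, using only the fields defined above (the numbers are illustrative):

    # Illustrative values: stats for the same section from two proposals.
    a = EvaluationStats(total=4, similar=3)
    b = EvaluationStats(total=6, similar=3)
    print(a + b)  # 6 / 10 = 60%

Note that `percent` divides by `total`, so formatting a stats object whose `total` is zero (e.g. a proposal with no LFC Analysis) would raise ZeroDivisionError.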
@@ -517,6 +578,8 @@ class Brave(APIClient):
         )

     def search(self, query):
+        logging.info(f" Searching for {query}...")
+
         response = self.make_request(
             "web/search",
             params={"q": query},
@@ -526,15 +589,18 @@ class Brave(APIClient):
             raise ValueError("No web pages in response")

         results = response["web"]["results"]

-        return [
-            SearchResult(
-                title=result["title"],
-                url=result["url"],
-                description=result["description"],
-                extra_snippets=result.get("extra_snippets", []),
-            )
-            for result in results
-        ]
+        return SearchResults(
+            results=[
+                SearchResult(
+                    title=result["title"],
+                    url=result["url"],
+                    description=result["description"],
+                    extra_snippets=result.get("extra_snippets", []),
+                )
+                for result in results
+            ]
+        )


 class LLM(APIClient):
@@ -570,6 +636,13 @@ class LLM(APIClient):
         summary = response["result"]["summary"]
         return AnalysisResponse(value=summary["content"], id=summary["id"])

+    def get_embeddings(self, text):
+        response = self.make_request(
+            "embeddings",
+            EmbeddingRequest(input=text),
+        )
+
+        return [embedding["embedding"] for embedding in response["data"]]


 class MarkdownRenderer:
     template_path: str
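
Since `EmbeddingRequest.input` is typed `str | list[str]`, `get_embeddings` presumably accepts either a single string or a batch; a hypothetical call (the sentences are made up):

    vectors = llm.get_embeddings(["First sentence.", "Second sentence."])
    # One embedding (a list of floats) per input sentence, in order.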
@@ -664,8 +737,77 @@ def clean_text(text):
     return text


-def generate_analysis_support(llm, proposal, search_engine):
+def split_text(text):
+    return nltk.sent_tokenize(text)
+
+
+def remove_backslashes(text):
+    # Replace escaped characters
+    text = re.sub(r"\\([,.\-()$:+?])", r"\1", text)
+
+    return text
+
+
+def dot_product(a, b):
+    result = 0
+    for i in range(len(a)):
+        result += a[i] * b[i]
+
+    return result
+
+
+def calculate_evaluation_stats(llm, llm_key, llm_value, proposal, **kwargs):
+    if "LFC Analysis" not in proposal.keys():
+        logging.error("LFC Analysis not found, not evaluating")
+        return EvaluationStats()
+
+    key_mapping = {
+        "Project Overview": "Overview",
+        "Diversity, Equity, Inclusion, and Accessibility": "Diversity, Equity and Inclusion",
+    }
+    lfc_key = key_mapping.get(llm_key, llm_key)
+    lfc_value = remove_backslashes(proposal["LFC Analysis"].get(lfc_key, ""))
+
+    llm_parts = split_text(llm_value)
+    lfc_parts = split_text(lfc_value)
+
+    stats = EvaluationStats(total=len(lfc_parts))
+
+    if not llm_parts or not lfc_parts:
+        return stats
+
+    llm_embeddings = llm.get_embeddings(llm_parts)
+    lfc_embeddings = llm.get_embeddings(lfc_parts)
+
+    for lfc_part, lfc_embedding in zip(lfc_parts, lfc_embeddings):
+        max_score, llm_part = max(
+            (
+                (
+                    dot_product(llm_embedding, lfc_embedding),
+                    llm_part,
+                )
+                for llm_part, llm_embedding in zip(llm_parts, llm_embeddings)
+            ),
+            key=lambda item: item[0],
+        )
+
+        if max_score >= kwargs.get("similarity", 0.77):
+            logging.debug(
+                f" * Similar sentence found ({max_score}):\n"
+                f"   - LLM: {llm_part}\n"
+                f"   - LFC: {lfc_part}\n"
+            )
+            stats.similar += 1
+
+    logging.info(f" - Similarity: {stats}")
+
+    return stats
+
+
+def generate_analysis_support(llm, proposal, search_engine, **kwargs):
     llm_analysis = {}
+    summary_stats = {}

     for name, section in sections.items():
         logging.info(f" * {name}...")
@@ -673,32 +815,16 @@ def generate_analysis_support(llm, proposal, search_engine):
         prompts = []

         for prompt in section.prompts.values():
-            text = clean_text(
-                LLMProposal(proposal).render_markdown(prompt.template_blocks)
-            )
-
             if name == "Reputational Risks":
-                logging.info(" Searching for controversies...")
-                query = proposal["Organization Name"] + " controversy"
-                results = search_engine.search(query)
-
-                if len(results) > 0:
-                    text = f"""
-The following results were returned when searching for "{query}":
-"""
-                    for result in results:
-                        text += f"""
-- {result.title}
-  {result.description}
-  {" ".join([snippet for snippet in result.extra_snippets])}
-"""
-                    text = clean_text(text)
+                text = str(
+                    search_engine.search(f"{proposal['Organization Name']} controversy")
+                )
+            else:
+                text = LLMProposal(proposal).render_markdown(prompt.template_blocks)

             prompts.append(
                 AnalysisRequest(
-                    text=text,
+                    text=clean_text(text),
                     considerations=prompt.sent,
                 )
             )
@@ -717,13 +843,18 @@ def generate_analysis_support(llm, proposal, search_engine):
             logging.info(f"{wrap_text(value, indent=' ')}")

+        if kwargs.get("evaluate"):
+            summary_stats[name] = calculate_evaluation_stats(
+                llm, name, value, proposal, **kwargs
+            )
+
     logging.debug("")
     logging.debug("*** Prompt and Analysis Support: ***")
     logging.debug("")
     logging.debug(pprint.pformat(llm_analysis))
     logging.debug("")

-    return llm_analysis
+    return llm_analysis, summary_stats


 def cli():
@@ -764,6 +895,8 @@ def cli():
     logging.info("Generating Analysis Support for:")
     logging.info("")

+    summary_stats = {}
+
     for proposal_id in args.proposals:
         logging.info(f" * #{proposal_id}")
@@ -774,16 +907,31 @@ def cli():
             continue

         try:
-            llm_analysis = generate_analysis_support(llm, proposal, brave)
+            llm_analysis, proposal_stats = generate_analysis_support(
+                llm, proposal, brave, **vars(args)
+            )
         except Exception as e:
             logging.error(f"Error generating analysis support: {e}")
             continue

+        for section, section_stats in proposal_stats.items():
+            if section not in summary_stats:
+                summary_stats[section] = EvaluationStats()
+
+            summary_stats[section] += section_stats
+
         if not args.dry_run:
             # Setting this variable on a torqueclient proposal saves the data back
             # out to the server
             proposal["LLM LFC Analysis"] = llm_analysis

+    if args.evaluate:
+        logging.info("")
+        logging.info("Evaluation Summary (similar / total):")
+
+        for section, stats in summary_stats.items():
+            logging.info(f" * {section}: {stats}")
+
     logging.info("")