Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
32f7aed
update kb's random selection (without replacement) and fix test
JiaenLiu Jul 24, 2024
22a88d3
Merge branch 'main' into GSK-3609-Avoid-redundant-questions-in-data-g…
henchaves Jul 29, 2024
cae6e80
Remove gpt4 dependency disclaimer from api docs
henchaves Aug 2, 2024
6fb07c8
update testcase for knowledge base
JiaenLiu Aug 14, 2024
b7848cd
pdm update
JiaenLiu Aug 14, 2024
ee762c1
update lock file
JiaenLiu Aug 14, 2024
ea1e29e
Regenerating pdm.lock
Aug 14, 2024
88433e2
Merge branch 'main' into GSK-3609-Avoid-redundant-questions-in-data-g…
henchaves Aug 19, 2024
a0915df
Regenerating pdm.lock
Aug 19, 2024
9213f90
Try to fix last dependencies version
henchaves Aug 19, 2024
c27318a
Regenerating pdm.lock
Aug 19, 2024
ca2feb5
Merge branch 'main' into GSK-3609-Avoid-redundant-questions-in-data-g…
kevinmessiaen Aug 21, 2024
4b37c65
pdm.lock
kevinmessiaen Aug 21, 2024
cf425dc
Merge branch 'main' into GSK-3609-Avoid-redundant-questions-in-data-g…
henchaves Aug 23, 2024
1f266e7
Merge branch 'main' into GSK-3609-Avoid-redundant-questions-in-data-g…
kevinmessiaen Aug 30, 2024
0ae67b6
Merge branch 'main' into GSK-3609-Avoid-redundant-questions-in-data-g…
henchaves Sep 13, 2024
7810628
Merge branch 'main' into GSK-3609-Avoid-redundant-questions-in-data-g…
henchaves Sep 13, 2024
ef1f1f0
Regenerating pdm.lock
Sep 13, 2024
064b6e6
Merge branch 'main' into GSK-3609-Avoid-redundant-questions-in-data-g…
henchaves Oct 1, 2024
00c784f
Regenerating pdm.lock
Oct 1, 2024
bd94913
Merge branch 'main' into GSK-3609-Avoid-redundant-questions-in-data-g…
henchaves Oct 30, 2024
8caa128
Merge remote-tracking branch 'origin/main' into GSK-3609-Avoid-redund…
henchaves Oct 30, 2024
e9012ed
Format file
henchaves Oct 30, 2024
6498980
Update knowledge base test
mattbit Oct 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions giskard/rag/knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,17 @@ def get_failure_plot(self, question_evaluation: Sequence[dict] = None):
def get_random_document(self):
    """Return a single document picked at random from the knowledge base.

    The pick is delegated to ``self._rng.choice`` over ``self._documents``.
    NOTE(review): ``self._rng`` is presumably a seeded NumPy random generator;
    confirm against the class constructor.
    """
    return self._rng.choice(self._documents)

def get_random_documents(self, n: int, with_replacement=False):
    """Return ``n`` documents drawn at random from the knowledge base.

    Parameters
    ----------
    n : int
        Number of documents to return.
    with_replacement : bool, optional
        If True, every draw is independent, so the same document may be
        returned several times. If False (default), documents are first drawn
        without replacement; only when ``n`` exceeds the number of documents
        is the remainder topped up with draws with replacement.

    Returns
    -------
    list
        A list of ``n`` documents.
    """
    if with_replacement:
        return list(self._rng.choice(self._documents, n, replace=True))

    # Draw as many distinct documents as the knowledge base can provide.
    docs = list(self._rng.choice(self._documents, min(n, len(self._documents)), replace=False))

    # Top up only when the knowledge base holds fewer than ``n`` documents;
    # when exactly ``n`` were sampled there is nothing left to draw.
    missing = n - len(docs)
    if missing > 0:
        docs.extend(self._rng.choice(self._documents, missing, replace=True))

    return docs

def get_neighbors(self, seed_document: Document, n_neighbors: int = 4, similarity_threshold: float = 0.2):
seed_embedding = seed_document.embeddings

Expand Down
6 changes: 4 additions & 2 deletions giskard/rag/question_generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ class GenerateFromSingleQuestionMixin:
_question_type: str

def generate_questions(self, knowledge_base: KnowledgeBase, num_questions: int, *args, **kwargs) -> Iterator[Dict]:
for _ in range(num_questions):
docs = knowledge_base.get_random_documents(num_questions)

for doc in docs:
try:
yield self.generate_single_question(knowledge_base, *args, **kwargs)
yield self.generate_single_question(knowledge_base, *args, **kwargs, seed_document=doc)
except Exception as e: # @TODO: specify exceptions
logger.error(f"Encountered error in question generation: {e}. Skipping.")
logger.exception(e)
Expand Down
4 changes: 2 additions & 2 deletions giskard/rag/question_generators/double_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ class DoubleQuestionsGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestio
_question_type = "double"

def generate_single_question(
self, knowledge_base: KnowledgeBase, agent_description: str, language: str
self, knowledge_base: KnowledgeBase, agent_description: str, language: str, seed_document=None
) -> QuestionSample:
seed_document = knowledge_base.get_random_document()
seed_document = seed_document or knowledge_base.get_random_document()
context_documents = knowledge_base.get_neighbors(
seed_document, self._context_neighbors, self._context_similarity_threshold
)
Expand Down
4 changes: 2 additions & 2 deletions giskard/rag/question_generators/oos_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class OutOfScopeGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestionGene
_question_type = "out of scope"

def generate_single_question(
self, knowledge_base: KnowledgeBase, agent_description: str, language: str
self, knowledge_base: KnowledgeBase, agent_description: str, language: str, seed_document=None
) -> QuestionSample:
"""
Generate a question from a list of context documents.
Expand All @@ -87,7 +87,7 @@ def generate_single_question(
Tuple[dict, dict]
The generated question and the metadata of the question.
"""
seed_document = knowledge_base.get_random_document()
seed_document = seed_document or knowledge_base.get_random_document()

context_documents = knowledge_base.get_neighbors(
seed_document, self._context_neighbors, self._context_similarity_threshold
Expand Down
11 changes: 9 additions & 2 deletions giskard/rag/question_generators/simple_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ class SimpleQuestionsGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestio

_question_type = "simple"

def generate_single_question(self, knowledge_base: KnowledgeBase, agent_description: str, language: str) -> dict:
def generate_single_question(
self,
knowledge_base: KnowledgeBase,
agent_description: str,
language: str,
seed_document=None,
) -> dict:
"""
Generate a question from a list of context documents.

Expand All @@ -80,7 +86,8 @@ def generate_single_question(self, knowledge_base: KnowledgeBase, agent_descript
QuestionSample
The generated question and the metadata of the question.
"""
seed_document = knowledge_base.get_random_document()
seed_document = seed_document or knowledge_base.get_random_document()

context_documents = knowledge_base.get_neighbors(
seed_document, self._context_neighbors, self._context_similarity_threshold
)
Expand Down
26 changes: 26 additions & 0 deletions tests/rag/test_knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,32 @@
from giskard.rag.knowledge_base import KnowledgeBase


def test_knowledge_base_get_random_documents():
    """Check ``KnowledgeBase.get_random_documents`` for k below, at and above the corpus size."""
    llm_client = Mock()
    embeddings = Mock()
    embeddings.embed.side_effect = [np.random.rand(5, 10), np.random.rand(3, 10)]

    kb = KnowledgeBase.from_pandas(
        df=pd.DataFrame({"text": ["This is a test string"] * 5}), llm_client=llm_client, embedding_model=embeddings
    )

    # Test when k is smaller than the number of documents
    docs = kb.get_random_documents(3)
    assert len(docs) == 3
    # Check that all document IDs are unique
    assert len({doc.id for doc in docs}) == len(docs)

    # Test when k is equal to the number of documents
    docs = kb.get_random_documents(5)
    assert len(docs) == 5
    assert all(doc == kb[doc.id] for doc in docs)
    # Sampling without replacement must return every document exactly once
    assert len({doc.id for doc in docs}) == 5

    # Test when k is larger than the number of documents
    docs = kb.get_random_documents(10)
    assert len(docs) == 10
    assert all(doc == kb[doc.id] for doc in docs)
    # Every document must be used before any duplicate is introduced
    assert len({doc.id for doc in docs}) == 5


def test_knowledge_base_creation_from_df():
dimension = 8
df = pd.DataFrame(["This is a test string"] * 5)
Expand Down
3 changes: 3 additions & 0 deletions tests/rag/test_question_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_simple_question_generation():
Document(dict(content="Milk is produced by cows, goats or sheep.")),
]
knowledge_base.get_random_document = Mock(return_value=documents[0])
knowledge_base.get_random_documents = Mock(return_value=documents)
knowledge_base.get_neighbors = Mock(return_value=documents)

question_generator = SimpleQuestionsGenerator(llm_client=llm_client)
Expand Down Expand Up @@ -212,6 +213,7 @@ def test_double_question_generation():
Document(dict(content="Milk is produced by cows, goats or sheep.")),
]
knowledge_base.get_random_document = Mock(return_value=documents[0])
knowledge_base.get_random_documents = Mock(return_value=documents)
knowledge_base.get_neighbors = Mock(return_value=documents)

question_generator = DoubleQuestionsGenerator(llm_client=llm_client)
Expand Down Expand Up @@ -304,6 +306,7 @@ def test_oos_question_generation():
dict(content="Paul Graham liked to buy a baguette every day at the local market."), doc_id="1"
)
)
knowledge_base.get_random_documents = Mock(return_value=documents)
knowledge_base.get_neighbors = Mock(return_value=documents)

question_generator = OutOfScopeGenerator(llm_client=llm_client)
Expand Down
2 changes: 2 additions & 0 deletions tests/rag/test_testset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def test_question_generation_fail(caplog):
knowledge_base.__getitem__ = lambda obj, idx: documents[0]
knowledge_base.topics = ["Cheese", "Ski"]

knowledge_base.get_random_documents = Mock(return_value=documents)

simple_gen = Mock()
simple_gen.generate_questions.return_value = [q1, q2]
failing_gen = SimpleQuestionsGenerator(llm_client=Mock())
Expand Down