Retriever
The Retriever class is the primary interface for loading your conversation data into Fair Forge. Every evaluation requires a custom retriever implementation: the retriever class (not an instance) is passed to a metric's run method, which forwards its keyword arguments to the retriever's constructor.
Basic Structure
from fair_forge.core.retriever import Retriever
from fair_forge.schemas.common import Dataset, Batch

class MyRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Store any configuration
        self.data_path = kwargs.get('data_path', './data')

    def load_dataset(self) -> list[Dataset]:
        """Load and return your conversation data."""
        datasets = []
        # Load your data and convert to Dataset objects
        return datasets
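Once defined, the retriever class itself is handed to a metric's run method along with any keyword arguments it needs. A minimal sketch, where SomeMetric is a placeholder for any Fair Forge metric (Toxicity, Context, Bias, and so on):

    # SomeMetric is a placeholder; keyword arguments are forwarded to MyRetriever.__init__
    metrics = SomeMetric.run(
        MyRetriever,
        data_path='./my_conversations',
    )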
Retriever Interface
The Retriever abstract base class requires one method:
from abc import ABC, abstractmethod

from fair_forge.schemas.common import Dataset

class Retriever(ABC):
    def __init__(self, **kwargs):
        """Initialize with optional configuration."""
        pass

    @abstractmethod
    def load_dataset(self) -> list[Dataset]:
        """Load and return datasets for evaluation.

        Returns:
            list[Dataset]: List of conversation datasets
        """
        pass
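If your data is small or hard-coded (for example in a smoke test), you can skip external storage entirely and build the objects in memory. A minimal sketch, assuming only the Dataset and Batch fields used in the examples below:

    from fair_forge.core.retriever import Retriever
    from fair_forge.schemas.common import Dataset, Batch

    class InMemoryRetriever(Retriever):
        """Return a single hard-coded conversation; handy for smoke tests."""

        def load_dataset(self) -> list[Dataset]:
            batches = [
                Batch(
                    qa_id="q1",
                    query="What is Fair Forge?",
                    assistant="Fair Forge is a framework for evaluating conversational AI.",
                ),
            ]
            return [
                Dataset(
                    session_id="demo-session",
                    assistant_id="demo-assistant",
                    language="english",
                    context="",
                    conversation=batches,
                )
            ]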
Examples
Loading from JSON File
import json
from pathlib import Path

from fair_forge.core.retriever import Retriever
from fair_forge.schemas.common import Dataset

class JSONRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.file_path = kwargs.get('file_path', 'data.json')

    def load_dataset(self) -> list[Dataset]:
        with open(self.file_path) as f:
            data = json.load(f)

        datasets = []
        for item in data:
            datasets.append(Dataset.model_validate(item))
        return datasets

# Usage
metrics = Toxicity.run(
    JSONRetriever,
    file_path='conversations.json',
    group_prototypes={...},
)
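The loader assumes each entry in the file already mirrors the Dataset schema, so Dataset.model_validate can consume it directly. A hypothetical conversations.json, written here as the equivalent Python literal:

    # Hypothetical file contents: a list of Dataset-shaped objects.
    [
        {
            "session_id": "session-001",
            "assistant_id": "my-assistant",
            "language": "english",
            "context": "Customer support conversation",
            "conversation": [
                {
                    "qa_id": "q1",
                    "query": "How do I reset my password?",
                    "assistant": "Open Settings and choose 'Reset password'.",
                    "ground_truth_assistant": "Use the 'Reset password' option in Settings.",
                },
            ],
        },
    ]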
Loading from Database
import sqlite3

from fair_forge.core.retriever import Retriever
from fair_forge.schemas.common import Dataset, Batch

class DatabaseRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.db_path = kwargs.get('db_path', 'conversations.db')
        self.assistant_id = kwargs.get('assistant_id', 'default')

    def load_dataset(self) -> list[Dataset]:
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        # Fetch sessions
        cursor.execute("""
            SELECT DISTINCT session_id, context, language
            FROM conversations
            WHERE assistant_id = ?
        """, (self.assistant_id,))

        datasets = []
        for session_id, context, language in cursor.fetchall():
            # Fetch conversation batches for this session
            cursor.execute("""
                SELECT qa_id, query, assistant, ground_truth
                FROM conversations
                WHERE session_id = ?
                ORDER BY created_at
            """, (session_id,))

            batches = [
                Batch(
                    qa_id=row[0],
                    query=row[1],
                    assistant=row[2],
                    ground_truth_assistant=row[3],
                )
                for row in cursor.fetchall()
            ]

            datasets.append(Dataset(
                session_id=session_id,
                assistant_id=self.assistant_id,
                language=language,
                context=context,
                conversation=batches,
            ))

        conn.close()
        return datasets

# Usage
metrics = Context.run(
    DatabaseRetriever,
    db_path='my_chatbot.db',
    assistant_id='gpt-4-assistant',
    model=judge_model,
)
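The queries above assume a single conversations table with one row per question/answer pair. A hypothetical schema covering just the columns the retriever reads (adjust names and types to your own database):

    import sqlite3

    # Hypothetical table layout matching the SELECT statements above.
    conn = sqlite3.connect('my_chatbot.db')
    conn.execute("""
        CREATE TABLE IF NOT EXISTS conversations (
            session_id    TEXT,
            assistant_id  TEXT,
            qa_id         TEXT,
            query         TEXT,
            assistant     TEXT,
            ground_truth  TEXT,
            context       TEXT,
            language      TEXT,
            created_at    TIMESTAMP
        )
    """)
    conn.commit()
    conn.close()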
Loading from API
import httpx

from fair_forge.core.retriever import Retriever
from fair_forge.schemas.common import Dataset, Batch

class APIRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_url = kwargs.get('api_url')
        self.api_key = kwargs.get('api_key')

    def load_dataset(self) -> list[Dataset]:
        headers = {'Authorization': f'Bearer {self.api_key}'}

        with httpx.Client() as client:
            response = client.get(
                f'{self.api_url}/conversations',
                headers=headers,
            )
            data = response.json()

        return [Dataset.model_validate(d) for d in data['datasets']]

# Usage
metrics = Bias.run(
    APIRetriever,
    api_url='https://api.example.com',
    api_key='your-api-key',
    guardian=LLamaGuard,
    config=guardian_config,
)
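For anything beyond a quick script you will usually want a request timeout and an explicit failure on non-2xx responses. A sketch of the same request with those added (the 30-second timeout is an arbitrary choice):

    with httpx.Client(timeout=30.0) as client:
        response = client.get(
            f'{self.api_url}/conversations',
            headers=headers,
        )
        # Raise httpx.HTTPStatusError on 4xx/5xx instead of parsing an error payload.
        response.raise_for_status()
        data = response.json()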
Loading from CSV
import csv
from collections import defaultdict

from fair_forge.core.retriever import Retriever
from fair_forge.schemas.common import Dataset, Batch

class CSVRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.file_path = kwargs.get('file_path', 'conversations.csv')

    def load_dataset(self) -> list[Dataset]:
        # Group conversations by session
        sessions = defaultdict(list)
        with open(self.file_path, newline='') as f:
            reader = csv.DictReader(f)
            for row in reader:
                sessions[row['session_id']].append(row)

        datasets = []
        for session_id, rows in sessions.items():
            batches = [
                Batch(
                    qa_id=row['qa_id'],
                    query=row['query'],
                    assistant=row['assistant'],
                    ground_truth_assistant=row.get('ground_truth', ''),
                )
                for row in rows
            ]
            datasets.append(Dataset(
                session_id=session_id,
                assistant_id=rows[0].get('assistant_id', 'unknown'),
                language=rows[0].get('language', 'english'),
                context=rows[0].get('context', ''),
                conversation=batches,
            ))
        return datasets
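Usage mirrors the earlier examples; the metric and its extra arguments depend on what you want to evaluate (Toxicity and group_prototypes below are just one possibility):

    # Usage (illustrative; any Fair Forge metric accepts a retriever the same way)
    metrics = Toxicity.run(
        CSVRetriever,
        file_path='conversations.csv',
        group_prototypes={...},
    )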
Multi-Assistant Retriever (for BestOf)
from fair_forge.core.retriever import Retriever
from fair_forge.schemas.common import Dataset, Batch

class MultiAssistantRetriever(Retriever):
    """Load datasets from multiple assistants for comparison."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.assistants = kwargs.get('assistants', [])

    def load_dataset(self) -> list[Dataset]:
        datasets = []

        # Same questions, different assistants
        questions = [
            "What are the benefits of renewable energy?",
            "Explain machine learning in simple terms.",
        ]

        for assistant_id in self.assistants:
            # Get responses from each assistant
            responses = self._get_responses(assistant_id, questions)

            batches = [
                Batch(qa_id=f"q{i}", query=q, assistant=r)
                for i, (q, r) in enumerate(zip(questions, responses))
            ]

            datasets.append(Dataset(
                session_id=f"eval-{assistant_id}",
                assistant_id=assistant_id,
                language="english",
                context="",
                conversation=batches,
            ))

        return datasets

    def _get_responses(self, assistant_id: str, questions: list) -> list:
        # Implement based on your assistant API
        pass

# Usage with BestOf
metrics = BestOf.run(
    MultiAssistantRetriever,
    assistants=['gpt-4', 'claude-3', 'llama-3'],
    model=judge_model,
)
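_get_responses is left unimplemented above because it depends entirely on how your assistants are served. A sketch against a hypothetical HTTP chat endpoint (the URL, payload shape, and 'answer' field are assumptions to adapt):

    import httpx

    def _get_responses(self, assistant_id: str, questions: list) -> list:
        # Hypothetical endpoint: POST each question and collect the answers in order.
        responses = []
        with httpx.Client(timeout=60.0) as client:
            for question in questions:
                reply = client.post(
                    f"https://api.example.com/assistants/{assistant_id}/chat",
                    json={"message": question},
                )
                reply.raise_for_status()
                responses.append(reply.json()["answer"])
        return responses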
Best Practices
Use kwargs for Configuration
Pass configuration through **kwargs to keep retrievers flexible:
class MyRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.file_path = kwargs.get('file_path', 'default.json')
        self.filter_by = kwargs.get('filter_by', None)
Handle Missing Data Gracefully
Set defaults for optional fields:
Batch(
    qa_id=row.get('qa_id', f'q-{idx}'),
    query=row['query'],  # Required
    assistant=row.get('assistant', ''),
    ground_truth_assistant=row.get('ground_truth'),  # None if missing
)
Log Data Loading
Use the logging utilities for debugging:
from fair_forge.utils.logging import logger

class MyRetriever(Retriever):
    def load_dataset(self) -> list[Dataset]:
        logger.info(f"Loading datasets from {self.file_path}")
        datasets = self._load()
        logger.info(f"Loaded {len(datasets)} datasets")
        return datasets
Validate Data
Use Pydantic validation to catch issues early:
from pydantic import ValidationError

def load_dataset(self) -> list[Dataset]:
    datasets = []
    for item in raw_data:
        try:
            datasets.append(Dataset.model_validate(item))
        except ValidationError as e:
            logger.warning(f"Skipping invalid dataset: {e}")
    return datasets