Retriever

The Retriever class is the primary interface for loading your conversation data into Fair Forge. Every evaluation requires a custom retriever implementation.

Basic Structure

from fair_forge.core.retriever import Retriever
from fair_forge.schemas.common import Dataset

class MyRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Store any configuration
        self.data_path = kwargs.get('data_path', './data')

    def load_dataset(self) -> list[Dataset]:
        """Load and return your conversation data."""
        datasets = []
        # Load your data and convert to Dataset objects
        return datasets
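
Judging from the usage snippets below, metrics instantiate the retriever for you: you pass the class itself to a metric's run method, which forwards its keyword arguments to the retriever's constructor. A minimal sketch of that calling convention (Toxicity stands in for any Fair Forge metric; its own required arguments are omitted):

# The metric constructs MyRetriever(data_path='./my-data') internally,
# then calls load_dataset() to obtain the conversations to evaluate.
metrics = Toxicity.run(
    MyRetriever,
    data_path='./my-data',
)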

Retriever Interface

The Retriever abstract base class requires subclasses to implement a single method:
from abc import ABC, abstractmethod

from fair_forge.schemas.common import Dataset

class Retriever(ABC):
    def __init__(self, **kwargs):
        """Initialize with optional configuration."""
        pass

    @abstractmethod
    def load_dataset(self) -> list[Dataset]:
        """Load and return datasets for evaluation.

        Returns:
            list[Dataset]: List of conversation datasets
        """
        pass
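
The Dataset objects returned by load_dataset are built from Batch objects, one per conversation turn. A hand-assembled example, with field names inferred from the constructor calls later on this page (consult fair_forge.schemas.common for the authoritative schema):

from fair_forge.schemas.common import Dataset, Batch

# One conversation turn. ground_truth_assistant appears to be optional
# in the examples below, so it is omitted here.
batch = Batch(
    qa_id='q1',
    query='What is renewable energy?',
    assistant='Energy from sources that replenish naturally.',
)

# One session: the turn list plus session-level metadata.
dataset = Dataset(
    session_id='session-1',
    assistant_id='my-assistant',
    language='english',
    context='',
    conversation=[batch],
)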

Examples

Loading from JSON File

import json
from pathlib import Path
from fair_forge.core.retriever import Retriever
from fair_forge.schemas.common import Dataset

class JSONRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.file_path = kwargs.get('file_path', 'data.json')

    def load_dataset(self) -> list[Dataset]:
        with open(self.file_path) as f:
            data = json.load(f)

        datasets = []
        for item in data:
            datasets.append(Dataset.model_validate(item))
        return datasets

# Usage
metrics = Toxicity.run(
    JSONRetriever,
    file_path='conversations.json',
    group_prototypes={...},
)
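
Because the retriever validates each item with Dataset.model_validate, the file is expected to contain a list of Dataset-shaped objects. A plausible layout, inferred from the fields used elsewhere on this page:

[
  {
    "session_id": "session-1",
    "assistant_id": "my-assistant",
    "language": "english",
    "context": "",
    "conversation": [
      {
        "qa_id": "q1",
        "query": "What is renewable energy?",
        "assistant": "Energy from sources that replenish naturally."
      }
    ]
  }
]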

Loading from Database

import sqlite3
from fair_forge.core.retriever import Retriever
from fair_forge.schemas.common import Dataset, Batch

class DatabaseRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.db_path = kwargs.get('db_path', 'conversations.db')
        self.assistant_id = kwargs.get('assistant_id', 'default')

    def load_dataset(self) -> list[Dataset]:
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        # Fetch sessions
        cursor.execute("""
            SELECT DISTINCT session_id, context, language
            FROM conversations
            WHERE assistant_id = ?
        """, (self.assistant_id,))

        datasets = []
        for session_id, context, language in cursor.fetchall():
            # Fetch conversation batches for this session
            cursor.execute("""
                SELECT qa_id, query, assistant, ground_truth
                FROM conversations
                WHERE session_id = ?
                ORDER BY created_at
            """, (session_id,))

            batches = [
                Batch(
                    qa_id=row[0],
                    query=row[1],
                    assistant=row[2],
                    ground_truth_assistant=row[3],
                )
                for row in cursor.fetchall()
            ]

            datasets.append(Dataset(
                session_id=session_id,
                assistant_id=self.assistant_id,
                language=language,
                context=context,
                conversation=batches,
            ))

        conn.close()
        return datasets

# Usage
metrics = Context.run(
    DatabaseRetriever,
    db_path='my_chatbot.db',
    assistant_id='gpt-4-assistant',
    model=judge_model,
)
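
The queries above assume a single conversations table with one row per turn and session-level columns repeated on each row. A schema sketch that would satisfy them, with column names taken from the SELECT statements (adapt to your own layout):

import sqlite3

conn = sqlite3.connect('conversations.db')
conn.executescript("""
    CREATE TABLE IF NOT EXISTS conversations (
        session_id   TEXT NOT NULL,
        assistant_id TEXT NOT NULL,
        qa_id        TEXT NOT NULL,
        query        TEXT NOT NULL,
        assistant    TEXT,
        ground_truth TEXT,
        context      TEXT,
        language     TEXT,
        created_at   TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    );
""")
conn.close()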

Loading from API

import httpx
from fair_forge.core.retriever import Retriever
from fair_forge.schemas.common import Dataset, Batch

class APIRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_url = kwargs.get('api_url')
        self.api_key = kwargs.get('api_key')

    def load_dataset(self) -> list[Dataset]:
        headers = {'Authorization': f'Bearer {self.api_key}'}

        with httpx.Client() as client:
            response = client.get(
                f'{self.api_url}/conversations',
                headers=headers,
            )
            # Surface HTTP errors instead of silently parsing an error body.
            response.raise_for_status()
            data = response.json()

        return [Dataset.model_validate(d) for d in data['datasets']]

# Usage
metrics = Bias.run(
    APIRetriever,
    api_url='https://api.example.com',
    api_key='your-api-key',
    guardian=LLamaGuard,
    config=guardian_config,
)
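
The retriever reads data['datasets'], so the endpoint is assumed to return a JSON object whose datasets key holds Dataset-shaped objects. A hypothetical response body consistent with the parsing above:

{
  "datasets": [
    {
      "session_id": "session-1",
      "assistant_id": "my-assistant",
      "language": "english",
      "context": "",
      "conversation": [
        {
          "qa_id": "q1",
          "query": "What is renewable energy?",
          "assistant": "Energy from sources that replenish naturally."
        }
      ]
    }
  ]
}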

Loading from CSV

import csv
from collections import defaultdict
from fair_forge.core.retriever import Retriever
from fair_forge.schemas.common import Dataset, Batch

class CSVRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.file_path = kwargs.get('file_path', 'conversations.csv')

    def load_dataset(self) -> list[Dataset]:
        # Group conversations by session
        sessions = defaultdict(list)

        with open(self.file_path, newline='') as f:
            reader = csv.DictReader(f)
            for row in reader:
                sessions[row['session_id']].append(row)

        datasets = []
        for session_id, rows in sessions.items():
            batches = [
                Batch(
                    qa_id=row['qa_id'],
                    query=row['query'],
                    assistant=row['assistant'],
                    ground_truth_assistant=row.get('ground_truth', ''),
                )
                for row in rows
            ]

            datasets.append(Dataset(
                session_id=session_id,
                assistant_id=rows[0].get('assistant_id', 'unknown'),
                language=rows[0].get('language', 'english'),
                context=rows[0].get('context', ''),
                conversation=batches,
            ))

        return datasets
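
The reader expects one row per conversation turn, with session-level columns repeated on each row of a session. A small file in the layout the code above reads (column names taken from the code; ground_truth, assistant_id, language, and context are optional):

session_id,qa_id,query,assistant,ground_truth,assistant_id,language,context
session-1,q1,What is renewable energy?,Energy from replenished sources.,Energy from sources that replenish naturally.,my-assistant,english,
session-1,q2,Is solar power renewable?,Yes.,Yes.,my-assistant,english,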

Multi-Assistant Retriever (for BestOf)

from fair_forge.core.retriever import Retriever
from fair_forge.schemas.common import Dataset, Batch

class MultiAssistantRetriever(Retriever):
    """Load datasets from multiple assistants for comparison."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.assistants = kwargs.get('assistants', [])

    def load_dataset(self) -> list[Dataset]:
        datasets = []

        # Same questions, different assistants
        questions = [
            "What are the benefits of renewable energy?",
            "Explain machine learning in simple terms.",
        ]

        for assistant_id in self.assistants:
            # Get responses from each assistant
            responses = self._get_responses(assistant_id, questions)

            batches = [
                Batch(qa_id=f"q{i}", query=q, assistant=r)
                for i, (q, r) in enumerate(zip(questions, responses))
            ]

            datasets.append(Dataset(
                session_id=f"eval-{assistant_id}",
                assistant_id=assistant_id,
                language="english",
                context="",
                conversation=batches,
            ))

        return datasets

    def _get_responses(self, assistant_id: str, questions: list[str]) -> list[str]:
        # Implement based on your assistant API.
        raise NotImplementedError

# Usage with BestOf
metrics = BestOf.run(
    MultiAssistantRetriever,
    assistants=['gpt-4', 'claude-3', 'llama-3'],
    model=judge_model,
)
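
_get_responses is left unimplemented above because it depends on how your assistants are served. A hypothetical sketch using httpx against an imaginary chat endpoint (the URL, payload, and response field are all assumptions):

import httpx

def _get_responses(self, assistant_id: str, questions: list[str]) -> list[str]:
    responses = []
    with httpx.Client() as client:
        for question in questions:
            result = client.post(
                'https://api.example.com/chat',            # assumed URL
                json={'model': assistant_id, 'prompt': question},
            )
            result.raise_for_status()
            responses.append(result.json()['answer'])      # assumed field
    return responses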

Best Practices

Pass configuration through **kwargs to keep retrievers flexible:
class MyRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.file_path = kwargs.get('file_path', 'default.json')
        self.filter_by = kwargs.get('filter_by', None)
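
The retriever can then apply optional settings only when they are provided. A sketch using the filter_by option from above (_load_all is a hypothetical helper):

def load_dataset(self) -> list[Dataset]:
    datasets = self._load_all()
    # Apply the optional filter only when one was configured.
    if self.filter_by is not None:
        datasets = [d for d in datasets if d.assistant_id == self.filter_by]
    return datasets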

Set defaults for optional fields:
Batch(
    qa_id=row.get('qa_id', f'q-{idx}'),
    query=row['query'],  # Required
    assistant=row.get('assistant', ''),
    ground_truth_assistant=row.get('ground_truth'),  # None if missing
)

Use the logging utilities for debugging:
from fair_forge.utils.logging import logger

class MyRetriever(Retriever):
    def load_dataset(self) -> list[Dataset]:
        logger.info(f"Loading datasets from {self.file_path}")
        datasets = self._load()
        logger.info(f"Loaded {len(datasets)} datasets")
        return datasets

Use Pydantic validation to catch issues early:
from pydantic import ValidationError

def load_dataset(self) -> list[Dataset]:
    datasets = []
    for item in raw_data:
        try:
            datasets.append(Dataset.model_validate(item))
        except ValidationError as e:
            logger.warning(f"Skipping invalid dataset: {e}")
    return datasets
