"""
Stage 1: Program Space Sampling
Generate diverse valid implementations of a stub using multiple strategies:
- Direct sampling from LLMs at various temperatures
- SFS-inspired scattering (arXiv:2411.05010): diversify via textual gradient directions
- Multi-model heterogeneous sampling: per the AlgoDiv finding (arXiv:2503.00691), diversity requires multiple models
- Concept-guided sampling: steer toward specific concept regions
Supports both API-based models (OpenAI, HF Inference) and local models.
"""
from __future__ import annotations
import re
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional
from reason_first_program.stub import Stub
from reason_first_program.program_space import Program, ProgramSpace, execute_program
logger = logging.getLogger(__name__)
@dataclass
class SamplingConfig:
"""Configuration for program sampling."""
n_samples: int = 100
temperatures: list[float] = field(
default_factory=lambda: [0.2, 0.6, 0.8, 1.0, 1.2]
)
models: list[str] = field(
default_factory=lambda: ["deepseek-coder"]
)
prompt_styles: list[str] = field(
default_factory=lambda: ["direct", "diverse"]
)
max_tokens: int = 1024
timeout_per_execution: float = 5.0
deduplicate: bool = True
filter_valid: bool = True
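# Illustrative override of the defaults above (the values here are arbitrary
# examples for a quick, cheap run, not recommendations):
#     config = SamplingConfig(n_samples=30, temperatures=[0.4, 1.0], max_tokens=512)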
class ModelBackend(ABC):
"""Abstract backend for code generation."""
@abstractmethod
def generate(
self,
prompt: str,
temperature: float = 0.8,
max_tokens: int = 1024,
n: int = 1,
) -> list[str]:
"""Generate n completions for the given prompt."""
...
@property
@abstractmethod
def model_id(self) -> str:
...
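# A minimal concrete backend, shown as a sketch of the interface above. It is
# not part of the original design: the name `StaticBackend` and its canned
# completion are purely illustrative (handy for offline tests of the samplers).
class StaticBackend(ModelBackend):
    """Toy backend that returns a fixed completion regardless of the prompt."""
    def __init__(self, completion: str = "    return None"):
        self._completion = completion
    @property
    def model_id(self) -> str:
        return "static-test"
    def generate(
        self,
        prompt: str,
        temperature: float = 0.8,
        max_tokens: int = 1024,
        n: int = 1,
    ) -> list[str]:
        # Ignore the sampling parameters and echo the canned body n times.
        return [self._completion] * n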
class HFInferenceBackend(ModelBackend):
"""HuggingFace Inference API backend."""
def __init__(self, model_name: str = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", token: Optional[str] = None):
self._model_name = model_name
self._token = token
@property
def model_id(self) -> str:
return self._model_name
def generate(
self,
prompt: str,
temperature: float = 0.8,
max_tokens: int = 1024,
n: int = 1,
) -> list[str]:
        try:
            from huggingface_hub import InferenceClient
        except ImportError as e:
            raise ImportError(
                "huggingface_hub is required for HFInferenceBackend: "
                "pip install huggingface_hub"
            ) from e
client = InferenceClient(model=self._model_name, token=self._token)
results = []
for _ in range(n):
try:
response = client.text_generation(
prompt,
max_new_tokens=max_tokens,
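                    # sampling endpoints typically reject temperature <= 0;
                    # clamp to a small positive value so do_sample stays valid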
temperature=max(temperature, 0.01),
do_sample=True,
)
results.append(response)
except Exception as e:
logger.warning(f"Generation failed: {e}")
continue
return results
class OpenAIBackend(ModelBackend):
"""OpenAI API backend."""
def __init__(self, model_name: str = "gpt-4o", api_key: Optional[str] = None):
self._model_name = model_name
self._api_key = api_key
@property
def model_id(self) -> str:
return self._model_name
def generate(
self,
prompt: str,
temperature: float = 0.8,
max_tokens: int = 1024,
n: int = 1,
) -> list[str]:
        try:
            import openai
        except ImportError as e:
            raise ImportError("openai is required for OpenAIBackend: pip install openai") from e
client = openai.OpenAI(api_key=self._api_key)
try:
response = client.chat.completions.create(
model=self._model_name,
messages=[
{"role": "system", "content": "You are an expert Python programmer. Output only the function body, no explanation."},
{"role": "user", "content": prompt},
],
temperature=temperature,
max_tokens=max_tokens,
n=n,
)
            # message.content can be None (e.g., refusals); drop empty completions
            return [
                choice.message.content
                for choice in response.choices
                if choice.message.content
            ]
except Exception as e:
logger.warning(f"OpenAI generation failed: {e}")
return []
class LocalModelBackend(ModelBackend):
"""Local model backend using transformers."""
def __init__(self, model_name: str = "deepseek-ai/deepseek-coder-1.3b-instruct", device: str = "auto"):
self._model_name = model_name
self._device = device
self._pipeline = None
@property
def model_id(self) -> str:
return self._model_name
def _load(self):
if self._pipeline is None:
            try:
                from transformers import pipeline
            except ImportError as e:
                raise ImportError(
                    "transformers and torch are required for LocalModelBackend: "
                    "pip install transformers torch"
                ) from e
self._pipeline = pipeline(
"text-generation",
model=self._model_name,
device_map=self._device,
trust_remote_code=True,
)
def generate(
self,
prompt: str,
temperature: float = 0.8,
max_tokens: int = 1024,
n: int = 1,
) -> list[str]:
self._load()
results = []
for _ in range(n):
try:
out = self._pipeline(
prompt,
max_new_tokens=max_tokens,
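                    # transformers sampling needs a positive temperature;
                    # clamp near-zero values to keep do_sample valid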
temperature=max(temperature, 0.01),
do_sample=True,
return_full_text=False,
)
results.append(out[0]["generated_text"])
except Exception as e:
logger.warning(f"Local generation failed: {e}")
return results
def _extract_function_body(raw_output: str, stub: Stub) -> Optional[str]:
"""
Extract a clean function body from LLM output.
Handles markdown code blocks, extra commentary, etc.
"""
text = raw_output.strip()
# Remove markdown code fences
code_block = re.search(r"```(?:python)?\s*\n(.*?)```", text, re.DOTALL)
if code_block:
text = code_block.group(1).strip()
# If the output contains a full function def, extract it
func_match = re.search(
rf"def\s+{re.escape(stub.name)}\s*\(.*?\).*?:\s*\n(.*)",
text,
re.DOTALL,
)
if func_match:
text = func_match.group(1)
    # Drop blank lines that precede the first code line; keep leading comments
lines = text.split("\n")
code_lines = []
in_code = False
for line in lines:
stripped = line.strip()
if stripped and not stripped.startswith("#") and not in_code:
in_code = True
if in_code or stripped.startswith("#"):
code_lines.append(line)
if not code_lines:
return None
return "\n".join(code_lines)
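# Illustrative behavior of the extraction above (the raw output is hypothetical):
#     raw = "Here you go:\n```python\ndef add(a, b):\n    return a + b\n```"
# For a stub named `add`, the markdown fence is stripped first, then the
# matching def line, leaving the body "    return a + b".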
def _build_full_source(body: str, stub: Stub) -> str:
"""Reconstruct full function source from body and stub signature."""
    # Extract the def line from the stub source; fall back to the recorded
    # signature if no single complete def line is found (e.g., the signature
    # spans multiple lines)
    for line in stub.source.split("\n"):
        if line.strip().startswith("def ") and line.rstrip().endswith(":"):
            def_line = line
            break
    else:
        def_line = f"def {stub.name}{stub.signature}:"
# Ensure proper indentation of body
indented_body = "\n".join(
f" {line}" if line.strip() else line for line in body.split("\n")
)
return f"{def_line}\n{indented_body}"
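# For example (illustrative), the body "return a + b" under a stub whose source
# contains "def add(a, b):" reconstructs to:
#     def add(a, b):
#         return a + b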
class ProgramSampler:
"""
Basic program sampler: generates completions from a single backend.
"""
def __init__(self, backend: ModelBackend, config: Optional[SamplingConfig] = None):
self.backend = backend
self.config = config or SamplingConfig()
def sample(self, stub: Stub) -> ProgramSpace:
"""Sample programs for a stub and return a ProgramSpace."""
space = ProgramSpace(stub)
samples_per_config = max(
1,
self.config.n_samples
// (len(self.config.temperatures) * len(self.config.prompt_styles)),
)
for temp in self.config.temperatures:
for style in self.config.prompt_styles:
prompt = stub.to_completion_prompt(style=style)
logger.info(
f"Sampling {samples_per_config} programs "
f"(temp={temp}, style={style}, model={self.backend.model_id})"
)
raw_outputs = self.backend.generate(
prompt=prompt,
temperature=temp,
max_tokens=self.config.max_tokens,
n=samples_per_config,
)
for raw in raw_outputs:
body = _extract_function_body(raw, stub)
if body is None:
continue
full_source = _build_full_source(body, stub)
program = Program(
source=body,
full_source=full_source,
stub_id=stub.stub_id,
model_id=self.backend.model_id,
metadata={
"temperature": temp,
"prompt_style": style,
},
)
# Execute and validate
if stub.test_inputs:
program = execute_program(
program, stub, stub.test_inputs,
timeout_seconds=self.config.timeout_per_execution,
)
space.add(program)
# Post-processing
if self.config.deduplicate:
space = space.deduplicate_syntactic()
if self.config.filter_valid and stub.test_inputs:
space = space.filter_valid()
return space
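# Minimal usage sketch (hedged): `my_stub` is assumed to be a Stub built
# elsewhere; the backend and sample count are examples only.
#     sampler = ProgramSampler(HFInferenceBackend(), SamplingConfig(n_samples=20))
#     space = sampler.sample(my_stub)
#     print(f"{len(space)} programs ({len(space.valid_programs)} valid)")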
class DiverseSampler:
"""
Diverse program sampler using multiple backends and SFS-inspired scattering.
    Key insight from AlgoDiv (arXiv:2503.00691): combining solutions from heterogeneous
    models increases algorithmic diversity more than any single-model technique.
"""
def __init__(
self,
backends: list[ModelBackend],
config: Optional[SamplingConfig] = None,
):
self.backends = backends
self.config = config or SamplingConfig()
def sample(self, stub: Stub) -> ProgramSpace:
"""Sample from all backends and merge into a single ProgramSpace."""
space = ProgramSpace(stub)
samples_per_backend = max(1, self.config.n_samples // len(self.backends))
for backend in self.backends:
backend_config = SamplingConfig(
n_samples=samples_per_backend,
temperatures=self.config.temperatures,
models=[backend.model_id],
prompt_styles=self.config.prompt_styles,
max_tokens=self.config.max_tokens,
timeout_per_execution=self.config.timeout_per_execution,
deduplicate=False, # We'll deduplicate at the end
filter_valid=False,
)
sampler = ProgramSampler(backend, backend_config)
backend_space = sampler.sample(stub)
for program in backend_space.programs:
space.add(program)
logger.info(
f"Backend {backend.model_id}: generated {len(backend_space)} programs"
)
# Post-processing across all backends
if self.config.deduplicate:
space = space.deduplicate_syntactic()
if self.config.filter_valid and stub.test_inputs:
space = space.filter_valid()
logger.info(
f"DiverseSampler: {len(space)} total programs "
f"({len(space.valid_programs)} valid)"
)
return space
def sample_with_scattering(
self, stub: Stub, n_directions: int = 5
) -> ProgramSpace:
"""
        SFS-inspired scattering (arXiv:2411.05010): first discover diverse algorithmic
directions, then sample implementations along each direction.
"""
# Phase 1: Discover algorithmic directions
scout_backend = self.backends[0]
direction_prompt = (
f"Consider this Python function stub:\n\n"
f"```python\n{stub.source}\n```\n\n"
f"{stub.constraints.to_prompt_context()}\n\n"
f"List {n_directions} fundamentally different algorithmic approaches "
f"to implement this function. For each, give a short name and 1-sentence "
f"description. Format: '1. NAME: description'"
)
direction_outputs = scout_backend.generate(
direction_prompt, temperature=0.7, n=1
)
directions = []
if direction_outputs:
for line in direction_outputs[0].split("\n"):
line = line.strip()
if line and line[0].isdigit():
# Extract direction name
match = re.match(r"\d+\.\s*(.+?)(?::|$)", line)
if match:
directions.append(match.group(1).strip())
if not directions:
directions = [
"iterative approach",
"recursive approach",
"functional/map-reduce approach",
"optimized in-place approach",
"library-heavy approach",
]
logger.info(f"Discovered {len(directions)} algorithmic directions: {directions}")
# Phase 2: Sample along each direction
space = ProgramSpace(stub)
samples_per_direction = max(
1, self.config.n_samples // (len(directions) * len(self.backends))
)
for direction in directions:
directed_prompt = (
f"Complete this Python function using the following approach: "
f"**{direction}**\n\n"
f"```python\n{stub.source}\n```\n\n"
f"{stub.constraints.to_prompt_context()}\n\n"
f"Only output the function body. Use the {direction} approach."
)
for backend in self.backends:
for temp in self.config.temperatures:
raw_outputs = backend.generate(
directed_prompt,
temperature=temp,
max_tokens=self.config.max_tokens,
n=samples_per_direction,
)
for raw in raw_outputs:
body = _extract_function_body(raw, stub)
if body is None:
continue
full_source = _build_full_source(body, stub)
program = Program(
source=body,
full_source=full_source,
stub_id=stub.stub_id,
model_id=backend.model_id,
metadata={
"temperature": temp,
"direction": direction,
"prompt_style": "scattered",
},
)
if stub.test_inputs:
program = execute_program(
program, stub, stub.test_inputs,
timeout_seconds=self.config.timeout_per_execution,
)
space.add(program)
# Post-processing
if self.config.deduplicate:
space = space.deduplicate_syntactic()
if self.config.filter_valid and stub.test_inputs:
space = space.filter_valid()
return space
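# End-to-end sketch (hedged): wiring heterogeneous backends into DiverseSampler
# and running SFS-style scattering. Stub construction is omitted because it
# lives in reason_first_program.stub; `my_stub` below is a placeholder.
#     backends = [
#         OpenAIBackend("gpt-4o"),
#         LocalModelBackend("deepseek-ai/deepseek-coder-1.3b-instruct"),
#     ]
#     sampler = DiverseSampler(backends, SamplingConfig(n_samples=60))
#     space = sampler.sample_with_scattering(my_stub, n_directions=5)
#     print(f"{len(space)} programs across the discovered directions")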