# Uploaded by ryanyen22
# Commit 498fb0a (verified): feat: add reason_first_program/experiment.py
"""
First Experiment: Concept Probing on Program Completions
This experiment validates the central hypothesis:
"Concept structure exists in the program space and can be extracted,
verified, and used for steering."
Steps:
1. Define stub functions with known diverse solution strategies
2. Generate multiple valid implementations (simulated or via LLM)
3. Run concept discovery (behavioral + structural)
4. Build concept-guided embedding space
5. Verify alignment: do concept dimensions separate meaningful program families?
6. Test steering: can we select programs by concept queries?
This script runs WITHOUT GPU and WITHOUT API keys by using hand-crafted
program populations. For the full LLM-powered pipeline, see experiment_llm.py.
"""
from __future__ import annotations
import json
import logging
import sys
from typing import Any
import numpy as np
from reason_first_program.stub import reason_first, Stub, StubConstraints
from reason_first_program.program_space import Program, ProgramSpace, execute_program
from reason_first_program.concepts import (
BehavioralConceptDiscovery,
AbstractionConceptDiscovery,
UnifiedConceptDiscovery,
ConceptSet,
)
from reason_first_program.embeddings import (
ConceptBottleneckAE,
GCAVEmbedding,
ConceptEmbeddingSpace,
)
from reason_first_program.steering import ConceptQuery, SteeringEngine, QueryLanguage
# Configure logging once at import time; each record is printed as
# "<logger name> - <message>" at INFO level and above.
logging.basicConfig(level=logging.INFO, format="%(name)s - %(message)s")
# Module-level logger shared by every stage of the experiment below.
logger = logging.getLogger(__name__)
# ============================================================
# Stub Definitions
# ============================================================
def create_two_sum_stub() -> Stub:
    """Build the 'two sum' stub — a problem with a rich space of valid implementations.

    Returns:
        A ``Stub`` named ``two_sum`` carrying its constraints plus four test
        cases. Every test input has exactly one index pair that sums to the
        target, as the inline spec requires, so every correct implementation
        must produce the same expected output.
    """
    constraints = StubConstraints(
        decorator_spec="Find two indices in nums that sum to target",
        inline_specs=["return pair of indices", "each input has exactly one solution"],
        type_hints={"nums": "list[int]", "target": "int"},
        return_type="list[int]",
    )
    return Stub(
        name="two_sum",
        source=(
            "def two_sum(nums: list[int], target: int) -> list[int]:\n"
            " #> return pair of indices; each input has exactly one solution\n"
            " ..."
        ),
        signature="(nums: list[int], target: int) -> list[int]",
        constraints=constraints,
        test_inputs=[
            {"nums": [2, 7, 11, 15], "target": 9},
            {"nums": [3, 2, 4], "target": 6},
            {"nums": [3, 3], "target": 6},
            # 5 + 7 is the only pair summing to 12, so [1, 3] is the unique
            # answer. (A target of 8 would admit both [0, 3] and [1, 2], and
            # would not match the expected [1, 3] at all.)
            {"nums": [1, 5, 3, 7, 2], "target": 12},
        ],
        test_outputs=[[0, 1], [1, 2], [0, 1], [1, 3]],
    )
def create_flatten_stub() -> Stub:
    """Build the ``flatten`` stub — nested-list flattening admits many algorithms.

    Returns:
        A ``Stub`` named ``flatten`` with its constraints and five test cases
        (mixed nesting, deep nesting, already flat, empty, uniformly nested).
    """
    # Each case pairs the keyword arguments with the expected flat list so the
    # two parallel lists handed to Stub cannot drift out of step.
    cases = [
        ({"nested": [1, [2, 3], [4, [5, 6]]]}, [1, 2, 3, 4, 5, 6]),
        ({"nested": [[1, 2], [3, [4, [5]]]]}, [1, 2, 3, 4, 5]),
        ({"nested": [1, 2, 3]}, [1, 2, 3]),
        ({"nested": []}, []),
        ({"nested": [[[1]], [[2]], [[3]]]}, [1, 2, 3]),
    ]
    stub_source = "\n".join(
        [
            "def flatten(nested: list) -> list:",
            " #> handle arbitrary nesting depth; preserve element order",
            " ...",
        ]
    )
    flatten_constraints = StubConstraints(
        decorator_spec="Flatten a nested list of arbitrary depth into a flat list",
        inline_specs=["handle arbitrary nesting depth", "preserve element order"],
        type_hints={"nested": "list"},
        return_type="list",
    )
    return Stub(
        name="flatten",
        source=stub_source,
        signature="(nested: list) -> list",
        constraints=flatten_constraints,
        test_inputs=[kwargs for kwargs, _ in cases],
        test_outputs=[expected for _, expected in cases],
    )
def create_group_anagrams_stub() -> Stub:
    """Build the ``group_anagrams`` stub — several data-structure choices apply.

    Returns:
        A ``Stub`` named ``group_anagrams`` with three test inputs. No
        expected outputs are supplied for this stub.
    """
    stub_source = "\n".join(
        [
            "def group_anagrams(strs: list[str]) -> list[list[str]]:",
            " #> anagrams have same letters in different order",
            " ...",
        ]
    )
    # One mixed word list plus the two degenerate inputs (empty string, single letter).
    word_lists = [
        ["eat", "tea", "tan", "ate", "nat", "bat"],
        [""],
        ["a"],
    ]
    anagram_constraints = StubConstraints(
        decorator_spec="Group strings that are anagrams of each other",
        inline_specs=["anagrams have same letters in different order"],
        type_hints={"strs": "list[str]"},
        return_type="list[list[str]]",
    )
    return Stub(
        name="group_anagrams",
        source=stub_source,
        signature="(strs: list[str]) -> list[list[str]]",
        constraints=anagram_constraints,
        test_inputs=[{"strs": words} for words in word_lists],
    )
# ============================================================
# Hand-Crafted Program Populations (for testing without LLM)
# ============================================================
# Hand-crafted bodies for ``two_sum(nums, target)``.
# Each "source" is a function BODY only: it is indented by four spaces so it
# can be placed directly under the stub's ``def`` line, and nested blocks keep
# their own relative indentation. "label" names the solution strategy.
TWO_SUM_IMPLEMENTATIONS = [
    # 1. Brute force - nested loops
    {
        "source": """    for i in range(len(nums)):
        for j in range(i + 1, len(nums)):
            if nums[i] + nums[j] == target:
                return [i, j]
    return []""",
        "label": "brute_force",
    },
    # 2. Hash map - single pass
    {
        "source": """    seen = {}
    for i, num in enumerate(nums):
        complement = target - num
        if complement in seen:
            return [seen[complement], i]
        seen[num] = i
    return []""",
        "label": "hash_single_pass",
    },
    # 3. Hash map - two pass
    {
        "source": """    lookup = {}
    for i, num in enumerate(nums):
        lookup[num] = i
    for i, num in enumerate(nums):
        complement = target - num
        if complement in lookup and lookup[complement] != i:
            return [i, lookup[complement]]
    return []""",
        "label": "hash_two_pass",
    },
    # 4. Sorting + two pointers
    {
        "source": """    indexed = sorted(enumerate(nums), key=lambda x: x[1])
    left, right = 0, len(indexed) - 1
    while left < right:
        current_sum = indexed[left][1] + indexed[right][1]
        if current_sum == target:
            return sorted([indexed[left][0], indexed[right][0]])
        elif current_sum < target:
            left += 1
        else:
            right -= 1
    return []""",
        "label": "sort_two_pointer",
    },
    # 5. List comprehension + index
    {
        "source": """    pairs = [(i, j) for i in range(len(nums)) for j in range(i+1, len(nums)) if nums[i] + nums[j] == target]
    return list(pairs[0]) if pairs else []""",
        "label": "comprehension",
    },
    # 6. itertools.combinations over index pairs
    {
        "source": """    from itertools import combinations
    for i, j in combinations(range(len(nums)), 2):
        if nums[i] + nums[j] == target:
            return [i, j]
    return []""",
        "label": "itertools",
    },
    # 7. Recursive approach
    {
        "source": """    def helper(start, nums_list, target_val):
        if start >= len(nums_list) - 1:
            return []
        for j in range(start + 1, len(nums_list)):
            if nums_list[start] + nums_list[j] == target_val:
                return [start, j]
        return helper(start + 1, nums_list, target_val)
    return helper(0, nums, target)""",
        "label": "recursive",
    },
    # 8. Set-based lookup
    {
        "source": """    num_set = set(nums)
    for i, num in enumerate(nums):
        complement = target - num
        if complement in num_set and nums.index(complement) != i:
            j = nums.index(complement)
            return [min(i, j), max(i, j)]
    return []""",
        "label": "set_based",
    },
    # 9. Dict with early return + validation
    {
        "source": """    if not nums or len(nums) < 2:
        return []
    seen = {}
    for idx, val in enumerate(nums):
        diff = target - val
        if diff in seen:
            return [seen[diff], idx]
        seen[val] = idx
    return []""",
        "label": "hash_with_guard",
    },
    # 10. While loop approach
    {
        "source": """    i = 0
    while i < len(nums):
        j = i + 1
        while j < len(nums):
            if nums[i] + nums[j] == target:
                return [i, j]
            j += 1
        i += 1
    return []""",
        "label": "while_loop",
    },
]
# Hand-crafted bodies for ``flatten(nested)``.
# Each "source" is a function BODY only, indented by four spaces so it can be
# placed directly under the stub's ``def`` line. The recursive variants call
# ``flatten`` by name, so the assembled function must be bound to that name
# when it is executed.
FLATTEN_IMPLEMENTATIONS = [
    # 1. Recursive
    {
        "source": """    result = []
    for item in nested:
        if isinstance(item, list):
            result.extend(flatten(item))
        else:
            result.append(item)
    return result""",
        "label": "recursive_extend",
    },
    # 2. Iterative worklist consumed from the front
    {
        "source": """    stack = list(nested)
    result = []
    while stack:
        item = stack.pop(0)
        if isinstance(item, list):
            stack = item + stack
        else:
            result.append(item)
    return result""",
        "label": "iterative_stack",
    },
    # 3. Generator-based
    {
        "source": """    def _flatten_gen(lst):
        for item in lst:
            if isinstance(item, list):
                yield from _flatten_gen(item)
            else:
                yield item
    return list(_flatten_gen(nested))""",
        "label": "generator",
    },
    # 4. Recursive one-liner
    {
        "source": """    return [x for item in nested for x in (flatten(item) if isinstance(item, list) else [item])]""",
        "label": "recursive_comprehension",
    },
    # 5. Iterative with explicit stack (LIFO)
    {
        "source": """    stack = [nested]
    result = []
    while stack:
        current = stack.pop()
        if isinstance(current, list):
            for item in reversed(current):
                stack.append(item)
        else:
            result.append(current)
    return result""",
        "label": "iterative_lifo",
    },
    # 6. Reduce-based
    {
        "source": """    from functools import reduce
    def reducer(acc, item):
        if isinstance(item, list):
            return reduce(reducer, item, acc)
        return acc + [item]
    return reduce(reducer, nested, [])""",
        "label": "reduce",
    },
    # 7. While loop mutation
    {
        "source": """    result = list(nested)
    i = 0
    while i < len(result):
        if isinstance(result[i], list):
            items = result.pop(i)
            for j, item in enumerate(items):
                result.insert(i + j, item)
        else:
            i += 1
    return result""",
        "label": "while_mutation",
    },
    # 8. Nested helper with accumulator
    {
        "source": """    def helper(lst, acc):
        for item in lst:
            if isinstance(item, list):
                helper(item, acc)
            else:
                acc.append(item)
        return acc
    return helper(nested, [])""",
        "label": "accumulator",
    },
]
def build_program(source: str, label: str, stub: Stub) -> Program:
    """Assemble a ``Program`` from a function body and run it on the stub's tests.

    Args:
        source: Body of the function (no ``def`` line).
        label: Strategy name, stored under ``metadata["label"]``.
        stub: Stub whose first source line is used as the ``def`` header and
            whose ``test_inputs`` (if any) are executed.

    Returns:
        The constructed ``Program``; when the stub has test inputs this is
        whatever ``execute_program`` returns for it.
    """
    header = stub.source.split("\n")[0]
    body_text = source.rstrip()
    body_lines = body_text.split("\n")

    # Locate the first line that carries any content; it decides whether the
    # body is treated as already indented.
    first_nonblank = ""
    for candidate in body_lines:
        if candidate.strip():
            first_nonblank = candidate
            break

    needs_indent = bool(first_nonblank) and not first_nonblank.startswith(" ")
    if needs_indent:
        # Prefix every non-blank line with a space; blank lines stay untouched.
        body_text = "\n".join(
            (" " + row) if row.strip() else row for row in body_lines
        )

    built = Program(
        source=source.strip(),
        full_source=header + "\n" + body_text,
        stub_id=stub.stub_id,
        model_id="handcrafted",
        metadata={"label": label},
    )
    if not stub.test_inputs:
        return built
    return execute_program(built, stub, stub.test_inputs)
def _log_stage(title: str, *, first: bool = False) -> None:
    """Log a three-line stage banner; all but the first start with a blank line."""
    rule = "=" * 60
    logger.info(rule if first else "\n" + rule)
    logger.info(title)
    logger.info(rule)


def _build_space(stub: Stub, implementations: list[dict[str, str]]) -> ProgramSpace:
    """Create a ProgramSpace for *stub* holding every hand-crafted implementation."""
    space = ProgramSpace(stub)
    for impl in implementations:
        space.add(build_program(impl["source"], impl["label"], stub))
    return space


def _log_discovered_concepts(name: str, space: ProgramSpace, concepts: ConceptSet) -> None:
    """Log each discovered concept with how many valid programs it covers."""
    logger.info(f"\n{name} concepts ({len(concepts.concepts)}):")
    n_valid = len(space.valid_programs)
    for c in concepts.concepts:
        logger.info(
            f" {c.name}: {c.description} "
            f"[{len(c.programs)}/{n_valid} programs]"
        )


def _train_gcav(name: str, programs: list[Program], concepts: ConceptSet) -> dict[str, dict[str, Any]]:
    """Train GCAV concept vectors on zero-padded behavioral vectors and return plain metrics."""
    logger.info(f"\nTraining GCAV for {name}...")
    # Keep the vectors as a list of separate arrays: stacking vectors of
    # unequal length with np.array() raises ValueError on current NumPy, which
    # would make the padding below unreachable in exactly the case it is for.
    vectors = [np.asarray(p.behavioral_vector()) for p in programs]
    max_len = max(len(vec) for vec in vectors)
    padded_features = np.zeros((len(programs), max_len))
    for row, vec in enumerate(vectors):
        padded_features[row, : len(vec)] = vec

    gcav = GCAVEmbedding()
    gcav_results = gcav.train_all(concepts, padded_features, programs)
    logger.info(f" Trained {len(gcav_results)} concept vectors:")
    for cname, metrics in gcav_results.items():
        logger.info(f" {cname}: accuracy={metrics['accuracy']:.2f}")

    # NumPy scalars (e.g. np.float32) are not instances of int/float, so they
    # are listed explicitly; anything non-numeric is stored as its str().
    numeric = (int, float, np.integer, np.floating)
    return {
        cname: {
            key: float(value) if isinstance(value, numeric) else str(value)
            for key, value in metrics.items()
        }
        for cname, metrics in gcav_results.items()
    }


def _embed_and_verify(
    name: str, space: ProgramSpace, concepts: ConceptSet, results: dict[str, Any]
) -> None:
    """Stage 3 for one problem: project, verify alignment, score, optionally train GCAV."""
    if not concepts.concepts:
        logger.warning(f"No concepts found for {name}, skipping embedding")
        return
    embedding = ConceptEmbeddingSpace(concepts)
    programs = space.valid_programs
    if not programs:
        logger.warning(f"No valid programs for {name}, skipping")
        return

    # Project into concept space
    projection = embedding.project(programs)
    logger.info(f"\n{name} concept projection shape: {projection.shape}")

    # Verify alignment
    alignment = embedding.verify_alignment(programs)
    logger.info(f"{name} alignment metrics:")
    for k, v in alignment.items():
        logger.info(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}")
    results[f"{name}_alignment"] = alignment

    # Show which concepts score above 0.5 for each program
    logger.info(f"\n{name} concept scores per program:")
    concept_names = concepts.names
    for program in programs:
        scores = concepts.score_program(program)
        active = [n for n, s in scores.items() if s > 0.5]
        label = program.metadata.get("label", "?")
        logger.info(f" [{label}]: {', '.join(active) if active else 'none'}")

    # Score matrix for GCAV training
    score_matrix = concepts.score_matrix(programs)
    results[f"{name}_score_matrix"] = {
        "shape": list(score_matrix.shape),
        "mean_scores": score_matrix.mean(axis=0).tolist(),
        "concept_names": concept_names,
    }

    # Train GCAV only with at least 4 programs and 2 concept columns
    if len(programs) >= 4 and score_matrix.shape[1] >= 2:
        results[f"{name}_gcav"] = _train_gcav(name, programs, concepts)


def _run_queries(
    name: str, space: ProgramSpace, concepts: ConceptSet, results: dict[str, Any]
) -> None:
    """Stage 4 for one problem: run the example concept queries and log concept boundaries."""
    if not concepts.concepts:
        return
    engine = SteeringEngine(concepts)
    ql = engine.query_language
    logger.info(f"\n{name} — Query Examples:")

    # (query string, human-readable description)
    queries = [
        ("uses_recursion=1.0", "Find recursive implementations"),
        ("uses_iteration=1.0, uses_mutation=-1.0", "Iterative but immutable"),
        ("uses_dict=1.0, uses_early_return=1.0", "Hash-based with guards"),
        ("uses_list_comprehension=1.0", "Comprehension-based"),
    ]
    query_results = []
    for query_str, description in queries:
        query = ql.parse(query_str)
        selected = engine.select(space, query, top_k=3)
        if not selected:
            # Queries that select nothing are neither logged nor recorded.
            continue
        logger.info(f"\n Query: {description}")
        logger.info(f" Parsed: {query}")
        for program, score in selected:
            label = program.metadata.get("label", "?")
            logger.info(f" -> [{label}] relevance={score:.3f}")
        query_results.append({
            "query": query_str,
            "description": description,
            "results": [
                {"label": p.metadata.get("label", "?"), "relevance": float(s)}
                for p, s in selected
            ],
        })
    results[f"{name}_queries"] = query_results

    # Concept boundary exploration for the first three concepts
    logger.info(f"\n{name} — Concept Boundaries:")
    for concept in concepts.concepts[:3]:
        boundary = engine.concept_boundary_programs(concept.name, space)
        inside_labels = [p.metadata.get("label", "?") for p in boundary["inside"]]
        outside_labels = [p.metadata.get("label", "?") for p in boundary["outside"]]
        logger.info(
            f" {concept.name}: inside={inside_labels}, outside={outside_labels}"
        )


def _summarize_lattice(
    name: str, space: ProgramSpace, concepts: ConceptSet, results: dict[str, Any]
) -> None:
    """Stage 5 for one problem: log the first lattice nodes and record a sample of them."""
    if not concepts.concepts:
        return
    lattice = concepts.concept_lattice()
    logger.info(f"\n{name} concept lattice ({len(lattice)} nodes):")
    for extent, intent in lattice[:10]:  # Show first 10
        n_programs = len(extent)
        concept_names = sorted(intent)
        # Resolve program ids in this extent to their labels; ids the space
        # cannot resolve are skipped.
        program_labels = []
        for pid in extent:
            p = space.get(pid)
            if p:
                program_labels.append(p.metadata.get("label", "?"))
        logger.info(
            f" Concepts: {concept_names} -> "
            f"Programs ({n_programs}): {program_labels[:5]}"
        )
    results[f"{name}_lattice"] = {
        "n_nodes": len(lattice),
        "sample_nodes": [
            {
                "intent": sorted(intent),
                "extent_size": len(extent),
            }
            for extent, intent in lattice[:20]
        ],
    }


def _log_summary(results: dict[str, Any]) -> None:
    """Log the per-problem diversity and alignment headline numbers."""
    for name in ["two_sum", "flatten"]:
        dr = results["program_spaces"][name]
        logger.info(f"\n{name}:")
        logger.info(f" Total programs: {dr['total_programs']}")
        logger.info(f" Valid programs: {dr['valid_programs']}")
        logger.info(f" Functional clusters: {dr['functionally_unique_clusters']}")
        logger.info(f" Entropy diversity: {dr['entropy_diversity']:.2f}")
        if f"{name}_alignment" in results:
            al = results[f"{name}_alignment"]
            logger.info(f" Concept dimensions: {al.get('n_concepts', 0)}")
            logger.info(f" Effective dimensionality: {al.get('effective_dimensionality', 0):.2f}")
            logger.info(f" Avg concept correlation: {al.get('avg_concept_correlation', 0):.3f}")


def run_experiment() -> dict[str, Any]:
    """Run the full concept probing experiment on the hand-crafted populations.

    Stages, in order: build the two program spaces, discover concepts, build
    and verify concept-guided embeddings (training GCAV when there is enough
    data), run example steering queries, and summarize the concept lattice.
    Each stage handles ``two_sum`` first and ``flatten`` second.

    Returns:
        A results dict keyed by ``"program_spaces"``, ``"concepts"`` and
        per-problem entries such as ``"two_sum_alignment"``. It is intended
        for JSON serialization; values may still contain NumPy scalars that
        need a ``default`` hook when dumped.
    """
    results: dict[str, Any] = {}

    # 1. Build Program Spaces
    _log_stage("STAGE 1: Building Program Spaces", first=True)
    ts_space = _build_space(create_two_sum_stub(), TWO_SUM_IMPLEMENTATIONS)
    fl_space = _build_space(create_flatten_stub(), FLATTEN_IMPLEMENTATIONS)
    logger.info(f"two_sum space: {len(ts_space)} programs, {len(ts_space.valid_programs)} valid")
    logger.info(f"flatten space: {len(fl_space)} programs, {len(fl_space.valid_programs)} valid")
    results["program_spaces"] = {
        "two_sum": ts_space.diversity_report(),
        "flatten": fl_space.diversity_report(),
    }

    # 2. Concept Discovery
    _log_stage("STAGE 2: Concept Discovery")
    discovery = UnifiedConceptDiscovery()
    ts_concepts = discovery.discover(ts_space)
    fl_concepts = discovery.discover(fl_space)
    _log_discovered_concepts("two_sum", ts_space, ts_concepts)
    _log_discovered_concepts("flatten", fl_space, fl_concepts)
    results["concepts"] = {
        "two_sum": ts_concepts.to_dict(),
        "flatten": fl_concepts.to_dict(),
    }

    # The remaining stages all walk the same (name, space, concepts) triples.
    problems = [
        ("two_sum", ts_space, ts_concepts),
        ("flatten", fl_space, fl_concepts),
    ]

    # 3. Concept-Guided Embedding
    _log_stage("STAGE 3: Concept-Guided Embeddings")
    for name, space, concepts in problems:
        _embed_and_verify(name, space, concepts, results)

    # 4. Steering / Query Language
    _log_stage("STAGE 4: Query Language & Steering")
    for name, space, concepts in problems:
        _run_queries(name, space, concepts, results)

    # 5. Concept Lattice (FCA)
    _log_stage("STAGE 5: Concept Lattice (Formal Concept Analysis)")
    for name, space, concepts in problems:
        _summarize_lattice(name, space, concepts, results)

    # Summary
    _log_stage("EXPERIMENT SUMMARY")
    _log_summary(results)
    return results
def make_serializable(obj: Any) -> Any:
    """``json`` ``default`` hook: convert NumPy values and sets to plain Python.

    Args:
        obj: A value the JSON encoder could not serialize on its own.

    Returns:
        ``bool``/``int``/``float`` for NumPy scalars, a nested list for an
        ``ndarray``, and a list for a ``set`` or ``frozenset``.

    Raises:
        TypeError: If *obj* is none of the supported types.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (set, frozenset)):
        return list(obj)
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")


if __name__ == "__main__":
    from pathlib import Path

    results = run_experiment()

    # Save results. The destination defaults to the original fixed location
    # and can be overridden by passing a path as the first CLI argument.
    output_path = sys.argv[1] if len(sys.argv) > 1 else "/app/experiment_results.json"

    # Serialize before touching the file so a serialization error cannot
    # leave a truncated results file behind.
    payload = json.dumps(results, indent=2, default=make_serializable)
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(payload, encoding="utf-8")
    logger.info(f"\nResults saved to {output_path}")