| """ |
| First Experiment: Concept Probing on Program Completions |
| |
| This experiment validates the central hypothesis: |
| "Concept structure exists in the program space and can be extracted, |
| verified, and used for steering." |
| |
| Steps: |
| 1. Define stub functions with known diverse solution strategies |
| 2. Generate multiple valid implementations (simulated or via LLM) |
| 3. Run concept discovery (behavioral + structural) |
| 4. Build concept-guided embedding space |
| 5. Verify alignment: do concept dimensions separate meaningful program families? |
| 6. Test steering: can we select programs by concept queries? |
| |
| This script runs WITHOUT GPU and WITHOUT API keys by using hand-crafted |
| program populations. For the full LLM-powered pipeline, see experiment_llm.py. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import sys |
| from typing import Any |
|
|
| import numpy as np |
|
|
| from reason_first_program.stub import reason_first, Stub, StubConstraints |
| from reason_first_program.program_space import Program, ProgramSpace, execute_program |
| from reason_first_program.concepts import ( |
| BehavioralConceptDiscovery, |
| AbstractionConceptDiscovery, |
| UnifiedConceptDiscovery, |
| ConceptSet, |
| ) |
| from reason_first_program.embeddings import ( |
| ConceptBottleneckAE, |
| GCAVEmbedding, |
| ConceptEmbeddingSpace, |
| ) |
| from reason_first_program.steering import ConceptQuery, SteeringEngine, QueryLanguage |
|
|
# Configure logging once for the script: INFO and above, each record rendered
# as "<logger name> - <message>".
logging.basicConfig(format="%(name)s - %(message)s", level=logging.INFO)
# Module-level logger used by every stage of the experiment below.
logger = logging.getLogger(__name__)
|
|
|
|
| |
| |
| |
|
|
def create_two_sum_stub() -> Stub:
    """Build the stub for the classic 'two sum' problem.

    The stub carries the textual specification (decorator spec, inline specs,
    type hints), the unimplemented function source, and four test cases.
    ``test_outputs`` lists the expected index pair for each entry of
    ``test_inputs``, in the same order.

    Every test input has exactly one pair of indices that sums to the target,
    as promised by the inline spec "each input has exactly one solution".
    """
    constraints = StubConstraints(
        decorator_spec="Find two indices in nums that sum to target",
        inline_specs=["return pair of indices", "each input has exactly one solution"],
        type_hints={"nums": "list[int]", "target": "int"},
        return_type="list[int]",
    )
    return Stub(
        name="two_sum",
        source=(
            "def two_sum(nums: list[int], target: int) -> list[int]:\n"
            " #> return pair of indices; each input has exactly one solution\n"
            " ..."
        ),
        signature="(nums: list[int], target: int) -> list[int]",
        constraints=constraints,
        test_inputs=[
            {"nums": [2, 7, 11, 15], "target": 9},
            {"nums": [3, 2, 4], "target": 6},
            {"nums": [3, 3], "target": 6},
            # Target 12 is reached only by 5 + 7 (indices 1 and 3). The
            # previous target of 8 had two solutions (1 + 7 and 5 + 3), and
            # neither of them was the expected [1, 3].
            {"nums": [1, 5, 3, 7, 2], "target": 12},
        ],
        test_outputs=[[0, 1], [1, 2], [0, 1], [1, 3]],
    )
|
|
|
|
def create_flatten_stub() -> Stub:
    """Build the stub for flattening an arbitrarily nested list.

    The stub bundles the specification, the unimplemented function source and
    five test cases; ``test_outputs`` holds the expected flat list for each
    entry of ``test_inputs``, in the same order.
    """
    # Each case pairs a nested input with the flat list it must produce, so
    # inputs and expectations cannot drift out of step.
    cases = [
        ([1, [2, 3], [4, [5, 6]]], [1, 2, 3, 4, 5, 6]),
        ([[1, 2], [3, [4, [5]]]], [1, 2, 3, 4, 5]),
        ([1, 2, 3], [1, 2, 3]),
        ([], []),
        ([[[1]], [[2]], [[3]]], [1, 2, 3]),
    ]
    stub_source = "\n".join(
        [
            "def flatten(nested: list) -> list:",
            " #> handle arbitrary nesting depth; preserve element order",
            " ...",
        ]
    )
    return Stub(
        name="flatten",
        source=stub_source,
        signature="(nested: list) -> list",
        constraints=StubConstraints(
            decorator_spec="Flatten a nested list of arbitrary depth into a flat list",
            inline_specs=["handle arbitrary nesting depth", "preserve element order"],
            type_hints={"nested": "list"},
            return_type="list",
        ),
        test_inputs=[{"nested": nested} for nested, _ in cases],
        test_outputs=[expected for _, expected in cases],
    )
|
|
|
|
def create_group_anagrams_stub() -> Stub:
    """Build the stub for grouping strings that are anagrams of each other.

    Only test inputs are supplied; no expected outputs are passed to the stub.
    """
    word_lists = (
        ["eat", "tea", "tan", "ate", "nat", "bat"],
        [""],
        ["a"],
    )
    stub_source = "\n".join(
        [
            "def group_anagrams(strs: list[str]) -> list[list[str]]:",
            " #> anagrams have same letters in different order",
            " ...",
        ]
    )
    return Stub(
        name="group_anagrams",
        source=stub_source,
        signature="(strs: list[str]) -> list[list[str]]",
        constraints=StubConstraints(
            decorator_spec="Group strings that are anagrams of each other",
            inline_specs=["anagrams have same letters in different order"],
            type_hints={"strs": "list[str]"},
            return_type="list[list[str]]",
        ),
        test_inputs=[{"strs": words} for words in word_lists],
    )
|
|
|
|
| |
| |
| |
|
|
# Hand-crafted population of two_sum function bodies.
#
# Each entry holds the *body* of the function (no ``def`` line) under
# "source" and a short strategy name under "label". Bodies are written with a
# four-space base indent and consistent nested indentation so that they form a
# valid function once placed under the stub's ``def`` line.
TWO_SUM_IMPLEMENTATIONS = [
    # Nested for-loops over every index pair.
    {
        "source": """    for i in range(len(nums)):
        for j in range(i + 1, len(nums)):
            if nums[i] + nums[j] == target:
                return [i, j]
    return []""",
        "label": "brute_force",
    },
    # One pass with a value -> index dictionary.
    {
        "source": """    seen = {}
    for i, num in enumerate(nums):
        complement = target - num
        if complement in seen:
            return [seen[complement], i]
        seen[num] = i
    return []""",
        "label": "hash_single_pass",
    },
    # Build the full lookup first, then scan for complements.
    {
        "source": """    lookup = {}
    for i, num in enumerate(nums):
        lookup[num] = i
    for i, num in enumerate(nums):
        complement = target - num
        if complement in lookup and lookup[complement] != i:
            return [i, lookup[complement]]
    return []""",
        "label": "hash_two_pass",
    },
    # Sort (index, value) pairs and close in from both ends.
    {
        "source": """    indexed = sorted(enumerate(nums), key=lambda x: x[1])
    left, right = 0, len(indexed) - 1
    while left < right:
        current_sum = indexed[left][1] + indexed[right][1]
        if current_sum == target:
            return sorted([indexed[left][0], indexed[right][0]])
        elif current_sum < target:
            left += 1
        else:
            right -= 1
    return []""",
        "label": "sort_two_pointer",
    },
    # All matching pairs via a list comprehension; take the first.
    {
        "source": """    pairs = [(i, j) for i in range(len(nums)) for j in range(i+1, len(nums)) if nums[i] + nums[j] == target]
    return list(pairs[0]) if pairs else []""",
        "label": "comprehension",
    },
    # Index pairs generated by itertools.combinations.
    {
        "source": """    from itertools import combinations
    for i, j in combinations(range(len(nums)), 2):
        if nums[i] + nums[j] == target:
            return [i, j]
    return []""",
        "label": "itertools",
    },
    # Recursion over the start index with an inner scan.
    {
        "source": """    def helper(start, nums_list, target_val):
        if start >= len(nums_list) - 1:
            return []
        for j in range(start + 1, len(nums_list)):
            if nums_list[start] + nums_list[j] == target_val:
                return [start, j]
        return helper(start + 1, nums_list, target_val)
    return helper(0, nums, target)""",
        "label": "recursive",
    },
    # Set membership test, then list.index to recover the position.
    {
        "source": """    num_set = set(nums)
    for i, num in enumerate(nums):
        complement = target - num
        if complement in num_set and nums.index(complement) != i:
            j = nums.index(complement)
            return [min(i, j), max(i, j)]
    return []""",
        "label": "set_based",
    },
    # Single-pass dictionary with an explicit guard for short input.
    {
        "source": """    if not nums or len(nums) < 2:
        return []
    seen = {}
    for idx, val in enumerate(nums):
        diff = target - val
        if diff in seen:
            return [seen[diff], idx]
        seen[val] = idx
    return []""",
        "label": "hash_with_guard",
    },
    # Same pairwise scan as brute force, written with while-loops.
    {
        "source": """    i = 0
    while i < len(nums):
        j = i + 1
        while j < len(nums):
            if nums[i] + nums[j] == target:
                return [i, j]
            j += 1
        i += 1
    return []""",
        "label": "while_loop",
    },
]
|
|
# Hand-crafted population of flatten function bodies.
#
# Each entry holds the *body* of the function (no ``def`` line) under
# "source" and a short strategy name under "label". Bodies use a four-space
# base indent and consistent nested indentation so that they form a valid
# function once placed under the stub's ``def`` line. Two of them call
# ``flatten`` recursively, so the enclosing function must be bound to that
# name when they run.
FLATTEN_IMPLEMENTATIONS = [
    # Direct recursion, extending a result list.
    {
        "source": """    result = []
    for item in nested:
        if isinstance(item, list):
            result.extend(flatten(item))
        else:
            result.append(item)
    return result""",
        "label": "recursive_extend",
    },
    # Iterative: take from the front and splice nested lists back in front.
    {
        "source": """    stack = list(nested)
    result = []
    while stack:
        item = stack.pop(0)
        if isinstance(item, list):
            stack = item + stack
        else:
            result.append(item)
    return result""",
        "label": "iterative_stack",
    },
    # Recursive generator with ``yield from``.
    {
        "source": """    def _flatten_gen(lst):
        for item in lst:
            if isinstance(item, list):
                yield from _flatten_gen(item)
            else:
                yield item
    return list(_flatten_gen(nested))""",
        "label": "generator",
    },
    # Recursion expressed as a single nested comprehension.
    {
        "source": """    return [x for item in nested for x in (flatten(item) if isinstance(item, list) else [item])]""",
        "label": "recursive_comprehension",
    },
    # Iterative LIFO stack; children are pushed reversed to keep order.
    {
        "source": """    stack = [nested]
    result = []
    while stack:
        current = stack.pop()
        if isinstance(current, list):
            for item in reversed(current):
                stack.append(item)
        else:
            result.append(current)
    return result""",
        "label": "iterative_lifo",
    },
    # functools.reduce with a self-referential reducer.
    {
        "source": """    from functools import reduce
    def reducer(acc, item):
        if isinstance(item, list):
            return reduce(reducer, item, acc)
        return acc + [item]
    return reduce(reducer, nested, [])""",
        "label": "reduce",
    },
    # In-place expansion of a copy while walking it by index.
    {
        "source": """    result = list(nested)
    i = 0
    while i < len(result):
        if isinstance(result[i], list):
            items = result.pop(i)
            for j, item in enumerate(items):
                result.insert(i + j, item)
        else:
            i += 1
    return result""",
        "label": "while_mutation",
    },
    # Recursive helper that appends into a shared accumulator.
    {
        "source": """    def helper(lst, acc):
        for item in lst:
            if isinstance(item, list):
                helper(item, acc)
            else:
                acc.append(item)
        return acc
    return helper(nested, [])""",
        "label": "accumulator",
    },
]
|
|
|
|
def build_program(source: str, label: str, stub: Stub) -> Program:
    """Wrap a function body in the stub's ``def`` line and execute it.

    Args:
        source: Body of the function, without the ``def`` line.
        label: Strategy name stored in the program's metadata under "label".
        stub: Stub supplying the ``def`` line (first line of ``stub.source``),
            the stub id, and the test inputs.

    Returns:
        The constructed ``Program``. When the stub has test inputs, this is
        whatever ``execute_program`` returns for it.
    """
    # The first line of the stub source is used as the function header.
    header = stub.source.partition("\n")[0]

    body_lines = source.rstrip().split("\n")

    # Look at the first non-blank line to decide whether the body already
    # carries indentation.
    first_nonblank = ""
    for candidate in body_lines:
        if candidate.strip():
            first_nonblank = candidate
            break

    if first_nonblank and not first_nonblank.startswith(" "):
        # Unindented body: shift every non-blank line right; blank lines are
        # left exactly as they are.
        body_lines = [(" " + text) if text.strip() else text for text in body_lines]

    full_source = header + "\n" + "\n".join(body_lines)

    # NOTE(review): ``source.strip()`` removes the leading indent of the first
    # body line only — confirm downstream consumers of ``Program.source``
    # expect that.
    built = Program(
        source=source.strip(),
        full_source=full_source,
        stub_id=stub.stub_id,
        model_id="handcrafted",
        metadata={"label": label},
    )

    if not stub.test_inputs:
        return built
    return execute_program(built, stub, stub.test_inputs)
|
|
|
|
def _pad_behavioral_vectors(vectors: list[Any]) -> np.ndarray:
    """Stack per-program vectors into one zero-padded 2-D float array.

    Each vector is flattened to 1-D; shorter vectors are right-padded with
    zeros to the length of the longest one. Padding is done from the raw list
    so that vectors of different lengths never go through ``np.array``
    together (which raises for ragged input on recent NumPy releases).
    """
    rows = [np.asarray(vector, dtype=float).ravel() for vector in vectors]
    width = max((row.shape[0] for row in rows), default=0)
    padded = np.zeros((len(rows), width))
    for i, row in enumerate(rows):
        padded[i, : row.shape[0]] = row
    return padded


def run_experiment() -> dict[str, Any]:
    """
    Run the full concept probing experiment.

    Stages, each announced in the log:
      1. Build a program space for two_sum and flatten from the hand-crafted
         implementation lists.
      2. Run unified concept discovery on each space.
      3. Project programs into the concept embedding space, collect alignment
         metrics and score matrices, and train GCAV concept vectors when there
         are at least 4 valid programs and at least 2 concepts.
      4. Parse a fixed set of concept queries and select the top 3 programs
         for each; log the boundary programs of the first three concepts.
      5. Compute the concept lattice and record up to 20 sample nodes.

    Returns a comprehensive results dict suitable for JSON serialization.
    """
    results: dict[str, Any] = {}
    banner = "=" * 60

    # ------------------------------------------------------------------
    # Stage 1: program spaces
    # ------------------------------------------------------------------
    logger.info(banner)
    logger.info("STAGE 1: Building Program Spaces")
    logger.info(banner)

    two_sum_stub = create_two_sum_stub()
    flatten_stub = create_flatten_stub()

    ts_space = ProgramSpace(two_sum_stub)
    for impl in TWO_SUM_IMPLEMENTATIONS:
        ts_space.add(build_program(impl["source"], impl["label"], two_sum_stub))

    fl_space = ProgramSpace(flatten_stub)
    for impl in FLATTEN_IMPLEMENTATIONS:
        fl_space.add(build_program(impl["source"], impl["label"], flatten_stub))

    logger.info(f"two_sum space: {len(ts_space)} programs, {len(ts_space.valid_programs)} valid")
    logger.info(f"flatten space: {len(fl_space)} programs, {len(fl_space.valid_programs)} valid")

    results["program_spaces"] = {
        "two_sum": ts_space.diversity_report(),
        "flatten": fl_space.diversity_report(),
    }

    # ------------------------------------------------------------------
    # Stage 2: concept discovery
    # ------------------------------------------------------------------
    logger.info("\n" + banner)
    logger.info("STAGE 2: Concept Discovery")
    logger.info(banner)

    discovery = UnifiedConceptDiscovery()

    ts_concepts = discovery.discover(ts_space)
    fl_concepts = discovery.discover(fl_space)

    # Every later stage walks the same (name, space, concepts) triples.
    targets = [
        ("two_sum", ts_space, ts_concepts),
        ("flatten", fl_space, fl_concepts),
    ]

    for name, space, concepts in targets:
        logger.info(f"\n{name} concepts ({len(concepts.concepts)}):")
        for c in concepts.concepts:
            logger.info(
                f" {c.name}: {c.description} "
                f"[{len(c.programs)}/{len(space.valid_programs)} programs]"
            )

    results["concepts"] = {
        "two_sum": ts_concepts.to_dict(),
        "flatten": fl_concepts.to_dict(),
    }

    # ------------------------------------------------------------------
    # Stage 3: concept-guided embeddings
    # ------------------------------------------------------------------
    logger.info("\n" + banner)
    logger.info("STAGE 3: Concept-Guided Embeddings")
    logger.info(banner)

    for name, space, concepts in targets:
        if not concepts.concepts:
            logger.warning(f"No concepts found for {name}, skipping embedding")
            continue

        embedding = ConceptEmbeddingSpace(concepts)
        programs = space.valid_programs

        if not programs:
            logger.warning(f"No valid programs for {name}, skipping")
            continue

        # Projection into concept space; only its shape is reported.
        projection = embedding.project(programs)
        logger.info(f"\n{name} concept projection shape: {projection.shape}")

        # Alignment metrics are logged and stored as returned.
        alignment = embedding.verify_alignment(programs)
        logger.info(f"{name} alignment metrics:")
        for k, v in alignment.items():
            logger.info(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}")

        results[f"{name}_alignment"] = alignment

        # Per-program list of concepts scoring above 0.5.
        logger.info(f"\n{name} concept scores per program:")
        concept_names = concepts.names
        for program in programs:
            scores = concepts.score_program(program)
            active = [n for n, s in scores.items() if s > 0.5]
            label = program.metadata.get("label", "?")
            logger.info(f" [{label}]: {', '.join(active) if active else 'none'}")

        score_matrix = concepts.score_matrix(programs)
        results[f"{name}_score_matrix"] = {
            "shape": list(score_matrix.shape),
            "mean_scores": score_matrix.mean(axis=0).tolist(),
            "concept_names": concept_names,
        }

        # GCAV training is attempted only with at least 4 programs and at
        # least 2 concept columns.
        if len(programs) >= 4 and score_matrix.shape[1] >= 2:
            logger.info(f"\nTraining GCAV for {name}...")

            # Behavioural vectors may differ in length, so they are padded
            # from the raw list rather than stacked first.
            padded_features = _pad_behavioral_vectors(
                [p.behavioral_vector() for p in programs]
            )

            gcav = GCAVEmbedding()
            gcav_results = gcav.train_all(concepts, padded_features, programs)
            logger.info(f" Trained {len(gcav_results)} concept vectors:")
            for cname, metrics in gcav_results.items():
                logger.info(f" {cname}: accuracy={metrics['accuracy']:.2f}")

            # Numeric metrics (Python or NumPy scalars) are stored as floats;
            # anything else is stored as its string form.
            numeric_types = (int, float, np.integer, np.floating)
            results[f"{name}_gcav"] = {
                cname: {
                    key: float(value) if isinstance(value, numeric_types) else str(value)
                    for key, value in metrics.items()
                }
                for cname, metrics in gcav_results.items()
            }

    # ------------------------------------------------------------------
    # Stage 4: query language and steering
    # ------------------------------------------------------------------
    logger.info("\n" + banner)
    logger.info("STAGE 4: Query Language & Steering")
    logger.info(banner)

    for name, space, concepts in targets:
        if not concepts.concepts:
            continue

        engine = SteeringEngine(concepts)
        ql = engine.query_language

        logger.info(f"\n{name} — Query Examples:")

        # (query string, human-readable description) pairs.
        queries = [
            ("uses_recursion=1.0", "Find recursive implementations"),
            ("uses_iteration=1.0, uses_mutation=-1.0", "Iterative but immutable"),
            ("uses_dict=1.0, uses_early_return=1.0", "Hash-based with guards"),
            ("uses_list_comprehension=1.0", "Comprehension-based"),
        ]

        query_results = []
        for query_str, description in queries:
            query = ql.parse(query_str)
            selected = engine.select(space, query, top_k=3)

            if selected:
                logger.info(f"\n Query: {description}")
                logger.info(f" Parsed: {query}")
                for program, score in selected:
                    label = program.metadata.get("label", "?")
                    logger.info(f" -> [{label}] relevance={score:.3f}")

            # Every query is recorded, with an empty result list when nothing
            # was selected.
            query_results.append({
                "query": query_str,
                "description": description,
                "results": [
                    {"label": p.metadata.get("label", "?"), "relevance": float(s)}
                    for p, s in selected
                ],
            })

        results[f"{name}_queries"] = query_results

        # Boundary programs are logged for the first three concepts only.
        logger.info(f"\n{name} — Concept Boundaries:")
        for concept in concepts.concepts[:3]:
            boundary = engine.concept_boundary_programs(concept.name, space)
            inside_labels = [
                p.metadata.get("label", "?") for p in boundary["inside"]
            ]
            outside_labels = [
                p.metadata.get("label", "?") for p in boundary["outside"]
            ]
            logger.info(
                f" {concept.name}: inside={inside_labels}, outside={outside_labels}"
            )

    # ------------------------------------------------------------------
    # Stage 5: concept lattice
    # ------------------------------------------------------------------
    logger.info("\n" + banner)
    logger.info("STAGE 5: Concept Lattice (Formal Concept Analysis)")
    logger.info(banner)

    for name, space, concepts in targets:
        if not concepts.concepts:
            continue

        lattice = concepts.concept_lattice()
        logger.info(f"\n{name} concept lattice ({len(lattice)} nodes):")

        # Log at most the first ten nodes, with up to five program labels each.
        for extent, intent in lattice[:10]:
            n_programs = len(extent)
            concept_names = sorted(intent)
            program_labels = []
            for pid in extent:
                p = space.get(pid)
                if p:
                    program_labels.append(p.metadata.get("label", "?"))
            logger.info(
                f" Concepts: {concept_names} -> "
                f"Programs ({n_programs}): {program_labels[:5]}"
            )

        results[f"{name}_lattice"] = {
            "n_nodes": len(lattice),
            "sample_nodes": [
                {
                    "intent": sorted(intent),
                    "extent_size": len(extent),
                }
                for extent, intent in lattice[:20]
            ],
        }

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------
    logger.info("\n" + banner)
    logger.info("EXPERIMENT SUMMARY")
    logger.info(banner)

    for name in ["two_sum", "flatten"]:
        dr = results["program_spaces"][name]
        logger.info(f"\n{name}:")
        logger.info(f" Total programs: {dr['total_programs']}")
        logger.info(f" Valid programs: {dr['valid_programs']}")
        logger.info(f" Functional clusters: {dr['functionally_unique_clusters']}")
        logger.info(f" Entropy diversity: {dr['entropy_diversity']:.2f}")

        # Alignment is absent when stage 3 skipped this stub.
        if f"{name}_alignment" in results:
            al = results[f"{name}_alignment"]
            logger.info(f" Concept dimensions: {al.get('n_concepts', 0)}")
            logger.info(f" Effective dimensionality: {al.get('effective_dimensionality', 0):.2f}")
            logger.info(f" Avg concept correlation: {al.get('avg_concept_correlation', 0):.3f}")

    return results
|
|
|
|
def make_serializable(obj):
    """``json.dump`` ``default`` hook for values the encoder cannot handle.

    Converts NumPy booleans, integers and floats to the matching Python
    scalar, NumPy arrays to nested lists, and sets to lists (sorted when the
    elements are mutually comparable, so the output file is stable between
    runs).

    Raises:
        TypeError: for any other type, as the ``default`` protocol requires.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (set, frozenset)):
        try:
            return sorted(obj)
        except TypeError:
            # Elements that cannot be ordered keep the set's iteration order.
            return list(obj)
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")


if __name__ == "__main__":
    from pathlib import Path

    results = run_experiment()

    # The default location is kept; an optional first command-line argument
    # overrides it.
    output_path = Path(sys.argv[1] if len(sys.argv) > 1 else "/app/experiment_results.json")
    # Make sure the target directory exists before opening the file.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with output_path.open("w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=make_serializable)

    logger.info(f"\nResults saved to {output_path}")
|
|