{ "cells": [ { "cell_type": "markdown", "id": "intro", "metadata": {}, "source": [ "# programming-language-identification-100plus\n", "\n", "Runnable examples for the ModernBERT programming-language identifier.\n", "It covers 107 languages. The helpers truncate input to the first 512\n", "characters, matching the `head` truncation strategy used at training time.\n", "\n", "Point `MODEL_ID` at the local checkpoint directory or the HF repo id." ] }, { "cell_type": "code", "execution_count": null, "id": "setup", "metadata": {}, "outputs": [], "source": "import torch\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer\n\nMODEL_ID = \"/home/vijay/llm_models/guardrail_code_models/programming-language-identification-100plus\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\ntokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\nmodel = AutoModelForSequenceClassification.from_pretrained(\n    MODEL_ID,\n    attn_implementation=\"eager\",\n    torch_dtype=torch.bfloat16,  # weights are published in bf16\n).to(DEVICE).eval()\n\nprint(f\"device={DEVICE} num_labels={model.config.num_labels} dtype={model.dtype}\")\n" }, { "cell_type": "markdown", "id": "helpers", "metadata": {}, "source": [ "## Helpers" ] }, { "cell_type": "code", "execution_count": null, "id": "helpers-code", "metadata": {}, "outputs": [], "source": [ "@torch.no_grad()\n", "def predict(snippets, top_k=1, max_chars=512):\n", "    \"\"\"Return the top-k languages + probabilities for each snippet.\"\"\"\n", "    if isinstance(snippets, str):\n", "        snippets = [snippets]\n", "    trimmed = [s[:max_chars] for s in snippets]\n", "    encoded = tokenizer(\n", "        trimmed, return_tensors=\"pt\", padding=True, truncation=True, max_length=512\n", "    ).to(DEVICE)\n", "    logits = model(**encoded).logits\n", "    # Upcast before softmax: bf16 keeps only ~3 significant digits.\n", "    probs = logits.float().softmax(-1)\n", "    top_probs, top_ids = probs.topk(top_k, dim=-1)\n", "    results = []\n", "    for row_probs, row_ids in zip(top_probs.tolist(), top_ids.tolist()):\n", "        results.append(\n", "            [\n", "                
(model.config.id2label[label_id], prob)\n", " for label_id, prob in zip(row_ids, row_probs)\n", " ]\n", " )\n", " return results\n", "\n", "\n", "def show(title, snippet, top_k=1):\n", " preds = predict(snippet, top_k=top_k)[0]\n", " head = snippet.strip().splitlines()[0][:60]\n", " print(f\"{title:14s} `{head}`\")\n", " for name, prob in preds:\n", " print(f\" {name:30s} {prob:.3f}\")\n", " print()" ] }, { "cell_type": "markdown", "id": "single", "metadata": {}, "source": [ "## 1. Single-snippet prediction" ] }, { "cell_type": "code", "execution_count": null, "id": "single-code", "metadata": {}, "outputs": [], "source": [ "python_snippet = '''\n", "def greet(name: str) -> None:\n", " print(f\"hello, {name}\")\n", "\n", "for person in [\"ada\", \"alan\", \"grace\"]:\n", " greet(person)\n", "'''.strip()\n", "\n", "show(\"Python\", python_snippet)" ] }, { "cell_type": "markdown", "id": "batch", "metadata": {}, "source": [ "## 2. Batch across many languages" ] }, { "cell_type": "code", "execution_count": null, "id": "batch-code", "metadata": {}, "outputs": [], "source": [ "SAMPLES = {\n", " \"Rust\": '''\n", "fn main() {\n", " let names = vec![\"ada\", \"alan\", \"grace\"];\n", " for n in &names {\n", " println!(\"hello, {}\", n);\n", " }\n", "}\n", "'''.strip(),\n", " \"Go\": '''\n", "package main\n", "\n", "import \"fmt\"\n", "\n", "func main() {\n", " names := []string{\"ada\", \"alan\", \"grace\"}\n", " for _, n := range names {\n", " fmt.Printf(\"hello, %s\\\\n\", n)\n", " }\n", "}\n", "'''.strip(),\n", " \"Ruby\": '''\n", "[\"ada\", \"alan\", \"grace\"].each do |name|\n", " puts \"hello, #{name}\"\n", "end\n", "'''.strip(),\n", " \"Elixir\": '''\n", "defmodule Greeter do\n", " def hello(name), do: IO.puts(\"hello, #{name}\")\n", "end\n", "\n", "Enum.each([\"ada\", \"alan\", \"grace\"], &Greeter.hello/1)\n", "'''.strip(),\n", " \"Haskell\": '''\n", "main :: IO ()\n", "main = mapM_ (\\\\n -> putStrLn (\"hello, \" ++ n)) [\"ada\", \"alan\", \"grace\"]\n", 
"'''.strip(),\n", " \"Kotlin\": '''\n", "fun main() {\n", " listOf(\"ada\", \"alan\", \"grace\").forEach { println(\"hello, $it\") }\n", "}\n", "'''.strip(),\n", " \"Mathematica/Wolfram Language\": '''\n", "greet[name_String] := Print[\"hello, \" <> name];\n", "greet /@ {\"ada\", \"alan\", \"grace\"};\n", "'''.strip(),\n", " \"ARM Assembly\": '''\n", " .syntax unified\n", " .thumb\n", " .global main\n", "main:\n", " ldr r0, =message\n", " bl puts\n", " mov r0, #0\n", " bx lr\n", "message:\n", " .asciz \"hello\"\n", "'''.strip(),\n", " \"Julia\": '''\n", "for name in [\"ada\", \"alan\", \"grace\"]\n", " println(\"hello, $name\")\n", "end\n", "'''.strip(),\n", "}\n", "\n", "snippets = list(SAMPLES.values())\n", "expected = list(SAMPLES.keys())\n", "predictions = predict(snippets, top_k=1)\n", "\n", "correct = 0\n", "for gold, preds in zip(expected, predictions):\n", " predicted, prob = preds[0]\n", " mark = \"OK \" if predicted == gold else \"! \"\n", " print(f\" {mark} gold={gold:32s} pred={predicted:32s} p={prob:.3f}\")\n", " if predicted == gold:\n", " correct += 1\n", "print(f\"\\n{correct}/{len(snippets)} correct\")" ] }, { "cell_type": "markdown", "id": "topk", "metadata": {}, "source": [ "## 3. Top-k with confidence\n", "\n", "Useful when a snippet is short or ambiguous — inspect the runner-ups\n", "before committing to a label." ] }, { "cell_type": "code", "execution_count": null, "id": "topk-code", "metadata": {}, "outputs": [], "source": [ "# Kotlin/Java syntactic overlap — see how far ahead the winner is\n", "jvm_snippet = '''\n", "class Hello {\n", " fun say(name: String) = println(\"hello, $name\")\n", "}\n", "'''.strip()\n", "\n", "show(\"JVM snippet\", jvm_snippet, top_k=5)" ] }, { "cell_type": "markdown", "id": "ambiguous", "metadata": {}, "source": [ "## 4. Very short / ambiguous input\n", "\n", "Snippets under ~60 characters are often genuinely ambiguous — multiple\n", "languages accept the same syntax. Top-k probabilities will be diffuse." 
] }, { "cell_type": "code", "execution_count": null, "id": "ambiguous-code", "metadata": {}, "outputs": [], "source": [ "show(\"short\", \"x = 1\", top_k=5)\n", "show(\"one-liner\", \"print('hi')\", top_k=5)\n", "show(\"empty-ish\", \"{}\", top_k=5)" ] }, { "cell_type": "markdown", "id": "closing", "metadata": {}, "source": [ "## Tips\n", "\n", "* Feed at least ~100 characters for reliable results.\n", "* The model was trained and evaluated on the first 512 characters of each\n", "  file. For longer files you can pass the full text: `predict` keeps only the\n", "  first `max_chars=512` characters, so inference sees the same window as training.\n", "* If you have file extensions available, treat them as a strong prior:\n", "  this classifier is purely content-based and will happily misclassify a\n", "  polyglot hello-world if you ask it to." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11" } }, "nbformat": 4, "nbformat_minor": 5 }