vijaym commited on
Commit
eec3a6f
·
verified ·
1 Parent(s): bcd4b22

Initial release: weights + ONNX + notebook

Browse files
Files changed (1) hide show
  1. inference_examples.ipynb +272 -0
inference_examples.ipynb ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "intro",
6
+ "metadata": {},
7
+ "source": [
8
+ "# programming-language-identification-100plus\n",
9
+ "\n",
10
+ "Runnable examples for the ModernBERT programming-language identifier.\n",
11
+ "Covers 107 languages. Input is truncated to the first 512 characters\n",
12
+ "(matches the training-time `head` strategy).\n",
13
+ "\n",
14
+ "Point `MODEL_ID` at the local checkpoint directory or the HF repo id."
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "id": "setup",
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": "import torch\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer\n\nMODEL_ID = \"/home/vijay/llm_models/guardrail_code_models/programming-language-identification-100plus\"\nDEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\ntokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\nmodel = AutoModelForSequenceClassification.from_pretrained(\n MODEL_ID,\n attn_implementation=\"eager\",\n torch_dtype=torch.bfloat16, # weights are published in bf16\n).to(DEVICE).eval()\n\nprint(f\"device={DEVICE} num_labels={model.config.num_labels} dtype={model.dtype}\")\n"
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "id": "helpers",
28
+ "metadata": {},
29
+ "source": [
30
+ "## Helpers"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "id": "helpers-code",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "@torch.no_grad()\n",
+ "def predict(snippets, top_k=1, max_chars=512):\n",
+ "    \"\"\"Return the top-k languages + probabilities for each snippet.\n",
+ "\n",
+ "    Args:\n",
+ "        snippets: One source string, or an iterable of source strings.\n",
+ "        top_k: Number of labels to return per snippet. Capped at the\n",
+ "            number of classes in the model output.\n",
+ "        max_chars: Each snippet is cut to this many leading characters\n",
+ "            before tokenization.\n",
+ "\n",
+ "    Returns:\n",
+ "        One list per snippet of ``(label, probability)`` tuples in the\n",
+ "        order produced by ``topk``. An empty input yields ``[]``.\n",
+ "    \"\"\"\n",
+ "    if isinstance(snippets, str):\n",
+ "        snippets = [snippets]\n",
+ "    trimmed = [s[:max_chars] for s in snippets]\n",
+ "    if not trimmed:\n",
+ "        # Nothing to classify: skip the tokenizer/model call entirely.\n",
+ "        return []\n",
+ "    encoded = tokenizer(\n",
+ "        trimmed, return_tensors=\"pt\", padding=True, truncation=True, max_length=512\n",
+ "    ).to(DEVICE)\n",
+ "    logits = model(**encoded).logits\n",
+ "    # The model is loaded in bf16; do the softmax in float32 so the\n",
+ "    # reported probabilities are not coarsely quantized near 1.0.\n",
+ "    probs = logits.float().softmax(-1)\n",
+ "    # topk raises when k exceeds the class dimension, so clamp it.\n",
+ "    k = min(top_k, probs.shape[-1])\n",
+ "    top_probs, top_ids = probs.topk(k, dim=-1)\n",
+ "    results = []\n",
+ "    for row_probs, row_ids in zip(top_probs.tolist(), top_ids.tolist()):\n",
+ "        results.append(\n",
+ "            [\n",
+ "                (model.config.id2label[label_id], prob)\n",
+ "                for label_id, prob in zip(row_ids, row_probs)\n",
+ "            ]\n",
+ "        )\n",
+ "    return results\n",
61
+ "\n",
62
+ "\n",
63
+ "def show(title, snippet, top_k=1):\n",
+ "    \"\"\"Print a one-line preview of ``snippet`` and its top-k predictions.\n",
+ "\n",
+ "    Args:\n",
+ "        title: Short label printed in front of the preview.\n",
+ "        snippet: Source text to classify.\n",
+ "        top_k: Number of predictions to list.\n",
+ "    \"\"\"\n",
+ "    preds = predict(snippet, top_k=top_k)[0]\n",
+ "    # An empty or whitespace-only snippet has no lines; use an empty\n",
+ "    # preview instead of raising IndexError on [0].\n",
+ "    lines = snippet.strip().splitlines()\n",
+ "    head = lines[0][:60] if lines else \"\"\n",
+ "    print(f\"{title:14s} `{head}`\")\n",
+ "    for name, prob in preds:\n",
+ "        print(f\" {name:30s} {prob:.3f}\")\n",
+ "    print()"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "markdown",
74
+ "id": "single",
75
+ "metadata": {},
76
+ "source": [
77
+ "## 1. Single-snippet prediction"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "id": "single-code",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "python_snippet = '''\n",
88
+ "def greet(name: str) -> None:\n",
89
+ " print(f\"hello, {name}\")\n",
90
+ "\n",
91
+ "for person in [\"ada\", \"alan\", \"grace\"]:\n",
92
+ " greet(person)\n",
93
+ "'''.strip()\n",
94
+ "\n",
95
+ "show(\"Python\", python_snippet)"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "markdown",
100
+ "id": "batch",
101
+ "metadata": {},
102
+ "source": [
103
+ "## 2. Batch across many languages"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "id": "batch-code",
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "SAMPLES = {\n",
114
+ " \"Rust\": '''\n",
115
+ "fn main() {\n",
116
+ " let names = vec![\"ada\", \"alan\", \"grace\"];\n",
117
+ " for n in &names {\n",
118
+ " println!(\"hello, {}\", n);\n",
119
+ " }\n",
120
+ "}\n",
121
+ "'''.strip(),\n",
122
+ " \"Go\": '''\n",
123
+ "package main\n",
124
+ "\n",
125
+ "import \"fmt\"\n",
126
+ "\n",
127
+ "func main() {\n",
128
+ " names := []string{\"ada\", \"alan\", \"grace\"}\n",
129
+ " for _, n := range names {\n",
130
+ " fmt.Printf(\"hello, %s\\\\n\", n)\n",
131
+ " }\n",
132
+ "}\n",
133
+ "'''.strip(),\n",
134
+ " \"Ruby\": '''\n",
135
+ "[\"ada\", \"alan\", \"grace\"].each do |name|\n",
136
+ " puts \"hello, #{name}\"\n",
137
+ "end\n",
138
+ "'''.strip(),\n",
139
+ " \"Elixir\": '''\n",
140
+ "defmodule Greeter do\n",
141
+ " def hello(name), do: IO.puts(\"hello, #{name}\")\n",
142
+ "end\n",
143
+ "\n",
144
+ "Enum.each([\"ada\", \"alan\", \"grace\"], &Greeter.hello/1)\n",
145
+ "'''.strip(),\n",
146
+ " \"Haskell\": '''\n",
147
+ "main :: IO ()\n",
148
+ "main = mapM_ (\\\\n -> putStrLn (\"hello, \" ++ n)) [\"ada\", \"alan\", \"grace\"]\n",
149
+ "'''.strip(),\n",
150
+ " \"Kotlin\": '''\n",
151
+ "fun main() {\n",
152
+ " listOf(\"ada\", \"alan\", \"grace\").forEach { println(\"hello, $it\") }\n",
153
+ "}\n",
154
+ "'''.strip(),\n",
155
+ " \"Mathematica/Wolfram Language\": '''\n",
156
+ "greet[name_String] := Print[\"hello, \" <> name];\n",
157
+ "greet /@ {\"ada\", \"alan\", \"grace\"};\n",
158
+ "'''.strip(),\n",
159
+ " \"ARM Assembly\": '''\n",
160
+ " .syntax unified\n",
161
+ " .thumb\n",
162
+ " .global main\n",
163
+ "main:\n",
164
+ " ldr r0, =message\n",
165
+ " bl puts\n",
166
+ " mov r0, #0\n",
167
+ " bx lr\n",
168
+ "message:\n",
169
+ " .asciz \"hello\"\n",
170
+ "'''.strip(),\n",
171
+ " \"Julia\": '''\n",
172
+ "for name in [\"ada\", \"alan\", \"grace\"]\n",
173
+ " println(\"hello, $name\")\n",
174
+ "end\n",
175
+ "'''.strip(),\n",
176
+ "}\n",
177
+ "\n",
178
+ "# Dicts preserve insertion order, so labels and snippets stay aligned.\n",
+ "expected = list(SAMPLES)\n",
+ "snippets = [SAMPLES[language] for language in expected]\n",
+ "predictions = predict(snippets, top_k=1)\n",
+ "\n",
+ "correct = 0\n",
+ "for gold, preds in zip(expected, predictions):\n",
+ "    predicted, prob = preds[0]\n",
+ "    matched = predicted == gold  # evaluate once, reuse for mark and tally\n",
+ "    mark = \"OK \" if matched else \"! \"\n",
+ "    print(f\" {mark} gold={gold:32s} pred={predicted:32s} p={prob:.3f}\")\n",
+ "    correct += int(matched)\n",
+ "print(f\"\\n{correct}/{len(snippets)} correct\")"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "markdown",
194
+ "id": "topk",
195
+ "metadata": {},
196
+ "source": [
197
+ "## 3. Top-k with confidence\n",
198
+ "\n",
199
+ "Useful when a snippet is short or ambiguous — inspect the runners-up\n",
200
+ "before committing to a label."
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "id": "topk-code",
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "# Kotlin/Java syntactic overlap — see how far ahead the winner is\n",
211
+ "jvm_snippet = '''\n",
212
+ "class Hello {\n",
213
+ " fun say(name: String) = println(\"hello, $name\")\n",
214
+ "}\n",
215
+ "'''.strip()\n",
216
+ "\n",
217
+ "show(\"JVM snippet\", jvm_snippet, top_k=5)"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "markdown",
222
+ "id": "ambiguous",
223
+ "metadata": {},
224
+ "source": [
225
+ "## 4. Very short / ambiguous input\n",
226
+ "\n",
227
+ "Snippets under ~60 characters are often genuinely ambiguous — multiple\n",
228
+ "languages accept the same syntax. Top-k probabilities will be diffuse."
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": null,
234
+ "id": "ambiguous-code",
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": [
238
+ "show(\"short\", \"x = 1\", top_k=5)\n",
239
+ "show(\"one-liner\", \"print('hi')\", top_k=5)\n",
240
+ "show(\"empty-ish\", \"{}\", top_k=5)"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "markdown",
245
+ "id": "closing",
246
+ "metadata": {},
247
+ "source": [
248
+ "## Tips\n",
249
+ "\n",
250
+ "* Feed at least ~100 characters for reliable results.\n",
251
+ "* The model was trained and evaluated with the first 512 characters of each\n",
252
+ " file. For longer files, that's also what you should pass.\n",
253
+ "* If you have file extensions available, treat them as a strong prior —\n",
254
+ " this classifier is purely content-based and will happily misclassify a\n",
255
+ " polyglot hello-world if you ask it to."
256
+ ]
257
+ }
258
+ ],
259
+ "metadata": {
260
+ "kernelspec": {
261
+ "display_name": "Python 3",
262
+ "language": "python",
263
+ "name": "python3"
264
+ },
265
+ "language_info": {
266
+ "name": "python",
267
+ "version": "3.11"
268
+ }
269
+ },
270
+ "nbformat": 4,
271
+ "nbformat_minor": 5
272
+ }