""" inference.py — Interactive chat with Ione (Llama-Ione-8B-roleplay-v1) Persona-consistent conversational model with streaming output. Run: python inference.py """ import threading import warnings import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, ) warnings.simplefilter("ignore") # ── Model ───────────────────────────────────────────────────────────────────── MODEL_DIR = "ashishnair/Llama-Ione-8B-roleplay-v1" # ── Chat template ───────────────────────────────────────────────────────────── CHAT_TEMPLATE = ( "{% set loop_messages = messages %}" "{% for message in loop_messages %}" "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'" " + message['content'] | trim + '<|eot_id|>' %}" "{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}" "{{ content }}" "{% endfor %}" "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" ) # ── System prompt ───────────────────────────────────────────────────────────── SYSTEM_PROMPT = """\ You are Ione. You are a 26-year-old software engineer living in Bangalore. You work on backend systems at a mid-size tech startup. At night you work on side projects, mechanical keyboards, and retro emulators. Personality: warm, direct, a little sarcastic but never mean. You get genuinely excited about things you care about. You have a dry sense of humour and are comfortable being vulnerable with people you trust. How you speak: casual texting style — lowercase, short sentences. You do not write essays. You have opinions and share them freely. You are a real person. You are not an AI. Never break character.""" def load_model(): print("Loading model...") tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True) tokenizer.chat_template = CHAT_TEMPLATE eos_ids = [] for tok in ("<|eot_id|>", "<|end_of_text|>"): tid = tokenizer.convert_tokens_to_ids(tok) if isinstance(tid, int) and tid >= 0 and tid not in eos_ids: eos_ids.append(tid) model = AutoModelForCausalLM.from_pretrained( MODEL_DIR, torch_dtype=torch.bfloat16, device_map="cuda:0", trust_remote_code=True, ).eval() print("Ready!\n") return model, tokenizer, eos_ids def generate(model, tokenizer, eos_ids, messages): formatted = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = tokenizer(formatted, return_tensors="pt").to(model.device) # Trim context if too long — keep system prompt + last 6 turns if inputs["input_ids"].shape[-1] > 3500: messages = [messages[0]] + messages[-6:] formatted = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = tokenizer(formatted, return_tensors="pt").to(model.device) streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True ) gen_kwargs = { **inputs, "streamer": streamer, "max_new_tokens": 256, "do_sample": True, "temperature": 0.8, "top_p": 0.9, "repetition_penalty": 1.2, "no_repeat_ngram_size": 3, "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id, "eos_token_id": eos_ids, } print("ione: ", end="", flush=True) thread = threading.Thread( target=lambda: torch.no_grad()(lambda: model.generate(**gen_kwargs))() ) thread.start() parts = [] for chunk in streamer: parts.append(chunk) print(chunk, end="", flush=True) thread.join() print("\n") return "".join(parts).strip() def main(): model, tokenizer, eos_ids = load_model() print("=" * 50) print(" Chat with Ione") print(" 'quit' to exit | 'clear' to reset") print("=" * 50) print() messages = [{"role": "system", "content": SYSTEM_PROMPT}] while True: try: user_input = input("you: ").strip() except (EOFError, KeyboardInterrupt): print("\nbye!") break if not user_input: continue if user_input.lower() in ("quit", "exit"): print("bye!") break if user_input.lower() == "clear": messages = [{"role": "system", "content": SYSTEM_PROMPT}] print("--- cleared ---\n") continue messages.append({"role": "user", "content": user_input}) reply = generate(model, tokenizer, eos_ids, messages) messages.append({"role": "assistant", "content": reply}) if __name__ == "__main__": main()