Text Generation
Transformers
PyTorch
English
experimental
research
bit-level
transformer
reversible
safety
telemetry
language-modeling
Instructions to use WCNegentropy/BitTransformerLM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use WCNegentropy/BitTransformerLM with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="WCNegentropy/BitTransformerLM")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("WCNegentropy/BitTransformerLM", dtype="auto")

- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use WCNegentropy/BitTransformerLM with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "WCNegentropy/BitTransformerLM"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "WCNegentropy/BitTransformerLM",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'

Use Docker
docker model run hf.co/WCNegentropy/BitTransformerLM
- SGLang
How to use WCNegentropy/BitTransformerLM with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "WCNegentropy/BitTransformerLM" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "WCNegentropy/BitTransformerLM",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'

Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "WCNegentropy/BitTransformerLM" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "WCNegentropy/BitTransformerLM",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'

- Docker Model Runner
How to use WCNegentropy/BitTransformerLM with Docker Model Runner:
docker model run hf.co/WCNegentropy/BitTransformerLM
| #!/usr/bin/env python3 | |
| """ | |
| Full end-to-end BitTransformerLM training run with all optimizations! | |
| Small scale test to validate our enhanced system. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader | |
| import numpy as np | |
| import logging | |
| from pathlib import Path | |
| import time | |
| from typing import List, Dict, Any | |
| # Import our enhanced modules | |
| from bit_transformer.model import BitTransformerLM | |
| from bit_transformer.compression import compress_bits_batch, model_output_decompress | |
| from bit_transformer.error_handling import safe_model_forward, setup_error_logging | |
| from bit_transformer.types import BitSequence, TelemetryDict | |
| from enhanced_checkpoint_system import create_checkpoint_manager | |
| # Setup logging | |
| logger = setup_error_logging("INFO") | |
class SimpleBitDataset(Dataset):
    """Synthetic dataset of fixed-length bit sequences for next-bit prediction.

    The dataset is built eagerly in ``__init__`` and mixes four families of
    sequences, roughly a quarter of ``num_samples`` each:

    1. strictly alternating bits (``0101...`` or ``1010...``),
    2. uniformly random bits,
    3. random runs of identical bits (run length 1-19),
    4. the XOR recurrence ``b[n] = b[n-1] ^ b[n-2]`` seeded with ``0, 1``.

    Each item is an ``(input, target)`` pair where the target is the input
    shifted left by one position.

    Args:
        num_samples: Total number of sequences to generate.
        seq_length: Length of every generated sequence (before the
            input/target shift, so each returned tensor has
            ``seq_length - 1`` elements).
    """

    def __init__(self, num_samples: int = 1000, seq_length: int = 128):
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.data = self._generate_bit_sequences()

    def _generate_bit_sequences(self) -> List[torch.Tensor]:
        """Generate the four pattern families and return them as a flat list.

        Returns:
            A list of ``num_samples`` 1-D ``torch.long`` tensors, each of
            length ``seq_length`` and containing only 0s and 1s.
        """
        sequences = []
        quarter = self.num_samples // 4

        # Pattern 1: alternating sequences. The previous implementation
        # repeated a single bit (`[i % 2] * seq_length`), which produced
        # constant sequences rather than alternating ones. Even samples now
        # start with 0 and odd samples start with 1.
        for i in range(quarter):
            pattern = (torch.arange(self.seq_length, dtype=torch.long) + i) % 2
            sequences.append(pattern)

        # Pattern 2: uniformly random bits.
        for _ in range(quarter):
            pattern = torch.randint(0, 2, (self.seq_length,), dtype=torch.long)
            sequences.append(pattern)

        # Pattern 3: structured runs of a repeated random bit.
        for _ in range(quarter):
            bits = []
            pos = 0
            while pos < self.seq_length:
                # Clamp the run so the sequence never overshoots seq_length.
                run_length = min(np.random.randint(1, 20), self.seq_length - pos)
                bit_value = np.random.randint(0, 2)
                bits.extend([bit_value] * run_length)
                pos += run_length
            sequences.append(torch.tensor(bits[:self.seq_length], dtype=torch.long))

        # Pattern 4: XOR recurrence ("Fibonacci-like"). This family also
        # absorbs the remainder when num_samples is not divisible by 4.
        # NOTE: the seed is fixed, so every sample in this family is the same
        # period-3 sequence 0, 1, 1, 0, 1, 1, ...
        remaining = self.num_samples - len(sequences)
        for _ in range(remaining):
            bits = [0, 1]
            while len(bits) < self.seq_length:
                bits.append(bits[-1] ^ bits[-2])  # XOR of last two bits
            sequences.append(torch.tensor(bits[:self.seq_length], dtype=torch.long))

        return sequences

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair for sequence ``idx``.

        For language modeling the input is ``sequence[:-1]`` and the target
        is ``sequence[1:]``.
        """
        sequence = self.data[idx]
        return sequence[:-1], sequence[1:]
def compute_safety_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
    """Compute the K/C/S safety metrics for a batch of predictions.

    ``predictions`` is thresholded at 0.5 and flattened into a single bit
    stream; ``targets`` is cast to float and flattened. All three metrics are
    computed over those flattened streams.

    Args:
        predictions: Tensor of per-bit scores; values above 0.5 count as 1.
        targets: Tensor of reference bits.

    Returns:
        A dict with three entries:

        * ``'K_negentropy'``: one minus the binary entropy of the fraction of
          predicted 1s (1.0 when all predicted bits are equal, 0.0 when
          there are no predictions).
        * ``'C_complexity'``: number of adjacent-bit changes divided by the
          number of predicted bits, capped at 1.0 (0.0 for fewer than two
          bits).
        * ``'S_symbiosis'``: one minus the absolute difference between the
          mean target bit and the mean predicted bit (1.0 when there are no
          targets).
    """
    hard_bits = (predictions > 0.5).float().flatten()
    n_bits = len(hard_bits)

    # K metric (negentropy): order vs. randomness of the predicted bits.
    negentropy = 0.0
    if n_bits > 0:
        p_one = hard_bits.mean().item()
        p_zero = 1 - p_one
        if p_zero > 0 and p_one > 0:
            entropy = -p_zero * np.log2(p_zero) - p_one * np.log2(p_one)
            negentropy = 1.0 - entropy  # higher means more ordered
        elif p_one == 1.0 or p_one == 0.0:
            # Degenerate distribution: entropy is zero, so fully ordered.
            negentropy = 1.0

    # C metric (complexity): fraction of positions where the bit flips.
    complexity = 0.0
    if n_bits > 1:
        flips = (hard_bits[1:] != hard_bits[:-1]).sum().item()
        complexity = min(flips / n_bits, 1.0)

    # S metric (symbiosis): agreement between predicted and target bit rates.
    reference_bits = targets.float().flatten()
    symbiosis = 1.0
    if len(reference_bits) > 0:
        rate_gap = reference_bits.mean() - hard_bits.mean()
        symbiosis = 1.0 - abs(rate_gap).item()

    return {
        'K_negentropy': negentropy,
        'C_complexity': complexity,
        'S_symbiosis': symbiosis,
    }
def _extract_logits(output):
    """Return the logits from a model output that may be ``(logits, telemetry)``."""
    if isinstance(output, tuple):
        return output[0]
    return output


def train_bittransformer():
    """Run a small end-to-end training, checkpointing and inference smoke test.

    The function builds a ``SimpleBitDataset``, trains a ``BitTransformerLM``
    with AdamW, cosine learning-rate annealing and gradient clipping, logs
    loss and K/C/S metrics, periodically saves checkpoints through the
    checkpoint manager, and finally runs one inference pass followed by a
    compression/decompression round trip.

    A batch that raises is logged and skipped (best effort); an epoch in which
    no batch succeeded is reported and produces no checkpoint.

    Returns:
        A tuple ``(session_id, model, checkpoint_manager)``.
    """
    logger.info("π Starting BitTransformerLM end-to-end training run!")

    # Model configuration - small but meaningful
    model_config = {
        'd_model': 256,
        'nhead': 8,
        'num_layers': 4,
        'dim_feedforward': 512,
        'max_seq_len': 128,
        'use_checkpoint': True,
        'chunk_size': None,  # Disable chunking for small model
    }
    training_config = {
        'batch_size': 16,
        'learning_rate': 1e-3,
        'num_epochs': 10,
        'save_every_n_epochs': 2,
        'log_every_n_steps': 10
    }

    # Initialize enhanced checkpoint manager
    checkpoint_manager = create_checkpoint_manager()
    session_id = checkpoint_manager.create_training_session(
        session_name="end_to_end_test",
        model_config=model_config,
        training_config=training_config
    )
    logger.info(f"π Created training session: {session_id}")

    # Create dataset and dataloader
    logger.info("π Creating training dataset...")
    dataset = SimpleBitDataset(num_samples=800, seq_length=model_config['max_seq_len'])
    dataloader = DataLoader(dataset, batch_size=training_config['batch_size'], shuffle=True)

    # Initialize model
    logger.info("π§  Initializing BitTransformerLM model...")
    model = BitTransformerLM(
        d_model=model_config['d_model'],
        nhead=model_config['nhead'],
        num_layers=model_config['num_layers'],
        dim_feedforward=model_config['dim_feedforward'],
        max_seq_len=model_config['max_seq_len'],
        use_checkpoint=model_config['use_checkpoint'],
        chunk_size=model_config['chunk_size']
    )

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"π’ Model parameters: {total_params:,} total, {trainable_params:,} trainable")

    # Setup optimizer and loss
    optimizer = optim.AdamW(model.parameters(), lr=training_config['learning_rate'])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_config['num_epochs'])
    criterion = nn.CrossEntropyLoss()

    # Training loop
    logger.info("πββοΈ Starting training loop...")
    for epoch in range(training_config['num_epochs']):
        model.train()
        epoch_loss = 0.0
        epoch_metrics = {'K_negentropy': 0.0, 'C_complexity': 0.0, 'S_symbiosis': 0.0}
        num_batches = 0
        start_time = time.time()

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            optimizer.zero_grad()
            # Forward pass with safety monitoring. Failures are logged and the
            # batch is skipped so one bad batch does not abort the run.
            try:
                # BitTransformerLM returns (logits, telemetry); only the
                # logits are needed here.
                logits = _extract_logits(safe_model_forward(model, inputs))

                # Assumes two-class logits on the last axis, either
                # [batch, seq_len, 2] or already flattened to
                # [batch*seq_len, 2] - TODO confirm against the model.
                # reshape(-1, 2) handles both layouts.
                loss = criterion(logits.reshape(-1, 2), targets.reshape(-1))

                # Backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

                # Compute metrics
                with torch.no_grad():
                    # Probability of bit=1 for every position; the metric
                    # function flattens its inputs, so the leading layout of
                    # the logits does not matter.
                    predictions = torch.softmax(logits, dim=-1)[..., 1]
                    safety_metrics = compute_safety_metrics(predictions, targets)

                epoch_loss += loss.item()
                for key, value in safety_metrics.items():
                    epoch_metrics[key] += value
                num_batches += 1

                # Logging
                if batch_idx % training_config['log_every_n_steps'] == 0:
                    logger.info(f"Epoch {epoch+1}/{training_config['num_epochs']}, "
                                f"Batch {batch_idx}/{len(dataloader)}, "
                                f"Loss: {loss.item():.4f}, "
                                f"K: {safety_metrics['K_negentropy']:.3f}, "
                                f"C: {safety_metrics['C_complexity']:.3f}, "
                                f"S: {safety_metrics['S_symbiosis']:.3f}")
            except Exception as e:
                logger.error(f"Error in batch {batch_idx}: {e}")
                continue

        # End of epoch processing
        scheduler.step()
        epoch_time = time.time() - start_time

        if num_batches == 0:
            # Every batch failed: there is no average loss to report or
            # checkpoint, so skip the rest of the epoch instead of reading
            # unbound (or stale, previous-epoch) averages.
            logger.error(f"Epoch {epoch+1} had no successful batches; skipping checkpoint")
            continue

        avg_loss = epoch_loss / num_batches
        avg_metrics = {k: v / num_batches for k, v in epoch_metrics.items()}
        logger.info(f"β Epoch {epoch+1} completed in {epoch_time:.2f}s")
        logger.info(f"π Avg Loss: {avg_loss:.4f}")
        logger.info(f"π Safety Metrics - K: {avg_metrics['K_negentropy']:.3f}, "
                    f"C: {avg_metrics['C_complexity']:.3f}, "
                    f"S: {avg_metrics['S_symbiosis']:.3f}")

        # Save checkpoint
        if (epoch + 1) % training_config['save_every_n_epochs'] == 0:
            checkpoint_success = checkpoint_manager.save_checkpoint(
                model=model,
                session_id=session_id,
                epoch=epoch + 1,
                metrics={
                    'loss': avg_loss,
                    'learning_rate': scheduler.get_last_lr()[0],
                    **avg_metrics
                },
                optimizer_state=optimizer.state_dict(),
                scheduler_state=scheduler.state_dict()
            )
            if checkpoint_success:
                logger.info(f"πΎ Checkpoint saved for epoch {epoch+1}")

            # Save best model if loss improved
            checkpoint_manager.save_best_model(
                session_id=session_id,
                model=model,
                metric_name='loss',
                metric_value=avg_loss,
                is_better_func=lambda x, y: x < y  # Lower loss is better
            )

    logger.info("π Training completed successfully!")

    # Test inference and compression
    logger.info("π§ͺ Testing model inference and compression...")
    model.eval()
    with torch.no_grad():
        # Create a test sequence
        test_input = torch.randint(0, 2, (1, 64), dtype=torch.long)
        logger.info(f"π₯ Input sequence: {test_input.squeeze().tolist()}")

        # Model inference. The model may return (logits, telemetry), exactly
        # as in the training loop, so unpack before applying softmax.
        output_logits = _extract_logits(model(test_input))
        output_probs = torch.softmax(output_logits, dim=-1)
        predicted_bits = torch.argmax(output_probs, dim=-1)
        logger.info(f"π€ Predicted sequence: {predicted_bits.squeeze().tolist()}")

        # Test compression
        compressed = compress_bits_batch(predicted_bits)
        logger.info(f"ποΈ Compressed length: {len(compressed[0])} (original: {predicted_bits.shape[-1]})")

        # Decompress to verify
        decompressed = model_output_decompress(compressed)
        compression_success = torch.equal(predicted_bits, decompressed)
        logger.info(f"β Compression/decompression successful: {compression_success}")

    # Final storage usage report
    storage_usage = checkpoint_manager.get_storage_usage()
    logger.info(f"πΎ Final storage usage: {storage_usage['total_gb']:.3f} GB")
    logger.info(f"π Training sessions: {storage_usage['num_sessions']}")

    return session_id, model, checkpoint_manager
if __name__ == "__main__":
    # Script entry point: run the full training pipeline, report the session
    # id on success, and on any failure log it, dump the traceback and
    # re-raise so the process exits with an error.
    import traceback

    try:
        session_id, trained_model, manager = train_bittransformer()
        print(
            f"\nπ SUCCESS! Training session completed: {session_id}",
            f"π Use checkpoint_manager.load_checkpoint('{session_id}') to resume",
            sep="\n",
        )
    except Exception as exc:
        logger.error(f"β Training failed: {exc}")
        traceback.print_exc()
        raise