Spaces:
Runtime error
Runtime error
File size: 4,925 Bytes
60b97da | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | import os
from datetime import datetime
from pathlib import Path
from sqlalchemy import (
JSON,
Column,
DateTime,
Float,
ForeignKey,
Integer,
String,
Text,
create_engine,
inspect,
text,
)
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
# Declarative base class that every ORM model in this module inherits from.
Base = declarative_base()
# Engines already built by init_db(), keyed by the resolved database URL.
_ENGINE_CACHE = {}
# Session factories matching _ENGINE_CACHE, keyed by the same resolved URL.
_SESSION_FACTORY_CACHE = {}
# Directory one level above the folder containing this file; relative SQLite
# paths are anchored here by resolve_database_url().
SERVER_DIR = Path(__file__).resolve().parents[1]
class Repository(Base):
    """ORM model mapped to the ``repositories`` table.

    A row is identified by its unique ``github_url`` and owns the related
    ``CodeChunk`` and ``ChatTurn`` rows: both relationships cascade with
    ``all, delete-orphan``.
    """

    __tablename__ = "repositories"

    id = Column(Integer, primary_key=True)

    # Where the repository comes from.
    github_url = Column(String(1024), unique=True, nullable=False)
    source_url = Column(String(1024))

    # Session bookkeeping; session_key is indexed.
    session_key = Column(String(255), index=True)
    session_expires_at = Column(DateTime)

    # Repository identity and checkout details.
    owner = Column(String(255), nullable=False)
    name = Column(String(255), nullable=False)
    branch = Column(String(255), default="main", nullable=False)
    local_path = Column(String(1024))

    # Processing state; new rows start as "queued" with zeroed counters.
    status = Column(String(64), default="queued", nullable=False)
    error_message = Column(Text)
    file_count = Column(Integer, default=0, nullable=False)
    chunk_count = Column(Integer, default=0, nullable=False)
    indexed_at = Column(DateTime)

    # Timestamps are filled in from datetime.utcnow (naive values).
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )

    # Child collections.
    chunks = relationship(
        "CodeChunk",
        back_populates="repository",
        cascade="all, delete-orphan",
    )
    chat_turns = relationship(
        "ChatTurn",
        back_populates="repository",
        cascade="all, delete-orphan",
    )
class CodeChunk(Base):
    """ORM model mapped to the ``code_chunks`` table.

    Every row belongs to exactly one ``Repository`` through the required
    ``repository_id`` foreign key.
    """

    __tablename__ = "code_chunks"

    id = Column(Integer, primary_key=True)
    repository_id = Column(
        Integer,
        ForeignKey("repositories.id"),
        nullable=False,
    )

    # Position of the chunk: file, language, symbol and line span.
    file_path = Column(String(1024), nullable=False)
    language = Column(String(64), nullable=False)
    symbol_name = Column(String(255))
    symbol_type = Column(String(128), default="chunk", nullable=False)
    line_start = Column(Integer, nullable=False)
    line_end = Column(Integer, nullable=False)

    # Text payload stored for the chunk.
    signature = Column(Text)
    content = Column(Text, nullable=False)
    searchable_text = Column(Text, nullable=False)

    # `dict` is passed as a callable, so each new row gets its own empty dict.
    metadata_json = Column(JSON, default=dict, nullable=False)

    # Optional retrieval fields.
    embedding_id = Column(Integer)
    rerank_score = Column(Float)

    created_at = Column(DateTime, default=datetime.utcnow)

    repository = relationship("Repository", back_populates="chunks")
class ChatTurn(Base):
    """ORM model mapped to the ``chat_turns`` table.

    Every row belongs to exactly one ``Repository`` through the required
    ``repository_id`` foreign key.
    """

    __tablename__ = "chat_turns"

    id = Column(Integer, primary_key=True)
    repository_id = Column(
        Integer,
        ForeignKey("repositories.id"),
        nullable=False,
    )

    # Required role label and message text.
    role = Column(String(32), nullable=False)
    content = Column(Text, nullable=False)

    # Optional JSON payload stored alongside the turn.
    answer_json = Column(JSON)

    created_at = Column(DateTime, default=datetime.utcnow)

    repository = relationship("Repository", back_populates="chat_turns")
def init_db(database_url: str = None):
    """Create, or fetch from cache, the engine and session factory for a database.

    Args:
        database_url: SQLAlchemy database URL. When omitted or empty, the
            ``DATABASE_URL`` environment variable is used; when that is
            unset or empty too, ``sqlite:///./codebase_rag.db`` is used.

    Returns:
        A ``(engine, session_factory)`` tuple. Repeated calls that resolve to
        the same URL return the same cached pair.

    Raises:
        Exception: whatever engine creation or schema setup raises. The
            half-initialised engine is disposed first and nothing is cached.
    """
    if not database_url:
        # `or` (rather than a getenv default) also covers a variable that is
        # set but blank, which would otherwise reach create_engine as "".
        database_url = os.getenv("DATABASE_URL") or "sqlite:///./codebase_rag.db"
    database_url = resolve_database_url(database_url)

    if database_url in _ENGINE_CACHE:
        return _ENGINE_CACHE[database_url], _SESSION_FACTORY_CACHE[database_url]

    # SQLite connections are thread-bound by default; this flag lifts that
    # restriction. Other backends take no extra connect arguments.
    connect_args = {"check_same_thread": False} if database_url.startswith("sqlite") else {}
    engine = create_engine(database_url, echo=False, connect_args=connect_args)

    try:
        Base.metadata.create_all(engine)
        _ensure_runtime_columns(engine)
    except Exception:
        # Release pooled connections instead of leaving them open behind an
        # engine that will never be cached or reused.
        engine.dispose()
        raise

    session_local = sessionmaker(bind=engine)
    _ENGINE_CACHE[database_url] = engine
    _SESSION_FACTORY_CACHE[database_url] = session_local
    return engine, session_local
def resolve_database_url(database_url: str) -> str:
    """Turn a file-based ``sqlite:///`` URL into one with an absolute path.

    Relative paths are anchored at ``SERVER_DIR``. The parent directory and
    the database file itself are created if they do not exist yet.

    Args:
        database_url: Any SQLAlchemy database URL.

    Returns:
        The URL unchanged when it does not start with ``sqlite:///``, when it
        names an in-memory database (empty path or ``:memory:``), or when the
        path is an SQLite ``file:`` URI filename. Otherwise a ``sqlite:///``
        URL with the resolved absolute path; any ``?query`` suffix is kept.
    """
    prefix = "sqlite:///"
    if not database_url.startswith(prefix):
        return database_url

    # Driver options may follow the path as a query string. Split them off so
    # they are not treated as part of the file name that gets touched on disk.
    sqlite_path, separator, query = database_url.removeprefix(prefix).partition("?")

    # An empty path or ":memory:" selects an in-memory database, and a "file:"
    # path is an SQLite URI filename; none of them is a plain filesystem path.
    if sqlite_path in ("", ":memory:") or sqlite_path.startswith("file:"):
        return database_url

    path = Path(sqlite_path)
    if not path.is_absolute():
        path = SERVER_DIR / path
    path.parent.mkdir(parents=True, exist_ok=True)
    path.touch(exist_ok=True)
    return f"{prefix}{path.resolve()}{separator}{query}"
def _ensure_runtime_columns(engine):
    """Add late-introduced columns to an existing ``repositories`` table.

    Checks for ``source_url``, ``session_key`` and ``session_expires_at`` and
    issues one ``ALTER TABLE ... ADD COLUMN`` per missing column inside a
    single transaction. Does nothing when the table is absent or already has
    all three columns.

    Args:
        engine: SQLAlchemy engine bound to the target database.
    """
    inspector = inspect(engine)
    if "repositories" not in inspector.get_table_names():
        return

    existing = {column["name"] for column in inspector.get_columns("repositories")}
    missing = [
        name
        for name in ("source_url", "session_key", "session_expires_at")
        if name not in existing
    ]
    if not missing:
        # Nothing to migrate, so no transaction is opened.
        return

    model_columns = Repository.__table__.c
    with engine.begin() as connection:
        for name in missing:
            # Render the type from the model for the active dialect rather
            # than hard-coding SQLite spellings: SQLite still gets
            # VARCHAR(n)/DATETIME, while backends without a DATETIME type get
            # their own equivalent.
            column_type = model_columns[name].type.compile(dialect=engine.dialect)
            connection.execute(
                text(f"ALTER TABLE repositories ADD COLUMN {name} {column_type}")
            )
def get_db_session(database_url: str = None):
    """Return a new session for the given database.

    The engine and session factory come from ``init_db``, which receives
    ``database_url`` unchanged.
    """
    session_factory = init_db(database_url)[1]
    return session_factory()
|