| """Synchronise champion MLflow models from the remote registry to the local filesystem.""" |
|
|
| import logging |
| import os |
| from pathlib import Path |
| import shutil |
|
|
| import mlflow |
| from mlflow.tracking import MlflowClient |
|
|
# Languages for which a "<lang>-champion" alias is looked up, in processing order.
LANGUAGES = ("python", "java", "pharo")
# Module-level logger named after this module so callers can filter its output.
logger = logging.getLogger(__name__)
|
|
|
|
def _get_mlflow_client() -> MlflowClient:
    """Build an MLflow client, honouring ``MLFLOW_TRACKING_URI`` when it is set.

    A non-empty ``MLFLOW_TRACKING_URI`` environment variable is forwarded to
    :func:`mlflow.set_tracking_uri` before the client is created; an unset or
    empty variable leaves MLflow's tracking URI untouched. Credentials (for
    example on DagsHub) are not handled here: MLflow reads them itself from
    ``MLFLOW_TRACKING_USERNAME`` and ``MLFLOW_TRACKING_PASSWORD``.

    Returns:
        A freshly constructed :class:`MlflowClient`.
    """
    uri = os.environ.get("MLFLOW_TRACKING_URI")
    if not uri:
        # Nothing to configure: fall back to whatever MLflow already uses.
        return MlflowClient()

    mlflow.set_tracking_uri(uri)
    return MlflowClient()
|
|
|
|
def _find_champion_version_for_language(
    client: MlflowClient,
    lang: str,
) -> "mlflow.entities.model_registry.ModelVersion | None":
    """Return the champion model version for the given language, if any.

    Every registered model whose name starts with ``"<lang>-"`` (for example
    ``"python-transformer"``) is probed for the alias ``"<lang>-champion"``
    via :meth:`MlflowClient.get_model_version_by_alias`. The first model for
    which the alias resolves wins.

    Registered models are fetched page by page, following the page token
    returned by :meth:`MlflowClient.search_registered_models`, so registries
    larger than a single result page are scanned completely.

    Args:
        client: Initialised MLflow client.
        lang: Language identifier, such as ``"python"``, ``"java"`` or
            ``"pharo"``.

    Returns:
        The matching :class:`mlflow.entities.model_registry.ModelVersion` if a
        champion is found, otherwise ``None``.

    """
    alias_name = f"{lang}-champion"
    prefix = f"{lang}-"

    page_token = None
    while True:
        page = client.search_registered_models(page_token=page_token)

        for rm in page:
            model_name = rm.name
            if not model_name.startswith(prefix):
                continue

            try:
                mv = client.get_model_version_by_alias(
                    name=model_name,
                    alias=alias_name,
                )
            except Exception as exc:
                # Best-effort lookup: a missing alias is the expected case, but
                # the cause is logged so that auth/network failures are not
                # mistaken for "no alias".
                logger.info(
                    "Alias %s not resolved for model %s (%s), trying next one.",
                    alias_name,
                    model_name,
                    exc,
                )
                continue

            logger.info(
                "Found champion model for %s: %s (version %s)",
                lang,
                model_name,
                mv.version,
            )
            return mv

        # A falsy/missing token means this was the last page of results.
        page_token = getattr(page, "token", None)
        if not page_token:
            break

    logger.warning("No champion model found for %s.", lang)
    return None
|
|
|
|
def _install_model_artifacts(artifact_uri: str, dest_dir: Path) -> None:
    """Download ``artifact_uri`` and install the model files as ``dest_dir``.

    The artifacts are first downloaded into a temporary staging directory
    created next to ``dest_dir``. Only after the download succeeded is any
    existing ``dest_dir`` removed and replaced, so a failed download leaves a
    previously synced model untouched.

    If the downloaded tree contains an inner ``model/`` directory, that
    directory becomes ``dest_dir``; otherwise the downloaded directory itself
    does.

    Args:
        artifact_uri: Artifact location passed to
            :func:`mlflow.artifacts.download_artifacts`.
        dest_dir: Final directory for the model files. Its parent directory
            must already exist.
    """
    import tempfile  # local import: only this helper needs it

    # Staging inside dest_dir's parent keeps the final move on one filesystem.
    with tempfile.TemporaryDirectory(
        dir=dest_dir.parent,
        prefix=f".{dest_dir.name}-",
    ) as staging:
        download_root = Path(staging) / "download"
        download_root.mkdir()

        downloaded_path = Path(
            mlflow.artifacts.download_artifacts(
                artifact_uri=artifact_uri,
                dst_path=str(download_root),
            ),
        )

        model_subdir = downloaded_path / "model"
        if model_subdir.is_dir():
            # Flatten the inner model/ directory into dest_dir.
            payload = model_subdir
        elif downloaded_path.is_dir():
            payload = downloaded_path
        else:
            # A single file was returned; install the directory holding it.
            payload = download_root

        # Swap in the new files only now that the download has succeeded.
        if dest_dir.exists():
            shutil.rmtree(dest_dir)
        shutil.move(str(payload), str(dest_dir))
        # Whatever remains in the staging directory is removed on exit.


def sync_best_models_to_disk(
    models_root: str | Path = "models",
    api_subdir: str = "api",
) -> None:
    """Download champion models from MLflow and write them to disk.

    For each language in :data:`LANGUAGES`, this function looks up the model
    version with alias ``"<lang>-champion"`` and downloads its artifacts. After
    download, the directory structure is normalised so that the final layout is:

    .. code-block:: text

        models/
          <api_subdir>/
            python/
              <model_type>/
                ...
            java/
              <model_type>/
                ...
            pharo/
              <model_type>/
                ...

    ``<model_type>`` is the part of the registered model name after the first
    ``"-"``. For transformer models logged via ``mlflow.transformers``, the
    inner ``model/`` directory is flattened so that the Hugging Face files
    (``config.json``, ``model.safetensors``, ``tokenizer.json``, and so on)
    live directly under ``<model_type>/``. Downloads without an inner
    ``model/`` directory are installed as they are.

    An existing ``<model_type>/`` directory is replaced only after the new
    download has succeeded. Languages without a champion, or with an
    unexpected model name, are skipped.

    Args:
        models_root: Base directory under which models are written. Can be a
            string or :class:`pathlib.Path`. Defaults to ``"models"``.
        api_subdir: Optional subdirectory appended under ``models_root`` (for
            example ``"api"``). If empty, models are stored directly under
            ``models_root``.

    Raises:
        OSError: If the root directory or a per-language directory cannot be
            created. Failures while downloading or installing an individual
            model are logged and do not propagate.

    """
    client = _get_mlflow_client()

    root = Path(models_root)
    if api_subdir:
        root = root / api_subdir
    root.mkdir(parents=True, exist_ok=True)
    logger.info("Syncing best models to: %s", root.resolve())

    for lang in LANGUAGES:
        mv = _find_champion_version_for_language(client, lang)
        if mv is None:
            continue

        model_name = mv.name
        lang_from_name, sep, model_type = model_name.partition("-")
        # An empty model_type would make dest_dir the whole language
        # directory, which must never be replaced wholesale.
        if not sep or not model_type:
            logger.error("Unexpected model name format: %s", model_name)
            continue

        if lang_from_name != lang:
            logger.warning(
                "Language mismatch for model %s: expected %s, got %s",
                model_name,
                lang,
                lang_from_name,
            )

        lang_dir = root / lang
        lang_dir.mkdir(parents=True, exist_ok=True)
        dest_dir = lang_dir / model_type

        logger.info(
            "Downloading model '%s' version %s to %s...",
            model_name,
            mv.version,
            dest_dir.resolve(),
        )

        try:
            _install_model_artifacts(mv.source, dest_dir)
        except Exception as e:
            # Best-effort per language: one failing model must not stop the
            # remaining languages from being synced.
            logger.error(
                "Failed to download/reshape model '%s' version %s: %s",
                model_name,
                mv.version,
                e,
            )
|
|
|
|
if __name__ == "__main__":
    # Script entry point: enable INFO-level logging so the sync progress
    # messages are emitted, then run the sync with its default arguments.
    logging.basicConfig(level=logging.INFO)
    sync_best_models_to_disk()
|
|