Instructions to use nvidia/C-RADIOv2-B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nvidia/C-RADIOv2-B with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("image-feature-extraction", model="nvidia/C-RADIOv2-B", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("nvidia/C-RADIOv2-B", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union | |
| import torch | |
| from torch import nn | |
| from timm.models import create_model, VisionTransformer | |
| from .enable_cpe_support import enable_cpe | |
| from .input_conditioner import InputConditioner | |
| from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput | |
| from . import eradio_model | |
| from .enable_spectral_reparam import configure_spectral_reparam_from_args | |
| from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer | |
| from . import dual_hybrid_vit | |
| class Resolution(NamedTuple): | |
| height: int | |
| width: int | |
| class RADIOModel(nn.Module): | |
| def __init__( | |
| self, | |
| model: nn.Module, | |
| input_conditioner: InputConditioner, | |
| patch_size: int, | |
| max_resolution: int, | |
| preferred_resolution: Resolution, | |
| summary_idxs: Optional[torch.Tensor] = None, | |
| window_size: int = None, | |
| adaptors: Dict[str, AdaptorBase] = None, | |
| feature_normalizer: Optional[FeatureNormalizer] = None, | |
| inter_feature_normalizer: Optional[IntermediateFeatureNormalizer] = None, | |
| ): | |
| super().__init__() | |
| self.model = model | |
| self.input_conditioner = input_conditioner | |
| if summary_idxs is not None: | |
| self.register_buffer('summary_idxs', summary_idxs) | |
| else: | |
| self.summary_idxs = None | |
| self._preferred_resolution = preferred_resolution | |
| self._patch_size = patch_size | |
| self._max_resolution = max_resolution | |
| self._window_size = window_size | |
| adaptors = adaptors or dict() | |
| self.adaptors = nn.ModuleDict(adaptors) | |
| if feature_normalizer is None: | |
| feature_normalizer = nn.Identity() | |
| self.feature_normalizer = feature_normalizer | |
| self.inter_feature_normalizer = inter_feature_normalizer | |
| def num_summary_tokens(self) -> int: | |
| if hasattr(self.model, 'num_summary_tokens'): | |
| return self.model.num_summary_tokens | |
| patch_gen = getattr(self.model, "patch_generator", None) | |
| if patch_gen is not None: | |
| return patch_gen.num_skip | |
| elif getattr(self.model, 'global_pool', None) == 'avg': | |
| return 0 | |
| return 1 | |
| def num_cls_tokens(self) -> int: | |
| if hasattr(self.model, 'num_cls_tokens'): | |
| return self.model.num_cls_tokens | |
| patch_gen = getattr(self.model, 'patch_generator', None) | |
| if patch_gen is not None: | |
| return patch_gen.num_cls_tokens | |
| elif getattr(self.model, 'global_pool', None) == 'avg': | |
| return 0 | |
| return 1 | |
| def patch_size(self) -> int: | |
| if self._patch_size is not None: | |
| return self._patch_size | |
| if hasattr(self.model, "patch_size"): | |
| return self.model.patch_size | |
| patch_gen = getattr(self.model, "patch_generator", None) | |
| if patch_gen is not None: | |
| return patch_gen.patch_size | |
| return None | |
| def max_resolution(self) -> int: | |
| return self._max_resolution | |
| def preferred_resolution(self) -> Resolution: | |
| return self._preferred_resolution | |
| def window_size(self) -> int: | |
| return self._window_size | |
| def min_resolution_step(self) -> int: | |
| res = self.patch_size | |
| if self.window_size is not None: | |
| res *= self.window_size | |
| return res | |
| def blocks(self) -> Iterable[nn.Module]: | |
| blocks = getattr(self.model, 'blocks', None) | |
| if blocks is not None: | |
| return blocks | |
| return None | |
| def embed_dim(self) -> int: | |
| return self.model.embed_dim | |
| def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]: | |
| ret = self.input_conditioner | |
| self.input_conditioner = nn.Identity() | |
| return ret | |
| def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution: | |
| height = int(round(height / self.min_resolution_step) * self.min_resolution_step) | |
| width = int(round(width / self.min_resolution_step) * self.min_resolution_step) | |
| height = max(height, self.min_resolution_step) | |
| width = max(width, self.min_resolution_step) | |
| return Resolution(height=height, width=width) | |
| def switch_to_deploy(self): | |
| fn = getattr(self.model, 'switch_to_deploy', None) | |
| if fn is not None: | |
| fn() | |
| def forward(self, x: torch.Tensor, feature_fmt: str = 'NLC') -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | |
| ''' | |
| Forward process for model. | |
| Args: | |
| x: Input tensor. Unless `make_preprocessor_external` has been called, then the dynamic range of `x` is expected to be `[0, 1]`, | |
| otherwise `x` is expected to be mean centered with unit standard deviation. | |
| feature_format: ['NLC', 'NCHW'] - The output format for the features. | |
| ''' | |
| res_step = self.min_resolution_step | |
| if res_step is not None and (x.shape[-2] % res_step != 0 or x.shape[-1] % res_step != 0): | |
| raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. ' | |
| '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. ' | |
| f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}') | |
| x = self.input_conditioner(x) | |
| y = self.model.forward_features(x) | |
| ret = self._extract_final(x, y, feature_fmt=feature_fmt) | |
| return ret | |
| def _extract_final(self, x: torch.Tensor, y: torch.Tensor, feature_fmt: str = 'NLC'): | |
| if isinstance(self.model, VisionTransformer): | |
| patch_gen = getattr(self.model, "patch_generator", None) | |
| if patch_gen is not None: | |
| all_summary = y[:, : patch_gen.num_cls_tokens] | |
| if self.summary_idxs is not None: | |
| bb_summary = all_summary[:, self.summary_idxs] | |
| else: | |
| bb_summary = all_summary | |
| all_feat = y[:, patch_gen.num_skip :] | |
| elif self.model.global_pool == "avg": | |
| all_summary = y[:, self.model.num_prefix_tokens :].mean(dim=1) | |
| bb_summary = all_summary | |
| all_feat = y | |
| else: | |
| all_summary = y[:, 0] | |
| bb_summary = all_summary | |
| all_feat = y[:, 1:] | |
| elif isinstance(self.model, eradio_model.ERADIO): | |
| _, f = y | |
| all_feat = f.flatten(2).transpose(1, 2) | |
| all_summary = all_feat.mean(dim=1) | |
| bb_summary = all_summary | |
| elif isinstance(y, (list, tuple)): | |
| all_summary, all_feat = y | |
| bb_summary = all_summary | |
| else: | |
| all_summary = y[:, :self.num_cls_tokens] | |
| if self.summary_idxs is not None and all_summary.shape[1] > 1: | |
| if all_summary.shape[1] == 1: | |
| # Create dummy duplicates | |
| all_summary = all_summary.expand(-1, 128, -1) | |
| bb_summary = all_summary[:, self.summary_idxs] | |
| else: | |
| bb_summary = all_summary | |
| all_feat = y[:, self.num_summary_tokens:] | |
| all_feat = self.feature_normalizer(all_feat) | |
| if feature_fmt == 'NCHW': | |
| fmt_feat = (all_feat.reshape(all_feat.shape[0], x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size, all_feat.shape[2]) | |
| .permute(0, 3, 1, 2) | |
| ) | |
| elif feature_fmt == 'NLC': | |
| fmt_feat = all_feat | |
| else: | |
| raise ValueError(f'Unsupported feature_fmt: {feature_fmt}. Must be one of ["NLC", "NCHW"]') | |
| ret = RadioOutput(bb_summary.flatten(1), fmt_feat) | |
| if self.adaptors: | |
| ret = dict(backbone=ret) | |
| for name, adaptor in self.adaptors.items(): | |
| if all_summary.ndim == 3: | |
| if all_summary.shape[1] == 1: | |
| summary = all_summary[:, 0] | |
| else: | |
| summary = all_summary[:, adaptor.head_idx] | |
| else: | |
| summary = all_summary | |
| ada_input = AdaptorInput(images=x, summary=summary.float(), features=all_feat, feature_fmt=feature_fmt, patch_size=self.patch_size) | |
| v = adaptor(ada_input).to(torch.float32) | |
| ret[name] = v | |
| return ret | |
| def forward_intermediates( | |
| self, | |
| x: torch.Tensor, | |
| indices: Optional[Union[int, List[int], Tuple[int]]] = None, | |
| return_prefix_tokens: bool = False, | |
| norm: bool = False, | |
| stop_early: bool = False, | |
| output_fmt: str = 'NCHW', | |
| intermediates_only: bool = False, | |
| aggregation: Optional[str] = "sparse", | |
| norm_alpha_scheme: Optional[str] = "post-alpha", | |
| ) -> List[RadioOutput]: | |
| """ Forward features that returns intermediates. | |
| Args: | |
| x: Input image tensor | |
| indices: Take last n blocks if int, select matching indices if sequence | |
| return_prefix_tokens: Return both prefix and spatial intermediate tokens | |
| norm: Apply norm layer to all intermediates | |
| stop_early: Stop iterating over blocks when last desired intermediate hit | |
| output_fmt: Shape of intermediate feature outputs. Options: NCHW, NLC | |
| intermediates_only: Only return intermediate features | |
| aggregation: intermediate layer aggregation method (sparse or dense). | |
| Dense accumulation is done by averaging the features in each group. | |
| norm_alpha_scheme: apply alpha before ("pre-alpha") or after accumulation ("post-alpha"), or don't normalize ("none") | |
| Only affects dense aggregation | |
| Returns: | |
| List of RadioOutput objects. | |
| """ | |
| x = self.input_conditioner(x) | |
| intermediates = self.model.forward_intermediates( | |
| x, | |
| indices=indices, | |
| return_prefix_tokens=return_prefix_tokens, | |
| norm=norm, | |
| stop_early=stop_early, | |
| output_fmt=output_fmt, | |
| intermediates_only=intermediates_only, | |
| aggregation=aggregation, | |
| inter_feature_normalizer=self.inter_feature_normalizer, | |
| norm_alpha_scheme=norm_alpha_scheme, | |
| ) | |
| if not intermediates_only: | |
| final, intermediates = intermediates | |
| def prepare_summary(summ: Optional[torch.Tensor]): | |
| if summ is None: | |
| return summ | |
| if self.summary_idxs is not None and summ.shape[1] > 1: | |
| summ = summ[:, self.summary_idxs] | |
| return summ.flatten(1) | |
| if return_prefix_tokens: | |
| radio_outputs = [ | |
| RadioOutput(prepare_summary(summary), features) | |
| for summary, features in intermediates | |
| ] | |
| else: | |
| radio_outputs = intermediates | |
| if intermediates_only: | |
| return radio_outputs | |
| else: | |
| final = self._extract_final(x, final, feature_fmt=output_fmt) | |
| return final, radio_outputs | |
| def create_model_from_args(args) -> nn.Module: | |
| in_chans = 3 | |
| if args.in_chans is not None: | |
| in_chans = args.in_chans | |
| elif args.input_size is not None: | |
| in_chans = args.input_size[0] | |
| # Skip weight initialization unless it's explicitly requested. | |
| weight_init = args.model_kwargs.pop("weight_init", "skip") | |
| model = create_model( | |
| args.model, | |
| pretrained=args.pretrained, | |
| in_chans=in_chans, | |
| num_classes=args.num_classes, | |
| drop_rate=args.drop, | |
| drop_path_rate=args.drop_path, | |
| drop_block_rate=args.drop_block, | |
| global_pool=args.gp, | |
| bn_momentum=args.bn_momentum, | |
| bn_eps=args.bn_eps, | |
| scriptable=args.torchscript, | |
| checkpoint_path=args.initial_checkpoint, | |
| weight_init=weight_init, | |
| **args.model_kwargs, | |
| ) | |
| if hasattr(model, 'norm') and not getattr(args, 'model_norm', False): | |
| model.norm = nn.Identity() | |
| model.head = nn.Identity() | |
| if args.cpe_max_size is not None: | |
| uq_teachers = set(t['name'] for t in args.teachers) | |
| enable_cpe( | |
| model, | |
| args.cpe_max_size, | |
| num_cls_tokens=len(uq_teachers) if args.cls_token_per_teacher else 1, | |
| register_multiple=getattr(args, 'register_multiple', None), | |
| num_registers=getattr(args, 'cpe_num_registers', None), | |
| ) | |
| return model | |