Instructions to use Synthyra/Boltz2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/Boltz2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Synthyra/Boltz2", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/Boltz2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import importlib | |
| import torch | |
| from torch import Tensor, nn | |
| from . import vb_layers_initialize as init | |
def kernel_triangular_mult(
    x,
    direction,
    mask,
    norm_in_weight,
    norm_in_bias,
    p_in_weight,
    g_in_weight,
    norm_out_weight,
    norm_out_bias,
    p_out_weight,
    g_out_weight,
    eps,
):
    """Run the fused cuequivariance triangle multiplicative update.

    The ``cuequivariance_torch`` dependency is resolved lazily, on the first
    call, so this module can be imported without it. If the package is not
    installed, the ``import_module`` call raises ``ModuleNotFoundError``.

    Parameters
    ----------
    x:
        Input tensor, passed positionally to the kernel.
    direction:
        Edge orientation forwarded to the kernel (the callers in this file
        pass ``"outgoing"`` or ``"incoming"``).
    mask, norm_in_weight, norm_in_bias, p_in_weight, g_in_weight,
    norm_out_weight, norm_out_bias, p_out_weight, g_out_weight, eps:
        Forwarded unchanged, by keyword, to
        ``triangle_multiplicative_update``.

    Returns
    -------
    Whatever ``triangle_multiplicative_update`` returns.
    """
    triangle_primitives = importlib.import_module(
        "cuequivariance_torch.primitives.triangle"
    )
    fused_update = triangle_primitives.triangle_multiplicative_update

    # Every argument except the input tensor goes through by keyword, so
    # the kernel's own parameter order does not matter here.
    kernel_kwargs = dict(
        direction=direction,
        mask=mask,
        norm_in_weight=norm_in_weight,
        norm_in_bias=norm_in_bias,
        p_in_weight=p_in_weight,
        g_in_weight=g_in_weight,
        norm_out_weight=norm_out_weight,
        norm_out_bias=norm_out_bias,
        p_out_weight=p_out_weight,
        g_out_weight=g_out_weight,
        eps=eps,
    )
    return fused_update(x, **kernel_kwargs)
class _TriangleMultiplicationBase(nn.Module):
    """Shared gated triangular multiplicative update (private base class).

    ``TriangleMultiplicationOutgoing`` and ``TriangleMultiplicationIncoming``
    differ only in which pair entries are contracted. The parameters, their
    initialization and the forward pass therefore live here. Subclasses set
    two class attributes:

    * ``_direction``: the ``direction`` string handed to
      ``kernel_triangular_mult`` when ``use_kernels`` is true.
    * ``_equation``: the ``torch.einsum`` equation used by the pure PyTorch
      path.
    """

    _direction: str = ""
    _equation: str = ""

    def __init__(self, dim: int = 128, eps: float = 1e-5) -> None:
        """Initialize the TriangularUpdate module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128
        eps: float
            Epsilon shared by both layer norms and forwarded to the fused
            kernel, default 1e-5 (the value previously hard-coded)
        """
        super().__init__()

        # Creation and initialization order match the original per-class
        # implementations, so seeded runs build the same weights.
        self.norm_in = nn.LayerNorm(dim, eps=eps)
        self.p_in = nn.Linear(dim, 2 * dim, bias=False)
        self.g_in = nn.Linear(dim, 2 * dim, bias=False)

        self.norm_out = nn.LayerNorm(dim, eps=eps)
        self.p_out = nn.Linear(dim, dim, bias=False)
        self.g_out = nn.Linear(dim, dim, bias=False)

        init.bias_init_one_(self.norm_in.weight)
        init.bias_init_zero_(self.norm_in.bias)

        init.lecun_normal_init_(self.p_in.weight)
        init.gating_init_(self.g_in.weight)

        init.bias_init_one_(self.norm_out.weight)
        init.bias_init_zero_(self.norm_out.bias)

        init.final_init_(self.p_out.weight)
        init.gating_init_(self.g_out.weight)

    def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (B, N, N, D)
        mask: torch.Tensor
            The input mask of shape (B, N, N); it is multiplied into the
            gated projection, so zero entries drop those pairs from the
            contraction
        use_kernels: bool
            Whether to use the fused cuequivariance kernel instead of the
            PyTorch implementation

        Returns
        -------
        x: torch.Tensor
            The output data of shape (B, N, N, D)
        """
        if use_kernels:
            # The fused kernel accepts a single eps. Both norms are built
            # from the same value, so reading it from norm_in is exact.
            return kernel_triangular_mult(
                x,
                direction=self._direction,
                mask=mask,
                norm_in_weight=self.norm_in.weight,
                norm_in_bias=self.norm_in.bias,
                p_in_weight=self.p_in.weight,
                g_in_weight=self.g_in.weight,
                norm_out_weight=self.norm_out.weight,
                norm_out_bias=self.norm_out.bias,
                p_out_weight=self.p_out.weight,
                g_out_weight=self.g_out.weight,
                eps=self.norm_in.eps,
            )

        # Input gating: D -> 2D, gated elementwise by a sigmoid
        x = self.norm_in(x)
        x_in = x  # normalized input, reused by the output gate
        x = self.p_in(x) * self.g_in(x).sigmoid()

        # Apply mask
        x = x * mask.unsqueeze(-1)

        # Split into the two D-wide operands and contract in float32
        a, b = torch.chunk(x.float(), 2, dim=-1)

        # Triangular projection: sum over the shared index k
        x = torch.einsum(self._equation, a, b)

        # Output gating: normalize, project D -> D, gate from the input
        x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
        return x


class TriangleMultiplicationOutgoing(_TriangleMultiplicationBase):
    """TriangleMultiplicationOutgoing.

    Output entry (i, j) sums over k the product of the first half of the
    gated projection at pair (i, k) with the second half at pair (j, k).
    """

    _direction = "outgoing"
    _equation = "bikd,bjkd->bijd"


class TriangleMultiplicationIncoming(_TriangleMultiplicationBase):
    """TriangleMultiplicationIncoming.

    Output entry (i, j) sums over k the product of the first half of the
    gated projection at pair (k, i) with the second half at pair (k, j).
    """

    _direction = "incoming"
    _equation = "bkid,bkjd->bijd"