| from typing import Optional |
| import torch |
|
|
| from src.models.score import so3, r3 |
| from src.common.rigid_utils import Rigid, Rotation, quat_multiply |
| from src.common import rotation3d |
|
|
|
|
def assemble_rigid(rotvec: torch.Tensor, trans: torch.Tensor):
    """Build a ``Rigid`` from rotation vectors and translations.

    Args:
        rotvec: rotation vectors (axis-angle); the last dimension is replaced
            by a 3x3 matrix, so presumably [..., 3].
        trans: translations, passed to ``Rigid`` unchanged.

    Returns:
        A ``Rigid`` whose rotation matrices have shape
        ``rotvec.shape[:-1] + (3, 3)``.
    """
    batch_dims = rotvec.shape[:-1]
    # Explicitly view the converted rotations as [..., 3, 3] matrices.
    rot_mats = rotation3d.axis_angle_to_matrix(rotvec).view(batch_dims + (3, 3))
    rotation = Rotation(rot_mats=rot_mats)
    return Rigid(rots=rotation, trans=trans)
|
|
def apply_mask(x_tgt, x_src, tgt_mask):
    """Select ``x_tgt`` where the mask is set and ``x_src`` elsewhere.

    For numeric masks this computes ``tgt_mask * x_tgt + (1 - tgt_mask) * x_src``,
    so fractional mask values give a linear blend. Note that with a numeric
    mask a NaN/inf in the unselected input still propagates (0 * nan = nan).

    Boolean torch masks are also accepted. torch does not support ``1 - mask``
    on bool tensors, so they are handled with ``torch.where`` instead.

    Args:
        x_tgt: values used where ``tgt_mask`` is 1 / True.
        x_src: values used where ``tgt_mask`` is 0 / False.
        tgt_mask: mask broadcastable against ``x_tgt`` and ``x_src``.

    Returns:
        The masked combination of ``x_tgt`` and ``x_src``.
    """
    if isinstance(tgt_mask, torch.Tensor) and tgt_mask.dtype == torch.bool:
        return torch.where(tgt_mask, x_tgt, x_src)
    return tgt_mask * x_tgt + (1 - tgt_mask) * x_src
|
|
|
|
class FrameDiffuser:
    """
    Wrapper class for diffusion of rigid body transformations,
    including rotations and translations.

    Rotations are handled as axis-angle vectors by ``rot_diffuser`` and
    translations by ``trans_diffuser``. Either diffuser may be ``None``. In
    that case the component is left undiffused: it keeps its input value, gets
    a zero score, and gets a score scaling of ones.
    """

    def __init__(
        self,
        trans_diffuser: Optional[r3.R3Diffuser] = None,
        rot_diffuser: Optional[so3.SO3Diffuser] = None,
        min_t: float = 0.001,
    ):
        """
        Args:
            trans_diffuser: diffuser for translations; ``None`` disables
                translation diffusion.
            rot_diffuser: diffuser for axis-angle rotations; ``None`` disables
                rotation diffusion.
            min_t: minimum diffusion time. It is stored on the instance but not
                read elsewhere in this class.
        """
        self.trans_diffuser = trans_diffuser
        self.rot_diffuser = rot_diffuser
        self.min_t = min_t

    @staticmethod
    def _prepare_mask(mask, like: torch.Tensor) -> torch.Tensor:
        """Convert a [..., N] mask for use with ``apply_mask``.

        Casts the mask to ``like``'s device and dtype, so bool or off-device
        masks work. Adds a trailing axis so the mask broadcasts over
        [..., N, 3] tensors.
        """
        return torch.as_tensor(mask, device=like.device, dtype=like.dtype)[..., None]

    def forward_marginal(
        self,
        rigids_0: Rigid,
        t: torch.Tensor,
        diffuse_mask: Optional[torch.Tensor] = None,
        as_tensor_7: bool = True,
    ):
        """Noise clean rigids to time ``t`` using each diffuser's forward marginal.

        Args:
            rigids_0: [..., N] openfold Rigid objects
            t: continuous time in [0, 1].
            diffuse_mask: optional [..., N] mask. Where it is 0, the clean
                rigid is kept and both scores are zeroed.
            as_tensor_7: if true, return ``rigids_t`` as a [..., N, 7] tensor.

        Returns:
            Dict contains:
                rigids_t: [..., N] noised rigid. [..., N, 7] if as_tensor_7 is true.
                trans_score: [..., N, 3] translation score
                rot_score: [..., N, 3] rotation score
                trans_score_scaling: translation score scaling at ``t``
                    (``ones_like(t)`` when translations are not diffused)
                rot_score_scaling: rotation score scaling at ``t``
                    (``ones_like(t)`` when rotations are not diffused)
        """
        rot_0 = rotation3d.matrix_to_axis_angle(rigids_0.get_rots().get_rot_mats())
        trans_0 = rigids_0.get_trans()

        # Rotation is sampled before translation; keep this order so RNG
        # consumption is unchanged.
        if self.rot_diffuser is None:
            # Not diffused: keep the clean rotation, use a zero score, and use
            # a neutral scaling (consistent with the translation branch).
            rot_t = rot_0
            rot_score = torch.zeros_like(rot_0)
            rot_score_scaling = torch.ones_like(t)
        else:
            rot_t, rot_score = self.rot_diffuser.forward_marginal(rot_0, t)
            rot_score_scaling = self.rot_diffuser.score_scaling(t)

        if self.trans_diffuser is None:
            trans_t = trans_0
            trans_score = torch.zeros_like(trans_0)
            trans_score_scaling = torch.ones_like(t)
        else:
            trans_t, trans_score = self.trans_diffuser.forward_marginal(trans_0, t)
            trans_score_scaling = self.trans_diffuser.score_scaling(t)

        if diffuse_mask is not None:
            mask = self._prepare_mask(diffuse_mask, trans_t)
            # Masked-out residues keep their clean frames and get zero scores.
            rot_t = apply_mask(rot_t, rot_0, mask)
            trans_t = apply_mask(trans_t, trans_0, mask)
            trans_score = apply_mask(trans_score, torch.zeros_like(trans_score), mask)
            rot_score = apply_mask(rot_score, torch.zeros_like(rot_score), mask)

        rigids_t = assemble_rigid(rot_t, trans_t)
        if as_tensor_7:
            rigids_t = rigids_t.to_tensor_7()

        return {
            'rigids_t': rigids_t,
            'trans_score': trans_score,
            'rot_score': rot_score,
            'trans_score_scaling': trans_score_scaling,
            'rot_score_scaling': rot_score_scaling,
        }

    def score(
        self,
        rigids_0: Rigid,
        rigids_t: Rigid,
        t: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ):
        """Compute the translation and rotation scores of ``rigids_t`` given ``rigids_0``.

        Args:
            rigids_0: [..., N] clean rigids.
            rigids_t: [..., N] noised rigids at time ``t``.
            t: continuous time in [0, 1].
            mask: optional [..., N] mask multiplied into both scores.

        Returns:
            Dict with ``trans_score`` and ``rot_score``. A component whose
            diffuser is ``None`` gets an all-zero score.
        """
        rot_0, trans_0 = rigids_0.get_rots(), rigids_0.get_trans()
        rot_t, trans_t = rigids_t.get_rots(), rigids_t.get_trans()

        if self.rot_diffuser is None:
            # ``rot_0`` is a Rotation object, not a tensor, so it cannot be
            # passed to ``zeros_like``. The rotation score is a per-residue
            # 3-vector, like the translations.
            rot_score = torch.zeros_like(trans_0)
        else:
            # Relative rotation rot_0^{-1} * rot_t, computed via quaternions
            # and converted to an axis-angle vector.
            quat_0_inv = rotation3d.matrix_to_quaternion(rot_0.invert().get_rot_mats())
            quat_t = rotation3d.matrix_to_quaternion(rot_t.get_rot_mats())
            rotvec_0t = rotation3d.quaternion_to_axis_angle(quat_multiply(quat_0_inv, quat_t))
            rot_score = self.rot_diffuser.score(rotvec_0t, t)

        if self.trans_diffuser is None:
            trans_score = torch.zeros_like(trans_0)
        else:
            trans_score = self.trans_diffuser.score(trans_t, trans_0, t, scale=True)

        if mask is not None:
            trans_score = trans_score * mask[..., None]
            rot_score = rot_score * mask[..., None]

        return {
            'trans_score': trans_score,
            'rot_score': rot_score
        }

    def score_scaling(self, t: torch.Tensor):
        """Return the translation and rotation score scalings at time ``t``.

        A component whose diffuser is ``None`` gets ``torch.ones_like(t)``,
        the same value ``forward_marginal`` reports for it.
        """
        if self.rot_diffuser is None:
            rot_score_scaling = torch.ones_like(t)
        else:
            rot_score_scaling = self.rot_diffuser.score_scaling(t)

        if self.trans_diffuser is None:
            trans_score_scaling = torch.ones_like(t)
        else:
            trans_score_scaling = self.trans_diffuser.score_scaling(t)

        return {
            'trans_score_scaling': trans_score_scaling,
            'rot_score_scaling': rot_score_scaling,
        }

    def reverse(
        self,
        rigids_t: Rigid,
        rot_score: torch.Tensor,
        trans_score: torch.Tensor,
        t: torch.Tensor,
        dt: float,
        diffuse_mask: Optional[torch.Tensor] = None,
        center_trans: bool = True,
        noise_scale: float = 1.0,
        probability_flow: bool = True,
    ):
        """Take one reverse-diffusion step of size ``dt`` from time ``t``.

        Args:
            rigids_t: [..., N] protein rigid objects at time t.
            rot_score: [..., N, 3] rotation score.
            trans_score: [..., N, 3] translation score.
            t: continuous time in [0, 1].
            dt: continuous step size in [0, 1].
            diffuse_mask: optional [..., N] mask of residues to update. Where it
                is 0, the input frame is kept unchanged.
            center_trans: forwarded as ``center`` to the translation diffuser
                (true to set the center of mass to zero after the step).
            noise_scale: forwarded to both diffusers' ``reverse``.
            probability_flow: whether to use the probability flow ODE.

        Returns:
            [..., N] ``Rigid`` after the step. A component whose diffuser is
            ``None`` is returned unchanged.
        """
        rot_t = rotation3d.matrix_to_axis_angle(rigids_t.get_rots().get_rot_mats())
        trans_t = rigids_t.get_trans()

        if self.rot_diffuser is None:
            rot_t_1 = rot_t
        else:
            rot_t_1 = self.rot_diffuser.reverse(
                rot_t=rot_t,
                score_t=rot_score,
                t=t,
                dt=dt,
                noise_scale=noise_scale,
                probability_flow=probability_flow,
            )

        if self.trans_diffuser is None:
            trans_t_1 = trans_t
        else:
            trans_t_1 = self.trans_diffuser.reverse(
                x_t=trans_t,
                score_t=trans_score,
                t=t,
                dt=dt,
                center=center_trans,
                noise_scale=noise_scale,
                probability_flow=probability_flow,
            )

        if diffuse_mask is not None:
            mask = self._prepare_mask(diffuse_mask, trans_t)
            trans_t_1 = apply_mask(trans_t_1, trans_t, mask)
            rot_t_1 = apply_mask(rot_t_1, rot_t, mask)

        return assemble_rigid(rot_t_1, trans_t_1)

    def sample_prior(
        self,
        shape: torch.Size,
        device: torch.device,
        reference_rigids: Optional[Rigid] = None,
        diffuse_mask: Optional[torch.Tensor] = None,
        as_tensor_7: bool = False
    ):
        """Samples rigids from the reference (prior) distribution.

        Args:
            shape: [..., N] batch shape of the rigids to sample (each rotation
                and translation sample has shape ``shape + (3,)``).
            device: device on which prior samples are drawn.
            reference_rigids: optional [..., N] rigids. Their values are kept
                where ``diffuse_mask`` is 0, and are used directly for any
                component whose diffuser is ``None``. Requires ``diffuse_mask``.
            diffuse_mask: [..., N] mask. It must be given if and only if
                ``reference_rigids`` is given.
            as_tensor_7: if true, return ``rigids_t`` as a [..., N, 7] tensor.

        Returns:
            Dict with key ``rigids_t``.
        """
        if reference_rigids is not None:
            # Compare against the translation tensor's batch dims ([..., N]).
            # This is unambiguous, unlike ``Rigid.shape``, which in OpenFold
            # already excludes the trailing 3.
            ref_batch_shape = reference_rigids.get_trans().shape[:-1]
            assert ref_batch_shape == shape, f"reference_rigids batch shape = {ref_batch_shape}, shape = {shape}"
            assert diffuse_mask is not None, "diffuse_mask must be provided if reference_rigids is given"
            rot_ref = rotation3d.matrix_to_axis_angle(reference_rigids.get_rots().get_rot_mats())
            trans_ref = reference_rigids.get_trans()
            if self.trans_diffuser is not None:
                # The blended sample is unscaled below, so move the reference
                # into the diffuser's scaled space before blending.
                trans_ref = self.trans_diffuser.scale(trans_ref)
        else:
            # Without a reference, every component must come from a prior.
            assert diffuse_mask is None, "diffuse_mask must be None if reference_rigids is None"
            assert self.rot_diffuser is not None and self.trans_diffuser is not None

        # Sample rotation before translation (preserves RNG order).
        if self.rot_diffuser is not None:
            rot_sample = self.rot_diffuser.sample_prior(shape=shape + (3, ), device=device)
        else:
            rot_sample = rot_ref
        if self.trans_diffuser is not None:
            trans_sample = self.trans_diffuser.sample_prior(shape=shape + (3, ), device=device)
        else:
            trans_sample = trans_ref

        if diffuse_mask is not None:
            mask = self._prepare_mask(diffuse_mask, trans_sample)
            rot_sample = apply_mask(rot_sample, rot_ref, mask)
            trans_sample = apply_mask(trans_sample, trans_ref, mask)

        if self.trans_diffuser is not None:
            trans_sample = self.trans_diffuser.unscale(trans_sample)

        rigids_t = assemble_rigid(rot_sample, trans_sample)
        if as_tensor_7:
            rigids_t = rigids_t.to_tensor_7()

        return {'rigids_t': rigids_t}