Spaces:
Runtime error
Runtime error
| import json | |
| from pathlib import Path | |
| from attr import dataclass | |
| import numpy as np | |
| import torch | |
| from tensor_utils import ( | |
| calculate_angle_from_points, | |
| read_mask_from_file, | |
| save_mask_to_file, | |
| ) | |
| class Trajectory: | |
| original_trajectory: dict[str, bool | list[torch.Tensor]] = None | |
| """ | |
| trajectory is dict, keys include 'is_rotation', 'points', if translation also has 'control_points' | |
| """ | |
| block_trajectories: list[dict[str, bool | list[torch.Tensor]]] = [] | |
| """block_num x trajectory | |
| trajectory has keys 'is_rotation' 'deltas' 'start_point' | |
| if is_rotation: trajectory also has 'rotation_center' | |
| """ | |
| mask: np.ndarray = None | |
| """ | |
| target mask for the trajectory | |
| """ | |
| def __init__( | |
| self, | |
| original_trajectory: dict[str, bool | list[torch.Tensor]] = None, | |
| mask: np.ndarray = None, | |
| ): | |
| self.original_trajectory = original_trajectory | |
| self.mask = mask | |
| if original_trajectory is not None: | |
| self.block_trajectories = self.original_to_block_trajectories(original_trajectory) | |
| else: | |
| self.block_trajectories = [] | |
| def original_to_block_trajectories( | |
| original_trajectory: dict[str, bool | list[torch.Tensor]], | |
| block_length: int = 3, | |
| ) -> list[dict[str, bool | list[torch.Tensor]]]: | |
| """Convert an original trajectory (with 'points') into per-block trajectories (with 'deltas'). | |
| For translation: | |
| deltas[i] = points[i+1] - points[0] (displacement from start) | |
| Each block gets `block_length` consecutive deltas. | |
| For rotation: | |
| points[0] is the rotation center. | |
| deltas[i] = angle(center, points[1], points[i+2]) | |
| Each block gets `block_length` consecutive deltas, | |
| plus 'rotation_center' and 'start_point'. | |
| """ | |
| is_rotation = original_trajectory.get("is_rotation", False) | |
| points = original_trajectory.get("points", []) | |
| if is_rotation: | |
| # points[0] = rotation center, points[1] = start arm, points[2:] = subsequent arms | |
| if len(points) < 2: | |
| return [] | |
| rotation_center = points[0] | |
| start_point = points[1] | |
| deltas = [ | |
| calculate_angle_from_points( | |
| rotation_center, | |
| start_point, | |
| point, | |
| ) | |
| for point in points[2:] | |
| ] | |
| else: | |
| # Translation: points[0] = start, points[1:] = subsequent positions | |
| if len(points) < 1: | |
| return [] | |
| start_point = points[0] | |
| deltas = [torch.Tensor(point) - torch.Tensor(start_point) for point in points[1:]] | |
| block_trajectories = [] | |
| for i in range(0, len(deltas), block_length): | |
| block_traj = { | |
| "is_rotation": is_rotation, | |
| "deltas": deltas[i : i + block_length], | |
| "start_point": start_point, | |
| } | |
| if is_rotation: | |
| block_traj["rotation_center"] = rotation_center | |
| block_trajectories.append(block_traj) | |
| return block_trajectories | |
| def set_original_trajectory( | |
| self, | |
| original_trajectory: dict[str, bool | list[torch.Tensor]] = None, | |
| ): | |
| self.original_trajectory = original_trajectory | |
| if original_trajectory is not None: | |
| self.block_trajectories = self.original_to_block_trajectories(original_trajectory) | |
| else: | |
| self.block_trajectories = [] | |
| def _serialize_value( | |
| v, | |
| ): | |
| """Recursively serialize a value to JSON-compatible types.""" | |
| if isinstance(v, torch.Tensor): | |
| return v.tolist() | |
| elif isinstance(v, np.ndarray): | |
| return v.tolist() | |
| elif isinstance(v, dict): | |
| return {k: Trajectory._serialize_value(val) for k, val in v.items()} | |
| elif isinstance(v, list): | |
| return [Trajectory._serialize_value(item) for item in v] | |
| else: | |
| return v | |
| def to_dict( | |
| self, | |
| mask_filename: str = None, | |
| ) -> dict: | |
| """Convert the Trajectory to a JSON-serializable dictionary. | |
| Args: | |
| mask_filename: If provided, store this filename instead of the mask array. | |
| """ | |
| result = {} | |
| if self.original_trajectory is not None: | |
| result["original_trajectory"] = self._serialize_value(self.original_trajectory) | |
| else: | |
| result["original_trajectory"] = None | |
| result["block_trajectories"] = self._serialize_value(self.block_trajectories) | |
| if mask_filename is not None: | |
| result["mask_file"] = mask_filename | |
| return result | |
| def save_mask( | |
| self, | |
| save_path: Path, | |
| ) -> None: | |
| """Save the mask as a PNG image.""" | |
| if self.mask is not None: | |
| save_mask_to_file(self.mask, save_path) | |
| def load( | |
| data: dict, | |
| save_dir: Path, | |
| ) -> "Trajectory": | |
| """Load a Trajectory from a dictionary and directory.""" | |
| traj = Trajectory() | |
| traj.original_trajectory = data.get("original_trajectory", None) | |
| traj.block_trajectories = data.get("block_trajectories", []) | |
| mask_file = data.get("mask_file", None) | |
| if mask_file is not None: | |
| traj.mask = read_mask_from_file(save_dir / mask_file) | |
| return traj | |
| class MultiTrajectory: | |
| block_number: int = 1 | |
| prompt: str = "" | |
| drag_or_animation_select: str = "Drag" | |
| trajectories: list[Trajectory] = [] | |
| """ | |
| multiple trajectories for a single prompt, each trajectory has its own mask | |
| """ | |
| movable_mask: np.ndarray = None | |
| """ | |
| the movable area mask for the whole image | |
| """ | |
| def save( | |
| self, | |
| save_dir: str | Path, | |
| prefix: str = "multi_traj", | |
| ) -> Path: | |
| """Save the MultiTrajectory to a directory. | |
| Masks are saved as PNG images, and metadata is saved as a JSON file. | |
| Args: | |
| save_dir: Directory to save files into. | |
| prefix: Filename prefix for all saved files. | |
| Returns: | |
| Path to the saved JSON file. | |
| """ | |
| save_dir = Path(save_dir) | |
| save_dir.mkdir(parents=True, exist_ok=True) | |
| result = { | |
| "block_number": self.block_number, | |
| "prompt": self.prompt, | |
| "drag_or_animation_select": self.drag_or_animation_select, | |
| } | |
| # Save movable_mask | |
| if self.movable_mask is not None: | |
| movable_mask_filename = f"{prefix}_movable_mask.png" | |
| save_mask_to_file(self.movable_mask, save_dir / movable_mask_filename) | |
| result["movable_area_mask_file"] = movable_mask_filename | |
| else: | |
| result["movable_area_mask_file"] = None | |
| # Save each trajectory and its mask | |
| traj_dicts = [] | |
| if self.trajectories is not None: | |
| for i, traj in enumerate(self.trajectories): | |
| mask_filename = None | |
| if traj.mask is not None: | |
| mask_filename = f"{prefix}_traj_{i}_mask.png" | |
| traj.save_mask(save_dir / mask_filename) | |
| traj_dicts.append(traj.to_dict(mask_filename=mask_filename)) | |
| result["trajectories"] = traj_dicts | |
| # Write JSON | |
| json_path = save_dir / f"{prefix}_trajectory.json" | |
| with open(json_path, "w") as f: | |
| json.dump(result, f, indent=2) | |
| return json_path | |
| def load( | |
| save_dir: str | Path, | |
| prefix: str = "multi_traj", | |
| ) -> "MultiTrajectory": | |
| """Load a MultiTrajectory from a directory. | |
| Args: | |
| save_dir: Directory containing the saved files. | |
| prefix: Filename prefix used when saving. | |
| Returns: | |
| The loaded MultiTrajectory instance. | |
| """ | |
| save_dir = Path(save_dir) | |
| json_path = save_dir / f"{prefix}_trajectory.json" | |
| with open(json_path, "r") as f: | |
| data = json.load(f) | |
| mt = MultiTrajectory() | |
| mt.block_number = data.get("block_number", 1) | |
| mt.prompt = data.get("prompt", "") | |
| mt.drag_or_animation_select = data.get("drag_or_animation_select", "Drag") | |
| # Load movable_mask | |
| movable_file = data.get("movable_area_mask_file", None) | |
| if movable_file is not None: | |
| mt.movable_mask = read_mask_from_file(save_dir / movable_file) | |
| # Load trajectories | |
| mt.trajectories = [] | |
| for traj_data in data.get("trajectories", []): | |
| mt.trajectories.append(Trajectory.load(traj_data, save_dir)) | |
| return mt | |
| def transpose_dict_2d(d): | |
| """Transpose a 2D dict: dict[key1][key2] -> dict[key2][key1].""" | |
| result = {} | |
| for key1, inner in d.items(): | |
| for key2, item in inner.items(): | |
| result.setdefault(key2, {})[key1] = item | |
| return result | |