# DragStream / optimize_utils.py
# Committed by bowmanchow in 0328207 ("add code").
import json
from pathlib import Path
from attr import dataclass
import numpy as np
import torch
from tensor_utils import (
calculate_angle_from_points,
read_mask_from_file,
save_mask_to_file,
)
@dataclass
class Trajectory:
original_trajectory: dict[str, bool | list[torch.Tensor]] = None
"""
trajectory is dict, keys include 'is_rotation', 'points', if translation also has 'control_points'
"""
block_trajectories: list[dict[str, bool | list[torch.Tensor]]] = []
"""block_num x trajectory
trajectory has keys 'is_rotation' 'deltas' 'start_point'
if is_rotation: trajectory also has 'rotation_center'
"""
mask: np.ndarray = None
"""
target mask for the trajectory
"""
def __init__(
self,
original_trajectory: dict[str, bool | list[torch.Tensor]] = None,
mask: np.ndarray = None,
):
self.original_trajectory = original_trajectory
self.mask = mask
if original_trajectory is not None:
self.block_trajectories = self.original_to_block_trajectories(original_trajectory)
else:
self.block_trajectories = []
@staticmethod
def original_to_block_trajectories(
original_trajectory: dict[str, bool | list[torch.Tensor]],
block_length: int = 3,
) -> list[dict[str, bool | list[torch.Tensor]]]:
"""Convert an original trajectory (with 'points') into per-block trajectories (with 'deltas').
For translation:
deltas[i] = points[i+1] - points[0] (displacement from start)
Each block gets `block_length` consecutive deltas.
For rotation:
points[0] is the rotation center.
deltas[i] = angle(center, points[1], points[i+2])
Each block gets `block_length` consecutive deltas,
plus 'rotation_center' and 'start_point'.
"""
is_rotation = original_trajectory.get("is_rotation", False)
points = original_trajectory.get("points", [])
if is_rotation:
# points[0] = rotation center, points[1] = start arm, points[2:] = subsequent arms
if len(points) < 2:
return []
rotation_center = points[0]
start_point = points[1]
deltas = [
calculate_angle_from_points(
rotation_center,
start_point,
point,
)
for point in points[2:]
]
else:
# Translation: points[0] = start, points[1:] = subsequent positions
if len(points) < 1:
return []
start_point = points[0]
deltas = [torch.Tensor(point) - torch.Tensor(start_point) for point in points[1:]]
block_trajectories = []
for i in range(0, len(deltas), block_length):
block_traj = {
"is_rotation": is_rotation,
"deltas": deltas[i : i + block_length],
"start_point": start_point,
}
if is_rotation:
block_traj["rotation_center"] = rotation_center
block_trajectories.append(block_traj)
return block_trajectories
def set_original_trajectory(
self,
original_trajectory: dict[str, bool | list[torch.Tensor]] = None,
):
self.original_trajectory = original_trajectory
if original_trajectory is not None:
self.block_trajectories = self.original_to_block_trajectories(original_trajectory)
else:
self.block_trajectories = []
@staticmethod
def _serialize_value(
v,
):
"""Recursively serialize a value to JSON-compatible types."""
if isinstance(v, torch.Tensor):
return v.tolist()
elif isinstance(v, np.ndarray):
return v.tolist()
elif isinstance(v, dict):
return {k: Trajectory._serialize_value(val) for k, val in v.items()}
elif isinstance(v, list):
return [Trajectory._serialize_value(item) for item in v]
else:
return v
def to_dict(
self,
mask_filename: str = None,
) -> dict:
"""Convert the Trajectory to a JSON-serializable dictionary.
Args:
mask_filename: If provided, store this filename instead of the mask array.
"""
result = {}
if self.original_trajectory is not None:
result["original_trajectory"] = self._serialize_value(self.original_trajectory)
else:
result["original_trajectory"] = None
result["block_trajectories"] = self._serialize_value(self.block_trajectories)
if mask_filename is not None:
result["mask_file"] = mask_filename
return result
def save_mask(
self,
save_path: Path,
) -> None:
"""Save the mask as a PNG image."""
if self.mask is not None:
save_mask_to_file(self.mask, save_path)
@staticmethod
def load(
data: dict,
save_dir: Path,
) -> "Trajectory":
"""Load a Trajectory from a dictionary and directory."""
traj = Trajectory()
traj.original_trajectory = data.get("original_trajectory", None)
traj.block_trajectories = data.get("block_trajectories", [])
mask_file = data.get("mask_file", None)
if mask_file is not None:
traj.mask = read_mask_from_file(save_dir / mask_file)
return traj
@dataclass(init=False)
class MultiTrajectory:
    """All trajectories for a single prompt, plus a whole-image movable mask.

    ``init=False`` plus the explicit ``__init__`` below give every instance
    its own ``trajectories`` list. A generated ``__init__`` with a class-level
    ``[]`` default would hand the same list to every default-constructed
    instance, so appending to one would leak into all the others.
    """

    block_number: int = 1
    prompt: str = ""
    drag_or_animation_select: str = "Drag"
    # Multiple trajectories for a single prompt; each trajectory has its own mask.
    # Always assigned in __init__, so this None default is never observed.
    trajectories: list[Trajectory] = None
    # The movable-area mask for the whole image.
    movable_mask: np.ndarray = None

    def __init__(
        self,
        block_number: int = 1,
        prompt: str = "",
        drag_or_animation_select: str = "Drag",
        trajectories: list[Trajectory] | None = None,
        movable_mask: np.ndarray = None,
    ):
        """Create a MultiTrajectory.

        Parameter names, order and defaults match the previously generated
        constructor. The one difference is that ``trajectories=None``
        (the default) now yields a fresh, per-instance empty list.
        """
        self.block_number = block_number
        self.prompt = prompt
        self.drag_or_animation_select = drag_or_animation_select
        self.trajectories = [] if trajectories is None else trajectories
        self.movable_mask = movable_mask

    def save(
        self,
        save_dir: str | Path,
        prefix: str = "multi_traj",
    ) -> Path:
        """Save the MultiTrajectory to a directory.

        Masks are written via ``save_mask_to_file`` to files named
        ``{prefix}_movable_mask.png`` and ``{prefix}_traj_{i}_mask.png``.
        Metadata goes to ``{prefix}_trajectory.json``.

        Args:
            save_dir: Directory to save files into (created if missing).
            prefix: Filename prefix for all saved files.

        Returns:
            Path to the saved JSON file.
        """
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        result = {
            "block_number": self.block_number,
            "prompt": self.prompt,
            "drag_or_animation_select": self.drag_or_animation_select,
            # Filled in below if a movable mask exists; the key keeps its position.
            "movable_area_mask_file": None,
        }
        if self.movable_mask is not None:
            movable_mask_filename = f"{prefix}_movable_mask.png"
            save_mask_to_file(self.movable_mask, save_dir / movable_mask_filename)
            result["movable_area_mask_file"] = movable_mask_filename
        # Save each trajectory and its mask; a None trajectories list is treated as empty.
        traj_dicts = []
        for i, traj in enumerate(self.trajectories or []):
            mask_filename = None
            if traj.mask is not None:
                mask_filename = f"{prefix}_traj_{i}_mask.png"
                traj.save_mask(save_dir / mask_filename)
            traj_dicts.append(traj.to_dict(mask_filename=mask_filename))
        result["trajectories"] = traj_dicts
        json_path = save_dir / f"{prefix}_trajectory.json"
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)
        return json_path

    @staticmethod
    def load(
        save_dir: str | Path,
        prefix: str = "multi_traj",
    ) -> "MultiTrajectory":
        """Load a MultiTrajectory previously written by ``save``.

        Args:
            save_dir: Directory containing the saved files.
            prefix: Filename prefix used when saving.

        Returns:
            The loaded MultiTrajectory. Missing metadata keys fall back to
            the constructor defaults.
        """
        save_dir = Path(save_dir)
        json_path = save_dir / f"{prefix}_trajectory.json"
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Read the movable mask before the per-trajectory masks (original order).
        movable_file = data.get("movable_area_mask_file", None)
        movable_mask = None
        if movable_file is not None:
            movable_mask = read_mask_from_file(save_dir / movable_file)
        # `or []` also covers an explicit null in the JSON.
        trajectories = [
            Trajectory.load(traj_data, save_dir)
            for traj_data in data.get("trajectories") or []
        ]
        return MultiTrajectory(
            block_number=data.get("block_number", 1),
            prompt=data.get("prompt", ""),
            drag_or_animation_select=data.get("drag_or_animation_select", "Drag"),
            trajectories=trajectories,
            movable_mask=movable_mask,
        )
def transpose_dict_2d(d):
    """Swap the two key levels of a nested dict.

    Given ``d[outer][inner] = value``, return a new plain dict ``t`` with
    ``t[inner][outer] = value``. Inner keys appear in first-seen order, and
    the input is not modified.
    """
    transposed = {}
    for outer_key, row in d.items():
        for inner_key, value in row.items():
            if inner_key not in transposed:
                transposed[inner_key] = {}
            transposed[inner_key][outer_key] = value
    return transposed