refactor: improve documentation, db5 preprocessing, and removed unused downstream tasks
d9b6f8e unverified | import os | |
| from typing import Tuple | |
| import h5py | |
| import numpy as np | |
| import scipy.io | |
| import scipy.signal as signal | |
| from scipy.signal import iirnotch | |
| from tqdm import tqdm | |
def sequence_to_seconds(seq_len: int, fs: float) -> float:
    """Express a length given in samples as a duration in seconds.

    Args:
        seq_len (int): Length of the sequence, counted in samples.
        fs (float): Sampling frequency in Hz.

    Returns:
        float: ``seq_len`` divided by ``fs``.
    """
    duration_s = seq_len / fs
    return duration_s
def random_amplitude_scale(
    sig: np.ndarray, scale_range: Tuple[float, float] = (0.9, 1.1)
) -> np.ndarray:
    """Multiply the whole signal by one randomly drawn gain.

    A single factor is sampled uniformly from ``scale_range`` and applied to
    every element of ``sig``; the input array is not modified.

    Args:
        sig (np.ndarray): The input signal array of shape (T, D).
        scale_range (Tuple[float, float], optional): Lower and upper bound of
            the uniform distribution the gain is drawn from.
            Defaults to (0.9, 1.1).

    Returns:
        np.ndarray: A new array holding the scaled signal.
    """
    lower, upper = scale_range
    gain = np.random.uniform(lower, upper)
    scaled = sig * gain
    return scaled
def random_time_jitter(sig: np.ndarray, jitter_ratio: float = 0.01) -> np.ndarray:
    """Perturb the signal with additive Gaussian noise scaled per channel.

    Standard-normal noise of the same shape as ``sig`` is multiplied by
    ``jitter_ratio`` times each channel's standard deviation (computed along
    the time axis) and added to the signal. The input array is not modified.

    Args:
        sig (np.ndarray): The input signal array of shape (T, D).
        jitter_ratio (float, optional): Noise amplitude relative to each
            channel's standard deviation. Defaults to 0.01.

    Returns:
        np.ndarray: A new array holding the signal plus noise.
    """
    n_samples, n_channels = sig.shape
    channel_sigma = np.std(sig, axis=0)
    # One noise amplitude per channel, broadcast over the time axis.
    noise_amplitude = jitter_ratio * channel_sigma
    perturbation = np.random.randn(n_samples, n_channels) * noise_amplitude
    return sig + perturbation
def random_channel_dropout(sig: np.ndarray, dropout_prob: float = 0.05) -> np.ndarray:
    """Randomly zero out whole channels of the signal.

    Each channel (column) is dropped independently with probability
    ``dropout_prob``. The operation works on a copy, so the caller's array is
    left untouched.

    Args:
        sig (np.ndarray): The input signal array of shape (T, D).
        dropout_prob (float, optional): Probability of dropping each channel.
            Defaults to 0.05.

    Returns:
        np.ndarray: A new array with the selected channels set to zero.
    """
    _, n_channels = sig.shape
    drop_mask = np.random.rand(n_channels) < dropout_prob
    # Copy first: zeroing columns in place would silently corrupt the
    # caller's data when the function is used outside augment_one_sample.
    out = sig.copy()
    out[:, drop_mask] = 0.0
    return out
def augment_one_sample(seg: np.ndarray) -> np.ndarray:
    """Run one segment through the full random augmentation chain.

    The segment is copied, then passed in order through amplitude scaling
    (range 0.9-1.1), Gaussian jitter (ratio 0.01) and channel dropout
    (probability 0.05).

    Args:
        seg (np.ndarray): Single signal segment of shape (window_size, n_ch).

    Returns:
        np.ndarray: The augmented signal segment.
    """
    # Work on a private copy so the source segment is never touched.
    augmented = random_amplitude_scale(seg.copy(), (0.9, 1.1))
    augmented = random_time_jitter(augmented, 0.01)
    return random_channel_dropout(augmented, 0.05)
def augment_train_data(
    data: np.ndarray, labels: np.ndarray, factor: int = 3
) -> Tuple[np.ndarray, np.ndarray]:
    """Enlarge a training set with randomly augmented copies of every sample.

    The output starts with the original samples, followed by ``factor``
    augmented versions of sample 0, then of sample 1, and so on. When
    ``factor`` is not positive or the dataset is empty, the inputs are
    returned unchanged.

    Args:
        data (np.ndarray): The input dataset of shape (N, window_size, n_ch).
        labels (np.ndarray): The corresponding labels of shape (N,).
        factor (int, optional): Number of augmented versions to create for
            each sample. Defaults to 3.

    Returns:
        Tuple[np.ndarray, np.ndarray]: A tuple containing:
            - The augmented dataset.
            - The augmented labels.
    """
    if factor <= 0 or data.shape[0] == 0:
        return data, labels

    n_samples = data.shape[0]
    data_parts = [data]
    label_parts = [labels]
    for idx in tqdm(range(n_samples), desc="Augmenting training data"):
        sample, sample_label = data[idx], labels[idx]  # sample: [window_size, n_ch]
        for _ in range(factor):
            # Add a leading axis so each copy concatenates as one sample.
            data_parts.append(augment_one_sample(sample)[np.newaxis, ...])
            label_parts.append([sample_label])

    stacked_data = np.concatenate(data_parts, axis=0)
    stacked_labels = np.concatenate(label_parts, axis=0).ravel()
    return stacked_data, stacked_labels
def notch_filter(
    data: np.ndarray, notch_freq: float = 50.0, Q: float = 30.0, fs: float = 200.0
) -> np.ndarray:
    """Applies a zero-phase IIR notch filter along the time axis.

    The filter is designed with ``scipy.signal.iirnotch`` and applied forward
    and backward with ``scipy.signal.filtfilt`` to every channel at once.

    Args:
        data (np.ndarray): The input signal array of shape (T, D).
        notch_freq (float, optional): The frequency to be removed (e.g., 50Hz
            or 60Hz). Defaults to 50.0.
        Q (float, optional): The quality factor. Defaults to 30.0.
        fs (float, optional): The sampling frequency of the signal.
            Defaults to 200.0.

    Returns:
        np.ndarray: The filtered signal array, same shape as ``data``.
            Floating-point (and complex) inputs keep their dtype; any other
            dtype (e.g. integers) is returned as float64 so the filtered
            values are not truncated.
    """
    b, a = iirnotch(notch_freq, Q, fs)
    # An integer output buffer would silently truncate the filtered samples,
    # so only inexact (float/complex) dtypes are preserved.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    # axis=0 filters every channel (column) along time in a single call.
    filtered = signal.filtfilt(b, a, data, axis=0)
    return filtered.astype(out_dtype, copy=False)
def bandpass_filter_emg(
    emg: np.ndarray,
    lowcut: float = 20.0,
    highcut: float = 90.0,
    fs: float = 200.0,
    order: int = 4,
) -> np.ndarray:
    """Applies a zero-phase Butterworth bandpass filter along the time axis.

    The cutoffs are normalized by the Nyquist frequency (``fs / 2``), the
    filter is designed with ``scipy.signal.butter`` and applied forward and
    backward with ``scipy.signal.filtfilt`` to every channel at once.

    Args:
        emg (np.ndarray): The input signal array of shape (T, D).
        lowcut (float, optional): Lower bound of the passband in Hz.
            Defaults to 20.0.
        highcut (float, optional): Upper bound of the passband in Hz.
            Defaults to 90.0.
        fs (float, optional): The sampling frequency of the signal.
            Defaults to 200.0.
        order (int, optional): The order of the filter. Defaults to 4.

    Returns:
        np.ndarray: The bandpass filtered signal array, same shape as ``emg``.
            Floating-point (and complex) inputs keep their dtype; any other
            dtype (e.g. integers) is returned as float64 so the filtered
            values are not truncated.
    """
    nyquist = 0.5 * fs
    band = [lowcut / nyquist, highcut / nyquist]
    b, a = signal.butter(order, band, btype="bandpass")
    # An integer output buffer would silently truncate the filtered samples,
    # so only inexact (float/complex) dtypes are preserved.
    out_dtype = emg.dtype if np.issubdtype(emg.dtype, np.inexact) else np.float64
    # axis=0 filters every channel (column) along time in a single call.
    filtered = signal.filtfilt(b, a, emg, axis=0)
    return filtered.astype(out_dtype, copy=False)
def process_emg_features(
    emg: np.ndarray,
    label: np.ndarray,
    rerep: np.ndarray,
    window_size: int = 1024,
    stride: int = 512,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Cut a recording into fixed-length, possibly overlapping windows.

    Windows start every ``stride`` samples from index 0 up to (but not
    including) ``len(label)``. A window that would run past the end of the
    recording is completed with zero rows. Each window takes the label and
    repetition index found at its first sample.

    Args:
        emg (np.ndarray): Raw EMG data of shape (T, n_ch).
        label (np.ndarray): Gesture labels of shape (T,).
        rerep (np.ndarray): Repetition indices of shape (T,).
        window_size (int, optional): Number of samples per window.
            Defaults to 1024.
        stride (int, optional): Number of samples to shift between windows.
            Defaults to 512.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
            - windowed segments (N, window_size, n_ch).
            - labels for each window (N,).
            - repetition indices for each window (N,).
    """
    total = len(label)
    windows, window_labels, window_reps = [], [], []

    for begin in range(0, total, stride):
        stop = begin + window_size
        chunk = emg[begin:min(stop, total)]
        if stop > total:
            # Tail window: top up with zero rows to reach window_size.
            filler = np.zeros((stop - total, emg.shape[1]))
            chunk = np.concatenate((chunk, filler), axis=0)
        windows.append(chunk)
        # The first sample of the window decides its label / repetition.
        window_labels.append(label[begin])
        window_reps.append(rerep[begin])

    return np.array(windows), np.array(window_labels), np.array(window_reps)
def _download_db5(data_dir: str) -> None:
    """Download and unpack the DB5 archives for subjects 1-10 into ``data_dir``.

    Relies on the ``wget`` and ``unzip`` executables. Commands are run with
    argument lists (no shell), so paths containing spaces or shell
    metacharacters are handled safely, and a failing command raises
    ``subprocess.CalledProcessError`` instead of being silently ignored.
    """
    import subprocess

    # https://ninapro.hevs.ch/instructions/DB5.html
    base_url = "https://ninapro.hevs.ch/files/DB5_Preproc/"
    for i in range(1, 11):  # subjects 1-10
        url = f"{base_url}s{i}.zip"
        zip_path = os.path.join(data_dir, f"s{i}.zip")
        subprocess.run(["wget", "-P", data_dir, url], check=True)
        subprocess.run(["unzip", "-o", zip_path, "-d", data_dir], check=True)
        os.remove(zip_path)
        print(f"Downloaded and unzipped subject {i}\n{data_dir}/s{i}.zip")


def main():
    """Command-line entry point: build train/val/test HDF5 files from DB5.

    Steps, in order:
        1. Optionally download the raw archives (``--download_data``).
        2. For every ``.mat`` file in every subject folder of ``--data_dir``:
           offset the labels of exercises E2/E3, bandpass (20-90 Hz) and notch
           (50 Hz) filter at 200 Hz, z-score each channel, and cut the
           recording into windows of ``--seq_len`` samples every ``--stride``
           samples.
        3. Assign each window to a split by its repetition index
           (train: 1, 3, 4, 6; val: 2; test: 5). Windows with any other
           repetition index are discarded.
        4. Augment the training split (factor 3), transpose every split to
           [N, ch, window_size], write ``<split>.h5`` with datasets ``data``
           and ``label`` into ``--save_dir``, and print label statistics.

    Raises:
        ValueError: If a split ends up with no windows.
        subprocess.CalledProcessError: If a download/unzip command fails.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Process EMG data from DB5.")
    parser.add_argument("--download_data", action="store_true")
    # The options below are mandatory: without them the script previously
    # failed later with an opaque TypeError on a None value.
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    parser.add_argument(
        "--seq_len",
        type=int,
        required=True,
        help="Size of the window in samples for segmentation.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        required=True,
        help="Step size between windows in samples for segmentation.",
    )
    args = parser.parse_args()

    data_dir = args.data_dir
    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)

    # download data if requested
    if args.download_data:
        _download_db5(data_dir)

    fs = 200.0  # original sampling rate
    window_size, stride = args.seq_len, args.stride
    window_seconds = sequence_to_seconds(window_size, fs)
    print(f"Window size: {window_size} samples ({window_seconds:.2f} seconds)")

    train_reps = [1, 3, 4, 6]
    val_reps = [2]
    test_reps = [5]
    # Repetition index -> split name; indices not listed here are dropped.
    rep_to_split = {}
    for split_name, split_reps in (
        ("train", train_reps),
        ("val", val_reps),
        ("test", test_reps),
    ):
        for rep in split_reps:
            rep_to_split[rep] = split_name

    all_data = {"train": [], "val": [], "test": []}
    all_lbls = {"train": [], "val": [], "test": []}

    for subj in sorted(os.listdir(data_dir)):
        if subj.startswith("h5"):
            # Skip folders created by this script when run multiple times
            continue
        subj_path = os.path.join(data_dir, subj)
        if not os.path.isdir(subj_path):
            continue
        print(f"Processing subject {subj}...")

        for mat in sorted(os.listdir(subj_path)):
            if not mat.endswith(".mat"):
                continue
            dd = scipy.io.loadmat(os.path.join(subj_path, mat))
            emg = dd["emg"]  # [N,16]
            label = dd["restimulus"].ravel().astype(int)
            rerep = dd["rerepetition"].ravel().astype(int)

            # label shift by exercise (label 0 is left untouched)
            if "E2" in mat:
                label = np.where(label != 0, label + 12, 0)
            elif "E3" in mat:
                label = np.where(label != 0, label + 29, 0)

            # filtering at original 200 Hz
            emg_filt = bandpass_filter_emg(emg, 20, 90, fs=fs)
            emg_filt = notch_filter(emg_filt, 50, 30, fs=fs)

            # z-score; a perfectly flat channel has std 0, which would turn
            # the whole channel into NaN/inf, so divide by 1 there instead.
            ch_std = emg_filt.std(axis=0, ddof=1)
            ch_std = np.where(ch_std == 0, 1.0, ch_std)
            emg_z = (emg_filt - emg_filt.mean(axis=0)) / ch_std

            # segment
            segs, lbls, reps = process_emg_features(
                emg_z, label, rerep, window_size, stride
            )

            # split by repetition index
            for seg, lab, rp in zip(segs, lbls, reps):
                target = rep_to_split.get(int(rp))
                if target is None:
                    continue
                all_data[target].append(seg)
                all_lbls[target].append(lab)

    # stack, augment train, transpose, save, and print stats
    stats = {}
    for split in ["train", "val", "test"]:
        if not all_data[split]:
            raise ValueError(
                f"No windows were collected for the '{split}' split; "
                f"check the contents of {data_dir!r}."
            )
        X = np.stack(all_data[split], axis=0)  # [N, window_size, ch]
        y = np.array(all_lbls[split], dtype=int)

        if split == "train":
            X, y = augment_train_data(X, y, factor=3)

        # transpose to [N, ch, window_size]
        X = X.transpose(0, 2, 1)

        # save
        with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as hf:
            hf.create_dataset("data", data=X)
            hf.create_dataset("label", data=y)

        # compute stats
        uniq, cnt = np.unique(y, return_counts=True)
        stats[split] = (X.shape, dict(zip(uniq.tolist(), cnt.tolist())))

    # print stats
    for split, (shape, dist) in stats.items():
        print(f"\n{split} → X={shape} [N, C, T]\nlabel distribution:")
        for lab, count in dist.items():
            print(f" label {lab}: {count} samples")


if __name__ == "__main__":
    main()