API Reference

This page gives an overview of all public pyxations objects, functions and methods.

pre_processing

PreProcessing

Pyxations preprocessing: trial segmentation, quality flags, and saccade direction. All mutating functions are safe (copy-aware) and validate required columns.

Tables (pd.DataFrame) expected:

    samples:       typically contains 'tSample' (ms), gaze columns (e.g., 'LX','LY','RX','RY' or 'X','Y')
    fixations:     typically contains 'tStart','tEnd' and optional 'xAvg','yAvg'
    saccades:      typically contains 'tStart','tEnd','xStart','yStart','xEnd','yEnd'
    blinks:        typically contains 'tStart','tEnd' (optional)
    user_messages: must contain 'timestamp','message'

New columns created:

    - All tables after trialing: 'phase', 'trial_number', 'trial_label' (optional)
    - samples/fixations/saccades: 'bad' (bool) after bad_samples()
    - saccades: 'deg' (float degrees), 'dir' (str) after saccades_direction()
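A minimal end-to-end sketch (the toy tables, screen size, and message strings below are illustrative, not part of the library):

import pandas as pd
from pyxations.pre_processing import PreProcessing

# Hypothetical toy session following the table layout above.
samples = pd.DataFrame({"tSample": [0, 4, 8, 12],
                        "X": [100, 120, 90, 2000], "Y": [50, 55, 60, 65]})
fixations = pd.DataFrame({"tStart": [0], "tEnd": [8], "xAvg": [105.0], "yAvg": [55.0]})
saccades = pd.DataFrame({"tStart": [8], "tEnd": [12],
                         "xStart": [90], "yStart": [60], "xEnd": [2000], "yEnd": [65]})
blinks = pd.DataFrame({"tStart": [], "tEnd": []})
msgs = pd.DataFrame({"timestamp": [0, 12], "message": ["TRIAL_START", "TRIAL_END"]})

pp = PreProcessing(samples, fixations, saccades, blinks, msgs,
                   session_path="derivatives/sub-0001")
pp.set_metadata(screen_width=1920, screen_height=1080)
pp.split_all_into_trials_by_msgs({"trial": ["TRIAL_START"]}, {"trial": ["TRIAL_END"]})
pp.bad_samples()                   # flags the off-screen sample (X=2000)
pp.saccades_direction(tol_deg=15)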

Source code in pyxations/pre_processing.py
class PreProcessing:
    """
    Pyxations preprocessing: trial segmentation, quality flags, and saccade direction.
    All mutating functions are safe (copy-aware) and validate required columns.

    Tables (pd.DataFrame) expected:
        samples:   typically contains 'tSample' (ms), gaze columns (e.g., 'LX','LY','RX','RY' or 'X','Y')
        fixations: typically contains 'tStart','tEnd' and optional 'xAvg','yAvg'
        saccades:  typically contains 'tStart','tEnd','xStart','yStart','xEnd','yEnd'
        blinks:    typically contains 'tStart','tEnd' (optional)
        user_messages: must contain 'timestamp','message'

    New columns created:
        - All tables after trialing: 'phase', 'trial_number', 'trial_label' (optional)
        - samples/fixations/saccades: 'bad' (bool) after bad_samples()
        - saccades: 'deg' (float degrees), 'dir' (str) after saccades_direction()
    """

    VERSION = "0.2.0"

    def __init__(
        self,
        samples: pd.DataFrame,
        fixations: pd.DataFrame,
        saccades: pd.DataFrame,
        blinks: pd.DataFrame,
        user_messages: pd.DataFrame,
        session_path: PathLike,
        metadata: Optional[SessionMetadata] = None,
    ):
        self.samples = samples.copy()
        self.fixations = fixations.copy()
        self.saccades = saccades.copy()
        self.blinks = blinks.copy()
        self.user_messages = user_messages.copy()
        self.session_path = Path(session_path)
        self.metadata = metadata or SessionMetadata()

        # Normalize dtypes where possible (strings for messages)
        if "message" in self.user_messages.columns:
            self.user_messages["message"] = self.user_messages["message"].astype(str)

    # ------------------------------- Utilities ------------------------------- #

    @staticmethod
    def _require_columns(
        df: pd.DataFrame, cols: Sequence[str], context: str
    ) -> None:
        missing = [c for c in cols if c not in df.columns]
        if missing:
            raise ValueError(
                f"[{context}] Missing required columns: {missing}. "
                f"Available: {list(df.columns)}"
            )

    @staticmethod
    def _assert_nonoverlap(starts: Sequence[int], ends: Sequence[int], key: str, session: Path) -> None:
        if len(starts) != len(ends):
            raise ValueError(
                f"[{key}] start_times and end_times must have the same length, "
                f"got {len(starts)} vs {len(ends)} in session: {session}"
            )
        for i, (s, e) in enumerate(zip(starts, ends)):
            if not (s < e):
                raise ValueError(
                    f"[{key}] Non-positive interval at trial {i}: start={s}, end={e} "
                    f"in session: {session}"
                )
            if i < len(starts) - 1:
                if e > starts[i + 1]:
                    raise ValueError(
                        f"[{key}] Overlapping trials {i}{i+1}: end[i]={e} > start[i+1]={starts[i+1]} "
                        f"in session: {session}"
                    )

    @staticmethod
    def _ensure_columns_exist(df: pd.DataFrame, cols: Sequence[str]) -> List[str]:
        """Return the subset of 'cols' that actually exist in df."""
        return [c for c in cols if c in df.columns]

    def _save_json_sidecar(self, obj: dict, filename: str) -> None:
        outdir = self.session_path
        outdir.mkdir(parents=True, exist_ok=True)
        with open(outdir / filename, "w", encoding="utf-8") as f:
            json.dump(obj, f, indent=2, ensure_ascii=False)

    # ---------------------------- Public API: Meta ---------------------------- #

    def set_metadata(
        self,
        coords_unit: Optional[str] = None,
        time_unit: Optional[str] = None,
        pupil_unit: Optional[str] = None,
        screen_width: Optional[int] = None,
        screen_height: Optional[int] = None,
        **extra,
    ) -> None:
        """Update session-level metadata used in bounds checks and documentation."""
        if coords_unit is not None:
            self.metadata.coords_unit = coords_unit
        if time_unit is not None:
            self.metadata.time_unit = time_unit
        if pupil_unit is not None:
            self.metadata.pupil_unit = pupil_unit
        if screen_width is not None:
            self.metadata.screen_width = screen_width
        if screen_height is not None:
            self.metadata.screen_height = screen_height
        self.metadata.extra.update(extra)

    def save_metadata(self, filename: str = "metadata.json") -> None:
        """Persist metadata next to derivatives for reproducibility."""
        self._save_json_sidecar(self.metadata.to_dict(), filename)

    # ----------------------- Public API: Message Parsing ---------------------- #

    def get_timestamps_from_messages(
        self,
        messages_dict: Dict[str, List[str]],
        *,
        case_insensitive: bool = True,
        use_regex: bool = True,
        return_match_token: bool = False,
    ) -> Dict[str, List[int]]:
        """
        Extract ordered timestamps per phase by matching message substrings/patterns.

        Parameters
        ----------
        messages_dict : dict
            e.g., {'trial': ['TRIAL_START', 'BEGIN_TRIAL'], 'stim': ['STIM_ONSET']}
        case_insensitive : bool
            If True, ignore case during matching.
        use_regex : bool
            If True, treat entries as regex patterns joined by '|'; otherwise escape literals.
        return_match_token : bool
            If True, also creates/updates a 'matched_token' column with the first matched pattern.

        Returns
        -------
        Dict[str, List[int]]
            Ordered timestamps in ms for each key.
        """
        df = self.user_messages
        self._require_columns(df, ["timestamp", "message"], "get_timestamps_from_messages")

        timestamps_dict: Dict[str, List[int]] = {}
        flags = re.I if case_insensitive else 0

        # Prepare an optional matched_token column for traceability
        if return_match_token and "matched_token" not in df.columns:
            df = df.copy()
            df["matched_token"] = pd.Series([None] * len(df), index=df.index)

        for key, tokens in messages_dict.items():
            if not tokens:
                raise ValueError(f"[{key}] Empty token list passed to get_timestamps_from_messages.")
            parts = tokens if use_regex else [re.escape(t) for t in tokens]
            pat = re.compile("|".join(parts), flags=flags)

            hits = df[df["message"].str.contains(pat, regex=True, na=False)].copy()
            hits.sort_values(by="timestamp", inplace=True)

            if return_match_token and not hits.empty:
                # record which token matched first for each hit
                def _which(m: str) -> Optional[str]:
                    for t in tokens:
                        if (re.search(t, m, flags=flags) if use_regex else re.search(re.escape(t), m, flags=flags)):
                            return t
                    return None

                hits["matched_token"] = hits["message"].apply(_which)
                # write back those rows (optional traceability)
                df.loc[hits.index, "matched_token"] = hits["matched_token"]

            stamps = hits["timestamp"].astype(int).tolist()
            if len(stamps) == 0:
                raise ValueError(
                    f"[{key}] No timestamps found for messages {tokens} "
                    f"in session: {self.session_path}"
                )
            timestamps_dict[key] = stamps

        # Persist updated matched_token if requested
        if return_match_token:
            self.user_messages = df

        return timestamps_dict

    # ---------------------- Public API: Trial Segmentation -------------------- #

    def split_all_into_trials(
        self,
        start_times: Dict[str, List[int]],
        end_times: Dict[str, List[int]],
        trial_labels: Optional[Dict[str, List[str]]] = None,
        *,
        allow_open_last: bool = True,
        require_nonoverlap: bool = True,
    ) -> None:
        """Segment samples/fixations/saccades/blinks using explicit times."""
        for df in (self.samples, self.fixations, self.saccades, self.blinks):
            self._split_into_trials_df(
                df, start_times, end_times, trial_labels,
                allow_open_last=allow_open_last,
                require_nonoverlap=require_nonoverlap,
            )

    def split_all_into_trials_by_msgs(
        self,
        start_msgs: Dict[str, List[str]],
        end_msgs: Dict[str, List[str]],
        trial_labels: Optional[Dict[str, List[str]]] = None,
        **msg_kwargs,
    ) -> None:
        """Segment tables using start and end message patterns."""
        starts = self.get_timestamps_from_messages(start_msgs, **msg_kwargs)
        ends = self.get_timestamps_from_messages(end_msgs, **msg_kwargs)
        self.split_all_into_trials(starts, ends, trial_labels)

    def split_all_into_trials_by_durations(
        self,
        start_msgs: Dict[str, List[str]],
        durations: Dict[str, List[int]],
        trial_labels: Optional[Dict[str, List[str]]] = None,
        **msg_kwargs,
    ) -> None:
        """Segment using start message patterns and per-trial durations (ms)."""
        starts = self.get_timestamps_from_messages(start_msgs, **msg_kwargs)
        end_times: Dict[str, List[int]] = {}
        for key, durs in durations.items():
            s = starts.get(key, [])
            if len(durs) < len(s):
                raise ValueError(
                    f"[{key}] Provided {len(durs)} durations but found {len(s)} start times "
                    f"in session: {self.session_path}"
                )
            end_times[key] = [st + du for st, du in zip(s, durs)]
        self.split_all_into_trials(starts, end_times, trial_labels)

    def _split_into_trials_df(
        self,
        data: pd.DataFrame,
        start_times: Dict[str, List[int]],
        end_times: Dict[str, List[int]],
        trial_labels: Optional[Dict[str, List[str]]] = None,
        *,
        allow_open_last: bool = True,
        require_nonoverlap: bool = True,
    ) -> None:
        """
        Core segmentation for a single table. Works with 'tSample' OR ('tStart','tEnd').
        Adds 'phase', 'trial_number', 'trial_label'.
        """
        if data is self.samples:
            time_mode = "sample"
            self._require_columns(data, ["tSample"], "split_into_trials(samples)")
        else:
            # events (fixations/saccades/blinks)
            time_mode = "event"
            self._require_columns(data, ["tStart", "tEnd"], "split_into_trials(events)")

        df = data.copy()
        # Initialize columns deterministically
        df["phase"] = ""
        df["trial_number"] = -1
        df["trial_label"] = ""

        for key in start_times.keys():
            start_list = list(start_times[key])
            end_list = list(end_times[key])

            # Discard starts after last end (common partial last-trial artifact)
            if allow_open_last and end_list:
                last_end = end_list[-1]
                start_list = [st for st in start_list if st < last_end]

            # Sanity checks
            if require_nonoverlap:
                self._assert_nonoverlap(start_list, end_list, key, self.session_path)
            elif len(start_list) != len(end_list):
                raise ValueError(
                    f"[{key}] start_times and end_times length mismatch: {len(start_list)} vs {len(end_list)} "
                    f"in session: {self.session_path}"
                )

            labels = trial_labels.get(key) if (trial_labels and key in trial_labels) else None
            if labels is not None and len(labels) != len(start_list):
                raise ValueError(
                    f"[{key}] Computed {len(start_list)} trials but got {len(labels)} trial labels "
                    f"in session: {self.session_path}"
                )

            # Apply segmentation
            if time_mode == "sample":
                t = df["tSample"].values
                for i, (st, en) in enumerate(zip(start_list, end_list)):
                    mask = (t >= st) & (t <= en)
                    if not np.any(mask):
                        continue
                    df.loc[mask, "trial_number"] = i
                    df.loc[mask, "phase"] = str(key)
                    if labels is not None:
                        df.loc[mask, "trial_label"] = labels[i]
            else:
                t0 = df["tStart"].values
                t1 = df["tEnd"].values
                for i, (st, en) in enumerate(zip(start_list, end_list)):
                    mask = (t0 >= st) & (t1 <= en)
                    if not np.any(mask):
                        continue
                    df.loc[mask, "trial_number"] = i
                    df.loc[mask, "phase"] = str(key)
                    if labels is not None:
                        df.loc[mask, "trial_label"] = labels[i]

        # Commit
        if data is self.samples:
            self.samples = df
        elif data is self.fixations:
            self.fixations = df
        elif data is self.saccades:
            self.saccades = df
        elif data is self.blinks:
            self.blinks = df

    # ------------------------- Public API: QC / Flags ------------------------- #

    def bad_samples(
        self,
        screen_height: Optional[int] = None,
        screen_width: Optional[int] = None,
        *,
        mark_nan_as_bad: bool = True,
        inclusive_bounds: bool = True,
    ) -> None:
        """
        Mark rows as 'bad' if any available coordinate falls outside screen bounds.
        Applies to samples, fixations, saccades. (Blinks unaffected.)

        If width/height are not provided, metadata.screen_* is used when available.
        """
        H = screen_height if screen_height is not None else self.metadata.screen_height
        W = screen_width if screen_width is not None else self.metadata.screen_width
        if H is None or W is None:
            raise ValueError(
                "bad_samples requires screen_height and screen_width (either passed "
                "or set via set_metadata())."
            )

        def _mark(df: pd.DataFrame) -> pd.DataFrame:
            d = df.copy()

            # Gather candidate coordinate columns if present
            coord_cols = self._ensure_columns_exist(
                d,
                [
                    "LX", "LY", "RX", "RY", "X", "Y",
                    "xStart", "xEnd", "yStart", "yEnd", "xAvg", "yAvg",
                ],
            )
            if not coord_cols:
                # If no coords present, default to 'not bad'
                if "bad" not in d.columns:
                    d["bad"] = False
                return d

            xcols = [c for c in coord_cols if c.lower().startswith("x")]
            ycols = [c for c in coord_cols if c.lower().startswith("y")]

            # Validity masks for each axis
            if inclusive_bounds:
                valid_w = np.logical_and.reduce([d[c].ge(0) & d[c].le(W) for c in xcols]) if xcols else True
                valid_h = np.logical_and.reduce([d[c].ge(0) & d[c].le(H) for c in ycols]) if ycols else True
            else:
                valid_w = np.logical_and.reduce([d[c].gt(0) & d[c].lt(W) for c in xcols]) if xcols else True
                valid_h = np.logical_and.reduce([d[c].gt(0) & d[c].lt(H) for c in ycols]) if ycols else True

            bad = ~(valid_w & valid_h)
            if mark_nan_as_bad:
                bad |= d[coord_cols].isna().any(axis=1)

            d["bad"] = bad.values
            return d

        self.samples = _mark(self.samples)
        self.fixations = _mark(self.fixations)
        self.saccades = _mark(self.saccades)

    # ---------------------- Public API: Saccade Direction --------------------- #

    def saccades_direction(self, tol_deg: float = 15.0) -> None:
        """
        Compute saccade angle (deg) and cardinal direction with tolerance bands.

        Parameters
        ----------
        tol_deg : float
            Half-width of the acceptance band around 0°, ±90°, and ±180°
            for classifying right/left/up/down.
        """
        df = self.saccades.copy()
        self._require_columns(
            df, ["xStart", "xEnd", "yStart", "yEnd"], "saccades_direction"
        )

        x_dif = df["xEnd"].astype(float) - df["xStart"].astype(float)
        y_dif = df["yEnd"].astype(float) - df["yStart"].astype(float)
        deg = np.degrees(np.arctan2(y_dif.to_numpy(), x_dif.to_numpy()))
        df["deg"] = deg.astype(float)

        # Tolerant direction bins
        right = (-tol_deg < df["deg"]) & (df["deg"] < tol_deg)
        left = (df["deg"] > 180 - tol_deg) | (df["deg"] < -180 + tol_deg)
        down = ((90 - tol_deg) < df["deg"]) & (df["deg"] < (90 + tol_deg))
        up = ((-90 - tol_deg) < df["deg"]) & (df["deg"] < (-90 + tol_deg))

        df["dir"] = ""
        df.loc[right, "dir"] = "right"
        df.loc[left, "dir"] = "left"
        df.loc[down, "dir"] = "down"
        df.loc[up, "dir"] = "up"

        self.saccades = df

    # -------------------------- Public API: Orchestrator ---------------------- #

    def process(
        self,
        functions_and_params: Dict[str, Dict],
        *,
        log_recipe: bool = True,
        recipe_filename: str = "preprocessing_recipe.json",
        provenance_filename: str = "preprocessing_provenance.json",
    ) -> None:
        """
        Run a declarative preprocessing recipe, e.g.:
            pp.process({
                "split_all_into_trials_by_msgs": {
                    "start_msgs": {"trial": ["TRIAL_START"]},
                    "end_msgs": {"trial": ["TRIAL_END"]},
                },
                "bad_samples": {"screen_height": 1080, "screen_width": 1920},
                "saccades_direction": {"tol_deg": 15},
            })

        Unknown function names raise a helpful error.
        """
        # Optional: save the declared recipe for exact reproducibility
        if log_recipe:
            recipe_obj = {
                "declared_recipe": functions_and_params,
                "tool_version": self.VERSION,
                "timestamp_utc": datetime.now(timezone.utc).isoformat(),
                "session_path": str(self.session_path),
            }
            self._save_json_sidecar(recipe_obj, recipe_filename)

        for func_name, params in functions_and_params.items():
            if not hasattr(self, func_name):
                raise AttributeError(
                    f"Unknown preprocessing function '{func_name}'. "
                    f"Available: {[m for m in dir(self) if not m.startswith('_')]}"
                )
            fn = getattr(self, func_name)
            if not isinstance(params, dict):
                raise TypeError(
                    f"Parameters for '{func_name}' must be a dict, got {type(params)}"
                )
            fn(**params)

        # Save lightweight provenance after successful run
        if log_recipe:
            prov = {
                "completed_recipe": list(functions_and_params.keys()),
                "tool_version": self.VERSION,
                "timestamp_utc": datetime.now(timezone.utc).isoformat(),
                "metadata": self.metadata.to_dict(),
            }
            self._save_json_sidecar(prov, provenance_filename)

bad_samples(screen_height=None, screen_width=None, *, mark_nan_as_bad=True, inclusive_bounds=True)

Mark rows as 'bad' if any available coordinate falls outside screen bounds. Applies to samples, fixations, saccades. (Blinks unaffected.)

If width/height are not provided, metadata.screen_* is used when available.
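Continuing the sketch from the top of this page (values are illustrative):

pp.set_metadata(screen_width=1920, screen_height=1080)
pp.bad_samples()                   # bounds come from metadata.screen_*
# or pass bounds explicitly and treat edge coordinates as out of bounds:
pp.bad_samples(screen_height=1080, screen_width=1920, inclusive_bounds=False)
print(int(pp.samples["bad"].sum()), "samples flagged")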

Source code in pyxations/pre_processing.py
def bad_samples(
    self,
    screen_height: Optional[int] = None,
    screen_width: Optional[int] = None,
    *,
    mark_nan_as_bad: bool = True,
    inclusive_bounds: bool = True,
) -> None:
    """
    Mark rows as 'bad' if any available coordinate falls outside screen bounds.
    Applies to samples, fixations, saccades. (Blinks unaffected.)

    If width/height are not provided, metadata.screen_* is used when available.
    """
    H = screen_height if screen_height is not None else self.metadata.screen_height
    W = screen_width if screen_width is not None else self.metadata.screen_width
    if H is None or W is None:
        raise ValueError(
            "bad_samples requires screen_height and screen_width (either passed "
            "or set via set_metadata())."
        )

    def _mark(df: pd.DataFrame) -> pd.DataFrame:
        d = df.copy()

        # Gather candidate coordinate columns if present
        coord_cols = self._ensure_columns_exist(
            d,
            [
                "LX", "LY", "RX", "RY", "X", "Y",
                "xStart", "xEnd", "yStart", "yEnd", "xAvg", "yAvg",
            ],
        )
        if not coord_cols:
            # If no coords present, default to 'not bad'
            if "bad" not in d.columns:
                d["bad"] = False
            return d

        xcols = [c for c in coord_cols if c.lower().startswith("x")]
        ycols = [c for c in coord_cols if c.lower().startswith("y")]

        # Validity masks for each axis
        if inclusive_bounds:
            valid_w = np.logical_and.reduce([d[c].ge(0) & d[c].le(W) for c in xcols]) if xcols else True
            valid_h = np.logical_and.reduce([d[c].ge(0) & d[c].le(H) for c in ycols]) if ycols else True
        else:
            valid_w = np.logical_and.reduce([d[c].gt(0) & d[c].lt(W) for c in xcols]) if xcols else True
            valid_h = np.logical_and.reduce([d[c].gt(0) & d[c].lt(H) for c in ycols]) if ycols else True

        bad = ~(valid_w & valid_h)
        if mark_nan_as_bad:
            bad |= d[coord_cols].isna().any(axis=1)

        d["bad"] = bad.values
        return d

    self.samples = _mark(self.samples)
    self.fixations = _mark(self.fixations)
    self.saccades = _mark(self.saccades)

get_timestamps_from_messages(messages_dict, *, case_insensitive=True, use_regex=True, return_match_token=False)

Extract ordered timestamps per phase by matching message substrings/patterns.

Parameters:

    messages_dict : dict (required)
        e.g., {'trial': ['TRIAL_START', 'BEGIN_TRIAL'], 'stim': ['STIM_ONSET']}
    case_insensitive : bool (default: True)
        If True, ignore case during matching.
    use_regex : bool (default: True)
        If True, treat entries as regex patterns joined by '|'; otherwise escape literals.
    return_match_token : bool (default: False)
        If True, also creates/updates a 'matched_token' column with the first matched pattern.

Returns:

    Dict[str, List[int]]
        Ordered timestamps in ms for each key.
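A sketch of a typical call (the message pattern is hypothetical):

stamps = pp.get_timestamps_from_messages(
    {"trial": [r"TRIAL_START \d+"]},   # regex patterns by default (use_regex=True)
    return_match_token=True,           # adds a 'matched_token' column for traceability
)
trial_starts = stamps["trial"]         # ordered timestamps in ms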

Source code in pyxations/pre_processing.py
def get_timestamps_from_messages(
    self,
    messages_dict: Dict[str, List[str]],
    *,
    case_insensitive: bool = True,
    use_regex: bool = True,
    return_match_token: bool = False,
) -> Dict[str, List[int]]:
    """
    Extract ordered timestamps per phase by matching message substrings/patterns.

    Parameters
    ----------
    messages_dict : dict
        e.g., {'trial': ['TRIAL_START', 'BEGIN_TRIAL'], 'stim': ['STIM_ONSET']}
    case_insensitive : bool
        If True, ignore case during matching.
    use_regex : bool
        If True, treat entries as regex patterns joined by '|'; otherwise escape literals.
    return_match_token : bool
        If True, also creates/updates a 'matched_token' column with the first matched pattern.

    Returns
    -------
    Dict[str, List[int]]
        Ordered timestamps in ms for each key.
    """
    df = self.user_messages
    self._require_columns(df, ["timestamp", "message"], "get_timestamps_from_messages")

    timestamps_dict: Dict[str, List[int]] = {}
    flags = re.I if case_insensitive else 0

    # Prepare an optional matched_token column for traceability
    if return_match_token and "matched_token" not in df.columns:
        df = df.copy()
        df["matched_token"] = pd.Series([None] * len(df), index=df.index)

    for key, tokens in messages_dict.items():
        if not tokens:
            raise ValueError(f"[{key}] Empty token list passed to get_timestamps_from_messages.")
        parts = tokens if use_regex else [re.escape(t) for t in tokens]
        pat = re.compile("|".join(parts), flags=flags)

        hits = df[df["message"].str.contains(pat, regex=True, na=False)].copy()
        hits.sort_values(by="timestamp", inplace=True)

        if return_match_token and not hits.empty:
            # record which token matched first for each hit
            def _which(m: str) -> Optional[str]:
                for t in tokens:
                    if (re.search(t, m, flags=flags) if use_regex else re.search(re.escape(t), m, flags=flags)):
                        return t
                return None

            hits["matched_token"] = hits["message"].apply(_which)
            # write back those rows (optional traceability)
            df.loc[hits.index, "matched_token"] = hits["matched_token"]

        stamps = hits["timestamp"].astype(int).tolist()
        if len(stamps) == 0:
            raise ValueError(
                f"[{key}] No timestamps found for messages {tokens} "
                f"in session: {self.session_path}"
            )
        timestamps_dict[key] = stamps

    # Persist updated matched_token if requested
    if return_match_token:
        self.user_messages = df

    return timestamps_dict

process(functions_and_params, *, log_recipe=True, recipe_filename='preprocessing_recipe.json', provenance_filename='preprocessing_provenance.json')

Run a declarative preprocessing recipe, e.g.:

    pp.process({
        "split_all_into_trials_by_msgs": {
            "start_msgs": {"trial": ["TRIAL_START"]},
            "end_msgs": {"trial": ["TRIAL_END"]},
        },
        "bad_samples": {"screen_height": 1080, "screen_width": 1920},
        "saccades_direction": {"tol_deg": 15},
    })

Unknown function names raise a helpful error.
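Python dicts preserve insertion order, so the recipe executes top to bottom. A minimal sketch that runs only the QC and direction steps and then inspects the provenance sidecar (paths are illustrative):

import json

pp.process({
    "bad_samples": {"screen_height": 1080, "screen_width": 1920},
    "saccades_direction": {"tol_deg": 15},
})

# Sidecars are written next to the session data:
with open(pp.session_path / "preprocessing_provenance.json", encoding="utf-8") as f:
    print(json.load(f)["completed_recipe"])   # ['bad_samples', 'saccades_direction']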

Source code in pyxations/pre_processing.py
def process(
    self,
    functions_and_params: Dict[str, Dict],
    *,
    log_recipe: bool = True,
    recipe_filename: str = "preprocessing_recipe.json",
    provenance_filename: str = "preprocessing_provenance.json",
) -> None:
    """
    Run a declarative preprocessing recipe, e.g.:
        pp.process({
            "split_all_into_trials_by_msgs": {
                "start_msgs": {"trial": ["TRIAL_START"]},
                "end_msgs": {"trial": ["TRIAL_END"]},
            },
            "bad_samples": {"screen_height": 1080, "screen_width": 1920},
            "saccades_direction": {"tol_deg": 15},
        })

    Unknown function names raise a helpful error.
    """
    # Optional: save the declared recipe for exact reproducibility
    if log_recipe:
        recipe_obj = {
            "declared_recipe": functions_and_params,
            "tool_version": self.VERSION,
            "timestamp_utc": datetime.now(timezone.utc).isoformat(),
            "session_path": str(self.session_path),
        }
        self._save_json_sidecar(recipe_obj, recipe_filename)

    for func_name, params in functions_and_params.items():
        if not hasattr(self, func_name):
            raise AttributeError(
                f"Unknown preprocessing function '{func_name}'. "
                f"Available: {[m for m in dir(self) if not m.startswith('_')]}"
            )
        fn = getattr(self, func_name)
        if not isinstance(params, dict):
            raise TypeError(
                f"Parameters for '{func_name}' must be a dict, got {type(params)}"
            )
        fn(**params)

    # Save lightweight provenance after successful run
    if log_recipe:
        prov = {
            "completed_recipe": list(functions_and_params.keys()),
            "tool_version": self.VERSION,
            "timestamp_utc": datetime.now(timezone.utc).isoformat(),
            "metadata": self.metadata.to_dict(),
        }
        self._save_json_sidecar(prov, provenance_filename)

saccades_direction(tol_deg=15.0)

Compute saccade angle (deg) and cardinal direction with tolerance bands.

Parameters:

    tol_deg : float (default: 15.0)
        Half-width of the acceptance band around 0°, ±90°, and ±180°
        for classifying right/left/up/down.
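A worked numeric example, assuming screen coordinates with y increasing downward (so +90° is 'down'):

import numpy as np

# Saccade from (100, 100) to (200, 90): x_dif = 100, y_dif = -10
deg = np.degrees(np.arctan2(-10, 100))   # ≈ -5.71°
# -15° < -5.71° < +15°, so with tol_deg=15 this saccade is classified as "right".
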
Source code in pyxations/pre_processing.py
def saccades_direction(self, tol_deg: float = 15.0) -> None:
    """
    Compute saccade angle (deg) and cardinal direction with tolerance bands.

    Parameters
    ----------
    tol_deg : float
        Half-width of the acceptance band around 0°, ±90°, and ±180°
        for classifying right/left/up/down.
    """
    df = self.saccades.copy()
    self._require_columns(
        df, ["xStart", "xEnd", "yStart", "yEnd"], "saccades_direction"
    )

    x_dif = df["xEnd"].astype(float) - df["xStart"].astype(float)
    y_dif = df["yEnd"].astype(float) - df["yStart"].astype(float)
    deg = np.degrees(np.arctan2(y_dif.to_numpy(), x_dif.to_numpy()))
    df["deg"] = deg.astype(float)

    # Tolerant direction bins
    right = (-tol_deg < df["deg"]) & (df["deg"] < tol_deg)
    left = (df["deg"] > 180 - tol_deg) | (df["deg"] < -180 + tol_deg)
    down = ((90 - tol_deg) < df["deg"]) & (df["deg"] < (90 + tol_deg))
    up = ((-90 - tol_deg) < df["deg"]) & (df["deg"] < (-90 + tol_deg))

    df["dir"] = ""
    df.loc[right, "dir"] = "right"
    df.loc[left, "dir"] = "left"
    df.loc[down, "dir"] = "down"
    df.loc[up, "dir"] = "up"

    self.saccades = df

save_metadata(filename='metadata.json')

Persist metadata next to derivatives for reproducibility.

Source code in pyxations/pre_processing.py
def save_metadata(self, filename: str = "metadata.json") -> None:
    """Persist metadata next to derivatives for reproducibility."""
    self._save_json_sidecar(self.metadata.to_dict(), filename)

set_metadata(coords_unit=None, time_unit=None, pupil_unit=None, screen_width=None, screen_height=None, **extra)

Update session-level metadata used in bounds checks and documentation.
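A minimal sketch (the values and the 'tracker' key are illustrative; extra keywords land in metadata.extra):

pp.set_metadata(
    coords_unit="px",
    time_unit="ms",
    screen_width=1920,
    screen_height=1080,
    tracker="EyeLink 1000",   # stored under metadata.extra
)
pp.save_metadata()            # writes metadata.json into the session folder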

Source code in pyxations/pre_processing.py
def set_metadata(
    self,
    coords_unit: Optional[str] = None,
    time_unit: Optional[str] = None,
    pupil_unit: Optional[str] = None,
    screen_width: Optional[int] = None,
    screen_height: Optional[int] = None,
    **extra,
) -> None:
    """Update session-level metadata used in bounds checks and documentation."""
    if coords_unit is not None:
        self.metadata.coords_unit = coords_unit
    if time_unit is not None:
        self.metadata.time_unit = time_unit
    if pupil_unit is not None:
        self.metadata.pupil_unit = pupil_unit
    if screen_width is not None:
        self.metadata.screen_width = screen_width
    if screen_height is not None:
        self.metadata.screen_height = screen_height
    self.metadata.extra.update(extra)

split_all_into_trials(start_times, end_times, trial_labels=None, *, allow_open_last=True, require_nonoverlap=True)

Segment samples/fixations/saccades/blinks using explicit times.
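A sketch with explicit times in ms (all values illustrative):

pp.split_all_into_trials(
    start_times={"trial": [1000, 5000, 9000]},
    end_times={"trial": [4000, 8000, 12000]},
    trial_labels={"trial": ["easy", "hard", "easy"]},
)
# Every table now carries 'phase', 'trial_number', and 'trial_label' columns.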

Source code in pyxations/pre_processing.py
def split_all_into_trials(
    self,
    start_times: Dict[str, List[int]],
    end_times: Dict[str, List[int]],
    trial_labels: Optional[Dict[str, List[str]]] = None,
    *,
    allow_open_last: bool = True,
    require_nonoverlap: bool = True,
) -> None:
    """Segment samples/fixations/saccades/blinks using explicit times."""
    for df in (self.samples, self.fixations, self.saccades, self.blinks):
        self._split_into_trials_df(
            df, start_times, end_times, trial_labels,
            allow_open_last=allow_open_last,
            require_nonoverlap=require_nonoverlap,
        )

split_all_into_trials_by_durations(start_msgs, durations, trial_labels=None, **msg_kwargs)

Segment using start message patterns and per-trial durations (ms).
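A sketch pairing a start message with fixed 2 s trials (message string and durations are hypothetical):

pp.split_all_into_trials_by_durations(
    start_msgs={"trial": ["TRIAL_START"]},
    durations={"trial": [2000, 2000, 2000]},   # one duration per trial, in ms
)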

Source code in pyxations/pre_processing.py
def split_all_into_trials_by_durations(
    self,
    start_msgs: Dict[str, List[str]],
    durations: Dict[str, List[int]],
    trial_labels: Optional[Dict[str, List[str]]] = None,
    **msg_kwargs,
) -> None:
    """Segment using start message patterns and per-trial durations (ms)."""
    starts = self.get_timestamps_from_messages(start_msgs, **msg_kwargs)
    end_times: Dict[str, List[int]] = {}
    for key, durs in durations.items():
        s = starts.get(key, [])
        if len(durs) < len(s):
            raise ValueError(
                f"[{key}] Provided {len(durs)} durations but found {len(s)} start times "
                f"in session: {self.session_path}"
            )
        end_times[key] = [st + du for st, du in zip(s, durs)]
    self.split_all_into_trials(starts, end_times, trial_labels)

split_all_into_trials_by_msgs(start_msgs, end_msgs, trial_labels=None, **msg_kwargs)

Segment tables using start and end message patterns.
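A sketch with start/end marker messages (the message strings are hypothetical); extra keyword arguments are forwarded to get_timestamps_from_messages:

pp.split_all_into_trials_by_msgs(
    start_msgs={"trial": ["TRIAL_START"], "stim": ["STIM_ONSET"]},
    end_msgs={"trial": ["TRIAL_END"], "stim": ["STIM_OFFSET"]},
    use_regex=False,   # treat the strings as literals, not regex
)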

Source code in pyxations/pre_processing.py
def split_all_into_trials_by_msgs(
    self,
    start_msgs: Dict[str, List[str]],
    end_msgs: Dict[str, List[str]],
    trial_labels: Optional[Dict[str, List[str]]] = None,
    **msg_kwargs,
) -> None:
    """Segment tables using start and end message patterns."""
    starts = self.get_timestamps_from_messages(start_msgs, **msg_kwargs)
    ends = self.get_timestamps_from_messages(end_msgs, **msg_kwargs)
    self.split_all_into_trials(starts, ends, trial_labels)

SessionMetadata dataclass

Lightweight metadata container saved alongside derivatives.
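A minimal construction sketch (values illustrative):

meta = SessionMetadata(screen_width=1920, screen_height=1080,
                       extra={"experiment": "visual_search"})
meta.to_dict()   # JSON-serializable dict used for the metadata sidecar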

Source code in pyxations/pre_processing.py
@dataclass
class SessionMetadata:
    """Lightweight metadata container saved alongside derivatives."""
    coords_unit: str = "px"          # 'px' or 'deg'
    time_unit: str = "ms"            # 'ms'
    pupil_unit: str = "arbitrary"
    screen_width: Optional[int] = None
    screen_height: Optional[int] = None
    extra: Dict[str, Union[str, int, float, bool, None]] = field(default_factory=dict)

    def to_dict(self) -> dict:
        return {
            "coords_unit": self.coords_unit,
            "time_unit": self.time_unit,
            "pupil_unit": self.pupil_unit,
            "screen_width": self.screen_width,
            "screen_height": self.screen_height,
            "extra": self.extra,
        }

bids_formatting

dataset_to_bids(target_folder_path, files_folder_path, dataset_name, session_substrings=1, format_name='eyelink')

Convert a dataset to BIDS format.

Args:

    target_folder_path (str): Path to the folder where the BIDS dataset will be created.
    files_folder_path (str): Path to the folder containing the EDF files. The EDF files are
        assumed to have the ID of the subject at the beginning of the file name, separated
        by an underscore.
    dataset_name (str): Name of the BIDS dataset.
    session_substrings (int): Number of substrings to use for the session ID. Default is 1.
    format_name (str): Name of the input format converter to use. Default is 'eyelink'.

Returns:

    Path: Path to the created BIDS dataset folder.
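A usage sketch (paths are illustrative):

from pyxations.bids_formatting import dataset_to_bids

bids_root = dataset_to_bids(
    target_folder_path="/data/bids",
    files_folder_path="/data/raw_edf",
    dataset_name="my_dataset",
    session_substrings=1,
    format_name="eyelink",
)
print(bids_root)   # /data/bids/my_dataset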

Source code in pyxations/bids_formatting.py
def dataset_to_bids(target_folder_path, files_folder_path, dataset_name, session_substrings=1, format_name='eyelink'):
    """
    Convert a dataset to BIDS format.

    Args:
        target_folder_path (str): Path to the folder where the BIDS dataset will be created.
        files_folder_path (str): Path to the folder containing the EDF files. The EDF files
            are assumed to have the ID of the subject at the beginning of the file name,
            separated by an underscore.
        dataset_name (str): Name of the BIDS dataset.
        session_substrings (int): Number of substrings to use for the session ID. Default is 1.
        format_name (str): Name of the input format converter to use. Default is 'eyelink'.

    Returns:
        Path: Path to the created BIDS dataset folder.
    """
    converter = get_converter(format_name)

    # Create a metadata tsv file
    metadata = pd.DataFrame(columns=['subject_id', 'old_subject_id'])
    files_folder_path = Path(files_folder_path)
    # List all file paths in the folder
    file_paths = []
    for file_path in files_folder_path.rglob('*'):  # Recursively go through all files
        if file_path.is_file():
            file_paths.append(file_path)

    file_paths = [file for file in file_paths if file.suffix.lower() in converter.relevant_extensions()]

    bids_folder_path = Path(target_folder_path) / dataset_name
    bids_folder_path.mkdir(parents=True, exist_ok=True)

    subj_ids = converter.get_subject_ids(file_paths)

    # If all of the subjects have numerical IDs, sort them numerically, else sort them alphabetically
    if all(subject_id.isdigit() for subject_id in subj_ids):
        subj_ids.sort(key=int)
    else:
        subj_ids.sort()
    new_subj_ids = [str(subject_index).zfill(4) for subject_index in range(1, len(subj_ids) + 1)]

    # Create subfolders for each session for each subject
    for subject_id in new_subj_ids:
        old_subject_id = subj_ids[int(subject_id) - 1]
        for file in file_paths:
            file_name = Path(file).name
            session_id = "_".join("".join(file_name.split(".")[:-1]).split("_")[1:session_substrings + 1])
            converter.move_file_to_bids_folder(file, bids_folder_path, subject_id, old_subject_id, session_id)

        metadata.loc[len(metadata.index)] = [subject_id, old_subject_id]
    # Save metadata to tsv file
    metadata.to_csv(bids_folder_path / "participants.tsv", sep="\t", index=False)
    return bids_folder_path

eye_movement_detection

visualization

utils

Visualization

Source code in pyxations/visualization/visualization.py
class Visualization():
    def __init__(self, derivatives_folder_path, events_detection_algorithm):
        self.derivatives_folder_path = Path(derivatives_folder_path)
        if events_detection_algorithm not in EYE_MOVEMENT_DETECTION_DICT and events_detection_algorithm != 'eyelink':
            raise ValueError(f"Detection algorithm {events_detection_algorithm} not found.")
        self.events_detection_folder = Path(events_detection_algorithm + '_events')

    def scanpath(
        self,
        fixations: pl.DataFrame,
        screen_height: int,
        screen_width: int,
        folder_path: str | Path | None = None,
        tmin: int | None = None,
        tmax: int | None = None,
        saccades: pl.DataFrame | None = None,
        samples: pl.DataFrame | None = None,
        phase_data: dict[str, dict] | None = None,
        display: bool = True,
    ):
        """
        Fast scan‑path visualiser.

        • **Vectorised**: no per‑row Python loops  
        • **Single pass** phase grouping  
        • Uses `Axes.broken_barh` for fixation spans  
        • Optional asynchronous PNG write via ThreadPoolExecutor (used automatically when `folder_path` is given)

        Parameters
        ----------
        fixations
            Polars DataFrame with at least `tStart`, `duration`, `xAvg`, `yAvg`, `phase`.
        screen_height, screen_width
            Stimulus resolution in pixels.
        folder_path
            Directory where 1 PNG per phase will be stored.  If *None*, nothing is saved.
        tmin, tmax
            Time window in **ms**.  If both `None`, the whole trial is plotted.
        saccades
            Polars DataFrame with `tStart`, `phase`, …  (optional).
        samples
            Polars DataFrame with gaze traces (`tSample`, `LX`, `LY`, `RX`, `RY` or
            `X`, `Y`) (optional).
        phase_data
            Per‑phase extras::

                {
                    "search": {
                        "img_paths": [...],
                        "img_plot_coords": [(x1,y1,x2,y2), ...],
                        "bbox": (x1,y1,x2,y2),
                    },
                    ...
                }

        display
            If *False* the figure canvas is never shown (faster for batch jobs).
        """


        # ------------- small helpers ------------------------------------------------
        def _make_axes(plot_samples: bool):
            if plot_samples:
                fig, (ax_main, ax_gaze) = plt.subplots(
                    2, 1, height_ratios=(4, 1), figsize=(10, 6), sharex=False
                )
            else:
                fig, ax_main = plt.subplots(figsize=(10, 6))
                ax_gaze = None
            ax_main.set_xlim(0, screen_width)
            ax_main.set_ylim(screen_height, 0)
            return fig, ax_main, ax_gaze

        def _maybe_cache_img(path):
            """Load image from disk with a small LRU cache."""

            # Cache hit: move to the end (most recently used)
            if path in _img_cache:
                img = _img_cache.pop(path)
                _img_cache[path] = img
                return img

            # Cache miss: load image
            img = mpimg.imread(path)

            # Optional: reduce memory if image is float64 in [0, 1]
            if isinstance(img, np.ndarray) and img.dtype == np.float64:
                img = (img * 255).clip(0, 255).astype(np.uint8)

            # Insert into cache
            _img_cache[path] = img

            # If cache too big, drop least recently used item
            if len(_img_cache) > _MAX_CACHE_ITEMS:
                _img_cache.popitem(last=False)  # pops the oldest inserted item

            return img

        # ---------------------------------------------------------------------------
        plot_saccades = saccades is not None
        plot_samples = samples is not None
        _img_cache = OrderedDict()
        _MAX_CACHE_ITEMS = 8  # or 5, 10, etc. Tune as you like.

        trial_idx = fixations["trial_number"][0]

        # ---- time filter ----------------------------------------------------------
        if tmin is not None and tmax is not None:
            fixations = fixations.filter(pl.col("tStart").is_between(tmin, tmax))
            if plot_saccades:
                saccades = saccades.filter(pl.col("tStart").is_between(tmin, tmax))
            if plot_samples:
                samples = samples.filter(pl.col("tSample").is_between(tmin, tmax))

        # remove empty phase markings
        fixations = fixations.filter(pl.col("phase") != "")
        if plot_saccades:
            saccades = saccades.filter(pl.col("phase") != "")
        if plot_samples:
            samples = samples.filter(pl.col("phase") != "")

        # ---- split once by phase --------------------------------------------------
        fix_by_phase = fixations.partition_by("phase", as_dict=True)
        sac_by_phase = (
            saccades.partition_by("phase", as_dict=True) if plot_saccades else {}
        )
        samp_by_phase = (
            samples.partition_by("phase", as_dict=True) if plot_samples else {}
        )

        # colour map shared across phases
        cmap = plt.cm.rainbow

        # ---- build & draw ---------------------------------------------------------
        # optional async saver: engaged automatically when folder_path is given
        from concurrent.futures import ThreadPoolExecutor
        saver = ThreadPoolExecutor(max_workers=4) if folder_path else None

        if not display:
            plt.ioff()

        for phase, phase_fix in fix_by_phase.items():
            if phase_fix.is_empty():
                continue

            # ---------- vectors (zero‑copy) -----------------
            fx, fy, fdur = phase_fix.select(["xAvg", "yAvg", "duration"]).to_numpy().T
            n_fix = fx.size
            fix_idx = np.arange(1, n_fix + 1)

            norm = mplcolors.BoundaryNorm(np.arange(1, n_fix + 2), cmap.N)

            # saccades
            sac_t = (
                sac_by_phase[phase]["tStart"].to_numpy()
                if plot_saccades and phase in sac_by_phase
                else np.empty(0)
            )

            # samples
            if plot_samples and phase in samp_by_phase and samp_by_phase[phase].height:
                samp_phase = samp_by_phase[phase]
                t0 = samp_phase["tSample"][0]
                ts = (samp_phase["tSample"].to_numpy() - t0) 
                get = samp_phase.get_column
                lx = get("LX").to_numpy() if "LX" in samp_phase.columns else None
                ly = get("LY").to_numpy() if "LY" in samp_phase.columns else None
                rx = get("RX").to_numpy() if "RX" in samp_phase.columns else None
                ry = get("RY").to_numpy() if "RY" in samp_phase.columns else None
                gx = get("X").to_numpy() if "X" in samp_phase.columns else None
                gy = get("Y").to_numpy() if "Y" in samp_phase.columns else None
            else:
                t0 = None

            # ---------- figure -----------------------------
            fig, ax_main, ax_gaze = _make_axes(plot_samples and t0 is not None)
            # scatter fixations
            sc = ax_main.scatter(
                fx,
                fy,
                c=fix_idx,
                s=fdur,
                cmap=cmap,
                norm=norm,
                alpha=0.5,
                zorder=2,
            )
            fig.colorbar(
                sc,
                ax=ax_main,
                ticks=[1, n_fix // 2 if n_fix > 2 else n_fix, n_fix],
                fraction=0.046,
                pad=0.04,
            ).set_label("# of fixation")

            # ---------- stimulus imagery / bbox ------------
            if phase_data and phase[0] in phase_data:
                pdict = phase_data[phase[0]]
                coords = pdict.get("img_plot_coords") or []
                bbox = pdict.get('bbox',None) 
                for img_path, box in zip(pdict.get("img_paths", []), coords):

                    ax_main.imshow(_maybe_cache_img(img_path), extent=[box[0], box[2], box[3], box[1]], zorder=0)
                if bbox is not None:
                    x1, y1, x2, y2 = bbox
                    ax_main.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], color='red', linewidth=1.5, zorder=3)

            # ---------- gaze traces ------------------------
            if ax_gaze is not None:
                if lx is not None:
                    ax_main.plot(lx, ly, "--", color="C0", zorder=1)
                    ax_gaze.plot(ts, lx, label="Left X")
                    ax_gaze.plot(ts, ly, label="Left Y")
                if rx is not None:
                    ax_main.plot(rx, ry, "--", color="k", zorder=1)
                    ax_gaze.plot(ts, rx, label="Right X")
                    ax_gaze.plot(ts, ry, label="Right Y")
                if gx is not None:
                    ax_main.plot(gx, gy, "--", color="k", zorder=1, alpha=0.6)
                    ax_gaze.plot(ts, gx, label="X")
                    ax_gaze.plot(ts, gy, label="Y")

                # fixation spans
                bars   = np.c_[phase_fix['tStart'].to_numpy() - t0,
                            phase_fix['duration'].to_numpy()]
                height = ax_gaze.get_ylim()[1] - ax_gaze.get_ylim()[0]
                colors = cmap(norm(fix_idx))

                # Draw all bars in one call; no BrokenBarHCollection import needed
                ax_gaze.broken_barh(bars, (0, height), facecolors=colors, alpha=0.4)
                # saccades
                if sac_t.size:
                    ymin, ymax = ax_gaze.get_ylim()
                    ax_gaze.vlines(
                        sac_t - t0,
                        ymin,
                        ymax,
                        colors="red",
                        linestyles="--",
                        linewidth=0.8,
                    )

                # tidy gaze axis
                h, l = ax_gaze.get_legend_handles_labels()
                by_label = {lab: hdl for hdl, lab in zip(h, l)}
                ax_gaze.legend(
                    by_label.values(),
                    by_label.keys(),
                    loc="center left",
                    bbox_to_anchor=(1, 0.5),
                )
                ax_gaze.set_ylabel("Gaze")
                ax_gaze.set_xlabel("Time [s]")

            fig.tight_layout()

            # ---------- save / show ------------------------
            if folder_path:
                scan_name = f"scanpath_{trial_idx}"
                if tmin is not None and tmax is not None:
                    scan_name += f"_{tmin}_{tmax}"
                out = Path(folder_path) / f"{scan_name}_{phase[0]}.png"
                # Save exactly once: asynchronously via the saver pool when it
                # exists, otherwise synchronously on this thread.
                if saver:
                    saver.submit(fig.savefig, out, dpi=150)
                else:
                    fig.savefig(out, dpi=150)

            if display:
                plt.show()
            plt.close(fig)

        if saver:
            saver.shutdown(wait=True)  # wait for queued PNG writes to finish

        if not display:
            plt.ion()


    def fix_duration(self, fixations: pl.DataFrame, axs=None):

        ax = axs
        if ax is None:
            fig, ax = plt.subplots()

        ax.hist(fixations.select(pl.col('duration')).to_numpy().ravel(), bins=100, edgecolor='black', linewidth=1.2, density=True)
        ax.set_title('Fixation duration')
        ax.set_xlabel('Time (ms)')
        ax.set_ylabel('Density')


    def sacc_amplitude(self, saccades: pl.DataFrame, axs=None):

        ax = axs
        if ax is None:
            fig, ax = plt.subplots()

        saccades_amp = saccades.select(pl.col('ampDeg')).to_numpy().ravel()
        ax.hist(saccades_amp, bins=100, range=(0, 20), edgecolor='black', linewidth=1.2, density=True)
        ax.set_title('Saccades amplitude')
        ax.set_xlabel('Amplitude (deg)')
        ax.set_ylabel('Density')


    def sacc_direction(self, saccades: pl.DataFrame, axs=None, figs=None):

        ax = axs
        if ax is None:
            fig = plt.figure()
            ax = plt.subplot(polar=True)
        else:
            ax.set_axis_off()
            ax = figs.add_subplot(2, 2, 3, projection='polar')
        if 'deg' not in saccades.columns or 'dir' not in saccades.columns:
            raise ValueError('Compute saccades direction first by using saccades_direction function from the PreProcessing module.')
        # Convert from deg to rad
        saccades_rad = saccades.select(pl.col('deg')).to_numpy().ravel() * np.pi / 180

        n_bins = 24
        ang_hist, bin_edges = np.histogram(saccades_rad, bins=n_bins, density=True)
        bin_centers = [np.mean((bin_edges[i], bin_edges[i+1])) for i in range(len(bin_edges) - 1)]

        bars = ax.bar(bin_centers, ang_hist, width=2*np.pi/n_bins, bottom=0.0, alpha=0.4, edgecolor='black')
        ax.set_title('Saccades direction')
        ax.set_yticklabels([])

        for r, bar in zip(ang_hist, bars):
            bar.set_facecolor(plt.cm.Blues(r / np.max(ang_hist)))


    def sacc_main_sequence(self, saccades: pl.DataFrame, axs=None, hline=None):

        ax = axs
        if ax is None:
            fig, ax = plt.subplots()
        # Logarithmic bins
        XL = np.log10(25)  # Adjusted to fit the xlim
        YL = np.log10(1000)  # Adjusted to fit the ylim

        saccades_peak_vel = saccades.select(pl.col('vPeak')).to_numpy().ravel()
        saccades_amp = saccades.select(pl.col('ampDeg')).to_numpy().ravel()

        # Create a 2D histogram with logarithmic bins
        ax.hist2d(saccades_amp, saccades_peak_vel, bins=[np.logspace(-1, XL, 50), np.logspace(0, YL, 50)])

        if hline:
            ax.hlines(y=hline, xmin=ax.get_xlim()[0], xmax=ax.get_xlim()[1], colors='grey', linestyles='--', label=hline)
            ax.legend()
        ax.set_yscale('log')
        ax.set_xscale('log')
        ax.set_title('Main sequence')
        ax.set_xlabel('Amplitude (deg)')
        ax.set_ylabel('Peak velocity (deg)')
         # Set the limits of the axes
        ax.set_xlim(0.1, 25)
        ax.set_ylim(10, 1000)
        ax.set_aspect('equal')


    def plot_multipanel(
            self,
            fixations: pl.DataFrame,
            saccades: pl.DataFrame,
            display: bool = True
        ) -> None:
        """
        Create a 2×2 multi‑panel diagnostic plot for every non‑empty
        phase label and save it as PNG in
        <derivatives_folder_path>/<events_detection_folder>/plots/.
        """
        # ── paths & matplotlib style ────────────────────────────────
        folder_path: Path = (
            self.derivatives_folder_path
            / self.events_detection_folder
            / "plots"
        )
        folder_path.mkdir(parents=True, exist_ok=True)
        plt.rcParams.update({"font.size": 12})

        # ── drop practice / invalid trials ─────────────────────────
        fixations = fixations.filter(pl.col("trial_number") != -1)
        saccades  = saccades.filter(pl.col("trial_number") != -1)

        # ── collect valid phase labels (skip empty string) ─────────
        phases = (
            fixations
            .select(pl.col("phase").filter(pl.col("phase") != ""))
            .unique()           # unique values in this Series
            .to_series()
            .to_list()          # plain Python list of strings
        )

        # ── one figure per phase ───────────────────────────────────
        for phase in phases:
            fix_phase   = fixations.filter(pl.col("phase") == phase)
            sacc_phase  = saccades.filter(pl.col("phase") == phase)

            fig, axs = plt.subplots(2, 2, figsize=(12, 7))

            self.fix_duration(fix_phase, axs=axs[0, 0])
            self.sacc_main_sequence(sacc_phase, axs=axs[1, 1])
            self.sacc_direction(sacc_phase, axs=axs[1, 0], figs=fig)
            self.sacc_amplitude(sacc_phase, axs=axs[0, 1])

            fig.tight_layout()
            fig.savefig(folder_path / f"multipanel_{phase}.png")
            if display:
                plt.show()
            plt.close(fig)

    def plot_animation(
        self,
        samples: pl.DataFrame,
        screen_height: int,
        screen_width: int,
        video_path: str | Path | None = None,
        background_image_path: str | Path | None = None,
        folder_path: str | Path | None = None,
        tmin: int | None = None,
        tmax: int | None = None,
        seconds_to_show: float | None = None,
        scale_factor: float = 0.5,
        gaze_radius: int = 10,
        gaze_color: tuple = (255, 0, 0),
        fps: float | None = None,
        output_format: str = "matplotlib",
        display: bool = True,
    ):
        """
        Create an animated visualization of eye-tracking data.

        When a video is provided, the animation syncs gaze samples with video frames.
        When no video is provided, gaze points are animated on a grey background or
        a provided background image, using the sample timestamps for timing.

        Parameters
        ----------
        samples
            Polars DataFrame with gaze samples. Must contain 'tSample' and gaze
            position columns ('X', 'Y' or 'LX', 'LY', 'RX', 'RY').
        screen_height, screen_width
            Stimulus resolution in pixels.
        video_path
            Path to a video file. If provided, gaze is overlaid on video frames.
        background_image_path
            Path to a background image. Only used when video_path is None.
            If both are None, a grey background is used.
        folder_path
            Directory where the animation will be saved. If None, nothing is saved.
            The file format depends on `output_format`.
        tmin, tmax
            Time window in **ms**. If both None, the whole trial is plotted.
        seconds_to_show
            Limit the animation to the first N seconds. If None, shows all available data.
        scale_factor
            Resolution scaling factor (1.0 = original, 0.5 = half resolution).
        gaze_radius
            Radius of the gaze point circle in pixels (before scaling).
        gaze_color
            RGB tuple for gaze point color.
        fps
            Frames per second for the animation. If None:
            - With video: uses the video's native FPS
            - Without video: defaults to 60 FPS
        output_format
            Output format for saved animations:
            - "html": Interactive HTML file (works in browsers)
            - "mp4": Video file (requires ffmpeg)
            - "gif": Animated GIF file (requires pillow)
            - "matplotlib": Show in matplotlib GUI window (blocking; the default)
        display
            If True and output_format is "html", returns an HTML object for notebooks.
            If output_format is "matplotlib", this is ignored (always shows window).
            If False, only saves to file (if folder_path is provided).

        Returns
        -------
        IPython.display.HTML or None
            Returns HTML animation if display=True and output_format="html", otherwise None.
        """
        try:
            import cv2
            from matplotlib.animation import FuncAnimation
            import matplotlib as mpl
            mpl.rcParams['animation.embed_limit'] = 100
        except ImportError as e:
            raise ImportError(
                f"Missing required dependency for animation: {e}. "
                "Please install cv2 (opencv-python)."
            )

        # Validate output_format
        valid_formats = ["html", "mp4", "gif", "matplotlib"]
        if output_format not in valid_formats:
            raise ValueError(f"output_format must be one of {valid_formats}, got '{output_format}'")

        # ---- Determine gaze columns ----
        if "X" in samples.columns and "Y" in samples.columns:
            x_col, y_col = "X", "Y"
        elif "LX" in samples.columns and "LY" in samples.columns:
            x_col, y_col = "LX", "LY"
        elif "RX" in samples.columns and "RY" in samples.columns:
            x_col, y_col = "RX", "RY"
        else:
            raise ValueError("Samples DataFrame must contain gaze columns (X, Y) or (LX, LY) or (RX, RY)")

        # ---- Time filter ----
        if tmin is not None and tmax is not None:
            samples = samples.filter(pl.col("tSample").is_between(tmin, tmax))

        if samples.is_empty():
            raise ValueError("No samples available after time filtering")

        # ---- Drop NaN gaze values ----
        samples = samples.filter(pl.col(x_col).is_not_null() & pl.col(y_col).is_not_null())

        # ---- Calculate scaled dimensions ----
        scaled_width = int(screen_width * scale_factor)
        scaled_height = int(screen_height * scale_factor)

        trial_idx = samples["trial_number"][0] if "trial_number" in samples.columns else 0

        # ================= WITH VIDEO =================
        if video_path is not None:
            video_path = Path(video_path)
            if not video_path.exists():
                raise FileNotFoundError(f"Video file not found: {video_path}")

            cap = cv2.VideoCapture(str(video_path))
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            if fps is None:
                fps = video_fps

            # Calculate time to frame mapping
            t_start = samples["tSample"].min()
            t_end = samples["tSample"].max()
            trial_duration = t_end - t_start

            # Create frame-to-time mapping
            frame_edges = np.linspace(t_start, t_end, total_frames + 1)
            frame_times = ((frame_edges[:-1] + frame_edges[1:]) / 2).astype(int)

            # Build a lookup: frame_index -> list of gaze points
            samples_np = samples.select([x_col, y_col, "tSample"]).to_numpy()
            gaze_by_frame = {i: [] for i in range(total_frames)}

            for x, y, t in samples_np:
                # Find the closest frame
                frame_idx = np.searchsorted(frame_times, t, side='right') - 1
                frame_idx = max(0, min(frame_idx, total_frames - 1))
                gaze_by_frame[frame_idx].append((x, y))

            # Limit frames if seconds_to_show is set
            frames_to_show = total_frames
            if seconds_to_show is not None:
                frames_to_show = min(int(fps * seconds_to_show), total_frames)

            # Reset video
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

            # Create figure
            fig, ax = plt.subplots(figsize=(10 * scale_factor, 6 * scale_factor))
            ax.axis('off')

            # Initialize with first frame
            ret, frame = cap.read()
            if not ret:
                cap.release()
                raise RuntimeError("Could not read first frame from video")

            frame_resized = cv2.resize(frame, (scaled_width, scaled_height), interpolation=cv2.INTER_AREA)
            frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
            im = ax.imshow(frame_rgb)

            def update_frame_video(frame_idx):
                ret, frame = cap.read()
                if not ret:
                    return [im]

                frame_resized = cv2.resize(frame, (scaled_width, scaled_height), interpolation=cv2.INTER_AREA)
                frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)

                # Draw gaze points for this frame
                for gx, gy in gaze_by_frame.get(frame_idx, []):
                    scaled_x = int(gx * scale_factor)
                    scaled_y = int(gy * scale_factor)
                    if 0 <= scaled_x < scaled_width and 0 <= scaled_y < scaled_height:
                        radius = max(3, int(gaze_radius * scale_factor))
                        cv2.circle(frame_rgb, (scaled_x, scaled_y), radius=radius, color=gaze_color, thickness=-1)

                im.set_array(frame_rgb)
                return [im]

            anim = FuncAnimation(fig, update_frame_video, frames=frames_to_show,
                                 interval=1000/fps, blit=True, repeat=True)

        # ================= WITHOUT VIDEO =================
        else:
            if fps is None:
                fps = 60  # Default FPS for sample-based animation

            # Prepare background
            if background_image_path is not None:
                bg_path = Path(background_image_path)
                if not bg_path.exists():
                    raise FileNotFoundError(f"Background image not found: {bg_path}")
                bg_img = mpimg.imread(str(bg_path))
                if np.issubdtype(bg_img.dtype, np.floating):  # imread may return float32/float64 in [0, 1]
                    bg_img = (bg_img * 255).clip(0, 255).astype(np.uint8)
                # Resize background to match screen dimensions then scale
                bg_img = cv2.resize(bg_img, (scaled_width, scaled_height), interpolation=cv2.INTER_AREA)
            else:
                # Grey background
                bg_img = np.ones((scaled_height, scaled_width, 3), dtype=np.uint8) * 128

            # Get time range
            t_start = samples["tSample"].min()
            t_end = samples["tSample"].max()
            trial_duration = t_end - t_start

            # Limit duration if seconds_to_show is set
            if seconds_to_show is not None:
                t_end = min(t_end, t_start + int(seconds_to_show * 1000))
                samples = samples.filter(pl.col("tSample") <= t_end)
                trial_duration = t_end - t_start

            # Calculate total frames based on duration and fps
            total_frames = int((trial_duration / 1000) * fps)
            if total_frames < 1:
                total_frames = 1

            # Create time bins for each animation frame
            frame_times = np.linspace(t_start, t_end, total_frames + 1)

            # Build gaze lookup by frame
            samples_np = samples.select([x_col, y_col, "tSample"]).to_numpy()
            gaze_by_frame = {i: [] for i in range(total_frames)}

            for x, y, t in samples_np:
                frame_idx = np.searchsorted(frame_times, t, side='right') - 1
                frame_idx = max(0, min(frame_idx, total_frames - 1))
                gaze_by_frame[frame_idx].append((x, y))

            # Create figure
            fig, ax = plt.subplots(figsize=(10 * scale_factor, 6 * scale_factor))
            ax.axis('off')

            # Initialize with background
            im = ax.imshow(bg_img.copy())

            def update_frame_no_video(frame_idx):
                # Start with fresh background copy
                frame_rgb = bg_img.copy()

                # Draw gaze points for this frame
                for gx, gy in gaze_by_frame.get(frame_idx, []):
                    scaled_x = int(gx * scale_factor)
                    scaled_y = int(gy * scale_factor)
                    if 0 <= scaled_x < scaled_width and 0 <= scaled_y < scaled_height:
                        radius = max(3, int(gaze_radius * scale_factor))
                        cv2.circle(frame_rgb, (scaled_x, scaled_y), radius=radius, color=gaze_color, thickness=-1)

                im.set_array(frame_rgb)
                return [im]

            anim = FuncAnimation(fig, update_frame_no_video, frames=total_frames,
                                 interval=1000/fps, blit=True, repeat=True)

        # ================= SAVE / DISPLAY =================
        result = None
        trial_idx_val = trial_idx

        # Build output filename
        anim_name = f"animation_{trial_idx_val}"
        if tmin is not None and tmax is not None:
            anim_name += f"_{tmin}_{tmax}"

        # Handle different output formats
        if output_format == "matplotlib":
            # Show in matplotlib GUI window (blocking)
            plt.show()
            # Cleanup video capture if used
            if video_path is not None:
                cap.release()
            return None

        elif output_format == "mp4":
            if folder_path:
                folder_path = Path(folder_path)
                folder_path.mkdir(parents=True, exist_ok=True)
                out_path = folder_path / f"{anim_name}.mp4"
                try:
                    anim.save(str(out_path), writer='ffmpeg', fps=fps)
                    print(f"Animation saved to: {out_path}")
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to save MP4. Make sure ffmpeg is installed. Error: {e}"
                    )
            plt.close(fig)

        elif output_format == "gif":
            if folder_path:
                folder_path = Path(folder_path)
                folder_path.mkdir(parents=True, exist_ok=True)
                out_path = folder_path / f"{anim_name}.gif"
                try:
                    anim.save(str(out_path), writer='pillow', fps=fps)
                    print(f"Animation saved to: {out_path}")
                except Exception as e:
                    raise RuntimeError(
                        f"Failed to save GIF. Make sure pillow is installed. Error: {e}"
                    )
            plt.close(fig)

        else:  # html
            if folder_path:
                folder_path = Path(folder_path)
                folder_path.mkdir(parents=True, exist_ok=True)
                out_path = folder_path / f"{anim_name}.html"
                with open(out_path, 'w') as f:
                    f.write(anim.to_jshtml())
                print(f"Animation saved to: {out_path}")

            if display:
                try:
                    from IPython.display import HTML
                    plt.close(fig)
                    result = HTML(anim.to_jshtml())
                except ImportError:
                    print("IPython not available. Use output_format='matplotlib' for GUI display.")
                    plt.close(fig)
            else:
                plt.close(fig)

        # Cleanup video capture if used
        if video_path is not None:
            cap.release()

        return result

plot_animation(samples, screen_height, screen_width, video_path=None, background_image_path=None, folder_path=None, tmin=None, tmax=None, seconds_to_show=None, scale_factor=0.5, gaze_radius=10, gaze_color=(255, 0, 0), fps=None, output_format='matplotlib', display=True)

Create an animated visualization of eye-tracking data.

When a video is provided, the animation syncs gaze samples with video frames. When no video is provided, gaze points are animated on a grey background or a provided background image, using the sample timestamps for timing.
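
Internally, each sample is routed to an animation frame by binning its timestamp into the frame's time slice. A minimal sketch of the assignment used in the sample-based path (numbers hypothetical):

import numpy as np

# 5 animation frames spanning a 0-1000 ms trial
n_frames = 5
frame_edges = np.linspace(0, 1000, n_frames + 1)   # 6 edges -> 5 time bins
t = np.array([10, 250, 999])                       # sample timestamps (ms)

# Bin index (= frame index) for each timestamp, clamped to valid frames
idx = np.searchsorted(frame_edges, t, side="right") - 1
idx = np.clip(idx, 0, n_frames - 1)
# idx -> array([0, 1, 4])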

Parameters:

    samples (DataFrame, required)
        Polars DataFrame with gaze samples. Must contain 'tSample' and gaze position columns ('X', 'Y' or 'LX', 'LY', 'RX', 'RY').
    screen_height, screen_width (int, required)
        Stimulus resolution in pixels.
    video_path (str | Path | None, default None)
        Path to a video file. If provided, gaze is overlaid on video frames.
    background_image_path (str | Path | None, default None)
        Path to a background image. Only used when video_path is None. If both are None, a grey background is used.
    folder_path (str | Path | None, default None)
        Directory where the animation will be saved. If None, nothing is saved. The file format depends on output_format.
    tmin, tmax (int | None, default None)
        Time window in ms. If both are None, the whole trial is plotted.
    seconds_to_show (float | None, default None)
        Limit the animation to the first N seconds. If None, shows all available data.
    scale_factor (float, default 0.5)
        Resolution scaling factor (1.0 = original, 0.5 = half resolution).
    gaze_radius (int, default 10)
        Radius of the gaze point circle in pixels (before scaling).
    gaze_color (tuple, default (255, 0, 0))
        RGB tuple for the gaze point color.
    fps (float | None, default None)
        Frames per second for the animation. If None, uses the video's native FPS with a video and 60 FPS without one.
    output_format (str, default 'matplotlib')
        Output format for saved animations: "html" (interactive HTML file, works in browsers), "mp4" (video file, requires ffmpeg), "gif" (animated GIF, requires pillow), or "matplotlib" (show in a matplotlib GUI window, blocking).
    display (bool, default True)
        If True and output_format is "html", returns an HTML object for notebooks. If output_format is "matplotlib", this is ignored (the window is always shown). If False, only saves to file (when folder_path is provided).

Returns:

    IPython.display.HTML or None
        The HTML animation if display=True and output_format="html", otherwise None.
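
A minimal usage sketch (hypothetical names: `viz` is the object exposing plot_animation; `samples` is a Polars frame with 'tSample' and 'X'/'Y' columns):

html = viz.plot_animation(
    samples,
    screen_height=1080,
    screen_width=1920,
    output_format="html",       # returns an IPython HTML object
    folder_path="plots/anim",   # also writes animation_<trial>.html
)
html  # renders inline in a notebook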

Source code in pyxations/visualization/visualization.py
def plot_animation(
    self,
    samples: pl.DataFrame,
    screen_height: int,
    screen_width: int,
    video_path: str | Path | None = None,
    background_image_path: str | Path | None = None,
    folder_path: str | Path | None = None,
    tmin: int | None = None,
    tmax: int | None = None,
    seconds_to_show: float | None = None,
    scale_factor: float = 0.5,
    gaze_radius: int = 10,
    gaze_color: tuple = (255, 0, 0),
    fps: float | None = None,
    output_format: str = "matplotlib",
    display: bool = True,
):
    """
    Create an animated visualization of eye-tracking data.

    When a video is provided, the animation syncs gaze samples with video frames.
    When no video is provided, gaze points are animated on a grey background or
    a provided background image, using the sample timestamps for timing.

    Parameters
    ----------
    samples
        Polars DataFrame with gaze samples. Must contain 'tSample' and gaze
        position columns ('X', 'Y' or 'LX', 'LY', 'RX', 'RY').
    screen_height, screen_width
        Stimulus resolution in pixels.
    video_path
        Path to a video file. If provided, gaze is overlaid on video frames.
    background_image_path
        Path to a background image. Only used when video_path is None.
        If both are None, a grey background is used.
    folder_path
        Directory where the animation will be saved. If None, nothing is saved.
        The file format depends on `output_format`.
    tmin, tmax
        Time window in **ms**. If both None, the whole trial is plotted.
    seconds_to_show
        Limit the animation to the first N seconds. If None, shows all available data.
    scale_factor
        Resolution scaling factor (1.0 = original, 0.5 = half resolution).
    gaze_radius
        Radius of the gaze point circle in pixels (before scaling).
    gaze_color
        RGB tuple for gaze point color.
    fps
        Frames per second for the animation. If None:
        - With video: uses the video's native FPS
        - Without video: defaults to 60 FPS
    output_format
        Output format for saved animations:
        - "html": Interactive HTML file (works in browsers)
        - "mp4": Video file (requires ffmpeg)
        - "gif": Animated GIF file (requires pillow)
        - "matplotlib": Show in matplotlib GUI window (blocking; the default)
    display
        If True and output_format is "html", returns an HTML object for notebooks.
        If output_format is "matplotlib", this is ignored (always shows window).
        If False, only saves to file (if folder_path is provided).

    Returns
    -------
    IPython.display.HTML or None
        Returns HTML animation if display=True and output_format="html", otherwise None.
    """
    try:
        import cv2
        from matplotlib.animation import FuncAnimation
        import matplotlib as mpl
        mpl.rcParams['animation.embed_limit'] = 100
    except ImportError as e:
        raise ImportError(
            f"Missing required dependency for animation: {e}. "
            "Please install cv2 (opencv-python)."
        )

    # Validate output_format
    valid_formats = ["html", "mp4", "gif", "matplotlib"]
    if output_format not in valid_formats:
        raise ValueError(f"output_format must be one of {valid_formats}, got '{output_format}'")

    # ---- Determine gaze columns ----
    if "X" in samples.columns and "Y" in samples.columns:
        x_col, y_col = "X", "Y"
    elif "LX" in samples.columns and "LY" in samples.columns:
        x_col, y_col = "LX", "LY"
    elif "RX" in samples.columns and "RY" in samples.columns:
        x_col, y_col = "RX", "RY"
    else:
        raise ValueError("Samples DataFrame must contain gaze columns (X, Y) or (LX, LY) or (RX, RY)")

    # ---- Time filter ----
    if tmin is not None and tmax is not None:
        samples = samples.filter(pl.col("tSample").is_between(tmin, tmax))

    if samples.is_empty():
        raise ValueError("No samples available after time filtering")

    # ---- Drop NaN gaze values ----
    samples = samples.filter(pl.col(x_col).is_not_null() & pl.col(y_col).is_not_null())

    # ---- Calculate scaled dimensions ----
    scaled_width = int(screen_width * scale_factor)
    scaled_height = int(screen_height * scale_factor)

    trial_idx = samples["trial_number"][0] if "trial_number" in samples.columns else 0

    # ================= WITH VIDEO =================
    if video_path is not None:
        video_path = Path(video_path)
        if not video_path.exists():
            raise FileNotFoundError(f"Video file not found: {video_path}")

        cap = cv2.VideoCapture(str(video_path))
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        if fps is None:
            fps = video_fps

        # Calculate time to frame mapping
        t_start = samples["tSample"].min()
        t_end = samples["tSample"].max()
        trial_duration = t_end - t_start

        # Create frame-to-time mapping
        frame_edges = np.linspace(t_start, t_end, total_frames + 1)
        frame_times = ((frame_edges[:-1] + frame_edges[1:]) / 2).astype(int)

        # Build a lookup: frame_index -> list of gaze points
        samples_np = samples.select([x_col, y_col, "tSample"]).to_numpy()
        gaze_by_frame = {i: [] for i in range(total_frames)}

        for x, y, t in samples_np:
            # Find the closest frame
            frame_idx = np.searchsorted(frame_times, t, side='right') - 1
            frame_idx = max(0, min(frame_idx, total_frames - 1))
            gaze_by_frame[frame_idx].append((x, y))

        # Limit frames if seconds_to_show is set
        frames_to_show = total_frames
        if seconds_to_show is not None:
            frames_to_show = min(int(fps * seconds_to_show), total_frames)

        # Reset video
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

        # Create figure
        fig, ax = plt.subplots(figsize=(10 * scale_factor, 6 * scale_factor))
        ax.axis('off')

        # Initialize with first frame
        ret, frame = cap.read()
        if not ret:
            cap.release()
            raise RuntimeError("Could not read first frame from video")

        frame_resized = cv2.resize(frame, (scaled_width, scaled_height), interpolation=cv2.INTER_AREA)
        frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
        im = ax.imshow(frame_rgb)

        def update_frame_video(frame_idx):
            ret, frame = cap.read()
            if not ret:
                return [im]

            frame_resized = cv2.resize(frame, (scaled_width, scaled_height), interpolation=cv2.INTER_AREA)
            frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)

            # Draw gaze points for this frame
            for gx, gy in gaze_by_frame.get(frame_idx, []):
                scaled_x = int(gx * scale_factor)
                scaled_y = int(gy * scale_factor)
                if 0 <= scaled_x < scaled_width and 0 <= scaled_y < scaled_height:
                    radius = max(3, int(gaze_radius * scale_factor))
                    cv2.circle(frame_rgb, (scaled_x, scaled_y), radius=radius, color=gaze_color, thickness=-1)

            im.set_array(frame_rgb)
            return [im]

        anim = FuncAnimation(fig, update_frame_video, frames=frames_to_show,
                             interval=1000/fps, blit=True, repeat=True)

    # ================= WITHOUT VIDEO =================
    else:
        if fps is None:
            fps = 60  # Default FPS for sample-based animation

        # Prepare background
        if background_image_path is not None:
            bg_path = Path(background_image_path)
            if not bg_path.exists():
                raise FileNotFoundError(f"Background image not found: {bg_path}")
            bg_img = mpimg.imread(str(bg_path))
            if np.issubdtype(bg_img.dtype, np.floating):  # imread may return float32/float64 in [0, 1]
                bg_img = (bg_img * 255).clip(0, 255).astype(np.uint8)
            # Resize background to match screen dimensions then scale
            bg_img = cv2.resize(bg_img, (scaled_width, scaled_height), interpolation=cv2.INTER_AREA)
        else:
            # Grey background
            bg_img = np.ones((scaled_height, scaled_width, 3), dtype=np.uint8) * 128

        # Get time range
        t_start = samples["tSample"].min()
        t_end = samples["tSample"].max()
        trial_duration = t_end - t_start

        # Limit duration if seconds_to_show is set
        if seconds_to_show is not None:
            t_end = min(t_end, t_start + int(seconds_to_show * 1000))
            samples = samples.filter(pl.col("tSample") <= t_end)
            trial_duration = t_end - t_start

        # Calculate total frames based on duration and fps
        total_frames = int((trial_duration / 1000) * fps)
        if total_frames < 1:
            total_frames = 1

        # Create time bins for each animation frame
        frame_times = np.linspace(t_start, t_end, total_frames + 1)

        # Build gaze lookup by frame
        samples_np = samples.select([x_col, y_col, "tSample"]).to_numpy()
        gaze_by_frame = {i: [] for i in range(total_frames)}

        for x, y, t in samples_np:
            frame_idx = np.searchsorted(frame_times, t, side='right') - 1
            frame_idx = max(0, min(frame_idx, total_frames - 1))
            gaze_by_frame[frame_idx].append((x, y))

        # Create figure
        fig, ax = plt.subplots(figsize=(10 * scale_factor, 6 * scale_factor))
        ax.axis('off')

        # Initialize with background
        im = ax.imshow(bg_img.copy())

        def update_frame_no_video(frame_idx):
            # Start with fresh background copy
            frame_rgb = bg_img.copy()

            # Draw gaze points for this frame
            for gx, gy in gaze_by_frame.get(frame_idx, []):
                scaled_x = int(gx * scale_factor)
                scaled_y = int(gy * scale_factor)
                if 0 <= scaled_x < scaled_width and 0 <= scaled_y < scaled_height:
                    radius = max(3, int(gaze_radius * scale_factor))
                    cv2.circle(frame_rgb, (scaled_x, scaled_y), radius=radius, color=gaze_color, thickness=-1)

            im.set_array(frame_rgb)
            return [im]

        anim = FuncAnimation(fig, update_frame_no_video, frames=total_frames,
                             interval=1000/fps, blit=True, repeat=True)

    # ================= SAVE / DISPLAY =================
    result = None
    trial_idx_val = trial_idx

    # Build output filename
    anim_name = f"animation_{trial_idx_val}"
    if tmin is not None and tmax is not None:
        anim_name += f"_{tmin}_{tmax}"

    # Handle different output formats
    if output_format == "matplotlib":
        # Show in matplotlib GUI window (blocking)
        plt.show()
        # Cleanup video capture if used
        if video_path is not None:
            cap.release()
        return None

    elif output_format == "mp4":
        if folder_path:
            folder_path = Path(folder_path)
            folder_path.mkdir(parents=True, exist_ok=True)
            out_path = folder_path / f"{anim_name}.mp4"
            try:
                anim.save(str(out_path), writer='ffmpeg', fps=fps)
                print(f"Animation saved to: {out_path}")
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save MP4. Make sure ffmpeg is installed. Error: {e}"
                )
        plt.close(fig)

    elif output_format == "gif":
        if folder_path:
            folder_path = Path(folder_path)
            folder_path.mkdir(parents=True, exist_ok=True)
            out_path = folder_path / f"{anim_name}.gif"
            try:
                anim.save(str(out_path), writer='pillow', fps=fps)
                print(f"Animation saved to: {out_path}")
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save GIF. Make sure pillow is installed. Error: {e}"
                )
        plt.close(fig)

    else:  # html
        if folder_path:
            folder_path = Path(folder_path)
            folder_path.mkdir(parents=True, exist_ok=True)
            out_path = folder_path / f"{anim_name}.html"
            with open(out_path, 'w') as f:
                f.write(anim.to_jshtml())
            print(f"Animation saved to: {out_path}")

        if display:
            try:
                from IPython.display import HTML
                plt.close(fig)
                result = HTML(anim.to_jshtml())
            except ImportError:
                print("IPython not available. Use output_format='matplotlib' for GUI display.")
                plt.close(fig)
        else:
            plt.close(fig)

    # Cleanup video capture if used
    if video_path is not None:
        cap.release()

    return result

plot_multipanel(fixations, saccades, display=True)

Create a 2×2 multi‑panel diagnostic plot for every non‑empty phase label and save it as PNG in <derivatives_folder_path>/<events_detection_folder>/plots/.
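
A minimal usage sketch (hypothetical: `viz` is the object exposing these plotting methods; the fixations and saccades frames must carry 'trial_number' and 'phase', plus 'duration', 'ampDeg', 'vPeak', 'deg', and 'dir' for the four panels):

viz.plot_multipanel(fixations, saccades, display=False)
# Writes one multipanel_<phase>.png per phase under
# <derivatives_folder_path>/<events_detection_folder>/plots/.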

Source code in pyxations/visualization/visualization.py
def plot_multipanel(
        self,
        fixations: pl.DataFrame,
        saccades: pl.DataFrame,
        display: bool = True
    ) -> None:
    """
    Create a 2×2 multi‑panel diagnostic plot for every non‑empty
    phase label and save it as PNG in
    <derivatives_folder_path>/<events_detection_folder>/plots/.
    """
    # ── paths & matplotlib style ────────────────────────────────
    folder_path: Path = (
        self.derivatives_folder_path
        / self.events_detection_folder
        / "plots"
    )
    folder_path.mkdir(parents=True, exist_ok=True)
    plt.rcParams.update({"font.size": 12})

    # ── drop practice / invalid trials ─────────────────────────
    fixations = fixations.filter(pl.col("trial_number") != -1)
    saccades  = saccades.filter(pl.col("trial_number") != -1)

    # ── collect valid phase labels (skip empty string) ─────────
    phases = (
        fixations
        .select(pl.col("phase").filter(pl.col("phase") != ""))
        .unique()           # unique values in this Series
        .to_series()
        .to_list()          # plain Python list of strings
    )

    # ── one figure per phase ───────────────────────────────────
    for phase in phases:
        fix_phase   = fixations.filter(pl.col("phase") == phase)
        sacc_phase  = saccades.filter(pl.col("phase") == phase)

        fig, axs = plt.subplots(2, 2, figsize=(12, 7))

        self.fix_duration(fix_phase, axs=axs[0, 0])
        self.sacc_main_sequence(sacc_phase, axs=axs[1, 1])
        self.sacc_direction(sacc_phase, axs=axs[1, 0], figs=fig)
        self.sacc_amplitude(sacc_phase, axs=axs[0, 1])

        fig.tight_layout()
        fig.savefig(folder_path / f"multipanel_{phase}.png")
        if display:
            plt.show()
        plt.close(fig)

scanpath(fixations, screen_height, screen_width, folder_path=None, tmin=None, tmax=None, saccades=None, samples=None, phase_data=None, display=True)

Fast scan‑path visualiser.

• Vectorised: no per‑row Python loops
• Single‑pass phase grouping (see the sketch below)
• Uses Axes.broken_barh for fixation spans
• Asynchronous PNG writes via ThreadPoolExecutor when a folder_path is given
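
The single-pass grouping relies on polars' partition_by; a tiny illustration with a hypothetical frame (in recent polars versions the dict keys are 1-tuples, which is why the source indexes phase[0]):

import polars as pl

fix = pl.DataFrame({"phase": ["search", "search", "recall"],
                    "xAvg": [10.0, 20.0, 30.0]})
by_phase = fix.partition_by("phase", as_dict=True)
# keys: ('search',) and ('recall',), each mapping to its sub-frame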

Parameters:

    fixations (DataFrame, required)
        Polars DataFrame with at least tStart, duration, xAvg, yAvg, phase.
    screen_height, screen_width (int, required)
        Stimulus resolution in pixels.
    folder_path (str | Path | None, default None)
        Directory where one PNG per phase will be stored. If None, nothing is saved.
    tmin, tmax (int | None, default None)
        Time window in ms. If both are None, the whole trial is plotted.
    saccades (DataFrame | None, default None)
        Polars DataFrame with tStart, phase, … (optional).
    samples (DataFrame | None, default None)
        Polars DataFrame with gaze traces (tSample, LX, LY, RX, RY or X, Y) (optional).
    phase_data (dict[str, dict] | None, default None)
        Per‑phase extras::

            {
                "search": {
                    "img_paths": [...],
                    "img_plot_coords": [(x1,y1,x2,y2), ...],
                    "bbox": (x1,y1,x2,y2),
                },
                ...
            }

    display (bool, default True)
        If False the figure canvas is never shown (faster for batch jobs).
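
A minimal usage sketch (hypothetical `viz`; the fixations frame must also carry 'trial_number'):

viz.scanpath(
    fixations,
    screen_height=1080,
    screen_width=1920,
    folder_path="plots/scanpaths",  # one scanpath_<trial>_<phase>.png per phase
    tmin=0, tmax=5000,              # optional window in ms
    saccades=saccades,              # marks saccade onsets on the gaze panel
    samples=samples,                # adds the lower gaze-trace panel
    display=False,                  # batch mode: suppress GUI windows
)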
Source code in pyxations/visualization/visualization.py
def scanpath(
    self,
    fixations: pl.DataFrame,
    screen_height: int,
    screen_width: int,
    folder_path: str | Path | None = None,
    tmin: int | None = None,
    tmax: int | None = None,
    saccades: pl.DataFrame | None = None,
    samples: pl.DataFrame | None = None,
    phase_data: dict[str, dict] | None = None,
    display: bool = True,
):
    """
    Fast scan‑path visualiser.

    • **Vectorised**: no per‑row Python loops  
    • **Single pass** phase grouping  
    • Uses `Axes.broken_barh` for fixation spans  
    • Asynchronous PNG writes via ThreadPoolExecutor when a `folder_path` is given

    Parameters
    ----------
    fixations
        Polars DataFrame with at least `tStart`, `duration`, `xAvg`, `yAvg`, `phase`.
    screen_height, screen_width
        Stimulus resolution in pixels.
    folder_path
        Directory where 1 PNG per phase will be stored.  If *None*, nothing is saved.
    tmin, tmax
        Time window in **ms**.  If both `None`, the whole trial is plotted.
    saccades
        Polars DataFrame with `tStart`, `phase`, …  (optional).
    samples
        Polars DataFrame with gaze traces (`tSample`, `LX`, `LY`, `RX`, `RY` or
        `X`, `Y`) (optional).
    phase_data
        Per‑phase extras::

            {
                "search": {
                    "img_paths": [...],
                    "img_plot_coords": [(x1,y1,x2,y2), ...],
                    "bbox": (x1,y1,x2,y2),
                },
                ...
            }

    display
        If *False* the figure canvas is never shown (faster for batch jobs).
    """


    # ------------- small helpers ------------------------------------------------
    def _make_axes(plot_samples: bool):
        if plot_samples:
            fig, (ax_main, ax_gaze) = plt.subplots(
                2, 1, height_ratios=(4, 1), figsize=(10, 6), sharex=False
            )
        else:
            fig, ax_main = plt.subplots(figsize=(10, 6))
            ax_gaze = None
        ax_main.set_xlim(0, screen_width)
        ax_main.set_ylim(screen_height, 0)
        return fig, ax_main, ax_gaze

    def _maybe_cache_img(path):
        """Load image from disk with a small LRU cache."""

        # Cache hit: move to the end (most recently used)
        if path in _img_cache:
            img = _img_cache.pop(path)
            _img_cache[path] = img
            return img

        # Cache miss: load image
        img = mpimg.imread(path)

        # Optional: reduce memory if image is float64 in [0, 1]
        if isinstance(img, np.ndarray) and img.dtype == np.float64:
            img = (img * 255).clip(0, 255).astype(np.uint8)

        # Insert into cache
        _img_cache[path] = img

        # If cache too big, drop least recently used item
        if len(_img_cache) > _MAX_CACHE_ITEMS:
            _img_cache.popitem(last=False)  # pops the oldest inserted item

        return img

    # ---------------------------------------------------------------------------
    plot_saccades = saccades is not None
    plot_samples = samples is not None
    _img_cache = OrderedDict()
    _MAX_CACHE_ITEMS = 8  # or 5, 10, etc. Tune as you like.

    trial_idx = fixations["trial_number"][0]

    # ---- time filter ----------------------------------------------------------
    if tmin is not None and tmax is not None:
        fixations = fixations.filter(pl.col("tStart").is_between(tmin, tmax))
        if plot_saccades:
            saccades = saccades.filter(pl.col("tStart").is_between(tmin, tmax))
        if plot_samples:
            samples = samples.filter(pl.col("tSample").is_between(tmin, tmax))

    # remove empty phase markings
    fixations = fixations.filter(pl.col("phase") != "")
    if plot_saccades:
        saccades = saccades.filter(pl.col("phase") != "")
    if plot_samples:
        samples = samples.filter(pl.col("phase") != "")

    # ---- split once by phase --------------------------------------------------
    fix_by_phase = fixations.partition_by("phase", as_dict=True)
    sac_by_phase = (
        saccades.partition_by("phase", as_dict=True) if plot_saccades else {}
    )
    samp_by_phase = (
        samples.partition_by("phase", as_dict=True) if plot_samples else {}
    )

    # colour map shared across phases
    cmap = plt.cm.rainbow

    # ---- build & draw ---------------------------------------------------------
    # async saver: offloads PNG writes to worker threads when saving to disk
    from concurrent.futures import ThreadPoolExecutor
    saver = ThreadPoolExecutor(max_workers=4) if folder_path else None

    if not display:
        plt.ioff()

    for phase, phase_fix in fix_by_phase.items():
        if phase_fix.is_empty():
            continue

        # ---------- vectors (zero‑copy) -----------------
        fx, fy, fdur = phase_fix.select(["xAvg", "yAvg", "duration"]).to_numpy().T
        n_fix = fx.size
        fix_idx = np.arange(1, n_fix + 1)

        norm = mplcolors.BoundaryNorm(np.arange(1, n_fix + 2), cmap.N)

        # saccades
        sac_t = (
            sac_by_phase[phase]["tStart"].to_numpy()
            if plot_saccades and phase in sac_by_phase
            else np.empty(0)
        )

        # samples
        if plot_samples and phase in samp_by_phase and samp_by_phase[phase].height:
            samp_phase = samp_by_phase[phase]
            t0 = samp_phase["tSample"][0]
            ts = (samp_phase["tSample"].to_numpy() - t0) 
            get = samp_phase.get_column
            lx = get("LX").to_numpy() if "LX" in samp_phase.columns else None
            ly = get("LY").to_numpy() if "LY" in samp_phase.columns else None
            rx = get("RX").to_numpy() if "RX" in samp_phase.columns else None
            ry = get("RY").to_numpy() if "RY" in samp_phase.columns else None
            gx = get("X").to_numpy() if "X" in samp_phase.columns else None
            gy = get("Y").to_numpy() if "Y" in samp_phase.columns else None
        else:
            t0 = None

        # ---------- figure -----------------------------
        fig, ax_main, ax_gaze = _make_axes(plot_samples and t0 is not None)
        # scatter fixations
        sc = ax_main.scatter(
            fx,
            fy,
            c=fix_idx,
            s=fdur,
            cmap=cmap,
            norm=norm,
            alpha=0.5,
            zorder=2,
        )
        fig.colorbar(
            sc,
            ax=ax_main,
            ticks=[1, n_fix // 2 if n_fix > 2 else n_fix, n_fix],
            fraction=0.046,
            pad=0.04,
        ).set_label("# of fixation")

        # ---------- stimulus imagery / bbox ------------
        if phase_data and phase[0] in phase_data:
            pdict = phase_data[phase[0]]
            coords = pdict.get("img_plot_coords") or []
            bbox = pdict.get('bbox',None) 
            for img_path, box in zip(pdict.get("img_paths", []), coords):

                ax_main.imshow(_maybe_cache_img(img_path), extent=[box[0], box[2], box[3], box[1]], zorder=0)
            if bbox is not None:
                x1, y1, x2, y2 = bbox
                ax_main.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], color='red', linewidth=1.5, zorder=3)

        # ---------- gaze traces ------------------------
        if ax_gaze is not None:
            if lx is not None:
                ax_main.plot(lx, ly, "--", color="C0", zorder=1)
                ax_gaze.plot(ts, lx, label="Left X")
                ax_gaze.plot(ts, ly, label="Left Y")
            if rx is not None:
                ax_main.plot(rx, ry, "--", color="k", zorder=1)
                ax_gaze.plot(ts, rx, label="Right X")
                ax_gaze.plot(ts, ry, label="Right Y")
            if gx is not None:
                ax_main.plot(gx, gy, "--", color="k", zorder=1, alpha=0.6)
                ax_gaze.plot(ts, gx, label="X")
                ax_gaze.plot(ts, gy, label="Y")

            # fixation spans
            bars   = np.c_[phase_fix['tStart'].to_numpy() - t0,
                        phase_fix['duration'].to_numpy()]
            height = ax_gaze.get_ylim()[1] - ax_gaze.get_ylim()[0]
            colors = cmap(norm(fix_idx))

            # Draw all bars in one call; no BrokenBarHCollection import needed
            ax_gaze.broken_barh(bars, (0, height), facecolors=colors, alpha=0.4)
            # saccades
            if sac_t.size:
                ymin, ymax = ax_gaze.get_ylim()
                ax_gaze.vlines(
                    sac_t - t0,
                    ymin,
                    ymax,
                    colors="red",
                    linestyles="--",
                    linewidth=0.8,
                )

            # tidy gaze axis
            h, l = ax_gaze.get_legend_handles_labels()
            by_label = {lab: hdl for hdl, lab in zip(h, l)}
            ax_gaze.legend(
                by_label.values(),
                by_label.keys(),
                loc="center left",
                bbox_to_anchor=(1, 0.5),
            )
            ax_gaze.set_ylabel("Gaze")
            ax_gaze.set_xlabel("Time [ms]")

        fig.tight_layout()

        # ---------- save / show ------------------------
        if folder_path:
            scan_name = f"scanpath_{trial_idx}"
            if tmin is not None and tmax is not None:
                scan_name += f"_{tmin}_{tmax}"
            out = Path(folder_path) / f"{scan_name}_{phase[0]}.png"
            if saver:
                saver.submit(fig.savefig, out, dpi=150)  # async write
            else:
                fig.savefig(out, dpi=150)

        if display:
            plt.show()
        plt.close(fig)

    if saver:
        saver.shutdown(wait=True)  # wait for pending PNG writes
    if not display:
        plt.ion()