API Reference

This page gives an overview of all public pyxations objects, functions and methods.

pre_processing

PreProcessing

Pyxations preprocessing: trial segmentation, quality flags, and saccade direction. All mutating functions are safe (copy-aware) and validate required columns.

Tables (pd.DataFrame) expected:

    samples:       typically contains 'tSample' (ms) and gaze columns (e.g., 'LX','LY','RX','RY' or 'X','Y')
    fixations:     typically contains 'tStart','tEnd' and optional 'xAvg','yAvg'
    saccades:      typically contains 'tStart','tEnd','xStart','yStart','xEnd','yEnd'
    blinks:        typically contains 'tStart','tEnd' (optional)
    user_messages: must contain 'timestamp','message'

New columns created:

    - All tables after trialing: 'phase', 'trial_number', 'trial_label' (optional)
    - samples/fixations/saccades: 'bad' (bool) after bad_samples()
    - saccades: 'deg' (float degrees), 'dir' (str) after saccades_direction()
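A minimal construction sketch (toy tables and a hypothetical session path; real tables come from your parsed recording):

import pandas as pd
from pyxations.pre_processing import PreProcessing

# Hypothetical toy tables with the column layout described above.
samples = pd.DataFrame({"tSample": [0, 4, 8], "X": [512.0, 515.0, 2100.0], "Y": [384.0, 380.0, 381.0]})
fixations = pd.DataFrame({"tStart": [0], "tEnd": [8], "xAvg": [513.0], "yAvg": [382.0]})
saccades = pd.DataFrame({"tStart": [2], "tEnd": [4], "xStart": [512.0], "yStart": [384.0], "xEnd": [515.0], "yEnd": [380.0]})
blinks = pd.DataFrame(columns=["tStart", "tEnd"])
messages = pd.DataFrame({"timestamp": [0, 8], "message": ["TRIAL_START", "TRIAL_END"]})

pp = PreProcessing(samples, fixations, saccades, blinks, messages, session_path="derivatives/sub-0001")
pp.set_metadata(coords_unit="px", screen_width=1920, screen_height=1080)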

Source code in pyxations/pre_processing.py
class PreProcessing:
    """
    Pyxations preprocessing: trial segmentation, quality flags, and saccade direction.
    All mutating functions are safe (copy-aware) and validate required columns.

    Tables (pd.DataFrame) expected:
        samples:   typically contains 'tSample' (ms), gaze columns (e.g., 'LX','LY','RX','RY' or 'X','Y')
        fixations: typically contains 'tStart','tEnd' and optional 'xAvg','yAvg'
        saccades:  typically contains 'tStart','tEnd','xStart','yStart','xEnd','yEnd'
        blinks:    typically contains 'tStart','tEnd' (optional)
        user_messages: must contain 'timestamp','message'

    New columns created:
        - All tables after trialing: 'phase', 'trial_number', 'trial_label' (optional)
        - samples/fixations/saccades: 'bad' (bool) after bad_samples()
        - saccades: 'deg' (float degrees), 'dir' (str) after saccades_direction()
    """

    VERSION = "0.2.0"

    def __init__(
        self,
        samples: pd.DataFrame,
        fixations: pd.DataFrame,
        saccades: pd.DataFrame,
        blinks: pd.DataFrame,
        user_messages: pd.DataFrame,
        session_path: PathLike,
        metadata: Optional[SessionMetadata] = None,
    ):
        self.samples = samples.copy()
        self.fixations = fixations.copy()
        self.saccades = saccades.copy()
        self.blinks = blinks.copy()
        self.user_messages = user_messages.copy()
        self.session_path = Path(session_path)
        self.metadata = metadata or SessionMetadata()

        # Normalize dtypes where possible (strings for messages)
        if "message" in self.user_messages.columns:
            self.user_messages["message"] = self.user_messages["message"].astype(str)

    # ------------------------------- Utilities ------------------------------- #

    @staticmethod
    def _require_columns(
        df: pd.DataFrame, cols: Sequence[str], context: str
    ) -> None:
        missing = [c for c in cols if c not in df.columns]
        if missing:
            raise ValueError(
                f"[{context}] Missing required columns: {missing}. "
                f"Available: {list(df.columns)}"
            )

    @staticmethod
    def _assert_nonoverlap(starts: Sequence[int], ends: Sequence[int], key: str, session: Path) -> None:
        if len(starts) != len(ends):
            raise ValueError(
                f"[{key}] start_times and end_times must have the same length, "
                f"got {len(starts)} vs {len(ends)} in session: {session}"
            )
        for i, (s, e) in enumerate(zip(starts, ends)):
            if not (s < e):
                raise ValueError(
                    f"[{key}] Non-positive interval at trial {i}: start={s}, end={e} "
                    f"in session: {session}"
                )
            if i < len(starts) - 1:
                if e > starts[i + 1]:
                    raise ValueError(
                        f"[{key}] Overlapping trials {i}{i+1}: end[i]={e} > start[i+1]={starts[i+1]} "
                        f"in session: {session}"
                    )

    @staticmethod
    def _ensure_columns_exist(df: pd.DataFrame, cols: Sequence[str]) -> List[str]:
        """Return the subset of 'cols' that actually exist in df."""
        return [c for c in cols if c in df.columns]

    def _save_json_sidecar(self, obj: dict, filename: str) -> None:
        outdir = self.session_path
        outdir.mkdir(parents=True, exist_ok=True)
        with open(outdir / filename, "w", encoding="utf-8") as f:
            json.dump(obj, f, indent=2, ensure_ascii=False)

    # ---------------------------- Public API: Meta ---------------------------- #

    def set_metadata(
        self,
        coords_unit: Optional[str] = None,
        time_unit: Optional[str] = None,
        pupil_unit: Optional[str] = None,
        screen_width: Optional[int] = None,
        screen_height: Optional[int] = None,
        **extra,
    ) -> None:
        """Update session-level metadata used in bounds checks and documentation."""
        if coords_unit is not None:
            self.metadata.coords_unit = coords_unit
        if time_unit is not None:
            self.metadata.time_unit = time_unit
        if pupil_unit is not None:
            self.metadata.pupil_unit = pupil_unit
        if screen_width is not None:
            self.metadata.screen_width = screen_width
        if screen_height is not None:
            self.metadata.screen_height = screen_height
        self.metadata.extra.update(extra)

    def save_metadata(self, filename: str = "metadata.json") -> None:
        """Persist metadata next to derivatives for reproducibility."""
        self._save_json_sidecar(self.metadata.to_dict(), filename)

    # ----------------------- Public API: Message Parsing ---------------------- #

    def get_timestamps_from_messages(
        self,
        messages_dict: Dict[str, List[str]],
        *,
        case_insensitive: bool = True,
        use_regex: bool = True,
        return_match_token: bool = False,
    ) -> Dict[str, List[int]]:
        """
        Extract ordered timestamps per phase by matching message substrings/patterns.

        Parameters
        ----------
        messages_dict : dict
            e.g., {'trial': ['TRIAL_START', 'BEGIN_TRIAL'], 'stim': ['STIM_ONSET']}
        case_insensitive : bool
            If True, ignore case during matching.
        use_regex : bool
            If True, treat entries as regex patterns joined by '|'; otherwise escape literals.
        return_match_token : bool
            If True, also creates/updates a 'matched_token' column with the first matched pattern.

        Returns
        -------
        Dict[str, List[int]]
            Ordered timestamps in ms for each key.
        """
        df = self.user_messages
        self._require_columns(df, ["timestamp", "message"], "get_timestamps_from_messages")

        timestamps_dict: Dict[str, List[int]] = {}
        flags = re.I if case_insensitive else 0

        # Prepare an optional matched_token column for traceability
        if return_match_token and "matched_token" not in df.columns:
            df = df.copy()
            df["matched_token"] = pd.Series([None] * len(df), index=df.index)

        for key, tokens in messages_dict.items():
            if not tokens:
                raise ValueError(f"[{key}] Empty token list passed to get_timestamps_from_messages.")
            parts = tokens if use_regex else [re.escape(t) for t in tokens]
            pat = re.compile("|".join(parts), flags=flags)

            hits = df[df["message"].str.contains(pat, regex=True, na=False)].copy()
            hits.sort_values(by="timestamp", inplace=True)

            if return_match_token and not hits.empty:
                # record which token matched first for each hit
                def _which(m: str) -> Optional[str]:
                    for t in tokens:
                        if (re.search(t, m, flags=flags) if use_regex else re.search(re.escape(t), m, flags=flags)):
                            return t
                    return None

                hits["matched_token"] = hits["message"].apply(_which)
                # write back those rows (optional traceability)
                df.loc[hits.index, "matched_token"] = hits["matched_token"]

            stamps = hits["timestamp"].astype(int).tolist()
            if len(stamps) == 0:
                raise ValueError(
                    f"[{key}] No timestamps found for messages {tokens} "
                    f"in session: {self.session_path}"
                )
            timestamps_dict[key] = stamps

        # Persist updated matched_token if requested
        if return_match_token:
            self.user_messages = df

        return timestamps_dict

    # ---------------------- Public API: Trial Segmentation -------------------- #

    def split_all_into_trials(
        self,
        start_times: Dict[str, List[int]],
        end_times: Dict[str, List[int]],
        trial_labels: Optional[Dict[str, List[str]]] = None,
        *,
        allow_open_last: bool = True,
        require_nonoverlap: bool = True,
    ) -> None:
        """Segment samples/fixations/saccades/blinks using explicit times."""
        for df in (self.samples, self.fixations, self.saccades, self.blinks):
            self._split_into_trials_df(
                df, start_times, end_times, trial_labels,
                allow_open_last=allow_open_last,
                require_nonoverlap=require_nonoverlap,
            )

    def split_all_into_trials_by_msgs(
        self,
        start_msgs: Dict[str, List[str]],
        end_msgs: Dict[str, List[str]],
        trial_labels: Optional[Dict[str, List[str]]] = None,
        **msg_kwargs,
    ) -> None:
        """Segment tables using start and end message patterns."""
        starts = self.get_timestamps_from_messages(start_msgs, **msg_kwargs)
        ends = self.get_timestamps_from_messages(end_msgs, **msg_kwargs)
        self.split_all_into_trials(starts, ends, trial_labels)

    def split_all_into_trials_by_durations(
        self,
        start_msgs: Dict[str, List[str]],
        durations: Dict[str, List[int]],
        trial_labels: Optional[Dict[str, List[str]]] = None,
        **msg_kwargs,
    ) -> None:
        """Segment using start message patterns and per-trial durations (ms)."""
        starts = self.get_timestamps_from_messages(start_msgs, **msg_kwargs)
        end_times: Dict[str, List[int]] = {}
        for key, durs in durations.items():
            s = starts.get(key, [])
            if len(durs) < len(s):
                raise ValueError(
                    f"[{key}] Provided {len(durs)} durations but found {len(s)} start times "
                    f"in session: {self.session_path}"
                )
            end_times[key] = [st + du for st, du in zip(s, durs)]
        self.split_all_into_trials(starts, end_times, trial_labels)

    def _split_into_trials_df(
        self,
        data: pd.DataFrame,
        start_times: Dict[str, List[int]],
        end_times: Dict[str, List[int]],
        trial_labels: Optional[Dict[str, List[str]]] = None,
        *,
        allow_open_last: bool = True,
        require_nonoverlap: bool = True,
    ) -> None:
        """
        Core segmentation for a single table. Works with 'tSample' OR ('tStart','tEnd').
        Adds 'phase', 'trial_number', 'trial_label'.
        """
        if data is self.samples:
            time_mode = "sample"
            self._require_columns(data, ["tSample"], "split_into_trials(samples)")
        else:
            # events (fixations/saccades/blinks)
            time_mode = "event"
            self._require_columns(data, ["tStart", "tEnd"], "split_into_trials(events)")

        df = data.copy()
        # Initialize columns deterministically
        df["phase"] = ""
        df["trial_number"] = -1
        df["trial_label"] = ""

        for key in start_times.keys():
            start_list = list(start_times[key])
            end_list = list(end_times[key])

            # Discard starts after last end (common partial last-trial artifact)
            if allow_open_last and end_list:
                last_end = end_list[-1]
                start_list = [st for st in start_list if st < last_end]

            # Sanity checks
            if require_nonoverlap:
                self._assert_nonoverlap(start_list, end_list, key, self.session_path)
            elif len(start_list) != len(end_list):
                raise ValueError(
                    f"[{key}] start_times and end_times length mismatch: {len(start_list)} vs {len(end_list)} "
                    f"in session: {self.session_path}"
                )

            labels = trial_labels.get(key) if (trial_labels and key in trial_labels) else None
            if labels is not None and len(labels) != len(start_list):
                raise ValueError(
                    f"[{key}] Computed {len(start_list)} trials but got {len(labels)} trial labels "
                    f"in session: {self.session_path}"
                )

            # Apply segmentation
            if time_mode == "sample":
                t = df["tSample"].values
                for i, (st, en) in enumerate(zip(start_list, end_list)):
                    mask = (t >= st) & (t <= en)
                    if not np.any(mask):
                        continue
                    df.loc[mask, "trial_number"] = i
                    df.loc[mask, "phase"] = str(key)
                    if labels is not None:
                        df.loc[mask, "trial_label"] = labels[i]
            else:
                t0 = df["tStart"].values
                t1 = df["tEnd"].values
                for i, (st, en) in enumerate(zip(start_list, end_list)):
                    mask = (t0 >= st) & (t1 <= en)
                    if not np.any(mask):
                        continue
                    df.loc[mask, "trial_number"] = i
                    df.loc[mask, "phase"] = str(key)
                    if labels is not None:
                        df.loc[mask, "trial_label"] = labels[i]

        # Commit
        if data is self.samples:
            self.samples = df
        elif data is self.fixations:
            self.fixations = df
        elif data is self.saccades:
            self.saccades = df
        elif data is self.blinks:
            self.blinks = df

    # ------------------------- Public API: QC / Flags ------------------------- #

    def bad_samples(
        self,
        screen_height: Optional[int] = None,
        screen_width: Optional[int] = None,
        *,
        mark_nan_as_bad: bool = True,
        inclusive_bounds: bool = True,
    ) -> None:
        """
        Mark rows as 'bad' if any available coordinate falls outside screen bounds.
        Applies to samples, fixations, saccades. (Blinks unaffected.)

        If width/height not provided, will use metadata.screen_* if available.
        """
        H = screen_height if screen_height is not None else self.metadata.screen_height
        W = screen_width if screen_width is not None else self.metadata.screen_width
        if H is None or W is None:
            raise ValueError(
                "bad_samples requires screen_height and screen_width (either passed "
                "or set via set_metadata())."
            )

        def _mark(df: pd.DataFrame) -> pd.DataFrame:
            d = df.copy()

            # Gather candidate coordinate columns if present
            coord_cols = self._ensure_columns_exist(
                d,
                [
                    "LX", "LY", "RX", "RY", "X", "Y",
                    "xStart", "xEnd", "yStart", "yEnd", "xAvg", "yAvg",
                ],
            )
            if not coord_cols:
                # If no coords present, default to 'not bad'
                if "bad" not in d.columns:
                    d["bad"] = False
                return d

            xcols = [c for c in coord_cols if c.lower().startswith("x")]
            ycols = [c for c in coord_cols if c.lower().startswith("y")]

            # Validity masks for each axis
            if inclusive_bounds:
                valid_w = np.logical_and.reduce([d[c].ge(0) & d[c].le(W) for c in xcols]) if xcols else True
                valid_h = np.logical_and.reduce([d[c].ge(0) & d[c].le(H) for c in ycols]) if ycols else True
            else:
                valid_w = np.logical_and.reduce([d[c].gt(0) & d[c].lt(W) for c in xcols]) if xcols else True
                valid_h = np.logical_and.reduce([d[c].gt(0) & d[c].lt(H) for c in ycols]) if ycols else True

            bad = ~(valid_w & valid_h)
            if mark_nan_as_bad:
                bad |= d[coord_cols].isna().any(axis=1)

            d["bad"] = bad.values
            return d

        self.samples = _mark(self.samples)
        self.fixations = _mark(self.fixations)
        self.saccades = _mark(self.saccades)

    # ---------------------- Public API: Saccade Direction --------------------- #

    def saccades_direction(self, tol_deg: float = 15.0) -> None:
        """
        Compute saccade angle (deg) and cardinal direction with tolerance bands.

        Parameters
        ----------
        tol_deg : float
            Half-width of the acceptance band around 0°, ±90°, and ±180°
            for classifying right/left/up/down.
        """
        df = self.saccades.copy()
        self._require_columns(
            df, ["xStart", "xEnd", "yStart", "yEnd"], "saccades_direction"
        )

        x_dif = df["xEnd"].astype(float) - df["xStart"].astype(float)
        y_dif = df["yEnd"].astype(float) - df["yStart"].astype(float)
        deg = np.degrees(np.arctan2(y_dif.to_numpy(), x_dif.to_numpy()))
        df["deg"] = deg.astype(float)

        # Tolerant direction bins
        right = (-tol_deg < df["deg"]) & (df["deg"] < tol_deg)
        left = (df["deg"] > 180 - tol_deg) | (df["deg"] < -180 + tol_deg)
        down = ((90 - tol_deg) < df["deg"]) & (df["deg"] < (90 + tol_deg))
        up = ((-90 - tol_deg) < df["deg"]) & (df["deg"] < (-90 + tol_deg))

        df["dir"] = ""
        df.loc[right, "dir"] = "right"
        df.loc[left, "dir"] = "left"
        df.loc[down, "dir"] = "down"
        df.loc[up, "dir"] = "up"

        self.saccades = df

    # -------------------------- Public API: Orchestrator ---------------------- #

    def process(
        self,
        functions_and_params: Dict[str, Dict],
        *,
        log_recipe: bool = True,
        recipe_filename: str = "preprocessing_recipe.json",
        provenance_filename: str = "preprocessing_provenance.json",
    ) -> None:
        """
        Run a declarative preprocessing recipe, e.g.:
            pp.process({
                "split_all_into_trials_by_msgs": {
                    "start_msgs": {"trial": ["TRIAL_START"]},
                    "end_msgs": {"trial": ["TRIAL_END"]},
                },
                "bad_samples": {"screen_height": 1080, "screen_width": 1920},
                "saccades_direction": {"tol_deg": 15},
            })

        Unknown function names raise a helpful error.
        """
        # Optional: save the declared recipe for exact reproducibility
        if log_recipe:
            recipe_obj = {
                "declared_recipe": functions_and_params,
                "tool_version": self.VERSION,
                "timestamp_utc": datetime.now(timezone.utc).isoformat(),
                "session_path": str(self.session_path),
            }
            self._save_json_sidecar(recipe_obj, recipe_filename)

        for func_name, params in functions_and_params.items():
            if not hasattr(self, func_name):
                raise AttributeError(
                    f"Unknown preprocessing function '{func_name}'. "
                    f"Available: {[m for m in dir(self) if not m.startswith('_')]}"
                )
            fn = getattr(self, func_name)
            if not isinstance(params, dict):
                raise TypeError(
                    f"Parameters for '{func_name}' must be a dict, got {type(params)}"
                )
            fn(**params)

        # Save lightweight provenance after successful run
        if log_recipe:
            prov = {
                "completed_recipe": list(functions_and_params.keys()),
                "tool_version": self.VERSION,
                "timestamp_utc": datetime.now(timezone.utc).isoformat(),
                "metadata": self.metadata.to_dict(),
            }
            self._save_json_sidecar(prov, provenance_filename)

bad_samples(screen_height=None, screen_width=None, *, mark_nan_as_bad=True, inclusive_bounds=True)

Mark rows as 'bad' if any available coordinate falls outside screen bounds. Applies to samples, fixations, saccades. (Blinks unaffected.)

If width/height not provided, will use metadata.screen_* if available.
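Continuing the construction sketch above (bounds are illustrative):

# Flag off-screen or NaN coordinates across samples/fixations/saccades.
pp.bad_samples(screen_height=1080, screen_width=1920)

# Downstream, keep only rows that passed the bounds check.
good_fixations = pp.fixations[~pp.fixations["bad"]]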


get_timestamps_from_messages(messages_dict, *, case_insensitive=True, use_regex=True, return_match_token=False)

Extract ordered timestamps per phase by matching message substrings/patterns.

Parameters:

    messages_dict (dict, required)
        e.g., {'trial': ['TRIAL_START', 'BEGIN_TRIAL'], 'stim': ['STIM_ONSET']}
    case_insensitive (bool, default True)
        If True, ignore case during matching.
    use_regex (bool, default True)
        If True, treat entries as regex patterns joined by '|'; otherwise escape literals.
    return_match_token (bool, default False)
        If True, also creates/updates a 'matched_token' column with the first matched pattern.

Returns:

    Dict[str, List[int]]
        Ordered timestamps in ms for each key.
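For example, with hypothetical 'TRIAL_START' / 'STIM_ONSET' messages logged during recording:

stamps = pp.get_timestamps_from_messages(
    {"trial": ["TRIAL_START"], "stim": ["STIM_ONSET"]},
    case_insensitive=True,
    use_regex=False,  # match the tokens literally
)
# stamps is, e.g., {"trial": [1000, 6000], "stim": [1500, 6500]}  (ms)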


process(functions_and_params, *, log_recipe=True, recipe_filename='preprocessing_recipe.json', provenance_filename='preprocessing_provenance.json')

Run a declarative preprocessing recipe, e.g.:

    pp.process({
        "split_all_into_trials_by_msgs": {
            "start_msgs": {"trial": ["TRIAL_START"]},
            "end_msgs": {"trial": ["TRIAL_END"]},
        },
        "bad_samples": {"screen_height": 1080, "screen_width": 1920},
        "saccades_direction": {"tol_deg": 15},
    })

Unknown function names raise a helpful error.


saccades_direction(tol_deg=15.0)

Compute saccade angle (deg) and cardinal direction with tolerance bands.

Parameters:

    tol_deg (float, default 15.0)
        Half-width of the acceptance band around 0°, ±90°, and ±180° for classifying right/left/up/down.
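A short example; note that in screen coordinates (y grows downward) an angle near +90° is classified as 'down':

pp.saccades_direction(tol_deg=15)
print(pp.saccades[["deg", "dir"]].head())  # 'dir' stays "" for oblique saccades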

save_metadata(filename='metadata.json')

Persist metadata next to derivatives for reproducibility.


set_metadata(coords_unit=None, time_unit=None, pupil_unit=None, screen_width=None, screen_height=None, **extra)

Update session-level metadata used in bounds checks and documentation.


split_all_into_trials(start_times, end_times, trial_labels=None, *, allow_open_last=True, require_nonoverlap=True)

Segment samples/fixations/saccades/blinks using explicit times.
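A sketch with explicit, illustrative times (ms):

pp.split_all_into_trials(
    start_times={"trial": [1000, 6000]},
    end_times={"trial": [5000, 9000]},
    trial_labels={"trial": ["easy", "hard"]},
)
# All four tables now carry 'phase', 'trial_number' and 'trial_label'.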


split_all_into_trials_by_durations(start_msgs, durations, trial_labels=None, **msg_kwargs)

Segment using start message patterns and per-trial durations (ms).
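A sketch, assuming a hypothetical 'STIM_ONSET' message and fixed 2 s presentations:

pp.split_all_into_trials_by_durations(
    start_msgs={"stim": ["STIM_ONSET"]},
    durations={"stim": [2000, 2000, 2000]},  # one duration (ms) per detected start
)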


split_all_into_trials_by_msgs(start_msgs, end_msgs, trial_labels=None, **msg_kwargs)

Segment tables using start and end message patterns.
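A sketch with hypothetical start/end tokens; extra keyword arguments are forwarded to get_timestamps_from_messages():

pp.split_all_into_trials_by_msgs(
    start_msgs={"trial": ["TRIAL_START"]},
    end_msgs={"trial": ["TRIAL_END"]},
    use_regex=False,
)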


SessionMetadata dataclass

Lightweight metadata container saved alongside derivatives.

Source code in pyxations/pre_processing.py
@dataclass
class SessionMetadata:
    """Lightweight metadata container saved alongside derivatives."""
    coords_unit: str = "px"          # 'px' or 'deg'
    time_unit: str = "ms"            # 'ms'
    pupil_unit: str = "arbitrary"
    screen_width: Optional[int] = None
    screen_height: Optional[int] = None
    extra: Dict[str, Union[str, int, float, bool, None]] = field(default_factory=dict)

    def to_dict(self) -> dict:
        return {
            "coords_unit": self.coords_unit,
            "time_unit": self.time_unit,
            "pupil_unit": self.pupil_unit,
            "screen_width": self.screen_width,
            "screen_height": self.screen_height,
            "extra": self.extra,
        }
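A small sketch of the container:

meta = SessionMetadata(coords_unit="px", screen_width=1920, screen_height=1080)
meta.extra["tracker"] = "EyeLink 1000"  # free-form extras are kept by to_dict()
assert meta.to_dict()["screen_width"] == 1920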

bids_formatting

dataset_to_bids(target_folder_path, files_folder_path, dataset_name, session_substrings=1, format_name='eyelink')

Convert a dataset to BIDS format.

Args:
    target_folder_path (str): Path to the folder where the BIDS dataset will be created.
    files_folder_path (str): Path to the folder containing the EDF files. The files are assumed to have the subject ID at the beginning of the file name, separated by an underscore.
    dataset_name (str): Name of the BIDS dataset.
    session_substrings (int): Number of filename substrings to use for the session ID. Default is 1.
    format_name (str): Name of the recording format used to select a converter. Default is 'eyelink'.

Returns:
    Path: Path to the created BIDS dataset folder.
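A hedged usage sketch (paths and names are hypothetical):

from pyxations.bids_formatting import dataset_to_bids

# Raw recordings live under raw/, e.g. raw/07_session1.edf (subject ID first).
bids_root = dataset_to_bids(
    target_folder_path="studies",
    files_folder_path="raw",
    dataset_name="my_dataset",
    session_substrings=1,
    format_name="eyelink",
)
print(bids_root)  # studies/my_dataset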

Source code in pyxations/bids_formatting.py
def dataset_to_bids(target_folder_path, files_folder_path, dataset_name, session_substrings=1, format_name='eyelink'):
    """
    Convert a dataset to BIDS format.

    Args:
        target_folder_path (str): Path to the folder where the BIDS dataset will be created.
        files_folder_path (str): Path to the folder containing the EDF files. The files are
            assumed to have the subject ID at the beginning of the file name, separated by
            an underscore.
        dataset_name (str): Name of the BIDS dataset.
        session_substrings (int): Number of filename substrings to use for the session ID.
            Default is 1.
        format_name (str): Name of the recording format used to select a converter.
            Default is 'eyelink'.

    Returns:
        Path: Path to the created BIDS dataset folder.
    """
    converter = get_converter(format_name)

    # Create a metadata tsv file
    metadata = pd.DataFrame(columns=['subject_id', 'old_subject_id'])
    files_folder_path = Path(files_folder_path)
    # List all file paths in the folder
    file_paths = []
    for file_path in files_folder_path.rglob('*'):  # Recursively go through all files
        if file_path.is_file():
            file_paths.append(file_path)

    file_paths = [file for file in file_paths if file.suffix.lower() in converter.relevant_extensions()]

    bids_folder_path = Path(target_folder_path) / dataset_name
    bids_folder_path.mkdir(parents=True, exist_ok=True)

    subj_ids = converter.get_subject_ids(file_paths)

    # If all of the subjects have numerical IDs, sort them numerically, else sort them alphabetically
    if all(subject_id.isdigit() for subject_id in subj_ids):
        subj_ids.sort(key=int)
    else:
        subj_ids.sort()
    new_subj_ids = [str(subject_index).zfill(4) for subject_index in range(1, len(subj_ids) + 1)]

    # Create subfolders for each session for each subject
    for subject_id in new_subj_ids:
        old_subject_id = subj_ids[int(subject_id) - 1]
        for file in file_paths:
            file_name = Path(file).name
            session_id = "_".join("".join(file_name.split(".")[:-1]).split("_")[1:session_substrings + 1])
            converter.move_file_to_bids_folder(file, bids_folder_path, subject_id, old_subject_id, session_id)

        metadata.loc[len(metadata.index)] = [subject_id, old_subject_id]
    # Save metadata to tsv file
    metadata.to_csv(bids_folder_path / "participants.tsv", sep="\t", index=False)
    return bids_folder_path

eye_movement_detection

visualization

utils

Visualization

Source code in pyxations/visualization/visualization.py
class Visualization:
    def __init__(self, derivatives_folder_path, events_detection_algorithm):
        self.derivatives_folder_path = Path(derivatives_folder_path)
        if events_detection_algorithm not in EYE_MOVEMENT_DETECTION_DICT and events_detection_algorithm != 'eyelink':
            raise ValueError(f"Detection algorithm {events_detection_algorithm} not found.")
        self.events_detection_folder = Path(events_detection_algorithm + '_events')

    def scanpath(
        self,
        fixations: pl.DataFrame,
        screen_height: int,
        screen_width: int,
        folder_path: str | Path | None = None,
        tmin: int | None = None,
        tmax: int | None = None,
        saccades: pl.DataFrame | None = None,
        samples: pl.DataFrame | None = None,
        phase_data: dict[str, dict] | None = None,
        display: bool = True,
    ):
        """
        Fast scan‑path visualiser.

        • **Vectorised**: no per‑row Python loops  
        • **Single pass** phase grouping  
        • Uses `Axes.broken_barh` for fixation spans
        • Optional asynchronous PNG write via ThreadPoolExecutor (drop‑in‑ready, see comment)

        Parameters
        ----------
        fixations
            Polars DataFrame with at least `tStart`, `duration`, `xAvg`, `yAvg`, `phase`.
        screen_height, screen_width
            Stimulus resolution in pixels.
        folder_path
            Directory where 1 PNG per phase will be stored.  If *None*, nothing is saved.
        tmin, tmax
            Time window in **ms**.  If both `None`, the whole trial is plotted.
        saccades
            Polars DataFrame with `tStart`, `phase`, …  (optional).
        samples
            Polars DataFrame with gaze traces (`tSample`, `LX`, `LY`, `RX`, `RY` or
            `X`, `Y`) (optional).
        phase_data
            Per‑phase extras::

                {
                    "search": {
                        "img_paths": [...],
                        "img_plot_coords": [(x1,y1,x2,y2), ...],
                        "bbox": (x1,y1,x2,y2),
                    },
                    ...
                }

        display
            If *False* the figure canvas is never shown (faster for batch jobs).
        """


        # ------------- small helpers ------------------------------------------------
        def _make_axes(plot_samples: bool):
            if plot_samples:
                fig, (ax_main, ax_gaze) = plt.subplots(
                    2, 1, height_ratios=(4, 1), figsize=(10, 6), sharex=False
                )
            else:
                fig, ax_main = plt.subplots(figsize=(10, 6))
                ax_gaze = None
            ax_main.set_xlim(0, screen_width)
            ax_main.set_ylim(screen_height, 0)
            return fig, ax_main, ax_gaze

        def _maybe_cache_img(path):
            """Load image from disk with a small LRU cache."""

            # Cache hit: move to the end (most recently used)
            if path in _img_cache:
                img = _img_cache.pop(path)
                _img_cache[path] = img
                return img

            # Cache miss: load image
            img = mpimg.imread(path)

            # Optional: reduce memory if image is float64 in [0, 1]
            if isinstance(img, np.ndarray) and img.dtype == np.float64:
                img = (img * 255).clip(0, 255).astype(np.uint8)

            # Insert into cache
            _img_cache[path] = img

            # If cache too big, drop least recently used item
            if len(_img_cache) > _MAX_CACHE_ITEMS:
                _img_cache.popitem(last=False)  # pops the oldest inserted item

            return img

        # ---------------------------------------------------------------------------
        plot_saccades = saccades is not None
        plot_samples = samples is not None
        _img_cache = OrderedDict()
        _MAX_CACHE_ITEMS = 8  # or 5, 10, etc. Tune as you like.

        trial_idx = fixations["trial_number"][0]

        # ---- time filter ----------------------------------------------------------
        if tmin is not None and tmax is not None:
            fixations = fixations.filter(pl.col("tStart").is_between(tmin, tmax))
            if plot_saccades:
                saccades = saccades.filter(pl.col("tStart").is_between(tmin, tmax))
            if plot_samples:
                samples = samples.filter(pl.col("tSample").is_between(tmin, tmax))

        # remove empty phase markings
        fixations = fixations.filter(pl.col("phase") != "")
        if plot_saccades:
            saccades = saccades.filter(pl.col("phase") != "")
        if plot_samples:
            samples = samples.filter(pl.col("phase") != "")

        # ---- split once by phase --------------------------------------------------
        fix_by_phase = fixations.partition_by("phase", as_dict=True)
        sac_by_phase = (
            saccades.partition_by("phase", as_dict=True) if plot_saccades else {}
        )
        samp_by_phase = (
            samples.partition_by("phase", as_dict=True) if plot_samples else {}
        )

        # colour map shared across phases
        cmap = plt.cm.rainbow

        # ---- build & draw ---------------------------------------------------------
        # optional async saver (uncomment if you save hundreds of files)
        from concurrent.futures import ThreadPoolExecutor
        saver = ThreadPoolExecutor(max_workers=4) if folder_path else None

        if not display:
            plt.ioff()

        for phase, phase_fix in fix_by_phase.items():
            if phase_fix.is_empty():
                continue

            # ---------- vectors (zero‑copy) -----------------
            fx, fy, fdur = phase_fix.select(["xAvg", "yAvg", "duration"]).to_numpy().T
            n_fix = fx.size
            fix_idx = np.arange(1, n_fix + 1)

            norm = mplcolors.BoundaryNorm(np.arange(1, n_fix + 2), cmap.N)

            # saccades
            sac_t = (
                sac_by_phase[phase]["tStart"].to_numpy()
                if plot_saccades and phase in sac_by_phase
                else np.empty(0)
            )

            # samples
            if plot_samples and phase in samp_by_phase and samp_by_phase[phase].height:
                samp_phase = samp_by_phase[phase]
                t0 = samp_phase["tSample"][0]
                ts = (samp_phase["tSample"].to_numpy() - t0) 
                get = samp_phase.get_column
                lx = get("LX").to_numpy() if "LX" in samp_phase.columns else None
                ly = get("LY").to_numpy() if "LY" in samp_phase.columns else None
                rx = get("RX").to_numpy() if "RX" in samp_phase.columns else None
                ry = get("RY").to_numpy() if "RY" in samp_phase.columns else None
                gx = get("X").to_numpy() if "X" in samp_phase.columns else None
                gy = get("Y").to_numpy() if "Y" in samp_phase.columns else None
            else:
                t0 = None

            # ---------- figure -----------------------------
            fig, ax_main, ax_gaze = _make_axes(plot_samples and t0 is not None)
            # scatter fixations
            sc = ax_main.scatter(
                fx,
                fy,
                c=fix_idx,
                s=fdur,
                cmap=cmap,
                norm=norm,
                alpha=0.5,
                zorder=2,
            )
            fig.colorbar(
                sc,
                ax=ax_main,
                ticks=[1, n_fix // 2 if n_fix > 2 else n_fix, n_fix],
                fraction=0.046,
                pad=0.04,
            ).set_label("# of fixation")

            # ---------- stimulus imagery / bbox ------------
            if phase_data and phase[0] in phase_data:
                pdict = phase_data[phase[0]]
                coords = pdict.get("img_plot_coords") or []
                bbox = pdict.get('bbox',None) 
                for img_path, box in zip(pdict.get("img_paths", []), coords):

                    ax_main.imshow(_maybe_cache_img(img_path), extent=[box[0], box[2], box[3], box[1]], zorder=0)
                if bbox is not None:
                    x1, y1, x2, y2 = bbox
                    ax_main.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], color='red', linewidth=1.5, zorder=3)

            # ---------- gaze traces ------------------------
            if ax_gaze is not None:
                if lx is not None:
                    ax_main.plot(lx, ly, "--", color="C0", zorder=1)
                    ax_gaze.plot(ts, lx, label="Left X")
                    ax_gaze.plot(ts, ly, label="Left Y")
                if rx is not None:
                    ax_main.plot(rx, ry, "--", color="k", zorder=1)
                    ax_gaze.plot(ts, rx, label="Right X")
                    ax_gaze.plot(ts, ry, label="Right Y")
                if gx is not None:
                    ax_main.plot(gx, gy, "--", color="k", zorder=1, alpha=0.6)
                    ax_gaze.plot(ts, gx, label="X")
                    ax_gaze.plot(ts, gy, label="Y")

                # fixation spans
                bars   = np.c_[phase_fix['tStart'].to_numpy() - t0,
                            phase_fix['duration'].to_numpy()]
                height = ax_gaze.get_ylim()[1] - ax_gaze.get_ylim()[0]
                colors = cmap(norm(fix_idx))

                # Draw all bars in one call; no BrokenBarHCollection import needed
                ax_gaze.broken_barh(bars, (0, height), facecolors=colors, alpha=0.4)
                # saccades
                if sac_t.size:
                    ymin, ymax = ax_gaze.get_ylim()
                    ax_gaze.vlines(
                        sac_t - t0,
                        ymin,
                        ymax,
                        colors="red",
                        linestyles="--",
                        linewidth=0.8,
                    )

                # tidy gaze axis
                h, l = ax_gaze.get_legend_handles_labels()
                by_label = {lab: hdl for hdl, lab in zip(h, l)}
                ax_gaze.legend(
                    by_label.values(),
                    by_label.keys(),
                    loc="center left",
                    bbox_to_anchor=(1, 0.5),
                )
                ax_gaze.set_ylabel("Gaze")
                ax_gaze.set_xlabel("Time [s]")

            fig.tight_layout()

            # ---------- save / show ------------------------
            if folder_path:
                scan_name = f"scanpath_{trial_idx}"
                if tmin is not None and tmax is not None:
                    scan_name += f"_{tmin}_{tmax}"
                out = Path(folder_path) / f"{scan_name}_{phase[0]}.png"
                # Save once: hand off to the thread pool when available, else save inline
                if saver:
                    saver.submit(fig.savefig, out, dpi=150)
                else:
                    fig.savefig(out, dpi=150)

            if display:
                plt.show()
            plt.close(fig)

        if not display:
            plt.ion()


    def fix_duration(self, fixations: pl.DataFrame, axs=None):

        ax = axs
        if ax is None:
            fig, ax = plt.subplots()

        ax.hist(fixations.select(pl.col('duration')).to_numpy().ravel(), bins=100, edgecolor='black', linewidth=1.2, density=True)
        ax.set_title('Fixation duration')
        ax.set_xlabel('Time (ms)')
        ax.set_ylabel('Density')


    def sacc_amplitude(self, saccades: pl.DataFrame, axs=None):

        ax = axs
        if ax is None:
            fig, ax = plt.subplots()

        saccades_amp = saccades.select(pl.col('ampDeg')).to_numpy().ravel()
        ax.hist(saccades_amp, bins=100, range=(0, 20), edgecolor='black', linewidth=1.2, density=True)
        ax.set_title('Saccades amplitude')
        ax.set_xlabel('Amplitude (deg)')
        ax.set_ylabel('Density')


    def sacc_direction(self, saccades: pl.DataFrame, axs=None, figs=None):

        ax = axs
        if ax is None:
            fig = plt.figure()
            ax = plt.subplot(polar=True)
        else:
            ax.set_axis_off()
            ax = figs.add_subplot(2, 2, 3, projection='polar')
        if 'deg' not in saccades.columns or 'dir' not in saccades.columns:
            raise ValueError('Compute saccades direction first by using saccades_direction function from the PreProcessing module.')
        # Convert from deg to rad
        saccades_rad = saccades.select(pl.col('deg')).to_numpy().ravel() * np.pi / 180

        n_bins = 24
        ang_hist, bin_edges = np.histogram(saccades_rad, bins=n_bins, density=True)
        bin_centers = [np.mean((bin_edges[i], bin_edges[i+1])) for i in range(len(bin_edges) - 1)]

        bars = ax.bar(bin_centers, ang_hist, width=2*np.pi/n_bins, bottom=0.0, alpha=0.4, edgecolor='black')
        ax.set_title('Saccades direction')
        ax.set_yticklabels([])

        for r, bar in zip(ang_hist, bars):
            bar.set_facecolor(plt.cm.Blues(r / np.max(ang_hist)))


    def sacc_main_sequence(self, saccades: pl.DataFrame, axs=None, hline=None):

        ax = axs
        if ax is None:
            fig, ax = plt.subplots()
        # Logarithmic bins
        XL = np.log10(25)  # Adjusted to fit the xlim
        YL = np.log10(1000)  # Adjusted to fit the ylim

        saccades_peak_vel = saccades.select(pl.col('vPeak')).to_numpy().ravel()
        saccades_amp = saccades.select(pl.col('ampDeg')).to_numpy().ravel()

        # Create a 2D histogram with logarithmic bins
        ax.hist2d(saccades_amp, saccades_peak_vel, bins=[np.logspace(-1, XL, 50), np.logspace(0, YL, 50)])

        if hline:
            ax.hlines(y=hline, xmin=ax.get_xlim()[0], xmax=ax.get_xlim()[1], colors='grey', linestyles='--', label=hline)
            ax.legend()
        ax.set_yscale('log')
        ax.set_xscale('log')
        ax.set_title('Main sequence')
        ax.set_xlabel('Amplitude (deg)')
        ax.set_ylabel('Peak velocity (deg/s)')
         # Set the limits of the axes
        ax.set_xlim(0.1, 25)
        ax.set_ylim(10, 1000)
        ax.set_aspect('equal')


    def plot_multipanel(
            self,
            fixations: pl.DataFrame,
            saccades: pl.DataFrame,
            display: bool = True
        ) -> None:
        """
        Create a 2×2 multi‑panel diagnostic plot for every non‑empty
        phase label and save it as PNG in
        <derivatives_folder_path>/<events_detection_folder>/plots/.
        """
        # ── paths & matplotlib style ────────────────────────────────
        folder_path: Path = (
            self.derivatives_folder_path
            / self.events_detection_folder
            / "plots"
        )
        folder_path.mkdir(parents=True, exist_ok=True)
        plt.rcParams.update({"font.size": 12})

        # ── drop practice / invalid trials ─────────────────────────
        fixations = fixations.filter(pl.col("trial_number") != -1)
        saccades  = saccades.filter(pl.col("trial_number") != -1)

        # ── collect valid phase labels (skip empty string) ─────────
        phases = (
            fixations
            .select(pl.col("phase").filter(pl.col("phase") != ""))
            .unique()           # unique values in this Series
            .to_series()
            .to_list()          # plain Python list of strings
        )

        # ── one figure per phase ───────────────────────────────────
        for phase in phases:
            fix_phase   = fixations.filter(pl.col("phase") == phase)
            sacc_phase  = saccades.filter(pl.col("phase") == phase)

            fig, axs = plt.subplots(2, 2, figsize=(12, 7))

            self.fix_duration(fix_phase, axs=axs[0, 0])
            self.sacc_main_sequence(sacc_phase, axs=axs[1, 1])
            self.sacc_direction(sacc_phase, axs=axs[1, 0], figs=fig)
            self.sacc_amplitude(sacc_phase, axs=axs[0, 1])

            fig.tight_layout()
            plt.savefig(folder_path / f"multipanel_{phase}.png")
            if display:
                plt.show()
            plt.close()

plot_multipanel(fixations, saccades, display=True)

Create a 2×2 multi‑panel diagnostic plot for every non‑empty phase label and save it as PNG in <derivatives_folder_path>/<events_detection_folder>/plots/.
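
A minimal usage sketch (hypothetical throughout: viz stands for an instance of this visualization class, and the parquet paths are placeholders; any loader that yields the expected Polars frames works):

import polars as pl

# Hypothetical inputs: frames carrying 'phase' and 'trial_number' columns,
# as produced by the PreProcessing trialing step.
fixations = pl.read_parquet("derivatives/fixations.parquet")
saccades = pl.read_parquet("derivatives/saccades.parquet")

# Writes one 2x2 PNG per non-empty phase; display=False suits batch runs.
viz.plot_multipanel(fixations, saccades, display=False)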

Source code in pyxations/visualization/visualization.py
def plot_multipanel(
        self,
        fixations: pl.DataFrame,
        saccades: pl.DataFrame,
        display: bool = True
    ) -> None:
    """
    Create a 2×2 multi‑panel diagnostic plot for every non‑empty
    phase label and save it as PNG in
    <derivatives_folder_path>/<events_detection_folder>/plots/.
    """
    # ── paths & matplotlib style ────────────────────────────────
    folder_path: Path = (
        self.derivatives_folder_path
        / self.events_detection_folder
        / "plots"
    )
    folder_path.mkdir(parents=True, exist_ok=True)
    plt.rcParams.update({"font.size": 12})

    # ── drop practice / invalid trials ─────────────────────────
    fixations = fixations.filter(pl.col("trial_number") != -1)
    saccades  = saccades.filter(pl.col("trial_number") != -1)

    # ── collect valid phase labels (skip empty string) ─────────
    phases = (
        fixations
        .select(pl.col("phase").filter(pl.col("phase") != ""))
        .unique()           # unique values in this Series
        .to_series()
        .to_list()          # plain Python list of strings
    )

    # ── one figure per phase ───────────────────────────────────
    for phase in phases:
        fix_phase   = fixations.filter(pl.col("phase") == phase)
        sacc_phase  = saccades.filter(pl.col("phase") == phase)

        fig, axs = plt.subplots(2, 2, figsize=(12, 7))

        self.fix_duration(fix_phase, axs=axs[0, 0])
        self.sacc_main_sequence(sacc_phase, axs=axs[1, 1])
        self.sacc_direction(sacc_phase, axs=axs[1, 0], figs=fig)
        self.sacc_amplitude(sacc_phase, axs=axs[0, 1])

        fig.tight_layout()
        plt.savefig(folder_path / f"multipanel_{phase}.png")
        if display:
            plt.show()
        plt.close()

scanpath(fixations, screen_height, screen_width, folder_path=None, tmin=None, tmax=None, saccades=None, samples=None, phase_data=None, display=True)

Fast scan‑path visualiser.

• Vectorised: no per‑row Python loops
• Single‑pass phase grouping via partition_by (see the sketch after this list)
• Uses Axes.broken_barh for fixation spans
• Optional asynchronous PNG write via ThreadPoolExecutor when folder_path is given
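
The single‑pass grouping relies on Polars' partition_by. A small standalone sketch (illustrative data, not pyxations code) showing why the source indexes phase[0]: with as_dict=True, recent Polars versions key the result dict by one‑element tuples.

import polars as pl

df = pl.DataFrame({
    "phase": ["search", "search", "memory"],
    "tStart": [0, 120, 480],
})

# One pass over the frame; returns {key: sub-DataFrame}.
parts = df.partition_by("phase", as_dict=True)
for key, sub in parts.items():
    # On recent Polars, key is a tuple like ("search",), hence key[0];
    # older versions returned the bare value, so hedge with isinstance.
    label = key[0] if isinstance(key, tuple) else key
    print(label, sub.height)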

Parameters:

    fixations : DataFrame, required
        Polars DataFrame with at least tStart, duration, xAvg, yAvg, phase.
    screen_height : int, required
        Stimulus resolution in pixels.
    screen_width : int, required
        Stimulus resolution in pixels.
    folder_path : str | Path | None, default None
        Directory where one PNG per phase will be stored. If None, nothing is saved.
    tmin : int | None, default None
        Start of the time window in ms. If both tmin and tmax are None, the whole trial is plotted.
    tmax : int | None, default None
        End of the time window in ms. If both tmin and tmax are None, the whole trial is plotted.
    saccades : DataFrame | None, default None
        Polars DataFrame with tStart, phase, … (optional).
    saccades : DataFrame | None, default None
        Polars DataFrame with gaze traces (tSample, LX, LY, RX, RY or X, Y) (optional).
    phase_data : dict[str, dict] | None, default None
        Per‑phase extras:

            {
                "search": {
                    "img_paths": [...],
                    "img_plot_coords": [(x1,y1,x2,y2), ...],
                    "bbox": (x1,y1,x2,y2),
                },
                ...
            }
    display : bool, default True
        If False the figure canvas is never shown (faster for batch jobs).
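A minimal call sketch (hypothetical values throughout: viz is an instance of this visualization class, the image path and coordinates are placeholders, and the frames must carry the columns listed above):

phase_data = {
    "search": {
        "img_paths": ["stimuli/scene_01.png"],
        "img_plot_coords": [(0, 0, 1920, 1080)],  # (x1, y1, x2, y2) in screen pixels
        "bbox": (700, 400, 1200, 700),            # region of interest drawn in red
    },
}

viz.scanpath(
    fixations,                       # single-trial frame: tStart, duration, xAvg, yAvg, phase
    screen_height=1080,
    screen_width=1920,
    folder_path="plots/scanpaths",   # one PNG per phase is written here
    saccades=saccades,
    samples=samples,
    phase_data=phase_data,
    display=False,                   # skip canvas rendering in batch jobs
)
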
Source code in pyxations/visualization/visualization.py
def scanpath(
    self,
    fixations: pl.DataFrame,
    screen_height: int,
    screen_width: int,
    folder_path: str | Path | None = None,
    tmin: int | None = None,
    tmax: int | None = None,
    saccades: pl.DataFrame | None = None,
    samples: pl.DataFrame | None = None,
    phase_data: dict[str, dict] | None = None,
    display: bool = True,
):
    """
    Fast scan‑path visualiser.

    • **Vectorised**: no per‑row Python loops
    • **Single pass** phase grouping
    • Uses `Axes.broken_barh` for fixation spans
    • Optional asynchronous PNG write via ThreadPoolExecutor when `folder_path` is given

    Parameters
    ----------
    fixations
        Polars DataFrame with at least `tStart`, `duration`, `xAvg`, `yAvg`, `phase`.
    screen_height, screen_width
        Stimulus resolution in pixels.
    folder_path
        Directory where 1 PNG per phase will be stored.  If *None*, nothing is saved.
    tmin, tmax
        Time window in **ms**.  If both `None`, the whole trial is plotted.
    saccades
        Polars DataFrame with `tStart`, `phase`, …  (optional).
    samples
        Polars DataFrame with gaze traces (`tSample`, `LX`, `LY`, `RX`, `RY` or
        `X`, `Y`) (optional).
    phase_data
        Per‑phase extras::

            {
                "search": {
                    "img_paths": [...],
                    "img_plot_coords": [(x1,y1,x2,y2), ...],
                    "bbox": (x1,y1,x2,y2),
                },
                ...
            }

    display
        If *False* the figure canvas is never shown (faster for batch jobs).
    """


    # ------------- small helpers ------------------------------------------------
    def _make_axes(plot_samples: bool):
        if plot_samples:
            fig, (ax_main, ax_gaze) = plt.subplots(
                2, 1, height_ratios=(4, 1), figsize=(10, 6), sharex=False
            )
        else:
            fig, ax_main = plt.subplots(figsize=(10, 6))
            ax_gaze = None
        ax_main.set_xlim(0, screen_width)
        ax_main.set_ylim(screen_height, 0)
        return fig, ax_main, ax_gaze

    def _maybe_cache_img(path):
        """Load image from disk with a small LRU cache."""

        # Cache hit: move to the end (most recently used)
        if path in _img_cache:
            img = _img_cache.pop(path)
            _img_cache[path] = img
            return img

        # Cache miss: load image
        img = mpimg.imread(path)

        # Optional: reduce memory if image is float64 in [0, 1]
        if isinstance(img, np.ndarray) and img.dtype == np.float64:
            img = (img * 255).clip(0, 255).astype(np.uint8)

        # Insert into cache
        _img_cache[path] = img

        # If cache too big, drop least recently used item
        if len(_img_cache) > _MAX_CACHE_ITEMS:
            _img_cache.popitem(last=False)  # pops the oldest inserted item

        return img

    # ---------------------------------------------------------------------------
    plot_saccades = saccades is not None
    plot_samples = samples is not None
    _img_cache = OrderedDict()
    _MAX_CACHE_ITEMS = 8  # or 5, 10, etc. Tune as you like.

    trial_idx = fixations["trial_number"][0]

    # ---- time filter ----------------------------------------------------------
    if tmin is not None and tmax is not None:
        fixations = fixations.filter(pl.col("tStart").is_between(tmin, tmax))
        if plot_saccades:
            saccades = saccades.filter(pl.col("tStart").is_between(tmin, tmax))
        if plot_samples:
            samples = samples.filter(pl.col("tSample").is_between(tmin, tmax))

    # remove empty phase markings
    fixations = fixations.filter(pl.col("phase") != "")
    if plot_saccades:
        saccades = saccades.filter(pl.col("phase") != "")
    if plot_samples:
        samples = samples.filter(pl.col("phase") != "")

    # ---- split once by phase --------------------------------------------------
    fix_by_phase = fixations.partition_by("phase", as_dict=True)
    sac_by_phase = (
        saccades.partition_by("phase", as_dict=True) if plot_saccades else {}
    )
    samp_by_phase = (
        samples.partition_by("phase", as_dict=True) if plot_samples else {}
    )

    # colour map shared across phases
    cmap = plt.cm.rainbow

    # ---- build & draw ---------------------------------------------------------
    # Async PNG saver: offloads disk writes to a small worker pool when saving to a folder
    from concurrent.futures import ThreadPoolExecutor
    saver = ThreadPoolExecutor(max_workers=4) if folder_path else None

    if not display:
        plt.ioff()

    for phase, phase_fix in fix_by_phase.items():
        if phase_fix.is_empty():
            continue

        # ---------- per-phase column vectors ------------
        fx, fy, fdur = phase_fix.select(["xAvg", "yAvg", "duration"]).to_numpy().T
        n_fix = fx.size
        fix_idx = np.arange(1, n_fix + 1)

        norm = mplcolors.BoundaryNorm(np.arange(1, n_fix + 2), cmap.N)

        # saccades
        sac_t = (
            sac_by_phase[phase]["tStart"].to_numpy()
            if plot_saccades and phase in sac_by_phase
            else np.empty(0)
        )

        # samples
        if plot_samples and phase in samp_by_phase and samp_by_phase[phase].height:
            samp_phase = samp_by_phase[phase]
            t0 = samp_phase["tSample"][0]
            ts = samp_phase["tSample"].to_numpy() - t0  # ms, relative to the first sample
            get = samp_phase.get_column
            lx = get("LX").to_numpy() if "LX" in samp_phase.columns else None
            ly = get("LY").to_numpy() if "LY" in samp_phase.columns else None
            rx = get("RX").to_numpy() if "RX" in samp_phase.columns else None
            ry = get("RY").to_numpy() if "RY" in samp_phase.columns else None
            gx = get("X").to_numpy() if "X" in samp_phase.columns else None
            gy = get("Y").to_numpy() if "Y" in samp_phase.columns else None
        else:
            t0 = None
            lx = ly = rx = ry = gx = gy = None  # no samples for this phase

        # ---------- figure -----------------------------
        fig, ax_main, ax_gaze = _make_axes(plot_samples and t0 is not None)
        # scatter fixations
        sc = ax_main.scatter(
            fx,
            fy,
            c=fix_idx,
            s=fdur,
            cmap=cmap,
            norm=norm,
            alpha=0.5,
            zorder=2,
        )
        fig.colorbar(
            sc,
            ax=ax_main,
            ticks=[1, n_fix // 2 if n_fix > 2 else n_fix, n_fix],
            fraction=0.046,
            pad=0.04,
        ).set_label("# of fixation")

        # ---------- stimulus imagery / bbox ------------
        if phase_data and phase[0] in phase_data:
            pdict = phase_data[phase[0]]
            coords = pdict.get("img_plot_coords") or []
            bbox = pdict.get("bbox")
            for img_path, box in zip(pdict.get("img_paths", []), coords):
                ax_main.imshow(_maybe_cache_img(img_path), extent=[box[0], box[2], box[3], box[1]], zorder=0)
            if bbox is not None:
                x1, y1, x2, y2 = bbox
                ax_main.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], color='red', linewidth=1.5, zorder=3)

        # ---------- gaze traces ------------------------
        if ax_gaze is not None:
            if lx is not None:
                ax_main.plot(lx, ly, "--", color="C0", zorder=1)
                ax_gaze.plot(ts, lx, label="Left X")
                ax_gaze.plot(ts, ly, label="Left Y")
            if rx is not None:
                ax_main.plot(rx, ry, "--", color="k", zorder=1)
                ax_gaze.plot(ts, rx, label="Right X")
                ax_gaze.plot(ts, ry, label="Right Y")
            if gx is not None:
                ax_main.plot(gx, gy, "--", color="k", zorder=1, alpha=0.6)
                ax_gaze.plot(ts, gx, label="X")
                ax_gaze.plot(ts, gy, label="Y")

            # fixation spans
            bars   = np.c_[phase_fix['tStart'].to_numpy() - t0,
                        phase_fix['duration'].to_numpy()]
            height = ax_gaze.get_ylim()[1] - ax_gaze.get_ylim()[0]
            colors = cmap(norm(fix_idx))

            # Draw all bars in one call; no BrokenBarHCollection import needed
            ax_gaze.broken_barh(bars, (0, height), facecolors=colors, alpha=0.4)
            # saccades
            if sac_t.size:
                ymin, ymax = ax_gaze.get_ylim()
                ax_gaze.vlines(
                    sac_t - t0,
                    ymin,
                    ymax,
                    colors="red",
                    linestyles="--",
                    linewidth=0.8,
                )

            # tidy gaze axis
            h, l = ax_gaze.get_legend_handles_labels()
            by_label = {lab: hdl for hdl, lab in zip(h, l)}
            ax_gaze.legend(
                by_label.values(),
                by_label.keys(),
                loc="center left",
                bbox_to_anchor=(1, 0.5),
            )
            ax_gaze.set_ylabel("Gaze")
            ax_gaze.set_xlabel("Time [ms]")

        fig.tight_layout()

        # ---------- save / show ------------------------
        if folder_path:
            scan_name = f"scanpath_{trial_idx}"
            if tmin is not None and tmax is not None:
                scan_name += f"_{tmin}_{tmax}"
            out = Path(folder_path) / f"{scan_name}_{phase[0]}.png"
            # Write each figure exactly once, asynchronously; the saver pool
            # exists whenever folder_path is given.
            saver.submit(fig.savefig, out, dpi=150)

        if display:
            plt.show()
        plt.close(fig)

    if saver:
        saver.shutdown(wait=True)  # block until all pending PNG writes complete
    if not display:
        plt.ion()