Source code for kineticstoolkit.player

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2020-2025 Félix Chénier

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides the Player class to visualize points and frames in 3d.

The Player class is accessible directly from the toplevel Kinetics
Toolkit namespace (i.e., ktk.Player).

"""

__author__ = "Félix Chénier"
__copyright__ = "Copyright (C) 2020-2025 Félix Chénier"
__email__ = "chenier.felix@uqam.ca"
__license__ = "Apache 2.0"


from kineticstoolkit.timeseries import TimeSeries
from kineticstoolkit.tools import check_interactive_backend
import kineticstoolkit.geometry as geometry
from kineticstoolkit._repr import _format_dict_entries
from kineticstoolkit.classes import dict_to_monitored_dict
from kineticstoolkit.exceptions import (
    TimeSeriesMergeConflictError,
    raise_ktk_error,
)

from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
from numpy import sin, cos
import time
from copy import deepcopy
from typing import Any
from kineticstoolkit.typing_ import check_param
import warnings

# To fit the new viewpoint on selecting a new point
import scipy.optimize as optim

REPR_HTML_MAX_DURATION = 10  # Max duration for _repr_html
PALETTE = {
    "k": (0.0, 0.0, 0.0),
    "r": (1.0, 0.0, 0.0),
    "g": (0.0, 1.0, 0.0),
    "b": (0.3, 0.3, 1.0),
    "y": (1.0, 1.0, 0.0),
    "m": (1.0, 0.0, 1.0),
    "c": (0.0, 1.0, 1.0),
    "w": (1.0, 1.0, 1.0),
}


HELP_TEXT = """
    ktk.Player help
    ----------------------------------------------------
    KEYBOARD COMMANDS
    show/hide this help : h
    previous index      : left
    next index          : right
    previous second     : shift+left
    next second         : shift+right
    play/pause          : space
    2x playback speed   : +
    0.5x playback speed : -
    toggle track        : t
    toggle perspective  : d (depth)
    set back/front view : 1/2
    set left/right view : 3/4
    set top/bottom view : 5/6
    set initial view    : 0
    ----------------------------------------------------
    MOUSE COMMANDS
    select a point      : left-click
    3d rotate           : left-drag
    pan                 : middle-drag or shift+left-drag
    zoom                : right-drag or wheel
"""


def _parse_color(
    value: str | tuple[float, float, float]
) -> tuple[float, float, float]:
    """Convert a color specification into a tuple[float, float, float]."""
    if isinstance(value, str):
        try:
            return PALETTE[value]
        except KeyError:
            raise ValueError(
                f"The specified color '{value}' is not recognized."
            )

    # Here, it's a sequence. Cast to tuple, check and return.
    value = tuple(value)  # type: ignore
    check_param("value", value, tuple, length=3, contents_type=float)
    if (
        (value[0] < 0.0)
        or (value[1] < 0.0)
        or (value[2] < 0.0)
        or (value[0] > 1.0)
        or (value[1] > 1.0)
        or (value[2] > 1.0)
    ):
        raise ValueError(
            f"The specified color '{value}' is invalid because each R, G, B "
            "value must be between 0.0 and 1.0."
        )
    return value


[docs] class Player: """ A class that allows visualizing points and frames in 3D. `player = ktk.Player(parameters)` creates and launches an interactive Player instance. Once the window is open, press `h` to show a help overlay. All of the following parameters are also accessible as read/write properties, except the contents and the interconnections that are accessible using `get_contents`, `set_contents`, `get_interconnections` and `set_interconnections`. Parameters ---------- *ts Contains the points and frames to visualize, where each data key is either a point position expressed as Nx4 array, or a frame expressed as a Nx4x4 array. Multiple TimeSeries can be provided. interconnections Optional. Each key corresponds to a group of interconnections, which is a dictionary with the following keys: - "Links": list of connections where each string is a point name. For example, to create a link that connects Point1 to Point2, and another link that spans Point3, Point4 and Point5:: interconnections["Example"]["Links"] = [ ["Point1", "Point2"], ["Point3", "Point4", "Point5"] ] which internally is converted to:: interconnections["Example"]["Links"] = [ ["Point1", "Point2"], ["Point3", "Point4"], ["Point4", "Point5"] ] Point names can include wildcards (*) either as a prefix or as a suffix. This is useful to apply a single set of interconnections to multiple bodies. For instance, if the Player's contents include these points: [Body1_HipR, Body1_HipL, Body1_L5S1, Body2_HipR, Body2_HipL, Body2_L5S1], we could link L5S1 and both hips at once using:: interconnections["Pelvis"]["Links"] = [ ["*_HipR", "*_HipL", "*_L5S1"] ] - "Color": character or tuple (RGB) that represents the color of the link. These two examples are equivalent:: interconnections["Pelvis"]["Color"] = 'r' interconnections["Pelvis"]["Color"] = (1.0, 0.0, 0.0) Its default value connects the four corners of force platforms in purple:: interconnections = { "ForcePlatforms": { "Links": [['*_Corner1', '*_Corner2'], ['*_Corner2', '*_Corner3'], ['*_Corner3', '*_Corner4'], ['*_Corner1', '*_Corner4']] "Color": (0.5, 0.0, 1.0) } } vectors Optional. A dictionary where each key is the name of a vector and each value contains its origin, scale, and color. For example:: vectors = { "WristForce": { "Origin": "WristCenter", "Scale": 0.001, "Color": (1.0, 1.0, 0.0) }, "ElbowForce": { "Origin": "ElbowCenter", "Scale": 0.001, "Color": (1.0, 1.0, 0.0) }, } will draw lines for the forces WristForce and ElbowForce, with their origin being at WristCenter and ElbowCenter, and with a scale of 0.001 metre per newton. Force and point names can include wildcards (*) either as a prefix or as a suffix. For instance, to draw forces recorded by multiple force plates, we could use:: vectors = { "*Force": { "Origin": "*COP", "Scale": 0.001, "Color": (1.0, 1.0, 0.0) } } which would assign any point ending by "COP" to its counterpart force. This is the default, so that force plate data read by read_c3d_file are shown by default in the Player. current_index Optional. The current index being shown. current_time Optional. The current time being shown. playback_speed Optional. Speed multiplier. Set to 1.0 for normal speed, 1.5 to increase playback speed by 50%, etc. up Optional. Defines the ground plane by setting which axis is up. May be {"x", "y", "z", "-x", "-y", "-z"}. Default is "y". anterior Optional. Defines the anterior direction. May be {"x", "y", "z", "-x", "-y", "-z"}. Default is "x". zoom Optional. Camera zoom multiplier. azimuth Optional. Camera azimuth in radians. If `anterior` is set, then an azimuth of 0 corresponds to the right sagittal plane, pi/2 to the front frontal plane, -pi/2 to the back frontal plane, etc. elevation Optional. Camera elevation in radians. Default is 0.2. If `up` is set, then a value of 0 corresponds to a purely horizontal view, pi/2 to the top transverse plane, -pi/2 to the bottom transverse plane, etc. perspective Optional. True to draw the scene using perspective, False to draw the scene orthogonally. pan Optional. Camera translation (panning). Default is (0.0, 0.0). target Optional. Camera target in metres. Default is (0.0, 0.0, 0.0). track Optional. False to keep the camera static, True to follow the last selected point when changing index. Default is False. default_point_color Optional. Default color for points that do not have a "Color" data_info. Can be a character or tuple (RGB) where each RGB color is between 0.0 and 1.0. Default is (0.8, 0.8, 0.8). default_interconnection_color Optional. Default color for interconnections. Can be a character or tuple (RGB) where each RGB color is between 0.0 and 1.0. Default is (0.8, 0.8, 0.8). default_vector_color Optional. Default color for vectors. Can be a character or tuple (RGB) where each RGB color is between 0.0 and 1.0. Default is (1.0, 1.0, 0.0). point_size Optional. Point size as defined by Matplotlib marker size. Default is 4.0. interconnection_width Optional. Width of the interconnections as defined by Matplotlib line width. Default is 1.5. vector_width Optional. Width of the vectors as defined by Matplotlib line width. Default is 2.0. frame_size Optional. Length of the frame axes in metres. Default is 0.1. frame_width Optional. Width of the frame axes as defined by Matplotlib line width. Default is 3.0. grid_size Optional. Length of one side of the grid in metres. Default is 10.0. grid_subdivision_size Optional. Length of one subdivision of the grid in metres. Default is 1.0. grid_width Optional. Width of the grid lines as defined by Matplotlib line width. Default is 1.0. grid_origin Optional. Origin of the grid in metres. Default is (0.0, 0.0, 0.0). grid_color Optional. Color of the grid. Can be a character or tuple (RGB) where each RGB color is between 0.0 and 1.0. Default is (0.3, 0.3, 0.3). background_color Optional. Background color. Can be a character or tuple (RGB) where each RGB color is between 0.0 and 1.0. Default is (0.0, 0.0, 0.0). Note ---- Matplotlib must be in interactive mode. """ # %% Init and properties getters and setters # Internal variables - for mypy _being_constructed: bool _contents: TimeSeries _processed_points: TimeSeries _processed_frames: TimeSeries _oriented_target: tuple[float, float, float] _interconnections: dict[str, dict[str, Any]] # Wildcard-extended interconnections, and with all fields including Color: _processed_interconnections: dict[str, dict[str, Any]] _vectors: dict[str, dict[str, Any]] _processed_vectors: dict[str, dict[str, Any]] _colors: set[tuple[float, float, float]] # A list of all point colors _selected_points: list[str] # List of point names _last_selected_point: str _current_index: int _current_time: float _playback_speed: float _up: str _anterior: str _zoom: float _azimuth: float _elevation: float _perspective: bool _initial_elevation: float _initial_azimuth: float _initial_perspective: bool _pan: np.ndarray _target: np.ndarray _track: bool _default_point_color: tuple[float, float, float] _default_interconnection_color: tuple[float, float, float] _default_vector_color: tuple[float, float, float] _point_size: float _interconnection_width: float _vector_width: float _frame_size: float _frame_width: float _force_factor: float _force_color: tuple[float, float, float] _force_width: float _grid_size: float _grid_subdivision_size: float _grid_width: float _grid_origin: np.ndarray _grid_color: tuple[float, float, float] _background_color: tuple[float, float, float] _title_text: str def __init__( self, *ts: TimeSeries, interconnections: dict[str, dict[str, Any]] = { "ForcePlatforms": { "Links": [ ["*_Corner1", "*_Corner2"], ["*_Corner2", "*_Corner3"], ["*_Corner3", "*_Corner4"], ["*_Corner1", "*_Corner4"], ], "Color": (0.5, 0.0, 1.0), }, }, vectors: dict[str, dict[str, Any]] = { "*Force": { "Origin": "*COP", "Scale": 0.001, "Color": (1.0, 1.0, 0.0), } }, current_index: int = 0, current_time: float | None = None, playback_speed: float = 1.0, up: str = "y", anterior: str = "x", zoom: float = 1.0, azimuth: float = 0.0, elevation: float = 0.2, pan: tuple[float, float] = (0.0, 0.0), target: tuple[float, float, float] = (0.0, 0.0, 0.0), perspective: bool = True, track: bool = False, default_point_color: str | tuple[float, float, float] = ( 0.8, 0.8, 0.8, ), default_interconnection_color: str | tuple[float, float, float] = ( 0.8, 0.8, 0.8, ), default_vector_color: str | tuple[float, float, float] = ( 1.0, 1.0, 0.0, ), point_size: float = 4.0, interconnection_width: float = 1.5, vector_width: float = 2.0, frame_size: float = 0.1, frame_width: float = 3.0, grid_size: float = 10.0, grid_subdivision_size: float = 1.0, grid_width: float = 1.0, grid_origin: tuple[float, float, float] = (0.0, 0.0, 0.0), grid_color: str | tuple[float, float, float] = ( 0.3, 0.3, 0.3, ), background_color: str | tuple[float, float, float] = ( 0.0, 0.0, 0.0, ), **kwargs, # Can be "inline_player=True", or older parameter names ): # Allow older parameter names if "segments" in kwargs and interconnections == {}: interconnections = kwargs["segments"] if "segment_width" in kwargs: interconnection_width = kwargs["segment_width"] if "current_frame" in kwargs: current_index = kwargs["current_frame"] if "translation" in kwargs: pan = kwargs["translation"] if "marker_radius" in kwargs: point_size = kwargs["marker_radius"] if "axis_length" in kwargs: frame_size = kwargs["axis_length"] if "axis_width" in kwargs: frame_width = kwargs["axis_width"] check_param("ts", ts, tuple, contents_type=TimeSeries) # The other parameters are checked by the property setters. # Merge all input TimeSeries merged_ts = TimeSeries() for one_ts in ts: try: merged_ts.merge(one_ts, in_place=True, on_conflict="error") except TimeSeriesMergeConflictError as e: warnings.warn( "Duplicate keys were found in at least two of the " "Player's input TimeSeries. When merging the inputs, we " f"got this error: {e} " "Please merge your input TimeSeries manually to avoid " "this situation. This will become an error in the future." ) # Warn if Matplotlib is not interactive check_interactive_backend() # Assign properties # Empty content for now. We set the final content after all # initializations. self._being_constructed = True # We ensure that _contents will ALWAYS have at least one sample. self._contents = TimeSeries(time=[0]) self._grid = np.array([]) self._processed_points = TimeSeries(time=self._contents.time) self._processed_frames = TimeSeries(time=self._contents.time) self._oriented_target = (0.0, 0.0, 0.0) self._interconnections = interconnections # Just to put stuff for now self._processed_interconnections = interconnections # idem self._vectors = vectors # idem self._processed_vectors = vectors # idem self._selected_points = [] self._last_selected_point = "" # Assign standard properties self.current_index = current_index self.playback_speed = playback_speed self.up = up self.anterior = anterior self.zoom = zoom self.azimuth = azimuth self.elevation = elevation self.perspective = perspective self._initial_elevation = elevation self._initial_azimuth = azimuth self._initial_perspective = perspective self.pan = pan self.target = target self.track = track self.default_point_color = default_point_color self.default_interconnection_color = default_interconnection_color self.default_vector_color = default_vector_color self.point_size = point_size self.interconnection_width = interconnection_width self.vector_width = vector_width self.frame_size = frame_size self.frame_width = frame_width self.grid_size = grid_size self.grid_width = grid_width self.grid_subdivision_size = grid_subdivision_size self.grid_origin = grid_origin self.grid_color = grid_color self.background_color = background_color self.title_text = "" self._running = False self._colors = set() # Init mouse navigation state self._state = { "ShiftPressed": False, "MouseLeftPressed": False, "MouseMiddlePressed": False, "MouseRightPressed": False, "MousePositionOnPress": (0.0, 0.0), "MousePositionOnMiddlePress": (0.0, 0.0), "MousePositionOnRightPress": (0.0, 0.0), "PanOnMousePress": (0.0, 0.0), "AzimutOnMousePress": 0.0, "ElevationOnMousePress": 0.0, "SystemTimeOnLastUpdate": time.time(), } # Create the figure and prepare its contents (fig, axes, anim) = self._create_empty_figure() self._mpl_objects = { "Figure": fig, "Axes": axes, "Anim": anim, } # Add the true contents using the public interface so that everything # is refreshed automatically self._being_constructed = False self.set_contents(merged_ts) self.vectors = vectors self.interconnections = interconnections self.grid_origin = grid_origin # Refresh grid # Now that everything is loaded, we can set the current time if # needed. if current_time is not None: self.current_time = current_time @property def contents(self): """Use get_contents or set_contents instead.""" raise AttributeError( "Please use Player.get_contents() and Player.set_contents() to " "read and write contents." ) @contents.setter def contents(self, value): """Use get_contents or set_contents instead.""" raise AttributeError( "Please use Player.get_contents() and Player.set_contents() to " "read and write contents." ) @property def interconnections(self): """Read/write interconnections.""" return self._interconnections @interconnections.setter def interconnections(self, value): """Set interconnections value.""" check_param("interconnections", value, dict, key_type=str) # Other checks during processing self._interconnections = value # Cast to MonitoredDict and MonitoredList in _process_interconnections self._process_interconnections() if not self._being_constructed: self._refresh() @property def vectors(self): """Read/write vectors.""" return self._vectors @vectors.setter def vectors(self, value): """Set vectors value.""" check_param("vectors", value, dict, key_type=str) # Other checks during processing self._vectors = value # Cast to MonitoredDict and MonitoredList in _process_vectors self._process_vectors() self._process_contents() if not self._being_constructed: self._refresh() @property def current_index(self) -> int: """Read/write current_index.""" return self._current_index @current_index.setter def current_index(self, value: int): """Set current_index value.""" self._current_index = value % len(self._contents.time) if not self._being_constructed: if self.track: self._recenter() self._fast_refresh() @property def current_time(self) -> float: """Read/write current_time.""" return self._contents.time[self._current_index] @current_time.setter def current_time(self, value: float): """Set current_time value.""" check_param("current_time", value, float) index = int(np.argmin(np.abs(self._contents.time - value))) self.current_index = index @property def playback_speed(self) -> float: """Read/write playback_speed.""" return self._playback_speed @playback_speed.setter def playback_speed(self, value: float): """Set playback_speed value.""" check_param("playback_speed", value, float) self._playback_speed = value @property def up(self) -> str: """Read/write up.""" return self._up @up.setter def up(self, value: str): """Set up value.""" check_param( "up", value, str, expected_values=["x", "y", "z", "-x", "-y", "-z"] ) self._up = value if not self._being_constructed: if self._up[-1] == self._anterior[-1]: # up and anterior cannot be the same axis. self._anterior = "x" if value[-1] != "x" else "y" self._process_contents() self._refresh() @property def anterior(self) -> str: """Read/write anterior.""" return self._anterior @anterior.setter def anterior(self, value: str): """Set anterior value.""" check_param( "anterior", value, str, expected_values=["x", "y", "z", "-x", "-y", "-z"], ) self._anterior = value if not self._being_constructed: if self._anterior[-1] == self._up[-1]: # up and anterior cannot be the same axis. self._up = "y" if value[-1] != "y" else "z" self._process_contents() self._refresh() @property def zoom(self) -> float: """Read/write zoom.""" return self._zoom @zoom.setter def zoom(self, value: float): """Set zoom value.""" check_param("zoom", value, float) self._zoom = value if not self._being_constructed: self._fast_refresh() @property def azimuth(self) -> float: """Read/write azimuth.""" return self._azimuth @azimuth.setter def azimuth(self, value: float): """Set azimuth value.""" check_param("azimuth", value, float) self._azimuth = value if not self._being_constructed: self._fast_refresh() @property def elevation(self) -> float: """Read/write elevation.""" return self._elevation @elevation.setter def elevation(self, value: float): """Set elevation value.""" check_param("elevation", value, float) self._elevation = value if not self._being_constructed: self._fast_refresh() @property def pan(self): """Read/write pan as (x, y).""" return (self._pan[0], self._pan[1]) @pan.setter def pan(self, value): """Set pan value using (x, y) or (x, y, ...).""" value = tuple(value) check_param("pan", value, tuple, contents_type=float) self._pan = np.array(value)[0:2] if not self._being_constructed: self._fast_refresh() @property def target(self): """Read/write target as (x, y, z).""" return tuple(self._target) @target.setter def target(self, value): """Set target value using (x, y, z) or (x, y, z, 1.0).""" value = tuple(value) check_param("target", value, tuple, contents_type=float) self._target = np.array(value)[0:3] if not self._being_constructed: self._process_contents() self._fast_refresh() @property def perspective(self) -> bool: """Read/write perspective.""" return self._perspective @perspective.setter def perspective(self, value: bool): """Set perspective value.""" check_param("perspective", value, bool) self._perspective = value if not self._being_constructed: self._fast_refresh() @property def track(self) -> bool: """Read/write track.""" return self._track @track.setter def track(self, value: bool): """Set perspective value.""" check_param("track", value, bool) self._track = value if not self._being_constructed: self._fast_refresh() @property def default_point_color(self): """Read/write default_point_color.""" return self._default_point_color @default_point_color.setter def default_point_color(self, value): """Set default_point_color value.""" self._default_point_color = _parse_color(value) if not self._being_constructed: self._refresh() @property def default_interconnection_color(self): """Read/write default_interconnection_color.""" return self._default_interconnection_color @default_interconnection_color.setter def default_interconnection_color(self, value): """Set default_interconnection_color value.""" self._default_interconnection_color = _parse_color(value) if not self._being_constructed: self._refresh() @property def default_vector_color(self): """Read/write default_vector_color.""" return self._default_vector_color @default_vector_color.setter def default_vector_color(self, value): """Set default_vector_color value.""" self._default_vector_color = _parse_color(value) if not self._being_constructed: self._refresh() @property def point_size(self) -> float: """Read/write point_size.""" return self._point_size @point_size.setter def point_size(self, value: float): """Set point_size value.""" check_param("point_size", value, float) self._point_size = value if not self._being_constructed: self._refresh() @property def interconnection_width(self) -> float: """Read/write interconnection_width.""" return self._interconnection_width @interconnection_width.setter def interconnection_width(self, value: float): """Set interconnection_width value.""" check_param("interconnection_width", value, float) self._interconnection_width = value if not self._being_constructed: self._refresh() @property def vector_width(self) -> float: """Read/write vector_width.""" return self._vector_width @vector_width.setter def vector_width(self, value: float): """Set vector_width value.""" check_param("vector_width", value, float) self._vector_width = value if not self._being_constructed: self._refresh() @property def frame_size(self) -> float: """Read/write frame_size.""" return self._frame_size @frame_size.setter def frame_size(self, value: float): """Set frame_size value.""" check_param("frame_size", value, float) self._frame_size = value if not self._being_constructed: self._fast_refresh() @property def frame_width(self) -> float: """Read/write frame_width.""" return self._frame_width @frame_width.setter def frame_width(self, value: float): """Set frame_width value.""" check_param("frame_width", value, float) self._frame_width = value if not self._being_constructed: self._refresh() @property def grid_size(self) -> float: """Read/write grid_size.""" return self._grid_size @grid_size.setter def grid_size(self, value: float): """Set grid_size value.""" check_param("grid_size", value, float) self._grid_size = value if not self._being_constructed: self._update_grid() self._refresh() @property def grid_width(self) -> float: """Read/write grid_width.""" return self._grid_width @grid_width.setter def grid_width(self, value: float): """Set grid_width value.""" check_param("grid_width", value, float) self._grid_width = value if not self._being_constructed: self._update_grid() self._refresh() @property def grid_subdivision_size(self) -> float: """Read/write grid_subdivision_size.""" return self._grid_subdivision_size @grid_subdivision_size.setter def grid_subdivision_size(self, value: float): """Set grid_subdivision_size value.""" check_param("grid_subdivision_size", value, float) self._grid_subdivision_size = value if not self._being_constructed: self._update_grid() self._refresh() @property def grid_origin(self): """Read/write grid_origin.""" return tuple(self._grid_origin) @grid_origin.setter def grid_origin(self, value): """Set grid_origin value.""" value = tuple(value) check_param("grid_subdivision_size", value, tuple, contents_type=float) self._grid_origin = np.array(value)[0:3] if not self._being_constructed: self._update_grid() self._refresh() @property def grid_color(self): """Read/write grid_color.""" return self._grid_color @grid_color.setter def grid_color(self, value): """Set grid_color value.""" self._grid_color = _parse_color(value) if not self._being_constructed: self._update_grid() self._refresh() @property def background_color(self): """Read/write background_color.""" return self._background_color @background_color.setter def background_color(self, value): """Set background_color value.""" self._background_color = _parse_color(value) if not self._being_constructed: self._refresh() @property def title_text(self) -> str: """Read/write the text info on top of the figure.""" return self._title_text @title_text.setter def title_text(self, value: str): """Set title_text.""" check_param("title_text", value, str) self._title_text = value if not self._being_constructed: self._mpl_objects["Axes"].set_title(value, pad=-20) def __dir__(self): """Return directory.""" return [ "play", "pause", "set_view", "close", "to_image", "to_video", "get_contents", "set_contents", ] def __str__(self) -> str: """Print a textual description of the Player properties.""" return "ktk.Player with properties:\n" + _format_dict_entries( { "current_index": self.current_index, "current_time": self.current_time, "playback_speed": self.playback_speed, "up": self.up, "anterior": self.anterior, "zoom": self.zoom, "azimuth": self.azimuth, "elevation": self.elevation, "perspective": self.perspective, "pan": self.pan, "target": self.target, "track": self.track, "default_point_color": self.default_point_color, "default_interconnection_color": self.default_interconnection_color, "default_vector_color": self.default_vector_color, "point_size": self.point_size, "interconnection_width": self.interconnection_width, "vector_width": self.vector_width, "frame_size": self.frame_size, "frame_width": self.frame_width, "grid_size": self.grid_size, "grid_width": self.grid_width, "grid_subdivision_size": self.grid_subdivision_size, "grid_origin": self.grid_origin, "grid_color": self.grid_color, "background_color": self.background_color, "title_text": self.title_text, } ) def __repr__(self) -> str: """Generate the class representation.""" return str(self) def _create_empty_figure(self) -> tuple: """Create figure and return Figure, Axes and AnimationTimer.""" # Create the figure and axes (fig, ax) = plt.subplots(num=None, figsize=(12, 9)) fig.set_facecolor("k") # Remove the toolbar try: # Try, setVisible method not always there fig.canvas.toolbar.setVisible(False) # type: ignore except AttributeError: pass plt.tight_layout() # Connect the callback functions fig.canvas.mpl_connect("pick_event", self._on_pick) fig.canvas.mpl_connect("key_press_event", self._on_key) fig.canvas.mpl_connect("key_release_event", self._on_release) fig.canvas.mpl_connect("scroll_event", self._on_scroll) fig.canvas.mpl_connect("button_press_event", self._on_mouse_press) fig.canvas.mpl_connect("button_release_event", self._on_mouse_release) fig.canvas.mpl_connect("motion_notify_event", self._on_mouse_motion) # Create the animation anim = animation.FuncAnimation( fig, self._on_timer, # type: ignore interval=33, cache_frame_data=False, ) # 30 ips return (fig, ax, anim) # %% Contents getters/setters # We use proper setters and getters to be absolutely sure the contents # could not be modified without adapting the Player to this new contents. # (e.g., rebuild the interconnection plots)
[docs] def get_contents(self) -> TimeSeries: """Get contents value.""" return self._contents
[docs] def set_contents(self, value: TimeSeries) -> None: """Set contents value.""" check_param("value", value, TimeSeries) # First reset index to 0 to be sure that we won't end up out of bounds self._current_index = 0 # Ensure that there is at least one sample so that the Player does not # crash and shows nothing instead. if len(value.time) > 0: self._contents = value.copy() else: self._contents = TimeSeries(time=[0]) self._process_vectors() self._process_interconnections() self._process_contents() self._refresh()
# %% Interconnection management def _interconnections_callback(self, *args, **kwargs) -> None: """ Process interconnection modification and refresh. Parameters *args and **kwargs are only there to catch values sent by the MonitoredList and MonitoredDict callbacks. They are not used in this function. """ self._process_interconnections() self._refresh() def _parse_interconnection_group( self, group: dict[str, Any] ) -> dict[str, Any]: """Return a well-formatted group or raise a ValueError.""" output = dict() # type: dict[str, Any] # Add links try: output["Links"] = [] links = group["Links"] for link in links: dest_link = [] for point in link: if not isinstance(point, str): raise TypeError("Should be a string.") dest_link.append(point) output["Links"].append(dest_link) except (TypeError, KeyError, ValueError): raise ValueError( f"Impossible to parse interconnection group {group}." ) # Add parsed color, create it if not existent try: color = group["Color"] except KeyError: color = self._default_interconnection_color output["Color"] = _parse_color(color) return output def _process_interconnections(self) -> None: """ Process interconnections after setting or modifying them. Update _processed_interconnections. Does not refresh. We don't throw errors if interconnections are malformed, simply because maybe they are being built. We try our best to parse what is there. """ # Convert _interconnections to monitored dicts and lists self._interconnections = dict_to_monitored_dict( self._interconnections, callback=self._interconnections_callback ) # Parse _interconnections to a full-defined temporary value parsed_interconnections = dict() for group in self._interconnections: try: parsed_interconnections[group] = ( self._parse_interconnection_group( self._interconnections[group] ) ) except ValueError: pass # Simply don't include this group as it's not valid. # Make a set of all patterns matched by the * in interconnection # point names. patterns = {"__NO_WILD_CARD_DEFAULT_PATTERN__"} keys = list(self._contents.data.keys()) for group in parsed_interconnections: links = parsed_interconnections[group]["Links"] for i_link, link in enumerate(links): for i_point, point in enumerate(link): if point.startswith("*") and point.endswith("*"): warnings.warn( f"Point {point} found in interconnections. " "Only one wildcard can be used, either as a " "prefix or as a suffix." ) continue elif point.startswith("*"): for key in keys: if key.endswith(point[1:]): patterns.add( key[: (len(key) - len(point) + 1)] ) elif point.endswith("*"): for key in keys: if key.startswith(point[:-1]): patterns.add(key[(len(point) - 1) :]) # Extend every * to every pattern # Also parse color or add default color if not present in # _interconnections self._processed_interconnections = dict() for pattern in patterns: for group in parsed_interconnections: extended_group_key = f"{pattern}{group}" self._processed_interconnections[extended_group_key] = dict() self._processed_interconnections[extended_group_key][ "Color" ] = parsed_interconnections[group]["Color"] # Add every link of this segment self._processed_interconnections[extended_group_key][ "Links" ] = [] links = parsed_interconnections[group]["Links"] for i_link, link in enumerate(links): self._processed_interconnections[extended_group_key][ "Links" ].append([s.replace("*", pattern) for s in link]) # %% Vector management def _vectors_callback(self, *args, **kwargs) -> None: """ Process vector modification and refresh. Parameters *args and **kwargs are only there to catch values sent by the MonitoredList and MonitoredDict callbacks. They are not used in this function. """ self._process_vectors() self._process_contents() # Because points are dependents of vectors self._refresh() def _parse_vector(self, value: dict[str, Any]) -> dict[str, Any]: """Return a well-formatted vector dict or raise a ValueError.""" output = {} # Add origin try: origin = value["Origin"] check_param("value['Origin']", origin, str) output["Origin"] = origin except KeyError: raise ValueError("No origin") # Add parsed color, create it if not existent try: color = value["Color"] except KeyError: color = self._default_vector_color output["Color"] = _parse_color(color) # Add scale, create it of not existant try: scale = float(value["Scale"]) except (KeyError, TypeError, ValueError): scale = 1.0 output["Scale"] = scale return output def _process_vectors(self) -> None: """ Process vectors after setting or modifying them. Update _processed_vectors. Does not refresh. We don't throw errors if vectors are malformed, simply because maybe they are being built. We try our best to parse what is there. """ # Convert _vectors to monitored dicts and lists self._vectors = dict_to_monitored_dict( self._vectors, callback=self._vectors_callback ) # Parse _vectors to a full-defined temporary value parsed_vectors = dict() for vector in self._vectors: try: parsed_vectors[vector] = self._parse_vector( self._vectors[vector] ) except ValueError: pass # Simply don't include this vector as it's not valid. # Make a set of all patterns matched by the * in vectors patterns = {"__NO_WILD_CARD_DEFAULT_PATTERN__"} keys = list(self._contents.data.keys()) for vector in parsed_vectors: point = parsed_vectors[vector]["Origin"] if "*" not in vector and "*" not in point: continue if vector.startswith("*") and vector.endswith("*"): warnings.warn( f"Vector {vector} found in vectors. " "Only one wildcard can be used, either as a " "prefix or as a suffix." ) continue if point.startswith("*") and point.endswith("*"): warnings.warn( f"Point {point} found in vectors. " "Only one wildcard can be used, either as a " "prefix or as a suffix." ) continue if ( (vector.startswith("*") and not point.startswith("*")) or (not vector.startswith("*") and point.startswith("*")) or (vector.endswith("*") and not point.endswith("*")) or (not vector.endswith("*") and point.endswith("*")) ): warnings.warn( f"Vector {vector} and Point {point} must have matching " "wildcards." ) continue # Here, everything is correct. We have either starting or ending # wildcards in both points and vectors. if point.startswith("*"): for key in keys: if key.endswith(point[1:]): patterns.add(key[: (len(key) - len(point) + 1)]) elif point.endswith("*"): for key in keys: if key.startswith(point[:-1]): patterns.add(key[(len(point) - 1) :]) # Extend every * to every pattern self._processed_vectors = dict() for pattern in patterns: for vector in parsed_vectors: point = parsed_vectors[vector]["Origin"] new_vector = vector.replace("*", pattern) new_point = point.replace("*", pattern) if ( new_vector in self._contents.data and new_point in self._contents.data ): self._processed_vectors[new_vector] = { "Origin": new_point, "Scale": parsed_vectors[vector]["Scale"], "Color": parsed_vectors[vector]["Color"], } # %% Contents management def _general_rotation(self) -> np.ndarray: """Return a 1x4x4 rotation matrix from up and anterior attributes.""" # Create a frame based on these specs if self.up == "x": up = [[1, 0, 0, 0]] elif self.up == "y": up = [[0, 1, 0, 0]] elif self.up == "z": up = [[0, 0, 1, 0]] elif self.up == "-x": up = [[-1, 0, 0, 0]] elif self.up == "-y": up = [[0, -1, 0, 0]] elif self.up == "-z": up = [[0, 0, -1, 0]] else: raise_ktk_error(ValueError("Bad value for up parameter.")) if self.anterior == "x": anterior = [[1, 0, 0, 0]] elif self.anterior == "y": anterior = [[0, 1, 0, 0]] elif self.anterior == "z": anterior = [[0, 0, 1, 0]] elif self.anterior == "-x": anterior = [[-1, 0, 0, 0]] elif self.anterior == "-y": anterior = [[0, -1, 0, 0]] elif self.anterior == "-z": anterior = [[0, 0, -1, 0]] else: raise_ktk_error(ValueError("Bad value for anterior parameter.")) inverse_transform = geometry.create_frames( origin=[[0, 0, 0, 1]], x=anterior, xy=up ) return geometry.inv(inverse_transform) def _process_contents(self) -> None: """ Update self._processed_points, _processed_frames, and _oriented_target. Rotate everything according to the up input, so that the end result is y up: |y | +---- x / z/ Also adds vectors and the global origin to _processed_frames. Does not refresh. """ self._processed_points = self._contents.copy(copy_data=False) self._processed_frames = self._contents.copy(copy_data=False) contents = self._contents.copy() # Add the global reference frame origin_name = "Origin" while origin_name in contents.data: origin_name += "_" contents.data[origin_name] = np.repeat( np.eye(4)[np.newaxis], len(contents.time), axis=0 ) rotation = self._general_rotation() # Orient points, vectors and frames for key in contents.data: if geometry.is_transform_series(contents.data[key]): # Simply rotate it and add it to processed frames self._processed_frames.data[key] = ( geometry.get_global_coordinates( contents.data[key], rotation ) ) elif geometry.is_point_series(contents.data[key]): # Simply rotate it and add it to processed points self._processed_points.data[key] = ( geometry.get_global_coordinates( contents.data[key], rotation ) ) elif ( geometry.is_vector_series(contents.data[key]) and key in self._processed_vectors ): # This is a vector. We must scale it and add it to an origin. self._processed_points.data[key] = ( geometry.get_global_coordinates( self._processed_vectors[key]["Scale"] * contents.data[key] + contents.data[ self._processed_vectors[key]["Origin"] ], rotation, ) ) # Orient target oriented_target = geometry.get_global_coordinates( np.array( [[self._target[0], self._target[1], self._target[2], 1.0]] ), rotation, )[0, 0:3] self._oriented_target = ( oriented_target[0], oriented_target[1], oriented_target[2], ) # %% Projection and update def _recenter(self) -> None: """Recenter the current view on the last selected point (tracking).""" try: new_target = self._processed_points.data[ self._last_selected_point ][self.current_index] except (KeyError, IndexError): new_target = np.array([np.nan, np.nan, np.nan, np.nan]) if not np.isnan(np.sum(new_target)): self.target = new_target def _project_to_camera(self, points_3d: np.ndarray) -> np.ndarray: """ Get a 3d --> 2d projection of a list of points. The method uses the Player's camera properties to project a list of 3d points onto a 2d canvas. Parameters ---------- points_3d Nx4 array, where the first dimension is the number of points and the second dimension is (x, y, z, 1). Returns ------- Nx2 array, where the first dimension is the number of points and the second dimension is (x, y) to be ploted on a 2d graphic. """ # ------------------------------------------------------------ # Create the rotation matrix to convert the lab's coordinates # (x anterior, y up, z right) to the camera coordinates (x right, # y up, z deep) R = ( np.array( [ [2 * self.zoom, 0, 0, 0], [0, 2 * self.zoom, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], ] ) @ np.array( [ [1, 0, 0, self.pan[0]], # Pan [0, 1, 0, self.pan[1]], [0, 0, 1, 0], [0, 0, 0, 1], ] ) @ np.array( [ [1, 0, 0, 0], [0, cos(-self.elevation), sin(self.elevation), 0], [0, sin(-self.elevation), cos(-self.elevation), 0], [0, 0, 0, 1], ] ) @ np.array( [ [cos(-self.azimuth), 0, sin(self.azimuth), 0], [0, 1, 0, 0], [sin(-self.azimuth), 0, cos(-self.azimuth), 0], [0, 0, 0, 1], ] ) @ np.array( [ [1, 0, 0, -self._oriented_target[0]], [0, 1, 0, -self._oriented_target[1]], [0, 0, -1, self._oriented_target[2]], [0, 0, 0, 1], ] ) ) # Add a first dimension to R and match first dimension of points_3d R = np.repeat(R[np.newaxis, :], points_3d.shape[0], axis=0) # Rotate points. rotated_points_3d = R @ points_3d[:, :, np.newaxis] rotated_points_3d = rotated_points_3d[:, :, 0] # Apply perspective. if self.perspective is True: # This uses an ugly magical constant but it works fine for now. denom = rotated_points_3d[:, 2] / 10 + 5 rotated_points_3d[:, 0] = rotated_points_3d[:, 0] / denom rotated_points_3d[:, 1] = rotated_points_3d[:, 1] / denom with np.errstate(invalid="ignore"): to_remove = denom < 1e-12 rotated_points_3d[to_remove, 0] = np.nan rotated_points_3d[to_remove, 1] = np.nan else: # Scale to match the point of view rotated_points_3d /= 5 # Return only x and y return rotated_points_3d[:, 0:2] def _update_grid(self) -> None: """ (Re)-create the grid. First create a ground plane matrix in the form: [ [x1, 0, z1], [x2, 0, z2], [nan, nan, nan], [x3, 0, z3], [x4, 0, z4], [nan, nan, nan, nan], ... ] The grid is on the x and z axes, at y=0, as to be shown in Matplotlib. Then translate it according to the 'up' attribute and 'grid_origin'. A full refresh must be called after to recreate the Matplotlib plot using the correct width. """ # Build the grid as an xz plane with y being up. temp_grid = [] # z-to-z lines for x in np.arange( -self._grid_size / 2, self._grid_size / 2 + self._grid_subdivision_size, self._grid_subdivision_size, ): for z in np.arange( -self._grid_size / 2, self._grid_size / 2 + self._grid_subdivision_size, self._grid_subdivision_size, ): temp_grid.append([x, 0.0, z, 1.0]) temp_grid.append([np.nan, np.nan, np.nan, np.nan]) # x-to-x lines for z in np.arange( -self._grid_size / 2, self._grid_size / 2 + self._grid_subdivision_size, self._grid_subdivision_size, ): for x in np.arange( -self._grid_size / 2, self._grid_size / 2 + self._grid_subdivision_size, self._grid_subdivision_size, ): temp_grid.append([x, 0.0, z, 1.0]) temp_grid.append([np.nan, np.nan, np.nan, np.nan]) self._grid = np.array(temp_grid) # Translate the grid translation = geometry.get_global_coordinates( [ [ self._grid_origin[0], self._grid_origin[1], self._grid_origin[2], 1.0, ] ], self._general_rotation(), )[0] translation[3] = 0 # Not a position, but a vector self._grid += translation def _update_points_interconnections_vectors(self) -> None: # Get a Nx4 matrices of every point at the current index points = self._processed_points if points is None: return else: n_points = len(points.data) points_data = dict() # Used to draw the points with different colors interconnection_points = dict() # Used to draw the interconnections for color in self._colors: # Reset unselected points points_data[(color, False)] = np.empty([n_points, 4]) points_data[(color, False)][:] = np.nan # Reset selected points points_data[(color, True)] = np.empty([n_points, 4]) points_data[(color, True)][:] = np.nan for i_point, point in enumerate(points.data): # Get this point's color if ( point in points.data_info and "Color" in points.data_info[point] ): color = _parse_color(points.data_info[point]["Color"]) else: color = self.default_point_color these_coordinates = points.data[point][self.current_index] interconnection_points[point] = these_coordinates # Assign to unselected(False) or selected(True) points_data if point in self._selected_points: points_data[(color, True)][i_point] = these_coordinates else: points_data[(color, False)][i_point] = these_coordinates # Update the points plot for color in self._colors: # Unselected points points_data[(color, False)] = self._project_to_camera( points_data[(color, False)] ) self._mpl_objects["PointPlots"][(color, False)].set_data( points_data[(color, False)][:, 0], points_data[(color, False)][:, 1], ) # Selected points points_data[(color, True)] = self._project_to_camera( points_data[(color, True)] ) self._mpl_objects["PointPlots"][(color, True)].set_data( points_data[(color, True)][:, 0], points_data[(color, True)][:, 1], ) # Draw the interconnections for interconnection in self._processed_interconnections: coordinates = [] links = self._processed_interconnections[interconnection]["Links"] for link in links: for point in link: try: coordinates.append(interconnection_points[point]) except KeyError: coordinates.append(np.repeat(np.nan, 4)) coordinates.append(np.repeat(np.nan, 4)) np_coordinates = np.array(coordinates) np_coordinates = self._project_to_camera(np_coordinates) self._mpl_objects["InterconnectionPlots"][ interconnection ].set_data(np_coordinates[:, 0], np_coordinates[:, 1]) # Draw the vectors for vector in self._processed_vectors: np_coordinates = np.ones((2, 4)) try: np_coordinates[0] = points.data[vector][self.current_index] except KeyError: # Most probably the vector was not a vector series and has not # been added to `points` timeseries. np_coordinates[0] = [np.nan, np.nan, np.nan, np.nan] origin = self._processed_vectors[vector]["Origin"] try: np_coordinates[1] = points.data[origin][self.current_index] except KeyError: # Most probably the origin was not a point series and has not # been added to `points` timeseries. np_coordinates[1] = [np.nan, np.nan, np.nan, np.nan] np_coordinates = self._project_to_camera(np_coordinates) self._mpl_objects["VectorPlots"][vector].set_data( np_coordinates[:, 0], np_coordinates[:, 1] ) def _fast_refresh(self) -> None: """Update plot data, assuming all plots have already been created.""" self._update_points_interconnections_vectors() # Get three (3N)x4 matrices (for x, y and z lines) for the rigid bodies # at the current index frames = self._processed_frames n_frames = len(frames.data) framex_data = np.empty([n_frames * 3, 4]) framey_data = np.empty([n_frames * 3, 4]) framez_data = np.empty([n_frames * 3, 4]) for i_rigid_body, rigid_body in enumerate(frames.data): # Origin framex_data[i_rigid_body * 3] = frames.data[rigid_body][ self.current_index, :, 3 ] framey_data[i_rigid_body * 3] = frames.data[rigid_body][ self.current_index, :, 3 ] framez_data[i_rigid_body * 3] = frames.data[rigid_body][ self.current_index, :, 3 ] # Direction framex_data[i_rigid_body * 3 + 1] = frames.data[rigid_body][ self.current_index ] @ np.array([self.frame_size, 0, 0, 1]) framey_data[i_rigid_body * 3 + 1] = frames.data[rigid_body][ self.current_index ] @ np.array([0, self.frame_size, 0, 1]) framez_data[i_rigid_body * 3 + 1] = frames.data[rigid_body][ self.current_index ] @ np.array([0, 0, self.frame_size, 1]) # NaN to cut the line between the different frames framex_data[i_rigid_body * 3 + 2] = np.repeat(np.nan, 4) framey_data[i_rigid_body * 3 + 2] = np.repeat(np.nan, 4) framez_data[i_rigid_body * 3 + 2] = np.repeat(np.nan, 4) # Update the ground plane if len(self._grid) > 0: gp = self._project_to_camera(self._grid) self._mpl_objects["GridPlot"].set_data(gp[:, 0], gp[:, 1]) # Create or update the frame plot framex_data = self._project_to_camera(framex_data) framey_data = self._project_to_camera(framey_data) framez_data = self._project_to_camera(framez_data) self._mpl_objects["FrameXPlot"].set_data( framex_data[:, 0], framex_data[:, 1] ) self._mpl_objects["FrameYPlot"].set_data( framey_data[:, 0], framey_data[:, 1] ) self._mpl_objects["FrameZPlot"].set_data( framez_data[:, 0], framez_data[:, 1] ) # Update the window title try: self._mpl_objects["Figure"].canvas.manager.set_window_title( f"{self.current_index}/{len(self._contents.time)}: " + "%2.2f s." % self._contents.time[self.current_index] ) except AttributeError: pass self._mpl_objects["Figure"].canvas.draw() def _update_colors(self) -> None: """Update self._colors.""" self._colors = set() # In contents (points and vectors) for key in self._contents.data: try: color = self._contents.data_info[key]["Color"] except KeyError: # Default color color = self._default_point_color self._colors.add(_parse_color(color)) # # In interconnections # for key in self._interconnections: # try: # color = self._interconnections[key]["Color"] # except KeyError: # Default color # color = self._default_interconnection_color # colors.add(_parse_color(color)) # # In vectors # for key in self._vectors: # try: # color = self._vectors[key]["Color"] # except KeyError: # Default color # color = self._default_vector_color # colors.add(_parse_color(color)) def _refresh(self): """ Perform a full refresh of the Player. Normally, this function does not need to be called by the user. Use it if the Player is not refreshed as it should. You may report this need as a bug in the issue tracker: https://github.com/kineticstoolkit/kineticstoolkit/issues """ # Clear and rebuild the mpl plots. self._mpl_objects["Plots"] = dict() self._mpl_objects["HelpText"] = None self._mpl_objects["Axes"].clear() self._mpl_objects["Figure"].set_facecolor(self._background_color) # Reset axes properties self._mpl_objects["Axes"].set_axis_off() # Create the ground plane self._mpl_objects["GridPlot"] = self._mpl_objects["Axes"].plot( np.nan, np.nan, linewidth=self._grid_width, color=self._grid_color, )[0] # Create the frame plots self._mpl_objects["FrameXPlot"] = self._mpl_objects["Axes"].plot( np.nan, np.nan, c="r", linewidth=self.frame_width, )[0] self._mpl_objects["FrameYPlot"] = self._mpl_objects["Axes"].plot( np.nan, np.nan, c="g", linewidth=self.frame_width, )[0] self._mpl_objects["FrameZPlot"] = self._mpl_objects["Axes"].plot( np.nan, np.nan, c="b", linewidth=self.frame_width, )[0] # Create the interconnection plots self._mpl_objects["InterconnectionPlots"] = dict() for interconnection in self._processed_interconnections: self._mpl_objects["InterconnectionPlots"][ interconnection ] = self._mpl_objects["Axes"].plot( np.nan, np.nan, "-", c=self._processed_interconnections[interconnection]["Color"], linewidth=self._interconnection_width, )[ 0 ] # Create the vector plots self._mpl_objects["VectorPlots"] = dict() for vector in self._processed_vectors: self._mpl_objects["VectorPlots"][vector] = self._mpl_objects[ "Axes" ].plot( np.nan, np.nan, "-", c=self._processed_vectors[vector]["Color"], linewidth=self._vector_width, )[ 0 ] # ---------------------- # Create the point plots # ---------------------- # Create all required point plots self._update_colors() self._mpl_objects["PointPlots"] = dict() for color in self._colors: # Unselected points self._mpl_objects["PointPlots"][ (color, False) ] = self._mpl_objects["Axes"].plot( np.nan, np.nan, ".", c=color, markersize=self._point_size, pickradius=1.1 * self._point_size, picker=True, )[ 0 ] # Selected points self._mpl_objects["PointPlots"][(color, True)] = self._mpl_objects[ "Axes" ].plot( np.nan, np.nan, ".", c=color, markersize=3 * self._point_size, )[ 0 ] # Add the title title_obj = plt.title("", fontfamily="monospace") plt.setp(title_obj, color=[0, 1, 0]) # Set a green title self._fast_refresh() # Draw everything once # Set limits once it's drawn self._mpl_objects["Axes"].set_xlim([-1.5, 1.5]) self._mpl_objects["Axes"].set_ylim([-1.0, 1.0]) def _set_new_target(self, target: tuple[float, float, float]) -> None: """Set new target and adapts pan and zoom consequently.""" # Save the current view if np.sum(np.isnan(target)) > 0: return initial_pan = deepcopy(self.pan) initial_zoom = deepcopy(self.zoom) initial_target = deepcopy(self.target) n_points = len(self._processed_points.data) points = np.empty((n_points, 4)) for i_point, point in enumerate(self._processed_points.data): points[i_point] = self._processed_points.data[point][ self.current_index ] initial_projected_points = self._project_to_camera(points) # Do not consider points that are not in the screen initial_projected_points[initial_projected_points[:, 0] < -1.5] = ( np.nan ) initial_projected_points[initial_projected_points[:, 0] > 1.5] = np.nan initial_projected_points[initial_projected_points[:, 1] < -1.0] = ( np.nan ) initial_projected_points[initial_projected_points[:, 1] > 1.0] = np.nan def error_function(input): self._pan = input[0:2] self._zoom = input[2] new_projected_points = self._project_to_camera(points) error = np.nanmean( (initial_projected_points - new_projected_points) ** 2 ) return error # Set the new target self._target = np.array(target) self._process_contents() # Try to find a camera pan/zoom so that the view is similar res = optim.minimize(error_function, np.hstack((self.pan, self.zoom))) if res.success is False: self.pan = initial_pan self.zoom = initial_zoom self.target = initial_target self._fast_refresh() # ------------------------------------ # Callbacks def _on_close(self, _) -> None: # pragma: no cover # Release all references to objects self.close() def _on_timer(self, _) -> None: # pragma: no cover """Implement callback for the animation timer object.""" if self._running is True: # We check self._running because we can enter this callback # even if the animation has been deactivated. This is because the # recommended way to deactivate a timer is to unreference it, # however the garbage collector may take time deleting the timer # and we will end up with still a few timer callbacks. Checking # self._running makes sure that we effectively stop. current_system_time = time.time() current_index = self.current_index self.current_time += self.playback_speed * ( time.time() - self._state["SystemTimeOnLastUpdate"] # type: ignore ) # Type ignored because mypy considers # self._state["SystemTimeOnLastUpdate"] as an "object" if current_index == self.current_index: # The time wasn't enough to advance a frame. Articifically # advance a frame. self.current_index += 1 self._state["SystemTimeOnLastUpdate"] = current_system_time else: self._mpl_objects["Anim"].event_source.stop() def _on_pick(self, event): # pragma: no cover """Implement callback for point selection.""" if event.mouseevent.button == 1: index = event.ind selected_point = list(self._processed_points.data.keys())[index[0]] self.title_text = selected_point # Mark selected self._selected_points = [selected_point] # Set as new target self._last_selected_point = selected_point self._set_new_target( self._contents.data[selected_point][self.current_index] ) self._fast_refresh() def _on_key(self, event): # pragma: no cover """Implement callback for keyboard key pressed.""" if event.key == " ": if self._running is False: self.play() else: self.pause() elif event.key == "left": self.current_index -= 1 elif event.key == "shift+left": self.current_time -= 1 elif event.key == "right": self.current_index += 1 elif event.key == "shift+right": self.current_time += 1 elif event.key == "-": self.playback_speed /= 2 self.title_text = f"Playback set to {self.playback_speed}x" elif event.key == "+": self.playback_speed *= 2 self.title_text = f"Playback set to {self.playback_speed}x" elif event.key == "h": if self._mpl_objects["HelpText"] is None: self._mpl_objects["HelpText"] = self._mpl_objects["Axes"].text( -1.5, -1, HELP_TEXT, color=[0, 1, 0], fontfamily="monospace", ) else: self._mpl_objects["HelpText"].remove() self._mpl_objects["HelpText"] = None elif event.key == "d": self.perspective = not self.perspective if self.perspective is True: self.title_text = "Camera set to perspective" else: self.title_text = "Camera set to orthogonal" elif event.key == "t": self.track = not self.track if self.track is True: self.title_text = "Point tracking activated" else: self.title_text = "Point tracking deactivated" elif event.key == "1": self.set_view("back") self.title_text = "Back view, orthogonal" elif event.key == "2": self.set_view("front") self.title_text = "Front view, orthogonal" elif event.key == "3": self.set_view("left") self.title_text = "Left view, orthogonal" elif event.key == "4": self.set_view("right") self.title_text = "Right view, orthogonal" elif event.key == "5": self.set_view("top") self.title_text = "Top view, orthogonal" elif event.key == "6": self.set_view("bottom") self.title_text = "Bottom view, orthogonal" elif event.key == "0": self.set_view("initial") self.title_text = "Initial view" elif event.key == "shift": self._state["ShiftPressed"] = True self._fast_refresh() def _on_release(self, event): # pragma: no cover if event.key == "shift": self._state["ShiftPressed"] = False def _on_scroll(self, event): # pragma: no cover if event.button == "up": self.zoom *= 1.05 elif event.button == "down": self.zoom /= 1.05 self._fast_refresh() def _on_mouse_press(self, event): # pragma: no cover self._state["PanOnMousePress"] = self.pan self._state["AzimutOnMousePress"] = self.azimuth self._state["ElevationOnMousePress"] = self.elevation self._state["ZoomOnMousePress"] = self.zoom self._state["MousePositionOnPress"] = (event.x, event.y) if event.button == 1: self._state["MouseLeftPressed"] = True elif event.button == 2: self._state["MouseMiddlePressed"] = True elif event.button == 3: self._state["MouseRightPressed"] = True self._fast_refresh() def _on_mouse_release(self, event): # pragma: no cover if event.button == 1: self._state["MouseLeftPressed"] = False elif event.button == 2: self._state["MouseMiddlePressed"] = False elif event.button == 3: self._state["MouseRightPressed"] = False def _on_mouse_motion(self, event): # pragma: no cover # Pan: if ( self._state["MouseLeftPressed"] and self._state["ShiftPressed"] ) or self._state["MouseMiddlePressed"]: self.pan = ( self._state["PanOnMousePress"][0] + (event.x - self._state["MousePositionOnPress"][0]) / (100 * self.zoom), self._state["PanOnMousePress"][1] + (event.y - self._state["MousePositionOnPress"][1]) / (100 * self.zoom), ) self._fast_refresh() # Rotation: elif ( self._state["MouseLeftPressed"] and not self._state["ShiftPressed"] ): self.azimuth = ( self._state["AzimutOnMousePress"] - (event.x - self._state["MousePositionOnPress"][0]) / 250 ) self.elevation = ( self._state["ElevationOnMousePress"] - (event.y - self._state["MousePositionOnPress"][1]) / 250 ) self._fast_refresh() # Zoom: elif self._state["MouseRightPressed"]: self.zoom = ( self._state["ZoomOnMousePress"] + (event.y - self._state["MousePositionOnPress"][1]) / 250 ) self._fast_refresh() def _to_animation(self): """ Create a matplotlib FuncAnimation for displaying in Jupyter notebooks. This also closes the figure so that Jupyter does not show both the animation and the figure. Parameters ---------- No parameter. Returns ------- A FuncAnimation to be displayed by Jupyter notebook. """ try: from IPython.display import Video except ModuleNotFoundError: raise RuntimeError( "This function must be run in an IPython session." ) self._mpl_objects["Figure"].set_size_inches(6, 4.5) # Half size self._mpl_objects["Figure"].tight_layout() self.to_video("temp.mp4", show_progress_bar=False) return Video( "temp.mp4", embed=True, html_attributes="controls loop autoplay" ) # %% Public methods
[docs] def play(self) -> None: """Start the animation.""" self._state["SystemTimeOnLastUpdate"] = time.time() self._running = True self._mpl_objects["Anim"].event_source.start()
[docs] def pause(self) -> None: """Pause the animation.""" self._running = False self._mpl_objects["Anim"].event_source.stop()
[docs] def set_view(self, plane: str) -> None: """ Set the current view to an orthogonal view in a given plane. Ensure that the player's `up` and `anterior` properties are set to the correct axes beforehand. By default, `up` is "y" and `anterior` is "x". Parameters ---------- plane Can be either "front", "back", "right", "left", "top", "bottom" or "initial". In the latter case, the view is reset to the initial view at Player creation. """ check_param("plane", plane, str) if plane.lower() == "initial": self.elevation = self._initial_elevation self.azimuth = self._initial_azimuth self.perspective = self._initial_perspective return # Set a "from rotation" matrix following x anterior, y up and z right if plane.lower() == "front": from_rot = geometry.create_transforms( "YXZ", [[90, 0, 0]], degrees=True ) self.perspective = False elif plane.lower() == "back": from_rot = geometry.create_transforms( "YXZ", [[-90, 0, 0]], degrees=True ) self.perspective = False elif plane.lower() == "top": from_rot = geometry.create_transforms( "YXZ", [[0, 90, 0]], degrees=True ) self.perspective = False elif plane.lower() == "bottom": from_rot = geometry.create_transforms( "YXZ", [[0, -90, 0]], degrees=True ) self.perspective = False elif plane.lower() == "right": from_rot = geometry.create_transforms( "YXZ", [[0, 0, 0]], degrees=True ) self.perspective = False elif plane.lower() == "left": from_rot = geometry.create_transforms( "YXZ", [[180, 0, 0]], degrees=True ) self.perspective = False else: raise ValueError( "Parameter plane can be either " '"front", "back", "right", "left", "top" or "bottom".' ) # Ignore gimbal lock warnings, gimbal locks are ok since SciPy # behaviour is well documented in those case (3rd angle set to 0). with warnings.catch_warnings(): warnings.simplefilter("ignore") angles = geometry.get_angles(from_rot, "YXZ") self.elevation = angles[0, 1] self.azimuth = angles[0, 0]
[docs] def close(self) -> None: """Close the Player and its associated window.""" plt.close(self._mpl_objects["Figure"]) self._mpl_objects = {}
[docs] def to_image(self, filename: str) -> None: """ Save the current view to an image file. Any format supported by Matplotlib can be used. Parameters ---------- filename Name of the image file to save (e.g., "file.png", "file.jpeg", "file.pdf", "file.svg", "file.tiff") Returns ------- None """ check_param("filename", filename, str) self._mpl_objects["Figure"].savefig(filename)
[docs] def to_video( self, filename: str, *, fps: int | None = None, downsample: int = 1, show_progress_bar: bool = True, ) -> None: """ Save the current view to an MP4 video file. Parameters ---------- filename Name of the video file to save. fps Optional. Frames per second of the output video. Default is None, which means that fps matches the current playback speed of the Player. This attribute does not affect the number of images in the output video; it only affects the playback speed of the output video. downsample Optional. Use it to reduce the file size on acquisitions at high sample rates. Default is 1, which means that the video is not downsampled. In this case, each index is exported as one frame of the output video. A value of 2 divides the number of frames by 2, which means that every other index is skipped. A value of 3 divides the number of frames by 3, etc. show_progress_bar Optional. True to show a progress bar while creating the video file. Returns ------- None """ check_param("filename", filename, str) check_param("fps", fps, (int, None)) check_param("downsample", downsample, int) check_param("show_progress_bar", show_progress_bar, bool) if downsample < 1: raise ValueError( "Parameter downsample must be stricly higher than 0." ) n_samples = int(len(self._contents.time) / downsample) # We create a specific animation and callback, since all processing # will be done offline. We set a very long delay between frames but # this is just so that the animation didn't advance by itself by the # time recording has started. def advance(args): self.current_index = args * downsample self.title_text = ( f"{self.current_index}/{(n_samples - 1) * downsample}: " f"{self.current_time:.3f} s." ) anim = animation.FuncAnimation( self._mpl_objects["Figure"], advance, # type: ignore frames=n_samples, interval=1e6, ) # 30 ips if fps is None: fps = int( self._contents.get_sample_rate() * self.playback_speed / downsample ) if np.isnan(fps): fps = 30 writervideo = animation.FFMpegWriter(fps=fps) self.pause() self.current_index = 0 if show_progress_bar: progress_bar = tqdm(n_samples - 1) update_progress_bar = lambda i, n: progress_bar.update(1) else: update_progress_bar = lambda i, n: None anim.save( filename, writer=writervideo, progress_callback=update_progress_bar ) anim.event_source.stop() if show_progress_bar: progress_bar.close() self.title_text = "" self.current_index = 0
# %% To deprecate
[docs] def get_interconnections(self) -> dict[str, dict[str, Any]]: """Get interconnections value (deprecated).""" return self._interconnections
[docs] def set_interconnections(self, value: dict[str, dict[str, Any]]) -> None: """Set interconnections value (deprecated).""" self.interconnections = value