#!/usr/bin/env python3 """ Graph Core - vzug-e-hinge ========================== Pure plotting functions for telemetry data visualization. Database-agnostic - works with any data source via table_query adapters. Features: - Matplotlib plotting functions - Overlay, subplots, drift comparison, multi-series - XY scatter plots - Export PNG and CSV - No database dependency (uses TelemetryData from table_query) Author: Kynsight Version: 2.0.0 """ from __future__ import annotations from typing import List, Dict, Tuple, Optional, Any from dataclasses import dataclass import matplotlib matplotlib.use('QtAgg') # For PyQt6 integration import matplotlib.pyplot as plt from matplotlib.figure import Figure import numpy as np # Import data structures from table_query from graph_table_query import TelemetryData, get_column_label # ============================================================================= # Plot Configuration # ============================================================================= @dataclass class PlotConfig: """Plot configuration.""" title: str = "Telemetry Data" xlabel: str = "Time (s)" figsize: Tuple[int, int] = (12, 8) dpi: int = 100 grid: bool = True legend: bool = True style: str = "default" # matplotlib style linestyle: str = "-" # Line style marker: Optional[str] = None # Marker style markersize: int = 3 # Marker size # ============================================================================= # Plotting Functions # ============================================================================= def plot_overlay( data_list: List[TelemetryData], x_column: str, y_column: str, xlabel: str, ylabel: str, config: Optional[PlotConfig] = None ) -> Figure: """ Create overlay plot (multiple runs on same axes). Args: data_list: List of TelemetryData objects x_column: Column name for X-axis (e.g., 'time_ms', 't_ns') y_column: Column name for Y-axis (e.g., 'pwm', 'motor_current') xlabel: X-axis label ylabel: Y-axis label config: Plot configuration Returns: Matplotlib Figure object """ if config is None: config = PlotConfig() fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi) # Set color cycle ax.set_prop_cycle(color=plt.cm.tab10.colors) # Plot each run for data in data_list: x_data = getattr(data, x_column, None) y_data = getattr(data, y_column, None) if x_data is None or y_data is None: continue # Updated label format: "session_name run_no (name)" label = f"{data.session_id} {data.run_no} ({data.session_name})" ax.plot(x_data, y_data, label=label, alpha=0.8, linestyle=config.linestyle, marker=config.marker, markersize=config.markersize) # Formatting ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(config.title) if config.grid: ax.grid(True, alpha=0.3) if config.legend: # Legend at bottom, outside plot ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=2) fig.tight_layout(rect=[0, 0.08, 1, 1]) # Leave space for legend return fig def plot_subplots( data_list: List[TelemetryData], x_column: str, y_column: str, xlabel: str, ylabel: str, config: Optional[PlotConfig] = None ) -> Figure: """ Create subplot grid (one subplot per run). Args: data_list: List of TelemetryData objects x_column: Column name for X-axis y_column: Column name for Y-axis xlabel: X-axis label ylabel: Y-axis label config: Plot configuration Returns: Matplotlib Figure object """ if config is None: config = PlotConfig() n_runs = len(data_list) # Calculate grid dimensions n_cols = min(2, n_runs) n_rows = (n_runs + n_cols - 1) // n_cols fig, axes = plt.subplots( n_rows, n_cols, figsize=config.figsize, dpi=config.dpi, squeeze=False ) # Flatten axes for easy iteration axes_flat = axes.flatten() # Plot each run for idx, data in enumerate(data_list): ax = axes_flat[idx] x_data = getattr(data, x_column, None) y_data = getattr(data, y_column, None) if x_data is None or y_data is None: ax.text(0.5, 0.5, 'No data', ha='center', va='center') ax.set_title(f"{data.session_id} - Run {data.run_no} ({data.session_name})") continue ax.plot(x_data, y_data, alpha=0.8, linestyle=config.linestyle, marker=config.marker, markersize=config.markersize) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) # Updated title format ax.set_title(f"{data.session_id} {data.run_no} ({data.session_name})") if config.grid: ax.grid(True, alpha=0.3) # Hide unused subplots for idx in range(len(data_list), len(axes_flat)): axes_flat[idx].set_visible(False) fig.tight_layout() return fig def plot_comparison( data_list: List[TelemetryData], x_column: str, y_column: str, xlabel: str, ylabel: str, reference_index: int = 0, config: Optional[PlotConfig] = None ) -> Figure: """ Create drift comparison plot (deviation from reference run). Args: data_list: List of TelemetryData objects x_column: Column name for X-axis y_column: Column name for Y-axis xlabel: X-axis label ylabel: Y-axis label reference_index: Index of reference run (default: 0 = first run) config: Plot configuration Returns: Matplotlib Figure object """ if config is None: config = PlotConfig() if reference_index >= len(data_list): reference_index = 0 fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi) reference = data_list[reference_index] ref_x = getattr(reference, x_column, None) ref_y = getattr(reference, y_column, None) if ref_x is None or ref_y is None: ax.text(0.5, 0.5, 'No reference data', ha='center', va='center') return fig # Plot reference as baseline (zero) ax.axhline(y=0, color='black', linestyle='--', # Updated label format label=f'Reference: {reference.session_id} {reference.run_no} ({reference.session_name})') # Set color cycle ax.set_prop_cycle(color=plt.cm.tab10.colors) # Plot deviations for idx, data in enumerate(data_list): if idx == reference_index: continue x_data = getattr(data, x_column, None) y_data = getattr(data, y_column, None) if x_data is None or y_data is None: continue # Interpolate to match reference x points (for comparison) if len(x_data) != len(ref_x) or not np.array_equal(x_data, ref_x): y_interp = np.interp(ref_x, x_data, y_data) else: y_interp = y_data # Calculate deviation deviation = y_interp - ref_y # Updated label format label = f"{data.session_id} {data.run_no} ({data.session_name})" ax.plot(ref_x, deviation, label=label, alpha=0.8, linestyle=config.linestyle, marker=config.marker, markersize=config.markersize) # Formatting ax.set_xlabel(xlabel) ax.set_ylabel(f"Deviation in {ylabel}") ax.set_title(f"{config.title} - Drift Analysis") if config.grid: ax.grid(True, alpha=0.3) if config.legend: # Legend at bottom ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=2) fig.tight_layout(rect=[0, 0.08, 1, 1]) return fig def plot_multi_series( data_list: List[TelemetryData], x_column: str, y_columns: List[str], xlabel: str, ylabels: List[str], config: Optional[PlotConfig] = None ) -> Figure: """ Create multi-series plot (multiple data columns, multiple runs). Args: data_list: List of TelemetryData objects x_column: Column name for X-axis y_columns: List of column names to plot xlabel: X-axis label ylabels: List of Y-axis labels config: Plot configuration Returns: Matplotlib Figure object """ if config is None: config = PlotConfig() n_series = len(y_columns) fig, axes = plt.subplots( n_series, 1, figsize=config.figsize, dpi=config.dpi, sharex=True ) # Handle single subplot case if n_series == 1: axes = [axes] # Plot each series for idx, (y_col, ylabel) in enumerate(zip(y_columns, ylabels)): ax = axes[idx] # Set color cycle ax.set_prop_cycle(color=plt.cm.tab10.colors) for data in data_list: x_data = getattr(data, x_column, None) y_data = getattr(data, y_col, None) if x_data is None or y_data is None: continue # Updated label format label = f"{data.session_id} {data.run_no} ({data.session_name})" ax.plot(x_data, y_data, label=label, alpha=0.8, linestyle=config.linestyle, marker=config.marker, markersize=config.markersize) ax.set_ylabel(ylabel) if config.grid: ax.grid(True, alpha=0.3) if config.legend and idx == 0: # Legend only on first subplot ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=2) # X-label only on bottom subplot axes[-1].set_xlabel(xlabel) fig.suptitle(config.title) fig.tight_layout(rect=[0, 0.08, 1, 0.96]) return fig def plot_xy_scatter( data_list: List[TelemetryData], x_column: str, y_column: str, xlabel: str, ylabel: str, config: Optional[PlotConfig] = None ) -> Figure: """ Create XY scatter/line plot (any column vs any column). Useful for phase plots, correlation analysis, etc. Example: motor_current vs pwm, angle vs encoder_value Args: data_list: List of TelemetryData objects x_column: Column name for X-axis y_column: Column name for Y-axis xlabel: X-axis label ylabel: Y-axis label config: Plot configuration Returns: Matplotlib Figure object """ if config is None: config = PlotConfig() fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi) # Set color cycle ax.set_prop_cycle(color=plt.cm.tab10.colors) # Plot each run for data in data_list: x_data = getattr(data, x_column, None) y_data = getattr(data, y_column, None) if x_data is None or y_data is None: continue # Updated label format label = f"{data.session_id} {data.run_no} ({data.session_name})" ax.plot(x_data, y_data, label=label, alpha=0.8, linestyle=config.linestyle, marker=config.marker or 'o', markersize=config.markersize if config.marker else 2, linewidth=1) # Formatting ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(config.title) if config.grid: ax.grid(True, alpha=0.3) if config.legend: # Legend at bottom ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=2) fig.tight_layout(rect=[0, 0.08, 1, 1]) return fig # ============================================================================= # Export Functions # ============================================================================= def export_png(fig: Figure, filepath: str, dpi: int = 300) -> bool: """ Export figure to PNG. Args: fig: Matplotlib figure filepath: Output file path dpi: Resolution Returns: True on success """ try: fig.savefig(filepath, dpi=dpi, bbox_inches='tight') return True except Exception as e: print(f"[EXPORT ERROR] Failed to save PNG: {e}") return False def export_csv( data_list: List[TelemetryData], filepath: str, x_column: str, y_columns: List[str] ) -> bool: """ Export telemetry data to CSV (only selected columns). Args: data_list: List of TelemetryData objects filepath: Output file path x_column: X-axis column name (e.g., 'time_ms') y_columns: List of Y-axis column names (e.g., ['motor_current', 'pwm']) Returns: True on success """ try: import csv with open(filepath, 'w', newline='') as f: writer = csv.writer(f) # Header: metadata + X column + selected Y columns header = ['session_id', 'session_name', 'run_no', x_column] + y_columns writer.writerow(header) # Data rows for data in data_list: # Get X data x_data = getattr(data, x_column, None) if x_data is None: continue # Get length from X column length = len(x_data) # Write each data point for i in range(length): row = [ data.session_id, data.session_name, data.run_no, x_data[i] ] # Add selected Y columns for y_col in y_columns: y_data = getattr(data, y_col, None) if y_data is not None and i < len(y_data): row.append(y_data[i]) else: row.append('') # Empty if column doesn't exist writer.writerow(row) return True except Exception as e: print(f"[EXPORT ERROR] Failed to save CSV: {e}") return False # ============================================================================= # Demo # ============================================================================= if __name__ == "__main__": print("Graph Core - Pure Plotting Functions") print("=" * 50) print() print("Database-agnostic plotting library!") print() print("Usage:") print(" from table_query import SQLiteAdapter, CSVAdapter") print(" from graph_core import plot_overlay, plot_xy_scatter") print() print(" # Load data via adapter") print(" adapter = SQLiteAdapter('./database/ehinge.db')") print(" adapter.connect()") print(" data = adapter.load_run_data('Session_A', 1)") print() print(" # Plot (pure function - no database!)") print(" fig = plot_overlay([data], 'time_ms', 'pwm', 'Time', 'PWM')") print() print("Available functions:") print(" • plot_overlay() - Multiple runs on same axes") print(" • plot_subplots() - One subplot per run") print(" • plot_comparison() - Drift analysis") print(" • plot_multi_series() - Stacked subplots") print(" • plot_xy_scatter() - XY scatter/line plot") print(" • export_png() - Save as PNG") print(" • export_csv() - Export data as CSV")