You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

540 lines
16 KiB

#!/usr/bin/env python3
"""
Graph Core - vzug-e-hinge
==========================
Pure plotting functions for telemetry data visualization.
Database-agnostic - works with any data source via table_query adapters.
Features:
- Matplotlib plotting functions
- Overlay, subplots, drift comparison, multi-series
- XY scatter plots
- Export PNG and CSV
- No database dependency (uses TelemetryData from table_query)
Author: Kynsight
Version: 2.0.0
"""
from __future__ import annotations
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass
import matplotlib
matplotlib.use('QtAgg') # For PyQt6 integration
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import numpy as np
# Import data structures from table_query
from graph_table_query import TelemetryData, get_column_label
# =============================================================================
# Plot Configuration
# =============================================================================
@dataclass
class PlotConfig:
    """
    Appearance settings shared by the plotting functions.

    Attributes:
        title: Plot title (some plots append a suffix or use it as suptitle).
        xlabel: X-axis label text. NOTE(review): the plotting functions take
            an explicit ``xlabel`` argument and do not read this field.
        figsize: Figure size (width, height) passed to Matplotlib.
        dpi: Figure resolution passed to Matplotlib.
        grid: Draw a light grid on the axes when True.
        legend: Draw a legend when True.
        style: Name of a Matplotlib style. NOTE(review): not read by the
            plotting functions in this module -- confirm whether callers apply it.
        linestyle: Matplotlib line style for the data lines.
        marker: Matplotlib marker style, or None for no markers.
        markersize: Marker size used when a marker is drawn.
    """

    # --- text ---
    title: str = "Telemetry Data"
    xlabel: str = "Time (s)"

    # --- figure geometry ---
    figsize: Tuple[int, int] = (12, 8)
    dpi: int = 100

    # --- decorations ---
    grid: bool = True
    legend: bool = True
    style: str = "default"

    # --- line appearance ---
    linestyle: str = "-"
    marker: Optional[str] = None
    markersize: int = 3
# =============================================================================
# Plotting Functions
# =============================================================================
def plot_overlay(
    data_list: List[TelemetryData],
    x_column: str,
    y_column: str,
    xlabel: str,
    ylabel: str,
    config: Optional[PlotConfig] = None
) -> Figure:
    """
    Create overlay plot (multiple runs on same axes).

    Each run becomes one line labelled "<session_id> <run_no> (<session_name>)".
    Columns are read as attributes of each TelemetryData object; a run that
    lacks either column is skipped silently.

    Args:
        data_list: List of TelemetryData objects
        x_column: Attribute name for X-axis (e.g., 'time_ms', 't_ns')
        y_column: Attribute name for Y-axis (e.g., 'pwm', 'motor_current')
        xlabel: X-axis label
        ylabel: Y-axis label
        config: Plot configuration (defaults to ``PlotConfig()``)

    Returns:
        Matplotlib Figure object
    """
    if config is None:
        config = PlotConfig()

    fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi)
    # tab10 palette: each run gets the next colour in a fixed sequence.
    ax.set_prop_cycle(color=plt.cm.tab10.colors)

    for data in data_list:
        x_data = getattr(data, x_column, None)
        y_data = getattr(data, y_column, None)
        if x_data is None or y_data is None:
            continue
        label = f"{data.session_id} {data.run_no} ({data.session_name})"
        ax.plot(x_data, y_data, label=label, alpha=0.8,
                linestyle=config.linestyle, marker=config.marker,
                markersize=config.markersize)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(config.title)
    if config.grid:
        ax.grid(True, alpha=0.3)

    # Only add a legend when at least one run was plotted; legend() with no
    # labelled artists warns and draws an empty box.
    _, labels = ax.get_legend_handles_labels()
    if config.legend and labels:
        # Legend below the axes, outside the plot area.
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=2)
    fig.tight_layout(rect=[0, 0.08, 1, 1])  # leave room at the bottom for the legend
    return fig
def plot_subplots(
    data_list: List[TelemetryData],
    x_column: str,
    y_column: str,
    xlabel: str,
    ylabel: str,
    config: Optional[PlotConfig] = None
) -> Figure:
    """
    Create subplot grid (one subplot per run).

    Runs are laid out in a grid at most two columns wide. Every subplot is
    titled "<session_id> <run_no> (<session_name>)". A run missing either
    column shows a "No data" placeholder; unused grid cells are hidden.
    An empty ``data_list`` yields a single "No data" axes instead of failing.

    Args:
        data_list: List of TelemetryData objects
        x_column: Attribute name for X-axis
        y_column: Attribute name for Y-axis
        xlabel: X-axis label
        ylabel: Y-axis label
        config: Plot configuration (defaults to ``PlotConfig()``)

    Returns:
        Matplotlib Figure object
    """
    if config is None:
        config = PlotConfig()

    n_runs = len(data_list)
    if n_runs == 0:
        # The grid arithmetic below would divide by zero; show a placeholder.
        fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi)
        ax.text(0.5, 0.5, 'No data', ha='center', va='center')
        fig.tight_layout()
        return fig

    n_cols = min(2, n_runs)
    n_rows = (n_runs + n_cols - 1) // n_cols  # ceiling division
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=config.figsize,
        dpi=config.dpi,
        squeeze=False  # always a 2-D array, even for a single run
    )
    axes_flat = axes.flatten()

    for ax, data in zip(axes_flat, data_list):
        # Same title format whether or not the run has data.
        ax.set_title(f"{data.session_id} {data.run_no} ({data.session_name})")
        x_data = getattr(data, x_column, None)
        y_data = getattr(data, y_column, None)
        if x_data is None or y_data is None:
            ax.text(0.5, 0.5, 'No data', ha='center', va='center')
            continue
        ax.plot(x_data, y_data, alpha=0.8,
                linestyle=config.linestyle, marker=config.marker,
                markersize=config.markersize)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if config.grid:
            ax.grid(True, alpha=0.3)

    # Hide grid cells left over when the run count is odd.
    for ax in axes_flat[n_runs:]:
        ax.set_visible(False)

    fig.tight_layout()
    return fig
def plot_comparison(
    data_list: List[TelemetryData],
    x_column: str,
    y_column: str,
    xlabel: str,
    ylabel: str,
    reference_index: int = 0,
    config: Optional[PlotConfig] = None
) -> Figure:
    """
    Create drift comparison plot (deviation from reference run).

    Each non-reference run is resampled onto the reference run's X values
    (``np.interp``, unless its X values already equal the reference's), and
    ``run_y - reference_y`` is plotted. The reference is drawn as a dashed
    zero line.

    NOTE(review): ``np.interp`` assumes each run's X values are increasing
    (presumably true for timestamps -- confirm). Outside a run's X range it
    holds that run's first/last Y value constant.

    Args:
        data_list: List of TelemetryData objects
        x_column: Attribute name for X-axis
        y_column: Attribute name for Y-axis
        xlabel: X-axis label
        ylabel: Y-axis label
        reference_index: Index of the reference run (default 0 = first run).
            Negative indices count from the end, as in Python; any other
            out-of-range value falls back to 0.
        config: Plot configuration (defaults to ``PlotConfig()``)

    Returns:
        Matplotlib Figure object
    """
    if config is None:
        config = PlotConfig()

    fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi)

    n_runs = len(data_list)
    if n_runs == 0:
        ax.text(0.5, 0.5, 'No reference data', ha='center', va='center')
        return fig

    if -n_runs <= reference_index < 0:
        # Normalise negative indices so the reference is recognised (and
        # skipped) in the deviation loop below.
        reference_index += n_runs
    elif not 0 <= reference_index < n_runs:
        reference_index = 0

    reference = data_list[reference_index]
    ref_x = getattr(reference, x_column, None)
    ref_y = getattr(reference, y_column, None)
    if ref_x is None or ref_y is None:
        ax.text(0.5, 0.5, 'No reference data', ha='center', va='center')
        return fig
    # Arrays make the subtraction element-wise even if columns are plain lists.
    ref_x = np.asarray(ref_x)
    ref_y = np.asarray(ref_y)

    # Reference drawn as the zero baseline.
    ax.axhline(y=0, color='black', linestyle='--',
               label=f'Reference: {reference.session_id} {reference.run_no} ({reference.session_name})')

    ax.set_prop_cycle(color=plt.cm.tab10.colors)

    for idx, data in enumerate(data_list):
        if idx == reference_index:
            continue
        x_data = getattr(data, x_column, None)
        y_data = getattr(data, y_column, None)
        if x_data is None or y_data is None:
            continue

        x_arr = np.asarray(x_data)
        y_arr = np.asarray(y_data)
        if np.array_equal(x_arr, ref_x):
            y_on_ref = y_arr  # already sampled at the reference X points
        elif x_arr.size == 0:
            continue  # np.interp cannot interpolate from an empty series
        else:
            y_on_ref = np.interp(ref_x, x_arr, y_arr)

        deviation = y_on_ref - ref_y
        label = f"{data.session_id} {data.run_no} ({data.session_name})"
        ax.plot(ref_x, deviation, label=label, alpha=0.8,
                linestyle=config.linestyle, marker=config.marker,
                markersize=config.markersize)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(f"Deviation in {ylabel}")
    ax.set_title(f"{config.title} - Drift Analysis")
    if config.grid:
        ax.grid(True, alpha=0.3)
    if config.legend:
        # Legend below the axes; the reference line always provides an entry.
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=2)
    fig.tight_layout(rect=[0, 0.08, 1, 1])
    return fig
def plot_multi_series(
    data_list: List[TelemetryData],
    x_column: str,
    y_columns: List[str],
    xlabel: str,
    ylabels: List[str],
    config: Optional[PlotConfig] = None
) -> Figure:
    """
    Create multi-series plot (multiple data columns, multiple runs).

    One subplot per entry in ``y_columns``, stacked vertically with a shared
    X axis; every run is drawn in every subplot where it has data. A run keeps
    the same colour in all subplots. A single figure-level legend lists the
    runs below the bottom subplot.

    Args:
        data_list: List of TelemetryData objects
        x_column: Attribute name for X-axis
        y_columns: Attribute names to plot, one subplot each
        xlabel: X-axis label (shown on the bottom subplot)
        ylabels: Y-axis labels matching ``y_columns``; if shorter, the
            remaining subplots are labelled with their column name
        config: Plot configuration (defaults to ``PlotConfig()``)

    Returns:
        Matplotlib Figure object (a single "No data" axes if ``y_columns``
        is empty)
    """
    if config is None:
        config = PlotConfig()

    n_series = len(y_columns)
    if n_series == 0:
        # plt.subplots(0, 1) raises; return an annotated placeholder instead.
        fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi)
        ax.text(0.5, 0.5, 'No data', ha='center', va='center')
        ax.set_xlabel(xlabel)
        fig.suptitle(config.title)
        return fig

    fig, axes = plt.subplots(
        n_series, 1,
        figsize=config.figsize,
        dpi=config.dpi,
        sharex=True,
        squeeze=False
    )
    axes = axes[:, 0]  # 1-D array of axes, top to bottom

    # zip() would silently drop series that have no label; fall back to
    # the column name for those.
    labels_y = list(ylabels) + list(y_columns[len(ylabels):])

    colors = plt.cm.tab10.colors
    legend_entries = {}  # run label -> first line handle, insertion-ordered

    for ax, y_col, ylabel in zip(axes, y_columns, labels_y):
        for run_idx, data in enumerate(data_list):
            x_data = getattr(data, x_column, None)
            y_data = getattr(data, y_col, None)
            if x_data is None or y_data is None:
                continue
            label = f"{data.session_id} {data.run_no} ({data.session_name})"
            # Colour keyed to the run (not plot order) so a run missing from
            # one subplot does not shift colours in that subplot.
            line, = ax.plot(x_data, y_data, label=label, alpha=0.8,
                            color=colors[run_idx % len(colors)],
                            linestyle=config.linestyle, marker=config.marker,
                            markersize=config.markersize)
            legend_entries.setdefault(label, line)
        ax.set_ylabel(ylabel)
        if config.grid:
            ax.grid(True, alpha=0.3)

    axes[-1].set_xlabel(xlabel)  # shared X: label only the bottom subplot
    fig.suptitle(config.title)

    if config.legend and legend_entries:
        # Figure-level legend at the bottom: an axes legend anchored below
        # the top subplot would sit on top of the subplot beneath it.
        fig.legend(list(legend_entries.values()), list(legend_entries.keys()),
                   loc='lower center', ncol=2)
    fig.tight_layout(rect=[0, 0.08, 1, 0.96])  # room for legend and suptitle
    return fig
def plot_xy_scatter(
    data_list: List[TelemetryData],
    x_column: str,
    y_column: str,
    xlabel: str,
    ylabel: str,
    config: Optional[PlotConfig] = None
) -> Figure:
    """
    Create XY scatter/line plot (any column vs any column).

    Useful for phase plots, correlation analysis, etc.
    Example: motor_current vs pwm, angle vs encoder_value

    Samples are always drawn with markers. If ``config.marker`` is unset, a
    small 'o' marker (size 2) is used; otherwise the configured marker and
    size are used. Consecutive samples are joined with ``config.linestyle``
    at line width 1. A run that lacks either column is skipped.

    Args:
        data_list: List of TelemetryData objects
        x_column: Attribute name for X-axis
        y_column: Attribute name for Y-axis
        xlabel: X-axis label
        ylabel: Y-axis label
        config: Plot configuration (defaults to ``PlotConfig()``)

    Returns:
        Matplotlib Figure object
    """
    if config is None:
        config = PlotConfig()

    fig, ax = plt.subplots(figsize=config.figsize, dpi=config.dpi)
    ax.set_prop_cycle(color=plt.cm.tab10.colors)

    # Marker choice is the same for every run; resolve it once.
    marker = config.marker or 'o'
    markersize = config.markersize if config.marker else 2

    for data in data_list:
        x_data = getattr(data, x_column, None)
        y_data = getattr(data, y_column, None)
        if x_data is None or y_data is None:
            continue
        label = f"{data.session_id} {data.run_no} ({data.session_name})"
        ax.plot(x_data, y_data, label=label, alpha=0.8,
                linestyle=config.linestyle, marker=marker,
                markersize=markersize, linewidth=1)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(config.title)
    if config.grid:
        ax.grid(True, alpha=0.3)

    # Skip the legend when nothing was plotted (avoids a warning and an
    # empty legend box).
    _, labels = ax.get_legend_handles_labels()
    if config.legend and labels:
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=2)
    fig.tight_layout(rect=[0, 0.08, 1, 1])
    return fig
# =============================================================================
# Export Functions
# =============================================================================
def export_png(fig: Figure, filepath: str, dpi: int = 300) -> bool:
    """
    Save ``fig`` to ``filepath`` (intended for PNG output).

    The figure is saved with ``bbox_inches='tight'`` to trim surrounding
    whitespace. No explicit format is passed, so Matplotlib chooses it from
    the file extension. Errors are reported on stdout, not raised.

    Args:
        fig: Matplotlib figure to save
        filepath: Destination path
        dpi: Output resolution in dots per inch (default 300)

    Returns:
        True if saving succeeded, False if ``savefig`` raised.
    """
    try:
        fig.savefig(filepath, dpi=dpi, bbox_inches='tight')
    except Exception as exc:
        print(f"[EXPORT ERROR] Failed to save PNG: {exc}")
        return False
    else:
        return True
def export_csv(
    data_list: List[TelemetryData],
    filepath: str,
    x_column: str,
    y_columns: List[str]
) -> bool:
    """
    Export telemetry data to CSV (only selected columns).

    The header is ``session_id, session_name, run_no, <x_column>,
    <y_columns...>``. One row is written per X sample of each run. Runs
    without the X column are skipped. A Y column that is missing, or shorter
    than X, produces empty cells. The file is always written as UTF-8, so
    non-ASCII session names do not depend on the platform locale.

    Args:
        data_list: List of TelemetryData objects
        filepath: Output file path
        x_column: X-axis column name (e.g., 'time_ms')
        y_columns: Y-axis column names, any sequence
            (e.g., ['motor_current', 'pwm'])

    Returns:
        True on success; False if writing failed (the error is printed).
    """
    import csv

    try:
        y_columns = list(y_columns)  # allow tuples etc. in the header concatenation
        with open(filepath, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['session_id', 'session_name', 'run_no', x_column] + y_columns)

            for data in data_list:
                x_data = getattr(data, x_column, None)
                if x_data is None:
                    continue

                # Resolve each Y column and its length once per run, not per row.
                columns = [getattr(data, y_col, None) for y_col in y_columns]
                lengths = [0 if col is None else len(col) for col in columns]
                meta = [data.session_id, data.session_name, data.run_no]

                for i, x_value in enumerate(x_data):
                    row = meta + [x_value]
                    # Empty cell when the column is absent (length 0) or too short.
                    row.extend(col[i] if i < n else '' for col, n in zip(columns, lengths))
                    writer.writerow(row)
        return True
    except Exception as e:
        print(f"[EXPORT ERROR] Failed to save CSV: {e}")
        return False
# =============================================================================
# Demo
# =============================================================================
if __name__ == "__main__":
print("Graph Core - Pure Plotting Functions")
print("=" * 50)
print()
print("Database-agnostic plotting library!")
print()
print("Usage:")
print(" from table_query import SQLiteAdapter, CSVAdapter")
print(" from graph_core import plot_overlay, plot_xy_scatter")
print()
print(" # Load data via adapter")
print(" adapter = SQLiteAdapter('./database/ehinge.db')")
print(" adapter.connect()")
print(" data = adapter.load_run_data('Session_A', 1)")
print()
print(" # Plot (pure function - no database!)")
print(" fig = plot_overlay([data], 'time_ms', 'pwm', 'Time', 'PWM')")
print()
print("Available functions:")
print(" • plot_overlay() - Multiple runs on same axes")
print(" • plot_subplots() - One subplot per run")
print(" • plot_comparison() - Drift analysis")
print(" • plot_multi_series() - Stacked subplots")
print(" • plot_xy_scatter() - XY scatter/line plot")
print(" • export_png() - Save as PNG")
print(" • export_csv() - Export data as CSV")