Circular-Linear Correlation
This tutorial demonstrates how to use the WaveSpace toolbox to compute circular-linear (phase-distance) correlations on simulated data, including grid setup, distance matrix calculation, and visualization of results.
Setup
Before running the example, make sure the project root is on the Python path so that the `WaveSpace` package can be imported:
import sys
import os
# Resolve the project root (the parent of the directory holding this example)
# and put it first on sys.path so the local WaveSpace package is importable.
# When run as a notebook (Jupyter/IPython) `__file__` is not defined, so fall
# back to the current working directory.
# NOTE(review): the fallback assumes the notebook is started from the example's
# own directory — confirm.
try:
    _exampleDir = os.path.dirname(__file__)
except NameError:
    _exampleDir = os.getcwd()
path = os.path.abspath(os.path.join(_exampleDir, '..'))
sys.path.insert(0, path)
print(path)
from WaveSpace.PlottingHelpers import Plotting
from WaveSpace.Utils import HelperFuns as hf
from WaveSpace.Utils import ImportHelpers
from WaveSpace.WaveAnalysis import DistanceCorrelation
from WaveSpace.SpatialArrangement import SensorLayout as sensors
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
from matplotlib import colormaps
Loading Simulated Data
# Load the simulated example data from the project's Examples/ExampleData/Output
# folder. Paths are joined component-wise (as elsewhere in this file) rather
# than concatenated with "/".
dataPath = os.path.join(path, "Examples", "ExampleData", "Output")
waveData = ImportHelpers.load_wavedata_object(os.path.join(dataPath, "complexData"))
Grid and Distance Matrix Setup
# The simulated data were generated on a regular grid, so the channel
# positions can be used as-is to build the sensor layout and distance matrix.
sensors.regularGrid(waveData)
Generalized Phase Distance Correlation
# Generalized-phase variant of the phase-distance correlation, computed on the
# "complexData" bucket. evaluationAngle and tolerance are passed straight to
# the toolbox.
# NOTE(review): their exact meaning is defined by DistanceCorrelation.
DistanceCorrelation.calculate_distance_correlation_GP(
    waveData,
    dataBucketName="complexData",
    evaluationAngle=np.pi,
    tolerance=0.2,
)
# Read the result back from the "PhaseDistanceCorrelation" bucket.
dataFrame = waveData.get_data("PhaseDistanceCorrelation")
Distance Correlation for Selected Source Points
# Use every second point on the grid diagonal as a source:
# (0, 0), (2, 2), ..., (18, 18).
pointRange = range(0, 20, 2)
sourcePoints = [(i, i) for i in pointRange]
DistanceCorrelation.calculate_distance_correlation(
    waveData,
    dataBucketName="complexData",
    sourcePoints=sourcePoints,
    pixelSpacing=1,
)
Plotting Phase-Distance Correlation Over Time
# For one trial, plot the "rho" series of every selected source point as its
# own coloured line. A small colour-grid legend shows where each source point
# sits on the grid.
phaseDistCorr = waveData.get_data("PhaseDistanceCorrelation")
shape = waveData.get_data("complexData").shape
dimord = waveData.DataBuckets["complexData"].get_dimord()
# dimord names the data axes separated by "_". The grid size is read from the
# "posx" axis and the axis immediately after it.
splitDimord = dimord.split("_")
spatialIndexStart = splitDimord.index("posx")
selectedTrial = 0

fig, ax = plt.subplots(figsize=(8, 6))
trialRows = phaseDistCorr[phaseDistCorr["trialInd"] == selectedTrial]
nPoints = len(sourcePoints)
for i, point in enumerate(sourcePoints):
    pointMask = (trialRows["sourcePointX"] == point[0]) & (trialRows["sourcePointY"] == point[1])
    rho = trialRows.loc[pointMask]
    color = Plotting.getProbeColor(i, nPoints)
    ax.plot(rho["rho"].tolist(), label=str(point), color=color)
ax.legend()

gridShape = (shape[spatialIndexStart], shape[spatialIndexStart + 1])
color_grid = Plotting.get_color_grid_from_probes(gridShape, sourcePoints)
Plotting.add_color_grid_legend(ax, color_grid, position=[0.2, 0.2, 1.5, 1.5])
plt.show()
Full Grid Correlation and Visualization
# Slow: use every point of the 20 x 20 grid (400 source points) as a source.
# The result is saved to disk so the next section can reload it instead of
# recomputing.
pointRange = (20, 20)
sourcePoints = [(x, y) for x in range(pointRange[0]) for y in range(pointRange[1])]
DistanceCorrelation.calculate_distance_correlation(
    waveData,
    dataBucketName="complexData",
    sourcePoints=sourcePoints,
    pixelSpacing=1,
)
output_path = os.path.join(path, "Examples/ExampleData/Output")
waveData.save_to_file(os.path.join(output_path, "DistanceCorrelation"))
Loading and Plotting Saved Correlation Data
# Reload the saved full-grid result. For every second trial (one "condition"),
# map the mean "rho" over a fixed sample window onto the grid of source points
# and show it as an image.
# The path is built from the project root so it matches where the file was
# saved (path/Examples/ExampleData/Output), regardless of the current working
# directory.
waveData = ImportHelpers.load_wavedata_object(
    os.path.join(path, "Examples", "ExampleData", "Output", "DistanceCorrelation"))
pointRange = (20, 20)
sourcePoints = []
for x in range(pointRange[0]):
    for y in range(pointRange[1]):
        sourcePoints.append((x, y))
phaseDistCorr = waveData.get_data("PhaseDistanceCorrelation")
# Every second trial is one condition: condition k corresponds to trial 2*k.
conditions = waveData.get_trialInfo()[::2]
shape = waveData.get_data("complexData").shape
selectedTrial = 4
# Positional sample window averaged per source point.
# NOTE(review): confirm the time span this covers for the data's sampling rate.
windowStart, windowStop = 300, 500
# One map per condition, sized from the data instead of a hard-coded (8, 20, 20).
rho = np.zeros((len(conditions),) + tuple(pointRange))
for condInd, condition in enumerate(conditions):
    # Filter to this condition's trial once, rather than once per source point.
    trialRows = phaseDistCorr[phaseDistCorr["trialInd"] == condInd * 2]
    for x, y in sourcePoints:
        phaseDistCorrOverTime = trialRows[(trialRows["sourcePointX"] == x) &
                                          (trialRows["sourcePointY"] == y)]
        # The filtered rows keep their original index labels. Plain [] slicing
        # with integers is ambiguous (positional vs label) in pandas, so use
        # .iloc to make the positional window explicit.
        rho[condInd, x, y] = np.mean(phaseDistCorrOverTime["rho"].iloc[windowStart:windowStop])
    fig, ax = plt.subplots(figsize=(8, 6))
    # NOTE(review): rho is indexed [x, y] and imshow draws the first array axis
    # vertically, so x runs along the vertical axis here. Transpose if x should
    # be horizontal.
    im = ax.imshow(rho[condInd], origin="lower")
    ax.set_title(condition)
    fig.colorbar(im, ax=ax)
    plt.show()