Optical Flow Analysis¶
This tutorial demonstrates how to use the WaveSpace toolbox to perform optical flow analysis on simulated wave data, identify motifs, and visualize the results.
Setup¶
Before running the example, ensure the project root is on the Python path:
import sys
import os
import time
path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, path )
print(path)
from WaveSpace.PlottingHelpers import Plotting
from WaveSpace.Utils import HelperFuns as hf
from WaveSpace.Utils import ImportHelpers
from WaveSpace.WaveAnalysis import OpticalFlow
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
from matplotlib import colormaps
Loading Simulated Data¶
# Load some simulated data
dataPath = os.path.join(path, "Examples/ExampleData/Output")
waveData = ImportHelpers.load_wavedata_object(dataPath + "/complexData")
Optical Flow Computation¶
tStart = time.time()
print("OpticalFlow started")
OpticalFlow.create_uv(waveData,
applyGaussianBlur=False,
type = "angle",
Sigma=1,
alpha = 0.1,
maxIter = 200,
dataBucketName="complexData",
is_phase = False)
print('optical flow took: ', time.time()-tStart)
Visualization¶
trialToPlot = 4
waveData.set_active_dataBucket('UV')
ani = Plotting.plot_optical_flow(waveData,
UVBucketName = 'UV',
PlottingDataBucketName = 'complexData',
dataInds = (0, trialToPlot, slice(None), slice(None), slice(None)),
plotangle=True,
normVectorLength = True)
output_path = os.path.join(path, "Examples/ExampleData/Output/")
ani.save( output_path + 'OpticalFlowAfterFilter_Hilbert.gif')
Motif Detection¶
foi = 10
cycleLength = waveData.get_sample_rate()/ foi
freqInd = 0
motifs = hf.find_wave_motifs(waveData,
dataBucketName="UV",
threshold = 0.8,
nTimepointsEdge=cycleLength,
mergeThreshold = 0.8,
minFrames=cycleLength,
pixelThreshold = 0.6,
magnitudeThreshold=.1,
dataInds = (freqInd, slice(None), slice(None), slice(None), slice(None)),
Mask = False)
Motif Mapping and Plotting¶
conds = waveData.get_trialInfo()
uniques = np.unique(conds)
trial_dict = {}
for trial_idx, condition in enumerate(conds):
if condition not in trial_dict:
trial_dict[condition] = []
trial_dict[condition].append(trial_idx)
motifMap = np.full((len(conds),len(waveData.get_time())), -1)
for ind, motif in enumerate(motifs):
trial_frames_list = motif['trial_frames']
for trial_frame in trial_frames_list:
(trial, (start_timepoint, end_timepoint)) = trial_frame
motifMap[trial, start_timepoint:end_timepoint] = ind
cmap = mcolors.ListedColormap(['grey', "#8F43D1", "#c50069",'#d67258', '#416ae4', '#378b8c', "#0f3200" ,'#a05195', "#4e2f13", "#3900AB","#b3ff00", "#ff0015", "#0d15ff"])
bounds = [-1, 0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12]
norm = mcolors.BoundaryNorm(bounds, cmap.N)
im = plt.pcolormesh(motifMap, cmap=cmap, norm=norm)
plt.colorbar()
# One column per condition
fig, axs = plt.subplots(1, len(motifs[0:6]), figsize=(12, 6), gridspec_kw={'wspace': 0.3})
for motifInd, motif in enumerate(motifs[0:6]):
# Quiver plot
axs[motifInd].quiver(-np.real(motif['average']), -np.imag(motif['average']), color='black')
axs[motifInd].set_facecolor('white')
axs[motifInd].set_aspect('equal')
for spine in axs[motifInd].spines.values():
spine.set_edgecolor(cmap(motifInd+1))
spine.set_linewidth(2)