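Optical Flow Analysis
=====================

This tutorial demonstrates how to use the **WaveSpace** toolbox to compute optical flow on simulated wave data, detect recurring wave motifs in the resulting flow field, and visualize the results.

.. contents:: Table of Contents

Setup
-----

Before running the example, make sure the project root is on the Python path. The snippet below assumes the script sits one directory below the project root (for example in ``Examples/``) and prepends that root to ``sys.path``:

.. code-block:: python

    import os
    import sys
    import time

    # Project root: the parent of the directory containing this script.
    # Note: __file__ is undefined in interactive sessions such as Jupyter.
    path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    sys.path.insert(0, path)
    print(path)

    from WaveSpace.PlottingHelpers import Plotting
    from WaveSpace.Utils import HelperFuns as hf
    from WaveSpace.Utils import ImportHelpers
    from WaveSpace.WaveAnalysis import OpticalFlow

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors
    import matplotlib.gridspec as gridspec
    from matplotlib import colormaps

Loading Simulated Data
----------------------

Load the simulated ``complexData`` object stored in ``Examples/ExampleData/Output``:

.. code-block:: python

    # Load the simulated data
    dataPath = os.path.join(path, "Examples", "ExampleData", "Output")
    waveData = ImportHelpers.load_wavedata_object(os.path.join(dataPath, "complexData"))

Optical Flow Computation
------------------------

Compute the optical flow of the ``complexData`` bucket and time the computation. The resulting vector field is stored in the ``UV`` data bucket, which the following steps read from:

.. code-block:: python

    tStart = time.time()
    print("OpticalFlow started")
    OpticalFlow.create_uv(waveData,
                          applyGaussianBlur=False,
                          type="angle",
                          Sigma=1,
                          alpha=0.1,
                          maxIter=200,
                          dataBucketName="complexData",
                          is_phase=False)
    print(f"Optical flow took {time.time() - tStart:.1f} s")

Visualization
-------------

Animate the flow field of a single trial on top of the ``complexData`` signal and save the animation as a GIF:

.. code-block:: python

    trialToPlot = 4
    waveData.set_active_dataBucket('UV')
    ani = Plotting.plot_optical_flow(waveData,
                                     UVBucketName='UV',
                                     PlottingDataBucketName='complexData',
                                     dataInds=(0, trialToPlot, slice(None), slice(None), slice(None)),
                                     plotangle=True,
                                     normVectorLength=True)
    output_path = os.path.join(path, "Examples", "ExampleData", "Output")
    ani.save(os.path.join(output_path, 'OpticalFlowAfterFilter_Hilbert.gif'))

Motif Detection
---------------

Detect recurring wave motifs in the ``UV`` bucket for the first frequency index. Both ``nTimepointsEdge`` and ``minFrames`` are set to ``cycleLength``, the number of samples in one cycle of the frequency of interest (``foi``, in Hz):

.. code-block:: python

    foi = 10  # frequency of interest in Hz
    # Samples per cycle at foi; rounded to an integer because it is used as a frame count
    cycleLength = int(round(waveData.get_sample_rate() / foi))
    freqInd = 0
    motifs = hf.find_wave_motifs(waveData,
                                 dataBucketName="UV",
                                 threshold=0.8,
                                 nTimepointsEdge=cycleLength,
                                 mergeThreshold=0.8,
                                 minFrames=cycleLength,
                                 pixelThreshold=0.6,
                                 magnitudeThreshold=0.1,
                                 dataInds=(freqInd, slice(None), slice(None), slice(None), slice(None)),
                                 Mask=False)

Motif Mapping and Plotting
--------------------------

Group the trials by condition, label every sample of every trial with the index of the motif active at that time (``-1`` where no motif was detected), and plot the resulting map. Then plot the average flow field of the first six motifs, one panel per motif, with each panel frame colored to match the map:

.. code-block:: python

    conds = waveData.get_trialInfo()
    uniques = np.unique(conds)  # unique condition labels

    # Trial indices for each condition label
    trial_dict = {}
    for trial_idx, condition in enumerate(conds):
        trial_dict.setdefault(condition, []).append(trial_idx)

    # Trial-by-time map of motif indices; -1 marks samples without a motif
    motifMap = np.full((len(conds), len(waveData.get_time())), -1)
    for ind, motif in enumerate(motifs):
        for trial, (start_timepoint, end_timepoint) in motif['trial_frames']:
            motifMap[trial, start_timepoint:end_timepoint] = ind

    # Grey for "no motif", then one color per motif index (first 12 motifs)
    cmap = mcolors.ListedColormap(['grey', "#8F43D1", "#c50069", '#d67258', '#416ae4',
                                   '#378b8c', "#0f3200", '#a05195', "#4e2f13", "#3900AB",
                                   "#b3ff00", "#ff0015", "#0d15ff"])
    # Half-integer edges give one bin per color: -1 -> color 0 (grey), motif k -> color k + 1
    bounds = np.arange(-1.5, cmap.N - 0.5, 1)
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    plt.figure()  # new figure so the map is not drawn into an earlier figure's axes
    im = plt.pcolormesh(motifMap, cmap=cmap, norm=norm)
    plt.colorbar(im, label='Motif index')
    plt.xlabel('Time (samples)')
    plt.ylabel('Trial')

    # One column per motif; squeeze=False keeps axs 2-D even when only one motif is plotted
    nMotifsToPlot = min(6, len(motifs))
    fig, axs = plt.subplots(1, nMotifsToPlot, figsize=(12, 6),
                            gridspec_kw={'wspace': 0.3}, squeeze=False)
    for motifInd, motif in enumerate(motifs[:nMotifsToPlot]):
        ax = axs[0, motifInd]
        # Quiver plot of the motif's average flow (real part as u, imaginary part as v)
        ax.quiver(-np.real(motif['average']), -np.imag(motif['average']), color='black')
        ax.set_facecolor('white')
        ax.set_aspect('equal')
        # Frame color matches this motif's color in the motif map
        for spine in ax.spines.values():
            spine.set_edgecolor(cmap(motifInd + 1))
            spine.set_linewidth(2)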
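
The mapping logic can be checked without loading any data. The following is a minimal NumPy-only sketch on toy input: the helper names ``build_motif_map``, ``trials_by_condition`` and ``motif_bin_edges`` are hypothetical, and the motif structure (a ``'trial_frames'`` list of ``(trial, (start, end))`` tuples with an exclusive ``end``) is assumed from how the tutorial unpacks and slices it, not taken from the WaveSpace documentation. ``np.digitize`` stands in for matplotlib's ``BoundaryNorm``, which bins values the same half-open way, ``[b[i], b[i+1])``.

.. code-block:: python

    import numpy as np


    def build_motif_map(motifs, n_trials, n_times):
        """Label each (trial, sample) with the index of the motif covering it.

        Assumes each motif dict has a 'trial_frames' list of
        (trial, (start, end)) tuples with an exclusive end index.
        Uncovered samples stay -1; later motifs overwrite earlier ones.
        """
        motif_map = np.full((n_trials, n_times), -1, dtype=int)
        for ind, motif in enumerate(motifs):
            for trial, (start, end) in motif["trial_frames"]:
                motif_map[trial, start:end] = ind
        return motif_map


    def trials_by_condition(conds):
        """Map each condition label to the trial indices that carry it."""
        groups = {}
        for trial_idx, condition in enumerate(conds):
            groups.setdefault(condition, []).append(trial_idx)
        return groups


    def motif_bin_edges(n_colors):
        """Half-integer bin edges: -1 falls in bin 0, motif k in bin k + 1."""
        return np.arange(-1.5, n_colors - 0.5, 1.0)


    # Hypothetical toy input: two motifs spread over three 10-sample trials
    toy_motifs = [
        {"trial_frames": [(0, (2, 5)), (2, (0, 3))]},
        {"trial_frames": [(1, (4, 9))]},
    ]
    toy_conds = ["A", "B", "A"]

    motif_map = build_motif_map(toy_motifs, n_trials=3, n_times=10)
    edges = motif_bin_edges(13)
    # np.digitize returns i with edges[i-1] <= x < edges[i]; subtract 1 for the bin index
    color_bins = np.digitize(motif_map, edges) - 1

    print(motif_map)
    print(trials_by_condition(toy_conds))
    print(color_bins)

Because the edges sit at half-integers, ``color_bins`` equals ``motif_map + 1``: unlabeled samples land in the grey bin 0 and motif ``k`` in bin ``k + 1``, the same index the quiver panels pass to ``cmap`` when coloring their frames.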