"""This module contains nodes for spectral analysis with Timeflux."""
import numpy as np
import pandas as pd
import xarray as xr
from scipy.signal import welch
from scipy.fft import fftfreq, rfftfreq, fft, rfft
from timeflux.core.node import Node
class FFT(Node):
    """Compute the one-dimensional discrete Fourier Transform of each column.

    The transform is computed with the Fast Fourier Transform algorithm
    (``scipy.fft.fft`` or ``scipy.fft.rfft``) along the rows of the incoming
    chunk, i.e. independently for every column.

    Attributes:
        i (Port): default input, expects DataFrame.
        o (Port): default output, provides DataArray with dimensions
            (time, freq, space).

    Example:
        In this example, we simulate a white noise and we apply FFT:

        * ``fs`` = `10.0`
        * ``nfft`` = `5`
        * ``return_onesided`` = `False`

        self.i.data::

                                                  A         B         C
            2017-12-31 23:59:59.998745401  0.185133  0.541901  0.872946
            2018-01-01 00:00:00.104507143  0.732225  0.806561  0.658783
            2018-01-01 00:00:00.202319939  0.692277  0.849196  0.249668
            2018-01-01 00:00:00.300986584  0.489425  0.221209  0.987668
            2018-01-01 00:00:00.396560186  0.944059  0.039427  0.705575

        self.o.data::

            xarray.DataArray (time: 1, freq: 5, space: 3)
            array([[[ 3.043119+0.j      ,  2.458294+0.j      ,  3.47464 +0.j      ],
                    [-0.252884+0.082233j, -0.06265 -1.098709j,  0.29353 +0.478287j],
                    [-0.805843+0.317437j,  0.188256+0.146341j,  0.151515-0.674376j],
                    [-0.805843-0.317437j,  0.188256-0.146341j,  0.151515+0.674376j],
                    [-0.252884-0.082233j, -0.06265 +1.098709j,  0.29353 -0.478287j]]])
            Coordinates:
              * time     (time) datetime64[ns] 2018-01-01T00:00:00.396560186
              * freq     (freq) float64 0.0 2.0 4.0 -4.0 -2.0
              * space    (space) object 'A' 'B' 'C'

    Notes:
        This node should be used after a buffer.

    References:
        * `scipy.fft <https://docs.scipy.org/doc/scipy/reference/fft.html>`_
    """

    def __init__(self, fs=1.0, nfft=None, return_onesided=True):
        """
        Args:
            fs (float): Nominal sampling rate of the input data.
            nfft (int|None): Length of the Fourier transform.
                Default: length of the first chunk received.
            return_onesided (bool): If `True`, the real part of the data is
                transformed with ``rfft`` (the imaginary part, if any, is
                discarded). If `False`, ``fft`` is used and a two-sided
                spectrum is returned.
                Default: `True`.
        """
        self._fs = fs
        # Coerce up front: the frequency axis is built right below, and the
        # frequency helpers need an integer length.
        self._nfft = None if nfft is None else int(nfft)
        self._sides = "onesided" if return_onesided else "twosided"
        if self._nfft is not None:
            self._set_freqs()

    def _check_nfft(self):
        """Resolve or validate ``nfft`` against the current chunk.

        Raises:
            ValueError: if ``nfft`` is smaller than the number of rows of the
                chunk.
        """
        n_rows = self.i.data.shape[0]
        if self._nfft is None:
            # No explicit length: adopt the length of the chunk at hand and
            # build the matching frequency axis once.
            self.logger.debug("nfft := length of the chunk ")
            self._nfft = n_rows
            self._set_freqs()
        elif self._nfft < n_rows:
            raise ValueError("nfft must be greater than or equal to length of chunk.")

    def _set_freqs(self):
        """Compute the frequency axis matching ``nfft``, ``fs`` and the sides."""
        spacing = 1 / self._fs
        if self._sides == "onesided":
            self._freqs = rfftfreq(self._nfft, spacing)
        else:
            self._freqs = fftfreq(self._nfft, spacing)

    def update(self):
        """Transform the current chunk and emit it as a 3D DataArray."""
        # copy the meta
        self.o = self.i
        # When we have not received data, there is nothing to do
        if not self.i.ready():
            return
        # At this point, we are sure that we have some data to process
        self._check_nfft()
        frame = self.i.data
        samples = frame.values
        if self._sides == "twosided":
            spectrum = fft(samples, n=self._nfft, axis=0)
        else:
            # rfft expects real input; np.real operates on the raw array so
            # this does not depend on pandas exposing a ``real`` accessor.
            spectrum = rfft(np.real(samples), n=self._nfft, axis=0)
        # NOTE(review): the frequency dimension is named "freq" here, while
        # Welch emits "frequency" and Bands selects on "frequency" -- confirm
        # whether FFT output is meant to be fed to Bands.
        self.o.data = xr.DataArray(
            np.stack([spectrum], 0),
            # The single time coordinate is the last timestamp of the chunk.
            coords=[[frame.index[-1]], self._freqs, frame.columns],
            dims=["time", "freq", "space"],
        )
class Welch(Node):
    """Estimate power spectral density using Welch's method.

    Attributes:
        i (Port): default input, expects DataFrame.
        o (Port): default output, provides DataArray with dimensions
            (time, frequency, space).

    Example:
        In this example, we simulate data with noisy sinus on three sensors
        (columns `a`, `b`, `c`):

        * ``fs`` = `100.0`
        * ``nfft`` = `24`

        node.i.data::

                                            a         b         c
            1970-01-01 00:00:00.000 -0.233920 -0.343296  0.157988
            1970-01-01 00:00:00.010  0.460353  0.777296  0.957201
            1970-01-01 00:00:00.020  0.768459  1.234923  1.942190
            1970-01-01 00:00:00.030  1.255393  1.782445  2.326175
            ...                           ...       ...       ...
            1970-01-01 00:00:01.190  1.185759  2.603828  3.315607

        node.o.data::

            <xarray.DataArray (time: 1, frequency: 13, space: 3)>
            array([[[2.823924e-02, 1.087382e-01, 1.153163e-01],
                    [1.703466e-01, 6.048703e-01, 6.310628e-01],
                    ...            ...           ...
                    [9.989429e-04, 8.519226e-04, 7.769918e-04],
                    [1.239551e-03, 7.412518e-04, 9.863335e-04],
                    [5.382880e-04, 4.999334e-04, 4.702757e-04]]])
            Coordinates:
              * time       (time) datetime64[ns] 1970-01-01T00:00:01.190000
              * frequency  (frequency) float64 0.0 4.167 8.333 12.5 16.67 ... 37.5 41.67 45.83 50.0
              * space      (space) object 'a' 'b' 'c'

    Notes:
        This node should be used after a Window with the appropriate length,
        with regard to the parameters `noverlap`, `nperseg` and `nfft`.
        It should be noted that a pipeline such as {LargeWindow-Welch} is in
        fact equivalent to a pipeline {SmallWindow-FFT-LargeWindow-Average}
        with SmallWindow's parameters `length` and `step` respectively
        equivalent to `nperseg` and `step` and with FFT node with same kwargs.
    """

    def __init__(self, rate=None, closed="right", **kwargs):
        """
        Args:
            rate (float|None): Nominal sampling rate of the input data.
                If `None`, the rate will be taken from the input meta.
            closed (str): Which timestamp of the chunk labels the output:
                `left` (first), `center` (middle) or `right` (last).
                Any other value behaves like `right`.
            kwargs: Keyword arguments to pass to scipy.signal.welch function.
                You can specify: window, nperseg, noverlap, nfft, detrend,
                return_onesided and scaling.
        """
        self._rate = rate
        self._closed = closed
        self._kwargs = kwargs
        self._set_default()

    def _set_default(self):
        """Fill in `nperseg`, `nfft` and `noverlap` when not given.

        The defaults are made explicit so that they can later be checked
        against the length of the input data.
        """
        if "nperseg" not in self._kwargs:
            self._kwargs["nperseg"] = 256
            self.logger.debug("nperseg := 256")
        if "nfft" not in self._kwargs:
            self._kwargs["nfft"] = self._kwargs["nperseg"]
            self.logger.debug("nfft := nperseg := %s", self._kwargs["nperseg"])
        if "noverlap" not in self._kwargs:
            self._kwargs["noverlap"] = self._kwargs["nperseg"] // 2
            self.logger.debug(
                "noverlap := nperseg/2 := %s", self._kwargs["noverlap"]
            )

    def _check_nfft(self):
        """Validate the segment parameters against the current chunk.

        Raises:
            ValueError: if any of `nfft`, `nperseg` or `noverlap` exceeds the
                number of rows of the chunk.
        """
        names = ("nfft", "nperseg", "noverlap")
        n_rows = len(self.i.data)
        if any(self._kwargs[name] > n_rows for name in names):
            raise ValueError(
                "nfft, noverlap and nperseg must be less than or equal to length of chunk."
            )
        # Values may have been supplied as floats; welch needs integers.
        for name in names:
            self._kwargs[name] = int(self._kwargs[name])

    def update(self):
        """Estimate the PSD of the current chunk and emit it as a DataArray.

        Raises:
            ValueError: if no rate was given and none is found in the input
                meta, or if the segment parameters do not fit the chunk.
        """
        # copy the meta
        self.o = self.i
        # When we have not received data, there is nothing to do
        if not self.i.ready():
            return
        # Check rate
        if self._rate:
            rate = self._rate
        elif "rate" in self.i.meta:
            rate = self.i.meta["rate"]
        else:
            raise ValueError(
                "The rate was neither explicitely defined nor found in the stream meta."
            )
        # At this point, we are sure that we have some data to process
        self._check_nfft()
        chunk = self.i.data
        f, Pxx = welch(x=chunk, fs=rate, **self._kwargs, axis=0)
        if self._closed == "left":
            # Left-closed: label the estimate with the first timestamp.
            time = chunk.index[0]
        elif self._closed == "center":
            # Middle row; for an even length this is the lower of the two
            # central rows (same as ceil(len / 2) - 1).
            time = chunk.index[(len(chunk) + 1) // 2 - 1]
        else:  # right
            time = chunk.index[-1]
        # f is the frequency axis and Pxx the average power, one column per
        # input column; stack it to fit ('time' x 'frequency' x 'space').
        self.o.data = xr.DataArray(
            np.stack([Pxx], 0),
            coords=[[time], f, chunk.columns],
            dims=["time", "frequency", "space"],
        )
class Bands(Node):
    """Sum the DataArray values over the frequency dimension, per band.

    For each band, this node selects the values whose `frequency` coordinate
    falls in the band range, sums them along that dimension and converts the
    result into a flat DataFrame (rows: time, columns: space).
    This node will output as many ports as given bands, with their
    respective name as suffix.

    Attributes:
        i (Port): default input, expects DataArray with 3 dimensions
            (time, frequency, space).
        o_* (Port): Dynamic outputs, provide DataFrame.
    """

    def __init__(self, bands=None, relative=False):
        """
        Args:
            bands (dict): Define the band to extract given its name and its
                range. An output port will be created with the given names as
                suffix. Default: delta [1, 4], theta [4, 8], alpha [8, 12]
                and beta [12, 30].
            relative (bool): If `True`, each band power is divided by the
                total power summed over all frequencies. Default: `False`.
        """
        bands = bands or {
            "delta": [1, 4],
            "theta": [4, 8],
            "alpha": [8, 12],
            "beta": [12, 30],
        }
        self._relative = relative
        self._bands = [
            {
                "port": getattr(self, "o_" + band_name),
                "slice": slice(band_range[0], band_range[1]),
                "meta": {"bands": {"range": band_range, "relative": relative}},
            }
            for band_name, band_range in bands.items()
        ]

    def update(self):
        """Compute the power of every band and publish it on its own port."""
        # When we have not received data, there is nothing to do
        if not self.i.ready():
            return
        # At this point, we are sure that we have some data to process
        spectrum = self.i.data
        tot_power = None
        if self._relative:
            # The total power does not depend on the band: compute it once.
            tot_power = spectrum.sum("frequency").values
            # Avoid dividing by zero where there is no power at all.
            tot_power[tot_power == 0.0] = 1
        for band in self._bands:
            # 1. select on the frequency coordinate, 2. sum along that axis
            band_power = (
                spectrum.loc[{"frequency": band["slice"]}].sum("frequency").values
            )
            if tot_power is not None:
                # Out-of-place division so that integer-typed power does not
                # fail on the implicit cast back to int.
                band_power = band_power / tot_power
            band["port"].data = pd.DataFrame(
                columns=spectrum.space.values,
                index=spectrum.time.values,
                data=band_power,
            )
            band["port"].meta = {**(self.i.meta or {}), **band["meta"]}