"""Epoching nodes"""
import json

import numpy as np
import pandas as pd
import xarray as xr

from timeflux.core.node import Node
from timeflux.core.exceptions import WorkerInterrupt
from timeflux.helpers.port import match_events


class Samples(Node):
"""Fixed-size epoching.
This node produces equal-length epochs from the default input stream. These epochs are triggered from the `events` stream.
Each epoch contains contextual metadata, making this node ideal in front of the `ml` node to train a model.
Non-monotonic data, late data, late events, jittered data and jumbled events are all handled reasonably well.
Multiple epochs are automatically assigned to dynamic outputs ports. For convenience, the first epoch is bound to the default output, so you can avoid enumerating all output ports if you expects only one epoch.
    Attributes:
        i (Port): Default data input, expects DataFrame.
        i_events (Port): Event input, expects DataFrame.
        o (Port): Default output, provides DataFrame and meta.
        o_* (Port): Dynamic outputs, provide DataFrame and meta.

    Args:
        trigger (string): The marker name.
        length (float): The length of the epoch, in seconds (default: 0.6).
        rate (float): The rate of the input stream. If None (the default), it
            will be taken from the input metadata.
        buffer (float): The length of the buffer, in seconds (default: 5).
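
    Example:
        A minimal sketch of driving this node manually, outside a Timeflux
        graph; the data and the fixed rate are illustrative assumptions::

            import numpy as np
            import pandas as pd

            rate = 250  # assumed sampling rate, in Hz
            times = pd.date_range("2024-01-01", periods=rate, freq="4ms")
            signal = pd.DataFrame(np.random.randn(rate, 2), index=times)
            events = pd.DataFrame({"label": ["start"], "data": [None]}, index=[times[10]])

            node = Samples(trigger="start", length=0.2, rate=rate)
            node.i.data = signal
            node.i_events.data = events
            node.update()
            if node.o.ready():
                print(node.o.data.shape)  # (50, 2): 0.2 s at 250 Hz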
"""

    def __init__(self, trigger, length=0.6, rate=None, buffer=5):
self._trigger = trigger
self._duration_epoch = length
self._duration_buffer = buffer
self._rate = rate
self._length_epoch = None
self._length_buffer = None
self._buffer = None
self._epochs = []

    def update(self):
if self.i.ready():
# We need a rate, either as an argument or from the input meta
if not self._rate:
if not "rate" in self.i.meta:
self.logger.error("Rate is not specified")
raise WorkerInterrupt()
self._rate = self.i.meta["rate"]
if not self._length_buffer:
self._length_buffer = round(self._duration_buffer * self._rate)
if not self._length_epoch:
self._length_epoch = round(self._duration_epoch * self._rate)
# Append to main buffer
if self._buffer is None:
self._buffer = self.i.data
else:
self._buffer = pd.concat([self._buffer, self.i.data])
# Detect onsets
matches = match_events(self.i_events, self._trigger)
if matches is not None:
for index, row in matches.iterrows():
# Start a new epoch
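                # The event payload may be a JSON string (parsed into a dict),
                # any other string (kept as-is), or missing (TypeError on None,
                # yielding an empty context).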
try:
context = json.loads(row["data"])
except json.JSONDecodeError:
context = row["data"]
except TypeError:
context = {}
self._epochs.append(
{
"data": None,
"meta": {"onset": index, "context": context},
}
)
# Update epochs
if self._epochs and self.i.ready():
indices = []
for index, epoch in enumerate(self._epochs):
if epoch["data"] is None:
                    # Discard if the event is outdated
                    if epoch["meta"]["onset"] < self._buffer.index[0]:
                        self.logger.warning("Outdated event")
                        indices.append(index)
                        continue
# Find the first sample and initialize the epoch
mask = self._buffer.index >= epoch["meta"]["onset"]
data = self._buffer[mask][: self._length_epoch]
if len(data) > 0:
epoch["data"] = data
else:
# Append
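                    # Take only samples newer than those already collected,
                    # capped at the number still missing.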
mask = self._buffer.index > epoch["data"].index[-1]
remaining = self._length_epoch - len(epoch["data"])
epoch["data"] = pd.concat(
[epoch["data"], self._buffer[mask][:remaining]]
)
# Send if the epoch is complete
if (
epoch["data"] is not None
and len(epoch["data"]) == self._length_epoch
):
o = getattr(self, "o_" + str(len(indices)))
o.data = epoch["data"]
o.meta = {"rate": self._rate, "epoch": epoch["meta"]}
indices.append(index)
if len(indices) > 0:
# Remove complete epochs
for index in sorted(set(indices), reverse=True):
del self._epochs[index]
self.o = self.o_0 # Bind default output to the first epoch
# Trim main buffer
if self._buffer is not None:
if len(self._buffer) > self._length_buffer:
low = len(self._buffer) - self._length_buffer
self._buffer = self._buffer[low:]


class Epoch(Node):
"""Event-triggered epoching.
This node continuously buffers a small amount of data (of a duration of ``before`` seconds) from the default input stream.
When it detects a marker matching the ``event_trigger`` in the ``label`` column of the event input stream, it starts accumulating data for ``after`` seconds.
It then sends the epoched data to an output stream, and sets the metadata to a dictionary containing the triggering marker and optional event data.
Multiple, overlapping epochs are authorized. Each concurrent epoch is assigned its own `Port`. For convenience, the first epoch is bound to the default output, so you can avoid enumerating all output ports if you expects only one epoch.
Attributes:
i (Port): Default data input, expects DataFrame.
i_events (Port): Event input, expects DataFrame.
o (Port): Default output, provides DataFrame and meta.
o_* (Port): Dynamic outputs, provide DataFrame and meta.

    Args:
        event_trigger (string): The marker name.
        before (float): Length before onset, in seconds (default: 0.2).
        after (float): Length after onset, in seconds (default: 0.6).

    Example:
        .. literalinclude:: /../examples/epoch.yaml
           :language: yaml
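
        A minimal sketch of driving this node manually, outside a Timeflux
        graph; the data and event values are illustrative assumptions::

            import numpy as np
            import pandas as pd

            rate = 100  # assumed sampling rate, in Hz
            times = pd.date_range("2024-01-01", periods=rate, freq="10ms")
            signal = pd.DataFrame(np.random.randn(rate, 1), index=times)
            events = pd.DataFrame({"label": ["stim"], "data": [None]}, index=[times[30]])

            node = Epoch(event_trigger="stim", before=0.1, after=0.3)
            node.i.data = signal
            node.i_events.data = events
            node.update()
            if node.o.ready():
                print(node.o.meta["epoch"]["onset"])  # the event timestamp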
"""

    def __init__(self, event_trigger, before=0.2, after=0.6):
self._event_trigger = event_trigger
self._before = pd.Timedelta(seconds=before)
self._after = pd.Timedelta(seconds=after)
self._buffer = None
self._epochs = []

    def update(self):
# Append to main buffer
if self.i.data is not None:
if not self.i.data.empty:
if self._buffer is None:
self._buffer = self.i.data
else:
self._buffer = pd.concat([self._buffer, self.i.data])
# Detect onset
matches = match_events(self.i_events, self._event_trigger)
if matches is not None:
for index, row in matches.iterrows():
# Start a new epoch
low = index - self._before
high = index + self._after
if self._buffer is not None:
if not self._buffer.index.is_monotonic_increasing:
self.logger.warning("Index must be monotonic. Skipping epoch.")
return
try:
context = json.loads(row["data"])
except json.JSONDecodeError:
context = row["data"]
except TypeError:
context = {}
self._epochs.append(
{
"data": self._buffer[low:high],
"meta": {
"onset": index,
"context": context,
"before": self._before.total_seconds(),
"after": self._after.total_seconds(),
},
}
)
# Trim main buffer
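        # Keep only the last `before` seconds of signal: that is the most any
        # future epoch can need prior to its onset.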
if self._buffer is not None:
low = self._buffer.index[-1] - self._before
self._buffer = self._buffer[low:]
# Update epochs
if self._epochs and self.i.ready():
complete = 0
for epoch in self._epochs:
high = epoch["meta"]["onset"] + self._after
last = self.i.data.index[-1]
if epoch["data"].empty:
low = epoch["meta"]["onset"] - self._before
mask = (self.i.data.index >= low) & (self.i.data.index <= high)
else:
low = epoch["data"].index[-1]
mask = (self.i.data.index > low) & (self.i.data.index <= high)
# Append
epoch["data"] = pd.concat([epoch["data"], self.i.data[mask]])
# Send if we have enough data
if last >= high:
o = getattr(self, "o_" + str(complete))
o.data = epoch["data"]
o.meta = {"epoch": epoch["meta"]}
complete += 1
if complete > 0:
del self._epochs[:complete] # Unqueue
self.o = self.o_0 # Bind default output to the first epoch


class Trim(Node):
"""Trim data so epochs are of equal length.
Because real-time data is often jittered, the `Epoch` node is not always able to
provide dataframes of equal dimensions. This can be problematic if the data is
further processed by the `Pipeline` node, for example. This simple node takes care
of trimming the extra samples. It should be placed just after an `Epoch` node.
Attributes:
i_* (Port): Epoched data input, expects DataFrame.
o_* (Port): Trimmed epochs, provides DataFrame and meta.
Args:
samples (int): The maximum number of samples per epoch.
If `0`, the size of the first epoch is used.
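
    Example:
        A minimal sketch with two synthetic epochs of unequal length; in a
        real graph, the input ports would be fed by an upstream `Epoch` node::

            import numpy as np
            import pandas as pd

            node = Trim()
            node.i_0.data = pd.DataFrame(np.random.randn(50, 2))  # sets the length
            node.i_1.data = pd.DataFrame(np.random.randn(52, 2))  # trimmed to 50
            node.update()
            print(len(node.o_0.data), len(node.o_1.data))  # 50 50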
"""

    def __init__(self, samples=0):
self.samples = samples

    def update(self):
ports = []
for _, _, port in self.iterate("i_*"):
if port.ready():
if self.samples == 0:
self.samples = len(port.data)
                if len(port.data) < self.samples:
                    self.logger.warning(
                        f"Epoch rejected: not enough samples ({len(port.data)}<{self.samples})"
                    )
else:
port.data = port.data.head(self.samples)
ports.append(port)
for i, port in enumerate(ports):
o = getattr(self, f"o_{i}")
o.data = port.data
o.meta = port.meta


class ToXArray(Node):
"""Convert multiple epochs to DataArray
This node iterates over input ports with valid epochs, concatenates them on the
first axis, and creates a XArray with dimensions ('epoch', 'time', 'space') where
epoch corresponds to th input ports, time to the ports data index and space to the
ports data columns.
A port is considered to be valid if it has meta with key 'epoch' and data with
expected number of samples.
If some epoch have an invalid length (which happens when the data has jitter), the
node either raises a warning, an error or pass.
Attributes:
i_* (Port): Dynamic inputs, expects DataFrame and meta.
o (Port): Default output, provides DataArray and meta.
Args:
reporting (string| None): How this function handles epochs with
invalid length: `warn` will issue a warning with :py:func:`warnings.warn`,
`error` will raise an exception, `None` will ignore it.
output (`DataArray`|`Dataset`): Type of output to return
context_key (string|None): If output type is `Dataset`, key to define the
target of the event. If `None`, the whole context is considered.
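
    Example:
        A minimal sketch with two synthetic epochs; the data and meta values
        are illustrative assumptions::

            import numpy as np
            import pandas as pd

            node = ToXArray()
            for i in range(2):
                port = getattr(node, f"i_{i}")
                port.data = pd.DataFrame(np.random.randn(30, 3))
                port.meta = {"epoch": {"onset": None, "context": {"target": i},
                                       "before": 0.1, "after": 0.2}}
            node.update()
            print(node.o.data.dims)  # ('epoch', 'time', 'space')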
"""

    def __init__(self, reporting="warn", output="DataArray", context_key=None):
self._reporting = reporting
self._output = output
self._context_key = context_key
self._columns = self._before = self._after = None
self._ready = False

    def update(self):
if not self._ready:
ports_ready = [port for _, _, port in self.iterate("i*") if port.ready()]
if len(ports_ready) < 1:
return
# initialize attributes on first ready port
port = ports_ready[0]
if port.ready():
self._columns = port.data.columns
self._before = port.meta["epoch"]["before"]
self._after = port.meta["epoch"]["after"]
self._num_times = len(port.data)
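            # Rebuild a uniform time axis spanning [-before, +after]; the
            # nominal rate is inferred from the spacing of this grid.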
self._times = pd.TimedeltaIndex(
data=np.linspace(-self._before, self._after, self._num_times),
unit="s",
)
self._rate = 1 / (self._times[1] - self._times[0]).total_seconds()
self._ready = True
ports_ready = [
port for _, _, port in self.iterate(name="i*") if self._valid_port(port)
]
if not ports_ready:
return
list_onset = [port.meta["epoch"].get("onset") for port in ports_ready]
list_context = [port.meta["epoch"].get("context") for port in ports_ready]
list_epochs = [port.data for port in ports_ready]
data = np.stack([epoch.values for epoch in list_epochs], axis=0)
meta = {
"epochs_context": list_context,
"epochs_onset": list_onset,
"rate": self._rate,
}
if self._output == "DataArray":
if self._context_key is not None:
data_array = xr.DataArray(
data,
dims=("target", "time", "space"),
coords=(
[self._extract_target(context) for context in list_context],
self._times,
self._columns,
),
)
else:
data_array = xr.DataArray(
data,
dims=("epoch", "time", "space"),
coords=(np.arange(data.shape[0]), self._times, self._columns),
)
self.o.data = data_array
self.o.meta = meta
else: # Dataset
data_array = xr.DataArray(
data,
dims=("epoch", "time", "space"),
coords=(np.arange(data.shape[0]), self._times, self._columns),
)
self.o.data = xr.Dataset(
{
"data": data_array,
"target": [
self._extract_target(context) for context in list_context
],
}
)
self.o.meta = meta

    def _extract_target(self, context):
if self._context_key is None:
return context
else:
if isinstance(context, str):
context = json.loads(context)
return context.get(self._context_key)

    def _valid_port(self, port):
"""Checks that the port has valid meta and data."""
if port.data is None or port.data.empty:
return False
if "epoch" not in port.meta:
return False
if port.data.shape[0] != self._num_times:
if self._reporting == "error":
raise WorkerInterrupt(
f"Received an epoch with {port.data.shape[0]} "
f"samples instead of {self._num_times}."
)
elif self._reporting == "warn":
self.logger.warning(
f"Received an epoch with {port.data.shape[0]} "
f"samples instead of {self._num_times}. "
f"Skipping."
)
return False
            else:  # reporting is None: skip silently
                return False
return True