Source code for timeflux.helpers.mne

"""MNE helpers"""

import pandas as pd
import numpy as np
import xarray as xr
import logging
from timeflux.core.exceptions import WorkerInterrupt

try:
    import mne
except ModuleNotFoundError:
    raise SystemExit(
        "MNE is not installed. Optional dependencies can be installed with: 'pip install timeflux[opt]'."
    )

logger = logging.getLogger()


def _context_to_id(context, context_key, event_id):
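    """Map an epoch context to an MNE event id.

    If `context_key` is None, the context itself is returned unchanged;
    otherwise the label found at ``context[context_key]`` is translated through
    ``event_id`` (``None`` is returned if either lookup fails).
    """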
    if context_key is None:
        return context
    else:
        return event_id.get(context.get(context_key))


def xarray_to_mne(
    data, meta, context_key, event_id, reporting="warn", ch_types="eeg", **kwargs
):
    """Convert DataArray and meta into mne Epochs object

    Args:
        data (DataArray): Array of dimensions ('epoch', 'time', 'space')
        meta (dict): Dictionary with keys 'epochs_context', 'rate', 'epochs_onset'
        context_key (str|None): key to select the context label.
            If the context is a string, `context_key` should be set to ``None``.
        event_id (dict): Associates a context label to an event_id that should be
            an int. (eg. dict(auditory=1, visual=3))
        reporting ('warn'|'error'|None): How this function handles epochs with an
            invalid context:

            - 'error' will raise a :class:`timeflux.core.exceptions.WorkerInterrupt`
            - 'warn' will log a warning and skip the corrupted epochs
            - ``None`` will silently skip the corrupted epochs
        ch_types (list|str): Channel type(s) passed to :py:func:`mne.create_info`.
            If a string is given, it is applied to every channel.

    Returns:
        epochs (mne.Epochs): mne object with the converted data.
    """
    if isinstance(ch_types, str):
        ch_types = [ch_types] * len(data.space)

    if isinstance(data, xr.DataArray):
        pass
    elif isinstance(data, xr.Dataset):
        # extract data
        data = data.data
    else:
        raise ValueError(
            f"data should be of type DataArray or Dataset, received {type(data)} instead."
        )

    _dims = data.coords.dims
    if "target" in _dims:
        np_data = data.transpose("target", "space", "time").values
    elif "epoch" in _dims:
        np_data = data.transpose("epoch", "space", "time").values
    else:
        raise ValueError(
            f"Data should have either `target` or `epoch` in its coordinates. Found {_dims}"
        )

    # create the events array: events are essentially numpy arrays with three columns:
    # event_sample | previous_event_id | event_id
    events = np.array(
        [
            [onset.value, 0, _context_to_id(context, context_key, event_id)]
            for (context, onset) in zip(meta["epochs_context"], meta["epochs_onset"])
        ]
    )  # one row per epoch
    events_mask = np.isnan(events.astype(float))[:, 2]
    if events_mask.any():
        if reporting == "error":
            raise WorkerInterrupt(
                f"Found {events_mask.sum()} epochs with corrupted context."
            )
        else:  # reporting is either None or 'warn'
            # be cool, skip those events
            events = events[~events_mask, :]
            np_data = np_data[~events_mask, :, :]
            if reporting == "warn":
                logger.warning(
                    f"Found {events_mask.sum()} epochs with corrupted context. "
                    f"Skipping them."
                )

    # Fill the second column with previous event ids.
    events[0, 1] = events[0, 2]
    events[1:, 1] = events[0:-1, 2]

    # set the info
    rate = meta["rate"]
    info = mne.create_info(
        ch_names=list(data.space.values), sfreq=rate, ch_types=ch_types
    )

    # construct the mne object
    epochs = mne.EpochsArray(
        np_data,
        info=info,
        events=events.astype(int),
        event_id=event_id,
        tmin=data.time.values[0] / np.timedelta64(1, "s"),
        verbose=False,
        **kwargs,
    )
    return epochs
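
# Usage sketch (illustrative only; the channel names, sampling rate, onsets and
# event_id mapping below are hypothetical):
#
#     >>> rate = 250
#     >>> times = pd.to_timedelta(np.arange(rate) / rate, unit="s")
#     >>> data = xr.DataArray(
#     ...     np.random.randn(2, rate, 3),
#     ...     dims=("epoch", "time", "space"),
#     ...     coords=([0, 1], times, ["Cz", "Pz", "Oz"]),
#     ... )
#     >>> meta = {
#     ...     "rate": rate,
#     ...     "epochs_context": [{"label": "auditory"}, {"label": "visual"}],
#     ...     "epochs_onset": [
#     ...         pd.Timestamp("2020-01-01 00:00:00"),
#     ...         pd.Timestamp("2020-01-01 00:00:01"),
#     ...     ],
#     ... }
#     >>> epochs = xarray_to_mne(data, meta, "label", {"auditory": 1, "visual": 3})
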
def mne_to_xarray(epochs, context_key, event_id, output="dataarray"):
    """Convert mne Epochs object into DataArray along with meta.

    Args:
        epochs (mne.Epochs): mne object with the converted data.
        context_key (str|None): key to select the context label.
            If the context is a string, `context_key` should be set to ``None``.
        event_id (dict): Associates a context label to an event_id that should be
            an int. (eg. dict(auditory=1, visual=3))
        output (str): type of the expected output: 'dataarray' (default) or 'dataset'

    Returns:
        data (DataArray|Dataset): Array of dimensions ('epoch', 'time', 'space')
        meta (dict): Dictionary with keys 'epochs_context', 'rate', 'epochs_onset'
    """
    reversed_event_id = {value: key for (key, value) in event_id.items()}
    np_data = epochs._data
    ch_names = epochs.ch_names
    epochs_onset = [pd.Timestamp(event_sample) for event_sample in epochs.events[:, 0]]
    epochs_context = [
        {context_key: reversed_event_id[_id]} for _id in epochs.events[:, 2]
    ]
    meta = dict(
        epochs_onset=epochs_onset,
        epochs_context=epochs_context,
        rate=epochs.info["sfreq"],
    )
    n_epochs = len(epochs)
    times = pd.TimedeltaIndex(data=epochs.times, unit="s")
    data = xr.DataArray(
        np_data,
        dims=("epoch", "space", "time"),
        coords=(np.arange(n_epochs), ch_names, times),
    ).transpose("epoch", "time", "space")
    if output == "dataarray":
        return data, meta
    else:  # output == 'dataset'
        data = xr.Dataset(
            {
                "data": data,
                "target": [reversed_event_id[_id] for _id in epochs.events[:, 2]],
            }
        )
        return data, meta
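
# Usage sketch (illustrative), converting the `epochs` object from the sketch
# above back into a DataArray and its meta dict; pass ``output="dataset"`` to
# get an xr.Dataset with an additional `target` variable instead:
#
#     >>> data, meta = mne_to_xarray(epochs, "label", {"auditory": 1, "visual": 3})
#     >>> data.dims
#     ('epoch', 'time', 'space')
#     >>> sorted(meta)
#     ['epochs_context', 'epochs_onset', 'rate']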