Source code for convnwb.sorting.process

"""Processing functions related to spike sorting / combinato files."""

import numpy as np

from convnwb.sorting.io import load_spike_data_file, load_sorting_data_file, save_units
from convnwb.sorting.utils import get_sorting_kept_labels, get_group_labels, extract_clusters

###################################################################################################
###################################################################################################

def collect_all_sorting(spike_data, sort_data):
    """Collect together all the organized spike sorting information for a channel of data.

    Parameters
    ----------
    spike_data : dict
        Loaded data from the spike data file.
        Should include the keys: `channel`, `polarity`, `times`, `waveforms`.
    sort_data : dict
        Loaded sorting data from the spike sorting data file.
        Should include the keys: `channel`, `polarity`, `index`, `classes`, `groups`.

    Returns
    -------
    outputs : dict
        Each value is an array of all values for valid events in the channel of data, including:

        * `times` : spike times for each event
        * `waveforms` : spike waveform for each event
        * `classes` : class assignment for each event
        * `clusters` : cluster (group) assignment for each event

    Notes
    -----
    Kept information is for all valid spikes - all clusters considered putative single units.
    This excludes:

    - spike events detected but excluded from sorting due to being listed as artifacts
    - spike events entered into sorting, but not assigned to a group
    - spike events sorted into a group, but whose group was listed as an artifact
    """

    assert spike_data['channel'] == sort_data['channel'], "Data file channels do not match."
    assert spike_data['polarity'] == sort_data['polarity'], "Data file polarity does not match."

    # Get the set of valid class & group labels, and make a mask of valid classes
    valid_classes, valid_groups = get_sorting_kept_labels(sort_data['groups'])
    class_mask = np.isin(sort_data['classes'], valid_classes)

    # Create a vector reflecting the group assignment of each spike
    group_labels = get_group_labels(sort_data['classes'], sort_data['groups'])

    outputs = {

        # collect metadata into the output
        'channel' : spike_data['channel'],
        'polarity' : spike_data['polarity'],

        # spike data collected as the non-artifact spikes, sub-selected for valid classes
        'times' : spike_data['times'][sort_data['index']][class_mask],
        'waveforms' : spike_data['waveforms'][sort_data['index'][class_mask], :],

        # spike sorting information collected as the valid class labels
        'classes' : sort_data['classes'][class_mask],
        'clusters' : group_labels[class_mask],
    }

    return outputs
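
# A minimal usage sketch for `collect_all_sorting`, assuming combinato output files
# for a channel are available in an input folder. The channel label 'chan_01', the
# `data_folder` path, and the 'ABC' user code are hypothetical placeholders, not
# part of the module:
#
#     spike_data = load_spike_data_file('chan_01', data_folder, 'neg')
#     sort_data = load_sorting_data_file('chan_01', data_folder, 'neg', 'ABC')
#     outputs = collect_all_sorting(spike_data, sort_data)
#     print('Kept {} valid spike events'.format(len(outputs['times'])))
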
def process_combinato_data(channel, input_folder, polarity, user, units_folder,
                           continue_on_fail=False, verbose=True):
    """Helper function to run the process of going from combinato -> extracted units files.

    Parameters
    ----------
    channel : int or str
        The channel number / label of the file to load.
    input_folder : str or Path
        The folder location to load the spike data from.
    polarity : {'neg', 'pos'}
        Which polarity of detected spikes to load.
    user : str
        The 3 character user label to load.
    units_folder : str or Path
        The folder destination to save the output units files to.
    continue_on_fail : bool, optional, default: False
        Whether to continue when an error is encountered.
    verbose : bool, optional, default: True
        Whether to print out updates about the extraction.
    """

    try:

        # Load spike & sorting data
        spike_data = load_spike_data_file(channel, input_folder, polarity)
        sort_data = load_sorting_data_file(channel, input_folder, polarity, user)

        # Organize and collect extracted data together, and extract unit clusters
        clusters = collect_all_sorting(spike_data, sort_data)
        units = extract_clusters(clusters)

        # Save out extracted unit data
        save_units(units, units_folder)

        if verbose:
            print('Extracted channel {:20s} - found {:2d} clusters\t\t'.format(
                str(channel), len(units)))

    except Exception:

        if not continue_on_fail:
            raise

        if verbose:
            print('Issue extracting channel: {}'.format(channel))
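
# A minimal end-to-end sketch of using this helper, assuming combinato outputs for
# a set of channels live under `data_folder` and that `units_folder` is a writable
# output location. Both paths, the channel labels, and the 'ABC' user code below
# are hypothetical placeholders, not part of the module:

if __name__ == '__main__':

    from pathlib import Path

    data_folder = Path('/path/to/combinato/outputs')    # hypothetical input location
    units_folder = Path('/path/to/units')                # hypothetical output location

    # Process each channel, continuing past any channels that fail to extract
    for channel in ['chan_01', 'chan_02']:
        process_combinato_data(channel, data_folder, 'neg', 'ABC', units_folder,
                               continue_on_fail=True)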