References
=== Analysis functions ===
Helper classes and functions to perform analysis on fitted models
PklHandler
Helper class to handle metadata and fit data from pkl file
Source code in pytau/changepoint_analysis.py
class PklHandler():
"""Helper class to handle metadata and fit data from pkl file
"""
def __init__(self, file_path):
"""Initialize PklHandler class
Args:
file_path (str): Path to pkl file
"""
self.dir_name = os.path.dirname(file_path)
file_name = os.path.basename(file_path)
self.file_name_base = file_name.split('.')[0]
self.pkl_file_path = \
os.path.join(self.dir_name, self.file_name_base + ".pkl")
with open(self.pkl_file_path, 'rb') as this_file:
self.data = pkl.load(this_file)
model_keys = ['model', 'approx', 'lambda', 'tau', 'data']
key_savenames = ['_model_structure', '_fit_model',
'lambda_array', 'tau_array', 'processed_spikes']
data_map = dict(zip(model_keys, key_savenames))
for key, var_name in data_map.items():
setattr(self, var_name, self.data['model_data'][key])
self.metadata = self.data['metadata']
self.pretty_metadata = pd.json_normalize(self.data['metadata']).T
self.tau = _tau(self.tau_array, self.metadata)
self.firing = _firing(self.tau, self.processed_spikes, self.metadata)
__init__(self, file_path)
Initialize PklHandler class
Parameters:
Name | Type | Description | Default
---|---|---|---
file_path | str | Path to pkl file | required
Source code in pytau/changepoint_analysis.py
def __init__(self, file_path):
"""Initialize PklHandler class
Args:
file_path (str): Path to pkl file
"""
self.dir_name = os.path.dirname(file_path)
file_name = os.path.basename(file_path)
self.file_name_base = file_name.split('.')[0]
self.pkl_file_path = \
os.path.join(self.dir_name, self.file_name_base + ".pkl")
with open(self.pkl_file_path, 'rb') as this_file:
self.data = pkl.load(this_file)
model_keys = ['model', 'approx', 'lambda', 'tau', 'data']
key_savenames = ['_model_structure', '_fit_model',
'lambda_array', 'tau_array', 'processed_spikes']
data_map = dict(zip(model_keys, key_savenames))
for key, var_name in data_map.items():
setattr(self, var_name, self.data['model_data'][key])
self.metadata = self.data['metadata']
self.pretty_metadata = pd.json_normalize(self.data['metadata']).T
self.tau = _tau(self.tau_array, self.metadata)
self.firing = _firing(self.tau, self.processed_spikes, self.metadata)
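A minimal usage sketch (the pkl path below is hypothetical; attribute names follow the __init__ shown above):
from pytau.changepoint_analysis import PklHandler

# Hypothetical path to a fit written out by FitHandler.save_fit_output()
handler = PklHandler('/path/to/saves/experiment_name_abcd1234.pkl')

print(handler.pretty_metadata)      # flattened metadata table
tau_samples = handler.tau_array     # raw changepoint samples from the fit
spikes = handler.processed_spikes   # data array the model was fit on
firing_info = handler.firing        # per-state firing summary built by the _firing helper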
get_state_firing(spike_array, tau_array)
Calculate firing rates within states given changepoint positions on data
Parameters:
Name | Type | Description | Default
---|---|---|---
spike_array | 3D Numpy array | trials x nrns x bins | required
tau_array | 2D Numpy array | trials x switchpoints | required
Returns:
Type | Description
---|---
Numpy array | Average firing given state bounds
Source code in pytau/changepoint_analysis.py
def get_state_firing(spike_array, tau_array):
"""Calculate firing rates within states given changepoint positions on data
Args:
spike_array (3D Numpy array): trials x nrns x bins
tau_array (2D Numpy array): trials x switchpoints
Returns:
Numpy array: Average firing given state bounds
"""
states = tau_array.shape[-1] + 1
# Get mean firing rate for each STATE using model
state_inds = np.hstack([np.zeros((tau_array.shape[0], 1)),
tau_array,
np.ones((tau_array.shape[0], 1))*spike_array.shape[-1]])
state_lims = np.array([state_inds[:, x:x+2] for x in range(states)])
state_lims = state_lims.astype(int)
state_lims = np.swapaxes(state_lims, 0, 1)
state_firing = \
np.array([[np.mean(trial_dat[:, start:end], axis=-1)
for start, end in trial_lims]
for trial_dat, trial_lims in zip(spike_array, state_lims)])
state_firing = np.nan_to_num(state_firing)
return state_firing
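A toy call with random data (shapes are illustrative): 10 trials, 5 neurons, 100 bins, and 3 changepoints per trial giving 4 states:
import numpy as np
from pytau.changepoint_analysis import get_state_firing

spike_array = np.random.poisson(0.5, size=(10, 5, 100))           # trials x nrns x bins
tau_array = np.sort(np.random.randint(1, 100, size=(10, 3)), -1)  # trials x switchpoints

state_firing = get_state_firing(spike_array, tau_array)
print(state_firing.shape)  # (trials, states, nrns) = (10, 4, 5)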
=== I/O functions ===
Pipeline to handle model fitting from data extraction to saving results
DatabaseHandler
Class to handle transactions with model database
Source code in pytau/changepoint_io.py
class DatabaseHandler():
"""Class to handle transactions with model database
"""
def __init__(self):
"""Initialize DatabaseHandler class
"""
self.unique_cols = ['exp.model_id', 'exp.save_path', 'exp.fit_date']
self.model_database_path = MODEL_DATABASE_PATH
self.model_save_base_dir = MODEL_SAVE_DIR
if os.path.exists(self.model_database_path):
self.fit_database = pd.read_csv(self.model_database_path,
index_col=0)
all_na = [all(x) for num, x in self.fit_database.isna().iterrows()]
if any(all_na):
print(f'{sum(all_na)} rows found with all NA, removing...')
self.fit_database = self.fit_database.dropna(how='all')
else:
print('Fit database does not exist yet')
def show_duplicates(self, keep='first'):
"""Find duplicates in database
Args:
keep (str, optional): Which duplicate to keep
(refer to pandas duplicated). Defaults to 'first'.
Returns:
pandas dataframe: Dataframe containing duplicated rows
pandas series : Indices of duplicated rows
"""
dup_inds = self.fit_database.drop(self.unique_cols, axis=1)\
.duplicated(keep=keep)
return self.fit_database.loc[dup_inds], dup_inds
def drop_duplicates(self):
"""Remove duplicated rows from database
"""
_, dup_inds = self.show_duplicates()
print(f'Removing {sum(dup_inds)} duplicate rows')
self.fit_database = self.fit_database.loc[~dup_inds]
def check_mismatched_paths(self):
"""Check if there are any mismatched pkl files between database and directory
Returns:
pandas dataframe: Dataframe containing rows for which pkl file not present
list: pkl files which cannot be matched to model in database
list: all files in save directory
"""
mismatch_from_database = [not os.path.exists(x + ".pkl")
for x in self.fit_database['exp.save_path']]
file_list = glob(os.path.join(self.model_save_base_dir, "*/*.pkl"))
mismatch_from_file = [not
(x.split('.')[0] in list(
self.fit_database['exp.save_path']))
for x in file_list]
print(f"{sum(mismatch_from_database)} mismatches from database" + "\n"
+ f"{sum(mismatch_from_file)} mismatches from files")
return mismatch_from_database, mismatch_from_file, file_list
def clear_mismatched_paths(self):
"""Remove mismatched files and rows in database
i.e. Remove
1) Files for which no entry can be found in database
2) Database entries for which no corresponding file can be found
"""
mismatch_from_database, mismatch_from_file, file_list = \
self.check_mismatched_paths()
mismatch_from_file = np.array(mismatch_from_file)
mismatch_from_database = np.array(mismatch_from_database)
self.fit_database = self.fit_database.loc[~mismatch_from_database]
mismatched_files = [x for x, y in zip(file_list, mismatch_from_file) if y]
for x in mismatched_files:
os.remove(x)
print('==== Clearing Completed ====')
def write_updated_database(self):
"""Can be called following clear_mismatched_entries to update current database
"""
database_backup_dir = os.path.join(
self.model_save_base_dir, '.database_backups')
if not os.path.exists(database_backup_dir):
os.makedirs(database_backup_dir)
#current_date = date.today().strftime("%m-%d-%y")
current_date = str(datetime.now()).replace(" ", "_")
shutil.copy(self.model_database_path,
os.path.join(database_backup_dir,
f"database_backup_{current_date}"))
self.fit_database.to_csv(self.model_database_path, mode='w')
def set_run_params(self, data_dir, experiment_name, taste_num, region_name):
"""Store metadata related to inference run
Args:
data_dir (str): Path to directory containing HDF5 file
experiment_name (str): Name given to fitted batch
(for metadata). Defaults to None.
taste_num (int): Index of taste to perform fit on (Corresponds to
INDEX of taste in spike array, not actual dig_ins)
region_name (str): Region on which to perform fit on
(must match regions in .info file)
"""
self.data_dir = data_dir
self.data_basename = os.path.basename(self.data_dir)
self.animal_name = self.data_basename.split("_")[0]
self.session_date = self.data_basename.split("_")[-1]
self.experiment_name = experiment_name
self.model_save_dir = os.path.join(self.model_save_base_dir,
experiment_name)
if not os.path.exists(self.model_save_dir):
os.makedirs(self.model_save_dir)
self.model_id = str(uuid.uuid4()).split('-')[0]
self.model_save_path = os.path.join(self.model_save_dir,
self.experiment_name +
"_" + self.model_id)
self.fit_date = date.today().strftime("%m-%d-%y")
self.taste_num = taste_num
self.region_name = region_name
self.fit_exists = None
def ingest_fit_data(self, met_dict):
"""Load external metadata
Args:
met_dict (dict): Dictionary of metadata from FitHandler class
"""
self.external_metadata = met_dict
def aggregate_metadata(self):
"""Collects information regarding data and current "experiment"
Raises:
Exception: If 'external_metadata' has not been ingested, that needs to be done first
Returns:
dict: Dictionary of metadata given to FitHandler class
"""
if 'external_metadata' not in dir(self):
raise Exception('Fit run metadata needs to be ingested '
'into data_handler first')
data_details = dict(zip(
['data_dir',
'basename',
'animal_name',
'session_date',
'taste_num',
'region_name'],
[self.data_dir,
self.data_basename,
self.animal_name,
self.session_date,
self.taste_num,
self.region_name]))
exp_details = dict(zip(
['exp_name',
'model_id',
'save_path',
'fit_date'],
[self.experiment_name,
self.model_id,
self.model_save_path,
self.fit_date]))
temp_ext_met = self.external_metadata
temp_ext_met['data'] = data_details
temp_ext_met['exp'] = exp_details
return temp_ext_met
def write_to_database(self):
"""Write out metadata to database
"""
agg_metadata = self.aggregate_metadata()
flat_metadata = pd.json_normalize(agg_metadata)
if not os.path.isfile(self.model_database_path):
flat_metadata.to_csv(self.model_database_path, mode='a')
else:
flat_metadata.to_csv(self.model_database_path,
mode='a', header=False)
def check_exists(self):
"""Check if the given fit already exists in database
Returns:
bool: Boolean for whether fit already exists or not
"""
if self.fit_exists is not None:
return self.fit_exists
__init__(self)
Initialize DatabaseHandler class
Source code in pytau/changepoint_io.py
def __init__(self):
"""Initialize DatabaseHandler class
"""
self.unique_cols = ['exp.model_id', 'exp.save_path', 'exp.fit_date']
self.model_database_path = MODEL_DATABASE_PATH
self.model_save_base_dir = MODEL_SAVE_DIR
if os.path.exists(self.model_database_path):
self.fit_database = pd.read_csv(self.model_database_path,
index_col=0)
all_na = [all(x) for num, x in self.fit_database.isna().iterrows()]
if any(all_na):
print(f'{sum(all_na)} rows found with all NA, removing...')
self.fit_database = self.fit_database.dropna(how='all')
else:
print('Fit database does not exist yet')
aggregate_metadata(self)
Collects information regarding data and current "experiment"
Exceptions:
Type | Description
---|---
Exception | If 'external_metadata' has not been ingested, that needs to be done first
Returns:
Type | Description
---|---
dict | Dictionary of metadata given to FitHandler class
Source code in pytau/changepoint_io.py
def aggregate_metadata(self):
"""Collects information regarding data and current "experiment"
Raises:
Exception: If 'external_metadata' has not been ingested, that needs to be done first
Returns:
dict: Dictionary of metadata given to FitHandler class
"""
if 'external_metadata' not in dir(self):
raise Exception('Fit run metadata needs to be ingested '
'into data_handler first')
data_details = dict(zip(
['data_dir',
'basename',
'animal_name',
'session_date',
'taste_num',
'region_name'],
[self.data_dir,
self.data_basename,
self.animal_name,
self.session_date,
self.taste_num,
self.region_name]))
exp_details = dict(zip(
['exp_name',
'model_id',
'save_path',
'fit_date'],
[self.experiment_name,
self.model_id,
self.model_save_path,
self.fit_date]))
temp_ext_met = self.external_metadata
temp_ext_met['data'] = data_details
temp_ext_met['exp'] = exp_details
return temp_ext_met
check_exists(self)
Check if the given fit already exists in database
Returns:
Type | Description
---|---
bool | Boolean for whether fit already exists or not
Source code in pytau/changepoint_io.py
def check_exists(self):
"""Check if the given fit already exists in database
Returns:
bool: Boolean for whether fit already exists or not
"""
if self.fit_exists is not None:
return self.fit_exists
check_mismatched_paths(self)
Check if there are any mismatched pkl files between database and directory
Returns:
Type | Description
---|---
pandas dataframe | Dataframe containing rows for which pkl file not present
list | pkl files which cannot be matched to model in database
list | all files in save directory
Source code in pytau/changepoint_io.py
def check_mismatched_paths(self):
"""Check if there are any mismatched pkl files between database and directory
Returns:
pandas dataframe: Dataframe containing rows for which pkl file not present
list: pkl files which cannot be matched to model in database
list: all files in save directory
"""
mismatch_from_database = [not os.path.exists(x + ".pkl")
for x in self.fit_database['exp.save_path']]
file_list = glob(os.path.join(self.model_save_base_dir, "*/*.pkl"))
mismatch_from_file = [not
(x.split('.')[0] in list(
self.fit_database['exp.save_path']))
for x in file_list]
print(f"{sum(mismatch_from_database)} mismatches from database" + "\n"
+ f"{sum(mismatch_from_file)} mismatches from files")
return mismatch_from_database, mismatch_from_file, file_list
clear_mismatched_paths(self)
Remove mismatched files and rows in database
i.e. Remove 1) Files for which no entry can be found in database 2) Database entries for which no corresponding file can be found
Source code in pytau/changepoint_io.py
def clear_mismatched_paths(self):
"""Remove mismatched files and rows in database
i.e. Remove
1) Files for which no entry can be found in database
2) Database entries for which no corresponding file can be found
"""
mismatch_from_database, mismatch_from_file, file_list = \
self.check_mismatched_paths()
mismatch_from_file = np.array(mismatch_from_file)
mismatch_from_database = np.array(mismatch_from_database)
self.fit_database = self.fit_database.loc[~mismatch_from_database]
mismatched_files = [x for x, y in zip(file_list, mismatch_from_file) if y]
for x in mismatched_files:
os.remove(x)
print('==== Clearing Completed ====')
drop_duplicates(self)
Remove duplicated rows from database
Source code in pytau/changepoint_io.py
def drop_duplicates(self):
"""Remove duplicated rows from database
"""
_, dup_inds = self.show_duplicates()
print(f'Removing {sum(dup_inds)} duplicate rows')
self.fit_database = self.fit_database.loc[~dup_inds]
ingest_fit_data(self, met_dict)
Load external metadata
Parameters:
Name | Type | Description | Default
---|---|---|---
met_dict | dict | Dictionary of metadata from FitHandler class | required
Source code in pytau/changepoint_io.py
def ingest_fit_data(self, met_dict):
"""Load external metadata
Args:
met_dict (dict): Dictionary of metadata from FitHandler class
"""
self.external_metadata = met_dict
set_run_params(self, data_dir, experiment_name, taste_num, region_name)
Store metadata related to inference run
Parameters:
Name | Type | Description | Default
---|---|---|---
data_dir | str | Path to directory containing HDF5 file | required
experiment_name | str | Name given to fitted batch (for metadata). Defaults to None. | required
taste_num | int | Index of taste to perform fit on (corresponds to INDEX of taste in spike array, not actual dig_ins) | required
region_name | str | Region to perform fit on (must match regions in .info file) | required
Source code in pytau/changepoint_io.py
def set_run_params(self, data_dir, experiment_name, taste_num, region_name):
"""Store metadata related to inference run
Args:
data_dir (str): Path to directory containing HDF5 file
experiment_name (str): Name given to fitted batch
(for metadata). Defaults to None.
taste_num (int): Index of taste to perform fit on (Corresponds to
INDEX of taste in spike array, not actual dig_ins)
region_name (str): Region on which to perform fit on
(must match regions in .info file)
"""
self.data_dir = data_dir
self.data_basename = os.path.basename(self.data_dir)
self.animal_name = self.data_basename.split("_")[0]
self.session_date = self.data_basename.split("_")[-1]
self.experiment_name = experiment_name
self.model_save_dir = os.path.join(self.model_save_base_dir,
experiment_name)
if not os.path.exists(self.model_save_dir):
os.makedirs(self.model_save_dir)
self.model_id = str(uuid.uuid4()).split('-')[0]
self.model_save_path = os.path.join(self.model_save_dir,
self.experiment_name +
"_" + self.model_id)
self.fit_date = date.today().strftime("%m-%d-%y")
self.taste_num = taste_num
self.region_name = region_name
self.fit_exists = None
show_duplicates(self, keep='first')
Find duplicates in database
Parameters:
Name | Type | Description | Default |
---|---|---|---|
keep |
str |
Which duplicate to keep (refer to pandas duplicated). Defaults to 'first'. |
'first' |
Returns:
Type | Description |
---|---|
pandas dataframe |
Dataframe containing duplicated rows pandas series : Indices of duplicated rows |
Source code in pytau/changepoint_io.py
def show_duplicates(self, keep='first'):
"""Find duplicates in database
Args:
keep (str, optional): Which duplicate to keep
(refer to pandas duplicated). Defaults to 'first'.
Returns:
pandas dataframe: Dataframe containing duplicated rows
pandas series : Indices of duplicated rows
"""
dup_inds = self.fit_database.drop(self.unique_cols, axis=1)\
.duplicated(keep=keep)
return self.fit_database.loc[dup_inds], dup_inds
write_to_database(self)
Write out metadata to database
Source code in pytau/changepoint_io.py
def write_to_database(self):
"""Write out metadata to database
"""
agg_metadata = self.aggregate_metadata()
flat_metadata = pd.json_normalize(agg_metadata)
if not os.path.isfile(self.model_database_path):
flat_metadata.to_csv(self.model_database_path, mode='a')
else:
flat_metadata.to_csv(self.model_database_path,
mode='a', header=False)
write_updated_database(self)
Can be called following clear_mismatched_entries to update current database
Source code in pytau/changepoint_io.py
def write_updated_database(self):
"""Can be called following clear_mismatched_entries to update current database
"""
database_backup_dir = os.path.join(
self.model_save_base_dir, '.database_backups')
if not os.path.exists(database_backup_dir):
os.makedirs(database_backup_dir)
#current_date = date.today().strftime("%m-%d-%y")
current_date = str(datetime.now()).replace(" ", "_")
shutil.copy(self.model_database_path,
os.path.join(database_backup_dir,
f"database_backup_{current_date}"))
self.fit_database.to_csv(self.model_database_path, mode='w')
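A maintenance sketch; DatabaseHandler reads its paths from the module-level MODEL_DATABASE_PATH and MODEL_SAVE_DIR constants, so no arguments are needed:
from pytau.changepoint_io import DatabaseHandler

db = DatabaseHandler()
dup_rows, dup_inds = db.show_duplicates()  # inspect duplicated fits
db.drop_duplicates()                       # drop them from the in-memory table
db.clear_mismatched_paths()                # delete orphaned pkl files and dangling rows
db.write_updated_database()                # back up the old CSV, then overwrite it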
FitHandler
Class to handle pipeline of model fitting including: 1) Loading data 2) Preprocessing loaded arrays 3) Fitting model 4) Writing out fitted parameters to pkl file
Source code in pytau/changepoint_io.py
class FitHandler():
"""Class to handle pipeline of model fitting including:
1) Loading data
2) Preprocessing loaded arrays
3) Fitting model
4) Writing out fitted parameters to pkl file
"""
def __init__(self,
data_dir,
taste_num,
region_name,
experiment_name=None,
model_params_path=None,
preprocess_params_path=None,
):
"""Initialize FitHandler class
Args:
data_dir (str): Path to directory containing HDF5 file
taste_num (int): Index of taste to perform fit on
(Corresponds to INDEX of taste in spike array, not actual dig_ins)
region_name (str): Region on which to perform fit on
(must match regions in .info file)
experiment_name (str, optional): Name given to fitted batch
(for metadata). Defaults to None.
model_params_path (str, optional): Path to json file
containing model parameters. Defaults to None.
preprocess_params_path (str, optional): Path to json file
containing preprocessing parameters. Defaults to None.
Raises:
Exception: If "experiment_name" is None
Exception: If "taste_num" is not integer or "all"
"""
# =============== Check for exceptions ===============
if experiment_name is None:
raise Exception('Please specify an experiment name')
if not (isinstance(taste_num,int) or taste_num == 'all'):
raise Exception('taste_num must be an integer or "all"')
# =============== Save relevant arguments ===============
self.data_dir = data_dir
self.EphysData = EphysData(self.data_dir)
#self.data = self.EphysData.get_spikes({"bla","gc","all"})
self.taste_num = taste_num
self.region_name = region_name
self.experiment_name = experiment_name
data_handler_init_kwargs = dict(zip(
['data_dir', 'experiment_name', 'taste_num', 'region_name'],
[data_dir, experiment_name, taste_num, region_name]))
self.database_handler = DatabaseHandler()
self.database_handler.set_run_params(**data_handler_init_kwargs)
if model_params_path is None:
print('MODEL_PARAMS will have to be set')
else:
self.set_model_params(file_path=model_params_path)
if preprocess_params_path is None:
print('PREPROCESS_PARAMS will have to be set')
else:
self.set_preprocess_params(file_path=preprocess_params_path)
########################################
# SET PARAMS
########################################
def set_preprocess_params(self,
time_lims,
bin_width,
data_transform,
file_path=None):
"""Load given params as "preprocess_params" attribute
Args:
time_lims (array/tuple/list): Start and end of where to cut
spike train array
bin_width (int): Bin width for binning spikes to counts
data_transform (str): Indicator for which transformation to
use (refer to changepoint_preprocess)
file_path (str, optional): Path to json file containing preprocess
parameters. Defaults to None.
"""
if file_path is None:
self.preprocess_params = \
dict(zip(['time_lims', 'bin_width', 'data_transform'],
[time_lims, bin_width, data_transform]))
else:
# Load json and save dict
pass
def set_model_params(self,
states,
fit,
samples,
file_path=None):
"""Load given params as "model_params" attribute
Args:
states (int): Number of states to use in model
fit (int): Iterations to use for model fitting (given ADVI fit)
samples (int): Number of samples to return from fitted model
file_path (str, optional): Path to json file containing
model parameters. Defaults to None.
"""
if file_path is None:
self.model_params = \
dict(zip(['states', 'fit', 'samples'], [states, fit, samples]))
else:
# Load json and save dict
pass
########################################
# SET PIPELINE FUNCS
########################################
def set_preprocessor(self, preprocessing_func):
"""Manually set preprocessor for data e.g.
FitHandler.set_preprocessor(
changepoint_preprocess.preprocess_single_taste)
Args:
preprocessing_func (func):
Function to preprocess data (refer to changepoint_preprocess)
"""
self.preprocessor = preprocessing_func
def preprocess_selector(self):
"""Function to return preprocess function based off of input flag
Preprocessing can be set manually but it is preferred to
go through preprocess selector
Raises:
Exception: If self.taste_num is neither int nor str
"""
if isinstance(self.taste_num, int):
self.set_preprocessor(
changepoint_preprocess.preprocess_single_taste)
elif self.taste_num == 'all':
self.set_preprocessor(changepoint_preprocess.preprocess_all_taste)
else:
raise Exception("Something went wrong")
def set_model_template(self, model_template):
"""Manually set model_template for data e.g.
FitHandler.set_model(changepoint_model.single_taste_poisson)
Args:
model_template (func): Function to generate model template for data
"""
self.model_template = model_template
def model_template_selector(self):
"""Function to set model based off of input flag
Models can be set manually but it is preferred to go through model selector
Raises:
Exception: If self.taste_num is neither int nor str
"""
if isinstance(self.taste_num, int):
self.set_model_template(changepoint_model.single_taste_poisson)
elif self.taste_num == 'all':
self.set_model_template(changepoint_model.all_taste_poisson)
else:
raise Exception("Something went wrong")
def set_inference(self, inference_func):
"""Manually set inference function for model fit e.g.
FitHandler.set_inference(changepoint_model.advi_fit)
Args:
inference_func (func): Function to use for fitting model
"""
self.inference_func = inference_func
def inference_func_selector(self):
"""Function to return model based off of input flag
Currently hard-coded to use "advi_fit"
"""
self.set_inference(changepoint_model.advi_fit)
########################################
# PIPELINE FUNCS
########################################
def load_spike_trains(self):
"""Helper function to load spike trains from data_dir using EphysData module
"""
full_spike_array = self.EphysData.return_region_spikes(self.region_name)
if isinstance(self.taste_num, int):
self.data = full_spike_array[self.taste_num]
if self.taste_num == 'all':
self.data = full_spike_array
print(f'Loading spike trains from {self.database_handler.data_basename}, '
f'dig_in {self.taste_num}')
def preprocess_data(self):
"""Perform data preprocessing
Will check for and complete:
1) Raw data loaded
2) Preprocessor selected
"""
if 'data' not in dir(self):
self.load_spike_trains()
if 'preprocessor' not in dir(self):
self.preprocess_selector()
print('Preprocessing spike trains, '
f'preprocessing func: <{self.preprocessor.__name__}>')
self.preprocessed_data = \
self.preprocessor(self.data, **self.preprocess_params)
def create_model(self):
"""Create model and save as attribute
Will check for and complete:
1) Data preprocessed
2) Model template selected
"""
if 'preprocessed_data' not in dir(self):
self.preprocess_data()
if 'model_template' not in dir(self):
self.model_template_selector()
# In future iterations, before fitting model,
# check that a similar entry doesn't exist
changepoint_model.compile_wait()
print(f'Generating Model, model func: <{self.model_template.__name__}>')
self.model = self.model_template(self.preprocessed_data,
self.model_params['states'])
def run_inference(self):
"""Perform inference on data
Will check for and complete:
1) Model created
2) Inference function selected
"""
if 'model' not in dir(self):
self.create_model()
if 'inference_func' not in dir(self):
self.inference_func_selector()
changepoint_model.compile_wait()
print('Running inference, inference func: '
f'<{self.inference_func.__name__}>')
temp_outs = self.inference_func(self.model,
self.model_params['fit'],
self.model_params['samples'])
varnames = ['model', 'approx', 'lambda', 'tau', 'data']
self.inference_outs = dict(zip(varnames, temp_outs))
def _gen_fit_metadata(self):
"""Generate metadata for fit
Generate metadata by compiling:
1) Preprocess parameters given as input
2) Model parameters given as input
3) Functions used in inference pipeline for : preprocessing,
model generation, fitting
Returns:
dict: Dictionary containing compiled metadata for different
parts of inference pipeline
"""
pre_params = self.preprocess_params
model_params = self.model_params
pre_params['preprocessor_name'] = self.preprocessor.__name__
model_params['model_template_name'] = self.model_template.__name__
model_params['inference_func_name'] = self.inference_func.__name__
fin_dict = dict(zip(['preprocess', 'model'], [pre_params, model_params]))
return fin_dict
def _pass_metadata_to_handler(self):
"""Function to coordinate transfer of metadata to DatabaseHandler
"""
self.database_handler.ingest_fit_data(self._gen_fit_metadata())
def _return_fit_output(self):
"""Compile data, model, fit, and metadata to save output
Returns:
dict: Dictionary containing fitted model data and metadata
"""
self._pass_metadata_to_handler()
agg_metadata = self.database_handler.aggregate_metadata()
return {'model_data': self.inference_outs, 'metadata': agg_metadata}
def save_fit_output(self):
"""Save fit output (fitted data + metadata) to pkl file
"""
if 'inference_outs' not in dir(self):
self.run_inference()
out_dict = self._return_fit_output()
with open(self.database_handler.model_save_path + '.pkl', 'wb') as buff:
pickle.dump(out_dict, buff)
json_file_name = os.path.join(
self.database_handler.model_save_path + '.info')
with open(json_file_name, 'w') as file:
json.dump(out_dict['metadata'], file, indent=4)
self.database_handler.write_to_database()
print('Saving inference output to : \n'
f'{self.database_handler.model_save_dir}'
"\n" + "================================" + '\n')
__init__(self, data_dir, taste_num, region_name, experiment_name=None, model_params_path=None, preprocess_params_path=None)
Initialize FitHandler class
Parameters:
Name | Type | Description | Default
---|---|---|---
data_dir | str | Path to directory containing HDF5 file | required
taste_num | int | Index of taste to perform fit on (corresponds to INDEX of taste in spike array, not actual dig_ins) | required
region_name | str | Region to perform fit on (must match regions in .info file) | required
experiment_name | str | Name given to fitted batch (for metadata). Defaults to None. | None
model_params_path | str | Path to json file containing model parameters. Defaults to None. | None
preprocess_params_path | str | Path to json file containing preprocessing parameters. Defaults to None. | None
Exceptions:
Type | Description
---|---
Exception | If "experiment_name" is None
Exception | If "taste_num" is not an integer or "all"
Source code in pytau/changepoint_io.py
def __init__(self,
data_dir,
taste_num,
region_name,
experiment_name=None,
model_params_path=None,
preprocess_params_path=None,
):
"""Initialize FitHandler class
Args:
data_dir (str): Path to directory containing HDF5 file
taste_num (int): Index of taste to perform fit on
(Corresponds to INDEX of taste in spike array, not actual dig_ins)
region_name (str): Region on which to perform fit on
(must match regions in .info file)
experiment_name (str, optional): Name given to fitted batch
(for metadata). Defaults to None.
model_params_path (str, optional): Path to json file
containing model parameters. Defaults to None.
preprocess_params_path (str, optional): Path to json file
containing preprocessing parameters. Defaults to None.
Raises:
Exception: If "experiment_name" is None
Exception: If "taste_num" is not integer or "all"
"""
# =============== Check for exceptions ===============
if experiment_name is None:
raise Exception('Please specify an experiment name')
if not (isinstance(taste_num,int) or taste_num == 'all'):
raise Exception('taste_num must be an integer or "all"')
# =============== Save relevant arguments ===============
self.data_dir = data_dir
self.EphysData = EphysData(self.data_dir)
#self.data = self.EphysData.get_spikes({"bla","gc","all"})
self.taste_num = taste_num
self.region_name = region_name
self.experiment_name = experiment_name
data_handler_init_kwargs = dict(zip(
['data_dir', 'experiment_name', 'taste_num', 'region_name'],
[data_dir, experiment_name, taste_num, region_name]))
self.database_handler = DatabaseHandler()
self.database_handler.set_run_params(**data_handler_init_kwargs)
if model_params_path is None:
print('MODEL_PARAMS will have to be set')
else:
self.set_model_params(file_path=model_params_path)
if preprocess_params_path is None:
print('PREPROCESS_PARAMS will have to be set')
else:
self.set_preprocess_params(file_path=preprocess_params_path)
create_model(self)
Create model and save as attribute
Will check for and complete: 1) Data preprocessed 2) Model template selected
Source code in pytau/changepoint_io.py
def create_model(self):
"""Create model and save as attribute
Will check for and complete:
1) Data preprocessed
2) Model template selected
"""
if 'preprocessed_data' not in dir(self):
self.preprocess_data()
if 'model_template' not in dir(self):
self.model_template_selector()
# In future iterations, before fitting model,
# check that a similar entry doesn't exist
changepoint_model.compile_wait()
print(f'Generating Model, model func: <{self.model_template.__name__}>')
self.model = self.model_template(self.preprocessed_data,
self.model_params['states'])
inference_func_selector(self)
Function to return model based off of input flag
Currently hard-coded to use "advi_fit"
Source code in pytau/changepoint_io.py
def inference_func_selector(self):
"""Function to return model based off of input flag
Currently hard-coded to use "advi_fit"
"""
self.set_inference(changepoint_model.advi_fit)
load_spike_trains(self)
Helper function to load spike trains from data_dir using EphysData module
Source code in pytau/changepoint_io.py
def load_spike_trains(self):
"""Helper function to load spike trains from data_dir using EphysData module
"""
full_spike_array = self.EphysData.return_region_spikes(self.region_name)
if isinstance(self.taste_num, int):
self.data = full_spike_array[self.taste_num]
if self.taste_num == 'all':
self.data = full_spike_array
print(f'Loading spike trains from {self.database_handler.data_basename}, '
f'dig_in {self.taste_num}')
model_template_selector(self)
Function to set model based off of input flag
Models can be set manually but it is preferred to go through model selector
Exceptions:
Type | Description |
---|---|
Exception |
If self.taste_num is neither int nor str |
Source code in pytau/changepoint_io.py
def model_template_selector(self):
"""Function to set model based off of input flag
Models can be set manually but it is preferred to go through model selector
Raises:
Exception: If self.taste_num is neither int nor str
"""
if isinstance(self.taste_num, int):
self.set_model_template(changepoint_model.single_taste_poisson)
elif self.taste_num == 'all':
self.set_model_template(changepoint_model.all_taste_poisson)
else:
raise Exception("Something went wrong")
preprocess_data(self)
Perform data preprocessing
Will check for and complete: 1) Raw data loaded 2) Preprocessor selected
Source code in pytau/changepoint_io.py
def preprocess_data(self):
"""Perform data preprocessing
Will check for and complete:
1) Raw data loaded
2) Preprocessor selected
"""
if 'data' not in dir(self):
self.load_spike_trains()
if 'preprocessor' not in dir(self):
self.preprocess_selector()
print('Preprocessing spike trains, '
f'preprocessing func: <{self.preprocessor.__name__}>')
self.preprocessed_data = \
self.preprocessor(self.data, **self.preprocess_params)
preprocess_selector(self)
Function to return preprocess function based off of input flag
Preprocessing can be set manually but it is preferred to go through preprocess selector
Exceptions:
Type | Description |
---|---|
Exception |
If self.taste_num is neither int nor str |
Source code in pytau/changepoint_io.py
def preprocess_selector(self):
"""Function to return preprocess function based off of input flag
Preprocessing can be set manually but it is preferred to
go through preprocess selector
Raises:
Exception: If self.taste_num is neither int nor str
"""
if isinstance(self.taste_num, int):
self.set_preprocessor(
changepoint_preprocess.preprocess_single_taste)
elif self.taste_num == 'all':
self.set_preprocessor(changepoint_preprocess.preprocess_all_taste)
else:
raise Exception("Something went wrong")
run_inference(self)
Perform inference on data
Will check for and complete: 1) Model created 2) Inference function selected
Source code in pytau/changepoint_io.py
def run_inference(self):
"""Perform inference on data
Will check for and complete:
1) Model created
2) Inference function selected
"""
if 'model' not in dir(self):
self.create_model()
if 'inference_func' not in dir(self):
self.inference_func_selector()
changepoint_model.compile_wait()
print('Running inference, inference func: '
f'<{self.inference_func.__name__}>')
temp_outs = self.inference_func(self.model,
self.model_params['fit'],
self.model_params['samples'])
varnames = ['model', 'approx', 'lambda', 'tau', 'data']
self.inference_outs = dict(zip(varnames, temp_outs))
save_fit_output(self)
Save fit output (fitted data + metadata) to pkl file
Source code in pytau/changepoint_io.py
def save_fit_output(self):
"""Save fit output (fitted data + metadata) to pkl file
"""
if 'inference_outs' not in dir(self):
self.run_inference()
out_dict = self._return_fit_output()
with open(self.database_handler.model_save_path + '.pkl', 'wb') as buff:
pickle.dump(out_dict, buff)
json_file_name = os.path.join(
self.database_handler.model_save_path + '.info')
with open(json_file_name, 'w') as file:
json.dump(out_dict['metadata'], file, indent=4)
self.database_handler.write_to_database()
print('Saving inference output to : \n'
f'{self.database_handler.model_save_dir}'
"\n" + "================================" + '\n')
set_inference(self, inference_func)
Manually set inference function for model fit e.g.
FitHandler.set_inference(changepoint_model.advi_fit)
Parameters:
Name | Type | Description | Default
---|---|---|---
inference_func | func | Function to use for fitting model | required
Source code in pytau/changepoint_io.py
def set_inference(self, inference_func):
"""Manually set inference function for model fit e.g.
FitHandler.set_inference(changepoint_model.advi_fit)
Args:
inference_func (func): Function to use for fitting model
"""
self.inference_func = inference_func
set_model_params(self, states, fit, samples, file_path=None)
Load given params as "model_params" attribute
Parameters:
Name | Type | Description | Default
---|---|---|---
states | int | Number of states to use in model | required
fit | int | Iterations to use for model fitting (given ADVI fit) | required
samples | int | Number of samples to return from fitted model | required
file_path | str | Path to json file containing model parameters. Defaults to None. | None
Source code in pytau/changepoint_io.py
def set_model_params(self,
states,
fit,
samples,
file_path=None):
"""Load given params as "model_params" attribute
Args:
states (int): Number of states to use in model
fit (int): Iterations to use for model fitting (given ADVI fit)
samples (int): Number of samples to return from fitted model
file_path (str, optional): Path to json file containing
model parameters. Defaults to None.
"""
if file_path is None:
self.model_params = \
dict(zip(['states', 'fit', 'samples'], [states, fit, samples]))
else:
# Load json and save dict
pass
set_model_template(self, model_template)
Manually set model_template for data e.g.
FitHandler.set_model(changepoint_model.single_taste_poisson)
Parameters:
Name | Type | Description | Default
---|---|---|---
model_template | func | Function to generate model template for data | required
Source code in pytau/changepoint_io.py
def set_model_template(self, model_template):
"""Manually set model_template for data e.g.
FitHandler.set_model(changepoint_model.single_taste_poisson)
Args:
model_template (func): Function to generate model template for data
"""
self.model_template = model_template
set_preprocess_params(self, time_lims, bin_width, data_transform, file_path=None)
Load given params as "preprocess_params" attribute
Parameters:
Name | Type | Description | Default
---|---|---|---
time_lims | array/tuple/list | Start and end of where to cut spike train array | required
bin_width | int | Bin width for binning spikes to counts | required
data_transform | str | Indicator for which transformation to use (refer to changepoint_preprocess) | required
file_path | str | Path to json file containing preprocess parameters. Defaults to None. | None
Source code in pytau/changepoint_io.py
def set_preprocess_params(self,
time_lims,
bin_width,
data_transform,
file_path=None):
"""Load given params as "preprocess_params" attribute
Args:
time_lims (array/tuple/list): Start and end of where to cut
spike train array
bin_width (int): Bin width for binning spikes to counts
data_transform (str): Indicator for which transformation to
use (refer to changepoint_preprocess)
file_path (str, optional): Path to json file containing preprocess
parameters. Defaults to None.
"""
if file_path is None:
self.preprocess_params = \
dict(zip(['time_lims', 'bin_width', 'data_transform'],
[time_lims, bin_width, data_transform]))
else:
# Load json and save dict
pass
set_preprocessor(self, preprocessing_func)
Manually set preprocessor for data e.g.
FitHandler.set_preprocessor( changepoint_preprocess.preprocess_single_taste)
Parameters:
Name | Type | Description | Default
---|---|---|---
preprocessing_func | func | Function to preprocess data (refer to changepoint_preprocess) | required
Source code in pytau/changepoint_io.py
def set_preprocessor(self, preprocessing_func):
"""Manually set preprocessor for data e.g.
FitHandler.set_preprocessor(
changepoint_preprocess.preprocess_single_taste)
Args:
preprocessing_func (func):
Function to preprocess data (refer to changepoint_preprocess)
"""
self.preprocessor = preprocessing_func
=== Model building functions ===
PyMC3 Blackbox Variational Inference implementation of Poisson Likelihood Changepoint for spike trains.
advi_fit(model, fit, samples)
Convenience function to perform ADVI fit on model
Parameters:
Name | Type | Description | Default
---|---|---|---
model | pymc3 model | Model object to run inference on | required
fit | int | Number of iterations to fit the model for | required
samples | int | Number of samples to draw from fitted model | required
Returns:
Type | Description
---|---
model | Original model on which inference was run
approx | Fitted model
lambda_stack | Array containing lambda (emission) values
tau_samples | Array containing samples from changepoint distribution
model.obs.observations | Processed array on which fit was run
Source code in pytau/changepoint_model.py
def advi_fit(model, fit, samples):
"""Convenience function to perform ADVI fit on model
Args:
model (pymc3 model): model object to run inference on
fit (int): Number of iterations to fit the model for
samples (int): Number of samples to draw from fitted model
Returns:
model: original model on which inference was run,
approx: fitted model,
lambda_stack: array containing lambda (emission) values,
tau_samples: array containing samples from changepoint distribution
model.obs.observations: processed array on which fit was run
"""
with model:
inference = pm.ADVI('full-rank')
approx = pm.fit(n=fit, method=inference)
trace = approx.sample(draws=samples)
# Extract relevant variables from trace
lambda_stack = trace['lambda'].swapaxes(0, 1)
tau_samples = trace['tau']
return model, approx, lambda_stack, tau_samples, model.obs.observations
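The five return values unpack in the order given above; a minimal call might look like this (iteration and sample counts are illustrative):
# model: a pymc3 model built by one of the constructors below
model, approx, lambda_stack, tau_samples, fit_data = advi_fit(model, fit=40000, samples=20000)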
all_taste_poisson(spike_array, states)
Model to fit changepoint to all tastes ** Largely taken from "_v1/poisson_all_tastes_changepoint_model.py"
spike_array :: Shape : tastes, trials, neurons, time_bins
states :: number of states to include in the model
Source code in pytau/changepoint_model.py
def all_taste_poisson(
spike_array,
states):
"""
** Model to fit changepoint to all tastes **
** Largely taken from "_v1/poisson_all_tastes_changepoint_model.py"
spike_array :: Shape : tastes, trials, neurons, time_bins
states :: number of states to include in the model
"""
# If model already doesn't exist, then create new one
#spike_array = this_dat_binned
# Unroll arrays along taste axis
#spike_array_long = np.reshape(spike_array,(-1,*spike_array.shape[-2:]))
spike_array_long = np.concatenate(spike_array, axis=0)
# Find mean firing for initial values
tastes = spike_array.shape[0]
length = spike_array.shape[-1]
nrns = spike_array.shape[2]
trials = spike_array.shape[1]
split_list = np.array_split(spike_array,states,axis=-1)
# Cut all to the same size
min_val = min([x.shape[-1] for x in split_list])
split_array = np.array([x[...,:min_val] for x in split_list])
mean_vals = np.mean(split_array,axis=(2,-1)).swapaxes(0,1)
mean_vals += 0.01 # To avoid zero starting prob
mean_nrn_vals = np.mean(mean_vals,axis=(0,1))
# Find evenly spaces switchpoints for initial values
idx = np.arange(spike_array.shape[-1]) # Index
array_idx = np.broadcast_to(idx, spike_array_long.shape)
even_switches = np.linspace(0,idx.max(),states+1)
even_switches_normal = even_switches/np.max(even_switches)
taste_label = np.repeat(np.arange(spike_array.shape[0]), spike_array.shape[1])
trial_num = array_idx.shape[0]
# Being constructing model
with pm.Model() as model:
# Hierarchical firing rates
# Refer to model diagram
# Mean firing rate of neuron AT ALL TIMES
lambda_nrn = pm.Exponential('lambda_nrn',
1/mean_nrn_vals,
shape=(mean_vals.shape[-1]))
# Priors for each state, derived from each neuron
# Mean firing rate of neuron IN EACH STATE (averaged across tastes)
lambda_state = pm.Exponential('lambda_state',
lambda_nrn,
shape=(mean_vals.shape[1:]))
# Mean firing rate of neuron PER STATE PER TASTE
lambda_latent = pm.Exponential('lambda',
lambda_state[np.newaxis, :, :],
testval=mean_vals,
shape=(mean_vals.shape))
# Changepoint time variable
# INDEPENDENT TAU FOR EVERY TRIAL
a = pm.HalfNormal('a_tau', 3., shape=states - 1)
b = pm.HalfNormal('b_tau', 3., shape=states - 1)
# Stack produces states x trials --> That gets transposed
# to trials x states and gets sorted along states (axis=-1)
# Sort should work the same way as the Ordered transform -->
# see rv_sort_test.ipynb
tau_latent = pm.Beta('tau_latent', a, b,
shape=(trial_num, states-1),
testval=tt.tile(even_switches_normal[1:(states)],
(array_idx.shape[0], 1))).sort(axis=-1)
tau = pm.Deterministic('tau',
idx.min() + (idx.max() - idx.min()) * tau_latent)
weight_stack = tt.nnet.sigmoid(
idx[np.newaxis, :]-tau[:, :, np.newaxis])
weight_stack = tt.concatenate(
[np.ones((tastes*trials, 1, length)), weight_stack], axis=1)
inverse_stack = 1 - weight_stack[:, 1:]
inverse_stack = tt.concatenate([inverse_stack, np.ones(
(tastes*trials, 1, length))], axis=1)
weight_stack = weight_stack*inverse_stack
weight_stack = tt.tile(weight_stack[:, :, None, :], (1, 1, nrns, 1))
lambda_latent = lambda_latent.dimshuffle(2, 0, 1)
lambda_latent = tt.repeat(lambda_latent, trials, axis=1)
lambda_latent = tt.tile(
lambda_latent[..., None], (1, 1, 1, length))
lambda_latent = lambda_latent.dimshuffle(1, 2, 0, 3)
lambda_ = tt.sum(lambda_latent * weight_stack, axis=1)
observation = pm.Poisson("obs", lambda_, observed=spike_array_long)
return model
compile_wait()
Function to allow waiting while a model is already fitting. Waits twice because the lock blips out between steps; 10 secs of waiting shouldn't be a problem for long fits (~mins). Also waits a random time at the beginning to stagger fits.
Source code in pytau/changepoint_model.py
def compile_wait():
"""
Function to allow waiting while a model is already fitting
Wait twice because lock blips out between steps
10 secs of waiting shouldn't be a problem for long fits (~mins)
And wait a random time in the beginning to stagger fits
"""
time.sleep(np.random.random()*10)
while theano_lock_present():
print('Lock present...waiting')
time.sleep(10)
while theano_lock_present():
print('Lock present...waiting')
time.sleep(10)
single_taste_poisson(spike_array, states)
Model for changepoint on single taste
Largely taken from "non_hardcoded_changepoint_test_3d.ipynb". Note: this model does not have a hierarchical structure for emissions.
Parameters:
Name | Type | Description | Default
---|---|---|---
spike_array | 3D Numpy array | trials x neurons x time | required
states | int | Number of states to model | required
Returns:
Type | Description
---|---
pymc3 model | Model class containing graph to run inference on
Source code in pytau/changepoint_model.py
def single_taste_poisson(
spike_array,
states):
"""Model for changepoint on single taste
** Largely taken from "non_hardcoded_changepoint_test_3d.ipynb"
** Note : This model does not have hierarchical structure for emissions
Args:
spike_array (3D Numpy array): trials x neurons x time
states (int): Number of states to model
Returns:
pymc3 model: Model class containing graph to run inference on
"""
mean_vals = np.array([np.mean(x, axis=-1)
for x in np.array_split(spike_array, states, axis=-1)]).T
mean_vals = np.mean(mean_vals, axis=1)
mean_vals += 0.01 # To avoid zero starting prob
nrns = spike_array.shape[1]
trials = spike_array.shape[0]
idx = np.arange(spike_array.shape[-1])
length = idx.max() + 1
with pm.Model() as model:
lambda_latent = pm.Exponential('lambda',
1/mean_vals,
shape=(nrns, states))
a_tau = pm.HalfCauchy('a_tau', 3., shape=states - 1)
b_tau = pm.HalfCauchy('b_tau', 3., shape=states - 1)
even_switches = np.linspace(0, 1, states+1)[1:-1]
tau_latent = pm.Beta('tau_latent', a_tau, b_tau,
testval=even_switches,
shape=(trials, states-1)).sort(axis=-1)
tau = pm.Deterministic('tau',
idx.min() + (idx.max() - idx.min()) * tau_latent)
weight_stack = tt.nnet.sigmoid(
idx[np.newaxis, :]-tau[:, :, np.newaxis])
weight_stack = tt.concatenate(
[np.ones((trials, 1, length)), weight_stack], axis=1)
inverse_stack = 1 - weight_stack[:, 1:]
inverse_stack = tt.concatenate(
[inverse_stack, np.ones((trials, 1, length))], axis=1)
weight_stack = np.multiply(weight_stack, inverse_stack)
lambda_ = tt.tensordot(weight_stack, lambda_latent, [1, 1]).swapaxes(1, 2)
observation = pm.Poisson("obs", lambda_, observed=spike_array)
return model
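A construction sketch with toy binned counts (shapes and state count are illustrative); the returned model is what advi_fit above expects:
import numpy as np
from pytau import changepoint_model

binned_spikes = np.random.poisson(1.0, size=(15, 8, 40))  # trials x neurons x time bins
model = changepoint_model.single_taste_poisson(binned_spikes, states=4)
# Pass `model` to changepoint_model.advi_fit(...) to run inference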
theano_lock_present()
Check if theano compilation lock is present
Source code in pytau/changepoint_model.py
def theano_lock_present():
"""
Check if theano compilation lock is present
"""
return os.path.exists(os.path.join(theano.config.compiledir, 'lock_dir'))
=== Preprocessing functions ===
Code to preprocess spike trains before feeding into model
preprocess_all_taste(spike_array, time_lims, bin_width, data_transform)
Preprocess array containing trials for all tastes (in blocks) concatenated
Parameters:
Name | Type | Description | Default
---|---|---|---
spike_array | 4D Numpy Array | Taste x Trials x Neurons x Time | required
time_lims | List/Tuple/Numpy Array | 2-element object indicating limits of array | required
bin_width | int | Width to use for binning | required
data_transform | str | Data-type to return {actual, shuffled, simulated} | required
Exceptions:
Type | Description
---|---
Exception | If data_transform does not belong to ['trial_shuffled', 'spike_shuffled', 'simulated', 'None', None]
Returns:
Type | Description
---|---
4D Numpy Array | Processed data
Source code in pytau/changepoint_preprocess.py
def preprocess_all_taste(spike_array,
time_lims,
bin_width,
data_transform):
"""Preprocess array containing trials for all tastes (in blocks) concatenated
Args:
spike_array (4D Numpy Array): Taste x Trials x Neurons x Time
time_lims (List/Tuple/Numpy Array): 2-element object indicating limits of array
bin_width (int): Width to use for binning
data_transform (str): Data-type to return {actual, shuffled, simulated}
Raises:
Exception: If transforms do not belong to ['trial_shuffled','spike_shuffled','simulated','None',None]
Returns:
(4D Numpy Array): Of processed data
"""
accepted_transforms = ['trial_shuffled','spike_shuffled','simulated','None',None]
if data_transform not in accepted_transforms:
raise Exception(
f'data_transform must be of type {accepted_transforms}')
##################################################
# Create shuffled data
##################################################
# Shuffle neurons across trials FOR SAME TASTE
if data_transform == 'trial_shuffled':
transformed_dat = np.array([np.random.permutation(neuron) \
for neuron in np.swapaxes(spike_array,2,0)])
transformed_dat = np.swapaxes(transformed_dat,0,2)
if data_transform == 'spike_shuffled':
transformed_dat = spike_array.swapaxes(-1,0)
transformed_dat = np.stack([np.random.permutation(x) \
for x in transformed_dat])
transformed_dat = transformed_dat.swapaxes(0,-1)
##################################################
# Create simulated data
##################################################
# Inhomogeneous poisson process using mean firing rates
elif data_transform == 'simulated':
mean_firing = np.mean(spike_array,axis=1)
mean_firing = np.broadcast_to(mean_firing[:,None], spike_array.shape)
# Simulate spikes
transformed_dat = (np.random.random(spike_array.shape) < mean_firing )*1
##################################################
# Null Transform Case
##################################################
elif data_transform == None or data_transform == 'None':
transformed_dat = spike_array
##################################################
# Bin Data
##################################################
spike_binned = \
np.sum(transformed_dat[...,time_lims[0]:time_lims[1]].\
reshape(*transformed_dat.shape[:-1],-1,bin_width),axis=-1)
spike_binned = spike_binned.astype(int)
return spike_binned
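A toy call (shapes and parameter values are illustrative); with data_transform=None the spike trains are only cut and binned:
import numpy as np
from pytau.changepoint_preprocess import preprocess_all_taste

spike_array = (np.random.random((4, 15, 8, 7000)) < 0.005).astype(int)  # tastes x trials x nrns x ms
binned = preprocess_all_taste(spike_array,
                              time_lims=[2000, 4000],
                              bin_width=50,
                              data_transform=None)
print(binned.shape)  # (4, 15, 8, 40)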
preprocess_single_taste(spike_array, time_lims, bin_width, data_transform)
Preprocess array containing trials for a single taste
** Note, it may be useful to use x-arrays here to keep track of coordinates
Parameters:
Name | Type | Description | Default
---|---|---|---
spike_array | 3D Numpy array | trials x neurons x time | required
time_lims | List/Tuple/Numpy Array | 2-element object indicating limits of array | required
bin_width | int | Width to use for binning | required
data_transform | str | Data-type to return {actual, shuffled, simulated} | required
Exceptions:
Type | Description
---|---
Exception | If transforms do not belong to ['shuffled', 'simulated', 'None', None]
Returns:
Type | Description
---|---
3D Numpy Array | Processed data
Source code in pytau/changepoint_preprocess.py
def preprocess_single_taste(spike_array,
time_lims,
bin_width,
data_transform):
"""Preprocess array containing trials for all tastes (in blocks) concatenated
** Note, it may be useful to use x-arrays here to keep track of coordinates
Args:
spike_array (3D Numpy array): trials x neurons x time
time_lims (List/Tuple/Numpy Array): 2-element object indicating limits of array
bin_width (int): Width to use for binning
data_transform (str): Data-type to return {actual, shuffled, simulated}
Raises:
Exception: If transforms do not belong to ['shuffled','simulated','None',None]
Returns:
(3D Numpy Array): Of processed data
"""
accepted_transforms = ['shuffled', 'simulated', 'None', None]
if data_transform not in accepted_transforms:
raise Exception(
f'data_transform must be of type {accepted_transforms}')
##################################################
# Create shuffled data
##################################################
# Shuffle neurons across trials FOR SAME TASTE
if data_transform == 'shuffled':
transformed_dat = np.array([np.random.permutation(neuron) \
for neuron in np.swapaxes(spike_array, 1, 0)])
transformed_dat = np.swapaxes(transformed_dat, 0, 1)
##################################################
# Create simulated data
##################################################
# Inhomogeneous poisson process using mean firing rates
elif data_transform == 'simulated':
mean_firing = np.mean(spike_array, axis=0)
# Simulate spikes
transformed_dat = np.array(
[np.random.random(mean_firing.shape) < mean_firing
for trial in range(spike_array.shape[0])])*1
##################################################
# Null Transform Case
##################################################
elif data_transform in (None, 'None'):
transformed_dat = spike_array
##################################################
# Bin Data
##################################################
spike_binned = \
np.sum(transformed_dat[..., time_lims[0]:time_lims[1]].
reshape(*spike_array.shape[:-1], -1, bin_width), axis=-1)
spike_binned = spike_binned.astype(int)
return spike_binned
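A toy call using the 'shuffled' transform, which permutes each neuron's trials to build a control dataset (shapes and parameter values are illustrative):
import numpy as np
from pytau.changepoint_preprocess import preprocess_single_taste

spike_array = (np.random.random((15, 8, 7000)) < 0.005).astype(int)  # trials x nrns x ms
binned = preprocess_single_taste(spike_array,
                                 time_lims=[2000, 4000],
                                 bin_width=50,
                                 data_transform='shuffled')
print(binned.shape)  # (15, 8, 40)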