References

=== Analysis functions ===

Helper classes and functions to perform analysis on fitted models

PklHandler

Helper class to handle metadata and fit data from pkl file

Source code in pytau/changepoint_analysis.py
class PklHandler():
    """Helper class to handle metadata and fit data from pkl file
    """

    def __init__(self, file_path):
        """Initialize PklHandler class

        Args:
            file_path (str): Path to pkl file
        """
        self.dir_name = os.path.dirname(file_path)
        file_name = os.path.basename(file_path)
        self.file_name_base = file_name.split('.')[0]
        self.pkl_file_path = \
            os.path.join(self.dir_name, self.file_name_base + ".pkl")
        with open(self.pkl_file_path, 'rb') as this_file:
            self.data = pkl.load(this_file)

        model_keys = ['model', 'approx', 'lambda', 'tau', 'data']
        key_savenames = ['_model_structure', '_fit_model',
                         'lambda_array', 'tau_array', 'processed_spikes']
        data_map = dict(zip(model_keys, key_savenames))

        for key, var_name in data_map.items():
            setattr(self, var_name, self.data['model_data'][key])

        self.metadata = self.data['metadata']
        self.pretty_metadata = pd.json_normalize(self.data['metadata']).T

        self.tau = _tau(self.tau_array, self.metadata)
        self.firing = _firing(self.tau, self.processed_spikes, self.metadata)
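
A minimal usage sketch; the path below is a hypothetical placeholder for a pkl file previously written by FitHandler.save_fit_output:

from pytau.changepoint_analysis import PklHandler

# Hypothetical path to a saved fit
handler = PklHandler('/path/to/saved_fits/example_fit_abc123.pkl')

print(handler.pretty_metadata)      # Flattened metadata as a (transposed) pandas DataFrame
tau_samples = handler.tau_array     # Changepoint samples extracted from the fit
spikes = handler.processed_spikes   # Spike array the model was fit on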

__init__(self, file_path) special

Initialize PklHandler class

Parameters:

    file_path (str, required): Path to pkl file

Source code in pytau/changepoint_analysis.py
def __init__(self, file_path):
    """Initialize PklHandler class

    Args:
        file_path (str): Path to pkl file
    """
    self.dir_name = os.path.dirname(file_path)
    file_name = os.path.basename(file_path)
    self.file_name_base = file_name.split('.')[0]
    self.pkl_file_path = \
        os.path.join(self.dir_name, self.file_name_base + ".pkl")
    with open(self.pkl_file_path, 'rb') as this_file:
        self.data = pkl.load(this_file)

    model_keys = ['model', 'approx', 'lambda', 'tau', 'data']
    key_savenames = ['_model_structure', '_fit_model',
                     'lambda_array', 'tau_array', 'processed_spikes']
    data_map = dict(zip(model_keys, key_savenames))

    for key, var_name in data_map.items():
        setattr(self, var_name, self.data['model_data'][key])

    self.metadata = self.data['metadata']
    self.pretty_metadata = pd.json_normalize(self.data['metadata']).T

    self.tau = _tau(self.tau_array, self.metadata)
    self.firing = _firing(self.tau, self.processed_spikes, self.metadata)

get_state_firing(spike_array, tau_array)

Calculate firing rates within states given changepoint positions on data

Parameters:

    spike_array (3D Numpy array, required): trials x nrns x bins
    tau_array (2D Numpy array, required): trials x switchpoints

Returns:

    Numpy array: Average firing given state bounds

Source code in pytau/changepoint_analysis.py
def get_state_firing(spike_array, tau_array):
    """Calculate firing rates within states given changepoint positions on data

    Args:
        spike_array (3D Numpy array): trials x nrns x bins
        tau_array (2D Numpy array): trials x switchpoints

    Returns:
        Numpy array: Average firing given state bounds
    """

    states = tau_array.shape[-1] + 1
    # Get mean firing rate for each STATE using model
    state_inds = np.hstack([np.zeros((tau_array.shape[0], 1)),
                            tau_array,
                            np.ones((tau_array.shape[0], 1))*spike_array.shape[-1]])
    state_lims = np.array([state_inds[:, x:x+2] for x in range(states)])
    state_lims = state_lims.astype(int)  # np.int was removed in NumPy >= 1.24
    state_lims = np.swapaxes(state_lims, 0, 1)

    state_firing = \
        np.array([[np.mean(trial_dat[:, start:end], axis=-1)
                   for start, end in trial_lims]
                  for trial_dat, trial_lims in zip(spike_array, state_lims)])

    state_firing = np.nan_to_num(state_firing)
    return state_firing
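
A small self-contained sketch with synthetic inputs; shapes are arbitrary and chosen only for illustration:

import numpy as np
from pytau.changepoint_analysis import get_state_firing

trials, nrns, bins = 5, 4, 100
spike_array = np.random.poisson(1, (trials, nrns, bins))     # trials x nrns x bins
tau_array = np.sort(np.random.randint(10, 90, (trials, 2)))  # trials x switchpoints (2 changepoints -> 3 states)

state_firing = get_state_firing(spike_array, tau_array)
print(state_firing.shape)  # (trials, states, nrns) -> (5, 3, 4)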

=== I/O functions ===

Pipeline to handle model fitting from data extraction to saving results

DatabaseHandler

Class to handle transactions with model database

Source code in pytau/changepoint_io.py
class DatabaseHandler():
    """Class to handle transactions with model database
    """

    def __init__(self):
        """Initialize DatabaseHandler class
        """
        self.unique_cols = ['exp.model_id', 'exp.save_path', 'exp.fit_date']
        self.model_database_path = MODEL_DATABASE_PATH
        self.model_save_base_dir = MODEL_SAVE_DIR

        if os.path.exists(self.model_database_path):
            self.fit_database = pd.read_csv(self.model_database_path,
                                            index_col=0)
            all_na = [all(x) for num, x in self.fit_database.isna().iterrows()]
            if any(all_na):
                print(f'{sum(all_na)} rows found with all NA, removing...')
                self.fit_database = self.fit_database.dropna(how='all')
        else:
            print('Fit database does not exist yet')

    def show_duplicates(self, keep='first'):
        """Find duplicates in database

        Args:
            keep (str, optional): Which duplicate to keep
                    (refer to pandas duplicated). Defaults to 'first'.

        Returns:
            pandas dataframe: Dataframe containing duplicated rows
            pandas series : Indices of duplicated rows
        """
        dup_inds = self.fit_database.drop(self.unique_cols, axis=1)\
            .duplicated(keep=keep)
        return self.fit_database.loc[dup_inds], dup_inds

    def drop_duplicates(self):
        """Remove duplicated rows from database
        """
        _, dup_inds = self.show_duplicates()
        print(f'Removing {sum(dup_inds)} duplicate rows')
        self.fit_database = self.fit_database.loc[~dup_inds]

    def check_mismatched_paths(self):
        """Check if there are any mismatched pkl files between database and directory

        Returns:
            pandas dataframe: Dataframe containing rows for which pkl file not present
            list: pkl files which cannot be matched to model in database
            list: all files in save directory
        """
        mismatch_from_database = [not os.path.exists(x + ".pkl")
                                  for x in self.fit_database['exp.save_path']]
        file_list = glob(os.path.join(self.model_save_base_dir, "*/*.pkl"))
        mismatch_from_file = [not
                              (x.split('.')[0] in list(
                                  self.fit_database['exp.save_path']))
                              for x in file_list]
        print(f"{sum(mismatch_from_database)} mismatches from database" + "\n"
              + f"{sum(mismatch_from_file)} mismatches from files")
        return mismatch_from_database, mismatch_from_file, file_list

    def clear_mismatched_paths(self):
        """Remove mismatched files and rows in database

        i.e. Remove
        1) Files for which no entry can be found in database
        2) Database entries for which no corresponding file can be found
        """
        mismatch_from_database, mismatch_from_file, file_list = \
            self.check_mismatched_paths()
        mismatch_from_file = np.array(mismatch_from_file)
        mismatch_from_database = np.array(mismatch_from_database)
        self.fit_database = self.fit_database.loc[~mismatch_from_database]
        mismatched_files = [x for x, y in zip(file_list, mismatch_from_file) if y]
        for x in mismatched_files:
            os.remove(x)
        print('==== Clearing Completed ====')

    def write_updated_database(self):
        """Can be called following clear_mismatched_entries to update current database
        """
        database_backup_dir = os.path.join(
            self.model_save_base_dir, '.database_backups')
        if not os.path.exists(database_backup_dir):
            os.makedirs(database_backup_dir)
        #current_date = date.today().strftime("%m-%d-%y")
        current_date = str(datetime.now()).replace(" ", "_")
        shutil.copy(self.model_database_path,
                    os.path.join(database_backup_dir,
                                 f"database_backup_{current_date}"))
        self.fit_database.to_csv(self.model_database_path, mode='w')

    def set_run_params(self, data_dir, experiment_name, taste_num, region_name):
        """Store metadata related to inference run

        Args:
            data_dir (str): Path to directory containing HDF5 file
            experiment_name (str): Name given to fitted batch
                    (for metadata). Defaults to None.
            taste_num (int): Index of taste to perform fit on (Corresponds to
                    INDEX of taste in spike array, not actual dig_ins)
            region_name (str): Region to perform the fit on
                    (must match regions in .info file)
        """
        self.data_dir = data_dir
        self.data_basename = os.path.basename(self.data_dir)
        self.animal_name = self.data_basename.split("_")[0]
        self.session_date = self.data_basename.split("_")[-1]

        self.experiment_name = experiment_name
        self.model_save_dir = os.path.join(self.model_save_base_dir,
                                           experiment_name)

        if not os.path.exists(self.model_save_dir):
            os.makedirs(self.model_save_dir)

        self.model_id = str(uuid.uuid4()).split('-')[0]
        self.model_save_path = os.path.join(self.model_save_dir,
                                            self.experiment_name +
                                            "_" + self.model_id)
        self.fit_date = date.today().strftime("%m-%d-%y")

        self.taste_num = taste_num
        self.region_name = region_name

        self.fit_exists = None

    def ingest_fit_data(self, met_dict):
        """Load external metadata

        Args:
            met_dict (dict): Dictionary of metadata from FitHandler class
        """
        self.external_metadata = met_dict

    def aggregate_metadata(self):
        """Collects information regarding data and current "experiment"

        Raises:
            Exception: If 'external_metadata' has not been ingested, that needs to be done first

        Returns:
            dict: Dictionary of metadata given to FitHandler class
        """
        if 'external_metadata' not in dir(self):
            raise Exception('Fit run metadata needs to be ingested '
                            'into data_handler first')

        data_details = dict(zip(
            ['data_dir',
             'basename',
             'animal_name',
             'session_date',
             'taste_num',
             'region_name'],
            [self.data_dir,
             self.data_basename,
             self.animal_name,
             self.session_date,
             self.taste_num,
             self.region_name]))

        exp_details = dict(zip(
            ['exp_name', 
             'model_id',
             'save_path',
             'fit_date'],
            [self.experiment_name,
             self.model_id,
             self.model_save_path,
             self.fit_date]))

        temp_ext_met = self.external_metadata
        temp_ext_met['data'] = data_details
        temp_ext_met['exp'] = exp_details

        return temp_ext_met

    def write_to_database(self):
        """Write out metadata to database
        """
        agg_metadata = self.aggregate_metadata()
        flat_metadata = pd.json_normalize(agg_metadata)
        if not os.path.isfile(self.model_database_path):
            flat_metadata.to_csv(self.model_database_path, mode='a')
        else:
            flat_metadata.to_csv(self.model_database_path,
                                 mode='a', header=False)

    def check_exists(self):
        """Check if the given fit already exists in database

        Returns:
            bool: Boolean for whether fit already exists or not
        """
        if self.fit_exists is not None:
            return self.fit_exists
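
A maintenance sketch, assuming the model database CSV already exists at MODEL_DATABASE_PATH:

from pytau.changepoint_io import DatabaseHandler

db = DatabaseHandler()

dup_frame, dup_inds = db.show_duplicates()  # Inspect duplicated rows
db.drop_duplicates()                        # Drop them from the in-memory copy
db.clear_mismatched_paths()                 # Delete orphaned pkl files and orphaned database rows
db.write_updated_database()                 # Back up the old CSV, then overwrite it with the cleaned table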

__init__(self) special

Initialize DatabaseHandler class

Source code in pytau/changepoint_io.py
def __init__(self):
    """Initialize DatabaseHandler class
    """
    self.unique_cols = ['exp.model_id', 'exp.save_path', 'exp.fit_date']
    self.model_database_path = MODEL_DATABASE_PATH
    self.model_save_base_dir = MODEL_SAVE_DIR

    if os.path.exists(self.model_database_path):
        self.fit_database = pd.read_csv(self.model_database_path,
                                        index_col=0)
        all_na = [all(x) for num, x in self.fit_database.isna().iterrows()]
        if any(all_na):
            print(f'{sum(all_na)} rows found with all NA, removing...')
            self.fit_database = self.fit_database.dropna(how='all')
    else:
        print('Fit database does not exist yet')

aggregate_metadata(self)

Collects information regarding data and current "experiment"

Exceptions:

    Exception: If 'external_metadata' has not been ingested, that needs to be done first

Returns:

    dict: Dictionary of metadata given to FitHandler class

Source code in pytau/changepoint_io.py
def aggregate_metadata(self):
    """Collects information regarding data and current "experiment"

    Raises:
        Exception: If 'external_metadata' has not been ingested, that needs to be done first

    Returns:
        dict: Dictionary of metadata given to FitHandler class
    """
    if 'external_metadata' not in dir(self):
        raise Exception('Fit run metadata needs to be ingested '
                        'into data_handler first')

    data_details = dict(zip(
        ['data_dir',
         'basename',
         'animal_name',
         'session_date',
         'taste_num',
         'region_name'],
        [self.data_dir,
         self.data_basename,
         self.animal_name,
         self.session_date,
         self.taste_num,
         self.region_name]))

    exp_details = dict(zip(
        ['exp_name', 
         'model_id',
         'save_path',
         'fit_date'],
        [self.experiment_name,
         self.model_id,
         self.model_save_path,
         self.fit_date]))

    temp_ext_met = self.external_metadata
    temp_ext_met['data'] = data_details
    temp_ext_met['exp'] = exp_details

    return temp_ext_met

check_exists(self)

Check if the given fit already exists in database

Returns:

    bool: Boolean for whether fit already exists or not

Source code in pytau/changepoint_io.py
def check_exists(self):
    """Check if the given fit already exists in database

    Returns:
        bool: Boolean for whether fit already exists or not
    """
    if self.fit_exists is not None:
        return self.fit_exists

check_mismatched_paths(self)

Check if there are any mismatched pkl files between database and directory

Returns:

    pandas dataframe: Dataframe containing rows for which pkl file not present
    list: pkl files which cannot be matched to model in database
    list: all files in save directory

Source code in pytau/changepoint_io.py
def check_mismatched_paths(self):
    """Check if there are any mismatched pkl files between database and directory

    Returns:
        pandas dataframe: Dataframe containing rows for which pkl file not present
        list: pkl files which cannot be matched to model in database
        list: all files in save directory
    """
    mismatch_from_database = [not os.path.exists(x + ".pkl")
                              for x in self.fit_database['exp.save_path']]
    file_list = glob(os.path.join(self.model_save_base_dir, "*/*.pkl"))
    mismatch_from_file = [not
                          (x.split('.')[0] in list(
                              self.fit_database['exp.save_path']))
                          for x in file_list]
    print(f"{sum(mismatch_from_database)} mismatches from database" + "\n"
          + f"{sum(mismatch_from_file)} mismatches from files")
    return mismatch_from_database, mismatch_from_file, file_list

clear_mismatched_paths(self)

Remove mismatched files and rows in database

i.e. Remove
1) Files for which no entry can be found in database
2) Database entries for which no corresponding file can be found

Source code in pytau/changepoint_io.py
def clear_mismatched_paths(self):
    """Remove mismatched files and rows in database

    i.e. Remove
    1) Files for which no entry can be found in database
    2) Database entries for which no corresponding file can be found
    """
    mismatch_from_database, mismatch_from_file, file_list = \
        self.check_mismatched_paths()
    mismatch_from_file = np.array(mismatch_from_file)
    mismatch_from_database = np.array(mismatch_from_database)
    self.fit_database = self.fit_database.loc[~mismatch_from_database]
    mismatched_files = [x for x, y in zip(file_list, mismatch_from_file) if y]
    for x in mismatched_files:
        os.remove(x)
    print('==== Clearing Completed ====')

drop_duplicates(self)

Remove duplicated rows from database

Source code in pytau/changepoint_io.py
def drop_duplicates(self):
    """Remove duplicated rows from database
    """
    _, dup_inds = self.show_duplicates()
    print(f'Removing {sum(dup_inds)} duplicate rows')
    self.fit_database = self.fit_database.loc[~dup_inds]

ingest_fit_data(self, met_dict)

Load external metadata

Parameters:

    met_dict (dict, required): Dictionary of metadata from FitHandler class

Source code in pytau/changepoint_io.py
def ingest_fit_data(self, met_dict):
    """Load external metadata

    Args:
        met_dict (dict): Dictionary of metadata from FitHandler class
    """
    self.external_metadata = met_dict

set_run_params(self, data_dir, experiment_name, taste_num, region_name)

Store metadata related to inference run

Parameters:

    data_dir (str, required): Path to directory containing HDF5 file
    experiment_name (str, required): Name given to fitted batch (for metadata). Defaults to None.
    taste_num (int, required): Index of taste to perform fit on (Corresponds to INDEX of taste in spike array, not actual dig_ins)
    region_name (str, required): Region to perform the fit on (must match regions in .info file)

Source code in pytau/changepoint_io.py
def set_run_params(self, data_dir, experiment_name, taste_num, region_name):
    """Store metadata related to inference run

    Args:
        data_dir (str): Path to directory containing HDF5 file
        experiment_name (str): Name given to fitted batch
                (for metadata). Defaults to None.
        taste_num (int): Index of taste to perform fit on (Corresponds to
                INDEX of taste in spike array, not actual dig_ins)
        region_name (str): Region to perform the fit on
                (must match regions in .info file)
    """
    self.data_dir = data_dir
    self.data_basename = os.path.basename(self.data_dir)
    self.animal_name = self.data_basename.split("_")[0]
    self.session_date = self.data_basename.split("_")[-1]

    self.experiment_name = experiment_name
    self.model_save_dir = os.path.join(self.model_save_base_dir,
                                       experiment_name)

    if not os.path.exists(self.model_save_dir):
        os.makedirs(self.model_save_dir)

    self.model_id = str(uuid.uuid4()).split('-')[0]
    self.model_save_path = os.path.join(self.model_save_dir,
                                        self.experiment_name +
                                        "_" + self.model_id)
    self.fit_date = date.today().strftime("%m-%d-%y")

    self.taste_num = taste_num
    self.region_name = region_name

    self.fit_exists = None

show_duplicates(self, keep='first')

Find duplicates in database

Parameters:

    keep (str, optional): Which duplicate to keep (refer to pandas duplicated). Defaults to 'first'.

Returns:

    pandas dataframe: Dataframe containing duplicated rows
    pandas series: Indices of duplicated rows

Source code in pytau/changepoint_io.py
def show_duplicates(self, keep='first'):
    """Find duplicates in database

    Args:
        keep (str, optional): Which duplicate to keep
                (refer to pandas duplicated). Defaults to 'first'.

    Returns:
        pandas dataframe: Dataframe containing duplicated rows
        pandas series : Indices of duplicated rows
    """
    dup_inds = self.fit_database.drop(self.unique_cols, axis=1)\
        .duplicated(keep=keep)
    return self.fit_database.loc[dup_inds], dup_inds

write_to_database(self)

Write out metadata to database

Source code in pytau/changepoint_io.py
def write_to_database(self):
    """Write out metadata to database
    """
    agg_metadata = self.aggregate_metadata()
    flat_metadata = pd.json_normalize(agg_metadata)
    if not os.path.isfile(self.model_database_path):
        flat_metadata.to_csv(self.model_database_path, mode='a')
    else:
        flat_metadata.to_csv(self.model_database_path,
                             mode='a', header=False)

write_updated_database(self)

Can be called following clear_mismatched_paths to update the current database

Source code in pytau/changepoint_io.py
def write_updated_database(self):
    """Can be called following clear_mismatched_entries to update current database
    """
    database_backup_dir = os.path.join(
        self.model_save_base_dir, '.database_backups')
    if not os.path.exists(database_backup_dir):
        os.makedirs(database_backup_dir)
    #current_date = date.today().strftime("%m-%d-%y")
    current_date = str(datetime.now()).replace(" ", "_")
    shutil.copy(self.model_database_path,
                os.path.join(database_backup_dir,
                             f"database_backup_{current_date}"))
    self.fit_database.to_csv(self.model_database_path, mode='w')

FitHandler

Class to handle pipeline of model fitting including:
1) Loading data
2) Preprocessing loaded arrays
3) Fitting model
4) Writing out fitted parameters to pkl file
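
An end-to-end sketch of the pipeline this class wraps; the data directory, region name, and all parameter values below are hypothetical placeholders, not recommendations:

from pytau.changepoint_io import FitHandler

fit_handler = FitHandler(data_dir='/path/to/recording_dir',
                         taste_num=0,
                         region_name='gc',
                         experiment_name='example_fit')

fit_handler.set_preprocess_params(time_lims=[2000, 4000],
                                  bin_width=50,
                                  data_transform=None)
fit_handler.set_model_params(states=4, fit=40000, samples=20000)

# Runs load -> preprocess -> create_model -> run_inference as needed, then writes the pkl,
# the .info json, and a row in the model database
fit_handler.save_fit_output()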

Source code in pytau/changepoint_io.py
class FitHandler():
    """Class to handle pipeline of model fitting including:
    1) Loading data
    2) Preprocessing loaded arrays
    3) Fitting model
    4) Writing out fitted parameters to pkl file

    """

    def __init__(self,
                 data_dir,
                 taste_num,
                 region_name,
                 experiment_name=None,
                 model_params_path=None,
                 preprocess_params_path=None,
                 ):
        """Initialize FitHandler class

        Args:
            data_dir (str): Path to directory containing HDF5 file
            taste_num (int): Index of taste to perform fit on
                    (Corresponds to INDEX of taste in spike array, not actual dig_ins)
            region_name (str): Region to perform the fit on
                    (must match regions in .info file)
            experiment_name (str, optional): Name given to fitted batch
                    (for metadata). Defaults to None.
            model_params_path (str, optional): Path to json file
                    containing model parameters. Defaults to None.
            preprocess_params_path (str, optional): Path to json file
                    containing preprocessing parameters. Defaults to None.

        Raises:
            Exception: If "experiment_name" is None
            Exception: If "taste_num" is not integer or "all"
        """

        # =============== Check for exceptions ===============
        if experiment_name is None:
            raise Exception('Please specify an experiment name')
        if not (isinstance(taste_num,int) or taste_num == 'all'):
            raise Exception('taste_num must be an integer or "all"')

        # =============== Save relevant arguments ===============
        self.data_dir = data_dir
        self.EphysData = EphysData(self.data_dir)
        #self.data = self.EphysData.get_spikes({"bla","gc","all"})

        self.taste_num = taste_num
        self.region_name = region_name
        self.experiment_name = experiment_name

        data_handler_init_kwargs = dict(zip(
            ['data_dir', 'experiment_name', 'taste_num', 'region_name'],
            [data_dir, experiment_name, taste_num, region_name]))
        self.database_handler = DatabaseHandler()
        self.database_handler.set_run_params(**data_handler_init_kwargs)

        if model_params_path is None:
            print('MODEL_PARAMS will have to be set')
        else:
            self.set_model_params(file_path=model_params_path)

        if preprocess_params_path is None:
            print('PREPROCESS_PARAMS will have to be set')
        else:
            self.set_preprocess_params(file_path=preprocess_params_path)

    ########################################
    # SET PARAMS
    ########################################

    def set_preprocess_params(self,
                              time_lims,
                              bin_width,
                              data_transform,
                              file_path=None):

        """Load given params as "preprocess_params" attribute

        Args:
            time_lims (array/tuple/list): Start and end of where to cut
                    spike train array
            bin_width (int): Bin width for binning spikes to counts
            data_transform (str): Indicator for which transformation to
                    use (refer to changepoint_preprocess)
            file_path (str, optional): Path to json file containing preprocess
                    parameters. Defaults to None.
        """

        if file_path is None:
            self.preprocess_params = \
                dict(zip(['time_lims', 'bin_width', 'data_transform'],
                         [time_lims, bin_width, data_transform]))
        else:
            # Load json and save dict
            pass

    def set_model_params(self,
                         states,
                         fit,
                         samples,
                         file_path=None):

        """Load given params as "model_params" attribute

        Args:
            states (int): Number of states to use in model
            fit (int): Iterations to use for model fitting (given ADVI fit)
            samples (int): Number of samples to return from fitted model
            file_path (str, optional): Path to json file containing
                    preprocess parameters. Defaults to None.
        """

        if file_path is None:
            self.model_params = \
                dict(zip(['states', 'fit', 'samples'], [states, fit, samples]))
        else:
            # Load json and save dict
            pass

    ########################################
    # SET PIPELINE FUNCS
    ########################################

    def set_preprocessor(self, preprocessing_func):
        """Manually set preprocessor for data e.g.

        FitHandler.set_preprocessor(
                    changepoint_preprocess.preprocess_single_taste)

        Args:
            preprocessing_func (func):
                    Function to preprocess data (refer to changepoint_preprocess)
        """
        self.preprocessor = preprocessing_func

    def preprocess_selector(self):
        """Function to return preprocess function based off of input flag

        Preprocessing can be set manually but it is preferred to
        go through preprocess selector

        Raises:
            Exception: If self.taste_num is neither int nor str

        """

        if isinstance(self.taste_num, int):
            self.set_preprocessor(
                changepoint_preprocess.preprocess_single_taste)
        elif self.taste_num == 'all':
            self.set_preprocessor(changepoint_preprocess.preprocess_all_taste)
        else:
            raise Exception("Something went wrong")

    def set_model_template(self, model_template):
        """Manually set model_template for data e.g.

        FitHandler.set_model_template(changepoint_model.single_taste_poisson)

        Args:
            model_template (func): Function to generate model template for data
        """
        self.model_template = model_template

    def model_template_selector(self):
        """Function to set model based off of input flag

        Models can be set manually but it is preferred to go through model selector

        Raises:
            Exception: If self.taste_num is neither int nor str

        """
        if isinstance(self.taste_num, int):
            self.set_model_template(changepoint_model.single_taste_poisson)
        elif self.taste_num == 'all':
            self.set_model_template(changepoint_model.all_taste_poisson)
        else:
            raise Exception("Something went wrong")

    def set_inference(self, inference_func):
        """Manually set inference function for model fit e.g.

        FitHandler.set_inference(changepoint_model.advi_fit)

        Args:
            inference_func (func): Function to use for fitting model
        """
        self.inference_func = inference_func

    def inference_func_selector(self):
        """Function to return model based off of input flag

        Currently hard-coded to use "advi_fit"
        """
        self.set_inference(changepoint_model.advi_fit)

    ########################################
    # PIPELINE FUNCS
    ########################################

    def load_spike_trains(self):
        """Helper function to load spike trains from data_dir using EphysData module
        """
        full_spike_array = self.EphysData.return_region_spikes(self.region_name)
        if isinstance(self.taste_num, int):
            self.data = full_spike_array[self.taste_num]
        if self.taste_num == 'all':
            self.data = full_spike_array
        print(f'Loading spike trains from {self.database_handler.data_basename}, '
              f'dig_in {self.taste_num}')

    def preprocess_data(self):
        """Perform data preprocessing

        Will check for and complete:
        1) Raw data loaded
        2) Preprocessor selected
        """
        if 'data' not in dir(self):
            self.load_spike_trains()
        if 'preprocessor' not in dir(self):
            self.preprocess_selector()
        print('Preprocessing spike trains, '
              f'preprocessing func: <{self.preprocessor.__name__}>')
        self.preprocessed_data = \
            self.preprocessor(self.data, **self.preprocess_params)

    def create_model(self):
        """Create model and save as attribute

        Will check for and complete:
        1) Data preprocessed
        2) Model template selected
        """
        if 'preprocessed_data' not in dir(self):
            self.preprocess_data()
        if 'model_template' not in dir(self):
            self.model_template_selector()

        # In future iterations, before fitting model,
        # check that a similar entry doesn't exist

        changepoint_model.compile_wait()
        print(f'Generating Model, model func: <{self.model_template.__name__}>')
        self.model = self.model_template(self.preprocessed_data,
                                         self.model_params['states'])

    def run_inference(self):
        """Perform inference on data

        Will check for and complete:
        1) Model created
        2) Inference function selected
        """
        if 'model' not in dir(self):
            self.create_model()
        if 'inference_func' not in dir(self):
            self.inference_func_selector()

        changepoint_model.compile_wait()
        print('Running inference, inference func: '
              f'<{self.inference_func.__name__}>')
        temp_outs = self.inference_func(self.model,
                                        self.model_params['fit'],
                                        self.model_params['samples'])
        varnames = ['model', 'approx', 'lambda', 'tau', 'data']
        self.inference_outs = dict(zip(varnames, temp_outs))

    def _gen_fit_metadata(self):
        """Generate metadata for fit

        Generate metadata by compiling:
        1) Preprocess parameters given as input
        2) Model parameters given as input
        3) Functions used in inference pipeline for : preprocessing,
                model generation, fitting

        Returns:
            dict: Dictionary containing compiled metadata for different
                    parts of inference pipeline
        """
        pre_params = self.preprocess_params
        model_params = self.model_params
        pre_params['preprocessor_name'] = self.preprocessor.__name__
        model_params['model_template_name'] = self.model_template.__name__
        model_params['inference_func_name'] = self.inference_func.__name__
        fin_dict = dict(zip(['preprocess', 'model'], [pre_params, model_params]))
        return fin_dict

    def _pass_metadata_to_handler(self):
        """Function to coordinate transfer of metadata to DatabaseHandler
        """
        self.database_handler.ingest_fit_data(self._gen_fit_metadata())

    def _return_fit_output(self):
        """Compile data, model, fit, and metadata to save output

        Returns:
            dict: Dictionary containing fitted model data and metadata
        """
        self._pass_metadata_to_handler()
        agg_metadata = self.database_handler.aggregate_metadata()
        return {'model_data': self.inference_outs, 'metadata': agg_metadata}

    def save_fit_output(self):
        """Save fit output (fitted data + metadata) to pkl file
        """
        if 'inference_outs' not in dir(self):
            self.run_inference()
        out_dict = self._return_fit_output()
        with open(self.database_handler.model_save_path + '.pkl', 'wb') as buff:
            pickle.dump(out_dict, buff)

        json_file_name = os.path.join(
            self.database_handler.model_save_path + '.info')
        with open(json_file_name, 'w') as file:
            json.dump(out_dict['metadata'], file, indent=4)

        self.database_handler.write_to_database()

        print('Saving inference output to : \n'
                f'{self.database_handler.model_save_dir}' 
                "\n" + "================================" + '\n')

__init__(self, data_dir, taste_num, region_name, experiment_name=None, model_params_path=None, preprocess_params_path=None) special

Initialize FitHandler class

Parameters:

    data_dir (str, required): Path to directory containing HDF5 file
    taste_num (int, required): Index of taste to perform fit on (Corresponds to INDEX of taste in spike array, not actual dig_ins)
    region_name (str, required): Region to perform the fit on (must match regions in .info file)
    experiment_name (str, optional): Name given to fitted batch (for metadata). Defaults to None.
    model_params_path (str, optional): Path to json file containing model parameters. Defaults to None.
    preprocess_params_path (str, optional): Path to json file containing preprocessing parameters. Defaults to None.

Exceptions:

    Exception: If "experiment_name" is None
    Exception: If "taste_num" is not integer or "all"

Source code in pytau/changepoint_io.py
def __init__(self,
             data_dir,
             taste_num,
             region_name,
             experiment_name=None,
             model_params_path=None,
             preprocess_params_path=None,
             ):
    """Initialize FitHandler class

    Args:
        data_dir (str): Path to directory containing HDF5 file
        taste_num (int): Index of taste to perform fit on
                (Corresponds to INDEX of taste in spike array, not actual dig_ins)
        region_name (str): Region to perform the fit on
                (must match regions in .info file)
        experiment_name (str, optional): Name given to fitted batch
                (for metadata). Defaults to None.
        model_params_path (str, optional): Path to json file
                containing model parameters. Defaults to None.
        preprocess_params_path (str, optional): Path to json file
                containing preprocessing parameters. Defaults to None.

    Raises:
        Exception: If "experiment_name" is None
        Exception: If "taste_num" is not integer or "all"
    """

    # =============== Check for exceptions ===============
    if experiment_name is None:
        raise Exception('Please specify an experiment name')
    if not (isinstance(taste_num,int) or taste_num == 'all'):
        raise Exception('taste_num must be an integer or "all"')

    # =============== Save relevant arguments ===============
    self.data_dir = data_dir
    self.EphysData = EphysData(self.data_dir)
    #self.data = self.EphysData.get_spikes({"bla","gc","all"})

    self.taste_num = taste_num
    self.region_name = region_name
    self.experiment_name = experiment_name

    data_handler_init_kwargs = dict(zip(
        ['data_dir', 'experiment_name', 'taste_num', 'region_name'],
        [data_dir, experiment_name, taste_num, region_name]))
    self.database_handler = DatabaseHandler()
    self.database_handler.set_run_params(**data_handler_init_kwargs)

    if model_params_path is None:
        print('MODEL_PARAMS will have to be set')
    else:
        self.set_model_params(file_path=model_params_path)

    if preprocess_params_path is None:
        print('PREPROCESS_PARAMS will have to be set')
    else:
        self.set_preprocess_params(file_path=preprocess_params_path)

create_model(self)

Create model and save as attribute

Will check for and complete:
1) Data preprocessed
2) Model template selected

Source code in pytau/changepoint_io.py
def create_model(self):
    """Create model and save as attribute

    Will check for and complete:
    1) Data preprocessed
    2) Model template selected
    """
    if 'preprocessed_data' not in dir(self):
        self.preprocess_data()
    if 'model_template' not in dir(self):
        self.model_template_selector()

    # In future iterations, before fitting model,
    # check that a similar entry doesn't exist

    changepoint_model.compile_wait()
    print(f'Generating Model, model func: <{self.model_template.__name__}>')
    self.model = self.model_template(self.preprocessed_data,
                                     self.model_params['states'])

inference_func_selector(self)

Function to return model based off of input flag

Currently hard-coded to use "advi_fit"

Source code in pytau/changepoint_io.py
def inference_func_selector(self):
    """Function to return model based off of input flag

    Currently hard-coded to use "advi_fit"
    """
    self.set_inference(changepoint_model.advi_fit)

load_spike_trains(self)

Helper function to load spike trains from data_dir using EphysData module

Source code in pytau/changepoint_io.py
def load_spike_trains(self):
    """Helper function to load spike trains from data_dir using EphysData module
    """
    full_spike_array = self.EphysData.return_region_spikes(self.region_name)
    if isinstance(self.taste_num, int):
        self.data = full_spike_array[self.taste_num]
    if self.taste_num == 'all':
        self.data = full_spike_array
    print(f'Loading spike trains from {self.database_handler.data_basename}, '
          f'dig_in {self.taste_num}')

model_template_selector(self)

Function to set model based off of input flag

Models can be set manually but it is preferred to go through model selector

Exceptions:

    Exception: If self.taste_num is neither int nor str

Source code in pytau/changepoint_io.py
def model_template_selector(self):
    """Function to set model based off of input flag

    Models can be set manually but it is preferred to go through model selector

    Raises:
        Exception: If self.taste_num is neither int nor str

    """
    if isinstance(self.taste_num, int):
        self.set_model_template(changepoint_model.single_taste_poisson)
    elif self.taste_num == 'all':
        self.set_model_template(changepoint_model.all_taste_poisson)
    else:
        raise Exception("Something went wrong")

preprocess_data(self)

Perform data preprocessing

Will check for and complete:
1) Raw data loaded
2) Preprocessor selected

Source code in pytau/changepoint_io.py
def preprocess_data(self):
    """Perform data preprocessing

    Will check for and complete:
    1) Raw data loaded
    2) Preprocessor selected
    """
    if 'data' not in dir(self):
        self.load_spike_trains()
    if 'preprocessor' not in dir(self):
        self.preprocess_selector()
    print('Preprocessing spike trains, '
          f'preprocessing func: <{self.preprocessor.__name__}>')
    self.preprocessed_data = \
        self.preprocessor(self.data, **self.preprocess_params)

preprocess_selector(self)

Function to return preprocess function based off of input flag

Preprocessing can be set manually but it is preferred to go through preprocess selector

Exceptions:

    Exception: If self.taste_num is neither int nor str

Source code in pytau/changepoint_io.py
def preprocess_selector(self):
    """Function to return preprocess function based off of input flag

    Preprocessing can be set manually but it is preferred to
    go through preprocess selector

    Raises:
        Exception: If self.taste_num is neither int nor str

    """

    if isinstance(self.taste_num, int):
        self.set_preprocessor(
            changepoint_preprocess.preprocess_single_taste)
    elif self.taste_num == 'all':
        self.set_preprocessor(changepoint_preprocess.preprocess_all_taste)
    else:
        raise Exception("Something went wrong")

run_inference(self)

Perform inference on data

Will check for and complete:
1) Model created
2) Inference function selected

Source code in pytau/changepoint_io.py
def run_inference(self):
    """Perform inference on data

    Will check for and complete:
    1) Model created
    2) Inference function selected
    """
    if 'model' not in dir(self):
        self.create_model()
    if 'inference_func' not in dir(self):
        self.inference_func_selector()

    changepoint_model.compile_wait()
    print('Running inference, inference func: '
          f'<{self.inference_func.__name__}>')
    temp_outs = self.inference_func(self.model,
                                    self.model_params['fit'],
                                    self.model_params['samples'])
    varnames = ['model', 'approx', 'lambda', 'tau', 'data']
    self.inference_outs = dict(zip(varnames, temp_outs))

save_fit_output(self)

Save fit output (fitted data + metadata) to pkl file

Source code in pytau/changepoint_io.py
def save_fit_output(self):
    """Save fit output (fitted data + metadata) to pkl file
    """
    if 'inference_outs' not in dir(self):
        self.run_inference()
    out_dict = self._return_fit_output()
    with open(self.database_handler.model_save_path + '.pkl', 'wb') as buff:
        pickle.dump(out_dict, buff)

    json_file_name = os.path.join(
        self.database_handler.model_save_path + '.info')
    with open(json_file_name, 'w') as file:
        json.dump(out_dict['metadata'], file, indent=4)

    self.database_handler.write_to_database()

    print('Saving inference output to : \n'
            f'{self.database_handler.model_save_dir}' 
            "\n" + "================================" + '\n')

set_inference(self, inference_func)

Manually set inference function for model fit e.g.

FitHandler.set_inference(changepoint_model.advi_fit)

Parameters:

    inference_func (func, required): Function to use for fitting model

Source code in pytau/changepoint_io.py
def set_inference(self, inference_func):
    """Manually set inference function for model fit e.g.

    FitHandler.set_inference(changepoint_model.advi_fit)

    Args:
        inference_func (func): Function to use for fitting model
    """
    self.inference_func = inference_func

set_model_params(self, states, fit, samples, file_path=None)

Load given params as "model_params" attribute

Parameters:

    states (int, required): Number of states to use in model
    fit (int, required): Iterations to use for model fitting (given ADVI fit)
    samples (int, required): Number of samples to return from fitted model
    file_path (str, optional): Path to json file containing preprocess parameters. Defaults to None.

Source code in pytau/changepoint_io.py
def set_model_params(self,
                     states,
                     fit,
                     samples,
                     file_path=None):

    """Load given params as "model_params" attribute

    Args:
        states (int): Number of states to use in model
        fit (int): Iterations to use for model fitting (given ADVI fit)
        samples (int): Number of samples to return from fitted model
        file_path (str, optional): Path to json file containing
                preprocess parameters. Defaults to None.
    """

    if file_path is None:
        self.model_params = \
            dict(zip(['states', 'fit', 'samples'], [states, fit, samples]))
    else:
        # Load json and save dict
        pass

set_model_template(self, model_template)

Manually set model_template for data e.g.

FitHandler.set_model_template(changepoint_model.single_taste_poisson)

Parameters:

    model_template (func, required): Function to generate model template for data

Source code in pytau/changepoint_io.py
def set_model_template(self, model_template):
    """Manually set model_template for data e.g.

    FitHandler.set_model_template(changepoint_model.single_taste_poisson)

    Args:
        model_template (func): Function to generate model template for data
    """
    self.model_template = model_template

set_preprocess_params(self, time_lims, bin_width, data_transform, file_path=None)

Load given params as "preprocess_params" attribute

Parameters:

    time_lims (array/tuple/list, required): Start and end of where to cut spike train array
    bin_width (int, required): Bin width for binning spikes to counts
    data_transform (str, required): Indicator for which transformation to use (refer to changepoint_preprocess)
    file_path (str, optional): Path to json file containing preprocess parameters. Defaults to None.

Source code in pytau/changepoint_io.py
def set_preprocess_params(self,
                          time_lims,
                          bin_width,
                          data_transform,
                          file_path=None):

    """Load given params as "preprocess_params" attribute

    Args:
        time_lims (array/tuple/list): Start and end of where to cut
                spike train array
        bin_width (int): Bin width for binning spikes to counts
        data_transform (str): Indicator for which transformation to
                use (refer to changepoint_preprocess)
        file_path (str, optional): Path to json file containing preprocess
                parameters. Defaults to None.
    """

    if file_path is None:
        self.preprocess_params = \
            dict(zip(['time_lims', 'bin_width', 'data_transform'],
                     [time_lims, bin_width, data_transform]))
    else:
        # Load json and save dict
        pass

set_preprocessor(self, preprocessing_func)

Manually set preprocessor for data e.g.

FitHandler.set_preprocessor( changepoint_preprocess.preprocess_single_taste)

Parameters:

    preprocessing_func (func, required): Function to preprocess data (refer to changepoint_preprocess)

Source code in pytau/changepoint_io.py
def set_preprocessor(self, preprocessing_func):
    """Manually set preprocessor for data e.g.

    FitHandler.set_preprocessor(
                changepoint_preprocess.preprocess_single_taste)

    Args:
        preprocessing_func (func):
                Function to preprocess data (refer to changepoint_preprocess)
    """
    self.preprocessor = preprocessing_func

=== Model building functions ===

PyMC3 Blackbox Variational Inference implementation of Poisson Likelihood Changepoint for spike trains.

advi_fit(model, fit, samples)

Convenience function to perform ADVI fit on model

Parameters:

    model (pymc3 model, required): model object to run inference on
    fit (int, required): Number of iterations to fit the model for
    samples (int, required): Number of samples to draw from fitted model

Returns:

    model: original model on which inference was run
    approx: fitted model
    lambda_stack: array containing lambda (emission) values
    tau_samples: array containing samples from changepoint distribution
    model.obs.observations: processed array on which fit was run

Source code in pytau/changepoint_model.py
def advi_fit(model, fit, samples):
    """Convenience function to perform ADVI fit on model

    Args:
        model (pymc3 model): model object to run inference on
        fit (int): Number of iterations to fit the model for
        samples (int): Number of samples to draw from fitted model

    Returns:
        model: original model on which inference was run,
        approx: fitted model,
        lambda_stack: array containing lambda (emission) values,
        tau_samples: array containing samples from changepoint distribution
        model.obs.observations: processed array on which fit was run
    """

    with model:
        inference = pm.ADVI('full-rank')
        approx = pm.fit(n=fit, method=inference)
        trace = approx.sample(draws=samples)

    # Extract relevant variables from trace
    lambda_stack = trace['lambda'].swapaxes(0, 1)
    tau_samples = trace['tau']

    return model, approx, lambda_stack, tau_samples, model.obs.observations
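
A minimal end-to-end sketch on synthetic data (assumes pymc3/theano are installed; shapes and iteration counts are deliberately tiny and purely illustrative, a real fit would use far more iterations):

import numpy as np
from pytau import changepoint_model

binned_spikes = np.random.poisson(1, (5, 4, 40))  # trials x neurons x time bins

model = changepoint_model.single_taste_poisson(binned_spikes, states=3)
model, approx, lambda_stack, tau_samples, fit_data = \
    changepoint_model.advi_fit(model, fit=1000, samples=100)

print(tau_samples.shape)  # (samples, trials, states - 1) for the single-taste model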

all_taste_poisson(spike_array, states)

Model to fit changepoint to all tastes. Largely taken from "_v1/poisson_all_tastes_changepoint_model.py"

spike_array :: Shape : tastes, trials, neurons, time_bins
states :: number of states to include in the model

Source code in pytau/changepoint_model.py
def all_taste_poisson(
        spike_array,
        states):

    """
    ** Model to fit changepoint to all tastes **
    ** Largely taken from "_v1/poisson_all_tastes_changepoint_model.py"

    spike_array :: Shape : tastes, trials, neurons, time_bins
    states :: number of states to include in the model 
    """

    # If model already doesn't exist, then create new one
    #spike_array = this_dat_binned
    # Unroll arrays along taste axis
    #spike_array_long = np.reshape(spike_array,(-1,*spike_array.shape[-2:]))
    spike_array_long = np.concatenate(spike_array, axis=0) 

    # Find mean firing for initial values
    tastes = spike_array.shape[0]
    length = spike_array.shape[-1]
    nrns = spike_array.shape[2]
    trials = spike_array.shape[1]

    split_list = np.array_split(spike_array,states,axis=-1)
    # Cut all to the same size
    min_val = min([x.shape[-1] for x in split_list])
    split_array = np.array([x[...,:min_val] for x in split_list])
    mean_vals = np.mean(split_array,axis=(2,-1)).swapaxes(0,1)
    mean_vals += 0.01 # To avoid zero starting prob
    mean_nrn_vals = np.mean(mean_vals,axis=(0,1))

    # Find evenly spaced switchpoints for initial values
    idx = np.arange(spike_array.shape[-1]) # Index
    array_idx = np.broadcast_to(idx, spike_array_long.shape)
    even_switches = np.linspace(0,idx.max(),states+1)
    even_switches_normal = even_switches/np.max(even_switches)

    taste_label = np.repeat(np.arange(spike_array.shape[0]), spike_array.shape[1])
    trial_num = array_idx.shape[0]

    # Begin constructing model
    with pm.Model() as model:

        # Hierarchical firing rates
        # Refer to model diagram
        # Mean firing rate of neuron AT ALL TIMES
        lambda_nrn = pm.Exponential('lambda_nrn',
                                    1/mean_nrn_vals,
                                    shape=(mean_vals.shape[-1]))
        # Priors for each state, derived from each neuron
        # Mean firing rate of neuron IN EACH STATE (averaged across tastes)
        lambda_state = pm.Exponential('lambda_state',
                                      lambda_nrn,
                                      shape=(mean_vals.shape[1:]))
        # Mean firing rate of neuron PER STATE PER TASTE
        lambda_latent = pm.Exponential('lambda',
                                       lambda_state[np.newaxis, :, :],
                                       testval=mean_vals,
                                       shape=(mean_vals.shape))

        # Changepoint time variable
        # INDEPENDENT TAU FOR EVERY TRIAL
        a = pm.HalfNormal('a_tau', 3., shape=states - 1)
        b = pm.HalfNormal('b_tau', 3., shape=states - 1)

        # Stack produces states x trials --> That gets transposed
        # to trials x states and gets sorted along states (axis=-1)
        # Sort should work the same way as the Ordered transform -->
        # see rv_sort_test.ipynb
        tau_latent = pm.Beta('tau_latent', a, b,
                             shape=(trial_num, states-1),
                             testval=tt.tile(even_switches_normal[1:(states)],
                                             (array_idx.shape[0], 1))).sort(axis=-1)

        tau = pm.Deterministic('tau',
                               idx.min() + (idx.max() - idx.min()) * tau_latent)

        weight_stack = tt.nnet.sigmoid(
            idx[np.newaxis, :]-tau[:, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((tastes*trials, 1, length)), weight_stack], axis=1)
        inverse_stack = 1 - weight_stack[:, 1:]
        inverse_stack = tt.concatenate([inverse_stack, np.ones(
            (tastes*trials, 1, length))], axis=1)
        weight_stack = weight_stack*inverse_stack
        weight_stack = tt.tile(weight_stack[:, :, None, :], (1, 1, nrns, 1))

        lambda_latent = lambda_latent.dimshuffle(2, 0, 1)
        lambda_latent = tt.repeat(lambda_latent, trials, axis=1)
        lambda_latent = tt.tile(
            lambda_latent[..., None], (1, 1, 1, length))
        lambda_latent = lambda_latent.dimshuffle(1, 2, 0, 3)
        lambda_ = tt.sum(lambda_latent * weight_stack, axis=1)

        observation = pm.Poisson("obs", lambda_, observed=spike_array_long)

    return model
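
A construction-only sketch with a synthetic 4D array; values are illustrative and assume pymc3/theano are installed:

import numpy as np
from pytau.changepoint_model import all_taste_poisson

spike_array = np.random.poisson(1, (4, 10, 6, 50))  # tastes x trials x neurons x time_bins
model = all_taste_poisson(spike_array, states=4)    # trials are unrolled across tastes internally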

compile_wait()

Function to allow waiting while a model is already fitting.
Wait twice because lock blips out between steps.
10 secs of waiting shouldn't be a problem for long fits (~mins).
And wait a random time in the beginning to stagger fits.

Source code in pytau/changepoint_model.py
def compile_wait():
    """
    Function to allow waiting while a model is already fitting
    Wait twice because lock blips out between steps
    10 secs of waiting shouldn't be a problem for long fits (~mins)
    And wait a random time in the beginning to stagger fits
    """
    time.sleep(np.random.random()*10)
    while theano_lock_present():
        print('Lock present...waiting')
        time.sleep(10)
    while theano_lock_present():
        print('Lock present...waiting')
        time.sleep(10)

single_taste_poisson(spike_array, states)

Model for changepoint on single taste

Largely taken from "non_hardcoded_changepoint_test_3d.ipynb"
Note: This model does not have hierarchical structure for emissions

Parameters:

    spike_array (3D Numpy array, required): trials x neurons x time
    states (int, required): Number of states to model

Returns:

    pymc3 model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def single_taste_poisson(
        spike_array,
        states):
    """Model for changepoint on single taste

    ** Largely taken from "non_hardcoded_changepoint_test_3d.ipynb"
    ** Note : This model does not have hierarchical structure for emissions

    Args:
        spike_array (3D Numpy array): trials x neurons x time
        states (int): Number of states to model

    Returns:
        pymc3 model: Model class containing graph to run inference on
    """

    mean_vals = np.array([np.mean(x, axis=-1)
                          for x in np.array_split(spike_array, states, axis=-1)]).T
    mean_vals = np.mean(mean_vals, axis=1)
    mean_vals += 0.01  # To avoid zero starting prob

    nrns = spike_array.shape[1]
    trials = spike_array.shape[0]
    idx = np.arange(spike_array.shape[-1])
    length = idx.max() + 1

    with pm.Model() as model:
        lambda_latent = pm.Exponential('lambda',
                                       1/mean_vals,
                                       shape=(nrns, states))

        a_tau = pm.HalfCauchy('a_tau', 3., shape=states - 1)
        b_tau = pm.HalfCauchy('b_tau', 3., shape=states - 1)

        even_switches = np.linspace(0, 1, states+1)[1:-1]
        tau_latent = pm.Beta('tau_latent', a_tau, b_tau,
                             testval=even_switches,
                             shape=(trials, states-1)).sort(axis=-1)

        tau = pm.Deterministic('tau',
                               idx.min() + (idx.max() - idx.min()) * tau_latent)

        weight_stack = tt.nnet.sigmoid(
                        idx[np.newaxis, :]-tau[:, :, np.newaxis])
        weight_stack = tt.concatenate(
                        [np.ones((trials, 1, length)), weight_stack], axis=1)
        inverse_stack = 1 - weight_stack[:, 1:]
        inverse_stack = tt.concatenate(
                        [inverse_stack, np.ones((trials, 1, length))], axis=1)
        weight_stack = weight_stack * inverse_stack

        lambda_ = tt.tensordot(weight_stack, lambda_latent, [1, 1]).swapaxes(1, 2)
        observation = pm.Poisson("obs", lambda_, observed=spike_array)

    return model
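
A minimal sketch of building and fitting this model on toy data with pymc3's ADVI interface. The data shapes, iteration count, and the direct pm.fit call are assumptions for illustration; pytau's own fitting helpers may wrap this differently.

import numpy as np
import pymc3 as pm
from pytau.changepoint_model import single_taste_poisson

rng = np.random.default_rng(0)
spike_array = rng.poisson(1.0, size=(15, 8, 120))   # toy counts: trials x neurons x time

model = single_taste_poisson(spike_array, states=3)

with model:
    approx = pm.fit(n=20000, method='advi')   # mean-field variational fit
    trace = approx.sample(500)                # samples from the fitted approximation

tau_samples = trace['tau']   # (500, trials, states-1) changepoint positions in bin units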

theano_lock_present()

Check if theano compilation lock is present

Source code in pytau/changepoint_model.py
def theano_lock_present():
    """
    Check if theano compilation lock is present
    """
    return os.path.exists(os.path.join(theano.config.compiledir, 'lock_dir'))
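
A sketch of how compile_wait and theano_lock_present fit into a parallel-fitting workflow: each worker waits for the shared compiledir lock before triggering a new compilation. The fit_single_model wrapper and the ADVI call below are hypothetical, not part of pytau.

import pymc3 as pm
from pytau.changepoint_model import compile_wait, single_taste_poisson

def fit_single_model(spike_array, states):
    """Hypothetical driver for one of several parallel fit processes."""
    compile_wait()                                 # staggered start, then wait out the lock
    model = single_taste_poisson(spike_array, states)
    with model:
        approx = pm.fit(n=20000, method='advi')    # theano compilation (and the lock) happens here
    return approx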

=== Preprocessing functions ===

Code to preprocess spike trains before feeding into model

preprocess_all_taste(spike_array, time_lims, bin_width, data_transform)

Preprocess array containing trials for all tastes (in blocks) concatenated

Parameters:

Name Type Description Default
spike_array 4D Numpy Array

Taste x Trials x Neurons x Time

required
time_lims List/Tuple/Numpy Array

2-element object indicating limits of array

required
bin_width int

Width to use for binning

required
data_transform str

Transform to apply before binning: one of {None, 'None', 'trial_shuffled', 'spike_shuffled', 'simulated'}

required

Exceptions:

Type Description
Exception

If data_transform does not belong to ['trial_shuffled', 'spike_shuffled', 'simulated', 'None', None]

Returns:

Type Description
(4D Numpy Array)

Of processed data

Source code in pytau/changepoint_preprocess.py
def preprocess_all_taste(spike_array,
                         time_lims,
                         bin_width,
                         data_transform):
    """Preprocess array containing trials for all tastes (in blocks) concatenated

    Args:
        spike_array (4D Numpy Array): Taste x Trials x Neurons x Time
        time_lims (List/Tuple/Numpy Array): 2-element object indicating limits of array
        bin_width (int): Width to use for binning
        data_transform (str): Transform to apply before binning {None, 'None', 'trial_shuffled', 'spike_shuffled', 'simulated'}

    Raises:
        Exception: If data_transform does not belong to ['trial_shuffled', 'spike_shuffled', 'simulated', 'None', None]

    Returns:
        (4D Numpy Array): Of processed data
    """

    accepted_transforms = ['trial_shuffled','spike_shuffled','simulated','None',None]
    if data_transform not in accepted_transforms:
        raise Exception(
            f'data_transform must be of type {accepted_transforms}')

    ##################################################
    # Create shuffled data
    ##################################################
    # Shuffle neurons across trials FOR SAME TASTE

    if data_transform == 'trial_shuffled':
        transformed_dat = np.array([np.random.permutation(neuron) \
                    for neuron in np.swapaxes(spike_array,2,0)])
        transformed_dat = np.swapaxes(transformed_dat,0,2)

    elif data_transform == 'spike_shuffled':
        transformed_dat = spike_array.swapaxes(-1,0) 
        transformed_dat = np.stack([np.random.permutation(x) \
                for x in transformed_dat])
        transformed_dat = transformed_dat.swapaxes(0,-1) 

    ##################################################
    # Create simulated data
    ##################################################
    # Inhomogeneous poisson process using mean firing rates

    elif data_transform == 'simulated':
        mean_firing = np.mean(spike_array,axis=1)
        mean_firing = np.broadcast_to(mean_firing[:,None], spike_array.shape)

        # Simulate spikes
        transformed_dat = (np.random.random(spike_array.shape) < mean_firing )*1

    ##################################################
    # Null Transform Case
    ##################################################
    elif data_transform in (None, 'None'):
        transformed_dat = spike_array

    ##################################################
    # Bin Data
    ##################################################
    spike_binned = \
            np.sum(transformed_dat[...,time_lims[0]:time_lims[1]].\
            reshape(*transformed_dat.shape[:-1],-1,bin_width),axis=-1)
    spike_binned = spike_binned.astype(int)  # ensure int dtype (np.int is removed in recent numpy)

    return spike_binned
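
For example, with a toy 4D spike array (shapes and values below are assumptions for illustration), restricting to a 2000 ms window and binning at 50 ms yields 2000 / 50 = 40 bins per trial:

import numpy as np
from pytau.changepoint_preprocess import preprocess_all_taste

# Toy data: 4 tastes x 30 trials x 10 neurons x 7000 time points of 0/1 spikes
spikes = (np.random.random((4, 30, 10, 7000)) < 0.005).astype(int)

binned = preprocess_all_taste(spikes,
                              time_lims=[2000, 4000],
                              bin_width=50,
                              data_transform=None)   # no shuffling or simulation

print(binned.shape)   # (4, 30, 10, 40)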

preprocess_single_taste(spike_array, time_lims, bin_width, data_transform)

Preprocess array containing trials for a single taste

** Note, it may be useful to use x-arrays here to keep track of coordinates

Parameters:

Name Type Description Default
spike_array 3D Numpy array

trials x neurons x time

required
time_lims List/Tuple/Numpy Array

2-element object indicating limits of array

required
bin_width int

Width to use for binning

required
data_transform str

Transform to apply before binning: one of {None, 'None', 'shuffled', 'simulated'}

required

Exceptions:

Type Description
Exception

If transforms do not belong to ['shuffled','simulated','None',None]

Returns:

Type Description
(3D Numpy Array)

Of processed data

Source code in pytau/changepoint_preprocess.py
def preprocess_single_taste(spike_array,
                            time_lims,
                            bin_width,
                            data_transform):
    """Preprocess array containing trials for all tastes (in blocks) concatenated

    ** Note, it may be useful to use x-arrays here to keep track of coordinates

    Args:
        spike_array (3D Numpy array): trials x neurons x time
        time_lims (List/Tuple/Numpy Array): 2-element object indicating limits of array
        bin_width (int): Width to use for binning
        data_transform (str): Transform to apply before binning {None, 'None', 'shuffled', 'simulated'}

    Raises:
        Exception: If transforms do not belong to ['shuffled','simulated','None',None]

    Returns:
        (3D Numpy Array): Of processed data
    """

    accepted_transforms = ['shuffled', 'simulated', 'None', None]
    if data_transform not in accepted_transforms:
        raise Exception(
            f'data_transform must be of type {accepted_transforms}')

    ##################################################
    # Create shuffled data
    ##################################################
    # Shuffle neurons across trials FOR SAME TASTE

    if data_transform == 'shuffled':
        transformed_dat = np.array([np.random.permutation(neuron) \
                                    for neuron in np.swapaxes(spike_array, 1, 0)])
        transformed_dat = np.swapaxes(transformed_dat, 0, 1)

    ##################################################
    # Create simulated data
    ##################################################
    # Inhomogeneous poisson process using mean firing rates

    elif data_transform == 'simulated':
        mean_firing = np.mean(spike_array, axis=0)

        # Simulate spikes
        transformed_dat = np.array(
            [np.random.random(mean_firing.shape) < mean_firing
             for trial in range(spike_array.shape[0])])*1

    ##################################################
    # Null Transform Case
    ##################################################
    elif data_transform in (None, 'None'):
        transformed_dat = spike_array

    ##################################################
    # Bin Data
    ##################################################
    spike_binned = \
        np.sum(transformed_dat[..., time_lims[0]:time_lims[1]].
               reshape(*spike_array.shape[:-1], -1, bin_width), axis=-1)
    spike_binned = spike_binned.astype(int)  # ensure int dtype (np.int is removed in recent numpy)

    return spike_binned
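
A similar sketch for the single-taste version, using the 'shuffled' transform (toy shapes, assumed for illustration). Shuffling only permutes trials within each neuron, so per-neuron spike totals in the binned window are unchanged:

import numpy as np
from pytau.changepoint_preprocess import preprocess_single_taste

spikes = (np.random.random((30, 10, 7000)) < 0.005).astype(int)   # trials x neurons x time

binned   = preprocess_single_taste(spikes, [2000, 4000], 50, None)
shuffled = preprocess_single_taste(spikes, [2000, 4000], 50, 'shuffled')

print(shuffled.shape)                                                # (30, 10, 40)
print((shuffled.sum(axis=(0, 2)) == binned.sum(axis=(0, 2))).all())  # True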