References

=== Analysis functions ===

Helper classes and functions to perform analysis on fitted models

PklHandler

Helper class to handle metadata and fit data from pkl file

Source code in pytau/changepoint_analysis.py
class PklHandler():
    """Helper class to handle metadata and fit data from pkl file
    """

    def __init__(self, file_path):
        """Initialize PklHandler class

        Args:
            file_path (str): Path to pkl file
        """
        self.dir_name = os.path.dirname(file_path)
        file_name = os.path.basename(file_path)
        self.file_name_base = file_name.split('.')[0]
        self.pkl_file_path = \
            os.path.join(self.dir_name, self.file_name_base + ".pkl")
        with open(self.pkl_file_path, 'rb') as this_file:
            self.data = pkl.load(this_file)

        model_keys = ['model', 'approx', 'lambda', 'tau', 'data']
        key_savenames = ['_model_structure', '_fit_model',
                         'lambda_array', 'tau_array', 'processed_spikes']
        data_map = dict(zip(model_keys, key_savenames))

        for key, var_name in data_map.items():
            setattr(self, var_name, self.data['model_data'][key])

        self.metadata = self.data['metadata']
        self.pretty_metadata = pd.json_normalize(self.data['metadata']).T

        self.tau = _tau(self.tau_array, self.metadata)
        self.firing = _firing(self.tau, self.processed_spikes, self.metadata)
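
A minimal usage sketch (the path below is hypothetical; attribute names follow the mapping defined in __init__ above):

from pytau.changepoint_analysis import PklHandler

# Hypothetical path to a fit written out by FitHandler.save_fit_output()
handler = PklHandler('/path/to/saved_fits/example_fit_abc123.pkl')

print(handler.pretty_metadata)      # flattened metadata as a pandas DataFrame
tau_samples = handler.tau_array     # raw tau samples from the fit
spikes = handler.processed_spikes   # preprocessed spike data used for the fit
scaled_tau = handler.tau            # _tau helper built from tau_array and metadata
firing = handler.firing             # _firing helper built from tau, spikes, and metadata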

__init__(file_path)

Initialize PklHandler class

Parameters:

    file_path (str): Path to pkl file. Required.

Source code in pytau/changepoint_analysis.py
def __init__(self, file_path):
    """Initialize PklHandler class

    Args:
        file_path (str): Path to pkl file
    """
    self.dir_name = os.path.dirname(file_path)
    file_name = os.path.basename(file_path)
    self.file_name_base = file_name.split('.')[0]
    self.pkl_file_path = \
        os.path.join(self.dir_name, self.file_name_base + ".pkl")
    with open(self.pkl_file_path, 'rb') as this_file:
        self.data = pkl.load(this_file)

    model_keys = ['model', 'approx', 'lambda', 'tau', 'data']
    key_savenames = ['_model_structure', '_fit_model',
                     'lambda_array', 'tau_array', 'processed_spikes']
    data_map = dict(zip(model_keys, key_savenames))

    for key, var_name in data_map.items():
        setattr(self, var_name, self.data['model_data'][key])

    self.metadata = self.data['metadata']
    self.pretty_metadata = pd.json_normalize(self.data['metadata']).T

    self.tau = _tau(self.tau_array, self.metadata)
    self.firing = _firing(self.tau, self.processed_spikes, self.metadata)

calc_significant_neurons_firing(state_firing, p_val=0.05)

Calculate significant changes in firing rate between states by iterating ANOVA over neurons for all states, with Bonferroni correction.

Args:

    state_firing (3D Numpy array): trials x states x nrns
    p_val (float, optional): p-value to use for significance. Defaults to 0.05.

Returns:

    anova_p_val_array (1D Numpy array): p-values for each neuron
    anova_sig_neurons (1D Numpy array): indices of significant neurons

Source code in pytau/changepoint_analysis.py
def calc_significant_neurons_firing(state_firing, p_val=0.05):
    """Calculate significant changes in firing rate between states
    Iterate ANOVA over neurons for all states
    With Bonferroni correction

    Args:
        state_firing (3D Numpy array): trials x states x nrns
        p_val (float, optional): p-value to use for significance. Defaults to 0.05.

    Returns:
        anova_p_val_array (1D Numpy array): p-values for each neuron
        anova_sig_neurons (1D Numpy array): indices of significant neurons
    """
    n_neurons = state_firing.shape[-1]
    # Calculate ANOVA p-values for each neuron
    anova_p_val_array = np.zeros(state_firing.shape[-1])
    for neuron in range(state_firing.shape[-1]):
        anova_p_val_array[neuron] = f_oneway(*state_firing[:, :, neuron].T)[1]
    anova_sig_neurons = np.where(anova_p_val_array < p_val/n_neurons)[0]

    return anova_p_val_array, anova_sig_neurons
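
A hedged usage sketch on synthetic data (array shape follows the docstring, trials x states x nrns; all numbers are arbitrary):

import numpy as np
from pytau.changepoint_analysis import calc_significant_neurons_firing

rng = np.random.default_rng(0)
# Synthetic firing rates: 30 trials, 4 states, 10 neurons
state_firing = rng.poisson(5, size=(30, 4, 10)).astype(float)

p_vals, sig_neurons = calc_significant_neurons_firing(state_firing, p_val=0.05)
print(p_vals.shape)    # (10,) -> one ANOVA p-value per neuron
print(sig_neurons)     # neuron indices passing the Bonferroni-corrected threshold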

calc_significant_neurons_snippets(transition_snips, p_val=0.05)

Calculate pairwise t-tests to detect differences across each transition, with Bonferroni correction.

Args:

    transition_snips (4D Numpy array): trials x nrns x bins x transitions
    p_val (float, optional): p-value to use for significance. Defaults to 0.05.

Returns:

    pairwise_p_val_array (2D Numpy array): neurons x transitions, p-values for each comparison
    pairwise_sig_neurons (2D Numpy array): neurons x transitions, boolean mask of significant comparisons

Source code in pytau/changepoint_analysis.py
def calc_significant_neurons_snippets(transition_snips, p_val=0.05):
    """Calculate pairwise t-tests to detect differences between each transition 
    With Bonferroni correction

    Args:
        transition_snips (4D Numpy array): trials x nrns x bins x transitions
        p_val (float, optional): p-value to use for significance. Defaults to 0.05.

    Returns:
        pairwise_p_val_array (2D Numpy array): neurons x transitions, p-values for each comparison
        pairwise_sig_neurons (2D Numpy array): neurons x transitions, boolean mask of significant comparisons
    """
    # Calculate pairwise t-tests for each transition
    # shape : [before, after] x trials x neurons x transitions
    mean_transition_snips = np.stack(np.array_split(
        transition_snips, 2, axis=2)).mean(axis=3)
    pairwise_p_val_array = np.zeros(mean_transition_snips.shape[2:])
    n_neuron, n_transitions = pairwise_p_val_array.shape
    for neuron in range(n_neuron):
        for transition in range(n_transitions):
            pairwise_p_val_array[neuron, transition] = \
                ttest_rel(*mean_transition_snips[:, :, neuron, transition])[1]
    pairwise_sig_neurons = pairwise_p_val_array < p_val  # /n_neuron
    return pairwise_p_val_array, pairwise_sig_neurons
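
A hedged sketch on synthetic snippets (shape per the docstring, trials x nrns x bins x transitions); note the second return value is a boolean mask over neuron/transition pairs:

import numpy as np
from pytau.changepoint_analysis import calc_significant_neurons_snippets

rng = np.random.default_rng(0)
# Synthetic snippets: 20 trials, 8 neurons, 600 bins, 3 transitions
transition_snips = rng.poisson(0.1, size=(20, 8, 600, 3)).astype(float)

p_vals, sig_mask = calc_significant_neurons_snippets(transition_snips, p_val=0.05)
print(p_vals.shape)    # (8, 3) -> neurons x transitions
print(sig_mask)        # boolean mask of significant neuron/transition pairs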

get_state_firing(spike_array, tau_array)

Calculate firing rates within states given changepoint positions on data

Parameters:

    spike_array (3D Numpy array): trials x nrns x bins. Required.
    tau_array (2D Numpy array): trials x switchpoints. Required.

Returns:

    state_firing (3D Numpy array): trials x states x nrns

Source code in pytau/changepoint_analysis.py
def get_state_firing(spike_array, tau_array):
    """Calculate firing rates within states given changepoint positions on data

    Args:
        spike_array (3D Numpy array): trials x nrns x bins
        tau_array (2D Numpy array): trials x switchpoints

    Returns:
        state_firing (3D Numpy array): trials x states x nrns
    """

    states = tau_array.shape[-1] + 1
    # Get mean firing rate for each STATE using model
    state_inds = np.hstack([np.zeros((tau_array.shape[0], 1)),
                            tau_array,
                            np.ones((tau_array.shape[0], 1))*spike_array.shape[-1]])
    state_lims = np.array([state_inds[:, x:x+2] for x in range(states)])
    # Cast state limit indices to integers (np.int is removed in newer NumPy)
    state_lims = state_lims.astype(int)
    state_lims = np.swapaxes(state_lims, 0, 1)

    state_firing = \
        np.array([[np.mean(trial_dat[:, start:end], axis=-1)
                   for start, end in trial_lims]
                  for trial_dat, trial_lims in zip(spike_array, state_lims)])

    state_firing = np.nan_to_num(state_firing)
    return state_firing
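
A hedged sketch with synthetic spikes and changepoints (shapes per the docstring; values are arbitrary):

import numpy as np
from pytau.changepoint_analysis import get_state_firing

rng = np.random.default_rng(0)
# 15 trials, 10 neurons, 2000 bins; 3 changepoints -> 4 states
spike_array = rng.poisson(0.05, size=(15, 10, 2000))
tau_array = np.sort(rng.integers(200, 1800, size=(15, 3)), axis=-1)

state_firing = get_state_firing(spike_array, tau_array)
print(state_firing.shape)   # (15, 4, 10) -> trials x states x nrns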

get_transition_snips(spike_array, tau_array, window_radius=300)

Get snippets of activity around changepoints for each trial

Parameters:

    spike_array (3D Numpy array): trials x nrns x bins. Required.
    tau_array (2D Numpy array): trials x switchpoints. Required.
    window_radius (int, optional): Half-width of the snippet window in bins. Defaults to 300.

Returns:

    Numpy array: Transition snippets, trials x nrns x bins x transitions

Raises a ValueError if any snippet window extends outside the bounds of the data.

Source code in pytau/changepoint_analysis.py
def get_transition_snips(spike_array, tau_array, window_radius=300):
    """Get snippets of activty around changepoints for each trial

    Args:
        spike_array (3D Numpy array): trials x nrns x bins
        tau_array (2D Numpy array): trials x switchpoints

    Returns:
        Numpy array: Transition snippets : trials x nrns x bins x transitions

    Make sure none of the snippets are outside the bounds of the data
    """
    # Get snippets of activity around changepoints for each trial
    n_trials, n_neurons, n_bins = spike_array.shape
    n_transitions = tau_array.shape[1]
    transition_snips = np.zeros(
        (n_trials, n_neurons, 2*window_radius, n_transitions))
    window_lims = np.stack(
        [tau_array - window_radius, tau_array + window_radius], axis=-1)

    # Make sure no lims are outside the bounds of the data
    if (window_lims < 0).sum(axis=None) or (window_lims > n_bins).sum(axis=None):
        raise ValueError('Transition window extends outside data bounds')

    # Pull out snippets
    for trial in range(n_trials):
        for transition in range(n_transitions):
            transition_snips[trial, :, :, transition] = \
                spike_array[trial, :,
                            window_lims[trial, transition, 0]:
                            window_lims[trial, transition, 1]]
    return transition_snips
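
A hedged sketch (same synthetic shapes as above); changepoints are kept at least window_radius bins from either edge so the bounds check passes:

import numpy as np
from pytau.changepoint_analysis import get_transition_snips

rng = np.random.default_rng(0)
spike_array = rng.poisson(0.05, size=(15, 10, 2000))
# Keep changepoints >= 300 bins away from both edges
tau_array = np.sort(rng.integers(400, 1600, size=(15, 3)), axis=-1)

snips = get_transition_snips(spike_array, tau_array, window_radius=300)
print(snips.shape)   # (15, 10, 600, 3) -> trials x nrns x bins x transitions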

=== I/O functions ===

Pipeline to handle model fitting from data extraction to saving results

DatabaseHandler

Class to handle transactions with model database

Source code in pytau/changepoint_io.py
class DatabaseHandler():
    """Class to handle transactions with model database
    """

    def __init__(self):
        """Initialize DatabaseHandler class
        """
        self.unique_cols = ['exp.model_id', 'exp.save_path', 'exp.fit_date']
        self.model_database_path = MODEL_DATABASE_PATH
        self.model_save_base_dir = MODEL_SAVE_DIR

        if os.path.exists(self.model_database_path):
            self.fit_database = pd.read_csv(self.model_database_path,
                                            index_col=0)
            all_na = [all(x) for num, x in self.fit_database.isna().iterrows()]
            if any(all_na):
                print(f'{sum(all_na)} rows found with all NA, removing...')
                self.fit_database = self.fit_database.dropna(how='all')
        else:
            print('Fit database does not exist yet')

    def show_duplicates(self, keep='first'):
        """Find duplicates in database

        Args:
            keep (str, optional): Which duplicate to keep
                    (refer to pandas duplicated). Defaults to 'first'.

        Returns:
            pandas dataframe: Dataframe containing duplicated rows
            pandas series : Indices of duplicated rows
        """
        dup_inds = self.fit_database.drop(self.unique_cols, axis=1)\
            .duplicated(keep=keep)
        return self.fit_database.loc[dup_inds], dup_inds

    def drop_duplicates(self):
        """Remove duplicated rows from database
        """
        _, dup_inds = self.show_duplicates()
        print(f'Removing {sum(dup_inds)} duplicate rows')
        self.fit_database = self.fit_database.loc[~dup_inds]

    def check_mismatched_paths(self):
        """Check if there are any mismatched pkl files between database and directory

        Returns:
            pandas dataframe: Dataframe containing rows for which pkl file not present
            list: pkl files which cannot be matched to model in database
            list: all files in save directory
        """
        mismatch_from_database = [not os.path.exists(x + ".pkl")
                                  for x in self.fit_database['exp.save_path']]
        file_list = glob(os.path.join(self.model_save_base_dir, "*/*.pkl"))
        mismatch_from_file = [not
                              (x.split('.')[0] in list(
                                  self.fit_database['exp.save_path']))
                              for x in file_list]
        print(f"{sum(mismatch_from_database)} mismatches from database" + "\n"
              + f"{sum(mismatch_from_file)} mismatches from files")
        return mismatch_from_database, mismatch_from_file, file_list

    def clear_mismatched_paths(self):
        """Remove mismatched files and rows in database

        i.e. Remove
        1) Files for which no entry can be found in database
        2) Database entries for which no corresponding file can be found
        """
        mismatch_from_database, mismatch_from_file, file_list = \
            self.check_mismatched_paths()
        mismatch_from_file = np.array(mismatch_from_file)
        mismatch_from_database = np.array(mismatch_from_database)
        self.fit_database = self.fit_database.loc[~mismatch_from_database]
        mismatched_files = [x for x, y in zip(file_list, mismatch_from_file) if y]
        for x in mismatched_files:
            os.remove(x)
        print('==== Clearing Completed ====')

    def write_updated_database(self):
        """Can be called following clear_mismatched_entries to update current database
        """
        database_backup_dir = os.path.join(
            self.model_save_base_dir, '.database_backups')
        if not os.path.exists(database_backup_dir):
            os.makedirs(database_backup_dir)
        #current_date = date.today().strftime("%m-%d-%y")
        current_date = str(datetime.now()).replace(" ", "_")
        shutil.copy(self.model_database_path,
                    os.path.join(database_backup_dir,
                                 f"database_backup_{current_date}"))
        self.fit_database.to_csv(self.model_database_path, mode='w')

    def set_run_params(
            self, 
            data_dir, 
            experiment_name, 
            taste_num, 
            laser_type,
            region_name):
        """Store metadata related to inference run

        Args:
            data_dir (str): Path to directory containing HDF5 file
            experiment_name (str): Name given to fitted batch
                    (for metadata). Defaults to None.
            taste_num (int): Index of taste to perform fit on (corresponds to
                    INDEX of taste in spike array, not actual dig_ins)
            laser_type (None or str): None, 'on', or 'off' (for a laser session,
                    which set of trials is wanted; None indicates return all trials)
            region_name (str): Region on which to perform fit
                    (must match regions in .info file)
        """
        self.data_dir = data_dir
        self.data_basename = os.path.basename(self.data_dir)
        self.animal_name = self.data_basename.split("_")[0]
        self.session_date = self.data_basename.split("_")[-1]

        self.experiment_name = experiment_name
        self.model_save_dir = os.path.join(self.model_save_base_dir,
                                           experiment_name)

        if not os.path.exists(self.model_save_dir):
            os.makedirs(self.model_save_dir)

        self.model_id = str(uuid.uuid4()).split('-')[0]
        self.model_save_path = os.path.join(self.model_save_dir,
                                            self.experiment_name +
                                            "_" + self.model_id)
        self.fit_date = date.today().strftime("%m-%d-%y")

        self.taste_num = taste_num
        self.laser_type = laser_type
        self.region_name = region_name

        self.fit_exists = None

    def ingest_fit_data(self, met_dict):
        """Load external metadata

        Args:
            met_dict (dict): Dictionary of metadata from FitHandler class
        """
        self.external_metadata = met_dict

    def aggregate_metadata(self):
        """Collects information regarding data and current "experiment"

        Raises:
            Exception: If 'external_metadata' has not been ingested, that needs to be done first

        Returns:
            dict: Dictionary of metadata given to FitHandler class
        """
        if 'external_metadata' not in dir(self):
            raise Exception('Fit run metadata needs to be ingested '
                            'into data_handler first')

        data_details = dict(zip(
            ['data_dir',
             'basename',
             'animal_name',
             'session_date',
             'taste_num',
             'laser_type',
             'region_name'],
            [self.data_dir,
             self.data_basename,
             self.animal_name,
             self.session_date,
             self.taste_num,
             self.laser_type,
             self.region_name]))

        exp_details = dict(zip(
            ['exp_name', 
             'model_id',
             'save_path',
             'fit_date'],
            [self.experiment_name,
             self.model_id,
             self.model_save_path,
             self.fit_date]))

        module_details = dict(zip(
            ['pymc3_version',
             'theano_version'],
            [pymc3.__version__,
                theano.__version__]
            ))

        temp_ext_met = self.external_metadata
        temp_ext_met['data'] = data_details
        temp_ext_met['exp'] = exp_details
        temp_ext_met['module'] = module_details

        return temp_ext_met

    def write_to_database(self):
        """Write out metadata to database
        """
        agg_metadata = self.aggregate_metadata()
        # Convert model_kwargs to str so that they are saved appropriately
        agg_metadata['model']['model_kwargs'] = str(agg_metadata['model']['model_kwargs'])
        flat_metadata = pd.json_normalize(agg_metadata)
        if not os.path.isfile(self.model_database_path):
            flat_metadata.to_csv(self.model_database_path, mode='a')
        else:
            flat_metadata.to_csv(self.model_database_path,
                                 mode='a', header=False)

    def check_exists(self):
        """Check if the given fit already exists in database

        Returns:
            bool: Boolean for whether fit already exists or not
        """
        if self.fit_exists is not None:
            return self.fit_exists
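
A hedged maintenance sketch (assumes MODEL_DATABASE_PATH and MODEL_SAVE_DIR are configured for your install and a fit database already exists; note that clear_mismatched_paths deletes unmatched pkl files from disk):

from pytau.changepoint_io import DatabaseHandler

db_handler = DatabaseHandler()

# Inspect and prune duplicate fits
dup_rows, dup_inds = db_handler.show_duplicates()
db_handler.drop_duplicates()

# Reconcile the database with the pkl files actually on disk
db_handler.clear_mismatched_paths()

# Back up the old database, then write the cleaned version
db_handler.write_updated_database()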

__init__()

Initialize DatabaseHandler class

Source code in pytau/changepoint_io.py
def __init__(self):
    """Initialize DatabaseHandler class
    """
    self.unique_cols = ['exp.model_id', 'exp.save_path', 'exp.fit_date']
    self.model_database_path = MODEL_DATABASE_PATH
    self.model_save_base_dir = MODEL_SAVE_DIR

    if os.path.exists(self.model_database_path):
        self.fit_database = pd.read_csv(self.model_database_path,
                                        index_col=0)
        all_na = [all(x) for num, x in self.fit_database.isna().iterrows()]
        if any(all_na):
            print(f'{sum(all_na)} rows found with all NA, removing...')
            self.fit_database = self.fit_database.dropna(how='all')
    else:
        print('Fit database does not exist yet')

aggregate_metadata()

Collects information regarding data and current "experiment"

Raises:

    Exception: If 'external_metadata' has not been ingested; that needs to be done first

Returns:

    dict: Dictionary of metadata given to FitHandler class

Source code in pytau/changepoint_io.py
def aggregate_metadata(self):
    """Collects information regarding data and current "experiment"

    Raises:
        Exception: If 'external_metadata' has not been ingested, that needs to be done first

    Returns:
        dict: Dictionary of metadata given to FitHandler class
    """
    if 'external_metadata' not in dir(self):
        raise Exception('Fit run metadata needs to be ingested '
                        'into data_handler first')

    data_details = dict(zip(
        ['data_dir',
         'basename',
         'animal_name',
         'session_date',
         'taste_num',
         'laser_type',
         'region_name'],
        [self.data_dir,
         self.data_basename,
         self.animal_name,
         self.session_date,
         self.taste_num,
         self.laser_type,
         self.region_name]))

    exp_details = dict(zip(
        ['exp_name', 
         'model_id',
         'save_path',
         'fit_date'],
        [self.experiment_name,
         self.model_id,
         self.model_save_path,
         self.fit_date]))

    module_details = dict(zip(
        ['pymc3_version',
         'theano_version'],
        [pymc3.__version__,
            theano.__version__]
        ))

    temp_ext_met = self.external_metadata
    temp_ext_met['data'] = data_details
    temp_ext_met['exp'] = exp_details
    temp_ext_met['module'] = module_details

    return temp_ext_met

check_exists()

Check if the given fit already exists in database

Returns:

    bool: Boolean for whether fit already exists or not

Source code in pytau/changepoint_io.py
def check_exists(self):
    """Check if the given fit already exists in database

    Returns:
        bool: Boolean for whether fit already exists or not
    """
    if self.fit_exists is not None:
        return self.fit_exists

check_mismatched_paths()

Check if there are any mismatched pkl files between database and directory

Returns:

    pandas dataframe: Dataframe containing rows for which pkl file not present
    list: pkl files which cannot be matched to model in database
    list: all files in save directory

Source code in pytau/changepoint_io.py
def check_mismatched_paths(self):
    """Check if there are any mismatched pkl files between database and directory

    Returns:
        pandas dataframe: Dataframe containing rows for which pkl file not present
        list: pkl files which cannot be matched to model in database
        list: all files in save directory
    """
    mismatch_from_database = [not os.path.exists(x + ".pkl")
                              for x in self.fit_database['exp.save_path']]
    file_list = glob(os.path.join(self.model_save_base_dir, "*/*.pkl"))
    mismatch_from_file = [not
                          (x.split('.')[0] in list(
                              self.fit_database['exp.save_path']))
                          for x in file_list]
    print(f"{sum(mismatch_from_database)} mismatches from database" + "\n"
          + f"{sum(mismatch_from_file)} mismatches from files")
    return mismatch_from_database, mismatch_from_file, file_list

clear_mismatched_paths()

Remove mismatched files and rows in database

i.e. Remove:
1) Files for which no entry can be found in database
2) Database entries for which no corresponding file can be found

Source code in pytau/changepoint_io.py
def clear_mismatched_paths(self):
    """Remove mismatched files and rows in database

    i.e. Remove
    1) Files for which no entry can be found in database
    2) Database entries for which no corresponding file can be found
    """
    mismatch_from_database, mismatch_from_file, file_list = \
        self.check_mismatched_paths()
    mismatch_from_file = np.array(mismatch_from_file)
    mismatch_from_database = np.array(mismatch_from_database)
    self.fit_database = self.fit_database.loc[~mismatch_from_database]
    mismatched_files = [x for x, y in zip(file_list, mismatch_from_file) if y]
    for x in mismatched_files:
        os.remove(x)
    print('==== Clearing Completed ====')

drop_duplicates()

Remove duplicated rows from database

Source code in pytau/changepoint_io.py
def drop_duplicates(self):
    """Remove duplicated rows from database
    """
    _, dup_inds = self.show_duplicates()
    print(f'Removing {sum(dup_inds)} duplicate rows')
    self.fit_database = self.fit_database.loc[~dup_inds]

ingest_fit_data(met_dict)

Load external metadata

Parameters:

    met_dict (dict): Dictionary of metadata from FitHandler class. Required.

Source code in pytau/changepoint_io.py
def ingest_fit_data(self, met_dict):
    """Load external metadata

    Args:
        met_dict (dict): Dictionary of metadata from FitHandler class
    """
    self.external_metadata = met_dict

set_run_params(data_dir, experiment_name, taste_num, laser_type, region_name)

Store metadata related to inference run

Parameters:

    data_dir (str): Path to directory containing HDF5 file. Required.
    experiment_name (str): Name given to fitted batch (for metadata). Required.
    taste_num (int): Index of taste to perform fit on (corresponds to INDEX of taste in spike array, not actual dig_ins). Required.
    laser_type (None or str): None, 'on', or 'off' (for a laser session, which set of trials is wanted; None indicates return all trials). Required.
    region_name (str): Region on which to perform fit (must match regions in .info file). Required.

Source code in pytau/changepoint_io.py
def set_run_params(
        self, 
        data_dir, 
        experiment_name, 
        taste_num, 
        laser_type,
        region_name):
    """Store metadata related to inference run

    Args:
        data_dir (str): Path to directory containing HDF5 file
        experiment_name (str): Name given to fitted batch
                (for metadata). Defaults to None.
        taste_num (int): Index of taste to perform fit on (corresponds to
                INDEX of taste in spike array, not actual dig_ins)
        laser_type (None or str): None, 'on', or 'off' (for a laser session,
                which set of trials is wanted; None indicates return all trials)
        region_name (str): Region on which to perform fit
                (must match regions in .info file)
    """
    self.data_dir = data_dir
    self.data_basename = os.path.basename(self.data_dir)
    self.animal_name = self.data_basename.split("_")[0]
    self.session_date = self.data_basename.split("_")[-1]

    self.experiment_name = experiment_name
    self.model_save_dir = os.path.join(self.model_save_base_dir,
                                       experiment_name)

    if not os.path.exists(self.model_save_dir):
        os.makedirs(self.model_save_dir)

    self.model_id = str(uuid.uuid4()).split('-')[0]
    self.model_save_path = os.path.join(self.model_save_dir,
                                        self.experiment_name +
                                        "_" + self.model_id)
    self.fit_date = date.today().strftime("%m-%d-%y")

    self.taste_num = taste_num
    self.laser_type = laser_type
    self.region_name = region_name

    self.fit_exists = None

show_duplicates(keep='first')

Find duplicates in database

Parameters:

    keep (str, optional): Which duplicate to keep (refer to pandas duplicated). Defaults to 'first'.

Returns:

    pandas dataframe: Dataframe containing duplicated rows
    pandas series: Boolean mask of duplicated rows

Source code in pytau/changepoint_io.py
def show_duplicates(self, keep='first'):
    """Find duplicates in database

    Args:
        keep (str, optional): Which duplicate to keep
                (refer to pandas duplicated). Defaults to 'first'.

    Returns:
        pandas dataframe: Dataframe containing duplicated rows
        pandas series : Indices of duplicated rows
    """
    dup_inds = self.fit_database.drop(self.unique_cols, axis=1)\
        .duplicated(keep=keep)
    return self.fit_database.loc[dup_inds], dup_inds

write_to_database()

Write out metadata to database

Source code in pytau/changepoint_io.py
def write_to_database(self):
    """Write out metadata to database
    """
    agg_metadata = self.aggregate_metadata()
    # Convert model_kwargs to str so that they are saved appropriately
    agg_metadata['model']['model_kwargs'] = str(agg_metadata['model']['model_kwargs'])
    flat_metadata = pd.json_normalize(agg_metadata)
    if not os.path.isfile(self.model_database_path):
        flat_metadata.to_csv(self.model_database_path, mode='a')
    else:
        flat_metadata.to_csv(self.model_database_path,
                             mode='a', header=False)

write_updated_database()

Can be called following clear_mismatched_paths to update current database

Source code in pytau/changepoint_io.py
def write_updated_database(self):
    """Can be called following clear_mismatched_entries to update current database
    """
    database_backup_dir = os.path.join(
        self.model_save_base_dir, '.database_backups')
    if not os.path.exists(database_backup_dir):
        os.makedirs(database_backup_dir)
    #current_date = date.today().strftime("%m-%d-%y")
    current_date = str(datetime.now()).replace(" ", "_")
    shutil.copy(self.model_database_path,
                os.path.join(database_backup_dir,
                             f"database_backup_{current_date}"))
    self.fit_database.to_csv(self.model_database_path, mode='w')

FitHandler

Class to handle pipeline of model fitting including:
1) Loading data
2) Preprocessing loaded arrays
3) Fitting model
4) Writing out fitted parameters to pkl file

Source code in pytau/changepoint_io.py
class FitHandler():
    """Class to handle pipeline of model fitting including:
    1) Loading data
    2) Preprocessing loaded arrays
    3) Fitting model
    4) Writing out fitted parameters to pkl file

    """

    def __init__(self,
                 data_dir,
                 taste_num,
                 region_name,
                 laser_type = None,
                 experiment_name=None,
                 model_params_path=None,
                 preprocess_params_path=None,
                 ):
        """Initialize FitHandler class

        Args:
            data_dir (str): Path to directory containing HDF5 file
            taste_num (int): Index of taste to perform fit on
                    (Corresponds to INDEX of taste in spike array, not actual dig_ins)
            region_name (str): Region on which to perform fit
                    (must match regions in .info file)
            experiment_name (str, optional): Name given to fitted batch
                    (for metadata). Defaults to None.
            model_params_path (str, optional): Path to json file
                    containing model parameters. Defaults to None.
            preprocess_params_path (str, optional): Path to json file
                    containing preprocessing parameters. Defaults to None.

        Raises:
            Exception: If "experiment_name" is None
            Exception: If "laser_type" is not in [None, 'on', 'off']
            Exception: If "taste_num" is not integer or "all"
        """

        # =============== Check for exceptions ===============
        if experiment_name is None:
            raise Exception('Please specify an experiment name')
        if laser_type not in [None, 'on','off']:
            raise Exception('laser_type must be from [None, "on","off"]')
        if not (isinstance(taste_num,int) or taste_num == 'all'):
            raise Exception('taste_num must be an integer or "all"')

        # =============== Save relevant arguments ===============
        self.data_dir = data_dir
        self.EphysData = EphysData(self.data_dir)
        #self.data = self.EphysData.get_spikes({"bla","gc","all"})

        self.taste_num = taste_num
        self.laser_type = laser_type
        self.region_name = region_name
        self.experiment_name = experiment_name

        data_handler_init_kwargs = dict(zip(
            [   'data_dir', 
                'experiment_name', 
                'taste_num', 
                'laser_type',
                'region_name'],
            [   data_dir, 
                experiment_name, 
                taste_num, 
                laser_type,
                region_name]))
        self.database_handler = DatabaseHandler()
        self.database_handler.set_run_params(**data_handler_init_kwargs)

        if model_params_path is None:
            print('MODEL_PARAMS will have to be set')
        else:
            self.set_model_params(file_path=model_params_path)

        if preprocess_params_path is None:
            print('PREPROCESS_PARAMS will have to be set')
        else:
            self.set_preprocess_params(file_path=preprocess_params_path)

    ########################################
    # SET PARAMS
    ########################################

    def set_preprocess_params(self,
                              time_lims,
                              bin_width,
                              data_transform,
                              file_path=None):

        """Load given params as "preprocess_params" attribute

        Args:
            time_lims (array/tuple/list): Start and end of where to cut
                    spike train array
            bin_width (int): Bin width for binning spikes to counts
            data_transform (str): Indicator for which transformation to
                    use (refer to changepoint_preprocess)
            file_path (str, optional): Path to json file containing preprocess
                    parameters. Defaults to None.
        """

        if file_path is None:
            preprocess_params_dict = dict(zip(['time_lims', 'bin_width', 'data_transform'],
                         [time_lims, bin_width, data_transform]))
            self.preprocess_params = preprocess_params_dict
            print('Set preprocess params to: {}'.format(preprocess_params_dict))

        else:
            # Load json and save dict
            pass

    def set_model_params(self,
                         states,
                         fit,
                         samples,
                         model_kwargs = None,
                         file_path=None):

        """Load given params as "model_params" attribute

        Args:
            states (int): Number of states to use in model
            fit (int): Iterations to use for model fitting (given ADVI fit)
            samples (int): Number of samples to return from fitted model
            model_kwargs (dict): Additional parameters for model
            file_path (str, optional): Path to json file containing
                    model parameters. Defaults to None.
        """

        if file_path is None:
            model_params_dict = dict(zip(['states', 'fit', 'samples', 'model_kwargs'], 
                    [states, fit, samples, model_kwargs]))
            self.model_params = model_params_dict
            print('Set model params to: {}'.format(model_params_dict))

        else:
            # Load json and save dict
            pass

    ########################################
    # SET PIPELINE FUNCS
    ########################################

    def set_preprocessor(self, preprocessing_func):
        """Manually set preprocessor for data e.g.

        FitHandler.set_preprocessor(
                    changepoint_preprocess.preprocess_single_taste)

        Args:
            preprocessing_func (func):
                    Function to preprocess data (refer to changepoint_preprocess)
        """
        self.preprocessor = preprocessing_func

    def preprocess_selector(self):
        """Function to return preprocess function based off of input flag

        Preprocessing can be set manually but it is preferred to
        go through preprocess selector

        Raises:
            Exception: If self.taste_num is neither int nor str

        """

        if isinstance(self.taste_num, int):
            self.set_preprocessor(
                changepoint_preprocess.preprocess_single_taste)
        elif self.taste_num == 'all':
            self.set_preprocessor(changepoint_preprocess.preprocess_all_taste)
        else:
            raise Exception("Something went wrong")

    def set_model_template(self, model_template):
        """Manually set model_template for data e.g.

        FitHandler.set_model_template(changepoint_model.single_taste_poisson)

        Args:
            model_template (func): Function to generate model template for data
        """
        self.model_template = model_template

    def model_template_selector(self):
        """Function to set model based off of input flag

        Models can be set manually but it is preferred to go through model selector

        Raises:
            Exception: If self.taste_num is neither int nor str

        """
        if isinstance(self.taste_num, int):
            #self.set_model_template(changepoint_model.single_taste_poisson_varsig)
            self.set_model_template(
                    changepoint_model.single_taste_poisson)
        elif self.taste_num == 'all':
            self.set_model_template(changepoint_model.all_taste_poisson)
        else:
            raise Exception("Something went wrong")

    def set_inference(self, inference_func):
        """Manually set inference function for model fit e.g.

        FitHandler.set_inference(changepoint_model.advi_fit)

        Args:
            inference_func (func): Function to use for fitting model
        """
        self.inference_func = inference_func

    def inference_func_selector(self):
        """Function to return model based off of input flag

        Currently hard-coded to use "advi_fit"
        """
        self.set_inference(changepoint_model.advi_fit)

    ########################################
    # PIPELINE FUNCS
    ########################################

    def load_spike_trains(self):
        """Helper function to load spike trains from data_dir using EphysData module
        """
        full_spike_array = self.EphysData.return_region_spikes(
                region_name = self.region_name,
                laser = self.laser_type)
        if isinstance(self.taste_num, int):
            self.data = full_spike_array[self.taste_num]
        if self.taste_num == 'all':
            self.data = full_spike_array
        print(f'Loading spike trains from {self.database_handler.data_basename}, '
              f'dig_in {self.taste_num}, laser {str(self.laser_type)}')

    def preprocess_data(self):
        """Perform data preprocessing

        Will check for and complete:
        1) Raw data loaded
        2) Preprocessor selected
        """
        if 'data' not in dir(self):
            self.load_spike_trains()
        if 'preprocessor' not in dir(self):
            self.preprocess_selector()
        print('Preprocessing spike trains, '
              f'preprocessing func: <{self.preprocessor.__name__}>')
        self.preprocessed_data = \
            self.preprocessor(self.data, **self.preprocess_params)

    def create_model(self):
        """Create model and save as attribute

        Will check for and complete:
        1) Data preprocessed
        2) Model template selected
        """
        if 'preprocessed_data' not in dir(self):
            self.preprocess_data()
        if 'model_template' not in dir(self):
            self.model_template_selector()

        # In future iterations, before fitting model,
        # check that a similar entry doesn't exist

        changepoint_model.compile_wait()
        print(f'Generating Model, model func: <{self.model_template.__name__}>')
        self.model = self.model_template(self.preprocessed_data,
                                         self.model_params['states'],
                                         **self.model_params['model_kwargs'])

    def run_inference(self):
        """Perform inference on data

        Will check for and complete:
        1) Model created
        2) Inference function selected
        """
        if 'model' not in dir(self):
            self.create_model()
        if 'inference_func' not in dir(self):
            self.inference_func_selector()

        changepoint_model.compile_wait()
        print('Running inference, inference func: '
              f'<{self.inference_func.__name__}>')
        temp_outs = self.inference_func(self.model,
                                        self.model_params['fit'],
                                        self.model_params['samples'])
        varnames = ['model', 'approx', 'lambda', 'tau', 'data']
        self.inference_outs = dict(zip(varnames, temp_outs))

    def _gen_fit_metadata(self):
        """Generate metadata for fit

        Generate metadata by compiling:
        1) Preprocess parameters given as input
        2) Model parameters given as input
        3) Functions used in inference pipeline for : preprocessing,
                model generation, fitting

        Returns:
            dict: Dictionary containing compiled metadata for different
                    parts of inference pipeline
        """
        pre_params = self.preprocess_params
        model_params = self.model_params
        pre_params['preprocessor_name'] = self.preprocessor.__name__
        model_params['model_template_name'] = self.model_template.__name__
        model_params['inference_func_name'] = self.inference_func.__name__
        fin_dict = dict(zip(['preprocess', 'model'], [pre_params, model_params]))
        return fin_dict

    def _pass_metadata_to_handler(self):
        """Function to coordinate transfer of metadata to DatabaseHandler
        """
        self.database_handler.ingest_fit_data(self._gen_fit_metadata())

    def _return_fit_output(self):
        """Compile data, model, fit, and metadata to save output

        Returns:
            dict: Dictionary containing fitted model data and metadata
        """
        self._pass_metadata_to_handler()
        agg_metadata = self.database_handler.aggregate_metadata()
        return {'model_data': self.inference_outs, 'metadata': agg_metadata}

    def save_fit_output(self):
        """Save fit output (fitted data + metadata) to pkl file
        """
        if 'inference_outs' not in dir(self):
            self.run_inference()
        out_dict = self._return_fit_output()
        with open(self.database_handler.model_save_path + '.pkl', 'wb') as buff:
            pickle.dump(out_dict, buff)

        json_file_name = os.path.join(
            self.database_handler.model_save_path + '.info')
        with open(json_file_name, 'w') as file:
            json.dump(out_dict['metadata'], file, indent=4)

        self.database_handler.write_to_database()

        print('Saving inference output to : \n'
                f'{self.database_handler.model_save_dir}' 
                "\n" + "================================" + '\n')

__init__(data_dir, taste_num, region_name, laser_type=None, experiment_name=None, model_params_path=None, preprocess_params_path=None)

Initialize FitHandler class

Parameters:

    data_dir (str): Path to directory containing HDF5 file. Required.
    taste_num (int): Index of taste to perform fit on (corresponds to INDEX of taste in spike array, not actual dig_ins). Required.
    region_name (str): Region on which to perform fit (must match regions in .info file). Required.
    laser_type (None or str, optional): None, 'on', or 'off'. Defaults to None.
    experiment_name (str, optional): Name given to fitted batch (for metadata). Defaults to None.
    model_params_path (str, optional): Path to json file containing model parameters. Defaults to None.
    preprocess_params_path (str, optional): Path to json file containing preprocessing parameters. Defaults to None.

Raises:

    Exception: If "experiment_name" is None
    Exception: If "laser_type" is not in [None, 'on', 'off']
    Exception: If "taste_num" is not integer or "all"

Source code in pytau/changepoint_io.py
def __init__(self,
             data_dir,
             taste_num,
             region_name,
             laser_type = None,
             experiment_name=None,
             model_params_path=None,
             preprocess_params_path=None,
             ):
    """Initialize FitHandler class

    Args:
        data_dir (str): Path to directory containing HDF5 file
        taste_num (int): Index of taste to perform fit on
                (Corresponds to INDEX of taste in spike array, not actual dig_ins)
        region_name (str): Region on which to perform fit
                (must match regions in .info file)
        experiment_name (str, optional): Name given to fitted batch
                (for metadata). Defaults to None.
        model_params_path (str, optional): Path to json file
                containing model parameters. Defaults to None.
        preprocess_params_path (str, optional): Path to json file
                containing preprocessing parameters. Defaults to None.

    Raises:
        Exception: If "experiment_name" is None
        Exception: If "laser_type" is not in [None, 'on', 'off']
        Exception: If "taste_num" is not integer or "all"
    """

    # =============== Check for exceptions ===============
    if experiment_name is None:
        raise Exception('Please specify an experiment name')
    if laser_type not in [None, 'on','off']:
        raise Exception('laser_type must be from [None, "on","off"]')
    if not (isinstance(taste_num,int) or taste_num == 'all'):
        raise Exception('taste_num must be an integer or "all"')

    # =============== Save relevant arguments ===============
    self.data_dir = data_dir
    self.EphysData = EphysData(self.data_dir)
    #self.data = self.EphysData.get_spikes({"bla","gc","all"})

    self.taste_num = taste_num
    self.laser_type = laser_type
    self.region_name = region_name
    self.experiment_name = experiment_name

    data_handler_init_kwargs = dict(zip(
        [   'data_dir', 
            'experiment_name', 
            'taste_num', 
            'laser_type',
            'region_name'],
        [   data_dir, 
            experiment_name, 
            taste_num, 
            laser_type,
            region_name]))
    self.database_handler = DatabaseHandler()
    self.database_handler.set_run_params(**data_handler_init_kwargs)

    if model_params_path is None:
        print('MODEL_PARAMS will have to be set')
    else:
        self.set_model_params(file_path=model_params_path)

    if preprocess_params_path is None:
        print('PREPROCESS_PARAMS will have to be set')
    else:
        self.set_preprocess_params(file_path=preprocess_params_path)

create_model()

Create model and save as attribute

Will check for and complete:
1) Data preprocessed
2) Model template selected

Source code in pytau/changepoint_io.py
def create_model(self):
    """Create model and save as attribute

    Will check for and complete:
    1) Data preprocessed
    2) Model template selected
    """
    if 'preprocessed_data' not in dir(self):
        self.preprocess_data()
    if 'model_template' not in dir(self):
        self.model_template_selector()

    # In future iterations, before fitting model,
    # check that a similar entry doesn't exist

    changepoint_model.compile_wait()
    print(f'Generating Model, model func: <{self.model_template.__name__}>')
    self.model = self.model_template(self.preprocessed_data,
                                     self.model_params['states'],
                                     **self.model_params['model_kwargs'])

inference_func_selector()

Function to return model based off of input flag

Currently hard-coded to use "advi_fit"

Source code in pytau/changepoint_io.py
def inference_func_selector(self):
    """Function to return model based off of input flag

    Currently hard-coded to use "advi_fit"
    """
    self.set_inference(changepoint_model.advi_fit)

load_spike_trains()

Helper function to load spike trains from data_dir using EphysData module

Source code in pytau/changepoint_io.py
def load_spike_trains(self):
    """Helper function to load spike trains from data_dir using EphysData module
    """
    full_spike_array = self.EphysData.return_region_spikes(
            region_name = self.region_name,
            laser = self.laser_type)
    if isinstance(self.taste_num, int):
        self.data = full_spike_array[self.taste_num]
    if self.taste_num == 'all':
        self.data = full_spike_array
    print(f'Loading spike trains from {self.database_handler.data_basename}, '
          f'dig_in {self.taste_num}, laser {str(self.laser_type)}')

model_template_selector()

Function to set model based off of input flag

Models can be set manually but it is preferred to go through model selector

Raises:

    Exception: If self.taste_num is neither int nor str

Source code in pytau/changepoint_io.py
def model_template_selector(self):
    """Function to set model based off of input flag

    Models can be set manually but it is preferred to go through model selector

    Raises:
        Exception: If self.taste_num is neither int nor str

    """
    if isinstance(self.taste_num, int):
        #self.set_model_template(changepoint_model.single_taste_poisson_varsig)
        self.set_model_template(
                changepoint_model.single_taste_poisson)
    elif self.taste_num == 'all':
        self.set_model_template(changepoint_model.all_taste_poisson)
    else:
        raise Exception("Something went wrong")

preprocess_data()

Perform data preprocessing

Will check for and complete:
1) Raw data loaded
2) Preprocessor selected

Source code in pytau/changepoint_io.py
def preprocess_data(self):
    """Perform data preprocessing

    Will check for and complete:
    1) Raw data loaded
    2) Preprocessor selected
    """
    if 'data' not in dir(self):
        self.load_spike_trains()
    if 'preprocessor' not in dir(self):
        self.preprocess_selector()
    print('Preprocessing spike trains, '
          f'preprocessing func: <{self.preprocessor.__name__}>')
    self.preprocessed_data = \
        self.preprocessor(self.data, **self.preprocess_params)

preprocess_selector()

Function to return preprocess function based off of input flag

Preprocessing can be set manually but it is preferred to go through preprocess selector

Raises:

    Exception: If self.taste_num is neither int nor str

Source code in pytau/changepoint_io.py
def preprocess_selector(self):
    """Function to return preprocess function based off of input flag

    Preprocessing can be set manually but it is preferred to
    go through preprocess selector

    Raises:
        Exception: If self.taste_num is neither int nor str

    """

    if isinstance(self.taste_num, int):
        self.set_preprocessor(
            changepoint_preprocess.preprocess_single_taste)
    elif self.taste_num == 'all':
        self.set_preprocessor(changepoint_preprocess.preprocess_all_taste)
    else:
        raise Exception("Something went wrong")

run_inference()

Perform inference on data

Will check for and complete:
1) Model created
2) Inference function selected

Source code in pytau/changepoint_io.py
def run_inference(self):
    """Perform inference on data

    Will check for and complete:
    1) Model created
    2) Inference function selected
    """
    if 'model' not in dir(self):
        self.create_model()
    if 'inference_func' not in dir(self):
        self.inference_func_selector()

    changepoint_model.compile_wait()
    print('Running inference, inference func: '
          f'<{self.inference_func.__name__}>')
    temp_outs = self.inference_func(self.model,
                                    self.model_params['fit'],
                                    self.model_params['samples'])
    varnames = ['model', 'approx', 'lambda', 'tau', 'data']
    self.inference_outs = dict(zip(varnames, temp_outs))

save_fit_output()

Save fit output (fitted data + metadata) to pkl file

Source code in pytau/changepoint_io.py
def save_fit_output(self):
    """Save fit output (fitted data + metadata) to pkl file
    """
    if 'inference_outs' not in dir(self):
        self.run_inference()
    out_dict = self._return_fit_output()
    with open(self.database_handler.model_save_path + '.pkl', 'wb') as buff:
        pickle.dump(out_dict, buff)

    json_file_name = os.path.join(
        self.database_handler.model_save_path + '.info')
    with open(json_file_name, 'w') as file:
        json.dump(out_dict['metadata'], file, indent=4)

    self.database_handler.write_to_database()

    print('Saving inference output to : \n'
            f'{self.database_handler.model_save_dir}' 
            "\n" + "================================" + '\n')

set_inference(inference_func)

Manually set inference function for model fit e.g.

FitHandler.set_inference(changepoint_model.advi_fit)

Parameters:

    inference_func (func): Function to use for fitting model. Required.

Source code in pytau/changepoint_io.py
def set_inference(self, inference_func):
    """Manually set inference function for model fit e.g.

    FitHandler.set_inference(changepoint_model.advi_fit)

    Args:
        inference_func (func): Function to use for fitting model
    """
    self.inference_func = inference_func

set_model_params(states, fit, samples, model_kwargs=None, file_path=None)

Load given params as "model_params" attribute

Parameters:

    states (int): Number of states to use in model. Required.
    fit (int): Iterations to use for model fitting (given ADVI fit). Required.
    samples (int): Number of samples to return from fitted model. Required.
    model_kwargs (dict): Additional parameters for model. Defaults to None.
    file_path (str): Path to json file containing model parameters. Defaults to None.
Source code in pytau/changepoint_io.py
def set_model_params(self,
                     states,
                     fit,
                     samples,
                     model_kwargs = None,
                     file_path=None):

    """Load given params as "model_params" attribute

    Args:
        states (int): Number of states to use in model
        fit (int): Iterations to use for model fitting (given ADVI fit)
        samples (int): Number of samples to return from fitted model
        model_kwargs (dict) : Additional parameters for model
        file_path (str, optional): Path to json file containing
                preprocess parameters. Defaults to None.
    """

    if file_path is None:
        model_params_dict = dict(zip(['states', 'fit', 'samples', 'model_kwargs'], 
                [states, fit, samples, model_kwargs]))
        self.model_params = model_params_dict
        print('Set model params to: {}'.format(model_params_dict))

    else:
        # Load json and save dict
        pass

set_model_template(model_template)

Manually set model_template for data e.g.

FitHandler.set_model_template(changepoint_model.single_taste_poisson)

Parameters:

    model_template (func): Function to generate model template for data. Required.
Source code in pytau/changepoint_io.py
def set_model_template(self, model_template):
    """Manually set model_template for data e.g.

    FitHandler.set_model_template(changepoint_model.single_taste_poisson)

    Args:
        model_template (func): Function to generate model template for data
    """
    self.model_template = model_template

set_preprocess_params(time_lims, bin_width, data_transform, file_path=None)

Load given params as "preprocess_params" attribute

Parameters:

    time_lims (array/tuple/list): Start and end of where to cut spike train array. Required.
    bin_width (int): Bin width for binning spikes to counts. Required.
    data_transform (str): Indicator for which transformation to use (refer to changepoint_preprocess). Required.
    file_path (str): Path to json file containing preprocess parameters. Defaults to None.
Source code in pytau/changepoint_io.py
def set_preprocess_params(self,
                          time_lims,
                          bin_width,
                          data_transform,
                          file_path=None):

    """Load given params as "preprocess_params" attribute

    Args:
        time_lims (array/tuple/list): Start and end of where to cut
                spike train array
        bin_width (int): Bin width for binning spikes to counts
        data_transform (str): Indicator for which transformation to
                use (refer to changepoint_preprocess)
        file_path (str, optional): Path to json file containing preprocess
                parameters. Defaults to None.
    """

    if file_path is None:
        preprocess_params_dict = dict(zip(['time_lims', 'bin_width', 'data_transform'],
                     [time_lims, bin_width, data_transform]))
        self.preprocess_params = preprocess_params_dict
        print('Set preprocess params to: {}'.format(preprocess_params_dict))

    else:
        # Load json and save dict
        pass

set_preprocessor(preprocessing_func)

Manually set preprocessor for data e.g.

FitHandler.set_preprocessor(changepoint_preprocess.preprocess_single_taste)

Parameters:

    preprocessing_func (func): Function to preprocess data (refer to changepoint_preprocess). Required.
Source code in pytau/changepoint_io.py
def set_preprocessor(self, preprocessing_func):
    """Manually set preprocessor for data e.g.

    FitHandler.set_preprocessor(
                changepoint_preprocess.preprocess_single_taste)

    Args:
        preprocessing_func (func):
                Function to preprocess data (refer to changepoint_preprocess)
    """
    self.preprocessor = preprocessing_func

=== Model building functions ===

PyMC3 Blackbox Variational Inference implementation of Poisson Likelihood Changepoint for spike trains.

advi_fit(model, fit, samples)

Convenience function to perform ADVI fit on model

Parameters:

    model (pymc3 model): Model object to run inference on. Required.
    fit (int): Number of iterations to fit the model for. Required.
    samples (int): Number of samples to draw from fitted model. Required.

Returns:

    model: original model on which inference was run
    approx: fitted model
    lambda_stack: array containing lambda (emission) values
    tau_samples: array containing samples from changepoint distribution
    model.obs.observations: processed array on which fit was run

Source code in pytau/changepoint_model.py
def advi_fit(model, fit, samples):
    """Convenience function to perform ADVI fit on model

    Args:
        model (pymc3 model): model object to run inference on
        fit (int): Number of iterations to fit the model for
        samples (int): Number of samples to draw from fitted model

    Returns:
        model: original model on which inference was run,
        approx: fitted model,
        lambda_stack: array containing lambda (emission) values,
        tau_samples,: array containing samples from changepoint distribution
        model.obs.observations: processed array on which fit was run
    """

    with model:
        inference = pm.ADVI('full-rank')
        approx = pm.fit(n=fit, method=inference)
        trace = approx.sample(draws=samples)

    # Extract relevant variables from trace
    tau_samples = trace['tau']
    if 'lambda' in trace.varnames:
        lambda_stack = trace['lambda'].swapaxes(0, 1)
        return model, approx, lambda_stack, tau_samples, model.obs.observations
    if 'mu' in trace.varnames:
        mu_stack = trace['mu'].swapaxes(0, 1)
        sigma_stack = trace['sigma'].swapaxes(0, 1)
        return model, approx, mu_stack, sigma_stack, tau_samples, model.obs.observations
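A minimal usage sketch, assuming pytau is importable and using gen_test_array from this module to build toy data:

import numpy as np
from pytau.changepoint_model import gen_test_array, single_taste_poisson, advi_fit

# Toy data: 15 trials x 10 neurons x 100 time bins, 4 underlying states
spike_array = gen_test_array((15, 10, 100), n_states=4, type='poisson')

model = single_taste_poisson(spike_array, states=4)
model, approx, lambda_stack, tau_samples, observed = advi_fit(model, fit=20000, samples=10000)

# approx.hist holds the ELBO loss trace; tau_samples holds changepoint samples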

all_taste_poisson(spike_array, states, **kwargs)

** Model to fit changepoint to all tastes **
** Largely taken from "_v1/poisson_all_tastes_changepoint_model.py"

spike_array :: Shape : tastes x trials x neurons x time_bins
states :: number of states to include in the model

Source code in pytau/changepoint_model.py
def all_taste_poisson(
        spike_array,
        states,
        **kwargs):
    """
    ** Model to fit changepoint to all tastes **
    ** Largely taken from "_v1/poisson_all_tastes_changepoint_model.py"

    spike_array :: Shape : tastes, trials, neurons, time_bins
    states :: number of states to include in the model 
    """

    # If model doesn't already exist, then create a new one
    #spike_array = this_dat_binned
    # Unroll arrays along taste axis
    #spike_array_long = np.reshape(spike_array,(-1,*spike_array.shape[-2:]))
    spike_array_long = np.concatenate(spike_array, axis=0)

    # Find mean firing for initial values
    tastes = spike_array.shape[0]
    length = spike_array.shape[-1]
    nrns = spike_array.shape[2]
    trials = spike_array.shape[1]

    split_list = np.array_split(spike_array, states, axis=-1)
    # Cut all to the same size
    min_val = min([x.shape[-1] for x in split_list])
    split_array = np.array([x[..., :min_val] for x in split_list])
    mean_vals = np.mean(split_array, axis=(2, -1)).swapaxes(0, 1)
    mean_vals += 0.01  # To avoid zero starting prob
    mean_nrn_vals = np.mean(mean_vals, axis=(0, 1))

    # Find evenly spaced switchpoints for initial values
    idx = np.arange(spike_array.shape[-1])  # Index
    array_idx = np.broadcast_to(idx, spike_array_long.shape)
    even_switches = np.linspace(0, idx.max(), states+1)
    even_switches_normal = even_switches/np.max(even_switches)

    taste_label = np.repeat(
        np.arange(spike_array.shape[0]), spike_array.shape[1])
    trial_num = array_idx.shape[0]

    # Begin constructing model
    with pm.Model() as model:

        # Hierarchical firing rates
        # Refer to model diagram
        # Mean firing rate of neuron AT ALL TIMES
        lambda_nrn = pm.Exponential('lambda_nrn',
                                    1/mean_nrn_vals,
                                    shape=(mean_vals.shape[-1]))
        # Priors for each state, derived from each neuron
        # Mean firing rate of neuron IN EACH STATE (averaged across tastes)
        lambda_state = pm.Exponential('lambda_state',
                                      lambda_nrn,
                                      shape=(mean_vals.shape[1:]))
        # Mean firing rate of neuron PER STATE PER TASTE
        lambda_latent = pm.Exponential('lambda',
                                       lambda_state[np.newaxis, :, :],
                                       testval=mean_vals,
                                       shape=(mean_vals.shape))

        # Changepoint time variable
        # INDEPENDENT TAU FOR EVERY TRIAL
        a = pm.HalfNormal('a_tau', 3., shape=states - 1)
        b = pm.HalfNormal('b_tau', 3., shape=states - 1)

        # Stack produces states x trials --> That gets transposed
        # to trials x states and gets sorted along states (axis=-1)
        # Sort should work the same way as the Ordered transform -->
        # see rv_sort_test.ipynb
        tau_latent = pm.Beta('tau_latent', a, b,
                             shape=(trial_num, states-1),
                             testval=tt.tile(even_switches_normal[1:(states)],
                                             (array_idx.shape[0], 1))).sort(axis=-1)

        tau = pm.Deterministic('tau',
                               idx.min() + (idx.max() - idx.min()) * tau_latent)

        weight_stack = tt.nnet.sigmoid(
            idx[np.newaxis, :]-tau[:, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((tastes*trials, 1, length)), weight_stack], axis=1)
        inverse_stack = 1 - weight_stack[:, 1:]
        inverse_stack = tt.concatenate([inverse_stack, np.ones(
            (tastes*trials, 1, length))], axis=1)
        weight_stack = weight_stack*inverse_stack
        weight_stack = tt.tile(weight_stack[:, :, None, :], (1, 1, nrns, 1))

        lambda_latent = lambda_latent.dimshuffle(2, 0, 1)
        lambda_latent = tt.repeat(lambda_latent, trials, axis=1)
        lambda_latent = tt.tile(
            lambda_latent[..., None], (1, 1, 1, length))
        lambda_latent = lambda_latent.dimshuffle(1, 2, 0, 3)
        lambda_ = tt.sum(lambda_latent * weight_stack, axis=1)

        observation = pm.Poisson("obs", lambda_, observed=spike_array_long)

    return model
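A usage sketch with a toy 4D (tastes x trials x neurons x time) array; sizes are illustrative only:

from pytau.changepoint_model import gen_test_array, all_taste_poisson, advi_fit

# tastes x trials x neurons x time_bins
spike_array = gen_test_array((4, 15, 10, 100), n_states=4, type='poisson')

model = all_taste_poisson(spike_array, states=4)
outputs = advi_fit(model, fit=20000, samples=10000)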

all_taste_poisson_trial_switch(spike_array, switch_components, states)

Assumes only emissions change across trials; the changepoint distribution remains constant

spike_array :: tastes x trials x nrns x time_bins
states :: number of states to include in the model

Source code in pytau/changepoint_model.py
def all_taste_poisson_trial_switch(
        spike_array,
        switch_components,
        states):
    """
    Assuming only emissions change across trials
    Changepoint distribution remains constant

    spike_array :: Tastes x trials x nrns x time_bins
    states :: number of states to include in the model 
    """

    tastes, trial_num, nrn_num, time_bins = spike_array.shape

    with pm.Model() as model:

        # Define Emissions
        # =================================================

        # nrns
        nrn_lambda = pm.Exponential('nrn_lambda', 10, shape=(nrn_num))

        # tastes x nrns
        taste_lambda = pm.Exponential('taste_lambda',
                                      nrn_lambda.dimshuffle('x', 0),
                                      shape=(tastes, nrn_num))

        # tastes x nrns x switch_comps
        trial_lambda = pm.Exponential('trial_lambda',
                                      taste_lambda.dimshuffle(0, 1, 'x'),
                                      shape=(tastes, nrn_num, switch_components))

        # tastes x nrns x switch_comps x states
        state_lambda = pm.Exponential('state_lambda',
                                      trial_lambda.dimshuffle(0, 1, 2, 'x'),
                                      shape=(tastes, nrn_num, switch_components, states))

        # Define Changepoints
        # =================================================
        # Assuming distribution of changepoints remains
        # the same across all trials

        a = pm.HalfCauchy('a_tau', 3., shape=states - 1)
        b = pm.HalfCauchy('b_tau', 3., shape=states - 1)

        even_switches = np.linspace(0, 1, states+1)[1:-1]
        tau_latent = pm.Beta('tau_latent', a, b,
                             testval=even_switches,
                             shape=(tastes, trial_num, states-1)).sort(axis=-1)

        # Tastes x Trials x Changepoints
        tau = pm.Deterministic('tau', time_bins * tau_latent)

        # Define trial switches
        # Will have same structure as regular changepoints

        #a_trial = pm.HalfCauchy('a_trial', 3., shape = switch_components - 1)
        #b_trial = pm.HalfCauchy('b_trial', 3., shape = switch_components - 1)

        even_trial_switches = np.linspace(0, 1, switch_components+1)[1:-1]
        tau_trial_latent = pm.Beta('tau_trial_latent', 1, 1,
                                   testval=even_trial_switches,
                                   shape=(switch_components-1)).sort(axis=-1)

        # Trial_changepoints
        # =================================================
        tau_trial = pm.Deterministic('tau_trial', trial_num * tau_trial_latent)

        trial_idx = np.arange(trial_num)
        trial_selector = tt.nnet.sigmoid(
            trial_idx[np.newaxis, :] - tau_trial.dimshuffle(0, 'x'))

        trial_selector = tt.concatenate(
            [np.ones((1, trial_num)), trial_selector], axis=0)
        inverse_trial_selector = 1 - trial_selector[1:, :]
        inverse_trial_selector = tt.concatenate([inverse_trial_selector,
                                                np.ones((1, trial_num))], axis=0)

        # switch_comps x trials
        trial_selector = np.multiply(trial_selector, inverse_trial_selector)

        # state_lambda: tastes x nrns x switch_comps x states

        # selected_trial_lambda : tastes x nrns x states x trials
        selected_trial_lambda = pm.Deterministic('selected_trial_lambda',
                                                 tt.sum(
                                                     # "tastes" x "nrns" x switch_comps x "states" x trials
                                                     trial_selector.dimshuffle(
                                                         'x', 'x', 0, 'x', 1) * state_lambda.dimshuffle(0, 1, 2, 3, 'x'),
                                                     axis=2)
                                                 )

        # First, we can "select" sets of emissions depending on trial_changepoints
        # =================================================
        trial_idx = np.arange(trial_num)
        trial_selector = tt.nnet.sigmoid(
            trial_idx[np.newaxis, :] - tau_trial.dimshuffle(0, 'x'))

        trial_selector = tt.concatenate(
            [np.ones((1, trial_num)), trial_selector], axis=0)
        inverse_trial_selector = 1 - trial_selector[1:, :]
        inverse_trial_selector = tt.concatenate([inverse_trial_selector,
                                                np.ones((1, trial_num))], axis=0)

        # switch_comps x trials
        trial_selector = np.multiply(trial_selector, inverse_trial_selector)

        # Then, we can select state_emissions for every trial
        # =================================================

        idx = np.arange(time_bins)

        # tau : Tastes x Trials x Changepoints
        weight_stack = tt.nnet.sigmoid(
            idx[np.newaxis, :]-tau[:, :, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((tastes, trial_num, 1, time_bins)), weight_stack], axis=2)
        inverse_stack = 1 - weight_stack[:, :, 1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((tastes, trial_num, 1, time_bins))], axis=2)

        # Tastes x Trials x states x Time
        weight_stack = np.multiply(weight_stack, inverse_stack)

        # Putting everything together
        # =================================================

        # selected_trial_lambda :           tastes x nrns x states x trials
        # Convert selected_trial_lambda --> tastes x trials x nrns x states x "time"

        # weight_stack :           tastes x trials x states x time
        # Convert weight_stack --> tastes x trials x "nrns" x states x time

        # tastes x trials x nrns x time
        lambda_ = tt.sum(selected_trial_lambda.dimshuffle(0, 3, 1, 2, 'x') * weight_stack.dimshuffle(0, 1, 'x', 2, 3),
                         axis=3)

        # Add observations
        observation = pm.Poisson("obs", lambda_, observed=spike_array)

    return model

all_taste_poisson_varsig_fixed(spike_array, states, inds_span=1)

** Model to fit changepoint to all tastes, using a sigmoid of fixed slope (set by inds_span) **
** Largely taken from "_v1/poisson_all_tastes_changepoint_model.py"

spike_array :: Shape : tastes x trials x neurons x time_bins
states :: number of states to include in the model

Source code in pytau/changepoint_model.py
def all_taste_poisson_varsig_fixed(
        spike_array,
        states,
        inds_span=1):
    """
    ** Model to fit changepoint to all tastes **
    ** Largely taken from "_v1/poisson_all_tastes_changepoint_model.py"

    spike_array :: Shape : tastes, trials, neurons, time_bins
    states :: number of states to include in the model 
    """

    # If model doesn't already exist, then create a new one
    #spike_array = this_dat_binned
    # Unroll arrays along taste axis
    #spike_array_long = np.reshape(spike_array,(-1,*spike_array.shape[-2:]))
    spike_array_long = np.concatenate(spike_array, axis=0)

    # Find mean firing for initial values
    tastes = spike_array.shape[0]
    length = spike_array.shape[-1]
    nrns = spike_array.shape[2]
    trials = spike_array.shape[1]

    split_list = np.array_split(spike_array, states, axis=-1)
    # Cut all to the same size
    min_val = min([x.shape[-1] for x in split_list])
    split_array = np.array([x[..., :min_val] for x in split_list])
    mean_vals = np.mean(split_array, axis=(2, -1)).swapaxes(0, 1)
    mean_vals += 0.01  # To avoid zero starting prob
    mean_nrn_vals = np.mean(mean_vals, axis=(0, 1))

    # Find evenly spaced switchpoints for initial values
    idx = np.arange(spike_array.shape[-1])  # Index
    array_idx = np.broadcast_to(idx, spike_array_long.shape)
    even_switches = np.linspace(0, idx.max(), states+1)
    even_switches_normal = even_switches/np.max(even_switches)

    taste_label = np.repeat(
        np.arange(spike_array.shape[0]), spike_array.shape[1])
    trial_num = array_idx.shape[0]

    # Define sigmoid with given sharpness
    sig_b = inds_to_b(inds_span)

    def sigmoid(x):
        b_temp = tt.tile(
            np.array(sig_b)[None, None, None], x.tag.test_value.shape)
        return 1/(1+tt.exp(-b_temp*x))

    # Begin constructing model
    with pm.Model() as model:

        # Hierarchical firing rates
        # Refer to model diagram
        # Mean firing rate of neuron AT ALL TIMES
        lambda_nrn = pm.Exponential('lambda_nrn',
                                    1/mean_nrn_vals,
                                    shape=(mean_vals.shape[-1]))
        # Priors for each state, derived from each neuron
        # Mean firing rate of neuron IN EACH STATE (averaged across tastes)
        lambda_state = pm.Exponential('lambda_state',
                                      lambda_nrn,
                                      shape=(mean_vals.shape[1:]))
        # Mean firing rate of neuron PER STATE PER TASTE
        lambda_latent = pm.Exponential('lambda',
                                       lambda_state[np.newaxis, :, :],
                                       testval=mean_vals,
                                       shape=(mean_vals.shape))

        # Changepoint time variable
        # INDEPENDENT TAU FOR EVERY TRIAL
        a = pm.HalfNormal('a_tau', 3., shape=states - 1)
        b = pm.HalfNormal('b_tau', 3., shape=states - 1)

        # Stack produces states x trials --> That gets transposed
        # to trials x states and gets sorted along states (axis=-1)
        # Sort should work the same way as the Ordered transform -->
        # see rv_sort_test.ipynb
        tau_latent = pm.Beta('tau_latent', a, b,
                             shape=(trial_num, states-1),
                             testval=tt.tile(even_switches_normal[1:(states)],
                                             (array_idx.shape[0], 1))).sort(axis=-1)

        tau = pm.Deterministic('tau',
                               idx.min() + (idx.max() - idx.min()) * tau_latent)

        weight_stack = sigmoid(
            idx[np.newaxis, :]-tau[:, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((tastes*trials, 1, length)), weight_stack], axis=1)
        inverse_stack = 1 - weight_stack[:, 1:]
        inverse_stack = tt.concatenate([inverse_stack, np.ones(
            (tastes*trials, 1, length))], axis=1)
        weight_stack = weight_stack*inverse_stack
        weight_stack = tt.tile(weight_stack[:, :, None, :], (1, 1, nrns, 1))

        lambda_latent = lambda_latent.dimshuffle(2, 0, 1)
        lambda_latent = tt.repeat(lambda_latent, trials, axis=1)
        lambda_latent = tt.tile(
            lambda_latent[..., None], (1, 1, 1, length))
        lambda_latent = lambda_latent.dimshuffle(1, 2, 0, 3)
        lambda_ = tt.sum(lambda_latent * weight_stack, axis=1)

        observation = pm.Poisson("obs", lambda_, observed=spike_array_long)

    return model

compile_wait()

Function to allow waiting while a model is already fitting. Waits twice because the lock blips out between steps; 10 secs of waiting shouldn't be a problem for long fits (~mins). Also waits a random time in the beginning to stagger fits.

Source code in pytau/changepoint_model.py
def compile_wait():
    """
    Function to allow waiting while a model is already fitting
    Wait twice because lock blips out between steps
    10 secs of waiting shouldn't be a problem for long fits (~mins)
    And wait a random time in the beginning to stagger fits
    """
    time.sleep(np.random.random()*10)
    while theano_lock_present():
        print('Lock present...waiting')
        time.sleep(10)
    while theano_lock_present():
        print('Lock present...waiting')
        time.sleep(10)

dpp_fit(model, n_chains=24, n_cores=1, tune=500, draws=500)

Convenience function to fit DPP model

Source code in pytau/changepoint_model.py
def dpp_fit(model, n_chains=24, n_cores=1, tune=500, draws=500):
    """Convenience function to fit DPP model
    """
    with model:
        dpp_trace = pm.sample(tune=tune,
                              draws=draws,
                              target_accept=0.95,
                              chains=n_chains,
                              cores=n_cores,
                              return_inferencedata=False)
    return dpp_trace
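A usage sketch pairing dpp_fit with one of the dirichlet-process models from this module; data, chain counts, and draw counts are illustrative only:

import numpy as np
from pytau.changepoint_model import gaussian_changepoint_mean_dirichlet, dpp_fit

# 8 dimensions x 200 time bins with a mean shift halfway through
data_array = np.random.randn(8, 200)
data_array[:, 100:] += 2.0

model = gaussian_changepoint_mean_dirichlet(data_array, max_states=10)
trace = dpp_fit(model, n_chains=4, n_cores=1, tune=500, draws=500)

tau_samples = trace['tau']        # changepoint positions
w_latent = trace['w_latent']      # inferred state weights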

extract_inferred_values(trace)

Convenience function to extract inferred values from ADVI fit

Parameters:

    trace (dict): Trace from fitted model. Required.

Returns:

    dict: dictionary of inferred values

Source code in pytau/changepoint_model.py
def extract_inferred_values(trace):
    """Convenience function to extract inferred values from ADVI fit

    Args:
        trace (dict): trace

    Returns:
        dict: dictionary of inferred values
    """
    # Extract relevant variables from trace
    out_dict = dict(
        tau_samples=trace['tau'])
    if 'lambda' in trace.varnames:
        out_dict['lambda_stack'] = trace['lambda'].swapaxes(0, 1)
    if 'mu' in trace.varnames:
        out_dict['mu_stack'] = trace['mu'].swapaxes(0, 1)
        out_dict['sigma_stack'] = trace['sigma'].swapaxes(0, 1)
    return out_dict
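A usage sketch; sampling with dpp_fit is only one way to obtain a trace containing the variables named below:

from pytau.changepoint_model import (gen_test_array, single_taste_poisson,
                                     dpp_fit, extract_inferred_values)

spike_array = gen_test_array((15, 10, 100), n_states=4, type='poisson')
model = single_taste_poisson(spike_array, states=4)
trace = dpp_fit(model, n_chains=4, n_cores=1, tune=200, draws=200)

inferred = extract_inferred_values(trace)
tau_samples = inferred['tau_samples']          # changepoint samples
lambda_stack = inferred.get('lambda_stack')    # emission samples (Poisson models only)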

find_best_states(data, model_generator, n_fit, n_samples, min_states=2, max_states=10)

Convenience function to find best number of states for model

Parameters:

    data (array): Array on which to run inference. Required.
    model_generator (function): Function that generates model. Required.
    n_fit (int): Number of iterations to fit the model for. Required.
    n_samples (int): Number of samples to draw from fitted model. Required.
    min_states (int): Minimum number of states to test. Defaults to 2.
    max_states (int): Maximum number of states to test. Defaults to 10.

Returns:

    best_model: model with best number of states
    model_list: list of models with different numbers of states
    elbo_values: list of ELBO values for different numbers of states

Source code in pytau/changepoint_model.py
def find_best_states(data, model_generator, n_fit, n_samples, min_states=2, max_states=10):
    """Convenience function to find best number of states for model

    Args:
        data (array): array on which to run inference
        model_generator (function): function that generates model
        n_fit (int): Number of iterations to fit the model for
        n_samples (int): Number of samples to draw from fitted model
        min_states (int): Minimum number of states to test
        max_states (int): Maximum number of states to test

    Returns:
        best_model: model with best number of states,
        model_list: list of models with different number of states,
        elbo_values: list of elbo values for different number of states
    """
    n_state_array = np.arange(min_states, max_states+1)
    elbo_values = []
    model_list = []
    for n_states in tqdm(n_state_array):
        print(f'Fitting model with {n_states} states')
        model = model_generator(data, n_states)
        model, approx = advi_fit(model, n_fit, n_samples)[:2]
        elbo_values.append(approx.hist[-1])
        model_list.append(model)
    best_model = model_list[np.argmin(elbo_values)]
    return best_model, model_list, elbo_values
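A usage sketch with toy data; the model_generator must accept (data, n_states), as single_taste_poisson does:

from pytau.changepoint_model import gen_test_array, single_taste_poisson, find_best_states

spike_array = gen_test_array((15, 10, 100), n_states=4, type='poisson')
best_model, model_list, elbo_values = find_best_states(
    spike_array,
    single_taste_poisson,
    n_fit=20000,
    n_samples=10000,
    min_states=2,
    max_states=6,
)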

gaussian_changepoint_mean_2d(data_array, n_states, **kwargs)

Model for gaussian data on 2D array detecting changes only in the mean.

Parameters:

    data_array (2D Numpy array): <dimension> x time. Required.
    n_states (int): Number of states to model. Required.

Returns:

    pymc3 model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def gaussian_changepoint_mean_2d(data_array, n_states, **kwargs):
    """Model for gaussian data on 2D array detecting changes only in 
    the mean.

    Args:
        data_array (2D Numpy array): <dimension> x time
        n_states (int): Number of states to model

    Returns:
        pymc3 model: Model class containing graph to run inference on
    """
    mean_vals = np.array([np.mean(x, axis=-1)
                          for x in np.array_split(data_array, n_states, axis=-1)]).T
    mean_vals += 0.01  # To avoid zero starting prob

    y_dim = data_array.shape[0]
    idx = np.arange(data_array.shape[-1])
    length = idx.max() + 1

    with pm.Model() as model:
        mu = pm.Normal('mu', mu=mean_vals, sd=1, shape=(y_dim, n_states))
        # One variance for each dimension
        sigma = pm.HalfCauchy('sigma', 3., shape=(y_dim))

        a_tau = pm.HalfCauchy('a_tau', 3., shape=n_states - 1)
        b_tau = pm.HalfCauchy('b_tau', 3., shape=n_states - 1)

        even_switches = np.linspace(0, 1, n_states+1)[1:-1]
        tau_latent = pm.Beta('tau_latent', a_tau, b_tau,
                             testval=even_switches,
                             shape=(n_states-1)).sort(axis=-1)

        tau = pm.Deterministic('tau',
                               idx.min() + (idx.max() - idx.min()) * tau_latent)

        weight_stack = tt.nnet.sigmoid(idx[np.newaxis, :]-tau[:, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((1, length)), weight_stack], axis=0)
        inverse_stack = 1 - weight_stack[1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((1, length))], axis=0)
        weight_stack = np.multiply(weight_stack, inverse_stack)

        mu_latent = mu.dot(weight_stack)
        sigma_latent = sigma.dimshuffle(0, 'x')
        observation = pm.Normal("obs", mu=mu_latent, sd=sigma_latent,
                                observed=data_array)

    return model
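A usage sketch on synthetic 2D data with a single mean shift; values are illustrative only:

import numpy as np
from pytau.changepoint_model import gaussian_changepoint_mean_2d, advi_fit

data_array = np.random.randn(8, 200)     # <dimension> x time
data_array[:, 120:] += 1.5               # mean shift, variance unchanged

model = gaussian_changepoint_mean_2d(data_array, n_states=2)
model, approx, mu_stack, sigma_stack, tau_samples, observed = \
    advi_fit(model, fit=20000, samples=10000)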

gaussian_changepoint_mean_dirichlet(data_array, max_states=15)

Model for gaussian data on 2D array detecting changes only in the mean. Number of states determined using dirichlet process prior.

Parameters:

    data_array (2D Numpy array): <dimension> x time. Required.
    max_states (int): Max number of states to include in truncated dirichlet process. Defaults to 15.

Returns:

    pymc3 model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def gaussian_changepoint_mean_dirichlet(data_array, max_states=15):
    """Model for gaussian data on 2D array detecting changes only in 
    the mean. Number of states determined using dirichlet process prior.

    Args:
        data_array (2D Numpy array): <dimension> x time
        max_states (int): Max number of states to include in truncated dirichlet process 

    Returns:
        pymc3 model: Model class containing graph to run inference on
    """

    y_dim = data_array.shape[0]
    idx = np.arange(data_array.shape[-1])
    length = idx.max() + 1

    mean_vals = np.array([np.mean(x, axis=-1)
                          for x in np.array_split(data_array, max_states, axis=-1)]).T
    mean_vals += 0.01  # To avoid zero starting prob
    test_std = np.std(data_array, axis=-1)

    with pm.Model() as model:
        # ===================
        # Emissions Variables
        # ===================
        lambda_latent = pm.Normal('lambda',
                                  mu=mean_vals, sigma=10,
                                  shape=(y_dim, max_states))
        # One variance for each dimension
        sigma = pm.HalfCauchy('sigma', test_std, shape=(y_dim))

        # =====================
        # Changepoint Variables
        # =====================

        # Hyperpriors on alpha
        a_gamma = pm.Gamma('a_gamma', 10, 1)
        b_gamma = pm.Gamma('b_gamma', 1.5, 1)

        # Concentration parameter for beta
        alpha = pm.Gamma('alpha', a_gamma, b_gamma)

        # Draw beta's to calculate stick lengths
        beta = pm.Beta('beta', 1, alpha, shape=max_states)

        # Calculate stick lengths using stick_breaking process
        w_raw = pm.Deterministic('w_raw', stick_breaking(beta))

        # Make sure lengths add to 1, and scale to length of data
        w_latent = pm.Deterministic('w_latent', w_raw / w_raw.sum())
        tau = pm.Deterministic('tau', tt.cumsum(w_latent * length)[:-1])

        # Weight stack to assign lambda's to point in time
        weight_stack = tt.nnet.sigmoid(idx[np.newaxis, :]-tau[:, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((1, length)), weight_stack], axis=0)
        inverse_stack = 1 - weight_stack[1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((1, length))], axis=0)
        weight_stack = np.multiply(weight_stack, inverse_stack)

        # Create timeseries for latent variable (mean emission)
        lambda_ = pm.Deterministic('lambda_',
                                   tt.tensordot(
                                       lambda_latent,
                                       weight_stack,
                                       axes=(1, 0)
                                   )
                                   )
        sigma_latent = sigma.dimshuffle(0, 'x')

        # Likelihood for observations
        observation = pm.Normal(
            "obs", mu=lambda_, sigma=sigma_latent, observed=data_array)
    return model

gaussian_changepoint_mean_var_2d(data_array, n_states, **kwargs)

Model for gaussian data on 2D array detecting changes in both mean and variance.

Parameters:

    data_array (2D Numpy array): <dimension> x time. Required.
    n_states (int): Number of states to model. Required.

Returns:

    pymc3 model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def gaussian_changepoint_mean_var_2d(data_array, n_states, **kwargs):
    """Model for gaussian data on 2D array detecting changes in both
    mean and variance.

    Args:
        data_array (2D Numpy array): <dimension> x time
        n_states (int): Number of states to model

    Returns:
        pymc3 model: Model class containing graph to run inference on
    """
    mean_vals = np.array([np.mean(x, axis=-1)
                          for x in np.array_split(data_array, n_states, axis=-1)]).T
    mean_vals += 0.01  # To avoid zero starting prob

    y_dim = data_array.shape[0]
    idx = np.arange(data_array.shape[-1])
    length = idx.max() + 1

    with pm.Model() as model:
        mu = pm.Normal('mu', mu=mean_vals, sd=1, shape=(y_dim, n_states))
        sigma = pm.HalfCauchy('sigma', 3., shape=(y_dim, n_states))

        a_tau = pm.HalfCauchy('a_tau', 3., shape=n_states - 1)
        b_tau = pm.HalfCauchy('b_tau', 3., shape=n_states - 1)

        even_switches = np.linspace(0, 1, n_states+1)[1:-1]
        tau_latent = pm.Beta('tau_latent', a_tau, b_tau,
                             testval=even_switches,
                             shape=(n_states-1)).sort(axis=-1)

        tau = pm.Deterministic('tau',
                               idx.min() + (idx.max() - idx.min()) * tau_latent)

        weight_stack = tt.nnet.sigmoid(idx[np.newaxis, :]-tau[:, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((1, length)), weight_stack], axis=0)
        inverse_stack = 1 - weight_stack[1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((1, length))], axis=0)
        weight_stack = np.multiply(weight_stack, inverse_stack)

        mu_latent = mu.dot(weight_stack)
        sigma_latent = sigma.dot(weight_stack)
        observation = pm.Normal("obs", mu=mu_latent, sd=sigma_latent,
                                observed=data_array)

    return model

gen_test_array(array_size, n_states, type='poisson')

Generate test array for model fitting. The last 2 dimensions constitute a single trial; time will always be the last dimension.

Parameters:

    array_size (tuple): Size of array to generate. Required.
    n_states (int): Number of states to generate. Required.
    type (str): Type of data to generate ('normal' or 'poisson'). Defaults to 'poisson'.
Source code in pytau/changepoint_model.py
def gen_test_array(array_size, n_states, type='poisson'):
    """
    Generate test array for model fitting
    Last 2 dimensions consist of a single trial
    Time will always be last dimension

    Args:
        array_size (tuple): Size of array to generate
        n_states (int): Number of states to generate
        type (str): Type of data to generate
            - normal
            - poisson
    """
    assert array_size[-1] > n_states, 'Array too small for states'
    assert type in [
        'normal', 'poisson'], 'Invalid type, please use normal or poisson'

    # Generate transition times
    transition_times = np.random.random((*array_size[:-2], n_states))
    transition_times = np.cumsum(transition_times, axis=-1)
    transition_times = transition_times / \
        transition_times.max(axis=-1, keepdims=True)
    transition_times *= array_size[-1]
    transition_times = np.vectorize(int)(transition_times)

    # Generate state bounds
    state_bounds = np.zeros((*array_size[:-2], n_states+1), dtype=int)
    state_bounds[..., 1:] = transition_times

    # Generate state rates
    lambda_vals = np.random.random((*array_size[:-1], n_states))

    # Generate array
    rate_array = np.zeros(array_size)
    inds = list(np.ndindex(lambda_vals.shape))
    for this_ind in inds:
        this_lambda = lambda_vals[this_ind[:-2]][:, this_ind[-1]]
        this_state_bounds = [
            state_bounds[(*this_ind[:-2], this_ind[-1])],
            state_bounds[(*this_ind[:-2], this_ind[-1]+1)]]
        rate_array[this_ind[:-2]][:,
                                  slice(*this_state_bounds)] = this_lambda[:, None]

    if type == 'poisson':
        return np.random.poisson(rate_array)
    else:
        return np.random.normal(loc=rate_array, scale=0.1)
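For instance, typical shapes might look like:

from pytau.changepoint_model import gen_test_array

# Single-taste shape: trials x neurons x time
poisson_data = gen_test_array((15, 10, 100), n_states=4, type='poisson')

# Gaussian test data with the same shape
normal_data = gen_test_array((15, 10, 100), n_states=4, type='normal')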

mcmc_fit(model, samples)

Convenience function to perform MCMC fit on model

Parameters:

    model (pymc3 model): Model object to run inference on. Required.
    samples (int): Number of samples to draw using MCMC. Required.

Returns:

    model: original model on which inference was run
    trace: samples drawn from MCMC
    lambda_stack: array containing lambda (emission) values
    tau_samples: array containing samples from changepoint distribution
    model.obs.observations: processed array on which fit was run

Source code in pytau/changepoint_model.py
def mcmc_fit(model, samples):
    """Convenience function to perform ADVI fit on model

    Args:
        model (pymc3 model): model object to run inference on
        samples (int): Number of samples to draw using MCMC 

    Returns:
        model: original model on which inference was run,
        trace:  samples drawn from MCMC,
        lambda_stack: array containing lambda (emission) values,
        tau_samples,: array containing samples from changepoint distribution
        model.obs.observations: processed array on which fit was run
    """

    with model:
        sampler_kwargs = {'cores': 1, 'chains': 4}
        trace = pm.sample(draws=samples, **sampler_kwargs)
        trace = trace[::10]

    # Extract relevant variables from trace
    tau_samples = trace['tau']
    if 'lambda' in trace.varnames:
        lambda_stack = trace['lambda'].swapaxes(0, 1)
        return model, trace, lambda_stack, tau_samples, model.obs.observations
    if 'mu' in trace.varnames:
        mu_stack = trace['mu'].swapaxes(0, 1)
        sigma_stack = trace['sigma'].swapaxes(0, 1)
        return model, trace, mu_stack, sigma_stack, tau_samples, model.obs.observations
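A usage sketch with toy data; note that MCMC on these models is typically much slower than ADVI:

from pytau.changepoint_model import gen_test_array, single_taste_poisson, mcmc_fit

spike_array = gen_test_array((15, 10, 100), n_states=4, type='poisson')
model = single_taste_poisson(spike_array, states=4)
model, trace, lambda_stack, tau_samples, observed = mcmc_fit(model, samples=2000)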

single_taste_poisson(spike_array, states, **kwargs)

Model for changepoint on single taste

** Largely taken from "non_hardcoded_changepoint_test_3d.ipynb"
** Note: This model does not have hierarchical structure for emissions

Parameters:

    spike_array (3D Numpy array): trials x neurons x time. Required.
    states (int): Number of states to model. Required.

Returns:

    pymc3 model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def single_taste_poisson(
        spike_array,
        states,
        **kwargs):
    """Model for changepoint on single taste

    ** Largely taken from "non_hardcoded_changepoint_test_3d.ipynb"
    ** Note : This model does not have hierarchical structure for emissions

    Args:
        spike_array (3D Numpy array): trials x neurons x time
        states (int): Number of states to model

    Returns:
        pymc3 model: Model class containing graph to run inference on
    """

    mean_vals = np.array([np.mean(x, axis=-1)
                          for x in np.array_split(spike_array, states, axis=-1)]).T
    mean_vals = np.mean(mean_vals, axis=1)
    mean_vals += 0.01  # To avoid zero starting prob

    nrns = spike_array.shape[1]
    trials = spike_array.shape[0]
    idx = np.arange(spike_array.shape[-1])
    length = idx.max() + 1

    with pm.Model() as model:
        lambda_latent = pm.Exponential('lambda',
                                       1/mean_vals,
                                       shape=(nrns, states))

        a_tau = pm.HalfCauchy('a_tau', 3., shape=states - 1)
        b_tau = pm.HalfCauchy('b_tau', 3., shape=states - 1)

        even_switches = np.linspace(0, 1, states+1)[1:-1]
        tau_latent = pm.Beta('tau_latent', a_tau, b_tau,
                             testval=even_switches,
                             shape=(trials, states-1)).sort(axis=-1)

        tau = pm.Deterministic('tau',
                               idx.min() + (idx.max() - idx.min()) * tau_latent)

        weight_stack = tt.nnet.sigmoid(
            idx[np.newaxis, :]-tau[:, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((trials, 1, length)), weight_stack], axis=1)
        inverse_stack = 1 - weight_stack[:, 1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((trials, 1, length))], axis=1)
        weight_stack = np.multiply(weight_stack, inverse_stack)

        lambda_ = tt.tensordot(weight_stack, lambda_latent, [
                               1, 1]).swapaxes(1, 2)
        observation = pm.Poisson("obs", lambda_, observed=spike_array)

    return model
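A minimal usage sketch with toy data:

from pytau.changepoint_model import gen_test_array, single_taste_poisson, advi_fit

spike_array = gen_test_array((15, 10, 100), n_states=4, type='poisson')  # trials x nrns x time
model = single_taste_poisson(spike_array, states=4)
_, _, lambda_stack, tau_samples, _ = advi_fit(model, fit=20000, samples=10000)

# tau_samples: samples x trials x (states - 1), in units of time-bin index
# lambda_stack: per-state emission (firing rate) samples for each neuron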

single_taste_poisson_dirichlet(spike_array, max_states=10, **kwargs)

Model for changepoint on single taste using dirichlet process prior

Parameters:

    spike_array (3D Numpy array): trials x neurons x time. Required.
    max_states (int): Maximum number of states to model. Defaults to 10.

Returns:

    pymc3 model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def single_taste_poisson_dirichlet(
        spike_array,
        max_states=10,
        **kwargs):
    """
    Model for changepoint on single taste using dirichlet process prior

    Args:
        spike_array (3D Numpy array): trials x neurons x time
        max_states (int): Maximum number of states to model

    Returns:
        pymc3 model: Model class containing graph to run inference on
    """
    mean_vals = np.array([np.mean(x, axis=-1)
                          for x in np.array_split(spike_array, max_states, axis=-1)]).T
    mean_vals = np.mean(mean_vals, axis=1)
    mean_vals += 0.01  # To avoid zero starting prob

    nrns = spike_array.shape[1]
    trials = spike_array.shape[0]
    idx = np.arange(spike_array.shape[-1])
    length = idx.max() + 1

    with pm.Model() as model:
        # ===================
        # Emissions Variables
        # ===================
        lambda_latent = pm.Exponential('lambda',
                                       1/mean_vals,
                                       shape=(nrns, max_states))

        # =====================
        # Changepoint Variables
        # =====================

        # Hyperpriors on alpha
        a_gamma = pm.Gamma('a_gamma', 10, 1)
        b_gamma = pm.Gamma('b_gamma', 1.5, 1)

        # Concentration parameter for beta
        alpha = pm.Gamma('alpha', a_gamma, b_gamma)

        # Draw beta's to calculate stick lengths
        beta = pm.Beta('beta', 1, alpha, shape=(trials, max_states))

        # Calculate stick lengths using stick_breaking process
        w_raw = pm.Deterministic('w_raw', stick_breaking_trial(beta, trials))

        # Make sure lengths add to 1, and scale to length of data
        w_latent = pm.Deterministic(
            'w_latent', w_raw / w_raw.sum(axis=-1)[:, None])
        tau = pm.Deterministic('tau', tt.cumsum(
            w_latent * length, axis=-1)[:, :-1])

        # =====================
        # Rate over time
        # =====================

        # Weight stack to assign lambda's to point in time
        weight_stack = tt.nnet.sigmoid(
            idx[np.newaxis, :]-tau[:, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((trials, 1, length)), weight_stack], axis=1)
        inverse_stack = 1 - weight_stack[:, 1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((trials, 1, length))], axis=1)
        # Trials x States x Time
        weight_stack = np.multiply(weight_stack, inverse_stack)

        lambda_ = pm.Deterministic('lambda_',
            tt.tensordot(weight_stack, lambda_latent, [1, 1]).swapaxes(1, 2))

        # =====================
        # Likelihood
        # =====================
        observation = pm.Poisson("obs", lambda_, observed=spike_array)

    return model
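A usage sketch; since the number of states is inferred, one option is to sample this model with dpp_fit and inspect the state weights:

from pytau.changepoint_model import gen_test_array, single_taste_poisson_dirichlet, dpp_fit

spike_array = gen_test_array((15, 10, 100), n_states=4, type='poisson')
model = single_taste_poisson_dirichlet(spike_array, max_states=8)
trace = dpp_fit(model, n_chains=4, n_cores=1, tune=500, draws=500)

w_latent = trace['w_latent']   # per-trial state weights; near-zero weights suggest unused states
tau_samples = trace['tau']     # per-trial changepoint positions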

single_taste_poisson_trial_switch(spike_array, switch_components, states)

Assumes only emissions change across trials; the changepoint distribution remains constant

spike_array :: trials x nrns x time
states :: number of states to include in the model

Source code in pytau/changepoint_model.py
def single_taste_poisson_trial_switch(
        spike_array,
        switch_components,
        states):
    """
    Assuming only emissions change across trials
    Changepoint distribution remains constant

    spike_array :: trials x nrns x time
    states :: number of states to include in the model 
    """

    trial_num, nrn_num, time_bins = spike_array.shape

    with pm.Model() as model:

        # Define Emissions

        # nrns
        nrn_lambda = pm.Exponential('nrn_lambda', 10, shape=(nrn_num))

        # nrns x switch_comps
        trial_lambda = pm.Exponential('trial_lambda',
                                      nrn_lambda.dimshuffle(0, 'x'),
                                      shape=(nrn_num, switch_components))

        # nrns x switch_comps x states
        state_lambda = pm.Exponential('state_lambda',
                                      trial_lambda.dimshuffle(0, 1, 'x'),
                                      shape=(nrn_num, switch_components, states))

        # Define Changepoints
        # Assuming distribution of changepoints remains
        # the same across all trials

        a = pm.HalfCauchy('a_tau', 3., shape=states - 1)
        b = pm.HalfCauchy('b_tau', 3., shape=states - 1)

        even_switches = np.linspace(0, 1, states+1)[1:-1]
        tau_latent = pm.Beta('tau_latent', a, b,
                             testval=even_switches,
                             shape=(trial_num, states-1)).sort(axis=-1)

        # Trials x Changepoints
        tau = pm.Deterministic('tau', time_bins * tau_latent)

        # Define trial switches
        # Will have same structure as regular changepoints

        even_trial_switches = np.linspace(0, 1, switch_components+1)[1:-1]
        tau_trial_latent = pm.Beta('tau_trial_latent', 1, 1,
                                   testval=even_trial_switches,
                                   shape=(switch_components-1)).sort(axis=-1)

        # Trial_changepoints
        tau_trial = pm.Deterministic('tau_trial', trial_num * tau_trial_latent)

        trial_idx = np.arange(trial_num)
        trial_selector = tt.nnet.sigmoid(
            trial_idx[np.newaxis, :] - tau_trial.dimshuffle(0, 'x'))

        trial_selector = tt.concatenate(
            [np.ones((1, trial_num)), trial_selector], axis=0)
        inverse_trial_selector = 1 - trial_selector[1:, :]
        inverse_trial_selector = tt.concatenate([inverse_trial_selector,
                                                np.ones((1, trial_num))], axis=0)

        # First, we can "select" sets of emissions depending on trial_changepoints
        # switch_comps x trials
        trial_selector = np.multiply(trial_selector, inverse_trial_selector)

        # state_lambda: nrns x switch_comps x states

        # selected_trial_lambda : nrns x states x trials
        selected_trial_lambda = pm.Deterministic('selected_trial_lambda',
                                                 tt.sum(
                                                     # "nrns" x switch_comps x "states" x trials
                                                     trial_selector.dimshuffle(
                                                         'x', 0, 'x', 1) * state_lambda.dimshuffle(0, 1, 2, 'x'),
                                                     axis=1)
                                                 )

        # Then, we can select state_emissions for every trial
        idx = np.arange(time_bins)

        # tau : Trials x Changepoints
        weight_stack = tt.nnet.sigmoid(
            idx[np.newaxis, :]-tau[:, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((trial_num, 1, time_bins)), weight_stack], axis=1)
        inverse_stack = 1 - weight_stack[:, 1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((trial_num, 1, time_bins))], axis=1)

        # Trials x states x Time
        weight_stack = np.multiply(weight_stack, inverse_stack)

        # Convert selected_trial_lambda : nrns x trials x states x "time"

        # nrns x trials x time
        lambda_ = tt.sum(selected_trial_lambda.dimshuffle(0, 2, 1, 'x') * weight_stack.dimshuffle('x', 0, 1, 2),
                         axis=2)

        # Convert to : trials x nrns x time
        lambda_ = lambda_.dimshuffle(1, 0, 2)

        # Add observations
        observation = pm.Poisson("obs", lambda_, observed=spike_array)

    return model
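A usage sketch with toy data; switch_components sets how many distinct emission regimes are allowed across trials:

from pytau.changepoint_model import gen_test_array, single_taste_poisson_trial_switch, dpp_fit

spike_array = gen_test_array((30, 10, 100), n_states=4, type='poisson')  # trials x nrns x time
model = single_taste_poisson_trial_switch(spike_array, switch_components=2, states=4)
trace = dpp_fit(model, n_chains=4, n_cores=1, tune=500, draws=500)

tau_samples = trace['tau']              # within-trial changepoints
tau_trial_samples = trace['tau_trial']  # trial index at which emissions switch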

single_taste_poisson_varsig(spike_array, states, **kwargs)

Model for changepoint on single taste
** Uses a variable sigmoid slope inferred from data

** Largely taken from "non_hardcoded_changepoint_test_3d.ipynb"
** Note: This model does not have hierarchical structure for emissions

Parameters:

    spike_array (3D Numpy array): trials x neurons x time. Required.
    states (int): Number of states to model. Required.

Returns:

    pymc3 model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def single_taste_poisson_varsig(
        spike_array,
        states,
        **kwargs):
    """Model for changepoint on single taste
        **Uses variable sigmoid slope inferred from data

    ** Largely taken from "non_hardcoded_changepoint_test_3d.ipynb"
    ** Note : This model does not have hierarchical structure for emissions

    Args:
        spike_array (3D Numpy array): trials x neurons x time
        states (int): Number of states to model

    Returns:
        pymc3 model: Model class containing graph to run inference on
    """

    mean_vals = np.array([np.mean(x, axis=-1)
                          for x in np.array_split(spike_array, states, axis=-1)]).T
    mean_vals = np.mean(mean_vals, axis=1)
    mean_vals += 0.01  # To avoid zero starting prob

    lambda_test_vals = np.diff(mean_vals, axis=-1)
    even_switches = np.linspace(0, 1, states+1)[1:-1]

    nrns = spike_array.shape[1]
    trials = spike_array.shape[0]
    idx = np.arange(spike_array.shape[-1])
    length = idx.max() + 1

    with pm.Model() as model:

        # Sigmoid slope
        sig_b = pm.Normal('sig_b', -1, 2, shape=states-1)

        # Initial value
        s0 = pm.Exponential('state0',
                            1/(np.mean(mean_vals)),
                            shape=nrns,
                            testval=mean_vals[:, 0])

        # Changes to lambda
        lambda_diff = pm.Normal('lambda_diff',
                                mu=0, sigma=10,
                                shape=(nrns, states-1),
                                testval=lambda_test_vals)

        # This is only here to be extracted at the end of sampling
        # NOT USED DIRECTLY IN MODEL
        lambda_fin = pm.Deterministic('lambda',
                                      tt.concatenate(
                                          [s0[:, np.newaxis], lambda_diff], axis=-1)
                                      )

        # Changepoint positions
        a = pm.HalfCauchy('a_tau', 10, shape=states - 1)
        b = pm.HalfCauchy('b_tau', 10, shape=states - 1)

        tau_latent = pm.Beta('tau_latent', a, b,
                             testval=even_switches,
                             shape=(trials, states-1)).sort(axis=-1)
        tau = pm.Deterministic('tau',
                               idx.min() + (idx.max() - idx.min()) * tau_latent)

        # Mechanical manipulations to generate firing rates
        idx_temp = np.tile(idx[np.newaxis, np.newaxis, :],
                           (trials, states-1, 1))
        tau_temp = tt.tile(tau[:, :, np.newaxis], (1, 1, len(idx)))
        sig_b_temp = tt.tile(
            sig_b[np.newaxis, :, np.newaxis], (trials, 1, len(idx)))

        weight_stack = var_sig_exp_tt(idx_temp-tau_temp, sig_b_temp)
        weight_stack_temp = tt.tile(
            weight_stack[:, np.newaxis, :, :], (1, nrns, 1, 1))

        s0_temp = tt.tile(s0[np.newaxis, :, np.newaxis,
                          np.newaxis], (trials, 1, states-1, len(idx)))
        lambda_diff_temp = tt.tile(
            lambda_diff[np.newaxis, :, :, np.newaxis], (trials, 1, 1, len(idx)))

        # Calculate lambda
        lambda_ = pm.Deterministic('lambda_',
                                   tt.sum(s0_temp + (weight_stack_temp*lambda_diff_temp), axis=2))
        # Bound lambda to prevent the diffs from making it negative
        # Don't let it go down to zero otherwise we have trouble with probabilities
        lambda_bounded = pm.Deterministic("lambda_bounded",
                                          tt.switch(lambda_ >= 0.01, lambda_, 0.01))

        # Add observations
        observation = pm.Poisson("obs", lambda_bounded, observed=spike_array)

    return model
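
A minimal usage sketch (not taken from the package's own examples): it assumes pymc3 is importable, that the spike array is already binned to counts, and the fit settings below are illustrative only.

```python
# Illustrative only: build the variable-slope model on synthetic data and fit
# it with ADVI through standard pymc3 calls (fit length chosen arbitrarily).
import numpy as np
import pymc3 as pm

from pytau.changepoint_model import single_taste_poisson_varsig

spike_array = np.random.binomial(1, 0.1, size=(15, 10, 50))   # trials x neurons x time
model = single_taste_poisson_varsig(spike_array, states=4)

with model:
    approx = pm.fit(n=20000, method='advi')   # variational approximation
    trace = approx.sample(500)                # draw posterior samples

tau_samples = trace['tau']   # changepoint positions: samples x trials x (states-1)
```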

single_taste_poisson_varsig_fixed(spike_array, states, inds_span=1)

Model for changepoint on a single taste. **Uses a sigmoid with a given (fixed) slope

** Largely taken from "non_hardcoded_changepoint_test_3d.ipynb"
** Note: This model does not have hierarchical structure for emissions

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| spike_array | 3D Numpy array | trials x neurons x time | required |
| states | int | Number of states to model | required |
| inds_span | float | Number of indices to cover 5-95% change in sigmoid | 1 |

Returns:

| Type | Description |
|------|-------------|
| pymc3 model | Model class containing graph to run inference on |

Source code in pytau/changepoint_model.py
def single_taste_poisson_varsig_fixed(
        spike_array,
        states,
        inds_span=1):
    """Model for changepoint on single taste
        **Uses sigmoid with given slope

    ** Largely taken from "non_hardcoded_changepoint_test_3d.ipynb"
    ** Note : This model does not have hierarchical structure for emissions

    Args:
        spike_array (3D Numpy array): trials x neurons x time
        states (int): Number of states to model
        inds_span (float): Number of indices to cover 5-95% change in sigmoid. Defaults to 1.

    Returns:
        pymc3 model: Model class containing graph to run inference on
    """

    mean_vals = np.array([np.mean(x, axis=-1)
                          for x in np.array_split(spike_array, states, axis=-1)]).T
    mean_vals = np.mean(mean_vals, axis=1)
    mean_vals += 0.01  # To avoid zero starting prob

    lambda_test_vals = np.diff(mean_vals, axis=-1)
    even_switches = np.linspace(0, 1, states+1)[1:-1]

    nrns = spike_array.shape[1]
    trials = spike_array.shape[0]
    idx = np.arange(spike_array.shape[-1])
    length = idx.max() + 1

    # Define sigmoid with given sharpness
    sig_b = inds_to_b(inds_span)

    def sigmoid(x):
        b_temp = tt.tile(
            np.array(sig_b)[None, None, None], x.tag.test_value.shape)
        return 1/(1+tt.exp(-b_temp*x))

    with pm.Model() as model:

        # Initial value
        s0 = pm.Exponential('state0',
                            1/(np.mean(mean_vals)),
                            shape=nrns,
                            testval=mean_vals[:, 0])

        # Changes to lambda
        lambda_diff = pm.Normal('lambda_diff',
                                mu=0, sigma=10,
                                shape=(nrns, states-1),
                                testval=lambda_test_vals)

        # This is only here to be extracted at the end of sampling
        # NOT USED DIRECTLY IN MODEL
        lambda_fin = pm.Deterministic('lambda',
                                      tt.concatenate(
                                          [s0[:, np.newaxis], lambda_diff], axis=-1)
                                      )

        # Changepoint positions
        a = pm.HalfCauchy('a_tau', 10, shape=states - 1)
        b = pm.HalfCauchy('b_tau', 10, shape=states - 1)

        tau_latent = pm.Beta('tau_latent', a, b,
                             testval=even_switches,
                             shape=(trials, states-1)).sort(axis=-1)
        tau = pm.Deterministic('tau',
                               idx.min() + (idx.max() - idx.min()) * tau_latent)

        # Mechanical manipulations to generate firing rates
        idx_temp = np.tile(idx[np.newaxis, np.newaxis, :],
                           (trials, states-1, 1))
        tau_temp = tt.tile(tau[:, :, np.newaxis], (1, 1, len(idx)))

        weight_stack = sigmoid(idx_temp-tau_temp)
        weight_stack_temp = tt.tile(
            weight_stack[:, np.newaxis, :, :], (1, nrns, 1, 1))

        s0_temp = tt.tile(s0[np.newaxis, :, np.newaxis,
                          np.newaxis], (trials, 1, states-1, len(idx)))
        lambda_diff_temp = tt.tile(
            lambda_diff[np.newaxis, :, :, np.newaxis], (trials, 1, 1, len(idx)))

        # Calculate lambda
        lambda_ = pm.Deterministic('lambda_',
                                   tt.sum(s0_temp + (weight_stack_temp*lambda_diff_temp), axis=2))
        # Bound lambda to prevent the diffs from making it negative
        # Don't let it go down to zero otherwise we have trouble with probabilities
        lambda_bounded = pm.Deterministic("lambda_bounded",
                                          tt.switch(lambda_ >= 0.01, lambda_, 0.01))

        # Add observations
        observation = pm.Poisson("obs", lambda_bounded, observed=spike_array)

    return model
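
The only difference from the variable-slope model is that the sigmoid sharpness is fixed up front via inds_span; a smaller span gives sharper, more step-like transitions. A hedged sketch comparing two settings (shapes and values are assumptions):

```python
# Sketch: two fixed-slope models differing only in transition sharpness;
# inds_span is the number of time bins covering the 5-95% rise of the sigmoid.
import numpy as np

from pytau.changepoint_model import single_taste_poisson_varsig_fixed

spike_array = np.random.binomial(1, 0.1, size=(15, 10, 50))   # trials x neurons x time
sharp_model = single_taste_poisson_varsig_fixed(spike_array, states=4, inds_span=1)
gradual_model = single_taste_poisson_varsig_fixed(spike_array, states=4, inds_span=10)
```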

theano_lock_present()

Check if theano compilation lock is present

Source code in pytau/changepoint_model.py
def theano_lock_present():
    """
    Check if theano compilation lock is present
    """
    return os.path.exists(os.path.join(theano.config.compiledir, 'lock_dir'))
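
One way this check might be used (a sketch, not an API provided by the package) is to poll the lock before launching another fit from a parallel worker:

```python
# Sketch: wait for theano's compilation lock to clear before building the next
# model in a separate process (the 10 s polling interval is an arbitrary choice).
import time

from pytau.changepoint_model import theano_lock_present

while theano_lock_present():
    time.sleep(10)
# Safe to construct and compile a new model here.
```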

var_sig_exp_tt(x, b)

x --> input (e.g. time index minus changepoint position); b --> log of the sigmoid slope (the slope applied is exp(b))

Source code in pytau/changepoint_model.py
def var_sig_exp_tt(x, b):
    """
    x --> input (e.g. time index minus changepoint position)
    b --> log of sigmoid slope (the slope applied is exp(b))
    """
    return 1/(1+tt.exp(-tt.exp(b)*x))

var_sig_tt(x, b)

x --> input (e.g. time index minus changepoint position); b --> sigmoid slope (applied directly)

Source code in pytau/changepoint_model.py
def var_sig_tt(x, b):
    """
    x --> input (e.g. time index minus changepoint position)
    b --> sigmoid slope (applied directly)
    """
    return 1/(1+tt.exp(-b*x))
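
For intuition, numpy mirrors of the two helpers are sketched below (an assumption based on the theano code above, not library functions): var_sig_exp_tt treats b as a log-slope, so the effective slope exp(b) is always positive, while var_sig_tt uses b as the slope directly.

```python
# Numpy mirror of the two sigmoid parameterizations above (illustration only).
import numpy as np

x = np.linspace(-10, 10, 5)
b = -1.0
log_slope_sigmoid = 1 / (1 + np.exp(-np.exp(b) * x))   # as in var_sig_exp_tt
raw_slope_sigmoid = 1 / (1 + np.exp(-b * x))            # as in var_sig_tt
```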

=== Preprocessing functions ===

Code to preprocess spike trains before feeding into model

preprocess_all_taste(spike_array, time_lims, bin_width, data_transform)

Preprocess array containing trials for all tastes (in blocks) concatenated

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| spike_array | 4D Numpy Array | Taste x Trials x Neurons x Time | required |
| time_lims | List/Tuple/Numpy Array | 2-element object indicating limits of array | required |
| bin_width | int | Width to use for binning | required |
| data_transform | str | Data-type to return {actual, trial_shuffled, spike_shuffled, simulated} | required |

Raises:

| Type | Description |
|------|-------------|
| Exception | If data_transform does not belong to ['trial_shuffled', 'spike_shuffled', 'simulated', 'None', None] |

Returns:

| Type | Description |
|------|-------------|
| 4D Numpy Array | Of processed data |

Source code in pytau/changepoint_preprocess.py
def preprocess_all_taste(spike_array,
                         time_lims,
                         bin_width,
                         data_transform):
    """Preprocess array containing trials for all tastes (in blocks) concatenated

    Args:
        spike_array (4D Numpy Array): Taste x Trials x Neurons x Time
        time_lims (List/Tuple/Numpy Array): 2-element object indicating limits of array
        bin_width (int): Width to use for binning
        data_transform (str): Data-type to return {actual, trial_shuffled, spike_shuffled, simulated}

    Raises:
        Exception: If transforms do not belong to ['trial_shuffled','spike_shuffled','simulated','None',None]

    Returns:
        (4D Numpy Array): Of processed data
    """

    accepted_transforms = ['trial_shuffled','spike_shuffled','simulated','None',None]
    if data_transform not in accepted_transforms:
        raise Exception(
            f'data_transform must be of type {accepted_transforms}')

    ##################################################
    # Create shuffled data
    ##################################################
    # Shuffle neurons across trials FOR SAME TASTE

    if data_transform == 'trial_shuffled':
        transformed_dat = np.array([np.random.permutation(neuron) \
                    for neuron in np.swapaxes(spike_array,2,0)])
        transformed_dat = np.swapaxes(transformed_dat,0,2)

    if data_transform == 'spike_shuffled':
        transformed_dat = spike_array.swapaxes(-1,0) 
        transformed_dat = np.stack([np.random.permutation(x) \
                for x in transformed_dat])
        transformed_dat = transformed_dat.swapaxes(0,-1) 

    ##################################################
    # Create simulated data
    ##################################################
    # Inhomogeneous poisson process using mean firing rates

    elif data_transform == 'simulated':
        mean_firing = np.mean(spike_array,axis=1)
        mean_firing = np.broadcast_to(mean_firing[:,None], spike_array.shape)

        # Simulate spikes
        transformed_dat = (np.random.random(spike_array.shape) < mean_firing )*1

    ##################################################
    # Null Transform Case
    ##################################################
    elif data_transform == None or data_transform == 'None':
        transformed_dat = spike_array

    ##################################################
    # Bin Data
    ##################################################
    spike_binned = \
            np.sum(transformed_dat[...,time_lims[0]:time_lims[1]].\
            reshape(*transformed_dat.shape[:-1],-1,bin_width),axis=-1)
    spike_binned = np.vectorize(np.int)(spike_binned)

    return spike_binned
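
A usage sketch with synthetic data (shapes, time limits, and bin width are illustrative assumptions, not package defaults). Note that the final np.vectorize(np.int) call in the source assumes an older NumPy in which np.int is still available.

```python
# 4 tastes x 30 trials x 10 neurons x 7000 ms of 1-ms spike indicators.
import numpy as np

from pytau.changepoint_preprocess import preprocess_all_taste

spike_array = np.random.binomial(1, 0.005, size=(4, 30, 10, 7000))
spike_binned = preprocess_all_taste(spike_array,
                                    time_lims=[2000, 4000],   # keep 2000-4000 ms
                                    bin_width=50,             # 50 ms bins
                                    data_transform=None)      # actual (untransformed) data
print(spike_binned.shape)   # (4, 30, 10, 40): tastes x trials x neurons x bins
```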

preprocess_single_taste(spike_array, time_lims, bin_width, data_transform)

Preprocess array containing trials for a single taste

** Note: it may be useful to use xarray here to keep track of coordinates

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| spike_array | 3D Numpy array | trials x neurons x time | required |
| time_lims | List/Tuple/Numpy Array | 2-element object indicating limits of array | required |
| bin_width | int | Width to use for binning | required |
| data_transform | str | Data-type to return {actual, trial_shuffled, spike_shuffled, simulated} | required |

Raises:

| Type | Description |
|------|-------------|
| Exception | If data_transform does not belong to ['trial_shuffled', 'spike_shuffled', 'simulated', 'None', None] |

Returns:

| Type | Description |
|------|-------------|
| 3D Numpy Array | Of processed data |

Source code in pytau/changepoint_preprocess.py
def preprocess_single_taste(spike_array,
                            time_lims,
                            bin_width,
                            data_transform):
    """Preprocess array containing trials for all tastes (in blocks) concatenated

    ** Note, it may be useful to use x-arrays here to keep track of coordinates

    Args:
        spike_array (3D Numpy array): trials x neurons x time
        time_lims (List/Tuple/Numpy Array): 2-element object indicating limits of array
        bin_width (int): Width to use for binning
        data_transform (str): Data-type to return 
                            {actual, trial_shuffled, spike_shuffled, simulated}

    Raises:
        Exception: If transforms do not belong to 
        ['trial_shuffled','spike_shuffled','simulated','None',None]

    Returns:
        (3D Numpy Array): Of processed data
    """

    accepted_transforms = ['trial_shuffled','spike_shuffled',
            'simulated','None',None]
    if data_transform not in accepted_transforms:
        raise Exception(
            f'data_transform must be of type {accepted_transforms}')

    ##################################################
    # Create shuffled data
    ##################################################
    # Shuffle neurons across trials FOR SAME TASTE

    if data_transform == 'trial_shuffled':
        transformed_dat = np.array([np.random.permutation(neuron) \
                                    for neuron in np.swapaxes(spike_array, 1, 0)])
        transformed_dat = np.swapaxes(transformed_dat, 0, 1)

    if data_transform == 'spike_shuffled':
        transformed_dat = np.moveaxis(spike_array,-1,0) 
        transformed_dat = np.stack([np.random.permutation(x) \
                for x in transformed_dat])
        transformed_dat = np.moveaxis(transformed_dat,0,-1) 
    ##################################################
    # Create simulated data
    ##################################################
    # Inhomogeneous poisson process using mean firing rates

    elif data_transform == 'simulated':
        mean_firing = np.mean(spike_array, axis=0)

        # Simulate spikes
        transformed_dat = np.array(
            [np.random.random(mean_firing.shape) < mean_firing
             for trial in range(spike_array.shape[0])])*1

    ##################################################
    # Null Transform Case
    ##################################################
    elif data_transform in (None, 'None'):
        transformed_dat = spike_array

    ##################################################
    # Bin Data
    ##################################################
    spike_binned = \
        np.sum(transformed_dat[..., time_lims[0]:time_lims[1]].
               reshape(*spike_array.shape[:-1], -1, bin_width), axis=-1)
    spike_binned = np.vectorize(np.int)(spike_binned)

    return spike_binned
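
A corresponding sketch for a single-taste array (again, shapes and arguments are illustrative assumptions):

```python
# 30 trials x 10 neurons x 7000 ms; shuffle trials within each neuron before binning.
import numpy as np

from pytau.changepoint_preprocess import preprocess_single_taste

spike_array = np.random.binomial(1, 0.005, size=(30, 10, 7000))
spike_binned = preprocess_single_taste(spike_array,
                                       time_lims=[2000, 4000],
                                       bin_width=50,
                                       data_transform='trial_shuffled')
print(spike_binned.shape)   # (30, 10, 40): trials x neurons x bins
```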