References

=== Analysis functions ===

Helper classes and functions to perform analysis on fitted models

PklHandler

Helper class to handle metadata and fit data from pkl file

Source code in pytau/changepoint_analysis.py
class PklHandler:
    """Helper class to handle metadata and fit data from pkl file"""

    def __init__(self, file_path):
        """Initialize PklHandler class

        Args:
            file_path (str): Path to pkl file
        """
        self.dir_name = os.path.dirname(file_path)
        file_name = os.path.basename(file_path)
        self.file_name_base = file_name.split(".")[0]
        self.pkl_file_path = os.path.join(
            self.dir_name, self.file_name_base + ".pkl")
        with open(self.pkl_file_path, "rb") as this_file:
            self.data = pkl.load(this_file)

        model_keys = ["model", "approx", "lambda", "tau", "data"]
        key_savenames = [
            "_model_structure",
            "_fit_model",
            "lambda_array",
            "tau_array",
            "processed_spikes",
        ]
        data_map = dict(zip(model_keys, key_savenames))

        for key, var_name in data_map.items():
            if key in self.data["model_data"]:
                setattr(self, var_name, self.data["model_data"][key])
            else:
                # Set to None if key is missing (e.g., due to pickling fallback)
                setattr(self, var_name, None)

        self.metadata = self.data["metadata"]
        self.pretty_metadata = pd.json_normalize(self.data["metadata"]).T

        # Get number of trials from processed_spikes for proper tau formatting
        n_trials = self.processed_spikes.shape[0] if hasattr(
            self.processed_spikes, 'shape') else None
        self.tau = _tau(self.tau_array, self.metadata, n_trials)
        self.firing = _firing(self.tau, self.processed_spikes, self.metadata)
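
A minimal usage sketch (the file path is hypothetical; it assumes a fit pkl previously written by FitHandler.save_fit_output):

from pytau.changepoint_analysis import PklHandler

# Hypothetical path to a fit saved by FitHandler.save_fit_output()
handler = PklHandler("/path/to/experiment_name_abc123.pkl")
print(handler.pretty_metadata)       # flattened metadata as a pandas DataFrame
tau = handler.tau                    # formatted changepoint estimates
spikes = handler.processed_spikes    # preprocessed spike array used for the fit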

__init__(file_path)

Initialize PklHandler class

Parameters:

    file_path (str): Path to pkl file. Required.
Source code in pytau/changepoint_analysis.py
def __init__(self, file_path):
    """Initialize PklHandler class

    Args:
        file_path (str): Path to pkl file
    """
    self.dir_name = os.path.dirname(file_path)
    file_name = os.path.basename(file_path)
    self.file_name_base = file_name.split(".")[0]
    self.pkl_file_path = os.path.join(
        self.dir_name, self.file_name_base + ".pkl")
    with open(self.pkl_file_path, "rb") as this_file:
        self.data = pkl.load(this_file)

    model_keys = ["model", "approx", "lambda", "tau", "data"]
    key_savenames = [
        "_model_structure",
        "_fit_model",
        "lambda_array",
        "tau_array",
        "processed_spikes",
    ]
    data_map = dict(zip(model_keys, key_savenames))

    for key, var_name in data_map.items():
        if key in self.data["model_data"]:
            setattr(self, var_name, self.data["model_data"][key])
        else:
            # Set to None if key is missing (e.g., due to pickling fallback)
            setattr(self, var_name, None)

    self.metadata = self.data["metadata"]
    self.pretty_metadata = pd.json_normalize(self.data["metadata"]).T

    # Get number of trials from processed_spikes for proper tau formatting
    n_trials = self.processed_spikes.shape[0] if hasattr(
        self.processed_spikes, 'shape') else None
    self.tau = _tau(self.tau_array, self.metadata, n_trials)
    self.firing = _firing(self.tau, self.processed_spikes, self.metadata)

calc_significant_neurons_firing(state_firing, p_val=0.05)

Calculate significant changes in firing rate between states. Iterates a one-way ANOVA over neurons across all states, with Bonferroni correction.

Parameters:

    state_firing (3D Numpy array): trials x states x nrns. Required.
    p_val (float, optional): p-value to use for significance. Defaults to 0.05.

Returns:

    anova_p_val_array (1D Numpy array): p-values for each neuron
    anova_sig_neurons (1D Numpy array): indices of significant neurons

Source code in pytau/changepoint_analysis.py
def calc_significant_neurons_firing(state_firing, p_val=0.05):
    """Calculate significant changes in firing rate between states
    Iterate ANOVA over neurons for all states
    With Bonferroni correction

    Args
        state_firing (3D Numpy array): trials x states x nrns
        p_val (float, optional): p-value to use for significance. Defaults to 0.05.

    Returns:
        anova_p_val_array (1D Numpy array): p-values for each neuron
        anova_sig_neurons (1D Numpy array): indices of significant neurons
    """
    n_neurons = state_firing.shape[-1]
    # Calculate ANOVA p-values for each neuron
    anova_p_val_array = np.zeros(state_firing.shape[-1])
    for neuron in range(state_firing.shape[-1]):
        anova_p_val_array[neuron] = f_oneway(*state_firing[:, :, neuron].T)[1]
    anova_sig_neurons = np.where(anova_p_val_array < p_val / n_neurons)[0]

    return anova_p_val_array, anova_sig_neurons
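
A quick sketch with synthetic data (shapes follow the docstring; the Poisson rates are arbitrary):

import numpy as np
from pytau.changepoint_analysis import calc_significant_neurons_firing

rng = np.random.default_rng(0)
# trials x states x nrns, with neuron 0 firing more in state 2
state_firing = rng.poisson(5, size=(30, 4, 10)).astype(float)
state_firing[:, 2, 0] += 10
p_vals, sig_nrns = calc_significant_neurons_firing(state_firing, p_val=0.05)
print(sig_nrns)  # neurons passing p < 0.05 / n_neurons (Bonferroni)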

calc_significant_neurons_snippets(transition_snips, p_val=0.05)

Calculate pairwise t-tests (mean activity before vs. after) to detect differences at each transition, with Bonferroni correction.

Parameters:

    transition_snips (4D Numpy array): trials x nrns x bins x transitions. Required.
    p_val (float, optional): p-value to use for significance. Defaults to 0.05.

Returns:

    pairwise_p_val_array ((neurons, transitions) Numpy array): p-values for each neuron and transition
    pairwise_sig_neurons ((neurons, transitions) boolean Numpy array): significance mask for each neuron and transition

Source code in pytau/changepoint_analysis.py
def calc_significant_neurons_snippets(transition_snips, p_val=0.05):
    """Calculate pairwise t-tests to detect differences between each transition
    With Bonferroni correction

    Args
        transition_snips (4D Numpy array): trials x nrns x bins x transitions
        p_val (float, optional): p-value to use for significance. Defaults to 0.05.

    Returns:
        pairwise_p_val_array (2D Numpy array, neurons x transitions): p-values for each neuron and transition
        pairwise_sig_neurons (2D boolean Numpy array, neurons x transitions): significance mask for each neuron and transition
    """
    # Calculate pairwise t-tests for each transition
    # shape : [before, after] x trials x neurons x transitions
    mean_transition_snips = np.stack(np.array_split(
        transition_snips, 2, axis=2)).mean(axis=3)
    pairwise_p_val_array = np.zeros(mean_transition_snips.shape[2:])
    n_neuron, n_transitions = pairwise_p_val_array.shape
    for neuron in range(n_neuron):
        for transition in range(n_transitions):
            pairwise_p_val_array[neuron, transition] = ttest_rel(
                *mean_transition_snips[:, :, neuron, transition]
            )[1]
    # Note: the Bonferroni divisor (p_val / n_neuron) is currently commented out
    pairwise_sig_neurons = pairwise_p_val_array < p_val  # /n_neuron
    return pairwise_p_val_array, pairwise_sig_neurons
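
A sketch with synthetic data (shapes follow the docstring; note the second return value is a boolean mask over neurons and transitions, not indices):

import numpy as np
from pytau.changepoint_analysis import calc_significant_neurons_snippets

rng = np.random.default_rng(1)
# trials x nrns x bins x transitions
snips = rng.poisson(3, size=(20, 8, 100, 3)).astype(float)
p_vals, sig_mask = calc_significant_neurons_snippets(snips, p_val=0.05)
print(p_vals.shape, sig_mask.shape)  # (8, 3) each: neurons x transitions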

get_state_firing(spike_array, tau_array)

Calculate firing rates within states given changepoint positions on data

Parameters:

    spike_array (3D Numpy array): trials x nrns x bins. Required.
    tau_array (2D Numpy array): trials x switchpoints. Required.

Returns:

    state_firing (3D Numpy array): trials x states x nrns

Source code in pytau/changepoint_analysis.py
def get_state_firing(spike_array, tau_array):
    """Calculate firing rates within states given changepoint positions on data

    Args:
        spike_array (3D Numpy array): trials x nrns x bins
        tau_array (2D Numpy array): trials x switchpoints

    Returns:
        state_firing (3D Numpy array): trials x states x nrns
    """

    states = tau_array.shape[-1] + 1
    # Get mean firing rate for each STATE using model
    state_inds = np.hstack(
        [
            np.zeros((tau_array.shape[0], 1)),
            tau_array,
            np.ones((tau_array.shape[0], 1)) * spike_array.shape[-1],
        ]
    )
    state_lims = np.array([state_inds[:, x: x + 2] for x in range(states)])
    state_lims = np.vectorize(int)(state_lims)
    state_lims = np.swapaxes(state_lims, 0, 1)

    state_firing = np.array(
        [
            [np.mean(trial_dat[:, start:end], axis=-1)
             for start, end in trial_lims]
            for trial_dat, trial_lims in zip(spike_array, state_lims)
        ]
    )

    state_firing = np.nan_to_num(state_firing)
    return state_firing
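
A sketch with synthetic spikes and changepoints (shapes follow the docstring):

import numpy as np
from pytau.changepoint_analysis import get_state_firing

rng = np.random.default_rng(2)
spikes = rng.poisson(0.1, size=(15, 10, 700))            # trials x nrns x bins
tau = np.sort(rng.integers(100, 600, size=(15, 3)), -1)  # trials x switchpoints
state_firing = get_state_firing(spikes, tau)
print(state_firing.shape)  # (15, 4, 10): trials x states x nrns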

get_transition_snips(spike_array, tau_array, window_radius=300)

Get snippets of activity around changepoints for each trial

Parameters:

    spike_array (3D Numpy array): trials x nrns x bins. Required.
    tau_array (2D Numpy array): trials x switchpoints. Required.
    window_radius (int, optional): Half-width of the snippet window in bins. Defaults to 300.

Returns:

    Numpy array: Transition snippets, trials x nrns x bins x transitions

Raises:

    ValueError: If any snippet window extends outside the bounds of the data

Source code in pytau/changepoint_analysis.py
def get_transition_snips(spike_array, tau_array, window_radius=300):
    """Get snippets of activty around changepoints for each trial

    Args:
        spike_array (3D Numpy array): trials x nrns x bins
        tau_array (2D Numpy array): trials x switchpoints

    Returns:
        Numpy array: Transition snippets : trials x nrns x bins x transitions

    Make sure none of the snippets are outside the bounds of the data
    """
    # Get snippets of activity around changepoints for each trial
    n_trials, n_neurons, n_bins = spike_array.shape
    n_transitions = tau_array.shape[1]
    transition_snips = np.zeros(
        (n_trials, n_neurons, 2 * window_radius, n_transitions))
    window_lims = np.stack(
        [tau_array - window_radius, tau_array + window_radius], axis=-1)

    # Make sure no lims are outside the bounds of the data
    if (window_lims < 0).sum(axis=None) or (window_lims > n_bins).sum(axis=None):
        raise ValueError("Transition window extends outside data bounds")

    # Pull out snippets
    for trial in range(n_trials):
        for transition in range(n_transitions):
            transition_snips[trial, :, :, transition] = spike_array[
                trial,
                :,
                window_lims[trial, transition, 0]: window_lims[trial, transition, 1],
            ]
    return transition_snips
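
A sketch using the same kind of synthetic data (window_radius is chosen so no window leaves the data; otherwise a ValueError is raised):

import numpy as np
from pytau.changepoint_analysis import get_transition_snips

rng = np.random.default_rng(3)
spikes = rng.poisson(0.1, size=(15, 10, 700))            # trials x nrns x bins
tau = np.sort(rng.integers(100, 600, size=(15, 3)), -1)  # trials x switchpoints
snips = get_transition_snips(spikes, tau, window_radius=50)
print(snips.shape)  # (15, 10, 100, 3): trials x nrns x (2 * radius) x transitions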

=== I/O functions ===

Pipeline to handle model fitting from data extraction to saving results

DatabaseHandler

Class to handle transactions with model database

Source code in pytau/changepoint_io.py
class DatabaseHandler:
    """Class to handle transactions with model database"""

    def __init__(self):
        """Initialize DatabaseHandler class"""
        self.unique_cols = ["exp.model_id", "exp.save_path", "exp.fit_date"]
        self.model_database_path = MODEL_DATABASE_PATH
        self.model_save_base_dir = MODEL_SAVE_DIR

        if os.path.exists(self.model_database_path):
            self.fit_database = pd.read_csv(
                self.model_database_path, index_col=0)
            all_na = [all(x) for num, x in self.fit_database.isna().iterrows()]
            if any(all_na):  # only prune when fully-NA rows exist
                print(f"{sum(all_na)} rows found with all NA, removing...")
                self.fit_database = self.fit_database.dropna(how="all")
        else:
            print("Fit database does not exist yet")

    def show_duplicates(self, keep="first"):
        """Find duplicates in database

        Args:
            keep (str, optional): Which duplicate to keep
                    (refer to pandas duplicated). Defaults to 'first'.

        Returns:
            pandas dataframe: Dataframe containing duplicated rows
            pandas series : Indices of duplicated rows
        """
        dup_inds = self.fit_database.drop(
            self.unique_cols, axis=1).duplicated(keep=keep)
        return self.fit_database.loc[dup_inds], dup_inds

    def drop_duplicates(self):
        """Remove duplicated rows from database"""
        _, dup_inds = self.show_duplicates()
        print(f"Removing {sum(dup_inds)} duplicate rows")
        self.fit_database = self.fit_database.loc[~dup_inds]

    def check_mismatched_paths(self):
        """Check if there are any mismatched pkl files between database and directory

        Returns:
            pandas dataframe: Dataframe containing rows for which pkl file not present
            list: pkl files which cannot be matched to model in database
            list: all files in save directory
        """
        mismatch_from_database = [
            not os.path.exists(x + ".pkl") for x in self.fit_database["exp.save_path"]
        ]
        file_list = glob(os.path.join(self.model_save_base_dir, "*/*.pkl"))
        # Only split basename by '.' in case there are multiple '.' in the file path
        mismatch_from_file = [
            not (
                os.path.join(
                    os.path.dirname(x),
                    os.path.basename(x).split(".")[0])
                in list(self.fit_database["exp.save_path"]))
            for x in file_list
        ]
        print(
            f"{sum(mismatch_from_database)} mismatches from database"
            + "\n"
            + f"{sum(mismatch_from_file)} mismatches from files"
        )
        return mismatch_from_database, mismatch_from_file, file_list

    def clear_mismatched_paths(self):
        """Remove mismatched files and rows in database

        i.e. Remove
        1) Files for which no entry can be found in database
        2) Database entries for which no corresponding file can be found
        """
        (
            mismatch_from_database,
            mismatch_from_file,
            file_list,
        ) = self.check_mismatched_paths()
        mismatch_from_file = np.array(mismatch_from_file)
        mismatch_from_database = np.array(mismatch_from_database)
        self.fit_database = self.fit_database.loc[~mismatch_from_database]
        mismatched_files = [x for x, y in zip(
            file_list, mismatch_from_file) if y]
        for x in mismatched_files:
            os.remove(x)
        print("==== Clearing Completed ====")

    def write_updated_database(self):
        """Can be called following clear_mismatched_entries to update current database"""
        database_backup_dir = os.path.join(
            self.model_save_base_dir, ".database_backups")
        if not os.path.exists(database_backup_dir):
            os.makedirs(database_backup_dir)
        # current_date = date.today().strftime("%m-%d-%y")
        current_date = str(datetime.now()).replace(" ", "_")
        shutil.copy(
            self.model_database_path,
            os.path.join(database_backup_dir,
                         f"database_backup_{current_date}"),
        )
        self.fit_database.to_csv(self.model_database_path, mode="w")

    def set_run_params(self, data_dir, experiment_name, taste_num, laser_type, region_name):
        """Store metadata related to inference run

        Args:
            data_dir (str): Path to directory containing HDF5 file
            experiment_name (str): Name given to fitted batch
                    (for metadata). Defaults to None.
            taste_num (int): Index of taste to perform fit on (Corresponds to
                    INDEX of taste in spike array, not actual dig_ins)
            laser_type (None or str): None, 'on', or 'off' (For a laser session,
                    which set of trials are wanted; None indicates return all trials)
            region_name (str): Region on which to perform fit
                    (must match regions in .info file)
        """
        self.data_dir = data_dir
        self.data_basename = os.path.basename(self.data_dir)
        self.animal_name = self.data_basename.split("_")[0]
        self.session_date = self.data_basename.split("_")[-1]

        self.experiment_name = experiment_name
        self.model_save_dir = os.path.join(
            self.model_save_base_dir, experiment_name)

        if not os.path.exists(self.model_save_dir):
            os.makedirs(self.model_save_dir)

        self.model_id = str(uuid.uuid4()).split("-")[0]
        self.model_save_path = os.path.join(
            self.model_save_dir, self.experiment_name + "_" + self.model_id
        )
        self.fit_date = date.today().strftime("%m-%d-%y")

        self.taste_num = taste_num
        self.laser_type = laser_type
        self.region_name = region_name

        self.fit_exists = None

    def ingest_fit_data(self, met_dict):
        """Load external metadata

        Args:
            met_dict (dict): Dictionary of metadata from FitHandler class
        """
        self.external_metadata = met_dict

    def aggregate_metadata(self):
        """Collects information regarding data and current "experiment"

        Raises:
            Exception: If 'external_metadata' has not been ingested, that needs to be done first

        Returns:
            dict: Dictionary of metadata given to FitHandler class
        """
        if "external_metadata" not in dir(self):
            raise Exception(
                "Fit run metadata needs to be ingested into data_handler first")

        data_details = dict(
            zip(
                [
                    "data_dir",
                    "basename",
                    "animal_name",
                    "session_date",
                    "taste_num",
                    "laser_type",
                    "region_name",
                ],
                [
                    self.data_dir,
                    self.data_basename,
                    self.animal_name,
                    self.session_date,
                    self.taste_num,
                    self.laser_type,
                    self.region_name,
                ],
            )
        )

        exp_details = dict(
            zip(
                ["exp_name", "model_id", "save_path", "fit_date"],
                [
                    self.experiment_name,
                    self.model_id,
                    self.model_save_path,
                    self.fit_date,
                ],
            )
        )

        module_details = dict(
            zip(
                ["pymc_version", "theano_version"],
                [pymc.__version__, theano.__version__],
            )
        )

        temp_ext_met = self.external_metadata
        temp_ext_met["data"] = data_details
        temp_ext_met["exp"] = exp_details
        temp_ext_met["module"] = module_details

        return temp_ext_met

    def write_to_database(self):
        """Write out metadata to database"""
        agg_metadata = self.aggregate_metadata()
        # Convert model_kwargs to str so that it is saved appropriately
        agg_metadata["model"]["model_kwargs"] = str(
            agg_metadata["model"]["model_kwargs"])
        flat_metadata = pd.json_normalize(agg_metadata)
        if not os.path.isfile(self.model_database_path):
            flat_metadata.to_csv(self.model_database_path, mode="a")
        else:
            flat_metadata.to_csv(self.model_database_path,
                                 mode="a", header=False)
        print(f"Updated model database @ {self.model_database_path}")

    def check_exists(self):
        """Check if the given fit already exists in database

        Returns:
            bool: Boolean for whether fit already exists or not
        """
        if self.fit_exists is not None:
            return self.fit_exists
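
A minimal maintenance sketch (assumes the database CSV at MODEL_DATABASE_PATH already exists):

from pytau.changepoint_io import DatabaseHandler

db = DatabaseHandler()                 # loads the CSV at MODEL_DATABASE_PATH if present
dups, dup_inds = db.show_duplicates()  # rows identical apart from id/path/date columns
db.drop_duplicates()                   # prune duplicates in memory
db.write_updated_database()            # back up the old CSV, then write the pruned one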

__init__()

Initialize DatabaseHandler class

Source code in pytau/changepoint_io.py
def __init__(self):
    """Initialize DatabaseHandler class"""
    self.unique_cols = ["exp.model_id", "exp.save_path", "exp.fit_date"]
    self.model_database_path = MODEL_DATABASE_PATH
    self.model_save_base_dir = MODEL_SAVE_DIR

    if os.path.exists(self.model_database_path):
        self.fit_database = pd.read_csv(
            self.model_database_path, index_col=0)
        all_na = [all(x) for num, x in self.fit_database.isna().iterrows()]
        if any(all_na):  # only prune when fully-NA rows exist
            print(f"{sum(all_na)} rows found with all NA, removing...")
            self.fit_database = self.fit_database.dropna(how="all")
    else:
        print("Fit database does not exist yet")

aggregate_metadata()

Collects information regarding data and current "experiment"

Raises:

    Exception: If 'external_metadata' has not been ingested; that needs to be done first

Returns:

    dict: Dictionary of metadata given to FitHandler class

Source code in pytau/changepoint_io.py
def aggregate_metadata(self):
    """Collects information regarding data and current "experiment"

    Raises:
        Exception: If 'external_metadata' has not been ingested, that needs to be done first

    Returns:
        dict: Dictionary of metadata given to FitHandler class
    """
    if "external_metadata" not in dir(self):
        raise Exception(
            "Fit run metadata needs to be ingested into data_handler first")

    data_details = dict(
        zip(
            [
                "data_dir",
                "basename",
                "animal_name",
                "session_date",
                "taste_num",
                "laser_type",
                "region_name",
            ],
            [
                self.data_dir,
                self.data_basename,
                self.animal_name,
                self.session_date,
                self.taste_num,
                self.laser_type,
                self.region_name,
            ],
        )
    )

    exp_details = dict(
        zip(
            ["exp_name", "model_id", "save_path", "fit_date"],
            [
                self.experiment_name,
                self.model_id,
                self.model_save_path,
                self.fit_date,
            ],
        )
    )

    module_details = dict(
        zip(
            ["pymc_version", "theano_version"],
            [pymc.__version__, theano.__version__],
        )
    )

    temp_ext_met = self.external_metadata
    temp_ext_met["data"] = data_details
    temp_ext_met["exp"] = exp_details
    temp_ext_met["module"] = module_details

    return temp_ext_met

check_exists()

Check if the given fit already exists in database

Returns:

    bool: Boolean for whether fit already exists or not

Source code in pytau/changepoint_io.py
def check_exists(self):
    """Check if the given fit already exists in database

    Returns:
        bool: Boolean for whether fit already exists or not
    """
    if self.fit_exists is not None:
        return self.fit_exists

check_mismatched_paths()

Check if there are any mismatched pkl files between database and directory

Returns:

    pandas dataframe: Dataframe containing rows for which pkl file not present
    list: pkl files which cannot be matched to model in database
    list: all files in save directory

Source code in pytau/changepoint_io.py
def check_mismatched_paths(self):
    """Check if there are any mismatched pkl files between database and directory

    Returns:
        pandas dataframe: Dataframe containing rows for which pkl file not present
        list: pkl files which cannot be matched to model in database
        list: all files in save directory
    """
    mismatch_from_database = [
        not os.path.exists(x + ".pkl") for x in self.fit_database["exp.save_path"]
    ]
    file_list = glob(os.path.join(self.model_save_base_dir, "*/*.pkl"))
    # Only split basename by '.' in case there are multiple '.' in the file path
    mismatch_from_file = [
        not (
            os.path.join(
                os.path.dirname(x),
                os.path.basename(x).split(".")[0])
            in list(self.fit_database["exp.save_path"]))
        for x in file_list
    ]
    print(
        f"{sum(mismatch_from_database)} mismatches from database"
        + "\n"
        + f"{sum(mismatch_from_file)} mismatches from files"
    )
    return mismatch_from_database, mismatch_from_file, file_list

clear_mismatched_paths()

Remove mismatched files and rows in database

i.e. Remove:

1) Files for which no entry can be found in database
2) Database entries for which no corresponding file can be found

Source code in pytau/changepoint_io.py
def clear_mismatched_paths(self):
    """Remove mismatched files and rows in database

    i.e. Remove
    1) Files for which no entry can be found in database
    2) Database entries for which no corresponding file can be found
    """
    (
        mismatch_from_database,
        mismatch_from_file,
        file_list,
    ) = self.check_mismatched_paths()
    mismatch_from_file = np.array(mismatch_from_file)
    mismatch_from_database = np.array(mismatch_from_database)
    self.fit_database = self.fit_database.loc[~mismatch_from_database]
    mismatched_files = [x for x, y in zip(
        file_list, mismatch_from_file) if y]
    for x in mismatched_files:
        os.remove(x)
    print("==== Clearing Completed ====")

drop_duplicates()

Remove duplicated rows from database

Source code in pytau/changepoint_io.py
def drop_duplicates(self):
    """Remove duplicated rows from database"""
    _, dup_inds = self.show_duplicates()
    print(f"Removing {sum(dup_inds)} duplicate rows")
    self.fit_database = self.fit_database.loc[~dup_inds]

ingest_fit_data(met_dict)

Load external metadata

Parameters:

    met_dict (dict): Dictionary of metadata from FitHandler class. Required.
Source code in pytau/changepoint_io.py
def ingest_fit_data(self, met_dict):
    """Load external metadata

    Args:
        met_dict (dict): Dictionary of metadata from FitHandler class
    """
    self.external_metadata = met_dict

set_run_params(data_dir, experiment_name, taste_num, laser_type, region_name)

Store metadata related to inference run

Parameters:

    data_dir (str): Path to directory containing HDF5 file. Required.
    experiment_name (str): Name given to fitted batch (for metadata). Required.
    taste_num (int): Index of taste to perform fit on (corresponds to INDEX of taste in spike array, not actual dig_ins). Required.
    laser_type (None or str): None, 'on', or 'off' (for a laser session, which set of trials are wanted; None indicates return all trials). Required.
    region_name (str): Region on which to perform fit (must match regions in .info file). Required.
Source code in pytau/changepoint_io.py
def set_run_params(self, data_dir, experiment_name, taste_num, laser_type, region_name):
    """Store metadata related to inference run

    Args:
        data_dir (str): Path to directory containing HDF5 file
        experiment_name (str): Name given to fitted batch
                (for metadata). Defaults to None.
        taste_num (int): Index of taste to perform fit on (Corresponds to
                INDEX of taste in spike array, not actual dig_ins)
        laser_type (None or str): None, 'on', or 'off' (For a laser session,
                which set of trials are wanted; None indicates return all trials)
        region_name (str): Region on which to perform fit
                (must match regions in .info file)
    """
    self.data_dir = data_dir
    self.data_basename = os.path.basename(self.data_dir)
    self.animal_name = self.data_basename.split("_")[0]
    self.session_date = self.data_basename.split("_")[-1]

    self.experiment_name = experiment_name
    self.model_save_dir = os.path.join(
        self.model_save_base_dir, experiment_name)

    if not os.path.exists(self.model_save_dir):
        os.makedirs(self.model_save_dir)

    self.model_id = str(uuid.uuid4()).split("-")[0]
    self.model_save_path = os.path.join(
        self.model_save_dir, self.experiment_name + "_" + self.model_id
    )
    self.fit_date = date.today().strftime("%m-%d-%y")

    self.taste_num = taste_num
    self.laser_type = laser_type
    self.region_name = region_name

    self.fit_exists = None

show_duplicates(keep='first')

Find duplicates in database

Parameters:

    keep (str, optional): Which duplicate to keep (refer to pandas duplicated). Defaults to 'first'.

Returns:

    pandas dataframe: Dataframe containing duplicated rows
    pandas series: Indices of duplicated rows

Source code in pytau/changepoint_io.py
def show_duplicates(self, keep="first"):
    """Find duplicates in database

    Args:
        keep (str, optional): Which duplicate to keep
                (refer to pandas duplicated). Defaults to 'first'.

    Returns:
        pandas dataframe: Dataframe containing duplicated rows
        pandas series : Indices of duplicated rows
    """
    dup_inds = self.fit_database.drop(
        self.unique_cols, axis=1).duplicated(keep=keep)
    return self.fit_database.loc[dup_inds], dup_inds

write_to_database()

Write out metadata to database

Source code in pytau/changepoint_io.py
def write_to_database(self):
    """Write out metadata to database"""
    agg_metadata = self.aggregate_metadata()
    # Convert model_kwargs to str so that it is saved appropriately
    agg_metadata["model"]["model_kwargs"] = str(
        agg_metadata["model"]["model_kwargs"])
    flat_metadata = pd.json_normalize(agg_metadata)
    if not os.path.isfile(self.model_database_path):
        flat_metadata.to_csv(self.model_database_path, mode="a")
    else:
        flat_metadata.to_csv(self.model_database_path,
                             mode="a", header=False)
    print(f"Updated model database @ {self.model_database_path}")

write_updated_database()

Can be called following clear_mismatched_paths to update the current database

Source code in pytau/changepoint_io.py
def write_updated_database(self):
    """Can be called following clear_mismatched_entries to update current database"""
    database_backup_dir = os.path.join(
        self.model_save_base_dir, ".database_backups")
    if not os.path.exists(database_backup_dir):
        os.makedirs(database_backup_dir)
    # current_date = date.today().strftime("%m-%d-%y")
    current_date = str(datetime.now()).replace(" ", "_")
    shutil.copy(
        self.model_database_path,
        os.path.join(database_backup_dir,
                     f"database_backup_{current_date}"),
    )
    self.fit_database.to_csv(self.model_database_path, mode="w")

FitHandler

Class to handle pipeline of model fitting including:

1) Loading data
2) Preprocessing loaded arrays
3) Fitting model
4) Writing out fitted parameters to pkl file

Source code in pytau/changepoint_io.py
class FitHandler:
    """Class to handle pipeline of model fitting including:
    1) Loading data
    2) Preprocessing loaded arrays
    3) Fitting model
    4) Writing out fitted parameters to pkl file

    """

    def __init__(
        self,
        data_dir,
        taste_num,
        region_name,
        laser_type=None,
        experiment_name=None,
        model_params_path=None,
        preprocess_params_path=None,
    ):
        """Initialize FitHandler class

        Args:
            data_dir (str): Path to directory containing HDF5 file
            taste_num (int): Index of taste to perform fit on
                    (Corresponds to INDEX of taste in spike array, not actual dig_ins)
            region_name (str): Region on which to perform fit
                    (must match regions in .info file)
            laser_type (None or str, optional): None, 'on', or 'off'.
                    Defaults to None.
            experiment_name (str, optional): Name given to fitted batch
                    (for metadata). Defaults to None.
            model_params_path (str, optional): Path to json file
                    containing model parameters. Defaults to None.
            preprocess_params_path (str, optional): Path to json file
                    containing preprocessing parameters. Defaults to None.

        Raises:
            Exception: If "experiment_name" is None
            Exception: If "laser_type" is not in [None, 'on', 'off']
            Exception: If "taste_num" is not integer or "all"
        """

        # =============== Check for exceptions ===============
        if experiment_name is None:
            raise Exception("Please specify an experiment name")
        if laser_type not in [None, "on", "off"]:
            raise Exception('laser_type must be from [None, "on","off"]')
        if not (isinstance(taste_num, int) or taste_num == "all"):
            raise Exception('taste_num must be an integer or "all"')

        # =============== Save relevant arguments ===============
        self.data_dir = data_dir
        self.EphysData = EphysData(self.data_dir)
        # self.data = self.EphysData.get_spikes({"bla","gc","all"})

        self.taste_num = taste_num
        self.laser_type = laser_type
        self.region_name = region_name
        self.experiment_name = experiment_name

        data_handler_init_kwargs = dict(
            zip(
                [
                    "data_dir",
                    "experiment_name",
                    "taste_num",
                    "laser_type",
                    "region_name",
                ],
                [data_dir, experiment_name, taste_num, laser_type, region_name],
            )
        )
        self.database_handler = DatabaseHandler()
        self.database_handler.set_run_params(**data_handler_init_kwargs)

        if model_params_path is None:
            print("MODEL_PARAMS will have to be set")
        else:
            self.set_model_params(file_path=model_params_path)

        if preprocess_params_path is None:
            print("PREPROCESS_PARAMS will have to be set")
        else:
            self.set_preprocess_params(file_path=preprocess_params_path)

    ########################################
    # SET PARAMS
    ########################################

    def set_preprocess_params(self, time_lims, bin_width, data_transform, file_path=None):
        """Load given params as "preprocess_params" attribute

        Args:
            time_lims (array/tuple/list): Start and end of where to cut
                    spike train array
            bin_width (int): Bin width for binning spikes to counts
            data_transform (str): Indicator for which transformation to
                    use (refer to changepoint_preprocess)
            file_path (str, optional): Path to json file containing preprocess
                    parameters. Defaults to None.
        """

        if file_path is None:
            preprocess_params_dict = dict(
                zip(
                    ["time_lims", "bin_width", "data_transform"],
                    [time_lims, bin_width, data_transform],
                )
            )
            self.preprocess_params = preprocess_params_dict
            print("Set preprocess params to: {}".format(preprocess_params_dict))

        else:
            # Load json and save dict
            pass

    def set_model_params(self, states, fit, samples, model_kwargs=None, file_path=None):
        """Load given params as "model_params" attribute

        Args:
            states (int): Number of states to use in model
            fit (int): Iterations to use for model fitting (given ADVI fit)
            samples (int): Number of samples to return from fitted model
            model_kwargs (dict): Additional parameters for model
            file_path (str, optional): Path to json file containing
                    preprocess parameters. Defaults to None.
        """

        if file_path is None:
            model_params_dict = dict(
                zip(
                    ["states", "fit", "samples", "model_kwargs"],
                    [states, fit, samples, model_kwargs],
                )
            )
            self.model_params = model_params_dict
            print("Set model params to: {}".format(model_params_dict))

        else:
            # Load json and save dict
            pass

    ########################################
    # SET PIPELINE FUNCS
    ########################################

    def set_preprocessor(self, preprocessing_func):
        """Manually set preprocessor for data e.g.

        FitHandler.set_preprocessor(
                    changepoint_preprocess.preprocess_single_taste)

        Args:
            preprocessing_func (func):
                    Function to preprocess data (refer to changepoint_preprocess)
        """
        self.preprocessor = preprocessing_func

    def preprocess_selector(self):
        """Function to return preprocess function based off of input flag

        Preprocessing can be set manually but it is preferred to
        go through preprocess selector

        Raises:
            Exception: If self.taste_num is neither int nor str

        """

        if isinstance(self.taste_num, int):
            self.set_preprocessor(
                changepoint_preprocess.preprocess_single_taste)
        elif self.taste_num == "all":
            self.set_preprocessor(changepoint_preprocess.preprocess_all_taste)
        else:
            raise Exception("Something went wrong")

    def set_model_template(self, model_template):
        """Manually set model_template for data e.g.

        FitHandler.set_model(changepoint_model.single_taste_poisson)

        Args:
            model_template (func): Function to generate model template for data
        """
        self.model_template = model_template

    def model_template_selector(self):
        """Function to set model based off of input flag

        Models can be set manually but it is preferred to go through model selector

        Raises:
            Exception: If self.taste_num is neither int nor str

        """
        if isinstance(self.taste_num, int):
            # self.set_model_template(changepoint_model.single_taste_poisson_varsig)
            self.set_model_template(changepoint_model.single_taste_poisson)
        elif self.taste_num == "all":
            self.set_model_template(changepoint_model.all_taste_poisson)
        else:
            raise Exception("Something went wrong")

    def set_inference(self, inference_func):
        """Manually set inference function for model fit e.g.

        FitHandler.set_inference(changepoint_model.advi_fit)

        Args:
            inference_func (func): Function to use for fitting model
        """
        self.inference_func = inference_func

    def inference_func_selector(self):
        """Function to return model based off of input flag

        Currently hard-coded to use "advi_fit"
        """
        self.set_inference(changepoint_model.advi_fit)

    ########################################
    # PIPELINE FUNCS
    ########################################

    def load_spike_trains(self):
        """Helper function to load spike trains from data_dir using EphysData module"""
        full_spike_array = self.EphysData.return_region_spikes(
            region_name=self.region_name, laser=self.laser_type
        )
        if isinstance(self.taste_num, int):
            self.data = full_spike_array[self.taste_num]
        if self.taste_num == "all":
            self.data = full_spike_array
        print(
            f"Loading spike trains from {self.database_handler.data_basename}, "
            f"dig_in {self.taste_num}, laser {str(self.laser_type)}"
        )

    def preprocess_data(self):
        """Perform data preprocessing

        Will check for and complete:
        1) Raw data loaded
        2) Preprocessor selected
        """
        if "data" not in dir(self):
            self.load_spike_trains()
        if "preprocessor" not in dir(self):
            self.preprocess_selector()
        print(
            "Preprocessing spike trains, " f"preprocessing func: <{self.preprocessor.__name__}>")
        self.preprocessed_data = self.preprocessor(
            self.data, **self.preprocess_params)

    def create_model(self):
        """Create model and save as attribute

        Will check for and complete:
        1) Data preprocessed
        2) Model template selected
        """
        if "preprocessed_data" not in dir(self):
            self.preprocess_data()
        if "model_template" not in dir(self):
            self.model_template_selector()

        # In future iterations, before fitting model,
        # check that a similar entry doesn't exist

        print(
            f"Generating Model, model func: <{self.model_template.__name__}>")
        self.model = self.model_template(
            self.preprocessed_data,
            self.model_params["states"],
            **(self.model_params["model_kwargs"] or {}),  # tolerate model_kwargs=None
        )

    def run_inference(self):
        """Perform inference on data

        Will check for and complete:
        1) Model created
        2) Inference function selected
        """
        if "model" not in dir(self):
            self.create_model()
        if "inference_func" not in dir(self):
            self.inference_func_selector()

        print(
            "Running inference, inference func: " f"<{self.inference_func.__name__}>")
        temp_outs = self.inference_func(
            self.model, self.model_params["fit"], self.model_params["samples"]
        )
        varnames = ["model", "approx", "lambda", "tau", "data"]
        self.inference_outs = dict(zip(varnames, temp_outs))

        # If data is None (e.g., from advi_fit to avoid PyMC5 compatibility issues),
        # use the preprocessed_data that was used for inference
        if self.inference_outs.get("data") is None and hasattr(self, "preprocessed_data"):
            self.inference_outs["data"] = self.preprocessed_data

    def _gen_fit_metadata(self):
        """Generate metadata for fit

        Generate metadata by compiling:
        1) Preprocess parameters given as input
        2) Model parameters given as input
        3) Functions used in inference pipeline for : preprocessing,
                model generation, fitting

        Returns:
            dict: Dictionary containing compiled metadata for different
                    parts of inference pipeline
        """
        pre_params = self.preprocess_params
        model_params = self.model_params
        pre_params["preprocessor_name"] = self.preprocessor.__name__
        model_params["model_template_name"] = self.model_template.__name__
        model_params["inference_func_name"] = self.inference_func.__name__
        fin_dict = dict(zip(["preprocess", "model"],
                        [pre_params, model_params]))
        return fin_dict

    def _pass_metadata_to_handler(self):
        """Function to coordinate transfer of metadata to DatabaseHandler"""
        self.database_handler.ingest_fit_data(self._gen_fit_metadata())

    def _return_fit_output(self):
        """Compile data, model, fit, and metadata to save output

        Returns:
            dict: Dictionary containing fitted model data and metadata
        """
        self._pass_metadata_to_handler()
        agg_metadata = self.database_handler.aggregate_metadata()
        return {"model_data": self.inference_outs, "metadata": agg_metadata}

    def save_fit_output(self):
        """Save fit output (fitted data + metadata) to pkl file"""
        if "inference_outs" not in dir(self):
            self.run_inference()
        out_dict = self._return_fit_output()

        # Save output to pkl file
        with open(self.database_handler.model_save_path + ".pkl", "wb") as buff:
            pickle.dump(out_dict, buff)
        print(
            f"Saved full output to {self.database_handler.model_save_path}.pkl")

        # # Create a copy without the model to avoid pickling issues with PyMC5
        # picklable_dict = out_dict.copy()
        # if "model_data" in picklable_dict and "model" in picklable_dict["model_data"]:
        #     picklable_model_data = picklable_dict["model_data"].copy()
        #     # Remove the model object as it contains unpicklable local functions in PyMC5
        #     picklable_model_data.pop("model", None)
        #     picklable_dict["model_data"] = picklable_model_data
        #
        # with open(self.database_handler.model_save_path + ".pkl", "wb") as buff:
        #     try:
        #         pickle.dump(picklable_dict, buff)
        #     except (TypeError, AttributeError) as e:
        #         print(
        #             f"Warning: Full pickling failed ({e}). Saving metadata-only version.")
        #         # If pickling fails, save only metadata and basic info
        #         model_data_fallback = {
        #             "tau_array": picklable_dict.get("model_data", {}).get("tau_array"),
        #             "processed_spikes": picklable_dict.get("model_data", {}).get("processed_spikes"),
        #         }
        #
        #         # Try to save approx.hist for ELBO plotting if available
        #         approx_obj = picklable_dict.get("model_data", {}).get("approx")
        #         if approx_obj and hasattr(approx_obj, 'hist'):
        #             try:
        #                 # Create a simple object with just the hist attribute
        #                 model_data_fallback["approx"] = SimpleApprox(
        #                     approx_obj.hist)
        #             except Exception:
        #                 # If even hist fails to pickle, skip it
        #                 pass
        #
        #         metadata_only_dict = {
        #             "metadata": picklable_dict.get("metadata", {}),
        #             "model_data": model_data_fallback
        #         }
        #         pickle.dump(metadata_only_dict, buff)

        json_file_name = os.path.join(
            self.database_handler.model_save_path + ".info")
        with open(json_file_name, "w") as file:
            json.dump(out_dict["metadata"], file, indent=4)

        self.database_handler.write_to_database()

        print(
            "Saving inference output to : \n"
            f"{self.database_handler.model_save_dir}"
            "\n" + "================================" + "\n"
        )
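
A minimal end-to-end sketch (paths and parameter values are hypothetical; valid data_transform flags are defined in changepoint_preprocess):

from pytau.changepoint_io import FitHandler

fh = FitHandler(
    data_dir="/path/to/session_dir",  # hypothetical HDF5 session directory
    taste_num=0,
    region_name="gc",                 # must match a region in the .info file
    experiment_name="demo_fits",
)
fh.set_preprocess_params(
    time_lims=[2000, 4000],           # hypothetical window into the spike train
    bin_width=50,
    data_transform=None,              # hypothetical flag; see changepoint_preprocess
)
fh.set_model_params(states=4, fit=40000, samples=20000, model_kwargs={})
fh.save_fit_output()  # runs load -> preprocess -> model -> inference -> save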

__init__(data_dir, taste_num, region_name, laser_type=None, experiment_name=None, model_params_path=None, preprocess_params_path=None)

Initialize FitHandler class

Parameters:

    data_dir (str): Path to directory containing HDF5 file. Required.
    taste_num (int): Index of taste to perform fit on (corresponds to INDEX of taste in spike array, not actual dig_ins). Required.
    region_name (str): Region on which to perform fit (must match regions in .info file). Required.
    laser_type (None or str, optional): None, 'on', or 'off'. Defaults to None.
    experiment_name (str, optional): Name given to fitted batch (for metadata). Defaults to None.
    model_params_path (str, optional): Path to json file containing model parameters. Defaults to None.
    preprocess_params_path (str, optional): Path to json file containing preprocessing parameters. Defaults to None.

Raises:

    Exception: If "experiment_name" is None
    Exception: If "laser_type" is not in [None, 'on', 'off']
    Exception: If "taste_num" is not an integer or "all"

Source code in pytau/changepoint_io.py
def __init__(
    self,
    data_dir,
    taste_num,
    region_name,
    laser_type=None,
    experiment_name=None,
    model_params_path=None,
    preprocess_params_path=None,
):
    """Initialize FitHandler class

    Args:
        data_dir (str): Path to directory containing HDF5 file
        taste_num (int): Index of taste to perform fit on
                (Corresponds to INDEX of taste in spike array, not actual dig_ins)
        region_name (str): Region on which to perform fit
                (must match regions in .info file)
        laser_type (None or str, optional): None, 'on', or 'off'.
                Defaults to None.
        experiment_name (str, optional): Name given to fitted batch
                (for metadata). Defaults to None.
        model_params_path (str, optional): Path to json file
                containing model parameters. Defaults to None.
        preprocess_params_path (str, optional): Path to json file
                containing preprocessing parameters. Defaults to None.

    Raises:
        Exception: If "experiment_name" is None
        Exception: If "laser_type" is not in [None, 'on', 'off']
        Exception: If "taste_num" is not integer or "all"
    """

    # =============== Check for exceptions ===============
    if experiment_name is None:
        raise Exception("Please specify an experiment name")
    if laser_type not in [None, "on", "off"]:
        raise Exception('laser_type must be from [None, "on","off"]')
    if not (isinstance(taste_num, int) or taste_num == "all"):
        raise Exception('taste_num must be an integer or "all"')

    # =============== Save relevant arguments ===============
    self.data_dir = data_dir
    self.EphysData = EphysData(self.data_dir)
    # self.data = self.EphysData.get_spikes({"bla","gc","all"})

    self.taste_num = taste_num
    self.laser_type = laser_type
    self.region_name = region_name
    self.experiment_name = experiment_name

    data_handler_init_kwargs = dict(
        zip(
            [
                "data_dir",
                "experiment_name",
                "taste_num",
                "laser_type",
                "region_name",
            ],
            [data_dir, experiment_name, taste_num, laser_type, region_name],
        )
    )
    self.database_handler = DatabaseHandler()
    self.database_handler.set_run_params(**data_handler_init_kwargs)

    if model_params_path is None:
        print("MODEL_PARAMS will have to be set")
    else:
        self.set_model_params(file_path=model_params_path)

    if preprocess_params_path is None:
        print("PREPROCESS_PARAMS will have to be set")
    else:
        self.set_preprocess_params(file_path=preprocess_params_path)

create_model()

Create model and save as attribute

Will check for and complete:

1) Data preprocessed
2) Model template selected

Source code in pytau/changepoint_io.py
def create_model(self):
    """Create model and save as attribute

    Will check for and complete:
    1) Data preprocessed
    2) Model template selected
    """
    if "preprocessed_data" not in dir(self):
        self.preprocess_data()
    if "model_template" not in dir(self):
        self.model_template_selector()

    # In future iterations, before fitting model,
    # check that a similar entry doesn't exist

    print(
        f"Generating Model, model func: <{self.model_template.__name__}>")
    self.model = self.model_template(
        self.preprocessed_data,
        self.model_params["states"],
        **(self.model_params["model_kwargs"] or {}),  # tolerate model_kwargs=None
    )

inference_func_selector()

Function to set the inference function based off of input flag

Currently hard-coded to use "advi_fit"

Source code in pytau/changepoint_io.py
def inference_func_selector(self):
    """Function to return model based off of input flag

    Currently hard-coded to use "advi_fit"
    """
    self.set_inference(changepoint_model.advi_fit)

load_spike_trains()

Helper function to load spike trains from data_dir using EphysData module

Source code in pytau/changepoint_io.py
def load_spike_trains(self):
    """Helper function to load spike trains from data_dir using EphysData module"""
    full_spike_array = self.EphysData.return_region_spikes(
        region_name=self.region_name, laser=self.laser_type
    )
    if isinstance(self.taste_num, int):
        self.data = full_spike_array[self.taste_num]
    if self.taste_num == "all":
        self.data = full_spike_array
    print(
        f"Loading spike trains from {self.database_handler.data_basename}, "
        f"dig_in {self.taste_num}, laser {str(self.laser_type)}"
    )

model_template_selector()

Function to set model based off of input flag

Models can be set manually but it is preferred to go through model selector

Raises:

    Exception: If self.taste_num is neither int nor str

Source code in pytau/changepoint_io.py
def model_template_selector(self):
    """Function to set model based off of input flag

    Models can be set manually but it is preferred to go through model selector

    Raises:
        Exception: If self.taste_num is neither int nor str

    """
    if isinstance(self.taste_num, int):
        # self.set_model_template(changepoint_model.single_taste_poisson_varsig)
        self.set_model_template(changepoint_model.single_taste_poisson)
    elif self.taste_num == "all":
        self.set_model_template(changepoint_model.all_taste_poisson)
    else:
        raise Exception("Something went wrong")

preprocess_data()

Perform data preprocessing

Will check for and complete:

1) Raw data loaded
2) Preprocessor selected

Source code in pytau/changepoint_io.py
def preprocess_data(self):
    """Perform data preprocessing

    Will check for and complete:
    1) Raw data loaded
    2) Preprocessor selected
    """
    if "data" not in dir(self):
        self.load_spike_trains()
    if "preprocessor" not in dir(self):
        self.preprocess_selector()
    print(
        "Preprocessing spike trains, " f"preprocessing func: <{self.preprocessor.__name__}>")
    self.preprocessed_data = self.preprocessor(
        self.data, **self.preprocess_params)

preprocess_selector()

Function to set the preprocess function based on input flag

Preprocessing can be set manually but it is preferred to go through preprocess selector

Raises:

- Exception: If self.taste_num is neither an int nor "all"

Source code in pytau/changepoint_io.py
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def preprocess_selector(self):
    """Function to return preprocess function based off of input flag

    Preprocessing can be set manually but it is preferred to
    go through preprocess selector

    Raises:
        Exception: If self.taste_num is neither an int nor "all"

    """

    if isinstance(self.taste_num, int):
        self.set_preprocessor(
            changepoint_preprocess.preprocess_single_taste)
    elif self.taste_num == "all":
        self.set_preprocessor(changepoint_preprocess.preprocess_all_taste)
    else:
        raise Exception("taste_num must be an int or 'all'")

run_inference()

Perform inference on data

Will check for and complete: 1) Model created 2) Inference function selected

Source code in pytau/changepoint_io.py
def run_inference(self):
    """Perform inference on data

    Will check for and complete:
    1) Model created
    2) Inference function selected
    """
    if "model" not in dir(self):
        self.create_model()
    if "inference_func" not in dir(self):
        self.inference_func_selector()

    print(
        "Running inference, inference func: " f"<{self.inference_func.__name__}>")
    temp_outs = self.inference_func(
        self.model, self.model_params["fit"], self.model_params["samples"]
    )
    varnames = ["model", "approx", "lambda", "tau", "data"]
    self.inference_outs = dict(zip(varnames, temp_outs))

    # If data is None (e.g., from advi_fit to avoid PyMC5 compatibility issues),
    # use the preprocessed_data that was used for inference
    if self.inference_outs.get("data") is None and hasattr(self, "preprocessed_data"):
        self.inference_outs["data"] = self.preprocessed_data

save_fit_output()

Save fit output (fitted data + metadata) to pkl file

Source code in pytau/changepoint_io.py
def save_fit_output(self):
    """Save fit output (fitted data + metadata) to pkl file"""
    if "inference_outs" not in dir(self):
        self.run_inference()
    out_dict = self._return_fit_output()

    # Save output to pkl file
    with open(self.database_handler.model_save_path + ".pkl", "wb") as buff:
        pickle.dump(out_dict, buff)
    print(
        f"Saved full output to {self.database_handler.model_save_path}.pkl")

    # # Create a copy without the model to avoid pickling issues with PyMC5
    # picklable_dict = out_dict.copy()
    # if "model_data" in picklable_dict and "model" in picklable_dict["model_data"]:
    #     picklable_model_data = picklable_dict["model_data"].copy()
    #     # Remove the model object as it contains unpicklable local functions in PyMC5
    #     picklable_model_data.pop("model", None)
    #     picklable_dict["model_data"] = picklable_model_data
    #
    # with open(self.database_handler.model_save_path + ".pkl", "wb") as buff:
    #     try:
    #         pickle.dump(picklable_dict, buff)
    #     except (TypeError, AttributeError) as e:
    #         print(
    #             f"Warning: Full pickling failed ({e}). Saving metadata-only version.")
    #         # If pickling fails, save only metadata and basic info
    #         model_data_fallback = {
    #             "tau_array": picklable_dict.get("model_data", {}).get("tau_array"),
    #             "processed_spikes": picklable_dict.get("model_data", {}).get("processed_spikes"),
    #         }
    #
    #         # Try to save approx.hist for ELBO plotting if available
    #         approx_obj = picklable_dict.get("model_data", {}).get("approx")
    #         if approx_obj and hasattr(approx_obj, 'hist'):
    #             try:
    #                 # Create a simple object with just the hist attribute
    #                 model_data_fallback["approx"] = SimpleApprox(
    #                     approx_obj.hist)
    #             except Exception:
    #                 # If even hist fails to pickle, skip it
    #                 pass
    #
    #         metadata_only_dict = {
    #             "metadata": picklable_dict.get("metadata", {}),
    #             "model_data": model_data_fallback
    #         }
    #         pickle.dump(metadata_only_dict, buff)

    json_file_name = os.path.join(
        self.database_handler.model_save_path + ".info")
    with open(json_file_name, "w") as file:
        json.dump(out_dict["metadata"], file, indent=4)

    self.database_handler.write_to_database()

    print(
        "Saving inference output to : \n"
        f"{self.database_handler.model_save_dir}"
        "\n" + "================================" + "\n"
    )
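
save_fit_output completes the lazy chain: it falls back to run_inference, which falls back to create_model, which in turn fills in preprocessing and template selection. Assuming the params have been set (see the setters below), a single call drives the whole pipeline, and the sidecar .info file can be read back without unpickling anything:

import json

fit_handler.save_fit_output()  # runs inference if needed, writes .pkl/.info, updates the database

with open(fit_handler.database_handler.model_save_path + ".info") as f:
    metadata = json.load(f)    # same dict as out_dict["metadata"] above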

set_inference(inference_func)

Manually set inference function for model fit e.g.

FitHandler.set_inference(changepoint_model.advi_fit)

Parameters:

- inference_func (func): Function to use for fitting model. [required]
Source code in pytau/changepoint_io.py
def set_inference(self, inference_func):
    """Manually set inference function for model fit e.g.

    FitHandler.set_inference(changepoint_model.advi_fit)

    Args:
        inference_func (func): Function to use for fitting model
    """
    self.inference_func = inference_func

set_model_params(states, fit, samples, model_kwargs=None, file_path=None)

Load given params as "model_params" attribute

Parameters:

- states (int): Number of states to use in model. [required]
- fit (int): Iterations to use for model fitting (given ADVI fit). [required]
- samples (int): Number of samples to return from fitted model. [required]
- model_kwargs (dict): Additional parameters for model. [default: None]
- file_path (str): Path to json file containing model parameters. [default: None]
Source code in pytau/changepoint_io.py
def set_model_params(self, states, fit, samples, model_kwargs=None, file_path=None):
    """Load given params as "model_params" attribute

    Args:
        states (int): Number of states to use in model
        fit (int): Iterations to use for model fitting (given ADVI fit)
        samples (int): Number of samples to return from fitted model
        model_kwargs (dict): Additional parameters for model
        file_path (str, optional): Path to json file containing
                model parameters. Defaults to None.
    """

    if file_path is None:
        model_params_dict = dict(
            zip(
                ["states", "fit", "samples", "model_kwargs"],
                [states, fit, samples, model_kwargs],
            )
        )
        self.model_params = model_params_dict
        print("Set model params to: {}".format(model_params_dict))

    else:
        # Load json and save dict
        pass
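
Note that the file_path branch above is currently a stub (the json-loading body is pass), so model params must be passed directly; illustrative values:

fit_handler.set_model_params(states=4, fit=40000, samples=20000, model_kwargs={})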

set_model_template(model_template)

Manually set model_template for data e.g.

FitHandler.set_model(changepoint_model.single_taste_poisson)

Parameters:

- model_template (func): Function to generate model template for data. [required]
Source code in pytau/changepoint_io.py
def set_model_template(self, model_template):
    """Manually set model_template for data e.g.

    FitHandler.set_model(changepoint_model.single_taste_poisson)

    Args:
        model_template (func): Function to generate model template for data
    """
    self.model_template = model_template

set_preprocess_params(time_lims, bin_width, data_transform, file_path=None)

Load given params as "preprocess_params" attribute

Parameters:

- time_lims (array/tuple/list): Start and end of where to cut spike train array. [required]
- bin_width (int): Bin width for binning spikes to counts. [required]
- data_transform (str): Indicator for which transformation to use (refer to changepoint_preprocess). [required]
- file_path (str): Path to json file containing preprocess parameters. [default: None]
Source code in pytau/changepoint_io.py
def set_preprocess_params(self, time_lims, bin_width, data_transform, file_path=None):
    """Load given params as "preprocess_params" attribute

    Args:
        time_lims (array/tuple/list): Start and end of where to cut
                spike train array
        bin_width (int): Bin width for binning spikes to counts
        data_transform (str): Indicator for which transformation to
                use (refer to changepoint_preprocess)
        file_path (str, optional): Path to json file containing preprocess
                parameters. Defaults to None.
    """

    if file_path is None:
        preprocess_params_dict = dict(
            zip(
                ["time_lims", "bin_width", "data_transform"],
                [time_lims, bin_width, data_transform],
            )
        )
        self.preprocess_params = preprocess_params_dict
        print("Set preprocess params to: {}".format(preprocess_params_dict))

    else:
        # Load json and save dict
        pass
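
As with set_model_params, the file_path branch is a stub, so preprocess params are passed directly (illustrative values; None is shown as a placeholder for the data_transform indicator, whose valid strings are defined in changepoint_preprocess):

fit_handler.set_preprocess_params(time_lims=[2000, 4000], bin_width=50, data_transform=None)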

set_preprocessor(preprocessing_func)

Manually set preprocessor for data e.g.

FitHandler.set_preprocessor( changepoint_preprocess.preprocess_single_taste)

Parameters:

- preprocessing_func (func): Function to preprocess data (refer to changepoint_preprocess). [required]
Source code in pytau/changepoint_io.py
def set_preprocessor(self, preprocessing_func):
    """Manually set preprocessor for data e.g.

    FitHandler.set_preprocessor(
                changepoint_preprocess.preprocess_single_taste)

    Args:
        preprocessing_func (func):
                Function to preprocess data (refer to changepoint_preprocess)
    """
    self.preprocessor = preprocessing_func

SimpleApprox

Simple approximation object that only stores the hist attribute for ELBO plotting

Source code in pytau/changepoint_io.py
class SimpleApprox:
    """Simple approximation object that only stores the hist attribute for ELBO plotting"""

    def __init__(self, hist):
        self.hist = hist
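
SimpleApprox exists so the ELBO history can survive the metadata-only pickling fallback sketched in the commented-out block of save_fit_output. Intended use, assuming approx is a fitted ADVI approximation (which exposes a .hist array):

import matplotlib.pyplot as plt

slim_approx = SimpleApprox(approx.hist)  # keep only the ELBO trace, drop the unpicklable rest
plt.plot(slim_approx.hist)               # negative ELBO per ADVI iteration
plt.xlabel("Iteration")
plt.ylabel("-ELBO")
plt.show()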

=== Model building functions ===

PyMC blackbox variational inference implementation of Poisson-likelihood changepoint models for spike trains.

AllTastePoisson

Bases: ChangepointModel

Model to fit changepoint to all tastes. Largely taken from "_v1/poisson_all_tastes_changepoint_model.py"

Source code in pytau/changepoint_model.py
class AllTastePoisson(ChangepointModel):
    """
    ** Model to fit changepoint to all tastes **
    ** Largely taken from "_v1/poisson_all_tastes_changepoint_model.py"
    """

    def __init__(self, data_array, n_states, **kwargs):
        """
        Args:
            data_array (4D Numpy array): tastes, trials, neurons, time_bins
            n_states (int): Number of states to model
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
        self.data_array = data_array
        self.n_states = n_states

    def generate_model(self):
        """
        Returns:
            pymc model: Model class containing graph to run inference on
        """
        data_array = self.data_array
        n_states = self.n_states

        # Unroll arrays along taste axis
        data_array_long = np.concatenate(data_array, axis=0)

        # Find mean firing for initial values
        tastes = data_array.shape[0]
        length = data_array.shape[-1]
        nrns = data_array.shape[2]
        trials = data_array.shape[1]

        split_list = np.array_split(data_array, n_states, axis=-1)
        # Cut all to the same size
        min_val = min([x.shape[-1] for x in split_list])
        split_array = np.array([x[..., :min_val] for x in split_list])
        mean_vals = np.mean(split_array, axis=(2, -1)).swapaxes(0, 1)
        mean_vals += 0.01  # To avoid zero starting prob
        mean_nrn_vals = np.mean(mean_vals, axis=(0, 1))

        # Find evenly spaced switchpoints for initial values
        idx = np.arange(data_array.shape[-1])  # Index
        array_idx = np.broadcast_to(idx, data_array_long.shape)
        even_switches = np.linspace(0, idx.max(), n_states + 1)
        even_switches_normal = even_switches / np.max(even_switches)

        taste_label = np.repeat(
            np.arange(data_array.shape[0]), data_array.shape[1])
        trial_num = array_idx.shape[0]

        # Begin constructing model
        with pm.Model() as model:
            # Hierarchical firing rates
            # Refer to model diagram
            # Mean firing rate of neuron AT ALL TIMES
            lambda_nrn = pm.Exponential(
                "lambda_nrn", 1 / mean_nrn_vals, shape=(mean_vals.shape[-1])
            )
            # Priors for each state, derived from each neuron
            # Mean firing rate of neuron IN EACH STATE (averaged across tastes)
            lambda_state = pm.Exponential(
                "lambda_state", lambda_nrn, shape=(mean_vals.shape[1:]))
            # Mean firing rate of neuron PER STATE PER TASTE
            lambda_latent = pm.Exponential(
                "lambda",
                lambda_state[np.newaxis, :, :],
                initval=mean_vals,
                shape=(mean_vals.shape),
            )

            # Changepoint time variable
            # INDEPENDENT TAU FOR EVERY TRIAL
            a = pm.HalfNormal("a_tau", 3.0, shape=n_states - 1)
            b = pm.HalfNormal("b_tau", 3.0, shape=n_states - 1)

            # Stack produces n_states x trials --> That gets transposed
            # to trials x n_states and gets sorted along n_states (axis=-1)
            # Sort should work the same way as the Ordered transform -->
            # see rv_sort_test.ipynb
            tau_latent = pm.Beta(
                "tau_latent",
                a,
                b,
                shape=(trial_num, n_states - 1),
                initval=tt.tile(even_switches_normal[1:(
                    n_states)], (array_idx.shape[0], 1)),
            ).sort(axis=-1)

            tau = pm.Deterministic(
                "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

            weight_stack = tt.math.sigmoid(
                idx[np.newaxis, :] - tau[:, :, np.newaxis])
            weight_stack = tt.concatenate(
                [np.ones((tastes * trials, 1, length)), weight_stack], axis=1
            )
            inverse_stack = 1 - weight_stack[:, 1:]
            inverse_stack = tt.concatenate(
                [inverse_stack, np.ones((tastes * trials, 1, length))], axis=1
            )
            weight_stack = weight_stack * inverse_stack
            weight_stack = tt.tile(
                weight_stack[:, :, None, :], (1, 1, nrns, 1))

            lambda_latent = lambda_latent.dimshuffle(2, 0, 1)
            lambda_latent = tt.repeat(lambda_latent, trials, axis=1)
            lambda_latent = tt.tile(
                lambda_latent[..., None], (1, 1, 1, length))
            lambda_latent = lambda_latent.dimshuffle(1, 2, 0, 3)
            lambda_ = tt.sum(lambda_latent * weight_stack, axis=1)

            observation = pm.Poisson("obs", lambda_, observed=data_array_long)

        return model

    def test(self):
        """Test the model with synthetic data"""
        # Generate test data
        test_data = gen_test_array(
            (2, 5, 10, 100), n_states=self.n_states, type="poisson")

        # Create model with test data
        test_model = AllTastePoisson(test_data, self.n_states)
        model = test_model.generate_model()

        # Run a minimal inference to verify model works
        with model:
            # Just do a few iterations to test functionality
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)

        # Check if expected variables are in the trace
        assert "lambda" in trace.varnames
        assert "tau" in trace.varnames
        assert "lambda_nrn" in trace.varnames
        assert "lambda_state" in trace.varnames

        print("Test for AllTastePoisson passed")
        return True
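
The heart of generate_model is the weight stack: products of shifted sigmoids act as soft indicators of which state each time bin is in, so the emission rate at every bin is a near-one-hot mixture of the per-state rates. A NumPy-only sketch of that construction for a single trial (illustrative changepoints; the model builds the same graph symbolically):

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

length, n_states = 100, 3
idx = np.arange(length)
tau = np.array([30.0, 65.0])                             # sorted changepoints for one trial

w = sigmoid(idx[None, :] - tau[:, None])                 # (n_states - 1, length): rises at each tau
w = np.concatenate([np.ones((1, length)), w])            # prepend an always-on row
inv = np.concatenate([1 - w[1:], np.ones((1, length))])
weight_stack = w * inv                                   # (n_states, length): soft state indicators

assert np.allclose(weight_stack.sum(axis=0), 1)          # rows form a (soft) partition of unity
print(weight_stack.argmax(axis=0))                       # state per bin: 0...0 1...1 2...2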

__init__(data_array, n_states, **kwargs)

Parameters:

- data_array (4D Numpy array): tastes, trials, neurons, time_bins. [required]
- n_states (int): Number of states to model. [required]
- **kwargs: Additional arguments. [default: {}]
Source code in pytau/changepoint_model.py
def __init__(self, data_array, n_states, **kwargs):
    """
    Args:
        data_array (4D Numpy array): tastes, trials, neurons, time_bins
        n_states (int): Number of states to model
        **kwargs: Additional arguments
    """
    super().__init__(**kwargs)
    self.data_array = data_array
    self.n_states = n_states

generate_model()

Returns:

- pymc model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def generate_model(self):
    """
    Returns:
        pymc model: Model class containing graph to run inference on
    """
    data_array = self.data_array
    n_states = self.n_states

    # Unroll arrays along taste axis
    data_array_long = np.concatenate(data_array, axis=0)

    # Find mean firing for initial values
    tastes = data_array.shape[0]
    length = data_array.shape[-1]
    nrns = data_array.shape[2]
    trials = data_array.shape[1]

    split_list = np.array_split(data_array, n_states, axis=-1)
    # Cut all to the same size
    min_val = min([x.shape[-1] for x in split_list])
    split_array = np.array([x[..., :min_val] for x in split_list])
    mean_vals = np.mean(split_array, axis=(2, -1)).swapaxes(0, 1)
    mean_vals += 0.01  # To avoid zero starting prob
    mean_nrn_vals = np.mean(mean_vals, axis=(0, 1))

    # Find evenly spaced switchpoints for initial values
    idx = np.arange(data_array.shape[-1])  # Index
    array_idx = np.broadcast_to(idx, data_array_long.shape)
    even_switches = np.linspace(0, idx.max(), n_states + 1)
    even_switches_normal = even_switches / np.max(even_switches)

    taste_label = np.repeat(
        np.arange(data_array.shape[0]), data_array.shape[1])
    trial_num = array_idx.shape[0]

    # Begin constructing model
    with pm.Model() as model:
        # Hierarchical firing rates
        # Refer to model diagram
        # Mean firing rate of neuron AT ALL TIMES
        lambda_nrn = pm.Exponential(
            "lambda_nrn", 1 / mean_nrn_vals, shape=(mean_vals.shape[-1])
        )
        # Priors for each state, derived from each neuron
        # Mean firing rate of neuron IN EACH STATE (averaged across tastes)
        lambda_state = pm.Exponential(
            "lambda_state", lambda_nrn, shape=(mean_vals.shape[1:]))
        # Mean firing rate of neuron PER STATE PER TASTE
        lambda_latent = pm.Exponential(
            "lambda",
            lambda_state[np.newaxis, :, :],
            initval=mean_vals,
            shape=(mean_vals.shape),
        )

        # Changepoint time variable
        # INDEPENDENT TAU FOR EVERY TRIAL
        a = pm.HalfNormal("a_tau", 3.0, shape=n_states - 1)
        b = pm.HalfNormal("b_tau", 3.0, shape=n_states - 1)

        # Stack produces n_states x trials --> That gets transposed
        # to trials x n_states and gets sorted along n_states (axis=-1)
        # Sort should work the same way as the Ordered transform -->
        # see rv_sort_test.ipynb
        tau_latent = pm.Beta(
            "tau_latent",
            a,
            b,
            shape=(trial_num, n_states - 1),
            initval=tt.tile(even_switches_normal[1:(
                n_states)], (array_idx.shape[0], 1)),
        ).sort(axis=-1)

        tau = pm.Deterministic(
            "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

        weight_stack = tt.math.sigmoid(
            idx[np.newaxis, :] - tau[:, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((tastes * trials, 1, length)), weight_stack], axis=1
        )
        inverse_stack = 1 - weight_stack[:, 1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((tastes * trials, 1, length))], axis=1
        )
        weight_stack = weight_stack * inverse_stack
        weight_stack = tt.tile(
            weight_stack[:, :, None, :], (1, 1, nrns, 1))

        lambda_latent = lambda_latent.dimshuffle(2, 0, 1)
        lambda_latent = tt.repeat(lambda_latent, trials, axis=1)
        lambda_latent = tt.tile(
            lambda_latent[..., None], (1, 1, 1, length))
        lambda_latent = lambda_latent.dimshuffle(1, 2, 0, 3)
        lambda_ = tt.sum(lambda_latent * weight_stack, axis=1)

        observation = pm.Poisson("obs", lambda_, observed=data_array_long)

    return model

test()

Test the model with synthetic data

Source code in pytau/changepoint_model.py
def test(self):
    """Test the model with synthetic data"""
    # Generate test data
    test_data = gen_test_array(
        (2, 5, 10, 100), n_states=self.n_states, type="poisson")

    # Create model with test data
    test_model = AllTastePoisson(test_data, self.n_states)
    model = test_model.generate_model()

    # Run a minimal inference to verify model works
    with model:
        # Just do a few iterations to test functionality
        inference = pm.ADVI()
        approx = pm.fit(n=10, method=inference)
        trace = approx.sample(draws=10)

    # Check if expected variables are in the trace
    assert "lambda" in trace.varnames
    assert "tau" in trace.varnames
    assert "lambda_nrn" in trace.varnames
    assert "lambda_state" in trace.varnames

    print("Test for AllTastePoisson passed")
    return True

AllTastePoissonTrialSwitch

Bases: ChangepointModel

Assumes only emissions change across trials; the changepoint distribution remains constant

Source code in pytau/changepoint_model.py
class AllTastePoissonTrialSwitch(ChangepointModel):
    """
    Assuming only emissions change across trials
    Changepoint distribution remains constant
    """

    def __init__(self, data_array, switch_components, n_states, **kwargs):
        """
        Args:
            data_array (4D Numpy array): tastes, trials, neurons, time_bins
            switch_components (int): Number of trial switch components
            n_states (int): Number of states to model
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
        self.data_array = data_array
        self.switch_components = switch_components
        self.n_states = n_states

    def generate_model(self):
        """
        Returns:
            pymc model: Model class containing graph to run inference on
        """
        data_array = self.data_array
        switch_components = self.switch_components
        n_states = self.n_states

        tastes, trial_num, nrn_num, time_bins = data_array.shape

        with pm.Model() as model:
            # Define Emissions
            # =================================================

            # nrns
            nrn_lambda = pm.Exponential("nrn_lambda", 10, shape=(nrn_num))

            # tastes x nrns
            taste_lambda = pm.Exponential(
                "taste_lambda", nrn_lambda.dimshuffle("x", 0), shape=(tastes, nrn_num)
            )

            # tastes x nrns x switch_comps
            trial_lambda = pm.Exponential(
                "trial_lambda",
                taste_lambda.dimshuffle(0, 1, "x"),
                shape=(tastes, nrn_num, switch_components),
            )

            # tastes x nrns x switch_comps x n_states
            state_lambda = pm.Exponential(
                "state_lambda",
                trial_lambda.dimshuffle(0, 1, 2, "x"),
                shape=(tastes, nrn_num, switch_components, n_states),
            )

            # Define Changepoints
            # =================================================
            # Assuming distribution of changepoints remains
            # the same across all trials

            a = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
            b = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)

            even_switches = np.linspace(0, 1, n_states + 1)[1:-1]
            tau_latent = pm.Beta(
                "tau_latent",
                a,
                b,
                # initval=even_switches,
                shape=(tastes, trial_num, n_states - 1),
            ).sort(axis=-1)

            # Tastes x Trials x Changepoints
            tau = pm.Deterministic("tau", time_bins * tau_latent)

            # Define trial switches
            # Will have same structure as regular changepoints

            # a_trial = pm.HalfCauchy('a_trial', 3., shape = switch_components - 1)
            # b_trial = pm.HalfCauchy('b_trial', 3., shape = switch_components - 1)

            even_trial_switches = np.linspace(
                0, 1, switch_components + 1)[1:-1]
            tau_trial_latent = pm.Beta(
                "tau_trial_latent",
                1,
                1,
                initval=even_trial_switches,
                shape=(switch_components - 1),
            ).sort(axis=-1)

            # Trial_changepoints
            # =================================================
            tau_trial = pm.Deterministic(
                "tau_trial", trial_num * tau_trial_latent)

            trial_idx = np.arange(trial_num)
            trial_selector = tt.math.sigmoid(
                trial_idx[np.newaxis, :] - tau_trial.dimshuffle(0, "x")
            )

            trial_selector = tt.concatenate(
                [np.ones((1, trial_num)), trial_selector], axis=0)
            inverse_trial_selector = 1 - trial_selector[1:, :]
            inverse_trial_selector = tt.concatenate(
                [inverse_trial_selector, np.ones((1, trial_num))], axis=0
            )

            # switch_comps x trials
            trial_selector = np.multiply(
                trial_selector, inverse_trial_selector)

            # state_lambda: tastes x nrns x switch_comps x states

            # selected_trial_lambda : tastes x nrns x states x trials
            selected_trial_lambda = pm.Deterministic(
                "selected_trial_lambda",
                tt.sum(
                    # "tastes" x "nrns" x switch_comps x "states" x trials
                    trial_selector.dimshuffle("x", "x", 0, "x", 1)
                    * state_lambda.dimshuffle(0, 1, 2, 3, "x"),
                    axis=2,
                ),
            )

            # First, we can "select" sets of emissions depending on trial_changepoints
            # =================================================
            trial_idx = np.arange(trial_num)
            trial_selector = tt.math.sigmoid(
                trial_idx[np.newaxis, :] - tau_trial.dimshuffle(0, "x")
            )

            trial_selector = tt.concatenate(
                [np.ones((1, trial_num)), trial_selector], axis=0)
            inverse_trial_selector = 1 - trial_selector[1:, :]
            inverse_trial_selector = tt.concatenate(
                [inverse_trial_selector, np.ones((1, trial_num))], axis=0
            )

            # switch_comps x trials
            trial_selector = np.multiply(
                trial_selector, inverse_trial_selector)

            # Then, we can select state_emissions for every trial
            # =================================================

            idx = np.arange(time_bins)

            # tau : Tastes x Trials x Changepoints
            weight_stack = tt.math.sigmoid(
                idx[np.newaxis, :] - tau[:, :, :, np.newaxis])
            weight_stack = tt.concatenate(
                [np.ones((tastes, trial_num, 1, time_bins)), weight_stack], axis=2
            )
            inverse_stack = 1 - weight_stack[:, :, 1:]
            inverse_stack = tt.concatenate(
                [inverse_stack, np.ones((tastes, trial_num, 1, time_bins))], axis=2
            )

            # Tastes x Trials x states x Time
            weight_stack = np.multiply(weight_stack, inverse_stack)

            # Putting everything together
            # =================================================

            # selected_trial_lambda :           tastes x nrns x states x trials
            # Convert selected_trial_lambda --> tastes x trials x nrns x states x "time"

            # weight_stack :           tastes x trials x states x time
            # Convert weight_stack --> tastes x trials x "nrns" x states x time

            # tastes x trials x nrns x time
            lambda_ = tt.sum(
                selected_trial_lambda.dimshuffle(0, 3, 1, 2, "x")
                * weight_stack.dimshuffle(0, 1, "x", 2, 3),
                axis=3,
            )

            # Add observations
            observation = pm.Poisson("obs", lambda_, observed=data_array)

        return model

    def test(self):
        """Test the model with synthetic data"""
        # Generate test data
        test_data = gen_test_array(
            (2, 5, 10, 100), n_states=self.n_states, type="poisson")

        # Create model with test data
        test_model = AllTastePoissonTrialSwitch(
            test_data, self.switch_components, self.n_states)
        model = test_model.generate_model()

        # Run a minimal inference to verify model works
        with model:
            # Just do a few iterations to test functionality
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)

        # Check if expected variables are in the trace
        assert "nrn_lambda" in trace.varnames
        assert "tau" in trace.varnames
        assert "tau_trial" in trace.varnames
        assert "state_lambda" in trace.varnames
        assert "taste_lambda" in trace.varnames

        print("Test for AllTastePoissonTrialSwitch passed")
        return True
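
The trial selector reuses the sigmoid-difference construction from the time-axis weight stack, but along the trial axis: each trial receives a near-one-hot weight over the switch_components emission sets. A NumPy sketch (illustrative switchpoint):

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

trial_num = 15
trial_idx = np.arange(trial_num)
tau_trial = np.array([7.0])                              # one switchpoint -> two components

sel = sigmoid(trial_idx[None, :] - tau_trial[:, None])
sel = np.concatenate([np.ones((1, trial_num)), sel])
inv = np.concatenate([1 - sel[1:], np.ones((1, trial_num))])
trial_selector = sel * inv                               # (switch_components, trial_num)

print(trial_selector.argmax(axis=0))                     # component assigned to each trial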

__init__(data_array, switch_components, n_states, **kwargs)

Parameters:

- data_array (4D Numpy array): tastes, trials, neurons, time_bins. [required]
- switch_components (int): Number of trial switch components. [required]
- n_states (int): Number of states to model. [required]
- **kwargs: Additional arguments. [default: {}]
Source code in pytau/changepoint_model.py
def __init__(self, data_array, switch_components, n_states, **kwargs):
    """
    Args:
        data_array (4D Numpy array): tastes, trials, neurons, time_bins
        switch_components (int): Number of trial switch components
        n_states (int): Number of states to model
        **kwargs: Additional arguments
    """
    super().__init__(**kwargs)
    self.data_array = data_array
    self.switch_components = switch_components
    self.n_states = n_states

generate_model()

Returns:

- pymc model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def generate_model(self):
    """
    Returns:
        pymc model: Model class containing graph to run inference on
    """
    data_array = self.data_array
    switch_components = self.switch_components
    n_states = self.n_states

    tastes, trial_num, nrn_num, time_bins = data_array.shape

    with pm.Model() as model:
        # Define Emissions
        # =================================================

        # nrns
        nrn_lambda = pm.Exponential("nrn_lambda", 10, shape=(nrn_num))

        # tastes x nrns
        taste_lambda = pm.Exponential(
            "taste_lambda", nrn_lambda.dimshuffle("x", 0), shape=(tastes, nrn_num)
        )

        # tastes x nrns x switch_comps
        trial_lambda = pm.Exponential(
            "trial_lambda",
            taste_lambda.dimshuffle(0, 1, "x"),
            shape=(tastes, nrn_num, switch_components),
        )

        # tastes x nrns x switch_comps x n_states
        state_lambda = pm.Exponential(
            "state_lambda",
            trial_lambda.dimshuffle(0, 1, 2, "x"),
            shape=(tastes, nrn_num, switch_components, n_states),
        )

        # Define Changepoints
        # =================================================
        # Assuming distribution of changepoints remains
        # the same across all trials

        a = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
        b = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)

        even_switches = np.linspace(0, 1, n_states + 1)[1:-1]
        tau_latent = pm.Beta(
            "tau_latent",
            a,
            b,
            # initval=even_switches,
            shape=(tastes, trial_num, n_states - 1),
        ).sort(axis=-1)

        # Tastes x Trials x Changepoints
        tau = pm.Deterministic("tau", time_bins * tau_latent)

        # Define trial switches
        # Will have same structure as regular changepoints

        # a_trial = pm.HalfCauchy('a_trial', 3., shape = switch_components - 1)
        # b_trial = pm.HalfCauchy('b_trial', 3., shape = switch_components - 1)

        even_trial_switches = np.linspace(
            0, 1, switch_components + 1)[1:-1]
        tau_trial_latent = pm.Beta(
            "tau_trial_latent",
            1,
            1,
            initval=even_trial_switches,
            shape=(switch_components - 1),
        ).sort(axis=-1)

        # Trial_changepoints
        # =================================================
        tau_trial = pm.Deterministic(
            "tau_trial", trial_num * tau_trial_latent)

        trial_idx = np.arange(trial_num)
        trial_selector = tt.math.sigmoid(
            trial_idx[np.newaxis, :] - tau_trial.dimshuffle(0, "x")
        )

        trial_selector = tt.concatenate(
            [np.ones((1, trial_num)), trial_selector], axis=0)
        inverse_trial_selector = 1 - trial_selector[1:, :]
        inverse_trial_selector = tt.concatenate(
            [inverse_trial_selector, np.ones((1, trial_num))], axis=0
        )

        # switch_comps x trials
        trial_selector = np.multiply(
            trial_selector, inverse_trial_selector)

        # state_lambda: tastes x nrns x switch_comps x states

        # selected_trial_lambda : tastes x nrns x states x trials
        selected_trial_lambda = pm.Deterministic(
            "selected_trial_lambda",
            tt.sum(
                # "tastes" x "nrns" x switch_comps x "states" x trials
                trial_selector.dimshuffle("x", "x", 0, "x", 1)
                * state_lambda.dimshuffle(0, 1, 2, 3, "x"),
                axis=2,
            ),
        )

        # First, we can "select" sets of emissions depending on trial_changepoints
        # =================================================
        trial_idx = np.arange(trial_num)
        trial_selector = tt.math.sigmoid(
            trial_idx[np.newaxis, :] - tau_trial.dimshuffle(0, "x")
        )

        trial_selector = tt.concatenate(
            [np.ones((1, trial_num)), trial_selector], axis=0)
        inverse_trial_selector = 1 - trial_selector[1:, :]
        inverse_trial_selector = tt.concatenate(
            [inverse_trial_selector, np.ones((1, trial_num))], axis=0
        )

        # switch_comps x trials
        trial_selector = np.multiply(
            trial_selector, inverse_trial_selector)

        # Then, we can select state_emissions for every trial
        # =================================================

        idx = np.arange(time_bins)

        # tau : Tastes x Trials x Changepoints
        weight_stack = tt.math.sigmoid(
            idx[np.newaxis, :] - tau[:, :, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((tastes, trial_num, 1, time_bins)), weight_stack], axis=2
        )
        inverse_stack = 1 - weight_stack[:, :, 1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((tastes, trial_num, 1, time_bins))], axis=2
        )

        # Tastes x Trials x states x Time
        weight_stack = np.multiply(weight_stack, inverse_stack)

        # Putting everything together
        # =================================================

        # selected_trial_lambda :           tastes x nrns x states x trials
        # Convert selected_trial_lambda --> tastes x trials x nrns x states x "time"

        # weight_stack :           tastes x trials x states x time
        # Convert weight_stack --> tastes x trials x "nrns" x states x time

        # tastes x trials x nrns x time
        lambda_ = tt.sum(
            selected_trial_lambda.dimshuffle(0, 3, 1, 2, "x")
            * weight_stack.dimshuffle(0, 1, "x", 2, 3),
            axis=3,
        )

        # Add observations
        observation = pm.Poisson("obs", lambda_, observed=data_array)

    return model

test()

Test the model with synthetic data

Source code in pytau/changepoint_model.py
def test(self):
    """Test the model with synthetic data"""
    # Generate test data
    test_data = gen_test_array(
        (2, 5, 10, 100), n_states=self.n_states, type="poisson")

    # Create model with test data
    test_model = AllTastePoissonTrialSwitch(
        test_data, self.switch_components, self.n_states)
    model = test_model.generate_model()

    # Run a minimal inference to verify model works
    with model:
        # Just do a few iterations to test functionality
        inference = pm.ADVI()
        approx = pm.fit(n=10, method=inference)
        trace = approx.sample(draws=10)

    # Check if expected variables are in the trace
    assert "nrn_lambda" in trace.varnames
    assert "tau" in trace.varnames
    assert "tau_trial" in trace.varnames
    assert "state_lambda" in trace.varnames
    assert "taste_lambda" in trace.varnames

    print("Test for AllTastePoissonTrialSwitch passed")
    return True

AllTastePoissonVarsigFixed

Bases: ChangepointModel

Model to fit changepoint to all tastes with fixed sigmoid. Largely taken from "_v1/poisson_all_tastes_changepoint_model.py"

Source code in pytau/changepoint_model.py
class AllTastePoissonVarsigFixed(ChangepointModel):
    """
    ** Model to fit changepoint to all tastes with fixed sigmoid **
    ** Largely taken from "_v1/poisson_all_tastes_changepoint_model.py"
    """

    def __init__(self, data_array, n_states, inds_span=1, **kwargs):
        """
        Args:
            data_array (4D Numpy array): tastes, trials, neurons, time_bins
            n_states (int): Number of states to model
            inds_span(float): Number of indices to cover 5-95% change in sigmoid
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
        self.data_array = data_array
        self.n_states = n_states
        self.inds_span = inds_span

    def generate_model(self):
        """
        Returns:
            pymc model: Model class containing graph to run inference on
        """
        data_array = self.data_array
        n_states = self.n_states
        inds_span = self.inds_span

        # Unroll arrays along taste axis
        data_array_long = np.concatenate(data_array, axis=0)

        # Find mean firing for initial values
        tastes = data_array.shape[0]
        length = data_array.shape[-1]
        nrns = data_array.shape[2]
        trials = data_array.shape[1]

        split_list = np.array_split(data_array, n_states, axis=-1)
        # Cut all to the same size
        min_val = min([x.shape[-1] for x in split_list])
        split_array = np.array([x[..., :min_val] for x in split_list])
        mean_vals = np.mean(split_array, axis=(2, -1)).swapaxes(0, 1)
        mean_vals += 0.01  # To avoid zero starting prob
        mean_nrn_vals = np.mean(mean_vals, axis=(0, 1))

        # Find evenly spaced switchpoints for initial values
        idx = np.arange(data_array.shape[-1])  # Index
        array_idx = np.broadcast_to(idx, data_array_long.shape)
        even_switches = np.linspace(0, idx.max(), n_states + 1)
        even_switches_normal = even_switches / np.max(even_switches)

        taste_label = np.repeat(
            np.arange(data_array.shape[0]), data_array.shape[1])
        trial_num = array_idx.shape[0]

        # Define sigmoid with given sharpness
        sig_b = inds_to_b(inds_span)

        def sigmoid(x):
            b_temp = tt.tile(
                np.array(sig_b)[None, None, None], x.tag.test_value.shape)
            return 1 / (1 + tt.exp(-b_temp * x))

        # Begin constructing model
        with pm.Model() as model:
            # Hierarchical firing rates
            # Refer to model diagram
            # Mean firing rate of neuron AT ALL TIMES
            lambda_nrn = pm.Exponential(
                "lambda_nrn", 1 / mean_nrn_vals, shape=(mean_vals.shape[-1])
            )
            # Priors for each state, derived from each neuron
            # Mean firing rate of neuron IN EACH STATE (averaged across tastes)
            lambda_state = pm.Exponential(
                "lambda_state", lambda_nrn, shape=(mean_vals.shape[1:]))
            # Mean firing rate of neuron PER STATE PER TASTE
            lambda_latent = pm.Exponential(
                "lambda",
                lambda_state[np.newaxis, :, :],
                initval=mean_vals,
                shape=(mean_vals.shape),
            )

            # Changepoint time variable
            # INDEPENDENT TAU FOR EVERY TRIAL
            a = pm.HalfNormal("a_tau", 3.0, shape=n_states - 1)
            b = pm.HalfNormal("b_tau", 3.0, shape=n_states - 1)

            # Stack produces n_states x trials --> That gets transposed
            # to trials x n_states and gets sorted along n_states (axis=-1)
            # Sort should work the same way as the Ordered transform -->
            # see rv_sort_test.ipynb
            tau_latent = pm.Beta(
                "tau_latent",
                a,
                b,
                shape=(trial_num, n_states - 1),
                initval=tt.tile(even_switches_normal[1:(
                    n_states)], (array_idx.shape[0], 1)),
            ).sort(axis=-1)

            tau = pm.Deterministic(
                "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

            weight_stack = sigmoid(idx[np.newaxis, :] - tau[:, :, np.newaxis])
            weight_stack = tt.concatenate(
                [np.ones((tastes * trials, 1, length)), weight_stack], axis=1
            )
            inverse_stack = 1 - weight_stack[:, 1:]
            inverse_stack = tt.concatenate(
                [inverse_stack, np.ones((tastes * trials, 1, length))], axis=1
            )
            weight_stack = weight_stack * inverse_stack
            weight_stack = tt.tile(
                weight_stack[:, :, None, :], (1, 1, nrns, 1))

            lambda_latent = lambda_latent.dimshuffle(2, 0, 1)
            lambda_latent = tt.repeat(lambda_latent, trials, axis=1)
            lambda_latent = tt.tile(
                lambda_latent[..., None], (1, 1, 1, length))
            lambda_latent = lambda_latent.dimshuffle(1, 2, 0, 3)
            lambda_ = tt.sum(lambda_latent * weight_stack, axis=1)

            observation = pm.Poisson("obs", lambda_, observed=data_array_long)

        return model

    def test(self):
        """Test the model with synthetic data"""
        # Generate test data
        test_data = gen_test_array(
            (2, 5, 10, 100), n_states=self.n_states, type="poisson")

        # Create model with test data
        test_model = AllTastePoissonVarsigFixed(
            test_data, self.n_states, self.inds_span)
        model = test_model.generate_model()

        # Run a minimal inference to verify model works
        with model:
            # Just do a few iterations to test functionality
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)

        # Check if expected variables are in the trace
        assert "lambda" in trace.varnames
        assert "tau" in trace.varnames
        assert "lambda_nrn" in trace.varnames
        assert "lambda_state" in trace.varnames

        print("Test for AllTastePoissonVarsigFixed passed")
        return True
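
inds_to_b (defined elsewhere in changepoint_model) converts the 5-95% rise span into the slope of the local sigmoid. Assuming it implements the standard conversion, solving 1/(1 + exp(-b*x)) = 0.95 at x = inds_span/2 gives b = 2*ln(19)/inds_span; a hypothetical re-derivation:

import numpy as np

def inds_to_b(inds_span):
    # Assumed behavior: pick b so the sigmoid rises from 0.05 to 0.95
    # over inds_span indices, i.e. b = 2*ln(19) / inds_span.
    return 2 * np.log(19) / inds_span

b = inds_to_b(1)
x = np.array([-0.5, 0.0, 0.5])
print(1 / (1 + np.exp(-b * x)))  # -> [0.05, 0.5, 0.95]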

__init__(data_array, n_states, inds_span=1, **kwargs)

Parameters:

- data_array (4D Numpy array): tastes, trials, neurons, time_bins. [required]
- n_states (int): Number of states to model. [required]
- inds_span (float): Number of indices to cover 5-95% change in sigmoid. [default: 1]
- **kwargs: Additional arguments. [default: {}]
Source code in pytau/changepoint_model.py
def __init__(self, data_array, n_states, inds_span=1, **kwargs):
    """
    Args:
        data_array (4D Numpy array): tastes, trials, neurons, time_bins
        n_states (int): Number of states to model
        inds_span(float): Number of indices to cover 5-95% change in sigmoid
        **kwargs: Additional arguments
    """
    super().__init__(**kwargs)
    self.data_array = data_array
    self.n_states = n_states
    self.inds_span = inds_span

generate_model()

Returns:

- pymc model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def generate_model(self):
    """
    Returns:
        pymc model: Model class containing graph to run inference on
    """
    data_array = self.data_array
    n_states = self.n_states
    inds_span = self.inds_span

    # Unroll arrays along taste axis
    data_array_long = np.concatenate(data_array, axis=0)

    # Find mean firing for initial values
    tastes = data_array.shape[0]
    length = data_array.shape[-1]
    nrns = data_array.shape[2]
    trials = data_array.shape[1]

    split_list = np.array_split(data_array, n_states, axis=-1)
    # Cut all to the same size
    min_val = min([x.shape[-1] for x in split_list])
    split_array = np.array([x[..., :min_val] for x in split_list])
    mean_vals = np.mean(split_array, axis=(2, -1)).swapaxes(0, 1)
    mean_vals += 0.01  # To avoid zero starting prob
    mean_nrn_vals = np.mean(mean_vals, axis=(0, 1))

    # Find evenly spaced switchpoints for initial values
    idx = np.arange(data_array.shape[-1])  # Index
    array_idx = np.broadcast_to(idx, data_array_long.shape)
    even_switches = np.linspace(0, idx.max(), n_states + 1)
    even_switches_normal = even_switches / np.max(even_switches)

    taste_label = np.repeat(
        np.arange(data_array.shape[0]), data_array.shape[1])
    trial_num = array_idx.shape[0]

    # Define sigmoid with given sharpness
    sig_b = inds_to_b(inds_span)

    def sigmoid(x):
        b_temp = tt.tile(
            np.array(sig_b)[None, None, None], x.tag.test_value.shape)
        return 1 / (1 + tt.exp(-b_temp * x))

    # Begin constructing model
    with pm.Model() as model:
        # Hierarchical firing rates
        # Refer to model diagram
        # Mean firing rate of neuron AT ALL TIMES
        lambda_nrn = pm.Exponential(
            "lambda_nrn", 1 / mean_nrn_vals, shape=(mean_vals.shape[-1])
        )
        # Priors for each state, derived from each neuron
        # Mean firing rate of neuron IN EACH STATE (averaged across tastes)
        lambda_state = pm.Exponential(
            "lambda_state", lambda_nrn, shape=(mean_vals.shape[1:]))
        # Mean firing rate of neuron PER STATE PER TASTE
        lambda_latent = pm.Exponential(
            "lambda",
            lambda_state[np.newaxis, :, :],
            initval=mean_vals,
            shape=(mean_vals.shape),
        )

        # Changepoint time variable
        # INDEPENDENT TAU FOR EVERY TRIAL
        a = pm.HalfNormal("a_tau", 3.0, shape=n_states - 1)
        b = pm.HalfNormal("b_tau", 3.0, shape=n_states - 1)

        # Stack produces n_states x trials --> That gets transposed
        # to trials x n_states and gets sorted along n_states (axis=-1)
        # Sort should work the same way as the Ordered transform -->
        # see rv_sort_test.ipynb
        tau_latent = pm.Beta(
            "tau_latent",
            a,
            b,
            shape=(trial_num, n_states - 1),
            initval=tt.tile(even_switches_normal[1:(
                n_states)], (array_idx.shape[0], 1)),
        ).sort(axis=-1)

        tau = pm.Deterministic(
            "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

        weight_stack = sigmoid(idx[np.newaxis, :] - tau[:, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((tastes * trials, 1, length)), weight_stack], axis=1
        )
        inverse_stack = 1 - weight_stack[:, 1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((tastes * trials, 1, length))], axis=1
        )
        weight_stack = weight_stack * inverse_stack
        weight_stack = tt.tile(
            weight_stack[:, :, None, :], (1, 1, nrns, 1))

        lambda_latent = lambda_latent.dimshuffle(2, 0, 1)
        lambda_latent = tt.repeat(lambda_latent, trials, axis=1)
        lambda_latent = tt.tile(
            lambda_latent[..., None], (1, 1, 1, length))
        lambda_latent = lambda_latent.dimshuffle(1, 2, 0, 3)
        lambda_ = tt.sum(lambda_latent * weight_stack, axis=1)

        observation = pm.Poisson("obs", lambda_, observed=data_array_long)

    return model

test()

Test the model with synthetic data

Source code in pytau/changepoint_model.py
def test(self):
    """Test the model with synthetic data"""
    # Generate test data
    test_data = gen_test_array(
        (2, 5, 10, 100), n_states=self.n_states, type="poisson")

    # Create model with test data
    test_model = AllTastePoissonVarsigFixed(
        test_data, self.n_states, self.inds_span)
    model = test_model.generate_model()

    # Run a minimal inference to verify model works
    with model:
        # Just do a few iterations to test functionality
        inference = pm.ADVI()
        approx = pm.fit(n=10, method=inference)
        trace = approx.sample(draws=10)

    # Check if expected variables are in the trace
    assert "lambda" in trace.varnames
    assert "tau" in trace.varnames
    assert "lambda_nrn" in trace.varnames
    assert "lambda_state" in trace.varnames

    print("Test for AllTastePoissonVarsigFixed passed")
    return True

CategoricalChangepoint2D

Bases: ChangepointModel

Model for categorical data changepoint detection on 2D arrays.

Source code in pytau/changepoint_model.py
class CategoricalChangepoint2D(ChangepointModel):
    """Model for categorical data changepoint detection on 2D arrays."""

    def __init__(self, data_array, n_states, **kwargs):
        """
        Args:
            data_array (2D Numpy array): trials x length
                - Each element is a positive integer representing a category
            n_states (int): Number of states to model
            **kwargs: Additional arguments

        """

        super().__init__(**kwargs)
        # Make sure data array is int
        if not np.issubdtype(data_array.dtype, np.integer):
            raise ValueError(
                "Data array must contain integer category values.")
        # Check that data_array is 2D
        if data_array.ndim != 2:
            # If 3D, take the first trial/dimension to make it 2D
            if data_array.ndim == 3:
                data_array = data_array[0]
            else:
                raise ValueError("Data array must be 2D (trials x length).")
        self.data_array = data_array
        self.n_states = n_states

    def generate_model(self):
        data_array = self.data_array
        n_states = self.n_states
        trials, length = data_array.shape
        features = len(np.unique(data_array))

        # If features in data_array are not continuous integer values, map them
        feature_set = np.unique(data_array)
        if not np.array_equal(feature_set, np.arange(len(feature_set))):
            # Create a mapping from original categories to continuous integers
            category_map = {cat: i for i, cat in enumerate(feature_set)}
            data_array = np.vectorize(category_map.get)(data_array)

        idx = np.arange(length)
        flat_data_array = data_array.reshape((trials * length,))

        with pm.Model() as model:
            p = pm.Dirichlet("p", a=np.ones(
                (n_states, features)), shape=(n_states, features))

            # Infer changepoint locations
            a_tau = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
            b_tau = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)
            # Shape: trials x changepoints
            tau_latent = pm.Beta("tau_latent", a_tau, b_tau, shape=(trials, n_states - 1)).sort(
                axis=-1
            )

            tau = pm.Deterministic(
                "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

            weight_stack = tt.math.sigmoid(
                idx[np.newaxis, :] - tau[:, :, np.newaxis])
            weight_stack = tt.concatenate(
                [np.ones((trials, 1, length)), weight_stack], axis=1)
            inverse_stack = 1 - weight_stack[:, 1:]
            inverse_stack = tt.concatenate(
                [inverse_stack, np.ones((trials, 1, length))], axis=1)
            weight_stack = np.multiply(weight_stack, inverse_stack)

            # shapes:
            #   - weight_stack: trials x states x length
            #   - p : states x features

            # shape: trials x length x features
            lambda_ = tt.tensordot(weight_stack, p, [1, 0])

            flat_lambda = lambda_.reshape((trials * length, features))

            # Use categorical likelihood
            # data_array = trials x length
            category = pm.Categorical(
                "category", p=flat_lambda, observed=flat_data_array)

        return model

    def test(self):
        test_data = np.random.randint(0, self.n_states, size=(5, 100))
        test_model = CategoricalChangepoint2D(test_data, self.n_states)
        model = test_model.generate_model()
        with model:
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)
        assert "p" in trace.varnames
        assert "tau" in trace.varnames
        print("Test for CategoricalChangepoint2D passed")
        return True

__init__(data_array, n_states, **kwargs)

Parameters:

    data_array (2D Numpy array): trials x length. Each element is a positive integer representing a category. Required.
    n_states (int): Number of states to model. Required.
    **kwargs: Additional arguments. Default: {}.
Source code in pytau/changepoint_model.py
def __init__(self, data_array, n_states, **kwargs):
    """
    Args:
        data_array (2D Numpy array): trials x length
            - Each element is a positive integer representing a category
        n_states (int): Number of states to model
        **kwargs: Additional arguments

    """

    super().__init__(**kwargs)
    # Make sure data array is int
    if not np.issubdtype(data_array.dtype, np.integer):
        raise ValueError(
            "Data array must contain integer category values.")
    # Check that data_array is 2D
    if data_array.ndim != 2:
        # If 3D, take the first trial/dimension to make it 2D
        if data_array.ndim == 3:
            data_array = data_array[0]
        else:
            raise ValueError("Data array must be 2D (trials x length).")
    self.data_array = data_array
    self.n_states = n_states
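
A hypothetical usage sketch for CategoricalChangepoint2D, mirroring test() above. Category labels need not be contiguous integers: generate_model() remaps them internally (here 2, 5, 9 become 0, 1, 2). The commented-out fitting step is the same minimal ADVI recipe used in test():

import numpy as np
from pytau.changepoint_model import CategoricalChangepoint2D

rng = np.random.default_rng(0)
labels = np.array([2, 5, 9])
data = rng.choice(labels, size=(5, 100))  # trials x length

model = CategoricalChangepoint2D(data, n_states=3).generate_model()
# with model:
#     approx = pm.fit(n=10, method=pm.ADVI())
#     trace = approx.sample(draws=10)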

ChangepointModel

Base class for all changepoint models

Source code in pytau/changepoint_model.py
class ChangepointModel:
    """Base class for all changepoint models"""

    def __init__(self, **kwargs):
        """Initialize model with keyword arguments"""
        self.kwargs = kwargs

    def generate_model(self):
        """Generate pymc model - to be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement generate_model()")

    def test(self):
        """Test model functionality - to be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement test()")

__init__(**kwargs)

Initialize model with keyword arguments

Source code in pytau/changepoint_model.py
def __init__(self, **kwargs):
    """Initialize model with keyword arguments"""
    self.kwargs = kwargs

generate_model()

Generate pymc model - to be implemented by subclasses

Source code in pytau/changepoint_model.py
def generate_model(self):
    """Generate pymc model - to be implemented by subclasses"""
    raise NotImplementedError("Subclasses must implement generate_model()")

test()

Test model functionality - to be implemented by subclasses

Source code in pytau/changepoint_model.py
def test(self):
    """Test model functionality - to be implemented by subclasses"""
    raise NotImplementedError("Subclasses must implement test()")
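
The base class only fixes the contract: subclasses store their data and hyperparameters in __init__ and provide generate_model() and test(). A minimal sketch of a conforming subclass (MyConstantRateModel is hypothetical and deliberately trivial; it assumes only the pm import used elsewhere on this page):

import numpy as np
import pymc as pm  # same import convention as the module

from pytau.changepoint_model import ChangepointModel

class MyConstantRateModel(ChangepointModel):
    """Toy model: one Poisson rate, no changepoints."""

    def __init__(self, data_array, **kwargs):
        super().__init__(**kwargs)
        self.data_array = data_array

    def generate_model(self):
        with pm.Model() as model:
            lam = pm.Exponential("lambda", 1.0)
            pm.Poisson("obs", lam, observed=self.data_array)
        return model

    def test(self):
        model = MyConstantRateModel(np.random.poisson(3, 100)).generate_model()
        with model:
            pm.fit(n=10, method=pm.ADVI())
        return True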

GaussianChangepointMean2D

Bases: ChangepointModel

Model for Gaussian data on a 2D array, detecting changes only in the mean.

Source code in pytau/changepoint_model.py
class GaussianChangepointMean2D(ChangepointModel):
    """Model for gaussian data on 2D array detecting changes only in
    the mean.
    """

    def __init__(self, data_array, n_states, **kwargs):
        """
        Args:
            data_array (2D Numpy array): <dimension> x time
            n_states (int): Number of states to model
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
        self.data_array = data_array
        self.n_states = n_states

    def generate_model(self):
        """
        Returns:
            pymc model: Model class containing graph to run inference on
        """
        data_array = self.data_array
        n_states = self.n_states

        mean_vals = np.array(
            [np.mean(x, axis=-1)
             for x in np.array_split(data_array, n_states, axis=-1)]
        ).T
        mean_vals += 0.01  # To avoid zero starting prob

        y_dim = data_array.shape[0]
        idx = np.arange(data_array.shape[-1])
        length = idx.max() + 1

        with pm.Model() as model:
            mu = pm.Normal("mu", mu=mean_vals, sigma=1,
                           shape=(y_dim, n_states))
            # One variance for each dimension
            sigma = pm.HalfCauchy("sigma", 3.0, shape=(y_dim))

            a_tau = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
            b_tau = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)

            even_switches = np.linspace(0, 1, n_states + 1)[1:-1]
            tau_latent = pm.Beta(
                "tau_latent", a_tau, b_tau, initval=even_switches, shape=(n_states - 1)
            ).sort(axis=-1)

            tau = pm.Deterministic(
                "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

            weight_stack = tt.math.sigmoid(
                idx[np.newaxis, :] - tau[:, np.newaxis])
            weight_stack = tt.concatenate(
                [np.ones((1, length)), weight_stack], axis=0)
            inverse_stack = 1 - weight_stack[1:]
            inverse_stack = tt.concatenate(
                [inverse_stack, np.ones((1, length))], axis=0)
            weight_stack = np.multiply(weight_stack, inverse_stack)

            mu_latent = mu.dot(weight_stack)
            sigma_latent = sigma.dimshuffle(0, "x")
            observation = pm.Normal(
                "obs", mu=mu_latent, sigma=sigma_latent, observed=data_array)

        return model

    def test(self):
        """Test the model with synthetic data"""
        # Generate test data
        test_data = gen_test_array(
            (10, 100), n_states=self.n_states, type="normal")

        # Create model with test data
        test_model = GaussianChangepointMean2D(test_data, self.n_states)
        model = test_model.generate_model()

        # Run a minimal inference to verify model works
        with model:
            # Just do a few iterations to test functionality
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)

        # Check if expected variables are in the trace
        assert "mu" in trace.varnames
        assert "sigma" in trace.varnames
        assert "tau" in trace.varnames

        print("Test for GaussianChangepointMean2D passed")
        return True

__init__(data_array, n_states, **kwargs)

Parameters:

    data_array (2D Numpy array): <dimension> x time. Required.
    n_states (int): Number of states to model. Required.
    **kwargs: Additional arguments. Default: {}.
Source code in pytau/changepoint_model.py
def __init__(self, data_array, n_states, **kwargs):
    """
    Args:
        data_array (2D Numpy array): <dimension> x time
        n_states (int): Number of states to model
        **kwargs: Additional arguments
    """
    super().__init__(**kwargs)
    self.data_array = data_array
    self.n_states = n_states

generate_model()

Returns:

    pymc model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def generate_model(self):
    """
    Returns:
        pymc model: Model class containing graph to run inference on
    """
    data_array = self.data_array
    n_states = self.n_states

    mean_vals = np.array(
        [np.mean(x, axis=-1)
         for x in np.array_split(data_array, n_states, axis=-1)]
    ).T
    mean_vals += 0.01  # To avoid zero starting prob

    y_dim = data_array.shape[0]
    idx = np.arange(data_array.shape[-1])
    length = idx.max() + 1

    with pm.Model() as model:
        mu = pm.Normal("mu", mu=mean_vals, sigma=1,
                       shape=(y_dim, n_states))
        # One variance for each dimension
        sigma = pm.HalfCauchy("sigma", 3.0, shape=(y_dim))

        a_tau = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
        b_tau = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)

        even_switches = np.linspace(0, 1, n_states + 1)[1:-1]
        tau_latent = pm.Beta(
            "tau_latent", a_tau, b_tau, initval=even_switches, shape=(n_states - 1)
        ).sort(axis=-1)

        tau = pm.Deterministic(
            "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

        weight_stack = tt.math.sigmoid(
            idx[np.newaxis, :] - tau[:, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((1, length)), weight_stack], axis=0)
        inverse_stack = 1 - weight_stack[1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((1, length))], axis=0)
        weight_stack = np.multiply(weight_stack, inverse_stack)

        mu_latent = mu.dot(weight_stack)
        sigma_latent = sigma.dimshuffle(0, "x")
        observation = pm.Normal(
            "obs", mu=mu_latent, sigma=sigma_latent, observed=data_array)

    return model

test()

Test the model with synthetic data

Source code in pytau/changepoint_model.py
def test(self):
    """Test the model with synthetic data"""
    # Generate test data
    test_data = gen_test_array(
        (10, 100), n_states=self.n_states, type="normal")

    # Create model with test data
    test_model = GaussianChangepointMean2D(test_data, self.n_states)
    model = test_model.generate_model()

    # Run a minimal inference to verify model works
    with model:
        # Just do a few iterations to test functionality
        inference = pm.ADVI()
        approx = pm.fit(n=10, method=inference)
        trace = approx.sample(draws=10)

    # Check if expected variables are in the trace
    assert "mu" in trace.varnames
    assert "sigma" in trace.varnames
    assert "tau" in trace.varnames

    print("Test for GaussianChangepointMean2D passed")
    return True
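
A hypothetical end-to-end sketch for GaussianChangepointMean2D: plant a mean shift in synthetic data, fit with ADVI, and read off the posterior mean changepoint. The trace["tau"] indexing assumes the MultiTrace interface that test() relies on:

import numpy as np
import pymc as pm  # same import convention as the module

from pytau.changepoint_model import GaussianChangepointMean2D

data = np.random.randn(10, 100)  # <dimension> x time
data[:, 60:] += 3.0              # mean shift at t = 60

model = GaussianChangepointMean2D(data, n_states=2).generate_model()
with model:
    approx = pm.fit(n=1000, method=pm.ADVI())
    trace = approx.sample(draws=100)

print(trace["tau"].mean(axis=0))  # should concentrate near 60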

GaussianChangepointMeanDirichlet

Bases: ChangepointModel

Model for Gaussian data on a 2D array, detecting changes only in the mean. The number of states is determined using a Dirichlet process prior.

Source code in pytau/changepoint_model.py
class GaussianChangepointMeanDirichlet(ChangepointModel):
    """Model for gaussian data on 2D array detecting changes only in
    the mean. Number of states determined using dirichlet process prior.
    """

    def __init__(self, data_array, max_states=15, **kwargs):
        """
        Args:
            data_array (2D Numpy array): <dimension> x time
            max_states (int): Max number of states to include in truncated dirichlet process
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
        self.data_array = data_array
        self.max_states = max_states

    def generate_model(self):
        """
        Returns:
            pymc model: Model class containing graph to run inference on
        """
        data_array = self.data_array
        max_states = self.max_states

        y_dim = data_array.shape[0]
        idx = np.arange(data_array.shape[-1])
        length = idx.max() + 1

        mean_vals = np.array(
            [np.mean(x, axis=-1)
             for x in np.array_split(data_array, max_states, axis=-1)]
        ).T
        mean_vals += 0.01  # To avoid zero starting prob
        test_std = np.std(data_array, axis=-1)

        with pm.Model() as model:
            # ===================
            # Emissions Variables
            # ===================
            lambda_latent = pm.Normal(
                "lambda", mu=mean_vals, sigma=10, shape=(y_dim, max_states))
            # One variance for each dimension
            sigma = pm.HalfCauchy("sigma", test_std, shape=(y_dim))

            # =====================
            # Changepoint Variables
            # =====================

            # Hyperpriors on alpha
            a_gamma = pm.Gamma("a_gamma", 10, 1)
            b_gamma = pm.Gamma("b_gamma", 1.5, 1)

            # Concentration parameter for beta
            alpha = pm.Gamma("alpha", a_gamma, b_gamma)

            # Draw beta's to calculate stick lengths
            beta = pm.Beta("beta", 1, alpha, shape=max_states)

            # Calculate stick lengths using stick_breaking process
            w_raw = pm.Deterministic("w_raw", stick_breaking(beta))

            # Make sure lengths add to 1, and scale to length of data
            w_latent = pm.Deterministic("w_latent", w_raw / w_raw.sum())
            tau = pm.Deterministic("tau", tt.cumsum(w_latent * length)[:-1])

            # Weight stack to assign lambda's to point in time
            weight_stack = tt.math.sigmoid(
                idx[np.newaxis, :] - tau[:, np.newaxis])
            weight_stack = tt.concatenate(
                [np.ones((1, length)), weight_stack], axis=0)
            inverse_stack = 1 - weight_stack[1:]
            inverse_stack = tt.concatenate(
                [inverse_stack, np.ones((1, length))], axis=0)
            weight_stack = np.multiply(weight_stack, inverse_stack)

            # Create timeseries for latent variable (mean emission)
            lambda_ = pm.Deterministic(
                "lambda_", tt.tensordot(
                    lambda_latent, weight_stack, axes=(1, 0))
            )
            sigma_latent = sigma.dimshuffle(0, "x")

            # Likelihood for observations
            observation = pm.Normal(
                "obs", mu=lambda_, sigma=sigma_latent, observed=data_array)
        return model

    def test(self):
        """Test the model with synthetic data"""
        # Generate test data
        test_data = gen_test_array((10, 100), n_states=3, type="normal")

        # Create model with test data
        test_model = GaussianChangepointMeanDirichlet(test_data, max_states=5)
        model = test_model.generate_model()

        # Run a minimal inference to verify model works
        with model:
            # Just do a few iterations to test functionality
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)

        # Check if expected variables are in the trace
        assert "lambda" in trace.varnames
        assert "sigma" in trace.varnames
        assert "tau" in trace.varnames
        assert "w_latent" in trace.varnames

        print("Test for GaussianChangepointMeanDirichlet passed")
        return True

__init__(data_array, max_states=15, **kwargs)

Parameters:

    data_array (2D Numpy array): <dimension> x time. Required.
    max_states (int): Max number of states to include in truncated Dirichlet process. Default: 15.
    **kwargs: Additional arguments. Default: {}.
Source code in pytau/changepoint_model.py
def __init__(self, data_array, max_states=15, **kwargs):
    """
    Args:
        data_array (2D Numpy array): <dimension> x time
        max_states (int): Max number of states to include in truncated dirichlet process
        **kwargs: Additional arguments
    """
    super().__init__(**kwargs)
    self.data_array = data_array
    self.max_states = max_states

generate_model()

Returns:

    pymc model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def generate_model(self):
    """
    Returns:
        pymc model: Model class containing graph to run inference on
    """
    data_array = self.data_array
    max_states = self.max_states

    y_dim = data_array.shape[0]
    idx = np.arange(data_array.shape[-1])
    length = idx.max() + 1

    mean_vals = np.array(
        [np.mean(x, axis=-1)
         for x in np.array_split(data_array, max_states, axis=-1)]
    ).T
    mean_vals += 0.01  # To avoid zero starting prob
    test_std = np.std(data_array, axis=-1)

    with pm.Model() as model:
        # ===================
        # Emissions Variables
        # ===================
        lambda_latent = pm.Normal(
            "lambda", mu=mean_vals, sigma=10, shape=(y_dim, max_states))
        # One variance for each dimension
        sigma = pm.HalfCauchy("sigma", test_std, shape=(y_dim))

        # =====================
        # Changepoint Variables
        # =====================

        # Hyperpriors on alpha
        a_gamma = pm.Gamma("a_gamma", 10, 1)
        b_gamma = pm.Gamma("b_gamma", 1.5, 1)

        # Concentration parameter for beta
        alpha = pm.Gamma("alpha", a_gamma, b_gamma)

        # Draw beta's to calculate stick lengths
        beta = pm.Beta("beta", 1, alpha, shape=max_states)

        # Calculate stick lengths using stick_breaking process
        w_raw = pm.Deterministic("w_raw", stick_breaking(beta))

        # Make sure lengths add to 1, and scale to length of data
        w_latent = pm.Deterministic("w_latent", w_raw / w_raw.sum())
        tau = pm.Deterministic("tau", tt.cumsum(w_latent * length)[:-1])

        # Weight stack to assign lambda's to point in time
        weight_stack = tt.math.sigmoid(
            idx[np.newaxis, :] - tau[:, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((1, length)), weight_stack], axis=0)
        inverse_stack = 1 - weight_stack[1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((1, length))], axis=0)
        weight_stack = np.multiply(weight_stack, inverse_stack)

        # Create timeseries for latent variable (mean emission)
        lambda_ = pm.Deterministic(
            "lambda_", tt.tensordot(
                lambda_latent, weight_stack, axes=(1, 0))
        )
        sigma_latent = sigma.dimshuffle(0, "x")

        # Likelihood for observations
        observation = pm.Normal(
            "obs", mu=lambda_, sigma=sigma_latent, observed=data_array)
    return model

test()

Test the model with synthetic data

Source code in pytau/changepoint_model.py
def test(self):
    """Test the model with synthetic data"""
    # Generate test data
    test_data = gen_test_array((10, 100), n_states=3, type="normal")

    # Create model with test data
    test_model = GaussianChangepointMeanDirichlet(test_data, max_states=5)
    model = test_model.generate_model()

    # Run a minimal inference to verify model works
    with model:
        # Just do a few iterations to test functionality
        inference = pm.ADVI()
        approx = pm.fit(n=10, method=inference)
        trace = approx.sample(draws=10)

    # Check if expected variables are in the trace
    assert "lambda" in trace.varnames
    assert "sigma" in trace.varnames
    assert "tau" in trace.varnames
    assert "w_latent" in trace.varnames

    print("Test for GaussianChangepointMeanDirichlet passed")
    return True
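
generate_model() above turns stick-breaking draws into changepoints: Beta draws become stick lengths, the normalized lengths are state durations, and their cumulative sum (scaled to the data length) gives tau. A NumPy re-derivation of that arithmetic (illustrative only; the actual stick_breaking helper lives in changepoint_model.py):

import numpy as np

rng = np.random.default_rng(1)
max_states, length = 5, 100

beta = rng.beta(1, 3, size=max_states)   # Beta(1, alpha) draws; alpha = 3 here
remaining = np.concatenate([[1.0], np.cumprod(1 - beta)[:-1]])
w_raw = beta * remaining                 # stick lengths
w_latent = w_raw / w_raw.sum()           # normalize so durations sum to 1
tau = np.cumsum(w_latent * length)[:-1]  # changepoint locations in [0, length)

print(np.round(w_latent, 3), np.round(tau, 1))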

GaussianChangepointMeanVar2D

Bases: ChangepointModel

Model for Gaussian data on a 2D array, detecting changes in both mean and variance.

Source code in pytau/changepoint_model.py
class GaussianChangepointMeanVar2D(ChangepointModel):
    """Model for gaussian data on 2D array detecting changes in both
    mean and variance.
    """

    def __init__(self, data_array, n_states, **kwargs):
        """
        Args:
            data_array (2D Numpy array): <dimension> x time
            n_states (int): Number of states to model
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
        self.data_array = data_array
        self.n_states = n_states

    def generate_model(self):
        """
        Returns:
            pymc model: Model class containing graph to run inference on
        """
        data_array = self.data_array
        n_states = self.n_states

        mean_vals = np.array(
            [np.mean(x, axis=-1)
             for x in np.array_split(data_array, n_states, axis=-1)]
        ).T
        mean_vals += 0.01  # To avoid zero starting prob

        y_dim = data_array.shape[0]
        idx = np.arange(data_array.shape[-1])
        length = idx.max() + 1

        with pm.Model() as model:
            mu = pm.Normal("mu", mu=mean_vals, sigma=1,
                           shape=(y_dim, n_states))
            sigma = pm.HalfCauchy("sigma", 3.0, shape=(y_dim, n_states))

            a_tau = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
            b_tau = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)

            even_switches = np.linspace(0, 1, n_states + 1)[1:-1]
            tau_latent = pm.Beta(
                "tau_latent", a_tau, b_tau, initval=even_switches, shape=(n_states - 1)
            ).sort(axis=-1)

            tau = pm.Deterministic(
                "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

            weight_stack = tt.math.sigmoid(
                idx[np.newaxis, :] - tau[:, np.newaxis])
            weight_stack = tt.concatenate(
                [np.ones((1, length)), weight_stack], axis=0)
            inverse_stack = 1 - weight_stack[1:]
            inverse_stack = tt.concatenate(
                [inverse_stack, np.ones((1, length))], axis=0)
            weight_stack = np.multiply(weight_stack, inverse_stack)

            mu_latent = mu.dot(weight_stack)
            sigma_latent = sigma.dot(weight_stack)
            observation = pm.Normal(
                "obs", mu=mu_latent, sigma=sigma_latent, observed=data_array)

        return model

    def test(self):
        """Test the model with synthetic data"""
        # Generate test data
        test_data = gen_test_array(
            (10, 100), n_states=self.n_states, type="normal")

        # Create model with test data
        test_model = GaussianChangepointMeanVar2D(test_data, self.n_states)
        model = test_model.generate_model()

        # Run a minimal inference to verify model works
        with model:
            # Just do a few iterations to test functionality
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)

        # Check if expected variables are in the trace
        assert "mu" in trace.varnames
        assert "sigma" in trace.varnames
        assert "tau" in trace.varnames

        print("Test for GaussianChangepointMeanVar2D passed")
        return True

__init__(data_array, n_states, **kwargs)

Parameters:

    data_array (2D Numpy array): <dimension> x time. Required.
    n_states (int): Number of states to model. Required.
    **kwargs: Additional arguments. Default: {}.
Source code in pytau/changepoint_model.py
def __init__(self, data_array, n_states, **kwargs):
    """
    Args:
        data_array (2D Numpy array): <dimension> x time
        n_states (int): Number of states to model
        **kwargs: Additional arguments
    """
    super().__init__(**kwargs)
    self.data_array = data_array
    self.n_states = n_states

generate_model()

Returns:

    pymc model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def generate_model(self):
    """
    Returns:
        pymc model: Model class containing graph to run inference on
    """
    data_array = self.data_array
    n_states = self.n_states

    mean_vals = np.array(
        [np.mean(x, axis=-1)
         for x in np.array_split(data_array, n_states, axis=-1)]
    ).T
    mean_vals += 0.01  # To avoid zero starting prob

    y_dim = data_array.shape[0]
    idx = np.arange(data_array.shape[-1])
    length = idx.max() + 1

    with pm.Model() as model:
        mu = pm.Normal("mu", mu=mean_vals, sigma=1,
                       shape=(y_dim, n_states))
        sigma = pm.HalfCauchy("sigma", 3.0, shape=(y_dim, n_states))

        a_tau = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
        b_tau = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)

        even_switches = np.linspace(0, 1, n_states + 1)[1:-1]
        tau_latent = pm.Beta(
            "tau_latent", a_tau, b_tau, initval=even_switches, shape=(n_states - 1)
        ).sort(axis=-1)

        tau = pm.Deterministic(
            "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

        weight_stack = tt.math.sigmoid(
            idx[np.newaxis, :] - tau[:, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((1, length)), weight_stack], axis=0)
        inverse_stack = 1 - weight_stack[1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((1, length))], axis=0)
        weight_stack = np.multiply(weight_stack, inverse_stack)

        mu_latent = mu.dot(weight_stack)
        sigma_latent = sigma.dot(weight_stack)
        observation = pm.Normal(
            "obs", mu=mu_latent, sigma=sigma_latent, observed=data_array)

    return model

test()

Test the model with synthetic data

Source code in pytau/changepoint_model.py
def test(self):
    """Test the model with synthetic data"""
    # Generate test data
    test_data = gen_test_array(
        (10, 100), n_states=self.n_states, type="normal")

    # Create model with test data
    test_model = GaussianChangepointMeanVar2D(test_data, self.n_states)
    model = test_model.generate_model()

    # Run a minimal inference to verify model works
    with model:
        # Just do a few iterations to test functionality
        inference = pm.ADVI()
        approx = pm.fit(n=10, method=inference)
        trace = approx.sample(draws=10)

    # Check if expected variables are in the trace
    assert "mu" in trace.varnames
    assert "sigma" in trace.varnames
    assert "tau" in trace.varnames

    print("Test for GaussianChangepointMeanVar2D passed")
    return True

PoissonChangepoint1D

Bases: ChangepointModel

Model for changepoint detection in 1D Poisson time series

This model detects changepoints in 1D time series data using a Poisson likelihood. It assumes the data follows a Poisson distribution with different rates in different segments separated by changepoints.
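
A hypothetical end-to-end sketch: build a two-rate Poisson series by hand (rather than via gen_test_array()) and check that the single inferred changepoint lands near the true switch. The trace["tau"] indexing assumes the MultiTrace interface used by test() below:

import numpy as np
import pymc as pm  # same import convention as the module

from pytau.changepoint_model import PoissonChangepoint1D

rng = np.random.default_rng(2)
series = np.concatenate([rng.poisson(2, 60), rng.poisson(8, 40)])

model = PoissonChangepoint1D(series, n_states=2).generate_model()
with model:
    approx = pm.fit(n=1000, method=pm.ADVI())
    trace = approx.sample(draws=100)

print(trace["tau"].mean(axis=0))  # should concentrate near t = 60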

Source code in pytau/changepoint_model.py
class PoissonChangepoint1D(ChangepointModel):
    """Model for changepoint detection in 1D Poisson time series

    This model detects changepoints in 1D time series data using a Poisson likelihood.
    It assumes the data follows a Poisson distribution with different rates in different
    segments separated by changepoints.
    """

    def __init__(self, data_array, n_states, **kwargs):
        """
        Args:
            data_array (1D Numpy array): Time series data
            n_states (int): Number of states to model
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
        self.data_array = np.asarray(data_array)
        if self.data_array.ndim != 1:
            raise ValueError("data_array must be 1-dimensional")
        self.n_states = n_states

    def generate_model(self):
        """
        Returns:
            pymc model: Model class containing graph to run inference on
        """
        data_array = self.data_array
        n_states = self.n_states

        # Calculate initial lambda values by splitting data into segments
        mean_vals = np.array([
            np.mean(x) for x in np.array_split(data_array, n_states)
        ])
        mean_vals += 0.01  # To avoid zero starting prob

        idx = np.arange(len(data_array))
        length = len(data_array)

        with pm.Model() as model:
            # Lambda parameters for each state (Poisson rates)
            lambda_latent = pm.Exponential(
                "lambda", 1 / mean_vals, shape=n_states
            )

            # Changepoint locations
            a_tau = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
            b_tau = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)

            # Initialize changepoints evenly across the time series
            even_switches = np.linspace(0, 1, n_states + 1)[1:-1]
            tau_latent = pm.Beta(
                "tau_latent",
                a_tau,
                b_tau,
                initval=even_switches,
                shape=(n_states - 1)
            ).sort(axis=-1)

            # Convert to actual time indices
            tau = pm.Deterministic(
                "tau", idx.min() + (idx.max() - idx.min()) * tau_latent
            )

            # Create weight matrix for smooth transitions between states
            weight_stack = tt.math.sigmoid(
                idx[np.newaxis, :] - tau[:, np.newaxis]
            )
            weight_stack = tt.concatenate(
                [np.ones((1, length)), weight_stack], axis=0
            )
            inverse_stack = 1 - weight_stack[1:]
            inverse_stack = tt.concatenate(
                [inverse_stack, np.ones((1, length))], axis=0
            )
            weight_stack = weight_stack * inverse_stack

            # Calculate time-varying lambda
            lambda_t = lambda_latent.dot(weight_stack)

            # Observation model
            observation = pm.Poisson("obs", lambda_t, observed=data_array)

        return model

    def test(self):
        """Test the model with synthetic data"""
        # Generate test data - 1D array with 100 time points
        test_data = gen_test_array(100, n_states=self.n_states, type="poisson")

        # Create model with test data
        test_model = PoissonChangepoint1D(test_data, self.n_states)
        model = test_model.generate_model()

        # Run a minimal inference to verify model works
        with model:
            # Just do a few iterations to test functionality
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)

        # Check if expected variables are in the trace
        assert "lambda" in trace.varnames
        assert "tau" in trace.varnames

        print("Test for PoissonChangepoint1D passed")
        return True

__init__(data_array, n_states, **kwargs)

Parameters:

    data_array (1D Numpy array): Time series data. Required.
    n_states (int): Number of states to model. Required.
    **kwargs: Additional arguments. Default: {}.
Source code in pytau/changepoint_model.py
def __init__(self, data_array, n_states, **kwargs):
    """
    Args:
        data_array (1D Numpy array): Time series data
        n_states (int): Number of states to model
        **kwargs: Additional arguments
    """
    super().__init__(**kwargs)
    self.data_array = np.asarray(data_array)
    if self.data_array.ndim != 1:
        raise ValueError("data_array must be 1-dimensional")
    self.n_states = n_states

generate_model()

Returns:

    pymc model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def generate_model(self):
    """
    Returns:
        pymc model: Model class containing graph to run inference on
    """
    data_array = self.data_array
    n_states = self.n_states

    # Calculate initial lambda values by splitting data into segments
    mean_vals = np.array([
        np.mean(x) for x in np.array_split(data_array, n_states)
    ])
    mean_vals += 0.01  # To avoid zero starting prob

    idx = np.arange(len(data_array))
    length = len(data_array)

    with pm.Model() as model:
        # Lambda parameters for each state (Poisson rates)
        lambda_latent = pm.Exponential(
            "lambda", 1 / mean_vals, shape=n_states
        )

        # Changepoint locations
        a_tau = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
        b_tau = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)

        # Initialize changepoints evenly across the time series
        even_switches = np.linspace(0, 1, n_states + 1)[1:-1]
        tau_latent = pm.Beta(
            "tau_latent",
            a_tau,
            b_tau,
            initval=even_switches,
            shape=(n_states - 1)
        ).sort(axis=-1)

        # Convert to actual time indices
        tau = pm.Deterministic(
            "tau", idx.min() + (idx.max() - idx.min()) * tau_latent
        )

        # Create weight matrix for smooth transitions between states
        weight_stack = tt.math.sigmoid(
            idx[np.newaxis, :] - tau[:, np.newaxis]
        )
        weight_stack = tt.concatenate(
            [np.ones((1, length)), weight_stack], axis=0
        )
        inverse_stack = 1 - weight_stack[1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((1, length))], axis=0
        )
        weight_stack = weight_stack * inverse_stack

        # Calculate time-varying lambda
        lambda_t = lambda_latent.dot(weight_stack)

        # Observation model
        observation = pm.Poisson("obs", lambda_t, observed=data_array)

    return model

test()

Test the model with synthetic data

Source code in pytau/changepoint_model.py
def test(self):
    """Test the model with synthetic data"""
    # Generate test data - 1D array with 100 time points
    test_data = gen_test_array(100, n_states=self.n_states, type="poisson")

    # Create model with test data
    test_model = PoissonChangepoint1D(test_data, self.n_states)
    model = test_model.generate_model()

    # Run a minimal inference to verify model works
    with model:
        # Just do a few iterations to test functionality
        inference = pm.ADVI()
        approx = pm.fit(n=10, method=inference)
        trace = approx.sample(draws=10)

    # Check if expected variables are in the trace
    assert "lambda" in trace.varnames
    assert "tau" in trace.varnames

    print("Test for PoissonChangepoint1D passed")
    return True

SingleTastePoisson

Bases: ChangepointModel

Model for changepoint detection on a single taste

Largely taken from "non_hardcoded_changepoint_test_3d.ipynb". Note: this model does not have a hierarchical structure for emissions.

Source code in pytau/changepoint_model.py
class SingleTastePoisson(ChangepointModel):
    """Model for changepoint on single taste

    ** Largely taken from "non_hardcoded_changepoint_test_3d.ipynb"
    ** Note : This model does not have hierarchical structure for emissions
    """

    def __init__(self, data_array, n_states, **kwargs):
        """
        Args:
            data_array (3D Numpy array): trials x neurons x time
            n_states (int): Number of states to model
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
        self.data_array = data_array
        self.n_states = n_states

    def generate_model(self):
        """
        Returns:
            pymc model: Model class containing graph to run inference on
        """
        data_array = self.data_array
        n_states = self.n_states

        mean_vals = np.array(
            [np.mean(x, axis=-1)
             for x in np.array_split(data_array, n_states, axis=-1)]
        ).T
        mean_vals = np.mean(mean_vals, axis=1)
        mean_vals += 0.01  # To avoid zero starting prob

        nrns = data_array.shape[1]
        trials = data_array.shape[0]
        idx = np.arange(data_array.shape[-1])
        length = idx.max() + 1

        with pm.Model() as model:
            lambda_latent = pm.Exponential(
                "lambda", 1 / mean_vals, shape=(nrns, n_states))

            a_tau = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
            b_tau = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)

            even_switches = np.linspace(0, 1, n_states + 1)[1:-1]
            tau_latent = pm.Beta(
                "tau_latent",
                a_tau,
                b_tau,
                # initval=even_switches,
                shape=(trials, n_states - 1),
            ).sort(axis=-1)

            tau = pm.Deterministic(
                "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

            weight_stack = tt.math.sigmoid(
                idx[np.newaxis, :] - tau[:, :, np.newaxis])
            weight_stack = tt.concatenate(
                [np.ones((trials, 1, length)), weight_stack], axis=1)
            inverse_stack = 1 - weight_stack[:, 1:]
            inverse_stack = tt.concatenate(
                [inverse_stack, np.ones((trials, 1, length))], axis=1)
            weight_stack = np.multiply(weight_stack, inverse_stack)

            lambda_ = tt.tensordot(weight_stack, lambda_latent, [
                                   1, 1]).swapaxes(1, 2)
            observation = pm.Poisson("obs", lambda_, observed=data_array)

        return model

    def test(self):
        """Test the model with synthetic data"""
        # Generate test data
        test_data = gen_test_array(
            (5, 10, 100), n_states=self.n_states, type="poisson")

        # Create model with test data
        test_model = SingleTastePoisson(test_data, self.n_states)
        model = test_model.generate_model()

        # Run a minimal inference to verify model works
        with model:
            # Just do a few iterations to test functionality
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)

        # Check if expected variables are in the trace
        assert "lambda" in trace.varnames
        assert "tau" in trace.varnames

        print("Test for SingleTastePoisson passed")
        return True

__init__(data_array, n_states, **kwargs)

Parameters:

    data_array (3D Numpy array): trials x neurons x time. Required.
    n_states (int): Number of states to model. Required.
    **kwargs: Additional arguments. Default: {}.
Source code in pytau/changepoint_model.py
def __init__(self, data_array, n_states, **kwargs):
    """
    Args:
        data_array (3D Numpy array): trials x neurons x time
        n_states (int): Number of states to model
        **kwargs: Additional arguments
    """
    super().__init__(**kwargs)
    self.data_array = data_array
    self.n_states = n_states

generate_model()

Returns:

    pymc model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def generate_model(self):
    """
    Returns:
        pymc model: Model class containing graph to run inference on
    """
    data_array = self.data_array
    n_states = self.n_states

    mean_vals = np.array(
        [np.mean(x, axis=-1)
         for x in np.array_split(data_array, n_states, axis=-1)]
    ).T
    mean_vals = np.mean(mean_vals, axis=1)
    mean_vals += 0.01  # To avoid zero starting prob

    nrns = data_array.shape[1]
    trials = data_array.shape[0]
    idx = np.arange(data_array.shape[-1])
    length = idx.max() + 1

    with pm.Model() as model:
        lambda_latent = pm.Exponential(
            "lambda", 1 / mean_vals, shape=(nrns, n_states))

        a_tau = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
        b_tau = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)

        even_switches = np.linspace(0, 1, n_states + 1)[1:-1]
        tau_latent = pm.Beta(
            "tau_latent",
            a_tau,
            b_tau,
            # initval=even_switches,
            shape=(trials, n_states - 1),
        ).sort(axis=-1)

        tau = pm.Deterministic(
            "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

        weight_stack = tt.math.sigmoid(
            idx[np.newaxis, :] - tau[:, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((trials, 1, length)), weight_stack], axis=1)
        inverse_stack = 1 - weight_stack[:, 1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((trials, 1, length))], axis=1)
        weight_stack = np.multiply(weight_stack, inverse_stack)

        lambda_ = tt.tensordot(weight_stack, lambda_latent, [
                               1, 1]).swapaxes(1, 2)
        observation = pm.Poisson("obs", lambda_, observed=data_array)

    return model

test()

Test the model with synthetic data

Source code in pytau/changepoint_model.py
def test(self):
    """Test the model with synthetic data"""
    # Generate test data
    test_data = gen_test_array(
        (5, 10, 100), n_states=self.n_states, type="poisson")

    # Create model with test data
    test_model = SingleTastePoisson(test_data, self.n_states)
    model = test_model.generate_model()

    # Run a minimal inference to verify model works
    with model:
        # Just do a few iterations to test functionality
        inference = pm.ADVI()
        approx = pm.fit(n=10, method=inference)
        trace = approx.sample(draws=10)

    # Check if expected variables are in the trace
    assert "lambda" in trace.varnames
    assert "tau" in trace.varnames

    print("Test for SingleTastePoisson passed")
    return True
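
A hypothetical shape-oriented sketch for SingleTastePoisson. Per the model above, tau is drawn per trial while the emission rates are shared across trials, which shows up directly in the trace shapes (MultiTrace interface assumed, as in test()):

import numpy as np
import pymc as pm  # same import convention as the module

from pytau.changepoint_model import SingleTastePoisson

spikes = np.random.poisson(3, size=(5, 10, 100))  # trials x neurons x time

model = SingleTastePoisson(spikes, n_states=4).generate_model()
with model:
    approx = pm.fit(n=100, method=pm.ADVI())
    trace = approx.sample(draws=50)

print(trace["tau"].shape)     # (50, 5, 3): draws x trials x changepoints
print(trace["lambda"].shape)  # (50, 10, 4): draws x neurons x states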

SingleTastePoissonDirichlet

Bases: ChangepointModel

Model for changepoint detection on a single taste, using a Dirichlet process prior

Source code in pytau/changepoint_model.py
class SingleTastePoissonDirichlet(ChangepointModel):
    """
    Model for changepoint on single taste using dirichlet process prior
    """

    def __init__(self, data_array, max_states=10, **kwargs):
        """
        Args:
            data_array (3D Numpy array): trials x neurons x time
            max_states (int): Maximum number of states to model
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
        self.data_array = data_array
        self.max_states = max_states

    def generate_model(self):
        """
        Returns:
            pymc model: Model class containing graph to run inference on
        """
        data_array = self.data_array
        max_states = self.max_states

        mean_vals = np.array(
            [np.mean(x, axis=-1)
             for x in np.array_split(data_array, max_states, axis=-1)]
        ).T
        mean_vals = np.mean(mean_vals, axis=1)
        mean_vals += 0.01  # To avoid zero starting prob

        nrns = data_array.shape[1]
        trials = data_array.shape[0]
        idx = np.arange(data_array.shape[-1])
        length = idx.max() + 1

        with pm.Model() as model:
            # ===================
            # Emissions Variables
            # ===================
            lambda_latent = pm.Exponential(
                "lambda", 1 / mean_vals, shape=(nrns, max_states))

            # =====================
            # Changepoint Variables
            # =====================

            # Hyperpriors on alpha
            a_gamma = pm.Gamma("a_gamma", 10, 1)
            b_gamma = pm.Gamma("b_gamma", 1.5, 1)

            # Concentration parameter for beta
            alpha = pm.Gamma("alpha", a_gamma, b_gamma)

            # Draw beta's to calculate stick lengths
            beta = pm.Beta("beta", 1, alpha, shape=(trials, max_states))

            # Calculate stick lengths using stick_breaking process
            w_raw = pm.Deterministic(
                "w_raw", stick_breaking_trial(beta, trials))

            # Make sure lengths add to 1, and scale to length of data
            w_latent = pm.Deterministic(
                "w_latent", w_raw / w_raw.sum(axis=-1)[:, None])
            tau = pm.Deterministic("tau", tt.cumsum(
                w_latent * length, axis=-1)[:, :-1])

            # =====================
            # Rate over time
            # =====================

            # Weight stack to assign lambda's to point in time
            weight_stack = tt.math.sigmoid(
                idx[np.newaxis, :] - tau[:, :, np.newaxis])
            weight_stack = tt.concatenate(
                [np.ones((trials, 1, length)), weight_stack], axis=1)
            inverse_stack = 1 - weight_stack[:, 1:]
            inverse_stack = tt.concatenate(
                [inverse_stack, np.ones((trials, 1, length))], axis=1)
            # Trials x States x Time
            weight_stack = np.multiply(weight_stack, inverse_stack)

            lambda_ = pm.Deterministic(
                "lambda_",
                tt.tensordot(weight_stack, lambda_latent,
                             [1, 1]).swapaxes(1, 2),
            )

            # =====================
            # Likelihood
            # =====================
            observation = pm.Poisson("obs", lambda_, observed=data_array)

        return model

    def test(self):
        """Test the model with synthetic data"""
        # Generate test data
        test_data = gen_test_array((5, 10, 100), n_states=3, type="poisson")

        # Create model with test data
        test_model = SingleTastePoissonDirichlet(test_data, max_states=5)
        model = test_model.generate_model()

        # Run a minimal inference to verify model works
        with model:
            # Just do a few iterations to test functionality
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)

        # Check if expected variables are in the trace
        assert "lambda" in trace.varnames
        assert "tau" in trace.varnames
        assert "w_latent" in trace.varnames

        print("Test for SingleTastePoissonDirichlet passed")
        return True

__init__(data_array, max_states=10, **kwargs)

Parameters:

    data_array (3D Numpy array): trials x neurons x time. Required.
    max_states (int): Maximum number of states to model. Default: 10.
    **kwargs: Additional arguments. Default: {}.
Source code in pytau/changepoint_model.py
def __init__(self, data_array, max_states=10, **kwargs):
    """
    Args:
        data_array (3D Numpy array): trials x neurons x time
        max_states (int): Maximum number of states to model
        **kwargs: Additional arguments
    """
    super().__init__(**kwargs)
    self.data_array = data_array
    self.max_states = max_states

generate_model()

Returns:

    pymc model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def generate_model(self):
    """
    Returns:
        pymc model: Model class containing graph to run inference on
    """
    data_array = self.data_array
    max_states = self.max_states

    mean_vals = np.array(
        [np.mean(x, axis=-1)
         for x in np.array_split(data_array, max_states, axis=-1)]
    ).T
    mean_vals = np.mean(mean_vals, axis=1)
    mean_vals += 0.01  # To avoid zero starting prob

    nrns = data_array.shape[1]
    trials = data_array.shape[0]
    idx = np.arange(data_array.shape[-1])
    length = idx.max() + 1

    with pm.Model() as model:
        # ===================
        # Emissions Variables
        # ===================
        lambda_latent = pm.Exponential(
            "lambda", 1 / mean_vals, shape=(nrns, max_states))

        # =====================
        # Changepoint Variables
        # =====================

        # Hyperpriors on alpha
        a_gamma = pm.Gamma("a_gamma", 10, 1)
        b_gamma = pm.Gamma("b_gamma", 1.5, 1)

        # Concentration parameter for beta
        alpha = pm.Gamma("alpha", a_gamma, b_gamma)

        # Draw beta's to calculate stick lengths
        beta = pm.Beta("beta", 1, alpha, shape=(trials, max_states))

        # Calculate stick lengths using stick_breaking process
        w_raw = pm.Deterministic(
            "w_raw", stick_breaking_trial(beta, trials))

        # Make sure lengths add to 1, and scale to length of data
        w_latent = pm.Deterministic(
            "w_latent", w_raw / w_raw.sum(axis=-1)[:, None])
        tau = pm.Deterministic("tau", tt.cumsum(
            w_latent * length, axis=-1)[:, :-1])

        # =====================
        # Rate over time
        # =====================

        # Weight stack to assign lambda's to point in time
        weight_stack = tt.math.sigmoid(
            idx[np.newaxis, :] - tau[:, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((trials, 1, length)), weight_stack], axis=1)
        inverse_stack = 1 - weight_stack[:, 1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((trials, 1, length))], axis=1)
        # Trials x States x Time
        weight_stack = np.multiply(weight_stack, inverse_stack)

        lambda_ = pm.Deterministic(
            "lambda_",
            tt.tensordot(weight_stack, lambda_latent,
                         [1, 1]).swapaxes(1, 2),
        )

        # =====================
        # Likelihood
        # =====================
        observation = pm.Poisson("obs", lambda_, observed=data_array)

    return model

test()

Test the model with synthetic data

Source code in pytau/changepoint_model.py
def test(self):
    """Test the model with synthetic data"""
    # Generate test data
    test_data = gen_test_array((5, 10, 100), n_states=3, type="poisson")

    # Create model with test data
    test_model = SingleTastePoissonDirichlet(test_data, max_states=5)
    model = test_model.generate_model()

    # Run a minimal inference to verify model works
    with model:
        # Just do a few iterations to test functionality
        inference = pm.ADVI()
        approx = pm.fit(n=10, method=inference)
        trace = approx.sample(draws=10)

    # Check if expected variables are in the trace
    assert "lambda" in trace.varnames
    assert "tau" in trace.varnames
    assert "w_latent" in trace.varnames

    print("Test for SingleTastePoissonDirichlet passed")
    return True
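
The following is a minimal usage sketch (illustrative, not part of the library source): it builds this model on synthetic data from gen_test_array and fits it with the advi_fit helper documented later on this page; the iteration and sample counts are arbitrary.

from pytau.changepoint_model import (
    SingleTastePoissonDirichlet, advi_fit, gen_test_array)

# Synthetic trials x neurons x time spike counts with 3 latent states
spikes = gen_test_array((5, 10, 100), n_states=3, type="poisson")

# max_states caps the stick-breaking prior; surplus states get near-zero weight
model = SingleTastePoissonDirichlet(spikes, max_states=8).generate_model()
model, approx, lambda_stack, tau_samples, _ = advi_fit(
    model, fit=1000, samples=200)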

SingleTastePoissonTrialSwitch

Bases: ChangepointModel

Assumes that only emissions change across trials; the changepoint distribution remains constant

Source code in pytau/changepoint_model.py
class SingleTastePoissonTrialSwitch(ChangepointModel):
    """
    Assuming only emissions change across trials
    Changepoint distribution remains constant
    """

    def __init__(self, data_array, switch_components, n_states, **kwargs):
        """
        Args:
            data_array (3D Numpy array): trials x neurons x time
            switch_components (int): Number of trial switch components
            n_states (int): Number of states to model
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
        self.data_array = data_array
        self.switch_components = switch_components
        self.n_states = n_states

    def generate_model(self):
        """
        Returns:
            pymc model: Model class containing graph to run inference on
        """
        data_array = self.data_array
        switch_components = self.switch_components
        n_states = self.n_states

        trial_num, nrn_num, time_bins = data_array.shape

        with pm.Model() as model:
            # Define Emissions

            # nrns
            nrn_lambda = pm.Exponential("nrn_lambda", 10, shape=(nrn_num))

            # nrns x switch_comps
            trial_lambda = pm.Exponential(
                "trial_lambda",
                nrn_lambda.dimshuffle(0, "x"),
                shape=(nrn_num, switch_components),
            )

            # nrns x switch_comps x n_states
            state_lambda = pm.Exponential(
                "state_lambda",
                trial_lambda.dimshuffle(0, 1, "x"),
                shape=(nrn_num, switch_components, n_states),
            )

            # Define Changepoints
            # Assuming distribution of changepoints remains
            # the same across all trials

            a = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
            b = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)

            even_switches = np.linspace(0, 1, n_states + 1)[1:-1]
            tau_latent = pm.Beta(
                "tau_latent", a, b,
                # initval=even_switches,
                shape=(trial_num, n_states - 1)
            ).sort(axis=-1)

            # Trials x Changepoints
            tau = pm.Deterministic("tau", time_bins * tau_latent)

            # Define trial switches
            # Will have same structure as regular changepoints

            even_trial_switches = np.linspace(
                0, 1, switch_components + 1)[1:-1]
            tau_trial_latent = pm.Beta(
                "tau_trial_latent",
                1,
                1,
                initval=even_trial_switches,
                shape=(switch_components - 1),
            ).sort(axis=-1)

            # Trial_changepoints
            tau_trial = pm.Deterministic(
                "tau_trial", trial_num * tau_trial_latent)

            trial_idx = np.arange(trial_num)
            trial_selector = tt.math.sigmoid(
                trial_idx[np.newaxis, :] - tau_trial.dimshuffle(0, "x")
            )

            trial_selector = tt.concatenate(
                [np.ones((1, trial_num)), trial_selector], axis=0)
            inverse_trial_selector = 1 - trial_selector[1:, :]
            inverse_trial_selector = tt.concatenate(
                [inverse_trial_selector, np.ones((1, trial_num))], axis=0
            )

            # First, we can "select" sets of emissions depending on trial_changepoints
            # switch_comps x trials
            trial_selector = np.multiply(
                trial_selector, inverse_trial_selector)

            # state_lambda: nrns x switch_comps x states

            # selected_trial_lambda : nrns x states x trials
            selected_trial_lambda = pm.Deterministic(
                "selected_trial_lambda",
                tt.sum(
                    # "nrns" x switch_comps x "states" x trials
                    trial_selector.dimshuffle("x", 0, "x", 1)
                    * state_lambda.dimshuffle(0, 1, 2, "x"),
                    axis=1,
                ),
            )

            # Then, we can select state_emissions for every trial
            idx = np.arange(time_bins)

            # tau : Trials x Changepoints
            weight_stack = tt.math.sigmoid(
                idx[np.newaxis, :] - tau[:, :, np.newaxis])
            weight_stack = tt.concatenate(
                [np.ones((trial_num, 1, time_bins)), weight_stack], axis=1
            )
            inverse_stack = 1 - weight_stack[:, 1:]
            inverse_stack = tt.concatenate(
                [inverse_stack, np.ones((trial_num, 1, time_bins))], axis=1
            )

            # Trials x states x Time
            weight_stack = np.multiply(weight_stack, inverse_stack)

            # Convert selected_trial_lambda : nrns x trials x states x "time"

            # nrns x trials x time
            lambda_ = tt.sum(
                selected_trial_lambda.dimshuffle(0, 2, 1, "x")
                * weight_stack.dimshuffle("x", 0, 1, 2),
                axis=2,
            )

            # Convert to : trials x nrns x time
            lambda_ = lambda_.dimshuffle(1, 0, 2)

            # Add observations
            observation = pm.Poisson("obs", lambda_, observed=data_array)

        return model

    def test(self):
        """Test the model with synthetic data"""
        # Generate test data
        test_data = gen_test_array(
            (5, 10, 100), n_states=self.n_states, type="poisson")

        # Create model with test data
        test_model = SingleTastePoissonTrialSwitch(
            test_data, self.switch_components, self.n_states)
        model = test_model.generate_model()

        # Run a minimal inference to verify model works
        with model:
            # Just do a few iterations to test functionality
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)

        # Check if expected variables are in the trace
        assert "nrn_lambda" in trace.varnames
        assert "tau" in trace.varnames
        assert "tau_trial" in trace.varnames
        assert "state_lambda" in trace.varnames

        print("Test for SingleTastePoissonTrialSwitch passed")
        return True
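
A hypothetical end-to-end sketch for this class (fit sizes are illustrative). Note that this model exposes state_lambda rather than a variable named "lambda", so advi_fit returns None in the emissions slot.

from pytau.changepoint_model import (
    SingleTastePoissonTrialSwitch, advi_fit, gen_test_array)

spikes = gen_test_array((30, 10, 100), n_states=3, type="poisson")
model = SingleTastePoissonTrialSwitch(
    spikes, switch_components=2, n_states=3).generate_model()

# No "lambda" variable in this model's posterior, so the third return is None
model, approx, _, tau_samples, _ = advi_fit(model, fit=1000, samples=200)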

__init__(data_array, switch_components, n_states, **kwargs)

Parameters:

Name               Type            Description                        Default
data_array         3D Numpy array  trials x neurons x time            required
switch_components  int             Number of trial switch components  required
n_states           int             Number of states to model          required
**kwargs           -               Additional arguments               {}

Source code in pytau/changepoint_model.py
def __init__(self, data_array, switch_components, n_states, **kwargs):
    """
    Args:
        data_array (3D Numpy array): trials x neurons x time
        switch_components (int): Number of trial switch components
        n_states (int): Number of states to model
        **kwargs: Additional arguments
    """
    super().__init__(**kwargs)
    self.data_array = data_array
    self.switch_components = switch_components
    self.n_states = n_states

generate_model()

Returns:

pymc model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def generate_model(self):
    """
    Returns:
        pymc model: Model class containing graph to run inference on
    """
    data_array = self.data_array
    switch_components = self.switch_components
    n_states = self.n_states

    trial_num, nrn_num, time_bins = data_array.shape

    with pm.Model() as model:
        # Define Emissions

        # nrns
        nrn_lambda = pm.Exponential("nrn_lambda", 10, shape=(nrn_num))

        # nrns x switch_comps
        trial_lambda = pm.Exponential(
            "trial_lambda",
            nrn_lambda.dimshuffle(0, "x"),
            shape=(nrn_num, switch_components),
        )

        # nrns x switch_comps x n_states
        state_lambda = pm.Exponential(
            "state_lambda",
            trial_lambda.dimshuffle(0, 1, "x"),
            shape=(nrn_num, switch_components, n_states),
        )

        # Define Changepoints
        # Assuming distribution of changepoints remains
        # the same across all trials

        a = pm.HalfCauchy("a_tau", 3.0, shape=n_states - 1)
        b = pm.HalfCauchy("b_tau", 3.0, shape=n_states - 1)

        even_switches = np.linspace(0, 1, n_states + 1)[1:-1]
        tau_latent = pm.Beta(
            "tau_latent", a, b,
            # initval=even_switches,
            shape=(trial_num, n_states - 1)
        ).sort(axis=-1)

        # Trials x Changepoints
        tau = pm.Deterministic("tau", time_bins * tau_latent)

        # Define trial switches
        # Will have same structure as regular changepoints

        even_trial_switches = np.linspace(
            0, 1, switch_components + 1)[1:-1]
        tau_trial_latent = pm.Beta(
            "tau_trial_latent",
            1,
            1,
            initval=even_trial_switches,
            shape=(switch_components - 1),
        ).sort(axis=-1)

        # Trial_changepoints
        tau_trial = pm.Deterministic(
            "tau_trial", trial_num * tau_trial_latent)

        trial_idx = np.arange(trial_num)
        trial_selector = tt.math.sigmoid(
            trial_idx[np.newaxis, :] - tau_trial.dimshuffle(0, "x")
        )

        trial_selector = tt.concatenate(
            [np.ones((1, trial_num)), trial_selector], axis=0)
        inverse_trial_selector = 1 - trial_selector[1:, :]
        inverse_trial_selector = tt.concatenate(
            [inverse_trial_selector, np.ones((1, trial_num))], axis=0
        )

        # First, we can "select" sets of emissions depending on trial_changepoints
        # switch_comps x trials
        trial_selector = np.multiply(
            trial_selector, inverse_trial_selector)

        # state_lambda: nrns x switch_comps x states

        # selected_trial_lambda : nrns x states x trials
        selected_trial_lambda = pm.Deterministic(
            "selected_trial_lambda",
            tt.sum(
                # "nrns" x switch_comps x "states" x trials
                trial_selector.dimshuffle("x", 0, "x", 1)
                * state_lambda.dimshuffle(0, 1, 2, "x"),
                axis=1,
            ),
        )

        # Then, we can select state_emissions for every trial
        idx = np.arange(time_bins)

        # tau : Trials x Changepoints
        weight_stack = tt.math.sigmoid(
            idx[np.newaxis, :] - tau[:, :, np.newaxis])
        weight_stack = tt.concatenate(
            [np.ones((trial_num, 1, time_bins)), weight_stack], axis=1
        )
        inverse_stack = 1 - weight_stack[:, 1:]
        inverse_stack = tt.concatenate(
            [inverse_stack, np.ones((trial_num, 1, time_bins))], axis=1
        )

        # Trials x states x Time
        weight_stack = np.multiply(weight_stack, inverse_stack)

        # Convert selected_trial_lambda : nrns x trials x states x "time"

        # nrns x trials x time
        lambda_ = tt.sum(
            selected_trial_lambda.dimshuffle(0, 2, 1, "x")
            * weight_stack.dimshuffle("x", 0, 1, 2),
            axis=2,
        )

        # Convert to : trials x nrns x time
        lambda_ = lambda_.dimshuffle(1, 0, 2)

        # Add observations
        observation = pm.Poisson("obs", lambda_, observed=data_array)

    return model

test()

Test the model with synthetic data

Source code in pytau/changepoint_model.py
def test(self):
    """Test the model with synthetic data"""
    # Generate test data
    test_data = gen_test_array(
        (5, 10, 100), n_states=self.n_states, type="poisson")

    # Create model with test data
    test_model = SingleTastePoissonTrialSwitch(
        test_data, self.switch_components, self.n_states)
    model = test_model.generate_model()

    # Run a minimal inference to verify model works
    with model:
        # Just do a few iterations to test functionality
        inference = pm.ADVI()
        approx = pm.fit(n=10, method=inference)
        trace = approx.sample(draws=10)

    # Check if expected variables are in the trace
    assert "nrn_lambda" in trace.varnames
    assert "tau" in trace.varnames
    assert "tau_trial" in trace.varnames
    assert "state_lambda" in trace.varnames

    print("Test for SingleTastePoissonTrialSwitch passed")
    return True

SingleTastePoissonVarsig

Bases: ChangepointModel

Model for changepoint on single taste. Uses a variable sigmoid slope inferred from the data.

Largely taken from "non_hardcoded_changepoint_test_3d.ipynb". Note: this model does not have a hierarchical structure for emissions.

Source code in pytau/changepoint_model.py
class SingleTastePoissonVarsig(ChangepointModel):
    """Model for changepoint on single taste
    **Uses variables sigmoid slope inferred from data

    ** Largely taken from "non_hardcoded_changepoint_test_3d.ipynb"
    ** Note : This model does not have hierarchical structure for emissions
    """

    def __init__(self, data_array, n_states, **kwargs):
        """
        Args:
            data_array (3D Numpy array): trials x neurons x time
            n_states (int): Number of states to model
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
        self.data_array = data_array
        self.n_states = n_states

    def generate_model(self):
        """
        Returns:
            pymc model: Model class containing graph to run inference on
        """
        data_array = self.data_array
        n_states = self.n_states

        mean_vals = np.array(
            [np.mean(x, axis=-1)
             for x in np.array_split(data_array, n_states, axis=-1)]
        ).T
        mean_vals = np.mean(mean_vals, axis=1)
        mean_vals += 0.01  # To avoid zero starting prob

        lambda_test_vals = np.diff(mean_vals, axis=-1)
        even_switches = np.linspace(0, 1, n_states + 1)[1:-1]

        nrns = data_array.shape[1]
        trials = data_array.shape[0]
        idx = np.arange(data_array.shape[-1])
        length = idx.max() + 1

        with pm.Model() as model:
            # Sigmoid slope
            sig_b = pm.Normal("sig_b", -1, 2, shape=n_states - 1)

            # Initial value
            s0 = pm.Exponential(
                "state0", 1 / (np.mean(mean_vals)), shape=nrns, initval=mean_vals[:, 0]
            )

            # Changes to lambda
            lambda_diff = pm.Normal(
                "lambda_diff",
                mu=0,
                sigma=10,
                shape=(nrns, n_states - 1),
                initval=lambda_test_vals,
            )

            # This is only here to be extracted at the end of sampling
            # NOT USED DIRECTLY IN MODEL
            lambda_fin = pm.Deterministic(
                "lambda", tt.concatenate(
                    [s0[:, np.newaxis], lambda_diff], axis=-1)
            )

            # Changepoint positions
            a = pm.HalfCauchy("a_tau", 10, shape=n_states - 1)
            b = pm.HalfCauchy("b_tau", 10, shape=n_states - 1)

            tau_latent = pm.Beta(
                "tau_latent", a, b,
                # initval=even_switches,
                shape=(trials, n_states - 1)
            ).sort(axis=-1)
            tau = pm.Deterministic(
                "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

            # Mechanical manipulations to generate firing rates
            idx_temp = np.tile(
                idx[np.newaxis, np.newaxis, :], (trials, n_states - 1, 1))
            tau_temp = tt.tile(tau[:, :, np.newaxis], (1, 1, len(idx)))
            sig_b_temp = tt.tile(
                sig_b[np.newaxis, :, np.newaxis], (trials, 1, len(idx)))

            weight_stack = var_sig_exp_tt(idx_temp - tau_temp, sig_b_temp)
            weight_stack_temp = tt.tile(
                weight_stack[:, np.newaxis, :, :], (1, nrns, 1, 1))

            s0_temp = tt.tile(
                s0[np.newaxis, :, np.newaxis, np.newaxis],
                (trials, 1, n_states - 1, len(idx)),
            )
            lambda_diff_temp = tt.tile(
                lambda_diff[np.newaxis, :, :,
                            np.newaxis], (trials, 1, 1, len(idx))
            )

            # Calculate lambda
            lambda_ = pm.Deterministic(
                "lambda_",
                tt.sum(s0_temp + (weight_stack_temp * lambda_diff_temp), axis=2),
            )
            # Bound lambda to prevent the diffs from making it negative
            # Don't let it go down to zero otherwise we have trouble with probabilities
            lambda_bounded = pm.Deterministic(
                "lambda_bounded", tt.switch(lambda_ >= 0.01, lambda_, 0.01)
            )

            # Add observations
            observation = pm.Poisson(
                "obs", lambda_bounded, observed=data_array)

        return model

    def test(self):
        """Test the model with synthetic data"""
        # Generate test data
        test_data = gen_test_array(
            (5, 10, 100), n_states=self.n_states, type="poisson")

        # Create model with test data
        test_model = SingleTastePoissonVarsig(test_data, self.n_states)
        model = test_model.generate_model()

        # Run a minimal inference to verify model works
        with model:
            # Just do a few iterations to test functionality
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)

        # Check if expected variables are in the trace
        assert "lambda" in trace.varnames
        assert "tau" in trace.varnames
        assert "sig_b" in trace.varnames

        print("Test for SingleTastePoissonVarsig passed")
        return True
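
A minimal usage sketch (sizes illustrative); the per-transition sigmoid slope sig_b is inferred jointly with the changepoints.

from pytau.changepoint_model import (
    SingleTastePoissonVarsig, advi_fit, gen_test_array)

spikes = gen_test_array((5, 10, 100), n_states=3, type="poisson")
model = SingleTastePoissonVarsig(spikes, n_states=3).generate_model()
model, approx, lambda_stack, tau_samples, _ = advi_fit(
    model, fit=1000, samples=200)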

__init__(data_array, n_states, **kwargs)

Parameters:

Name        Type            Description                Default
data_array  3D Numpy array  trials x neurons x time    required
n_states    int             Number of states to model  required
**kwargs    -               Additional arguments       {}

Source code in pytau/changepoint_model.py
def __init__(self, data_array, n_states, **kwargs):
    """
    Args:
        data_array (3D Numpy array): trials x neurons x time
        n_states (int): Number of states to model
        **kwargs: Additional arguments
    """
    super().__init__(**kwargs)
    self.data_array = data_array
    self.n_states = n_states

generate_model()

Returns:

pymc model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def generate_model(self):
    """
    Returns:
        pymc model: Model class containing graph to run inference on
    """
    data_array = self.data_array
    n_states = self.n_states

    mean_vals = np.array(
        [np.mean(x, axis=-1)
         for x in np.array_split(data_array, n_states, axis=-1)]
    ).T
    mean_vals = np.mean(mean_vals, axis=1)
    mean_vals += 0.01  # To avoid zero starting prob

    lambda_test_vals = np.diff(mean_vals, axis=-1)
    even_switches = np.linspace(0, 1, n_states + 1)[1:-1]

    nrns = data_array.shape[1]
    trials = data_array.shape[0]
    idx = np.arange(data_array.shape[-1])
    length = idx.max() + 1

    with pm.Model() as model:
        # Sigmoid slope
        sig_b = pm.Normal("sig_b", -1, 2, shape=n_states - 1)

        # Initial value
        s0 = pm.Exponential(
            "state0", 1 / (np.mean(mean_vals)), shape=nrns, initval=mean_vals[:, 0]
        )

        # Changes to lambda
        lambda_diff = pm.Normal(
            "lambda_diff",
            mu=0,
            sigma=10,
            shape=(nrns, n_states - 1),
            initval=lambda_test_vals,
        )

        # This is only here to be extracted at the end of sampling
        # NOT USED DIRECTLY IN MODEL
        lambda_fin = pm.Deterministic(
            "lambda", tt.concatenate(
                [s0[:, np.newaxis], lambda_diff], axis=-1)
        )

        # Changepoint positions
        a = pm.HalfCauchy("a_tau", 10, shape=n_states - 1)
        b = pm.HalfCauchy("b_tau", 10, shape=n_states - 1)

        tau_latent = pm.Beta(
            "tau_latent", a, b,
            # initval=even_switches,
            shape=(trials, n_states - 1)
        ).sort(axis=-1)
        tau = pm.Deterministic(
            "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

        # Mechanical manipulations to generate firing rates
        idx_temp = np.tile(
            idx[np.newaxis, np.newaxis, :], (trials, n_states - 1, 1))
        tau_temp = tt.tile(tau[:, :, np.newaxis], (1, 1, len(idx)))
        sig_b_temp = tt.tile(
            sig_b[np.newaxis, :, np.newaxis], (trials, 1, len(idx)))

        weight_stack = var_sig_exp_tt(idx_temp - tau_temp, sig_b_temp)
        weight_stack_temp = tt.tile(
            weight_stack[:, np.newaxis, :, :], (1, nrns, 1, 1))

        s0_temp = tt.tile(
            s0[np.newaxis, :, np.newaxis, np.newaxis],
            (trials, 1, n_states - 1, len(idx)),
        )
        lambda_diff_temp = tt.tile(
            lambda_diff[np.newaxis, :, :,
                        np.newaxis], (trials, 1, 1, len(idx))
        )

        # Calculate lambda
        lambda_ = pm.Deterministic(
            "lambda_",
            tt.sum(s0_temp + (weight_stack_temp * lambda_diff_temp), axis=2),
        )
        # Bound lambda to prevent the diffs from making it negative
        # Don't let it go down to zero otherwise we have trouble with probabilities
        lambda_bounded = pm.Deterministic(
            "lambda_bounded", tt.switch(lambda_ >= 0.01, lambda_, 0.01)
        )

        # Add observations
        observation = pm.Poisson(
            "obs", lambda_bounded, observed=data_array)

    return model

test()

Test the model with synthetic data

Source code in pytau/changepoint_model.py
def test(self):
    """Test the model with synthetic data"""
    # Generate test data
    test_data = gen_test_array(
        (5, 10, 100), n_states=self.n_states, type="poisson")

    # Create model with test data
    test_model = SingleTastePoissonVarsig(test_data, self.n_states)
    model = test_model.generate_model()

    # Run a minimal inference to verify model works
    with model:
        # Just do a few iterations to test functionality
        inference = pm.ADVI()
        approx = pm.fit(n=10, method=inference)
        trace = approx.sample(draws=10)

    # Check if expected variables are in the trace
    assert "lambda" in trace.varnames
    assert "tau" in trace.varnames
    assert "sig_b" in trace.varnames

    print("Test for SingleTastePoissonVarsig passed")
    return True

SingleTastePoissonVarsigFixed

Bases: ChangepointModel

Model for changepoint on single taste. Uses a sigmoid with a given (fixed) slope.

Largely taken from "non_hardcoded_changepoint_test_3d.ipynb". Note: this model does not have a hierarchical structure for emissions.

Source code in pytau/changepoint_model.py
class SingleTastePoissonVarsigFixed(ChangepointModel):
    """Model for changepoint on single taste
    **Uses sigmoid with given slope

    ** Largely taken from "non_hardcoded_changepoint_test_3d.ipynb"
    ** Note : This model does not have hierarchical structure for emissions
    """

    def __init__(self, data_array, n_states, inds_span=1, **kwargs):
        """
        Args:
            data_array (3D Numpy array): trials x neurons x time
            n_states (int): Number of states to model
            inds_span(float) : Number of indices to cover 5-95% change in sigmoid
            **kwargs: Additional arguments
        """
        super().__init__(**kwargs)
        self.data_array = data_array
        self.n_states = n_states
        self.inds_span = inds_span

    def generate_model(self):
        """
        Returns:
            pymc model: Model class containing graph to run inference on
        """
        data_array = self.data_array
        n_states = self.n_states
        inds_span = self.inds_span

        mean_vals = np.array(
            [np.mean(x, axis=-1)
             for x in np.array_split(data_array, n_states, axis=-1)]
        ).T
        mean_vals = np.mean(mean_vals, axis=1)
        mean_vals += 0.01  # To avoid zero starting prob

        lambda_test_vals = np.diff(mean_vals, axis=-1)
        even_switches = np.linspace(0, 1, n_states + 1)[1:-1]

        nrns = data_array.shape[1]
        trials = data_array.shape[0]
        idx = np.arange(data_array.shape[-1])
        length = idx.max() + 1

        # Define sigmoid with given sharpness
        sig_b = inds_to_b(inds_span)

        def sigmoid(x):
            b_temp = tt.tile(
                np.array(sig_b)[None, None, None], x.tag.test_value.shape)
            return 1 / (1 + tt.exp(-b_temp * x))

        with pm.Model() as model:
            # Initial value
            s0 = pm.Exponential(
                "state0", 1 / (np.mean(mean_vals)), shape=nrns, initval=mean_vals[:, 0]
            )

            # Changes to lambda
            lambda_diff = pm.Normal(
                "lambda_diff",
                mu=0,
                sigma=10,
                shape=(nrns, n_states - 1),
                initval=lambda_test_vals,
            )

            # This is only here to be extracted at the end of sampling
            # NOT USED DIRECTLY IN MODEL
            lambda_fin = pm.Deterministic(
                "lambda", tt.concatenate(
                    [s0[:, np.newaxis], lambda_diff], axis=-1)
            )

            # Changepoint positions
            a = pm.HalfCauchy("a_tau", 10, shape=n_states - 1)
            b = pm.HalfCauchy("b_tau", 10, shape=n_states - 1)

            tau_latent = pm.Beta(
                "tau_latent", a, b,
                # initval=even_switches,
                shape=(trials, n_states - 1)
            ).sort(axis=-1)
            tau = pm.Deterministic(
                "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

            # Mechanical manipulations to generate firing rates
            idx_temp = np.tile(
                idx[np.newaxis, np.newaxis, :], (trials, n_states - 1, 1))
            tau_temp = tt.tile(tau[:, :, np.newaxis], (1, 1, len(idx)))

            weight_stack = sigmoid(idx_temp - tau_temp)
            weight_stack_temp = tt.tile(
                weight_stack[:, np.newaxis, :, :], (1, nrns, 1, 1))

            s0_temp = tt.tile(
                s0[np.newaxis, :, np.newaxis, np.newaxis],
                (trials, 1, n_states - 1, len(idx)),
            )
            lambda_diff_temp = tt.tile(
                lambda_diff[np.newaxis, :, :,
                            np.newaxis], (trials, 1, 1, len(idx))
            )

            # Calculate lambda
            lambda_ = pm.Deterministic(
                "lambda_",
                tt.sum(s0_temp + (weight_stack_temp * lambda_diff_temp), axis=2),
            )
            # Bound lambda to prevent the diffs from making it negative
            # Don't let it go down to zero otherwise we have trouble with probabilities
            lambda_bounded = pm.Deterministic(
                "lambda_bounded", tt.switch(lambda_ >= 0.01, lambda_, 0.01)
            )

            # Add observations
            observation = pm.Poisson(
                "obs", lambda_bounded, observed=data_array)

        return model

    def test(self):
        """Test the model with synthetic data"""
        # Generate test data
        test_data = gen_test_array(
            (5, 10, 100), n_states=self.n_states, type="poisson")

        # Create model with test data
        test_model = SingleTastePoissonVarsigFixed(
            test_data, self.n_states, self.inds_span)
        model = test_model.generate_model()

        # Run a minimal inference to verify model works
        with model:
            # Just do a few iterations to test functionality
            inference = pm.ADVI()
            approx = pm.fit(n=10, method=inference)
            trace = approx.sample(draws=10)

        # Check if expected variables are in the trace
        assert "lambda" in trace.varnames
        assert "tau" in trace.varnames
        assert "state0" in trace.varnames

        print("Test for SingleTastePoissonVarsigFixed passed")
        return True
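
The fixed-slope variant is used the same way, except transition sharpness is set up front through inds_span rather than inferred; a sketch with illustrative sizes.

from pytau.changepoint_model import (
    SingleTastePoissonVarsigFixed, advi_fit, gen_test_array)

spikes = gen_test_array((5, 10, 100), n_states=3, type="poisson")

# inds_span: time bins over which each sigmoid spans its 5-95% change
model = SingleTastePoissonVarsigFixed(
    spikes, n_states=3, inds_span=1).generate_model()
model, approx, lambda_stack, tau_samples, _ = advi_fit(
    model, fit=1000, samples=200)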

__init__(data_array, n_states, inds_span=1, **kwargs)

Parameters:

Name        Type            Description                                          Default
data_array  3D Numpy array  trials x neurons x time                              required
n_states    int             Number of states to model                            required
inds_span   float           Number of indices to cover 5-95% change in sigmoid   1
**kwargs    -               Additional arguments                                 {}

Source code in pytau/changepoint_model.py
def __init__(self, data_array, n_states, inds_span=1, **kwargs):
    """
    Args:
        data_array (3D Numpy array): trials x neurons x time
        n_states (int): Number of states to model
        inds_span(float) : Number of indices to cover 5-95% change in sigmoid
        **kwargs: Additional arguments
    """
    super().__init__(**kwargs)
    self.data_array = data_array
    self.n_states = n_states
    self.inds_span = inds_span

generate_model()

Returns:

pymc model: Model class containing graph to run inference on

Source code in pytau/changepoint_model.py
def generate_model(self):
    """
    Returns:
        pymc model: Model class containing graph to run inference on
    """
    data_array = self.data_array
    n_states = self.n_states
    inds_span = self.inds_span

    mean_vals = np.array(
        [np.mean(x, axis=-1)
         for x in np.array_split(data_array, n_states, axis=-1)]
    ).T
    mean_vals = np.mean(mean_vals, axis=1)
    mean_vals += 0.01  # To avoid zero starting prob

    lambda_test_vals = np.diff(mean_vals, axis=-1)
    even_switches = np.linspace(0, 1, n_states + 1)[1:-1]

    nrns = data_array.shape[1]
    trials = data_array.shape[0]
    idx = np.arange(data_array.shape[-1])
    length = idx.max() + 1

    # Define sigmoid with given sharpness
    sig_b = inds_to_b(inds_span)

    def sigmoid(x):
        b_temp = tt.tile(
            np.array(sig_b)[None, None, None], x.tag.test_value.shape)
        return 1 / (1 + tt.exp(-b_temp * x))

    with pm.Model() as model:
        # Initial value
        s0 = pm.Exponential(
            "state0", 1 / (np.mean(mean_vals)), shape=nrns, initval=mean_vals[:, 0]
        )

        # Changes to lambda
        lambda_diff = pm.Normal(
            "lambda_diff",
            mu=0,
            sigma=10,
            shape=(nrns, n_states - 1),
            initval=lambda_test_vals,
        )

        # This is only here to be extracted at the end of sampling
        # NOT USED DIRECTLY IN MODEL
        lambda_fin = pm.Deterministic(
            "lambda", tt.concatenate(
                [s0[:, np.newaxis], lambda_diff], axis=-1)
        )

        # Changepoint positions
        a = pm.HalfCauchy("a_tau", 10, shape=n_states - 1)
        b = pm.HalfCauchy("b_tau", 10, shape=n_states - 1)

        tau_latent = pm.Beta(
            "tau_latent", a, b,
            # initval=even_switches,
            shape=(trials, n_states - 1)
        ).sort(axis=-1)
        tau = pm.Deterministic(
            "tau", idx.min() + (idx.max() - idx.min()) * tau_latent)

        # Mechanical manipulations to generate firing rates
        idx_temp = np.tile(
            idx[np.newaxis, np.newaxis, :], (trials, n_states - 1, 1))
        tau_temp = tt.tile(tau[:, :, np.newaxis], (1, 1, len(idx)))

        weight_stack = sigmoid(idx_temp - tau_temp)
        weight_stack_temp = tt.tile(
            weight_stack[:, np.newaxis, :, :], (1, nrns, 1, 1))

        s0_temp = tt.tile(
            s0[np.newaxis, :, np.newaxis, np.newaxis],
            (trials, 1, n_states - 1, len(idx)),
        )
        lambda_diff_temp = tt.tile(
            lambda_diff[np.newaxis, :, :,
                        np.newaxis], (trials, 1, 1, len(idx))
        )

        # Calculate lambda
        lambda_ = pm.Deterministic(
            "lambda_",
            tt.sum(s0_temp + (weight_stack_temp * lambda_diff_temp), axis=2),
        )
        # Bound lambda to prevent the diffs from making it negative
        # Don't let it go down to zero otherwise we have trouble with probabilities
        lambda_bounded = pm.Deterministic(
            "lambda_bounded", tt.switch(lambda_ >= 0.01, lambda_, 0.01)
        )

        # Add observations
        observation = pm.Poisson(
            "obs", lambda_bounded, observed=data_array)

    return model

test()

Test the model with synthetic data

Source code in pytau/changepoint_model.py
def test(self):
    """Test the model with synthetic data"""
    # Generate test data
    test_data = gen_test_array(
        (5, 10, 100), n_states=self.n_states, type="poisson")

    # Create model with test data
    test_model = SingleTastePoissonVarsigFixed(
        test_data, self.n_states, self.inds_span)
    model = test_model.generate_model()

    # Run a minimal inference to verify model works
    with model:
        # Just do a few iterations to test functionality
        inference = pm.ADVI()
        approx = pm.fit(n=10, method=inference)
        trace = approx.sample(draws=10)

    # Check if expected variables are in the trace
    assert "lambda" in trace.varnames
    assert "tau" in trace.varnames
    assert "state0" in trace.varnames

    print("Test for SingleTastePoissonVarsigFixed passed")
    return True

advi_fit(model, fit, samples, convergence_tol=None)

Convenience function to perform ADVI fit on model

Parameters:

Name             Type        Description                                               Default
model            pymc model  model object to run inference on                          required
fit              int         Number of iterations to fit the model for                 required
samples          int         Number of samples to draw from fitted model               required
convergence_tol  float       Tolerance for the convergence callback; no check if None  None

Returns:

model: original model on which inference was run
approx: fitted model
lambda_stack: array containing lambda (emission) values
tau_samples: array containing samples from changepoint distribution
model.obs.observations: processed array on which fit was run

Source code in pytau/changepoint_model.py
def advi_fit(model, fit, samples, convergence_tol=None):
    """Convenience function to perform ADVI fit on model

    Args:
        model (pymc model): model object to run inference on
        fit (int): Number of iterations to fit the model for
        samples (int): Number of samples to draw from fitted model

    Returns:
        model: original model on which inference was run,
        approx: fitted model,
        lambda_stack: array containing lambda (emission) values,
        tau_samples: array containing samples from changepoint distribution
        model.obs.observations: processed array on which fit was run
    """

    if convergence_tol is not None:
        callbacks = [pm.callbacks.CheckParametersConvergence(
            tolerance=convergence_tol)]
        print("Using convergence callback with tolerance:", convergence_tol)
    else:
        callbacks = None
    with model:
        inference = pm.ADVI("full-rank")
        approx = pm.fit(n=fit, method=inference, callbacks=callbacks)
        idata = approx.sample(draws=samples)

    # Check if tau exists in posterior samples (PyMC5 uses InferenceData)
    if "tau" not in idata.posterior.data_vars:
        available_vars = list(idata.posterior.data_vars.keys())
        raise KeyError(
            f"'tau' not found in posterior samples. Available variables: {available_vars}")

    # Extract relevant variables from InferenceData posterior
    try:
        tau_samples = idata.posterior["tau"].values
        # Handle potential dimension issues
        if tau_samples.ndim > 2:
            tau_samples = tau_samples.reshape(-1, tau_samples.shape[-1])
    except Exception as e:
        print(f"Error extracting tau samples: {e}")
        tau_samples = None

    # Get observed data from model (PyMC5 compatible)
    # Since notebooks don't use fit_data, return None to avoid compatibility issues
    observed_data = None

    if "lambda" in idata.posterior.data_vars:
        try:
            lambda_stack = idata.posterior["lambda"].values
            # Handle potential dimension issues
            if lambda_stack.ndim > 3:
                lambda_stack = lambda_stack.reshape(-1,
                                                    *lambda_stack.shape[-2:])
            lambda_stack = lambda_stack.swapaxes(0, 1)
            return model, approx, lambda_stack, tau_samples, observed_data
        except Exception as e:
            print(f"Error extracting lambda samples: {e}")
            return model, approx, None, tau_samples, observed_data

    if "mu" in idata.posterior.data_vars:
        try:
            mu_stack = idata.posterior["mu"].values
            sigma_stack = idata.posterior["sigma"].values
            # Handle potential dimension issues
            if mu_stack.ndim > 3:
                mu_stack = mu_stack.reshape(-1, *mu_stack.shape[-2:])
            if sigma_stack.ndim > 3:
                sigma_stack = sigma_stack.reshape(-1, *sigma_stack.shape[-2:])
            mu_stack = mu_stack.swapaxes(0, 1)
            sigma_stack = sigma_stack.swapaxes(0, 1)
            return model, approx, mu_stack, sigma_stack, tau_samples, observed_data
        except Exception as e:
            print(f"Error extracting mu/sigma samples: {e}")
            return model, approx, None, None, tau_samples, observed_data

    # Fallback - return what we can
    return model, approx, None, tau_samples, observed_data
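
A sketch of a typical call (iteration counts illustrative). Poisson-family models that expose "lambda" return the 5-tuple listed above; Gaussian models that expose mu/sigma return a 6-tuple with both stacks.

from pytau.changepoint_model import (
    SingleTastePoissonDirichlet, advi_fit, gen_test_array)

spikes = gen_test_array((5, 10, 100), n_states=3, type="poisson")
model = SingleTastePoissonDirichlet(spikes, max_states=8).generate_model()

# convergence_tol attaches a CheckParametersConvergence callback
model, approx, lambda_stack, tau_samples, _ = advi_fit(
    model, fit=2000, samples=500, convergence_tol=1e-2)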

all_taste_poisson(data_array, n_states, **kwargs)

Wrapper function for backward compatibility

Source code in pytau/changepoint_model.py
def all_taste_poisson(data_array, n_states, **kwargs):
    """Wrapper function for backward compatibility"""
    model_class = AllTastePoisson(data_array, n_states, **kwargs)
    return model_class.generate_model()
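
All of the backward-compatibility wrappers on this page follow the same pattern: instantiate the class and return generate_model(). For example, these two calls build the same graph (the 4D array shape follows run_all_tests; the first axis is presumably tastes):

from pytau.changepoint_model import (
    AllTastePoisson, all_taste_poisson, gen_test_array)

# 4D test array, shaped as in run_all_tests
spikes = gen_test_array((2, 5, 10, 100), n_states=3, type="poisson")

model_a = all_taste_poisson(spikes, n_states=3)
model_b = AllTastePoisson(spikes, n_states=3).generate_model()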

all_taste_poisson_trial_switch(data_array, switch_components, n_states, **kwargs)

Wrapper function for backward compatibility

Source code in pytau/changepoint_model.py
def all_taste_poisson_trial_switch(data_array, switch_components, n_states, **kwargs):
    """Wrapper function for backward compatibility"""
    model_class = AllTastePoissonTrialSwitch(
        data_array, switch_components, n_states, **kwargs)
    return model_class.generate_model()

all_taste_poisson_varsig_fixed(data_array, n_states, inds_span=1, **kwargs)

Wrapper function for backward compatibility

Source code in pytau/changepoint_model.py
def all_taste_poisson_varsig_fixed(data_array, n_states, inds_span=1, **kwargs):
    """Wrapper function for backward compatibility"""
    model_class = AllTastePoissonVarsigFixed(
        data_array, n_states, inds_span, **kwargs)
    return model_class.generate_model()

dpp_fit(model, n_chains=24, n_cores=1, tune=500, draws=500, use_numpyro=False)

Convenience function to fit DPP model

Source code in pytau/changepoint_model.py
def dpp_fit(model, n_chains=24, n_cores=1, tune=500, draws=500, use_numpyro=False):
    """Convenience function to fit DPP model"""
    if not use_numpyro:
        with model:
            dpp_trace = pm.sample(
                tune=tune,
                draws=draws,
                target_accept=0.95,
                chains=n_chains,
                cores=n_cores,
                return_inferencedata=False,
            )
    else:
        with model:
            dpp_trace = pm.sample(
                nuts_sampler="numpyro",
                tune=tune,
                draws=draws,
                target_accept=0.95,
                chains=n_chains,
                cores=n_cores,
                return_inferencedata=False,
            )
    return dpp_trace
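
Unlike advi_fit, this helper draws posterior samples with NUTS (optionally through numpyro) and returns the raw trace; a sketch with illustrative chain counts.

from pytau.changepoint_model import (
    single_taste_poisson_dirichlet, dpp_fit, gen_test_array)

spikes = gen_test_array((5, 10, 100), n_states=3, type="poisson")
model = single_taste_poisson_dirichlet(spikes, max_states=8)

# return_inferencedata=False, so this is a MultiTrace object
trace = dpp_fit(model, n_chains=4, n_cores=4, tune=500, draws=500)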

extract_inferred_values(trace)

Convenience function to extract inferred values from ADVI fit

Parameters:

Name   Type  Description  Default
trace  dict  trace        required

Returns:

dict: dictionary of inferred values

Source code in pytau/changepoint_model.py
def extract_inferred_values(trace):
    """Convenience function to extract inferred values from ADVI fit

    Args:
        trace (dict): trace

    Returns:
        dict: dictionary of inferred values
    """
    # Extract relevant variables from trace
    out_dict = dict(tau_samples=trace["tau"])
    if "lambda" in trace.varnames:
        out_dict["lambda_stack"] = trace["lambda"].swapaxes(0, 1)
    if "mu" in trace.varnames:
        out_dict["mu_stack"] = trace["mu"].swapaxes(0, 1)
        out_dict["sigma_stack"] = trace["sigma"].swapaxes(0, 1)
    return out_dict
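
Because it indexes the trace directly and reads trace.varnames, this helper expects a MultiTrace-style object such as the one dpp_fit returns; continuing the sketch above:

from pytau.changepoint_model import extract_inferred_values

inferred = extract_inferred_values(trace)
tau_samples = inferred["tau_samples"]        # always present
lambda_stack = inferred.get("lambda_stack")  # only for Poisson models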

find_best_states(data, model_generator, n_fit, n_samples, min_states=2, max_states=10, convergence_tol=None)

Convenience function to find best number of states for model

Parameters:

Name             Type      Description                                                      Default
data             array     array on which to run inference                                  required
model_generator  function  function that generates model                                    required
n_fit            int       Number of iterations to fit the model for                        required
n_samples        int       Number of samples to draw from fitted model                      required
min_states       int       Minimum number of states to test                                 2
max_states       int       Maximum number of states to test                                 10
convergence_tol  float     Tolerance for convergence; if None, convergence is not checked   None

Returns:

best_model: model with best number of states
model_list: list of models with different numbers of states
elbo_values: list of ELBO values for the different numbers of states

Source code in pytau/changepoint_model.py
def find_best_states(
        data,
        model_generator,
        n_fit, n_samples,
        min_states=2,
        max_states=10,
        convergence_tol=None,
):
    """Convenience function to find best number of states for model

    Args:
        data (array): array on which to run inference
        model_generator (function): function that generates model
        n_fit (int): Number of iterations to fit the model for
        n_samples (int): Number of samples to draw from fitted model
        min_states (int): Minimum number of states to test
        max_states (int): Maximum number of states to test
        convergence_tol (float): Tolerance for convergence. If None, will not check for convergence.

    Returns:
        best_model: model with best number of states,
        model_list: list of models with different number of states,
        elbo_values: list of elbo values for different number of states
    """
    n_state_array = np.arange(min_states, max_states + 1)
    elbo_values = []
    model_list = []
    for n_states in tqdm(n_state_array):
        print(f"Fitting model with {n_states} states")
        # Have to use int instead of np.int64
        model = model_generator(data, int(n_states))
        model, approx = advi_fit(model, n_fit, n_samples, convergence_tol)[:2]
        elbo_values.append(approx.hist[-1])
        model_list.append(model)
    best_model = model_list[np.argmin(elbo_values)]
    return best_model, model_list, elbo_values
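
A sketch of a model-selection run (fit sizes illustrative); any wrapper with a (data, n_states) signature can serve as model_generator.

from pytau.changepoint_model import (
    find_best_states, gen_test_array, single_taste_poisson)

spikes = gen_test_array((5, 10, 100), n_states=3, type="poisson")

# Selects the state count whose final ADVI loss is lowest
best_model, model_list, elbo_values = find_best_states(
    spikes, single_taste_poisson, n_fit=2000, n_samples=500,
    min_states=2, max_states=5)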

gaussian_changepoint_mean_2d(data_array, n_states, **kwargs)

Wrapper function for backward compatibility

Source code in pytau/changepoint_model.py
def gaussian_changepoint_mean_2d(data_array, n_states, **kwargs):
    """Wrapper function for backward compatibility"""
    model_class = GaussianChangepointMean2D(data_array, n_states, **kwargs)
    return model_class.generate_model()

gaussian_changepoint_mean_dirichlet(data_array, max_states=15, **kwargs)

Wrapper function for backward compatibility

Source code in pytau/changepoint_model.py
def gaussian_changepoint_mean_dirichlet(data_array, max_states=15, **kwargs):
    """Wrapper function for backward compatibility"""
    model_class = GaussianChangepointMeanDirichlet(
        data_array, max_states, **kwargs)
    return model_class.generate_model()

gaussian_changepoint_mean_var_2d(data_array, n_states, **kwargs)

Wrapper function for backward compatibility

Source code in pytau/changepoint_model.py
def gaussian_changepoint_mean_var_2d(data_array, n_states, **kwargs):
    """Wrapper function for backward compatibility"""
    model_class = GaussianChangepointMeanVar2D(data_array, n_states, **kwargs)
    return model_class.generate_model()

gen_test_array(array_size, n_states, type='poisson')

Generate test array for model fitting. The last two dimensions make up a single trial; time is always the last dimension.

Parameters:

Name        Type          Description                                               Default
array_size  tuple or int  Size of array to generate; if int, generates a 1D array   required
n_states    int           Number of states to generate                              required
type        str           Type of data to generate: "normal" or "poisson"           'poisson'
Source code in pytau/changepoint_model.py
def gen_test_array(array_size, n_states, type="poisson"):
    """
    Generate test array for model fitting
    Last 2 dimensions consist of a single trial
    Time will always be last dimension

    Args:
        array_size (tuple or int): Size of array to generate. If int, generates 1D array.
        n_states (int): Number of states to generate
        type (str): Type of data to generate
            - normal
            - poisson
    """
    # Handle 1D case
    if isinstance(array_size, int):
        assert array_size > n_states, "Array too small for states"
        assert type in [
            "normal", "poisson"], "Invalid type, please use normal or poisson"

        # Generate transition times for 1D case
        transition_times = np.random.random(n_states)
        transition_times = np.cumsum(transition_times)
        transition_times = transition_times / transition_times.max()
        transition_times *= array_size
        transition_times = transition_times.astype(int)

        # Generate state bounds
        state_bounds = np.zeros(n_states + 1, dtype=int)
        state_bounds[1:] = transition_times
        state_bounds[-1] = array_size

        # Generate state rates
        lambda_vals = np.random.exponential(2.0, n_states) + 0.5

        # Generate 1D array
        rate_array = np.zeros(array_size)
        for i in range(n_states):
            start_idx = state_bounds[i]
            end_idx = state_bounds[i + 1]
            rate_array[start_idx:end_idx] = lambda_vals[i]

        if type == "poisson":
            return np.random.poisson(rate_array)
        else:
            return np.random.normal(loc=rate_array, scale=0.1)

    # Handle multi-dimensional case (existing code)
    assert array_size[-1] > n_states, "Array too small for states"
    assert type in [
        "normal", "poisson"], "Invalid type, please use normal or poisson"

    # Generate transition times
    transition_times = np.random.random((*array_size[:-2], n_states))
    transition_times = np.cumsum(transition_times, axis=-1)
    transition_times = transition_times / \
        transition_times.max(axis=-1, keepdims=True)
    transition_times *= array_size[-1]
    transition_times = np.vectorize(int)(transition_times)

    # Generate state bounds
    state_bounds = np.zeros((*array_size[:-2], n_states + 1), dtype=int)
    state_bounds[..., 1:] = transition_times

    # Generate state rates
    lambda_vals = np.random.random((*array_size[:-1], n_states))

    # Generate array
    rate_array = np.zeros(array_size)
    inds = list(np.ndindex(lambda_vals.shape))
    for this_ind in inds:
        this_lambda = lambda_vals[this_ind[:-2]][:, this_ind[-1]]
        this_state_bounds = [
            state_bounds[(*this_ind[:-2], this_ind[-1])],
            state_bounds[(*this_ind[:-2], this_ind[-1] + 1)],
        ]
        rate_array[this_ind[:-2]][:,
                                  slice(*this_state_bounds)] = this_lambda[:, None]

    if type == "poisson":
        return np.random.poisson(rate_array)
    else:
        return np.random.normal(loc=rate_array, scale=0.1)
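
Two representative calls; the 3D shape matches the trials x neurons x time layout the single-taste models expect.

from pytau.changepoint_model import gen_test_array

test_1d = gen_test_array(100, n_states=3, type="poisson")
test_3d = gen_test_array((5, 10, 100), n_states=3, type="poisson")
assert test_3d.shape == (5, 10, 100)  # trials x neurons x time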

mcmc_fit(model, samples)

Convenience function to perform MCMC fit on model

Parameters:

Name     Type        Description                           Default
model    pymc model  model object to run inference on      required
samples  int         Number of samples to draw using MCMC  required

Returns:

model: original model on which inference was run
trace: samples drawn from MCMC
lambda_stack: array containing lambda (emission) values
tau_samples: array containing samples from changepoint distribution
model.obs.observations: processed array on which fit was run

Source code in pytau/changepoint_model.py
def mcmc_fit(model, samples):
    """Convenience function to perform ADVI fit on model

    Args:
        model (pymc model): model object to run inference on
        samples (int): Number of samples to draw using MCMC

    Returns:
        model: original model on which inference was run,
        trace: samples drawn from MCMC,
        lambda_stack: array containing lambda (emission) values,
        tau_samples: array containing samples from changepoint distribution
        model.obs.observations: processed array on which fit was run
    """

    with model:
        sampler_kwargs = {"cores": 1, "chains": 4}
        idata = pm.sample(draws=samples, **sampler_kwargs)
        # Thin the samples (every 10th sample)
        idata_thinned = idata.sel(draw=slice(None, None, 10))

    # Extract relevant variables from InferenceData posterior
    try:
        tau_samples = idata_thinned.posterior["tau"].values
        # Handle potential dimension issues
        if tau_samples.ndim > 2:
            tau_samples = tau_samples.reshape(-1, tau_samples.shape[-1])
    except Exception as e:
        print(f"Error extracting tau samples: {e}")
        tau_samples = None

    # Get observed data from model (PyMC5 compatible)
    # Since notebooks don't use fit_data, return None to avoid compatibility issues
    observed_data = None

    if "lambda" in idata_thinned.posterior.data_vars:
        try:
            lambda_stack = idata_thinned.posterior["lambda"].values
            # Handle potential dimension issues
            if lambda_stack.ndim > 3:
                lambda_stack = lambda_stack.reshape(-1,
                                                    *lambda_stack.shape[-2:])
            lambda_stack = lambda_stack.swapaxes(0, 1)
            return model, idata_thinned, lambda_stack, tau_samples, observed_data
        except Exception as e:
            print(f"Error extracting lambda samples: {e}")
            return model, idata_thinned, None, tau_samples, observed_data

    if "mu" in idata_thinned.posterior.data_vars:
        try:
            mu_stack = idata_thinned.posterior["mu"].values
            sigma_stack = idata_thinned.posterior["sigma"].values
            # Handle potential dimension issues
            if mu_stack.ndim > 3:
                mu_stack = mu_stack.reshape(-1, *mu_stack.shape[-2:])
            if sigma_stack.ndim > 3:
                sigma_stack = sigma_stack.reshape(-1, *sigma_stack.shape[-2:])
            mu_stack = mu_stack.swapaxes(0, 1)
            sigma_stack = sigma_stack.swapaxes(0, 1)
            return model, idata_thinned, mu_stack, sigma_stack, tau_samples, observed_data
        except Exception as e:
            print(f"Error extracting mu/sigma samples: {e}")
            return model, idata_thinned, None, None, tau_samples, observed_data

    # Fallback - return what we can
    return model, idata_thinned, None, tau_samples, observed_data
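
A sketch of a typical call (sample count illustrative); the returned InferenceData is thinned to every 10th draw before variables are extracted.

from pytau.changepoint_model import (
    gen_test_array, mcmc_fit, single_taste_poisson_dirichlet)

spikes = gen_test_array((5, 10, 100), n_states=3, type="poisson")
model = single_taste_poisson_dirichlet(spikes, max_states=8)
model, idata, lambda_stack, tau_samples, _ = mcmc_fit(model, samples=1000)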

poisson_changepoint_1d(data_array, n_states, **kwargs)

Wrapper function for backward compatibility

Source code in pytau/changepoint_model.py, lines 2013-2016
def poisson_changepoint_1d(data_array, n_states, **kwargs):
    """Wrapper function for backward compatibility"""
    model_class = PoissonChangepoint1D(data_array, n_states, **kwargs)
    return model_class.generate_model()
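The wrapper is equivalent to instantiating the class and calling generate_model(), as this quick sketch shows (gen_test_array as documented on this page):

from pytau.changepoint_model import (
    PoissonChangepoint1D, poisson_changepoint_1d, gen_test_array)

data = gen_test_array(100, n_states=3, type="poisson")

model_a = poisson_changepoint_1d(data, n_states=3)        # legacy interface
model_b = PoissonChangepoint1D(data, 3).generate_model()  # class interface

The same pattern applies to the other backward-compatibility wrappers below.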

run_all_tests()

Run tests for all model classes

Source code in pytau/changepoint_model.py, lines 1860-1896
def run_all_tests():
    """Run tests for all model classes"""
    # Create test data
    test_data_1d = gen_test_array(100, n_states=3, type="poisson")
    test_data_2d = gen_test_array((10, 100), n_states=3, type="normal")
    test_data_3d = gen_test_array((5, 10, 100), n_states=3, type="poisson")
    test_data_4d = gen_test_array((2, 5, 10, 100), n_states=3, type="poisson")

    # Test each model class
    models_to_test = [
        PoissonChangepoint1D(test_data_1d, 3),
        GaussianChangepointMeanVar2D(test_data_2d, 3),
        GaussianChangepointMeanDirichlet(test_data_2d, 5),
        GaussianChangepointMean2D(test_data_2d, 3),
        SingleTastePoissonDirichlet(test_data_3d, 5),
        SingleTastePoisson(test_data_3d, 3),
        SingleTastePoissonVarsig(test_data_3d, 3),
        SingleTastePoissonVarsigFixed(test_data_3d, 3, 1),
        SingleTastePoissonTrialSwitch(test_data_3d, 2, 3),
        AllTastePoisson(test_data_4d, 3),
        AllTastePoissonVarsigFixed(test_data_4d, 3, 1),
        AllTastePoissonTrialSwitch(test_data_4d, 2, 3),
    ]

    failed_tests = []
    pbar = tqdm(models_to_test, total=len(models_to_test))
    for model in pbar:
        try:
            model.test()
            pbar.set_description(f"Test passed for {model.__class__.__name__}")
        except Exception as e:
            failed_tests.append(model.__class__.__name__)
            print(f"Test failed for {model.__class__.__name__}: {str(e)}")

    print("All tests completed")
    if failed_tests:
        print("Failed tests:", failed_tests)

single_taste_poisson(data_array, n_states, **kwargs)

Wrapper function for backward compatibility

Source code in pytau/changepoint_model.py, lines 699-702
def single_taste_poisson(data_array, n_states, **kwargs):
    """Wrapper function for backward compatibility"""
    model_class = SingleTastePoisson(data_array, n_states, **kwargs)
    return model_class.generate_model()

single_taste_poisson_dirichlet(data_array, max_states=10, **kwargs)

Wrapper function for backward compatibility

Source code in pytau/changepoint_model.py, lines 595-598
def single_taste_poisson_dirichlet(data_array, max_states=10, **kwargs):
    """Wrapper function for backward compatibility"""
    model_class = SingleTastePoissonDirichlet(data_array, max_states, **kwargs)
    return model_class.generate_model()

single_taste_poisson_trial_switch(data_array, switch_components, n_states, **kwargs)

Wrapper function for backward compatibility

Source code in pytau/changepoint_model.py, lines 1529-1533
def single_taste_poisson_trial_switch(data_array, switch_components, n_states, **kwargs):
    """Wrapper function for backward compatibility"""
    model_class = SingleTastePoissonTrialSwitch(
        data_array, switch_components, n_states, **kwargs)
    return model_class.generate_model()

single_taste_poisson_varsig(data_array, n_states, **kwargs)

Wrapper function for backward compatibility

Source code in pytau/changepoint_model.py, lines 864-867
def single_taste_poisson_varsig(data_array, n_states, **kwargs):
    """Wrapper function for backward compatibility"""
    model_class = SingleTastePoissonVarsig(data_array, n_states, **kwargs)
    return model_class.generate_model()

single_taste_poisson_varsig_fixed(data_array, n_states, inds_span=1, **kwargs)

Wrapper function for backward compatibility

Source code in pytau/changepoint_model.py, lines 1024-1028
def single_taste_poisson_varsig_fixed(data_array, n_states, inds_span=1, **kwargs):
    """Wrapper function for backward compatibility"""
    model_class = SingleTastePoissonVarsigFixed(
        data_array, n_states, inds_span, **kwargs)
    return model_class.generate_model()

var_sig_exp_tt(x, b)

Sigmoid in x with slope exp(b); returns 1 / (1 + exp(-exp(b) * x))

Source code in pytau/changepoint_model.py, lines 705-710
def var_sig_exp_tt(x, b):
    """Sigmoid in x with slope exp(b)

    Args:
        x: input array
        b: log of the sigmoid slope
    """
    return 1 / (1 + tt.exp(-tt.exp(b) * x))

var_sig_tt(x, b)

Sigmoid in x with slope b; returns 1 / (1 + exp(-b * x))

Source code in pytau/changepoint_model.py, lines 713-718
def var_sig_tt(x, b):
    """Sigmoid in x with slope b

    Args:
        x: input array
        b: sigmoid slope
    """
    return 1 / (1 + tt.exp(-b * x))
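To make the difference between the two sigmoids concrete, here is a plain-NumPy sketch (illustrative only; the library versions operate on tensor variables via tt): var_sig_tt uses the slope b directly, while var_sig_exp_tt exponentiates it, keeping the slope positive for any real-valued b.

import numpy as np

def var_sig(x, b):
    """NumPy analogue of var_sig_tt"""
    return 1 / (1 + np.exp(-b * x))

def var_sig_exp(x, b):
    """NumPy analogue of var_sig_exp_tt"""
    return 1 / (1 + np.exp(-np.exp(b) * x))

x = np.linspace(-5, 5, 11)
print(np.allclose(var_sig(x, b=1.0), var_sig_exp(x, b=0.0)))  # True: exp(0) == 1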

=== Preprocessing functions ===

Code to preprocess spike trains before feeding into model

preprocess_all_taste(spike_array, time_lims, bin_width, data_transform)

Preprocess array containing trials for all tastes (in blocks) concatenated

Parameters:

Name Type Description Default
spike_array 4D Numpy Array

Taste x Trials x Neurons x Time

required
time_lims List/Tuple/Numpy Array

2-element object indicating limits of array

required
bin_width int

Width to use for binning

required
data_transform str

Data-type to return {trial_shuffled, spike_shuffled, simulated, None}; None or "None" returns the actual data

required

Raises:

Type Description
Exception

If data_transform does not belong to ['trial_shuffled','spike_shuffled','simulated','None',None]

Returns:

Type Description
4D Numpy Array

Of processed data

Source code in pytau/changepoint_preprocess.py, lines 103-178
def preprocess_all_taste(spike_array, time_lims, bin_width, data_transform):
    """Preprocess array containing trials for all tastes (in blocks) concatenated

    Args:
        spike_array (4D Numpy Array): Taste x Trials x Neurons x Time
        time_lims (List/Tuple/Numpy Array): 2-element object indicating limits of array
        bin_width (int): Width to use for binning
        data_transform (str): Data-type to return {trial_shuffled, spike_shuffled, simulated, None}

    Raises:
        Exception: If data_transform does not belong to
            ['trial_shuffled','spike_shuffled','simulated','None',None]

    Returns:
        (4D Numpy Array): Of processed data
    """

    accepted_transforms = [
        "trial_shuffled",
        "spike_shuffled",
        "simulated",
        "None",
        None,
    ]
    if data_transform not in accepted_transforms:
        raise Exception(
            f"data_transform must be of type {accepted_transforms}")

    ##################################################
    # Create shuffled data
    ##################################################
    # Shuffle neurons across trials FOR SAME TASTE

    if data_transform == "trial_shuffled":
        transformed_dat = np.array(
            [np.random.permutation(neuron)
             for neuron in np.swapaxes(spike_array, 2, 0)]
        )
        transformed_dat = np.swapaxes(transformed_dat, 0, 2)

    if data_transform == "spike_shuffled":
        transformed_dat = spike_array.swapaxes(-1, 0)
        transformed_dat = np.stack([np.random.permutation(x)
                                   for x in transformed_dat])
        transformed_dat = transformed_dat.swapaxes(0, -1)

    ##################################################
    # Create simulated data
    ##################################################
    # Inhomogeneous Poisson process using mean firing rates

    elif data_transform == "simulated":
        mean_firing = np.mean(spike_array, axis=1)
        mean_firing = np.broadcast_to(mean_firing[:, None], spike_array.shape)

        # Simulate spikes
        transformed_dat = (np.random.random(
            spike_array.shape) < mean_firing) * 1

    ##################################################
    # Null Transform Case
    ##################################################
    elif data_transform in (None, "None"):
        transformed_dat = spike_array

    ##################################################
    # Bin Data
    ##################################################
    spike_binned = np.sum(
        transformed_dat[..., time_lims[0]: time_lims[1]].reshape(
            *transformed_dat.shape[:-1], -1, bin_width
        ),
        axis=-1,
    )
    spike_binned = spike_binned.astype(int)

    return spike_binned
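Example usage (a hypothetical sketch with random data; shapes and window are illustrative only):

import numpy as np
from pytau.changepoint_preprocess import preprocess_all_taste

# taste x trials x neurons x time binary spike array
spike_array = (np.random.random((4, 15, 10, 7000)) < 0.05).astype(int)

binned = preprocess_all_taste(
    spike_array,
    time_lims=[2000, 4000],  # keep a 2000-sample analysis window
    bin_width=50,            # 2000 / 50 = 40 bins per trial
    data_transform=None,     # None (or "None") returns the actual data
)
print(binned.shape)  # (4, 15, 10, 40)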

preprocess_single_taste(spike_array, time_lims, bin_width, data_transform)

Preprocess array containing trials for a single taste

Note: it may be useful to use xarray here to keep track of coordinates

Parameters:

Name Type Description Default
spike_array 3D Numpy array

trials x neurons x time

required
time_lims List/Tuple/Numpy Array

2-element object indicating limits of array

required
bin_width int

Width to use for binning

required
data_transform str

Data-type to return {trial_shuffled, spike_shuffled, simulated, None}; None or "None" returns the actual data

required

Raises:

Type Description
Exception

If data_transform does not belong to ['trial_shuffled','spike_shuffled','simulated','None',None]

Returns:

Type Description
3D Numpy Array

Of processed data

Source code in pytau/changepoint_preprocess.py, lines 16-100
def preprocess_single_taste(spike_array, time_lims, bin_width, data_transform):
    """Preprocess array containing trials for all tastes (in blocks) concatenated

    ** Note, it may be useful to use x-arrays here to keep track of coordinates

    Args:
        spike_array (3D Numpy array): trials x neurons x time
        time_lims (List/Tuple/Numpy Array): 2-element object indicating limits of array
        bin_width (int): Width to use for binning
        data_transform (str): Data-type to return
                            {trial_shuffled, spike_shuffled, simulated, None}

    Raises:
        Exception: If transforms do not belong to
        ['trial_shuffled','spike_shuffled','simulated','None',None]

    Returns:
        (3D Numpy Array): Of processed data
    """

    accepted_transforms = [
        "trial_shuffled",
        "spike_shuffled",
        "simulated",
        "None",
        None,
    ]
    if data_transform not in accepted_transforms:
        raise Exception(
            f"data_transform must be of type {accepted_transforms}")

    ##################################################
    # Create shuffled data
    ##################################################
    # Shuffle neurons across trials FOR SAME TASTE

    if data_transform == "trial_shuffled":
        transformed_dat = np.array(
            [np.random.permutation(neuron)
             for neuron in np.swapaxes(spike_array, 1, 0)]
        )
        transformed_dat = np.swapaxes(transformed_dat, 0, 1)

    if data_transform == "spike_shuffled":
        transformed_dat = np.moveaxis(spike_array, -1, 0)
        transformed_dat = np.stack([np.random.permutation(x)
                                   for x in transformed_dat])
        transformed_dat = np.moveaxis(transformed_dat, 0, -1)
    ##################################################
    # Create simulated data
    ##################################################
    # Inhomogeneous Poisson process using mean firing rates

    elif data_transform == "simulated":
        mean_firing = np.mean(spike_array, axis=0)

        # Simulate spikes
        transformed_dat = (
            np.array(
                [
                    np.random.random(mean_firing.shape) < mean_firing
                    for trial in range(spike_array.shape[0])
                ]
            )
            * 1
        )

    ##################################################
    # Null Transform Case
    ##################################################
    elif data_transform in (None, "None"):
        transformed_dat = spike_array

    ##################################################
    # Bin Data
    ##################################################
    spike_binned = np.sum(
        transformed_dat[..., time_lims[0]: time_lims[1]].reshape(
            *spike_array.shape[:-1], -1, bin_width
        ),
        axis=-1,
    )
    spike_binned = spike_binned.astype(int)

    return spike_binned
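Example usage (a hypothetical sketch with random data, here requesting a trial-shuffled surrogate):

import numpy as np
from pytau.changepoint_preprocess import preprocess_single_taste

# trials x neurons x time binary spike array
spike_array = (np.random.random((15, 10, 7000)) < 0.05).astype(int)

binned = preprocess_single_taste(
    spike_array,
    time_lims=[2000, 4000],
    bin_width=50,
    data_transform="trial_shuffled",  # shuffle each neuron across trials
)
print(binned.shape)  # (15, 10, 40)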