Source code for pyPCG.lr_hsmm

"""
Logistic Regression - Hidden Semi-Markov Model segmenter

Based on:
D. B. Springer, L. Tarassenko and G. D. Clifford, "Logistic Regression-HSMM-Based Heart Sound Segmentation," in IEEE Transactions on Biomedical Engineering, vol. 63, no. 4, pp. 822-832, April 2016, doi: 10.1109/TBME.2015.2475278.
"""

import pywt
import json
import warnings
import numpy as np
import numpy.typing as npt
import scipy.signal as sgn
from math import ceil
from sklearn.linear_model import LogisticRegression
from hsmmlearn.emissions import AbstractEmissions
from hsmmlearn.hsmm import HSMMModel
from scipy.stats import norm
from joblib import Parallel, delayed
from scipy.stats import multivariate_normal

INF = 999999999

[docs] class LR_HSMM(): """ Main segmenter object Attributes: signal_fs (int): Sampling frequency of the input signal. Default: 1000 [Hz] feature_fs (int): Sampling frequency for feature calculations. Default: 50 [Hz] mean_s1_len (float): Average S1 duration. Default: 122 [ms] mean_s2_len (float): Average S2 duration. Default: 99 [ms] std_s1_len (float): Standard deviation of S1 durations. Default: 22 [ms] std_s2_len (float): Standard deviation of S2 durations. Default: 22 [ms] bandpass_frq (tuple[float,float]): Cutoff frequencies for pre-processing band-pass filtering (Butterworth, 4th order). Default: (25,400) [Hz] expected_hr_range (tuple[float,float]): Minimum and maximum expected heartrates. Default: (30,120) [bpm] hsmm_model (hsmmlearn.hsmm.HSMMModel): State predictor model lr_model (hsmmlearn.emissions.AbstractEmissions): Probability emissions for the HSMM states. Includes a LogisticRegression model for each state """ def __init__(self) -> None: self.signal_fs = 1000 self.feature_fs = 50 self.mean_s1_len = 122 self.mean_s2_len = 99 self.std_s1_len = 22 self.std_s2_len = 22 self.bandpass_frq = (25,400) self.expected_hr_range = (30,120) self.hsmm_model = None self.lr_model = _LREmission()
[docs] def train_model(self,train_data:npt.NDArray[np.float64]|list[float],train_s1_annot:npt.NDArray[np.float64]|list[float],train_s2_annot:npt.NDArray[np.float64]|list[float],multiprocess:int|None=None) -> None: """Trains the model on the specified data with S1 and S2 location annotations Args: train_data (np.ndarray): Array of input signals train_s1_annot (np.ndarray): Array of S1 annotations for each signal train_s2_annot (np.ndarray): Array of S2 annotations for each signal """ tmat = np.array([[0.,1.,0.,0.],[0.,0.,1.,0.],[0.,0.,0.,1.],[1.,0.,0.,0.]]) states = np.array([]) f_henv, f_env, f_psd, f_wt = np.array([]),np.array([]),np.array([]),np.array([]) d_hr, d_sys = np.array([]),np.array([]) print("Generating features...") if multiprocess is not None: result = Parallel(n_jobs=multiprocess,backend="multiprocessing")(delayed(_generate_features)(data,self.signal_fs,self.feature_fs,self.bandpass_frq) for data in train_data) #type: ignore henv, env, psd, wt = zip(*result) #type: ignore for f in henv: f_henv = np.append(f_henv,f) for f in env: f_env = np.append(f_env,f) for f in psd: f_psd = np.append(f_psd,f) for f in wt: f_wt = np.append(f_wt,f) for i,(data,s1_annot,s2_annot) in enumerate(zip(train_data,train_s1_annot,train_s2_annot)): hr, sys = _get_hr_sys(data,self.signal_fs,self.bandpass_frq,self.expected_hr_range[0],self.expected_hr_range[1]) sts = _generate_states(data,s1_annot,s2_annot,self.signal_fs,self.feature_fs,self.mean_s1_len,self.mean_s2_len,self.std_s1_len,self.std_s2_len) if multiprocess is None: henv, env, psd, wt = _generate_features(data,self.signal_fs,self.feature_fs,self.bandpass_frq) f_henv = np.append(f_henv,henv) f_env = np.append(f_env,env) f_psd = np.append(f_psd,psd) f_wt = np.append(f_wt,wt) states = np.append(states,sts) d_hr = np.append(d_hr,hr) d_sys = np.append(d_sys,sys) features = np.array([f_henv,f_env,f_psd,f_wt]) # Calculating duration distributions here is not in parity with Springer et al. durs = _get_duration_distributions(np.mean(d_hr),np.mean(d_sys),self.feature_fs,self.mean_s1_len,self.mean_s2_len,self.std_s1_len,self.std_s2_len) print("Training model...") self.lr_model = _LREmission(features.T, states) self.hsmm_model = HSMMModel(self.lr_model,durs,tmat)
def train_with_precalc_features(self,features:npt.NDArray[np.float64],train_data:npt.NDArray[np.float64]|list[float],train_s1_annot:npt.NDArray[np.float64]|list[float],train_s2_annot:npt.NDArray[np.float64]|list[float]) -> None: tmat = np.array([[0.,1.,0.,0.],[0.,0.,1.,0.],[0.,0.,0.,1.],[1.,0.,0.,0.]]) states = np.array([]) d_hr, d_sys = np.array([]),np.array([]) for i,(data,s1_annot,s2_annot) in enumerate(zip(train_data,train_s1_annot,train_s2_annot)): hr, sys = _get_hr_sys(data,self.signal_fs,self.bandpass_frq,self.expected_hr_range[0],self.expected_hr_range[1]) sts = _generate_states(data,s1_annot,s2_annot,self.signal_fs,self.feature_fs,self.mean_s1_len,self.mean_s2_len,self.std_s1_len,self.std_s2_len) states = np.append(states,sts) d_hr = np.append(d_hr,hr) d_sys = np.append(d_sys,sys) # Calculating duration distributions here is not in parity with Springer et al. durs = _get_duration_distributions(np.mean(d_hr),np.mean(d_sys),self.feature_fs,self.mean_s1_len,self.mean_s2_len,self.std_s1_len,self.std_s2_len) print("Training model...") self.lr_model = _LREmission(features.T, states) self.hsmm_model = HSMMModel(self.lr_model,durs,tmat)
[docs] def segment_single(self,sig:npt.NDArray[np.float64],recalc_timing:bool=False) -> tuple[npt.NDArray[np.float64],npt.NDArray[np.float64]]: """Predicts the states for the given PCG signal Args: sig (np.ndarray): Input signal to be segmented Returns: tuple[np.ndarray,np.ndarray]: Predicted state for each sample and the calculated features """ henv, env, psd, wt = _generate_features(sig,self.signal_fs,self.feature_fs,self.bandpass_frq) seg_features = np.array([henv, env, psd, wt], dtype=np.float64) if self.hsmm_model is None: warnings.warn("Attempting to segment with untrained model. Returning empty states...",RuntimeWarning) return np.empty(0), np.empty(0) if recalc_timing: # Recalculating duration distributions for only the record to be segmented hr, sys = _get_hr_sys(sig,self.signal_fs,self.bandpass_frq,self.expected_hr_range[0],self.expected_hr_range[1]) durs = _get_duration_distributions(hr,sys,self.feature_fs,self.mean_s1_len,self.mean_s2_len,self.std_s1_len,self.std_s2_len) self.hsmm_model.durations = durs d_states = self.hsmm_model.decode(seg_features.T) e_states = _expand_states(d_states+1,self.feature_fs,self.signal_fs,len(sig)) return e_states, seg_features
[docs] def save_model(self,filename:str) -> None: """Saves the model parameters to a json file Args: filename (str): Name of file to be saved """ if self.hsmm_model is None: warnings.warn("Attempting to save untrained model. No file will be written. Returning...",RuntimeWarning) return durs = self.hsmm_model.durations.tolist() #type: ignore tmat = self.hsmm_model.tmat.tolist() #type: ignore emissions = self.hsmm_model.emissions.serialize() #type: ignore config = { "sig_fs":self.signal_fs, "f_fs": self.feature_fs, "preproc_bp": self.bandpass_frq, "mean_s1": self.mean_s1_len, "mean_s2" : self.mean_s2_len, "std_s1" : self.std_s1_len, "std_s2" : self.std_s2_len, "min_hr": self.expected_hr_range[0], "max_hr": self.expected_hr_range[1] } serialized_hsmm = {"config":config,"durations":durs,"transition":tmat,"emissions":emissions} with open(filename,"w") as save: save.write(json.dumps(serialized_hsmm))
[docs] def load_model(self,filename:str) -> None: """Loads the model parameters from a json file Args: filename (str): Name of file containing model parameters """ with open(filename,"r") as load: data = json.loads(load.read()) durs = np.array(data["durations"]) tmat = np.array(data["transition"]) config = data["config"] self.signal_fs = config["sig_fs"] self.feature_fs = config["f_fs"] self.bandpass_frq = config["preproc_bp"] self.mean_s1_len = config["mean_s1"] self.mean_s2_len = config["mean_s2"] self.std_s1_len = config["std_s1"] self.std_s2_len = config["std_s2"] self.expected_hr_range = (config["min_hr"],config["max_hr"]) emission = _LREmission() emission.unserialize(data["emissions"]) self.lr_model = emission self.hsmm_model = HSMMModel(self.lr_model,durs,tmat)
class _LREmission(AbstractEmissions): def __init__(self, features=None, states=None) -> None: self.LRmodel_s1 = LogisticRegression() self.LRmodel_s2 = LogisticRegression() self.LRmodel_sys = LogisticRegression() self.LRmodel_dia = LogisticRegression() self.predictors = [self.LRmodel_s1,self.LRmodel_sys,self.LRmodel_s2,self.LRmodel_dia] if(features is None or states is None): return s1_train, s2_train, sys_train, dia_train = np.zeros_like(states), np.zeros_like(states), np.zeros_like(states), np.zeros_like(states) s1_train[states != 1] = 1 s2_train[states != 3] = 1 sys_train[states != 2] = 1 dia_train[states != 4] = 1 total = np.concatenate((features[states==1],features[states==2],features[states==3],features[states==4])) self.mu = np.mean(total,axis=0) self.sigma = np.cov(total.T) print("Training S1 LR...") self.LRmodel_s1 = LogisticRegression(random_state=0,max_iter=100,class_weight="balanced",tol=1e-6).fit(features,s1_train) print("Training S2 LR...") self.LRmodel_s2 = LogisticRegression(random_state=0,max_iter=100,class_weight="balanced",tol=1e-6).fit(features,s2_train) print("Training sys LR...") self.LRmodel_sys = LogisticRegression(random_state=0,max_iter=100,class_weight="balanced",tol=1e-6).fit(features,sys_train) print("Training dia LR...") self.LRmodel_dia = LogisticRegression(random_state=0,max_iter=100,class_weight="balanced",tol=1e-6).fit(features,dia_train) # self.LRmodel_complete = LogisticRegression(random_state=0,max_iter=100,class_weight="balanced",multi_class="multinomial",tol=1e-6).fit(features,states) self.predictors = [self.LRmodel_s1,self.LRmodel_sys,self.LRmodel_s2,self.LRmodel_dia] #TODO: possible to replace with a single LR predictor def likelihood(self, obs): probs = np.empty((len(obs),len(self.predictors))) for n,predictor in enumerate(self.predictors): pi_hat = predictor.predict_proba(obs)[:,0] for t in range(len(obs)): correction = multivariate_normal.pdf(obs[t,:],mean=self.mu,cov=self.sigma) #type:ignore probs[t,n] = (pi_hat[t]*correction)/0.25 return probs.T def serialize(self): serialized_models = {"lr_s1":None, "lr_sys":None, "lr_s2":None, "lr_dia":None} for lr_type,model in zip(serialized_models.keys(),self.predictors): params = model.get_params() attrs = [i for i in dir(model) if i.endswith('_') and not i.endswith('__') and not i.startswith('_')] attr_dict = {i: getattr(model, i) for i in attrs} for k in attr_dict: if isinstance(attr_dict[k], np.ndarray): attr_dict[k] = attr_dict[k].tolist() serialized_lr = {"params":params,"attrs":attr_dict} serialized_models[lr_type] = serialized_lr #type: ignore extra_params = {"mu":self.mu.tolist(), "sigma":self.sigma.tolist()} serialized = serialized_models | extra_params return serialized def unserialize(self,serial): saved_predictors = list(serial.keys()) saved_predictors.remove("mu") saved_predictors.remove("sigma") for loaded,lr in zip(saved_predictors,self.predictors): params = serial[loaded]["params"] attrs = serial[loaded]["attrs"] lr.set_params(**params) for k in attrs: if isinstance(attrs[k],list): setattr(lr,k,np.array(attrs[k])) else: setattr(lr,k,attrs[k]) self.mu = serial["mu"] self.sigma = serial["sigma"] def _envelope_feature(sig): env = abs(sgn.hilbert(sig)) #type: ignore return env def _h_envelope_feature(sig,sig_fs): env = _envelope_feature(sig) lp = sgn.butter(1,8,output='sos',fs=sig_fs,btype='lowpass') filt = np.exp(sgn.sosfiltfilt(lp,np.log(env))) filt[0] = filt[1] #? return filt def _psd_feature(sig,sig_fs): f_lo = 40 f_hi = 60 [f,_,Zxx] = sgn.stft(sig,fs=sig_fs,window="hamming",nperseg=sig_fs//40,scaling="psd",nfft=1024) lo_pos = np.argmin(np.abs(f-f_lo)) hi_pos = np.argmin(np.abs(f-f_hi)) psd = np.mean(np.abs(Zxx[lo_pos:hi_pos,:])**2,axis=0) psd_re = sgn.resample_poly(psd,len(sig),len(psd)) return psd_re def _wt_feature(sig): coefs = pywt.wavedec(sig,"rbio3.9",level=3) cD = coefs[1:] cD.reverse() detail = np.zeros((3,len(sig))) for i in range(3): d = np.tile(cD[i],(2**(i+1),1)).T.ravel() start = len(d)-len(sig)-1//2 end = start + len(sig) detail[i,:] = d[start:end] d3 = np.abs(detail[2,:]) return d3 def _normalize(sig): m = np.mean(sig) s = np.std(sig) n_sig = (sig-m)/s return n_sig def _spike_removal(sig,sig_fs): window_s = round(sig_fs/2) trailing = len(sig) % window_s frames = np.reshape(sig[:-trailing],(window_s,-1)) MAAs = np.max(np.abs(frames),axis=0) if len(MAAs) == 0: return sig while(np.any(MAAs>np.median(MAAs,axis=0)*3)): framenum = np.argmax(MAAs) pos = np.argmax(np.abs(frames[:,framenum])) zerocrossings = np.append(np.abs(np.diff(np.sign(frames[:,framenum])))>1,0) spike_start = 0 find = np.nonzero(zerocrossings[:pos])[0] if len(find)>0: spike_start = max(1,find[-1]) zerocrossings[:pos] = 0 find = np.nonzero(zerocrossings)[0] spike_end = window_s+1 if len(find)>0: spike_end = min(find[0],window_s)+1 frames[spike_start:spike_end,framenum] = 0.0001 MAAs = np.max(np.abs(frames),axis=0) removed = np.reshape(frames,(-1,1)) removed = np.append(removed,sig[len(removed):]) return removed def _generate_features(sig,sig_fs,f_fs,preproc=(25,400)): bpf = sgn.butter(4,preproc,"bandpass",output="sos",fs=sig_fs) f_sig = sgn.sosfiltfilt(bpf,sig) rem_sig = _spike_removal(f_sig,sig_fs) h_env = _h_envelope_feature(rem_sig,sig_fs) env = _envelope_feature(rem_sig) psd = _psd_feature(rem_sig,sig_fs) wt = _wt_feature(rem_sig) d_h_env = _normalize(sgn.resample_poly(h_env,f_fs,sig_fs)) d_env = _normalize(sgn.resample_poly(env,f_fs,sig_fs)) d_psd = _normalize(sgn.resample_poly(psd,f_fs,sig_fs)) d_wt = _normalize(sgn.resample_poly(wt,f_fs,sig_fs)) return np.array(d_h_env),np.array(d_env),np.array(d_psd),np.array(d_wt) def _generate_states(sig,annot_s1,annot_s2,sig_fs,f_fs,mean_s1=122,mean_s2=99,std_s1=22,std_s2=22): henv = _h_envelope_feature(sig,sig_fs) env = sgn.resample_poly(henv,f_fs,sig_fs) states = np.zeros_like(env) scale = f_fs/1000 as1 = np.round(np.array(annot_s1)*f_fs).astype(int) as2 = np.round(np.array(annot_s2)*f_fs).astype(int) ms1 = round(mean_s1*scale) ms2 = round(mean_s2*scale) ss1 = round(std_s1*scale) ss2 = round(std_s2*scale) for s1 in as1: upper_s1 = min(len(states)-1,s1+ms1+ss1) # +std_s1 lower_s1 = max(1,s1-ms1-ss1) if lower_s1>=upper_s1: continue search = env[lower_s1:upper_s1] s1_ind = np.argmax(search) s1_ind = min(len(states)-1,lower_s1+s1_ind) #type: ignore upper_s1 = min(len(states)-1,ceil(s1_ind+(ms1/2))) lower_s1 = max(0,ceil(s1_ind-(ms1/2))) states[lower_s1:upper_s1] = 1 for s2 in as2: upper_s2 = min(len(states)-1,s2+ms2+ss2) lower_s2 = max(0,s2-ms2-ss2) if lower_s2>=upper_s2: continue search = env[lower_s2:upper_s2]*(1-states[lower_s2:upper_s2]) s2_ind = np.argmax(search) s2_ind = min(len(states)-1,lower_s2+s2_ind) #type: ignore upper_s2 = min(len(states)-1,ceil(s2_ind+(ms2/2))) lower_s2 = max(0,ceil(s2_ind-(ms2/2))) states[lower_s2:upper_s2] = 3 s1_labels = as1 diffs = s1_labels - s2 diffs[diffs<0] = INF end_pos = 0 if len(diffs<INF)==0: end_pos = len(states)-1 else: end_pos = s1_labels[np.argmin(diffs)] states[ceil(s2_ind+(ms2/2)):end_pos] = 4 empty_states = np.nonzero(states)[0] if len(empty_states)>0: first_definite = empty_states[0] if first_definite > 0: if states[first_definite+1] == 1: states[0:first_definite] = 4 if states[first_definite+1] == 3: states[0:first_definite] = 2 last_definite = empty_states[-1] if last_definite > 0: if states[last_definite] == 1: states[last_definite:] = 2 if states[last_definite] == 3: states[last_definite:] = 4 states[states==0] = 2 return states def _get_hr_sys(sig,sig_fs,preproc=(25,400),min_hr=30,max_hr=120): bpf = sgn.butter(4,preproc,"bandpass",output="sos",fs=sig_fs) f_sig = sgn.sosfiltfilt(bpf,sig) rem_sig = _spike_removal(f_sig,sig_fs) h_env = _h_envelope_feature(rem_sig,sig_fs) y = h_env - np.mean(h_env) coef = sgn.correlate(y,y) coef = coef[len(h_env):]/np.max(coef) min_ind = round((60/max_hr)*sig_fs) max_ind = round((60/min_hr)*sig_fs) ind = np.argmax(coef[min_ind:max_ind]) + min_ind heartrate = 60/(ind/sig_fs) max_sys = round(((60/heartrate)*sig_fs)/2) min_sys = round(0.2*sig_fs) ind = np.argmax(coef[min_sys:max_sys]) + min_sys systole = ind/sig_fs return heartrate, systole def _get_duration_params(hr,sys,mean_s1=122,mean_s2=99,std_s1=22,std_s2=22,fs=50): m_s1 = round(mean_s1/1000*fs) s_s1 = round(std_s1/1000*fs) m_s2 = round(mean_s2/1000*fs) s_s2 = round(std_s2/1000*fs) mean_sys = round(sys*fs) - m_s1 std_sys = (25/1000)*fs #TODO: extract to model parameter mean_dia = ((60/hr)-sys-mean_s2/1000)*fs std_dia = 0.07*mean_dia + (6/1000)*fs #TODO: extract to model parameter # min_sys = mean_sys - 3*(std_sys+std_s1) #unused max_sys = mean_sys + 3*(std_sys+std_s1) # min_dia = mean_dia - 3*std_dia #unused max_dia = mean_dia + 3*std_dia # min_s1 = m_s1 - 3*s_s1 #unused max_s1 = m_s1 + 3*s_s1 # min_s2 = m_s2 - 3*s_s2 #unused max_s2 = m_s2 + 3*s_s2 max_duration = max([max_s1+2*std_s1,max_s2+2*std_s2,max_sys+2*(std_sys+std_s1),max_dia+2*std_dia]) # min_duration = min([min_s1-2*std_s1,min_s2-2*std_s2,min_sys-2*(std_sys-std_s1),min_dia-2*std_dia]) #unused return max_duration, mean_sys, std_sys, mean_dia, std_dia def _get_duration_distributions(hr,sys,f_fs,mean_s1=122,mean_s2=99,std_s1=22,std_s2=22): max_duration, mean_sys, std_sys, mean_dia, std_dia = _get_duration_params(hr,sys,mean_s1,mean_s2,std_s1,std_s2,f_fs) m_s1 = round(mean_s1/1000*f_fs) s_s1 = round(std_s1/1000*f_fs) m_s2 = round(mean_s2/1000*f_fs) s_s2 = round(std_s2/1000*f_fs) max_duration = round(max_duration) s1_dur, s2_dur, sys_dur, dia_dur = np.zeros((max_duration)),np.zeros((max_duration)),np.zeros((max_duration)),np.zeros((max_duration)) for i in range(1,max_duration+1): s1_dur[i-1] = norm.pdf(i,loc=m_s1,scale=s_s1) s2_dur[i-1] = norm.pdf(i,loc=m_s2,scale=s_s2) sys_dur[i-1] = norm.pdf(i,loc=mean_sys,scale=std_sys) dia_dur[i-1] = norm.pdf(i,loc=mean_dia,scale=std_dia) return np.array([s1_dur,sys_dur,s2_dur,dia_dur]) def _expand_states(states,orig_fs,new_fs,new_len): expanded = np.zeros(new_len) changes = np.nonzero(np.diff(states))[0] changes = np.append(changes,len(states)-1) start = 0 for end in changes: mid = round((end-start)/2) + start mid_val = states[mid] exp_start = round((start/orig_fs)*new_fs) exp_end = round((end/orig_fs)*new_fs) if round((end/orig_fs)*new_fs) < new_len else new_len expanded[exp_start:exp_end] = mid_val start = end return expanded if __name__ == '__main__': print("LR-HSMM model based on Springer et al.")