Using mlad for interval censored data
Interval censoring
Interval censoring occurs when the exact time of an event is unknown, but it is known to occur between two time-points. The stintreg command was introduced in Stata 16 for fitting parametric models with interval censoring. However, stintreg uses an lf0 evaluator, which means derivatives are obtained numerically, so there is some potential for speed improvements when using mlad.
The log-likelihood function needs to handle a combination of exact uncensored (UC) observations, right censored (RC) observations, left censored (LC) observations, and interval censored (IC) observations. Stata uses the following notation in the stintreg manual,
Type | \(t_l\) | \(t_u\) | Event time |
---|---|---|---|
uncensored | a | a | a |
Interval censored | a | b | (a,b] |
left censored | . | b | (0,b] |
left censored | 0 | b | (0,b] |
right censored | a | . | (a,\(\infty\)) |
where \(t_l\) and \(t_u\) represent the lower and upper endpoints of the interval.
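The table can be read as a set of classification rules. The sketch below is a plain-Python illustration of those rules, not part of mlad or stintreg: the function name `censor_type` is hypothetical, and `None` stands in for Stata's missing value.

```python
def censor_type(t_l, t_u):
    """Classify one observation from its interval endpoints.

    Hypothetical helper mirroring the table above; None plays the
    role of Stata's missing value (.).
    """
    if t_l is not None and t_u is not None and t_l == t_u:
        return "uncensored"
    if t_u is None:
        return "right censored"
    if t_l is None or t_l == 0:
        return "left censored"
    if t_u > t_l:
        return "interval censored"
    raise ValueError("upper endpoint must not be below the lower endpoint")


# One example per row of the table
examples = [(3.0, 3.0), (2.0, 5.0), (None, 4.0), (0.0, 4.0), (6.0, None)]
for t_l, t_u in examples:
    print(t_l, t_u, censor_type(t_l, t_u))
```

The order of the checks matters: a lower endpoint of 0 is treated as left censoring before the interval-censored rule is considered.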
The log-likelihood for a general parametric distribution with the 4 types of censoring can be expressed as follows,
\[ \begin{align*} \ln L =& \sum_{i \in UC} \ln\left[f_i(t_{li})\right] \ \ + \\\\ &\sum_{i \in RC} \ln\left[S_i(t_{li})\right] \ \ + \\\\ &\sum_{i \in LC} \ln\left[1-S_i(t_{ui})\right] \ \ + \\\\ &\sum_{i \in IC} \ln\left[S_i(t_{li})-S_i(t_{ui})\right] \end{align*} \]
where \(f_i(t)\) is the density function for the \(i^{th}\) subject and \(S_i(t)\) is the survival function.
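To make the four sums concrete, here is a minimal plain-Python sketch of the log-likelihood for a Weibull distribution. It assumes the proportional hazards parametrisation \(S(t) = \exp(-\lambda t^{\gamma})\); the function names and the tuple layout of `obs` are illustrative choices, not the mlad interface.

```python
import math


def weib_surv(t, lam, gam):
    # Assumed PH parametrisation: S(t) = exp(-lam * t**gam)
    return math.exp(-lam * t ** gam)


def weib_dens(t, lam, gam):
    # f(t) = h(t) * S(t), with hazard h(t) = lam * gam * t**(gam - 1)
    return lam * gam * t ** (gam - 1) * weib_surv(t, lam, gam)


def loglik(obs, lam, gam):
    """Sum the contributions; obs is a list of (kind, t_l, t_u)
    with kind one of "UC", "RC", "LC", "IC"."""
    total = 0.0
    for kind, t_l, t_u in obs:
        if kind == "UC":      # exact event time: density at t_l
            total += math.log(weib_dens(t_l, lam, gam))
        elif kind == "RC":    # event after t_l: S(t_l)
            total += math.log(weib_surv(t_l, lam, gam))
        elif kind == "LC":    # event before t_u: 1 - S(t_u)
            total += math.log(1.0 - weib_surv(t_u, lam, gam))
        elif kind == "IC":    # event in (t_l, t_u]: S(t_l) - S(t_u)
            total += math.log(weib_surv(t_l, lam, gam) - weib_surv(t_u, lam, gam))
        else:
            raise ValueError(f"unknown censoring type: {kind}")
    return total


data = [("UC", 1.0, 1.0), ("RC", 2.0, None), ("LC", None, 1.0), ("IC", 1.0, 2.0)]
print(loglik(data, lam=1.0, gam=1.0))
```

With `lam=1` and `gam=1` the Weibull reduces to an exponential distribution, which makes each contribution easy to check by hand.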
Using stintreg
To develop the mlad Python function I will use the small example included in the manual for stintreg, comparing the cosmetic effects of two cancer treatments: radiotherapy alone versus radiotherapy plus adjuvant chemotherapy on breast retraction. I will illustrate using a Weibull model, but the approach is easily extended to other distributions. First the model can be fitted using stintreg as follows,
. use https://www.stata-press.com/data/r17/cosmesis, clear
(Cosmetic deterioration of breast cancer patients)

. stintreg treat age, interval(ltime rtime) distribution(weibull)

Fitting constant-only model:
Iteration 0: Log likelihood = -200.17506
Iteration 1: Log likelihood = -175.10124
Iteration 2: Log likelihood = -155.15341
Iteration 3: Log likelihood = -148.73905
Iteration 4: Log likelihood = -148.65585
Iteration 5: Log likelihood = -148.65584

Fitting full model:
Iteration 0: Log likelihood = -148.65584
Iteration 1: Log likelihood = -143.35409
Iteration 2: Log likelihood = -142.98393
Iteration 3: Log likelihood = -142.98285
Iteration 4: Log likelihood = -142.98285

Weibull PH regression                             Number of obs  =     94
                                                      Uncensored =      0
                                                   Left-censored =      5
                                                  Right-censored =     38
                                                  Interval-cens. =     51

                                                  LR chi2(2)     =  11.35
Log likelihood = -142.98285                       Prob > chi2    = 0.0034

------------------------------------------------------------------------------
             | Haz. ratio   Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
       treat |   2.449817   .6972537     3.15   0.002     1.402391    4.279549
         age |   .9760212   .0366728    -0.65   0.518     .9067266    1.050611
       _cons |   .0047193   .0076133    -3.32   0.001     .0001998    .1114466
-------------+----------------------------------------------------------------
       /ln_p |   .4820043    .120089     4.01   0.000     .2466342    .7173743
-------------+----------------------------------------------------------------
           p |   1.619317   .1944621                      1.279711    2.049046
         1/p |   .6175444   .0741603                       .488032    .7814265
------------------------------------------------------------------------------
Note: _cons estimates baseline hazard.

. estimates store stintreg
I have stored the model estimates so they can be compared to the estimates from mlad.
Using mlad
In order to fit the model using mlad it is necessary to create a variable to denote the type of censoring. I create the variable ctype below.
. gen byte ctype = .
(94 missing values generated)

. qui replace ctype = 1 if ltime==rtime // UC

. qui replace ctype = 2 if rtime >= . // RC

. qui replace ctype = 3 if (ltime >= . | ltime==0) // LC

. qui replace ctype = 4 if (rtime-ltime)>0 & !inlist(ctype,2,3) // IC

. tab ctype

      ctype |      Freq.     Percent        Cum.
------------+-----------------------------------
          2 |         38       40.43       40.43
          3 |          5        5.32       45.74
          4 |         51       54.26      100.00
------------+-----------------------------------
      Total |         94      100.00
I will later pass ctype to Python.
The Python likelihood function used by mlad is shown below.
. type weib_ic_ll.py
import jax.numpy as jnp
import mladutil as mu

def python_ll(beta, X, wt, M):
  lam = jnp.exp(mu.linpred(beta,X,1))
  gam = jnp.exp(mu.linpred(beta,X,2))

  lli = (jnp.where(M["ctype"]==1,jnp.log(mu.weibdens(M["ltime"],lam,gam)),0) +
         jnp.where(M["ctype"]==2,jnp.log(mu.weibsurv(M["ltime"],lam,gam)),0) +
         jnp.where(M["ctype"]==3,jnp.log(1 - mu.weibsurv(M["rtime"],lam,gam)),0) +
         jnp.where(M["ctype"]==4,jnp.log(mu.weibsurv(M["ltime"],lam,gam)-mu.weibsurv(M["rtime"],lam,gam)),0))
  return(jnp.sum(wt*lli))
The log-likelihood function is simple: it extracts the parameters using mu.linpred, defines the individual contributions to the log-likelihood, and then sums these so that a single value for the log-likelihood can be returned.
An important point is that the jnp.log() function must never return a missing or infinite value anywhere in the function, e.g. when taking the log of zero. This holds even when jnp.where() discards that value, because the discarded branch is still evaluated. Below I create ltime2 and rtime2, which replace the incompatible values (ltime equal to 0 and missing rtime). These substituted values do not contribute to the likelihood, because jnp.where() selects a different term for those observations.
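The reason a discarded log of zero can still cause trouble is easiest to see in the derivative. The sketch below is a plain-Python imitation of what automatic differentiation does with a masked term; it does not call JAX, and the function names are hypothetical. The derivative of the masked term is a weight of zero multiplied by the derivative of log(x), and in floating point zero times infinity is not a number.

```python
import math


def log_and_derivative(x):
    # Value and derivative of log(x), with the IEEE limits at x == 0
    if x == 0.0:
        return -math.inf, math.inf
    return math.log(x), 1.0 / x


def masked_term(use_log, x):
    """Imitate where(use_log, log(x), 0): return the value and d/dx.

    Hypothetical helper; the masked branch is still computed, as it
    would be when both arguments are evaluated before selection.
    """
    value, slope = log_and_derivative(x)
    weight = 1.0 if use_log else 0.0
    out = value if use_log else 0.0   # forward pass: pure selection
    grad = weight * slope             # backward pass: 0 * slope when masked
    return out, grad


print(masked_term(False, 0.0))    # value is fine, derivative is nan
print(masked_term(False, 1e-8))   # small positive stand-in: derivative is 0.0
```

Replacing the zero with a small positive number leaves the selected value unchanged and keeps the derivative finite, which is the role of the substituted time variables.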
The Python function and the required variables are then passed to mlad.
. gen double ltime2 = cond(ltime==0,1e-8,ltime)

. gen double rtime2 = cond(rtime==.,99,rtime)

. mlad (ln_lambda: = treat age, ) ///
> (ln_gamma: = ), ///
> othervars(ctype ltime2 rtime2) ///
> othervarnames(ctype ltime rtime) ///
> llfile(weib_ic_ll)

Initial: Log likelihood = -2088.4515
Alternative: Log likelihood = -1117.6877
Rescale: Log likelihood = -200.17506
Rescale eq: Log likelihood = -200.17506
Iteration 0: Log likelihood = -200.17506
Iteration 1: Log likelihood = -173.82452
Iteration 2: Log likelihood = -151.01988
Iteration 3: Log likelihood = -143.65323
Iteration 4: Log likelihood = -143.21346
Iteration 5: Log likelihood = -142.98514
Iteration 6: Log likelihood = -142.98285
Iteration 7: Log likelihood = -142.98285

. ml display

                                                  Number of obs  =     94
                                                  Wald chi2(2)   =  10.87
Log likelihood = -142.98285                       Prob > chi2    = 0.0044

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
ln_lambda    |
       treat |   .8960133   .2846146     3.15   0.002      .338179    1.453848
         age |   -.024271   .0375738    -0.65   0.518    -.0979143    .0493723
       _cons |    -5.3561   1.613238    -3.32   0.001     -8.51799   -2.194211
-------------+----------------------------------------------------------------
ln_gamma     |
       _cons |   .4820043   .1200888     4.01   0.000     .2466346     .717374
------------------------------------------------------------------------------

. estimates store mlad
The coefficients can be compared using estimates table.
. estimates table stintreg mlad, equations(1,2) se

----------------------------------------
    Variable |  stintreg       mlad
-------------+--------------------------
#1           |
       treat |  .89601332    .89601334
             |  .28461462     .2846146
         age | -.02427097   -.02427097
             |  .03757381    .03757381
       _cons | -5.3561002   -5.3561004
             |  1.6132388    1.6132384
-------------+--------------------------
#2           |
        ln_p |  .48200426
             |  .12008898
       _cons |               .48200431
             |               .12008879
----------------------------------------
                          Legend: b/se
The estimates and standard errors agree to at least 6 decimal places.
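The stintreg output reports hazard ratios and p, whereas ml display reports coefficients on the log scale. As a quick check that the two outputs describe the same fit, the short sketch below exponentiates the mlad coefficients copied from the output above; the variable names are only illustrative.

```python
import math

# Coefficients from the mlad output above (log scale)
mlad_coef = {"treat": 0.8960133, "age": -0.024271, "_cons": -5.3561}
ln_gamma = 0.4820043

# Exponentiating gives the hazard ratio scale used by stintreg
hazard_ratios = {name: math.exp(b) for name, b in mlad_coef.items()}
p = math.exp(ln_gamma)

for name, hr in hazard_ratios.items():
    print(f"{name:>6}: {hr:.6f}")
print(f"     p: {p:.6f}   1/p: {1 / p:.6f}")
```

The printed values can be compared directly with the Haz. ratio column and the p and 1/p rows of the stintreg table.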
Speed comparison
The above was useful for developing the Python log-likelihood file, but speed gains will only be worthwhile for larger data sets.
The code below simulates interval censored data for various sample sizes with 10 covariates. The 10 covariates are also included to model the shape parameter in both stintreg and mlad. The methods are then compared in terms of execution time.
. foreach ss in 1000 10000 100000 1000000 {
    quietly {
      clear
      timer clear
      set obs `ss'
      forvalues i = 1/10 {
        gen x`i' = rnormal()
        local cov `cov' x`i' 0.1
      }
      survsim rtime d, dist(weib) lambda(0.2) gamma(0.8) maxt(5) cov(`cov')
      gen double ltime = cond(runiform()<75 & d==1,runiform()*rtime,rtime)
      replace ltime = rtime if d==0
      replace rtime = . if d==0

      timer on 1
      stintreg x1-x10, interval(ltime rtime) distribution(weibull) anc(x1-x10)
      timer off 1

      timer on 2
      gen byte ctype = .
      qui replace ctype = 1 if ltime==rtime // UC
      qui replace ctype = 2 if rtime >= . // RC
      qui replace ctype = 3 if (ltime >= . | ltime==0) // LC
      qui replace ctype = 4 if (rtime-ltime)>0 & !inlist(ctype,2,3) // IC

      replace ltime = 1 if ltime ==0
      replace rtime = 1 if rtime ==.
      mlad (ln_lambda: = x1-x10, ) ///
           (ln_gamma: = x1-x10), ///
           othervars(ctype ltime rtime) ///
           llfile(weib_ic_ll) search(off)
      timer off 2
    }
    di _newline "Sample Size: `ss'"
    timer list
    di "(1) stintreg, (2) mlad"
  }
Sample Size: 1000
1: 0.12 / 1 = 0.1190
2: 0.51 / 1 = 0.5050
(1) stintreg, (2) mlad
Sample Size: 10000
1: 0.62 / 1 = 0.6170
2: 0.61 / 1 = 0.6150
(1) stintreg, (2) mlad
Sample Size: 100000
1: 6.54 / 1 = 6.5450
2: 1.13 / 1 = 1.1330
(1) stintreg, (2) mlad
Sample Size: 1000000
1: 84.20 / 1 = 84.1980
2: 7.00 / 1 = 7.0030
(1) stintreg, (2) mlad
stintreg is faster for a sample size of 1,000, and the two commands take essentially the same time for a sample size of 10,000. mlad reduces execution time by ~83% for a sample size of 100,000 and by ~92% for a sample size of 1,000,000.
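The relative differences can be recomputed from the timer output above. The sketch below copies the timings (in seconds) into a dictionary and reports the percentage reduction in execution time for mlad relative to stintreg; a negative value means mlad was slower.

```python
# Timings in seconds from the timer output: (stintreg, mlad)
timings = {
    1_000: (0.1190, 0.5050),
    10_000: (0.6170, 0.6150),
    100_000: (6.5450, 1.1330),
    1_000_000: (84.1980, 7.0030),
}


def reduction(t_stintreg, t_mlad):
    # Percentage reduction in run time when moving from stintreg to mlad
    return 100.0 * (t_stintreg - t_mlad) / t_stintreg


for n, (t1, t2) in timings.items():
    print(f"n = {n:>9,}: {reduction(t1, t2):7.1f}%")
```

These are single runs, so the figures for the smaller sample sizes in particular should be read as rough indications rather than precise benchmarks.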