Using mlad for interval censored data
Interval censoring
Interval censoring occurs when the exact time of an event is unknown, but it is known that it occurs between two time-points.
The stintreg command was introduced in Stata 16 for fitting parametric models with interval censoring.
However, stintreg uses an lf0 evaluator, which means derivatives are obtained numerically,
so there is some potential for speed improvements when using mlad.
The log-likelihood function needs to handle a combination of exact uncensored (UC) observations,
right censored (RC) observations,
left censored (LC) observations, and
interval censored (IC) observations. The stintreg manual uses the following notation,
Type | $t_l$ | $t_u$ | Event time |
---|---|---|---|
Uncensored | a | a | a |
Interval censored | a | b | (a,b] |
Left censored | . | b | (0,b] |
Left censored | 0 | b | (0,b] |
Right censored | a | . | (a,$\infty$) |
where $t_l$ and $t_u$ represent the lower and upper endpoints of the interval, and "." denotes a missing value.
The log-likelihood for a general parametric distribution with the 4 types of censoring can be expressed as follows,
$$
\begin{align*}
\ln L =& \sum_{i \in UC} \ln\left[f_i(t_{li})\right] \ \ + \\
&\sum_{i \in RC} \ln\left[S_i(t_{li})\right] \ \ + \\
&\sum_{i \in LC} \ln\left[1-S_i(t_{ui})\right] \ \ + \\
&\sum_{i \in IC} \ln\left[S_i(t_{li})-S_i(t_{ui})\right]
\end{align*}
$$
where $f_i(t)$ is the density function for the $i^{th}$ subject and $S_i(t)$ is the survival function.
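The four sums above can be sketched in a few lines of NumPy. This is only an illustration, not the code used by stintreg or mlad: it assumes a Weibull proportional hazards model with $S(t) = \exp(-\lambda t^{\gamma})$, and the function names and the numeric codes 1 to 4 for UC, RC, LC and IC are hypothetical labels chosen for the example.

```python
import numpy as np

def weibull_surv(t, lam, gam):
    """Weibull PH survival function, S(t) = exp(-lam * t**gam)."""
    return np.exp(-lam * t**gam)

def weibull_dens(t, lam, gam):
    """Weibull PH density, f(t) = lam * gam * t**(gam - 1) * S(t)."""
    return lam * gam * t**(gam - 1) * weibull_surv(t, lam, gam)

def interval_loglik(t_l, t_u, ctype, lam, gam):
    """Sum the log-likelihood contributions of the four censoring types.

    ctype uses illustrative codes: 1 = UC, 2 = RC, 3 = LC, 4 = IC.
    """
    t_l = np.asarray(t_l, dtype=float)
    t_u = np.asarray(t_u, dtype=float)
    ctype = np.asarray(ctype)
    uc, rc, lc, ic = (ctype == k for k in (1, 2, 3, 4))
    ll = np.zeros(t_l.shape)
    ll[uc] = np.log(weibull_dens(t_l[uc], lam, gam))        # ln f(t_l)
    ll[rc] = np.log(weibull_surv(t_l[rc], lam, gam))        # ln S(t_l)
    ll[lc] = np.log(1 - weibull_surv(t_u[lc], lam, gam))    # ln [1 - S(t_u)]
    ll[ic] = np.log(weibull_surv(t_l[ic], lam, gam)
                    - weibull_surv(t_u[ic], lam, gam))      # ln [S(t_l) - S(t_u)]
    return ll.sum()

# Toy data (made up for illustration): one observation of each type.
toy_ll = interval_loglik(
    t_l=[2.0, 3.0, np.nan, 1.0],
    t_u=[2.0, np.nan, 4.0, 2.5],
    ctype=[1, 2, 3, 4],
    lam=0.1,
    gam=1.5,
)
```

Indexing by censoring type means the missing endpoints (NaN here) are never passed to the functions that would need them.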
Using stintreg
To develop the mlad Python function I will use the small example included in the manual for stintreg, comparing
the cosmetic effects of two cancer treatments, radiotherapy alone versus radiotherapy plus adjuvant chemotherapy,
on breast retraction.
I will illustrate using a Weibull model, but the approach is easily extended to other distributions.
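For reference, the general log-likelihood can be made concrete for the Weibull case. The expressions below assume the proportional hazards parameterisation, with $\lambda_i = \exp(x_i\beta)$ and shape parameter $\gamma$ (reported as p by stintreg); this is my reading of the parameterisation and is worth checking against the stintreg and mlad documentation.

$$
\begin{align*}
S_i(t) &= \exp\left(-\lambda_i t^{\gamma}\right) \\
f_i(t) &= \lambda_i \gamma t^{\gamma-1} \exp\left(-\lambda_i t^{\gamma}\right)
\end{align*}
$$

Under this assumption an interval censored observation contributes $\ln\left[\exp\left(-\lambda_i t_{li}^{\gamma}\right) - \exp\left(-\lambda_i t_{ui}^{\gamma}\right)\right]$ to the log-likelihood.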
First, the model can be fitted using stintreg as follows,
. use https://www.stata-press.com/data/r17/cosmesis, clear
(Cosmetic deterioration of breast cancer patients)
. stintreg treat age, interval(ltime rtime) distribution(weibull)
Fitting constant-only model:
Iteration 0: Log likelihood = -200.17506
Iteration 1: Log likelihood = -175.10124
Iteration 2: Log likelihood = -155.15341
Iteration 3: Log likelihood = -148.73905
Iteration 4: Log likelihood = -148.65585
Iteration 5: Log likelihood = -148.65584
Fitting full model:
Iteration 0: Log likelihood = -148.65584
Iteration 1: Log likelihood = -143.35409
Iteration 2: Log likelihood = -142.98393
Iteration 3: Log likelihood = -142.98285
Iteration 4: Log likelihood = -142.98285
Weibull PH regression Number of obs = 94
Uncensored = 0
Left-censored = 5
Right-censored = 38
Interval-cens. = 51
LR chi2(2) = 11.35
Log likelihood = -142.98285 Prob > chi2 = 0.0034
------------------------------------------------------------------------------
| Haz. ratio Std. err. z P>|z| [95% conf. interval]
-------------+----------------------------------------------------------------
treat | 2.449817 .6972537 3.15 0.002 1.402391 4.279549
age | .9760212 .0366728 -0.65 0.518 .9067266 1.050611
_cons | .0047193 .0076133 -3.32 0.001 .0001998 .1114466
-------------+----------------------------------------------------------------
/ln_p | .4820043 .120089 4.01 0.000 .2466342 .7173743
-------------+----------------------------------------------------------------
p | 1.619317 .1944621 1.279711 2.049046
1/p | .6175444 .0741603 .488032 .7814265
------------------------------------------------------------------------------
Note: _cons estimates baseline hazard.
. estimates store stintreg
I have stored the model estimates so they can be compared to the estimates from mlad.
Using mlad
In order to fit the model using mlad it is necessary to create a variable to denote the type of censoring.
I create the variable ctype below.
. gen byte ctype = .
(94 missing values generated)
. qui replace ctype = 1 if ltime==rtime // UC
. qui replace ctype = 2 if rtime >= . // RC
. qui replace ctype = 3 if (ltime >= . | ltime==0) // LC
. qui replace ctype = 4 if (rtime-ltime)>0 & !inlist(ctype,2,3) // IC
. tab ctype
ctype | Freq. Percent Cum.
------------+-----------------------------------
2 | 38 40.43 40.43
3 | 5 5.32 45.74
4 | 51 54.26 100.00
------------+-----------------------------------
Total | 94 100.00
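The classification rules can also be written as a small Python function, which may help when checking the logic on a few hand-picked pairs. This is a sketch only: the function name is hypothetical, None or NaN stands in for Stata's missing value, and the order of the checks is chosen to mirror the Stata code above, where later replace statements override earlier ones.

```python
import math
from collections import Counter

UC, RC, LC, IC = 1, 2, 3, 4  # same codes as the ctype variable

def censoring_type(ltime, rtime):
    """Classify one (ltime, rtime) pair into UC, RC, LC or IC."""
    def missing(v):
        return v is None or math.isnan(v)

    if missing(ltime) or ltime == 0:
        return LC              # event somewhere in (0, rtime]
    if missing(rtime):
        return RC              # event after ltime
    if ltime == rtime:
        return UC              # exact event time
    if rtime > ltime:
        return IC              # event in (ltime, rtime]
    raise ValueError("rtime must not be smaller than ltime")

# Toy pairs (made up for illustration), one per row of the notation table.
pairs = [(4, 4), (4, 11), (None, 5), (0, 5), (8, None)]
counts = Counter(censoring_type(l, r) for l, r in pairs)
```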
I will later pass ctype to Python.
The Python likelihood function used by mlad is shown below.
. type weib_ic_ll.py
import jax.numpy as jnp
import mladutil as mu

def python_ll(beta, X, wt, M):
    # Parameters from the two equations (ln_lambda and ln_gamma)
    lam = jnp.exp(mu.linpred(beta,X,1))
    gam = jnp.exp(mu.linpred(beta,X,2))
    # Contributions by censoring type: 1 = UC, 2 = RC, 3 = LC, 4 = IC
    lli = (jnp.where(M["ctype"]==1,jnp.log(mu.weibdens(M["ltime"],lam,gam)),0) +
           jnp.where(M["ctype"]==2,jnp.log(mu.weibsurv(M["ltime"],lam,gam)),0) +
           jnp.where(M["ctype"]==3,jnp.log(1 - mu.weibsurv(M["rtime"],lam,gam)),0) +
           jnp.where(M["ctype"]==4,jnp.log(mu.weibsurv(M["ltime"],lam,gam)-mu.weibsurv(M["rtime"],lam,gam)),0))
    return(jnp.sum(wt*lli))
The log-likelihood is simple and consists of extracting the parameters using mu.linpred,
defining the individual contributions to the log-likelihood and then summing these so that a single value for the log-likelihood can be returned.
An important point is that the jnp.log() function must never return a missing or non-finite value, e.g. when taking the log of zero. This applies even when jnp.where() discards that result, so that the log of zero is never actually used. Below I create ltime2 and rtime2, which replace the incompatible values (a zero ltime and a missing rtime) with arbitrary placeholders. These placeholders do not feed into the likelihood function, because jnp.where() only selects them for observations of a different censoring type.
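A small NumPy analogy may help to show why an unused log of zero can still cause trouble. My understanding, which is an assumption about the automatic differentiation used behind mlad rather than something stated here, is that the derivative of a masked-out branch is computed as zero times the local derivative, and zero times infinity is NaN. The sketch below mimics that chain-rule product by hand; the variable names are hypothetical.

```python
import numpy as np

t = np.array([0.0, 2.0])            # the first time is zero
use_log = np.array([False, True])   # only the second observation needs log(t)

with np.errstate(divide="ignore", invalid="ignore"):
    # The value is fine: the -inf from log(0) is computed but then discarded.
    value = np.where(use_log, np.log(t), 0.0)
    # Hand-written chain rule: mask * d/dt log(t) = mask * (1 / t).
    # For the masked-out zero this is 0 * inf, which is NaN rather than 0.
    grad_unsafe = use_log * (1.0 / t)

# Fix: swap the unused zero for any positive placeholder before taking logs,
# which is the role played by ltime2 and rtime2.
t_safe = np.where(use_log, t, 1.0)
grad_safe = use_log * (1.0 / t_safe)
```

The placeholder value itself is arbitrary, since the mask removes its contribution from both the value and the derivative.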
The Python function and the required variables are then passed to mlad.
. gen double ltime2 = cond(ltime==0,1e-8,ltime)
. gen double rtime2 = cond(rtime==.,99,rtime)
. mlad (ln_lambda: = treat age, ) ///
> (ln_gamma: = ), ///
> othervars(ctype ltime2 rtime2) ///
> othervarnames(ctype ltime rtime) ///
> llfile(weib_ic_ll)
Initial: Log likelihood = -2088.4515
Alternative: Log likelihood = -1117.6877
Rescale: Log likelihood = -200.17506
Rescale eq: Log likelihood = -200.17506
Iteration 0: Log likelihood = -200.17506
Iteration 1: Log likelihood = -173.82452
Iteration 2: Log likelihood = -151.01988
Iteration 3: Log likelihood = -143.65323
Iteration 4: Log likelihood = -143.21346
Iteration 5: Log likelihood = -142.98514
Iteration 6: Log likelihood = -142.98285
Iteration 7: Log likelihood = -142.98285
. ml display
Number of obs = 94
Wald chi2(2) = 10.87
Log likelihood = -142.98285 Prob > chi2 = 0.0044
------------------------------------------------------------------------------
| Coefficient Std. err. z P>|z| [95% conf. interval]
-------------+----------------------------------------------------------------
ln_lambda |
treat | .8960133 .2846146 3.15 0.002 .338179 1.453848
age | -.024271 .0375738 -0.65 0.518 -.0979143 .0493723
_cons | -5.3561 1.613238 -3.32 0.001 -8.51799 -2.194211
-------------+----------------------------------------------------------------
ln_gamma |
_cons | .4820043 .1200888 4.01 0.000 .2466346 .717374
------------------------------------------------------------------------------
. estimates store mlad
The coefficients and standard errors can be compared using estimates table.
. estimates table stintreg mlad, equations(1,2) se
----------------------------------------
Variable | stintreg mlad
-------------+--------------------------
#1 |
treat | .89601332 .89601334
| .28461462 .2846146
age | -.02427097 -.02427097
| .03757381 .03757381
_cons | -5.3561002 -5.3561004
| 1.6132388 1.6132384
-------------+--------------------------
#2 |
ln_p | .48200426
| .12008898
_cons | .48200431
| .12008879
----------------------------------------
Legend: b/se
The estimates and standard errors agree to about 6 decimal places.
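The two sets of output are on different scales: stintreg displays hazard ratios and p, whereas mlad displays coefficients on the log scale. As a quick check, exponentiating the mlad coefficients from the table above should reproduce the stintreg display; the snippet below simply copies the printed values and is not part of the original analysis.

```python
import math

# Coefficients copied from the mlad column of the estimates table above.
mlad_coef = {
    "treat": 0.89601334,
    "age": -0.02427097,
    "_cons": -5.3561004,
    "p": 0.48200431,       # ln_gamma, displayed by stintreg as /ln_p
}

# Hazard ratios and p copied from the stintreg output.
stintreg_display = {
    "treat": 2.449817,
    "age": 0.9760212,
    "_cons": 0.0047193,
    "p": 1.619317,
}

exponentiated = {name: math.exp(b) for name, b in mlad_coef.items()}

# Relative difference between exp(coefficient) and the stintreg display.
rel_diff = {
    name: abs(exponentiated[name] - stintreg_display[name]) / stintreg_display[name]
    for name in mlad_coef
}
```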
Speed comparison
The above was useful for developing the Python log-likelihood file, but speed gains will only be worthwhile for larger data sets.
The code below simulates interval censored data for various sample sizes with 10 covariates.
The 10 covariates are also included to model the shape parameter in both stintreg and mlad.
The methods are then compared in terms of execution time.
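For readers who want to see the data-generating idea outside Stata, here is a rough NumPy sketch of a similar simulation. It is not what survsim does internally: it assumes Weibull proportional hazards times drawn by inverting the survival function, administrative censoring at maxt, and a lower endpoint drawn uniformly below each event time. The function name and defaults are hypothetical, chosen to echo the options used in the Stata code.

```python
import numpy as np

def simulate_interval_censored(n, lam=0.2, gam=0.8, maxt=5.0,
                               n_cov=10, log_hr=0.1, seed=1234):
    """Simulate (ltime, rtime) pairs with right and interval censoring."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n, n_cov))
    # Subject-specific lambda under proportional hazards.
    rate = lam * np.exp(log_hr * x.sum(axis=1))
    # Invert S(t) = exp(-rate * t**gam); a standard exponential draw is -ln(U).
    t = (rng.exponential(size=n) / rate) ** (1.0 / gam)
    event = t <= maxt
    # Events: known only to lie in (ltime, rtime]; others: right censored at maxt.
    rtime = np.where(event, t, np.nan)
    ltime = np.where(event, rng.uniform(size=n) * t, maxt)
    return x, ltime, rtime, event

x, ltime, rtime, event = simulate_interval_censored(1000)
```

Right censored observations keep maxt as the lower endpoint and have a missing (NaN) upper endpoint, matching the notation table earlier.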
. foreach ss in 1000 10000 100000 1000000 {
2. quietly {
3. clear
4. timer clear
5. set obs `ss'
6. forvalues i = 1/10 {
7. gen x`i' = rnormal()
8. local cov `cov' x`i' 0.1
9. }
10.
. survsim rtime d, dist(weib) lambda(0.2) gamma(0.8) maxt(5) cov(`cov')
11.
. gen double ltime = cond(runiform()<75 & d==1,runiform()*rtime,rtime)
12. replace ltime = rtime if d==0
13. replace rtime = . if d==0
14.
. timer on 1
15. stintreg x1-x10, interval(ltime rtime) distribution(weibull) anc(x1-x10)
16. timer off 1
17.
. timer on 2
18. gen byte ctype = .
19. qui replace ctype = 1 if ltime==rtime // UC
20. qui replace ctype = 2 if rtime >= . // RC
21. qui replace ctype = 3 if (ltime >= . | ltime==0) // LC
22. qui replace ctype = 4 if (rtime-ltime)>0 & !inlist(ctype,2,3) // IC
23.
. replace ltime = 1 if ltime ==0
24. replace rtime = 1 if rtime ==.
25. mlad (ln_lambda: = x1-x10, ) ///
> (ln_gamma: = x1-x10), ///
> othervars(ctype ltime rtime) ///
> llfile(weib_ic_ll) ///
> search(off)
26. timer off 2
27. }
28.
. di _newline "Sample Size: `ss'"
29. timer list
30. di "(1) stintreg, (2) mlad"
31. }
Sample Size: 1000
1: 0.15 / 1 = 0.1470
2: 0.56 / 1 = 0.5560
(1) stintreg, (2) mlad
Sample Size: 10000
1: 0.60 / 1 = 0.5970
2: 0.62 / 1 = 0.6220
(1) stintreg, (2) mlad
Sample Size: 100000
1: 6.11 / 1 = 6.1080
2: 1.38 / 1 = 1.3790
(1) stintreg, (2) mlad
Sample Size: 1000000
1: 80.53 / 1 = 80.5310
2: 7.27 / 1 = 7.2730
(1) stintreg, (2) mlad
stintreg is faster for a sample size of 1,000 and marginally faster for a sample size of 10,000 (0.60 versus 0.62 seconds). mlad
reduces execution time by ~77% for a sample size of 100,000 and by ~91% for a sample size of 1,000,000.
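The percentage reductions can be recomputed directly from the timer output. The snippet below copies the timings shown above and is only a convenience for checking the arithmetic.

```python
# (stintreg seconds, mlad seconds) copied from the timer output above.
timings = {
    1_000: (0.1470, 0.5560),
    10_000: (0.5970, 0.6220),
    100_000: (6.1080, 1.3790),
    1_000_000: (80.5310, 7.2730),
}

def pct_reduction(t_stintreg, t_mlad):
    """Percentage reduction in run time when using mlad instead of stintreg."""
    return 100 * (1 - t_mlad / t_stintreg)

reduction = {n: pct_reduction(*secs) for n, secs in timings.items()}
speedup = {n: secs[0] / secs[1] for n, secs in timings.items()}

for n in timings:
    print(f"n = {n:>9,}: reduction {reduction[n]:7.1f}%, speed-up x{speedup[n]:.2f}")
```

A negative reduction means mlad was slower than stintreg at that sample size.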