Using mlad for interval censored data

Interval censoring

Interval censoring occurs when the exact time of an event is unknown, but it is known to occur between two time-points. The stintreg command was introduced in Stata 16 for fitting parametric models with interval censoring. However, stintreg uses an lf0 evaluator, which means derivatives are obtained numerically, so there is some potential for speed improvements when using mlad.

The log-likelihood function needs to handle data containing a combination of exact uncensored (UC) observations, right censored (RC) observations, left censored (LC) observations, and interval censored (IC) observations. Stata uses the following notation in stintreg,

Type                $t_l$   $t_u$   Event time
Uncensored          a       a       a
Interval censored   a       b       (a,b]
Left censored       .       b       (0,b]
Left censored       0       b       (0,b]
Right censored      a       .       (a,$\infty$)

where $t_l$ and $t_u$ represent the lower and upper endpoints of the interval.
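
The notation table can be read as a small decision rule. The following is a minimal sketch in Python, assuming Stata's missing value (.) is represented as None. The function name censoring_type is hypothetical and is not part of stintreg or mlad.

```python
from typing import Optional


def censoring_type(t_l: Optional[float], t_u: Optional[float]) -> str:
    """Classify one observation using the (t_l, t_u) notation table.

    None stands in for Stata's missing value (an assumption of this sketch).
    """
    if t_u is None:
        if t_l is None:
            raise ValueError("both endpoints are missing")
        return "RC"  # right censored: event time in (t_l, infinity)
    if t_l is None or t_l == 0:
        return "LC"  # left censored: event time in (0, t_u]
    if t_l == t_u:
        return "UC"  # uncensored: exact event time
    if t_l < t_u:
        return "IC"  # interval censored: event time in (t_l, t_u]
    raise ValueError("lower endpoint exceeds upper endpoint")


for pair in [(4.0, 4.0), (2.0, 6.0), (None, 3.0), (0.0, 3.0), (5.0, None)]:
    print(pair, censoring_type(*pair))
```

The checks run in the order right censored, left censored, uncensored, interval censored, so an observation with a missing or zero lower endpoint is never mistaken for an interval censored one.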

The log-likelihood for a general parametric distribution with the 4 types of censoring can be expressed as follows,

$$ \begin{align*} \ln L =& \sum_{i \in UC} \ln\left[f_i(t_{li})\right] \ \ + \\
&\sum_{i \in RC} \ln\left[S_i(t_{li})\right] \ \ + \\
&\sum_{i \in LC} \ln\left[1-S_i(t_{ui})\right] \ \ + \\
&\sum_{i \in IC} \ln\left[S_i(t_{li})-S_i(t_{ui})\right] \end{align*} $$

where $f_i(t)$ is the density function for the $i^{th}$ subject and $S_i(t)$ is the survival function.
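
To make the four sums concrete, here is a minimal sketch of the log-likelihood in plain Python for a Weibull model without covariates. It assumes the proportional hazards parameterisation $S(t) = \exp(-\lambda t^{\gamma})$. The helper names weib_surv, weib_dens and loglik are hypothetical, and this is not the code used by stintreg or mlad.

```python
import math


def weib_surv(t, lam, gam):
    """Weibull survival function S(t) = exp(-lam * t**gam) (assumed PH form)."""
    return math.exp(-lam * t ** gam)


def weib_dens(t, lam, gam):
    """Weibull density f(t) = lam * gam * t**(gam - 1) * S(t)."""
    return lam * gam * t ** (gam - 1) * weib_surv(t, lam, gam)


def loglik(data, lam, gam):
    """Sum the contributions for UC, RC, LC and IC observations.

    data is a list of (type, t_l, t_u) tuples; unused endpoints may be None.
    """
    total = 0.0
    for kind, t_l, t_u in data:
        if kind == "UC":
            total += math.log(weib_dens(t_l, lam, gam))
        elif kind == "RC":
            total += math.log(weib_surv(t_l, lam, gam))
        elif kind == "LC":
            total += math.log(1.0 - weib_surv(t_u, lam, gam))
        elif kind == "IC":
            total += math.log(weib_surv(t_l, lam, gam) - weib_surv(t_u, lam, gam))
        else:
            raise ValueError(f"unknown censoring type: {kind}")
    return total


# Illustrative values only, one observation of each type
example = [("UC", 1.0, 1.0), ("RC", 2.0, None), ("LC", None, 1.0), ("IC", 1.0, 2.0)]
print(loglik(example, lam=1.0, gam=1.0))
```

With lam = gam = 1 the model reduces to an exponential distribution, so each term can be checked by hand: for example, the right censored observation at time 2 contributes ln S(2) = -2.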

Using stintreg

To develop the mlad Python function I will use the small example included in the manual for stintreg comparing the cosmetic effects of two cancer treatments: radiotherapy alone versus radiotherapy plus adjuvant chemotherapy on breast retraction. I will illustrate using a Weibull model, but the approach is easily extended to other distributions. First the model can be fitted using stintreg as follows,

. use https://www.stata-press.com/data/r17/cosmesis, clear
(Cosmetic deterioration of breast cancer patients)

. stintreg treat age, interval(ltime rtime) distribution(weibull)

Fitting constant-only model:
Iteration 0:   log likelihood = -200.17506  
Iteration 1:   log likelihood = -175.10124  
Iteration 2:   log likelihood = -155.15341  
Iteration 3:   log likelihood = -148.73905  
Iteration 4:   log likelihood = -148.65585  
Iteration 5:   log likelihood = -148.65584  

Fitting full model:
Iteration 0:   log likelihood = -148.65584  
Iteration 1:   log likelihood = -143.35409  
Iteration 2:   log likelihood = -142.98393  
Iteration 3:   log likelihood = -142.98285  
Iteration 4:   log likelihood = -142.98285  

Weibull PH regression                               Number of obs     =     94
                                                           Uncensored =      0
                                                        Left-censored =      5
                                                       Right-censored =     38
                                                       Interval-cens. =     51

                                                    LR chi2(2)        =  11.35
Log likelihood = -142.98285                         Prob > chi2       = 0.0034

------------------------------------------------------------------------------
             | Haz. ratio   Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
       treat |   2.449817   .6972537     3.15   0.002     1.402391    4.279549
         age |   .9760212   .0366728    -0.65   0.518     .9067266    1.050611
       _cons |   .0047193   .0076133    -3.32   0.001     .0001998    .1114466
-------------+----------------------------------------------------------------
       /ln_p |   .4820043    .120089     4.01   0.000     .2466342    .7173743
-------------+----------------------------------------------------------------
           p |   1.619317   .1944621                      1.279711    2.049046
         1/p |   .6175444   .0741603                       .488032    .7814265
------------------------------------------------------------------------------
Note: _cons estimates baseline hazard.

. estimates store stintreg

I have stored the model estimates so they can be compared to the estimates from mlad.

Using mlad

In order to fit the model using mlad it is necessary to create a variable to denote the type of censoring. I create the variable ctype below.

. gen byte ctype = .
(94 missing values generated)

. qui replace ctype = 1 if ltime==rtime                         // UC

. qui replace ctype = 2 if rtime >= .                           // RC

. qui replace ctype = 3 if (ltime >= . | ltime==0)              // LC

. qui replace ctype = 4 if (rtime-ltime)>0 & !inlist(ctype,2,3) // IC 

. tab ctype

      ctype |      Freq.     Percent        Cum.
------------+-----------------------------------
          2 |         38       40.43       40.43
          3 |          5        5.32       45.74
          4 |         51       54.26      100.00
------------+-----------------------------------
      Total |         94      100.00

I will later pass ctype to Python.

The Python likelihood function used by mlad is shown below.

. type weib_ic_ll.py
import jax.numpy as jnp
import mladutil as mu

def python_ll(beta, X, wt, M):
  lam = jnp.exp(mu.linpred(beta,X,1))
  gam = jnp.exp(mu.linpred(beta,X,2))
  
  lli = (jnp.where(M["ctype"]==1,jnp.log(mu.weibdens(M["ltime"],lam,gam)),0)                                 +
         jnp.where(M["ctype"]==2,jnp.log(mu.weibsurv(M["ltime"],lam,gam)),0)                                 +
         jnp.where(M["ctype"]==3,jnp.log(1 - mu.weibsurv(M["rtime"],lam,gam)),0)                             +
         jnp.where(M["ctype"]==4,jnp.log(mu.weibsurv(M["ltime"],lam,gam)-mu.weibsurv(M["rtime"],lam,gam)),0))
  return(jnp.sum(wt*lli))      

The log-likelihood is simple and consists of extracting the parameters using mu.linpred, defining the individual contributions to the log-likelihood, and then summing these so that a single value for the log-likelihood can be returned.

Something important to note here is that we must ensure that the jnp.log() function never returns a missing (non-finite) value anywhere in the function, e.g. when taking the log of zero. This is still the case when jnp.where() discards that result, because jnp.where() evaluates both of its branches for every observation. Below I create ltime2 and rtime2, which replace incompatible values. It is important to note that these replacement values do not feed into the likelihood function due to the use of jnp.where().
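
The following is a hand-written illustration of why a discarded log of zero can still cause trouble once derivatives are involved. It is a sketch of the chain-rule arithmetic as I understand it for automatic differentiation through a masked term, written in plain Python; it does not call JAX or mlad, and the helper names dlog and masked_log_grad are hypothetical.

```python
import math


def dlog(x):
    """Derivative of log(x); returns inf at x == 0, following IEEE arithmetic."""
    return math.inf if x == 0.0 else 1.0 / x


def masked_log_grad(x, selected):
    """Chain-rule derivative of where(selected, log(x), 0) with respect to x.

    The mask contributes a factor of 1.0 (selected) or 0.0 (discarded),
    which is multiplied by the local derivative of log.
    """
    upstream = 1.0 if selected else 0.0
    return upstream * dlog(x)


grad_raw = masked_log_grad(0.0, selected=False)    # 0.0 * inf
grad_safe = masked_log_grad(1e-8, selected=False)  # 0.0 * (finite value)
print(grad_raw, grad_safe)
```

The discarded term contributes nothing to the log-likelihood value in either case, but with a raw zero the product 0.0 * inf is not a number, whereas with a small positive placeholder the product is exactly zero. This is the motivation for replacing incompatible values before passing the variables to Python.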

The python function and the required variables are then passed to mlad.

. gen double ltime2 = cond(ltime==0,1e-8,ltime)

. gen double rtime2 = cond(rtime==.,99,rtime)

. mlad (ln_lambda: = treat age, )   ///
>      (ln_gamma: = ),              ///
>      othervars(ctype ltime2 rtime2) ///
>      othervarnames(ctype ltime rtime) ///
>      llfile(weib_ic_ll)  

initial:       log likelihood = -2088.4515
alternative:   log likelihood = -1117.6877
rescale:       log likelihood = -200.17506
rescale eq:    log likelihood = -200.17506
Iteration 0:   log likelihood = -200.17506  
Iteration 1:   log likelihood = -173.82452  
Iteration 2:   log likelihood = -151.01988  
Iteration 3:   log likelihood = -143.65323  
Iteration 4:   log likelihood = -143.21346  
Iteration 5:   log likelihood = -142.98514  
Iteration 6:   log likelihood = -142.98285  
Iteration 7:   log likelihood = -142.98285  

. ml display     

                                                        Number of obs =     94
                                                        Wald chi2(2)  =  10.87
Log likelihood = -142.98285                             Prob > chi2   = 0.0044

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
ln_lambda    |
       treat |   .8960133   .2846146     3.15   0.002      .338179    1.453848
         age |   -.024271   .0375738    -0.65   0.518    -.0979143    .0493723
       _cons |    -5.3561   1.613238    -3.32   0.001     -8.51799   -2.194211
-------------+----------------------------------------------------------------
ln_gamma     |
       _cons |   .4820043   .1200888     4.01   0.000     .2466346     .717374
------------------------------------------------------------------------------

. estimates store mlad

The coefficients can be compared using estimates table.

. estimates table stintreg mlad, equations(1,2) se

----------------------------------------
    Variable |  stintreg       mlad     
-------------+--------------------------
#1           |
       treat |  .89601332    .89601334  
             |  .28461462     .2846146  
         age | -.02427097   -.02427097  
             |  .03757381    .03757381  
       _cons | -5.3561002   -5.3561004  
             |  1.6132388    1.6132384  
-------------+--------------------------
#2           |
        ln_p |  .48200426               
             |  .12008898               
       _cons |               .48200431  
             |               .12008879  
----------------------------------------
                            Legend: b/se


The estimates agree to about 6 decimal places.
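
The stintreg output reports hazard ratios and p, while ml display reports coefficients on the log scale, so the two tables are related by exponentiation. A minimal sketch in Python, using the coefficients copied from the ml display output above (the dictionary names are hypothetical):

```python
import math

# Log-scale coefficients copied from the ml display output
mlad_coef = {
    "treat": 0.8960133,
    "age": -0.024271,
    "_cons": -5.3561,
    "ln_gamma": 0.4820043,
}

# Exponentiate to obtain hazard ratios, the baseline hazard constant,
# and the Weibull shape parameter
exponentiated = {name: math.exp(b) for name, b in mlad_coef.items()}

for name, value in exponentiated.items():
    print(f"{name:>9s}  {value:.6f}")
```

The exponentiated values can be compared directly with the Haz. ratio column and the row for p in the stintreg output.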

Speed comparison

The above was useful for developing the Python log-likelihood file, but speed gains will only be worthwhile for larger data sets.

The code below simulates interval censored data for various sample sizes with 10 covariates. The 10 covariates are also included to model the shape parameter in both stintreg and mlad. The methods are then compared in terms of execution time.

. foreach ss in  1000 10000 100000 1000000 {
  2.   quietly {
  3.     clear
  4.     timer clear
  5.     set obs `ss'
  6.     forvalues i = 1/10 {
  7.       gen x`i' = rnormal()
  8.       local cov `cov' x`i' 0.1
  9.     }  
 10. 
.     survsim rtime d, dist(weib) lambda(0.2) gamma(0.8) maxt(5) cov(`cov') 
 11.   
.     gen double ltime = cond(runiform()<75 & d==1,runiform()*rtime,rtime) 
 12.     replace ltime = rtime if d==0
 13.     replace rtime = . if d==0
 14. 
.     timer on 1
 15.     stintreg x1-x10, interval(ltime rtime) distribution(weibull) anc(x1-x10)
 16.     timer off 1
 17. 
.     timer on 2
 18.     gen byte ctype = .
 19.     qui replace ctype = 1 if ltime==rtime                         // UC
 20.     qui replace ctype = 2 if rtime >= .                           // RC
 21.     qui replace ctype = 3 if (ltime >= . | ltime==0)              // LC
 22.     qui replace ctype = 4 if (rtime-ltime)>0 & !inlist(ctype,2,3) // IC 
 23.   
.     replace ltime = 1 if ltime ==0 
 24.     replace rtime = 1 if rtime ==. 
 25.     mlad (ln_lambda: = x1-x10, )      ///
>          (ln_gamma:  = x1-x10),       ///
>          othervars(ctype ltime rtime) ///
>          llfile(weib_ic_ll)           ///
>          search(off) 
 26.     timer off 2
 27.   }
 28.   
.   di _newline "Sample Size: `ss'"
 29.   timer list
 30.   di "(1) stintreg, (2) mlad" 
 31. }

Sample Size: 1000
   1:      0.13 /        1 =       0.1250
   2:      0.60 /        1 =       0.5980
(1) stintreg, (2) mlad

Sample Size: 10000
   1:      0.64 /        1 =       0.6430
   2:      0.68 /        1 =       0.6780
(1) stintreg, (2) mlad

Sample Size: 100000
   1:      5.62 /        1 =       5.6190
   2:      1.37 /        1 =       1.3730
(1) stintreg, (2) mlad

Sample Size: 1000000
   1:     77.67 /        1 =      77.6670
   2:      7.29 /        1 =       7.2880
(1) stintreg, (2) mlad

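stintreg is faster for sample sizes of 1,000 and 10,000, although the difference at 10,000 is small (0.64 versus 0.68 seconds). mlad reduces execution time by ~76% for a sample size of 100,000 (5.62 versus 1.37 seconds) and by ~91% for a sample size of 1,000,000 (77.67 versus 7.29 seconds).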
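
The relative change in execution time can be recomputed from the timer output above. A minimal sketch in Python, with the timings (in seconds) copied from the timer list results; the function name pct_reduction is hypothetical:

```python
# (stintreg seconds, mlad seconds) for each sample size, from the timer output
timings = {
    1_000: (0.1250, 0.5980),
    10_000: (0.6430, 0.6780),
    100_000: (5.6190, 1.3730),
    1_000_000: (77.6670, 7.2880),
}


def pct_reduction(t_stintreg, t_mlad):
    """Percentage reduction in time when using mlad; negative means mlad is slower."""
    return 100.0 * (1.0 - t_mlad / t_stintreg)


for n, (t1, t2) in timings.items():
    print(f"{n:>9,d}: {pct_reduction(t1, t2):7.1f}%")
```

A negative value indicates that mlad took longer than stintreg, which is the case for the two smaller sample sizes.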
