Introduction to mlad - a Weibull survival model

A Weibull model

In this introduction to mlad I will fit a Weibull survival model using both ml and mlad in order to show their similarities and differences. You can fit a Weibull survival model using Stata’s streg command, but for the purposes of this example we will assume that streg does not exist and that we need to code the model from scratch.

The survival function for a Weibull model is,

$$S(t) = \exp(-\lambda t^{\gamma})$$

and the hazard function is,

$$h(t) = \lambda \gamma t^{\gamma-1}$$

The parameters $\lambda$ and $\gamma$ can be modelled using a linear predictor, so that their values depend on covariates $\boldsymbol{X}$. As $\lambda$ and $\gamma$ are positive, this is usually done on the log scale.

With survival data, for each individual $i$ we have a survival time, $t_i$, and an event indicator, $d_i$, with $d_i=1$ denoting an event (e.g. death) and $d_i=0$ a censored observation.

The log-likelihood contribution for the $i^{th}$ individual with survival data is,

$$LL_i = d_i \ln[h(t_i)] + \ln[S(t_i)]$$

So for a Weibull model the log-likelihood contribution for individual $i$ is,

$$LL_i = d_i \left[\ln(\lambda) + \ln(\gamma) +(\gamma-1)\ln(t_i)\right] - \lambda t_i^{\gamma}$$

The total log-likelihood is just the sum over the $N$ individuals,

$$LL = \sum_{i=1}^N LL_i$$
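As a rough illustration of the formulas above, the total log-likelihood can be written in a few lines of plain Python. This is only a sketch for a model without covariates; the function name `weibull_loglik` and the three observations are made up for the example.

```python
import math

def weibull_loglik(t, d, lam, gam):
    """Total Weibull log-likelihood for survival times t and event indicators d."""
    total = 0.0
    for t_i, d_i in zip(t, d):
        # ln h(t_i) = ln(lambda) + ln(gamma) + (gamma - 1) ln(t_i)
        log_hazard = math.log(lam) + math.log(gam) + (gam - 1.0) * math.log(t_i)
        # lambda * t_i^gamma is the cumulative hazard, i.e. -ln S(t_i)
        cum_hazard = lam * t_i ** gam
        total += d_i * log_hazard - cum_hazard
    return total

# Three made-up observations: two events and one observation censored at t = 5
times = [1.0, 2.5, 5.0]
events = [1, 1, 0]
print(weibull_loglik(times, events, lam=0.2, gam=0.8))
```

A censored observation contributes only the cumulative hazard term, because $d_i=0$ removes the log hazard.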

Maximizing the likelihood using ml

I will first simulate some survival data using survsim.

. clear

. set obs 500000
Number of observations (_N) was 0, now 500,000.

. set seed 2345

. gen x1 = rnormal()

. gen x2 = rnormal()

. survsim t d, dist(weibull) lambda(0.2) gamma(0.8) maxt(5) cov(x1 0.1 x2 0.1)

I have simulated 500,000 survival times from a Weibull distribution with parameters $\lambda=0.2$ and $\gamma=0.8$. There is a maximum follow-up time of 5 years. I have introduced 2 covariates, x1 and x2, both of which have a log hazard ratio of 0.1. Now that the data have been simulated I can use stset,

. stset t, failure(d=1)

Survival-time data settings

         Failure event: d==1
Observed time interval: (0, t]
     Exit on or before: failure

--------------------------------------------------------------------------
    500,000  total observations
          0  exclusions
--------------------------------------------------------------------------
    500,000  observations remaining, representing
    257,478  failures in single-record/single-failure data
  1,705,319  total analysis time at risk and under observation
                                                At risk from t =         0
                                     Earliest observed entry t =         0
                                          Last observed exit t =         5
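For readers who want to see what the simulation step amounts to, here is a sketch of inverse-transform sampling from the same Weibull model in plain Python. It is not how survsim is implemented, and it will not reproduce Stata's random number stream; the function name `simulate_weibull` is hypothetical. Setting $S(t)=u$ with $u$ uniform and solving for $t$ gives $t = \left[-\ln(u)/\lambda_i\right]^{1/\gamma}$, where $\lambda_i$ includes the covariate effects.

```python
import math
import random

def simulate_weibull(n, lam, gam, betas, maxt, seed=2345):
    """Draw n survival times with administrative censoring at maxt."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = [rng.gauss(0.0, 1.0) for _ in betas]            # standard normal covariates
        lam_i = lam * math.exp(sum(b * x_j for b, x_j in zip(betas, x)))
        u = 1.0 - rng.random()                               # uniform on (0, 1], avoids log(0)
        t = (-math.log(u) / lam_i) ** (1.0 / gam)            # invert S(t) = u
        d = 1 if t <= maxt else 0                            # event indicator
        rows.append((min(t, maxt), d, x))
    return rows

rows = simulate_weibull(10000, lam=0.2, gam=0.8, betas=[0.1, 0.1], maxt=5.0)
print(sum(d for _, d, _ in rows) / len(rows))                # observed proportion of events
```

The proportion of events should be in the region of the 257,478 / 500,000 ≈ 0.515 seen in the stset output, up to sampling variation.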

To fit the Weibull model using ml, an ado file needs to be written that returns the log-likelihood. This is shown below,

. type weibull_d0.ado
program weibull_d0
  version 17.0
  args todo b lnf g H
  
  tempvar lnlambda lngamma
  mleval `lnlambda' = `b', eq(1)
  mleval `lngamma'  = `b', eq(2)
  
  mlsum `lnf' = _d*(`lnlambda' + `lngamma' + (exp(`lngamma') - 1)*ln(_t)) - ///
                exp(`lnlambda')*_t^(exp(`lngamma')) 
  if (`todo'==0 | `lnf'>=.) exit
end

This is a d0 evaluator, which means that only the log-likelihood (a scalar) needs to be returned. ml will calculate the derivatives needed for the gradient and Hessian matrix numerically. This is great as I do not have to sit down and do the maths, but it will be slow with large datasets.

In the ado file the linear predictors for ln(lambda) and ln(gamma) are extracted using mleval and then fed into the log-likelihood function, which is summed using mlsum.

The model can now be fitted using ml model.
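To give a feel for why numerical derivatives are expensive, the sketch below approximates the gradient of a covariate-free Weibull log-likelihood by central differences in plain Python. This is a generic finite-difference scheme and not necessarily the one ml uses internally; the function names and the three observations are made up for the example.

```python
import math

def loglik(theta, t, d):
    """Weibull log-likelihood with theta = (ln_lambda, ln_gamma), no covariates."""
    ln_lam, ln_gam = theta
    gam = math.exp(ln_gam)
    return sum(d_i * (ln_lam + ln_gam + (gam - 1.0) * math.log(t_i))
               - math.exp(ln_lam) * t_i ** gam
               for t_i, d_i in zip(t, d))

def numeric_gradient(f, theta, h=1e-5):
    """Central differences: two evaluations of f per parameter."""
    grad, calls = [], 0
    for k in range(len(theta)):
        up = list(theta)
        up[k] += h
        down = list(theta)
        down[k] -= h
        grad.append((f(up) - f(down)) / (2.0 * h))
        calls += 2
    return grad, calls

times, events = [1.0, 2.5, 5.0], [1, 1, 0]
grad, calls = numeric_gradient(lambda th: loglik(th, times, events),
                               [math.log(0.2), math.log(0.8)])
print(grad, calls)
```

Every evaluation of `f` is a full pass over the data, and the count grows with the number of parameters (and faster still for the Hessian), which is why this approach slows down with large datasets and many covariates.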

. timer on 1

. ml model d0 weibull_d0 (ln_lambda: = x1 x2) (ln_gamma:), maximize 

Initial:      Log likelihood = -1705318.9
Alternative:  Log likelihood = -871949.09
Rescale:      Log likelihood = -804798.79
Rescale eq:   Log likelihood = -763674.17
Iteration 0:  Log likelihood = -763674.17  
Iteration 1:  Log likelihood = -733377.66  
Iteration 2:  Log likelihood = -733244.73  
Iteration 3:  Log likelihood = -733244.65  
Iteration 4:  Log likelihood = -733244.65  

. ml display

                                                       Number of obs = 500,000
                                                       Wald chi2(2)  = 5013.81
Log likelihood = -733244.65                            Prob > chi2   =  0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
ln_lambda    |
          x1 |   .0996771    .001975    50.47   0.000     .0958062    .1035479
          x2 |   .0983837   .0019757    49.80   0.000     .0945114    .1022561
       _cons |  -1.612564   .0028114  -573.58   0.000    -1.618074   -1.607053
-------------+----------------------------------------------------------------
ln_gamma     |
       _cons |  -.2244506   .0018181  -123.45   0.000    -.2280141   -.2208872
------------------------------------------------------------------------------

. timer off 1

. timer list
   1:      5.66 /        1 =       5.6600

The model fitted in 5.66 seconds and the parameter estimates are close to the true values: the estimated $\lambda$ is exp(_b[_cons])=0.199 and the estimated $\gamma$ is exp([ln_gamma][_cons])=0.799.
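The back-transformation from the log scale can be checked by hand. The short Python snippet below simply exponentiates the two intercepts copied from the ml output above; the variable names are only illustrative.

```python
import math

ln_lambda_cons = -1.612564    # _cons in the ln_lambda equation
ln_gamma_cons = -0.2244506    # _cons in the ln_gamma equation

lam_hat = math.exp(ln_lambda_cons)
gam_hat = math.exp(ln_gamma_cons)
print(round(lam_hat, 3), round(gam_hat, 3))   # 0.199 0.799
```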

Maximizing the likelihood using mlad

I will now fit the same model using mlad. Rather than write a Stata ado file to define the log-likelihood, a Python function must be written. This is shown below,

. type weibull_ll.py
import jax.numpy as jnp   
import mladutil as mu

def python_ll(beta,X,wt,M):
  lnlam =  mu.linpred(beta,X,1)
  lngam  = mu.linpred(beta,X,2)
  gam = jnp.exp(lngam)

  lli = M["d"]*(lnlam + lngam + (gam - 1)*jnp.log(M["t"])) - jnp.exp(lnlam)*M["t"]**(gam)
  return(jnp.sum(lli))



First, two modules are imported. The first is JAX’s version of numpy. This will nearly always have to be imported. The second, mladutil, is a set of utility programs for mlad. The function name must always be python_ll(). There are 4 function arguments.

  • The first argument, beta, is a Python list with the first item the parameters for ln(lambda) and the second item the parameters for ln(gamma).

  • The second function argument is X. The covariates are automatically transferred to Python and stored in a list, with the array of covariates for the first equation in X[1] and for the kth equation in X[k]. If any offsets have been specified, these are stored in X[0].

  • The third argument, wt, defines any weights that have been specified, or a column of 1’s if they have not been specified.

  • The final argument, M, is a Python dictionary containing any variables specified in the othervars() option of mlad, matrices specified in the matrices() option or scalars specified in the scalars() option. Here the survival time (_t) and the event indicator (_d) are needed to calculate the likelihood function. Note that these will be named t and d in the Python dictionary, M, as defined in the othervarnames() option.

  • linpred() is a utility function to calculate the current predicted value for the kth equation given X and beta. It is recommended that you use this function. linpred will automatically incorporate any offsets.
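To make the shapes of these arguments more concrete, here is a plain-Python mock of the call, with no JAX or mlad involved. Everything in it is an assumption made for illustration: `linpred_sketch` is a hypothetical stand-in and not the real mladutil.linpred, the way beta is indexed, the placement of the constant column, and the use of None for "no offsets" are my guesses at a convenient layout, and offsets and weights are ignored in the calculation.

```python
import math

def linpred_sketch(beta, X, eq):
    """Hypothetical stand-in for a linear predictor: row-wise X[eq] . beta[eq - 1].
    The indexing convention is an assumption and offsets are ignored."""
    return [sum(b * x for b, x in zip(beta[eq - 1], row)) for row in X[eq]]

def python_ll_sketch(beta, X, wt, M):
    """Mirrors the structure of python_ll() using lists instead of JAX arrays."""
    lnlam = linpred_sketch(beta, X, 1)
    lngam = linpred_sketch(beta, X, 2)
    total = 0.0
    for i in range(len(wt)):          # wt is used only for the number of rows here
        gam = math.exp(lngam[i])
        total += (M["d"][i] * (lnlam[i] + lngam[i] + (gam - 1.0) * math.log(M["t"][i]))
                  - math.exp(lnlam[i]) * M["t"][i] ** gam)
    return total

# Two made-up individuals; each row of X[1] is (x1, x2, constant)
beta = [[0.1, 0.1, math.log(0.2)], [math.log(0.8)]]
X = [None, [[0.5, -1.0, 1.0], [0.0, 0.3, 1.0]], [[1.0], [1.0]]]
wt = [1.0, 1.0]
M = {"t": [2.0, 5.0], "d": [1, 0]}
print(python_ll_sketch(beta, X, wt, M))
```

In real use the arrays are built by mlad, so the only code you write is the body of python_ll().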

The syntax of mlad is similar to that of ml when specifying the equations. The name of the Python file giving the log-likelihood is passed using the llfile() option. The survival time _t and the event indicator _d need to be passed to Python using the othervars() option. By default these will have the same names in the Python dictionary passed to the likelihood function, M, but are renamed below to t and d using the othervarnames() option.

. timer on 2

. mlad (ln_lambda: = x1 x2)   ///
>      (ln_gamma: ),          /// 
>       othervars(_t _d)      ///
>       othervarnames(t d)    ///
>       llfile(weibull_ll)  

Initial:      Log likelihood = -1705318.9
Alternative:  Log likelihood = -871949.09
Rescale:      Log likelihood = -804798.79
Rescale eq:   Log likelihood = -763674.17
Iteration 0:  Log likelihood = -763674.17  
Iteration 1:  Log likelihood = -733382.53  
Iteration 2:  Log likelihood = -733244.73  
Iteration 3:  Log likelihood = -733244.65  
Iteration 4:  Log likelihood = -733244.65  

. ml display

                                                       Number of obs = 500,000
                                                       Wald chi2(2)  = 5013.81
Log likelihood = -733244.65                            Prob > chi2   =  0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
ln_lambda    |
          x1 |   .0996771    .001975    50.47   0.000     .0958062    .1035479
          x2 |   .0983837   .0019757    49.80   0.000     .0945114    .1022561
       _cons |  -1.612564   .0028114  -573.58   0.000    -1.618074   -1.607053
-------------+----------------------------------------------------------------
ln_gamma     |
       _cons |  -.2244506   .0018181  -123.45   0.000    -.2280141   -.2208872
------------------------------------------------------------------------------

. timer off 2

. timer list                    
   1:      5.66 /        1 =       5.6600
   2:      0.84 /        1 =       0.8380

The estimates are identical to those obtained using ml. mlad is faster, running in 0.84 seconds compared with 5.66 seconds for ml. Greater speed gains will be obtained as the sample size increases.

If we look at what type of ml method has been used, we see it is a d2 evaluator.

. di "`e(ml_method)'"
d2

This means that although only the log-likelihood function was returned by the Python function, the automatic differentiation provided by the JAX module in Python has enabled the gradient and Hessian functions to be returned.
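To show the idea behind automatic differentiation without needing JAX, the sketch below uses dual numbers (forward-mode differentiation) in plain Python. This is a deliberately simple substitute technique and not how JAX is implemented; the class `Dual` and the helper names are made up for the example. Each number carries its value and its derivative, and the usual rules of calculus are applied operation by operation, so the derivative is exact rather than a finite-difference approximation.

```python
import math

class Dual:
    """A value together with its derivative with respect to one chosen input."""
    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(float(x))

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)
    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.val - other.val, self.der - other.der)

    def __rsub__(self, other):
        return Dual.lift(other) - self

    def __mul__(self, other):
        other = Dual.lift(other)
        # product rule
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)
    __rmul__ = __mul__

def dexp(x):
    """exp() for Dual numbers, using the chain rule."""
    e = math.exp(x.val)
    return Dual(e, e * x.der)

def ll_i(ln_lam, ln_gam, t, d):
    """Weibull log-likelihood contribution written with Dual arithmetic (t > 0)."""
    gam = dexp(ln_gam)
    log_t = math.log(t)
    t_pow_gam = dexp(gam * log_t)              # t**gam = exp(gam * ln t)
    return d * (ln_lam + ln_gam + (gam - 1.0) * log_t) - dexp(ln_lam) * t_pow_gam

def gradient(ln_lam, ln_gam, t, d):
    """One forward pass per parameter, seeding that parameter's derivative with 1."""
    g1 = ll_i(Dual(ln_lam, 1.0), Dual(ln_gam, 0.0), t, d).der
    g2 = ll_i(Dual(ln_lam, 0.0), Dual(ln_gam, 1.0), t, d).der
    return g1, g2

print(gradient(math.log(0.2), math.log(0.8), t=2.5, d=1))
```

Only the log-likelihood was written by hand; the derivatives come from the arithmetic itself, which is the same division of labour that mlad offers.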

Running on a larger dataset

The table below shows the time using different methods for estimation for a similar example to above, but now with a simulated sample size of 10,000,000 observations and 10 covariates. Each of the 10 covariates is included in the linear predictor for both $\ln(\lambda)$ and $\ln(\gamma)$ (thus a total of 22 parameters when including the intercepts for each equation),

Method    Time (seconds)    Likelihood program calls
ml d0               8762                        3208
mlad                  52                          22
ml d2                136                          22
ml lf0               281                          79
streg                207                          33

  • Using ml with a d0 evaluator takes 8762 seconds, which is about 2 hours and 26 minutes. Part of the reason it takes so long is that the likelihood program has to be called 3208 times, because numerical differentiation is required for all 22 parameters in the model.

  • Supplying the same likelihood, but now in Python using mlad, leads to the model being fitted in 52 seconds, a dramatic improvement. Most of the speed gain is due to fewer program calls: the likelihood function is only called 22 times. This is because there are now functions to derive the gradient vector and Hessian matrix rather than using numerical differentiation.

  • The maths to derive the gradient and Hessian matrix for a Weibull model is fairly simple, so in this case writing a d2 evaluator for ml is straightforward (see here). The model fits in 136 seconds. Note this is still slower than mlad. The reason for this is that the Python functions are compiled using the XLA compiler and are able to access multiple processors (I am using a 12 core computer with 2 threads per core).

  • As the log-likelihood in this case can be expressed in linear form, i.e. we can sum the log-likelihood contribution for each row to give the total log-likelihood, an lf0 evaluator can be used. At 281 seconds this fits much faster than d0.

  • A Weibull model can also be fitted using streg, which took 207 seconds to fit. streg takes more care that the model will converge by fitting a constant only model first before fitting the full model, which explains the greater number of calls to the likelihood program.
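For reference, the first derivatives that a hand-written d2 evaluator would need can be sketched as follows. This is my own working from the log-likelihood contribution given earlier, not a copy of any particular evaluator, and the notation $\eta_{1i}=\ln(\lambda_i)$, $\eta_{2i}=\ln(\gamma_i)$ for the two linear predictors is introduced only for this illustration.

$$\frac{\partial LL_i}{\partial \eta_{1i}} = d_i - \lambda_i t_i^{\gamma_i}$$

$$\frac{\partial LL_i}{\partial \eta_{2i}} = d_i\left[1 + \gamma_i \ln(t_i)\right] - \lambda_i \gamma_i t_i^{\gamma_i} \ln(t_i)$$

By the chain rule, the gradient with respect to the coefficients of each equation is the sum over individuals of these terms multiplied by the corresponding covariate values. The Hessian follows by differentiating once more.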
