Introduction to mlad - a Weibull survival model

A Weibull model

In this introduction to mlad I will fit a Weibull survival model using ml and mlad to show their similarities and their differences. A Weibull survival model can be fitted using Stata’s streg, but for the purposes of this example we will assume that streg does not exist and that the model must be coded from scratch.

The survival function for a Weibull model is,

$$S(t) = \exp(-\lambda t^{\gamma})$$

and the hazard function is,

$$h(t) = \lambda \gamma t^{\gamma-1}$$
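The two functions are linked by $h(t) = -\frac{d}{dt}\ln S(t)$ with $S(t)=\exp(-\lambda t^{\gamma})$. The following minimal Python sketch (plain `math`, not part of mlad; the function names are hypothetical) evaluates both at $\lambda=0.2$, $\gamma=0.8$ and checks that relationship with a central difference. Because $\gamma<1$, the hazard decreases over time.

```python
import math

def weibull_survival(t, lam, gam):
    """S(t) = exp(-lam * t**gam)."""
    return math.exp(-lam * t ** gam)

def weibull_hazard(t, lam, gam):
    """h(t) = lam * gam * t**(gam - 1)."""
    return lam * gam * t ** (gam - 1)

def hazard_from_survival(t, lam, gam, eps=1e-6):
    """Numerical check of h(t) = -d/dt ln S(t) using a central difference."""
    up = math.log(weibull_survival(t + eps, lam, gam))
    down = math.log(weibull_survival(t - eps, lam, gam))
    return -(up - down) / (2 * eps)

lam, gam = 0.2, 0.8
for t in (0.5, 1.0, 2.5, 5.0):
    print(t, weibull_survival(t, lam, gam), weibull_hazard(t, lam, gam),
          hazard_from_survival(t, lam, gam))
```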

The parameters $\lambda$ and $\gamma$ can be modelled using a linear predictor. As $\lambda$ and $\gamma$ are positive, this is usually done on the log scale.
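Modelling on the log scale guarantees positivity: $\lambda_i = \exp(\beta_0 + \beta_1 x_{1i} + \beta_2 x_{2i})$ is positive for any coefficients. When $\gamma$ does not depend on covariates, as in this example, $\lambda_i$ multiplies the whole hazard, so the coefficients in the $\ln(\lambda)$ equation are log hazard ratios. This is a small sketch; the defaults are the true values used in the simulation later and the function names are hypothetical.

```python
import math

def weibull_lambda(x1, x2, beta0=math.log(0.2), beta1=0.1, beta2=0.1):
    """lambda_i = exp(beta0 + beta1*x1 + beta2*x2): positive for any coefficients."""
    return math.exp(beta0 + beta1 * x1 + beta2 * x2)

def weibull_gamma(alpha0=math.log(0.8)):
    """gamma = exp(alpha0): a constant-only equation, like ln_gamma in this example."""
    return math.exp(alpha0)

# With gamma free of covariates, h(t) = lambda * gamma * t**(gamma - 1), so a
# one-unit increase in x1 multiplies the hazard at every t by exp(beta1).
hr_x1 = weibull_lambda(1.0, 0.0) / weibull_lambda(0.0, 0.0)
print(round(hr_x1, 4))  # 1.1052
```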

With survival data, for each individual $i$ we have a survival time, $t_i$, and an event indicator, $d_i$, with $d_i=1$ denoting an event (e.g. death) and $d_i=0$ a censored observation.

The log-likelihood contribution for the $i^{th}$ individual with survival data is,

$$ll_i = d_i \ln[h(t_i)] + \ln[S(t_i)]$$
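The reasoning behind $ll_i = d_i \ln[h(t_i)] + \ln[S(t_i)]$: an individual with an event contributes the density $f(t_i) = h(t_i)S(t_i)$, while a censored individual contributes only $S(t_i)$. A minimal Python sketch of this generic contribution, using illustrative Weibull functions (all names are hypothetical):

```python
import math

LAM, GAM = 0.2, 0.8  # illustrative Weibull parameters

def hazard(t):
    return LAM * GAM * t ** (GAM - 1)

def survival(t):
    return math.exp(-LAM * t ** GAM)

def ll_contribution(t, d, h, S):
    """ll_i = d_i * ln h(t_i) + ln S(t_i) for any hazard/survival pair h, S."""
    return d * math.log(h(t)) + math.log(S(t))

event = ll_contribution(2.0, 1, hazard, survival)     # event (e.g. death) at t = 2
censored = ll_contribution(2.0, 0, hazard, survival)  # censored at t = 2
print(event, censored)
```

Exponentiating `event` recovers the density $h(2)S(2)$ and exponentiating `censored` recovers $S(2)$.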

So for a Weibull model the log-likelihood contribution for individual $i$ is,

$$ll_i = d_i \left[\ln(\lambda) + \ln(\gamma) + (\gamma-1)\ln(t_i)\right] - \lambda t_i^{\gamma}$$

The total log-likelihood is just the sum over the $N$ individuals,

$$ll = \sum_{i=1}^N ll_i$$
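Transcribing the closed-form Weibull $ll_i$ into Python and summing it gives a reference implementation that is handy for checking an evaluator. This is a minimal sketch with the same two covariates in the $\ln(\lambda)$ equation as the example; `weibull_total_ll` and the four-person dataset are hypothetical.

```python
import math

def weibull_total_ll(beta, alpha0, x, t, d):
    """Sum over i of d_i*[ln(lam_i) + ln(gam) + (gam-1)*ln(t_i)] - lam_i*t_i**gam,
    where ln(lam_i) = beta[0] + beta[1]*x1_i + beta[2]*x2_i and ln(gam) = alpha0."""
    lngam = alpha0
    gam = math.exp(lngam)
    total = 0.0
    for (x1, x2), ti, di in zip(x, t, d):
        lnlam = beta[0] + beta[1] * x1 + beta[2] * x2
        total += di * (lnlam + lngam + (gam - 1) * math.log(ti)) - math.exp(lnlam) * ti ** gam
    return total

# Hypothetical data for four individuals: (x1, x2), survival time, event indicator
x = [(0.3, -0.5), (-1.1, 0.2), (0.0, 1.4), (0.8, 0.8)]
t = [1.7, 5.0, 0.4, 5.0]
d = [1, 0, 1, 0]
print(weibull_total_ll([math.log(0.2), 0.1, 0.1], math.log(0.8), x, t, d))
```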

Maximizing the likelihood using ml

I will first simulate some survival data using survsim.

. clear

. set obs 500000
Number of observations (_N) was 0, now 500,000.

. set seed 987123

. gen x1 = rnormal()

. gen x2 = rnormal()

. survsim t d, dist(weibull) lambda(0.2) gamma(0.8) maxt(5) cov(x1 0.1 x2 0.1)

I have simulated 500,000 survival times from a Weibull distribution with parameters $\lambda=0.2$ and $\gamma=0.8$, with a maximum follow-up time of 5 years. I have included two covariates, x1 and x2, each with a log hazard ratio of 0.1. Now that the data are simulated, I can stset them:

. stset t, failure(d=1)

Survival-time data settings

         Failure event: d==1
Observed time interval: (0, t]
     Exit on or before: failure

--------------------------------------------------------------------------
    500,000  total observations
          0  exclusions
--------------------------------------------------------------------------
    500,000  observations remaining, representing
    258,175  failures in single-record/single-failure data
  1,703,093  total analysis time at risk and under observation
                                                At risk from t =         0
                                     Earliest observed entry t =         0
                                          Last observed exit t =         5
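survsim handles the simulation for us, but for intuition, one standard way to generate such data is inverse-transform sampling. Because $S(T \mid x) = \exp(-\lambda e^{x\beta} T^{\gamma})$ is uniformly distributed on (0, 1), setting it equal to a uniform draw $u$ and solving gives $T = [-\ln u / (\lambda e^{x\beta})]^{1/\gamma}$, which is then censored at 5 years. This Python sketch is an assumed equivalent of the setup, not survsim's own code, and Python's random-number generator will not reproduce Stata's draws. With the same parameters, it should give an event proportion close to the 51.6% (258,175 of 500,000) seen above.

```python
import math
import random

def simulate_weibull(n, lam=0.2, gam=0.8, beta=(0.1, 0.1), maxt=5.0, seed=1):
    """Simulate rows (t, d, x1, x2) by inverting S(t | x) = exp(-lam*exp(xb)*t**gam)."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x1, x2 = rng.gauss(0, 1), rng.gauss(0, 1)
        lam_i = lam * math.exp(beta[0] * x1 + beta[1] * x2)
        u = 1.0 - rng.random()  # uniform on (0, 1], avoids log(0)
        event_time = (-math.log(u) / lam_i) ** (1 / gam)
        # Administrative censoring at maxt, analogous to survsim's maxt(5)
        t, d = (event_time, 1) if event_time <= maxt else (maxt, 0)
        rows.append((t, d, x1, x2))
    return rows

rows = simulate_weibull(100_000)
print(sum(r[1] for r in rows) / len(rows))  # proportion of events
```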

To fit the Weibull model using ml, an ado file that returns the log-likelihood needs to be written. This is shown below:

program weibull_d0
    version 17.0
    args todo b lnf g H

    // linear predictors for ln(lambda) and ln(gamma)
    tempvar lnlambda lngamma
    mleval `lnlambda' = `b', eq(1)
    mleval `lngamma'  = `b', eq(2)

    // sum the Weibull log-likelihood contributions over observations
    mlsum `lnf' = _d*(`lnlambda' + `lngamma' + (exp(`lngamma') - 1)*ln(_t)) - ///
                  exp(`lnlambda')*_t^(exp(`lngamma'))
    if (`todo'==0 | `lnf'>=.) exit
end

This is a d0 evaluator, which means that only the log-likelihood (a scalar) needs to be returned. ml calculates the derivatives needed for the gradient and Hessian matrix numerically. This is convenient, as I do not have to sit down and do the maths, but it can be slow with large datasets because the log-likelihood must be evaluated many times at each iteration.

In the ado file, the linear predictors for ln(lambda) and ln(gamma) are obtained using mleval and then fed into the log-likelihood function, which is summed over observations using mlsum.
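To see why numerical derivatives are costly, here is a Python sketch of the principle for the constant-only Weibull model. A central-difference gradient needs two full passes over the data per parameter (and a Hessian needs more), so the cost grows with both the number of parameters and the number of observations. This illustrates the idea only; ml's own numerical-derivative routine is not reproduced here, and all function names are hypothetical. The analytic derivatives used for comparison, $\partial ll/\partial\ln\lambda = \sum_i (d_i - \lambda t_i^{\gamma})$ and $\partial ll/\partial\ln\gamma = \sum_i [d_i(1+\gamma\ln t_i) - \lambda\gamma t_i^{\gamma}\ln t_i]$, follow from differentiating $ll_i$.

```python
import math

def weibull_ll(lnlam, lngam, t, d):
    """Constant-only Weibull log-likelihood (no covariates)."""
    gam = math.exp(lngam)
    return sum(di * (lnlam + lngam + (gam - 1) * math.log(ti)) - math.exp(lnlam) * ti ** gam
               for ti, di in zip(t, d))

def numerical_gradient(f, params, h=1e-6):
    """Central differences: two evaluations of f (each a full data pass) per parameter."""
    grad = []
    for k in range(len(params)):
        up = list(params)
        up[k] += h
        down = list(params)
        down[k] -= h
        grad.append((f(*up) - f(*down)) / (2 * h))
    return grad

def analytic_gradient(lnlam, lngam, t, d):
    """Derivatives with respect to ln(lambda) and ln(gamma), derived by hand."""
    lam, gam = math.exp(lnlam), math.exp(lngam)
    g1 = sum(di - lam * ti ** gam for ti, di in zip(t, d))
    g2 = sum(di * (1 + gam * math.log(ti)) - lam * gam * ti ** gam * math.log(ti)
             for ti, di in zip(t, d))
    return [g1, g2]

t = [0.5, 1.2, 3.0, 5.0, 5.0]
d = [1, 1, 1, 0, 0]
params = [math.log(0.2), math.log(0.8)]

def loglik(a, b):
    return weibull_ll(a, b, t, d)

print(numerical_gradient(loglik, params))
print(analytic_gradient(*params, t, d))
```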

The model can now be fitted using ml model.

. timer on 1

. ml model d0 weibull_d0 (ln_lambda: = x1 x2) (ln_gamma:), maximize

initial:       log likelihood = -1703093.1
alternative:   log likelihood = -872315.77
rescale:       log likelihood = -806348.45
rescale eq:    log likelihood = -764593.08
Iteration 0:   log likelihood = -764593.08
Iteration 1:   log likelihood = -734592.41
Iteration 2:   log likelihood = -734446.28
Iteration 3:   log likelihood = -734446.19
Iteration 4:   log likelihood = -734446.19

. ml display

                                                       Number of obs = 500,000
                                                       Wald chi2(2)  = 4794.46
Log likelihood = -734446.19                            Prob > chi2   =  0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
ln_lambda    |
          x1 |   .0948762   .0019735    48.08   0.000     .0910083    .0987442
          x2 |    .098692   .0019759    49.95   0.000     .0948193    .1025647
       _cons |  -1.610073   .0028083  -573.33   0.000    -1.615577   -1.604569
-------------+----------------------------------------------------------------
ln_gamma     |
       _cons |  -.2227555   .0018153  -122.71   0.000    -.2263135   -.2191975
------------------------------------------------------------------------------

. timer off 1

. timer list
1:      6.68 /        1 =       6.6820

The model was fitted in 6.68 seconds and the parameter estimates are close to the true values: the estimated $\lambda$ is exp([ln_lambda][_cons]) = 0.200, the estimated $\gamma$ is exp([ln_gamma][_cons]) = 0.800, and the log hazard ratios for x1 and x2 (0.095 and 0.099) are close to 0.1.
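The back-transformation is simple arithmetic. This Python sketch uses the estimates copied from the ml display output above, exponentiating the constants to recover $\lambda$ and $\gamma$ and the covariate coefficients to give hazard ratios:

```python
import math

# Estimates copied from the ml display output above
b_lnlambda_cons = -1.610073
b_lngamma_cons = -0.2227555
b_x1, b_x2 = 0.0948762, 0.098692

lam_hat = math.exp(b_lnlambda_cons)            # estimate of lambda (true value 0.2)
gam_hat = math.exp(b_lngamma_cons)             # estimate of gamma (true value 0.8)
hr_x1, hr_x2 = math.exp(b_x1), math.exp(b_x2)  # hazard ratios (true value exp(0.1))
print(round(lam_hat, 3), round(gam_hat, 3), round(hr_x1, 3), round(hr_x2, 3))
```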

I will now fit the same model using mlad. Rather than writing a Stata ado file to define the log-likelihood, a Python function must be written. This is shown below:

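. type weibull_ll.py
import jax.numpy as jnp
import mladutil as mu

def python_ll(beta,X,wt,M):
    # linear predictors for ln(lambda) and ln(gamma)
    lnlam =  mu.linpred(beta,X,1)
    lngam  = mu.linpred(beta,X,2)
    gam = jnp.exp(lngam)

    # Weibull log-likelihood contributions, summed over observations
    lli = M["d"]*(lnlam + lngam + (gam - 1)*jnp.log(M["t"])) - jnp.exp(lnlam)*M["t"]**(gam)
    return(jnp.sum(lli))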

First, two modules are imported. The first is JAX’s version of numpy, which will nearly always need to be imported. The second, mladutil, is a set of utility functions for mlad. The function must always be named python_ll(), and it has 4 arguments.

• The first argument, beta, is a Python list whose first item holds the parameters for ln(lambda) and whose second item holds the parameters for ln(gamma).

• The second argument is X. The covariates are automatically transferred to Python and stored in a list, with the covariates for each equation held in a separate array (one list item per equation). If any offsets have been specified, these will also be included in X.

• The third argument, wt, contains any weights that have been specified, or a column of 1’s if none have been specified.

• The final argument, M, is a Python dictionary containing any variables specified in the othervars() option of mlad, matrices specified in the matrices() option, and scalars specified in the scalars() option. Here the survival time (_t) and the event indicator (_d) are needed to calculate the log-likelihood. Note that these will be named t and d in the Python dictionary, M, as defined in the othervarnames() option.

• linpred() is a mladutil utility function that calculates the current linear predictor for the kth equation given X and beta. Its use is recommended, as it automatically incorporates any offsets.
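To make the roles of the four arguments concrete without JAX or mlad installed, here is a plain-Python sketch. `toy_linpred` and `toy_python_ll` are hypothetical stand-ins, and the data layout (covariate rows per equation, with the constant stored last in each item of beta) is an assumption for illustration only; the real mladutil linpred() works on JAX arrays and also handles offsets. Unlike python_ll, the sketch multiplies each contribution by wt, which makes no difference here because wt is a column of 1’s.

```python
import math

def toy_linpred(beta, X, k):
    """HYPOTHETICAL stand-in for mladutil's linpred(): equation k (1-based) uses
    covariate rows X[k-1] and coefficients beta[k-1], constant last. No offsets."""
    *slopes, cons = beta[k - 1]
    return [cons + sum(b * x for b, x in zip(slopes, row)) for row in X[k - 1]]

def toy_python_ll(beta, X, wt, M):
    """Same arithmetic as python_ll in weibull_ll.py, in plain Python, weighted by wt."""
    lnlam = toy_linpred(beta, X, 1)
    lngam = toy_linpred(beta, X, 2)
    total = 0.0
    for w, lnl, lng, t, d in zip(wt, lnlam, lngam, M["t"], M["d"]):
        g = math.exp(lng)
        total += w * (d * (lnl + lng + (g - 1) * math.log(t)) - math.exp(lnl) * t ** g)
    return total

# Three hypothetical individuals: eq. 1 uses x1 and x2, eq. 2 (ln_gamma) is constant-only
X = [[(0.5, -1.0), (0.0, 0.3), (-1.2, 0.8)], [(), (), ()]]
beta = [[0.1, 0.1, math.log(0.2)], [math.log(0.8)]]
wt = [1.0, 1.0, 1.0]                        # no weights specified: a column of 1's
M = {"t": [0.7, 5.0, 2.1], "d": [1, 0, 1]}  # names as set by othervarnames(t d)
print(toy_python_ll(beta, X, wt, M))
```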

The syntax of mlad is similar to that of ml when specifying the equations. The name of the Python file containing the log-likelihood is passed using the llfile() option. The survival time _t and the event indicator _d need to be passed to Python using the othervars() option. By default, these will have the same names in the Python dictionary, M, that is passed to the likelihood function, but below they are renamed to t and d using the othervarnames() option.

. timer on 2

. mlad (ln_lambda: = x1 x2)   ///
>      (ln_gamma: ),          ///
>       othervars(_t _d)      ///
>       othervarnames(t d)    ///
>       llfile(weibull_ll)

initial:       log likelihood = -1703093.1
alternative:   log likelihood = -872315.77
rescale:       log likelihood = -806348.45
rescale eq:    log likelihood = -764593.08
Iteration 0:   log likelihood = -764593.08
Iteration 1:   log likelihood = -734598.86
Iteration 2:   log likelihood = -734446.29
Iteration 3:   log likelihood = -734446.19
Iteration 4:   log likelihood = -734446.19

. ml display

                                                       Number of obs = 500,000
                                                       Wald chi2(2)  = 4794.46
Log likelihood = -734446.19                            Prob > chi2   =  0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
ln_lambda    |
          x1 |   .0948762   .0019735    48.08   0.000     .0910083    .0987442
          x2 |    .098692   .0019759    49.95   0.000     .0948193    .1025647
       _cons |  -1.610073   .0028083  -573.33   0.000    -1.615577   -1.604569
-------------+----------------------------------------------------------------
ln_gamma     |
       _cons |  -.2227555   .0018153  -122.71   0.000    -.2263135   -.2191976
------------------------------------------------------------------------------

. timer off 2

. timer list
1:      6.68 /        1 =       6.6820
2:      1.24 /        1 =       1.2420

The estimates are essentially identical to those obtained using ml (only the last digit of one confidence limit differs). mlad ran in 1.24 seconds compared with 6.68 seconds for ml, roughly five times faster. Greater speed gains will be obtained as the sample size increases.

If we look at which ml method was used, we see that it is a d2 evaluator:

. di "`e(ml_method)'"
d2

This means that although the Python function returned only the log-likelihood, the automatic differentiation provided by JAX enabled the gradient and Hessian to be returned as well.
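JAX exposes this directly through functions such as jax.grad and jax.hessian. To illustrate the idea without JAX, the sketch below swaps in the simplest form of the technique, forward-mode automatic differentiation with dual numbers. This is not how JAX or mlad is implemented. The log-likelihood is written once, just as in python_ll, and evaluating it on dual numbers gives derivatives that are exact up to floating-point rounding, with no finite-difference step size. Only the gradient is shown; a Hessian needs derivatives of derivatives (for example, by nesting differentiation passes). All names are hypothetical.

```python
import math

class Dual:
    """Dual number val + der*eps with eps**2 = 0; der carries a derivative."""
    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(o):
        return o if isinstance(o, Dual) else Dual(o)

    def __add__(self, o):
        o = Dual.lift(o)
        return Dual(self.val + o.val, self.der + o.der)
    __radd__ = __add__

    def __sub__(self, o):
        o = Dual.lift(o)
        return Dual(self.val - o.val, self.der - o.der)

    def __rsub__(self, o):
        return Dual.lift(o) - self

    def __mul__(self, o):
        o = Dual.lift(o)  # product rule
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)
    __rmul__ = __mul__

def dexp(x):
    """exp() on a dual number: d/dx exp(x) = exp(x)."""
    e = math.exp(x.val)
    return Dual(e, e * x.der)

def weibull_ll(lnlam, lngam, t, d):
    """Constant-only Weibull log-likelihood, written once like python_ll."""
    gam = dexp(lngam)
    total = Dual(0.0)
    for ti, di in zip(t, d):
        lnt = math.log(ti)
        # t**gam is computed as exp(gam * ln t) so that it works on dual numbers
        total = total + di * (lnlam + lngam + (gam - 1) * lnt) - dexp(lnlam) * dexp(gam * lnt)
    return total

def ad_gradient(lnlam, lngam, t, d):
    """Forward mode: one pass per parameter, seeding that parameter's der with 1."""
    g_lnlam = weibull_ll(Dual(lnlam, 1.0), Dual(lngam), t, d).der
    g_lngam = weibull_ll(Dual(lnlam), Dual(lngam, 1.0), t, d).der
    return [g_lnlam, g_lngam]

t = [0.5, 1.2, 3.0, 5.0, 5.0]
d = [1, 1, 1, 0, 0]
print(ad_gradient(math.log(0.2), math.log(0.8), t, d))
```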

Running on a larger dataset

The table below shows the estimation times for different methods in an example similar to the one above, but now with a simulated sample size of 10,000,000 observations and 10 covariates. Each of the 10 covariates is included in the linear predictor for both $\ln(\lambda)$ and $\ln(\gamma)$:

Method Time Program
ml d0 8762 3208