Introduction to mlad - a Weibull survival model
A Weibull model
In this introduction to `mlad` I will fit a Weibull survival model using both `ml` and `mlad` in order to show their similarities and their differences. You can fit a Weibull survival model using Stata's `streg` command, but for the purposes of this example we will assume that `streg` does not exist and that we need to code the model from scratch.
The survival function for a Weibull model is,
\[S(t) = \exp(-\lambda t^{\gamma})\]
and the hazard function is,
\[h(t) = \lambda \gamma t^{\gamma-1}\]
The parameters \(\lambda\) and \(\gamma\) can be modelled using a linear predictor, so their values depend on covariates \(\boldsymbol{X}\). As \(\lambda\) and \(\gamma\) are positive, this is usually done on the log scale.
With survival data, for each individual \(i\) we have a survival time, \(t_i\), and an event indicator, \(d_i\), with \(d_i=1\) denoting an event (e.g. death) and \(d_i=0\) a censored observation.
The log-likelihood contribution for the \(i^{th}\) individual with survival data is,
\[LL_i = d_i \ln[h(t_i)] + \ln[S(t_i)]\]
Substituting the hazard and survival functions, \(\ln h(t_i) = \ln(\lambda) + \ln(\gamma) + (\gamma-1)\ln(t_i)\) and \(\ln S(t_i) = -\lambda t_i^{\gamma}\), the log-likelihood contribution for individual \(i\) in a Weibull model is,
\[LL_i = d_i \left[\ln(\lambda) + \ln(\gamma) +(\gamma-1)\ln(t_i)\right] - \lambda t_i^{\gamma}\]
The total log-likelihood is just the sum over the \(N\) individuals,
\[LL = \sum_{i=1}^N LL_i\]
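As a quick numerical check of this algebra (a toy calculation of my own, not part of the original example), the expanded Weibull form agrees with the generic \(d_i \ln[h(t_i)] + \ln[S(t_i)]\) form:

```python
import math

# Toy check: the expanded Weibull log-likelihood contribution should
# equal d*ln(h(t)) + ln(S(t)) for any parameter values.
lam, gam, t, d = 0.2, 0.8, 2.0, 1

h = lam * gam * t ** (gam - 1)   # hazard function
S = math.exp(-lam * t ** gam)    # survival function

ll_generic = d * math.log(h) + math.log(S)
ll_weibull = d * (math.log(lam) + math.log(gam) + (gam - 1) * math.log(t)) - lam * t ** gam

assert math.isclose(ll_generic, ll_weibull)
```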
Maximizing the likelihood using `ml`
I will first simulate some survival data using `survsim`.
```
. clear

. set obs 500000
Number of observations (_N) was 0, now 500,000.

. set seed 2345

. gen x1 = rnormal()

. gen x2 = rnormal()

. survsim t d, dist(weibull) lambda(0.2) gamma(0.8) maxt(5) cov(x1 0.1 x2 0.1)
```
I have simulated 500,000 survival times from a Weibull distribution with parameters \(\lambda=0.2\) and \(\gamma=0.8\). There is a maximum follow-up time of 5 years. I have introduced two covariates, `x1` and `x2`, both of which have a log hazard ratio of 0.1.
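For intuition about the data-generating process, here is a minimal Python sketch of how such survival times can be drawn by inverting the survival function (my own illustration, not `survsim`'s internal code; the covariates enter through a proportional-hazards term \(\lambda\exp(x\beta)\)):

```python
import numpy as np

rng = np.random.default_rng(2345)
n = 500_000
lam, gam, maxt = 0.2, 0.8, 5.0

x1 = rng.standard_normal(n)
x2 = rng.standard_normal(n)
eta = 0.1 * x1 + 0.1 * x2                    # log hazard-ratio terms

# Invert S(t) = exp(-lam * exp(eta) * t**gam) at a uniform draw
u = rng.uniform(size=n)
t_event = (-np.log(u) / (lam * np.exp(eta))) ** (1 / gam)

t = np.minimum(t_event, maxt)                # administrative censoring at 5 years
d = (t_event <= maxt).astype(int)            # event indicator
```

Now the data are simulated, I can use `stset`: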
```
. stset t, failure(d=1)

Survival-time data settings

         Failure event: d==1
Observed time interval: (0, t]
     Exit on or before: failure

--------------------------------------------------------------------------
    500,000  total observations
          0  exclusions
--------------------------------------------------------------------------
    500,000  observations remaining, representing
    257,478  failures in single-record/single-failure data
  1,705,319  total analysis time at risk and under observation
                                                At risk from t =         0
                                     Earliest observed entry t =         0
                                          Last observed exit t =         5
```
To fit the Weibull model using `ml`, an ado file needs to be written that returns the log-likelihood. This is shown below:
```
. type weibull_d0.ado
program weibull_d0
  version 17.0
  args todo b lnf g H
  tempvar lnlambda lngamma
  mleval `lnlambda' = `b', eq(1)
  mleval `lngamma'  = `b', eq(2)
  mlsum `lnf' = _d*(`lnlambda' + `lngamma' + (exp(`lngamma') - 1)*ln(_t)) - ///
                exp(`lnlambda')*_t^(exp(`lngamma'))
  if (`todo'==0 | `lnf'>=.) exit
end
```
This is a `d0` evaluator, which means that only the log-likelihood (a scalar) needs to be returned. `ml` will calculate the derivatives needed for the gradient and Hessian matrix numerically. This is great as I do not have to sit down and do the maths, but it will be slow with large datasets.
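To see why this can be slow, here is a minimal sketch of central finite differences (my own illustration of the idea, not `ml`'s actual algorithm): each gradient element costs two extra likelihood evaluations, before even starting on the Hessian, so the number of likelihood calls grows quickly with the number of parameters.

```python
import numpy as np

def numeric_grad(loglik, b, eps=1e-6):
    # Central finite differences: two likelihood evaluations per parameter.
    g = np.zeros_like(b)
    for j in range(b.size):
        bp, bm = b.copy(), b.copy()
        bp[j] += eps
        bm[j] -= eps
        g[j] = (loglik(bp) - loglik(bm)) / (2 * eps)
    return g
```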
In the ado file the linear predictors for both \(\ln(\lambda)\) and \(\ln(\gamma)\) are extracted using `mleval`, and these are then fed into the log-likelihood function, which is summed using `mlsum`.
The model can now be fitted using `ml model`.
```
. timer on 1

. ml model d0 weibull_d0 (ln_lambda: = x1 x2) (ln_gamma:), maximize

Initial:      Log likelihood = -1705318.9
Alternative:  Log likelihood = -871949.09
Rescale:      Log likelihood = -804798.79
Rescale eq:   Log likelihood = -763674.17
Iteration 0:  Log likelihood = -763674.17
Iteration 1:  Log likelihood = -733377.66
Iteration 2:  Log likelihood = -733244.73
Iteration 3:  Log likelihood = -733244.65
Iteration 4:  Log likelihood = -733244.65
```
```
. ml display

                                                      Number of obs =  500,000
                                                      Wald chi2(2)  =  5013.81
Log likelihood = -733244.65                           Prob > chi2   =   0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
ln_lambda    |
          x1 |   .0996771    .001975    50.47   0.000     .0958062    .1035479
          x2 |   .0983837   .0019757    49.80   0.000     .0945114    .1022561
       _cons |  -1.612564   .0028114  -573.58   0.000    -1.618074   -1.607053
-------------+----------------------------------------------------------------
ln_gamma     |
       _cons |  -.2244506   .0018181  -123.45   0.000    -.2280141   -.2208872
------------------------------------------------------------------------------
```
```
. timer off 1

. timer list
   1:     15.41 /        2 =       7.7065
   2:      0.75 /        1 =       0.7480
```
The model fitted in 15.41 seconds and the parameter estimates are close to the true values, with the estimated \(\lambda\), `exp(_b[_cons])` = 0.199, and the estimated \(\gamma\), `exp([ln_gamma][_cons])` = 0.799.
Maximizing the likelihood using `mlad`
I will now fit the same model using `mlad`. Rather than writing a Stata ado file to define the log-likelihood, a Python function must be written. This is shown below:
```
. type weibull_ll.py
import jax.numpy as jnp
import mladutil as mu

def python_ll(beta, X, wt, M):
    lnlam = mu.linpred(beta, X, 1)
    lngam = mu.linpred(beta, X, 2)
    gam   = jnp.exp(lngam)
    lli   = M["d"]*(lnlam + lngam + (gam - 1)*jnp.log(M["t"])) - jnp.exp(lnlam)*M["t"]**(gam)
    return(jnp.sum(lli))
```
First, two modules are imported. The first is JAX's version of numpy, which will nearly always have to be imported. The second, `mladutil`, is a set of utility programs for `mlad`. The function name must always be `python_ll()`. There are four function arguments.

The first argument, `beta`, is a Python list with the first item containing the parameters for \(\ln(\lambda)\) and the second item the parameters for \(\ln(\gamma)\).

The second argument, `X`, holds the covariates, which are automatically transferred to Python and stored in a list, with the covariates for the first equation stored in an array in `X[1]` and those for the \(k^{th}\) equation in `X[k]`. If any offsets have been specified, these will also be included in `X[0]`.

The third argument, `wt`, defines any weights that have been specified, or a column of 1's if they have not.

The final argument, `M`, is a Python dictionary containing any variables specified in the `othervars()` option of `mlad`, matrices specified in the `matrices()` option, or scalars specified in the `scalars()` option. Here the survival time (`_t`) and the event indicator (`_d`) are needed to calculate the likelihood function. Note that these are named `t` and `d` in the Python dictionary, `M`, as defined in the `othervarnames()` option.

`linpred()` is a utility function that calculates the current predicted value for the \(k^{th}\) equation given `X` and `beta`. It is recommended that you use this function; `linpred()` will automatically incorporate any offsets.
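As a rough mental model of `linpred()` (an assumption based on the description above, not `mladutil`'s actual source), it does something like the following:

```python
import jax.numpy as jnp

def linpred_sketch(beta, X, k):
    # Hypothetical illustration: the linear predictor for equation k is
    # the covariate array for that equation times its coefficient vector.
    # The real mladutil.linpred() also adds any offset held in X[0].
    return jnp.dot(X[k], beta[k - 1])
```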
The syntax of `mlad` is similar to that of `ml` when specifying the equations. The name of the Python file containing the log-likelihood is passed using the `llfile()` option. The survival time `_t` and the event indicator `_d` need to be passed to Python using the `othervars()` option. By default these will have the same names in the Python dictionary, `M`, passed to the likelihood function, but below they are renamed to `t` and `d` using the `othervarnames()` option.
```
. timer on 2

. mlad (ln_lambda: = x1 x2)   ///
>      (ln_gamma: ),          ///
>      othervars(_t _d)       ///
>      othervarnames(t d)     ///
>      llfile(weibull_ll)

Initial:      Log likelihood = -1705318.9
Alternative:  Log likelihood = -871949.09
Rescale:      Log likelihood = -804798.79
Rescale eq:   Log likelihood = -763674.17
Iteration 0:  Log likelihood = -763674.17
Iteration 1:  Log likelihood = -733382.53
Iteration 2:  Log likelihood = -733244.73
Iteration 3:  Log likelihood = -733244.65
Iteration 4:  Log likelihood = -733244.65
```
```
. ml display

                                                      Number of obs =  500,000
                                                      Wald chi2(2)  =  5013.81
Log likelihood = -733244.65                           Prob > chi2   =   0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
ln_lambda    |
          x1 |   .0996771    .001975    50.47   0.000     .0958062    .1035479
          x2 |   .0983837   .0019757    49.80   0.000     .0945114    .1022561
       _cons |  -1.612564   .0028114  -573.58   0.000    -1.618074   -1.607053
-------------+----------------------------------------------------------------
ln_gamma     |
       _cons |  -.2244506   .0018181  -123.45   0.000    -.2280141   -.2208872
------------------------------------------------------------------------------
```
```
. timer off 2

. timer list
   1:     15.41 /        2 =       7.7065
   2:      1.58 /        2 =       0.7915
```
The estimates are identical to those obtained using `ml`. There is a clear increase in speed, with `mlad` running in 1.58 seconds compared with 15.41 seconds for `ml`. Greater speed gains will be obtained as the sample size increases.
If we look at what type of `ml` method has been used, we see it is a `d2` evaluator.
di "`e(ml_method)'"
. d2
This means that although only the log-likelihood function was returned by the Python function, the automatic differentiation provided by the JAX module in Python has enabled the gradient and Hessian functions to be generated.
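To illustrate the idea outside of `mlad` (a standalone toy example, with the parameter handling simplified to a single vector), `jax.grad()` and `jax.hessian()` turn a scalar log-likelihood into gradient and Hessian functions with no extra maths:

```python
import jax
import jax.numpy as jnp

def weibull_ll(params, t, d):
    # Scalar Weibull log-likelihood with params = [ln(lambda), ln(gamma)]
    lnlam, lngam = params
    gam = jnp.exp(lngam)
    lli = d * (lnlam + lngam + (gam - 1) * jnp.log(t)) - jnp.exp(lnlam) * t**gam
    return jnp.sum(lli)

grad_fn = jax.grad(weibull_ll)       # gradient via automatic differentiation
hess_fn = jax.hessian(weibull_ll)    # Hessian via automatic differentiation

t = jnp.array([1.0, 2.5, 0.7])
d = jnp.array([1.0, 0.0, 1.0])
params = jnp.array([jnp.log(0.2), jnp.log(0.8)])
print(grad_fn(params, t, d))
print(hess_fn(params, t, d))
```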
Running on a larger dataset
The table below shows the time taken by different estimation methods for a similar example to the one above, but now with a simulated sample size of 10,000,000 observations and 10 covariates. Each of the 10 covariates is included in the linear predictor for both \(\ln(\lambda)\) and \(\ln(\gamma)\) (a total of 22 parameters when including the intercepts for each equation):
| Method | Time (seconds) | Likelihood program calls |
|---|---|---|
| `ml` `d0` | 8762 | 3208 |
| `mlad` | 52 | 22 |
| `ml` `d2` | 136 | 22 |
| `ml` `lf0` | 281 | 79 |
| `streg` | 207 | 33 |
Using `ml` with a `d0` evaluator takes about 2 hours and 26 minutes. Part of the reason it takes so long is that the likelihood program has to be called 3,208 times, because numerical differentiation is required for all 22 parameters in the model.

Supplying the same likelihood, but now in Python using `mlad`, leads to the model being fitted in 52 seconds, a dramatic improvement. Most of the speed gain is due to fewer program calls; the likelihood function is only called 22 times. This is because there are now functions to derive the gradient vector and Hessian matrix rather than relying on numerical differentiation.

The maths to derive the gradient and Hessian matrix for a Weibull model is fairly simple, so in this case writing a `d2` evaluator for `ml` is straightforward (see here); the first derivatives are sketched at the end of this section. The model fits in 136 seconds. Note this is still slower than `mlad`. The reason for this is that the Python functions are compiled using the XLA compiler and are able to access multiple processors (I am using a 12-core computer with 2 threads per core).

As the log-likelihood in this case can be expressed in linear form, i.e. we can sum the log-likelihood contribution for each row to give the total log-likelihood, an `lf0` evaluator can be used. This fits much faster than `d0`, at 281 seconds.

A Weibull model can also be fitted using `streg`, which took 207 seconds. `streg` takes more care that the model will converge by fitting a constant-only model before fitting the full model, which explains its greater number of calls to the likelihood program.
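For reference, here are the first derivatives needed for such a `d2` evaluator (my own derivation from the log-likelihood above, not taken from the linked code). Writing \(\eta_1 = \ln(\lambda)\) and \(\eta_2 = \ln(\gamma)\) for the two linear predictors,

\[\frac{\partial LL_i}{\partial \eta_1} = d_i - \lambda t_i^{\gamma}\]
\[\frac{\partial LL_i}{\partial \eta_2} = d_i\left(1 + \gamma\ln t_i\right) - \lambda \gamma t_i^{\gamma}\ln t_i\]

The second derivatives follow in the same way, and the chain rule maps these onto the covariates in each equation.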