Using mlad for splines on the log hazard scale.
Restricted cubic splines for the log hazard function.
We have done a lot of work using restricted cubic splines to model survival data. The models we use most often use splines to model the log cumulative hazard function (see stpm3). However, sometimes it can be useful to model directly on the log hazard scale. Models on the log hazard scale using splines are computationally more intensive, as the cumulative hazard at each event/censoring time is needed to maximize the log-likelihood function and this has to be obtained through numerical integration for each individual in the study.
Various Stata commands can fit a model on the log-hazard scale, including stgenreg, strcs and merlin. Since writing mlad I have written stpm3, which can also be used to fit models on the log-hazard scale.
A proportional hazards model using restricted cubic splines to estimate the baseline hazard function can be written as,
\[\ln[h(t)] = s(\ln[t]|\boldsymbol{\gamma}) + \mathbf{X}\boldsymbol{\beta}\]
where \(s(\ln[t]|\boldsymbol{\gamma})\) is a restricted cubic spline function of log time with parameters \(\boldsymbol{\gamma}\), \(\mathbf{X}\) is a set of covariates and \(\boldsymbol{\beta}\) are their associated parameters (log hazard ratios).
The log-likelihood contribution for the \(i^{th}\) subject is
\[ll_i = d_i \ln[h(t_i)] - \int_{t_{0i}}^{t_i} h(u)du \]
where \(t_i\) is the event/censoring time, \(t_{0i}\) is the entry time and \(d_i\) is the event indicator for the \(i^{th}\) subject. The inclusion of \(t_{0i}\) allows for models with delayed entry (left truncation).
For simple parametric models such as a Weibull model it is possible to derive the integral analytically. However, when using splines on the log-hazard scale this is not possible, so numerical integration needs to be used. Note that this is one reason why it is often advantageous to use splines on the log cumulative hazard scale, as the cumulative hazard can be obtained analytically.
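As a worked case of the analytic route, assume a Weibull hazard written as \(h(t) = \lambda p t^{p-1}\) (the symbols \(\lambda\) and \(p\) are chosen here purely for illustration and are not used elsewhere in this example). The integral in the log-likelihood then has a closed form,
\[ \int_{t_{0i}}^{t_i} \lambda p u^{p-1} du = \lambda \left(t_i^{p} - t_{0i}^{p}\right) \]
so no numerical integration is required for this model.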
A simple way to do the numerical integration is to use Gauss-Legendre quadrature. In order to numerically integrate the hazard function between \(t_{0i}\) and \(t_i\), a set of nodes, \(x_k\), and weights, \(w_k\), is chosen. The more nodes and weights used, the greater the accuracy of the numerical integration. The nodes and weights are defined for integration over [-1,1], but through a change of interval rule we can integrate over [\(t_{0i}\),\(t_{i}\)].
With \(n\) quadrature nodes and weights the integral can be approximated using,
\[ \int_{t_{0i}}^{t_i} h(u) du \approx \frac{t_i - t_{0i}}{2} \sum_{k=1}^n w_k h\left(\frac{t_i - t_{0i}}{2}x_k + \frac{t_i + t_{0i}}{2}\right)\]
Note that this integral needs to be calculated for every individual in the study every time the likelihood function is called.
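The change of interval formula above can be sketched in plain Python. This is only an illustration under stated assumptions, not how mlad or stgenreg implement it: the function names gauss_legendre and cumhaz_gl are hypothetical, the nodes and weights are found by Newton iteration on the Legendre polynomial, and a Weibull hazard \(h(u) = \lambda p u^{p-1}\) with arbitrary parameter values is assumed so that the result can be checked against the closed form \(\lambda(t^p - t_0^p)\).

```python
import math

def gauss_legendre(n):
    """Return Gauss-Legendre nodes and weights on [-1, 1] (Newton iteration)."""
    nodes, weights = [], []
    for k in range(1, n + 1):
        x = math.cos(math.pi * (k - 0.25) / (n + 0.5))  # starting value for the k-th root
        for _ in range(100):
            # Legendre polynomials P_{n-1}(x) and P_n(x) via the three-term recurrence
            p_prev, p = 1.0, x
            for j in range(2, n + 1):
                p_prev, p = p, ((2 * j - 1) * x * p - (j - 1) * p_prev) / j
            dp = n * (x * p - p_prev) / (x * x - 1.0)  # derivative P_n'(x)
            step = p / dp
            x -= step
            if abs(step) < 1e-15:
                break
        nodes.append(x)
        weights.append(2.0 / ((1.0 - x * x) * dp * dp))
    return nodes, weights

def cumhaz_gl(hazard, t0, t, n=50):
    """Approximate the integral of hazard(u) from t0 to t with n nodes."""
    x, w = gauss_legendre(n)
    half, mid = (t - t0) / 2.0, (t + t0) / 2.0  # change of interval to [t0, t]
    return half * sum(wk * hazard(half * xk + mid) for xk, wk in zip(x, w))

# Assumed Weibull hazard (illustrative values only)
lam, p = 0.1, 1.5

def weibull_hazard(u):
    return lam * p * u ** (p - 1)

approx = cumhaz_gl(weibull_hazard, 0.5, 4.0, n=50)
exact = lam * (4.0 ** p - 0.5 ** p)  # closed-form cumulative hazard
print(abs(approx - exact))
```

In a real model the nodes and weights would be generated once and reused for every individual, since only the interval \([t_{0i}, t_i]\) changes between individuals.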
An example using stgenreg
stgenreg was the first command in Stata able to fit splines on the log hazard scale (see Crowther and Lambert 2013 and 2014) without having to finely split the time scale and approximate the integral using Poisson regression. stgenreg is a very general command that allows the user to define just about any parametric function for the (log) hazard function. Its generality makes it slow with large datasets and it is a d0 type evaluator, which means the gradient and Hessian matrix are obtained using numerical differentiation.
I will use the rott2 data to develop the mlad function. This example will only fit a proportional hazards model. I will use expand 50 to increase the size of the dataset to 149,100 observations as I am mainly interested in performance in larger datasets. I will include a few pre-selected covariates in the model.
. use https://www.pclambert.net/data/rott2b, clear
(Rotterdam breast cancer data (augmented with cause of death))

. expand 50
(146,118 observations created)

. tab size, gen(size)

     Tumour |
    size, 3 |
classes (t) |      Freq.     Percent        Cum.
------------+-----------------------------------
    <=20 mm |     69,350       46.51       46.51
  >20-50mmm |     64,550       43.29       89.81
     >50 mm |     15,200       10.19      100.00
------------+-----------------------------------
      Total |    149,100      100.00

. stset os, failure(osi=1) scale(12) exit(time 120)

Survival-time data settings

         Failure event: osi==1
Observed time interval: (0, os]
     Exit on or before: time 120
     Time for analysis: time/12

--------------------------------------------------------------------------
    149,100  total observations
          0  exclusions
--------------------------------------------------------------------------
    149,100  observations remaining, representing
     58,550  failures in single-record/single-failure data
  1,000,121  total analysis time at risk and under observation
                                                At risk from t =         0
                                     Earliest observed entry t =         0
                                          Last observed exit t =        10

. timer clear

. timer on 1

. stgenreg, loghazard([xb]) ///
>     xb(hormon age size2 size3 enodes er pr_1 | #rcs(df(5))) ///
>     nodes(50)
Variables _eq1_cp2_rcs1 to _eq1_cp2_rcs5 were created
Initial: Log likelihood = -926516.48
Alternative: Log likelihood = -562274.45
Rescale: Log likelihood = -178846.95
Iteration 0: Log likelihood = -178846.95
Iteration 1: Log likelihood = -133783.09
Iteration 2: Log likelihood = -132229.16
Iteration 3: Log likelihood = -131981.55
Iteration 4: Log likelihood = -131970.09
Iteration 5: Log likelihood = -131970.07
Iteration 6: Log likelihood = -131970.07
Log likelihood = -131970.07                       Number of obs = 149,100
-------------------------------------------------------------------------------
| Coefficient Std. err. z P>|z| [95% conf. interval]
--------------+----------------------------------------------------------------
hormon | -.2124007 .0127695 -16.63 0.000 -.2374285 -.1873729
age | .0118462 .0003482 34.02 0.000 .0111637 .0125286
size2 | .3920075 .0098359 39.85 0.000 .3727295 .4112854
size3 | .6967237 .0135336 51.48 0.000 .6701984 .723249
enodes | -1.866594 .0149733 -124.66 0.000 -1.895941 -1.837247
er | -8.12e-06 .0000157 -0.52 0.606 -.000039 .0000227
pr_1 | -.0924092 .001937 -47.71 0.000 -.0962056 -.0886127
_eq1_cp2_rcs1 | .1130317 .0063038 17.93 0.000 .1006765 .1253869
_eq1_cp2_rcs2 | .1819154 .0049743 36.57 0.000 .172166 .1916649
_eq1_cp2_rcs3 | -.0080271 .0050397 -1.59 0.111 -.0179047 .0018505
_eq1_cp2_rcs4 | -.0268743 .0046945 -5.72 0.000 -.0360755 -.0176732
_eq1_cp2_rcs5 | .0171344 .0047121 3.64 0.000 .0078988 .02637
_cons | -1.907739 .0254978 -74.82 0.000 -1.957713 -1.857764
-------------------------------------------------------------------------------
Quadrature method: Gauss-Legendre with 50 nodes
. timer off 1

. timer list
   1:    366.05 /        1 =     366.0470
The stgenreg command has fitted a model on the log hazard scale using restricted cubic splines with 5 d.f. (6 knots) to model the effect of time on the log hazard function, together with a selection of covariates. I have used 50 nodes for the numerical integration as the default of 15 can be too low in some cases. The model took 366.0 seconds to fit.
Fitting the same model using mlad
First I need to write the log-likelihood function in Python. The log-likelihood file is shown below.
. type rcs_hazard.py
import jax.numpy as jnp
import mladutil as mu
from jax import vmap

def python_ll(beta,X,wt,M,Nnodes):
  ## Parameters
  xb    = mu.linpred(beta,X,1)
  xbrcs = mu.linpred(beta,X,2)

  ## hazard function
  def rcshaz(t):
    vrcsgen = vmap(mu.rcsgen_beta,(0,None,None,None))
    return(jnp.exp(vrcsgen(jnp.log(t),M["knots"][0],beta[1],M["R_bhazard"]) + xb))

  ## cumulative hazard
  cumhaz = mu.vecquad_gl(rcshaz,M["t0"],M["t"],Nnodes,())

  ## return likelihood
  return(jnp.sum(wt*(M["d"]*(xb + xbrcs) - cumhaz)))
First the JAX version of numpy and the mladutil modules are loaded. In addition, the JAX function vmap is loaded. This will be described below.
The arguments of the python_ll function are similar to previous examples, but the number of nodes is passed as a separate argument rather than contained in the dictionary, M. I will explain why below.
Although I could have included one linear predictor, I have chosen to have one equation for the baseline log hazard function and one for the covariates. This makes the numerical integration easier.
The log-likelihood function needs to calculate the hazard function at various time points to perform the numerical integration. This is where the vmap function is particularly useful as it can lead to vast speed improvements. vmap here takes a function, rcsgen_beta(), that returns the restricted cubic spline basis functions multiplied by a vector of parameters for a single time point, and vectorizes it so that it can return the predicted values for all nodes (i.e. many time points). This new function is named vrcsgen().
Having defined the vrcsgen() function, the numerical integration to calculate the cumulative hazard using Gauss-Legendre quadrature can be performed using the mlad utility function, mu.vecquad_gl(). Note that the number of nodes for the numerical integration is passed to the function. This was passed separately to the python_ll function as it dictates the size of arrays that are calculated and JAX can give an error if it thinks the size of arrays may change. When using mlad below the number of nodes is passed as a static scalar. This tells JAX that this will not change with different calls to the functions when fitting the model.
Finally, the log-likelihood is returned as a scalar by summing the individual contributions to the likelihood.
Now mlad can be called to maximize the log-likelihood. First I will calculate the restricted cubic spline basis functions at the event/censoring times and store the knots and projection matrix so these can be passed to mlad. Note that the projection matrix can be used to transform the non-orthogonalized splines to orthogonalized splines. I will use the same number of nodes for the numerical integration as I used with stgenreg.
. timer on 2

. gen double lnt = ln(_t)

. rcsgen lnt, gen(_rcs) df(5) if2(_d==1) orthog
Variables _rcs1 to _rcs5 were created

. mata: st_matrix("knots",strtoreal(tokens(st_global("r(knots)"))))

. matrix R_bhazard = r(R)

. scalar Nnodes = 50

. mlad (xb: = hormon age size2 size3 enodes er pr_1, nocons ) ///
>      (rcs: = _rcs1 _rcs2 _rcs3 _rcs4 _rcs5) ///
>      , llfile(rcs_hazard) ///
>      othervars(_t0 _t _d) ///
>      othervarnames(t0 t d) ///
>      matrices(knots R_bhazard) ///
>      staticscalars(Nnodes)
Initial: Log likelihood = -1000121.2
Alternative: Log likelihood = -429338.54
Rescale: Log likelihood = -248692.69
Rescale eq: Log likelihood = -224204.35
Iteration 0: Log likelihood = -224204.35
Iteration 1: Log likelihood = -218528.68
Iteration 2: Log likelihood = -206013.59
Iteration 3: Log likelihood = -205602.38
Iteration 4: Log likelihood = -205574.87
Iteration 5: Log likelihood = -205574.8
Iteration 6: Log likelihood = -205574.8
. ml display

                                                  Number of obs = 149,100
                                                  Wald chi2(7)  = 33137.55
Log likelihood = -205574.8                        Prob > chi2   = 0.0000
------------------------------------------------------------------------------
| Coefficient Std. err. z P>|z| [95% conf. interval]
-------------+----------------------------------------------------------------
xb |
hormon | -.2124006 .0127695 -16.63 0.000 -.2374284 -.1873728
age | .0118462 .0003482 34.02 0.000 .0111637 .0125286
size2 | .3920078 .0098359 39.85 0.000 .3727299 .4112858
size3 | .6967243 .0135336 51.48 0.000 .670199 .7232496
enodes | -1.866594 .0149733 -124.66 0.000 -1.895941 -1.837247
er | -8.12e-06 .0000157 -0.52 0.606 -.000039 .0000227
pr_1 | -.0924092 .001937 -47.71 0.000 -.0962056 -.0886127
-------------+----------------------------------------------------------------
rcs |
_rcs1 | .1130317 .0063038 17.93 0.000 .1006765 .1253869
_rcs2 | .1819154 .0049743 36.57 0.000 .172166 .1916648
_rcs3 | -.0080271 .0050397 -1.59 0.111 -.0179047 .0018505
_rcs4 | -.0268743 .0046945 -5.72 0.000 -.0360755 -.0176732
_rcs5 | .0171344 .0047121 3.64 0.000 .0078988 .02637
_cons | -1.90774 .0254977 -74.82 0.000 -1.957714 -1.857765
------------------------------------------------------------------------------

. timer off 2
I have used two equations to separate out covariate effects from the effect of time.
I pass variables needed by mlad using the othervars() option and rename them using othervarnames().
The matrices containing the knot positions and the projection matrix are passed using the matrices() option.
Finally, the number of nodes is passed using the staticscalars() option rather than the scalar option for the reason described above.
. timer list
   1:    366.05 /        1 =     366.0470
   2:      9.76 /        1 =       9.7600
The model is notably faster than stgenreg, taking 9.8 seconds to fit. This is a speed gain of 97.3%. This is perhaps not a fair comparison as users are far more likely these days to fit such a model using stpm3 or merlin.
The same model using stpm3
stpm3 can be used to fit the same model, but with user-friendly syntax. stpm3 uses a gf2 evaluator, meaning that the derivatives required for the gradient and Hessian functions have been derived analytically.
The model is fitted below,
. timer on 3

. stpm3 hormon age size2 size3 enodes er pr_1, df(5) nodes(50) scale(lnhazard) ///
>       splinetype(rcs) integoptions(gl allnum)
Iteration 0: Log likelihood = -878143.18
Iteration 1: Log likelihood = -153694.4
Iteration 2: Log likelihood = -137939.98
Iteration 3: Log likelihood = -132405.71
Iteration 4: Log likelihood = -132006.08
Iteration 5: Log likelihood = -131970.24
Iteration 6: Log likelihood = -131970.07
Iteration 7: Log likelihood = -131970.07
                                                  Number of obs = 149,100
                                                  Wald chi2(7)  = 33137.55
Log likelihood = -131970.07                       Prob > chi2   = 0.0000
------------------------------------------------------------------------------
| Coefficient Std. err. z P>|z| [95% conf. interval]
-------------+----------------------------------------------------------------
xb |
hormon | -.2124006 .0127695 -16.63 0.000 -.2374284 -.1873728
age | .0118462 .0003482 34.02 0.000 .0111637 .0125286
size2 | .3920078 .0098359 39.85 0.000 .3727299 .4112858
size3 | .6967243 .0135336 51.48 0.000 .670199 .7232496
enodes | -1.866594 .0149733 -124.66 0.000 -1.895941 -1.837247
er | -8.12e-06 .0000157 -0.52 0.606 -.000039 .0000227
pr_1 | -.0924092 .001937 -47.71 0.000 -.0962056 -.0886127
-------------+----------------------------------------------------------------
time |
_rcs1 | 1.420007 .0455306 31.19 0.000 1.330769 1.509246
_rcs2 | -.8115046 .1058837 -7.66 0.000 -1.019033 -.6039765
_rcs3 | 3.057937 .3717691 8.23 0.000 2.329283 3.786591
_rcs4 | -3.85433 .6594532 -5.84 0.000 -5.146834 -2.561825
_rcs5 | 2.040627 .5611932 3.64 0.000 .9407085 3.140545
_cons | -2.165198 .0401335 -53.95 0.000 -2.243859 -2.086538
------------------------------------------------------------------------------
Quadrature method: Gauss-Legendre with 50 nodes.

. timer off 3
Note that by default stpm3 uses tanh-sinh quadrature rather than Gauss-Legendre and uses 3-part integration, which makes use of the fact that analytical integrals can be obtained before the first and after the last knot, so I use integoptions(gl allnum) to make the models comparable. In addition, stpm3 uses a different spline basis by default so, again for comparability, I use splinetype(rcs).
The same model using merlin
The same model can be fitted using merlin, which is a general command to fit a range of models that also include models with random effects (not used here). For the model fitted here merlin uses a gf2 type evaluator.
The model is fitted below
. timer on 4

. merlin (_t hormon age size2 size3 enodes er pr_1 rcs(_t, df(5) orthog log event), ///
>         family(loghazard, failure(_d) ) timevar(_t))
variables created for model 1, component 8: _cmp_1_8_1 to _cmp_1_8_5

Fitting full model:
Iteration 0: Log likelihood = -1000121.2
Iteration 1: Log likelihood = -218883.77
Iteration 2: Log likelihood = -207709.63
Iteration 3: Log likelihood = -205789.3
Iteration 4: Log likelihood = -205585.3
Iteration 5: Log likelihood = -205574.91
Iteration 6: Log likelihood = -205574.88
Iteration 7: Log likelihood = -205574.88
Fixed effects regression model                    Number of obs = 149,100
Log likelihood = -205574.88
------------------------------------------------------------------------------
| Coefficient Std. err. z P>|z| [95% conf. interval]
-------------+----------------------------------------------------------------
_t: |
hormon | -.2123989 .0127695 -16.63 0.000 -.2374267 -.1873711
age | .0118462 .0003482 34.02 0.000 .0111637 .0125286
size2 | .3920084 .0098359 39.85 0.000 .3727304 .4112863
size3 | .6967249 .0135336 51.48 0.000 .6701996 .7232502
enodes | -1.866594 .0149733 -124.66 0.000 -1.895941 -1.837247
er | -8.12e-06 .0000157 -0.52 0.606 -.000039 .0000227
pr_1 | -.0924091 .001937 -47.71 0.000 -.0962056 -.0886127
rcs():1 | .1130367 .0063038 17.93 0.000 .1006815 .1253919
rcs():2 | .1819153 .0049744 36.57 0.000 .1721657 .1916649
rcs():3 | -.0080333 .0050397 -1.59 0.111 -.017911 .0018444
rcs():4 | -.0268593 .0046946 -5.72 0.000 -.0360605 -.017658
rcs():5 | .0171084 .0047121 3.63 0.000 .0078729 .0263439
_cons | -1.907736 .0254977 -74.82 0.000 -1.95771 -1.857761
------------------------------------------------------------------------------

. timer off 4

. timer list
   1:    366.05 /        1 =     366.0470
   2:      9.76 /        1 =       9.7600
   3:     20.43 /        1 =      20.4340
   4:    133.52 /        1 =     133.5190
In this dataset mlad has a speed gain of 97.3% over stgenreg, 52.2% over stpm3 and 92.7% over merlin.
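The percentages quoted here follow from the timer output as 100 × (1 − mlad time / other time). A short Python check of that arithmetic, using the times reported by timer list (the function name speed_gain is just a label chosen for this sketch):

```python
# Times in seconds, taken from the timer list output above
times = {"stgenreg": 366.047, "mlad": 9.760, "stpm3": 20.434, "merlin": 133.519}

def speed_gain(other_time, mlad_time):
    """Percentage reduction in run time when using mlad instead of another command."""
    return 100.0 * (1.0 - mlad_time / other_time)

gains = {
    command: round(speed_gain(seconds, times["mlad"]), 1)
    for command, seconds in times.items()
    if command != "mlad"
}
print(gains)
```

A negative value from the same formula would mean that mlad was slower than the other command.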
Performance in larger datasets
The following table gives times in seconds and, in brackets, the percentage speed improvement of mlad when compared with stgenreg, strcs, merlin and stpm3 for a range of sample sizes. This model is a proportional hazards model incorporating 10 covariates.
Sample Size | mlad | stgenreg | strcs | merlin | stpm3 |
---|---|---|---|---|---|
1,000 | 0.6 | 4.9 (87.8%) | 0.4 (-50%) | 1.9 (68.4%) | - |
10,000 | 0.9 | 48 (98.1%) | 2.4 (62.5%) | 11 (91.7%) | 0.7 (-30.3%) |
50,000 | 2.2 | 193 (98.9%) | 12 (82.0%) | 86 (97.5%) | 4.0 (46.0%) |
100,000 | 2.5 | 452 (99.2%) | 27 (87.2%) | 178 (98.1%) | 7.7 (67.9%) |
250,000 | 5.2 | 1,125 (99.3%) | 69 (89.2%) | 441 (98.3%) | 21 (75.3%) |
500,000 | 12.9 | 2,329 (99.4%) | 139 (89.8%) | 898 (98.4%) | 44 (70.1%) |
1,000,000 | 19.4 | 1,530 (99.4%) | 75 (90.7%) | 1,238 (98.5%) | 90 (78.6%) |
2,500,000 | 50.7 | - | 678 (90.7%) | 4,734 (98.6%) | 229 (77.8%) |
The speed gains over stgenreg are substantial, with the models running in less than 1% of the time for sample sizes of 100,000 or more. The speed gains over strcs are of note, with the models running in less than 10% of the time for sample sizes of 1,000,000 or more. The speed gains over merlin are substantial, with the models running in less than 2% of the time for sample sizes of 100,000 or more. The speed gains over stpm3 are notable, but not as large as the others. I put some effort into making stpm3 computationally efficient.
This program is not efficient
The likelihood function for this model is fairly simple, but it is inefficient. The restricted cubic spline basis functions at the nodes are calculated each time the function is called. This is unnecessary as the positions of the nodes do not change. In another example I fit the same model but pre-calculate the basis functions at the nodes.
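The idea of pre-calculating can be sketched in plain Python. This is not the code from the other example: the names precompute, cumhaz and toy_basis are hypothetical, toy_basis is a two-term stand-in for the restricted cubic spline basis, and a 2-point Gauss-Legendre rule is used only to keep the sketch short. The point is that the basis at the quadrature time points depends on \(t_{0i}\), \(t_i\) and the nodes but not on the parameters, so it can be computed once and reused at every likelihood evaluation.

```python
import math

# 2-point Gauss-Legendre rule on [-1, 1] (low order, chosen only for brevity)
X_NODES = [-1.0 / math.sqrt(3.0), 1.0 / math.sqrt(3.0)]
W_NODES = [1.0, 1.0]

def toy_basis(u):
    """Hypothetical stand-in for the spline basis at time u: [1, ln(u)]."""
    return [1.0, math.log(u)]

def precompute(t0, t):
    """Done once: half-widths and basis values at every node for every individual."""
    stored = []
    for t0i, ti in zip(t0, t):
        half, mid = (ti - t0i) / 2.0, (ti + t0i) / 2.0
        stored.append((half, [toy_basis(half * xk + mid) for xk in X_NODES]))
    return stored

def cumhaz(gamma, stored):
    """Done at every likelihood call: only the parameter-dependent part is computed."""
    result = []
    for half, bases in stored:
        hazards = [math.exp(sum(g * b for g, b in zip(gamma, basis))) for basis in bases]
        result.append(half * sum(w * h for w, h in zip(W_NODES, hazards)))
    return result

stored = precompute([0.0, 1.0], [2.0, 5.0])
# A constant hazard of 0.2 (second parameter set to zero) as a simple check
H = cumhaz([math.log(0.2), 0.0], stored)
print(H)
```

With a constant hazard the quadrature is exact, so each value of H should equal 0.2 multiplied by the follow-up time, up to floating point error.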
References
Crowther, M.J. merlin—A unified modeling framework for data analysis and methods development in Stata The Stata Journal 2020;20:763–784
Bower, H., Crowther, M.J., Lambert, P.C. strcs: A command for fitting flexible parametric survival models on the log-hazard scale The Stata Journal 2016;16:989-1012
Crowther, M.J., Lambert, P.C. A general framework for parametric survival analysis. Statistics in Medicine 2014;33:5280-5297
Crowther, M.J., Lambert, P.C. stgenreg: A Stata Package for General Parametric Survival Analysis Journal of Statistical Software 2013;53:1-17