Using mlad for splines on the log hazard scale.
Restricted cubic splines for the log hazard function.
We have done a lot of work using restricted cubic splines to model survival data. The models we use most often apply splines to the log cumulative hazard function (see stpm2). However, it can sometimes be useful to model directly on the log hazard scale. Spline models on the log hazard scale are more computationally intensive because maximizing the loglikelihood function requires the cumulative hazard at each event/censoring time, and this has to be obtained through numerical integration for each individual in the study.
Various Stata commands can fit a model on the log hazard scale, including stgenreg, strcs and merlin.
A proportional hazards model using restricted cubic splines to estimate the baseline hazard function can be written as
$$\ln[h(t)] = s(\ln[t] \mid \boldsymbol{\gamma}) + \mathbf{X}\boldsymbol{\beta}$$
where $s(\ln[t] \mid \boldsymbol{\gamma})$ is a restricted cubic spline function of $\ln[t]$ with parameters $\boldsymbol{\gamma}$, and $\mathbf{X}$ is a set of covariates with associated parameters (log hazard ratios) $\boldsymbol{\beta}$.
The loglikelihood contribution for the $i^{th}$ subject is
$$ll_i = d_i \ln[h(t_i)] - \int_{t_{0i}}^{t_i} h(u)du $$
where $t_i$ is the event/censoring time, $t_{0i}$ is the entry time and $d_i$ is the event indicator for the $i^{th}$ subject. The inclusion of $t_{0i}$ allows for models with delayed entry (left truncation).
For simple parametric models such as a Weibull model it is possible to derive the integral analytically. However, when using splines on the log hazard scale this is not possible, so numerical integration needs to be used. Note that this is one reason why it is often advantageous to use splines on the log cumulative hazard scale, as the cumulative hazard can then be obtained analytically.
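As a concrete instance of the analytic case, consider a Weibull hazard without covariates. The parameterisation below, with symbols $\lambda$ and $\kappa$, is an assumption chosen for illustration and is not used elsewhere in this example.

$$h(t) = \lambda \kappa t^{\kappa - 1}, \qquad \int_{t_{0i}}^{t_i} h(u)du = \lambda\left(t_i^{\kappa} - t_{0i}^{\kappa}\right)$$

so the loglikelihood contribution has the closed form

$$ll_i = d_i\left[\ln\lambda + \ln\kappa + (\kappa - 1)\ln t_i\right] - \lambda\left(t_i^{\kappa} - t_{0i}^{\kappa}\right)$$

No such closed form is available once $\ln[h(t)]$ is a spline function of $\ln[t]$.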
A simple way to do the numerical integration is Gauss-Legendre quadrature. To numerically integrate the hazard function between $t_{0i}$ and $t_i$, a set of nodes, $x_k$, and weights, $w_k$, is chosen. The more nodes and weights, the greater the accuracy of the numerical integration. The nodes and weights are defined for integration over $[-1,1]$, but through a change of interval rule we can integrate over $[t_{0i},t_{i}]$.
With $n$ quadrature nodes and weights the integral can be approximated using
$$ \int_{t_{0i}}^{t_i} h(u) du \approx \frac{t_i - t_{0i}}{2} \sum_{k=1}^n w_k h\left(\frac{t_i - t_{0i}}{2}x_k + \frac{t_i + t_{0i}}{2}\right)$$
Note that this integral needs to be calculated for every individual in the study every time the likelihood function is called.
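The change of interval rule can be checked numerically. The sketch below is a minimal NumPy illustration, not the code used by mlad or stgenreg: the function name `cumhaz_gl` and the Weibull parameter values are assumptions made for the example. It compares the quadrature approximation with the analytic Weibull cumulative hazard, including one subject with delayed entry.

```python
import numpy as np

def cumhaz_gl(hazard, t0, t, n_nodes=30):
    """Gauss-Legendre approximation to the integral of hazard(u) over [t0_i, t_i]."""
    # Nodes x_k and weights w_k for integration over [-1, 1]
    x, w = np.polynomial.legendre.leggauss(n_nodes)
    t0 = np.asarray(t0, dtype=float)
    t = np.asarray(t, dtype=float)
    half_width = (t - t0) / 2
    midpoint = (t + t0) / 2
    # Change of interval: one row of mapped nodes per subject
    u = half_width[:, None] * x[None, :] + midpoint[:, None]
    # Weighted sum over the nodes, scaled by (t_i - t0_i)/2
    return half_width * (hazard(u) @ w)

# Hypothetical Weibull hazard, used only to check the approximation
lam, kappa = 0.1, 1.5
weibull_hazard = lambda u: lam * kappa * u ** (kappa - 1)

t0 = np.array([0.0, 0.0, 2.0])   # third subject has delayed entry
t = np.array([1.5, 4.0, 6.0])

approx = cumhaz_gl(weibull_hazard, t0, t)
exact = lam * (t ** kappa - t0 ** kappa)
print(np.max(np.abs(approx - exact)))
```

Because every subject has its own interval, the mapped nodes form a matrix with one row per subject and one column per node, which is why the cost grows with both the sample size and the number of nodes.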
An example using stgenreg
stgenreg was the first command in Stata able to fit splines on the log hazard scale (see Crowther and Lambert 2013 and 2014) without having to finely split the time scale and approximate the integral using Poisson regression.
stgenreg is a very general command that allows the user to define just about any parametric function for the (log) hazard function.
Its generality makes it slow with large datasets, and it is a d0 type evaluator, which means the gradient and Hessian matrix are obtained using numerical differentiation.
I will use the rott2 data to develop the mlad function. This example will only fit a proportional hazards model. I will use expand 50 to increase the size of the dataset to 149,100 observations, as I am mainly interested in performance in larger datasets. I will include a few preselected covariates in the model.
. use https://www.pclambert.net/data/rott2b, clear
(Rotterdam breast cancer data (augmented with cause of death))
. expand 50
(146,118 observations created)
. tab size, gen(size)
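     Tumour |
    size, 3 |
classes (t) |      Freq.     Percent        Cum.
------------+-----------------------------------
    <=20 mm |     69,350       46.51       46.51
  >20-50mmm |     64,550       43.29       89.81
     >50 mm |     15,200       10.19      100.00
------------+-----------------------------------
      Total |    149,100      100.00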
. stset os, failure(osi=1) scale(12) exit (time 120)
Survival-time data settings
Failure event: osi==1
Observed time interval: (0, os]
Exit on or before: time 120
Time for analysis: time/12

149,100 total observations
0 exclusions

149,100 observations remaining, representing
58,550 failures in single-record/single-failure data
1,000,121 total analysis time at risk and under observation
At risk from t = 0
Earliest observed entry t = 0
Last observed exit t = 10
.
. timer clear
. timer on 1
. stgenreg, loghazard([xb]) ///
> xb(hormon age size2 size3 enodes er pr_1 | #rcs(df(5))) ///
> nodes(50)
Variables _eq1_cp2_rcs1 to _eq1_cp2_rcs5 were created
Initial:      Log likelihood = -926516.48
Alternative:  Log likelihood = -562274.45
Rescale:      Log likelihood = -178846.95
Iteration 0:  Log likelihood = -178846.95
Iteration 1:  Log likelihood = -133783.09
Iteration 2:  Log likelihood = -132229.16
Iteration 3:  Log likelihood = -131981.55
Iteration 4:  Log likelihood = -131970.09
Iteration 5:  Log likelihood = -131970.07
Iteration 6:  Log likelihood = -131970.07
Log likelihood = -131970.07 Number of obs = 149,100

-------------------------------------------------------------------------------
              | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
--------------+----------------------------------------------------------------
       hormon |  -.2124007   .0127695   -16.63   0.000    -.2374285   -.1873729
          age |   .0118462   .0003482    34.02   0.000     .0111637    .0125286
        size2 |   .3920075   .0098359    39.85   0.000     .3727295    .4112854
        size3 |   .6967237   .0135336    51.48   0.000     .6701984     .723249
       enodes |  -1.866594   .0149733  -124.66   0.000    -1.895941   -1.837247
           er |  -8.12e-06   .0000157    -0.52   0.606     -.000039    .0000227
         pr_1 |  -.0924092    .001937   -47.71   0.000    -.0962056   -.0886127
_eq1_cp2_rcs1 |   .1130317   .0063038    17.93   0.000     .1006765    .1253869
_eq1_cp2_rcs2 |   .1819154   .0049743    36.57   0.000      .172166    .1916649
_eq1_cp2_rcs3 |  -.0080271   .0050397    -1.59   0.111    -.0179047    .0018505
_eq1_cp2_rcs4 |  -.0268743   .0046945    -5.72   0.000    -.0360755   -.0176732
_eq1_cp2_rcs5 |   .0171344   .0047121     3.64   0.000     .0078988      .02637
        _cons |  -1.907739   .0254978   -74.82   0.000    -1.957713   -1.857764
-------------------------------------------------------------------------------
Quadrature method: Gauss-Legendre with 50 nodes
. timer off 1
. timer list
1: 571.86 / 1 = 571.8600
The stgenreg command has fitted a model on the log hazard scale using restricted cubic splines with 5 d.f. (6 knots) to model the effect of time on the log hazard function, together with a selection of covariates. I have used 50 nodes for the numerical integration as the default of 15 can be too low in some cases. The model took 571.9 seconds to fit.
Fitting the same model using mlad
First I need to write the loglikelihood function in Python. The loglikelihood file is shown below.
. type rcs_hazard.py
import jax.numpy as jnp
import mladutil as mu
from jax import vmap
def python_ll(beta,X,wt,M,Nnodes):
    ## Parameters
    xb    = mu.linpred(beta,X,1)
    xbrcs = mu.linpred(beta,X,2)
    ## hazard function
    def rcshaz(t):
        vrcsgen = vmap(mu.rcsgen_beta,(0,None,None,None))
        return(jnp.exp(vrcsgen(jnp.log(t),M["knots"][0],beta[1],M["R_bhazard"]) + xb))
    ## cumulative hazard
    cumhaz = mu.vecquad_gl(rcshaz,M["t0"],M["t"],Nnodes,())
    ## return likelihood
    return(jnp.sum(wt*(M["d"]*(xb + xbrcs) - cumhaz)))

First the JAX version of numpy and the mladutil module are loaded. In addition, the JAX function vmap is loaded. This will be described below.
The arguments of the python_ll function are similar to previous examples, but the number of nodes is passed as a separate argument rather than contained in the dictionary, M. I will explain why below.
Although I could have included one linear predictor, I have chosen to have one equation for the baseline log hazard function and one for the covariates. This makes the numerical integration easier.

The loglikelihood function needs to calculate the hazard function at various time points to perform the numerical integration. This is where the vmap function is particularly useful, as it can lead to vast speed improvements. Here vmap takes a function, rcsgen_beta(), that returns the restricted cubic spline basis functions multiplied by a vector of parameters for a single time point, and vectorizes it so that it can return the predicted values for all nodes (i.e. many time points). This new function is named vrcsgen().
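To illustrate what vmap does, the sketch below vectorizes a function written for a single time point. The function `toy_spline()` is a hypothetical stand-in built from truncated cubic terms; it is not `mu.rcsgen_beta()`, whose internals are not shown here, and the knots and parameter values are made up.

```python
import jax.numpy as jnp
from jax import vmap

def toy_spline(logt, knots, gamma):
    """Linear predictor at ONE time point: basis terms multiplied by parameters."""
    # Basis: a linear term plus one truncated cubic term per knot (illustrative only)
    basis = jnp.concatenate([jnp.atleast_1d(logt),
                             jnp.maximum(logt - knots, 0.0) ** 3])
    return jnp.dot(basis, gamma)

# Map over the first argument only; knots and gamma are shared across calls (None)
vtoy_spline = vmap(toy_spline, in_axes=(0, None, None))

logt = jnp.log(jnp.array([0.5, 1.0, 2.0, 4.0]))
knots = jnp.array([-0.5, 0.0, 0.5])
gamma = jnp.array([0.3, 0.1, -0.2, 0.05])

vectorized = vtoy_spline(logt, knots, gamma)                      # one call, all time points
looped = jnp.stack([toy_spline(x, knots, gamma) for x in logt])   # explicit Python loop
```

Both versions give the same values, but the vmapped function evaluates all time points in a single vectorized operation instead of a Python loop.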
Having defined the vrcsgen() function, the numerical integration to calculate the cumulative hazard using Gauss-Legendre quadrature can be performed using the mlad utility function, mu.vecquad_gl(). Note that the number of nodes for the numerical integration is passed to the function. This was passed separately to the python_ll function as it dictates the size of the arrays that are calculated, and JAX can give an error if it thinks the size of arrays may change. When using mlad below, the number of nodes is passed as a static scalar. This tells JAX that it will not change between calls to the function when fitting the model.
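The role of a static argument can be seen in a small, generic JAX example. This is a sketch under stated assumptions rather than a description of how mlad is implemented: the function `cumhaz_constant()` and its constant hazard are hypothetical, chosen so that the result is easy to verify.

```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(3,))
def cumhaz_constant(log_rate, t0, t, n_nodes):
    """Cumulative hazard for a constant hazard exp(log_rate), by quadrature."""
    # n_nodes sets the length of x and w, so it must be a concrete Python int
    # while JAX traces the function; static_argnums marks it as such.
    x, w = np.polynomial.legendre.leggauss(n_nodes)
    x, w = jnp.asarray(x), jnp.asarray(w)
    half_width = (t - t0) / 2
    u = half_width[:, None] * x[None, :] + ((t + t0) / 2)[:, None]
    hazard = jnp.exp(log_rate) * jnp.ones_like(u)   # same hazard at every node
    return half_width * (hazard @ w)

t0 = jnp.array([0.0, 1.0])
t = jnp.array([2.0, 5.0])
H = cumhaz_constant(jnp.log(0.2), t0, t, 50)
```

Calling the function with a different number of nodes triggers a fresh compilation, since the static value is part of what JAX compiles against.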
Finally, the loglikelihood is returned as a scalar by summing the individual contributions to the likelihood.
Now mlad can be called to maximize the loglikelihood.
First I will calculate the restricted cubic spline basis functions at the event/censoring times and store the knots and projection matrix so these can be passed to mlad.
Note that the projection matrix can be used to transform the non-orthogonalized splines to orthogonalized splines.
I will use the same number of nodes for the numerical integration as I used with stgenreg.
. timer on 2
. gen double lnt = ln(_t)
. rcsgen lnt, gen(_rcs) df(5) if2(_d==1) orthog
Variables _rcs1 to _rcs5 were created
. mata: st_matrix("knots",strtoreal(tokens(st_global("r(knots)"))))
. matrix R_bhazard = r(R)
.
. scalar Nnodes = 50
. mlad (xb: = hormon age size2 size3 enodes er pr_1, nocons ) ///
> (rcs: = _rcs1 _rcs2 _rcs3 _rcs4 _rcs5) ///
> , llfile(rcs_hazard) ///
> othervars(_t0 _t _d) ///
> othervarnames(t0 t d) ///
> matrices(knots R_bhazard) ///
> staticscalars(Nnodes)
Initial:      Log likelihood = -1000121.2
Alternative:  Log likelihood = -429338.54
Rescale:      Log likelihood = -248692.69
Rescale eq:   Log likelihood = -224204.35
Iteration 0:  Log likelihood = -224204.35
Iteration 1:  Log likelihood = -218528.68
Iteration 2:  Log likelihood = -206013.59
Iteration 3:  Log likelihood = -205602.38
Iteration 4:  Log likelihood = -205574.87
Iteration 5:  Log likelihood =  -205574.8
Iteration 6:  Log likelihood =  -205574.8
. ml display
Number of obs = 149,100
Wald chi2(7) = 33137.55
Log likelihood = -205574.8 Prob > chi2 = 0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
xb           |
      hormon |  -.2124006   .0127695   -16.63   0.000    -.2374284   -.1873728
         age |   .0118462   .0003482    34.02   0.000     .0111637    .0125286
       size2 |   .3920078   .0098359    39.85   0.000     .3727299    .4112858
       size3 |   .6967243   .0135336    51.48   0.000      .670199    .7232496
      enodes |  -1.866594   .0149733  -124.66   0.000    -1.895941   -1.837247
          er |  -8.12e-06   .0000157    -0.52   0.606     -.000039    .0000227
        pr_1 |  -.0924092    .001937   -47.71   0.000    -.0962056   -.0886127
-------------+----------------------------------------------------------------
rcs          |
       _rcs1 |   .1130317   .0063038    17.93   0.000     .1006765    .1253869
       _rcs2 |   .1819154   .0049743    36.57   0.000      .172166    .1916648
       _rcs3 |  -.0080271   .0050397    -1.59   0.111    -.0179047    .0018505
       _rcs4 |  -.0268743   .0046945    -5.72   0.000    -.0360755   -.0176732
       _rcs5 |   .0171344   .0047121     3.64   0.000     .0078988      .02637
       _cons |   -1.90774   .0254977   -74.82   0.000    -1.957714   -1.857765
------------------------------------------------------------------------------

. timer off 2

I have used two equations to separate out covariate effects from the effect of time.

I pass variables needed by mlad using the othervars() option and rename them using othervarnames().
The matrices containing the knot positions and the projection matrix are passed using the matrices() option.
Finally, the number of nodes is passed using the staticscalars() option rather than the scalar option for the reason described above.
. timer list
1: 571.86 / 1 = 571.8600
2: 16.05 / 1 = 16.0550
mlad is notably faster than stgenreg, taking 16.1 seconds to fit the model.
This is a speed gain of 97.2%.
This is perhaps not a fair comparison, as users are far more likely these days to fit such a model using stpm3 or merlin.
The same model using stpm3
stpm3 can be used to fit the same model, but with user-friendly syntax. stpm3 uses a gf2 evaluator, meaning that the derivatives required for the gradient and Hessian functions have been derived analytically.
The model is fitted below.
. timer on 3
. stpm3 hormon age size2 size3 enodes er pr_1, df(5) nodes(50) scale(lnhazard) splinetype(rcs)
Iteration 0:  Log likelihood = -137164.27
Iteration 1:  Log likelihood = -132882.41
Iteration 2:  Log likelihood = -132083.78
Iteration 3:  Log likelihood = -131971.67
Iteration 4:  Log likelihood = -131970.07
Iteration 5:  Log likelihood = -131970.07
Number of obs = 149,100
Wald chi2(7) = 33137.55
Log likelihood = -131970.07 Prob > chi2 = 0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
xb           |
      hormon |  -.2124006   .0127695   -16.63   0.000    -.2374284   -.1873728
         age |   .0118462   .0003482    34.02   0.000     .0111637    .0125286
       size2 |   .3920078   .0098359    39.85   0.000     .3727299    .4112858
       size3 |   .6967243   .0135336    51.48   0.000      .670199    .7232496
      enodes |  -1.866594   .0149733  -124.66   0.000    -1.895941   -1.837247
          er |  -8.12e-06   .0000157    -0.52   0.606     -.000039    .0000227
        pr_1 |  -.0924092    .001937   -47.71   0.000    -.0962056   -.0886127
-------------+----------------------------------------------------------------
time         |
       _rcs1 |   1.420003   .0455305    31.19   0.000     1.330765    1.509242
       _rcs2 |   -.811509   .1058836    -7.66   0.000    -1.019037   -.6039809
       _rcs3 |   3.057948    .371769     8.23   0.000     2.329294    3.786602
       _rcs4 |  -3.854341   .6594531    -5.84   0.000    -5.146846   -2.561837
       _rcs5 |   2.040633   .5611931     3.64   0.000     .9407149    3.140552
       _cons |    -2.1652   .0401335   -53.95   0.000     -2.24386    -2.08654
------------------------------------------------------------------------------
Quadrature method: Gauss-Legendre with 50 nodes
. timer off 3
The same model using merlin
The same model can be fitted using merlin, which is a general command for fitting a range of models, including models with random effects (not used here).
For the model fitted here merlin uses a gf2 type evaluator.
The model is fitted below.
. timer on 4
. merlin (_t hormon age size2 size3 enodes er pr_1 rcs(_t, df(5) orthog log event), ///
> family(loghazard, failure(_d) ) timevar(_t))
variables created for model 1, component 8: _cmp_1_8_1 to _cmp_1_8_5
Fitting full model:
Iteration 0:  Log likelihood = -1000121.2
Iteration 1:  Log likelihood = -218883.77
Iteration 2:  Log likelihood = -207709.63
Iteration 3:  Log likelihood =  -205789.3
Iteration 4:  Log likelihood =  -205585.3
Iteration 5:  Log likelihood = -205574.91
Iteration 6:  Log likelihood = -205574.88
Iteration 7:  Log likelihood = -205574.88
Fixed effects regression model Number of obs = 149,100
Log likelihood = -205574.88

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
_t:          |
      hormon |  -.2123989   .0127695   -16.63   0.000    -.2374267   -.1873711
         age |   .0118462   .0003482    34.02   0.000     .0111637    .0125286
       size2 |   .3920084   .0098359    39.85   0.000     .3727304    .4112863
       size3 |   .6967249   .0135336    51.48   0.000     .6701996    .7232502
      enodes |  -1.866594   .0149733  -124.66   0.000    -1.895941   -1.837247
          er |  -8.12e-06   .0000157    -0.52   0.606     -.000039    .0000227
        pr_1 |  -.0924091    .001937   -47.71   0.000    -.0962056   -.0886127
     rcs():1 |   .1130367   .0063038    17.93   0.000     .1006815    .1253919
     rcs():2 |   .1819153   .0049744    36.57   0.000     .1721657    .1916649
     rcs():3 |  -.0080333   .0050397    -1.59   0.111     -.017911    .0018444
     rcs():4 |  -.0268593   .0046946    -5.72   0.000    -.0360605    -.017658
     rcs():5 |   .0171084   .0047121     3.63   0.000     .0078729    .0263439
       _cons |  -1.907736   .0254977   -74.82   0.000     -1.95771   -1.857761
------------------------------------------------------------------------------

. timer off 4
. timer list
1: 571.86 / 1 = 571.8600
2: 16.05 / 1 = 16.0550
3: 21.55 / 1 = 21.5450
4: 171.62 / 1 = 171.6200
In this dataset mlad has a speed gain of 97.2% over stgenreg, 25.5% over stpm3 and 90.6% over merlin.
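The percentages quoted are consistent with defining the speed gain as the proportional reduction in run time relative to the other command. That definition is inferred from the timings rather than stated explicitly, so the short check below should be read under that assumption; the function name `speed_gain` is introduced only for this illustration.

```python
def speed_gain(t_other, t_mlad):
    """Percentage reduction in run time relative to the other command."""
    return 100 * (t_other - t_mlad) / t_other

# Timings (seconds) taken from the timer list above
t_mlad = 16.055
timings = {"stgenreg": 571.86, "stpm3": 21.545, "merlin": 171.62}

gains = {cmd: round(speed_gain(sec, t_mlad), 1) for cmd, sec in timings.items()}
print(gains)
```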
Performance in larger datasets
The following table gives times and percentage speed improvements when comparing mlad with stgenreg, strcs and merlin for a range of sample sizes.
This model is a proportional hazards model incorporating 10 covariates.
[This table includes strcs rather than stpm3. This will be updated soon.]
| Sample Size | mlad | stgenreg | strcs | merlin |
|---|---|---|---|---|
| 1,000 | 0.6 | 4.9 (87.8%) | 0.4 (-50%) | 1.9 (68.4%) |
| 10,000 | 0.9 | 48 (98.1%) | 2.4 (62.5%) | 11 (91.7%) |
| 50,000 | 2.2 | 193 (98.9%) | 12 (82.0%) | 86 (97.5%) |
| 100,000 | 3.4 | 452 (99.2%) | 27 (87.2%) | 178 (98.1%) |
| 250,000 | 7.4 | 1,125 (99.3%) | 69 (89.2%) | 441 (98.3%) |
| 500,000 | 14.2 | 2,329 (99.4%) | 139 (89.8%) | 898 (98.4%) |
| 1,000,000 | 26.4 | 4,694 (99.4%) | 285 (90.7%) | 1,789 (98.5%) |
| 2,500,000 | 65.0 | - | 678 (90.7%) | 4,734 (98.6%) |
The speed gains over stgenreg are substantial, with the models running in less than 1% of the time for sample sizes of 100,000 or more.
The speed gains over strcs are of note, with the models running in less than 10% of the time for sample sizes of 1,000,000 or more.
The speed gains over merlin are substantial, with the models running in less than 2% of the time for sample sizes of 100,000 or more.
This program is not efficient
The likelihood function for this model is fairly simple, but it is inefficient. The restricted cubic spline basis functions at the nodes are calculated each time the function is called. This is unnecessary as the positions of the nodes do not change. In another example I fit the same model but precalculate the basis functions at the nodes.
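The idea of precalculating can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not the implementation used in the other example: `toy_basis()` is a hypothetical stand-in for the restricted cubic spline basis, and all parameter values are made up. The node positions depend only on $t_{0i}$ and $t_i$, so the basis at the nodes can be built once and reused for every evaluation of the likelihood.

```python
import numpy as np

def toy_basis(logt):
    """Hypothetical stand-in for a spline basis: 1, log t, (log t)^2."""
    return np.stack([np.ones_like(logt), logt, logt ** 2], axis=-1)

def node_times(t0, t, x):
    """Map quadrature nodes from [-1, 1] onto each subject's [t0_i, t_i]."""
    return ((t - t0) / 2)[:, None] * x[None, :] + ((t + t0) / 2)[:, None]

def cumhaz_recompute(gamma, xb, t0, t, x, w):
    # Basis is rebuilt from scratch on every call
    B = toy_basis(np.log(node_times(t0, t, x)))
    return (t - t0) / 2 * (np.exp(B @ gamma + xb[:, None]) @ w)

def cumhaz_precomputed(gamma, xb, t0, t, B_nodes, w):
    # Basis was built once, outside the optimisation loop
    return (t - t0) / 2 * (np.exp(B_nodes @ gamma + xb[:, None]) @ w)

x, w = np.polynomial.legendre.leggauss(20)
t0 = np.array([0.0, 0.5, 1.0])
t = np.array([2.0, 3.0, 8.0])
xb = np.array([0.1, -0.2, 0.3])

B_nodes = toy_basis(np.log(node_times(t0, t, x)))  # computed a single time

# Two made-up parameter vectors stand in for successive optimiser iterations
for gamma in (np.array([-2.0, 0.5, 0.1]), np.array([-1.5, 0.2, -0.05])):
    slow = cumhaz_recompute(gamma, xb, t0, t, x, w)
    fast = cumhaz_precomputed(gamma, xb, t0, t, B_nodes, w)
    assert np.allclose(slow, fast)
```

The two functions return the same cumulative hazards; only the second avoids repeating the basis construction at each iteration.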
References
Crowther, M.J. merlin—A unified modeling framework for data analysis and methods development in Stata The Stata Journal 2020;20:763–784
Bower H., Crowther M.J., Lambert, P.C. strcs: A command for fitting flexible parametric survival models on the log-hazard scale The Stata Journal 2016;16:989–1012
Crowther, M.J., Lambert, P.C. A general framework for parametric survival analysis. Statistics in Medicine 2014;33:5280–5297
Crowther, M.J., Lambert, P.C. stgenreg: A Stata Package for General Parametric Survival Analysis Journal of Statistical Software 2013;53:1–17