Using mlad for splines on the log hazard scale.

Restricted cubic splines for the log hazard function.

We have done a lot of work using restricted cubic splines to model survival data. Most commonly we use splines to model the log cumulative hazard function (see stpm2). However, it can sometimes be useful to model directly on the log hazard scale. Spline models on the log hazard scale are more computationally intensive because the cumulative hazard at each event/censoring time is needed to maximize the log-likelihood function, and this has to be obtained through numerical integration for each individual in the study.

Various Stata commands can fit a model on the log-hazard scale including stgenreg, strcs and merlin.

A proportional hazards model using restricted cubic splines to estimate the baseline hazard function can be written as,

$$\ln[h(t)] = s(\ln[t]|\boldsymbol{\gamma}) + \mathbf{X}\boldsymbol{\beta}$$

where $s(\ln[t]|\boldsymbol{\gamma})$ is a restricted cubic spline function of $\ln(t)$ with parameters $\boldsymbol{\gamma}$, $\mathbf{X}$ is a set of covariates and $\boldsymbol{\beta}$ are the associated parameters (log hazard ratios).
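
As a rough illustration of the spline term, the sketch below builds a non-orthogonalized restricted cubic spline basis in NumPy using the standard truncated-power form, and can be used to check that the function is linear beyond the boundary knots. The function name `rcs_basis` and the knot values are made up for illustration; this is not the code used by rcsgen or mladutil.

```python
import numpy as np

def rcs_basis(x, knots):
    """Restricted cubic spline basis (non-orthogonalized) for a 1-D array x.

    Hypothetical helper: with K knots it returns K-1 columns, the first
    being x itself and the rest truncated-power terms constrained so the
    spline is linear beyond the boundary knots.
    """
    x = np.asarray(x, dtype=float)
    k = np.asarray(knots, dtype=float)
    kmin, kmax = k[0], k[-1]
    cols = [x]
    for kj in k[1:-1]:
        lam = (kmax - kj) / (kmax - kmin)
        cols.append(np.maximum(x - kj, 0.0) ** 3
                    - lam * np.maximum(x - kmin, 0.0) ** 3
                    - (1.0 - lam) * np.maximum(x - kmax, 0.0) ** 3)
    return np.column_stack(cols)

# Six illustrative knots on the ln(t) scale give five basis functions (5 d.f.)
knots = np.array([-2.0, -0.5, 0.3, 1.0, 1.6, 2.3])
lnt = np.linspace(-3.0, 3.0, 7)
basis = rcs_basis(lnt, knots)
print(basis.shape)  # (7, 5)
```

Multiplying this basis by a parameter vector gives $s(\ln[t]|\boldsymbol{\gamma})$ at each time point (plus an intercept, if one is included).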

The log-likelihood contribution for the $i^{th}$ subject is

$$ll_i = d_i \ln[h(t_i)] - \int_{t_{0i}}^{t_i} h(u)du $$

where $t_i$ is the event/censoring time, $t_{0i}$ is the entry time and $d_i$ is the event indicator for the $i^{th}$ subject. The inclusion of $t_{0i}$ allows for models with delayed entry (left truncation).

For simple parametric models such as a Weibull model it is possible to derive the integral analytically. However, when using splines on the log hazard scale this is not possible, so numerical integration needs to be used. Note that this is one reason why it is often advantageous to use splines on the log cumulative hazard scale, as the cumulative hazard can then be obtained analytically.
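
For instance, taking one common Weibull parameterization with hazard $h(t) = \lambda p t^{p-1}$ (the symbols $\lambda$ and $p$ are introduced here purely for illustration), the integral in the log-likelihood has a closed form:

$$\int_{t_{0i}}^{t_i} \lambda p u^{p-1} du = \lambda\left(t_i^{p} - t_{0i}^{p}\right)$$

so that $ll_i = d_i \ln\left[\lambda p t_i^{p-1}\right] - \lambda\left(t_i^{p} - t_{0i}^{p}\right)$ and no numerical integration is needed.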

A simple way to do the numerical integration is to use Gauss-Legendre quadrature. To numerically integrate the hazard function between $t_{0i}$ and $t_i$, a set of $n$ nodes, $x_k$, and weights, $w_k$, is chosen. The more nodes and weights, the greater the accuracy of the numerical integration. The nodes and weights are defined for integration over $[-1,1]$, but through a change of interval rule we can integrate over $[t_{0i},t_{i}]$.

With $n$ quadrature nodes and weights the integral can be obtained using,

$$ \int_{t_{0i}}^{t_i} h(u) du \approx \frac{t_i - t_{0i}}{2} \sum_{k=1}^n w_k h\left(\frac{t_i - t_{0i}}{2}x_k + \frac{t_i + t_{0i}}{2}\right)$$

Note that this integral needs to be calculated for every individual in the study every time the likelihood function is called.
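
The quadrature rule above can be sketched in a few lines of NumPy. This is only an illustration for a single subject: the function name `gl_integrate` and the Weibull parameter values are made up, and a Weibull hazard is used simply because its integral is known exactly and so gives something to check against.

```python
import numpy as np

def gl_integrate(f, a, b, n):
    """Approximate the integral of f over [a, b] with n Gauss-Legendre nodes."""
    x, w = np.polynomial.legendre.leggauss(n)    # nodes and weights on [-1, 1]
    half = (b - a) / 2.0
    mid = (b + a) / 2.0
    return half * np.sum(w * f(half * x + mid))  # change of interval to [a, b]

# Illustrative Weibull hazard h(t) = lam * p * t**(p - 1); values are made up
lam, p = 0.1, 1.5

def hazard(t):
    return lam * p * t ** (p - 1)

t0, t = 0.5, 6.0                                 # entry and exit times
approx = gl_integrate(hazard, t0, t, 30)
exact = lam * (t ** p - t0 ** p)                 # closed-form integral
print(approx, exact)
```

With a spline-based hazard there is no closed form to compare with, which is why the number of nodes has to be chosen with some care.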

An example using stgenreg

stgenreg was the first command in Stata able to fit splines on the log hazard scale (see Crowther and Lambert 2013 and 2014) without having to finely split the time scale and approximate the integral using Poisson regression. stgenreg is a very general command that allows the user to define just about any parametric function for the (log) hazard function. Its generality makes it slow with large datasets, and it is a d0 type evaluator, which means the gradient and Hessian matrix are obtained using numerical differentiation.

I will use the rott2 data to develop the mlad function. This example will only fit a proportional hazards model. I will use expand 50 to increase the size of the dataset to 149,100 observations, as I am mainly interested in performance in larger datasets. I will include a few pre-selected covariates in the model.

. use https://www.pclambert.net/data/rott2b, clear
(Rotterdam breast cancer data (augmented with cause of death))

. expand 50
(146,118 observations created)

. tab size, gen(size)

     Tumour |
    size, 3 |
classes (t) |      Freq.     Percent        Cum.
------------+-----------------------------------
    <=20 mm |     69,350       46.51       46.51
  >20-50mmm |     64,550       43.29       89.81
     >50 mm |     15,200       10.19      100.00
------------+-----------------------------------
      Total |    149,100      100.00

. stset os, failure(osi=1) scale(12) exit (time 120)

Survival-time data settings

         Failure event: osi==1
Observed time interval: (0, os]
     Exit on or before: time 120
     Time for analysis: time/12

--------------------------------------------------------------------------
    149,100  total observations
          0  exclusions
--------------------------------------------------------------------------
    149,100  observations remaining, representing
     58,550  failures in single-record/single-failure data
  1,000,121  total analysis time at risk and under observation
                                                At risk from t =         0
                                     Earliest observed entry t =         0
                                          Last observed exit t =        10

. 
. timer clear

. timer on 1

. stgenreg, loghazard([xb])                                         ///
>           xb(hormon age size2 size3 enodes er pr_1 | #rcs(df(5))) ///
>           nodes(50)
Variables _eq1_cp2_rcs1 to _eq1_cp2_rcs5 were created

initial:       log likelihood = -926516.48
alternative:   log likelihood = -562274.45
rescale:       log likelihood = -178846.95
Iteration 0:   log likelihood = -178846.95  
Iteration 1:   log likelihood = -133783.09  
Iteration 2:   log likelihood = -132229.16  
Iteration 3:   log likelihood = -131981.55  
Iteration 4:   log likelihood = -131970.09  
Iteration 5:   log likelihood = -131970.07  
Iteration 6:   log likelihood = -131970.07  

Log likelihood = -131970.07                            Number of obs = 149,100

-------------------------------------------------------------------------------
              | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
--------------+----------------------------------------------------------------
       hormon |  -.2124007   .0127695   -16.63   0.000    -.2374285   -.1873729
          age |   .0118462   .0003482    34.02   0.000     .0111637    .0125286
        size2 |   .3920075   .0098359    39.85   0.000     .3727295    .4112854
        size3 |   .6967237   .0135336    51.48   0.000     .6701984     .723249
       enodes |  -1.866594   .0149733  -124.66   0.000    -1.895941   -1.837247
           er |  -8.12e-06   .0000157    -0.52   0.606     -.000039    .0000227
         pr_1 |  -.0924092    .001937   -47.71   0.000    -.0962056   -.0886127
_eq1_cp2_rcs1 |   .1130317   .0063038    17.93   0.000     .1006765    .1253869
_eq1_cp2_rcs2 |   .1819154   .0049743    36.57   0.000      .172166    .1916649
_eq1_cp2_rcs3 |  -.0080271   .0050397    -1.59   0.111    -.0179047    .0018505
_eq1_cp2_rcs4 |  -.0268743   .0046945    -5.72   0.000    -.0360755   -.0176732
_eq1_cp2_rcs5 |   .0171344   .0047121     3.64   0.000     .0078988      .02637
        _cons |  -1.907739   .0254978   -74.82   0.000    -1.957713   -1.857764
-------------------------------------------------------------------------------
 Quadrature method: Gauss-Legendre with 50 nodes

. timer off 1

. timer list
   1:    524.77 /        1 =     524.7710

The stgenreg command has fitted a model on the log hazard scale using restricted cubic splines with 5 d.f. (6 knots) to model the effect of time, together with a selection of covariates. I have used 50 nodes for the numerical integration as the default of 15 can be too low in some cases. The model took 524.8 seconds to fit.

Fitting the same model using mlad

First I need to write the log-likelihood function in Python. The log-likelihood file is shown below.

. type rcs_hazard.py
import jax.numpy as jnp   
import mladutil as mu
from   jax import vmap

def python_ll(beta,X,wt,M,Nnodes):
  ## Parameters
  xb    = mu.linpred(beta,X,1)
  xbrcs = mu.linpred(beta,X,2)

  ## hazard function
  def rcshaz(t):
    vrcsgen = vmap(mu.rcsgen_beta,(0,None,None,None))
    return(jnp.exp(vrcsgen(jnp.log(t),M["knots"][0],beta[1],M["R_bhazard"]) + xb))

  ## cumulative hazard
  cumhaz = mu.vecquad_gl(rcshaz,M["t0"],M["t"],Nnodes,())   

  ## return log-likelihood
  return(jnp.sum(wt*(M["d"]*(xb + xbrcs) - cumhaz)))

  • First the JAX version of numpy and the mladutil modules are loaded. In addition, the JAX function vmap is loaded. This will be described below.

  • The arguments of the python_ll function are similar to previous examples, but the number of nodes is passed as a separate argument rather than contained in the dictionary, M. I will explain why below.

  • Although I could have used a single linear predictor, I have chosen to have one equation for the baseline log hazard function and one for the covariates. This makes the numerical integration easier.

  • The log-likelihood function needs to calculate the hazard function at various time points to perform the numerical integration. This is where the vmap function is particularly useful as it can lead to vast speed improvements. Here vmap takes a function, rcsgen_beta(), that returns the restricted cubic spline basis functions multiplied by a vector of parameters for a single time point, and vectorizes it so that it can return the predicted values for all nodes (i.e. many time points). This new function is named vrcsgen().

  • Having defined the vrcsgen() function, the numerical integration to calculate the cumulative hazard using Gauss-Legendre quadrature can be performed using the mlad utility function, mu.vecquad_gl(). Note that the number of nodes for the numerical integration is passed to the function. This was passed separately to the python_ll function as it dictates the size of the arrays that are calculated, and JAX can give an error if it thinks the size of arrays may change. When using mlad below the number of nodes is passed as a static scalar. This tells JAX that it will not change between calls to the function when fitting the model.

  • Finally, the log-likelihood is returned as a scalar by summing the individual contributions to the likelihood.
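
To make the structure of the calculation concrete, here is a plain NumPy analogue of the steps described above: subject-specific quadrature nodes, the hazard evaluated at all nodes for all subjects at once, and the log-likelihood summed over subjects. It is a sketch only. It does not reproduce how mu.vecquad_gl() or vmap work internally; the helper `vec_gl`, the simulated data and the parameter values are all hypothetical; and the log hazard is taken to be linear in ln(t) rather than a spline so that the result can be checked against a closed form.

```python
import numpy as np

def vec_gl(hazard, t0, t, n_nodes):
    """Integrate hazard over [t0_i, t_i] for every subject at once.

    hazard must accept an (N, n_nodes) array of times and return an
    array of the same shape (row i belongs to subject i).
    """
    x, w = np.polynomial.legendre.leggauss(n_nodes)
    half = (t - t0)[:, None] / 2.0
    mid = (t + t0)[:, None] / 2.0
    nodes = half * x[None, :] + mid            # (N, n_nodes) subject-specific nodes
    return half[:, 0] * np.sum(w[None, :] * hazard(nodes), axis=1)

# Simulated data (made up for illustration)
rng = np.random.default_rng(1)
N = 5
t0 = rng.uniform(0.5, 1.0, N)                  # entry times
t = t0 + rng.uniform(0.5, 3.0, N)              # exit times
d = rng.integers(0, 2, N)                      # event indicators
xb = rng.normal(0.0, 0.3, N)                   # covariate linear predictor
g0, g1 = -2.0, 0.5                             # log hazard = g0 + g1*ln(t) + xb

def hazard(u):
    return np.exp(g0 + g1 * np.log(u) + xb[:, None])

cumhaz = vec_gl(hazard, t0, t, 30)
loglik = np.sum(d * (g0 + g1 * np.log(t) + xb) - cumhaz)

# Closed-form cumulative hazard for this simple log-linear hazard
exact = np.exp(g0 + xb) * (t ** (g1 + 1) - t0 ** (g1 + 1)) / (g1 + 1)
print(np.max(np.abs(cumhaz - exact)), loglik)
```

The number of nodes fixes the second dimension of the `nodes` array, which is the array-size issue that motivates passing it as a static scalar in the JAX version.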

Now mlad can be called to maximize the likelihood. First I will calculate the restricted cubic spline basis functions at the event/censoring times and store the knots and projection matrix so that these can be passed to mlad. Note that the projection matrix can be used to transform the non-orthogonalized splines to orthogonalized splines. I will use the same number of nodes for the numerical integration as I used with stgenreg.

. timer on 2

. gen double lnt = ln(_t)

. rcsgen lnt, gen(_rcs) df(5) if2(_d==1) orthog
Variables _rcs1 to _rcs5 were created

. mata: st_matrix("knots",strtoreal(tokens(st_global("r(knots)"))))  

. matrix R_bhazard = r(R)

. 
. scalar Nnodes = 50

. mlad (xb: = hormon age size2 size3 enodes er pr_1, nocons )  ///
>      (rcs: = _rcs1 _rcs2 _rcs3 _rcs4 _rcs5)                  ///
>      , llfile(rcs_hazard)                                    ///
>        othervars(_t0 _t _d)                                  ///
>        othervarnames(t0 t d)                                 ///
>        matrices(knots R_bhazard)                             ///
>        staticscalars(Nnodes) 

initial:       log likelihood = -1000121.2
alternative:   log likelihood = -429338.54
rescale:       log likelihood = -248692.69
rescale eq:    log likelihood = -224204.35
Iteration 0:   log likelihood = -224204.35  
Iteration 1:   log likelihood = -218528.68  
Iteration 2:   log likelihood = -206013.59  
Iteration 3:   log likelihood = -205602.38  
Iteration 4:   log likelihood = -205574.87  
Iteration 5:   log likelihood =  -205574.8  
Iteration 6:   log likelihood =  -205574.8  

. ml display       

                                                      Number of obs =  149,100
                                                      Wald chi2(7)  = 33137.55
Log likelihood = -205574.8                            Prob > chi2   =   0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
xb           |
      hormon |  -.2124006   .0127695   -16.63   0.000    -.2374284   -.1873728
         age |   .0118462   .0003482    34.02   0.000     .0111637    .0125286
       size2 |   .3920078   .0098359    39.85   0.000     .3727299    .4112858
       size3 |   .6967243   .0135336    51.48   0.000      .670199    .7232496
      enodes |  -1.866594   .0149733  -124.66   0.000    -1.895941   -1.837247
          er |  -8.12e-06   .0000157    -0.52   0.606     -.000039    .0000227
        pr_1 |  -.0924092    .001937   -47.71   0.000    -.0962056   -.0886127
-------------+----------------------------------------------------------------
rcs          |
       _rcs1 |   .1130317   .0063038    17.93   0.000     .1006765    .1253869
       _rcs2 |   .1819154   .0049743    36.57   0.000      .172166    .1916648
       _rcs3 |  -.0080271   .0050397    -1.59   0.111    -.0179047    .0018505
       _rcs4 |  -.0268743   .0046945    -5.72   0.000    -.0360755   -.0176732
       _rcs5 |   .0171344   .0047121     3.64   0.000     .0078988      .02637
       _cons |   -1.90774   .0254977   -74.82   0.000    -1.957714   -1.857765
------------------------------------------------------------------------------

. timer off 2       

  • I have used two equations to separate out covariate effects from the effect of time.

  • I pass the variables needed by mlad using the othervars() option and rename them using othervarnames().

  • The matrices containing the knot positions and the projection matrix are passed using the matrices() option.

  • Finally the number of nodes is passed using the staticscalars() option rather than the scalar option for the reason described above.

. timer list
   1:    524.77 /        1 =     524.7710
   2:     12.41 /        1 =      12.4140

The mlad model is notably faster than stgenreg, taking 12.4 seconds to fit. This is a speed gain (percentage reduction in run time) of 97.6%. This is perhaps not a fair comparison as users are far more likely these days to fit such a model using strcs or merlin.

The same model using strcs.

strcs can be used to fit the same model, but with user-friendly syntax. strcs uses a gf2 evaluator, meaning that the derivatives required for the gradient and Hessian have been derived analytically. In addition, the integration to derive the cumulative hazard is more accurate in strcs as it makes use of the fact that the integral before the first knot and after the last knot can be derived analytically, so the numerical integration is only performed between the first and last knots.
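
To see why the tails can be handled analytically, recall that a restricted cubic spline is linear in $\ln(t)$ beyond the boundary knots. As a sketch, write the spline before the first knot, $k_1$ (on the time scale), as $a + b\ln(t)$, where $a$ and $b$ are illustrative symbols that depend on $\boldsymbol{\gamma}$. The hazard is then a power function of time and, for a subject with $t_{0i}=0$ and assuming $b > -1$,

$$\int_0^{k_1} \exp\left(a + b\ln[u] + \mathbf{X}\boldsymbol{\beta}\right)du = \exp\left(a + \mathbf{X}\boldsymbol{\beta}\right)\frac{k_1^{b+1}}{b+1}$$

A result of the same form holds after the last knot, with its own intercept and slope.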

The model is fitted below,

. timer on 3

. strcs hormon age size2 size3 enodes er pr_1, df(5) nodes(50) nohr

Iteration 0:   log likelihood =  -131990.6  
Iteration 1:   log likelihood = -131970.07  
Iteration 2:   log likelihood = -131970.07  

Log likelihood = -131970.07                            Number of obs = 149,100

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
xb           |
      hormon |  -.2124006   .0127695   -16.63   0.000    -.2374284   -.1873729
         age |   .0118462   .0003482    34.02   0.000     .0111637    .0125286
       size2 |   .3920078   .0098359    39.85   0.000     .3727298    .4112858
       size3 |   .6967242   .0135336    51.48   0.000     .6701989    .7232495
      enodes |  -1.866594   .0149733  -124.66   0.000    -1.895941   -1.837247
          er |  -8.12e-06   .0000157    -0.52   0.606     -.000039    .0000227
        pr_1 |  -.0924092    .001937   -47.71   0.000    -.0962056   -.0886127
-------------+----------------------------------------------------------------
rcs          |
        __s1 |   .1130307   .0063038    17.93   0.000     .1006755    .1253859
        __s2 |   .1819159   .0049743    36.57   0.000     .1721665    .1916653
        __s3 |  -.0080258   .0050397    -1.59   0.111    -.0179034    .0018518
        __s4 |  -.0268745   .0046945    -5.72   0.000    -.0360756   -.0176734
        __s5 |    .017137   .0047121     3.64   0.000     .0079015    .0263725
       _cons |   -1.90774   .0254977   -74.82   0.000    -1.957715   -1.857766
------------------------------------------------------------------------------
 Quadrature method: Gauss-Legendre with 50 nodes

. timer off 3

Note that there are some very small differences in the estimated coefficients compared with the other models, as the integration is more accurate in strcs.

An updated version of strcs will soon be released with a python option that will call mlad and thus lead to substantial speed gains.

The same model using merlin.

The same model can be fitted using merlin, which is a general command that fits a wide range of models, including models with random effects (not used here). For the model fitted here merlin uses a gf2 type evaluator.

The model is fitted below

. timer on 4

. merlin (_t hormon age size2 size3 enodes er pr_1 rcs(_t, df(5) orthog log event), ///
>        family(loghazard, failure(_d) ) timevar(_t)) 
variables created for model 1, component 8: _cmp_1_8_1 to _cmp_1_8_5

Fitting full model:

Iteration 0:   log likelihood = -1099449.9  
Iteration 1:   log likelihood = -220407.83  
Iteration 2:   log likelihood = -207773.08  
Iteration 3:   log likelihood = -205807.88  
Iteration 4:   log likelihood = -205586.73  
Iteration 5:   log likelihood = -205574.89  
Iteration 6:   log likelihood = -205574.88  

Fixed effects regression model                         Number of obs = 149,100
Log likelihood = -205574.88
------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
_t:          |            
      hormon |  -.2123989   .0127695   -16.63   0.000    -.2374267   -.1873711
         age |   .0118462   .0003482    34.02   0.000     .0111637    .0125286
       size2 |   .3920084   .0098359    39.85   0.000     .3727304    .4112863
       size3 |   .6967249   .0135336    51.48   0.000     .6701996    .7232502
      enodes |  -1.866594   .0149733  -124.66   0.000    -1.895941   -1.837247
          er |  -8.12e-06   .0000157    -0.52   0.606     -.000039    .0000227
        pr_1 |  -.0924091    .001937   -47.71   0.000    -.0962056   -.0886127
     rcs():1 |   .1130363   .0063038    17.93   0.000     .1006811    .1253915
     rcs():2 |   .1819146   .0049744    36.57   0.000      .172165    .1916642
     rcs():3 |  -.0080328   .0050397    -1.59   0.111    -.0179105     .001845
     rcs():4 |  -.0268595   .0046946    -5.72   0.000    -.0360607   -.0176582
     rcs():5 |   .0171086   .0047121     3.63   0.000      .007873    .0263441
       _cons |  -1.907736   .0254977   -74.82   0.000     -1.95771   -1.857761
------------------------------------------------------------------------------

. timer off 4

. timer list
   1:    524.77 /        1 =     524.7710
   2:     12.41 /        1 =      12.4140
   3:     62.48 /        1 =      62.4820
   4:    145.83 /        1 =     145.8270

In this dataset mlad has a speed gain of 97.6% over stgenreg, 80.1% over strcs and 91.5% over merlin.
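
The speed gains quoted here are the percentage reduction in run time relative to each command, which can be reproduced from the timer output with a few lines of Python (the helper name `speed_gain` is just for illustration):

```python
# Run times in seconds, taken from the timer list output above
times = {"mlad": 12.41, "stgenreg": 524.77, "strcs": 62.48, "merlin": 145.83}

def speed_gain(t_other, t_mlad):
    """Percentage reduction in run time when using mlad instead."""
    return 100.0 * (t_other - t_mlad) / t_other

gains = {cmd: round(speed_gain(sec, times["mlad"]), 1)
         for cmd, sec in times.items() if cmd != "mlad"}
print(gains)  # {'stgenreg': 97.6, 'strcs': 80.1, 'merlin': 91.5}
```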

Performance in larger datasets

The following table gives times and percentage speed improvements when comparing mlad with stgenreg, strcs and merlin for a range of sample sizes. This model is a proportional hazards model incorporating 10 covariates.

Sample size | mlad | stgenreg (% gain) | strcs (% gain) | merlin (% gain)
------------+------+-------------------+----------------+----------------
      1,000 |  0.6 |       4.9 (87.8%) |     0.4 (-50%) |     1.9 (68.4%)
     10,000 |  0.9 |        48 (98.1%) |    2.4 (62.5%) |      11 (91.7%)
     50,000 |  2.2 |       193 (98.9%) |     12 (82.0%) |      86 (97.5%)
    100,000 |  3.4 |       452 (99.2%) |     27 (87.2%) |     178 (98.1%)
    250,000 |  7.4 |     1,125 (99.3%) |     69 (89.2%) |     441 (98.3%)
    500,000 | 14.2 |     2,329 (99.4%) |    139 (89.8%) |     898 (98.4%)
  1,000,000 | 26.4 |     4,694 (99.4%) |    285 (90.7%) |   1,789 (98.5%)
  2,500,000 | 65.0 |                 - |    678 (90.7%) |   4,734 (98.6%)

The speed gains over stgenreg are substantial with the models running in less than 1% of the time for sample sizes of 100,000 or more. The speed gains over strcs are of note with the models running in less than 10% of the time for sample sizes of 1,000,000 or more. The speed gains over merlin are substantial with the models running in less than 2% of the time for sample sizes of 100,000 or more.

This program is not efficient

The likelihood function for this model is fairly simple, but it is inefficient. The restricted cubic spline basis functions at the nodes are calculated each time the function is called. This is unnecessary as the positions of the nodes do not change. In another example I fit the same model but pre-calculate the basis functions at the nodes.
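
The idea can be sketched in NumPy as follows. The node positions depend only on $t_{0i}$ and $t_i$, so the basis functions at the nodes can be computed once and reused; each likelihood call then reduces to a matrix product and a weighted sum. Everything here is hypothetical and for illustration only: the data values, the function names, and in particular the stand-in polynomial basis in ln(t), which a real model would replace with restricted cubic spline basis functions.

```python
import numpy as np

n_nodes = 30
x, w = np.polynomial.legendre.leggauss(n_nodes)

# Made-up entry and exit times for three subjects
t0 = np.array([0.5, 0.8, 1.0])
t = np.array([2.0, 3.5, 4.0])
half = (t - t0)[:, None] / 2.0
nodes = half * x[None, :] + (t + t0)[:, None] / 2.0   # (N, n_nodes), fixed across iterations

def basis(u):
    """Stand-in basis in ln(u): [1, ln u, (ln u)^2]."""
    lnu = np.log(u)
    return np.stack([np.ones_like(lnu), lnu, lnu ** 2], axis=-1)

B_nodes = basis(nodes)     # computed once, shape (N, n_nodes, 3)

def cumhaz(gamma, xb):
    """Cumulative hazard per subject using the pre-computed basis."""
    log_h = B_nodes @ gamma + xb[:, None]
    return half[:, 0] * np.sum(w * np.exp(log_h), axis=1)

def cumhaz_recompute(gamma, xb):
    """Same quantity, but rebuilding the basis on every call."""
    log_h = basis(nodes) @ gamma + xb[:, None]
    return half[:, 0] * np.sum(w * np.exp(log_h), axis=1)

gamma = np.array([-2.0, 0.4, -0.1])   # illustrative parameter values
xb = np.array([0.1, -0.2, 0.3])
print(cumhaz(gamma, xb))
```

Both functions return the same values; the first simply avoids repeating the basis calculation at every iteration of the optimizer.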

References

Crowther, M.J. merlin—A unified modeling framework for data analysis and methods development in Stata The Stata Journal 2020;20:763–784

Bower, H., Crowther, M.J., Lambert, P.C. strcs: A command for fitting flexible parametric survival models on the log-hazard scale The Stata Journal 2016;16:989-1012

Crowther, M.J., Lambert, P.C. A general framework for parametric survival analysis. Statistics in Medicine 2014;33:5280-5297

Crowther, M.J., Lambert, P.C. stgenreg: A Stata Package for General Parametric Survival Analysis Journal of Statistical Software 2013;53:1-17
