Using mlad for splines on the log hazard scale - with pysetup.

Restricted cubic splines for the log hazard function.

Here I will fit the same model as in the previous example, but use the pysetup() option. This pre-calculates the spline basis functions at the nodes once, rather than recalculating them every time the likelihood function is called, as was done in the previous example.

First I will run mlad as in the previous example on the rott2 data. I use expand 400 to increase the sample size to 1,192,800.

. use https://www.pclambert.net/data/rott2b, clear
(Rotterdam breast cancer data (augmented with cause of death))

. expand 400
(1,189,818 observations created)

. tab size, gen(size)

     Tumour |
    size, 3 |
classes (t) |      Freq.     Percent        Cum.
------------+-----------------------------------
    <=20 mm |    554,800       46.51       46.51
  >20-50mmm |    516,400       43.29       89.81
     >50 mm |    121,600       10.19      100.00
------------+-----------------------------------
      Total |  1,192,800      100.00

. stset os, failure(osi=1) scale(12) exit (time 120)

Survival-time data settings

         Failure event: osi==1
Observed time interval: (0, os]
     Exit on or before: time 120
     Time for analysis: time/12

--------------------------------------------------------------------------
  1,192,800  total observations
          0  exclusions
--------------------------------------------------------------------------
  1,192,800  observations remaining, representing
    468,400  failures in single-record/single-failure data
  8,000,970  total analysis time at risk and under observation
                                                At risk from t =         0
                                     Earliest observed entry t =         0
                                          Last observed exit t =        10

. timer clear

. timer on 1

. gen double lnt = ln(_t)

. rcsgen lnt, gen(_rcs) df(5) if2(_d==1) orthog
Variables _rcs1 to _rcs5 were created

. mata: st_matrix("knots",strtoreal(tokens(st_global("r(knots)"))))  

. matrix R_bhazard = r(R)

. 
. scalar Nnodes = 50

. mlad (xb: = hormon age size2 size3 enodes er pr_1, nocons )  ///
>      (rcs: = _rcs1 _rcs2 _rcs3 _rcs4 _rcs5)                  ///
>      , llfile(rcs_hazard)                                    ///
>        othervars(_t0 _t _d)                                  ///
>        othervarnames(t0 t d)                                 ///
>        matrices(knots R_bhazard)                             ///
>        staticscalars(Nnodes) 

initial:       log likelihood = -8000969.7
alternative:   log likelihood = -3434708.4
rescale:       log likelihood = -1989541.5
rescale eq:    log likelihood = -1793634.8
Iteration 0:   log likelihood = -1793634.8  
Iteration 1:   log likelihood = -1748229.4  
Iteration 2:   log likelihood = -1648108.7  
Iteration 3:   log likelihood = -1644819.1  
Iteration 4:   log likelihood =   -1644599  
Iteration 5:   log likelihood = -1644598.4  
Iteration 6:   log likelihood = -1644598.4  

. ml display       

                                                     Number of obs = 1,192,800
                                                     Wald chi2(7)  = 265100.41
Log likelihood = -1644598.4                          Prob > chi2   =    0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
xb           |
      hormon |  -.2124006   .0045147   -47.05   0.000    -.2212492   -.2035519
         age |   .0118462   .0001231    96.23   0.000     .0116049    .0120875
       size2 |   .3920078   .0034775   112.73   0.000      .385192    .3988236
       size3 |   .6967243   .0047848   145.61   0.000     .6873461    .7061024
      enodes |  -1.866594   .0052939  -352.60   0.000     -1.87697   -1.856218
          er |  -8.12e-06   5.56e-06    -1.46   0.145     -.000019    2.79e-06
        pr_1 |  -.0924092   .0006848  -134.94   0.000    -.0937514   -.0910669
-------------+----------------------------------------------------------------
rcs          |
       _rcs1 |   .1130317   .0022287    50.72   0.000     .1086635    .1173999
       _rcs2 |   .1819154   .0017587   103.44   0.000     .1784684    .1853623
       _rcs3 |  -.0080271   .0017818    -4.51   0.000    -.0115194   -.0045348
       _rcs4 |  -.0268743   .0016598   -16.19   0.000    -.0301274   -.0236213
       _rcs5 |   .0171344    .001666    10.28   0.000     .0138691    .0203997
       _cons |   -1.90774   .0090148  -211.62   0.000    -1.925408   -1.890071
------------------------------------------------------------------------------

. timer off 1

Now I will write a setup file, which slightly changes the function used to calculate the log-likelihood. The key point is that an array of the nodes required for the numerical integration will be calculated once in the setup file rather than every time the likelihood function is called. The gradient and Hessian functions (obtained through automatic differentiation) will also benefit from this.

This setup Python file is shown below.

. type setup_rcs_hazard.py
from scipy.special import roots_legendre
from jax import vmap
import jax.numpy as jnp
import mladutil as mu

def mlad_setup(M):
  vrcsgen = (vmap(mu.rcsgen,(0,None,None),0))
  nodes, weights = roots_legendre(M["Nnodes"])
  
  nodes2 = 0.5*(M["t"] - M["t0"])*nodes + 0.5*(M["t"] + M["t0"])
  M["allnodes"] = vrcsgen(jnp.log(nodes2),M["knots"][0],M["R_bhazard"])
  M["weights"] =  weights 
  return(M)

  • The setup function must be called mlad_setup() and has one argument, M, the Python dictionary.

  • First I use vmap on the rcsgen() function to vectorise it, i.e. the function vrcsgen will be able to return the restricted cubic spline basis functions at all of the nodes.

  • The nodes and weights for Gauss-Legendre quadrature are then calculated using roots_legendre().

  • The Gauss-Legendre nodes are for an integral over [-1,1], so they are transformed to each individual's interval [t0,t] using a change of interval rule and stored in nodes2.

  • The basis functions are then calculated at each node and stored in M["allnodes"] and the weights stored in M["weights"].

  • The updated dictionary, M, is then returned.
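
The change of interval rule can be illustrated with a small standalone sketch. This is not the setup file itself: it uses plain Python, a hard-coded 3-point Gauss-Legendre rule rather than the 50 nodes from roots_legendre(), and a hypothetical Weibull hazard (the names transform_nodes, cumulative_hazard, lam and gam are assumptions for illustration), chosen only because its cumulative hazard has a closed form to compare against.

```python
import math

# 3-point Gauss-Legendre rule on [-1, 1] (exact for polynomials up to degree 5).
# Illustrative only: the setup file obtains 50 nodes from roots_legendre().
NODES = (-math.sqrt(3.0 / 5.0), 0.0, math.sqrt(3.0 / 5.0))
WEIGHTS = (5.0 / 9.0, 8.0 / 9.0, 5.0 / 9.0)

def transform_nodes(t0, t):
    """Map the [-1, 1] nodes onto [t0, t], mirroring the nodes2 step."""
    half_width = 0.5 * (t - t0)
    midpoint = 0.5 * (t + t0)
    return [half_width * x + midpoint for x in NODES]

def cumulative_hazard(hazard, t0, t):
    """Approximate the integral of hazard(u) over [t0, t] by quadrature."""
    half_width = 0.5 * (t - t0)
    nodes = transform_nodes(t0, t)
    return half_width * sum(w * hazard(u) for w, u in zip(WEIGHTS, nodes))

# Hypothetical Weibull hazard h(u) = lam * gam * u**(gam - 1); its cumulative
# hazard lam * t**gam is known exactly, so the quadrature can be checked.
lam, gam = 0.1, 2.0

def weibull_hazard(u):
    return lam * gam * u ** (gam - 1.0)

approx = cumulative_hazard(weibull_hazard, 0.0, 10.0)
exact = lam * 10.0 ** gam
print(approx, exact)
```

The transformed nodes depend only on t0 and t, not on the parameters, which is why they (and the spline basis evaluated at them) can be computed once in the setup step.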

The Python log-likelihood file is listed below.

. type rcs_hazard_pysetup.py
import jax.numpy as jnp   
import mladutil as mu

def python_ll(beta,X,wt,M):
  ## Parameters
  xb    = mu.linpred(beta,X,1)
  xbrcs = mu.linpred(beta,X,2)

  ## cumulative hazard
  ch_at_nodes = jnp.exp(jnp.matmul(M["allnodes"],beta[1][:-1]) + beta[1][-1] + xb)
  cumhaz = (0.5*(M["t"]-M["t0"]))*jnp.sum(M["weights"]*ch_at_nodes,axis=1,keepdims=True)

  return(jnp.sum(wt*(M["d"]*(xb + xbrcs) - cumhaz)))

  • The linear predictors for the covariate effects and the restricted cubic splines are calculated using mu.linpred().

  • The hazard at each of the nodes for each individual is then calculated.

  • These are then used in conjunction with the weights to obtain the cumulative hazard at each event/censoring time.

  • Finally, the log-likelihood is returned.
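
Written out, the quantity returned by python_ll() appears to correspond to the expressions below. The notation is my own shorthand rather than anything defined in the code: z_j and w_j are the Gauss-Legendre nodes and weights on [-1,1], m is Nnodes, s(.) is the spline linear predictor including its constant, and omega_i is the weight wt.

```latex
\begin{aligned}
\ln h(t \mid x_i) &= x_i\beta + s(\ln t;\gamma), \\
u_{ij} &= \frac{t_i - t_{0i}}{2}\, z_j + \frac{t_i + t_{0i}}{2}, \\
H_i &= \int_{t_{0i}}^{t_i} h(u \mid x_i)\, du
     \;\approx\; \frac{t_i - t_{0i}}{2} \sum_{j=1}^{m} w_j\, h(u_{ij} \mid x_i), \\
\ell &= \sum_i \omega_i \left[ d_i \ln h(t_i \mid x_i) - H_i \right].
\end{aligned}
```

Only the spline basis at the points u_ij is fixed across iterations, so it is this part that the setup file stores in M["allnodes"]; the parameters still enter through the matrix product in the log-likelihood function.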

Now the model can be fitted again, but this time making use of the setup program. This is passed to mlad using the pysetup() option.

. drop lnt _rcs*

. timer on 2

. gen double lnt = ln(_t)

. rcsgen lnt, gen(_rcs) df(5) if2(_d==1) orthog
Variables _rcs1 to _rcs5 were created

. mata: st_matrix("knots",strtoreal(tokens(st_global("r(knots)"))))  

. matrix R_bhazard = r(R)

. 
. scalar Nnodes = 50

. mlad (xb: = hormon age size2 size3 enodes er pr_1, nocons ) ///
>      (rcs: = _rcs1 _rcs2 _rcs3 _rcs4 _rcs5)                 ///
>      , llfile(rcs_hazard_pysetup)                           ///
>        pysetup(setup_rcs_hazard)                            ///
>        othervars(_t0 _t _d)                                 ///
>        othervarnames(t0 t d)                                ///
>        matrices(knots R_bhazard)                            ///
>        scalars(Nnodes) 

initial:       log likelihood = -8000969.7
alternative:   log likelihood = -3434708.4
rescale:       log likelihood = -1989541.5
rescale eq:    log likelihood = -1793634.8
Iteration 0:   log likelihood = -1793634.8  
Iteration 1:   log likelihood = -1748229.4  
Iteration 2:   log likelihood = -1648108.7  
Iteration 3:   log likelihood = -1644819.1  
Iteration 4:   log likelihood =   -1644599  
Iteration 5:   log likelihood = -1644598.4  
Iteration 6:   log likelihood = -1644598.4  

. ml display       

                                                     Number of obs = 1,192,800
                                                     Wald chi2(7)  = 265100.41
Log likelihood = -1644598.4                          Prob > chi2   =    0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
xb           |
      hormon |  -.2124006   .0045147   -47.05   0.000    -.2212492   -.2035519
         age |   .0118462   .0001231    96.23   0.000     .0116049    .0120875
       size2 |   .3920078   .0034775   112.73   0.000      .385192    .3988236
       size3 |   .6967243   .0047848   145.61   0.000     .6873461    .7061024
      enodes |  -1.866594   .0052939  -352.60   0.000     -1.87697   -1.856218
          er |  -8.12e-06   5.56e-06    -1.46   0.145     -.000019    2.79e-06
        pr_1 |  -.0924092   .0006848  -134.94   0.000    -.0937514   -.0910669
-------------+----------------------------------------------------------------
rcs          |
       _rcs1 |   .1130317   .0022287    50.72   0.000     .1086635    .1173999
       _rcs2 |   .1819154   .0017587   103.44   0.000     .1784684    .1853623
       _rcs3 |  -.0080271   .0017818    -4.51   0.000    -.0115194   -.0045348
       _rcs4 |  -.0268743   .0016598   -16.19   0.000    -.0301274   -.0236213
       _rcs5 |   .0171344    .001666    10.28   0.000     .0138691    .0203997
       _cons |   -1.90774   .0090148  -211.62   0.000    -1.925408   -1.890071
------------------------------------------------------------------------------

. timer off 2       

The estimates from the models are identical. The times to fit each model can be seen below.

. timer list
   1:     89.14 /        1 =      89.1350
   2:     53.94 /        1 =      53.9420

In this dataset, using the setup file reduces the time to fit the model by 39.5% compared with the previous example.
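
The percentage comes directly from the two timer values above. As a quick check (the variable names are mine), the same numbers can also be expressed as a speed-up ratio:

```python
# Elapsed times in seconds, as reported by `timer list` above.
time_without_setup = 89.1350
time_with_setup = 53.9420

# Relative reduction in fitting time, and the equivalent speed-up ratio.
reduction_pct = 100.0 * (time_without_setup - time_with_setup) / time_without_setup
speedup_ratio = time_without_setup / time_with_setup

print(f"time reduction: {reduction_pct:.1f}%")
print(f"speed-up ratio: {speedup_ratio:.2f}x")
```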
