Using mlad for splines on the log hazard scale - with pysetup.
Restricted cubic splines for the log hazard function.
Here I will fit the same model as in the previous example, but use the pysetup() option. This pre-calculates the spline basis functions at the nodes once, rather than recalculating them every time the function is called, as was done in the previous example.
First I will run mlad as in the previous example on the rott2 data. I use expand 400 to increase the sample size to 1,192,800.
. use https://www.pclambert.net/data/rott2b, clear
(Rotterdam breast cancer data (augmented with cause of death))
. expand 400
(1,189,818 observations created)
. tab size, gen(size)
Tumour |
size, 3 |
classes (t) | Freq. Percent Cum.
------------+-----------------------------------
<=20 mm | 554,800 46.51 46.51
>20-50mmm | 516,400 43.29 89.81
>50 mm | 121,600 10.19 100.00
------------+-----------------------------------
Total | 1,192,800 100.00
. stset os, failure(osi=1) scale(12) exit (time 120)
Survival-time data settings
Failure event: osi==1
Observed time interval: (0, os]
Exit on or before: time 120
Time for analysis: time/12
--------------------------------------------------------------------------
1,192,800 total observations
0 exclusions
--------------------------------------------------------------------------
1,192,800 observations remaining, representing
468,400 failures in single-record/single-failure data
8,000,970 total analysis time at risk and under observation
At risk from t = 0
Earliest observed entry t = 0
Last observed exit t = 10
. timer clear
. timer on 1
. gen double lnt = ln(_t)
. rcsgen lnt, gen(_rcs) df(5) if2(_d==1) orthog
Variables _rcs1 to _rcs5 were created
. mata: st_matrix("knots",strtoreal(tokens(st_global("r(knots)"))))
. matrix R_bhazard = r(R)
.
. scalar Nnodes = 50
. mlad (xb: = hormon age size2 size3 enodes er pr_1, nocons ) ///
> (rcs: = _rcs1 _rcs2 _rcs3 _rcs4 _rcs5) ///
> , llfile(rcs_hazard) ///
> othervars(_t0 _t _d) ///
> othervarnames(t0 t d) ///
> matrices(knots R_bhazard) ///
> staticscalars(Nnodes)
Initial: Log likelihood = -8000969.7
Alternative: Log likelihood = -3434708.4
Rescale: Log likelihood = -1989541.5
Rescale eq: Log likelihood = -1793634.8
Iteration 0: Log likelihood = -1793634.8
Iteration 1: Log likelihood = -1748229.4
Iteration 2: Log likelihood = -1648108.7
Iteration 3: Log likelihood = -1644819.1
Iteration 4: Log likelihood = -1644599
Iteration 5: Log likelihood = -1644598.4
Iteration 6: Log likelihood = -1644598.4
. ml display
Number of obs = 1,192,800
Wald chi2(7) = 265100.41
Log likelihood = -1644598.4 Prob > chi2 = 0.0000
------------------------------------------------------------------------------
| Coefficient Std. err. z P>|z| [95% conf. interval]
-------------+----------------------------------------------------------------
xb |
hormon | -.2124006 .0045147 -47.05 0.000 -.2212492 -.2035519
age | .0118462 .0001231 96.23 0.000 .0116049 .0120875
size2 | .3920078 .0034775 112.73 0.000 .385192 .3988236
size3 | .6967243 .0047848 145.61 0.000 .6873461 .7061024
enodes | -1.866594 .0052939 -352.60 0.000 -1.87697 -1.856218
er | -8.12e-06 5.56e-06 -1.46 0.145 -.000019 2.79e-06
pr_1 | -.0924092 .0006848 -134.94 0.000 -.0937514 -.0910669
-------------+----------------------------------------------------------------
rcs |
_rcs1 | .1130317 .0022287 50.72 0.000 .1086635 .1173999
_rcs2 | .1819154 .0017587 103.44 0.000 .1784684 .1853623
_rcs3 | -.0080271 .0017818 -4.51 0.000 -.0115194 -.0045348
_rcs4 | -.0268743 .0016598 -16.19 0.000 -.0301274 -.0236213
_rcs5 | .0171344 .001666 10.28 0.000 .0138691 .0203997
_cons | -1.90774 .0090148 -211.62 0.000 -1.925408 -1.890071
------------------------------------------------------------------------------
. timer off 1
Now I will write a setup file which will slightly change the function used to calculate the log-likelihood. The key point is that an array of the nodes required for the numerical integration will be calculated once in the setup file rather than every time the likelihood function is called. The gradient and Hessian functions (obtained through automatic differentiation) will also benefit from this.
This setup Python file is shown below.
. type setup_rcs_hazard.py
from scipy.special import roots_legendre
from jax import vmap
import jax.numpy as jnp
import mladutil as mu

def mlad_setup(M):
    vrcsgen = (vmap(mu.rcsgen,(0,None,None),0))
    nodes, weights = roots_legendre(M["Nnodes"])
    nodes2 = 0.5*(M["t"] - M["t0"])*nodes + 0.5*(M["t"] + M["t0"])
    M["allnodes"] = vrcsgen(jnp.log(nodes2),M["knots"][0],M["R_bhazard"])
    M["weights"] = weights
    return(M)
- The setup function must be called mlad_setup() and has one argument, M, the Python dictionary.
- First I use vmap on the rcsgen() function to vectorise it, i.e. the function vrcsgen will be able to return the restricted cubic splines basis functions at all of the nodes.
- The nodes and weights for Gauss-Legendre quadrature are then calculated using roots_legendre().
- The Gauss-Legendre nodes are for an integral over [-1,1], but are transformed using a change of interval rule and stored in nodes2.
- The basis functions are then calculated at each node and stored in M["allnodes"] and the weights stored in M["weights"].
- The updated dictionary, M, is then returned.
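The change of interval used for nodes2 can be checked on a toy problem. The sketch below is only an illustration and is not part of mlad: it uses NumPy's leggauss() in place of roots_legendre(), made-up entry and exit times, and a Weibull hazard (with assumed parameters lam and gam) chosen purely because its cumulative hazard has a closed form to compare against.

```python
import numpy as np

# Hypothetical entry (t0) and exit (t) times for three subjects, as N x 1 columns.
t0 = np.array([[0.0], [0.5], [1.0]])
t = np.array([[2.0], [3.0], [4.5]])

# Gauss-Legendre nodes and weights on [-1, 1]; these are computed once.
nodes, weights = np.polynomial.legendre.leggauss(30)

# Change of interval: map [-1, 1] onto [t0, t] for every subject.
# Broadcasting gives an N x Nnodes array, one row of nodes per subject.
nodes2 = 0.5 * (t - t0) * nodes + 0.5 * (t + t0)

# Illustrative Weibull hazard h(u) = lam * gam * u**(gam - 1) (assumed values).
lam, gam = 0.2, 1.5
haz_at_nodes = lam * gam * nodes2 ** (gam - 1)

# Quadrature approximation to the cumulative hazard between t0 and t.
cumhaz = 0.5 * (t - t0) * np.sum(weights * haz_at_nodes, axis=1, keepdims=True)

# Closed-form cumulative hazard for comparison.
exact = lam * (t ** gam - t0 ** gam)
print(np.hstack([cumhaz, exact]))
```

Because nodes2 depends only on the data and not on the parameters, it (and anything computed from it, such as the spline basis) only needs to be built once.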
The Python log-likelihood file is listed below.
. type rcs_hazard_pysetup.py
import jax.numpy as jnp
import mladutil as mu

def python_ll(beta,X,wt,M):
    ## Parameters
    xb = mu.linpred(beta,X,1)
    xbrcs = mu.linpred(beta,X,2)
    ## cumulative hazard
    ch_at_nodes = jnp.exp(jnp.matmul(M["allnodes"],beta[1][:-1]) + beta[1][-1] + xb)
    cumhaz = (0.5*(M["t"]-M["t0"]))*jnp.sum(M["weights"]*ch_at_nodes,axis=1,keepdims=True)
    return(jnp.sum(wt*(M["d"]*(xb + xbrcs) - cumhaz)))
- The linear predictors for the covariate effects and the restricted cubic splines are calculated using mu.linpred().
- The hazard function is then evaluated at each of the nodes for each individual (stored in ch_at_nodes), using the pre-calculated basis functions in M["allnodes"].
- These are then used in conjunction with the weights to obtain the cumulative hazard at each event/censoring time.
- Finally, the log-likelihood is returned.
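The split between a one-off setup step and a cheap log-likelihood can be sketched without mlad or JAX. The example below is a simplified stand-in under stated assumptions: the functions toy_setup() and toy_ll() are hypothetical names, the data are made up, and the spline basis is replaced by a single column, log time, so that the log hazard is b0 + b1*log(t) and the result can be checked against a closed form.

```python
import numpy as np

def toy_setup(M, n_nodes=30):
    """Run once: store everything that does not depend on the parameters."""
    nodes, weights = np.polynomial.legendre.leggauss(n_nodes)
    nodes2 = 0.5 * (M["t"] - M["t0"]) * nodes + 0.5 * (M["t"] + M["t0"])
    M["allnodes"] = np.log(nodes2)  # toy one-column "basis": log time at each node
    M["weights"] = weights
    return M

def toy_ll(b0, b1, M):
    """Run at every iteration: only parameter-dependent work is done here."""
    log_haz_at_t = b0 + b1 * np.log(M["t"])
    haz_at_nodes = np.exp(b0 + b1 * M["allnodes"])
    cumhaz = 0.5 * (M["t"] - M["t0"]) * np.sum(
        M["weights"] * haz_at_nodes, axis=1, keepdims=True
    )
    return float(np.sum(M["d"] * log_haz_at_t - cumhaz))

# Hypothetical data: entry time, exit time and event indicator (N x 1 columns).
M = {
    "t0": np.array([[0.0], [0.5], [1.0]]),
    "t": np.array([[2.0], [3.0], [4.5]]),
    "d": np.array([[1.0], [0.0], [1.0]]),
}
M = toy_setup(M)

b0, b1 = -1.0, 0.5
ll = toy_ll(b0, b1, M)

# Closed form: H(t) - H(t0) = exp(b0) * (t**(b1+1) - t0**(b1+1)) / (b1 + 1).
exact_cumhaz = np.exp(b0) * (M["t"] ** (b1 + 1) - M["t0"] ** (b1 + 1)) / (b1 + 1)
ll_exact = float(np.sum(M["d"] * (b0 + b1 * np.log(M["t"])) - exact_cumhaz))
print(ll, ll_exact)
```

The same pattern is what makes the setup file worthwhile: the optimiser calls the log-likelihood (and its derivatives) many times, but the node positions and basis values are computed only once.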
Now the model can be fitted again, but this time making use of the setup program. This is passed to mlad using the pysetup() option.
. drop lnt _rcs*
. timer on 2
. gen double lnt = ln(_t)
. rcsgen lnt, gen(_rcs) df(5) if2(_d==1) orthog
Variables _rcs1 to _rcs5 were created
. mata: st_matrix("knots",strtoreal(tokens(st_global("r(knots)"))))
. matrix R_bhazard = r(R)
.
. scalar Nnodes = 50
. mlad (xb: = hormon age size2 size3 enodes er pr_1, nocons ) ///
> (rcs: = _rcs1 _rcs2 _rcs3 _rcs4 _rcs5) ///
> , llfile(rcs_hazard_pysetup) ///
> pysetup(setup_rcs_hazard) ///
> othervars(_t0 _t _d) ///
> othervarnames(t0 t d) ///
> matrices(knots R_bhazard) ///
> scalars(Nnodes)
Initial: Log likelihood = -8000969.7
Alternative: Log likelihood = -3434708.4
Rescale: Log likelihood = -1989541.5
Rescale eq: Log likelihood = -1793634.8
Iteration 0: Log likelihood = -1793634.8
Iteration 1: Log likelihood = -1748229.4
Iteration 2: Log likelihood = -1648108.7
Iteration 3: Log likelihood = -1644819.1
Iteration 4: Log likelihood = -1644599
Iteration 5: Log likelihood = -1644598.4
Iteration 6: Log likelihood = -1644598.4
. ml display
Number of obs = 1,192,800
Wald chi2(7) = 265100.41
Log likelihood = -1644598.4 Prob > chi2 = 0.0000
------------------------------------------------------------------------------
| Coefficient Std. err. z P>|z| [95% conf. interval]
-------------+----------------------------------------------------------------
xb |
hormon | -.2124006 .0045147 -47.05 0.000 -.2212492 -.2035519
age | .0118462 .0001231 96.23 0.000 .0116049 .0120875
size2 | .3920078 .0034775 112.73 0.000 .385192 .3988236
size3 | .6967243 .0047848 145.61 0.000 .6873461 .7061024
enodes | -1.866594 .0052939 -352.60 0.000 -1.87697 -1.856218
er | -8.12e-06 5.56e-06 -1.46 0.145 -.000019 2.79e-06
pr_1 | -.0924092 .0006848 -134.94 0.000 -.0937514 -.0910669
-------------+----------------------------------------------------------------
rcs |
_rcs1 | .1130317 .0022287 50.72 0.000 .1086635 .1173999
_rcs2 | .1819154 .0017587 103.44 0.000 .1784684 .1853623
_rcs3 | -.0080271 .0017818 -4.51 0.000 -.0115194 -.0045348
_rcs4 | -.0268743 .0016598 -16.19 0.000 -.0301274 -.0236213
_rcs5 | .0171344 .001666 10.28 0.000 .0138691 .0203997
_cons | -1.90774 .0090148 -211.62 0.000 -1.925408 -1.890071
------------------------------------------------------------------------------
. timer off 2
The estimates from the models are identical. The times to fit each model can be seen below.
. timer list
1: 108.64 / 1 = 108.6410
2: 61.87 / 1 = 61.8730
In this dataset, using the setup file reduces the time taken to fit the model by 43.0% compared with the previous example (from 108.6 to 61.9 seconds).
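For reference, the percentage can be reproduced from the two timer values listed above. The variable names below are only illustrative.

```python
# Times in seconds, copied from the timer list output.
time_without_setup = 108.641  # timer 1: spline basis recomputed at every call
time_with_setup = 61.873      # timer 2: spline basis pre-calculated via pysetup()

# Relative reduction in run time, as a percentage of the original time.
reduction_pct = 100 * (time_without_setup - time_with_setup) / time_without_setup

# The same comparison expressed as a speed-up factor.
speedup = time_without_setup / time_with_setup

print(f"Reduction in run time: {reduction_pct:.1f}%")
print(f"Speed-up factor: {speedup:.2f}")
```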