Using mlad for splines on the log hazard scale - with pysetup.
Restricted cubic splines for the log hazard function.
Here I will fit the same model as in the previous example, but use the pysetup() option. This pre-calculates the spline basis functions at the nodes rather than recalculating them every time the likelihood function is called, as was done in the previous example.
First I will run mlad as in the previous example on the rott2 data. I use expand 400 to increase the sample size to 1,192,800.
. use https://www.pclambert.net/data/rott2b, clear
(Rotterdam breast cancer data (augmented with cause of death))

. expand 400
(1,189,818 observations created)

. tab size, gen(size)

     Tumour |
    size, 3 |
classes (t) |      Freq.     Percent        Cum.
------------+-----------------------------------
    <=20 mm |    554,800       46.51       46.51
  >20-50mmm |    516,400       43.29       89.81
     >50 mm |    121,600       10.19      100.00
------------+-----------------------------------
      Total |  1,192,800      100.00

. stset os, failure(osi=1) scale(12) exit(time 120)

Survival-time data settings

         Failure event: osi==1
Observed time interval: (0, os]
     Exit on or before: time 120
     Time for analysis: time/12

--------------------------------------------------------------------------
  1,192,800  total observations
          0  exclusions
--------------------------------------------------------------------------
  1,192,800  observations remaining, representing
    468,400  failures in single-record/single-failure data
  8,000,970  total analysis time at risk and under observation
                                                At risk from t =         0
                                     Earliest observed entry t =         0
                                          Last observed exit t =        10

. timer clear

. timer on 1

. gen double lnt = ln(_t)

. rcsgen lnt, gen(_rcs) df(5) if2(_d==1) orthog
Variables _rcs1 to _rcs5 were created

. mata: st_matrix("knots",strtoreal(tokens(st_global("r(knots)"))))

. matrix R_bhazard = r(R)

. scalar Nnodes = 50

. mlad (xb: = hormon age size2 size3 enodes er pr_1, nocons ) ///
>      (rcs: = _rcs1 _rcs2 _rcs3 _rcs4 _rcs5)                 ///
>      , llfile(rcs_hazard)                                   ///
>        othervars(_t0 _t _d)                                 ///
>        othervarnames(t0 t d)                                ///
>        matrices(knots R_bhazard)                            ///
>        staticscalars(Nnodes)
Initial: Log likelihood = -8000969.7
Alternative: Log likelihood = -3434708.4
Rescale: Log likelihood = -1989541.5
Rescale eq: Log likelihood = -1793634.8
Iteration 0: Log likelihood = -1793634.8
Iteration 1: Log likelihood = -1748229.4
Iteration 2: Log likelihood = -1648108.7
Iteration 3: Log likelihood = -1644819.1
Iteration 4: Log likelihood = -1644599
Iteration 5: Log likelihood = -1644598.4
Iteration 6: Log likelihood = -1644598.4

. ml display

                                                     Number of obs = 1,192,800
                                                     Wald chi2(7)  = 265100.41
Log likelihood = -1644598.4                          Prob > chi2   =    0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
xb           |
      hormon |  -.2124006   .0045147   -47.05   0.000    -.2212492   -.2035519
         age |   .0118462   .0001231    96.23   0.000     .0116049    .0120875
       size2 |   .3920078   .0034775   112.73   0.000      .385192    .3988236
       size3 |   .6967243   .0047848   145.61   0.000     .6873461    .7061024
      enodes |  -1.866594   .0052939  -352.60   0.000     -1.87697   -1.856218
          er |  -8.12e-06   5.56e-06    -1.46   0.145     -.000019    2.79e-06
        pr_1 |  -.0924092   .0006848  -134.94   0.000    -.0937514   -.0910669
-------------+----------------------------------------------------------------
rcs          |
       _rcs1 |   .1130317   .0022287    50.72   0.000     .1086635    .1173999
       _rcs2 |   .1819154   .0017587   103.44   0.000     .1784684    .1853623
       _rcs3 |  -.0080271   .0017818    -4.51   0.000    -.0115194   -.0045348
       _rcs4 |  -.0268743   .0016598   -16.19   0.000    -.0301274   -.0236213
       _rcs5 |   .0171344    .001666    10.28   0.000     .0138691    .0203997
       _cons |   -1.90774   .0090148  -211.62   0.000    -1.925408   -1.890071
------------------------------------------------------------------------------

. timer off 1
Now I will write a setup file, which slightly changes the function that calculates the log-likelihood. The key point is that an array of the nodes required for the numerical integration is calculated once in the setup file, rather than every time the likelihood function is called. The gradient and Hessian functions (obtained through automatic differentiation) also benefit from this.
The Python setup file is shown below.
. type setup_rcs_hazard.py
import jax.numpy as jnp
from jax import vmap
from scipy.special import roots_legendre
import mladutil as mu

def mlad_setup(M):
  vrcsgen = (vmap(mu.rcsgen,(0,None,None),0))
  nodes, weights = roots_legendre(M["Nnodes"])
  nodes2 = 0.5*(M["t"] - M["t0"])*nodes + 0.5*(M["t"] + M["t0"])
  M["allnodes"] = vrcsgen(jnp.log(nodes2),M["knots"][0],M["R_bhazard"])
  M["weights"] = weights
  return(M)
The setup function must be called mlad_setup() and has one argument, M, the Python dictionary.
First I use vmap on the rcsgen() function to vectorise it, i.e. the function vrcsgen returns the restricted cubic spline basis functions at all of the nodes for every individual.
The nodes and weights for Gauss-Legendre quadrature are then calculated using roots_legendre().
The Gauss-Legendre nodes are defined for an integral over [-1,1], so they are transformed to each individual's interval (t0, t] using the change of interval rule and stored in nodes2.
The basis functions are then calculated at each node and stored in M["allnodes"], and the weights are stored in M["weights"].
The updated dictionary, M, is then returned.
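The change of interval step can be checked in isolation. The sketch below is not part of mlad; it uses NumPy's leggauss() as a stand-in for roots_legendre() (both return Gauss-Legendre nodes and weights on [-1,1]), and the entry and exit times are made up for illustration. It maps the nodes to each individual's interval and integrates a function with a known answer.

```python
import numpy as np

# Gauss-Legendre nodes and weights on [-1, 1]. NumPy's leggauss is used here
# as a stand-in for scipy.special.roots_legendre (an assumption of this sketch).
n_nodes = 50
nodes, weights = np.polynomial.legendre.leggauss(n_nodes)

# Hypothetical entry and exit times for three individuals, as column vectors.
t0 = np.array([[0.0], [0.0], [1.0]])
t = np.array([[2.0], [5.5], [10.0]])

# Change of interval: map the nodes from [-1, 1] to [t0_i, t_i] for each row.
# Broadcasting (3, 1) against (50,) gives a (3, 50) array of nodes.
nodes2 = 0.5 * (t - t0) * nodes + 0.5 * (t + t0)

# Integrate f(u) = u**2 over each interval; the exact value is (t^3 - t0^3)/3.
approx = 0.5 * (t - t0) * np.sum(weights * nodes2**2, axis=1, keepdims=True)
exact = (t**3 - t0**3) / 3.0

# Largest absolute difference between the quadrature and the exact integrals.
print(np.max(np.abs(approx - exact)))
```

Each row of nodes2 holds the 50 nodes for one individual, which is the same layout the setup file passes to vrcsgen after taking logs.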
The Python log-likelihood file is listed below.
. type rcs_hazard_pysetup.py
import jax.numpy as jnp
import mladutil as mu

def python_ll(beta,X,wt,M):
  ## Parameters
  xb    = mu.linpred(beta,X,1)
  xbrcs = mu.linpred(beta,X,2)
  ## cumulative hazard
  ch_at_nodes = jnp.exp(jnp.matmul(M["allnodes"],beta[1][:-1]) + beta[1][-1] + xb)
  cumhaz = (0.5*(M["t"]-M["t0"]))*jnp.sum(M["weights"]*ch_at_nodes,axis=1,keepdims=True)
  return(jnp.sum(wt*(M["d"]*(xb + xbrcs) - cumhaz)))
The linear predictors for the covariate effects and for the restricted cubic splines are calculated using mu.linpred().
The hazard at each of the nodes for each individual is then calculated (ch_at_nodes).
These are then used in conjunction with the weights to obtain the cumulative hazard at each event/censoring time.
Finally, the log-likelihood is returned.
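To see the split between the setup step and the likelihood function without the spline machinery, here is a minimal NumPy sketch. It is an illustration under stated assumptions and not the mlad code: toy_setup() and toy_ll() are hypothetical names, a Weibull log hazard stands in for the spline function so that the answer can be checked against a closed form, and the data are made up.

```python
import numpy as np

def toy_setup(M):
    # Run once: map Gauss-Legendre nodes to each (t0, t] and store their logs.
    nodes, weights = np.polynomial.legendre.leggauss(M["Nnodes"])
    nodes2 = 0.5 * (M["t"] - M["t0"]) * nodes + 0.5 * (M["t"] + M["t0"])
    M["lognodes"] = np.log(nodes2)
    M["weights"] = weights
    return M

def toy_ll(beta, M):
    # Weibull log hazard: ln h(t) = ln(lam) + ln(gam) + (gam - 1) * ln(t).
    lnlam, lngam = beta
    gam = np.exp(lngam)
    loghaz_at_t = lnlam + lngam + (gam - 1.0) * np.log(M["t"])
    # Hazard at the precomputed nodes, then the weighted sum for the integral.
    haz_at_nodes = np.exp(lnlam + lngam + (gam - 1.0) * M["lognodes"])
    cumhaz = 0.5 * (M["t"] - M["t0"]) * np.sum(
        M["weights"] * haz_at_nodes, axis=1, keepdims=True)
    return np.sum(M["d"] * loghaz_at_t - cumhaz)

# Hypothetical entry times, exit times and event indicators (column vectors).
M = {
    "t0": np.array([[0.0], [0.0], [1.0], [2.0]]),
    "t": np.array([[2.0], [5.5], [10.0], [4.0]]),
    "d": np.array([[1.0], [0.0], [1.0], [1.0]]),
    "Nnodes": 50,
}
M = toy_setup(M)

lam, gam = 0.1, 2.0
ll_quad = toy_ll((np.log(lam), np.log(gam)), M)

# Closed form for comparison: H(t) - H(t0) = lam * (t**gam - t0**gam).
ll_exact = np.sum(
    M["d"] * (np.log(lam * gam) + (gam - 1.0) * np.log(M["t"]))
    - lam * (M["t"] ** gam - M["t0"] ** gam))
print(ll_quad, ll_exact)
```

Only toy_ll() would be evaluated repeatedly by an optimiser; the node calculation in toy_setup() is done once, which is the saving that pysetup() is designed to give.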
Now the model can be fitted again, but this time making use of the setup program. This is passed to mlad using the pysetup() option.
. drop lnt _rcs*

. timer on 2

. gen double lnt = ln(_t)

. rcsgen lnt, gen(_rcs) df(5) if2(_d==1) orthog
Variables _rcs1 to _rcs5 were created

. mata: st_matrix("knots",strtoreal(tokens(st_global("r(knots)"))))

. matrix R_bhazard = r(R)

. scalar Nnodes = 50

. mlad (xb: = hormon age size2 size3 enodes er pr_1, nocons ) ///
>      (rcs: = _rcs1 _rcs2 _rcs3 _rcs4 _rcs5)                 ///
>      , llfile(rcs_hazard_pysetup)                           ///
>        pysetup(setup_rcs_hazard)                            ///
>        othervars(_t0 _t _d)                                 ///
>        othervarnames(t0 t d)                                ///
>        matrices(knots R_bhazard)                            ///
>        scalars(Nnodes)
Initial: Log likelihood = -8000969.7
Alternative: Log likelihood = -3434708.4
Rescale: Log likelihood = -1989541.5
Rescale eq: Log likelihood = -1793634.8
Iteration 0: Log likelihood = -1793634.8
Iteration 1: Log likelihood = -1748229.4
Iteration 2: Log likelihood = -1648108.7
Iteration 3: Log likelihood = -1644819.1
Iteration 4: Log likelihood = -1644599
Iteration 5: Log likelihood = -1644598.4
Iteration 6: Log likelihood = -1644598.4

. ml display

                                                     Number of obs = 1,192,800
                                                     Wald chi2(7)  = 265100.41
Log likelihood = -1644598.4                          Prob > chi2   =    0.0000

------------------------------------------------------------------------------
             | Coefficient  Std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
xb           |
      hormon |  -.2124006   .0045147   -47.05   0.000    -.2212492   -.2035519
         age |   .0118462   .0001231    96.23   0.000     .0116049    .0120875
       size2 |   .3920078   .0034775   112.73   0.000      .385192    .3988236
       size3 |   .6967243   .0047848   145.61   0.000     .6873461    .7061024
      enodes |  -1.866594   .0052939  -352.60   0.000     -1.87697   -1.856218
          er |  -8.12e-06   5.56e-06    -1.46   0.145     -.000019    2.79e-06
        pr_1 |  -.0924092   .0006848  -134.94   0.000    -.0937514   -.0910669
-------------+----------------------------------------------------------------
rcs          |
       _rcs1 |   .1130317   .0022287    50.72   0.000     .1086635    .1173999
       _rcs2 |   .1819154   .0017587   103.44   0.000     .1784684    .1853623
       _rcs3 |  -.0080271   .0017818    -4.51   0.000    -.0115194   -.0045348
       _rcs4 |  -.0268743   .0016598   -16.19   0.000    -.0301274   -.0236213
       _rcs5 |   .0171344    .001666    10.28   0.000     .0138691    .0203997
       _cons |   -1.90774   .0090148  -211.62   0.000    -1.925408   -1.890071
------------------------------------------------------------------------------

. timer off 2
The estimates from the models are identical. The times to fit each model can be seen below.
. timer list
   1:     73.37 /        1 =      73.3660
   2:     46.40 /        1 =      46.4030
In this dataset, using the setup file reduces the time to fit the model by 36.8% compared with the previous example (46.4 seconds rather than 73.4 seconds).
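For reference, the percentage comes directly from the two timers. The short calculation below just reproduces that arithmetic from the timings listed above; the variable names are my own.

```python
# Timings in seconds, copied from the timer list output.
time_without_setup = 73.3660  # timer 1: likelihood recalculates the nodes
time_with_setup = 46.4030     # timer 2: nodes pre-calculated via pysetup()

# Relative reduction in run time, and the equivalent speed-up factor.
reduction = 100 * (time_without_setup - time_with_setup) / time_without_setup
speedup = time_without_setup / time_with_setup

print(f"{reduction:.1f}% less time, {speedup:.2f} times faster")
```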
The approach here is essentially what is implemented in stpm3 when using the python option.