Commit 02ca4c4c authored by Marmaduke Woodman's avatar Marmaduke Woodman

use tvb param defaults & jit the thing

parent 01d41d82
......@@ -19,28 +19,64 @@ drift = Vector(
n = 6
flatjac = lambda x, y: x.diff(y).reshape(len(x)*len(y)).tolist()
terms = drift.reshape(1, n).tolist()[0] + flatjac(drift, state) + flatjac(drift, param)
from tvb.simulator.models.jansen_rit import JansenRit
defaults = {}
for (p,) in param.tolist():
key = str(p)
if hasattr(JansenRit, key):
defaults[key] = getattr(JansenRit, key).default.item()
use_defaults = True
if use_defaults:
terms = [_.subs(defaults).simplify() for _ in terms]
do_cse = True
if do_cse:
reps, rexs = cse(terms)
else:
reps, rexs = [], terms
for lhs, rhs in reps:
print(lhs, '=', rhs)
drift_rex = rexs[:n]
jac_state = rexs[n:n+n*n]
jac_param = rexs[n+n*n:]
lines = []
for i, var in enumerate(state):
lines.append(f'{var} = state[{i}]')
for lhs, rhs in reps:
lines.append(f'{lhs} = {rhs}')
for i, par in enumerate(param):
if str(par) not in defaults:
lines.append(f'{par} = param[{i}]')
for i, ex in enumerate(drift_rex):
print('xt[%d]' % i, '=', ex)
for i in range(n):
for j in range(n):
val = jac_state[i*n + j]
if val == 0:
continue
print("jxtx[%d, %d]" % (i, j), '=', val)
for i in range(n):
for j in range(n):
val = jac_param[i*n + j]
if val == 0:
continue
print("jxtp[%d, %d]" % (i, j), '=', val)
\ No newline at end of file
lines.append(f'xt[{i}] = {ex}')
for i, (js, jp) in enumerate(zip(jac_state, jac_param)):
if js != 0:
lines.append(f"jxtx[{i//n}, {i%n}] = {js}")
if jp != 0:
lines.append(f"jxtp[{i//n}, {i%n}] = {jp}")
# embed Euler and update multiple at once
header = 'def gufunc(state, param, xt, jxtx, jxtp):\n'
code = header + '\n'.join([' ' + line for line in lines])
ns = {}
exec('from math import *', ns)
exec(code, ns)
gufunc = ns['gufunc']
from numba import njit
import numpy as np
jit_gufunc = njit(gufunc)
state = np.random.randn(n)
param = np.random.randn(len(param))
xt = np.zeros((n, ))
jxtx = np.zeros((n, n))
jxtp = np.zeros((n, len(param)))
jit_gufunc(state, param, xt, jxtx, jxtp)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment