Source code for difw.backend.numpy.transformer
# %% SETUP
import numpy as np
# Single-precision machine epsilon. The cell-boundary helpers pad each cell by
# this amount so a point lying exactly on an edge still counts as inside it.
eps = np.finfo("float32").eps
# Several piecewise formulas evaluate expressions such as ``b / a`` for every
# point and then select the ``a == 0`` branch with ``np.where``; silence the
# resulting divide-by-zero / invalid-value warnings.
# NOTE: np.seterr changes NumPy's process-global error state.
np.seterr(invalid="ignore", divide="ignore")
# %% BATCH EFFECT
def batch_effect(x, theta):
    """Flatten ``x`` into a single vector covering every batch element.

    A 1-D ``x`` is shared by all batch elements, so it is tiled once per row
    of ``theta`` (``theta.shape[0]`` copies) before flattening. Any other
    ``x`` is flattened as-is (row-major).

    Returns
    -------
    numpy.ndarray
        1-D array. For 1-D input its length is
        ``theta.shape[0] * x.shape[-1]``.
    """
    if x.ndim != 1:
        return x.flatten()
    batch_size = theta.shape[0]
    tiled = np.broadcast_to(x, (batch_size, x.shape[-1]))
    return tiled.flatten()
# %% FUNCTIONS
def get_affine(x, theta, params):
    """Return the per-cell affine coefficients and the batch index of each point.

    Parameters
    ----------
    x : numpy.ndarray
        Points, expected to have been passed through ``batch_effect`` already
        (flattened, batch-major). The batch index ``r`` is derived from that
        layout.
    theta : numpy.ndarray
        Parameters, one row per batch element.
    params : object
        Must provide ``precomputed``. When it is true, the cached
        ``params.A`` and ``params.r`` are returned unchanged. Otherwise
        ``params.B`` is used to map ``theta`` to the coefficients.

    Returns
    -------
    A : numpy.ndarray
        Shape ``(n_batch, n_cells, 2)``. ``A[i, c] = (a, b)`` are the slope
        and intercept of the velocity ``a*x + b`` in cell ``c`` for batch
        element ``i``.
    r : numpy.ndarray
        Integer batch index for every point of ``x``.
    """
    if params.precomputed:
        return params.A, params.r
    n_batch = theta.shape[0]
    n_points = x.shape[-1]
    # NOTE: assumes batch_effect was already applied, so each batch element
    # owns an equal, contiguous run of points. The repeat count must be an
    # integer: NumPy refuses to cast a float repeat count to an index.
    r = np.arange(n_batch).repeat(n_points // n_batch)
    A = params.B.dot(theta.T).T.reshape(n_batch, -1, 2)
    return A, r
def precompute_affine(x, theta, params):
    """Return a copy of ``params`` with ``A`` and ``r`` computed and cached.

    The affine coefficients are computed once from ``x`` and ``theta``. The
    copy has ``precomputed`` set to True, so later ``get_affine`` calls
    return the cached values. Attribute assignments go to the object
    returned by ``params.copy()``, not to ``params`` itself (how deep that
    copy is depends on the params type).
    """
    cached = params.copy()
    # Force get_affine to really compute, even if the source was cached.
    cached.precomputed = False
    affine, batch_index = get_affine(x, theta, cached)
    cached.A = affine
    cached.r = batch_index
    cached.precomputed = True
    return cached
def right_boundary(c, params):
    """Right edge of cell(s) ``c`` on the uniform grid over [xmin, xmax].

    The edge is shifted outward by ``eps`` so a point sitting exactly on it
    is still considered inside cell ``c``.
    """
    lo, hi, n_cells = params.xmin, params.xmax, params.nc
    return lo + (c + 1) * (hi - lo) / n_cells + eps
def left_boundary(c, params):
    """Left edge of cell(s) ``c`` on the uniform grid over [xmin, xmax].

    The edge is shifted outward by ``eps`` so a point sitting exactly on it
    is still considered inside cell ``c``.
    """
    lo, hi, n_cells = params.xmin, params.xmax, params.nc
    return lo + c * (hi - lo) / n_cells - eps
def get_cell(x, params):
    """Index of the uniform grid cell containing each point of ``x``.

    Points left of ``xmin`` map to cell 0. Points at or beyond ``xmax`` map
    to the last cell ``nc - 1``. The result has dtype int32.
    """
    lo, hi, n_cells = params.xmin, params.xmax, params.nc
    raw_index = np.floor((x - lo) / (hi - lo) * n_cells)
    return np.clip(raw_index, 0, n_cells - 1).astype(np.int32)
def get_velocity(x, theta, params):
    """Evaluate the piecewise-affine velocity field ``v(x) = a*x + b``.

    ``(a, b)`` are taken from the cell containing each point and the batch
    element that point belongs to.
    """
    A, r = get_affine(x, theta, params)
    cell = get_cell(x, params)
    slope, intercept = A[r, cell, 0], A[r, cell, 1]
    return slope * x + intercept
def get_psi(x, t, theta, params):
    """Closed-form solution of ``dx/dt = a*x + b`` after time ``t``.

    Uses the affine coefficients of the cell that currently contains ``x``,
    ignoring any boundary crossing:

    * ``a == 0``: ``x + t*b``
    * otherwise: ``exp(t*a)*x + (b/a)*(exp(t*a) - 1)``
    """
    A, r = get_affine(x, theta, params)
    cell = get_cell(x, params)
    a = A[r, cell, 0]
    b = A[r, cell, 1]
    constant_field = x + t * b
    growth = np.exp(t * a)
    # b / a is inf/nan where a == 0; np.where discards those entries.
    affine_field = growth * x + (b / a) * (growth - 1.0)
    return np.where(a == 0, constant_field, affine_field)
def get_hit_time(x, theta, params):
    """Time for each point to reach the boundary of its current cell.

    The target boundary is the right edge when the velocity is >= 0 and the
    left edge otherwise. The time is obtained by inverting the closed-form
    trajectory: ``(xc - x) / b`` when ``a == 0``, else
    ``log((xc + b/a) / (x + b/a)) / a``.
    """
    A, r = get_affine(x, theta, params)
    cell = get_cell(x, params)
    velocity = get_velocity(x, theta, params)
    target = np.where(
        velocity >= 0,
        right_boundary(cell, params),
        left_boundary(cell, params),
    )
    a = A[r, cell, 0]
    b = A[r, cell, 1]
    shift = b / a
    time_constant = (target - x) / b
    time_affine = np.log((target + shift) / (x + shift)) / a
    return np.where(a == 0, time_constant, time_affine)
def get_phi_numeric(x, t, theta, params):
    """Integrate ``dx/dt = v(x)`` from 0 to ``t`` with the explicit midpoint rule.

    Performs ``params.nSteps2`` equal sub-steps of size ``t / nSteps2``. Each
    sub-step evaluates the velocity at a half-step midpoint (RK2).

    Returns
    -------
    numpy.ndarray
        Positions after time ``t``, same shape as ``x``.
    """
    n_steps = params.nSteps2
    dt = t / n_steps
    y = x
    for _ in range(n_steps):
        # get_velocity looks up the cell itself, so no separate get_cell call
        # is needed here.
        midpoint = y + dt / 2 * get_velocity(y, theta, params)
        y = y + dt * get_velocity(midpoint, theta, params)
    return y
# %% INTEGRATION
def integrate_numeric(x, theta, params, time=1.0):
    """Integrate the velocity field with a hybrid closed-form / numerical scheme.

    The interval ``[0, time]`` is split into ``params.nSteps1`` steps. In each
    step a point keeps its closed-form update (``get_psi``) if that update
    stays in the point's current cell. Otherwise it takes the numerical
    midpoint result (``get_phi_numeric``).

    Returns
    -------
    numpy.ndarray
        Final positions reshaped to ``(n_batch, -1)``.
    """
    x = batch_effect(x, theta)
    n_batch = theta.shape[0]
    params = precompute_affine(x, theta, params)

    n_steps = params.nSteps1
    dt = time / n_steps
    current = x
    cell = get_cell(current, params)
    for _ in range(n_steps):
        analytic = get_psi(current, dt, theta, params)
        stays_in_cell = cell == get_cell(analytic, params)
        numeric = get_phi_numeric(current, dt, theta, params)
        current = np.where(stays_in_cell, analytic, numeric)
        cell = get_cell(current, params)
    return current.reshape((n_batch, -1))
def integrate_closed_form(x, theta, params, time=1.0):
    """Integrate the velocity field exactly, one cell crossing at a time.

    Each iteration applies the closed-form solution ``get_psi`` with the
    remaining time. Points whose result stays inside their cell are
    finished, as are points moving outward from the first or last cell.
    The other points are advanced to the cell boundary, their remaining
    time is reduced by the hit time, and they move to the neighbouring cell.

    Returns
    -------
    numpy.ndarray
        Final positions reshaped to ``(n_batch, -1)``.

    Raises
    ------
    RuntimeError
        If some points are still unfinished after more than ``params.nc``
        cell crossings.
    """
    x = batch_effect(x, theta)
    t = np.ones_like(x) * time
    params = precompute_affine(x, theta, params)
    n_batch = theta.shape[0]

    phi = np.empty_like(x)
    done = np.full_like(x, False, dtype=bool)
    c = get_cell(x, params)
    n_crossings = 0
    while True:
        left = left_boundary(c, params)
        right = right_boundary(c, params)
        v = get_velocity(x, theta, params)
        psi = get_psi(x, t, theta, params)
        # Finished if psi stays in the current cell, or if the point leaves
        # the grid through its outer edge (no further boundary to cross).
        cond1 = np.logical_and(left <= psi, psi <= right)
        cond2 = np.logical_and(v >= 0, c == params.nc - 1)
        cond3 = np.logical_and(v <= 0, c == 0)
        valid = np.any((cond1, cond2, cond3), axis=0)
        phi[~done] = psi
        done[~done] = valid
        # np.all replaces np.alltrue, which was removed in NumPy 2.0.
        if np.all(valid):
            return phi.reshape((n_batch, -1))
        # Keep only unfinished points. params is the private copy returned
        # by precompute_affine, so narrowing its batch index r is safe.
        x, t, params.r = x[~valid], t[~valid], params.r[~valid]
        # Consume the time needed to reach the boundary (x is still the
        # start position here), then restart from that boundary.
        t -= get_hit_time(x, theta, params)
        x = np.clip(psi, left, right)[~valid]
        c = np.where(v >= 0, c + 1, c - 1)[~valid]
        n_crossings += 1
        if n_crossings > params.nc:
            raise RuntimeError(
                "integrate_closed_form: points did not settle after "
                f"{params.nc} cell crossings"
            )
def integrate_closed_form_trace(x, theta, params, time=1.0):
    """Like ``integrate_closed_form`` but also record where each point ended.

    Returns
    -------
    numpy.ndarray
        Shape ``(N, 3)`` for the ``N`` batch-flattened points. The columns
        are recorded in the iteration where the point finishes:

        * final position ``psi``
        * remaining time ``t`` used in the last cell
        * index ``c`` of that last cell (stored as float)

    Raises
    ------
    RuntimeError
        If some points are still unfinished after more than ``params.nc``
        cell crossings.
    """
    x = batch_effect(x, theta)
    t = np.ones_like(x) * time
    params = precompute_affine(x, theta, params)
    result = np.empty((*x.shape, 3))
    done = np.full_like(x, False, dtype=bool)
    c = get_cell(x, params)
    n_crossings = 0
    while True:
        left = left_boundary(c, params)
        right = right_boundary(c, params)
        v = get_velocity(x, theta, params)
        psi = get_psi(x, t, theta, params)
        # Finished if psi stays in the current cell, or if the point leaves
        # the grid through its outer edge (no further boundary to cross).
        cond1 = np.logical_and(left <= psi, psi <= right)
        cond2 = np.logical_and(v >= 0, c == params.nc - 1)
        cond3 = np.logical_and(v <= 0, c == 0)
        valid = np.any((cond1, cond2, cond3), axis=0)
        result[~done] = np.array([psi, t, c]).T
        done[~done] = valid
        # np.all replaces np.alltrue, which was removed in NumPy 2.0.
        if np.all(valid):
            return result
        # Keep only unfinished points. params is the private copy returned
        # by precompute_affine, so narrowing its batch index r is safe.
        x, t, params.r = x[~valid], t[~valid], params.r[~valid]
        t -= get_hit_time(x, theta, params)
        x = np.clip(psi, left, right)[~valid]
        c = np.where(v >= 0, c + 1, c - 1)[~valid]
        n_crossings += 1
        if n_crossings > params.nc:
            raise RuntimeError(
                "integrate_closed_form_trace: points did not settle after "
                f"{params.nc} cell crossings"
            )
# %% DERIVATIVE
def derivative_numeric(x, theta, params, time=1.0, h=1e-3):
    """Forward finite-difference gradient of the integrated points w.r.t. theta.

    Each parameter column ``k`` of ``theta`` is perturbed by ``h`` in turn,
    and the points are re-integrated with ``integrate_numeric``.

    Returns
    -------
    phi_1 : numpy.ndarray
        Unperturbed result, shape ``(n_batch, n_points)``.
    der : numpy.ndarray
        Shape ``(n_batch, n_points, d)``; ``der[..., k] ~ dphi/dtheta_k``.
    """
    n_points = x.shape[-1]
    n_batch = theta.shape[0]
    d = theta.shape[1]

    der = np.empty((n_batch, n_points, d))
    phi_1 = integrate_numeric(x, theta, params, time)
    for k in range(d):
        shifted = theta.copy()
        shifted[:, k] += h
        phi_shifted = integrate_numeric(x, shifted, params, time)
        der[:, :, k] = (phi_shifted - phi_1) / h
    return phi_1, der
def derivative_closed_form(x, theta, params, time=1.0):
    """Analytic gradient of the integrated points w.r.t. theta.

    The chain rule is applied along each point's path. The derivatives of
    all intermediate hit times are accumulated while walking from the
    start cell to the final cell recorded by
    ``integrate_closed_form_trace``. They are then combined with the
    derivatives of the closed-form solution in the final cell.

    Returns
    -------
    phi : numpy.ndarray
        Integrated points, shape ``(n_batch, n_points)``.
    der : numpy.ndarray
        Shape ``(n_batch, n_points, d)``; ``der[..., k] = dphi/dtheta_k``.
    """
    n_points = x.shape[-1]
    n_batch = theta.shape[0]
    d = theta.shape[1]

    # Trace rows: (final position, remaining time in last cell, last cell).
    result = integrate_closed_form_trace(x, theta, params, time)
    phi = result[:, 0].reshape((n_batch, -1))
    tm = result[:, 1]
    cm = result[:, 2]

    x = batch_effect(x, theta)
    params = precompute_affine(x, theta, params)
    # The starting cell does not depend on k; compute it once.
    c_start = get_cell(x, params)
    der = np.empty((n_batch, n_points, d))
    for k in range(d):
        dthit_dtheta_cum = np.zeros_like(x)
        xm = x.copy()
        c = c_start
        while True:
            valid = c == cm
            # np.all replaces np.alltrue, which was removed in NumPy 2.0.
            if np.all(valid):
                break
            step = np.sign(cm - c)
            dthit_dtheta_cum[~valid] -= derivative_thit_theta(xm, theta, k, params)[~valid]
            xm[~valid] = np.where(
                step == 1, right_boundary(c, params), left_boundary(c, params)
            )[~valid]
            c = c + step
        dpsi_dtheta = derivative_psi_theta(xm, tm, theta, k, params)
        dpsi_dtime = derivative_phi_time(xm, tm, theta, k, params)
        dphi_dtheta = dpsi_dtheta + dpsi_dtime * dthit_dtheta_cum
        der[:, :, k] = dphi_dtheta.reshape(n_batch, n_points)
    return phi, der
def derivative_psi_theta(x, t, theta, k, params):
    """Partial derivative of ``psi(x, t)`` w.r.t. ``theta[:, k]`` (x, t fixed).

    In the cell ``c`` of ``x`` the velocity is ``a*x + b``. Since
    ``A = B @ theta``, ``da/dtheta_k = B[2c, k]`` and
    ``db/dtheta_k = B[2c+1, k]``. The closed form being differentiated is

    * ``a != 0``: ``psi = exp(t*a)*x + (b/a)*(exp(t*a) - 1)``
    * ``a == 0``: the ``a -> 0`` limit of the above.
    """
    A, r = get_affine(x, theta, params)
    c = get_cell(x, params)
    a = A[r, c, 0]
    b = A[r, c, 1]
    ak = params.B[2 * c, k]
    bk = params.B[2 * c + 1, k]
    cond = a == 0
    # a == 0 is the a -> 0 limit of d2. The slope can still vary with
    # theta_k (ak != 0), and d/da[b*(exp(t*a) - 1)/a] at a = 0 equals
    # b*t**2/2, hence the extra ak*b*t**2/2 term.
    d1 = t * (x * ak + bk) + 0.5 * ak * b * t ** 2
    d2 = ak * t * np.exp(a * t) * (x + b / a) + (np.exp(t * a) - 1) * (bk * a - ak * b) / a ** 2
    dpsi_dtheta = np.where(cond, d1, d2)
    return dpsi_dtheta
def derivative_phi_time(x, t, theta, k, params):
    """Time derivative ``d psi(x, t) / dt`` in the cell containing ``x``.

    This is ``b`` when ``a == 0``, else ``exp(t*a)*(a*x + b)``. ``k`` is
    unused; it only keeps the signature parallel to
    ``derivative_psi_theta``.
    """
    A, r = get_affine(x, theta, params)
    cell = get_cell(x, params)
    slope = A[r, cell, 0]
    offset = A[r, cell, 1]
    growing = np.exp(t * slope) * (slope * x + offset)
    return np.where(slope == 0, offset, growing)
def derivative_thit_theta(x, theta, k, params):
    """Partial derivative of the cell hit time w.r.t. ``theta[:, k]``.

    The hit time to the boundary ``xc`` (right edge if ``v >= 0``, else
    left) is ``log((a*xc + b) / (a*x + b)) / a`` for ``a != 0``. For
    ``a == 0`` it is the ``a -> 0`` limit, ``(xc - x) / b``. The derivative
    uses ``da/dtheta_k = B[2c, k]`` and ``db/dtheta_k = B[2c+1, k]``.
    """
    A, r = get_affine(x, theta, params)
    c = get_cell(x, params)
    a = A[r, c, 0]
    b = A[r, c, 1]
    ak = params.B[2 * c, k]
    bk = params.B[2 * c + 1, k]
    v = get_velocity(x, theta, params)
    xc = np.where(v >= 0, right_boundary(c, params), left_boundary(c, params))
    cond = a == 0
    # a == 0 is the a -> 0 limit of d2 + d3. Expanding the log around a = 0
    # gives thit ~ (xc - x)/b - a*(xc**2 - x**2)/(2*b**2), so the slope
    # sensitivity ak contributes ak*(x - xc)*(x + xc)/(2*b**2).
    d1 = (x - xc) * (bk + 0.5 * ak * (x + xc)) / b ** 2
    d2 = -ak * np.log((a * xc + b) / (a * x + b)) / a ** 2
    d3 = (x - xc) * (bk * a - ak * b) / (a * (a * x + b) * (a * xc + b))
    dthit_dtheta = np.where(cond, d1, d2 + d3)
    return dthit_dtheta
# %% DERIVATIVE SPACE
def derivative_space_numeric(x, theta, params, time=1.0, h=1e-3):
    """Forward finite-difference derivative of the integrated points w.r.t. x.

    ``x`` and ``x + h`` are integrated in a single ``integrate_numeric`` call,
    and the result is split back into the two halves.

    Returns
    -------
    phi_1 : numpy.ndarray
        Integrated ``x``, shape ``(n_batch, n_points)``.
    der : numpy.ndarray
        ``(phi(x + h) - phi(x)) / h``, same shape as ``phi_1``.
    """
    # Join along the point axis (not the batch axis). This keeps every row
    # of a 2-D ``x`` paired with its own theta row after batch_effect. For
    # 1-D ``x`` it matches a plain concatenate.
    xe = np.concatenate([x, x + h], axis=-1)
    phi = integrate_numeric(xe, theta, params, time)
    phi_1, phi_2 = np.split(phi, 2, axis=1)
    der = (phi_2 - phi_1) / h
    return phi_1, der
def derivative_space_closed_form(x, theta, params, time=1.0):
    """Analytic derivative of the integrated points w.r.t. the start position x.

    There are two cases:

    * Points that finish in their starting cell: ``dphi/dx = exp(time*a)``.
    * Points that cross cells: only the first hit time depends on ``x``.
      ``dphi/dx`` is ``d psi/dt`` in the final cell times
      ``derivative_thit_x`` (``1/(a*x + b)``, i.e. ``-dthit/dx``).

    Returns
    -------
    phi : numpy.ndarray
        Integrated points, shape ``(n_batch, n_points)``.
    dphi_dx : numpy.ndarray
        Same shape as ``phi``.
    """
    n_points = x.shape[-1]
    n_batch = theta.shape[0]

    # Trace rows: (final position, remaining time in last cell, last cell).
    result = integrate_closed_form_trace(x, theta, params, time)
    phi = result[:, 0].reshape((n_batch, -1))
    tm = result[:, 1]
    cm = result[:, 2]

    x = batch_effect(x, theta)
    # Build t from the batch-flattened x. Building it from the raw x gave a
    # length mismatch (broadcast error) whenever n_batch > 1.
    t = np.ones_like(x) * time
    params = precompute_affine(x, theta, params)
    dthit_dx = np.zeros_like(x)
    dpsi_dx = np.zeros_like(x)
    c = get_cell(x, params)
    valid = c == cm
    # dpsi_dx only for points that never leave their first cell.
    dpsi_dx[valid] = derivative_psi_x(x, t, theta, params)[valid]
    # dthit_dx only for points that cross at least one boundary.
    dthit_dx[~valid] = derivative_thit_x(x, t, theta, params)[~valid]
    xm = x.copy()
    while True:
        valid = c == cm
        # np.all replaces np.alltrue, which was removed in NumPy 2.0.
        if np.all(valid):
            break
        step = np.sign(cm - c)
        xm[~valid] = np.where(
            step == 1, right_boundary(c, params), left_boundary(c, params)
        )[~valid]
        c = c + step
    # d psi / dt evaluated at the entry point of the final cell.
    dpsi_dtime = derivative_psi_t(xm, tm, theta, params)
    dphi_dx = dpsi_dx + dpsi_dtime * dthit_dx
    dphi_dx = dphi_dx.reshape(n_batch, n_points)
    return phi, dphi_dx
def derivative_thit_x(x, t, theta, params):
    """Return ``1 / (a*x + b)``, the reciprocal of the velocity at ``x``.

    This is the negative of ``d(hit time)/dx`` in the cell containing ``x``.
    The sign is absorbed by the caller. ``t`` is unused.
    """
    A, r = get_affine(x, theta, params)
    cell = get_cell(x, params)
    velocity = A[r, cell, 0] * x + A[r, cell, 1]
    return 1.0 / velocity
def derivative_psi_x(x, t, theta, params):
    """Spatial derivative ``d psi(x, t) / dx = exp(t*a)`` within x's cell.

    This is valid for ``a == 0`` as well, since ``exp(0) = 1`` matches
    ``psi = x + t*b``. Only the slope ``a`` is needed, so the intercept is
    not read.
    """
    A, r = get_affine(x, theta, params)
    c = get_cell(x, params)
    a = A[r, c, 0]
    return np.exp(t * a)
def derivative_psi_t(x, t, theta, params):
    """Time derivative ``d psi(x, t) / dt = exp(t*a) * (a*x + b)``.

    Uses the affine coefficients of the cell containing ``x``.
    """
    A, r = get_affine(x, theta, params)
    cell = get_cell(x, params)
    slope = A[r, cell, 0]
    offset = A[r, cell, 1]
    return np.exp(t * slope) * (slope * x + offset)