# %% SETUP
import torch
# %% COMPARE EQUAL TO ZERO
eps = torch.finfo(torch.float32).eps
[docs]def cmpf(x, y):
return torch.abs(x - y) < eps
[docs]def cmpf0(x):
return torch.abs(x) < eps
# %% BATCH EFFECT
[docs]def batch_effect(x, theta):
if x.ndim == 1:
n_batch = theta.shape[0]
n_points = x.shape[-1]
x = torch.broadcast_to(x, (n_batch, n_points)) # .flatten()
return x.flatten()
# %% FUNCTIONS
[docs]def get_affine(x, theta, params):
if params.precomputed:
return params.A, params.r
else:
n_batch = theta.shape[0]
n_points = x.shape[-1]
repeat = int(n_points / n_batch)
# r = np.broadcast_to(np.arange(n_batch), [n_points, n_batch]).T
# NOTE: here we suppose batch effect has been already executed
r = torch.arange(n_batch).repeat_interleave(repeat).long().to(x.device)
A = params.B.mm(theta.T).T.reshape(n_batch, -1, 2).to(x.device)
return A, r
[docs]def precompute_affine(x, theta, params):
params = params.copy()
params.precomputed = False
params.A, params.r = get_affine(x, theta, params)
params.precomputed = True
return params
[docs]def right_boundary(c, params):
xmin, xmax, nc = params.xmin, params.xmax, params.nc
return xmin + (c + 1) * (xmax - xmin) / nc + eps
[docs]def left_boundary(c, params):
xmin, xmax, nc = params.xmin, params.xmax, params.nc
return xmin + c * (xmax - xmin) / nc - eps
[docs]def get_cell(x, params):
xmin, xmax, nc = params.xmin, params.xmax, params.nc
c = torch.floor((x - xmin) / (xmax - xmin) * nc)
c = torch.clamp(c, 0, nc - 1).long()
return c
[docs]def get_velocity(x, theta, params):
A, r = get_affine(x, theta, params)
c = get_cell(x, params)
a = A[r, c, 0]
b = A[r, c, 1]
return a * x + b
[docs]def calc_velocity(grid, theta, params):
grid = batch_effect(grid, theta)
v = get_velocity(grid, theta, params)
return v.reshape(theta.shape[0], -1)
[docs]def get_psi(x, t, theta, params):
A, r = get_affine(x, theta, params)
c = get_cell(x, params)
a = A[r, c, 0]
b = A[r, c, 1]
cond = cmpf0(a)
x1 = x + t * b
eta = torch.exp(t * a)
x2 = eta * x + (b / a) * (eta - 1.0)
# x2 = torch.exp(t * a) * (x + (b / a)) - (b / a)
psi = torch.where(cond, x1, x2)
return psi
[docs]def get_hit_time(x, theta, params):
thit = torch.empty_like(x)
valid = torch.full_like(x, True, dtype=bool)
c = get_cell(x, params)
A, r = get_affine(x, theta, params)
a = A[r, c, 0]
b = A[r, c, 1]
v = a * x + b
cond1 = cmpf0(v)
cc = c + torch.sign(v)
cond2 = torch.logical_or(cc < 0, cc >= params.nc)
xc = torch.where(v > 0, right_boundary(c, params), left_boundary(c, params))
vc = a * xc + b
cond3 = cmpf0(vc)
cond4 = torch.sign(v) != torch.sign(vc)
cond5 = torch.logical_or(xc == params.xmin, xc == params.xmax)
cond = cond1 | cond2 | cond3 | cond4 | cond5
thit[~cond] = torch.where(
cmpf0(a[~cond]), (xc[~cond] - x[~cond]) / b[~cond], torch.log(vc[~cond] / v[~cond]) / a[~cond],
)
thit[cond] = float("inf")
return thit, xc, cc
[docs]def get_phi_numeric(x, t, theta, params):
nSteps2 = params.nSteps2
yn = x
deltaT = t / nSteps2
for j in range(nSteps2):
midpoint = yn + deltaT / 2 * get_velocity(yn, theta, params)
yn = yn + deltaT * get_velocity(midpoint, theta, params)
return yn
# %% INTEGRATION
[docs]def integrate_numeric(x, theta, params, time=1.0):
# setup
x = batch_effect(x, theta)
n_batch = theta.shape[0]
t = time
params = precompute_affine(x, theta, params)
# computation
xPrev = x
nSteps1 = params.nSteps1
deltaT = t / nSteps1
for j in range(nSteps1):
c = get_cell(xPrev, params)
xTemp = get_psi(xPrev, deltaT, theta, params)
cTemp = get_cell(xTemp, params)
xNum = get_phi_numeric(xPrev, deltaT, theta, params)
xPrev = torch.where(c == cTemp, xTemp, xNum)
return xPrev.reshape((n_batch, -1))
# %% DERIVATIVE
[docs]def derivative_numeric(x, theta, params, time=1.0, h=1e-3):
# setup
n_points = x.shape[-1]
n_batch = theta.shape[0]
d = theta.shape[1]
# computation
der = torch.empty((n_batch, n_points, d), device=x.device)
phi_1 = integrate_numeric(x, theta, params, time)
for k in range(d):
theta2 = theta.clone()
theta2[:, k] += h
phi_2 = integrate_numeric(x, theta2, params, time)
der[:, :, k] = (phi_2 - phi_1) / h
return phi_1, der
[docs]def derivative_psi_theta(x, t, theta, k, params):
A, r = get_affine(x, theta, params)
A = A.double() # NOTE: double precision is necessary
c = get_cell(x, params)
a = A[r, c, 0]
b = A[r, c, 1]
ak = params.B[2 * c, k]
bk = params.B[2 * c + 1, k]
cond = cmpf0(a)
d1 = t * (x * ak + bk)
d2 = ak * t * torch.exp(t * a) * (x + b / a) + (torch.exp(t * a) - 1) * (bk * a - ak * b) / a ** 2
d1 = d1.double()
d2 = d2.double()
dpsi_dtheta = torch.where(cond, d1, d2)
return dpsi_dtheta
[docs]def derivative_phi_time(x, t, theta, k, params):
A, r = get_affine(x, theta, params)
c = get_cell(x, params)
a = A[r, c, 0]
b = A[r, c, 1]
cond = cmpf0(a)
d1 = b
d2 = torch.exp(t * a) * (a * x + b)
dpsi_dtime = torch.where(cond, d1, d2)
return dpsi_dtime
[docs]def derivative_thit_theta(x, theta, k, params):
A, r = get_affine(x, theta, params)
A = A.double() # NOTE: double precision is necessary
c = get_cell(x, params)
a = A[r, c, 0]
b = A[r, c, 1]
ak = params.B[2 * c, k]
bk = params.B[2 * c + 1, k]
v = get_velocity(x, theta, params)
xc = torch.where(v >= 0, right_boundary(c, params), left_boundary(c, params))
cond = cmpf0(a)
d1 = (x - xc) * bk / b ** 2
d2 = -ak * torch.log((a * xc + b) / (a * x + b)) / a ** 2
d3 = (x - xc) * (bk * a - ak * b) / (a * (a * x + b) * (a * xc + b))
dthit_dtheta = torch.where(cond, d1, d2 + d3)
return dthit_dtheta
# %% TRANSFORMATION
[docs]def derivative_numeric_trace(phi_1, x, theta, params, time=1.0, h=1e-3):
# setup
n_points = x.shape[-1]
n_batch = theta.shape[0]
d = theta.shape[1]
# computation
der = torch.empty((n_batch, n_points, d), device=x.device)
# phi_1 = integrate_numeric(x, theta, params)
for k in range(d):
theta[:, k] += h
phi_2 = integrate_numeric(x, theta, params, time)
theta[:, k] -= h
der[:, :, k] = (phi_2 - phi_1) / h
return der
# %% GRADIENT SPACE
[docs]def derivative_space_numeric(x, theta, params, time=1.0, h=1e-3):
# setup
n_points = x.shape[-1]
n_batch = theta.shape[0]
d = theta.shape[1]
# phi_1 = integrate_numeric(x, theta, params, time)
# phi_2 = integrate_numeric(x+h, theta, params, time)
# der = (phi_2 - phi_1) / h
# return phi_1, der
# computation
xe = torch.cat([x, x+h])
phi = integrate_numeric(xe, theta, params, time)
phi_1, phi_2 = torch.split(phi, n_points, dim=1)
der = (phi_2 - phi_1) / h
return phi_1, der
[docs]def derivative_thit_x(x, t, theta, params):
A, r = get_affine(x, theta, params)
c = get_cell(x, params)
a = A[r, c, 0]
b = A[r, c, 1]
return 1.0 / (a*x + b)
[docs]def derivative_psi_x(x, t, theta, params):
A, r = get_affine(x, theta, params)
c = get_cell(x, params)
a = A[r, c, 0]
b = A[r, c, 1]
return torch.exp(t * a)
[docs]def derivative_psi_t(x, t, theta, params):
A, r = get_affine(x, theta, params)
c = get_cell(x, params)
a = A[r, c, 0]
b = A[r, c, 1]
return torch.exp(t * a) * (a * x + b)