import torch
from .interpolation import interpolate
from .transformer import transformer, gradient, gradient_space
from .transformer import get_cell, calc_velocity
from .transformer import interpolate_grid
# %%
def assert_version(min_version=(1, 0)):
    """Check that the installed PyTorch is at least ``min_version``.

    Args:
        min_version: ``(major, minor)`` pair of the oldest supported release.
            Defaults to ``(1, 0)``; extra components beyond the minor version
            are ignored.

    Raises:
        AssertionError: if the installed version is older than ``min_version``
            or its major/minor numbers cannot be parsed. It is raised
            explicitly (not via an ``assert`` statement) so the check still
            runs under ``python -O``, while keeping the original exception type.
    """
    import re

    version_string = str(torch.__version__)
    # Compare integer tuples: the previous float("major.minor") trick is wrong
    # as soon as a minor version has two digits (float("1.10") == 1.1 < 1.9).
    match = re.match(r"(\d+)\.(\d+)", version_string)
    required = tuple(int(part) for part in min_version)[:2]
    if match is None or (int(match.group(1)), int(match.group(2))) < required:
        raise AssertionError(
            "You are using an older installation of pytorch ({}), please "
            "install {}.0 or newer".format(
                version_string, ".".join(str(part) for part in required)
            )
        )
# %%
# Short, backend-neutral aliases for the torch floating-point dtypes
# (16-, 32- and 64-bit respectively).
half, single, double = torch.float16, torch.float32, torch.float64
def to(x, dtype=torch.float32, device=None):
    """Return a detached copy of ``x`` as a torch tensor of ``dtype`` on ``device``.

    Args:
        x: a tensor, or anything ``torch.tensor`` accepts (scalars, nested
            lists, numpy arrays, ...).
        dtype: target torch dtype (default ``torch.float32``).
        device: ``None`` (torch's default device), a ``torch.device``, the
            package alias ``"gpu"`` (mapped to ``"cuda"``), or any device
            string torch understands, such as ``"cpu"``, ``"cuda"`` or
            ``"cuda:1"``.

    Returns:
        A new tensor that shares neither memory nor autograd history with ``x``.
    """
    if isinstance(device, str):
        # "gpu" is this package's alias for CUDA. Every other string is handed
        # to torch so that "cuda" / "cuda:1" are no longer silently downgraded
        # to the CPU, as they were when every non-"gpu" string meant "cpu".
        device = torch.device("cuda") if device == "gpu" else torch.device(device)
    if torch.is_tensor(x):
        # copy=True guarantees a fresh tensor even when dtype/device already
        # match, replacing the former clone() + type() + to() double copy.
        return x.detach().to(device=device, dtype=dtype, copy=True)
    return torch.tensor(x, dtype=dtype, device=device)
# %%
def tonumpy(x):
    """Return tensor ``x`` as a numpy array.

    The tensor is detached first, because ``Tensor.numpy`` refuses tensors
    that require grad. It is then moved to the CPU. For a tensor that is
    already on the CPU, the returned array may share memory with ``x``.
    """
    return x.detach().cpu().numpy()
# %%
def check_device(x, device_name):
    """Report whether tensor ``x`` sits on the device family ``device_name``.

    The name ``"gpu"`` matches CUDA tensors; any other name matches tensors
    that are not on CUDA.
    """
    wants_gpu = device_name == "gpu"
    return wants_gpu == x.is_cuda
# %%
def backend_type():
    """Return the tensor class used by this backend, i.e. ``torch.Tensor``."""
    tensor_class = torch.Tensor
    return tensor_class
# %%
# %%
def identity(d, n_sample=1, epsilon=0, device="cpu", dtype=torch.float32):
    """Return an ``(n_sample, d)`` tensor with every entry equal to ``epsilon``.

    With the default ``epsilon=0`` this is an all-zero tensor.

    Args:
        d: number of columns.
        n_sample: number of rows (default 1).
        epsilon: non-negative value added to every entry (default 0).
        device: a ``torch.device``, the package alias ``"gpu"`` (mapped to
            ``"cuda"``), or any device string torch understands (``"cpu"``,
            ``"cuda"``, ``"cuda:1"``, ...). ``None`` uses torch's default device.
        dtype: dtype of the result (default ``torch.float32``).

    Raises:
        AssertionError: if ``epsilon`` is negative. It is raised explicitly so
            the check survives ``python -O``; the type matches the original
            ``assert``.
    """
    if epsilon < 0:
        raise AssertionError("epsilon needs to be non-negative")
    if isinstance(device, str):
        # Previously any value other than the exact string "cpu" (including a
        # torch.device("cpu") object or "cuda:1") was forced to plain "cuda".
        device = torch.device("cuda") if device == "gpu" else torch.device(device)
    return torch.zeros(n_sample, d, dtype=dtype, device=device) + epsilon
# %%
# %%
def exp(*tensors, **options):
    """Element-wise exponential; forwards every argument to ``torch.exp``."""
    op = torch.exp
    return op(*tensors, **options)
def linspace(*bounds, **options):
    """Evenly spaced values; forwards every argument to ``torch.linspace``."""
    op = torch.linspace
    return op(*bounds, **options)
def meshgrid(*tensors, **options):
    """Coordinate grids; forwards every argument to ``torch.meshgrid``.

    NOTE(review): recent torch releases warn when ``indexing`` is omitted;
    callers may pass ``indexing=`` through ``options`` on torch >= 1.10.
    """
    op = torch.meshgrid
    return op(*tensors, **options)
def matmul(*operands, **options):
    """Matrix product; forwards every argument to ``torch.matmul``."""
    op = torch.matmul
    return op(*operands, **options)
def max(*operands, **options):
    """Maximum; forwards every argument to ``torch.max``.

    NOTE: this module-level name shadows the builtin ``max`` for all code in
    this module, so other functions here cannot use the builtin by that name.
    """
    op = torch.max
    return op(*operands, **options)
def ones(*sizes, **options):
    """Tensor of ones; forwards every argument to ``torch.ones``."""
    op = torch.ones
    return op(*sizes, **options)
def pdist(c):
    """Return the matrix of absolute pairwise differences of the entries of ``c``.

    For a 1-D tensor ``c`` of length n, the result is an ``(n, n)`` tensor with
    ``out[i, j] == abs(c[i] - c[j])``. A 0-D tensor is treated as length 1.
    Higher-rank inputs are flattened first.
    """
    # Broadcasting a column against a row yields the same grid as
    # torch.meshgrid(c, c) with "ij" indexing. Unlike meshgrid, it does not
    # emit the missing-`indexing` UserWarning of recent torch releases.
    column = c.reshape(-1, 1)
    row = c.reshape(1, -1)
    return torch.abs(column - row)