import torch, torch.nn.functional as F
def __init__(self, H, W, M=4, fx=..., fy=..., cx=..., cy=...,
             sigma_rgb=10, sigma_d=0.02, dt=1/30):
    """Particle-filter scene-flow tracker over a dense depth image.

    Particle state is a (H, W, M, 6) tensor holding [x, y, z, vx, vy, vz]
    for M particles per pixel.

    Args:
        H, W: image height / width in pixels.
        fx, fy, cx, cy: pinhole intrinsics (focal lengths, principal point).
            NOTE(review): the `...` (Ellipsis) defaults look like stubs —
            these should probably be required floats; confirm with callers.
        sigma_rgb: colour-likelihood std-dev, in 0-255 intensity units.
        sigma_d: depth-likelihood scale in metres (used as a Laplace scale).
        dt: frame interval in seconds (default 1/30 s).
    """
    # Store EVERYTHING the other methods read: backproject/step use
    # self.H/self.W/self.M, step uses self.dt and checks self.state.
    self.H, self.W, self.M = H, W, M
    self.fx, self.fy, self.cx, self.cy = fx, fy, cx, cy
    self.sigma_rgb, self.sigma_d = sigma_rgb, sigma_d
    self.dt = dt
    # Lazily initialised from the first depth frame (see step()).
    self.state = None
# ---------- helper -------------------------------------------------
def backproject(self, depth):
    """Back-project a depth map to 3-D points via the pinhole model.

    Args:
        depth: (H, W) tensor of metric depths (the z coordinate).

    Returns:
        (H, W, 3) tensor of [x, y, z] camera-frame coordinates.
        NOTE(review): the original docstring said "world coords" — that is
        only true if the camera frame IS the world frame; confirm.
    """
    H, W = self.H, self.W  # fix: originals used bare H/W (NameError)
    # Pixel-coordinate grids: u varies along columns, v along rows.
    u = torch.arange(W, device=depth.device).view(1, -1).expand(H, -1)
    v = torch.arange(H, device=depth.device).view(-1, 1).expand(-1, W)
    # Standard pinhole inversion: x = (u - cx) * z / fx, y = (v - cy) * z / fy.
    x = (u - self.cx) * depth / self.fx
    y = (v - self.cy) * depth / self.fy
    # fix: original stacked an undefined `z`; depth itself is the z coord.
    return torch.stack((x, y, depth), -1)  # (H, W, 3)
# ---------- API ----------------------------------------------------
def init(self, depth0):
    """Initialise the particle cloud from the first depth frame.

    (fix: the `def` header for this method was missing in the source —
    step() calls `self.init(depth_t)`, so it must exist under this name.)

    Args:
        depth0: (H, W) metric depth map of the first frame.

    Side effect: sets self.state to a (H, W, M, 6) tensor where every
    particle of a pixel starts at that pixel's back-projected point with
    zero velocity.
    """
    xyz0 = self.backproject(depth0)  # (H, W, 3)
    self.state = torch.zeros(self.H, self.W, self.M, 6, device=xyz0.device)
    # Broadcast the single point to all M particles of each pixel.
    self.state[..., :3] = xyz0.unsqueeze(2)  # position
    self.state[..., 3:] = 0                  # zero initial velocity
def step(self, rgb_t, depth_t, rgb_t1, depth_t1):
    """Advance the particle filter one frame and return dense 3-D flow.

    Args:
        rgb_t, rgb_t1: uint8 (H, W, 3) images in 0-255, frames t and t+1.
        depth_t, depth_t1: float (H, W) metric depth maps.

    Returns:
        (H, W, 3) mean particle velocity per pixel (the 3-D flow field).

    Fixes vs. original: `st` was never bound to self.state; the depth
    grid_sample was missing its grid argument and reshape; the weight
    `w = p_rgb * p_d` line was missing; resampling was missing the
    torch.gather; `st[..., 'z slice']` was a placeholder for `st[..., 2]`.
    """
    if self.state is None:
        self.init(depth_t)
    H, W, _, M = *rgb_t.shape, self.M
    st = self.state
    # ----- PREDICT: constant-velocity motion model + process noise -------
    noise_v = 0.01 * torch.randn_like(st[..., 3:])
    noise_p = 0.001 * torch.randn_like(st[..., :3])
    st[..., 3:] += noise_v                           # v <- v + eps
    st[..., :3] += st[..., 3:] * self.dt + noise_p   # x <- x + v*dt + eps
    # ----- PROJECT particles into frame t+1 ------------------------------
    x, y, z = st[..., :3].unbind(-1)                 # each (H, W, M)
    z = z.clamp(min=1e-6)                            # guard divide-by-zero
    u = self.fx * x / z + self.cx
    v = self.fy * y / z + self.cy
    # Normalised device coords for grid_sample, range [-1, 1].
    grid = torch.stack((2 * u / W - 1, 2 * v / H - 1), -1).view(1, -1, 1, 2)
    rgb_pred = F.grid_sample(
        rgb_t1.float().permute(2, 0, 1).unsqueeze(0) / 255.,
        grid, mode='bilinear', padding_mode='border', align_corners=False,
    ).view(3, H, W, M).permute(1, 2, 3, 0)           # (H, W, M, 3)
    depth_pred = F.grid_sample(
        depth_t1.unsqueeze(0).unsqueeze(0),
        grid, mode='bilinear', padding_mode='border', align_corners=False,
    ).view(H, W, M)
    # ----- LIKELIHOOD: Gaussian in colour, Laplace in depth --------------
    col_err = ((rgb_pred * 255 - rgb_t.unsqueeze(2)) ** 2).mean(-1)
    p_rgb = torch.exp(-col_err / (2 * self.sigma_rgb ** 2))  # (H, W, M)
    d_err = torch.abs(depth_pred - z)
    p_d = torch.exp(-d_err / self.sigma_d)                   # (H, W, M)
    w = p_rgb * p_d + 1e-12          # epsilon avoids all-zero rows -> NaN
    w /= w.sum(-1, keepdim=True)     # normalise per pixel
    # ----- RESAMPLE (systematic per pixel) + ray jitter on z -------------
    idx = torch.multinomial(w.view(-1, M), M, replacement=True).view(H, W, M)
    st = torch.gather(st, 2, idx.unsqueeze(-1).expand(-1, -1, -1, 6))
    st[..., 2] += torch.randn_like(st[..., 2]) * 0.002  # jitter along ray
    self.state = st                   # keep for next call
    flow = st[..., 3:]                # velocity field
    return flow.mean(2)               # (H, W, 3) mean over particles