# polyfit/main.py
# Snapshot: 2025-08-02 15:41:51 +02:00 (226 lines, 5.9 KiB, Python)

import json
import os
from threading import Lock
import cv2
import numpy as np
import torch
# Title of the single OpenCV window everything is drawn into.
WIN = "window"
# Target distance between neighbouring fitted points, in canvas pixels.
SPACING = 150
# Colour triples in OpenCV's BGR channel order: saturated ones for the drawn
# line (green) and its fit (red), muted ones for the random per-iteration
# samples.
GREEN, RED = (0, 255, 0), (0, 0, 255)
LIGHT_GREEN, LIGHT_RED = (63, 191, 63), (63, 63, 191)
# Where the hand-drawn polyline is persisted between sessions.
LINE_FILE = "/tmp/line.json"
def save_line(line, path=None):
    """Write the polyline *line* to disk as JSON.

    Args:
        line: Sequence of ``(x, y)`` points; tuples are stored as JSON arrays.
        path: Destination file. Defaults to ``LINE_FILE`` when omitted.
    """
    if path is None:
        path = LINE_FILE
    with open(path, "w", encoding="utf-8") as f:
        json.dump(line, f)
def try_load_line(path=None):
    """Load a previously saved polyline.

    Args:
        path: File to read. Defaults to ``LINE_FILE`` when omitted.

    Returns:
        The decoded JSON content (normally the list of ``[x, y]`` pairs
        written by ``save_line``), or ``[]`` when the file does not exist.
    """
    if path is None:
        path = LINE_FILE
    # EAFP: no window for the file to vanish between an exists() check and
    # the open() call.
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return []
def draw_circles(img, points, size, color):
    """Stamp a filled disc of radius *size* at every point onto *img*.

    Points may be fractional; each one is rounded to integer pixel
    coordinates first. *img* is drawn on in place and also returned so calls
    can be chained.
    """
    for pt in points:
        center = np.round(pt).astype(np.int32)
        cv2.circle(img, tuple(center), radius=size, color=color, thickness=-1)
    return img
def draw_line(img, line, color, *, with_points=False):
    """Draw *line* as an open, 2-px-thick polyline on *img*.

    Nothing is stroked when *line* has fewer than two vertices. With
    *with_points* set, every vertex also gets a radius-5 filled dot in the
    same colour. Returns the drawn-on image.
    """
    if len(line) >= 2:
        vertices = np.round(line).astype(np.int32)
        cv2.polylines(img, [vertices], isClosed=False, color=color, thickness=2)
    return draw_circles(img, line, 5, color) if with_points else img
def area_between_curves(a, b):
    """Area of the polygon enclosed between two polylines.

    *a* and *b* are (N, 2) / (M, 2) point tensors (columns x, y). The ring is
    closed by walking *b* forwards and then *a* backwards, and its area is
    taken with the shoelace formula. If the curves cross each other the ring
    self-intersects, lobes of opposite orientation cancel, and the result
    then underestimates the true area between them.
    """
    ring = torch.cat([b, a.flip(0)], dim=0)
    xs, ys = ring[:, 0], ring[:, 1]
    # 2 * signed area = sum_i (x_i * y_{i+1} - x_{i+1} * y_i), indices wrap.
    twice_signed = xs @ ys.roll(-1) - xs.roll(-1) @ ys
    return twice_signed.abs() * 0.5
def find_beta(x, a, b):
    """Relative position of *x* inside the interval [a, b].

    Gives 0 at ``a`` and (almost exactly) 1 at ``b``. The 1e-8 added to the
    interval width keeps a zero-width interval from dividing by zero.
    """
    width = b - a + 1e-8
    return (x - a) / width
def sample_segments(line, t):
    """For each arc-length fraction in *t*, find the polyline segment holding it.

    Args:
        line: (N, D) tensor of polyline vertices (2-D points in this file),
            N >= 2.
        t: 1-D tensor of fractions of the total polyline length, expected to
            lie in [0, 1].

    Returns:
        A pair ``(segs, ds)``. ``segs[..., 0, :]`` and ``segs[..., 1, :]``
        are the start and end vertices of the segment containing each ``t``.
        ``ds[..., 0]`` and ``ds[..., 1]`` are the cumulative length fractions
        at those two vertices. Repeated vertices at the very start can still
        yield a zero-length segment for ``t == 0``.
    """
    deltas = seglengths(line)
    # Normalised cumulative arc length at every vertex: d[0] == 0, d[-1] ~= 1.
    d = torch.cumsum(deltas, dim=0) / torch.sum(deltas)
    d = torch.cat([torch.zeros_like(d[0:1]), d])
    # Segment k spans vertices k-1..k, so the valid indices are 1..N-1.
    # - Lower bound: for t == 0, searchsorted returns 0, and idx - 1 == -1
    #   would otherwise wrap around to the *last* vertex.
    # - Upper bound: absorbs t values that land past d[-1] because of
    #   rounding in the normalisation.
    idx = torch.clamp(torch.searchsorted(d, t), min=1, max=line.shape[-2] - 1)
    return (
        torch.stack([line[idx - 1], line[idx]], dim=-2),
        torch.stack([d[idx - 1], d[idx]], dim=-1),
    )
def sample_polyline(line, t):
    """Points located at arc-length fractions *t* along the polyline *line*.

    Each result linearly interpolates between the two vertices of the
    segment that ``sample_segments`` selects for that fraction.
    """
    segments, bounds = sample_segments(line, t)
    start = segments[..., 0, :]
    end = segments[..., 1, :]
    weight = find_beta(t, bounds[..., 0], bounds[..., 1]).unsqueeze(-1)
    return (1 - weight) * start + weight * end
def l1_norm(x):
    """Sum of absolute values over the last axis (Manhattan length)."""
    return x.abs().sum(dim=-1)
def l2_norm_squared(x):
    """Squared Euclidean length over the last axis (no square root taken)."""
    return x.square().sum(dim=-1)
def cross_2d(a, b):
    """Scalar (z-component) cross product of 2-D vectors in the last axis.

    Leading axes broadcast; the sign tells which way *b* turns from *a*.
    """
    ax, ay = a[..., 0], a[..., 1]
    bx, by = b[..., 0], b[..., 1]
    return ax * by - ay * bx
def point_to_segment_distance(p, s):
    """Perpendicular distance from points *p* to the lines through segments *s*.

    ``s[..., 0, :]`` and ``s[..., 1, :]`` are the segment endpoints. Despite
    the name, each segment is treated as an infinite line: the projection is
    not clamped to the endpoints. A zero-length segment divides by zero and
    yields inf/nan.
    """
    start, end = s[..., 0, :], s[..., 1, :]
    direction = end - start
    # cross(p - start, direction) == cross(p, direction) + cross(end, start):
    # twice the triangle area, and dividing by the base length gives height.
    twice_area = torch.abs(cross_2d(p, direction) + cross_2d(end, start))
    return twice_area / torch.norm(direction, dim=-1)
def seglengths(line):
    """Euclidean length of every segment between consecutive vertices (axis -2)."""
    steps = line[..., 1:, :] - line[..., :-1, :]
    return steps.norm(dim=-1)
def spacing_error(line):
    """Mean absolute deviation of the segment lengths from ``SPACING``.

    Zero exactly when every pair of consecutive points is ``SPACING`` apart.
    """
    deviation = SPACING - seglengths(line)
    return deviation.abs().mean()
def fit_to_line(line, *, canvas):
    """Interactively fit evenly spaced points to a hand-drawn polyline.

    The number of fit points is chosen so that neighbours end up roughly
    ``SPACING`` apart along *line*. The points start on the straight chord
    from the first to the last vertex. They are then moved with Adam to
    minimise a weighted sum of the three local loss helpers below. Every
    iteration is drawn into the ``WIN`` window on top of *canvas*. The loop
    runs until "x" is pressed in that window; afterwards the area between the
    curves and the spacing error are printed.

    Args:
        line: (N, 2) array of polyline vertices.
        canvas: Background image; it is copied, never drawn on directly.

    Returns:
        One of three values:
        - ``None`` if *line* has fewer than two vertices.
        - The straight-chord initialisation if it has fewer than three
          points (nothing to optimise).
        - Otherwise, the optimised float32 point array.
    """
    # Check first: taking the length of an empty 1-D array would raise.
    if len(line) < 2:
        return None
    total_length = np.sum(np.linalg.norm(np.diff(line, axis=0), axis=1))
    # N points cover N - 1 gaps, hence the +1. A plain int also avoids
    # np.long, which is missing in NumPy 1.24-1.26.
    num_points = int(np.round(total_length / SPACING + 1))
    init = np.linspace(line[0], line[-1], num_points)
    if len(init) < 3:
        return init

    img = draw_line(canvas.copy(), line, GREEN)
    fit = torch.tensor(init, dtype=torch.float32, requires_grad=True)
    optimizer = torch.optim.Adam([fit], lr=1)
    line_pt = torch.tensor(line, dtype=torch.float32)
    # Fixed targets: where fit point i would sit if the points were spread
    # uniformly by arc length along the drawn line.
    fit_t = torch.linspace(0, 1, num_points, dtype=torch.float32)
    segs_at_exact, _ = sample_segments(line_pt, fit_t)
    line_at_exact = sample_polyline(line_pt, fit_t)
    num_samples = num_points // 2

    # Defined once, outside the loop: the optimizer updates `fit` in place,
    # so these helpers always see its current values.
    def point_to_point_error(radius=SPACING):
        """Mean squared distance of each fit point to its target, over radius**2."""
        return torch.mean(l2_norm_squared(line_at_exact - fit) / radius**2)

    def point_to_line_error(radius=SPACING):
        """Mean distance of each fit point to the line through its target's segment."""
        return torch.mean(point_to_segment_distance(fit, segs_at_exact) / radius)

    def line_to_line_error(radius=SPACING):
        """Mean L1 gap between both polylines at shared random arc-length fractions.

        Also returns the sampled points so the caller can draw them.
        """
        sample_points = torch.rand(num_samples, dtype=torch.float32)
        sample_line = sample_polyline(line_pt, sample_points)
        sample_fit = sample_polyline(fit, sample_points)
        error = torch.mean(l1_norm(sample_line - sample_fit) / radius)
        return error, sample_line, sample_fit

    while True:
        show = img.copy()
        optimizer.zero_grad()
        p2p = point_to_point_error()
        p2l = point_to_line_error()
        l2l, sample_line, sample_fit = line_to_line_error()
        show = draw_circles(show, sample_line.numpy(), 5, LIGHT_GREEN)
        show = draw_circles(show, sample_fit.detach().numpy(), 5, LIGHT_RED)
        # spacing_error(fit) can be added here as an extra regulariser; it is
        # currently disabled.
        loss = 0.8 * p2p + p2l * 4 + l2l
        loss.backward()
        optimizer.step()
        show = draw_line(show, fit.detach().numpy(), RED, with_points=True)
        cv2.imshow(WIN, show)
        if cv2.waitKey(1) == ord("x"):
            break

    fit = fit.detach()
    print("Area between curves:", area_between_curves(line_pt, fit).item())
    print("Spacing error:", spacing_error(fit).item())
    return fit.numpy()
def main():
    """Run the interactive polyline editor.

    Left-clicking appends a point and re-runs the fit (press "x" in the
    window to end a fit). Keys:
    - "c" clears the line.
    - "s" saves it to ``LINE_FILE``.
    - "l" reloads it from there and refits.
    - "d" deletes the saved file.
    - "q" quits.
    """
    line = try_load_line()
    # fit_to_line pumps the GUI with cv2.waitKey(), so on_mouse can be
    # invoked while a fit is still running; the lock turns those nested
    # clicks into no-ops.
    lock = Lock()
    canvas = np.zeros((500, 500, 3), dtype=np.uint8)

    def optimization():
        """Fit the current line and show it with its fit; no-op below two points."""
        if len(line) < 2:
            return
        img = draw_line(canvas.copy(), line, GREEN)
        fit = fit_to_line(np.array(line), canvas=canvas)
        if fit is not None:
            img = draw_line(img, fit, RED)
        cv2.imshow(WIN, img)

    def on_mouse(event, x, y, flags, _):
        """Append the clicked point and refit, unless a fit is in progress."""
        if lock.locked():
            return
        if event == cv2.EVENT_LBUTTONDOWN:
            with lock:
                line.append((x, y))
                optimization()

    cv2.namedWindow(WIN)
    cv2.imshow(WIN, canvas)
    cv2.waitKey(1)
    cv2.setMouseCallback(WIN, on_mouse)
    if line:
        with lock:
            optimization()
    while True:
        k = cv2.waitKey(1)
        if k == ord("c"):
            line.clear()
            cv2.imshow(WIN, canvas)
        elif k == ord("s"):
            if line:
                save_line(line)
        elif k == ord("l"):
            line[:] = try_load_line()
            # Redraw right away so the reloaded line (or an empty canvas) is
            # visible without waiting for the next click.
            cv2.imshow(WIN, canvas)
            with lock:
                optimization()
        elif k == ord("d"):
            if os.path.isfile(LINE_FILE):
                os.unlink(LINE_FILE)
        elif k == ord("q"):
            break
    cv2.destroyAllWindows()
# Start the interactive editor only when run as a script, not on import.
if __name__ == "__main__":
    main()