import json import os from threading import Lock import cv2 import numpy as np import torch WIN = "window" SPACING = 50 GREEN = (0, 255, 0) RED = (0, 0, 255) LIGHT_GREEN = (31, 167, 31) LIGHT_RED = (31, 31, 167) LINE_FILE = "/tmp/line.json" def save_line(line): with open(LINE_FILE, "w") as f: json.dump(line, f) def try_load_line(): if not os.path.exists(LINE_FILE): return [] with open(LINE_FILE, "r") as f: return json.load(f) def draw_circles(img, points, size, color): for point in points: cv2.circle(img, tuple(np.round(point).astype(np.int32)), size, color, -1) return img def draw_line(img, line, color, *, with_points=False): if len(line) > 1: cv2.polylines(img, [np.round(line).astype(np.int32)], False, color, 2) if with_points: img = draw_circles(img, line, 5, color) return img def area_between_curves(a, b): polygon = torch.cat([b, torch.flip(a, dims=[0])], dim=0) # Calculate the area using the shoelace formula x = polygon[:, 0] y = polygon[:, 1] return 0.5 * torch.abs(x @ torch.roll(y, -1) - torch.roll(x, -1) @ y) def find_beta(x, a, b): return (x - a) / (b - a + 1e-8) def sample_segments(line, t): deltas = seglengths(line) d = torch.cumsum(deltas, dim=0) / torch.sum(deltas) d = torch.cat([torch.zeros_like(d[0:1]), d]) idx = torch.minimum( torch.searchsorted(d, t), torch.tensor(line.shape[-2] - 1, dtype=torch.long), ) return ( torch.stack([line[idx - 1], line[idx]], dim=-2), torch.stack([d[idx - 1], d[idx]], dim=-1), ) def sample_polyline(line, t): segs, ds = sample_segments(line, t) beta = find_beta(t, ds[..., 0], ds[..., 1])[..., None] return (1 - beta) * segs[..., 0, :] + beta * segs[..., 1, :] def l1_norm(x): return torch.sum(torch.abs(x), dim=-1) def l2_norm_squared(x): return torch.sum(torch.square(x), dim=-1) def cross_2d(a, b): return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0] def point_to_segment_distance(p, s): s0 = s[..., 0, :] s1 = s[..., 1, :] ds = s1 - s0 return torch.abs(cross_2d(p, ds) + cross_2d(s1, s0)) / torch.norm(ds, dim=-1) def seglengths(line): return torch.norm(torch.diff(line, dim=-2), dim=-1) def spacing_error(line): # Calculate the distances between consecutive points return torch.mean(torch.abs(SPACING - seglengths(line))) def fit_to_line(line, *, canvas): total_length = np.sum(np.linalg.norm(np.diff(line, axis=0), axis=1)) num_points = np.round(total_length / SPACING + 1).astype(np.long) # Generate evenly spaced points along the line if len(line) < 2: return None init = np.linspace(line[0], line[-1], num_points) if len(init) < 3: return init img = canvas.copy() img = draw_line(img, line, GREEN) show = draw_line(img.copy(), init, RED, with_points=True) cv2.imshow(WIN, show) cv2.waitKey(1) fit = torch.tensor(init, dtype=torch.float32, requires_grad=True) optimizer = torch.optim.Adam([fit], lr=5) line_pt = torch.tensor(line, dtype=torch.float32) fit_t = torch.linspace(0, 1, num_points, dtype=torch.float32) segs_at_exact, _ = sample_segments(line_pt, fit_t) line_at_exact = sample_polyline(line_pt, fit_t) while True: show = img.copy() optimizer.zero_grad() def point_to_point_error(radius=SPACING): # ruff: noqa: B023 return torch.mean(l2_norm_squared(line_at_exact - fit) / radius**2) def point_to_line_error(radius=SPACING): # ruff: noqa: B023 return torch.mean(point_to_segment_distance(fit, segs_at_exact) / radius) def line_to_line_error(radius=SPACING): sample_points = torch.rand((num_points // 2), dtype=torch.float32) sample_line = sample_polyline(line_pt, sample_points) sample_fit = sample_polyline(fit, sample_points) nonlocal show show = draw_circles(show, sample_line.numpy(), 5, LIGHT_GREEN) show = draw_circles(show, sample_fit.detach().numpy(), 5, LIGHT_RED) return torch.mean(l1_norm(sample_line - sample_fit) / radius) loss = ( 0 # + point_to_point_error() + point_to_line_error() * 2 + line_to_line_error() # # + spacing_error(fit) ) loss.backward() optimizer.step() show = draw_line(show, fit.detach().numpy(), RED, with_points=True) cv2.imshow(WIN, show) if cv2.waitKey(1) == ord("x"): break fit = fit.detach() abc = area_between_curves(line_pt, fit) spacing_err = spacing_error(fit) print("Area between curves:", abc.item()) print("Spacing error:", spacing_err.item()) return fit.numpy() def main(): line = try_load_line() lock = Lock() def optimization(): if len(line) < 2: return img = draw_line(canvas.copy(), line, GREEN) fit = fit_to_line(np.array(line), canvas=canvas) if fit is not None: img = draw_line(img, fit, RED, with_points=True) cv2.imshow(WIN, img) def on_mouse(event, x, y, flags, _): if lock.locked(): return if event == cv2.EVENT_LBUTTONDOWN: with lock: line.append((x, y)) optimization() canvas = np.zeros((500, 500, 3), dtype=np.uint8) cv2.namedWindow(WIN) cv2.imshow(WIN, canvas) cv2.waitKey(1) cv2.setMouseCallback(WIN, on_mouse) with lock: optimization() while True: if (k := cv2.waitKey(1)) == ord("c"): line.clear() cv2.imshow(WIN, canvas) elif k == ord("s"): if line: save_line(line) elif k == ord("l"): line[:] = try_load_line() with lock: optimization() elif k == ord("d"): if os.path.isfile(LINE_FILE): os.unlink(LINE_FILE) elif k == ord("q"): break if __name__ == "__main__": main()