From c1e4615da6cc7ba0ea7d106006a93e38ef62ad5b Mon Sep 17 00:00:00 2001 From: Pavel Lutskov Date: Sat, 2 Aug 2025 15:41:51 +0200 Subject: [PATCH] make things nicer --- main.py | 206 +++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 137 insertions(+), 69 deletions(-) diff --git a/main.py b/main.py index 91640c9..cae83d8 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,44 @@ +import json +import os +from threading import Lock + import cv2 import numpy as np import torch -from tqdm import tqdm WIN = "window" -SPACING = 50 +SPACING = 150 GREEN = (0, 255, 0) RED = (0, 0, 255) +LIGHT_GREEN = (63, 191, 63) +LIGHT_RED = (63, 63, 191) +LINE_FILE = "/tmp/line.json" + + +def save_line(line): + with open(LINE_FILE, "w") as f: + json.dump(line, f) + + +def try_load_line(): + if not os.path.exists(LINE_FILE): + return [] + with open(LINE_FILE, "r") as f: + return json.load(f) + + +def draw_circles(img, points, size, color): + for point in points: + cv2.circle(img, tuple(np.round(point).astype(np.int32)), size, color, -1) + return img + + +def draw_line(img, line, color, *, with_points=False): + if len(line) > 1: + cv2.polylines(img, [np.round(line).astype(np.int32)], False, color, 2) + if with_points: + img = draw_circles(img, line, 5, color) + return img def area_between_curves(a, b): @@ -18,28 +50,47 @@ def area_between_curves(a, b): return 0.5 * torch.abs(x @ torch.roll(y, -1) - torch.roll(x, -1) @ y) -def sample_polyline(line, t): +def find_beta(x, a, b): + return (x - a) / (b - a + 1e-8) + + +def sample_segments(line, t): deltas = seglengths(line) d = torch.cumsum(deltas, dim=0) / torch.sum(deltas) d = torch.cat([torch.zeros_like(d[0:1]), d]) - idx = torch.searchsorted(d, t) - - plus = line[idx] - minus = line[idx - 1] - rem = (t - d[idx - 1]) / (d[idx] - d[idx - 1] + 1e-8) - return (1 - rem[..., None]) * minus + rem[..., None] * plus + idx = torch.minimum( + torch.searchsorted(d, t), + torch.tensor(line.shape[-2] - 1, dtype=torch.long), + ) + return ( + torch.stack([line[idx - 1], line[idx]], dim=-2), + torch.stack([d[idx - 1], d[idx]], dim=-1), + ) -def per_point_manhattan_distance(a, b, t): - points_a = sample_polyline(a, t) - points_b = sample_polyline(b, t) - return torch.sum(torch.abs(points_a - points_b), dim=-1) +def sample_polyline(line, t): + segs, ds = sample_segments(line, t) + beta = find_beta(t, ds[..., 0], ds[..., 1])[..., None] + return (1 - beta) * segs[..., 0, :] + beta * segs[..., 1, :] -def per_point_euclidean_distance(a, b, t): - points_a = sample_polyline(a, t) - points_b = sample_polyline(b, t) - return torch.norm(points_a - points_b, dim=-1) +def l1_norm(x): + return torch.sum(torch.abs(x), dim=-1) + + +def l2_norm_squared(x): + return torch.sum(torch.square(x), dim=-1) + + +def cross_2d(a, b): + return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0] + + +def point_to_segment_distance(p, s): + s0 = s[..., 0, :] + s1 = s[..., 1, :] + ds = s1 - s0 + return torch.abs(cross_2d(p, ds) + cross_2d(s1, s0)) / torch.norm(ds, dim=-1) def seglengths(line): @@ -53,41 +104,65 @@ def spacing_error(line): def fit_to_line(line, *, canvas): total_length = np.sum(np.linalg.norm(np.diff(line, axis=0), axis=1)) - num_points = np.ceil(total_length / SPACING).astype(int) + num_points = np.round(total_length / SPACING + 1).astype(np.long) + # Generate evenly spaced points along the line if len(line) < 2: return None init = np.linspace(line[0], line[-1], num_points) - if len(line) < 3: + if len(init) < 3: return init + img = canvas.copy() + img = draw_line(img, line, GREEN) + fit = torch.tensor(init, dtype=torch.float32, requires_grad=True) - line_pt = torch.tensor(line, dtype=torch.float32) optimizer = torch.optim.Adam([fit], lr=1) + line_pt = torch.tensor(line, dtype=torch.float32) + fit_t = torch.linspace(0, 1, num_points, dtype=torch.float32) + segs_at_exact, _ = sample_segments(line_pt, fit_t) + line_at_exact = sample_polyline(line_pt, fit_t) + while True: + show = img.copy() optimizer.zero_grad() - sample_points = torch.rand((num_points // 2), dtype=torch.float32) - sample_line = sample_polyline(line_pt, sample_points) - sample_fit = sample_polyline(fit, sample_points) + def point_to_point_error(radius=SPACING): + # ruff: noqa: B023 + return torch.mean(l2_norm_squared(line_at_exact - fit) / radius**2) - loss = torch.mean(torch.norm(sample_line - sample_fit, dim=-1)) + spacing_error( - fit + def point_to_line_error(radius=SPACING): + # ruff: noqa: B023 + return torch.mean(point_to_segment_distance(fit, segs_at_exact) / radius) + + def line_to_line_error(radius=SPACING): + sample_points = torch.rand((num_points // 2), dtype=torch.float32) + sample_line = sample_polyline(line_pt, sample_points) + sample_fit = sample_polyline(fit, sample_points) + nonlocal show + show = draw_circles(show, sample_line.numpy(), 5, LIGHT_GREEN) + show = draw_circles(show, sample_fit.detach().numpy(), 5, LIGHT_RED) + return torch.mean(l1_norm(sample_line - sample_fit) / radius) + + loss = ( + 0 + # + + 0.8 * point_to_point_error() + + point_to_line_error() * 4 + + line_to_line_error() + # + # + spacing_error(fit) ) loss.backward() optimizer.step() - img = canvas.copy() - img = draw_line(img, line_pt.numpy(), GREEN) - img = draw_line(img, fit.detach().numpy(), RED, with_points=True) - # img = draw_circles(img, sample_line.numpy(), GREEN) - # img = draw_circles(img, sample_fit.detach().numpy(), RED) - cv2.imshow(WIN, img) - if cv2.waitKey(1) == ord("q"): + show = draw_line(show, fit.detach().numpy(), RED, with_points=True) + cv2.imshow(WIN, show) + if cv2.waitKey(1) == ord("x"): break fit = fit.detach() @@ -99,56 +174,49 @@ def fit_to_line(line, *, canvas): return fit.numpy() -def draw_circles(img, points, color): - for point in points: - cv2.circle(img, tuple(np.round(point).astype(np.int32)), 5, color, -1) - return img - - -def draw_line(img, line, color, *, with_points=False): - if len(line) > 1: - cv2.polylines(img, [np.round(line).astype(np.int32)], False, color, 2) - if with_points: - img = draw_circles(img, line, color) - return img - - def main(): - line = [] - is_calculating = {} + line = try_load_line() + lock = Lock() - canvas = np.zeros((500, 500, 3), dtype=np.uint8) - - cv2.namedWindow(WIN) - cv2.imshow(WIN, canvas) - - def on_mouse(event, x, y, flags, _): - - def lbuttondown(): - line.append((x, y)) - if len(line) == 1: - return - line_np = np.array(line) - img = draw_line(canvas.copy(), line_np, GREEN) - if (fit := fit_to_line(line_np, canvas=canvas)) is not None: - img = draw_line(img, fit, RED) + def optimization(): + if len(line) < 2: + return + img = draw_line(canvas.copy(), line, GREEN) + fit = fit_to_line(np.array(line), canvas=canvas) + if fit is not None: + img = draw_line(img, fit, RED) cv2.imshow(WIN, img) - if is_calculating.get(event, False): + def on_mouse(event, x, y, flags, _): + if lock.locked(): return - is_calculating[event] = True - if event == cv2.EVENT_LBUTTONDOWN: - lbuttondown() - - is_calculating[event] = False + with lock: + line.append((x, y)) + optimization() + canvas = np.zeros((500, 500, 3), dtype=np.uint8) + cv2.namedWindow(WIN) + cv2.imshow(WIN, canvas) + cv2.waitKey(1) cv2.setMouseCallback(WIN, on_mouse) + if line: + with lock: + optimization() + while True: if (k := cv2.waitKey(1)) == ord("c"): line.clear() cv2.imshow(WIN, canvas) + elif k == ord("s"): + if line: + save_line(line) + elif k == ord("l"): + line[:] = try_load_line() + elif k == ord("d"): + if os.path.isfile(LINE_FILE): + os.unlink(LINE_FILE) elif k == ord("q"): break