import cv2
import numpy as np
import torch
from tqdm import tqdm

# Name of the single OpenCV window used for all drawing.
WIN = "window"
# Target distance (in pixels) between consecutive points of the fitted polyline.
SPACING = 50
# Colours in OpenCV's BGR channel order: the user's stroke and the fitted polyline.
GREEN = (0, 255, 0)
RED = (0, 0, 255)


def area_between_curves(a, b):
    """Return the area enclosed between two polylines.

    The polygon is built by walking ``b`` forward and then ``a`` backward, and
    its area is computed with the shoelace formula. Where the curves cross,
    the resulting lobes have opposite orientation and partially cancel, so
    the value is only an approximation in that case.

    Args:
        a: (N, 2) tensor of vertices.
        b: (M, 2) tensor of vertices (same dtype as ``a``).

    Returns:
        Scalar tensor holding the non-negative enclosed area.
    """
    polygon = torch.cat([b, torch.flip(a, dims=[0])], dim=0)
    x = polygon[:, 0]
    y = polygon[:, 1]
    # Shoelace formula: 0.5 * |sum(x_i * y_{i+1} - x_{i+1} * y_i)|.
    return 0.5 * torch.abs(x @ torch.roll(y, -1) - torch.roll(x, -1) @ y)


def sample_polyline(line, t):
    """Sample points along a polyline at normalized arc-length positions.

    Args:
        line: (N, 2) tensor of vertices with N >= 2 and non-zero total length.
        t: float tensor of positions in [0, 1]; 0 maps to the first vertex and
            1 to the last one.

    Returns:
        Tensor of shape ``t.shape + (2,)`` with the linearly interpolated
        points. Differentiable with respect to both ``line`` and ``t``.
    """
    deltas = seglengths(line)
    # Normalized cumulative arc length at every vertex: d[0] == 0, d[-1] ~= 1.
    d = torch.cumsum(deltas, dim=0) / torch.sum(deltas)
    d = torch.cat([d.new_zeros(1), d])
    # Index of the segment's end vertex. Clamping keeps t == 0 on the first
    # segment (instead of wrapping to index -1) and stops float32 round-off
    # near t == 1 (d[-1] slightly below 1) from indexing past the last vertex.
    idx = torch.searchsorted(d, t).clamp(1, len(d) - 1)
    plus = line[idx]
    minus = line[idx - 1]
    # Fractional position inside the segment; the epsilon guards duplicate
    # vertices (zero-length segments).
    rem = (t - d[idx - 1]) / (d[idx] - d[idx - 1] + 1e-8)
    return (1 - rem[..., None]) * minus + rem[..., None] * plus


def per_point_manhattan_distance(a, b, t):
    """Return the L1 distance between ``a`` and ``b`` sampled at the same ``t``."""
    points_a = sample_polyline(a, t)
    points_b = sample_polyline(b, t)
    return torch.sum(torch.abs(points_a - points_b), dim=-1)


def per_point_euclidean_distance(a, b, t):
    """Return the L2 distance between ``a`` and ``b`` sampled at the same ``t``."""
    points_a = sample_polyline(a, t)
    points_b = sample_polyline(b, t)
    return torch.norm(points_a - points_b, dim=-1)


def seglengths(line):
    """Return the Euclidean length of each segment of a (..., N, 2) polyline."""
    return torch.norm(torch.diff(line, dim=-2), dim=-1)


def spacing_error(line):
    """Return the mean absolute deviation of segment lengths from SPACING."""
    return torch.mean(torch.abs(SPACING - seglengths(line)))


def fit_to_line(line, *, canvas):
    """Fit an evenly spaced polyline to a hand-drawn stroke.

    Runs Adam on the fitted vertices, minimizing the mean distance to the
    stroke at random arc-length positions plus ``spacing_error``. Each step
    is drawn on a copy of ``canvas`` in window ``WIN``. The loop blocks until
    the user presses ``q`` in that window. The area between the curves and
    the final spacing error are printed when it stops.

    Args:
        line: (N, 2) array of stroke vertices in pixel coordinates.
        canvas: image drawn under the stroke and fit on every step; not
            modified.

    Returns:
        None if the stroke has fewer than 2 vertices. Otherwise an (M, 2)
        array with M >= 2 of fitted vertices: for 2-vertex or zero-length
        strokes this is the evenly spaced straight segment between the
        endpoints, and in all other cases the optimized polyline.
    """
    line = np.asarray(line)
    if len(line) < 2:
        return None

    total_length = float(np.sum(np.linalg.norm(np.diff(line, axis=0), axis=1)))
    # At least two points are required: a single-point fit has no segments,
    # which breaks sample_polyline and yields zero random samples below.
    num_points = max(2, int(np.ceil(total_length / SPACING)))
    init = np.linspace(line[0], line[-1], num_points)
    if len(line) < 3 or total_length == 0:
        # A straight segment is already exact; a zero-length stroke has no
        # arc-length parametrization to optimize against.
        return init

    line_pt = torch.tensor(line, dtype=torch.float32)
    if np.allclose(line[0], line[-1]):
        # Closed stroke: the straight-line init collapses to a single point
        # (zero total length -> NaN in sample_polyline), so start from an
        # even arc-length resampling of the stroke instead.
        init = sample_polyline(line_pt, torch.linspace(0, 1, num_points)).numpy()

    fit = torch.tensor(init, dtype=torch.float32, requires_grad=True)
    optimizer = torch.optim.Adam([fit], lr=1)
    num_samples = num_points // 2  # >= 1 because num_points >= 2

    while True:
        optimizer.zero_grad()
        # Compare both curves at the same random normalized arc-length positions.
        sample_points = torch.rand(num_samples, dtype=torch.float32)
        sample_line = sample_polyline(line_pt, sample_points)
        sample_fit = sample_polyline(fit, sample_points)
        loss = torch.mean(torch.norm(sample_line - sample_fit, dim=-1)) + spacing_error(fit)
        loss.backward()
        optimizer.step()

        img = draw_line(canvas.copy(), line_pt.numpy(), GREEN)
        img = draw_line(img, fit.detach().numpy(), RED, with_points=True)
        cv2.imshow(WIN, img)
        # waitKey also pumps GUI events (including mouse callbacks).
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break

    fit = fit.detach()
    abc = area_between_curves(line_pt, fit)
    spacing_err = spacing_error(fit)
    print("Area between curves:", abc.item())
    print("Spacing error:", spacing_err.item())
    return fit.numpy()


def draw_circles(img, points, color):
    """Draw a filled radius-5 circle at every point of ``points`` on ``img``.

    Draws in place and returns ``img`` for chaining.
    """
    for point in np.round(np.asarray(points)).astype(int):
        # Plain Python ints: some OpenCV builds reject NumPy integer scalars
        # when parsing the ``center`` argument.
        cv2.circle(img, (int(point[0]), int(point[1])), 5, color, -1)
    return img


def draw_line(img, line, color, *, with_points=False):
    """Draw ``line`` as an open 2-px polyline on ``img`` (in place).

    Nothing is drawn for fewer than two vertices. With ``with_points=True``
    every vertex is also marked with a circle. Returns ``img``.
    """
    if len(line) > 1:
        cv2.polylines(img, [np.round(line).astype(np.int32)], False, color, 2)
        if with_points:
            img = draw_circles(img, line, color)
    return img


def main():
    """Run the interactive demo.

    Left-click adds a vertex to the stroke and refits it; ``c`` clears the
    stroke; ``q`` quits (while a fit is running, the first ``q`` only stops
    that fit).
    """
    line = []
    # Per-event-type re-entrancy flags (see on_mouse).
    is_calculating = {}
    canvas = np.zeros((500, 500, 3), dtype=np.uint8)
    cv2.namedWindow(WIN)
    cv2.imshow(WIN, canvas)

    def lbuttondown(x, y):
        """Append a vertex, then draw the stroke and its fit."""
        line.append((x, y))
        if len(line) == 1:
            return
        line_np = np.array(line)
        img = draw_line(canvas.copy(), line_np, GREEN)
        if (fit := fit_to_line(line_np, canvas=canvas)) is not None:
            img = draw_line(img, fit, RED)
        cv2.imshow(WIN, img)

    def on_mouse(event, x, y, flags, _param):
        # fit_to_line calls cv2.waitKey, which dispatches GUI events, so this
        # callback can re-enter while a fit is running; drop nested events of
        # the same type instead of starting a second fit.
        if is_calculating.get(event, False):
            return
        is_calculating[event] = True
        try:
            if event == cv2.EVENT_LBUTTONDOWN:
                lbuttondown(x, y)
        finally:
            # Always release the flag, otherwise one failed fit would block
            # every later click.
            is_calculating[event] = False

    cv2.setMouseCallback(WIN, on_mouse)
    while True:
        key = cv2.waitKey(1) & 0xFF
        if key == ord("c"):
            line.clear()
            cv2.imshow(WIN, canvas)
        elif key == ord("q"):
            break
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()