make things nicer
This commit is contained in:
206
main.py
206
main.py
@@ -1,12 +1,44 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
WIN = "window"
|
WIN = "window"
|
||||||
SPACING = 50
|
SPACING = 150
|
||||||
GREEN = (0, 255, 0)
|
GREEN = (0, 255, 0)
|
||||||
RED = (0, 0, 255)
|
RED = (0, 0, 255)
|
||||||
|
LIGHT_GREEN = (63, 191, 63)
|
||||||
|
LIGHT_RED = (63, 63, 191)
|
||||||
|
LINE_FILE = "/tmp/line.json"
|
||||||
|
|
||||||
|
|
||||||
|
def save_line(line):
|
||||||
|
with open(LINE_FILE, "w") as f:
|
||||||
|
json.dump(line, f)
|
||||||
|
|
||||||
|
|
||||||
|
def try_load_line():
|
||||||
|
if not os.path.exists(LINE_FILE):
|
||||||
|
return []
|
||||||
|
with open(LINE_FILE, "r") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def draw_circles(img, points, size, color):
|
||||||
|
for point in points:
|
||||||
|
cv2.circle(img, tuple(np.round(point).astype(np.int32)), size, color, -1)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def draw_line(img, line, color, *, with_points=False):
|
||||||
|
if len(line) > 1:
|
||||||
|
cv2.polylines(img, [np.round(line).astype(np.int32)], False, color, 2)
|
||||||
|
if with_points:
|
||||||
|
img = draw_circles(img, line, 5, color)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
def area_between_curves(a, b):
|
def area_between_curves(a, b):
|
||||||
@@ -18,28 +50,47 @@ def area_between_curves(a, b):
|
|||||||
return 0.5 * torch.abs(x @ torch.roll(y, -1) - torch.roll(x, -1) @ y)
|
return 0.5 * torch.abs(x @ torch.roll(y, -1) - torch.roll(x, -1) @ y)
|
||||||
|
|
||||||
|
|
||||||
def sample_polyline(line, t):
|
def find_beta(x, a, b):
|
||||||
|
return (x - a) / (b - a + 1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_segments(line, t):
|
||||||
deltas = seglengths(line)
|
deltas = seglengths(line)
|
||||||
d = torch.cumsum(deltas, dim=0) / torch.sum(deltas)
|
d = torch.cumsum(deltas, dim=0) / torch.sum(deltas)
|
||||||
d = torch.cat([torch.zeros_like(d[0:1]), d])
|
d = torch.cat([torch.zeros_like(d[0:1]), d])
|
||||||
idx = torch.searchsorted(d, t)
|
idx = torch.minimum(
|
||||||
|
torch.searchsorted(d, t),
|
||||||
plus = line[idx]
|
torch.tensor(line.shape[-2] - 1, dtype=torch.long),
|
||||||
minus = line[idx - 1]
|
)
|
||||||
rem = (t - d[idx - 1]) / (d[idx] - d[idx - 1] + 1e-8)
|
return (
|
||||||
return (1 - rem[..., None]) * minus + rem[..., None] * plus
|
torch.stack([line[idx - 1], line[idx]], dim=-2),
|
||||||
|
torch.stack([d[idx - 1], d[idx]], dim=-1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def per_point_manhattan_distance(a, b, t):
|
def sample_polyline(line, t):
|
||||||
points_a = sample_polyline(a, t)
|
segs, ds = sample_segments(line, t)
|
||||||
points_b = sample_polyline(b, t)
|
beta = find_beta(t, ds[..., 0], ds[..., 1])[..., None]
|
||||||
return torch.sum(torch.abs(points_a - points_b), dim=-1)
|
return (1 - beta) * segs[..., 0, :] + beta * segs[..., 1, :]
|
||||||
|
|
||||||
|
|
||||||
def per_point_euclidean_distance(a, b, t):
|
def l1_norm(x):
|
||||||
points_a = sample_polyline(a, t)
|
return torch.sum(torch.abs(x), dim=-1)
|
||||||
points_b = sample_polyline(b, t)
|
|
||||||
return torch.norm(points_a - points_b, dim=-1)
|
|
||||||
|
def l2_norm_squared(x):
|
||||||
|
return torch.sum(torch.square(x), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def cross_2d(a, b):
|
||||||
|
return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]
|
||||||
|
|
||||||
|
|
||||||
|
def point_to_segment_distance(p, s):
|
||||||
|
s0 = s[..., 0, :]
|
||||||
|
s1 = s[..., 1, :]
|
||||||
|
ds = s1 - s0
|
||||||
|
return torch.abs(cross_2d(p, ds) + cross_2d(s1, s0)) / torch.norm(ds, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
def seglengths(line):
|
def seglengths(line):
|
||||||
@@ -53,41 +104,65 @@ def spacing_error(line):
|
|||||||
|
|
||||||
def fit_to_line(line, *, canvas):
|
def fit_to_line(line, *, canvas):
|
||||||
total_length = np.sum(np.linalg.norm(np.diff(line, axis=0), axis=1))
|
total_length = np.sum(np.linalg.norm(np.diff(line, axis=0), axis=1))
|
||||||
num_points = np.ceil(total_length / SPACING).astype(int)
|
num_points = np.round(total_length / SPACING + 1).astype(np.long)
|
||||||
|
|
||||||
# Generate evenly spaced points along the line
|
# Generate evenly spaced points along the line
|
||||||
if len(line) < 2:
|
if len(line) < 2:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
init = np.linspace(line[0], line[-1], num_points)
|
init = np.linspace(line[0], line[-1], num_points)
|
||||||
|
|
||||||
if len(line) < 3:
|
if len(init) < 3:
|
||||||
return init
|
return init
|
||||||
|
|
||||||
|
img = canvas.copy()
|
||||||
|
img = draw_line(img, line, GREEN)
|
||||||
|
|
||||||
fit = torch.tensor(init, dtype=torch.float32, requires_grad=True)
|
fit = torch.tensor(init, dtype=torch.float32, requires_grad=True)
|
||||||
line_pt = torch.tensor(line, dtype=torch.float32)
|
|
||||||
optimizer = torch.optim.Adam([fit], lr=1)
|
optimizer = torch.optim.Adam([fit], lr=1)
|
||||||
|
|
||||||
|
line_pt = torch.tensor(line, dtype=torch.float32)
|
||||||
|
fit_t = torch.linspace(0, 1, num_points, dtype=torch.float32)
|
||||||
|
segs_at_exact, _ = sample_segments(line_pt, fit_t)
|
||||||
|
line_at_exact = sample_polyline(line_pt, fit_t)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
show = img.copy()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
sample_points = torch.rand((num_points // 2), dtype=torch.float32)
|
def point_to_point_error(radius=SPACING):
|
||||||
sample_line = sample_polyline(line_pt, sample_points)
|
# ruff: noqa: B023
|
||||||
sample_fit = sample_polyline(fit, sample_points)
|
return torch.mean(l2_norm_squared(line_at_exact - fit) / radius**2)
|
||||||
|
|
||||||
loss = torch.mean(torch.norm(sample_line - sample_fit, dim=-1)) + spacing_error(
|
def point_to_line_error(radius=SPACING):
|
||||||
fit
|
# ruff: noqa: B023
|
||||||
|
return torch.mean(point_to_segment_distance(fit, segs_at_exact) / radius)
|
||||||
|
|
||||||
|
def line_to_line_error(radius=SPACING):
|
||||||
|
sample_points = torch.rand((num_points // 2), dtype=torch.float32)
|
||||||
|
sample_line = sample_polyline(line_pt, sample_points)
|
||||||
|
sample_fit = sample_polyline(fit, sample_points)
|
||||||
|
nonlocal show
|
||||||
|
show = draw_circles(show, sample_line.numpy(), 5, LIGHT_GREEN)
|
||||||
|
show = draw_circles(show, sample_fit.detach().numpy(), 5, LIGHT_RED)
|
||||||
|
return torch.mean(l1_norm(sample_line - sample_fit) / radius)
|
||||||
|
|
||||||
|
loss = (
|
||||||
|
0
|
||||||
|
#
|
||||||
|
+ 0.8 * point_to_point_error()
|
||||||
|
+ point_to_line_error() * 4
|
||||||
|
+ line_to_line_error()
|
||||||
|
#
|
||||||
|
# + spacing_error(fit)
|
||||||
)
|
)
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
img = canvas.copy()
|
show = draw_line(show, fit.detach().numpy(), RED, with_points=True)
|
||||||
img = draw_line(img, line_pt.numpy(), GREEN)
|
cv2.imshow(WIN, show)
|
||||||
img = draw_line(img, fit.detach().numpy(), RED, with_points=True)
|
if cv2.waitKey(1) == ord("x"):
|
||||||
# img = draw_circles(img, sample_line.numpy(), GREEN)
|
|
||||||
# img = draw_circles(img, sample_fit.detach().numpy(), RED)
|
|
||||||
cv2.imshow(WIN, img)
|
|
||||||
if cv2.waitKey(1) == ord("q"):
|
|
||||||
break
|
break
|
||||||
|
|
||||||
fit = fit.detach()
|
fit = fit.detach()
|
||||||
@@ -99,56 +174,49 @@ def fit_to_line(line, *, canvas):
|
|||||||
return fit.numpy()
|
return fit.numpy()
|
||||||
|
|
||||||
|
|
||||||
def draw_circles(img, points, color):
|
|
||||||
for point in points:
|
|
||||||
cv2.circle(img, tuple(np.round(point).astype(np.int32)), 5, color, -1)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def draw_line(img, line, color, *, with_points=False):
|
|
||||||
if len(line) > 1:
|
|
||||||
cv2.polylines(img, [np.round(line).astype(np.int32)], False, color, 2)
|
|
||||||
if with_points:
|
|
||||||
img = draw_circles(img, line, color)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
line = []
|
line = try_load_line()
|
||||||
is_calculating = {}
|
lock = Lock()
|
||||||
|
|
||||||
canvas = np.zeros((500, 500, 3), dtype=np.uint8)
|
def optimization():
|
||||||
|
if len(line) < 2:
|
||||||
cv2.namedWindow(WIN)
|
return
|
||||||
cv2.imshow(WIN, canvas)
|
img = draw_line(canvas.copy(), line, GREEN)
|
||||||
|
fit = fit_to_line(np.array(line), canvas=canvas)
|
||||||
def on_mouse(event, x, y, flags, _):
|
if fit is not None:
|
||||||
|
img = draw_line(img, fit, RED)
|
||||||
def lbuttondown():
|
|
||||||
line.append((x, y))
|
|
||||||
if len(line) == 1:
|
|
||||||
return
|
|
||||||
line_np = np.array(line)
|
|
||||||
img = draw_line(canvas.copy(), line_np, GREEN)
|
|
||||||
if (fit := fit_to_line(line_np, canvas=canvas)) is not None:
|
|
||||||
img = draw_line(img, fit, RED)
|
|
||||||
cv2.imshow(WIN, img)
|
cv2.imshow(WIN, img)
|
||||||
|
|
||||||
if is_calculating.get(event, False):
|
def on_mouse(event, x, y, flags, _):
|
||||||
|
if lock.locked():
|
||||||
return
|
return
|
||||||
is_calculating[event] = True
|
|
||||||
|
|
||||||
if event == cv2.EVENT_LBUTTONDOWN:
|
if event == cv2.EVENT_LBUTTONDOWN:
|
||||||
lbuttondown()
|
with lock:
|
||||||
|
line.append((x, y))
|
||||||
is_calculating[event] = False
|
optimization()
|
||||||
|
|
||||||
|
canvas = np.zeros((500, 500, 3), dtype=np.uint8)
|
||||||
|
cv2.namedWindow(WIN)
|
||||||
|
cv2.imshow(WIN, canvas)
|
||||||
|
cv2.waitKey(1)
|
||||||
cv2.setMouseCallback(WIN, on_mouse)
|
cv2.setMouseCallback(WIN, on_mouse)
|
||||||
|
|
||||||
|
if line:
|
||||||
|
with lock:
|
||||||
|
optimization()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if (k := cv2.waitKey(1)) == ord("c"):
|
if (k := cv2.waitKey(1)) == ord("c"):
|
||||||
line.clear()
|
line.clear()
|
||||||
cv2.imshow(WIN, canvas)
|
cv2.imshow(WIN, canvas)
|
||||||
|
elif k == ord("s"):
|
||||||
|
if line:
|
||||||
|
save_line(line)
|
||||||
|
elif k == ord("l"):
|
||||||
|
line[:] = try_load_line()
|
||||||
|
elif k == ord("d"):
|
||||||
|
if os.path.isfile(LINE_FILE):
|
||||||
|
os.unlink(LINE_FILE)
|
||||||
elif k == ord("q"):
|
elif k == ord("q"):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user