good enough

This commit is contained in:
2025-08-02 16:35:50 +02:00
parent c1e4615da6
commit 9225f31fca

25
main.py
View File

@@ -7,11 +7,11 @@ import numpy as np
import torch import torch
WIN = "window" WIN = "window"
SPACING = 150 SPACING = 50
GREEN = (0, 255, 0) GREEN = (0, 255, 0)
RED = (0, 0, 255) RED = (0, 0, 255)
LIGHT_GREEN = (63, 191, 63) LIGHT_GREEN = (31, 167, 31)
LIGHT_RED = (63, 63, 191) LIGHT_RED = (31, 31, 167)
LINE_FILE = "/tmp/line.json" LINE_FILE = "/tmp/line.json"
@@ -118,8 +118,12 @@ def fit_to_line(line, *, canvas):
img = canvas.copy() img = canvas.copy()
img = draw_line(img, line, GREEN) img = draw_line(img, line, GREEN)
show = draw_line(img.copy(), init, RED, with_points=True)
cv2.imshow(WIN, show)
cv2.waitKey(1)
fit = torch.tensor(init, dtype=torch.float32, requires_grad=True) fit = torch.tensor(init, dtype=torch.float32, requires_grad=True)
optimizer = torch.optim.Adam([fit], lr=1) optimizer = torch.optim.Adam([fit], lr=5)
line_pt = torch.tensor(line, dtype=torch.float32) line_pt = torch.tensor(line, dtype=torch.float32)
fit_t = torch.linspace(0, 1, num_points, dtype=torch.float32) fit_t = torch.linspace(0, 1, num_points, dtype=torch.float32)
@@ -150,8 +154,8 @@ def fit_to_line(line, *, canvas):
loss = ( loss = (
0 0
# #
+ 0.8 * point_to_point_error() + point_to_point_error()
+ point_to_line_error() * 4 + point_to_line_error() * 2
+ line_to_line_error() + line_to_line_error()
# #
# + spacing_error(fit) # + spacing_error(fit)
@@ -184,7 +188,7 @@ def main():
img = draw_line(canvas.copy(), line, GREEN) img = draw_line(canvas.copy(), line, GREEN)
fit = fit_to_line(np.array(line), canvas=canvas) fit = fit_to_line(np.array(line), canvas=canvas)
if fit is not None: if fit is not None:
img = draw_line(img, fit, RED) img = draw_line(img, fit, RED, with_points=True)
cv2.imshow(WIN, img) cv2.imshow(WIN, img)
def on_mouse(event, x, y, flags, _): def on_mouse(event, x, y, flags, _):
@@ -201,9 +205,8 @@ def main():
cv2.waitKey(1) cv2.waitKey(1)
cv2.setMouseCallback(WIN, on_mouse) cv2.setMouseCallback(WIN, on_mouse)
if line: with lock:
with lock: optimization()
optimization()
while True: while True:
if (k := cv2.waitKey(1)) == ord("c"): if (k := cv2.waitKey(1)) == ord("c"):
@@ -214,6 +217,8 @@ def main():
save_line(line) save_line(line)
elif k == ord("l"): elif k == ord("l"):
line[:] = try_load_line() line[:] = try_load_line()
with lock:
optimization()
elif k == ord("d"): elif k == ord("d"):
if os.path.isfile(LINE_FILE): if os.path.isfile(LINE_FILE):
os.unlink(LINE_FILE) os.unlink(LINE_FILE)