good enough

This commit is contained in:
2025-08-02 16:35:50 +02:00
parent c1e4615da6
commit 9225f31fca

21
main.py
View File

@@ -7,11 +7,11 @@ import numpy as np
import torch
WIN = "window"
SPACING = 150
SPACING = 50
GREEN = (0, 255, 0)
RED = (0, 0, 255)
LIGHT_GREEN = (63, 191, 63)
LIGHT_RED = (63, 63, 191)
LIGHT_GREEN = (31, 167, 31)
LIGHT_RED = (31, 31, 167)
LINE_FILE = "/tmp/line.json"
@@ -118,8 +118,12 @@ def fit_to_line(line, *, canvas):
img = canvas.copy()
img = draw_line(img, line, GREEN)
show = draw_line(img.copy(), init, RED, with_points=True)
cv2.imshow(WIN, show)
cv2.waitKey(1)
fit = torch.tensor(init, dtype=torch.float32, requires_grad=True)
optimizer = torch.optim.Adam([fit], lr=1)
optimizer = torch.optim.Adam([fit], lr=5)
line_pt = torch.tensor(line, dtype=torch.float32)
fit_t = torch.linspace(0, 1, num_points, dtype=torch.float32)
@@ -150,8 +154,8 @@ def fit_to_line(line, *, canvas):
loss = (
0
#
+ 0.8 * point_to_point_error()
+ point_to_line_error() * 4
+ point_to_point_error()
+ point_to_line_error() * 2
+ line_to_line_error()
#
# + spacing_error(fit)
@@ -184,7 +188,7 @@ def main():
img = draw_line(canvas.copy(), line, GREEN)
fit = fit_to_line(np.array(line), canvas=canvas)
if fit is not None:
img = draw_line(img, fit, RED)
img = draw_line(img, fit, RED, with_points=True)
cv2.imshow(WIN, img)
def on_mouse(event, x, y, flags, _):
@@ -201,7 +205,6 @@ def main():
cv2.waitKey(1)
cv2.setMouseCallback(WIN, on_mouse)
if line:
with lock:
optimization()
@@ -214,6 +217,8 @@ def main():
save_line(line)
elif k == ord("l"):
line[:] = try_load_line()
with lock:
optimization()
elif k == ord("d"):
if os.path.isfile(LINE_FILE):
os.unlink(LINE_FILE)