diff --git a/main.py b/main.py index cae83d8..651136c 100644 --- a/main.py +++ b/main.py @@ -7,11 +7,11 @@ import numpy as np import torch WIN = "window" -SPACING = 150 +SPACING = 50 GREEN = (0, 255, 0) RED = (0, 0, 255) -LIGHT_GREEN = (63, 191, 63) -LIGHT_RED = (63, 63, 191) +LIGHT_GREEN = (31, 167, 31) +LIGHT_RED = (31, 31, 167) LINE_FILE = "/tmp/line.json" @@ -118,8 +118,12 @@ def fit_to_line(line, *, canvas): img = canvas.copy() img = draw_line(img, line, GREEN) + show = draw_line(img.copy(), init, RED, with_points=True) + cv2.imshow(WIN, show) + cv2.waitKey(1) + fit = torch.tensor(init, dtype=torch.float32, requires_grad=True) - optimizer = torch.optim.Adam([fit], lr=1) + optimizer = torch.optim.Adam([fit], lr=5) line_pt = torch.tensor(line, dtype=torch.float32) fit_t = torch.linspace(0, 1, num_points, dtype=torch.float32) @@ -150,8 +154,8 @@ def fit_to_line(line, *, canvas): loss = ( 0 # - + 0.8 * point_to_point_error() - + point_to_line_error() * 4 + + point_to_point_error() + + point_to_line_error() * 2 + line_to_line_error() # # + spacing_error(fit) @@ -184,7 +188,7 @@ def main(): img = draw_line(canvas.copy(), line, GREEN) fit = fit_to_line(np.array(line), canvas=canvas) if fit is not None: - img = draw_line(img, fit, RED) + img = draw_line(img, fit, RED, with_points=True) cv2.imshow(WIN, img) def on_mouse(event, x, y, flags, _): @@ -201,9 +205,8 @@ def main(): cv2.waitKey(1) cv2.setMouseCallback(WIN, on_mouse) - if line: - with lock: - optimization() + with lock: + optimization() while True: if (k := cv2.waitKey(1)) == ord("c"): @@ -214,6 +217,8 @@ def main(): save_line(line) elif k == ord("l"): line[:] = try_load_line() + with lock: + optimization() elif k == ord("d"): if os.path.isfile(LINE_FILE): os.unlink(LINE_FILE)