good enough
This commit is contained in:
25
main.py
25
main.py
@@ -7,11 +7,11 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
WIN = "window"
|
WIN = "window"
|
||||||
SPACING = 150
|
SPACING = 50
|
||||||
GREEN = (0, 255, 0)
|
GREEN = (0, 255, 0)
|
||||||
RED = (0, 0, 255)
|
RED = (0, 0, 255)
|
||||||
LIGHT_GREEN = (63, 191, 63)
|
LIGHT_GREEN = (31, 167, 31)
|
||||||
LIGHT_RED = (63, 63, 191)
|
LIGHT_RED = (31, 31, 167)
|
||||||
LINE_FILE = "/tmp/line.json"
|
LINE_FILE = "/tmp/line.json"
|
||||||
|
|
||||||
|
|
||||||
@@ -118,8 +118,12 @@ def fit_to_line(line, *, canvas):
|
|||||||
img = canvas.copy()
|
img = canvas.copy()
|
||||||
img = draw_line(img, line, GREEN)
|
img = draw_line(img, line, GREEN)
|
||||||
|
|
||||||
|
show = draw_line(img.copy(), init, RED, with_points=True)
|
||||||
|
cv2.imshow(WIN, show)
|
||||||
|
cv2.waitKey(1)
|
||||||
|
|
||||||
fit = torch.tensor(init, dtype=torch.float32, requires_grad=True)
|
fit = torch.tensor(init, dtype=torch.float32, requires_grad=True)
|
||||||
optimizer = torch.optim.Adam([fit], lr=1)
|
optimizer = torch.optim.Adam([fit], lr=5)
|
||||||
|
|
||||||
line_pt = torch.tensor(line, dtype=torch.float32)
|
line_pt = torch.tensor(line, dtype=torch.float32)
|
||||||
fit_t = torch.linspace(0, 1, num_points, dtype=torch.float32)
|
fit_t = torch.linspace(0, 1, num_points, dtype=torch.float32)
|
||||||
@@ -150,8 +154,8 @@ def fit_to_line(line, *, canvas):
|
|||||||
loss = (
|
loss = (
|
||||||
0
|
0
|
||||||
#
|
#
|
||||||
+ 0.8 * point_to_point_error()
|
+ point_to_point_error()
|
||||||
+ point_to_line_error() * 4
|
+ point_to_line_error() * 2
|
||||||
+ line_to_line_error()
|
+ line_to_line_error()
|
||||||
#
|
#
|
||||||
# + spacing_error(fit)
|
# + spacing_error(fit)
|
||||||
@@ -184,7 +188,7 @@ def main():
|
|||||||
img = draw_line(canvas.copy(), line, GREEN)
|
img = draw_line(canvas.copy(), line, GREEN)
|
||||||
fit = fit_to_line(np.array(line), canvas=canvas)
|
fit = fit_to_line(np.array(line), canvas=canvas)
|
||||||
if fit is not None:
|
if fit is not None:
|
||||||
img = draw_line(img, fit, RED)
|
img = draw_line(img, fit, RED, with_points=True)
|
||||||
cv2.imshow(WIN, img)
|
cv2.imshow(WIN, img)
|
||||||
|
|
||||||
def on_mouse(event, x, y, flags, _):
|
def on_mouse(event, x, y, flags, _):
|
||||||
@@ -201,9 +205,8 @@ def main():
|
|||||||
cv2.waitKey(1)
|
cv2.waitKey(1)
|
||||||
cv2.setMouseCallback(WIN, on_mouse)
|
cv2.setMouseCallback(WIN, on_mouse)
|
||||||
|
|
||||||
if line:
|
with lock:
|
||||||
with lock:
|
optimization()
|
||||||
optimization()
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if (k := cv2.waitKey(1)) == ord("c"):
|
if (k := cv2.waitKey(1)) == ord("c"):
|
||||||
@@ -214,6 +217,8 @@ def main():
|
|||||||
save_line(line)
|
save_line(line)
|
||||||
elif k == ord("l"):
|
elif k == ord("l"):
|
||||||
line[:] = try_load_line()
|
line[:] = try_load_line()
|
||||||
|
with lock:
|
||||||
|
optimization()
|
||||||
elif k == ord("d"):
|
elif k == ord("d"):
|
||||||
if os.path.isfile(LINE_FILE):
|
if os.path.isfile(LINE_FILE):
|
||||||
os.unlink(LINE_FILE)
|
os.unlink(LINE_FILE)
|
||||||
|
|||||||
Reference in New Issue
Block a user