initial commit
This commit is contained in:
157
main.py
Normal file
157
main.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
WIN = "window"
|
||||
SPACING = 50
|
||||
GREEN = (0, 255, 0)
|
||||
RED = (0, 0, 255)
|
||||
|
||||
|
||||
def area_between_curves(a, b):
|
||||
polygon = torch.cat([b, torch.flip(a, dims=[0])], dim=0)
|
||||
|
||||
# Calculate the area using the shoelace formula
|
||||
x = polygon[:, 0]
|
||||
y = polygon[:, 1]
|
||||
return 0.5 * torch.abs(x @ torch.roll(y, -1) - torch.roll(x, -1) @ y)
|
||||
|
||||
|
||||
def sample_polyline(line, t):
|
||||
deltas = seglengths(line)
|
||||
d = torch.cumsum(deltas, dim=0) / torch.sum(deltas)
|
||||
d = torch.cat([torch.zeros_like(d[0:1]), d])
|
||||
idx = torch.searchsorted(d, t)
|
||||
|
||||
plus = line[idx]
|
||||
minus = line[idx - 1]
|
||||
rem = (t - d[idx - 1]) / (d[idx] - d[idx - 1] + 1e-8)
|
||||
return (1 - rem[..., None]) * minus + rem[..., None] * plus
|
||||
|
||||
|
||||
def per_point_manhattan_distance(a, b, t):
|
||||
points_a = sample_polyline(a, t)
|
||||
points_b = sample_polyline(b, t)
|
||||
return torch.sum(torch.abs(points_a - points_b), dim=-1)
|
||||
|
||||
|
||||
def per_point_euclidean_distance(a, b, t):
|
||||
points_a = sample_polyline(a, t)
|
||||
points_b = sample_polyline(b, t)
|
||||
return torch.norm(points_a - points_b, dim=-1)
|
||||
|
||||
|
||||
def seglengths(line):
|
||||
return torch.norm(torch.diff(line, dim=-2), dim=-1)
|
||||
|
||||
|
||||
def spacing_error(line):
|
||||
# Calculate the distances between consecutive points
|
||||
return torch.mean(torch.abs(SPACING - seglengths(line)))
|
||||
|
||||
|
||||
def fit_to_line(line, *, canvas):
|
||||
total_length = np.sum(np.linalg.norm(np.diff(line, axis=0), axis=1))
|
||||
num_points = np.ceil(total_length / SPACING).astype(int)
|
||||
# Generate evenly spaced points along the line
|
||||
if len(line) < 2:
|
||||
return None
|
||||
|
||||
init = np.linspace(line[0], line[-1], num_points)
|
||||
|
||||
if len(line) < 3:
|
||||
return init
|
||||
|
||||
fit = torch.tensor(init, dtype=torch.float32, requires_grad=True)
|
||||
line_pt = torch.tensor(line, dtype=torch.float32)
|
||||
optimizer = torch.optim.Adam([fit], lr=1)
|
||||
|
||||
while True:
|
||||
optimizer.zero_grad()
|
||||
|
||||
sample_points = torch.rand((num_points // 2), dtype=torch.float32)
|
||||
sample_line = sample_polyline(line_pt, sample_points)
|
||||
sample_fit = sample_polyline(fit, sample_points)
|
||||
|
||||
loss = torch.mean(torch.norm(sample_line - sample_fit, dim=-1)) + spacing_error(
|
||||
fit
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
img = canvas.copy()
|
||||
img = draw_line(img, line_pt.numpy(), GREEN)
|
||||
img = draw_line(img, fit.detach().numpy(), RED, with_points=True)
|
||||
# img = draw_circles(img, sample_line.numpy(), GREEN)
|
||||
# img = draw_circles(img, sample_fit.detach().numpy(), RED)
|
||||
cv2.imshow(WIN, img)
|
||||
if cv2.waitKey(1) == ord("q"):
|
||||
break
|
||||
|
||||
fit = fit.detach()
|
||||
abc = area_between_curves(line_pt, fit)
|
||||
spacing_err = spacing_error(fit)
|
||||
print("Area between curves:", abc.item())
|
||||
print("Spacing error:", spacing_err.item())
|
||||
|
||||
return fit.numpy()
|
||||
|
||||
|
||||
def draw_circles(img, points, color):
|
||||
for point in points:
|
||||
cv2.circle(img, tuple(np.round(point).astype(np.int32)), 5, color, -1)
|
||||
return img
|
||||
|
||||
|
||||
def draw_line(img, line, color, *, with_points=False):
|
||||
if len(line) > 1:
|
||||
cv2.polylines(img, [np.round(line).astype(np.int32)], False, color, 2)
|
||||
if with_points:
|
||||
img = draw_circles(img, line, color)
|
||||
return img
|
||||
|
||||
|
||||
def main():
|
||||
line = []
|
||||
is_calculating = {}
|
||||
|
||||
canvas = np.zeros((500, 500, 3), dtype=np.uint8)
|
||||
|
||||
cv2.namedWindow(WIN)
|
||||
cv2.imshow(WIN, canvas)
|
||||
|
||||
def on_mouse(event, x, y, flags, _):
|
||||
|
||||
def lbuttondown():
|
||||
line.append((x, y))
|
||||
if len(line) == 1:
|
||||
return
|
||||
line_np = np.array(line)
|
||||
img = draw_line(canvas.copy(), line_np, GREEN)
|
||||
if (fit := fit_to_line(line_np, canvas=canvas)) is not None:
|
||||
img = draw_line(img, fit, RED)
|
||||
cv2.imshow(WIN, img)
|
||||
|
||||
if is_calculating.get(event, False):
|
||||
return
|
||||
is_calculating[event] = True
|
||||
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
lbuttondown()
|
||||
|
||||
is_calculating[event] = False
|
||||
|
||||
cv2.setMouseCallback(WIN, on_mouse)
|
||||
|
||||
while True:
|
||||
if (k := cv2.waitKey(1)) == ord("c"):
|
||||
line.clear()
|
||||
cv2.imshow(WIN, canvas)
|
||||
elif k == ord("q"):
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user