Implement landmark plot function
This commit is contained in:
25
src/falign/plot.py
Normal file
25
src/falign/plot.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from pathlib import Path
|
||||
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from skimage.io import imsave
|
||||
|
||||
|
||||
def imsave_with_landmarks(
|
||||
path: Path,
|
||||
image: NDArray, landmarks: NDArray,
|
||||
size: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Save image with landmarks.
|
||||
Args:
|
||||
path: The path to save the image.
|
||||
image: The image array.
|
||||
landmarks: The landmarks array of shape (N, 2).
|
||||
"""
|
||||
|
||||
for landmark in landmarks:
|
||||
x, y = int(landmark[0]), int(landmark[1])
|
||||
image[(y-size):(y+size+1), (x-size):(x+size+1)] = 255
|
||||
|
||||
imsave(path, image)
|
||||
@@ -1,8 +1,9 @@
|
||||
from pathlib import Path
|
||||
|
||||
from matplotlib import pyplot
|
||||
from skimage.io import imread
|
||||
|
||||
from falign.landmarks import get_landmarks
|
||||
from falign.plot import imsave_with_landmarks
|
||||
|
||||
|
||||
def test_get_landmarks():
|
||||
@@ -12,11 +13,8 @@ def test_get_landmarks():
|
||||
|
||||
assert landmarks.shape == (68, 2)
|
||||
|
||||
fig, ax = pyplot.subplots()
|
||||
ax.imshow(pyplot.imread(path_image))
|
||||
ax.scatter(landmarks[:, 0], landmarks[:, 1], s=1)
|
||||
ax.axis("off")
|
||||
ax.margins(0)
|
||||
ax.set_aspect("equal")
|
||||
fig.tight_layout(pad=0)
|
||||
fig.savefig(dir_gallery / "test_get_landmarks.jpg")
|
||||
imsave_with_landmarks(
|
||||
dir_gallery / "test_get_landmarks.jpg",
|
||||
imread(path_image), landmarks,
|
||||
size=5
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user