Implement get_landmarks
This commit is contained in:
@@ -11,6 +11,7 @@ requires-python = "~3.13"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"numpy (>=2.3.3,<3.0.0)",
|
"numpy (>=2.3.3,<3.0.0)",
|
||||||
"face-alignment (>=1.4.1,<2.0.0)",
|
"face-alignment (>=1.4.1,<2.0.0)",
|
||||||
|
"scikit-image (>=0.25.2,<0.26.0)",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -24,3 +25,4 @@ torch = [
|
|||||||
{url = "https://download.pytorch.org/whl/cu129/torch-2.8.0%2Bcu129-cp313-cp313-win_amd64.whl", platform = "win32"},
|
{url = "https://download.pytorch.org/whl/cu129/torch-2.8.0%2Bcu129-cp313-cp313-win_amd64.whl", platform = "win32"},
|
||||||
]
|
]
|
||||||
pytest = "^8.4.2"
|
pytest = "^8.4.2"
|
||||||
|
matplotlib = "^3.10.6"
|
||||||
|
|||||||
21
src/falign/landmarks.py
Normal file
21
src/falign/landmarks.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from skimage.io import imread
|
||||||
|
import torch
|
||||||
|
import face_alignment
|
||||||
|
|
||||||
|
|
||||||
|
def _get_face_landmarks(image: NDArray, face_aligner: face_alignment.FaceAlignment):
|
||||||
|
landmarks = face_aligner.get_landmarks(image)
|
||||||
|
assert landmarks is not None, "No face detected"
|
||||||
|
|
||||||
|
return landmarks[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_landmarks(path_image: Path, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cuda")):
|
||||||
|
face_aligner = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, dtype=dtype, device=device.type)
|
||||||
|
image = imread(path_image)
|
||||||
|
|
||||||
|
return _get_face_landmarks(image, face_aligner)
|
||||||
0
test/test_falign/__init__.py
Normal file
0
test/test_falign/__init__.py
Normal file
22
test/test_falign/test_landmarks.py
Normal file
22
test/test_falign/test_landmarks.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from matplotlib import pyplot
|
||||||
|
|
||||||
|
from falign.landmarks import get_landmarks
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_landmarks():
|
||||||
|
dir_gallery = Path(__file__).parent.parent.parent / "gallery"
|
||||||
|
path_image = dir_gallery / "original.jpg"
|
||||||
|
landmarks = get_landmarks(path_image)
|
||||||
|
|
||||||
|
assert landmarks.shape == (68, 2)
|
||||||
|
|
||||||
|
fig, ax = pyplot.subplots()
|
||||||
|
ax.imshow(pyplot.imread(path_image))
|
||||||
|
ax.scatter(landmarks[:, 0], landmarks[:, 1], s=1)
|
||||||
|
ax.axis("off")
|
||||||
|
ax.margins(0)
|
||||||
|
ax.set_aspect("equal")
|
||||||
|
fig.tight_layout(pad=0)
|
||||||
|
fig.savefig(dir_gallery / "test_get_landmarks.jpg")
|
||||||
Reference in New Issue
Block a user