diff --git a/pyproject.toml b/pyproject.toml index 8cf9b14..a1016e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ requires-python = "~3.13" dependencies = [ "numpy (>=2.3.3,<3.0.0)", "face-alignment (>=1.4.1,<2.0.0)", + "scikit-image (>=0.25.2,<0.26.0)", ] @@ -24,3 +25,4 @@ torch = [ {url = "https://download.pytorch.org/whl/cu129/torch-2.8.0%2Bcu129-cp313-cp313-win_amd64.whl", platform = "win32"}, ] pytest = "^8.4.2" +matplotlib = "^3.10.6" diff --git a/src/falign/landmarks.py b/src/falign/landmarks.py new file mode 100644 index 0000000..84bdd1b --- /dev/null +++ b/src/falign/landmarks.py @@ -0,0 +1,21 @@ +from numpy.typing import NDArray + +from pathlib import Path + +from skimage.io import imread +import torch +import face_alignment + + +def _get_face_landmarks(image: NDArray, face_aligner: face_alignment.FaceAlignment): + landmarks = face_aligner.get_landmarks(image) + assert landmarks is not None, "No face detected" + + return landmarks[0] + + +def get_landmarks(path_image: Path, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cuda")): + face_aligner = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, dtype=dtype, device=device.type) + image = imread(path_image) + + return _get_face_landmarks(image, face_aligner) diff --git a/test/test_falign/__init__.py b/test/test_falign/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_falign/test_landmarks.py b/test/test_falign/test_landmarks.py new file mode 100644 index 0000000..22f3b08 --- /dev/null +++ b/test/test_falign/test_landmarks.py @@ -0,0 +1,22 @@ +from pathlib import Path + +from matplotlib import pyplot + +from falign.landmarks import get_landmarks + + +def test_get_landmarks(): + dir_gallery = Path(__file__).parent.parent.parent / "gallery" + path_image = dir_gallery / "original.jpg" + landmarks = get_landmarks(path_image) + + assert landmarks.shape == (68, 2) + + fig, ax = pyplot.subplots() + ax.imshow(pyplot.imread(path_image)) + ax.scatter(landmarks[:, 0], landmarks[:, 1], s=1) + ax.axis("off") + ax.margins(0) + ax.set_aspect("equal") + fig.tight_layout(pad=0) + fig.savefig(dir_gallery / "test_get_landmarks.jpg")