from collections import defaultdict

import face_recognition
import numpy as np

from utils import LOGGER


def face_detection_and_recognition_3(id_list, body_list, *, min_eye_lip_dist=13, tolerance=0.4):
    """
    Encode the faces found in body crops and match them against known face ids.

    :param id_list: list of ``(face_id, frame_id, id_encoding)`` tuples for ids
        that are still waiting to be matched.
    :param body_list: list of ``((cam_id, body), frame_id, encoding)`` tuples;
        ``encoding`` is ``None`` when the ``body`` image has not been encoded yet.
    :param min_eye_lip_dist: a detected face is only encoded when the vertical
        pixel gap between its first ``right_eye`` and first ``bottom_lip``
        landmark points exceeds this value (default 13, as before).
    :param tolerance: a body matches an id when their face distance is strictly
        below this value (default 0.4, as before).
    :return: ``(ret_id_list, ret_body_list, id_to_body)``:
        ``ret_id_list`` holds the ``(face_id, frame_id, id_encoding)`` entries
        that matched no body; ``ret_body_list`` holds the encoded body entries
        that matched no id; ``id_to_body`` maps each matched ``face_id`` to a
        list of ``(cam_id, body)`` pairs.

    Note: when a body image contains several faces, ``unique_face`` paints the
    smaller ones gray *in place*, so the caller's image arrays may be modified.
    """
    id_to_body = defaultdict(list)
    if not body_list:
        return id_list, body_list, id_to_body

    # Keep bodies that already carry an encoding; try to encode the rest and
    # drop any body in which no usable face was found.
    ret_body_list = []
    for (cam_id, body), frame_id, encoding in body_list:
        if encoding is None:
            encoding = _encode_largest_face(body, min_eye_lip_dist)
            if encoding is None:
                continue
        ret_body_list.append(((cam_id, body), frame_id, encoding))

    if not ret_body_list:
        # NOTE(review): the message says 20, but the actual threshold is
        # min_eye_lip_dist (default 13).
        LOGGER.debug('all face small than 20.')
        return id_list, ret_body_list, id_to_body

    if not id_list:
        return id_list, ret_body_list, id_to_body

    # One (n_bodies, 128) float64 matrix; row i corresponds to ret_body_list[i]
    # and the two are kept in sync as matched rows are removed below.
    bodies_encoding = np.vstack([enc for _, _, enc in ret_body_list]).astype(np.float64, copy=False)

    ret_id_list = []
    for face_id, frame_id, id_encoding in id_list:
        if len(bodies_encoding) == 0:
            # Every body has been claimed; remaining ids pass through unchanged
            # (same 3-tuple shape as the unmatched branch below).
            ret_id_list.append((face_id, frame_id, id_encoding))
            continue

        dist = face_recognition.face_distance(bodies_encoding, id_encoding)
        matched = np.where(dist < tolerance)[0]
        if len(matched) == 0:
            ret_id_list.append((face_id, frame_id, id_encoding))
            continue

        for i in matched:
            id_to_body[face_id].append(ret_body_list[i][0])

        # Consume the matched bodies so that later ids cannot claim them.
        bodies_encoding = np.delete(bodies_encoding, matched, axis=0)
        for i in sorted(matched, reverse=True):
            ret_body_list.pop(i)

    return ret_id_list, ret_body_list, id_to_body


def _encode_largest_face(body, min_eye_lip_dist):
    """
    Return a (1, 128) encoding of the largest face in ``body``, or None.

    None is returned when no face is detected, when no landmarks are found on
    the face crop, or when the eye-to-lip landmark gap does not exceed
    ``min_eye_lip_dist``. ``body`` may be modified in place by ``unique_face``.
    """
    face_locs = face_recognition.face_locations(body)
    if not face_locs:
        return None

    locs, img = unique_face(face_locs, body)
    # Boxes are (top, right, bottom, left), per unique_face's docstring.
    top, right, bottom, left = locs[0]
    face_landmarks_list = face_recognition.face_landmarks(img[top:bottom, left:right])
    if not face_landmarks_list:
        return None

    landmarks = face_landmarks_list[0]
    # Vertical pixel gap between the first right-eye point and the first
    # bottom-lip point; small gaps are rejected (presumably faces that are too
    # small or too oblique for a reliable encoding).
    dis1 = landmarks['bottom_lip'][0][1] - landmarks['right_eye'][0][1]
    if dis1 <= min_eye_lip_dist:
        LOGGER.debug('face not valid.')
        return None

    return face_recognition.face_encodings(img, locs)[0].reshape((1, 128))


def unique_face(face_locations, img):
    """
    Keep only the largest face box and paint every other face box gray.

    :param face_locations: sequence of face boxes, each ``(top, right, bottom, left)``
    :param img: image array indexed as ``img[row, col]``; when more than one box
        is given, the non-largest boxes are overwritten in place with
        ``[128, 128, 128]``
    :return: ``(locations, img)``. ``locations`` is ``face_locations`` itself when
        it holds exactly one box, otherwise a one-element list with the largest
        box (the first one on ties). ``img`` is the same array object passed in.
    """
    if len(face_locations) == 1:
        return face_locations, img

    def _box_area(box):
        top, right, bottom, left = box[:4]
        return abs((bottom - top) * (right - left))

    largest = np.argmax([_box_area(box) for box in face_locations])

    for position, box in enumerate(face_locations):
        if position != largest:
            top, right, bottom, left = box[:4]
            img[top:bottom, left:right] = [[[128, 128, 128]]]

    return [face_locations[largest]], img
