import json
import os
from multiprocessing import Process, Queue
from pprint import pformat
from time import sleep

import torch
from torchvision import transforms as T

from attribute_detect.net import ResNet50_nFC
from utils import LOGGER

# Preprocessing applied to every crop before it is fed to the attribute model:
# resize to a fixed 288x144 (torchvision reads a size pair as (height, width)),
# convert to a tensor, then normalise each of the three channels.
# NOTE(review): the mean/std values look like the usual ImageNet statistics —
# confirm they match what the checkpoint was trained with.
trans = T.Compose([
    T.Resize(size=(288, 144)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def start_attr_det(cfg):
    """Create the attribute-detection queue and launch the worker process.

    A bounded ``multiprocessing.Queue`` (at most 64 pending items) is stored
    on ``cfg.ATTR.PIPE`` so that producers can reach it through ``cfg``, then
    ``attr_det`` is started in a separate process with ``cfg`` as its only
    argument.

    Args:
        cfg: Configuration object exposing an ``ATTR`` namespace. Its ``PIPE``
            attribute is overwritten with the newly created queue.

    Returns:
        The started ``multiprocessing.Process``. The worker loops forever, so
        the handle is returned to let the caller ``terminate()``/``join()`` it
        on shutdown; callers that ignore the return value are unaffected.
    """
    cfg.ATTR.PIPE = Queue(maxsize=64)
    worker = Process(target=attr_det, args=(cfg,))
    worker.start()
    return worker


def attr_det(cfg):
    """Worker loop: pull crops from the queue, predict attributes, save them.

    Each queue item is a ``(src, face_folder, face_id)`` tuple. ``src`` is run
    through the module-level ``trans`` pipeline (so it is presumably a PIL
    image — confirm against the producer), batched, and passed to the
    attribute network. Outputs above 0.5 are treated as positive and handed
    to ``predict_decoder.decode``, which does the actual bookkeeping (writing
    the attribute JSON into ``face_folder`` and logging the result).

    The network and the decoder are created once, before the loop starts;
    the decoder only depends on ``cfg.ATTR.DATASET`` and its JSON files, so
    there is no need to rebuild it for every item.

    Args:
        cfg: Configuration object; ``cfg.ATTR.PIPE`` is the input queue and
            ``cfg.ATTR.DATASET`` selects the label/attribute tables.

    This function never returns.
    """
    # multiprocessing.Queue.get signals a timeout with queue.Empty.
    from queue import Empty

    attr_model = load_network()
    attr_queue = cfg.ATTR.PIPE
    decoder = predict_decoder(cfg.ATTR.DATASET)

    while True:
        if not cfg.ATTR.PIPE:
            # No queue configured: idle instead of spinning.
            sleep(0.1)
            continue

        try:
            # Block for up to 0.1s; an empty queue just means "poll again",
            # it must not kill the worker.
            src, face_folder, face_id = attr_queue.get(timeout=0.1)
        except Empty:
            continue

        batch = trans(src).unsqueeze(dim=0)
        # Pure inference: skip autograd bookkeeping, and go through __call__
        # rather than .forward() so module hooks are honoured.
        with torch.no_grad():
            out = attr_model(batch)
        pred = torch.gt(out, 0.5)  # threshold=0.5
        # The concrete handling of the prediction happens inside decode().
        decoder.decode(pred, face_folder, face_id)


def load_network():
    """Build the 30-output ResNet50 attribute network and load its weights.

    The weights are read from
    ``attribute_detect/checkpoints/market/resnet50/net-last.pth`` (relative to
    the current working directory).

    Returns:
        The ``ResNet50_nFC`` model with the checkpoint loaded, switched to
        evaluation mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    network = ResNet50_nFC(30)
    save_path = os.path.join('attribute_detect/checkpoints', 'market', 'resnet50', 'net-last.pth')
    # Map storages to CPU while deserialising: the model is never moved to a
    # GPU in this module, and without this a checkpoint saved from CUDA fails
    # to load on a CPU-only host (or needlessly allocates GPU memory).
    state_dict = torch.load(save_path, map_location='cpu')
    network.load_state_dict(state_dict)
    return network.eval()


class predict_decoder(object):
    """Translate thresholded attribute predictions into named attributes.

    The ordered label list and the per-label ``(name, choices)`` table are
    read from ``attribute_detect/doc/label.json`` and
    ``attribute_detect/doc/attribute.json`` for the requested dataset.
    """

    def __init__(self, dataset):
        """Load the label list and attribute table for ``dataset``.

        Args:
            dataset: Key selecting the dataset section inside both JSON files.
        """
        self.label_list = self._read_section('attribute_detect/doc/label.json', dataset)
        self.attribute_dict = self._read_section('attribute_detect/doc/attribute.json', dataset)
        self.dataset = dataset
        self.num_label = len(self.label_list)

    @staticmethod
    def _read_section(path, dataset):
        """Return the ``dataset`` entry of the JSON document stored at ``path``."""
        with open(path, 'r') as handle:
            return json.load(handle)[dataset]

    def decode(self, pred, face_folder, face_id):
        """Write the attributes selected by ``pred`` to disk and log them.

        Args:
            pred: Prediction tensor with a leading dimension that is squeezed
                away; element ``i`` is used as an index into the choices of
                label ``i``.
            face_folder: Directory in which the JSON file is created.
            face_id: Integer used to build the ``attribute_<id>.json`` name.
        """
        flags = pred.squeeze(dim=0)
        attributes = {}
        for position in range(self.num_label):
            label = self.label_list[position]
            name, choice = self.attribute_dict[label]
            value = choice[flags[position]]
            # Falsy choices (e.g. an empty entry) are left out of the result.
            if value:
                attributes[name] = value

        target = os.path.join(face_folder, 'attribute_{:d}.json'.format(face_id))
        with open(target, 'w') as handle:
            json.dump(attributes, handle, indent=2)

        LOGGER.info('\n{:s}'.format(pformat(attributes, indent=2)))
