import torch

from config import config as cfg
from reid.model_3dta import HUANG_3DTA_test
from reid.output import extract_feature, load_3dta, load_3dta_data, output_batch
from utils import LOGGER


class Model3DTA:
    """Wrapper around the 3DTA re-identification model with a cached gallery.

    Usage: construct once, call ``update_gallery()`` to extract and cache the
    gallery features, then call ``predict()`` for a query split.
    """

    def __init__(self, cfg):
        """Build the model from ``cfg.REID`` settings and switch it to eval mode.

        Args:
            cfg: configuration object. Here it reads ``cfg.REID.WHICH_EPOCH``,
                ``NAME``, ``BATCHSIZE`` and ``GPU_ID``; the other methods read
                ``cfg.REID.TEST_DIR`` and ``BATCHSIZE`` from the same object.
        """
        model_3dta = HUANG_3DTA_test(load_3dta(cfg.REID.WHICH_EPOCH, cfg.REID.NAME, cfg.REID.BATCHSIZE))
        if torch.cuda.is_available():
            reid_device = torch.device('cuda:{:d}'.format(cfg.REID.GPU_ID))
            model_3dta = model_3dta.cuda(reid_device)
            LOGGER.info('REID model using GPU-{:d}'.format(cfg.REID.GPU_ID))
        model_3dta.eval()

        self.model = model_3dta
        # Keep the config the instance was built with; all methods read from it.
        self.cfg = cfg
        # True only once update_gallery() has populated gallery_feature.
        self.init = False
        self.gallery_feature = None
        self.dataloaders = {}
        self.image_datasets = {}

    def clear(self):
        """Drop the cached gallery features, datasets and data loaders.

        Attributes are reset rather than deleted, so the object stays usable
        and ``clear()`` may be called repeatedly. ``init`` is reset so that
        ``predict()`` refuses to run until ``update_gallery()`` is called again.
        """
        self.init = False
        self.gallery_feature = None
        self.dataloaders = {}
        self.image_datasets = {}

    def update_gallery(self):
        """(Re)load the 'gallery' split from ``cfg.REID.TEST_DIR`` and cache its features."""
        reid_cfg = self.cfg.REID
        image_datasets, dataloaders, _ = load_3dta_data(reid_cfg.TEST_DIR,
                                                        reid_cfg.BATCHSIZE,
                                                        ['gallery'])
        # Drop references to the previous gallery before extracting the new one,
        # so the old features can be released first. If extraction raises, the
        # object is left cleared with init == False.
        self.clear()
        self.gallery_feature = extract_feature(self.model, dataloaders['gallery'])
        self.dataloaders = dataloaders
        self.image_datasets = image_datasets
        self.init = True

    def predict(self, dataset_split='query'):
        """Match ``dataset_split`` against the cached gallery features.

        Args:
            dataset_split: name of the split to load from ``cfg.REID.TEST_DIR``.

        Returns:
            ``(imgs_path, score)`` as returned by ``output_batch``, or
            ``([], [])`` (after logging a warning) if ``update_gallery()`` has
            not been run.
        """
        if not self.init:
            LOGGER.warning('Update gallery feature first!')
            return [], []

        reid_cfg = self.cfg.REID
        image_datasets, dataloaders, _ = load_3dta_data(reid_cfg.TEST_DIR,
                                                        reid_cfg.BATCHSIZE,
                                                        [dataset_split])
        # Merge into the cached dicts so output_batch receives both the
        # 'gallery' entries and this split's entries.
        self.image_datasets[dataset_split] = image_datasets[dataset_split]
        self.dataloaders[dataset_split] = dataloaders[dataset_split]
        imgs_path, score = output_batch(self.image_datasets, self.dataloaders,
                                        self.gallery_feature, self.model, dataset_split)
        return imgs_path, score
