import gc
import os
from typing import Dict, List, Union
from huggingface_hub import login, hf_hub_download
import safetensors
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, CLIPTextModelWithProjection
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel

# Get the number of logical CPUs and let torch use all of them for intra-op
# parallelism. os.cpu_count() returns None when the count cannot be
# determined, and torch.set_num_threads() needs a positive int, so fall back
# to a single thread in that case.
cpu_cores = os.cpu_count() or 1
print(f"CPUのコア数: {cpu_cores}")
torch.set_num_threads(cpu_cores)



# Log in to the Hugging Face Hub only when a token is supplied through the
# HF_TOKEN environment variable. Calling login() with an empty string sends an
# empty bearer token that the Hub rejects, aborting the script at import time.
# Never hard-code a real token in source.
# NOTE(review): without a token, downloads presumably rely on any cached
# credentials; gated repos may require one — confirm for each repo used below.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Hugging Face Hub repository IDs of the four base models whose UNet weights
# are downloaded and linearly merged by load_evo_ukiyoe().
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"

# Hub repository ID of the Evo-Ukiyoe LoRA weights that load_evo_ukiyoe()
# loads and fuses on top of the merged pipeline.
UKIYOE_REPO = "SakanaAI/Evo-Ukiyoe-v1"


def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
    """Load a checkpoint file from disk onto ``device``.

    Files whose name ends in ``.safetensors`` are read with
    ``safetensors.torch.load_file``; any other file is read with ``torch.load``.

    Args:
        checkpoint_file: Path to the checkpoint file.
        device: Device the loaded tensors are placed on (e.g. ``"cpu"``,
            ``"cuda"``).

    Returns:
        The loaded object; for safetensors files this is a
        ``Dict[str, torch.Tensor]``.
    """
    file_extension = os.path.basename(checkpoint_file).split(".")[-1]
    if file_extension == "safetensors":
        # `import safetensors` alone is not guaranteed to load the
        # `safetensors.torch` submodule; import it explicitly instead of
        # relying on another library having imported it first.
        from safetensors.torch import load_file

        return load_file(checkpoint_file, device=device)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code if the checkpoint is untrusted; consider weights_only=True.
    return torch.load(checkpoint_file, map_location=device)


def load_from_pretrained(
    repo_id,
    filename="diffusion_pytorch_model.fp16.safetensors",
    subfolder="unet",
    device="cuda",
) -> Dict[str, torch.Tensor]:
    """Fetch one checkpoint file from a Hub repo and load it as a state dict.

    Args:
        repo_id: Hugging Face Hub repository ID.
        filename: Checkpoint file name inside ``subfolder``.
        subfolder: Repository subfolder holding the file.
        device: Device the tensors are loaded onto.

    Returns:
        The state dict produced by ``load_state_dict`` for the downloaded file.
    """
    # hf_hub_download returns the local path of the (possibly cached) file.
    local_path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
    return load_state_dict(local_path, device=device)


def reshape_weight_task_tensors(task_tensors, weights):
    """Return a view of ``weights`` padded with trailing size-1 dims.

    The number of appended dims is ``task_tensors.dim() - weights.dim()``, so
    the result broadcasts against ``task_tensors`` along the leading axis
    (e.g. weights ``(n,)`` with task tensors ``(n, a, b)`` -> ``(n, 1, 1)``).
    """
    missing_dims = task_tensors.dim() - weights.dim()
    target_shape = (*weights.shape, *((1,) * missing_dims))
    return weights.view(target_shape)


def linear(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
    """Compute the weighted sum ``sum_i weights[i] * task_tensors[i]``.

    Args:
        task_tensors: Tensors of identical shape, one per model.
        weights: 1-D tensor with one coefficient per entry of ``task_tensors``.

    Returns:
        A tensor with the same shape as each input tensor.
    """
    stacked = torch.stack(task_tensors, dim=0)
    # Pad the weights with singleton dims so they broadcast over each tensor.
    broadcast_weights = reshape_weight_task_tensors(stacked, weights)
    return (stacked * broadcast_weights).sum(dim=0)


def merge_models(task_tensors, weights):
    """Linearly merge several state dicts key by key.

    For every key of the first state dict, the tensors from all dicts are
    combined as ``sum_i weights[i] * task_tensors[i][key]`` via ``linear``.

    Args:
        task_tensors: List of state dicts (``str -> torch.Tensor``). Every dict
            must contain all keys of the first one. The dicts are consumed:
            each key is popped as it is merged so source tensors can be freed
            early.
        weights: Sequence of merge coefficients, one per state dict.

    Returns:
        A new dict of merged tensors, in the key order of the first dict.

    Raises:
        ValueError: If ``len(weights) != len(task_tensors)``.
        KeyError: If a later state dict lacks a key present in the first.
    """
    if len(weights) != len(task_tensors):
        # Without this check, a mismatched length (e.g. a single weight) would
        # silently broadcast inside `linear` instead of failing.
        raise ValueError(
            f"Expected {len(task_tensors)} weights, got {len(weights)}"
        )
    if not task_tensors:
        return {}
    keys = list(task_tensors[0].keys())
    if not keys:
        return {}
    # Put the coefficients on the same device as the tensors being merged.
    weights = torch.tensor(weights, device=task_tensors[0][keys[0]].device)
    state_dict = {}
    for key in tqdm(keys, desc="Merging"):
        # pop() drops each source reference as soon as it has been merged.
        w_list = [sd.pop(key) for sd in task_tensors]
        state_dict[key] = linear(task_tensors=w_list, weights=weights)
    return state_dict


def split_conv_attn(weights):
    """Partition a state dict into attention and non-attention parameters.

    A key counts as attention if it contains any of ``to_k``, ``to_q``,
    ``to_v`` or ``to_out.0``; every other key goes to ``"conv"``. Entries are
    popped from ``weights``, which is left empty.

    Returns:
        ``{"conv": {...}, "attn": {...}}`` preserving the original key order.
    """
    attn_markers = ("to_k", "to_q", "to_v", "to_out.0")
    buckets = {"conv": {}, "attn": {}}
    for name in list(weights):
        is_attn = any(marker in name for marker in attn_markers)
        buckets["attn" if is_attn else "conv"][name] = weights.pop(name)
    return buckets


def _free_memory(device):
    """Run the garbage collector and, for CUDA devices, release cached GPU blocks."""
    gc.collect()
    if "cuda" in device:
        torch.cuda.empty_cache()


def _load_merged_unet(device):
    """Build an SDXL UNet whose weights are a linear merge of four base UNets.

    Attention projections (``to_q``/``to_k``/``to_v``/``to_out.0``) and all
    remaining parameters are merged with separate coefficient sets.
    """
    # Load base models
    sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
    dpo_weights = split_conv_attn(load_from_pretrained(DPO_REPO, "diffusion_pytorch_model.safetensors", device=device))
    jn_weights = split_conv_attn(load_from_pretrained(JN_REPO, device=device))
    jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))

    # Merge base models
    tensors = [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights]
    new_conv = merge_models(
        [sd["conv"] for sd in tensors],
        [0.15928833971605916, 0.1032449268871776, 0.6503217149752791, 0.08714501842148402],
    )
    new_attn = merge_models(
        [sd["attn"] for sd in tensors],
        [0.1877279276437178, 0.20014114603909822, 0.3922685507065275, 0.2198623756106564],
    )

    # merge_models popped every entry from the source dicts; drop the remaining
    # containers, including the `tensors` list that also references them.
    del tensors, sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
    _free_memory(device)

    # Instantiate UNet
    unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
    unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
    unet.load_state_dict({**new_conv, **new_attn})

    # load_state_dict copies the values into the UNet's own parameters, so the
    # merged dicts are now a redundant full copy; free them before the rest of
    # the pipeline is loaded.
    del new_conv, new_attn
    _free_memory(device)
    return unet


def load_evo_ukiyoe(device="cuda") -> StableDiffusionXLPipeline:
    """Assemble the Evo-Ukiyoe SDXL pipeline.

    Steps:
    1. Build the merged UNet on ``device``.
    2. Load the Japanese SDXL text encoder and tokenizer.
    3. Assemble the pipeline from the SDXL base repo.
    4. Load and fuse the Evo-Ukiyoe LoRA weights.

    Args:
        device: Device on which the base UNet weights are loaded and merged
            (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        The assembled pipeline, moved to CPU.
    """
    unet = _load_merged_unet(device)

    # Load other modules
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        JSDXL_REPO, subfolder="text_encoder", variant="fp16",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        JSDXL_REPO, subfolder="tokenizer", use_fast=False,
    )

    # Load pipeline
    pipe = StableDiffusionXLPipeline.from_pretrained(
        SDXL_REPO,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        variant="fp16",
    )

    # Load Evo-Ukiyoe weights
    pipe.load_lora_weights(UKIYOE_REPO)
    pipe.fuse_lora(lora_scale=1.0)

    # NOTE(review): the pipeline is always returned on CPU, whatever `device`
    # is; `device` only decides where the merge runs. Confirm this is intended.
    pipe = pipe.to(device=torch.device("cpu"))
    _free_memory(device)
    return pipe


# Example usage: select CUDA when available (CPU otherwise) and build the pipeline.
if __name__ == "__main__":
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    pipe = load_evo_ukiyoe(device=device)
    print("Evo-Ukiyoe pipeline loaded successfully.")

