Aman Jobanputra
Published © GPL3+

Wear it!

The user selects a dress and is shown, in real time, how they would look wearing it.

Intermediate · Showcase (no instructions) · 15 days · 75

Things used in this project

Hardware components

Minisforum Venus UM790 Pro with AMD Ryzen™ 9
×1

Software apps and online services

AMD - Vitis AI
AMD - Ryzen AI Software

Story


Code

tester.ipynb

Python
Put this file in the EdgeStyle-main folder
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "from diffusers import (\n",
    "    AutoencoderKL,\n",
    "    UNet2DConditionModel,\n",
    "    UniPCMultistepScheduler,\n",
    "    StableDiffusionControlNetPipeline,\n",
    ")\n",
    "from diffusers.optimization import get_scheduler\n",
    "from transformers import AutoTokenizer, CLIPTextModel, CLIPModel, CLIPProcessor\n",
    "\n",
    "from model.utils import BestEmbeddings\n",
    "from model.edgestyle_multicontrolnet import EdgeStyleMultiControlNetModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\torch\\hub.py:294: UserWarning: You are about to download and run code from an untrusted repository. In a future release, this won't be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, or load(..., trust_repo=True), which will assume that the prompt is to be answered with 'yes'. You can also use load(..., trust_repo='check') which will only prompt for confirmation if the repo is not already trusted. This will eventually be the default behaviour\n",
      "  warnings.warn(\n",
      "Downloading: \"https://github.com/ultralytics/yolov5/zipball/master\" to C:\\Users\\amanj/.cache\\torch\\hub\\master.zip\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31m\u001b[1mrequirements:\u001b[0m Ultralytics requirement ['pillow>=10.3.0'] not found, attempting AutoUpdate...\n",
      "Retry 1/2 failed: Command 'pip install --no-cache-dir \"pillow>=10.3.0\" ' returned non-zero exit status 1.\n",
      "Requirement already satisfied: pillow>=10.3.0 in c:\\users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages (10.4.0)\n",
      "\n",
      "\u001b[31m\u001b[1mrequirements:\u001b[0m AutoUpdate success  13.4s, installed 1 package: ['pillow>=10.3.0']\n",
      "\u001b[31m\u001b[1mrequirements:\u001b[0m  \u001b[1mRestart runtime or rerun command for updates to take effect\u001b[0m\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "YOLOv5  2024-7-31 Python-3.9.18 torch-2.1.0+cpu CPU\n",
      "\n",
      "Downloading https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s.pt to yolov5s.pt...\n",
      "100%|| 14.1M/14.1M [00:12<00:00, 1.15MB/s]\n",
      "\n",
      "Fusing layers... \n",
      "YOLOv5s summary: 213 layers, 7225885 parameters, 0 gradients, 16.4 GFLOPs\n",
      "Adding AutoShape... \n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a422995dc9a14caea9814cb0a531be16",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "body_pose_model.pth:   0%|          | 0.00/209M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cb5d332b210147fc943530185620a3b7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "hand_pose_model.pth:   0%|          | 0.00/147M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f163ebf469ff4c90a13764f16f7266f2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "facenet.pth:   0%|          | 0.00/154M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# local\n",
    "from model.controllora import ControlLoRAModel, CachedControlNetModel\n",
    "from model.utils import BestEmbeddings\n",
    "from model.edgestyle_multicontrolnet import EdgeStyleMultiControlNetModel\n",
    "from model.edgestyle_pipeline import EdgeStyleStableDiffusionControlNetPipeline\n",
    "from extract_dataset import process_batch, create_sam_images_for_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "RESOLUTION = 512\n",
    "\n",
    "IMAGES_TRANSFORMS = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize([0.5], [0.5]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "CONDITIONING_IMAGES_TRANSFORMS = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "CONTROLNET_PATTERN = [0, None, 1, None, 1, None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model loaded\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The config attributes {'mid_block_type': 'UNetMidBlock2DCrossAttn'} were passed to ControlLoRAModel, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
      "The config attributes {'mid_block_type': 'UNetMidBlock2DCrossAttn'} were passed to ControlLoRAModel, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pipeline started\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a8dba0d6390e42398b887cd8e2fdd3d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "PRETRAINED_MODEL_NAME_OR_PATH = \"./models/Realistic_Vision_V5.1_noVAE\"\n",
    "PRETRAINED_VAE_NAME_OR_PATH = \"./models/sd-vae-ft-mse\"\n",
    "PRETRAINED_OPENPOSE_NAME_OR_PATH = \"./models/control_v11p_sd15_openpose\"\n",
    "CONTROLNET_MODEL_NAME_OR_PATH = \"./models/EdgeStyle/controlnet\"\n",
    "CLIP_MODEL_NAME_OR_PATH = \"./models/clip-vit-large-patch14\"\n",
    "\n",
    "\n",
    "NEGATIVE_PROMPT = (\n",
    "    r\"deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, \"\n",
    "    \"cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, \"\n",
    "    \"disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, \"\n",
    "    \"floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation\"\n",
    ")\n",
    "\n",
    "PROMT_TO_ADD = (\n",
    "    \", gray background, RAW photo, subject, 8k uhd, dslr, soft lighting, high quality\"\n",
    ")\n",
    "print(\"model loaded\")\n",
    "model = CLIPModel.from_pretrained(CLIP_MODEL_NAME_OR_PATH)\n",
    "processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME_OR_PATH)\n",
    "\n",
    "best_embeddings = BestEmbeddings(model, processor)\n",
    "\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    PRETRAINED_MODEL_NAME_OR_PATH,\n",
    "    subfolder=\"tokenizer\",\n",
    "    use_fast=False,\n",
    ")\n",
    "\n",
    "text_encoder = CLIPTextModel.from_pretrained(\n",
    "    PRETRAINED_MODEL_NAME_OR_PATH,\n",
    "    subfolder=\"text_encoder\",\n",
    ")\n",
    "vae = AutoencoderKL.from_pretrained(PRETRAINED_VAE_NAME_OR_PATH)\n",
    "\n",
    "unet = UNet2DConditionModel.from_pretrained(\n",
    "    PRETRAINED_MODEL_NAME_OR_PATH,\n",
    "    subfolder=\"unet\",\n",
    ")\n",
    "\n",
    "openpose = CachedControlNetModel.from_pretrained(PRETRAINED_OPENPOSE_NAME_OR_PATH)\n",
    "\n",
    "controlnet = EdgeStyleMultiControlNetModel.from_pretrained(\n",
    "    CONTROLNET_MODEL_NAME_OR_PATH,\n",
    "    vae=vae,\n",
    "    controlnet_class=ControlLoRAModel,\n",
    "    load_pattern=CONTROLNET_PATTERN,\n",
    "    static_controlnets=[None, openpose, None, openpose, None, openpose],\n",
    ")\n",
    "for net in controlnet.nets:\n",
    "    if net is not openpose:\n",
    "        net.tie_weights(unet)\n",
    "\n",
    "# pipeline = StableDiffusionControlNetPipeline.from_pretrained(\n",
    "#     PRETRAINED_MODEL_NAME_OR_PATH,\n",
    "#     vae=vae,\n",
    "#     text_encoder=text_encoder,\n",
    "#     tokenizer=tokenizer,\n",
    "#     unet=unet,\n",
    "#     controlnet=controlnet,\n",
    "#     safety_checker=None,\n",
    "# )\n",
    "print(\"pipeline started\")\n",
    "pipeline = EdgeStyleStableDiffusionControlNetPipeline.from_pretrained(\n",
    "    PRETRAINED_MODEL_NAME_OR_PATH,\n",
    "    vae=vae,\n",
    "    text_encoder=text_encoder,\n",
    "    tokenizer=tokenizer,\n",
    "    unet=unet,\n",
    "    controlnet=controlnet,\n",
    "    safety_checker=None,\n",
    ")\n",
    "pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)\n",
    "generator = torch.Generator().manual_seed(42)\n",
    "# vae.enable_xformers_memory_efficient_attention(attention_op=None)\n",
    "# pipeline.enable_xformers_memory_efficient_attention()\n",
    "pipeline = pipeline.to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(image_subject, image_cloth1, image_cloth2):\n",
    "    data = process_batch([image_subject, image_cloth1, image_cloth2])\n",
    "    if data is None or len(data) < 3:\n",
    "        # try again, sometimes first time fails\n",
    "        print(\"Retrying\")\n",
    "        data = process_batch([image_subject, image_cloth1, image_cloth2])\n",
    "    data = create_sam_images_for_batch(data)\n",
    "\n",
    "    image_subject_head = data[\"head_image\"].iloc[0]\n",
    "    image_cloth1_clothes = data[\"clothes_image\"].iloc[1]\n",
    "    image_cloth2_clothes = data[\"clothes_image\"].iloc[2]\n",
    "\n",
    "    image_subject_openpose = data[\"openpose_image\"].iloc[0]\n",
    "    image_cloth1_openpose = data[\"openpose_image\"].iloc[1]\n",
    "    image_cloth2_openpose = data[\"openpose_image\"].iloc[2]\n",
    "    print(\"preprocess done\")\n",
    "    return (\n",
    "        image_subject_head,\n",
    "        image_subject_openpose,\n",
    "        image_cloth1_clothes,\n",
    "        image_cloth1_openpose,\n",
    "        image_cloth2_clothes,\n",
    "        image_cloth2_openpose,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|| 1/1 [00:02<00:00,  2.32s/it]\n"
     ]
    }
   ],
   "source": [
    "data = process_batch([subject, cloth1, cloth2])\n",
    "data = create_sam_images_for_batch(data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>xmin</th>\n",
       "      <th>ymin</th>\n",
       "      <th>xmax</th>\n",
       "      <th>ymax</th>\n",
       "      <th>confidence</th>\n",
       "      <th>...</th>\n",
       "      <th>subject_image</th>\n",
       "      <th>mask_image</th>\n",
       "      <th>agnostic_image</th>\n",
       "      <th>clothes_image</th>\n",
       "      <th>head_image</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>224.032074</td>\n",
       "      <td>35.632568</td>\n",
       "      <td>422.107849</td>\n",
       "      <td>512.0</td>\n",
       "      <td>0.922629</td>\n",
       "      <td>...</td>\n",
       "      <td>&lt;PIL.Image.Image image mode=RGB size=512x512 a...</td>\n",
       "      <td>&lt;PIL.Image.Image image mode=RGB size=512x512 a...</td>\n",
       "      <td>&lt;PIL.Image.Image image mode=RGB size=512x512 a...</td>\n",
       "      <td>&lt;PIL.Image.Image image mode=RGB size=512x512 a...</td>\n",
       "      <td>&lt;PIL.Image.Image image mode=RGB size=512x512 a...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1 rows  17 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         xmin       ymin        xmax   ymax  confidence  ...  \\\n",
       "0  224.032074  35.632568  422.107849  512.0    0.922629  ...   \n",
       "\n",
       "                                       subject_image  \\\n",
       "0  <PIL.Image.Image image mode=RGB size=512x512 a...   \n",
       "\n",
       "                                          mask_image  \\\n",
       "0  <PIL.Image.Image image mode=RGB size=512x512 a...   \n",
       "\n",
       "                                      agnostic_image  \\\n",
       "0  <PIL.Image.Image image mode=RGB size=512x512 a...   \n",
       "\n",
       "                                       clothes_image  \\\n",
       "0  <PIL.Image.Image image mode=RGB size=512x512 a...   \n",
       "\n",
       "                                          head_image  \n",
       "0  <PIL.Image.Image image mode=RGB size=512x512 a...  \n",
       "\n",
       "[1 rows x 17 columns]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def try_on(\n",
    "    image_subject_agnostic,\n",
    "    image_subject_openpose,\n",
    "    image_cloth1_clothes,\n",
    "    image_cloth1_openpose,\n",
    "    image_cloth2_clothes,\n",
    "    image_cloth2_openpose,\n",
    "    scale,\n",
    "    steps,\n",
    "):\n",
    "    with torch.no_grad():\n",
    "        generator.manual_seed(42)\n",
    "        prompts = best_embeddings([image_cloth1_clothes])\n",
    "        image = pipeline(\n",
    "            prompt=prompts[0] + \" \" + PROMT_TO_ADD,\n",
    "            # prompt=prompts[0],\n",
    "            guidance_scale=scale,\n",
    "            image=[\n",
    "                IMAGES_TRANSFORMS(image_subject_agnostic).unsqueeze(0),\n",
    "                CONDITIONING_IMAGES_TRANSFORMS(image_subject_openpose).unsqueeze(0),\n",
    "                IMAGES_TRANSFORMS(image_cloth1_clothes).unsqueeze(0),\n",
    "                CONDITIONING_IMAGES_TRANSFORMS(image_cloth1_openpose).unsqueeze(0),\n",
    "                IMAGES_TRANSFORMS(image_cloth2_clothes).unsqueeze(0),\n",
    "                CONDITIONING_IMAGES_TRANSFORMS(image_cloth2_openpose).unsqueeze(0),\n",
    "            ],\n",
    "            negative_prompt=NEGATIVE_PROMPT,\n",
    "            num_inference_steps=steps,\n",
    "            generator=generator,\n",
    "            # control_guidance_start=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n",
    "            # control_guidance_end=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n",
    "        ).images[0]\n",
    "    print(\"try on done\")\n",
    "\n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import cv2\n",
    "imgs = list((Path.cwd() / \"test_images\").glob(\"*\"))\n",
    "cloth1, cloth2 = cv2.imread(imgs[1]), cv2.imread(imgs[2])\n",
    "subject = cv2.imread(imgs[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|| 3/3 [00:07<00:00,  2.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "preprocess done\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "preprocess_output = preprocess(subject, cloth1, cloth2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ea7aeb37ed684ee286a8bcb29cb8752e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "try on done\n"
     ]
    }
   ],
   "source": [
    "fimage = try_on(*preprocess_output, scale=4.00, steps=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "fimage.show() #ye red tshirt ko purple kar dia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ryzenai-1.1-20240510-134534",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
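
The notebook imports gradio but the cells shown never build the interface. As a hedged sketch only (not part of the original code), the preprocess() and try_on() helpers defined above could be wired into a simple Gradio app along these lines; note that process_batch() was fed OpenCV BGR arrays in the notebook, so the RGB images Gradio delivers may need converting:

# Sketch of a Gradio front end for tester.ipynb. Assumes the notebook's cells
# (pipeline, preprocess, try_on) have already been executed in this session.
import gradio as gr


def virtual_try_on(subject_img, cloth1_img, cloth2_img, scale, steps):
    # split the three photos into head/clothes/openpose conditioning images
    parts = preprocess(subject_img, cloth1_img, cloth2_img)
    # run the EdgeStyle ControlNet pipeline and return the generated image
    return try_on(*parts, scale, int(steps))


demo = gr.Interface(
    fn=virtual_try_on,
    inputs=[
        gr.Image(type="numpy", label="Your photo"),
        gr.Image(type="numpy", label="Clothing photo 1"),
        gr.Image(type="numpy", label="Clothing photo 2"),
        gr.Slider(1.0, 10.0, value=4.0, label="Guidance scale"),
        gr.Slider(10, 50, value=30, step=1, label="Inference steps"),
    ],
    outputs=gr.Image(label="You, wearing the selected dress"),
    title="Wear it!",
)

demo.launch()

Launching this from the same session that ran the notebook gives the select-a-dress, see-yourself flow described at the top of the project.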

onnx-export.ipynb

Python
Put this file in the EdgeStyle-main folder
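
For reference, here is a hedged sketch (not part of the notebook below) of how the UNet-plus-ControlNets graph written by export_unet() could be loaded and smoke-tested with ONNX Runtime. The input names, shapes, and Vitis AI provider options mirror the ones used inside the notebook; vaip_config.json is assumed to come from the Ryzen AI Software installation, and the session falls back to CPU (with a warning) if the VitisAIExecutionProvider is not available:

# Sketch: load the exported model and run one dummy forward pass.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "./models/EdgeStyle/unet/model.onnx",
    providers=["VitisAIExecutionProvider", "CPUExecutionProvider"],
    provider_options=[
        {"config_file": "vaip_config.json", "cacheKey": "modelcachekey"},
        {},
    ],
)

# dummy inputs with the same shapes used for the export
feeds = {
    "sample": np.random.randn(2, 4, 64, 64).astype(np.float32),
    "timestep": np.zeros(2, dtype=np.int64),
    "encoder_hidden_states": np.random.randn(2, 77, 768).astype(np.float32),
    "conditioning_scale": np.ones(6, dtype=np.float32),
}
for i in (0, 2, 4):  # latent-space conditioning images
    feeds[f"image_{i}"] = np.random.randn(2, 320, 64, 64).astype(np.float32)
for i in (1, 3, 5):  # openpose conditioning images
    feeds[f"image_{i}"] = np.random.randn(2, 3, 512, 512).astype(np.float32)

noise_pred = session.run(["output"], feeds)[0]
print(noise_pred.shape)  # expected: (2, 4, 64, 64)
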
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, Any, Tuple, Union, Optional\n",
    "\n",
    "from fvcore.nn import FlopCountAnalysis\n",
    "from torchinfo import summary\n",
    "\n",
    "import gc\n",
    "import os\n",
    "import torch\n",
    "import onnx\n",
    "import numpy as np\n",
    "\n",
    "from diffusers import UNet2DConditionModel, AutoencoderKL\n",
    "from diffusers.models.modeling_utils import ModelMixin\n",
    "from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel\n",
    "\n",
    "from model.edgestyle_multicontrolnet import EdgeStyleMultiControlNetModel\n",
    "from model.controllora import ControlLoRAModel, CachedControlNetModel\n",
    "\n",
    "from optimum.onnx.utils import check_model_uses_external_data\n",
    "\n",
    "# from onnx import version_converter, helper\n",
    "# from onnxruntime.transformers.shape_infer_helper import SymbolicShapeInferenceHelper\n",
    "# import onnx_graphsurgeon as gs\n",
    "\n",
    "PRETRAINED_MODEL_NAME_OR_PATH = \"./models/Realistic_Vision_V5.1_noVAE\"\n",
    "PRETRAINED_VAE_NAME_OR_PATH = \"./models/sd-vae-ft-mse\"\n",
    "PRETRAINED_OPENPOSE_NAME_OR_PATH = \"./models/control_v11p_sd15_openpose\"\n",
    "CONTROLNET_MODEL_NAME_OR_PATH = \"./models/EdgeStyle/controlnet\"\n",
    "ONNX_MODEL_NAME_OR_PATH = \"./models/EdgeStyle/\"\n",
    "CONTROLNET_PATTERN = [0, None, 1, None, 1, None]\n",
    "\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "\n",
    "class OnnxUNetAndControlnets(ModelMixin):\n",
    "    def __init__(\n",
    "        self, unet: UNet2DConditionModel, controlnet: EdgeStyleMultiControlNetModel\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.unet = unet\n",
    "        self.controlnet = controlnet\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        sample: torch.FloatTensor,\n",
    "        timestep: Union[torch.Tensor, float, int],\n",
    "        encoder_hidden_states: torch.Tensor,\n",
    "        conditioning_scale: torch.Tensor,\n",
    "        image_0: torch.FloatTensor,\n",
    "        image_1: torch.FloatTensor,\n",
    "        image_2: torch.FloatTensor,\n",
    "        image_3: torch.FloatTensor,\n",
    "        image_4: torch.FloatTensor,\n",
    "        image_5: torch.FloatTensor,\n",
    "    ) -> torch.FloatTensor:\n",
    "        down_block_res_samples, mid_block_res_sample = self.controlnet(\n",
    "            sample,\n",
    "            timestep,\n",
    "            encoder_hidden_states=encoder_hidden_states,\n",
    "            controlnet_cond=[image_0, image_1, image_2, image_3, image_4, image_5],\n",
    "            conditioning_scale=conditioning_scale,\n",
    "            return_dict=False,\n",
    "        )\n",
    "\n",
    "        noise_pred = self.unet(\n",
    "            sample,\n",
    "            timestep,\n",
    "            encoder_hidden_states=encoder_hidden_states,\n",
    "            down_block_additional_residuals=down_block_res_samples,\n",
    "            mid_block_additional_residual=mid_block_res_sample,\n",
    "            return_dict=False,\n",
    "        )[0]\n",
    "\n",
    "        return noise_pred\n",
    "\n",
    "\n",
    "def print_tensor(tensor):\n",
    "    name = tensor.name\n",
    "    tensor_type = tensor.type.tensor_type\n",
    "    data_type = onnx.TensorProto.DataType.keys()[tensor_type.elem_type]\n",
    "    shape = [\n",
    "        dim.dim_param if dim.dim_param else dim.dim_value\n",
    "        for dim in tensor_type.shape.dim\n",
    "    ]\n",
    "    print(f\"{name} = {data_type} {shape}\")\n",
    "\n",
    "\n",
    "def print_io(model):\n",
    "    for tensor in model.graph.input:\n",
    "        print_tensor(tensor)\n",
    "    for tensor in model.graph.output:\n",
    "        print_tensor(tensor)\n",
    "\n",
    "@torch.no_grad()\n",
    "def export_unet():\n",
    "    vae = AutoencoderKL.from_pretrained(PRETRAINED_VAE_NAME_OR_PATH)\n",
    "\n",
    "    unet = UNet2DConditionModel.from_pretrained(\n",
    "        PRETRAINED_MODEL_NAME_OR_PATH,\n",
    "        subfolder=\"unet\",\n",
    "        # torch_dtype=torch.float16,\n",
    "    )\n",
    "\n",
    "    openpose = CachedControlNetModel.from_pretrained(PRETRAINED_OPENPOSE_NAME_OR_PATH,\n",
    "                                                    #   torch_dtype=torch.float16,\n",
    "                                                      )\n",
    "\n",
    "    controlnet = EdgeStyleMultiControlNetModel.from_pretrained(\n",
    "        CONTROLNET_MODEL_NAME_OR_PATH,\n",
    "        vae=vae,\n",
    "        controlnet_class=ControlLoRAModel,\n",
    "        load_pattern=CONTROLNET_PATTERN,\n",
    "        static_controlnets=[None, openpose, None, openpose, None, openpose],\n",
    "        # torch_dtype=torch.float16,\n",
    "    )\n",
    "    for net in controlnet.nets:\n",
    "        if net is not openpose:\n",
    "            net.tie_weights(unet)\n",
    "            # net.fuse_lora()\n",
    "\n",
    "    model = OnnxUNetAndControlnets(unet, controlnet)    \n",
    "\n",
    "    # set all parameters to not require gradients\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad = False\n",
    "\n",
    "    model = model.eval()\n",
    "\n",
    "    model = model.to(device)\n",
    "\n",
    "    latent_model_input = torch.randn(2, 4, 64, 64)\n",
    "    timesteps = torch.randint(0, 1, (1,)).long()\n",
    "    timesteps = timesteps.repeat(2)\n",
    "    prompt_embeds = torch.randn(2, 77, 768)\n",
    "\n",
    "    conditioning_scale = torch.randn(6)\n",
    "\n",
    "    image_0 = torch.randn(1, 320, 64, 64)\n",
    "    image_0 = image_0.repeat(2, 1, 1, 1)\n",
    "    image_1 = torch.randn(1, 3, 512, 512)\n",
    "    image_1 = image_1.repeat(2, 1, 1, 1)\n",
    "    image_2 = torch.randn(1, 320, 64, 64)\n",
    "    image_2 = image_2.repeat(2, 1, 1, 1)\n",
    "    image_3 = torch.randn(1, 3, 512, 512)\n",
    "    image_3 = image_3.repeat(2, 1, 1, 1)\n",
    "    image_4 = torch.randn(1, 320, 64, 64)\n",
    "    image_4 = image_4.repeat(2, 1, 1, 1)\n",
    "    image_5 = torch.randn(1, 3, 512, 512)\n",
    "    image_5 = image_5.repeat(2, 1, 1, 1)\n",
    "\n",
    "    dummy_input = (\n",
    "        latent_model_input,\n",
    "        timesteps,\n",
    "        prompt_embeds,\n",
    "        conditioning_scale,\n",
    "        image_0,\n",
    "        image_1,\n",
    "        image_2,\n",
    "        image_3,\n",
    "        image_4,\n",
    "        image_5,\n",
    "    )    \n",
    "    dummy_input = tuple(x.to(device) for x in dummy_input)\n",
    "\n",
    "    predicted_noise_torch = model(*dummy_input)\n",
    "\n",
    "    flops = FlopCountAnalysis(model, dummy_input)\n",
    "    print(f\"fvcore FLOPs: {flops.total()/1e9:.2f} GFLOPs\")\n",
    "\n",
    "    statistics = summary(model, input_data=dummy_input, verbose=0)\n",
    "    print(f\"torchinfo FLOPs: {statistics.total_mult_adds/1e9:.2f} GFLOPs\")\n",
    "\n",
    "    onnx_model_path = os.path.join(ONNX_MODEL_NAME_OR_PATH, 'unet', \"model.onnx\")\n",
    "    onnx_model_dir = os.path.join(onnx_model_path.replace(\"model.onnx\", \"\"))\n",
    "\n",
    "    # onnx_model_dir = os.path.join(ONNX_MODEL_NAME_OR_PATH, 'unet')\n",
    "                                                  \n",
    "    os.makedirs(onnx_model_dir, exist_ok=True)\n",
    "    # os.makedirs(onnx_model_dir + '-infer', exist_ok=True)\n",
    "\n",
    "    print('exporting onnx model for unet and controlnets...') \n",
    "\n",
    "    # Conversion to ONNX\n",
    "    torch.onnx.export(\n",
    "        model,\n",
    "        dummy_input,\n",
    "        onnx_model_path,\n",
    "        export_params=True,\n",
    "        input_names=[\n",
    "            \"sample\",\n",
    "            \"timestep\",\n",
    "            \"encoder_hidden_states\",\n",
    "            \"conditioning_scale\",\n",
    "            \"image_0\",\n",
    "            \"image_1\",\n",
    "            \"image_2\",\n",
    "            \"image_3\",\n",
    "            \"image_4\",\n",
    "            \"image_5\",\n",
    "        ],\n",
    "        output_names=[\"output\"],\n",
    "        dynamic_axes={\n",
    "            \"sample\": {0: \"batch_size\"},  # variable length axes\n",
    "            \"timestep\": {0: \"batch_size\"},          \n",
    "            \"encoder_hidden_states\": {0: \"batch_size\"},\n",
    "            \"image_0\": {0: \"batch_size\"},\n",
    "            \"image_1\": {0: \"batch_size\"},\n",
    "            \"image_2\": {0: \"batch_size\"},\n",
    "            \"image_3\": {0: \"batch_size\"},\n",
    "            \"image_4\": {0: \"batch_size\"},\n",
    "            \"image_5\": {0: \"batch_size\"},\n",
    "            \"output\": {0: \"batch_size\"},\n",
    "        },\n",
    "        # training=torch.onnx.TrainingMode.EVAL,\n",
    "        do_constant_folding=True,\n",
    "        # verbose=True,\n",
    "        opset_version=14,\n",
    "    )\n",
    "\n",
    "    # check if external data was exported\n",
    "    onnx_model = onnx.load(onnx_model_path, load_external_data=False)\n",
    "    model_uses_external_data = check_model_uses_external_data(onnx_model)\n",
    "\n",
    "    if model_uses_external_data:\n",
    "        # try free model memory\n",
    "        del model\n",
    "        del onnx_model\n",
    "        gc.collect()\n",
    "        if device.type == \"cuda\" and torch.cuda.is_available():\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "        onnx_model = onnx.load(\n",
    "            str(onnx_model_path), load_external_data=True\n",
    "        )  # this will probably be too memory heavy for large models\n",
    "        onnx.save(\n",
    "            onnx_model,\n",
    "            onnx_model_path,\n",
    "            save_as_external_data=True,\n",
    "            all_tensors_to_one_file=True,\n",
    "            location='model.onnx' + \"_data\",\n",
    "            size_threshold=1024,\n",
    "            convert_attribute=True,\n",
    "        )\n",
    "\n",
    "        del onnx_model\n",
    "        gc.collect()\n",
    "        if device.type == \"cuda\" and torch.cuda.is_available():\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "        # delete all files except the model.onnx and onnx external data\n",
    "        for file in os.listdir(onnx_model_dir):\n",
    "            if file != \"model.onnx\" and file != \"model.onnx\" + \"_data\":\n",
    "                os.remove(os.path.join(onnx_model_dir, file))\n",
    "        \n",
    "    # print('running shape inference script...')\n",
    "\n",
    "    # # Print input/output shapes\n",
    "    # print(\"\\n*** BEFORE ***\")\n",
    "    # onnx_model = onnx.load(onnx_model_path, load_external_data=True)\n",
    "    # print_io(onnx_model)    \n",
    "\n",
    "    # # Run symbolic shape inference\n",
    "    # shape_infer_helper = SymbolicShapeInferenceHelper(onnx_model, verbose=3, auto_merge=True, guess_output_rank=False)\n",
    "    # all_infered = shape_infer_helper.infer(dynamic_axis_mapping = {\"batch_size\": 2})\n",
    "    # print(f\"Shape inference completed: {all_infered}\")\n",
    "    # onnx_model = shape_infer_helper.model_\n",
    "\n",
    "    # # Print input/output shapes again\n",
    "    # print(\"\\n\\n*** AFTER ***\")\n",
    "    # print_io(onnx_model)\n",
    "\n",
    "    # # save the model\n",
    "    # onnx.save(onnx_model, \n",
    "    #           onnx_model_dir + '-infer/model.onnx', \n",
    "    #           save_as_external_data=True, \n",
    "    #           all_tensors_to_one_file=True, \n",
    "    #           location='model.onnx' + \"_data\", \n",
    "    #           size_threshold=1024, \n",
    "    #           convert_attribute=True)\n",
    "    \n",
    "    # os.rename(f\"{onnx_model_path}\", f\"{onnx_model_path}-no-infer\")\n",
    "\n",
    "    # run shape inference script\n",
    "    # os.system(\n",
    "    #     f\"python -m onnxruntime.tools.symbolic_shape_infer \"\n",
    "    #     f\"--input {onnx_model_path} \"\n",
    "    #     f\"--output {onnx_model_path}-infer \"\n",
    "    #     f\"--auto_merge \"\n",
    "    #     # f\"--save_as_external_data \"\n",
    "    #     # f\"--all_tensors_to_one_file \"\n",
    "    #     # f\"--external_data_location model.onnx-infer_data \"\n",
    "    #     # f\"--external_data_size_threshold 1024 \"\n",
    "    #     f\"--verbose 3\"\n",
    "    # )\n",
    "    \n",
    "    # # delete old onnx model and data and rename new onnx model\n",
    "    # # os.remove(onnx_model_path)\n",
    "    # # os.remove(onnx_model_path + \"_data\")\n",
    "    # os.rename(f\"{onnx_model_path}-infer\", onnx_model_path)\n",
    "    \n",
    "    # print('shape inference script completed...')\n",
    "\n",
    "    print('exported onnx model for unet and controlnets...')\n",
    "    print('checking onnx model...')    \n",
    "\n",
    "    # graph = gs.import_onnx(onnx.load(onnx_model_dir + '-infer/model.onnx'))\n",
    "    # tensors = graph.tensors()\n",
    "\n",
    "    # for tensor_key in tensors.keys():\n",
    "    #     if '/controlnet/multi_controlnet_down_blocks' in tensor_key:\n",
    "    #         print(tensors[tensor_key])\n",
    "\n",
    "    # onnx.checker.check_model(onnx_model_path, full_check=True)\n",
    "\n",
    "    onnx_unet = OnnxRuntimeModel.from_pretrained(onnx_model_dir, provider=\"VitisAIExecutionProvider\", provider_options = [{'config_file': 'vaip_config.json', 'cacheKey': 'modelcachekey'}])\n",
    "    \n",
    "\n",
    "    predicted_noise_onnx = onnx_unet(\n",
    "        sample=latent_model_input,\n",
    "        timestep=timesteps,\n",
    "        encoder_hidden_states=prompt_embeds,\n",
    "        conditioning_scale=conditioning_scale,\n",
    "        image_0=image_0,\n",
    "        image_1=image_1,\n",
    "        image_2=image_2,\n",
    "        image_3=image_3,\n",
    "        image_4=image_4,\n",
    "        image_5=image_5,\n",
    "    )[0]\n",
    "    print('checking how close the output is...')\n",
    "\n",
    "    # compare the output error predicted_noise_torch is FloatTensor and predicted_noise_onnx is numpy array\n",
    "    np.testing.assert_allclose(\n",
    "        predicted_noise_torch.detach().cpu().numpy(),\n",
    "        predicted_noise_onnx,\n",
    "        rtol=1e-03,\n",
    "        atol=1e-05,\n",
    "    )\n",
    "    print('exported onnx model for unet and controlnets is correct...')\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def export_vae_encoder_decoder():\n",
    "\n",
    "    vae = AutoencoderKL.from_pretrained(PRETRAINED_VAE_NAME_OR_PATH)\n",
    "\n",
    "    vae_encoder = vae.encoder\n",
    "    vae_decoder = vae.decoder\n",
    "\n",
    "    latent_model_input = torch.randn(2, 3, 512, 512)\n",
    "    dummy_input = latent_model_input\n",
    "\n",
    "    onnx_model_path = os.path.join(\n",
    "        ONNX_MODEL_NAME_OR_PATH, \"encoder\", \"model.onnx\"\n",
    "    )\n",
    "    onnx_model_dir = os.path.join(onnx_model_path.replace(\"model.onnx\", \"\"))\n",
    "\n",
    "    os.makedirs(onnx_model_dir, exist_ok=True)\n",
    "\n",
    "    torch.onnx.export(\n",
    "        vae_encoder,\n",
    "        dummy_input,\n",
    "        onnx_model_path,\n",
    "        export_params=True,\n",
    "        input_names=[\n",
    "            \"image\",\n",
    "        ],\n",
    "        output_names=[\"output\"],\n",
    "        dynamic_axes={\n",
    "            \"image\": {0: \"batch_size\"},  # variable length axes\n",
    "            \"output\": {0: \"batch_size\"},\n",
    "        },\n",
    "    )\n",
    "    # check the output has correct shape\n",
    "\n",
    "    onnx.checker.check_model(onnx_model_path)\n",
    "\n",
    "    latent_model_input = torch.randn(2, 4, 64, 64)\n",
    "    dummy_input = latent_model_input\n",
    "\n",
    "    onnx_model_path = os.path.join(\n",
    "        ONNX_MODEL_NAME_OR_PATH, \"decoder\", \"model.onnx\"\n",
    "    )\n",
    "    onnx_model_dir = os.path.join(onnx_model_path.replace(\"model.onnx\", \"\"))\n",
    "\n",
    "    os.makedirs(onnx_model_dir, exist_ok=True)\n",
    "\n",
    "    torch.onnx.export(\n",
    "        vae_decoder,\n",
    "        dummy_input,\n",
    "        onnx_model_path,\n",
    "        export_params=True,\n",
    "        input_names=[\n",
    "            \"latent_sample\",\n",
    "        ],\n",
    "        output_names=[\"output\"],\n",
    "        dynamic_axes={\n",
    "            \"latent_sample\": {0: \"batch_size\"},  # variable length axes\n",
    "            \"output\": {0: \"batch_size\"},\n",
    "        },\n",
    "    )\n",
    "    onnx.checker.check_model(onnx_model_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The config attributes {'mid_block_type': 'UNetMidBlock2DCrossAttn'} were passed to ControlLoRAModel, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
      "The config attributes {'mid_block_type': 'UNetMidBlock2DCrossAttn'} were passed to ControlLoRAModel, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
      "Unsupported operator aten::mul encountered 577 time(s)\n",
      "Unsupported operator aten::div encountered 205 time(s)\n",
      "Unsupported operator aten::exp encountered 7 time(s)\n",
      "Unsupported operator aten::sin encountered 7 time(s)\n",
      "Unsupported operator aten::cos encountered 7 time(s)\n",
      "Unsupported operator aten::add encountered 661 time(s)\n",
      "Unsupported operator aten::silu encountered 301 time(s)\n",
      "Unsupported operator aten::scaled_dot_product_attention encountered 116 time(s)\n",
      "Unsupported operator aten::gelu encountered 58 time(s)\n",
      "The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.\n",
      "controlnet.nets.0.controlnet_cond_embedding, controlnet.nets.0.controlnet_cond_embedding.autoencoder, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.conv_act, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.conv_in, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.conv_norm_out, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.conv_out, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.attentions.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.attentions.0.group_norm, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.attentions.0.to_k, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.attentions.0.to_out.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.attentions.0.to_out.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.attentions.0.to_q, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.attentions.0.to_v, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.resnets.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.resnets.0.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.resnets.0.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.resnets.0.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.resnets.0.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.resnets.0.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.resnets.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.resnets.1.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.resnets.1.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.resnets.1.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.resnets.1.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.mid_block.resnets.1.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.0.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.0.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.0.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.0.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.0.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.1.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.1.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.1.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.1.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.1.norm2, 
controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.2.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.2.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.2.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.2.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.resnets.2.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.upsamplers.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.0.upsamplers.0.conv, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.0.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.0.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.0.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.0.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.0.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.1.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.1.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.1.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.1.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.1.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.2.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.2.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.2.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.2.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.resnets.2.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.upsamplers.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.1.upsamplers.0.conv, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.0.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.0.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.0.conv_shortcut, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.0.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.0.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.0.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.1, 
controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.1.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.1.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.1.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.1.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.1.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.2.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.2.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.2.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.2.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.resnets.2.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.upsamplers.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.2.upsamplers.0.conv, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.0.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.0.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.0.conv_shortcut, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.0.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.0.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.0.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.1.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.1.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.1.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.1.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.1.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.2.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.2.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.2.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.2.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.decoder.up_blocks.3.resnets.2.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.conv_act, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.conv_in, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.conv_norm_out, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.conv_out, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0, 
controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.downsamplers.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.downsamplers.0.conv, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.resnets.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.resnets.0.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.resnets.0.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.resnets.0.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.resnets.0.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.resnets.0.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.resnets.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.resnets.1.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.resnets.1.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.resnets.1.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.resnets.1.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.0.resnets.1.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.downsamplers.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.downsamplers.0.conv, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.0.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.0.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.0.conv_shortcut, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.0.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.0.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.0.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.1.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.1.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.1.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.1.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.1.resnets.1.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.downsamplers.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.downsamplers.0.conv, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.0.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.0.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.0.conv_shortcut, 
controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.0.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.0.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.0.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.1.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.1.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.1.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.1.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.2.resnets.1.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3.resnets.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3.resnets.0.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3.resnets.0.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3.resnets.0.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3.resnets.0.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3.resnets.0.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3.resnets.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3.resnets.1.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3.resnets.1.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3.resnets.1.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3.resnets.1.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.down_blocks.3.resnets.1.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.attentions.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.attentions.0.group_norm, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.attentions.0.to_k, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.attentions.0.to_out.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.attentions.0.to_out.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.attentions.0.to_q, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.attentions.0.to_v, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.resnets.0, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.resnets.0.conv1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.resnets.0.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.resnets.0.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.resnets.0.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.resnets.0.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.resnets.1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.resnets.1.conv1, 
controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.resnets.1.conv2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.resnets.1.dropout, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.resnets.1.norm1, controlnet.nets.0.controlnet_cond_embedding.autoencoder.encoder.mid_block.resnets.1.norm2, controlnet.nets.0.controlnet_cond_embedding.autoencoder.post_quant_conv, controlnet.nets.0.controlnet_cond_embedding.autoencoder.quant_conv, controlnet.nets.2.controlnet_cond_embedding\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fvcore FLOPs: 2056.65 GFLOPs\n",
      "torchinfo FLOPs: 1224.51 GFLOPs\n",
      "exporting onnx model for unet and controlnets...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\EdgeStyle-main\\EdgeStyle-main\\model\\edgestyle_multicontrolnet.py:134: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).\n",
      "  zip(controlnet_cond, conditioning_scale, self.nets)\n",
      "d:\\EdgeStyle-main\\EdgeStyle-main\\model\\controllora.py:200: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  if controlnet_cond.shape[2:] != sample.shape[2:]:\n",
      "c:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\diffusers\\models\\downsampling.py:135: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  assert hidden_states.shape[1] == self.channels\n",
      "c:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\diffusers\\models\\downsampling.py:144: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  assert hidden_states.shape[1] == self.channels\n",
      "d:\\EdgeStyle-main\\EdgeStyle-main\\model\\edgestyle_multicontrolnet.py:490: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  t.size() == tensors[0].size() for t in tensors\n",
      "c:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\diffusers\\models\\unets\\unet_2d_condition.py:924: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  if dim % default_overall_up_factor != 0:\n",
      "c:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\diffusers\\models\\upsampling.py:149: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  assert hidden_states.shape[1] == self.channels\n",
      "c:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\diffusers\\models\\upsampling.py:165: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  if hidden_states.shape[0] >= 64:\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "exported onnx model for unet and controlnets...\n",
      "checking onnx model...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\onnxruntime\\capi\\onnxruntime_inference_collection.py:65: UserWarning: Specified provider 'VitisAIExecutionProvider' is not in available provider names.Available providers: 'TensorrtExecutionProvider, CUDAExecutionProvider, CPUExecutionProvider'\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "checking how close the output is...\n"
     ]
    },
    {
     "ename": "AssertionError",
     "evalue": "\nNot equal to tolerance rtol=0.001, atol=1e-05\n\nMismatched elements: 32764 / 32768 (100%)\nMax absolute difference: 3.1349723\nMax relative difference: 2354816.2\n x: array([[[[ 1.503965,  0.324597, -0.027452, ..., -0.186348,  0.139646,\n          -0.41074 ],\n         [-0.176023, -1.568672,  0.783047, ...,  0.124436, -0.284108,...\n y: array([[[[ 0.029409, -0.021185, -0.016922, ..., -0.035758, -0.05489 ,\n          -0.031097],\n         [-0.013142, -0.054783, -0.002355, ..., -0.0032  , -0.029497,...",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[12], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mexport_unet\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m      2\u001b[0m \u001b[38;5;66;03m# export_vae_encoder_decoder()\u001b[39;00m\n\u001b[0;32m      3\u001b[0m \u001b[38;5;66;03m# onnx_model_path = os.path.join(\u001b[39;00m\n\u001b[0;32m      4\u001b[0m \u001b[38;5;66;03m#     ONNX_MODEL_NAME_OR_PATH, \"unet\", \"model.onnx\"\u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m     11\u001b[0m \n\u001b[0;32m     12\u001b[0m \u001b[38;5;66;03m# print(\"Model exported and validated successfully.\")\u001b[39;00m\n",
      "File \u001b[1;32mc:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\torch\\utils\\_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[0;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m    114\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[1;32m--> 115\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "Cell \u001b[1;32mIn[11], line 333\u001b[0m, in \u001b[0;36mexport_unet\u001b[1;34m()\u001b[0m\n\u001b[0;32m    330\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mchecking how close the output is...\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m    332\u001b[0m \u001b[38;5;66;03m# compare the output error predicted_noise_torch is FloatTensor and predicted_noise_onnx is numpy array\u001b[39;00m\n\u001b[1;32m--> 333\u001b[0m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtesting\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43massert_allclose\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    334\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpredicted_noise_torch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdetach\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnumpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    335\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpredicted_noise_onnx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    336\u001b[0m \u001b[43m    \u001b[49m\u001b[43mrtol\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-03\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m    337\u001b[0m \u001b[43m    \u001b[49m\u001b[43matol\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-05\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m    338\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    339\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mexported onnx model for unet and controlnets is correct...\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
      "    \u001b[1;31m[... skipping hidden 1 frame]\u001b[0m\n",
      "File \u001b[1;32mc:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\contextlib.py:79\u001b[0m, in \u001b[0;36mContextDecorator.__call__.<locals>.inner\u001b[1;34m(*args, **kwds)\u001b[0m\n\u001b[0;32m     76\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[0;32m     77\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds):\n\u001b[0;32m     78\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_recreate_cm():\n\u001b[1;32m---> 79\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n",
      "File \u001b[1;32mc:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\numpy\\testing\\_private\\utils.py:797\u001b[0m, in \u001b[0;36massert_array_compare\u001b[1;34m(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf, strict)\u001b[0m\n\u001b[0;32m    793\u001b[0m         err_msg \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(remarks)\n\u001b[0;32m    794\u001b[0m         msg \u001b[38;5;241m=\u001b[39m build_err_msg([ox, oy], err_msg,\n\u001b[0;32m    795\u001b[0m                             verbose\u001b[38;5;241m=\u001b[39mverbose, header\u001b[38;5;241m=\u001b[39mheader,\n\u001b[0;32m    796\u001b[0m                             names\u001b[38;5;241m=\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m'\u001b[39m), precision\u001b[38;5;241m=\u001b[39mprecision)\n\u001b[1;32m--> 797\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(msg)\n\u001b[0;32m    798\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m:\n\u001b[0;32m    799\u001b[0m     \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtraceback\u001b[39;00m\n",
      "\u001b[1;31mAssertionError\u001b[0m: \nNot equal to tolerance rtol=0.001, atol=1e-05\n\nMismatched elements: 32764 / 32768 (100%)\nMax absolute difference: 3.1349723\nMax relative difference: 2354816.2\n x: array([[[[ 1.503965,  0.324597, -0.027452, ..., -0.186348,  0.139646,\n          -0.41074 ],\n         [-0.176023, -1.568672,  0.783047, ...,  0.124436, -0.284108,...\n y: array([[[[ 0.029409, -0.021185, -0.016922, ..., -0.035758, -0.05489 ,\n          -0.031097],\n         [-0.013142, -0.054783, -0.002355, ..., -0.0032  , -0.029497,..."
     ]
    }
   ],
   "source": [
    "export_unet()\n",
    "# export_vae_encoder_decoder()\n",
    "# onnx_model_path = os.path.join(\n",
    "#     ONNX_MODEL_NAME_OR_PATH, \"unet\", \"model.onnx\"\n",
    "# )\n",
    "# onnx_model_dir = os.path.join(onnx_model_path.replace(\"model.onnx\", \"\"))\n",
    "\n",
    "# import onnx\n",
    "# onnx_model = onnx.load(onnx_model_path)\n",
    "# onnx.checker.check_model(onnx_model)\n",
    "\n",
    "# print(\"Model exported and validated successfully.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "111777"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ryzenai-1.1-20240510-134534",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

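The run above ends in an AssertionError because the notebook compares the PyTorch UNet/ControlNet output against the freshly exported ONNX graph with np.testing.assert_allclose (rtol=1e-03, atol=1e-05) and the two disagree on nearly every element. The snippet below is a minimal sketch of that export-and-compare pattern on a toy module, not the project's actual export code; ToyNet, toy.onnx, and the tensor shapes are illustrative stand-ins. It also prints onnxruntime's available execution providers, which is what triggers the earlier warning when VitisAIExecutionProvider is not installed in the environment.

import numpy as np
import torch
import onnxruntime as ort

# Toy stand-in for the UNet: a single convolution is enough to show the pattern.
class ToyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = ToyNet().eval()
sample = torch.randn(1, 4, 64, 64)

# Export the traced graph to ONNX.
torch.onnx.export(
    model,
    (sample,),
    "toy.onnx",
    input_names=["sample"],
    output_names=["noise_pred"],
    opset_version=17,
)

# If VitisAIExecutionProvider is not in this list, onnxruntime warns and falls back to CPU.
print(ort.get_available_providers())
session = ort.InferenceSession("toy.onnx", providers=["CPUExecutionProvider"])

# Run both backends on the same input and compare, using the same rtol/atol
# as the failing assert in the notebook output above.
onnx_out = session.run(None, {"sample": sample.numpy()})[0]
with torch.no_grad():
    torch_out = model(sample).numpy()
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-03, atol=1e-05)
print("ONNX output matches the PyTorch reference within tolerance")
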
tester-onnx.ipynb

Python
Put this file in the EdgeStyle-main folder
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "from diffusers import AutoencoderKL, OnnxRuntimeModel, UniPCMultistepScheduler\n",
    "from diffusers.optimization import get_scheduler\n",
    "from transformers import AutoTokenizer, CLIPTextModel, CLIPModel, CLIPProcessor\n",
    "\n",
    "from model.utils import BestEmbeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "from diffusers import AutoencoderKL, OnnxRuntimeModel, UniPCMultistepScheduler\n",
    "from diffusers.optimization import get_scheduler\n",
    "from transformers import AutoTokenizer, CLIPTextModel, CLIPModel, CLIPProcessor\n",
    "\n",
    "from model.utils import BestEmbeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in C:\\Users\\amanj/.cache\\torch\\hub\\ultralytics_yolov5_master\n",
      "YOLOv5  2024-7-31 Python-3.9.18 torch-2.1.0+cpu CPU\n",
      "\n",
      "Fusing layers... \n",
      "YOLOv5s summary: 213 layers, 7225885 parameters, 0 gradients, 16.4 GFLOPs\n",
      "Adding AutoShape... \n"
     ]
    }
   ],
   "source": [
    "# local\n",
    "from model.controllora import ControlLoRAModel, CachedControlNetModel\n",
    "from model.utils import BestEmbeddings\n",
    "from extract_dataset import process_batch, create_sam_images_for_batch\n",
    "from model.edgestyle_onnx_pipeline import EdgeStyleOnnxStableDiffusionControlNetPipeline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "RESOLUTION = 512\n",
    "\n",
    "IMAGES_TRANSFORMS = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize([0.5], [0.5]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "CONDITIONING_IMAGES_TRANSFORMS = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "CONTROLNET_PATTERN = [0, None, 1, None, 1, None]\n",
    "\n",
    "device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "ename": "HFValidationError",
     "evalue": "Repo id must be in the form 'repo_name' or 'namespace/repo_name': './models/Realistic_Vision_V5.1_noVAE-onnx/text_encoder'. Use `repo_type` argument if needed.",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mHFValidationError\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[9], line 31\u001b[0m\n\u001b[0;32m     22\u001b[0m best_embeddings \u001b[38;5;241m=\u001b[39m BestEmbeddings(model, processor)\n\u001b[0;32m     25\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m AutoTokenizer\u001b[38;5;241m.\u001b[39mfrom_pretrained(\n\u001b[0;32m     26\u001b[0m     PRETRAINED_MODEL_NAME_OR_PATH,\n\u001b[0;32m     27\u001b[0m     subfolder\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m     28\u001b[0m     use_fast\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[0;32m     29\u001b[0m )\n\u001b[1;32m---> 31\u001b[0m text_encoder \u001b[38;5;241m=\u001b[39m \u001b[43mOnnxRuntimeModel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m     32\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m./models/Realistic_Vision_V5.1_noVAE-onnx/text_encoder\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\n\u001b[0;32m     33\u001b[0m \u001b[43m)\u001b[49m\n\u001b[0;32m     34\u001b[0m \u001b[38;5;66;03m# vae = AutoencoderKL.from_pretrained(PRETRAINED_VAE_NAME_OR_PATH)\u001b[39;00m\n\u001b[0;32m     35\u001b[0m \n\u001b[0;32m     36\u001b[0m \u001b[38;5;66;03m# unet = UNet2DConditionModel.from_pretrained(\u001b[39;00m\n\u001b[0;32m     37\u001b[0m \u001b[38;5;66;03m#     PRETRAINED_MODEL_NAME_OR_PATH,\u001b[39;00m\n\u001b[0;32m     38\u001b[0m \u001b[38;5;66;03m#     subfolder=\"unet\",\u001b[39;00m\n\u001b[0;32m     39\u001b[0m \u001b[38;5;66;03m# )\u001b[39;00m\n\u001b[0;32m     41\u001b[0m vae_encoder \u001b[38;5;241m=\u001b[39m OnnxRuntimeModel\u001b[38;5;241m.\u001b[39mfrom_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./models/sd-vae-ft-mse-onnx/encoder\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[1;32mc:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\huggingface_hub\\utils\\_validators.py:114\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[0;32m    112\u001b[0m     kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[1;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[1;32mc:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\diffusers\\pipelines\\onnx_utils.py:208\u001b[0m, in \u001b[0;36mOnnxRuntimeModel.from_pretrained\u001b[1;34m(cls, model_id, force_download, token, cache_dir, **model_kwargs)\u001b[0m\n\u001b[0;32m    205\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mstr\u001b[39m(model_id)\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m@\u001b[39m\u001b[38;5;124m\"\u001b[39m)) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[0;32m    206\u001b[0m     model_id, revision \u001b[38;5;241m=\u001b[39m model_id\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m@\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m--> 208\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_from_pretrained(\n\u001b[0;32m    209\u001b[0m     model_id\u001b[38;5;241m=\u001b[39mmodel_id,\n\u001b[0;32m    210\u001b[0m     revision\u001b[38;5;241m=\u001b[39mrevision,\n\u001b[0;32m    211\u001b[0m     cache_dir\u001b[38;5;241m=\u001b[39mcache_dir,\n\u001b[0;32m    212\u001b[0m     force_download\u001b[38;5;241m=\u001b[39mforce_download,\n\u001b[0;32m    213\u001b[0m     token\u001b[38;5;241m=\u001b[39mtoken,\n\u001b[0;32m    214\u001b[0m     \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs,\n\u001b[0;32m    215\u001b[0m )\n",
      "File \u001b[1;32mc:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\huggingface_hub\\utils\\_validators.py:114\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[0;32m    112\u001b[0m     kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[1;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[1;32mc:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\diffusers\\pipelines\\onnx_utils.py:181\u001b[0m, in \u001b[0;36mOnnxRuntimeModel._from_pretrained\u001b[1;34m(cls, model_id, token, revision, force_download, cache_dir, file_name, provider, sess_options, **kwargs)\u001b[0m\n\u001b[0;32m    177\u001b[0m     kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_save_dir\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m Path(model_id)\n\u001b[0;32m    178\u001b[0m \u001b[38;5;66;03m# load model from hub\u001b[39;00m\n\u001b[0;32m    179\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m    180\u001b[0m     \u001b[38;5;66;03m# download model\u001b[39;00m\n\u001b[1;32m--> 181\u001b[0m     model_cache_path \u001b[38;5;241m=\u001b[39m \u001b[43mhf_hub_download\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    182\u001b[0m \u001b[43m        \u001b[49m\u001b[43mrepo_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    183\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfilename\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_file_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    184\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    185\u001b[0m \u001b[43m        \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    186\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    187\u001b[0m \u001b[43m        \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    188\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    189\u001b[0m     kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_save_dir\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m Path(model_cache_path)\u001b[38;5;241m.\u001b[39mparent\n\u001b[0;32m    190\u001b[0m     kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlatest_model_name\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m Path(model_cache_path)\u001b[38;5;241m.\u001b[39mname\n",
      "File \u001b[1;32mc:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\huggingface_hub\\utils\\_deprecation.py:101\u001b[0m, in \u001b[0;36m_deprecate_arguments.<locals>._inner_deprecate_positional_args.<locals>.inner_f\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m     99\u001b[0m         message \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m custom_message\n\u001b[0;32m    100\u001b[0m     warnings\u001b[38;5;241m.\u001b[39mwarn(message, \u001b[38;5;167;01mFutureWarning\u001b[39;00m)\n\u001b[1;32m--> 101\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m f(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[1;32mc:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\huggingface_hub\\utils\\_validators.py:106\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m    101\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m arg_name, arg_value \u001b[38;5;129;01min\u001b[39;00m chain(\n\u001b[0;32m    102\u001b[0m     \u001b[38;5;28mzip\u001b[39m(signature\u001b[38;5;241m.\u001b[39mparameters, args),  \u001b[38;5;66;03m# Args values\u001b[39;00m\n\u001b[0;32m    103\u001b[0m     kwargs\u001b[38;5;241m.\u001b[39mitems(),  \u001b[38;5;66;03m# Kwargs values\u001b[39;00m\n\u001b[0;32m    104\u001b[0m ):\n\u001b[0;32m    105\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m arg_name \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrepo_id\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_id\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mto_id\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[1;32m--> 106\u001b[0m         \u001b[43mvalidate_repo_id\u001b[49m\u001b[43m(\u001b[49m\u001b[43marg_value\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    108\u001b[0m     \u001b[38;5;28;01melif\u001b[39;00m arg_name \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m arg_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    109\u001b[0m         has_token \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
      "File \u001b[1;32mc:\\Users\\amanj\\miniconda3\\envs\\ryzenai-1.1-20240510-134534\\lib\\site-packages\\huggingface_hub\\utils\\_validators.py:154\u001b[0m, in \u001b[0;36mvalidate_repo_id\u001b[1;34m(repo_id)\u001b[0m\n\u001b[0;32m    151\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m HFValidationError(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRepo id must be a string, not \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(repo_id)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrepo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m    153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m repo_id\u001b[38;5;241m.\u001b[39mcount(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m--> 154\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m HFValidationError(\n\u001b[0;32m    155\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRepo id must be in the form \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrepo_name\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m or \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnamespace/repo_name\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m:\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    156\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrepo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m. Use `repo_type` argument if needed.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    157\u001b[0m     )\n\u001b[0;32m    159\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m REPO_ID_REGEX\u001b[38;5;241m.\u001b[39mmatch(repo_id):\n\u001b[0;32m    160\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m HFValidationError(\n\u001b[0;32m    161\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRepo id must use alphanumeric chars or \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m-\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m--\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m and \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m..\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m are\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    162\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m forbidden, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m-\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m and \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m cannot start or end the name, max length is 96:\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    163\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrepo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m    164\u001b[0m     )\n",
      "\u001b[1;31mHFValidationError\u001b[0m: Repo id must be in the form 'repo_name' or 'namespace/repo_name': './models/Realistic_Vision_V5.1_noVAE-onnx/text_encoder'. Use `repo_type` argument if needed."
     ]
    }
   ],
   "source": [
    "PRETRAINED_MODEL_NAME_OR_PATH = \"./models/Realistic_Vision_V5.1_noVAE\"\n",
    "PRETRAINED_VAE_NAME_OR_PATH = \"./models/sd-vae-ft-mse\"\n",
    "PRETRAINED_OPENPOSE_NAME_OR_PATH = \"./models/control_v11p_sd15_openpose\"\n",
    "CONTROLNET_MODEL_NAME_OR_PATH = \"./models/EdgeStyle/controlnet\"\n",
    "CLIP_MODEL_NAME_OR_PATH = \"./models/clip-vit-large-patch14\"\n",
    "\n",
    "\n",
    "NEGATIVE_PROMPT = (\n",
    "    r\"deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, \"\n",
    "    \"anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, \"\n",
    "    \"bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, \"\n",
    "    \"mutated, ugly, disgusting, amputation,\"\n",
    ")\n",
    "\n",
    "PROMT_TO_ADD = (\n",
    "    \", gray background, RAW photo, subject, 8k uhd, dslr, soft lighting, high quality\"\n",
    ")\n",
    "\n",
    "model = CLIPModel.from_pretrained(CLIP_MODEL_NAME_OR_PATH)\n",
    "processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME_OR_PATH)\n",
    "\n",
    "best_embeddings = BestEmbeddings(model, processor)\n",
    "\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    PRETRAINED_MODEL_NAME_OR_PATH,\n",
    "    subfolder=\"tokenizer\",\n",
    "    use_fast=False,\n",
    ")\n",
    "\n",
    "text_encoder = OnnxRuntimeModel.from_pretrained(\n",
    "    \"./models/Realistic_Vision_V5.1_noVAE-onnx/text_encoder\"\n",
    ")\n",
    "# vae = AutoencoderKL.from_pretrained(PRETRAINED_VAE_NAME_OR_PATH)\n",
    "\n",
    "# unet = UNet2DConditionModel.from_pretrained(\n",
    "#     PRETRAINED_MODEL_NAME_OR_PATH,\n",
    "#     subfolder=\"unet\",\n",
    "# )\n",
    "\n",
    "vae_encoder = OnnxRuntimeModel.from_pretrained(\"./models/sd-vae-ft-mse-onnx/encoder\")\n",
    "vae_decoder = OnnxRuntimeModel.from_pretrained(\"./models/sd-vae-ft-mse-onnx/decoder\")\n",
    "\n",
    "\n",
    "unet = OnnxRuntimeModel.from_pretrained(\n",
    "    \"./models/Realistic_Vision_V5.1_noVAE-onnx/unet\",\n",
    ")\n",
    "\n",
    "\n",
    "openpose = CachedControlNetModel.from_pretrained(PRETRAINED_OPENPOSE_NAME_OR_PATH)\n",
    "\n",
    "# controlnet = EdgeStyleMultiControlNetModel.from_pretrained(\n",
    "#     CONTROLNET_MODEL_NAME_OR_PATH,\n",
    "#     vae=vae,\n",
    "#     controlnet_class=ControlLoRAModel,\n",
    "#     load_pattern=CONTROLNET_PATTERN,\n",
    "#     static_controlnets=[None, openpose, None, openpose, None, openpose],\n",
    "# )\n",
    "# for net in controlnet.nets:\n",
    "#     if net is not openpose:\n",
    "#         net.tie_weights(unet)\n",
    "\n",
    "scheduler = UniPCMultistepScheduler.from_config(\n",
    "    PRETRAINED_MODEL_NAME_OR_PATH, subfolder=\"scheduler\"\n",
    ")\n",
    "\n",
    "pipeline = EdgeStyleOnnxStableDiffusionControlNetPipeline(\n",
    "    vae_encoder=vae_encoder,\n",
    "    vae_decoder=vae_decoder,\n",
    "    text_encoder=text_encoder,\n",
    "    tokenizer=tokenizer,\n",
    "    unet=unet,\n",
    "    scheduler=scheduler,\n",
    "    safety_checker=None,\n",
    "    feature_extractor=processor,\n",
    "    requires_safety_checker=False,\n",
    ")\n",
    "# generator = torch.Generator(device).manual_seed(42)\n",
    "pipeline = pipeline.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.onnx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ryzenai-1.1-20240510-134534",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

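The HFValidationError above comes from how OnnxRuntimeModel.from_pretrained resolves its argument: the path is treated as a local folder only if that folder exists; otherwise diffusers hands it to hf_hub_download, and a relative path such as ./models/Realistic_Vision_V5.1_noVAE-onnx/text_encoder is then rejected as an invalid repo id. A small guard like the sketch below (the directory paths are the ones used in the notebook; the loop and error message are illustrative) makes a missing export from tester.ipynb fail with a clearer message.

import os
from diffusers import OnnxRuntimeModel

# Local ONNX export folders expected by tester-onnx.ipynb.
ONNX_DIRS = {
    "text_encoder": "./models/Realistic_Vision_V5.1_noVAE-onnx/text_encoder",
    "unet": "./models/Realistic_Vision_V5.1_noVAE-onnx/unet",
    "vae_encoder": "./models/sd-vae-ft-mse-onnx/encoder",
    "vae_decoder": "./models/sd-vae-ft-mse-onnx/decoder",
}

models = {}
for name, path in ONNX_DIRS.items():
    # from_pretrained uses the path directly only when the directory exists;
    # otherwise it tries the Hugging Face Hub and raises HFValidationError.
    if not os.path.isdir(path):
        raise FileNotFoundError(
            f"{name}: no ONNX export found at {path}; export it with tester.ipynb first"
        )
    models[name] = OnnxRuntimeModel.from_pretrained(path)
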

Credits

Aman Jobanputra
Thanks to Andrei Ciobanu.
