Buckets:

download
raw
104 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;减少内存使用&quot;,&quot;local&quot;:&quot;减少内存使用&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;多个 GPU&quot;,&quot;local&quot;:&quot;多个-gpu&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;分片检查点&quot;,&quot;local&quot;:&quot;分片检查点&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;设备放置&quot;,&quot;local&quot;:&quot;设备放置&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;VAE 切片&quot;,&quot;local&quot;:&quot;vae-切片&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;VAE 平铺&quot;,&quot;local&quot;:&quot;vae-平铺&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;卸载&quot;,&quot;local&quot;:&quot;卸载&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;CPU 卸载&quot;,&quot;local&quot;:&quot;cpu-卸载&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;模型卸载&quot;,&quot;local&quot;:&quot;模型卸载&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;组卸载&quot;,&quot;local&quot;:&quot;组卸载&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;CUDA 流&quot;,&quot;local&quot;:&quot;cuda-流&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4},{&quot;title&quot;:&quot;卸载到磁盘&quot;,&quot;local&quot;:&quot;卸载到磁盘&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4}],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;分层类型转换&quot;,&quot;local&quot;:&quot;分层类型转换&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;torch.channels_last&quot;,&quot;local&quot;:&quot;torchchannelslast&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;torch.jit.trace&quot;,&quot;local&quot;:&quot;torchjittrace&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;内存高效注意力&quot;,&quot;local&quot;:&quot;内存高效注意力&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/diffusers/pr_12652/zh/_app/immutable/assets/0.e3b0c442.css" rel="stylesheet">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/entry/start.ca7a833f.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/scheduler.e4ff9b64.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/singletons.71526a34.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/index.f9be34a7.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/paths.0df57e7f.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/entry/app.746b83f3.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/preload-helper.bb94e341.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/index.09f1bca0.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/nodes/0.8237e20e.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/nodes/27.ab2341a8.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.3bffcf96.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/CodeBlock.3dd9a65d.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/zh/_app/immutable/chunks/HfOption.44827c7f.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;减少内存使用&quot;,&quot;local&quot;:&quot;减少内存使用&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;多个 GPU&quot;,&quot;local&quot;:&quot;多个-gpu&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;分片检查点&quot;,&quot;local&quot;:&quot;分片检查点&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;设备放置&quot;,&quot;local&quot;:&quot;设备放置&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;VAE 切片&quot;,&quot;local&quot;:&quot;vae-切片&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;VAE 平铺&quot;,&quot;local&quot;:&quot;vae-平铺&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;卸载&quot;,&quot;local&quot;:&quot;卸载&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;CPU 卸载&quot;,&quot;local&quot;:&quot;cpu-卸载&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;模型卸载&quot;,&quot;local&quot;:&quot;模型卸载&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;组卸载&quot;,&quot;local&quot;:&quot;组卸载&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;CUDA 
流&quot;,&quot;local&quot;:&quot;cuda-流&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4},{&quot;title&quot;:&quot;卸载到磁盘&quot;,&quot;local&quot;:&quot;卸载到磁盘&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4}],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;分层类型转换&quot;,&quot;local&quot;:&quot;分层类型转换&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;torch.channels_last&quot;,&quot;local&quot;:&quot;torchchannelslast&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;torch.jit.trace&quot;,&quot;local&quot;:&quot;torchjittrace&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;内存高效注意力&quot;,&quot;local&quot;:&quot;内存高效注意力&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" 
height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="减少内存使用" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#减少内存使用"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>减少内存使用</span></h1> <p data-svelte-h="svelte-kg0q1b">现代 diffusion models,如 <a href="../api/pipelines/flux">Flux</a> 和 <a href="../api/pipelines/wan">Wan</a>,拥有数十亿参数,在您的硬件上进行推理时会占用大量内存。这是一个挑战,因为常见的 GPU 通常没有足够的内存。为了克服内存限制,您可以使用多个 
GPU(如果可用)、将一些管道组件卸载到 CPU 等。</p> <p data-svelte-h="svelte-i3qlfe">本指南将展示如何减少内存使用。</p> <blockquote class="tip" data-svelte-h="svelte-1j60o52"><p>请记住,这些技术可能需要根据模型进行调整。例如,基于 transformer 的扩散模型可能不会像基于 UNet 的模型那样从这些内存优化中同等受益。</p></blockquote> <h2 class="relative group"><a id="多个-gpu" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#多个-gpu"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>多个 GPU</span></h2> <p data-svelte-h="svelte-13on0j6">如果您有多个 GPU 的访问权限,有几种选项可以高效地在硬件上加载和分发大型模型。这些功能由 <a href="https://huggingface.co/docs/accelerate/index" rel="nofollow">Accelerate</a> 库支持,因此请确保先安装它。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install -U accelerate<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="分片检查点" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#分片检查点"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>分片检查点</span></h3> <p data-svelte-h="svelte-1080qfz">将大型检查点加载到多个分片中很有用,因为分片会逐个加载。这保持了低内存使用,只需要足够的内存来容纳模型大小和最大分片大小。我们建议当 fp32 检查点大于 5GB 时进行分片。默认分片大小为 5GB。</p> <p data-svelte-h="svelte-1ji9awf">在 <code>save_pretrained()</code> 
中使用 <code>max_shard_size</code> 参数对检查点进行分片。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel
unet = AutoModel.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>, subfolder=<span class="hljs-string">&quot;unet&quot;</span>
)
unet.save_pretrained(<span class="hljs-string">&quot;sdxl-unet-sharded&quot;</span>, max_shard_size=<span class="hljs-string">&quot;5GB&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1q2d295">现在您可以使用分片检查点,而不是常规检查点,以节省内存。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel, StableDiffusionXLPipeline
unet = AutoModel.from_pretrained(
<span class="hljs-string">&quot;username/sdxl-unet-sharded&quot;</span>, torch_dtype=torch.float16
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>,
unet=unet,
torch_dtype=torch.float16
).to(<span class="hljs-string">&quot;cuda&quot;</span>)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="设备放置" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#设备放置"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>设备放置</span></h3> <blockquote class="warning" data-svelte-h="svelte-1v88325"><p>设备放置是一个实验性功能,API 可能会更改。目前仅支持 <code>balanced</code> 策略。我们计划在未来支持额外的映射策略。</p></blockquote> <p data-svelte-h="svelte-1yx7fp"><code>device_map</code> 参数控制如何将管道中的模型组件或
单个模型中的层分布在多个设备上。</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">pipeline level </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">model level </div></div> <div class="language-select"><p data-svelte-h="svelte-on9lpr"><code>balanced</code> 设备放置策略将管道均匀分割到所有可用设备上。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel, StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>,
torch_dtype=torch.float16,
device_map=<span class="hljs-string">&quot;balanced&quot;</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-7k1fi2">您可以使用 <code>hf_device_map</code> 检查管道的设备映射。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">print</span>(pipeline.hf_device_map)
{<span class="hljs-string">&#x27;unet&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;vae&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;safety_checker&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;text_encoder&#x27;</span>: <span class="hljs-number">0</span>}<!-- HTML_TAG_END --></pre></div> </div> <p data-svelte-h="svelte-5mk4n7">当设计您自己的 <code>device_map</code> 时,它应该是一个字典,包含模型的特定模块名称或层以及设备标识符(整数表示 GPU,<code>cpu</code> 表示 CPU,<code>disk</code> 表示磁盘)。</p> <p data-svelte-h="svelte-n6hxwz">在模型上调用 <code>hf_device_map</code> 以查看模型层如何分布,然后设计您自己的映射。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">print</span>(transformer.hf_device_map)
{<span class="hljs-string">&#x27;pos_embed&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;time_text_embed&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;context_embedder&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;x_embedder&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;transformer_blocks&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.0&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.1&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.2&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.3&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.4&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.5&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.6&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.7&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.8&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.9&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.10&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.11&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.12&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span 
class="hljs-string">&#x27;single_transformer_blocks.13&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.14&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.15&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.16&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.17&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.18&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.19&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.20&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.21&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.22&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.23&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.24&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.25&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.26&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.27&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.28&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span 
class="hljs-string">&#x27;single_transformer_blocks.29&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.30&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.31&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.32&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.33&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.34&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.35&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.36&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.37&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;norm_out&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;proj_out&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-yo76tu">例如,下面的 <code>device_map</code> 将 <code>single_transformer_blocks.10</code> 到 <code>single_transformer_blocks.20</code> 放置在第二个 GPU(<code>1</code>)上。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" 
preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel
device_map = {
<span class="hljs-string">&#x27;pos_embed&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;time_text_embed&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;context_embedder&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;x_embedder&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;transformer_blocks&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.0&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.1&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.2&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.3&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.4&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.5&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.6&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.7&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.8&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.9&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.10&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.11&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.12&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.13&#x27;</span>: <span 
class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.14&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.15&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.16&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.17&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.18&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.19&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.20&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.21&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.22&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.23&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.24&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.25&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.26&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.27&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.28&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.29&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span 
class="hljs-string">&#x27;single_transformer_blocks.30&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.31&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.32&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.33&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.34&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.35&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.36&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.37&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;norm_out&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;proj_out&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>
}
transformer = AutoModel.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>,
subfolder=<span class="hljs-string">&quot;transformer&quot;</span>,
device_map=device_map,
torch_dtype=torch.bfloat16
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ze8jh9">传递一个字典,将最大内存使用量映射到每个设备以强制执行限制。如果设备不在 <code>max_memory</code> 中,它将被忽略,管道组件不会分发到该设备。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel, StableDiffusionXLPipeline
max_memory = {<span class="hljs-number">0</span>:<span class="hljs-string">&quot;1GB&quot;</span>, <span class="hljs-number">1</span>:<span class="hljs-string">&quot;1GB&quot;</span>}
pipeline = StableDiffusionXLPipeline.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>,
torch_dtype=torch.float16,
device_map=<span class="hljs-string">&quot;balanced&quot;</span>,
max_memory=max_memory
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1fbdini">Diffusers 默认使用所有设备的最大内存,但如果它们无法适应 GPU,则需要使用单个 GPU 并通过以下方法卸载到 CPU。</p> <ul data-svelte-h="svelte-15h2fme"><li><code>enable_model_cpu_offload()</code> 仅适用于单个 GPU,但非常大的模型可能无法适应它</li> <li>使用 <code>enable_sequential_cpu_offload()</code> 可能有效,但它极其缓慢,并且仅限于单个 GPU。</li></ul> <p data-svelte-h="svelte-t6jreo">使用 <code>reset_device_map()</code> 方法来重置 <code>device_map</code>。如果您想在已进行设备映射的管道上使用 <code>.to()</code>、<code>enable_sequential_cpu_offload()</code> 和 <code>enable_model_cpu_offload()</code> 等方法,这是必要的。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline.reset_device_map()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="vae-切片" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 
with-hover:group-hover:opacity-100 with-hover:right-full" href="#vae-切片"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>VAE 切片</span></h2> <p data-svelte-h="svelte-hhjzz5">VAE 切片通过将大批次输入拆分为单个数据批次并分别处理它们来节省内存。这种方法在同时生成多个图像时效果最佳。</p> <p data-svelte-h="svelte-azd05b">例如,如果您同时生成 4 个图像,解码会将峰值激活内存增加 4 倍。VAE 切片通过一次只解码 1 个图像而不是所有 4 个图像来减少这种情况。</p> <p data-svelte-h="svelte-1jcre3h">调用 <code>enable_vae_slicing()</code> 来启用切片 VAE。您可以预期在解码多图像批次时性能会有小幅提升,而在单图像批次时没有性能影响。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none 
transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel, StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>,
torch_dtype=torch.float16,
).to(<span class="hljs-string">&quot;cuda&quot;</span>)
pipeline.enable_vae_slicing()
pipeline([<span class="hljs-string">&quot;An astronaut riding a horse on Mars&quot;</span>]*<span class="hljs-number">32</span>).images[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Max memory reserved: <span class="hljs-subst">{torch.cuda.max_memory_allocated() / <span class="hljs-number">1024</span>**<span class="hljs-number">3</span>:<span class="hljs-number">.2</span>f}</span> GB&quot;</span>)<!-- HTML_TAG_END --></pre></div> <blockquote class="warning" data-svelte-h="svelte-u9dy4s"><p><code>AutoencoderKLWan</code> 和 <code>AsymmetricAutoencoderKL</code> 类不支持切片。</p></blockquote> <h2 class="relative group"><a id="vae-平铺" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#vae-平铺"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>VAE 平铺</span></h2> <p data-svelte-h="svelte-1n49s7i">VAE 平铺通过将图像划分为较小的重叠图块而不是一次性处理整个图像来节省内存。这也减少了峰值内存使用量,因为 GPU 一次只处理一个图块。</p> <p data-svelte-h="svelte-9l3su6">调用 <code>enable_vae_tiling()</code> 来启用 VAE 平铺。生成的图像可能因图块到图块的色调变化而有所不同,因为它们被单独解码,但图块之间不应有明显的接缝。对于低于预设(但可配置)限制的分辨率,平铺被禁用。例如,对于 <code>StableDiffusionPipeline</code> 中的 VAE,此限制为 512x512。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer 
focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoPipelineForImage2Image
<span class="hljs-keyword">from</span> diffusers.utils <span class="hljs-keyword">import</span> load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>, torch_dtype=torch.float16
).to(<span class="hljs-string">&quot;cuda&quot;</span>)
pipeline.enable_vae_tiling()
init_image = load_image(<span class="hljs-string">&quot;https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png&quot;</span>)
prompt = <span class="hljs-string">&quot;Astronaut in a jungle, cold color palette, muted colors, detailed, 8k&quot;</span>
pipeline(prompt, image=init_image, strength=<span class="hljs-number">0.5</span>).images[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Max memory reserved: <span class="hljs-subst">{torch.cuda.max_memory_allocated() / <span class="hljs-number">1024</span>**<span class="hljs-number">3</span>:<span class="hljs-number">.2</span>f}</span> GB&quot;</span>)<!-- HTML_TAG_END --></pre></div> <blockquote class="warning" data-svelte-h="svelte-1w22rhs"><p><code>AutoencoderKLWan</code> 和 <code>AsymmetricAutoencoderKL</code> 不支持平铺。</p></blockquote> <h2 class="relative group"><a id="卸载" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#卸载"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>卸载</span></h2>
<p data-svelte-h="svelte-18bl8b2">卸载策略将当前不活动的层或模型移动到 CPU,以避免增加 GPU 内存。这些策略可以与量化和 torch.compile 结合使用,以平衡推理速度和内存使用。</p> <p data-svelte-h="svelte-1o67dnx">有关更多详细信息,请参考 <a href="./speed-memory-optims">编译和卸载量化模型</a> 指南。</p> <h3 class="relative group"><a id="cpu-卸载" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#cpu-卸载"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>CPU 卸载</span></h3> <p data-svelte-h="svelte-8sy35u">CPU 卸载选择性地将权重从 GPU 移动到 CPU。当需要某个组件时,它被传输到 GPU;当不需要时,它被移动到 CPU。此方法作用于子模块而非整个模型。它通过避免将整个模型存储在 GPU 上来节省内存。</p> <p data-svelte-h="svelte-19u9hu8">CPU 卸载显著减少内存使用,但由于子模块在设备之间多次来回传递,它也非常慢。由于速度极慢,它通常不实用。</p> <blockquote class="warning" data-svelte-h="svelte-14l2uyw"><p>在调用 <code>enable_sequential_cpu_offload()</code> 之前,不要将管道移动到 CUDA,否则节省的内存非常有限(更多细节请参考此 <a href="https://github.com/huggingface/diffusers/issues/1934" rel="nofollow">issue</a>)。这是一个状态操作,会在模型上安装钩子。</p></blockquote> <p data-svelte-h="svelte-1u58nwg">调用 <code>enable_sequential_cpu_offload()</code> 以在管道上启用它。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition 
duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-schnell&quot;</span>, torch_dtype=torch.bfloat16
)
pipeline.enable_sequential_cpu_offload()
pipeline(
prompt=<span class="hljs-string">&quot;An astronaut riding a horse on Mars&quot;</span>,
guidance_scale=<span class="hljs-number">0.</span>,
height=<span class="hljs-number">768</span>,
width=<span class="hljs-number">1360</span>,
num_inference_steps=<span class="hljs-number">4</span>,
max_sequence_length=<span class="hljs-number">256</span>,
).images[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Max memory reserved: <span class="hljs-subst">{torch.cuda.max_memory_allocated() / <span class="hljs-number">1024</span>**<span class="hljs-number">3</span>:<span class="hljs-number">.2</span>f}</span> GB&quot;</span>)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="模型卸载" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#模型卸载"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>模型卸载</span></h3> <p data-svelte-h="svelte-9x85p5">模型卸载将整个模型移动到 GPU,而不是选择性地移动某些层或模型组件。每次只有一个主要管道模型(通常是文本编码器、UNet 或 VAE)被放置在 GPU 上,而其他组件保持在 CPU 上。像 UNet 这样运行多次的组件会一直留在 GPU 上,直到完全完成且不再需要。这消除了 <a href="#cpu-卸载">CPU 卸载</a> 的通信开销,使模型卸载成为一个更快的替代方案。权衡是内存节省不会那么大。</p> <blockquote class="warning" data-svelte-h="svelte-6420if"><p>请注意,如果在安装钩子后模型在管道外部被重用(更多细节请参考 <a href="https://huggingface.co/docs/accelerate/en/package_reference/big_modeling#accelerate.hooks.remove_hook_from_module" rel="nofollow">移除钩子</a>),您需要按预期顺序运行整个管道和模型以正确卸载它们。这是一个状态操作,会在模型上安装钩子。</p></blockquote> <p data-svelte-h="svelte-1l12c2i">调用 <code>enable_model_cpu_offload()</code> 以在管道上启用它。</p> <div 
class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-schnell&quot;</span>, torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()
pipeline(
prompt=<span class="hljs-string">&quot;An astronaut riding a horse on Mars&quot;</span>,
guidance_scale=<span class="hljs-number">0.</span>,
height=<span class="hljs-number">768</span>,
width=<span class="hljs-number">1360</span>,
num_inference_steps=<span class="hljs-number">4</span>,
max_sequence_length=<span class="hljs-number">256</span>,
).images[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;最大内存保留: <span class="hljs-subst">{torch.cuda.max_memory_allocated() / <span class="hljs-number">1024</span>**<span class="hljs-number">3</span>:<span class="hljs-number">.2</span>f}</span> GB&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1k1me1b"><code>enable_model_cpu_offload()</code> 在您单独使用 <code>encode_prompt()</code> 方法生成文本编码器隐藏状态时也有帮助。</p> <h3 class="relative group"><a id="组卸载" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#组卸载"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>组卸载</span></h3> <p data-svelte-h="svelte-19yi3hk">组卸载将内部层组(<a href="https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html" rel="nofollow">torch.nn.ModuleList</a> 或 <a href="https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html" rel="nofollow">torch.nn.Sequential</a>)移动到 CPU。它比<a href="#模型卸载">模型卸载</a>使用更少的内存,并且比<a href="#cpu-卸载">CPU 卸载</a>更快,因为它减少了通信开销。</p> <blockquote class="warning" data-svelte-h="svelte-1ql5wwf"><p>如果前向实现包含权重相关的输入设备转换,组卸载可能不适用于所有模型,因为它可能与组卸载的设备转换机制冲突。</p></blockquote> <p 
data-svelte-h="svelte-eltwhr">调用 <code>enable_group_offload()</code> 为继承自 <code>ModelMixin</code> 的标准 Diffusers 模型组件启用它。对于不继承自 <code>ModelMixin</code> 的其他模型组件,例如通用 <a href="https://pytorch.org/docs/stable/generated/torch.nn.Module.html" rel="nofollow">torch.nn.Module</a>,使用 <code>apply_group_offloading()</code> 代替。</p> <p data-svelte-h="svelte-14zzrw4"><code>offload_type</code> 参数可以设置为 <code>block_level</code> 或 <code>leaf_level</code>。</p> <ul data-svelte-h="svelte-1ray78r"><li><code>block_level</code> 基于 <code>num_blocks_per_group</code> 参数卸载层组。例如,如果 <code>num_blocks_per_group=2</code> 在一个有 40 层的模型上,每次加载和卸载 2 层(总共 20 次加载/卸载)。这大大减少了内存需求。</li> <li><code>leaf_level</code> 在最低级别卸载单个层,等同于<a href="#cpu-卸载">CPU 卸载</a>。但如果您使用流而不放弃推理速度,它可以更快。</li></ul> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span 
class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> CogVideoXPipeline
<span class="hljs-keyword">from</span> diffusers.hooks <span class="hljs-keyword">import</span> apply_group_offloading
<span class="hljs-keyword">from</span> diffusers.utils <span class="hljs-keyword">import</span> export_to_video
onload_device = torch.device(<span class="hljs-string">&quot;cuda&quot;</span>)
offload_device = torch.device(<span class="hljs-string">&quot;cpu&quot;</span>)
pipeline = CogVideoXPipeline.from_pretrained(<span class="hljs-string">&quot;THUDM/CogVideoX-5b&quot;</span>, torch_dtype=torch.bfloat16)
<span class="hljs-comment"># 对 Diffusers 模型实现使用 enable_group_offload 方法</span>
pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=<span class="hljs-string">&quot;leaf_level&quot;</span>)
pipeline.vae.enable_group_offload(onload_device=onload_device, offload_type=<span class="hljs-string">&quot;leaf_level&quot;</span>)
<span class="hljs-comment"># 对其他模型组件使用 apply_group_offloading 方法</span>
apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type=<span class="hljs-string">&quot;block_level&quot;</span>, num_blocks_per_group=<span class="hljs-number">2</span>)
prompt = (
<span class="hljs-string">&quot;A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. &quot;</span>
<span class="hljs-string">&quot;The panda&#x27;s fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other &quot;</span>
<span class="hljs-string">&quot;pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, &quot;</span>
<span class="hljs-string">&quot;casting a gentle glow on the scene. The panda&#x27;s face is expressive, showing concentration and joy as it plays. &quot;</span>
<span class="hljs-string">&quot;The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical &quot;</span>
<span class="hljs-string">&quot;atmosphere of this unique musical performance.&quot;</span>
)
video = pipeline(prompt=prompt, guidance_scale=<span class="hljs-number">6</span>, num_inference_steps=<span class="hljs-number">50</span>).frames[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Max memory reserved: <span class="hljs-subst">{torch.cuda.max_memory_allocated() / <span class="hljs-number">1024</span>**<span class="hljs-number">3</span>:<span class="hljs-number">.2</span>f}</span> GB&quot;</span>)
export_to_video(video, <span class="hljs-string">&quot;output.mp4&quot;</span>, fps=<span class="hljs-number">8</span>)<!-- HTML_TAG_END --></pre></div> <h4 class="relative group"><a id="cuda-流" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#cuda-流"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>CUDA 流</span></h4> <p data-svelte-h="svelte-pfde8s">对于支持异步数据传输流的 CUDA 设备,可以启用 <code>use_stream</code> 参数,与 <a href="#cpu-卸载">CPU 卸载</a> 相比可以减少整体执行时间。它通过使用层预取重叠数据传输和计算。下一个要执行的层在当前层仍在执行时加载到 GPU 上。这会显著增加 CPU 内存,因此请确保您有模型大小的 2 倍内存。</p> <p data-svelte-h="svelte-1vw5frn">设置 <code>record_stream=True</code> 以获得更多速度提升,代价是内存使用量略有增加。请参阅 <a href="https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html" rel="nofollow">torch.Tensor.record_stream</a> 文档了解更多信息。</p> <blockquote class="tip" data-svelte-h="svelte-14tpl6d"><p>在启用平铺的 VAE 上使用 <code>use_stream=True</code> 时,确保在推理前进行虚拟前向传递(可以使用虚拟输入),以避免设备不匹配错误。这可能不适用于所有实现,因此如果遇到任何问题,请随时提出问题。</p></blockquote> <p data-svelte-h="svelte-w0fj4y">如果您在启用 <code>use_stream</code> 的情况下使用 <code>block_level</code> 组卸载,<code>num_blocks_per_group</code> 参数应设置为 <code>1</code>,否则会引发警告。</p> <div 
class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=<span class="hljs-string">&quot;leaf_level&quot;</span>, use_stream=<span class="hljs-literal">True</span>, record_stream=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-t6xcks"><code>low_cpu_mem_usage</code> 参数可以设置为 <code>True</code>,以在使用流进行组卸载时减少 CPU 内存使用。它最适合 <code>leaf_level</code> 卸载和 CPU 内存瓶颈的情况。通过动态创建固定张量而不是预先固定它们来节省内存。然而,这可能会增加整体执行时间。</p> <h4 class="relative group"><a id="卸载到磁盘" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#卸载到磁盘"><span><svg class="" xmlns="http://www.w3.org/2000/svg" 
xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>卸载到磁盘</span></h4> <p data-svelte-h="svelte-1nq3fej">组卸载可能会消耗大量系统内存,具体取决于模型大小。在内存有限的系统上,尝试将组卸载到磁盘作为辅助内存。</p> <p data-svelte-h="svelte-1hbe5xk">在 <code>enable_group_offload()</code> 或 <code>apply_group_offloading()</code> 中设置 <code>offload_to_disk_path</code> 参数,将模型卸载到磁盘。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform 
-translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=<span class="hljs-string">&quot;leaf_level&quot;</span>, offload_to_disk_path=<span class="hljs-string">&quot;path/to/disk&quot;</span>)
apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type=<span class="hljs-string">&quot;block_level&quot;</span>, num_blocks_per_group=<span class="hljs-number">2</span>, offload_to_disk_path=<span class="hljs-string">&quot;path/to/disk&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-17y83c8">参考这些<a href="https://github.com/huggingface/diffusers/pull/11682#issue-3129365363" rel="nofollow">两个</a><a href="https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126" rel="nofollow">表格</a>来比较速度和内存的权衡。</p> <h2 class="relative group"><a id="分层类型转换" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#分层类型转换"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>分层类型转换</span></h2> <blockquote class="tip" data-svelte-h="svelte-1rchtg8"><p>将分层类型转换与<a href="#组卸载">组卸载</a>结合使用,以获得更多内存节省。</p></blockquote> <p data-svelte-h="svelte-1f732i8">分层类型转换将权重存储在较小的数据格式中(例如 <code>torch.float8_e4m3fn</code> 或 <code>torch.float8_e5m2</code>),以使用更少的内存,并在计算时将这些权重向上转换为更高精度,如 <code>torch.float16</code> 或 <code>torch.bfloat16</code>。某些层(归一化和调制相关权重)被跳过,因为将它们存储在 fp8 中可能会降低生成质量。</p> <blockquote 
class="warning" data-svelte-h="svelte-hsn5v6"><p>如果前向实现包含权重的内部类型转换,分层类型转换可能不适用于所有模型。当前的分层类型转换实现假设前向传递独立于权重精度,并且输入数据类型始终在 <code>compute_dtype</code> 中指定(请参见<a href="https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299" rel="nofollow">这里</a>以获取不兼容的实现)。</p> <p>分层类型转换也可能在使用<a href="https://huggingface.co/docs/peft/index" rel="nofollow">PEFT</a>层的自定义建模实现上失败。有一些检查可用,但它们没有经过广泛测试或保证在所有情况下都能工作。</p></blockquote> <p data-svelte-h="svelte-zead05">调用 <code>enable_layerwise_casting()</code> 来设置存储和计算数据类型。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> CogVideoXPipeline, CogVideoXTransformer3DModel
<span class="hljs-keyword">from</span> diffusers.utils <span class="hljs-keyword">import</span> export_to_video
transformer = CogVideoXTransformer3DModel.from_pretrained(
<span class="hljs-string">&quot;THUDM/CogVideoX-5b&quot;</span>,
subfolder=<span class="hljs-string">&quot;transformer&quot;</span>,
torch_dtype=torch.bfloat16
)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipeline = CogVideoXPipeline.from_pretrained(<span class="hljs-string">&quot;THUDM/CogVideoX-5b&quot;</span>,
transformer=transformer,
torch_dtype=torch.bfloat16
).to(<span class="hljs-string">&quot;cuda&quot;</span>)
prompt = (
<span class="hljs-string">&quot;A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. &quot;</span>
<span class="hljs-string">&quot;The panda&#x27;s fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other &quot;</span>
<span class="hljs-string">&quot;pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, &quot;</span>
<span class="hljs-string">&quot;casting a gentle glow on the scene. The panda&#x27;s face is expressive, showing concentration and joy as it plays. &quot;</span>
<span class="hljs-string">&quot;The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical &quot;</span>
<span class="hljs-string">&quot;atmosphere of this unique musical performance.&quot;</span>
)
video = pipeline(prompt=prompt, guidance_scale=<span class="hljs-number">6</span>, num_inference_steps=<span class="hljs-number">50</span>).frames[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Max memory reserved: <span class="hljs-subst">{torch.cuda.max_memory_allocated() / <span class="hljs-number">1024</span>**<span class="hljs-number">3</span>:<span class="hljs-number">.2</span>f}</span> GB&quot;</span>)
export_to_video(video, <span class="hljs-string">&quot;output.mp4&quot;</span>, fps=<span class="hljs-number">8</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-txkzq8"><code>apply_layerwise_casting()</code> 方法也可以在您需要更多控制和灵活性时使用。它可以通过在特定内部模块上调用它来部分应用于模型层。使用 <code>skip_modules_pattern</code> 和 <code>skip_modules_classes</code> 参数来指定要避免的模块,例如归一化和调制层。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> CogVideoXTransformer3DModel
<span class="hljs-keyword">from</span> diffusers.hooks <span class="hljs-keyword">import</span> apply_layerwise_casting
transformer = CogVideoXTransformer3DModel.from_pretrained(
<span class="hljs-string">&quot;THUDM/CogVideoX-5b&quot;</span>,
subfolder=<span class="hljs-string">&quot;transformer&quot;</span>,
torch_dtype=torch.bfloat16
)
<span class="hljs-comment"># 跳过归一化层</span>
apply_layerwise_casting(
transformer,
storage_dtype=torch.float8_e4m3fn,
compute_dtype=torch.bfloat16,
skip_modules_classes=[<span class="hljs-string">&quot;norm&quot;</span>],
non_blocking=<span class="hljs-literal">True</span>,
)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="torchchannelslast" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchchannelslast"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>torch.channels_last</span></h2> <p data-svelte-h="svelte-8d4g9i"><a href="https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html" rel="nofollow">torch.channels_last</a> 将张量的存储方式从 <code>(批次大小, 通道数, 高度, 宽度)</code> 翻转为 <code>(批次大小, 高度, 宽度, 通道数)</code>。这使张量与硬件如何顺序访问存储在内存中的张量对齐,并避免了在内存中跳转以访问像素值。</p> <p data-svelte-h="svelte-kbl8lh">并非所有运算符当前都支持通道最后格式,并且可能导致性能更差,但仍然值得尝试。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path 
d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">print</span>(pipeline.unet.conv_out.state_dict()[<span class="hljs-string">&quot;weight&quot;</span>].stride()) <span class="hljs-comment"># (2880, 9, 3, 1)</span>
pipeline.unet.to(memory_format=torch.channels_last) <span class="hljs-comment"># 原地操作</span>
<span class="hljs-built_in">print</span>(
pipeline.unet.conv_out.state_dict()[<span class="hljs-string">&quot;weight&quot;</span>].stride()
) <span class="hljs-comment"># (2880, 1, 960, 320) 第二个维度的跨度为1证明它有效</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="torchjittrace" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchjittrace"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>torch.jit.trace</span></h2> <p data-svelte-h="svelte-1feyw5u"><a href="https://pytorch.org/docs/stable/generated/torch.jit.trace.html" rel="nofollow">torch.jit.trace</a> 记录模型在样本输入上执行的操作,并根据记录的执行路径创建一个新的、优化的模型表示。在跟踪过程中,模型被优化以减少来自Python和动态控制流的开销,并且操作被融合在一起以提高效率。返回的可执行文件或 <a href="https://pytorch.org/docs/stable/generated/torch.jit.ScriptFunction.html" rel="nofollow">ScriptFunction</a> 可以被编译。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 
32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> time
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> StableDiffusionPipeline
<span class="hljs-keyword">import</span> functools
<span class="hljs-comment"># torch 禁用梯度</span>
torch.set_grad_enabled(<span class="hljs-literal">False</span>)
<span class="hljs-comment"># 设置变量</span>
n_experiments = <span class="hljs-number">2</span>
unet_runs_per_experiment = <span class="hljs-number">50</span>
<span class="hljs-comment"># 加载样本输入</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">generate_inputs</span>():
    sample = torch.randn((<span class="hljs-number">2</span>, <span class="hljs-number">4</span>, <span class="hljs-number">64</span>, <span class="hljs-number">64</span>), device=<span class="hljs-string">&quot;cuda&quot;</span>, dtype=torch.float16)
    timestep = torch.rand(<span class="hljs-number">1</span>, device=<span class="hljs-string">&quot;cuda&quot;</span>, dtype=torch.float16) * <span class="hljs-number">999</span>
    encoder_hidden_states = torch.randn((<span class="hljs-number">2</span>, <span class="hljs-number">77</span>, <span class="hljs-number">768</span>), device=<span class="hljs-string">&quot;cuda&quot;</span>, dtype=torch.float16)
    <span class="hljs-keyword">return</span> sample, timestep, encoder_hidden_states
pipeline = StableDiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;stable-diffusion-v1-5/stable-diffusion-v1-5&quot;</span>,
torch_dtype=torch.float16,
use_safetensors=<span class="hljs-literal">True</span>,
).to(<span class="hljs-string">&quot;cuda&quot;</span>)
unet = pipeline.unet
unet.<span class="hljs-built_in">eval</span>()
unet.to(memory_format=torch.channels_last) <span class="hljs-comment"># 使用 channels_last 内存格式</span>
unet.forward = functools.partial(unet.forward, return_dict=<span class="hljs-literal">False</span>) <span class="hljs-comment"># 设置 return_dict=False 为默认</span>
<span class="hljs-comment"># 预热</span>
<span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">3</span>):
    <span class="hljs-keyword">with</span> torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet(*inputs)
<span class="hljs-comment"># 追踪</span>
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;tracing..&quot;</span>)
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.<span class="hljs-built_in">eval</span>()
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;done tracing&quot;</span>)
<span class="hljs-comment"># 预热和优化图</span>
<span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">5</span>):
    <span class="hljs-keyword">with</span> torch.inference_mode():
        inputs = generate_inputs()
        orig_output = unet_traced(*inputs)
<span class="hljs-comment"># 基准测试</span>
<span class="hljs-keyword">with</span> torch.inference_mode():
    <span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        <span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(unet_runs_per_experiment):
            orig_output = unet_traced(*inputs)
        torch.cuda.synchronize()
        <span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;unet traced inference took <span class="hljs-subst">{time.time() - start_time:<span class="hljs-number">.2</span>f}</span> seconds&quot;</span>)
    <span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(n_experiments):
        torch.cuda.synchronize()
        start_time = time.time()
        <span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(unet_runs_per_experiment):
            orig_output = unet(*inputs)
        torch.cuda.synchronize()
        <span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;unet inference took <span class="hljs-subst">{time.time() - start_time:<span class="hljs-number">.2</span>f}</span> seconds&quot;</span>)
<span class="hljs-comment"># 保存模型</span>
unet_traced.save(<span class="hljs-string">&quot;unet_traced.pt&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-qm13iw">替换管道的 UNet 为追踪版本。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> StableDiffusionPipeline
<span class="hljs-keyword">from</span> dataclasses <span class="hljs-keyword">import</span> dataclass
<span class="hljs-meta">@dataclass</span>
<span class="hljs-keyword">class</span> <span class="hljs-title class_">UNet2DConditionOutput</span>:
    sample: torch.Tensor
pipeline = StableDiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;stable-diffusion-v1-5/stable-diffusion-v1-5&quot;</span>,
torch_dtype=torch.float16,
use_safetensors=<span class="hljs-literal">True</span>,
).to(<span class="hljs-string">&quot;cuda&quot;</span>)
<span class="hljs-comment"># 使用 jitted unet</span>
unet_traced = torch.jit.load(<span class="hljs-string">&quot;unet_traced.pt&quot;</span>)
<span class="hljs-comment"># del pipeline.unet</span>
<span class="hljs-keyword">class</span> <span class="hljs-title class_">TracedUNet</span>(torch.nn.Module):
    <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self</span>):
        <span class="hljs-built_in">super</span>().__init__()
        self.in_channels = pipeline.unet.config.in_channels
        self.device = pipeline.unet.device
    <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, latent_model_input, t, encoder_hidden_states</span>):
        sample = unet_traced(latent_model_input, t, encoder_hidden_states)[<span class="hljs-number">0</span>]
        <span class="hljs-keyword">return</span> UNet2DConditionOutput(sample=sample)
pipeline.unet = TracedUNet()
<span class="hljs-keyword">with</span> torch.inference_mode():
    image = pipeline([prompt] * <span class="hljs-number">1</span>, num_inference_steps=<span class="hljs-number">50</span>).images[<span class="hljs-number">0</span>]<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="内存高效注意力" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#内存高效注意力"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>内存高效注意力</span></h2> <blockquote class="tip" data-svelte-h="svelte-mizp64"><p>内存高效注意力同时优化内存使用<em>和</em><a href="./fp16#scaled-dot-product-attention">推理速度</a>!</p></blockquote> <p data-svelte-h="svelte-1m5qcbu">Transformers 注意力机制是内存密集型的,尤其对于长序列,因此您可以尝试使用不同且更内存高效的注意力类型。</p> <p data-svelte-h="svelte-5jk4ak">默认情况下,如果安装了 PyTorch &gt;= 2.0,则使用 <a href="https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html" rel="nofollow">scaled dot-product attention (SDPA)</a>。您无需对代码进行任何额外更改。</p> <p data-svelte-h="svelte-1f4dy2o">SDPA 还支持 <a href="https://github.com/Dao-AILab/flash-attention" rel="nofollow">FlashAttention</a> 和 <a href="https://github.com/facebookresearch/xformers" rel="nofollow">xFormers</a>,以及一个原生的 C++ PyTorch 实现。它会根据您的输入自动选择最优的实现。</p> <p data-svelte-h="svelte-b9wnp4">您可以使用 <code>enable_xformers_memory_efficient_attention()</code> 方法显式地使用 xFormers。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># pip install xformers</span>
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>,
torch_dtype=torch.float16,
).to(<span class="hljs-string">&quot;cuda&quot;</span>)
pipeline.enable_xformers_memory_efficient_attention()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-go6u4e">调用 <code>disable_xformers_memory_efficient_attention()</code> 来禁用它。</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline.disable_xformers_memory_efficient_attention()<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/diffusers/blob/main/docs/source/zh/optimization/memory.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path 
d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
// Client bootstrap for this docs page: publish the path configuration, then
// load the two entry modules and start the app on this script's parent element.
// Assigned without a declaration keyword, so in this classic <script> it
// becomes a global. NOTE(review): presumably read by the imported entry
// modules — confirm against the build output.
__sveltekit_oaf0l2 = {
assets: "/docs/diffusers/pr_12652/zh",
base: "/docs/diffusers/pr_12652/zh",
env: {}
};
// Must be read synchronously while this script is executing; the value is
// captured here and used later inside the asynchronous .then() callback.
const element = document.currentScript.parentElement;
// Two null placeholders. NOTE(review): looks like one entry per id in
// node_ids below — confirm.
const data = [null,null];
// Fetch both entry modules in parallel; kit.start runs only after both resolve.
Promise.all([
import("/docs/diffusers/pr_12652/zh/_app/immutable/entry/start.ca7a833f.js"),
import("/docs/diffusers/pr_12652/zh/_app/immutable/entry/app.746b83f3.js")
]).then(([kit, app]) => {
// No .catch() is attached, so a failed module load surfaces as an
// unhandled promise rejection.
kit.start(app, element, {
// NOTE(review): presumably the layout/page node indices for this route — confirm.
node_ids: [0, 27],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
104 kB
·
Xet hash:
9c236712f171e6f8330ac4af7a3a0e54d5547d2bcc4da6becbd2139d0fad68d8

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.