From 302ee73f349a194fc748971e17103db5cc999470 Mon Sep 17 00:00:00 2001
From: Isaac Marovitz <42140194+IsaacMarovitz@users.noreply.github.com>
Date: Mon, 2 Sep 2024 12:55:30 +0100
Subject: [PATCH] Metal: Unsupported topology indexed draw conversion (#40)

* Convert unsupported indexed buffer topologies

* Fix index count and dispatch size

* Cleanup

* Fix typos
---
 src/Ryujinx.Graphics.Metal/BufferHolder.cs    |  29 +++++
 src/Ryujinx.Graphics.Metal/BufferManager.cs   |  10 ++
 src/Ryujinx.Graphics.Metal/CacheByRange.cs    | 108 ++++++------------
 src/Ryujinx.Graphics.Metal/HelperShader.cs    |  66 ++++++++++-
 .../IndexBufferPattern.cs                     |  23 ----
 .../IndexBufferState.cs                       |  37 ++++++
 src/Ryujinx.Graphics.Metal/Pipeline.cs        |  29 +++--
 .../Ryujinx.Graphics.Metal.csproj             |   1 +
 .../Shaders/ConvertIndexBuffer.metal          |  59 ++++++++++
 9 files changed, 256 insertions(+), 106 deletions(-)
 create mode 100644 src/Ryujinx.Graphics.Metal/Shaders/ConvertIndexBuffer.metal

diff --git a/src/Ryujinx.Graphics.Metal/BufferHolder.cs b/src/Ryujinx.Graphics.Metal/BufferHolder.cs
index 47e9cd0e3..cc86a403f 100644
--- a/src/Ryujinx.Graphics.Metal/BufferHolder.cs
+++ b/src/Ryujinx.Graphics.Metal/BufferHolder.cs
@@ -318,6 +318,35 @@ namespace Ryujinx.Graphics.Metal
             return holder.GetBuffer();
         }
 
+        public Auto<DisposableBuffer> GetBufferTopologyConversion(CommandBufferScoped cbs, int offset, int size, IndexBufferPattern pattern, int indexSize)
+        {
+            if (!BoundToRange(offset, ref size))
+            {
+                return null;
+            }
+
+            var key = new TopologyConversionCacheKey(_renderer, pattern, indexSize);
+
+            if (!_cachedConvertedBuffers.TryGetValue(offset, size, key, out var holder))
+            {
+                // The destination index size is always I32.
+
+                int indexCount = size / indexSize;
+
+                int convertedCount = pattern.GetConvertedCount(indexCount);
+
+                holder = _renderer.BufferManager.Create(convertedCount * 4);
+
+                _renderer.HelperShader.ConvertIndexBuffer(cbs, this, holder, pattern, indexSize, offset, indexCount);
+
+                key.SetBuffer(holder.GetBuffer());
+
+                _cachedConvertedBuffers.Add(offset, size, key, holder);
+            }
+
+            return holder.GetBuffer();
+        }
+
         public bool TryGetCachedConvertedBuffer(int offset, int size, ICacheKey key, out BufferHolder holder)
         {
             return _cachedConvertedBuffers.TryGetValue(offset, size, key, out holder);
diff --git a/src/Ryujinx.Graphics.Metal/BufferManager.cs b/src/Ryujinx.Graphics.Metal/BufferManager.cs
index 71620f424..07a686223 100644
--- a/src/Ryujinx.Graphics.Metal/BufferManager.cs
+++ b/src/Ryujinx.Graphics.Metal/BufferManager.cs
@@ -177,6 +177,16 @@ namespace Ryujinx.Graphics.Metal
             return null;
         }
 
+        public Auto<DisposableBuffer> GetBufferTopologyConversion(CommandBufferScoped cbs, BufferHandle handle, int offset, int size, IndexBufferPattern pattern, int indexSize)
+        {
+            if (TryGetBuffer(handle, out var holder))
+            {
+                return holder.GetBufferTopologyConversion(cbs, offset, size, pattern, indexSize);
+            }
+
+            return null;
+        }
+
         public PinnedSpan<byte> GetData(BufferHandle handle, int offset, int size)
         {
             if (TryGetBuffer(handle, out var holder))
diff --git a/src/Ryujinx.Graphics.Metal/CacheByRange.cs b/src/Ryujinx.Graphics.Metal/CacheByRange.cs
index 80a0c1018..e2eb24f66 100644
--- a/src/Ryujinx.Graphics.Metal/CacheByRange.cs
+++ b/src/Ryujinx.Graphics.Metal/CacheByRange.cs
@@ -39,80 +39,42 @@ namespace Ryujinx.Graphics.Metal
         }
     }
 
-    // [SupportedOSPlatform("macos")]
-    // struct AlignedVertexBufferCacheKey : ICacheKey
-    // {
-    //     private readonly int _stride;
-    //     private readonly int _alignment;
-    //
-    //     // Used to notify the pipeline that bindings have invalidated on dispose.
-    //     // private readonly MetalRenderer _renderer;
-    //     // private Auto<DisposableBuffer> _buffer;
-    //
-    //     public AlignedVertexBufferCacheKey(MetalRenderer renderer, int stride, int alignment)
-    //     {
-    //         // _renderer = renderer;
-    //         _stride = stride;
-    //         _alignment = alignment;
-    //         // _buffer = null;
-    //     }
-    //
-    //     public readonly bool KeyEqual(ICacheKey other)
-    //     {
-    //         return other is AlignedVertexBufferCacheKey entry &&
-    //                entry._stride == _stride &&
-    //                entry._alignment == _alignment;
-    //     }
-    //
-    //     public void SetBuffer(Auto<DisposableBuffer> buffer)
-    //     {
-    //         // _buffer = buffer;
-    //     }
-    //
-    //     public readonly void Dispose()
-    //     {
-    //         // TODO: Tell pipeline buffer is dirty!
-    //         // _renderer.PipelineInternal.DirtyVertexBuffer(_buffer);
-    //     }
-    // }
+    [SupportedOSPlatform("macos")]
+    struct TopologyConversionCacheKey : ICacheKey
+    {
+        private readonly IndexBufferPattern _pattern;
+        private readonly int _indexSize;
 
-    // [SupportedOSPlatform("macos")]
-    // struct TopologyConversionCacheKey : ICacheKey
-    // {
-    //     // TODO: Patterns
-    //     // private readonly IndexBufferPattern _pattern;
-    //     private readonly int _indexSize;
-    //
-    //     // Used to notify the pipeline that bindings have invalidated on dispose.
-    //     // private readonly MetalRenderer _renderer;
-    //     // private Auto<DisposableBuffer> _buffer;
-    //
-    //     public TopologyConversionCacheKey(MetalRenderer renderer, /*IndexBufferPattern pattern, */int indexSize)
-    //     {
-    //         // _renderer = renderer;
-    //         // _pattern = pattern;
-    //         _indexSize = indexSize;
-    //         // _buffer = null;
-    //     }
-    //
-    //     public readonly bool KeyEqual(ICacheKey other)
-    //     {
-    //         return other is TopologyConversionCacheKey entry &&
-    //                // entry._pattern == _pattern &&
-    //                entry._indexSize == _indexSize;
-    //     }
-    //
-    //     public void SetBuffer(Auto<DisposableBuffer> buffer)
-    //     {
-    //         // _buffer = buffer;
-    //     }
-    //
-    //     public readonly void Dispose()
-    //     {
-    //         // TODO: Tell pipeline buffer is dirty!
-    //         // _renderer.PipelineInternal.DirtyVertexBuffer(_buffer);
-    //     }
-    // }
+        // Used to notify the pipeline that bindings have invalidated on dispose.
+        // private readonly MetalRenderer _renderer;
+        // private Auto<DisposableBuffer> _buffer;
+
+        public TopologyConversionCacheKey(MetalRenderer renderer, IndexBufferPattern pattern, int indexSize)
+        {
+            // _renderer = renderer;
+            // _buffer = null;
+            _pattern = pattern;
+            _indexSize = indexSize;
+        }
+
+        public readonly bool KeyEqual(ICacheKey other)
+        {
+            return other is TopologyConversionCacheKey entry &&
+                   entry._pattern == _pattern &&
+                   entry._indexSize == _indexSize;
+        }
+
+        public void SetBuffer(Auto<DisposableBuffer> buffer)
+        {
+            // _buffer = buffer;
+        }
+
+        public readonly void Dispose()
+        {
+            // TODO: Tell pipeline buffer is dirty!
+            // _renderer.PipelineInternal.DirtyVertexBuffer(_buffer);
+        }
+    }
 
     [SupportedOSPlatform("macos")]
     readonly struct Dependency
diff --git a/src/Ryujinx.Graphics.Metal/HelperShader.cs b/src/Ryujinx.Graphics.Metal/HelperShader.cs
index 8039641ea..a4a1215a6 100644
--- a/src/Ryujinx.Graphics.Metal/HelperShader.cs
+++ b/src/Ryujinx.Graphics.Metal/HelperShader.cs
@@ -33,6 +33,7 @@ namespace Ryujinx.Graphics.Metal
         private readonly IProgram _programDepthStencilClear;
         private readonly IProgram _programStrideChange;
         private readonly IProgram _programConvertD32S8ToD24S8;
+        private readonly IProgram _programConvertIndexBuffer;
         private readonly IProgram _programDepthBlit;
         private readonly IProgram _programDepthBlitMs;
         private readonly IProgram _programStencilBlit;
@@ -163,6 +164,17 @@ namespace Ryujinx.Graphics.Metal
                 new ShaderSource(convertD32S8ToD24S8Source, ShaderStage.Compute, TargetLanguage.Msl)
             ], convertD32S8ToD24S8ResourceLayout, device, new ComputeSize(64, 1, 1));
 
+            var convertIndexBufferLayout = new ResourceLayoutBuilder()
+                .Add(ResourceStages.Compute, ResourceType.StorageBuffer, 1)
+                .Add(ResourceStages.Compute, ResourceType.StorageBuffer, 2, true)
+                .Add(ResourceStages.Compute, ResourceType.StorageBuffer, 3).Build();
+
+            var convertIndexBufferSource = ReadMsl("ConvertIndexBuffer.metal");
+            _programConvertIndexBuffer = new Program(
+            [
+                new ShaderSource(convertIndexBufferSource, ShaderStage.Compute, TargetLanguage.Msl)
+            ], convertIndexBufferLayout, device, new ComputeSize(16, 1, 1));
+
             var depthBlitSource = ReadMsl("DepthBlit.metal");
             _programDepthBlit = new Program(
             [
@@ -574,7 +586,7 @@ namespace Ryujinx.Graphics.Metal
             var srcBuffer = src.GetBuffer();
             var dstBuffer = dst.GetBuffer();
 
-            const int ParamsBufferSize = 16;
+            const int ParamsBufferSize = 4 * sizeof(int);
 
             // Save current state
             _pipeline.SwapState(_helperShaderState);
@@ -636,6 +648,58 @@ namespace Ryujinx.Graphics.Metal
             _pipeline.SwapState(null);
         }
 
+        public void ConvertIndexBuffer(
+            CommandBufferScoped cbs,
+            BufferHolder src,
+            BufferHolder dst,
+            IndexBufferPattern pattern,
+            int indexSize,
+            int srcOffset,
+            int indexCount)
+        {
+            // TODO: Support conversion with primitive restart enabled.
+
+            int primitiveCount = pattern.GetPrimitiveCount(indexCount);
+            int outputIndexSize = 4;
+
+            var srcBuffer = src.GetBuffer();
+            var dstBuffer = dst.GetBuffer();
+
+            const int ParamsBufferSize = 16 * sizeof(int);
+
+            // Save current state
+            _pipeline.SwapState(_helperShaderState);
+
+            Span<int> shaderParams = stackalloc int[ParamsBufferSize / sizeof(int)];
+
+            shaderParams[8] = pattern.PrimitiveVertices;
+            shaderParams[9] = pattern.PrimitiveVerticesOut;
+            shaderParams[10] = indexSize;
+            shaderParams[11] = outputIndexSize;
+            shaderParams[12] = pattern.BaseIndex;
+            shaderParams[13] = pattern.IndexStride;
+            shaderParams[14] = srcOffset;
+            shaderParams[15] = primitiveCount;
+
+            pattern.OffsetIndex.CopyTo(shaderParams[..pattern.OffsetIndex.Length]);
+
+            using var patternScoped = _renderer.BufferManager.ReserveOrCreate(cbs, ParamsBufferSize);
+            patternScoped.Holder.SetDataUnchecked<int>(patternScoped.Offset, shaderParams);
+
+            Span<Auto<DisposableBuffer>> sbRanges = new Auto<DisposableBuffer>[2];
+
+            sbRanges[0] = srcBuffer;
+            sbRanges[1] = dstBuffer;
+            _pipeline.SetStorageBuffers(1, sbRanges);
+            _pipeline.SetStorageBuffers([new BufferAssignment(3, patternScoped.Range)]);
+
+            _pipeline.SetProgram(_programConvertIndexBuffer);
+            _pipeline.DispatchCompute(BitUtils.DivRoundUp(primitiveCount, 16), 1, 1, "Convert Index Buffer");
+
+            // Restore previous state
+            _pipeline.SwapState(null);
+        }
+
         public unsafe void ClearColor(
             int index,
             ReadOnlySpan<float> clearColor,
diff --git a/src/Ryujinx.Graphics.Metal/IndexBufferPattern.cs b/src/Ryujinx.Graphics.Metal/IndexBufferPattern.cs
index 7292b3134..24e3222fe 100644
--- a/src/Ryujinx.Graphics.Metal/IndexBufferPattern.cs
+++ b/src/Ryujinx.Graphics.Metal/IndexBufferPattern.cs
@@ -1,6 +1,5 @@
 using Ryujinx.Graphics.GAL;
 using System;
-using System.Collections.Generic;
 using System.Runtime.InteropServices;
 using System.Runtime.Versioning;
 
@@ -49,28 +48,6 @@ namespace Ryujinx.Graphics.Metal
             return primitiveCount * OffsetIndex.Length;
         }
 
-        public IEnumerable<int> GetIndexMapping(int indexCount)
-        {
-            int primitiveCount = GetPrimitiveCount(indexCount);
-            int index = BaseIndex;
-
-            for (int i = 0; i < primitiveCount; i++)
-            {
-                if (RepeatStart)
-                {
-                    // Used for triangle fan
-                    yield return 0;
-                }
-
-                for (int j = RepeatStart ? 1 : 0; j < OffsetIndex.Length; j++)
-                {
-                    yield return index + OffsetIndex[j];
-                }
-
-                index += IndexStride;
-            }
-        }
-
         public BufferHandle GetRepeatingBuffer(int vertexCount, out int indexCount)
         {
             int primitiveCount = GetPrimitiveCount(vertexCount);
diff --git a/src/Ryujinx.Graphics.Metal/IndexBufferState.cs b/src/Ryujinx.Graphics.Metal/IndexBufferState.cs
index 7cd2ff42e..411df9685 100644
--- a/src/Ryujinx.Graphics.Metal/IndexBufferState.cs
+++ b/src/Ryujinx.Graphics.Metal/IndexBufferState.cs
@@ -62,5 +62,42 @@ namespace Ryujinx.Graphics.Metal
 
             return (new MTLBuffer(IntPtr.Zero), 0, MTLIndexType.UInt16);
         }
+
+        public (MTLBuffer, int, MTLIndexType) GetConvertedIndexBuffer(
+            MetalRenderer renderer,
+            CommandBufferScoped cbs,
+            int firstIndex,
+            int indexCount,
+            int convertedCount,
+            IndexBufferPattern pattern)
+        {
+            // Convert the index buffer using the given pattern.
+            int indexSize = GetIndexSize();
+
+            int firstIndexOffset = firstIndex * indexSize;
+
+            var autoBuffer = renderer.BufferManager.GetBufferTopologyConversion(cbs, _handle, _offset + firstIndexOffset, indexCount * indexSize, pattern, indexSize);
+
+            int size = convertedCount * 4;
+
+            if (autoBuffer != null)
+            {
+                DisposableBuffer buffer = autoBuffer.Get(cbs, 0, size);
+
+                return (buffer.Value, 0, MTLIndexType.UInt32);
+            }
+
+            return (new MTLBuffer(IntPtr.Zero), 0, MTLIndexType.UInt32);
+        }
+
+        private int GetIndexSize()
+        {
+            return _type switch
+            {
+                IndexType.UInt => 4,
+                IndexType.UShort => 2,
+                _ => 1,
+            };
+        }
     }
 }
diff --git a/src/Ryujinx.Graphics.Metal/Pipeline.cs b/src/Ryujinx.Graphics.Metal/Pipeline.cs
index 6b42578ea..f6a5e0908 100644
--- a/src/Ryujinx.Graphics.Metal/Pipeline.cs
+++ b/src/Ryujinx.Graphics.Metal/Pipeline.cs
@@ -404,6 +404,8 @@ namespace Ryujinx.Graphics.Metal
                 return;
             }
 
+            var primitiveType = TopologyRemap(_encoderStateManager.Topology).Convert();
+
             if (TopologyUnsupported(_encoderStateManager.Topology))
             {
                 var pattern = GetIndexBufferPattern();
@@ -412,7 +414,6 @@ namespace Ryujinx.Graphics.Metal
                 var buffer = _renderer.BufferManager.GetBuffer(handle, false);
                 var mtlBuffer = buffer.Get(Cbs, 0, indexCount * sizeof(int)).Value;
 
-                var primitiveType = TopologyRemap(_encoderStateManager.Topology).Convert();
                 var renderCommandEncoder = GetOrCreateRenderEncoder(true);
 
                 renderCommandEncoder.DrawIndexedPrimitives(
@@ -424,7 +425,6 @@ namespace Ryujinx.Graphics.Metal
             }
             else
             {
-                var primitiveType = TopologyRemap(_encoderStateManager.Topology).Convert();
                 var renderCommandEncoder = GetOrCreateRenderEncoder(true);
 
                 if (debugGroupName != String.Empty)
@@ -483,15 +483,26 @@ namespace Ryujinx.Graphics.Metal
                 return;
             }
 
-            // TODO: Reindex unsupported topologies
-            if (TopologyUnsupported(_encoderStateManager.Topology))
-            {
-                Logger.Warning?.Print(LogClass.Gpu, $"Drawing indexed with unsupported topology: {_encoderStateManager.Topology}");
-            }
+            MTLBuffer mtlBuffer;
+            int offset;
+            MTLIndexType type;
+            int finalIndexCount = indexCount;
 
             var primitiveType = TopologyRemap(_encoderStateManager.Topology).Convert();
 
-            (MTLBuffer mtlBuffer, int offset, MTLIndexType type) = _encoderStateManager.IndexBuffer.GetIndexBuffer(_renderer, Cbs);
+            if (TopologyUnsupported(_encoderStateManager.Topology))
+            {
+                var pattern = GetIndexBufferPattern();
+                int convertedCount = pattern.GetConvertedCount(indexCount);
+
+                finalIndexCount = convertedCount;
+
+                (mtlBuffer, offset, type) = _encoderStateManager.IndexBuffer.GetConvertedIndexBuffer(_renderer, Cbs, firstIndex, indexCount, convertedCount, pattern);
+            }
+            else
+            {
+                (mtlBuffer, offset, type) = _encoderStateManager.IndexBuffer.GetIndexBuffer(_renderer, Cbs);
+            }
 
             if (mtlBuffer.NativePtr != IntPtr.Zero)
             {
@@ -499,7 +510,7 @@ namespace Ryujinx.Graphics.Metal
 
                 renderCommandEncoder.DrawIndexedPrimitives(
                     primitiveType,
-                    (ulong)indexCount,
+                    (ulong)finalIndexCount,
                     type,
                     mtlBuffer,
                     (ulong)offset,
diff --git a/src/Ryujinx.Graphics.Metal/Ryujinx.Graphics.Metal.csproj b/src/Ryujinx.Graphics.Metal/Ryujinx.Graphics.Metal.csproj
index 0839c426a..cc1345598 100644
--- a/src/Ryujinx.Graphics.Metal/Ryujinx.Graphics.Metal.csproj
+++ b/src/Ryujinx.Graphics.Metal/Ryujinx.Graphics.Metal.csproj
@@ -19,6 +19,7 @@
       <EmbeddedResource Include="Shaders\BlitMs.metal" />
       <EmbeddedResource Include="Shaders\ChangeBufferStride.metal" />
       <EmbeddedResource Include="Shaders\ConvertD32S8ToD24S8.metal" />
+      <EmbeddedResource Include="Shaders\ConvertIndexBuffer.metal" />
       <EmbeddedResource Include="Shaders\ColorClear.metal" />
       <EmbeddedResource Include="Shaders\DepthStencilClear.metal" />
       <EmbeddedResource Include="Shaders\DepthBlit.metal" />
diff --git a/src/Ryujinx.Graphics.Metal/Shaders/ConvertIndexBuffer.metal b/src/Ryujinx.Graphics.Metal/Shaders/ConvertIndexBuffer.metal
new file mode 100644
index 000000000..c8fee5818
--- /dev/null
+++ b/src/Ryujinx.Graphics.Metal/Shaders/ConvertIndexBuffer.metal
@@ -0,0 +1,59 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+struct IndexBufferPattern {
+    int pattern[8];
+    int primitiveVertices;
+    int primitiveVerticesOut;
+    int indexSize;
+    int indexSizeOut;
+    int baseIndex;
+    int indexStride;
+    int srcOffset;
+    int totalPrimitives;
+};
+
+struct InData {
+    uint8_t data[1];
+};
+
+struct OutData {
+    uint8_t data[1];
+};
+
+struct StorageBuffers {
+    device InData* in_data;
+    device OutData* out_data;
+    constant IndexBufferPattern* index_buffer_pattern;
+};
+
+kernel void kernelMain(device StorageBuffers &storage_buffers [[buffer(STORAGE_BUFFERS_INDEX)]],
+                       uint3 thread_position_in_grid [[thread_position_in_grid]])
+{
+    int primitiveIndex = int(thread_position_in_grid.x);
+    if (primitiveIndex >= storage_buffers.index_buffer_pattern->totalPrimitives)
+    {
+        return;
+    }
+
+    int inOffset = primitiveIndex * storage_buffers.index_buffer_pattern->indexStride;
+    int outOffset = primitiveIndex * storage_buffers.index_buffer_pattern->primitiveVerticesOut;
+
+    for (int i = 0; i < storage_buffers.index_buffer_pattern->primitiveVerticesOut; i++)
+    {
+        int j;
+        int io = max(0, inOffset + storage_buffers.index_buffer_pattern->baseIndex + storage_buffers.index_buffer_pattern->pattern[i]) * storage_buffers.index_buffer_pattern->indexSize;
+        int oo = (outOffset + i) * storage_buffers.index_buffer_pattern->indexSizeOut;
+
+        for (j = 0; j < storage_buffers.index_buffer_pattern->indexSize; j++)
+        {
+            storage_buffers.out_data->data[oo + j] = storage_buffers.in_data->data[storage_buffers.index_buffer_pattern->srcOffset + io + j];
+        }
+
+        for(; j < storage_buffers.index_buffer_pattern->indexSizeOut; j++)
+        {
+            storage_buffers.out_data->data[oo + j] = uint8_t(0);
+        }
+    }
+}