0
0
Fork 0
mirror of https://github.com/GreemDev/Ryujinx.git synced 2024-12-22 23:55:47 +00:00

Metal: Compute Shaders (#19)

* check for too bix texture bindings

* implement lod query

* print shader stage name

* always have fragment input

* resolve merge conflicts

* fix: lod query

* fix: casting texture coords

* support non-array memories

* use structure types for buffers

* implement compute pipeline cache

* compute dispatch

* improve error message

* rebind compute state

* bind compute textures

* pass local size as an argument to dispatch

* implement texture buffers

* hack: change vertex index to vertex id

* pass support buffer as an argument to every function

* return at the end of function

* fix: certain missing compute bindings

* implement texture base

* improve texture binding system

* remove useless exception

* move texture handle to texture base

* fix: segfault when using disposed textures

---------

Co-authored-by: Samuliak <samuliak77@gmail.com>
Co-authored-by: SamoZ256 <96914946+SamoZ256@users.noreply.github.com>
This commit is contained in:
Isaac Marovitz 2024-05-29 16:21:59 +01:00
parent 131ab75d55
commit b064d76a4f
26 changed files with 718 additions and 224 deletions

View file

@ -25,7 +25,7 @@ namespace Ryujinx.Graphics.GAL
void CopyBuffer(BufferHandle source, BufferHandle destination, int srcOffset, int dstOffset, int size);
void DispatchCompute(int groupsX, int groupsY, int groupsZ);
void DispatchCompute(int groupsX, int groupsY, int groupsZ, int groupSizeX, int groupSizeY, int groupSizeZ);
void Draw(int vertexCount, int instanceCount, int firstVertex, int firstInstance);
void DrawIndexed(

View file

@ -6,17 +6,23 @@ namespace Ryujinx.Graphics.GAL.Multithreading.Commands
private int _groupsX;
private int _groupsY;
private int _groupsZ;
private int _groupSizeX;
private int _groupSizeY;
private int _groupSizeZ;
public void Set(int groupsX, int groupsY, int groupsZ)
public void Set(int groupsX, int groupsY, int groupsZ, int groupSizeX, int groupSizeY, int groupSizeZ)
{
_groupsX = groupsX;
_groupsY = groupsY;
_groupsZ = groupsZ;
_groupSizeX = groupSizeX;
_groupSizeY = groupSizeY;
_groupSizeZ = groupSizeZ;
}
public static void Run(ref DispatchComputeCommand command, ThreadedRenderer threaded, IRenderer renderer)
{
renderer.Pipeline.DispatchCompute(command._groupsX, command._groupsY, command._groupsZ);
renderer.Pipeline.DispatchCompute(command._groupsX, command._groupsY, command._groupsZ, command._groupSizeX, command._groupSizeY, command._groupSizeZ);
}
}
}

View file

@ -63,9 +63,9 @@ namespace Ryujinx.Graphics.GAL.Multithreading
_renderer.QueueCommand();
}
public void DispatchCompute(int groupsX, int groupsY, int groupsZ)
public void DispatchCompute(int groupsX, int groupsY, int groupsZ, int groupSizeX, int groupSizeY, int groupSizeZ)
{
_renderer.New<DispatchComputeCommand>().Set(groupsX, groupsY, groupsZ);
_renderer.New<DispatchComputeCommand>().Set(groupsX, groupsY, groupsZ, groupSizeX, groupSizeY, groupSizeZ);
_renderer.QueueCommand();
}

View file

@ -200,7 +200,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Compute
_channel.BufferManager.CommitComputeBindings();
_context.Renderer.Pipeline.DispatchCompute(qmd.CtaRasterWidth, qmd.CtaRasterHeight, qmd.CtaRasterDepth);
_context.Renderer.Pipeline.DispatchCompute(qmd.CtaRasterWidth, qmd.CtaRasterHeight, qmd.CtaRasterDepth, qmd.CtaThreadDimension0, qmd.CtaThreadDimension1, qmd.CtaThreadDimension2);
_3dEngine.ForceShaderUpdate();
}

View file

@ -211,7 +211,10 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed.ComputeDraw
_context.Renderer.Pipeline.DispatchCompute(
BitUtils.DivRoundUp(_count, ComputeLocalSize),
BitUtils.DivRoundUp(_instanceCount, ComputeLocalSize),
1);
1,
ComputeLocalSize,
ComputeLocalSize,
ComputeLocalSize);
}
/// <summary>
@ -260,7 +263,10 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed.ComputeDraw
_context.Renderer.Pipeline.DispatchCompute(
BitUtils.DivRoundUp(primitivesCount, ComputeLocalSize),
BitUtils.DivRoundUp(_instanceCount, ComputeLocalSize),
_geometryAsCompute.Info.ThreadsPerInputPrimitive);
_geometryAsCompute.Info.ThreadsPerInputPrimitive,
ComputeLocalSize,
ComputeLocalSize,
ComputeLocalSize);
}
/// <summary>

View file

@ -0,0 +1,36 @@
using Ryujinx.Common.Logging;
using SharpMetal.Foundation;
using SharpMetal.Metal;
using System;
using System.Runtime.Versioning;
namespace Ryujinx.Graphics.Metal
{
[SupportedOSPlatform("macos")]
public class ComputePipelineCache : StateCache<MTLComputePipelineState, MTLFunction, MTLFunction>
{
private readonly MTLDevice _device;
public ComputePipelineCache(MTLDevice device)
{
_device = device;
}
protected override MTLFunction GetHash(MTLFunction function)
{
return function;
}
protected override MTLComputePipelineState CreateValue(MTLFunction function)
{
var error = new NSError(IntPtr.Zero);
var pipelineState = _device.NewComputePipelineState(function, ref error);
if (error != IntPtr.Zero)
{
Logger.Error?.PrintMsg(LogClass.Gpu, $"Failed to create Compute Pipeline State: {StringHelper.String(error.LocalizedDescription)}");
}
return pipelineState;
}
}
}

View file

@ -15,6 +15,6 @@ namespace Ryujinx.Graphics.Metal
// TODO: Check this value
public const int MaxVertexLayouts = 16;
public const int MaxTextures = 31;
public const int MaxSamplers = 31;
public const int MaxSamplers = 16;
}
}

View file

@ -8,20 +8,23 @@ namespace Ryujinx.Graphics.Metal
{
public struct DirtyFlags
{
public bool Pipeline = false;
public bool RenderPipeline = false;
public bool ComputePipeline = false;
public bool DepthStencil = false;
public DirtyFlags() { }
public void MarkAll()
{
Pipeline = true;
RenderPipeline = true;
ComputePipeline = true;
DepthStencil = true;
}
public void Clear()
{
Pipeline = false;
RenderPipeline = false;
ComputePipeline = false;
DepthStencil = false;
}
}
@ -31,13 +34,17 @@ namespace Ryujinx.Graphics.Metal
{
public MTLFunction? VertexFunction = null;
public MTLFunction? FragmentFunction = null;
public MTLFunction? ComputeFunction = null;
public MTLTexture[] FragmentTextures = new MTLTexture[Constants.MaxTextures];
public TextureBase[] FragmentTextures = new TextureBase[Constants.MaxTextures];
public MTLSamplerState[] FragmentSamplers = new MTLSamplerState[Constants.MaxSamplers];
public MTLTexture[] VertexTextures = new MTLTexture[Constants.MaxTextures];
public TextureBase[] VertexTextures = new TextureBase[Constants.MaxTextures];
public MTLSamplerState[] VertexSamplers = new MTLSamplerState[Constants.MaxSamplers];
public TextureBase[] ComputeTextures = new TextureBase[Constants.MaxTextures];
public MTLSamplerState[] ComputeSamplers = new MTLSamplerState[Constants.MaxSamplers];
public List<BufferInfo> UniformBuffers = [];
public List<BufferInfo> StorageBuffers = [];
@ -87,10 +94,12 @@ namespace Ryujinx.Graphics.Metal
{
// Certain state (like viewport and scissor) doesn't need to be cloned, as it is always reacreated when assigned to
EncoderState clone = this;
clone.FragmentTextures = (MTLTexture[])FragmentTextures.Clone();
clone.FragmentTextures = (TextureBase[])FragmentTextures.Clone();
clone.FragmentSamplers = (MTLSamplerState[])FragmentSamplers.Clone();
clone.VertexTextures = (MTLTexture[])VertexTextures.Clone();
clone.VertexTextures = (TextureBase[])VertexTextures.Clone();
clone.VertexSamplers = (MTLSamplerState[])VertexSamplers.Clone();
clone.ComputeTextures = (TextureBase[])ComputeTextures.Clone();
clone.ComputeSamplers = (MTLSamplerState[])ComputeSamplers.Clone();
clone.BlendDescriptors = (BlendDescriptor?[])BlendDescriptors.Clone();
clone.VertexBuffers = (VertexBufferDescriptor[])VertexBuffers.Clone();
clone.VertexAttribs = (VertexAttribDescriptor[])VertexAttribs.Clone();

View file

@ -15,6 +15,7 @@ namespace Ryujinx.Graphics.Metal
private readonly Pipeline _pipeline;
private readonly RenderPipelineCache _renderPipelineCache;
private readonly ComputePipelineCache _computePipelineCache;
private readonly DepthStencilCache _depthStencilCache;
private EncoderState _currentState = new();
@ -33,6 +34,7 @@ namespace Ryujinx.Graphics.Metal
{
_pipeline = pipeline;
_renderPipelineCache = new(device);
_computePipelineCache = new(device);
_depthStencilCache = new(device);
// Zero buffer
@ -50,6 +52,7 @@ namespace Ryujinx.Graphics.Metal
_currentState.BackFaceStencil.Dispose();
_renderPipelineCache.Dispose();
_computePipelineCache.Dispose();
_depthStencilCache.Dispose();
}
@ -77,8 +80,8 @@ namespace Ryujinx.Graphics.Metal
SetScissors(renderCommandEncoder);
SetViewports(renderCommandEncoder);
SetVertexBuffers(renderCommandEncoder, _currentState.VertexBuffers);
SetBuffers(renderCommandEncoder, _currentState.UniformBuffers, true);
SetBuffers(renderCommandEncoder, _currentState.StorageBuffers, true);
SetRenderBuffers(renderCommandEncoder, _currentState.UniformBuffers, true);
SetRenderBuffers(renderCommandEncoder, _currentState.StorageBuffers, true);
SetCullMode(renderCommandEncoder);
SetFrontFace(renderCommandEncoder);
SetStencilRefValue(renderCommandEncoder);
@ -107,7 +110,7 @@ namespace Ryujinx.Graphics.Metal
if (_currentState.RenderTargets[i] != null)
{
var passAttachment = renderPassDescriptor.ColorAttachments.Object((ulong)i);
passAttachment.Texture = _currentState.RenderTargets[i].MTLTexture;
passAttachment.Texture = _currentState.RenderTargets[i].GetHandle();
passAttachment.LoadAction = _currentState.ClearLoadAction ? MTLLoadAction.Clear : MTLLoadAction.Load;
passAttachment.StoreAction = MTLStoreAction.Store;
}
@ -118,19 +121,19 @@ namespace Ryujinx.Graphics.Metal
if (_currentState.DepthStencil != null)
{
switch (_currentState.DepthStencil.MTLTexture.PixelFormat)
switch (_currentState.DepthStencil.GetHandle().PixelFormat)
{
// Depth Only Attachment
case MTLPixelFormat.Depth16Unorm:
case MTLPixelFormat.Depth32Float:
depthAttachment.Texture = _currentState.DepthStencil.MTLTexture;
depthAttachment.Texture = _currentState.DepthStencil.GetHandle();
depthAttachment.LoadAction = MTLLoadAction.Load;
depthAttachment.StoreAction = MTLStoreAction.Store;
break;
// Stencil Only Attachment
case MTLPixelFormat.Stencil8:
stencilAttachment.Texture = _currentState.DepthStencil.MTLTexture;
stencilAttachment.Texture = _currentState.DepthStencil.GetHandle();
stencilAttachment.LoadAction = MTLLoadAction.Load;
stencilAttachment.StoreAction = MTLStoreAction.Store;
break;
@ -138,16 +141,16 @@ namespace Ryujinx.Graphics.Metal
// Combined Attachment
case MTLPixelFormat.Depth24UnormStencil8:
case MTLPixelFormat.Depth32FloatStencil8:
depthAttachment.Texture = _currentState.DepthStencil.MTLTexture;
depthAttachment.Texture = _currentState.DepthStencil.GetHandle();
depthAttachment.LoadAction = MTLLoadAction.Load;
depthAttachment.StoreAction = MTLStoreAction.Store;
stencilAttachment.Texture = _currentState.DepthStencil.MTLTexture;
stencilAttachment.Texture = _currentState.DepthStencil.GetHandle();
stencilAttachment.LoadAction = MTLLoadAction.Load;
stencilAttachment.StoreAction = MTLStoreAction.Store;
break;
default:
Logger.Error?.PrintMsg(LogClass.Gpu, $"Unsupported Depth/Stencil Format: {_currentState.DepthStencil.MTLTexture.PixelFormat}!");
Logger.Error?.PrintMsg(LogClass.Gpu, $"Unsupported Depth/Stencil Format: {_currentState.DepthStencil.GetHandle().PixelFormat}!");
break;
}
}
@ -166,10 +169,18 @@ namespace Ryujinx.Graphics.Metal
SetViewports(renderCommandEncoder);
SetScissors(renderCommandEncoder);
SetVertexBuffers(renderCommandEncoder, _currentState.VertexBuffers);
SetBuffers(renderCommandEncoder, _currentState.UniformBuffers, true);
SetBuffers(renderCommandEncoder, _currentState.StorageBuffers, true);
SetTextureAndSampler(renderCommandEncoder, ShaderStage.Vertex, _currentState.VertexTextures, _currentState.VertexSamplers);
SetTextureAndSampler(renderCommandEncoder, ShaderStage.Fragment, _currentState.FragmentTextures, _currentState.FragmentSamplers);
SetRenderBuffers(renderCommandEncoder, _currentState.UniformBuffers, true);
SetRenderBuffers(renderCommandEncoder, _currentState.StorageBuffers, true);
for (ulong i = 0; i < Constants.MaxTextures; i++)
{
SetRenderTexture(renderCommandEncoder, ShaderStage.Vertex, i, _currentState.VertexTextures[i]);
SetRenderTexture(renderCommandEncoder, ShaderStage.Fragment, i, _currentState.FragmentTextures[i]);
}
for (ulong i = 0; i < Constants.MaxSamplers; i++)
{
SetRenderSampler(renderCommandEncoder, ShaderStage.Vertex, i, _currentState.VertexSamplers[i]);
SetRenderSampler(renderCommandEncoder, ShaderStage.Fragment, i, _currentState.FragmentSamplers[i]);
}
// Cleanup
renderPassDescriptor.Dispose();
@ -177,11 +188,34 @@ namespace Ryujinx.Graphics.Metal
return renderCommandEncoder;
}
public void RebindState(MTLRenderCommandEncoder renderCommandEncoder)
public MTLComputeCommandEncoder CreateComputeCommandEncoder()
{
if (_currentState.Dirty.Pipeline)
var descriptor = new MTLComputePassDescriptor();
var computeCommandEncoder = _pipeline.CommandBuffer.ComputeCommandEncoder(descriptor);
// Rebind all the state
SetComputeBuffers(computeCommandEncoder, _currentState.UniformBuffers);
SetComputeBuffers(computeCommandEncoder, _currentState.StorageBuffers);
for (ulong i = 0; i < Constants.MaxTextures; i++)
{
SetPipelineState(renderCommandEncoder);
SetComputeTexture(computeCommandEncoder, i, _currentState.ComputeTextures[i]);
}
for (ulong i = 0; i < Constants.MaxSamplers; i++)
{
SetComputeSampler(computeCommandEncoder, i, _currentState.ComputeSamplers[i]);
}
// Cleanup
descriptor.Dispose();
return computeCommandEncoder;
}
public void RebindRenderState(MTLRenderCommandEncoder renderCommandEncoder)
{
if (_currentState.Dirty.RenderPipeline)
{
SetRenderPipelineState(renderCommandEncoder);
}
if (_currentState.Dirty.DepthStencil)
@ -190,10 +224,22 @@ namespace Ryujinx.Graphics.Metal
}
// Clear the dirty flags
_currentState.Dirty.Clear();
_currentState.Dirty.RenderPipeline = false;
_currentState.Dirty.DepthStencil = false;
}
private readonly void SetPipelineState(MTLRenderCommandEncoder renderCommandEncoder)
public void RebindComputeState(MTLComputeCommandEncoder computeCommandEncoder)
{
if (_currentState.Dirty.ComputePipeline)
{
SetComputePipelineState(computeCommandEncoder);
}
// Clear the dirty flags
_currentState.Dirty.ComputePipeline = false;
}
private readonly void SetRenderPipelineState(MTLRenderCommandEncoder renderCommandEncoder)
{
var renderPipelineDescriptor = new MTLRenderPipelineDescriptor();
@ -202,7 +248,7 @@ namespace Ryujinx.Graphics.Metal
if (_currentState.RenderTargets[i] != null)
{
var pipelineAttachment = renderPipelineDescriptor.ColorAttachments.Object((ulong)i);
pipelineAttachment.PixelFormat = _currentState.RenderTargets[i].MTLTexture.PixelFormat;
pipelineAttachment.PixelFormat = _currentState.RenderTargets[i].GetHandle().PixelFormat;
pipelineAttachment.SourceAlphaBlendFactor = MTLBlendFactor.SourceAlpha;
pipelineAttachment.DestinationAlphaBlendFactor = MTLBlendFactor.OneMinusSourceAlpha;
pipelineAttachment.SourceRGBBlendFactor = MTLBlendFactor.SourceAlpha;
@ -225,27 +271,27 @@ namespace Ryujinx.Graphics.Metal
if (_currentState.DepthStencil != null)
{
switch (_currentState.DepthStencil.MTLTexture.PixelFormat)
switch (_currentState.DepthStencil.GetHandle().PixelFormat)
{
// Depth Only Attachment
case MTLPixelFormat.Depth16Unorm:
case MTLPixelFormat.Depth32Float:
renderPipelineDescriptor.DepthAttachmentPixelFormat = _currentState.DepthStencil.MTLTexture.PixelFormat;
renderPipelineDescriptor.DepthAttachmentPixelFormat = _currentState.DepthStencil.GetHandle().PixelFormat;
break;
// Stencil Only Attachment
case MTLPixelFormat.Stencil8:
renderPipelineDescriptor.StencilAttachmentPixelFormat = _currentState.DepthStencil.MTLTexture.PixelFormat;
renderPipelineDescriptor.StencilAttachmentPixelFormat = _currentState.DepthStencil.GetHandle().PixelFormat;
break;
// Combined Attachment
case MTLPixelFormat.Depth24UnormStencil8:
case MTLPixelFormat.Depth32FloatStencil8:
renderPipelineDescriptor.DepthAttachmentPixelFormat = _currentState.DepthStencil.MTLTexture.PixelFormat;
renderPipelineDescriptor.StencilAttachmentPixelFormat = _currentState.DepthStencil.MTLTexture.PixelFormat;
renderPipelineDescriptor.DepthAttachmentPixelFormat = _currentState.DepthStencil.GetHandle().PixelFormat;
renderPipelineDescriptor.StencilAttachmentPixelFormat = _currentState.DepthStencil.GetHandle().PixelFormat;
break;
default:
Logger.Error?.PrintMsg(LogClass.Gpu, $"Unsupported Depth/Stencil Format: {_currentState.DepthStencil.MTLTexture.PixelFormat}!");
Logger.Error?.PrintMsg(LogClass.Gpu, $"Unsupported Depth/Stencil Format: {_currentState.DepthStencil.GetHandle().PixelFormat}!");
break;
}
}
@ -287,6 +333,18 @@ namespace Ryujinx.Graphics.Metal
}
}
private readonly void SetComputePipelineState(MTLComputeCommandEncoder computeCommandEncoder)
{
if (_currentState.ComputeFunction == null)
{
return;
}
var pipelineState = _computePipelineCache.GetOrCreate(_currentState.ComputeFunction.Value);
computeCommandEncoder.SetComputePipelineState(pipelineState);
}
public void UpdateIndexBuffer(BufferRange buffer, IndexType type)
{
if (buffer.Handle != BufferHandle.Null)
@ -307,17 +365,34 @@ namespace Ryujinx.Graphics.Metal
{
Program prg = (Program)program;
if (prg.VertexFunction == IntPtr.Zero)
if (prg.VertexFunction == IntPtr.Zero && prg.ComputeFunction == IntPtr.Zero)
{
Logger.Error?.PrintMsg(LogClass.Gpu, "Invalid Vertex Function!");
if (prg.FragmentFunction == IntPtr.Zero)
{
Logger.Error?.PrintMsg(LogClass.Gpu, "No compute function");
}
else
{
Logger.Error?.PrintMsg(LogClass.Gpu, "No vertex function");
}
return;
}
_currentState.VertexFunction = prg.VertexFunction;
_currentState.FragmentFunction = prg.FragmentFunction;
if (prg.VertexFunction != IntPtr.Zero)
{
_currentState.VertexFunction = prg.VertexFunction;
_currentState.FragmentFunction = prg.FragmentFunction;
// Mark dirty
_currentState.Dirty.Pipeline = true;
// Mark dirty
_currentState.Dirty.RenderPipeline = true;
}
if (prg.ComputeFunction != IntPtr.Zero)
{
_currentState.ComputeFunction = prg.ComputeFunction;
// Mark dirty
_currentState.Dirty.ComputePipeline = true;
}
}
public void UpdateRenderTargets(ITexture[] colors, ITexture depthStencil)
@ -383,7 +458,7 @@ namespace Ryujinx.Graphics.Metal
_currentState.VertexAttribs = vertexAttribs.ToArray();
// Mark dirty
_currentState.Dirty.Pipeline = true;
_currentState.Dirty.RenderPipeline = true;
}
public void UpdateBlendDescriptors(int index, BlendDescriptor blend)
@ -557,7 +632,7 @@ namespace Ryujinx.Graphics.Metal
}
// Mark dirty
_currentState.Dirty.Pipeline = true;
_currentState.Dirty.RenderPipeline = true;
}
// Inlineable
@ -579,10 +654,18 @@ namespace Ryujinx.Graphics.Metal
}
// Inline update
if (_pipeline.CurrentEncoderType == EncoderType.Render && _pipeline.CurrentEncoder != null)
if (_pipeline.CurrentEncoder != null)
{
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value);
SetBuffers(renderCommandEncoder, _currentState.UniformBuffers, true);
if (_pipeline.CurrentEncoderType == EncoderType.Render)
{
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value);
SetRenderBuffers(renderCommandEncoder, _currentState.UniformBuffers, true);
}
else if (_pipeline.CurrentEncoderType == EncoderType.Compute)
{
var computeCommandEncoder = new MTLComputeCommandEncoder(_pipeline.CurrentEncoder.Value);
SetComputeBuffers(computeCommandEncoder, _currentState.UniformBuffers);
}
}
}
@ -606,10 +689,18 @@ namespace Ryujinx.Graphics.Metal
}
// Inline update
if (_pipeline.CurrentEncoderType == EncoderType.Render && _pipeline.CurrentEncoder != null)
if (_pipeline.CurrentEncoder != null)
{
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value);
SetBuffers(renderCommandEncoder, _currentState.StorageBuffers, true);
if (_pipeline.CurrentEncoderType == EncoderType.Render)
{
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value);
SetRenderBuffers(renderCommandEncoder, _currentState.StorageBuffers, true);
}
else if (_pipeline.CurrentEncoderType == EncoderType.Compute)
{
var computeCommandEncoder = new MTLComputeCommandEncoder(_pipeline.CurrentEncoder.Value);
SetComputeBuffers(computeCommandEncoder, _currentState.StorageBuffers);
}
}
}
@ -653,29 +744,86 @@ namespace Ryujinx.Graphics.Metal
}
// Inlineable
public readonly void UpdateTextureAndSampler(ShaderStage stage, ulong binding, MTLTexture texture, MTLSamplerState sampler)
public readonly void UpdateTexture(ShaderStage stage, ulong binding, TextureBase texture)
{
if (binding > 30)
{
Logger.Warning?.Print(LogClass.Gpu, $"Texture binding ({binding}) must be <= 30");
return;
}
switch (stage)
{
case ShaderStage.Fragment:
_currentState.FragmentTextures[binding] = texture;
_currentState.FragmentSamplers[binding] = sampler;
break;
case ShaderStage.Vertex:
_currentState.VertexTextures[binding] = texture;
_currentState.VertexSamplers[binding] = sampler;
break;
case ShaderStage.Compute:
_currentState.ComputeTextures[binding] = texture;
break;
}
if (_pipeline.CurrentEncoderType == EncoderType.Render && _pipeline.CurrentEncoder != null)
if (_pipeline.CurrentEncoder != null)
{
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value);
// TODO: Only update the new ones
SetTextureAndSampler(renderCommandEncoder, ShaderStage.Vertex, _currentState.VertexTextures, _currentState.VertexSamplers);
SetTextureAndSampler(renderCommandEncoder, ShaderStage.Fragment, _currentState.FragmentTextures, _currentState.FragmentSamplers);
if (_pipeline.CurrentEncoderType == EncoderType.Render)
{
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value);
SetRenderTexture(renderCommandEncoder, ShaderStage.Vertex, binding, texture);
SetRenderTexture(renderCommandEncoder, ShaderStage.Fragment, binding, texture);
}
else if (_pipeline.CurrentEncoderType == EncoderType.Compute)
{
var computeCommandEncoder = new MTLComputeCommandEncoder(_pipeline.CurrentEncoder.Value);
SetComputeTexture(computeCommandEncoder, binding, texture);
}
}
}
// Inlineable
public readonly void UpdateSampler(ShaderStage stage, ulong binding, MTLSamplerState sampler)
{
if (binding > 15)
{
Logger.Warning?.Print(LogClass.Gpu, $"Sampler binding ({binding}) must be <= 15");
return;
}
switch (stage)
{
case ShaderStage.Fragment:
_currentState.FragmentSamplers[binding] = sampler;
break;
case ShaderStage.Vertex:
_currentState.VertexSamplers[binding] = sampler;
break;
case ShaderStage.Compute:
_currentState.ComputeSamplers[binding] = sampler;
break;
}
if (_pipeline.CurrentEncoder != null)
{
if (_pipeline.CurrentEncoderType == EncoderType.Render)
{
var renderCommandEncoder = new MTLRenderCommandEncoder(_pipeline.CurrentEncoder.Value);
SetRenderSampler(renderCommandEncoder, ShaderStage.Vertex, binding, sampler);
SetRenderSampler(renderCommandEncoder, ShaderStage.Fragment, binding, sampler);
}
else if (_pipeline.CurrentEncoderType == EncoderType.Compute)
{
var computeCommandEncoder = new MTLComputeCommandEncoder(_pipeline.CurrentEncoder.Value);
SetComputeSampler(computeCommandEncoder, binding, sampler);
}
}
}
// Inlineable
public readonly void UpdateTextureAndSampler(ShaderStage stage, ulong binding, TextureBase texture, MTLSamplerState sampler)
{
UpdateTexture(stage, binding, texture);
UpdateSampler(stage, binding, sampler);
}
private readonly void SetDepthStencilState(MTLRenderCommandEncoder renderCommandEncoder)
{
if (_currentState.DepthStencilState != null)
@ -807,10 +955,10 @@ namespace Ryujinx.Graphics.Metal
Index = bufferDescriptors.Length
});
SetBuffers(renderCommandEncoder, buffers);
SetRenderBuffers(renderCommandEncoder, buffers);
}
private readonly void SetBuffers(MTLRenderCommandEncoder renderCommandEncoder, List<BufferInfo> buffers, bool fragment = false)
private readonly void SetRenderBuffers(MTLRenderCommandEncoder renderCommandEncoder, List<BufferInfo> buffers, bool fragment = false)
{
foreach (var buffer in buffers)
{
@ -823,6 +971,14 @@ namespace Ryujinx.Graphics.Metal
}
}
private readonly void SetComputeBuffers(MTLComputeCommandEncoder computeCommandEncoder, List<BufferInfo> buffers)
{
foreach (var buffer in buffers)
{
computeCommandEncoder.SetBuffer(new MTLBuffer(buffer.Handle), (ulong)buffer.Offset, (ulong)buffer.Index);
}
}
private readonly void SetCullMode(MTLRenderCommandEncoder renderCommandEncoder)
{
renderCommandEncoder.SetCullMode(_currentState.CullMode);
@ -838,41 +994,64 @@ namespace Ryujinx.Graphics.Metal
renderCommandEncoder.SetStencilReferenceValues((uint)_currentState.FrontRefValue, (uint)_currentState.BackRefValue);
}
private static void SetTextureAndSampler(MTLRenderCommandEncoder renderCommandEncoder, ShaderStage stage, MTLTexture[] textures, MTLSamplerState[] samplers)
private static void SetRenderTexture(MTLRenderCommandEncoder renderCommandEncoder, ShaderStage stage, ulong binding, TextureBase texture)
{
for (int i = 0; i < textures.Length; i++)
if (texture == null)
{
var texture = textures[i];
if (texture != IntPtr.Zero)
{
switch (stage)
{
case ShaderStage.Vertex:
renderCommandEncoder.SetVertexTexture(texture, (ulong)i);
break;
case ShaderStage.Fragment:
renderCommandEncoder.SetFragmentTexture(texture, (ulong)i);
break;
}
}
return;
}
for (int i = 0; i < samplers.Length; i++)
var textureHandle = texture.GetHandle();
if (textureHandle != IntPtr.Zero)
{
var sampler = samplers[i];
if (sampler != IntPtr.Zero)
switch (stage)
{
switch (stage)
{
case ShaderStage.Vertex:
renderCommandEncoder.SetVertexSamplerState(sampler, (ulong)i);
break;
case ShaderStage.Fragment:
renderCommandEncoder.SetFragmentSamplerState(sampler, (ulong)i);
break;
}
case ShaderStage.Vertex:
renderCommandEncoder.SetVertexTexture(textureHandle, binding);
break;
case ShaderStage.Fragment:
renderCommandEncoder.SetFragmentTexture(textureHandle, binding);
break;
}
}
}
private static void SetRenderSampler(MTLRenderCommandEncoder renderCommandEncoder, ShaderStage stage, ulong binding, MTLSamplerState sampler)
{
if (sampler != IntPtr.Zero)
{
switch (stage)
{
case ShaderStage.Vertex:
renderCommandEncoder.SetVertexSamplerState(sampler, binding);
break;
case ShaderStage.Fragment:
renderCommandEncoder.SetFragmentSamplerState(sampler, binding);
break;
}
}
}
private static void SetComputeTexture(MTLComputeCommandEncoder computeCommandEncoder, ulong binding, TextureBase texture)
{
if (texture == null)
{
return;
}
var textureHandle = texture.GetHandle();
if (textureHandle != IntPtr.Zero)
{
computeCommandEncoder.SetTexture(textureHandle, binding);
}
}
private static void SetComputeSampler(MTLComputeCommandEncoder computeCommandEncoder, ulong binding, MTLSamplerState sampler)
{
if (sampler != IntPtr.Zero)
{
computeCommandEncoder.SetSamplerState(sampler, binding);
}
}
}
}

View file

@ -5,6 +5,8 @@ using Ryujinx.Graphics.Shader.Translation;
using SharpMetal.Metal;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Versioning;
namespace Ryujinx.Graphics.Metal

View file

@ -97,9 +97,12 @@ namespace Ryujinx.Graphics.Metal
public ITexture CreateTexture(TextureCreateInfo info)
{
var texture = new Texture(_device, _pipeline, info);
if (info.Target == Target.TextureBuffer)
{
return new TextureBuffer(_device, _pipeline, info);
}
return texture;
return new Texture(_device, _pipeline, info);
}
public ITextureArray CreateTextureArray(int size, bool isBuffer)

View file

@ -69,7 +69,6 @@ namespace Ryujinx.Graphics.Metal
public MTLRenderCommandEncoder GetOrCreateRenderEncoder()
{
MTLRenderCommandEncoder renderCommandEncoder;
if (_currentEncoder == null || _currentEncoderType != EncoderType.Render)
{
renderCommandEncoder = BeginRenderPass();
@ -79,7 +78,7 @@ namespace Ryujinx.Graphics.Metal
renderCommandEncoder = new MTLRenderCommandEncoder(_currentEncoder.Value);
}
_encoderStateManager.RebindState(renderCommandEncoder);
_encoderStateManager.RebindRenderState(renderCommandEncoder);
return renderCommandEncoder;
}
@ -99,15 +98,19 @@ namespace Ryujinx.Graphics.Metal
public MTLComputeCommandEncoder GetOrCreateComputeEncoder()
{
if (_currentEncoder != null)
MTLComputeCommandEncoder computeCommandEncoder;
if (_currentEncoder == null || _currentEncoderType != EncoderType.Compute)
{
if (_currentEncoderType == EncoderType.Compute)
{
return new MTLComputeCommandEncoder(_currentEncoder.Value);
}
computeCommandEncoder = BeginComputePass();
}
else
{
computeCommandEncoder = new MTLComputeCommandEncoder(_currentEncoder.Value);
}
return BeginComputePass();
_encoderStateManager.RebindComputeState(computeCommandEncoder);
return computeCommandEncoder;
}
public void EndCurrentPass()
@ -164,8 +167,7 @@ namespace Ryujinx.Graphics.Metal
{
EndCurrentPass();
var descriptor = new MTLComputePassDescriptor();
var computeCommandEncoder = _commandBuffer.ComputeCommandEncoder(descriptor);
var computeCommandEncoder = _encoderStateManager.CreateComputeCommandEncoder();
_currentEncoder = computeCommandEncoder;
_currentEncoderType = EncoderType.Compute;
@ -274,9 +276,13 @@ namespace Ryujinx.Graphics.Metal
(ulong)size);
}
public void DispatchCompute(int groupsX, int groupsY, int groupsZ)
public void DispatchCompute(int groupsX, int groupsY, int groupsZ, int groupSizeX, int groupSizeY, int groupSizeZ)
{
Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
var computeCommandEncoder = GetOrCreateComputeEncoder();
computeCommandEncoder.DispatchThreadgroups(
new MTLSize{width = (ulong)groupsX, height = (ulong)groupsY, depth = (ulong)groupsZ},
new MTLSize{width = (ulong)groupSizeX, height = (ulong)groupSizeY, depth = (ulong)groupSizeZ});
}
public void Draw(int vertexCount, int instanceCount, int firstVertex, int firstInstance)
@ -397,7 +403,10 @@ namespace Ryujinx.Graphics.Metal
public void SetImage(ShaderStage stage, int binding, ITexture texture, Format imageFormat)
{
Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
if (texture is TextureBase tex)
{
_encoderStateManager.UpdateTexture(stage, (ulong)binding, tex);
}
}
public void SetImageArray(ShaderStage stage, int binding, IImageArray array)
@ -491,28 +500,14 @@ namespace Ryujinx.Graphics.Metal
public void SetTextureAndSampler(ShaderStage stage, int binding, ITexture texture, ISampler sampler)
{
if (texture is Texture tex)
if (texture is TextureBase tex)
{
if (sampler is Sampler samp)
{
var mtlTexture = tex.MTLTexture;
var mtlSampler = samp.GetSampler();
var index = (ulong)binding;
switch (stage)
{
case ShaderStage.Vertex:
case ShaderStage.Fragment:
_encoderStateManager.UpdateTextureAndSampler(stage, index, mtlTexture, mtlSampler);
break;
case ShaderStage.Compute:
var computeCommandEncoder = GetOrCreateComputeEncoder();
computeCommandEncoder.SetTexture(mtlTexture, index);
computeCommandEncoder.SetSamplerState(mtlSampler, index);
break;
default:
throw new ArgumentOutOfRangeException(nameof(stage), stage, "Unsupported shader stage!");
}
_encoderStateManager.UpdateTextureAndSampler(stage, index, tex, mtlSampler);
}
}
}

View file

@ -26,7 +26,7 @@ namespace Ryujinx.Graphics.Metal
var shaderLibrary = device.NewLibrary(StringHelper.NSString(shader.Code), new MTLCompileOptions(IntPtr.Zero), ref libraryError);
if (libraryError != IntPtr.Zero)
{
Logger.Warning?.Print(LogClass.Gpu, $"Shader linking failed: \n{StringHelper.String(libraryError.LocalizedDescription)}");
Logger.Warning?.Print(LogClass.Gpu, $"{shader.Stage} shader linking failed: \n{StringHelper.String(libraryError.LocalizedDescription)}");
_status = ProgramLinkStatus.Failure;
return;
}
@ -34,7 +34,7 @@ namespace Ryujinx.Graphics.Metal
switch (shaders[index].Stage)
{
case ShaderStage.Compute:
ComputeFunction = shaderLibrary.NewFunction(StringHelper.NSString("computeMain"));
ComputeFunction = shaderLibrary.NewFunction(StringHelper.NSString("kernelMain"));
break;
case ShaderStage.Vertex:
VertexFunction = shaderLibrary.NewFunction(StringHelper.NSString("vertexMain"));

View file

@ -10,24 +10,10 @@ using System.Runtime.Versioning;
namespace Ryujinx.Graphics.Metal
{
[SupportedOSPlatform("macos")]
class Texture : ITexture, IDisposable
class Texture : TextureBase, ITexture
{
private readonly TextureCreateInfo _info;
private readonly Pipeline _pipeline;
private readonly MTLDevice _device;
public MTLTexture MTLTexture;
public TextureCreateInfo Info => _info;
public int Width => Info.Width;
public int Height => Info.Height;
public int Depth => Info.Depth;
public Texture(MTLDevice device, Pipeline pipeline, TextureCreateInfo info)
public Texture(MTLDevice device, Pipeline pipeline, TextureCreateInfo info) : base(device, pipeline, info)
{
_device = device;
_pipeline = pipeline;
_info = info;
var descriptor = new MTLTextureDescriptor
{
PixelFormat = FormatTable.GetFormat(Info.Format),
@ -50,15 +36,11 @@ namespace Ryujinx.Graphics.Metal
descriptor.Swizzle = GetSwizzle(info, descriptor.PixelFormat);
MTLTexture = _device.NewTexture(descriptor);
_mtlTexture = _device.NewTexture(descriptor);
}
public Texture(MTLDevice device, Pipeline pipeline, TextureCreateInfo info, MTLTexture sourceTexture, int firstLayer, int firstLevel)
public Texture(MTLDevice device, Pipeline pipeline, TextureCreateInfo info, MTLTexture sourceTexture, int firstLayer, int firstLevel) : base(device, pipeline, info)
{
_device = device;
_pipeline = pipeline;
_info = info;
var pixelFormat = FormatTable.GetFormat(Info.Format);
var textureType = Info.Target.Convert();
NSRange levels;
@ -75,7 +57,7 @@ namespace Ryujinx.Graphics.Metal
var swizzle = GetSwizzle(info, pixelFormat);
MTLTexture = sourceTexture.NewTextureView(pixelFormat, textureType, levels, slices, swizzle);
_mtlTexture = sourceTexture.NewTextureView(pixelFormat, textureType, levels, slices, swizzle);
}
private MTLTextureSwizzleChannels GetSwizzle(TextureCreateInfo info, MTLPixelFormat pixelFormat)
@ -118,14 +100,14 @@ namespace Ryujinx.Graphics.Metal
if (destination is Texture destinationTexture)
{
blitCommandEncoder.CopyFromTexture(
MTLTexture,
_mtlTexture,
(ulong)firstLayer,
(ulong)firstLevel,
destinationTexture.MTLTexture,
destinationTexture._mtlTexture,
(ulong)firstLayer,
(ulong)firstLevel,
MTLTexture.ArrayLength,
MTLTexture.MipmapLevelCount);
_mtlTexture.ArrayLength,
_mtlTexture.MipmapLevelCount);
}
}
@ -136,14 +118,14 @@ namespace Ryujinx.Graphics.Metal
if (destination is Texture destinationTexture)
{
blitCommandEncoder.CopyFromTexture(
MTLTexture,
_mtlTexture,
(ulong)srcLayer,
(ulong)srcLevel,
destinationTexture.MTLTexture,
destinationTexture._mtlTexture,
(ulong)dstLayer,
(ulong)dstLevel,
MTLTexture.ArrayLength,
MTLTexture.MipmapLevelCount);
_mtlTexture.ArrayLength,
_mtlTexture.MipmapLevelCount);
}
}
@ -158,7 +140,7 @@ namespace Ryujinx.Graphics.Metal
ulong bytesPerRow = (ulong)Info.GetMipStride(level);
ulong bytesPerImage = 0;
if (MTLTexture.TextureType == MTLTextureType.Type3D)
if (_mtlTexture.TextureType == MTLTextureType.Type3D)
{
bytesPerImage = bytesPerRow * (ulong)Info.Height;
}
@ -167,11 +149,11 @@ namespace Ryujinx.Graphics.Metal
MTLBuffer mtlBuffer = new(Unsafe.As<BufferHandle, IntPtr>(ref handle));
blitCommandEncoder.CopyFromTexture(
MTLTexture,
_mtlTexture,
(ulong)layer,
(ulong)level,
new MTLOrigin(),
new MTLSize { width = MTLTexture.Width, height = MTLTexture.Height, depth = MTLTexture.Depth },
new MTLSize { width = _mtlTexture.Width, height = _mtlTexture.Height, depth = _mtlTexture.Depth },
mtlBuffer,
(ulong)range.Offset,
bytesPerRow,
@ -180,7 +162,7 @@ namespace Ryujinx.Graphics.Metal
public ITexture CreateView(TextureCreateInfo info, int firstLayer, int firstLevel)
{
return new Texture(_device, _pipeline, info, MTLTexture, firstLayer, firstLevel);
return new Texture(_device, _pipeline, info, _mtlTexture, firstLayer, firstLevel);
}
public PinnedSpan<byte> GetData()
@ -195,7 +177,7 @@ namespace Ryujinx.Graphics.Metal
ulong bytesPerRow = (ulong)Info.GetMipStride(level);
ulong length = bytesPerRow * (ulong)Info.Height;
ulong bytesPerImage = 0;
if (MTLTexture.TextureType == MTLTextureType.Type3D)
if (_mtlTexture.TextureType == MTLTextureType.Type3D)
{
bytesPerImage = length;
}
@ -205,11 +187,11 @@ namespace Ryujinx.Graphics.Metal
var mtlBuffer = _device.NewBuffer(length, MTLResourceOptions.ResourceStorageModeShared);
blitCommandEncoder.CopyFromTexture(
MTLTexture,
_mtlTexture,
(ulong)layer,
(ulong)level,
new MTLOrigin(),
new MTLSize { width = MTLTexture.Width, height = MTLTexture.Height, depth = MTLTexture.Depth },
new MTLSize { width = _mtlTexture.Width, height = _mtlTexture.Height, depth = _mtlTexture.Depth },
mtlBuffer,
0,
bytesPerRow,
@ -255,7 +237,7 @@ namespace Ryujinx.Graphics.Metal
(ulong)Info.GetMipStride(level),
(ulong)mipSize,
new MTLSize { width = (ulong)width, height = (ulong)height, depth = is3D ? (ulong)depth : 1 },
MTLTexture,
_mtlTexture,
0,
(ulong)level,
new MTLOrigin()
@ -282,7 +264,7 @@ namespace Ryujinx.Graphics.Metal
ulong bytesPerRow = (ulong)Info.GetMipStride(level);
ulong bytesPerImage = 0;
if (MTLTexture.TextureType == MTLTextureType.Type3D)
if (_mtlTexture.TextureType == MTLTextureType.Type3D)
{
bytesPerImage = bytesPerRow * (ulong)Info.Height;
}
@ -299,8 +281,8 @@ namespace Ryujinx.Graphics.Metal
0,
bytesPerRow,
bytesPerImage,
new MTLSize { width = MTLTexture.Width, height = MTLTexture.Height, depth = MTLTexture.Depth },
MTLTexture,
new MTLSize { width = _mtlTexture.Width, height = _mtlTexture.Height, depth = _mtlTexture.Depth },
_mtlTexture,
(ulong)layer,
(ulong)level,
new MTLOrigin()
@ -317,7 +299,7 @@ namespace Ryujinx.Graphics.Metal
ulong bytesPerRow = (ulong)Info.GetMipStride(level);
ulong bytesPerImage = 0;
if (MTLTexture.TextureType == MTLTextureType.Type3D)
if (_mtlTexture.TextureType == MTLTextureType.Type3D)
{
bytesPerImage = bytesPerRow * (ulong)Info.Height;
}
@ -335,7 +317,7 @@ namespace Ryujinx.Graphics.Metal
bytesPerRow,
bytesPerImage,
new MTLSize { width = (ulong)region.Width, height = (ulong)region.Height, depth = 1 },
MTLTexture,
_mtlTexture,
(ulong)layer,
(ulong)level,
new MTLOrigin { x = (ulong)region.X, y = (ulong)region.Y }
@ -348,18 +330,7 @@ namespace Ryujinx.Graphics.Metal
public void SetStorage(BufferRange buffer)
{
Logger.Warning?.Print(LogClass.Gpu, "Not Implemented!");
}
public void Release()
{
Dispose();
}
public void Dispose()
{
MTLTexture.SetPurgeableState(MTLPurgeableState.Volatile);
MTLTexture.Dispose();
throw new NotImplementedException();
}
}
}

View file

@ -0,0 +1,59 @@
using Ryujinx.Common.Logging;
using Ryujinx.Graphics.GAL;
using SharpMetal.Foundation;
using SharpMetal.Metal;
using System;
using System.Buffers;
using System.Runtime.CompilerServices;
using System.Runtime.Versioning;
namespace Ryujinx.Graphics.Metal
{
[SupportedOSPlatform("macos")]
abstract class TextureBase : IDisposable
{
private bool _disposed;
protected readonly TextureCreateInfo _info;
protected readonly Pipeline _pipeline;
protected readonly MTLDevice _device;
protected MTLTexture _mtlTexture;
public TextureCreateInfo Info => _info;
public int Width => Info.Width;
public int Height => Info.Height;
public int Depth => Info.Depth;
public TextureBase(MTLDevice device, Pipeline pipeline, TextureCreateInfo info)
{
_device = device;
_pipeline = pipeline;
_info = info;
}
public MTLTexture GetHandle()
{
if (_disposed)
{
return new MTLTexture(IntPtr.Zero);
}
return _mtlTexture;
}
public void Release()
{
Dispose();
}
public void Dispose()
{
if (_mtlTexture != IntPtr.Zero)
{
_mtlTexture.Dispose();
}
_disposed = true;
}
}
}

View file

@ -0,0 +1,112 @@
using Ryujinx.Common.Logging;
using Ryujinx.Graphics.GAL;
using SharpMetal.Foundation;
using SharpMetal.Metal;
using System;
using System.Buffers;
using System.Runtime.CompilerServices;
using System.Runtime.Versioning;
namespace Ryujinx.Graphics.Metal
{
[SupportedOSPlatform("macos")]
class TextureBuffer : Texture, ITexture
{
private MTLBuffer? _bufferHandle;
private int _offset;
private int _size;
public TextureBuffer(MTLDevice device, Pipeline pipeline, TextureCreateInfo info) : base(device, pipeline, info) { }
public void CreateView()
{
var descriptor = new MTLTextureDescriptor
{
PixelFormat = FormatTable.GetFormat(Info.Format),
Usage = MTLTextureUsage.ShaderRead | MTLTextureUsage.ShaderWrite,
StorageMode = MTLStorageMode.Shared,
TextureType = Info.Target.Convert(),
Width = (ulong)Info.Width,
Height = (ulong)Info.Height
};
_mtlTexture = _bufferHandle.Value.NewTexture(descriptor, (ulong)_offset, (ulong)_size);
}
public void CopyTo(ITexture destination, int firstLayer, int firstLevel)
{
throw new NotSupportedException();
}
public void CopyTo(ITexture destination, int srcLayer, int dstLayer, int srcLevel, int dstLevel)
{
throw new NotSupportedException();
}
public void CopyTo(ITexture destination, Extents2D srcRegion, Extents2D dstRegion, bool linearFilter)
{
throw new NotSupportedException();
}
public ITexture CreateView(TextureCreateInfo info, int firstLayer, int firstLevel)
{
throw new NotSupportedException();
}
// TODO: Implement this method
public PinnedSpan<byte> GetData()
{
throw new NotImplementedException();
}
public PinnedSpan<byte> GetData(int layer, int level)
{
return GetData();
}
public void CopyTo(BufferRange range, int layer, int level, int stride)
{
throw new NotImplementedException();
}
public void SetData(IMemoryOwner<byte> data)
{
// TODO
//_gd.SetBufferData(_bufferHandle, _offset, data.Memory.Span);
data.Dispose();
}
public void SetData(IMemoryOwner<byte> data, int layer, int level)
{
throw new NotSupportedException();
}
public void SetData(IMemoryOwner<byte> data, int layer, int level, Rectangle<int> region)
{
throw new NotSupportedException();
}
public void SetStorage(BufferRange buffer)
{
if (buffer.Handle != BufferHandle.Null)
{
var handle = buffer.Handle;
MTLBuffer bufferHandle = new(Unsafe.As<BufferHandle, IntPtr>(ref handle));
if (_bufferHandle == bufferHandle &&
_offset == buffer.Offset &&
_size == buffer.Size)
{
return;
}
_bufferHandle = bufferHandle;
_offset = buffer.Offset;
_size = buffer.Size;
Release();
CreateView();
}
}
}
}

View file

@ -205,7 +205,7 @@ namespace Ryujinx.Graphics.OpenGL
Buffer.Copy(source, destination, srcOffset, dstOffset, size);
}
public void DispatchCompute(int groupsX, int groupsY, int groupsZ)
public void DispatchCompute(int groupsX, int groupsY, int groupsZ, int groupSizeX, int groupSizeY, int groupSizeZ)
{
if (!_program.IsLinked)
{

View file

@ -8,6 +8,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
public const string Tab = " ";
// The number of additional arguments that every function (except for the main one) must have (for instance support_buffer)
public const int additionalArgCount = 1;
public StructuredFunction CurrentFunction { get; set; }
public StructuredProgramInfo Info { get; }

View file

@ -54,6 +54,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
DeclareInputAttributes(context, info.IoDefinitions.Where(x => IsUserDefined(x, StorageKind.Input)));
context.AppendLine();
DeclareOutputAttributes(context, info.IoDefinitions.Where(x => x.StorageKind == StorageKind.Output));
context.AppendLine();
DeclareBufferStructures(context, context.Properties.ConstantBuffers.Values);
DeclareBufferStructures(context, context.Properties.StorageBuffers.Values);
}
static bool IsUserDefined(IoDefinition ioDefinition, StorageKind storageKind)
@ -111,8 +114,41 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
foreach (var memory in memories)
{
string arraySize = "";
if ((memory.Type & AggregateType.Array) != 0)
{
arraySize = $"[{memory.ArrayLength}]";
}
var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {memory.Name}[{memory.ArrayLength}];");
context.AppendLine($"{typeName} {memory.Name}{arraySize};");
}
}
private static void DeclareBufferStructures(CodeGenContext context, IEnumerable<BufferDefinition> buffers)
{
foreach (BufferDefinition buffer in buffers)
{
context.AppendLine($"struct Struct_{buffer.Name}");
context.EnterScope();
foreach (StructureField field in buffer.Type.Fields)
{
if (field.Type.HasFlag(AggregateType.Array) && field.ArrayLength > 0)
{
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {field.Name}[{field.ArrayLength}];");
}
else
{
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {field.Name};");
}
}
context.LeaveScope(";");
context.AppendLine();
}
}
@ -124,7 +160,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
}
else
{
if (inputs.Any())
if (inputs.Any() || context.Definitions.Stage == ShaderStage.Fragment)
{
string prefix = "";
@ -136,9 +172,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
case ShaderStage.Fragment:
context.AppendLine($"struct FragmentIn");
break;
case ShaderStage.Compute:
context.AppendLine($"struct KernelIn");
break;
}
context.EnterScope();

View file

@ -134,7 +134,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
case Instruction.Load:
return Load(context, operation);
case Instruction.Lod:
return "|| LOD ||";
return Lod(context, operation);
case Instruction.MemoryBarrier:
return "|| MEMORY BARRIER ||";
case Instruction.Store:

View file

@ -12,11 +12,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
var functon = context.GetFunction(funcId.Value);
string[] args = new string[operation.SourcesCount - 1];
int argCount = operation.SourcesCount - 1;
string[] args = new string[argCount + CodeGenContext.additionalArgCount];
for (int i = 0; i < args.Length; i++)
// Additional arguments
args[0] = "support_buffer";
int argIndex = CodeGenContext.additionalArgCount;
for (int i = 0; i < argCount; i++)
{
args[i] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));
args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));
}
return $"{functon.Name}({string.Join(", ", args)})";

View file

@ -24,6 +24,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
inputsCount--;
}
string fieldName = "";
switch (storageKind)
{
case StorageKind.ConstantBuffer:
@ -45,6 +46,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
StructureField field = buffer.Type.Fields[fieldIndex.Value];
varName = buffer.Name;
if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0)
{
// Unsized array, the buffer is indexed instead of the field
fieldName = "." + field.Name;
}
else
{
varName += "->" + field.Name;
}
varType = field.Type;
break;
@ -126,6 +136,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
varName += $"[{GetSourceExpr(context, src, AggregateType.S32)}]";
}
}
varName += fieldName;
if (isStore)
{
@ -141,6 +152,37 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
return GenerateLoadOrStore(context, operation, isStore: false);
}
// TODO: check this
public static string Lod(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
int coordsCount = texOp.Type.GetDimensions();
int coordsIndex = 0;
string samplerName = GetSamplerName(context.Properties, texOp);
string coordsExpr;
if (coordsCount > 1)
{
string[] elems = new string[coordsCount];
for (int index = 0; index < coordsCount; index++)
{
elems[index] = GetSourceExpr(context, texOp.GetSource(coordsIndex + index), AggregateType.FP32);
}
coordsExpr = "float" + coordsCount + "(" + string.Join(", ", elems) + ")";
}
else
{
coordsExpr = GetSourceExpr(context, texOp.GetSource(coordsIndex), AggregateType.FP32);
}
return $"tex_{samplerName}.calculate_unclamped_lod(samp_{samplerName}, {coordsExpr}){GetMaskMultiDest(texOp.Index)}";
}
public static string Store(CodeGenContext context, AstOperation operation)
{
return GenerateLoadOrStore(context, operation, isStore: true);
@ -176,11 +218,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
}
else
{
texCall += "sample";
if (isGather)
{
texCall += "_gather";
texCall += "gather";
}
else
{
texCall += "sample";
}
if (isShadow)
@ -188,22 +232,31 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
texCall += "_compare";
}
texCall += $"(samp_{samplerName}";
texCall += $"(samp_{samplerName}, ";
}
int coordsCount = texOp.Type.GetDimensions();
int pCount = coordsCount;
bool appended = false;
void Append(string str)
{
texCall += ", " + str;
if (appended)
{
texCall += ", ";
}
else {
appended = true;
}
texCall += str;
}
AggregateType coordType = intCoords ? AggregateType.S32 : AggregateType.FP32;
string AssemblePVector(int count)
{
string coords;
if (count > 1)
{
string[] elems = new string[count];
@ -213,14 +266,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
elems[index] = Src(coordType);
}
string prefix = intCoords ? "int" : "float";
return prefix + count + "(" + string.Join(", ", elems) + ")";
coords = string.Join(", ", elems);
}
else
{
return Src(coordType);
coords = Src(coordType);
}
string prefix = intCoords ? "uint" : "float";
return prefix + (count > 1 ? count : "") + "(" + coords + ")";
}
Append(AssemblePVector(pCount));
@ -254,6 +309,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
private static string GetMaskMultiDest(int mask)
{
if (mask == 0x0)
{
return "";
}
string swizzle = ".";
for (int i = 0; i < 4; i++)

View file

@ -35,7 +35,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
IoVariable.VertexId => ("vertex_id", AggregateType.S32),
// gl_VertexIndex does not have a direct equivalent in MSL
IoVariable.VertexIndex => ("vertex_index", AggregateType.U32),
IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),
IoVariable.ViewportIndex => ("viewport_array_index", AggregateType.S32),
IoVariable.FragmentCoord => ("in.position", AggregateType.Vector4 | AggregateType.FP32),
_ => (null, AggregateType.Invalid),

View file

@ -48,6 +48,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
PrintBlock(context, function.MainBlock, isMainFunc);
// In case the shader hasn't returned, return
if (isMainFunc && stage != ShaderStage.Compute)
{
context.AppendLine("return out;");
}
context.LeaveScope();
}
@ -57,11 +63,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
ShaderStage stage,
bool isMainFunc = false)
{
string[] args = new string[function.InArguments.Length + function.OutArguments.Length];
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.additionalArgCount;
string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length];
// All non-main functions need to be able to access the support_buffer as well
if (!isMainFunc)
{
args[0] = "constant Struct_support_buffer* support_buffer";
}
int argIndex = additionalArgCount;
for (int i = 0; i < function.InArguments.Length; i++)
{
args[i] = $"{Declarations.GetVarTypeName(context, function.InArguments[i])} {OperandManager.GetArgumentName(i)}";
args[argIndex++] = $"{Declarations.GetVarTypeName(context, function.InArguments[i])} {OperandManager.GetArgumentName(i)}";
}
for (int i = 0; i < function.OutArguments.Length; i++)
@ -69,7 +84,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
int j = i + function.InArguments.Length;
// Likely need to be made into pointers
args[j] = $"out {Declarations.GetVarTypeName(context, function.OutArguments[i])} {OperandManager.GetArgumentName(j)}";
args[argIndex++] = $"out {Declarations.GetVarTypeName(context, function.OutArguments[i])} {OperandManager.GetArgumentName(j)}";
}
string funcKeyword = "inline";
@ -97,20 +112,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
returnType = "void";
}
if (context.AttributeUsage.UsedInputAttributes != 0)
if (stage == ShaderStage.Vertex)
{
if (stage == ShaderStage.Vertex)
if (context.AttributeUsage.UsedInputAttributes != 0)
{
args = args.Prepend("VertexIn in [[stage_in]]").ToArray();
}
else if (stage == ShaderStage.Fragment)
{
args = args.Prepend("FragmentIn in [[stage_in]]").ToArray();
}
else if (stage == ShaderStage.Compute)
{
args = args.Prepend("KernelIn in [[stage_in]]").ToArray();
}
}
else if (stage == ShaderStage.Fragment)
{
args = args.Prepend("FragmentIn in [[stage_in]]").ToArray();
}
// TODO: add these only if they are used
@ -119,18 +130,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
args = args.Append("uint vertex_id [[vertex_id]]").ToArray();
args = args.Append("uint instance_id [[instance_id]]").ToArray();
}
else if (stage == ShaderStage.Compute)
{
args = args.Append("uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]]").ToArray();
args = args.Append("uint3 thread_position_in_grid [[thread_position_in_grid]]").ToArray();
args = args.Append("uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]]").ToArray();
}
foreach (var constantBuffer in context.Properties.ConstantBuffers.Values)
{
var varType = constantBuffer.Type.Fields[0].Type & ~AggregateType.Array;
args = args.Append($"constant {Declarations.GetVarTypeName(context, varType)} *{constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
}
foreach (var storageBuffers in context.Properties.StorageBuffers.Values)
{
var varType = storageBuffers.Type.Fields[0].Type & ~AggregateType.Array;
// Offset the binding by 15 to avoid clashing with the constant buffers
args = args.Append($"device {Declarations.GetVarTypeName(context, varType)} *{storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
}
foreach (var texture in context.Properties.Textures.Values)

View file

@ -861,7 +861,7 @@ namespace Ryujinx.Graphics.Vulkan
_pipeline.SetStorageBuffers(1, sbRanges);
_pipeline.SetProgram(_programStrideChange);
_pipeline.DispatchCompute(1 + elems / ConvertElementsPerWorkgroup, 1, 1);
_pipeline.DispatchCompute(1 + elems / ConvertElementsPerWorkgroup, 1, 1, 0, 0, 0);
_pipeline.Finish(gd, cbs);
}
@ -1044,7 +1044,7 @@ namespace Ryujinx.Graphics.Vulkan
int dispatchX = (Math.Min(srcView.Info.Width, dstView.Info.Width) + 31) / 32;
int dispatchY = (Math.Min(srcView.Info.Height, dstView.Info.Height) + 31) / 32;
_pipeline.DispatchCompute(dispatchX, dispatchY, 1);
_pipeline.DispatchCompute(dispatchX, dispatchY, 1, 0, 0, 0);
if (srcView != src)
{
@ -1170,7 +1170,7 @@ namespace Ryujinx.Graphics.Vulkan
_pipeline.SetTextureAndSamplerIdentitySwizzle(ShaderStage.Compute, 0, srcView, null);
_pipeline.SetImage(ShaderStage.Compute, 0, dstView.GetView(format));
_pipeline.DispatchCompute(dispatchX, dispatchY, 1);
_pipeline.DispatchCompute(dispatchX, dispatchY, 1, 0, 0, 0);
if (srcView != src)
{
@ -1582,7 +1582,7 @@ namespace Ryujinx.Graphics.Vulkan
_pipeline.SetStorageBuffers(stackalloc[] { new BufferAssignment(3, patternScoped.Range) });
_pipeline.SetProgram(_programConvertIndirectData);
_pipeline.DispatchCompute(1, 1, 1);
_pipeline.DispatchCompute(1, 1, 1, 0, 0, 0);
BufferHolder.InsertBufferBarrier(
gd,
@ -1684,7 +1684,7 @@ namespace Ryujinx.Graphics.Vulkan
_pipeline.SetStorageBuffers(1, sbRanges);
_pipeline.SetProgram(_programConvertD32S8ToD24S8);
_pipeline.DispatchCompute(1 + inSize / ConvertElementsPerWorkgroup, 1, 1);
_pipeline.DispatchCompute(1 + inSize / ConvertElementsPerWorkgroup, 1, 1, 0, 0, 0);
_pipeline.Finish(gd, cbs);

View file

@ -295,7 +295,7 @@ namespace Ryujinx.Graphics.Vulkan
}
}
public void DispatchCompute(int groupsX, int groupsY, int groupsZ)
public void DispatchCompute(int groupsX, int groupsY, int groupsZ, int groupSizeX, int groupSizeY, int groupSizeZ)
{
if (!_program.IsLinked)
{