0
0
Fork 0
mirror of https://github.com/GreemDev/Ryujinx.git synced 2024-12-23 10:25:47 +00:00

Fix vote and shuffle shader instructions on AMD GPUs (#5540)

* Move shuffle handling out of the backend to a transform pass

* Handle subgroup sizes higher than 32

* Stop using the subgroup size control extension

* Make GenerateShuffleFunction static

* Shader cache version bump
This commit is contained in:
gdkchan 2023-08-16 21:31:07 -03:00 committed by GitHub
parent 64079c034c
commit 6ed613a6e6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 445 additions and 265 deletions

View file

@ -52,6 +52,7 @@ namespace Ryujinx.Graphics.GAL
public readonly int MaximumComputeSharedMemorySize; public readonly int MaximumComputeSharedMemorySize;
public readonly float MaximumSupportedAnisotropy; public readonly float MaximumSupportedAnisotropy;
public readonly int ShaderSubgroupSize;
public readonly int StorageBufferOffsetAlignment; public readonly int StorageBufferOffsetAlignment;
public readonly int GatherBiasPrecision; public readonly int GatherBiasPrecision;
@ -101,6 +102,7 @@ namespace Ryujinx.Graphics.GAL
uint maximumImagesPerStage, uint maximumImagesPerStage,
int maximumComputeSharedMemorySize, int maximumComputeSharedMemorySize,
float maximumSupportedAnisotropy, float maximumSupportedAnisotropy,
int shaderSubgroupSize,
int storageBufferOffsetAlignment, int storageBufferOffsetAlignment,
int gatherBiasPrecision) int gatherBiasPrecision)
{ {
@ -148,6 +150,7 @@ namespace Ryujinx.Graphics.GAL
MaximumImagesPerStage = maximumImagesPerStage; MaximumImagesPerStage = maximumImagesPerStage;
MaximumComputeSharedMemorySize = maximumComputeSharedMemorySize; MaximumComputeSharedMemorySize = maximumComputeSharedMemorySize;
MaximumSupportedAnisotropy = maximumSupportedAnisotropy; MaximumSupportedAnisotropy = maximumSupportedAnisotropy;
ShaderSubgroupSize = shaderSubgroupSize;
StorageBufferOffsetAlignment = storageBufferOffsetAlignment; StorageBufferOffsetAlignment = storageBufferOffsetAlignment;
GatherBiasPrecision = gatherBiasPrecision; GatherBiasPrecision = gatherBiasPrecision;
} }

View file

@ -22,7 +22,7 @@ namespace Ryujinx.Graphics.Gpu.Shader.DiskCache
private const ushort FileFormatVersionMajor = 1; private const ushort FileFormatVersionMajor = 1;
private const ushort FileFormatVersionMinor = 2; private const ushort FileFormatVersionMinor = 2;
private const uint FileFormatVersionPacked = ((uint)FileFormatVersionMajor << 16) | FileFormatVersionMinor; private const uint FileFormatVersionPacked = ((uint)FileFormatVersionMajor << 16) | FileFormatVersionMinor;
private const uint CodeGenVersion = 5576; private const uint CodeGenVersion = 5540;
private const string SharedTocFileName = "shared.toc"; private const string SharedTocFileName = "shared.toc";
private const string SharedDataFileName = "shared.data"; private const string SharedDataFileName = "shared.data";

View file

@ -137,6 +137,8 @@ namespace Ryujinx.Graphics.Gpu.Shader
public int QueryHostStorageBufferOffsetAlignment() => _context.Capabilities.StorageBufferOffsetAlignment; public int QueryHostStorageBufferOffsetAlignment() => _context.Capabilities.StorageBufferOffsetAlignment;
public int QueryHostSubgroupSize() => _context.Capabilities.ShaderSubgroupSize;
public bool QueryHostSupportsBgraFormat() => _context.Capabilities.SupportsBgraFormat; public bool QueryHostSupportsBgraFormat() => _context.Capabilities.SupportsBgraFormat;
public bool QueryHostSupportsFragmentShaderInterlock() => _context.Capabilities.SupportsFragmentShaderInterlock; public bool QueryHostSupportsFragmentShaderInterlock() => _context.Capabilities.SupportsFragmentShaderInterlock;

View file

@ -7,5 +7,6 @@
public const int MaxVertexAttribs = 16; public const int MaxVertexAttribs = 16;
public const int MaxVertexBuffers = 16; public const int MaxVertexBuffers = 16;
public const int MaxTransformFeedbackBuffers = 4; public const int MaxTransformFeedbackBuffers = 4;
public const int MaxSubgroupSize = 64;
} }
} }

View file

@ -175,6 +175,7 @@ namespace Ryujinx.Graphics.OpenGL
maximumImagesPerStage: 8, maximumImagesPerStage: 8,
maximumComputeSharedMemorySize: HwCapabilities.MaximumComputeSharedMemorySize, maximumComputeSharedMemorySize: HwCapabilities.MaximumComputeSharedMemorySize,
maximumSupportedAnisotropy: HwCapabilities.MaximumSupportedAnisotropy, maximumSupportedAnisotropy: HwCapabilities.MaximumSupportedAnisotropy,
shaderSubgroupSize: Constants.MaxSubgroupSize,
storageBufferOffsetAlignment: HwCapabilities.StorageBufferOffsetAlignment, storageBufferOffsetAlignment: HwCapabilities.StorageBufferOffsetAlignment,
gatherBiasPrecision: intelWindows || amdWindows ? 8 : 0); // Precision is 8 for these vendors on Vulkan. gatherBiasPrecision: intelWindows || amdWindows ? 8 : 0); // Precision is 8 for these vendors on Vulkan.
} }

View file

@ -25,6 +25,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
{ {
context.AppendLine("#extension GL_KHR_shader_subgroup_basic : enable"); context.AppendLine("#extension GL_KHR_shader_subgroup_basic : enable");
context.AppendLine("#extension GL_KHR_shader_subgroup_ballot : enable"); context.AppendLine("#extension GL_KHR_shader_subgroup_ballot : enable");
context.AppendLine("#extension GL_KHR_shader_subgroup_shuffle : enable");
} }
context.AppendLine("#extension GL_ARB_shader_group_vote : enable"); context.AppendLine("#extension GL_ARB_shader_group_vote : enable");
@ -201,26 +202,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/MultiplyHighU32.glsl"); AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/MultiplyHighU32.glsl");
} }
if ((info.HelperFunctionsMask & HelperFunctionsMask.Shuffle) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/Shuffle.glsl");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.ShuffleDown) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleDown.glsl");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.ShuffleUp) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleUp.glsl");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.ShuffleXor) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleXor.glsl");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.SwizzleAdd) != 0) if ((info.HelperFunctionsMask & HelperFunctionsMask.SwizzleAdd) != 0)
{ {
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/SwizzleAdd.glsl"); AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/SwizzleAdd.glsl");

View file

@ -5,10 +5,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl
public static string MultiplyHighS32 = "Helper_MultiplyHighS32"; public static string MultiplyHighS32 = "Helper_MultiplyHighS32";
public static string MultiplyHighU32 = "Helper_MultiplyHighU32"; public static string MultiplyHighU32 = "Helper_MultiplyHighU32";
public static string Shuffle = "Helper_Shuffle";
public static string ShuffleDown = "Helper_ShuffleDown";
public static string ShuffleUp = "Helper_ShuffleUp";
public static string ShuffleXor = "Helper_ShuffleXor";
public static string SwizzleAdd = "Helper_SwizzleAdd"; public static string SwizzleAdd = "Helper_SwizzleAdd";
} }
} }

View file

@ -1,11 +0,0 @@
float Helper_Shuffle(float x, uint index, uint mask, out bool valid)
{
uint clamp = mask & 0x1fu;
uint segMask = (mask >> 8) & 0x1fu;
uint minThreadId = $SUBGROUP_INVOCATION$ & segMask;
uint maxThreadId = minThreadId | (clamp & ~segMask);
uint srcThreadId = (index & ~segMask) | minThreadId;
valid = srcThreadId <= maxThreadId;
float v = $SUBGROUP_BROADCAST$(x, srcThreadId);
return valid ? v : x;
}

View file

@ -1,11 +0,0 @@
float Helper_ShuffleDown(float x, uint index, uint mask, out bool valid)
{
uint clamp = mask & 0x1fu;
uint segMask = (mask >> 8) & 0x1fu;
uint minThreadId = $SUBGROUP_INVOCATION$ & segMask;
uint maxThreadId = minThreadId | (clamp & ~segMask);
uint srcThreadId = $SUBGROUP_INVOCATION$ + index;
valid = srcThreadId <= maxThreadId;
float v = $SUBGROUP_BROADCAST$(x, srcThreadId);
return valid ? v : x;
}

View file

@ -1,9 +0,0 @@
float Helper_ShuffleUp(float x, uint index, uint mask, out bool valid)
{
uint segMask = (mask >> 8) & 0x1fu;
uint minThreadId = $SUBGROUP_INVOCATION$ & segMask;
uint srcThreadId = $SUBGROUP_INVOCATION$ - index;
valid = int(srcThreadId) >= int(minThreadId);
float v = $SUBGROUP_BROADCAST$(x, srcThreadId);
return valid ? v : x;
}

View file

@ -1,11 +0,0 @@
float Helper_ShuffleXor(float x, uint index, uint mask, out bool valid)
{
uint clamp = mask & 0x1fu;
uint segMask = (mask >> 8) & 0x1fu;
uint minThreadId = $SUBGROUP_INVOCATION$ & segMask;
uint maxThreadId = minThreadId | (clamp & ~segMask);
uint srcThreadId = $SUBGROUP_INVOCATION$ ^ index;
valid = srcThreadId <= maxThreadId;
float v = $SUBGROUP_BROADCAST$(x, srcThreadId);
return valid ? v : x;
}

View file

@ -9,6 +9,7 @@ using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenFSI;
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenMemory; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenMemory;
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenPacking; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenPacking;
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenShuffle;
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenVector; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenVector;
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo; using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
@ -174,6 +175,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
case Instruction.PackHalf2x16: case Instruction.PackHalf2x16:
return PackHalf2x16(context, operation); return PackHalf2x16(context, operation);
case Instruction.Shuffle:
return Shuffle(context, operation);
case Instruction.Store: case Instruction.Store:
return Store(context, operation); return Store(context, operation);

View file

@ -13,14 +13,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
AggregateType dstType = GetSrcVarType(operation.Inst, 0); AggregateType dstType = GetSrcVarType(operation.Inst, 0);
string arg = GetSoureExpr(context, operation.GetSource(0), dstType); string arg = GetSoureExpr(context, operation.GetSource(0), dstType);
char component = "xyzw"[operation.Index];
if (context.HostCapabilities.SupportsShaderBallot) if (context.HostCapabilities.SupportsShaderBallot)
{ {
return $"unpackUint2x32(ballotARB({arg})).x"; return $"unpackUint2x32(ballotARB({arg})).{component}";
} }
else else
{ {
return $"subgroupBallot({arg}).x"; return $"subgroupBallot({arg}).{component}";
} }
} }
} }

View file

@ -108,10 +108,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3); Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3);
Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3); Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3);
Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3); Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3);
Add(Instruction.Shuffle, InstType.CallQuaternary, HelperFunctionNames.Shuffle); Add(Instruction.Shuffle, InstType.Special);
Add(Instruction.ShuffleDown, InstType.CallQuaternary, HelperFunctionNames.ShuffleDown); Add(Instruction.ShuffleDown, InstType.CallBinary, "subgroupShuffleDown");
Add(Instruction.ShuffleUp, InstType.CallQuaternary, HelperFunctionNames.ShuffleUp); Add(Instruction.ShuffleUp, InstType.CallBinary, "subgroupShuffleUp");
Add(Instruction.ShuffleXor, InstType.CallQuaternary, HelperFunctionNames.ShuffleXor); Add(Instruction.ShuffleXor, InstType.CallBinary, "subgroupShuffleXor");
Add(Instruction.Sine, InstType.CallUnary, "sin"); Add(Instruction.Sine, InstType.CallUnary, "sin");
Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt"); Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt");
Add(Instruction.Store, InstType.Special); Add(Instruction.Store, InstType.Special);

View file

@ -0,0 +1,25 @@
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper;
namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions
{
static class InstGenShuffle
{
public static string Shuffle(CodeGenContext context, AstOperation operation)
{
string value = GetSoureExpr(context, operation.GetSource(0), AggregateType.FP32);
string index = GetSoureExpr(context, operation.GetSource(1), AggregateType.U32);
if (context.HostCapabilities.SupportsShaderBallot)
{
return $"readInvocationARB({value}, {index})";
}
else
{
return $"subgroupShuffle({value}, {index})";
}
}
}
}

View file

@ -231,7 +231,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var execution = context.Constant(context.TypeU32(), Scope.Subgroup); var execution = context.Constant(context.TypeU32(), Scope.Subgroup);
var maskVector = context.GroupNonUniformBallot(uvec4Type, execution, context.Get(AggregateType.Bool, source)); var maskVector = context.GroupNonUniformBallot(uvec4Type, execution, context.Get(AggregateType.Bool, source));
var mask = context.CompositeExtract(context.TypeU32(), maskVector, (SpvLiteralInteger)0); var mask = context.CompositeExtract(context.TypeU32(), maskVector, (SpvLiteralInteger)operation.Index);
return new OperationResult(AggregateType.U32, mask); return new OperationResult(AggregateType.U32, mask);
} }
@ -1100,117 +1100,40 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
private static OperationResult GenerateShuffle(CodeGenContext context, AstOperation operation) private static OperationResult GenerateShuffle(CodeGenContext context, AstOperation operation)
{ {
var x = context.GetFP32(operation.GetSource(0)); var value = context.GetFP32(operation.GetSource(0));
var index = context.GetU32(operation.GetSource(1)); var index = context.GetU32(operation.GetSource(1));
var mask = context.GetU32(operation.GetSource(2));
var const31 = context.Constant(context.TypeU32(), 31); var result = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), value, index);
var const8 = context.Constant(context.TypeU32(), 8);
var clamp = context.BitwiseAnd(context.TypeU32(), mask, const31);
var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
var notSegMask = context.Not(context.TypeU32(), segMask);
var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask);
var indexNotSegMask = context.BitwiseAnd(context.TypeU32(), index, notSegMask);
var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
var srcThreadId = context.BitwiseOr(context.TypeU32(), indexNotSegMask, minThreadId);
var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId);
var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId);
var result = context.Select(context.TypeFP32(), valid, value, x);
var validLocal = (AstOperand)operation.GetSource(3);
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid));
return new OperationResult(AggregateType.FP32, result); return new OperationResult(AggregateType.FP32, result);
} }
private static OperationResult GenerateShuffleDown(CodeGenContext context, AstOperation operation) private static OperationResult GenerateShuffleDown(CodeGenContext context, AstOperation operation)
{ {
var x = context.GetFP32(operation.GetSource(0)); var value = context.GetFP32(operation.GetSource(0));
var index = context.GetU32(operation.GetSource(1)); var index = context.GetU32(operation.GetSource(1));
var mask = context.GetU32(operation.GetSource(2));
var const31 = context.Constant(context.TypeU32(), 31); var result = context.GroupNonUniformShuffleDown(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), value, index);
var const8 = context.Constant(context.TypeU32(), 8);
var clamp = context.BitwiseAnd(context.TypeU32(), mask, const31);
var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
var notSegMask = context.Not(context.TypeU32(), segMask);
var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask);
var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
var srcThreadId = context.IAdd(context.TypeU32(), threadId, index);
var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId);
var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId);
var result = context.Select(context.TypeFP32(), valid, value, x);
var validLocal = (AstOperand)operation.GetSource(3);
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid));
return new OperationResult(AggregateType.FP32, result); return new OperationResult(AggregateType.FP32, result);
} }
private static OperationResult GenerateShuffleUp(CodeGenContext context, AstOperation operation) private static OperationResult GenerateShuffleUp(CodeGenContext context, AstOperation operation)
{ {
var x = context.GetFP32(operation.GetSource(0)); var value = context.GetFP32(operation.GetSource(0));
var index = context.GetU32(operation.GetSource(1)); var index = context.GetU32(operation.GetSource(1));
var mask = context.GetU32(operation.GetSource(2));
var const31 = context.Constant(context.TypeU32(), 31); var result = context.GroupNonUniformShuffleUp(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), value, index);
var const8 = context.Constant(context.TypeU32(), 8);
var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
var srcThreadId = context.ISub(context.TypeU32(), threadId, index);
var valid = context.SGreaterThanEqual(context.TypeBool(), srcThreadId, minThreadId);
var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId);
var result = context.Select(context.TypeFP32(), valid, value, x);
var validLocal = (AstOperand)operation.GetSource(3);
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid));
return new OperationResult(AggregateType.FP32, result); return new OperationResult(AggregateType.FP32, result);
} }
private static OperationResult GenerateShuffleXor(CodeGenContext context, AstOperation operation) private static OperationResult GenerateShuffleXor(CodeGenContext context, AstOperation operation)
{ {
var x = context.GetFP32(operation.GetSource(0)); var value = context.GetFP32(operation.GetSource(0));
var index = context.GetU32(operation.GetSource(1)); var index = context.GetU32(operation.GetSource(1));
var mask = context.GetU32(operation.GetSource(2));
var const31 = context.Constant(context.TypeU32(), 31); var result = context.GroupNonUniformShuffleXor(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), value, index);
var const8 = context.Constant(context.TypeU32(), 8);
var clamp = context.BitwiseAnd(context.TypeU32(), mask, const31);
var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31);
var notSegMask = context.Not(context.TypeU32(), segMask);
var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask);
var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId);
var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask);
var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask);
var srcThreadId = context.BitwiseXor(context.TypeU32(), threadId, index);
var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId);
var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId);
var result = context.Select(context.TypeFP32(), valid, value, x);
var validLocal = (AstOperand)operation.GetSource(3);
context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid));
return new OperationResult(AggregateType.FP32, result); return new OperationResult(AggregateType.FP32, result);
} }

View file

@ -28,12 +28,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
_poolLock = new object(); _poolLock = new object();
} }
private const HelperFunctionsMask NeedsInvocationIdMask = private const HelperFunctionsMask NeedsInvocationIdMask = HelperFunctionsMask.SwizzleAdd;
HelperFunctionsMask.Shuffle |
HelperFunctionsMask.ShuffleDown |
HelperFunctionsMask.ShuffleUp |
HelperFunctionsMask.ShuffleXor |
HelperFunctionsMask.SwizzleAdd;
public static byte[] Generate(StructuredProgramInfo info, CodeGenParameters parameters) public static byte[] Generate(StructuredProgramInfo info, CodeGenParameters parameters)
{ {

View file

@ -307,6 +307,9 @@ namespace Ryujinx.Graphics.Shader.Decoders
case InstName.Sts: case InstName.Sts:
context.SetUsedFeature(FeatureFlags.SharedMemory); context.SetUsedFeature(FeatureFlags.SharedMemory);
break; break;
case InstName.Shfl:
context.SetUsedFeature(FeatureFlags.Shuffle);
break;
} }
block.OpCodes.Add(op); block.OpCodes.Add(op);

View file

@ -194,6 +194,15 @@ namespace Ryujinx.Graphics.Shader
return 16; return 16;
} }
/// <summary>
/// Queries host shader subgroup size.
/// </summary>
/// <returns>Host shader subgroup size in invocations</returns>
int QueryHostSubgroupSize()
{
return 32;
}
/// <summary> /// <summary>
/// Queries host support for texture formats with BGRA component order (such as BGRA8). /// Queries host support for texture formats with BGRA component order (such as BGRA8).
/// </summary> /// </summary>

View file

@ -76,7 +76,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
switch (op.SReg) switch (op.SReg)
{ {
case SReg.LaneId: case SReg.LaneId:
src = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); src = EmitLoadSubgroupLaneId(context);
break; break;
case SReg.InvocationId: case SReg.InvocationId:
@ -146,19 +146,19 @@ namespace Ryujinx.Graphics.Shader.Instructions
break; break;
case SReg.EqMask: case SReg.EqMask:
src = context.Load(StorageKind.Input, IoVariable.SubgroupEqMask, null, Const(0)); src = EmitLoadSubgroupMask(context, IoVariable.SubgroupEqMask);
break; break;
case SReg.LtMask: case SReg.LtMask:
src = context.Load(StorageKind.Input, IoVariable.SubgroupLtMask, null, Const(0)); src = EmitLoadSubgroupMask(context, IoVariable.SubgroupLtMask);
break; break;
case SReg.LeMask: case SReg.LeMask:
src = context.Load(StorageKind.Input, IoVariable.SubgroupLeMask, null, Const(0)); src = EmitLoadSubgroupMask(context, IoVariable.SubgroupLeMask);
break; break;
case SReg.GtMask: case SReg.GtMask:
src = context.Load(StorageKind.Input, IoVariable.SubgroupGtMask, null, Const(0)); src = EmitLoadSubgroupMask(context, IoVariable.SubgroupGtMask);
break; break;
case SReg.GeMask: case SReg.GeMask:
src = context.Load(StorageKind.Input, IoVariable.SubgroupGeMask, null, Const(0)); src = EmitLoadSubgroupMask(context, IoVariable.SubgroupGeMask);
break; break;
default: default:
@ -169,6 +169,52 @@ namespace Ryujinx.Graphics.Shader.Instructions
context.Copy(GetDest(op.Dest), src); context.Copy(GetDest(op.Dest), src);
} }
private static Operand EmitLoadSubgroupLaneId(EmitterContext context)
{
if (context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize() <= 32)
{
return context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
}
return context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x1f));
}
private static Operand EmitLoadSubgroupMask(EmitterContext context, IoVariable ioVariable)
{
int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize();
if (subgroupSize <= 32)
{
return context.Load(StorageKind.Input, ioVariable, null, Const(0));
}
else if (subgroupSize == 64)
{
Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
Operand low = context.Load(StorageKind.Input, ioVariable, null, Const(0));
Operand high = context.Load(StorageKind.Input, ioVariable, null, Const(1));
return context.ConditionalSelect(context.BitwiseAnd(laneId, Const(32)), high, low);
}
else
{
Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
Operand element = context.ShiftRightU32(laneId, Const(5));
Operand res = context.Load(StorageKind.Input, ioVariable, null, Const(0));
res = context.ConditionalSelect(
context.ICompareEqual(element, Const(1)),
context.Load(StorageKind.Input, ioVariable, null, Const(1)), res);
res = context.ConditionalSelect(
context.ICompareEqual(element, Const(2)),
context.Load(StorageKind.Input, ioVariable, null, Const(2)), res);
res = context.ConditionalSelect(
context.ICompareEqual(element, Const(3)),
context.Load(StorageKind.Input, ioVariable, null, Const(3)), res);
return res;
}
}
public static void SelR(EmitterContext context) public static void SelR(EmitterContext context)
{ {
InstSelR op = context.GetOp<InstSelR>(); InstSelR op = context.GetOp<InstSelR>();

View file

@ -50,20 +50,7 @@ namespace Ryujinx.Graphics.Shader.Instructions
InstVote op = context.GetOp<InstVote>(); InstVote op = context.GetOp<InstVote>();
Operand pred = GetPredicate(context, op.SrcPred, op.SrcPredInv); Operand pred = GetPredicate(context, op.SrcPred, op.SrcPredInv);
Operand res = null; Operand res = EmitVote(context, op.VoteMode, pred);
switch (op.VoteMode)
{
case VoteMode.All:
res = context.VoteAll(pred);
break;
case VoteMode.Any:
res = context.VoteAny(pred);
break;
case VoteMode.Eq:
res = context.VoteAllEqual(pred);
break;
}
if (res != null) if (res != null)
{ {
@ -76,7 +63,81 @@ namespace Ryujinx.Graphics.Shader.Instructions
if (op.Dest != RegisterConsts.RegisterZeroIndex) if (op.Dest != RegisterConsts.RegisterZeroIndex)
{ {
context.Copy(GetDest(op.Dest), context.Ballot(pred)); context.Copy(GetDest(op.Dest), EmitBallot(context, pred));
}
}
private static Operand EmitVote(EmitterContext context, VoteMode voteMode, Operand pred)
{
int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize();
if (subgroupSize <= 32)
{
return voteMode switch
{
VoteMode.All => context.VoteAll(pred),
VoteMode.Any => context.VoteAny(pred),
VoteMode.Eq => context.VoteAllEqual(pred),
_ => null,
};
}
// Emulate vote with ballot masks.
// We do that when the GPU thread count is not 32,
// since the shader code assumes it is 32.
// allInvocations => ballot(pred) == ballot(true),
// anyInvocation => ballot(pred) != 0,
// allInvocationsEqual => ballot(pred) == balot(true) || ballot(pred) == 0
Operand ballotMask = EmitBallot(context, pred);
Operand AllTrue() => context.ICompareEqual(ballotMask, EmitBallot(context, Const(IrConsts.True)));
return voteMode switch
{
VoteMode.All => AllTrue(),
VoteMode.Any => context.ICompareNotEqual(ballotMask, Const(0)),
VoteMode.Eq => context.BitwiseOr(AllTrue(), context.ICompareEqual(ballotMask, Const(0))),
_ => null,
};
}
private static Operand EmitBallot(EmitterContext context, Operand pred)
{
int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize();
if (subgroupSize <= 32)
{
return context.Ballot(pred, 0);
}
else if (subgroupSize == 64)
{
// TODO: Add support for vector destination and do that with a single operation.
Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
Operand low = context.Ballot(pred, 0);
Operand high = context.Ballot(pred, 1);
return context.ConditionalSelect(context.BitwiseAnd(laneId, Const(32)), high, low);
}
else
{
// TODO: Add support for vector destination and do that with a single operation.
Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
Operand element = context.ShiftRightU32(laneId, Const(5));
Operand res = context.Ballot(pred, 0);
res = context.ConditionalSelect(
context.ICompareEqual(element, Const(1)),
context.Ballot(pred, 1), res);
res = context.ConditionalSelect(
context.ICompareEqual(element, Const(2)),
context.Ballot(pred, 2), res);
res = context.ConditionalSelect(
context.ICompareEqual(element, Const(3)),
context.Ballot(pred, 3), res);
return res;
} }
} }
} }

View file

@ -12,10 +12,6 @@
<ItemGroup> <ItemGroup>
<EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\MultiplyHighS32.glsl" /> <EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\MultiplyHighS32.glsl" />
<EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\MultiplyHighU32.glsl" /> <EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\MultiplyHighU32.glsl" />
<EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\Shuffle.glsl" />
<EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\ShuffleDown.glsl" />
<EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\ShuffleUp.glsl" />
<EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\ShuffleXor.glsl" />
<EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\SwizzleAdd.glsl" /> <EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\SwizzleAdd.glsl" />
</ItemGroup> </ItemGroup>

View file

@ -7,10 +7,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
{ {
MultiplyHighS32 = 1 << 2, MultiplyHighS32 = 1 << 2,
MultiplyHighU32 = 1 << 3, MultiplyHighU32 = 1 << 3,
Shuffle = 1 << 4,
ShuffleDown = 1 << 5,
ShuffleUp = 1 << 6,
ShuffleXor = 1 << 7,
SwizzleAdd = 1 << 10, SwizzleAdd = 1 << 10,
FSI = 1 << 11, FSI = 1 << 11,
} }

View file

@ -109,14 +109,15 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
Add(Instruction.PackDouble2x32, AggregateType.FP64, AggregateType.U32, AggregateType.U32); Add(Instruction.PackDouble2x32, AggregateType.FP64, AggregateType.U32, AggregateType.U32);
Add(Instruction.PackHalf2x16, AggregateType.U32, AggregateType.FP32, AggregateType.FP32); Add(Instruction.PackHalf2x16, AggregateType.U32, AggregateType.FP32, AggregateType.FP32);
Add(Instruction.ReciprocalSquareRoot, AggregateType.Scalar, AggregateType.Scalar); Add(Instruction.ReciprocalSquareRoot, AggregateType.Scalar, AggregateType.Scalar);
Add(Instruction.Return, AggregateType.Void, AggregateType.U32);
Add(Instruction.Round, AggregateType.Scalar, AggregateType.Scalar); Add(Instruction.Round, AggregateType.Scalar, AggregateType.Scalar);
Add(Instruction.ShiftLeft, AggregateType.S32, AggregateType.S32, AggregateType.S32); Add(Instruction.ShiftLeft, AggregateType.S32, AggregateType.S32, AggregateType.S32);
Add(Instruction.ShiftRightS32, AggregateType.S32, AggregateType.S32, AggregateType.S32); Add(Instruction.ShiftRightS32, AggregateType.S32, AggregateType.S32, AggregateType.S32);
Add(Instruction.ShiftRightU32, AggregateType.U32, AggregateType.U32, AggregateType.S32); Add(Instruction.ShiftRightU32, AggregateType.U32, AggregateType.U32, AggregateType.S32);
Add(Instruction.Shuffle, AggregateType.FP32, AggregateType.FP32, AggregateType.U32, AggregateType.U32, AggregateType.Bool); Add(Instruction.Shuffle, AggregateType.FP32, AggregateType.FP32, AggregateType.U32);
Add(Instruction.ShuffleDown, AggregateType.FP32, AggregateType.FP32, AggregateType.U32, AggregateType.U32, AggregateType.Bool); Add(Instruction.ShuffleDown, AggregateType.FP32, AggregateType.FP32, AggregateType.U32);
Add(Instruction.ShuffleUp, AggregateType.FP32, AggregateType.FP32, AggregateType.U32, AggregateType.U32, AggregateType.Bool); Add(Instruction.ShuffleUp, AggregateType.FP32, AggregateType.FP32, AggregateType.U32);
Add(Instruction.ShuffleXor, AggregateType.FP32, AggregateType.FP32, AggregateType.U32, AggregateType.U32, AggregateType.Bool); Add(Instruction.ShuffleXor, AggregateType.FP32, AggregateType.FP32, AggregateType.U32);
Add(Instruction.Sine, AggregateType.Scalar, AggregateType.Scalar); Add(Instruction.Sine, AggregateType.Scalar, AggregateType.Scalar);
Add(Instruction.SquareRoot, AggregateType.Scalar, AggregateType.Scalar); Add(Instruction.SquareRoot, AggregateType.Scalar, AggregateType.Scalar);
Add(Instruction.Store, AggregateType.Void); Add(Instruction.Store, AggregateType.Void);
@ -131,7 +132,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
Add(Instruction.VoteAll, AggregateType.Bool, AggregateType.Bool); Add(Instruction.VoteAll, AggregateType.Bool, AggregateType.Bool);
Add(Instruction.VoteAllEqual, AggregateType.Bool, AggregateType.Bool); Add(Instruction.VoteAllEqual, AggregateType.Bool, AggregateType.Bool);
Add(Instruction.VoteAny, AggregateType.Bool, AggregateType.Bool); Add(Instruction.VoteAny, AggregateType.Bool, AggregateType.Bool);
#pragma warning restore IDE0055v #pragma warning restore IDE0055
} }
private static void Add(Instruction inst, AggregateType destType, params AggregateType[] srcTypes) private static void Add(Instruction inst, AggregateType destType, params AggregateType[] srcTypes)

View file

@ -282,18 +282,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
case Instruction.MultiplyHighU32: case Instruction.MultiplyHighU32:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.MultiplyHighU32; context.Info.HelperFunctionsMask |= HelperFunctionsMask.MultiplyHighU32;
break; break;
case Instruction.Shuffle:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.Shuffle;
break;
case Instruction.ShuffleDown:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.ShuffleDown;
break;
case Instruction.ShuffleUp:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.ShuffleUp;
break;
case Instruction.ShuffleXor:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.ShuffleXor;
break;
case Instruction.SwizzleAdd: case Instruction.SwizzleAdd:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.SwizzleAdd; context.Info.HelperFunctionsMask |= HelperFunctionsMask.SwizzleAdd;
break; break;

View file

@ -112,9 +112,13 @@ namespace Ryujinx.Graphics.Shader.Translation
return context.Add(Instruction.AtomicXor, storageKind, Local(), Const(binding), e0, e1, value); return context.Add(Instruction.AtomicXor, storageKind, Local(), Const(binding), e0, e1, value);
} }
public static Operand Ballot(this EmitterContext context, Operand a) public static Operand Ballot(this EmitterContext context, Operand a, int index)
{ {
return context.Add(Instruction.Ballot, Local(), a); Operand dest = Local();
context.Add(new Operation(Instruction.Ballot, index, dest, a));
return dest;
} }
public static Operand Barrier(this EmitterContext context) public static Operand Barrier(this EmitterContext context)
@ -782,21 +786,41 @@ namespace Ryujinx.Graphics.Shader.Translation
return context.Add(Instruction.ShiftRightU32, Local(), a, b); return context.Add(Instruction.ShiftRightU32, Local(), a, b);
} }
public static Operand Shuffle(this EmitterContext context, Operand a, Operand b)
{
return context.Add(Instruction.Shuffle, Local(), a, b);
}
public static (Operand, Operand) Shuffle(this EmitterContext context, Operand a, Operand b, Operand c) public static (Operand, Operand) Shuffle(this EmitterContext context, Operand a, Operand b, Operand c)
{ {
return context.Add(Instruction.Shuffle, (Local(), Local()), a, b, c); return context.Add(Instruction.Shuffle, (Local(), Local()), a, b, c);
} }
public static Operand ShuffleDown(this EmitterContext context, Operand a, Operand b)
{
return context.Add(Instruction.ShuffleDown, Local(), a, b);
}
public static (Operand, Operand) ShuffleDown(this EmitterContext context, Operand a, Operand b, Operand c) public static (Operand, Operand) ShuffleDown(this EmitterContext context, Operand a, Operand b, Operand c)
{ {
return context.Add(Instruction.ShuffleDown, (Local(), Local()), a, b, c); return context.Add(Instruction.ShuffleDown, (Local(), Local()), a, b, c);
} }
public static Operand ShuffleUp(this EmitterContext context, Operand a, Operand b)
{
return context.Add(Instruction.ShuffleUp, Local(), a, b);
}
public static (Operand, Operand) ShuffleUp(this EmitterContext context, Operand a, Operand b, Operand c) public static (Operand, Operand) ShuffleUp(this EmitterContext context, Operand a, Operand b, Operand c)
{ {
return context.Add(Instruction.ShuffleUp, (Local(), Local()), a, b, c); return context.Add(Instruction.ShuffleUp, (Local(), Local()), a, b, c);
} }
public static Operand ShuffleXor(this EmitterContext context, Operand a, Operand b)
{
return context.Add(Instruction.ShuffleXor, Local(), a, b);
}
public static (Operand, Operand) ShuffleXor(this EmitterContext context, Operand a, Operand b, Operand c) public static (Operand, Operand) ShuffleXor(this EmitterContext context, Operand a, Operand b, Operand c)
{ {
return context.Add(Instruction.ShuffleXor, (Local(), Local()), a, b, c); return context.Add(Instruction.ShuffleXor, (Local(), Local()), a, b, c);

View file

@ -18,6 +18,7 @@ namespace Ryujinx.Graphics.Shader.Translation
InstanceId = 1 << 3, InstanceId = 1 << 3,
DrawParameters = 1 << 4, DrawParameters = 1 << 4,
RtLayer = 1 << 5, RtLayer = 1 << 5,
Shuffle = 1 << 6,
FixedFuncAttr = 1 << 9, FixedFuncAttr = 1 << 9,
LocalMemory = 1 << 10, LocalMemory = 1 << 10,
SharedMemory = 1 << 11, SharedMemory = 1 << 11,

View file

@ -56,6 +56,20 @@ namespace Ryujinx.Graphics.Shader.Translation
return functionId; return functionId;
} }
public int GetOrCreateShuffleFunctionId(HelperFunctionName functionName, int subgroupSize)
{
if (_functionIds.TryGetValue((int)functionName, out int functionId))
{
return functionId;
}
Function function = GenerateShuffleFunction(functionName, subgroupSize);
functionId = AddFunction(function);
_functionIds.Add((int)functionName, functionId);
return functionId;
}
private Function GenerateFunction(HelperFunctionName functionName) private Function GenerateFunction(HelperFunctionName functionName)
{ {
return functionName switch return functionName switch
@ -216,6 +230,137 @@ namespace Ryujinx.Graphics.Shader.Translation
return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, $"SharedStore{bitSize}_{id}", false, 2, 0); return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, $"SharedStore{bitSize}_{id}", false, 2, 0);
} }
private static Function GenerateShuffleFunction(HelperFunctionName functionName, int subgroupSize)
{
return functionName switch
{
HelperFunctionName.Shuffle => GenerateShuffle(subgroupSize),
HelperFunctionName.ShuffleDown => GenerateShuffleDown(subgroupSize),
HelperFunctionName.ShuffleUp => GenerateShuffleUp(subgroupSize),
HelperFunctionName.ShuffleXor => GenerateShuffleXor(subgroupSize),
_ => throw new ArgumentException($"Invalid function name {functionName}"),
};
}
private static Function GenerateShuffle(int subgroupSize)
{
EmitterContext context = new();
Operand value = Argument(0);
Operand index = Argument(1);
Operand mask = Argument(2);
Operand clamp = context.BitwiseAnd(mask, Const(0x1f));
Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
Operand minThreadId = context.BitwiseAnd(GenerateLoadSubgroupLaneId(context, subgroupSize), segMask);
Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId);
Operand srcThreadId = context.BitwiseOr(context.BitwiseAnd(index, context.BitwiseNot(segMask)), minThreadId);
Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId);
context.Copy(Argument(3), valid);
Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
context.Return(context.ConditionalSelect(valid, result, value));
return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "Shuffle", true, 3, 1);
}
private static Function GenerateShuffleDown(int subgroupSize)
{
EmitterContext context = new();
Operand value = Argument(0);
Operand index = Argument(1);
Operand mask = Argument(2);
Operand clamp = context.BitwiseAnd(mask, Const(0x1f));
Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize);
Operand minThreadId = context.BitwiseAnd(laneId, segMask);
Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId);
Operand srcThreadId = context.IAdd(laneId, index);
Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId);
context.Copy(Argument(3), valid);
Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
context.Return(context.ConditionalSelect(valid, result, value));
return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleDown", true, 3, 1);
}
private static Function GenerateShuffleUp(int subgroupSize)
{
EmitterContext context = new();
Operand value = Argument(0);
Operand index = Argument(1);
Operand mask = Argument(2);
Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize);
Operand minThreadId = context.BitwiseAnd(laneId, segMask);
Operand srcThreadId = context.ISubtract(laneId, index);
Operand valid = context.ICompareGreaterOrEqual(srcThreadId, minThreadId);
context.Copy(Argument(3), valid);
Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
context.Return(context.ConditionalSelect(valid, result, value));
return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleUp", true, 3, 1);
}
private static Function GenerateShuffleXor(int subgroupSize)
{
EmitterContext context = new();
Operand value = Argument(0);
Operand index = Argument(1);
Operand mask = Argument(2);
Operand clamp = context.BitwiseAnd(mask, Const(0x1f));
Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize);
Operand minThreadId = context.BitwiseAnd(laneId, segMask);
Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId);
Operand srcThreadId = context.BitwiseExclusiveOr(laneId, index);
Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId);
context.Copy(Argument(3), valid);
Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
context.Return(context.ConditionalSelect(valid, result, value));
return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleXor", true, 3, 1);
}
private static Operand GenerateLoadSubgroupLaneId(EmitterContext context, int subgroupSize)
{
if (subgroupSize <= 32)
{
return context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
}
return context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x1f));
}
private static Operand GenerateSubgroupShuffleIndex(EmitterContext context, Operand srcThreadId, int subgroupSize)
{
if (subgroupSize <= 32)
{
return srcThreadId;
}
return context.BitwiseOr(
context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x60)),
srcThreadId);
}
private Function GenerateTexelFetchScaleFunction() private Function GenerateTexelFetchScaleFunction()
{ {
EmitterContext context = new(); EmitterContext context = new();

View file

@ -2,12 +2,18 @@ namespace Ryujinx.Graphics.Shader.Translation
{ {
enum HelperFunctionName enum HelperFunctionName
{ {
Invalid,
ConvertDoubleToFloat, ConvertDoubleToFloat,
ConvertFloatToDouble, ConvertFloatToDouble,
SharedAtomicMaxS32, SharedAtomicMaxS32,
SharedAtomicMinS32, SharedAtomicMinS32,
SharedStore8, SharedStore8,
SharedStore16, SharedStore16,
Shuffle,
ShuffleDown,
ShuffleUp,
ShuffleXor,
TexelFetchScale, TexelFetchScale,
TextureSizeUnscale, TextureSizeUnscale,
} }

View file

@ -0,0 +1,52 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.Translation.Optimizations;
using System.Collections.Generic;
using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
namespace Ryujinx.Graphics.Shader.Translation.Transforms
{
class ShufflePass : ITransformPass
{
public static bool IsEnabled(IGpuAccessor gpuAccessor, ShaderStage stage, TargetLanguage targetLanguage, FeatureFlags usedFeatures)
{
return usedFeatures.HasFlag(FeatureFlags.Shuffle);
}
public static LinkedListNode<INode> RunPass(TransformContext context, LinkedListNode<INode> node)
{
Operation operation = (Operation)node.Value;
HelperFunctionName functionName = operation.Inst switch
{
Instruction.Shuffle => HelperFunctionName.Shuffle,
Instruction.ShuffleDown => HelperFunctionName.ShuffleDown,
Instruction.ShuffleUp => HelperFunctionName.ShuffleUp,
Instruction.ShuffleXor => HelperFunctionName.ShuffleXor,
_ => HelperFunctionName.Invalid,
};
if (functionName == HelperFunctionName.Invalid || operation.SourcesCount != 3 || operation.DestsCount != 2)
{
return node;
}
int functionId = context.Hfm.GetOrCreateShuffleFunctionId(functionName, context.GpuAccessor.QueryHostSubgroupSize());
Operand result = operation.GetDest(0);
Operand valid = operation.GetDest(1);
Operand value = operation.GetSource(0);
Operand index = operation.GetSource(1);
Operand mask = operation.GetSource(2);
operation.Dest = null;
Operand[] callArgs = new Operand[] { Const(functionId), value, index, mask, valid };
LinkedListNode<INode> newNode = node.List.AddBefore(node, new Operation(Instruction.Call, 0, result, callArgs));
Utils.DeleteNode(node, operation);
return newNode;
}
}
}

View file

@ -13,6 +13,7 @@ namespace Ryujinx.Graphics.Shader.Translation.Transforms
RunPass<TexturePass>(context); RunPass<TexturePass>(context);
RunPass<SharedStoreSmallIntCas>(context); RunPass<SharedStoreSmallIntCas>(context);
RunPass<SharedAtomicSignedCas>(context); RunPass<SharedAtomicSignedCas>(context);
RunPass<ShufflePass>(context);
} }
private static void RunPass<T>(TransformContext context) where T : ITransformPass private static void RunPass<T>(TransformContext context) where T : ITransformPass

View file

@ -25,7 +25,6 @@ namespace Ryujinx.Graphics.Vulkan
public readonly bool SupportsIndirectParameters; public readonly bool SupportsIndirectParameters;
public readonly bool SupportsFragmentShaderInterlock; public readonly bool SupportsFragmentShaderInterlock;
public readonly bool SupportsGeometryShaderPassthrough; public readonly bool SupportsGeometryShaderPassthrough;
public readonly bool SupportsSubgroupSizeControl;
public readonly bool SupportsShaderFloat64; public readonly bool SupportsShaderFloat64;
public readonly bool SupportsShaderInt8; public readonly bool SupportsShaderInt8;
public readonly bool SupportsShaderStencilExport; public readonly bool SupportsShaderStencilExport;
@ -45,9 +44,7 @@ namespace Ryujinx.Graphics.Vulkan
public readonly bool SupportsViewportArray2; public readonly bool SupportsViewportArray2;
public readonly bool SupportsHostImportedMemory; public readonly bool SupportsHostImportedMemory;
public readonly bool SupportsDepthClipControl; public readonly bool SupportsDepthClipControl;
public readonly uint MinSubgroupSize; public readonly uint SubgroupSize;
public readonly uint MaxSubgroupSize;
public readonly ShaderStageFlags RequiredSubgroupSizeStages;
public readonly SampleCountFlags SupportedSampleCounts; public readonly SampleCountFlags SupportedSampleCounts;
public readonly PortabilitySubsetFlags PortabilitySubset; public readonly PortabilitySubsetFlags PortabilitySubset;
public readonly uint VertexBufferAlignment; public readonly uint VertexBufferAlignment;
@ -64,7 +61,6 @@ namespace Ryujinx.Graphics.Vulkan
bool supportsIndirectParameters, bool supportsIndirectParameters,
bool supportsFragmentShaderInterlock, bool supportsFragmentShaderInterlock,
bool supportsGeometryShaderPassthrough, bool supportsGeometryShaderPassthrough,
bool supportsSubgroupSizeControl,
bool supportsShaderFloat64, bool supportsShaderFloat64,
bool supportsShaderInt8, bool supportsShaderInt8,
bool supportsShaderStencilExport, bool supportsShaderStencilExport,
@ -84,9 +80,7 @@ namespace Ryujinx.Graphics.Vulkan
bool supportsViewportArray2, bool supportsViewportArray2,
bool supportsHostImportedMemory, bool supportsHostImportedMemory,
bool supportsDepthClipControl, bool supportsDepthClipControl,
uint minSubgroupSize, uint subgroupSize,
uint maxSubgroupSize,
ShaderStageFlags requiredSubgroupSizeStages,
SampleCountFlags supportedSampleCounts, SampleCountFlags supportedSampleCounts,
PortabilitySubsetFlags portabilitySubset, PortabilitySubsetFlags portabilitySubset,
uint vertexBufferAlignment, uint vertexBufferAlignment,
@ -102,7 +96,6 @@ namespace Ryujinx.Graphics.Vulkan
SupportsIndirectParameters = supportsIndirectParameters; SupportsIndirectParameters = supportsIndirectParameters;
SupportsFragmentShaderInterlock = supportsFragmentShaderInterlock; SupportsFragmentShaderInterlock = supportsFragmentShaderInterlock;
SupportsGeometryShaderPassthrough = supportsGeometryShaderPassthrough; SupportsGeometryShaderPassthrough = supportsGeometryShaderPassthrough;
SupportsSubgroupSizeControl = supportsSubgroupSizeControl;
SupportsShaderFloat64 = supportsShaderFloat64; SupportsShaderFloat64 = supportsShaderFloat64;
SupportsShaderInt8 = supportsShaderInt8; SupportsShaderInt8 = supportsShaderInt8;
SupportsShaderStencilExport = supportsShaderStencilExport; SupportsShaderStencilExport = supportsShaderStencilExport;
@ -122,9 +115,7 @@ namespace Ryujinx.Graphics.Vulkan
SupportsViewportArray2 = supportsViewportArray2; SupportsViewportArray2 = supportsViewportArray2;
SupportsHostImportedMemory = supportsHostImportedMemory; SupportsHostImportedMemory = supportsHostImportedMemory;
SupportsDepthClipControl = supportsDepthClipControl; SupportsDepthClipControl = supportsDepthClipControl;
MinSubgroupSize = minSubgroupSize; SubgroupSize = subgroupSize;
MaxSubgroupSize = maxSubgroupSize;
RequiredSubgroupSizeStages = requiredSubgroupSizeStages;
SupportedSampleCounts = supportedSampleCounts; SupportedSampleCounts = supportedSampleCounts;
PortabilitySubset = portabilitySubset; PortabilitySubset = portabilitySubset;
VertexBufferAlignment = vertexBufferAlignment; VertexBufferAlignment = vertexBufferAlignment;

View file

@ -352,11 +352,6 @@ namespace Ryujinx.Graphics.Vulkan
return pipeline; return pipeline;
} }
if (gd.Capabilities.SupportsSubgroupSizeControl)
{
UpdateStageRequiredSubgroupSizes(gd, 1);
}
var pipelineCreateInfo = new ComputePipelineCreateInfo var pipelineCreateInfo = new ComputePipelineCreateInfo
{ {
SType = StructureType.ComputePipelineCreateInfo, SType = StructureType.ComputePipelineCreateInfo,
@ -616,11 +611,6 @@ namespace Ryujinx.Graphics.Vulkan
PDynamicStates = dynamicStates, PDynamicStates = dynamicStates,
}; };
if (gd.Capabilities.SupportsSubgroupSizeControl)
{
UpdateStageRequiredSubgroupSizes(gd, (int)StagesCount);
}
var pipelineCreateInfo = new GraphicsPipelineCreateInfo var pipelineCreateInfo = new GraphicsPipelineCreateInfo
{ {
SType = StructureType.GraphicsPipelineCreateInfo, SType = StructureType.GraphicsPipelineCreateInfo,
@ -659,19 +649,6 @@ namespace Ryujinx.Graphics.Vulkan
return pipeline; return pipeline;
} }
private readonly unsafe void UpdateStageRequiredSubgroupSizes(VulkanRenderer gd, int count)
{
for (int index = 0; index < count; index++)
{
bool canUseExplicitSubgroupSize =
(gd.Capabilities.RequiredSubgroupSizeStages & Stages[index].Stage) != 0 &&
gd.Capabilities.MinSubgroupSize <= RequiredSubgroupSize &&
gd.Capabilities.MaxSubgroupSize >= RequiredSubgroupSize;
Stages[index].PNext = canUseExplicitSubgroupSize ? StageRequiredSubgroupSizes.Pointer + index : null;
}
}
private void UpdateVertexAttributeDescriptions(VulkanRenderer gd) private void UpdateVertexAttributeDescriptions(VulkanRenderer gd)
{ {
// Vertex attributes exceeding the stride are invalid. // Vertex attributes exceeding the stride are invalid.

View file

@ -37,7 +37,6 @@ namespace Ryujinx.Graphics.Vulkan
"VK_EXT_shader_stencil_export", "VK_EXT_shader_stencil_export",
"VK_KHR_shader_float16_int8", "VK_KHR_shader_float16_int8",
"VK_EXT_shader_subgroup_ballot", "VK_EXT_shader_subgroup_ballot",
"VK_EXT_subgroup_size_control",
"VK_NV_geometry_shader_passthrough", "VK_NV_geometry_shader_passthrough",
"VK_NV_viewport_array2", "VK_NV_viewport_array2",
"VK_EXT_depth_clip_control", "VK_EXT_depth_clip_control",

View file

@ -151,6 +151,14 @@ namespace Ryujinx.Graphics.Vulkan
SType = StructureType.PhysicalDeviceProperties2, SType = StructureType.PhysicalDeviceProperties2,
}; };
PhysicalDeviceSubgroupProperties propertiesSubgroup = new()
{
SType = StructureType.PhysicalDeviceSubgroupProperties,
PNext = properties2.PNext,
};
properties2.PNext = &propertiesSubgroup;
PhysicalDeviceBlendOperationAdvancedPropertiesEXT propertiesBlendOperationAdvanced = new() PhysicalDeviceBlendOperationAdvancedPropertiesEXT propertiesBlendOperationAdvanced = new()
{ {
SType = StructureType.PhysicalDeviceBlendOperationAdvancedPropertiesExt, SType = StructureType.PhysicalDeviceBlendOperationAdvancedPropertiesExt,
@ -164,18 +172,6 @@ namespace Ryujinx.Graphics.Vulkan
properties2.PNext = &propertiesBlendOperationAdvanced; properties2.PNext = &propertiesBlendOperationAdvanced;
} }
PhysicalDeviceSubgroupSizeControlPropertiesEXT propertiesSubgroupSizeControl = new()
{
SType = StructureType.PhysicalDeviceSubgroupSizeControlPropertiesExt,
};
bool supportsSubgroupSizeControl = _physicalDevice.IsDeviceExtensionPresent("VK_EXT_subgroup_size_control");
if (supportsSubgroupSizeControl)
{
properties2.PNext = &propertiesSubgroupSizeControl;
}
bool supportsTransformFeedback = _physicalDevice.IsDeviceExtensionPresent(ExtTransformFeedback.ExtensionName); bool supportsTransformFeedback = _physicalDevice.IsDeviceExtensionPresent(ExtTransformFeedback.ExtensionName);
PhysicalDeviceTransformFeedbackPropertiesEXT propertiesTransformFeedback = new() PhysicalDeviceTransformFeedbackPropertiesEXT propertiesTransformFeedback = new()
@ -315,7 +311,6 @@ namespace Ryujinx.Graphics.Vulkan
_physicalDevice.IsDeviceExtensionPresent(KhrDrawIndirectCount.ExtensionName), _physicalDevice.IsDeviceExtensionPresent(KhrDrawIndirectCount.ExtensionName),
_physicalDevice.IsDeviceExtensionPresent("VK_EXT_fragment_shader_interlock"), _physicalDevice.IsDeviceExtensionPresent("VK_EXT_fragment_shader_interlock"),
_physicalDevice.IsDeviceExtensionPresent("VK_NV_geometry_shader_passthrough"), _physicalDevice.IsDeviceExtensionPresent("VK_NV_geometry_shader_passthrough"),
supportsSubgroupSizeControl,
features2.Features.ShaderFloat64, features2.Features.ShaderFloat64,
featuresShaderInt8.ShaderInt8, featuresShaderInt8.ShaderInt8,
_physicalDevice.IsDeviceExtensionPresent("VK_EXT_shader_stencil_export"), _physicalDevice.IsDeviceExtensionPresent("VK_EXT_shader_stencil_export"),
@ -335,9 +330,7 @@ namespace Ryujinx.Graphics.Vulkan
_physicalDevice.IsDeviceExtensionPresent("VK_NV_viewport_array2"), _physicalDevice.IsDeviceExtensionPresent("VK_NV_viewport_array2"),
_physicalDevice.IsDeviceExtensionPresent(ExtExternalMemoryHost.ExtensionName), _physicalDevice.IsDeviceExtensionPresent(ExtExternalMemoryHost.ExtensionName),
supportsDepthClipControl && featuresDepthClipControl.DepthClipControl, supportsDepthClipControl && featuresDepthClipControl.DepthClipControl,
propertiesSubgroupSizeControl.MinSubgroupSize, propertiesSubgroup.SubgroupSize,
propertiesSubgroupSizeControl.MaxSubgroupSize,
propertiesSubgroupSizeControl.RequiredSubgroupSizeStages,
supportedSampleCounts, supportedSampleCounts,
portabilityFlags, portabilityFlags,
vertexBufferAlignment, vertexBufferAlignment,
@ -623,6 +616,7 @@ namespace Ryujinx.Graphics.Vulkan
maximumImagesPerStage: Constants.MaxImagesPerStage, maximumImagesPerStage: Constants.MaxImagesPerStage,
maximumComputeSharedMemorySize: (int)limits.MaxComputeSharedMemorySize, maximumComputeSharedMemorySize: (int)limits.MaxComputeSharedMemorySize,
maximumSupportedAnisotropy: (int)limits.MaxSamplerAnisotropy, maximumSupportedAnisotropy: (int)limits.MaxSamplerAnisotropy,
shaderSubgroupSize: (int)Capabilities.SubgroupSize,
storageBufferOffsetAlignment: (int)limits.MinStorageBufferOffsetAlignment, storageBufferOffsetAlignment: (int)limits.MinStorageBufferOffsetAlignment,
gatherBiasPrecision: IsIntelWindows || IsAmdWindows ? (int)Capabilities.SubTexelPrecisionBits : 0); gatherBiasPrecision: IsIntelWindows || IsAmdWindows ? (int)Capabilities.SubTexelPrecisionBits : 0);
} }