From 49f970d5bd9163e2b4e26a33ef8f84529174d5de Mon Sep 17 00:00:00 2001 From: gdkchan Date: Sun, 25 Oct 2020 17:00:44 -0300 Subject: [PATCH] Implement CAL and RET shader instructions (#1618) * Add support for CAL and RET shader instructions * Remove unused stuff * Fix a bug that could cause the wrong values to be passed to a function * Avoid repopulating function id dictionary every time * PR feedback * Fix vertex shader A/B merge --- .../CodeGen/Glsl/CodeGenContext.cs | 15 +- .../CodeGen/Glsl/Declarations.cs | 7 +- .../CodeGen/Glsl/DefaultNames.cs | 2 + .../CodeGen/Glsl/GlslGenerator.cs | 116 +++-- .../CodeGen/Glsl/Instructions/InstGen.cs | 10 + .../CodeGen/Glsl/Instructions/InstGenCall.cs | 29 ++ .../Glsl/Instructions/InstGenHelper.cs | 3 +- .../Glsl/Instructions/InstGenMemory.cs | 4 +- .../CodeGen/Glsl/OperandManager.cs | 26 +- Ryujinx.Graphics.Shader/Decoders/Decoder.cs | 294 ++++++----- .../Decoders/OpCodeTable.cs | 2 + .../Instructions/InstEmitFlow.cs | 33 +- .../IntermediateRepresentation/BasicBlock.cs | 1 + .../IntermediateRepresentation/Function.cs | 23 + .../IntermediateRepresentation/Instruction.cs | 2 + .../OperandHelper.cs | 5 + .../IntermediateRepresentation/OperandType.cs | 1 + .../IntermediateRepresentation/Operation.cs | 19 + .../StructuredIr/AstOperation.cs | 19 +- .../StructuredIr/AstOptimizer.cs | 4 +- .../StructuredIr/AstTextureOperation.cs | 2 +- .../StructuredIr/InstructionInfo.cs | 5 + .../StructuredIr/StructuredFunction.cs | 41 ++ .../StructuredIr/StructuredProgram.cs | 111 ++-- .../StructuredIr/StructuredProgramContext.cs | 29 +- .../StructuredIr/StructuredProgramInfo.cs | 10 +- .../Translation/ControlFlowGraph.cs | 49 +- .../Translation/Dominance.cs | 44 +- .../Translation/EmitterContext.cs | 17 +- .../Translation/EmitterContextInsts.cs | 19 +- .../Translation/Optimizations/Optimizer.cs | 2 + .../Translation/RegisterUsage.cs | 484 ++++++++++++++++++ .../Translation/Translator.cs | 310 ++++++----- 33 files changed, 1337 insertions(+), 401 deletions(-) create mode 100644 Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenCall.cs create mode 100644 Ryujinx.Graphics.Shader/IntermediateRepresentation/Function.cs create mode 100644 Ryujinx.Graphics.Shader/StructuredIr/StructuredFunction.cs create mode 100644 Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs index 50b9bc9f..85347dfd 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/CodeGenContext.cs @@ -10,9 +10,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl { public const string Tab = " "; + private readonly StructuredProgramInfo _info; + + public StructuredFunction CurrentFunction { get; set; } + public ShaderConfig Config { get; } - public bool CbIndexable { get; } + public bool CbIndexable => _info.UsesCbIndexing; public List CBufferDescriptors { get; } public List SBufferDescriptors { get; } @@ -27,10 +31,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl private string _indentation; - public CodeGenContext(ShaderConfig config, bool cbIndexable) + public CodeGenContext(StructuredProgramInfo info, ShaderConfig config) { + _info = info; Config = config; - CbIndexable = cbIndexable; CBufferDescriptors = new List(); SBufferDescriptors = new List(); @@ -95,6 +99,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl descriptor.CbufOffset == cBufOffset); } + public StructuredFunction GetFunction(int id) + { + return _info.Functions[id]; + } + private void UpdateIndentation() { _indentation = GetIndentation(_level); diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs index 08279839..6f5e75aa 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs @@ -187,9 +187,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl } } - public static void DeclareLocals(CodeGenContext context, StructuredProgramInfo info) + public static void DeclareLocals(CodeGenContext context, StructuredFunction function) { - foreach (AstOperand decl in info.Locals) + foreach (AstOperand decl in function.Locals) { string name = context.OperandManager.DeclareLocal(decl); @@ -197,13 +197,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl } } - private static string GetVarTypeName(VariableType type) + public static string GetVarTypeName(VariableType type) { switch (type) { case VariableType.Bool: return "bool"; case VariableType.F32: return "precise float"; case VariableType.F64: return "double"; + case VariableType.None: return "void"; case VariableType.S32: return "int"; case VariableType.U32: return "uint"; } diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/DefaultNames.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/DefaultNames.cs index d1cf4636..cd9ca96e 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/DefaultNames.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/DefaultNames.cs @@ -22,6 +22,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl public const string LocalMemoryName = "local_mem"; public const string SharedMemoryName = "shared_mem"; + public const string ArgumentNamePrefix = "a"; + public const string UndefinedName = "undef"; public const string IsBgraName = "is_bgra"; diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs index 00a32262..276544fc 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/GlslGenerator.cs @@ -10,13 +10,32 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl { static class GlslGenerator { + private const string MainFunctionName = "main"; + public static GlslProgram Generate(StructuredProgramInfo info, ShaderConfig config) { - CodeGenContext context = new CodeGenContext(config, info.UsesCbIndexing); + CodeGenContext context = new CodeGenContext(info, config); Declarations.Declare(context, info); - PrintMainBlock(context, info); + if (info.Functions.Count != 0) + { + for (int i = 1; i < info.Functions.Count; i++) + { + context.AppendLine($"{GetFunctionSignature(info.Functions[i])};"); + } + + context.AppendLine(); + + for (int i = 1; i < info.Functions.Count; i++) + { + PrintFunction(context, info, info.Functions[i]); + + context.AppendLine(); + } + } + + PrintFunction(context, info, info.Functions[0], MainFunctionName); return new GlslProgram( context.CBufferDescriptors.ToArray(), @@ -26,55 +45,78 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl context.GetCode()); } - private static void PrintMainBlock(CodeGenContext context, StructuredProgramInfo info) + private static void PrintFunction(CodeGenContext context, StructuredProgramInfo info, StructuredFunction function, string funcName = null) { - context.AppendLine("void main()"); + context.CurrentFunction = function; + context.AppendLine(GetFunctionSignature(function, funcName)); context.EnterScope(); - Declarations.DeclareLocals(context, info); + Declarations.DeclareLocals(context, function); - // Some games will leave some elements of gl_Position uninitialized, - // in those cases, the elements will contain undefined values according - // to the spec, but on NVIDIA they seems to be always initialized to (0, 0, 0, 1), - // so we do explicit initialization to avoid UB on non-NVIDIA gpus. - if (context.Config.Stage == ShaderStage.Vertex) + if (funcName == MainFunctionName) { - context.AppendLine("gl_Position = vec4(0.0, 0.0, 0.0, 1.0);"); - } - - // Ensure that unused attributes are set, otherwise the downstream - // compiler may eliminate them. - // (Not needed for fragment shader as it is the last stage). - if (context.Config.Stage != ShaderStage.Compute && - context.Config.Stage != ShaderStage.Fragment) - { - for (int attr = 0; attr < Declarations.MaxAttributes; attr++) + // Some games will leave some elements of gl_Position uninitialized, + // in those cases, the elements will contain undefined values according + // to the spec, but on NVIDIA they seems to be always initialized to (0, 0, 0, 1), + // so we do explicit initialization to avoid UB on non-NVIDIA gpus. + if (context.Config.Stage == ShaderStage.Vertex) { - if (info.OAttributes.Contains(attr)) - { - continue; - } + context.AppendLine("gl_Position = vec4(0.0, 0.0, 0.0, 1.0);"); + } - if ((context.Config.Flags & TranslationFlags.Feedback) != 0) + // Ensure that unused attributes are set, otherwise the downstream + // compiler may eliminate them. + // (Not needed for fragment shader as it is the last stage). + if (context.Config.Stage != ShaderStage.Compute && + context.Config.Stage != ShaderStage.Fragment) + { + for (int attr = 0; attr < Declarations.MaxAttributes; attr++) { - context.AppendLine($"{DefaultNames.OAttributePrefix}{attr}_x = 0;"); - context.AppendLine($"{DefaultNames.OAttributePrefix}{attr}_y = 0;"); - context.AppendLine($"{DefaultNames.OAttributePrefix}{attr}_z = 0;"); - context.AppendLine($"{DefaultNames.OAttributePrefix}{attr}_w = 0;"); - } - else - { - context.AppendLine($"{DefaultNames.OAttributePrefix}{attr} = vec4(0);"); + if (info.OAttributes.Contains(attr)) + { + continue; + } + + if ((context.Config.Flags & TranslationFlags.Feedback) != 0) + { + context.AppendLine($"{DefaultNames.OAttributePrefix}{attr}_x = 0;"); + context.AppendLine($"{DefaultNames.OAttributePrefix}{attr}_y = 0;"); + context.AppendLine($"{DefaultNames.OAttributePrefix}{attr}_z = 0;"); + context.AppendLine($"{DefaultNames.OAttributePrefix}{attr}_w = 0;"); + } + else + { + context.AppendLine($"{DefaultNames.OAttributePrefix}{attr} = vec4(0);"); + } } } } - PrintBlock(context, info.MainBlock); + PrintBlock(context, function.MainBlock); context.LeaveScope(); } + private static string GetFunctionSignature(StructuredFunction function, string funcName = null) + { + string[] args = new string[function.InArguments.Length + function.OutArguments.Length]; + + for (int i = 0; i < function.InArguments.Length; i++) + { + args[i] = $"{Declarations.GetVarTypeName(function.InArguments[i])} {OperandManager.GetArgumentName(i)}"; + } + + for (int i = 0; i < function.OutArguments.Length; i++) + { + int j = i + function.InArguments.Length; + + args[j] = $"out {Declarations.GetVarTypeName(function.OutArguments[i])} {OperandManager.GetArgumentName(j)}"; + } + + return $"{Declarations.GetVarTypeName(function.ReturnType)} {funcName ?? function.Name}({string.Join(", ", args)})"; + } + private static void PrintBlock(CodeGenContext context, AstBlock block) { AstBlockVisitor visitor = new AstBlockVisitor(block); @@ -123,8 +165,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl } else if (node is AstAssignment assignment) { - VariableType srcType = OperandManager.GetNodeDestType(assignment.Source); - VariableType dstType = OperandManager.GetNodeDestType(assignment.Destination); + VariableType srcType = OperandManager.GetNodeDestType(context, assignment.Source); + VariableType dstType = OperandManager.GetNodeDestType(context, assignment.Destination); string dest; @@ -154,7 +196,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl private static string GetCondExpr(CodeGenContext context, IAstNode cond) { - VariableType srcType = OperandManager.GetNodeDestType(cond); + VariableType srcType = OperandManager.GetNodeDestType(context, cond); return ReinterpretCast(context, cond, srcType, VariableType.Bool); } diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs index f1c741e6..388f0c25 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs @@ -2,6 +2,7 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.StructuredIr; using System; +using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenCall; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenMemory; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenPacking; @@ -82,6 +83,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions { string op = info.OpName; + // Return may optionally have a return value (and in this case it is unary). + if (inst == Instruction.Return && operation.SourcesCount != 0) + { + return $"{op} {GetSoureExpr(context, operation.GetSource(0), context.CurrentFunction.ReturnType)}"; + } + int arity = (int)(info.Type & InstType.ArityMask); string[] expr = new string[arity]; @@ -116,6 +123,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions { switch (inst) { + case Instruction.Call: + return Call(context, operation); + case Instruction.ImageLoad: return ImageLoadOrStore(context, operation); diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenCall.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenCall.cs new file mode 100644 index 00000000..2df6960d --- /dev/null +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenCall.cs @@ -0,0 +1,29 @@ +using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using Ryujinx.Graphics.Shader.StructuredIr; +using System.Diagnostics; + +using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper; + +namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions +{ + static class InstGenCall + { + public static string Call(CodeGenContext context, AstOperation operation) + { + AstOperand funcId = (AstOperand)operation.GetSource(0); + + Debug.Assert(funcId.Type == OperandType.Constant); + + var function = context.GetFunction(funcId.Value); + + string[] args = new string[operation.SourcesCount - 1]; + + for (int i = 0; i < args.Length; i++) + { + args[i] = GetSoureExpr(context, operation.GetSource(i + 1), function.GetArgumentType(i)); + } + + return $"{function.Name}({string.Join(", ", args)})"; + } + } +} \ No newline at end of file diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs index 15f9b666..1b1efe9d 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs @@ -36,6 +36,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions Add(Instruction.BitwiseExclusiveOr, InstType.OpBinaryCom, "^", 7); Add(Instruction.BitwiseNot, InstType.OpUnary, "~", 0); Add(Instruction.BitwiseOr, InstType.OpBinaryCom, "|", 8); + Add(Instruction.Call, InstType.Special); Add(Instruction.Ceiling, InstType.CallUnary, "ceil"); Add(Instruction.Clamp, InstType.CallTernary, "clamp"); Add(Instruction.ClampU32, InstType.CallTernary, "clamp"); @@ -135,7 +136,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions public static string GetSoureExpr(CodeGenContext context, IAstNode node, VariableType dstType) { - return ReinterpretCast(context, node, OperandManager.GetNodeDestType(node), dstType); + return ReinterpretCast(context, node, OperandManager.GetNodeDestType(context, node), dstType); } public static string Enclose(string expr, IAstNode node, Instruction pInst, bool isLhs) diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs index cb339f05..7fdca138 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenMemory.cs @@ -226,7 +226,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions string offsetExpr = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 0)); - VariableType srcType = OperandManager.GetNodeDestType(src2); + VariableType srcType = OperandManager.GetNodeDestType(context, src2); string src = TypeConversion.ReinterpretCast(context, src2, srcType, VariableType.U32); @@ -242,7 +242,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions string indexExpr = GetSoureExpr(context, src1, GetSrcVarType(operation.Inst, 0)); string offsetExpr = GetSoureExpr(context, src2, GetSrcVarType(operation.Inst, 1)); - VariableType srcType = OperandManager.GetNodeDestType(src3); + VariableType srcType = OperandManager.GetNodeDestType(context, src3); string src = TypeConversion.ReinterpretCast(context, src3, srcType, VariableType.U32); diff --git a/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs b/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs index 459b60c4..14ea7032 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Glsl/OperandManager.cs @@ -3,6 +3,7 @@ using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.Translation; using System; using System.Collections.Generic; +using System.Diagnostics; using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo; @@ -96,6 +97,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl { switch (operand.Type) { + case OperandType.Argument: + return GetArgumentName(operand.Value); + case OperandType.Attribute: return GetAttributeName(operand, config); @@ -287,7 +291,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl return "xyzw"[value]; } - public static VariableType GetNodeDestType(IAstNode node) + public static string GetArgumentName(int argIndex) + { + return $"{DefaultNames.ArgumentNamePrefix}{argIndex}"; + } + + public static VariableType GetNodeDestType(CodeGenContext context, IAstNode node) { if (node is AstOperation operation) { @@ -298,6 +307,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl { return GetOperandVarType((AstOperand)operation.GetSource(0)); } + else if (operation.Inst == Instruction.Call) + { + AstOperand funcId = (AstOperand)operation.GetSource(0); + + Debug.Assert(funcId.Type == OperandType.Constant); + + return context.GetFunction(funcId.Value).ReturnType; + } else if (operation is AstTextureOperation texOp && (texOp.Inst == Instruction.ImageLoad || texOp.Inst == Instruction.ImageStore)) @@ -309,6 +326,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl } else if (node is AstOperand operand) { + if (operand.Type == OperandType.Argument) + { + int argIndex = operand.Value; + + return context.CurrentFunction.GetArgumentType(argIndex); + } + return GetOperandVarType(operand); } else diff --git a/Ryujinx.Graphics.Shader/Decoders/Decoder.cs b/Ryujinx.Graphics.Shader/Decoders/Decoder.cs index 3f08bdd9..ca45aab5 100644 --- a/Ryujinx.Graphics.Shader/Decoders/Decoder.cs +++ b/Ryujinx.Graphics.Shader/Decoders/Decoder.cs @@ -9,149 +9,172 @@ namespace Ryujinx.Graphics.Shader.Decoders { static class Decoder { - public static Block[] Decode(IGpuAccessor gpuAccessor, ulong startAddress) + public static Block[][] Decode(IGpuAccessor gpuAccessor, ulong startAddress) { - List blocks = new List(); + List funcs = new List(); - Queue workQueue = new Queue(); + Queue funcQueue = new Queue(); + HashSet funcVisited = new HashSet(); - Dictionary visited = new Dictionary(); - - Block GetBlock(ulong blkAddress) + void EnqueueFunction(ulong funcAddress) { - if (!visited.TryGetValue(blkAddress, out Block block)) + if (funcVisited.Add(funcAddress)) { - block = new Block(blkAddress); - - workQueue.Enqueue(block); - - visited.Add(blkAddress, block); - } - - return block; - } - - GetBlock(0); - - while (workQueue.TryDequeue(out Block currBlock)) - { - // Check if the current block is inside another block. - if (BinarySearch(blocks, currBlock.Address, out int nBlkIndex)) - { - Block nBlock = blocks[nBlkIndex]; - - if (nBlock.Address == currBlock.Address) - { - throw new InvalidOperationException("Found duplicate block address on the list."); - } - - nBlock.Split(currBlock); - - blocks.Insert(nBlkIndex + 1, currBlock); - - continue; - } - - // If we have a block after the current one, set the limit address. - ulong limitAddress = ulong.MaxValue; - - if (nBlkIndex != blocks.Count) - { - Block nBlock = blocks[nBlkIndex]; - - int nextIndex = nBlkIndex + 1; - - if (nBlock.Address < currBlock.Address && nextIndex < blocks.Count) - { - limitAddress = blocks[nextIndex].Address; - } - else if (nBlock.Address > currBlock.Address) - { - limitAddress = blocks[nBlkIndex].Address; - } - } - - FillBlock(gpuAccessor, currBlock, limitAddress, startAddress); - - if (currBlock.OpCodes.Count != 0) - { - // We should have blocks for all possible branch targets, - // including those from SSY/PBK instructions. - foreach (OpCodePush pushOp in currBlock.PushOpCodes) - { - GetBlock(pushOp.GetAbsoluteAddress()); - } - - // Set child blocks. "Branch" is the block the branch instruction - // points to (when taken), "Next" is the block at the next address, - // executed when the branch is not taken. For Unconditional Branches - // or end of program, Next is null. - OpCode lastOp = currBlock.GetLastOp(); - - if (lastOp is OpCodeBranch opBr) - { - currBlock.Branch = GetBlock(opBr.GetAbsoluteAddress()); - } - else if (lastOp is OpCodeBranchIndir opBrIndir) - { - // An indirect branch could go anywhere, we don't know the target. - // Those instructions are usually used on a switch to jump table - // compiler optimization, and in those cases the possible targets - // seems to be always right after the BRX itself. We can assume - // that the possible targets are all the blocks in-between the - // instruction right after the BRX, and the common target that - // all the "cases" should eventually jump to, acting as the - // switch break. - Block firstTarget = GetBlock(currBlock.EndAddress); - - firstTarget.BrIndir = opBrIndir; - - opBrIndir.PossibleTargets.Add(firstTarget); - } - - if (!IsUnconditionalBranch(lastOp)) - { - currBlock.Next = GetBlock(currBlock.EndAddress); - } - } - - // Insert the new block on the list (sorted by address). - if (blocks.Count != 0) - { - Block nBlock = blocks[nBlkIndex]; - - blocks.Insert(nBlkIndex + (nBlock.Address < currBlock.Address ? 1 : 0), currBlock); - } - else - { - blocks.Add(currBlock); - } - - // Do we have a block after the current one? - if (currBlock.BrIndir != null && HasBlockAfter(gpuAccessor, currBlock, startAddress)) - { - bool targetVisited = visited.ContainsKey(currBlock.EndAddress); - - Block possibleTarget = GetBlock(currBlock.EndAddress); - - currBlock.BrIndir.PossibleTargets.Add(possibleTarget); - - if (!targetVisited) - { - possibleTarget.BrIndir = currBlock.BrIndir; - } + funcQueue.Enqueue(funcAddress); } } - foreach (Block block in blocks.Where(x => x.PushOpCodes.Count != 0)) + funcQueue.Enqueue(0); + + while (funcQueue.TryDequeue(out ulong funcAddress)) { - for (int pushOpIndex = 0; pushOpIndex < block.PushOpCodes.Count; pushOpIndex++) + List blocks = new List(); + Queue workQueue = new Queue(); + Dictionary visited = new Dictionary(); + + Block GetBlock(ulong blkAddress) { - PropagatePushOp(visited, block, pushOpIndex); + if (!visited.TryGetValue(blkAddress, out Block block)) + { + block = new Block(blkAddress); + + workQueue.Enqueue(block); + visited.Add(blkAddress, block); + } + + return block; } + + GetBlock(funcAddress); + + while (workQueue.TryDequeue(out Block currBlock)) + { + // Check if the current block is inside another block. + if (BinarySearch(blocks, currBlock.Address, out int nBlkIndex)) + { + Block nBlock = blocks[nBlkIndex]; + + if (nBlock.Address == currBlock.Address) + { + throw new InvalidOperationException("Found duplicate block address on the list."); + } + + nBlock.Split(currBlock); + blocks.Insert(nBlkIndex + 1, currBlock); + + continue; + } + + // If we have a block after the current one, set the limit address. + ulong limitAddress = ulong.MaxValue; + + if (nBlkIndex != blocks.Count) + { + Block nBlock = blocks[nBlkIndex]; + + int nextIndex = nBlkIndex + 1; + + if (nBlock.Address < currBlock.Address && nextIndex < blocks.Count) + { + limitAddress = blocks[nextIndex].Address; + } + else if (nBlock.Address > currBlock.Address) + { + limitAddress = blocks[nBlkIndex].Address; + } + } + + FillBlock(gpuAccessor, currBlock, limitAddress, startAddress); + + if (currBlock.OpCodes.Count != 0) + { + // We should have blocks for all possible branch targets, + // including those from SSY/PBK instructions. + foreach (OpCodePush pushOp in currBlock.PushOpCodes) + { + GetBlock(pushOp.GetAbsoluteAddress()); + } + + // Set child blocks. "Branch" is the block the branch instruction + // points to (when taken), "Next" is the block at the next address, + // executed when the branch is not taken. For Unconditional Branches + // or end of program, Next is null. + OpCode lastOp = currBlock.GetLastOp(); + + if (lastOp is OpCodeBranch opBr) + { + if (lastOp.Emitter == InstEmit.Cal) + { + EnqueueFunction(opBr.GetAbsoluteAddress()); + } + else + { + currBlock.Branch = GetBlock(opBr.GetAbsoluteAddress()); + } + } + else if (lastOp is OpCodeBranchIndir opBrIndir) + { + // An indirect branch could go anywhere, we don't know the target. + // Those instructions are usually used on a switch to jump table + // compiler optimization, and in those cases the possible targets + // seems to be always right after the BRX itself. We can assume + // that the possible targets are all the blocks in-between the + // instruction right after the BRX, and the common target that + // all the "cases" should eventually jump to, acting as the + // switch break. + Block firstTarget = GetBlock(currBlock.EndAddress); + + firstTarget.BrIndir = opBrIndir; + + opBrIndir.PossibleTargets.Add(firstTarget); + } + + if (!IsUnconditionalBranch(lastOp)) + { + currBlock.Next = GetBlock(currBlock.EndAddress); + } + } + + // Insert the new block on the list (sorted by address). + if (blocks.Count != 0) + { + Block nBlock = blocks[nBlkIndex]; + + blocks.Insert(nBlkIndex + (nBlock.Address < currBlock.Address ? 1 : 0), currBlock); + } + else + { + blocks.Add(currBlock); + } + + // Do we have a block after the current one? + if (currBlock.BrIndir != null && HasBlockAfter(gpuAccessor, currBlock, startAddress)) + { + bool targetVisited = visited.ContainsKey(currBlock.EndAddress); + + Block possibleTarget = GetBlock(currBlock.EndAddress); + + currBlock.BrIndir.PossibleTargets.Add(possibleTarget); + + if (!targetVisited) + { + possibleTarget.BrIndir = currBlock.BrIndir; + } + } + } + + foreach (Block block in blocks.Where(x => x.PushOpCodes.Count != 0)) + { + for (int pushOpIndex = 0; pushOpIndex < block.PushOpCodes.Count; pushOpIndex++) + { + PropagatePushOp(visited, block, pushOpIndex); + } + } + + funcs.Add(blocks.ToArray()); } - return blocks.ToArray(); + return funcs.ToArray(); } private static bool HasBlockAfter(IGpuAccessor gpuAccessor, Block currBlock, ulong startAdddress) @@ -251,7 +274,7 @@ namespace Ryujinx.Graphics.Shader.Decoders block.OpCodes.Add(op); } - while (!IsBranch(block.GetLastOp())); + while (!IsControlFlowChange(block.GetLastOp())); block.EndAddress = address; @@ -260,7 +283,7 @@ namespace Ryujinx.Graphics.Shader.Decoders private static bool IsUnconditionalBranch(OpCode opCode) { - return IsUnconditional(opCode) && IsBranch(opCode); + return IsUnconditional(opCode) && IsControlFlowChange(opCode); } private static bool IsUnconditional(OpCode opCode) @@ -273,7 +296,7 @@ namespace Ryujinx.Graphics.Shader.Decoders return opCode.Predicate.Index == RegisterConsts.PredicateTrueIndex && !opCode.InvertPredicate; } - private static bool IsBranch(OpCode opCode) + private static bool IsControlFlowChange(OpCode opCode) { return (opCode is OpCodeBranch opBranch && !opBranch.PushTarget) || opCode is OpCodeBranchIndir || @@ -281,11 +304,6 @@ namespace Ryujinx.Graphics.Shader.Decoders opCode is OpCodeExit; } - private static bool IsExit(OpCode opCode) - { - return opCode is OpCodeExit; - } - private struct PathBlockState { public Block Block { get; } diff --git a/Ryujinx.Graphics.Shader/Decoders/OpCodeTable.cs b/Ryujinx.Graphics.Shader/Decoders/OpCodeTable.cs index 61168b59..302f1fc4 100644 --- a/Ryujinx.Graphics.Shader/Decoders/OpCodeTable.cs +++ b/Ryujinx.Graphics.Shader/Decoders/OpCodeTable.cs @@ -46,6 +46,7 @@ namespace Ryujinx.Graphics.Shader.Decoders Set("111000100100xx", InstEmit.Bra, OpCodeBranch.Create); Set("111000110100xx", InstEmit.Brk, OpCodeBranchPop.Create); Set("111000100101xx", InstEmit.Brx, OpCodeBranchIndir.Create); + Set("111000100110xx", InstEmit.Cal, OpCodeBranch.Create); Set("0101000010100x", InstEmit.Csetp, OpCodePset.Create); Set("0100110001110x", InstEmit.Dadd, OpCodeFArithCbuf.Create); Set("0011100x01110x", InstEmit.Dadd, OpCodeDArithImm.Create); @@ -185,6 +186,7 @@ namespace Ryujinx.Graphics.Shader.Decoders Set("0011100x11110x", InstEmit.R2p, OpCodeAluImm.Create); Set("0101110011110x", InstEmit.R2p, OpCodeAluReg.Create); Set("1110101111111x", InstEmit.Red, OpCodeRed.Create); + Set("111000110010xx", InstEmit.Ret, OpCodeExit.Create); Set("0100110010010x", InstEmit.Rro, OpCodeFArithCbuf.Create); Set("0011100x10010x", InstEmit.Rro, OpCodeFArithImm.Create); Set("0101110010010x", InstEmit.Rro, OpCodeFArithReg.Create); diff --git a/Ryujinx.Graphics.Shader/Instructions/InstEmitFlow.cs b/Ryujinx.Graphics.Shader/Instructions/InstEmitFlow.cs index e34f2988..332074ae 100644 --- a/Ryujinx.Graphics.Shader/Instructions/InstEmitFlow.cs +++ b/Ryujinx.Graphics.Shader/Instructions/InstEmitFlow.cs @@ -51,10 +51,25 @@ namespace Ryujinx.Graphics.Shader.Instructions } } - public static void Depbar(EmitterContext context) { } + public static void Cal(EmitterContext context) + { + OpCodeBranch op = (OpCodeBranch)context.CurrOp; + + context.Call(context.GetFunctionId(op.GetAbsoluteAddress()), false); + } + + public static void Depbar(EmitterContext context) + { + } public static void Exit(EmitterContext context) { + if (context.IsNonMain) + { + context.Config.GpuAccessor.Log("Invalid exit on non-main function."); + return; + } + OpCodeExit op = (OpCodeExit)context.CurrOp; // TODO: Figure out how this is supposed to work in the @@ -70,13 +85,27 @@ namespace Ryujinx.Graphics.Shader.Instructions context.Discard(); } - public static void Nop(EmitterContext context) { } + public static void Nop(EmitterContext context) + { + } public static void Pbk(EmitterContext context) { EmitPbkOrSsy(context); } + public static void Ret(EmitterContext context) + { + if (context.IsNonMain) + { + context.Return(); + } + else + { + context.Config.GpuAccessor.Log("Invalid return on main function."); + } + } + public static void Ssy(EmitterContext context) { EmitPbkOrSsy(context); diff --git a/Ryujinx.Graphics.Shader/IntermediateRepresentation/BasicBlock.cs b/Ryujinx.Graphics.Shader/IntermediateRepresentation/BasicBlock.cs index 94975337..1f7d2b25 100644 --- a/Ryujinx.Graphics.Shader/IntermediateRepresentation/BasicBlock.cs +++ b/Ryujinx.Graphics.Shader/IntermediateRepresentation/BasicBlock.cs @@ -24,6 +24,7 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation } public bool HasBranch => _branch != null; + public bool Reachable => Index == 0 || Predecessors.Count != 0; public List Predecessors { get; } diff --git a/Ryujinx.Graphics.Shader/IntermediateRepresentation/Function.cs b/Ryujinx.Graphics.Shader/IntermediateRepresentation/Function.cs new file mode 100644 index 00000000..e535c3fc --- /dev/null +++ b/Ryujinx.Graphics.Shader/IntermediateRepresentation/Function.cs @@ -0,0 +1,23 @@ +namespace Ryujinx.Graphics.Shader.IntermediateRepresentation +{ + class Function + { + public BasicBlock[] Blocks { get; } + + public string Name { get; } + + public bool ReturnsValue { get; } + + public int InArgumentsCount { get; } + public int OutArgumentsCount { get; } + + public Function(BasicBlock[] blocks, string name, bool returnsValue, int inArgumentsCount, int outArgumentsCount) + { + Blocks = blocks; + Name = name; + ReturnsValue = returnsValue; + InArgumentsCount = inArgumentsCount; + OutArgumentsCount = outArgumentsCount; + } + } +} \ No newline at end of file diff --git a/Ryujinx.Graphics.Shader/IntermediateRepresentation/Instruction.cs b/Ryujinx.Graphics.Shader/IntermediateRepresentation/Instruction.cs index 4a6c3a78..c0356e46 100644 --- a/Ryujinx.Graphics.Shader/IntermediateRepresentation/Instruction.cs +++ b/Ryujinx.Graphics.Shader/IntermediateRepresentation/Instruction.cs @@ -31,6 +31,8 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation Branch, BranchIfFalse, BranchIfTrue, + Call, + CallOutArgument, Ceiling, Clamp, ClampU32, diff --git a/Ryujinx.Graphics.Shader/IntermediateRepresentation/OperandHelper.cs b/Ryujinx.Graphics.Shader/IntermediateRepresentation/OperandHelper.cs index 6765f8a4..221e278f 100644 --- a/Ryujinx.Graphics.Shader/IntermediateRepresentation/OperandHelper.cs +++ b/Ryujinx.Graphics.Shader/IntermediateRepresentation/OperandHelper.cs @@ -5,6 +5,11 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation { static class OperandHelper { + public static Operand Argument(int value) + { + return new Operand(OperandType.Argument, value); + } + public static Operand Attribute(int value) { return new Operand(OperandType.Attribute, value); diff --git a/Ryujinx.Graphics.Shader/IntermediateRepresentation/OperandType.cs b/Ryujinx.Graphics.Shader/IntermediateRepresentation/OperandType.cs index 8f8df9e4..3427b103 100644 --- a/Ryujinx.Graphics.Shader/IntermediateRepresentation/OperandType.cs +++ b/Ryujinx.Graphics.Shader/IntermediateRepresentation/OperandType.cs @@ -2,6 +2,7 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation { enum OperandType { + Argument, Attribute, Constant, ConstantBuffer, diff --git a/Ryujinx.Graphics.Shader/IntermediateRepresentation/Operation.cs b/Ryujinx.Graphics.Shader/IntermediateRepresentation/Operation.cs index 2c4a88cd..a86a278a 100644 --- a/Ryujinx.Graphics.Shader/IntermediateRepresentation/Operation.cs +++ b/Ryujinx.Graphics.Shader/IntermediateRepresentation/Operation.cs @@ -48,6 +48,25 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation Index = index; } + public void AppendOperands(params Operand[] operands) + { + int startIndex = _sources.Length; + + Array.Resize(ref _sources, startIndex + operands.Length); + + for (int index = 0; index < operands.Length; index++) + { + Operand source = operands[index]; + + if (source.Type == OperandType.LocalVariable) + { + source.UseOps.Add(this); + } + + _sources[startIndex + index] = source; + } + } + private Operand AssignDest(Operand dest) { if (dest != null && dest.Type == OperandType.LocalVariable) diff --git a/Ryujinx.Graphics.Shader/StructuredIr/AstOperation.cs b/Ryujinx.Graphics.Shader/StructuredIr/AstOperation.cs index 76eee71e..a8474955 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/AstOperation.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/AstOperation.cs @@ -14,24 +14,35 @@ namespace Ryujinx.Graphics.Shader.StructuredIr public int SourcesCount => _sources.Length; - public AstOperation(Instruction inst, params IAstNode[] sources) + public AstOperation(Instruction inst, IAstNode[] sources, int sourcesCount) { Inst = inst; _sources = sources; - foreach (IAstNode source in sources) + for (int index = 0; index < sources.Length; index++) { - AddUse(source, this); + if (index < sourcesCount) + { + AddUse(sources[index], this); + } + else + { + AddDef(sources[index], this); + } } Index = 0; } - public AstOperation(Instruction inst, int index, params IAstNode[] sources) : this(inst, sources) + public AstOperation(Instruction inst, int index, IAstNode[] sources, int sourcesCount) : this(inst, sources, sourcesCount) { Index = index; } + public AstOperation(Instruction inst, params IAstNode[] sources) : this(inst, sources, sources.Length) + { + } + public IAstNode GetSource(int index) { return _sources[index]; diff --git a/Ryujinx.Graphics.Shader/StructuredIr/AstOptimizer.cs b/Ryujinx.Graphics.Shader/StructuredIr/AstOptimizer.cs index a37e1a3e..4c6d17a0 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/AstOptimizer.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/AstOptimizer.cs @@ -11,7 +11,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { public static void Optimize(StructuredProgramContext context) { - AstBlock mainBlock = context.Info.MainBlock; + AstBlock mainBlock = context.CurrentFunction.MainBlock; // When debug mode is enabled, we disable expression propagation // (this makes comparison with the disassembly easier). @@ -34,7 +34,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { visitor.Block.Remove(assignment); - context.Info.Locals.Remove(propVar); + context.CurrentFunction.Locals.Remove(propVar); } } } diff --git a/Ryujinx.Graphics.Shader/StructuredIr/AstTextureOperation.cs b/Ryujinx.Graphics.Shader/StructuredIr/AstTextureOperation.cs index a3fa3e3a..188bf919 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/AstTextureOperation.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/AstTextureOperation.cs @@ -19,7 +19,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr int handle, int arraySize, int index, - params IAstNode[] sources) : base(inst, index, sources) + params IAstNode[] sources) : base(inst, index, sources, sources.Length) { Type = type; Format = format; diff --git a/Ryujinx.Graphics.Shader/StructuredIr/InstructionInfo.cs b/Ryujinx.Graphics.Shader/StructuredIr/InstructionInfo.cs index 3fcc5f11..fcf39cc0 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/InstructionInfo.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/InstructionInfo.cs @@ -49,6 +49,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr Add(Instruction.BitwiseOr, VariableType.Int, VariableType.Int, VariableType.Int); Add(Instruction.BranchIfTrue, VariableType.None, VariableType.Bool); Add(Instruction.BranchIfFalse, VariableType.None, VariableType.Bool); + Add(Instruction.Call, VariableType.Scalar); Add(Instruction.Ceiling, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar); Add(Instruction.Clamp, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar, VariableType.Scalar); Add(Instruction.ClampU32, VariableType.U32, VariableType.U32, VariableType.U32, VariableType.U32); @@ -151,6 +152,10 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { return VariableType.F32; } + else if (inst == Instruction.Call) + { + return VariableType.S32; + } return GetFinalVarType(_infoTbl[(int)(inst & Instruction.Mask)].SrcTypes[index], inst); } diff --git a/Ryujinx.Graphics.Shader/StructuredIr/StructuredFunction.cs b/Ryujinx.Graphics.Shader/StructuredIr/StructuredFunction.cs new file mode 100644 index 00000000..3723f259 --- /dev/null +++ b/Ryujinx.Graphics.Shader/StructuredIr/StructuredFunction.cs @@ -0,0 +1,41 @@ +using System.Collections.Generic; + +namespace Ryujinx.Graphics.Shader.StructuredIr +{ + class StructuredFunction + { + public AstBlock MainBlock { get; } + + public string Name { get; } + + public VariableType ReturnType { get; } + + public VariableType[] InArguments { get; } + public VariableType[] OutArguments { get; } + + public HashSet Locals { get; } + + public StructuredFunction( + AstBlock mainBlock, + string name, + VariableType returnType, + VariableType[] inArguments, + VariableType[] outArguments) + { + MainBlock = mainBlock; + Name = name; + ReturnType = returnType; + InArguments = inArguments; + OutArguments = outArguments; + + Locals = new HashSet(); + } + + public VariableType GetArgumentType(int index) + { + return index >= InArguments.Length + ? OutArguments[index - InArguments.Length] + : InArguments[index]; + } + } +} \ No newline at end of file diff --git a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs index 65de5218..66570dc9 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs @@ -8,51 +8,108 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { static class StructuredProgram { - public static StructuredProgramInfo MakeStructuredProgram(BasicBlock[] blocks, ShaderConfig config) + public static StructuredProgramInfo MakeStructuredProgram(Function[] functions, ShaderConfig config) { - PhiFunctions.Remove(blocks); + StructuredProgramContext context = new StructuredProgramContext(config); - StructuredProgramContext context = new StructuredProgramContext(blocks.Length, config); - - for (int blkIndex = 0; blkIndex < blocks.Length; blkIndex++) + for (int funcIndex = 0; funcIndex < functions.Length; funcIndex++) { - BasicBlock block = blocks[blkIndex]; + Function function = functions[funcIndex]; - context.EnterBlock(block); + BasicBlock[] blocks = function.Blocks; - foreach (INode node in block.Operations) + VariableType returnType = function.ReturnsValue ? VariableType.S32 : VariableType.None; + + VariableType[] inArguments = new VariableType[function.InArgumentsCount]; + VariableType[] outArguments = new VariableType[function.OutArgumentsCount]; + + for (int i = 0; i < inArguments.Length; i++) { - Operation operation = (Operation)node; + inArguments[i] = VariableType.S32; + } - if (IsBranchInst(operation.Inst)) + for (int i = 0; i < outArguments.Length; i++) + { + outArguments[i] = VariableType.S32; + } + + context.EnterFunction(blocks.Length, function.Name, returnType, inArguments, outArguments); + + PhiFunctions.Remove(blocks); + + for (int blkIndex = 0; blkIndex < blocks.Length; blkIndex++) + { + BasicBlock block = blocks[blkIndex]; + + context.EnterBlock(block); + + for (LinkedListNode opNode = block.Operations.First; opNode != null; opNode = opNode.Next) { - context.LeaveBlock(block, operation); - } - else - { - AddOperation(context, operation); + Operation operation = (Operation)opNode.Value; + + if (IsBranchInst(operation.Inst)) + { + context.LeaveBlock(block, operation); + } + else if (operation.Inst != Instruction.CallOutArgument) + { + AddOperation(context, opNode); + } } } + + GotoElimination.Eliminate(context.GetGotos()); + + AstOptimizer.Optimize(context); + + context.LeaveFunction(); } - GotoElimination.Eliminate(context.GetGotos()); - - AstOptimizer.Optimize(context); - return context.Info; } - private static void AddOperation(StructuredProgramContext context, Operation operation) + private static void AddOperation(StructuredProgramContext context, LinkedListNode opNode) { + Operation operation = (Operation)opNode.Value; + Instruction inst = operation.Inst; - IAstNode[] sources = new IAstNode[operation.SourcesCount]; + bool isCall = inst == Instruction.Call; - for (int index = 0; index < sources.Length; index++) + int sourcesCount = operation.SourcesCount; + + List callOutOperands = new List(); + + if (isCall) + { + LinkedListNode scan = opNode.Next; + + while (scan != null && scan.Value is Operation nextOp && nextOp.Inst == Instruction.CallOutArgument) + { + callOutOperands.Add(nextOp.Dest); + scan = scan.Next; + } + + sourcesCount += callOutOperands.Count; + } + + IAstNode[] sources = new IAstNode[sourcesCount]; + + for (int index = 0; index < operation.SourcesCount; index++) { sources[index] = context.GetOperandUse(operation.GetSource(index)); } + if (isCall) + { + for (int index = 0; index < callOutOperands.Count; index++) + { + sources[operation.SourcesCount + index] = context.GetOperandDef(callOutOperands[index]); + } + + callOutOperands.Clear(); + } + AstTextureOperation GetAstTextureOperation(TextureOperation texOp) { return new AstTextureOperation( @@ -98,8 +155,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr AddSBufferUse(context.Info.SBuffers, operation); } - AstAssignment assignment; - // If all the sources are bool, it's better to use short-circuiting // logical operations, rather than forcing a cast to int and doing // a bitwise operation with the value, as it is likely to be used as @@ -152,16 +207,14 @@ namespace Ryujinx.Graphics.Shader.StructuredIr } else if (!isCopy) { - source = new AstOperation(inst, operation.Index, sources); + source = new AstOperation(inst, operation.Index, sources, operation.SourcesCount); } else { source = sources[0]; } - assignment = new AstAssignment(dest, source); - - context.AddNode(assignment); + context.AddNode(new AstAssignment(dest, source)); } else if (operation.Inst == Instruction.Comment) { @@ -182,7 +235,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr AddSBufferUse(context.Info.SBuffers, operation); } - context.AddNode(new AstOperation(inst, operation.Index, sources)); + context.AddNode(new AstOperation(inst, operation.Index, sources, operation.SourcesCount)); } // Those instructions needs to be emulated by using helper functions, diff --git a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs index b7d5efbe..2667be1d 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs @@ -24,11 +24,25 @@ namespace Ryujinx.Graphics.Shader.StructuredIr private int _currEndIndex; private int _loopEndIndex; + public StructuredFunction CurrentFunction { get; private set; } + public StructuredProgramInfo Info { get; } public ShaderConfig Config { get; } - public StructuredProgramContext(int blocksCount, ShaderConfig config) + public StructuredProgramContext(ShaderConfig config) + { + Info = new StructuredProgramInfo(); + + Config = config; + } + + public void EnterFunction( + int blocksCount, + string name, + VariableType returnType, + VariableType[] inArguments, + VariableType[] outArguments) { _loopTails = new HashSet(); @@ -45,9 +59,12 @@ namespace Ryujinx.Graphics.Shader.StructuredIr _currEndIndex = blocksCount; _loopEndIndex = blocksCount; - Info = new StructuredProgramInfo(_currBlock); + CurrentFunction = new StructuredFunction(_currBlock, name, returnType, inArguments, outArguments); + } - Config = config; + public void LeaveFunction() + { + Info.Functions.Add(CurrentFunction); } public void EnterBlock(BasicBlock block) @@ -185,7 +202,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr // so it is reset to false by the "local" assignment anyway. if (block.Index != 0) { - Info.MainBlock.AddFirst(Assign(gotoTempAsg.Destination, Const(IrConsts.False))); + CurrentFunction.MainBlock.AddFirst(Assign(gotoTempAsg.Destination, Const(IrConsts.False))); } } @@ -253,7 +270,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { AstOperand newTemp = Local(type); - Info.Locals.Add(newTemp); + CurrentFunction.Locals.Add(newTemp); return newTemp; } @@ -304,7 +321,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr _localsMap.Add(operand, astOperand); - Info.Locals.Add(astOperand); + CurrentFunction.Locals.Add(astOperand); } return astOperand; diff --git a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramInfo.cs b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramInfo.cs index ef3b3eca..16a27f51 100644 --- a/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramInfo.cs +++ b/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramInfo.cs @@ -4,9 +4,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { class StructuredProgramInfo { - public AstBlock MainBlock { get; } - - public HashSet Locals { get; } + public List Functions { get; } public HashSet CBuffers { get; } public HashSet SBuffers { get; } @@ -22,11 +20,9 @@ namespace Ryujinx.Graphics.Shader.StructuredIr public HashSet Samplers { get; } public HashSet Images { get; } - public StructuredProgramInfo(AstBlock mainBlock) + public StructuredProgramInfo() { - MainBlock = mainBlock; - - Locals = new HashSet(); + Functions = new List(); CBuffers = new HashSet(); SBuffers = new HashSet(); diff --git a/Ryujinx.Graphics.Shader/Translation/ControlFlowGraph.cs b/Ryujinx.Graphics.Shader/Translation/ControlFlowGraph.cs index e2ca74a4..fb0535c8 100644 --- a/Ryujinx.Graphics.Shader/Translation/ControlFlowGraph.cs +++ b/Ryujinx.Graphics.Shader/Translation/ControlFlowGraph.cs @@ -3,9 +3,52 @@ using System.Collections.Generic; namespace Ryujinx.Graphics.Shader.Translation { - static class ControlFlowGraph + class ControlFlowGraph { - public static BasicBlock[] MakeCfg(Operation[] operations) + public BasicBlock[] Blocks { get; } + public BasicBlock[] PostOrderBlocks { get; } + public int[] PostOrderMap { get; } + + public ControlFlowGraph(BasicBlock[] blocks) + { + Blocks = blocks; + + HashSet visited = new HashSet(); + + Stack blockStack = new Stack(); + + List postOrderBlocks = new List(blocks.Length); + + PostOrderMap = new int[blocks.Length]; + + visited.Add(blocks[0]); + + blockStack.Push(blocks[0]); + + while (blockStack.TryPop(out BasicBlock block)) + { + if (block.Next != null && visited.Add(block.Next)) + { + blockStack.Push(block); + blockStack.Push(block.Next); + } + else if (block.Branch != null && visited.Add(block.Branch)) + { + blockStack.Push(block); + blockStack.Push(block.Branch); + } + else + { + PostOrderMap[block.Index] = postOrderBlocks.Count; + + postOrderBlocks.Add(block); + } + } + + PostOrderBlocks = postOrderBlocks.ToArray(); + } + + public static ControlFlowGraph Create(Operation[] operations) { Dictionary labels = new Dictionary(); @@ -86,7 +129,7 @@ namespace Ryujinx.Graphics.Shader.Translation } } - return blocks.ToArray(); + return new ControlFlowGraph(blocks.ToArray()); } private static bool EndsWithUnconditionalInst(INode node) diff --git a/Ryujinx.Graphics.Shader/Translation/Dominance.cs b/Ryujinx.Graphics.Shader/Translation/Dominance.cs index 6a3ff35f..da4a38da 100644 --- a/Ryujinx.Graphics.Shader/Translation/Dominance.cs +++ b/Ryujinx.Graphics.Shader/Translation/Dominance.cs @@ -7,50 +7,18 @@ namespace Ryujinx.Graphics.Shader.Translation { // Those methods are an implementation of the algorithms on "A Simple, Fast Dominance Algorithm". // https://www.cs.rice.edu/~keith/EMBED/dom.pdf - public static void FindDominators(BasicBlock entry, int blocksCount) + public static void FindDominators(ControlFlowGraph cfg) { - HashSet visited = new HashSet(); - - Stack blockStack = new Stack(); - - List postOrderBlocks = new List(blocksCount); - - int[] postOrderMap = new int[blocksCount]; - - visited.Add(entry); - - blockStack.Push(entry); - - while (blockStack.TryPop(out BasicBlock block)) - { - if (block.Next != null && visited.Add(block.Next)) - { - blockStack.Push(block); - blockStack.Push(block.Next); - } - else if (block.Branch != null && visited.Add(block.Branch)) - { - blockStack.Push(block); - blockStack.Push(block.Branch); - } - else - { - postOrderMap[block.Index] = postOrderBlocks.Count; - - postOrderBlocks.Add(block); - } - } - BasicBlock Intersect(BasicBlock block1, BasicBlock block2) { while (block1 != block2) { - while (postOrderMap[block1.Index] < postOrderMap[block2.Index]) + while (cfg.PostOrderMap[block1.Index] < cfg.PostOrderMap[block2.Index]) { block1 = block1.ImmediateDominator; } - while (postOrderMap[block2.Index] < postOrderMap[block1.Index]) + while (cfg.PostOrderMap[block2.Index] < cfg.PostOrderMap[block1.Index]) { block2 = block2.ImmediateDominator; } @@ -59,7 +27,7 @@ namespace Ryujinx.Graphics.Shader.Translation return block1; } - entry.ImmediateDominator = entry; + cfg.Blocks[0].ImmediateDominator = cfg.Blocks[0]; bool modified; @@ -67,9 +35,9 @@ namespace Ryujinx.Graphics.Shader.Translation { modified = false; - for (int blkIndex = postOrderBlocks.Count - 2; blkIndex >= 0; blkIndex--) + for (int blkIndex = cfg.PostOrderBlocks.Length - 2; blkIndex >= 0; blkIndex--) { - BasicBlock block = postOrderBlocks[blkIndex]; + BasicBlock block = cfg.PostOrderBlocks[blkIndex]; BasicBlock newIDom = null; diff --git a/Ryujinx.Graphics.Shader/Translation/EmitterContext.cs b/Ryujinx.Graphics.Shader/Translation/EmitterContext.cs index c5ebe9e7..d5d30f12 100644 --- a/Ryujinx.Graphics.Shader/Translation/EmitterContext.cs +++ b/Ryujinx.Graphics.Shader/Translation/EmitterContext.cs @@ -13,16 +13,18 @@ namespace Ryujinx.Graphics.Shader.Translation public ShaderConfig Config { get; } - private List _operations; + public bool IsNonMain { get; } - private Dictionary _labels; + private readonly IReadOnlyDictionary _funcs; + private readonly List _operations; + private readonly Dictionary _labels; - public EmitterContext(ShaderConfig config) + public EmitterContext(ShaderConfig config, bool isNonMain, IReadOnlyDictionary funcs) { Config = config; - + IsNonMain = isNonMain; + _funcs = funcs; _operations = new List(); - _labels = new Dictionary(); } @@ -71,6 +73,11 @@ namespace Ryujinx.Graphics.Shader.Translation return label; } + public int GetFunctionId(ulong address) + { + return _funcs[address]; + } + public void PrepareForReturn() { if (Config.Stage == ShaderStage.Fragment) diff --git a/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs b/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs index c8d622b2..40f3370f 100644 --- a/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs +++ b/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs @@ -136,6 +136,16 @@ namespace Ryujinx.Graphics.Shader.Translation return context.Add(Instruction.BranchIfTrue, d, a); } + public static Operand Call(this EmitterContext context, int funcId, bool returns, params Operand[] args) + { + Operand[] args2 = new Operand[args.Length + 1]; + + args2[0] = Const(funcId); + args.CopyTo(args2, 1); + + return context.Add(Instruction.Call, returns ? Local() : null, args2); + } + public static Operand ConditionalSelect(this EmitterContext context, Operand a, Operand b, Operand c) { return context.Add(Instruction.ConditionalSelect, Local(), a, b, c); @@ -521,11 +531,16 @@ namespace Ryujinx.Graphics.Shader.Translation return context.Add(Instruction.PackHalf2x16, Local(), a, b); } - public static Operand Return(this EmitterContext context) + public static void Return(this EmitterContext context) { context.PrepareForReturn(); + context.Add(Instruction.Return); + } - return context.Add(Instruction.Return); + public static void Return(this EmitterContext context, Operand returnValue) + { + context.PrepareForReturn(); + context.Add(Instruction.Return, null, returnValue); } public static Operand ShiftLeft(this EmitterContext context, Operand a, Operand b) diff --git a/Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs b/Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs index 286574cf..32c7d2f0 100644 --- a/Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs +++ b/Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs @@ -287,6 +287,8 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations case Instruction.AtomicOr: case Instruction.AtomicSwap: case Instruction.AtomicXor: + case Instruction.Call: + case Instruction.CallOutArgument: return true; } } diff --git a/Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs b/Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs new file mode 100644 index 00000000..fd90391f --- /dev/null +++ b/Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs @@ -0,0 +1,484 @@ +using Ryujinx.Graphics.Shader.Decoders; +using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Numerics; + +namespace Ryujinx.Graphics.Shader.Translation +{ + static class RegisterUsage + { + private const int RegsCount = 256; + private const int RegsMask = RegsCount - 1; + + private const int GprMasks = 4; + private const int PredMasks = 1; + private const int FlagMasks = 1; + private const int TotalMasks = GprMasks + PredMasks + FlagMasks; + + private struct RegisterMask : IEquatable + { + public long GprMask0 { get; set; } + public long GprMask1 { get; set; } + public long GprMask2 { get; set; } + public long GprMask3 { get; set; } + public long PredMask { get; set; } + public long FlagMask { get; set; } + + public RegisterMask(long gprMask0, long gprMask1, long gprMask2, long gprMask3, long predMask, long flagMask) + { + GprMask0 = gprMask0; + GprMask1 = gprMask1; + GprMask2 = gprMask2; + GprMask3 = gprMask3; + PredMask = predMask; + FlagMask = flagMask; + } + + public long GetMask(int index) + { + return index switch + { + 0 => GprMask0, + 1 => GprMask1, + 2 => GprMask2, + 3 => GprMask3, + 4 => PredMask, + 5 => FlagMask, + _ => throw new ArgumentOutOfRangeException(nameof(index)) + }; + } + + public static RegisterMask operator &(RegisterMask x, RegisterMask y) + { + return new RegisterMask( + x.GprMask0 & y.GprMask0, + x.GprMask1 & y.GprMask1, + x.GprMask2 & y.GprMask2, + x.GprMask3 & y.GprMask3, + x.PredMask & y.PredMask, + x.FlagMask & y.FlagMask); + } + + public static RegisterMask operator |(RegisterMask x, RegisterMask y) + { + return new RegisterMask( + x.GprMask0 | y.GprMask0, + x.GprMask1 | y.GprMask1, + x.GprMask2 | y.GprMask2, + x.GprMask3 | y.GprMask3, + x.PredMask | y.PredMask, + x.FlagMask | y.FlagMask); + } + + public static RegisterMask operator ~(RegisterMask x) + { + return new RegisterMask( + ~x.GprMask0, + ~x.GprMask1, + ~x.GprMask2, + ~x.GprMask3, + ~x.PredMask, + ~x.FlagMask); + } + + public static bool operator ==(RegisterMask x, RegisterMask y) + { + return x.Equals(y); + } + + public static bool operator !=(RegisterMask x, RegisterMask y) + { + return !x.Equals(y); + } + + public override bool Equals(object obj) + { + return obj is RegisterMask regMask && Equals(regMask); + } + + public bool Equals(RegisterMask other) + { + return GprMask0 == other.GprMask0 && + GprMask1 == other.GprMask1 && + GprMask2 == other.GprMask2 && + GprMask3 == other.GprMask3 && + PredMask == other.PredMask && + FlagMask == other.FlagMask; + } + + public override int GetHashCode() + { + return HashCode.Combine(GprMask0, GprMask1, GprMask2, GprMask3, PredMask, FlagMask); + } + } + + public struct FunctionRegisterUsage + { + public Register[] InArguments { get; } + public Register[] OutArguments { get; } + + public FunctionRegisterUsage(Register[] inArguments, Register[] outArguments) + { + InArguments = inArguments; + OutArguments = outArguments; + } + } + + public static FunctionRegisterUsage RunPass(ControlFlowGraph cfg) + { + List inArguments = new List(); + List outArguments = new List(); + + // Compute local register inputs and outputs used inside blocks. + RegisterMask[] localInputs = new RegisterMask[cfg.Blocks.Length]; + RegisterMask[] localOutputs = new RegisterMask[cfg.Blocks.Length]; + + foreach (BasicBlock block in cfg.Blocks) + { + for (LinkedListNode node = block.Operations.First; node != null; node = node.Next) + { + Operation operation = node.Value as Operation; + + for (int srcIndex = 0; srcIndex < operation.SourcesCount; srcIndex++) + { + Operand source = operation.GetSource(srcIndex); + + if (source.Type != OperandType.Register) + { + continue; + } + + Register register = source.GetRegister(); + + localInputs[block.Index] |= GetMask(register) & ~localOutputs[block.Index]; + } + + if (operation.Dest != null && operation.Dest.Type == OperandType.Register) + { + localOutputs[block.Index] |= GetMask(operation.Dest.GetRegister()); + } + } + } + + // Compute global register inputs and outputs used across blocks. + RegisterMask[] globalCmnOutputs = new RegisterMask[cfg.Blocks.Length]; + + RegisterMask[] globalInputs = new RegisterMask[cfg.Blocks.Length]; + RegisterMask[] globalOutputs = new RegisterMask[cfg.Blocks.Length]; + + RegisterMask allOutputs = new RegisterMask(); + RegisterMask allCmnOutputs = new RegisterMask(-1L, -1L, -1L, -1L, -1L, -1L); + + bool modified; + + bool firstPass = true; + + do + { + modified = false; + + // Compute register outputs. + for (int index = cfg.PostOrderBlocks.Length - 1; index >= 0; index--) + { + BasicBlock block = cfg.PostOrderBlocks[index]; + + if (block.Predecessors.Count != 0) + { + BasicBlock predecessor = block.Predecessors[0]; + + RegisterMask cmnOutputs = localOutputs[predecessor.Index] | globalCmnOutputs[predecessor.Index]; + + RegisterMask outputs = globalOutputs[predecessor.Index]; + + for (int pIndex = 1; pIndex < block.Predecessors.Count; pIndex++) + { + predecessor = block.Predecessors[pIndex]; + + cmnOutputs &= localOutputs[predecessor.Index] | globalCmnOutputs[predecessor.Index]; + + outputs |= globalOutputs[predecessor.Index]; + } + + globalInputs[block.Index] |= outputs & ~cmnOutputs; + + if (!firstPass) + { + cmnOutputs &= globalCmnOutputs[block.Index]; + } + + if (EndsWithReturn(block)) + { + allCmnOutputs &= cmnOutputs | localOutputs[block.Index]; + } + + if (Exchange(globalCmnOutputs, block.Index, cmnOutputs)) + { + modified = true; + } + + outputs |= localOutputs[block.Index]; + + if (Exchange(globalOutputs, block.Index, globalOutputs[block.Index] | outputs)) + { + allOutputs |= outputs; + modified = true; + } + } + else if (Exchange(globalOutputs, block.Index, localOutputs[block.Index])) + { + allOutputs |= localOutputs[block.Index]; + modified = true; + } + } + + // Compute register inputs. + for (int index = 0; index < cfg.PostOrderBlocks.Length; index++) + { + BasicBlock block = cfg.PostOrderBlocks[index]; + + RegisterMask inputs = localInputs[block.Index]; + + if (block.Next != null) + { + inputs |= globalInputs[block.Next.Index]; + } + + if (block.Branch != null) + { + inputs |= globalInputs[block.Branch.Index]; + } + + inputs &= ~globalCmnOutputs[block.Index]; + + if (Exchange(globalInputs, block.Index, globalInputs[block.Index] | inputs)) + { + modified = true; + } + } + + firstPass = false; + } + while (modified); + + // Insert load and store context instructions where needed. + foreach (BasicBlock block in cfg.Blocks) + { + // The only block without any predecessor should be the entry block. + // It always needs a context load as it is the first block to run. + if (block.Predecessors.Count == 0) + { + RegisterMask inputs = globalInputs[block.Index] | (allOutputs & ~allCmnOutputs); + + LoadLocals(block, inputs, inArguments); + } + + if (EndsWithReturn(block)) + { + StoreLocals(block, allOutputs, inArguments.Count, outArguments); + } + } + + return new FunctionRegisterUsage(inArguments.ToArray(), outArguments.ToArray()); + } + + public static void FixupCalls(BasicBlock[] blocks, FunctionRegisterUsage[] frus) + { + foreach (BasicBlock block in blocks) + { + for (LinkedListNode node = block.Operations.First; node != null; node = node.Next) + { + Operation operation = node.Value as Operation; + + if (operation.Inst == Instruction.Call) + { + Operand funcId = operation.GetSource(0); + + Debug.Assert(funcId.Type == OperandType.Constant); + + var fru = frus[funcId.Value]; + + Operand[] regs = new Operand[fru.InArguments.Length]; + + for (int i = 0; i < fru.InArguments.Length; i++) + { + regs[i] = OperandHelper.Register(fru.InArguments[i]); + } + + operation.AppendOperands(regs); + + for (int i = 0; i < fru.OutArguments.Length; i++) + { + Operation callOutArgOp = new Operation(Instruction.CallOutArgument, OperandHelper.Register(fru.OutArguments[i])); + + node = block.Operations.AddAfter(node, callOutArgOp); + } + } + } + } + } + + private static bool StartsWith(BasicBlock block, Instruction inst) + { + if (block.Operations.Count == 0) + { + return false; + } + + return block.Operations.First.Value is Operation operation && operation.Inst == inst; + } + + private static bool EndsWith(BasicBlock block, Instruction inst) + { + if (block.Operations.Count == 0) + { + return false; + } + + return block.Operations.Last.Value is Operation operation && operation.Inst == inst; + } + + private static RegisterMask GetMask(Register register) + { + Span gprMasks = stackalloc long[4]; + long predMask = 0; + long flagMask = 0; + + switch (register.Type) + { + case RegisterType.Gpr: + gprMasks[register.Index >> 6] = 1L << (register.Index & 0x3f); + break; + case RegisterType.Predicate: + predMask = 1L << register.Index; + break; + case RegisterType.Flag: + flagMask = 1L << register.Index; + break; + } + + return new RegisterMask(gprMasks[0], gprMasks[1], gprMasks[2], gprMasks[3], predMask, flagMask); + } + + private static bool Exchange(RegisterMask[] masks, int blkIndex, RegisterMask value) + { + RegisterMask oldValue = masks[blkIndex]; + + masks[blkIndex] = value; + + return oldValue != value; + } + + private static void LoadLocals(BasicBlock block, RegisterMask masks, List inArguments) + { + bool fillArgsList = inArguments.Count == 0; + LinkedListNode node = null; + int argIndex = 0; + + for (int i = 0; i < TotalMasks; i++) + { + (RegisterType regType, int baseRegIndex) = GetRegTypeAndBaseIndex(i); + long mask = masks.GetMask(i); + + while (mask != 0) + { + int bit = BitOperations.TrailingZeroCount(mask); + + mask &= ~(1L << bit); + + Register register = new Register(baseRegIndex + bit, regType); + + if (fillArgsList) + { + inArguments.Add(register); + } + + Operation copyOp = new Operation(Instruction.Copy, OperandHelper.Register(register), OperandHelper.Argument(argIndex++)); + + if (node == null) + { + node = block.Operations.AddFirst(copyOp); + } + else + { + node = block.Operations.AddAfter(node, copyOp); + } + } + } + + Debug.Assert(argIndex <= inArguments.Count); + } + + private static void StoreLocals(BasicBlock block, RegisterMask masks, int inArgumentsCount, List outArguments) + { + LinkedListNode node = null; + int argIndex = inArgumentsCount; + bool fillArgsList = outArguments.Count == 0; + + for (int i = 0; i < TotalMasks; i++) + { + (RegisterType regType, int baseRegIndex) = GetRegTypeAndBaseIndex(i); + long mask = masks.GetMask(i); + + while (mask != 0) + { + int bit = BitOperations.TrailingZeroCount(mask); + + mask &= ~(1L << bit); + + Register register = new Register(baseRegIndex + bit, regType); + + if (fillArgsList) + { + outArguments.Add(register); + } + + Operation copyOp = new Operation(Instruction.Copy, OperandHelper.Argument(argIndex++), OperandHelper.Register(register)); + + if (node == null) + { + node = block.Operations.AddBefore(block.Operations.Last, copyOp); + } + else + { + node = block.Operations.AddAfter(node, copyOp); + } + } + } + + Debug.Assert(argIndex <= inArgumentsCount + outArguments.Count); + } + + private static (RegisterType RegType, int BaseRegIndex) GetRegTypeAndBaseIndex(int i) + { + RegisterType regType = RegisterType.Gpr; + int baseRegIndex = 0; + + if (i < GprMasks) + { + baseRegIndex = i * sizeof(long) * 8; + } + else if (i == GprMasks) + { + regType = RegisterType.Predicate; + } + else + { + regType = RegisterType.Flag; + } + + return (regType, baseRegIndex); + } + + private static bool EndsWithReturn(BasicBlock block) + { + if (!(block.GetLastOp() is Operation operation)) + { + return false; + } + + return operation.Inst == Instruction.Return; + } + } +} \ No newline at end of file diff --git a/Ryujinx.Graphics.Shader/Translation/Translator.cs b/Ryujinx.Graphics.Shader/Translation/Translator.cs index db0924b3..f8093c84 100644 --- a/Ryujinx.Graphics.Shader/Translation/Translator.cs +++ b/Ryujinx.Graphics.Shader/Translation/Translator.cs @@ -14,6 +14,16 @@ namespace Ryujinx.Graphics.Shader.Translation { private const int HeaderSize = 0x50; + private struct FunctionCode + { + public Operation[] Code { get; } + + public FunctionCode(Operation[] code) + { + Code = code; + } + } + public static ShaderProgram Translate(ulong address, IGpuAccessor gpuAccessor, TranslationFlags flags) { return Translate(DecodeShader(address, gpuAccessor, flags, out ShaderConfig config), config); @@ -21,32 +31,65 @@ namespace Ryujinx.Graphics.Shader.Translation public static ShaderProgram Translate(ulong addressA, ulong addressB, IGpuAccessor gpuAccessor, TranslationFlags flags) { - Operation[] opsA = DecodeShader(addressA, gpuAccessor, flags | TranslationFlags.VertexA, out ShaderConfig configA); - Operation[] opsB = DecodeShader(addressB, gpuAccessor, flags, out ShaderConfig config); + FunctionCode[] funcA = DecodeShader(addressA, gpuAccessor, flags | TranslationFlags.VertexA, out ShaderConfig configA); + FunctionCode[] funcB = DecodeShader(addressB, gpuAccessor, flags, out ShaderConfig config); config.SetUsedFeature(configA.UsedFeatures); - return Translate(Combine(opsA, opsB), config, configA.Size); + return Translate(Combine(funcA, funcB), config, configA.Size); } - private static ShaderProgram Translate(Operation[] ops, ShaderConfig config, int sizeA = 0) + private static ShaderProgram Translate(FunctionCode[] functions, ShaderConfig config, int sizeA = 0) { - BasicBlock[] blocks = ControlFlowGraph.MakeCfg(ops); + var cfgs = new ControlFlowGraph[functions.Length]; + var frus = new RegisterUsage.FunctionRegisterUsage[functions.Length]; - if (blocks.Length > 0) + for (int i = 0; i < functions.Length; i++) { - Dominance.FindDominators(blocks[0], blocks.Length); + cfgs[i] = ControlFlowGraph.Create(functions[i].Code); - Dominance.FindDominanceFrontiers(blocks); - - Ssa.Rename(blocks); - - Optimizer.RunPass(blocks, config); - - Lowering.RunPass(blocks, config); + if (i != 0) + { + frus[i] = RegisterUsage.RunPass(cfgs[i]); + } } - StructuredProgramInfo sInfo = StructuredProgram.MakeStructuredProgram(blocks, config); + Function[] funcs = new Function[functions.Length]; + + for (int i = 0; i < functions.Length; i++) + { + var cfg = cfgs[i]; + + int inArgumentsCount = 0; + int outArgumentsCount = 0; + + if (i != 0) + { + var fru = frus[i]; + + inArgumentsCount = fru.InArguments.Length; + outArgumentsCount = fru.OutArguments.Length; + } + + if (cfg.Blocks.Length != 0) + { + RegisterUsage.FixupCalls(cfg.Blocks, frus); + + Dominance.FindDominators(cfg); + + Dominance.FindDominanceFrontiers(cfg.Blocks); + + Ssa.Rename(cfg.Blocks); + + Optimizer.RunPass(cfg.Blocks, config); + + Lowering.RunPass(cfg.Blocks, config); + } + + funcs[i] = new Function(cfg.Blocks, $"fun{i}", false, inArgumentsCount, outArgumentsCount); + } + + StructuredProgramInfo sInfo = StructuredProgram.MakeStructuredProgram(funcs, config); GlslProgram program = GlslGenerator.Generate(sInfo, config); @@ -62,9 +105,9 @@ namespace Ryujinx.Graphics.Shader.Translation return new ShaderProgram(spInfo, config.Stage, glslCode, config.Size, sizeA); } - private static Operation[] DecodeShader(ulong address, IGpuAccessor gpuAccessor, TranslationFlags flags, out ShaderConfig config) + private static FunctionCode[] DecodeShader(ulong address, IGpuAccessor gpuAccessor, TranslationFlags flags, out ShaderConfig config) { - Block[] cfg; + Block[][] cfg; if ((flags & TranslationFlags.Compute) != 0) { @@ -83,112 +126,131 @@ namespace Ryujinx.Graphics.Shader.Translation { gpuAccessor.Log("Invalid branch detected, failed to build CFG."); - return Array.Empty(); + return Array.Empty(); } - EmitterContext context = new EmitterContext(config); + Dictionary funcIds = new Dictionary(); + + for (int funcIndex = 0; funcIndex < cfg.Length; funcIndex++) + { + funcIds.Add(cfg[funcIndex][0].Address, funcIndex); + } + + List funcs = new List(); ulong maxEndAddress = 0; - for (int blkIndex = 0; blkIndex < cfg.Length; blkIndex++) + for (int funcIndex = 0; funcIndex < cfg.Length; funcIndex++) { - Block block = cfg[blkIndex]; + EmitterContext context = new EmitterContext(config, funcIndex != 0, funcIds); - if (maxEndAddress < block.EndAddress) + for (int blkIndex = 0; blkIndex < cfg[funcIndex].Length; blkIndex++) { - maxEndAddress = block.EndAddress; + Block block = cfg[funcIndex][blkIndex]; + + if (maxEndAddress < block.EndAddress) + { + maxEndAddress = block.EndAddress; + } + + context.CurrBlock = block; + + context.MarkLabel(context.GetLabel(block.Address)); + + EmitOps(context, block); } - context.CurrBlock = block; - - context.MarkLabel(context.GetLabel(block.Address)); - - for (int opIndex = 0; opIndex < block.OpCodes.Count; opIndex++) - { - OpCode op = block.OpCodes[opIndex]; - - if ((flags & TranslationFlags.DebugMode) != 0) - { - string instName; - - if (op.Emitter != null) - { - instName = op.Emitter.Method.Name; - } - else - { - instName = "???"; - - gpuAccessor.Log($"Invalid instruction at 0x{op.Address:X6} (0x{op.RawOpCode:X16})."); - } - - string dbgComment = $"0x{op.Address:X6}: 0x{op.RawOpCode:X16} {instName}"; - - context.Add(new CommentNode(dbgComment)); - } - - if (op.NeverExecute) - { - continue; - } - - Operand predSkipLbl = null; - - bool skipPredicateCheck = op is OpCodeBranch opBranch && !opBranch.PushTarget; - - if (op is OpCodeBranchPop opBranchPop) - { - // If the instruction is a SYNC or BRK instruction with only one - // possible target address, then the instruction is basically - // just a simple branch, we can generate code similar to branch - // instructions, with the condition check on the branch itself. - skipPredicateCheck = opBranchPop.Targets.Count < 2; - } - - if (!(op.Predicate.IsPT || skipPredicateCheck)) - { - Operand label; - - if (opIndex == block.OpCodes.Count - 1 && block.Next != null) - { - label = context.GetLabel(block.Next.Address); - } - else - { - label = Label(); - - predSkipLbl = label; - } - - Operand pred = Register(op.Predicate); - - if (op.InvertPredicate) - { - context.BranchIfTrue(label, pred); - } - else - { - context.BranchIfFalse(label, pred); - } - } - - context.CurrOp = op; - - op.Emitter?.Invoke(context); - - if (predSkipLbl != null) - { - context.MarkLabel(predSkipLbl); - } - } + funcs.Add(new FunctionCode(context.GetOperations())); } config.SizeAdd((int)maxEndAddress + (flags.HasFlag(TranslationFlags.Compute) ? 0 : HeaderSize)); - return context.GetOperations(); + return funcs.ToArray(); } - private static Operation[] Combine(Operation[] a, Operation[] b) + internal static void EmitOps(EmitterContext context, Block block) + { + for (int opIndex = 0; opIndex < block.OpCodes.Count; opIndex++) + { + OpCode op = block.OpCodes[opIndex]; + + if ((context.Config.Flags & TranslationFlags.DebugMode) != 0) + { + string instName; + + if (op.Emitter != null) + { + instName = op.Emitter.Method.Name; + } + else + { + instName = "???"; + + context.Config.GpuAccessor.Log($"Invalid instruction at 0x{op.Address:X6} (0x{op.RawOpCode:X16})."); + } + + string dbgComment = $"0x{op.Address:X6}: 0x{op.RawOpCode:X16} {instName}"; + + context.Add(new CommentNode(dbgComment)); + } + + if (op.NeverExecute) + { + continue; + } + + Operand predSkipLbl = null; + + bool skipPredicateCheck = op is OpCodeBranch opBranch && !opBranch.PushTarget; + + if (op is OpCodeBranchPop opBranchPop) + { + // If the instruction is a SYNC or BRK instruction with only one + // possible target address, then the instruction is basically + // just a simple branch, we can generate code similar to branch + // instructions, with the condition check on the branch itself. + skipPredicateCheck = opBranchPop.Targets.Count < 2; + } + + if (!(op.Predicate.IsPT || skipPredicateCheck)) + { + Operand label; + + if (opIndex == block.OpCodes.Count - 1 && block.Next != null) + { + label = context.GetLabel(block.Next.Address); + } + else + { + label = Label(); + + predSkipLbl = label; + } + + Operand pred = Register(op.Predicate); + + if (op.InvertPredicate) + { + context.BranchIfTrue(label, pred); + } + else + { + context.BranchIfFalse(label, pred); + } + } + + context.CurrOp = op; + + op.Emitter?.Invoke(context); + + if (predSkipLbl != null) + { + context.MarkLabel(predSkipLbl); + } + } + } + + private static FunctionCode[] Combine(FunctionCode[] a, FunctionCode[] b) { // Here we combine two shaders. // For shader A: @@ -199,15 +261,17 @@ namespace Ryujinx.Graphics.Shader.Translation // For shader B: // - All user attribute loads on shader B are turned into copies from a // temporary variable, as long that attribute is written by shader A. - List output = new List(a.Length + b.Length); + FunctionCode[] output = new FunctionCode[a.Length + b.Length - 1]; + + List ops = new List(a.Length + b.Length); Operand[] temps = new Operand[AttributeConsts.UserAttributesCount * 4]; Operand lblB = Label(); - for (int index = 0; index < a.Length; index++) + for (int index = 0; index < a[0].Code.Length; index++) { - Operation operation = a[index]; + Operation operation = a[0].Code[index]; if (IsUserAttribute(operation.Dest)) { @@ -227,19 +291,19 @@ namespace Ryujinx.Graphics.Shader.Translation if (operation.Inst == Instruction.Return) { - output.Add(new Operation(Instruction.Branch, lblB)); + ops.Add(new Operation(Instruction.Branch, lblB)); } else { - output.Add(operation); + ops.Add(operation); } } - output.Add(new Operation(Instruction.MarkLabel, lblB)); + ops.Add(new Operation(Instruction.MarkLabel, lblB)); - for (int index = 0; index < b.Length; index++) + for (int index = 0; index < b[0].Code.Length; index++) { - Operation operation = b[index]; + Operation operation = b[0].Code[index]; for (int srcIndex = 0; srcIndex < operation.SourcesCount; srcIndex++) { @@ -256,10 +320,22 @@ namespace Ryujinx.Graphics.Shader.Translation } } - output.Add(operation); + ops.Add(operation); } - return output.ToArray(); + output[0] = new FunctionCode(ops.ToArray()); + + for (int i = 1; i < a.Length; i++) + { + output[i] = a[i]; + } + + for (int i = 1; i < b.Length; i++) + { + output[a.Length + i - 1] = b[i]; + } + + return output; } private static bool IsUserAttribute(Operand operand)