From 8fa8f3a39086ef97e8c84ee5ec794a31c378b3de Mon Sep 17 00:00:00 2001
From: Isaac Marovitz <isaacryu@icloud.com>
Date: Thu, 1 Aug 2024 15:51:06 +0100
Subject: [PATCH] Precise Float Fixes

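In the MSL backend, emit FP add/subtract/multiply operations flagged
ForcePrecise as calls to fma-based helper functions marked
[[clang::optnone]] (Precise.metal), and format float constants with
"G9" instead of "F9".

Fixes artifacts in TOTK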
---
 .../CodeGen/Msl/Declarations.cs                    |  5 +++++
 .../CodeGen/Msl/HelperFunctions/Precise.metal      | 14 ++++++++++++++
 .../CodeGen/Msl/Instructions/InstGen.cs            | 12 ++++++++++++
 .../CodeGen/Msl/NumberFormatter.cs                 |  6 ++++--
 .../Ryujinx.Graphics.Shader.csproj                 |  1 +
 .../StructuredIr/HelperFunctionsMask.cs            |  2 ++
 .../StructuredIr/StructuredProgram.cs              |  3 ++-
 .../StructuredIr/StructuredProgramContext.cs       |  3 ++-
 .../StructuredIr/StructuredProgramInfo.cs          |  7 ++++++-
 .../Translation/FeatureFlags.cs                    |  1 +
 .../Translation/Transforms/ForcePreciseEnable.cs   |  2 ++
 .../Translation/TranslatorContext.cs               |  1 +
 12 files changed, 52 insertions(+), 5 deletions(-)
 create mode 100644 src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/Precise.metal

diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs
index ed423a60b..50cce8200 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs
@@ -122,6 +122,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
                 AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/SwizzleAdd.metal");
             }
 
+            if ((info.HelperFunctionsMask & HelperFunctionsMask.Precise) != 0)
+            {
+                AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/Precise.metal");
+            }
+
             return sets;
         }
 
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/Precise.metal b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/Precise.metal
new file mode 100644
index 000000000..366bea1ac
--- /dev/null
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/Precise.metal
@@ -0,0 +1,22 @@
+// Helpers emitted for operations flagged as precise. Each one is a single fma()
+// inside a function marked [[clang::optnone]] (presumably so the compiler cannot
+// fold, fuse or re-associate the operation with its neighbours).
+
+// Returns l + r, computed as fma(1, l, r).
+template<typename T>
+[[clang::optnone]] T PreciseFAdd(T l, T r) {
+    return fma(T(1), l, r);
+}
+
+// Returns l - r, computed as fma(-1, r, l).
+template<typename T>
+[[clang::optnone]] T PreciseFSub(T l, T r) {
+    return fma(T(-1), r, l);
+}
+
+// Returns l * r, computed as fma(l, r, -0.0).
+// The addend is negative zero because (-0) + (+0) rounds to +0: adding +0 would
+// drop the sign of a negative-zero product, while adding -0 leaves any product unchanged.
+template<typename T>
+[[clang::optnone]] T PreciseFMul(T l, T r) {
+    return fma(l, r, T(-0.0f));
+}
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs
index ac294d960..715688987 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs
@@ -118,6 +118,25 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
                         return op + expr[0];
 
                     case 2:
+                        if (operation.ForcePrecise)
+                        {
+                            // Only add, subtract and multiply have a Precise* helper. Any other
+                            // binary operation falls through to the plain operator form below
+                            // instead of making the switch expression throw.
+                            string func = (inst & Instruction.Mask) switch
+                            {
+                                Instruction.Add => "PreciseFAdd",
+                                Instruction.Subtract => "PreciseFSub",
+                                Instruction.Multiply => "PreciseFMul",
+                                _ => null,
+                            };
+
+                            if (func != null)
+                            {
+                                return $"{func}({expr[0]}, {expr[1]})";
+                            }
+                        }
+
                         return $"{expr[0]} {op} {expr[1]}";
 
                     case 3:
diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/NumberFormatter.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/NumberFormatter.cs
index 8d288da3e..86cdfc0e6 100644
--- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/NumberFormatter.cs
+++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/NumberFormatter.cs
@@ -49,9 +49,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
                 return false;
             }
 
-            formatted = value.ToString("F9", CultureInfo.InvariantCulture);
+            formatted = value.ToString("G9", CultureInfo.InvariantCulture);
 
-            if (!formatted.Contains('.'))
+            if (!(formatted.Contains('.') ||
+                  formatted.Contains('e') ||
+                  formatted.Contains('E')))
             {
                 formatted += ".0f";
             }
diff --git a/src/Ryujinx.Graphics.Shader/Ryujinx.Graphics.Shader.csproj b/src/Ryujinx.Graphics.Shader/Ryujinx.Graphics.Shader.csproj
index ad26cbd56..6ba6d4225 100644
--- a/src/Ryujinx.Graphics.Shader/Ryujinx.Graphics.Shader.csproj
+++ b/src/Ryujinx.Graphics.Shader/Ryujinx.Graphics.Shader.csproj
@@ -20,5 +20,6 @@
     <EmbeddedResource Include="CodeGen\Msl\HelperFunctions\FindMSBS32.metal" />
     <EmbeddedResource Include="CodeGen\Msl\HelperFunctions\FindMSBU32.metal" />
     <EmbeddedResource Include="CodeGen\Msl\HelperFunctions\SwizzleAdd.metal" />
+    <EmbeddedResource Include="CodeGen\Msl\HelperFunctions\Precise.metal" />
   </ItemGroup>
 </Project>
diff --git a/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs b/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs
index 8e7bbd6f1..b70def78c 100644
--- a/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs
+++ b/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs
@@ -14,5 +14,8 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
 
         SwizzleAdd = 1 << 10,
         FSI = 1 << 11,
+
+        // When set, the MSL code generator appends the Precise.metal helper functions.
+        Precise = 1 << 13,
     }
 }
diff --git a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs
index 394099902..a1aef7f97 100644
--- a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs
+++ b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs
@@ -18,9 +18,10 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
             ShaderDefinitions definitions,
             ResourceManager resourceManager,
             TargetLanguage targetLanguage,
+            bool precise,
             bool debugMode)
         {
-            StructuredProgramContext context = new(attributeUsage, definitions, resourceManager, debugMode);
+            StructuredProgramContext context = new(attributeUsage, definitions, resourceManager, precise, debugMode);
 
             for (int funcIndex = 0; funcIndex < functions.Count; funcIndex++)
             {
diff --git a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs
index 045662a1e..c26086c72 100644
--- a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs
+++ b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramContext.cs
@@ -36,9 +36,10 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
             AttributeUsage attributeUsage,
             ShaderDefinitions definitions,
             ResourceManager resourceManager,
+            bool precise,
             bool debugMode)
         {
-            Info = new StructuredProgramInfo();
+            Info = new StructuredProgramInfo(precise);
 
             Definitions = definitions;
             ResourceManager = resourceManager;
diff --git a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramInfo.cs b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramInfo.cs
index ded2f2a89..585497ed3 100644
--- a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramInfo.cs
+++ b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgramInfo.cs
@@ -10,11 +10,16 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
 
         public HelperFunctionsMask HelperFunctionsMask { get; set; }
 
-        public StructuredProgramInfo()
+        public StructuredProgramInfo(bool precise)
         {
             Functions = new List<StructuredFunction>();
 
             IoDefinitions = new HashSet<IoDefinition>();
+
+            if (precise)
+            {
+                HelperFunctionsMask |= HelperFunctionsMask.Precise;
+            }
         }
     }
 }
diff --git a/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs b/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs
index 82a54db83..26c924e89 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs
@@ -26,5 +26,6 @@ namespace Ryujinx.Graphics.Shader.Translation
         SharedMemory = 1 << 11,
         Store = 1 << 12,
         VtgAsCompute = 1 << 13,
+        Precise = 1 << 14,
     }
 }
diff --git a/src/Ryujinx.Graphics.Shader/Translation/Transforms/ForcePreciseEnable.cs b/src/Ryujinx.Graphics.Shader/Translation/Transforms/ForcePreciseEnable.cs
index 6b7e1410f..c774816a3 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/Transforms/ForcePreciseEnable.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/Transforms/ForcePreciseEnable.cs
@@ -27,6 +27,8 @@ namespace Ryujinx.Graphics.Shader.Translation.Transforms
                 addOp.Inst == (Instruction.FP32 | Instruction.Add) &&
                 addOp.GetSource(1).Type == OperandType.Constant)
             {
+                context.UsedFeatures |= FeatureFlags.Precise;
+
                 addOp.ForcePrecise = true;
             }
 
diff --git a/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs b/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs
index 5ca17690e..bec20bc2c 100644
--- a/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs
+++ b/src/Ryujinx.Graphics.Shader/Translation/TranslatorContext.cs
@@ -332,6 +332,7 @@ namespace Ryujinx.Graphics.Shader.Translation
                 definitions,
                 resourceManager,
                 Options.TargetLanguage,
+                usedFeatures.HasFlag(FeatureFlags.Precise),
                 Options.Flags.HasFlag(TranslationFlags.DebugMode));
 
             int geometryVerticesPerPrimitive = Definitions.OutputTopology switch