using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.Translation; using System.Collections.Generic; using System.Linq; using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper; namespace Ryujinx.Graphics.Shader.StructuredIr { static class AstOptimizer { public static void Optimize(StructuredProgramContext context) { AstBlock mainBlock = context.CurrentFunction.MainBlock; // When debug mode is enabled, we disable expression propagation // (this makes comparison with the disassembly easier). if (!context.Config.Options.Flags.HasFlag(TranslationFlags.DebugMode)) { AstBlockVisitor visitor = new AstBlockVisitor(mainBlock); foreach (IAstNode node in visitor.Visit()) { if (node is AstAssignment assignment && assignment.Destination is AstOperand propVar) { bool isWorthPropagating = propVar.Uses.Count == 1 || IsWorthPropagating(assignment.Source); if (propVar.Defs.Count == 1 && isWorthPropagating) { PropagateExpression(propVar, assignment.Source); } if (propVar.Type == OperandType.LocalVariable && propVar.Uses.Count == 0) { visitor.Block.Remove(assignment); context.CurrentFunction.Locals.Remove(propVar); } } } } RemoveEmptyBlocks(mainBlock); } private static bool IsWorthPropagating(IAstNode source) { if (!(source is AstOperation srcOp)) { return false; } if (!InstructionInfo.IsUnary(srcOp.Inst)) { return false; } return srcOp.GetSource(0) is AstOperand || srcOp.Inst == Instruction.Copy; } private static void PropagateExpression(AstOperand propVar, IAstNode source) { IAstNode[] uses = propVar.Uses.ToArray(); foreach (IAstNode useNode in uses) { if (useNode is AstBlock useBlock) { useBlock.Condition = source; } else if (useNode is AstOperation useOperation) { for (int srcIndex = 0; srcIndex < useOperation.SourcesCount; srcIndex++) { if (useOperation.GetSource(srcIndex) == propVar) { useOperation.SetSource(srcIndex, source); } } } else if (useNode is AstAssignment useAssignment) { useAssignment.Source = source; } } } private static void RemoveEmptyBlocks(AstBlock mainBlock) { Queue<AstBlock> pending = new Queue<AstBlock>(); pending.Enqueue(mainBlock); while (pending.TryDequeue(out AstBlock block)) { foreach (IAstNode node in block) { if (node is AstBlock childBlock) { pending.Enqueue(childBlock); } } AstBlock parent = block.Parent; if (parent == null) { continue; } AstBlock nextBlock = Next(block) as AstBlock; bool hasElse = nextBlock != null && nextBlock.Type == AstBlockType.Else; bool isIf = block.Type == AstBlockType.If; if (block.Count == 0) { if (isIf) { if (hasElse) { nextBlock.TurnIntoIf(InverseCond(block.Condition)); } parent.Remove(block); } else if (block.Type == AstBlockType.Else) { parent.Remove(block); } } else if (isIf && parent.Type == AstBlockType.Else && parent.Count == (hasElse ? 2 : 1)) { AstBlock parentOfParent = parent.Parent; parent.Remove(block); parentOfParent.AddAfter(parent, block); if (hasElse) { parent.Remove(nextBlock); parentOfParent.AddAfter(block, nextBlock); } parentOfParent.Remove(parent); block.TurnIntoElseIf(); } } } } }