diff --git a/Ryujinx.Memory/WindowsShared/IntervalTree.cs b/Ryujinx.Memory/WindowsShared/IntervalTree.cs deleted file mode 100644 index fe12e8b8..00000000 --- a/Ryujinx.Memory/WindowsShared/IntervalTree.cs +++ /dev/null @@ -1,453 +0,0 @@ -using Ryujinx.Common.Collections; -using System; -using System.Collections.Generic; - -namespace Ryujinx.Memory.WindowsShared -{ - /// - /// An Augmented Interval Tree based off of the "TreeDictionary"'s Red-Black Tree. Allows fast overlap checking of ranges. - /// - /// Key - /// Value - class IntervalTree : IntrusiveRedBlackTreeImpl> where K : IComparable - { - private const int ArrayGrowthSize = 32; - - #region Public Methods - - /// - /// Gets the values of the interval whose key is . - /// - /// Key of the node value to get - /// Value with the given - /// True if the key is on the dictionary, false otherwise - public bool TryGet(K key, out V value) - { - IntervalTreeNode node = GetNode(key); - - if (node == null) - { - value = default; - return false; - } - - value = node.Value; - return true; - } - - /// - /// Returns the start addresses of the intervals whose start and end keys overlap the given range. - /// - /// Start of the range - /// End of the range - /// Overlaps array to place results in - /// Index to start writing results into the array. Defaults to 0 - /// Number of intervals found - public int Get(K start, K end, ref IntervalTreeNode[] overlaps, int overlapCount = 0) - { - GetNodes(Root, start, end, ref overlaps, ref overlapCount); - - return overlapCount; - } - - /// - /// Adds a new interval into the tree whose start is , end is and value is . - /// - /// Start of the range to add - /// End of the range to insert - /// Value to add - /// is null - public void Add(K start, K end, V value) - { - if (value == null) - { - throw new ArgumentNullException(nameof(value)); - } - - BSTInsert(start, end, value, null, out _); - } - - /// - /// Removes a value from the tree, searching for it with . - /// - /// Key of the node to remove - /// Number of deleted values - public int Remove(K key) - { - return Remove(GetNode(key)); - } - - /// - /// Removes a value from the tree, searching for it with . - /// - /// Node to be removed - /// Number of deleted values - public int Remove(IntervalTreeNode nodeToDelete) - { - if (nodeToDelete == null) - { - return 0; - } - - Delete(nodeToDelete); - - Count--; - - return 1; - } - - /// - /// Adds all the nodes in the dictionary into . - /// - /// A list of all values sorted by Key Order - public List AsList() - { - List list = new List(); - - AddToList(Root, list); - - return list; - } - - #endregion - - #region Private Methods (BST) - - /// - /// Adds all values that are children of or contained within into , in Key Order. - /// - /// The node to search for values within - /// The list to add values to - private void AddToList(IntervalTreeNode node, List list) - { - if (node == null) - { - return; - } - - AddToList(node.Left, list); - - list.Add(node.Value); - - AddToList(node.Right, list); - } - - /// - /// Retrieve the node reference whose key is , or null if no such node exists. - /// - /// Key of the node to get - /// is null - /// Node reference in the tree - private IntervalTreeNode GetNode(K key) - { - if (key == null) - { - throw new ArgumentNullException(nameof(key)); - } - - IntervalTreeNode node = Root; - while (node != null) - { - int cmp = key.CompareTo(node.Start); - if (cmp < 0) - { - node = node.Left; - } - else if (cmp > 0) - { - node = node.Right; - } - else - { - return node; - } - } - return null; - } - - /// - /// Retrieve all nodes that overlap the given start and end keys. - /// - /// Start of the range - /// End of the range - /// Overlaps array to place results in - /// Overlaps count to update - private void GetNodes(IntervalTreeNode node, K start, K end, ref IntervalTreeNode[] overlaps, ref int overlapCount) - { - if (node == null || start.CompareTo(node.Max) >= 0) - { - return; - } - - GetNodes(node.Left, start, end, ref overlaps, ref overlapCount); - - bool endsOnRight = end.CompareTo(node.Start) > 0; - if (endsOnRight) - { - if (start.CompareTo(node.End) < 0) - { - if (overlaps.Length >= overlapCount) - { - Array.Resize(ref overlaps, overlapCount + ArrayGrowthSize); - } - - overlaps[overlapCount++] = node; - } - - GetNodes(node.Right, start, end, ref overlaps, ref overlapCount); - } - } - - /// - /// Propagate an increase in max value starting at the given node, heading up the tree. - /// This should only be called if the max increases - not for rebalancing or removals. - /// - /// The node to start propagating from - private void PropagateIncrease(IntervalTreeNode node) - { - K max = node.Max; - IntervalTreeNode ptr = node; - - while ((ptr = ptr.Parent) != null) - { - if (max.CompareTo(ptr.Max) > 0) - { - ptr.Max = max; - } - else - { - break; - } - } - } - - /// - /// Propagate recalculating max value starting at the given node, heading up the tree. - /// This fully recalculates the max value from all children when there is potential for it to decrease. - /// - /// The node to start propagating from - private void PropagateFull(IntervalTreeNode node) - { - IntervalTreeNode ptr = node; - - do - { - K max = ptr.End; - - if (ptr.Left != null && ptr.Left.Max.CompareTo(max) > 0) - { - max = ptr.Left.Max; - } - - if (ptr.Right != null && ptr.Right.Max.CompareTo(max) > 0) - { - max = ptr.Right.Max; - } - - ptr.Max = max; - } while ((ptr = ptr.Parent) != null); - } - - /// - /// Insertion Mechanism for the interval tree. Similar to a BST insert, with the start of the range as the key. - /// Iterates the tree starting from the root and inserts a new node where all children in the left subtree are less than , and all children in the right subtree are greater than . - /// Each node can contain multiple values, and has an end address which is the maximum of all those values. - /// Post insertion, the "max" value of the node and all parents are updated. - /// - /// Start of the range to insert - /// End of the range to insert - /// Value to insert - /// Optional factory used to create a new value if is already on the tree - /// Node that was inserted or modified - /// True if was not yet on the tree, false otherwise - private bool BSTInsert(K start, K end, V value, Func updateFactoryCallback, out IntervalTreeNode outNode) - { - IntervalTreeNode parent = null; - IntervalTreeNode node = Root; - - while (node != null) - { - parent = node; - int cmp = start.CompareTo(node.Start); - if (cmp < 0) - { - node = node.Left; - } - else if (cmp > 0) - { - node = node.Right; - } - else - { - outNode = node; - - if (updateFactoryCallback != null) - { - // Replace - node.Value = updateFactoryCallback(start, node.Value); - - int endCmp = end.CompareTo(node.End); - - if (endCmp > 0) - { - node.End = end; - if (end.CompareTo(node.Max) > 0) - { - node.Max = end; - PropagateIncrease(node); - RestoreBalanceAfterInsertion(node); - } - } - else if (endCmp < 0) - { - node.End = end; - PropagateFull(node); - } - } - - return false; - } - } - IntervalTreeNode newNode = new IntervalTreeNode(start, end, value, parent); - if (newNode.Parent == null) - { - Root = newNode; - } - else if (start.CompareTo(parent.Start) < 0) - { - parent.Left = newNode; - } - else - { - parent.Right = newNode; - } - - PropagateIncrease(newNode); - Count++; - RestoreBalanceAfterInsertion(newNode); - outNode = newNode; - return true; - } - - /// - /// Removes the value from the dictionary after searching for it with . - /// - /// Tree node to be removed - private void Delete(IntervalTreeNode nodeToDelete) - { - IntervalTreeNode replacementNode; - - if (LeftOf(nodeToDelete) == null || RightOf(nodeToDelete) == null) - { - replacementNode = nodeToDelete; - } - else - { - replacementNode = nodeToDelete.Predecessor; - } - - IntervalTreeNode tmp = LeftOf(replacementNode) ?? RightOf(replacementNode); - - if (tmp != null) - { - tmp.Parent = ParentOf(replacementNode); - } - - if (ParentOf(replacementNode) == null) - { - Root = tmp; - } - else if (replacementNode == LeftOf(ParentOf(replacementNode))) - { - ParentOf(replacementNode).Left = tmp; - } - else - { - ParentOf(replacementNode).Right = tmp; - } - - if (replacementNode != nodeToDelete) - { - nodeToDelete.Start = replacementNode.Start; - nodeToDelete.Value = replacementNode.Value; - nodeToDelete.End = replacementNode.End; - nodeToDelete.Max = replacementNode.Max; - } - - PropagateFull(replacementNode); - - if (tmp != null && ColorOf(replacementNode) == Black) - { - RestoreBalanceAfterRemoval(tmp); - } - } - - #endregion - - #region Private Methods (RBL) - - protected override void RotateLeft(IntervalTreeNode node) - { - if (node != null) - { - base.RotateLeft(node); - - PropagateFull(node); - } - } - - protected override void RotateRight(IntervalTreeNode node) - { - if (node != null) - { - base.RotateRight(node); - - PropagateFull(node); - } - } - - #endregion - - public bool ContainsKey(K key) - { - return GetNode(key) != null; - } - } - - /// - /// Represents a node in the IntervalTree which contains start and end keys of type K, and a value of generic type V. - /// - /// Key type of the node - /// Value type of the node - class IntervalTreeNode : IntrusiveRedBlackTreeNode> - { - /// - /// The start of the range. - /// - public K Start; - - /// - /// The end of the range. - /// - public K End; - - /// - /// The maximum end value of this node and all its children. - /// - public K Max; - - /// - /// Value stored on this node. - /// - public V Value; - - public IntervalTreeNode(K start, K end, V value, IntervalTreeNode parent) - { - Start = start; - End = end; - Max = end; - Value = value; - Parent = parent; - } - } -} diff --git a/Ryujinx.Memory/WindowsShared/MappingTree.cs b/Ryujinx.Memory/WindowsShared/MappingTree.cs new file mode 100644 index 00000000..8f880f0c --- /dev/null +++ b/Ryujinx.Memory/WindowsShared/MappingTree.cs @@ -0,0 +1,69 @@ +using Ryujinx.Common.Collections; +using System; + +namespace Ryujinx.Memory.WindowsShared +{ + /// + /// A intrusive Red-Black Tree that also supports getting nodes overlapping a given range. + /// + /// Type of the value stored on the node + class MappingTree : IntrusiveRedBlackTree> + { + public int GetNodes(ulong start, ulong end, ref RangeNode[] overlaps, int overlapCount = 0) + { + RangeNode node = GetNode(new RangeNode(start, start + 1UL, default)); + + for (; node != null; node = node.Successor) + { + if (overlaps.Length <= overlapCount) + { + Array.Resize(ref overlaps, overlapCount + 1); + } + + overlaps[overlapCount++] = node; + + if (node.End >= end) + { + break; + } + } + + return overlapCount; + } + } + + class RangeNode : IntrusiveRedBlackTreeNode>, IComparable> + { + public ulong Start { get; } + public ulong End { get; private set; } + public T Value { get; } + + public RangeNode(ulong start, ulong end, T value) + { + Start = start; + End = end; + Value = value; + } + + public void Extend(ulong sizeDelta) + { + End += sizeDelta; + } + + public int CompareTo(RangeNode other) + { + if (Start < other.Start) + { + return -1; + } + else if (Start <= other.End - 1UL) + { + return 0; + } + else + { + return 1; + } + } + } +} \ No newline at end of file diff --git a/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs b/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs index 0937d462..6db8d7df 100644 --- a/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs +++ b/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs @@ -13,10 +13,10 @@ namespace Ryujinx.Memory.WindowsShared [SupportedOSPlatform("windows")] class PlaceholderManager { - private const ulong MinimumPageSize = 0x1000; + private const int InitialOverlapsSize = 10; - private readonly IntervalTree _mappings; - private readonly IntervalTree _protections; + private readonly MappingTree _mappings; + private readonly MappingTree _protections; private readonly IntPtr _partialUnmapStatePtr; private readonly Thread _partialUnmapTrimThread; @@ -25,8 +25,8 @@ namespace Ryujinx.Memory.WindowsShared /// public PlaceholderManager() { - _mappings = new IntervalTree(); - _protections = new IntervalTree(); + _mappings = new MappingTree(); + _protections = new MappingTree(); _partialUnmapStatePtr = PartialUnmapState.GlobalState; @@ -67,7 +67,7 @@ namespace Ryujinx.Memory.WindowsShared { lock (_mappings) { - _mappings.Add(address, address + size, ulong.MaxValue); + _mappings.Add(new RangeNode(address, address + size, ulong.MaxValue)); } } @@ -81,12 +81,12 @@ namespace Ryujinx.Memory.WindowsShared { ulong endAddress = address + size; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int count; lock (_mappings) { - count = _mappings.Get(address, endAddress, ref overlaps); + count = _mappings.GetNodes(address, endAddress, ref overlaps); for (int index = 0; index < count; index++) { @@ -178,11 +178,11 @@ namespace Ryujinx.Memory.WindowsShared { ulong endAddress = address + size; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; lock (_mappings) { - int count = _mappings.Get(address, endAddress, ref overlaps); + int count = _mappings.GetNodes(address, endAddress, ref overlaps); Debug.Assert(count == 1); Debug.Assert(!IsMapped(overlaps[0].Value)); @@ -206,8 +206,8 @@ namespace Ryujinx.Memory.WindowsShared (IntPtr)size, AllocationType.Release | AllocationType.PreservePlaceholder)); - _mappings.Add(overlapStart, address, overlapValue); - _mappings.Add(endAddress, overlapEnd, AddBackingOffset(overlapValue, endAddress - overlapStart)); + _mappings.Add(new RangeNode(overlapStart, address, overlapValue)); + _mappings.Add(new RangeNode(endAddress, overlapEnd, AddBackingOffset(overlapValue, endAddress - overlapStart))); } else if (overlapStartsBefore) { @@ -218,7 +218,7 @@ namespace Ryujinx.Memory.WindowsShared (IntPtr)overlappedSize, AllocationType.Release | AllocationType.PreservePlaceholder)); - _mappings.Add(overlapStart, address, overlapValue); + _mappings.Add(new RangeNode(overlapStart, address, overlapValue)); } else if (overlapEndsAfter) { @@ -229,10 +229,10 @@ namespace Ryujinx.Memory.WindowsShared (IntPtr)overlappedSize, AllocationType.Release | AllocationType.PreservePlaceholder)); - _mappings.Add(endAddress, overlapEnd, AddBackingOffset(overlapValue, overlappedSize)); + _mappings.Add(new RangeNode(endAddress, overlapEnd, AddBackingOffset(overlapValue, overlappedSize))); } - _mappings.Add(address, endAddress, backingOffset); + _mappings.Add(new RangeNode(address, endAddress, backingOffset)); } } @@ -280,12 +280,12 @@ namespace Ryujinx.Memory.WindowsShared ulong unmapSize = (ulong)size; ulong endAddress = startAddress + unmapSize; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int count; lock (_mappings) { - count = _mappings.Get(startAddress, endAddress, ref overlaps); + count = _mappings.GetNodes(startAddress, endAddress, ref overlaps); } for (int index = 0; index < count; index++) @@ -302,7 +302,7 @@ namespace Ryujinx.Memory.WindowsShared lock (_mappings) { _mappings.Remove(overlap); - _mappings.Add(overlapStart, overlapEnd, ulong.MaxValue); + _mappings.Add(new RangeNode(overlapStart, overlapEnd, ulong.MaxValue)); } bool overlapStartsBefore = overlapStart < startAddress; @@ -374,44 +374,53 @@ namespace Ryujinx.Memory.WindowsShared ulong endAddress = address + size; ulong blockAddress = (ulong)owner.Pointer; ulong blockEnd = blockAddress + owner.Size; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int unmappedCount = 0; lock (_mappings) { - int count = _mappings.Get( - Math.Max(address - MinimumPageSize, blockAddress), - Math.Min(endAddress + MinimumPageSize, blockEnd), ref overlaps); + int count = _mappings.GetNodes(address, endAddress, ref overlaps); - if (count < 2) + if (count == 0) { - // Nothing to coalesce if we only have 1 or no overlaps. + // Nothing to coalesce if we no overlaps. return; } + RangeNode predecessor = overlaps[0].Predecessor; + RangeNode successor = overlaps[count - 1].Successor; + for (int index = 0; index < count; index++) { var overlap = overlaps[index]; if (!IsMapped(overlap.Value)) { - if (address > overlap.Start) - { - address = overlap.Start; - } - - if (endAddress < overlap.End) - { - endAddress = overlap.End; - } + address = Math.Min(address, overlap.Start); + endAddress = Math.Max(endAddress, overlap.End); _mappings.Remove(overlap); - unmappedCount++; } } - _mappings.Add(address, endAddress, ulong.MaxValue); + if (predecessor != null && !IsMapped(predecessor.Value) && predecessor.Start >= blockAddress) + { + address = Math.Min(address, predecessor.Start); + + _mappings.Remove(predecessor); + unmappedCount++; + } + + if (successor != null && !IsMapped(successor.Value) && successor.End <= blockEnd) + { + endAddress = Math.Max(endAddress, successor.End); + + _mappings.Remove(successor); + unmappedCount++; + } + + _mappings.Add(new RangeNode(address, endAddress, ulong.MaxValue)); } if (unmappedCount > 1) @@ -462,12 +471,12 @@ namespace Ryujinx.Memory.WindowsShared ulong reprotectSize = (ulong)size; ulong endAddress = reprotectAddress + reprotectSize; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int count; lock (_mappings) { - count = _mappings.Get(reprotectAddress, endAddress, ref overlaps); + count = _mappings.GetNodes(reprotectAddress, endAddress, ref overlaps); } bool success = true; @@ -567,12 +576,12 @@ namespace Ryujinx.Memory.WindowsShared private void AddProtection(ulong address, ulong size, MemoryPermission permission) { ulong endAddress = address + size; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int count; lock (_protections) { - count = _protections.Get(address, endAddress, ref overlaps); + count = _protections.GetNodes(address, endAddress, ref overlaps); if (count == 1 && overlaps[0].Start <= address && @@ -610,17 +619,17 @@ namespace Ryujinx.Memory.WindowsShared { if (startAddress > protAddress) { - _protections.Add(protAddress, startAddress, protPermission); + _protections.Add(new RangeNode(protAddress, startAddress, protPermission)); } if (endAddress < protEndAddress) { - _protections.Add(endAddress, protEndAddress, protPermission); + _protections.Add(new RangeNode(endAddress, protEndAddress, protPermission)); } } } - _protections.Add(startAddress, endAddress, permission); + _protections.Add(new RangeNode(startAddress, endAddress, permission)); } } @@ -632,12 +641,12 @@ namespace Ryujinx.Memory.WindowsShared private void RemoveProtection(ulong address, ulong size) { ulong endAddress = address + size; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int count; lock (_protections) { - count = _protections.Get(address, endAddress, ref overlaps); + count = _protections.GetNodes(address, endAddress, ref overlaps); for (int index = 0; index < count; index++) { @@ -651,12 +660,12 @@ namespace Ryujinx.Memory.WindowsShared if (address > protAddress) { - _protections.Add(protAddress, address, protPermission); + _protections.Add(new RangeNode(protAddress, address, protPermission)); } if (endAddress < protEndAddress) { - _protections.Add(endAddress, protEndAddress, protPermission); + _protections.Add(new RangeNode(endAddress, protEndAddress, protPermission)); } } } @@ -670,12 +679,12 @@ namespace Ryujinx.Memory.WindowsShared private void RestoreRangeProtection(ulong address, ulong size) { ulong endAddress = address + size; - var overlaps = Array.Empty>(); + var overlaps = new RangeNode[InitialOverlapsSize]; int count; lock (_protections) { - count = _protections.Get(address, endAddress, ref overlaps); + count = _protections.GetNodes(address, endAddress, ref overlaps); } ulong startAddress = address;