0
0
Fork 0
mirror of https://github.com/GreemDev/Ryujinx.git synced 2025-01-05 16:11:59 +00:00

Separate guest/host tracking + unaligned protection (#6486)

* WIP: Separate guest/host tracking + unaligned protection

Allow memory manager to define support for single byte guest tracking

* Formatting

* Improve docs

* Properly handle cases where the address space bits are too low

* Address feedback
This commit is contained in:
riperiperi 2024-03-14 22:38:27 +00:00 committed by GitHub
parent ce607db944
commit fdd3263e31
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 774 additions and 763 deletions

View file

@ -91,54 +91,54 @@ namespace ARMeilleure.Instructions
#region "Read" #region "Read"
public static byte ReadByte(ulong address) public static byte ReadByte(ulong address)
{ {
return GetMemoryManager().ReadTracked<byte>(address); return GetMemoryManager().ReadGuest<byte>(address);
} }
public static ushort ReadUInt16(ulong address) public static ushort ReadUInt16(ulong address)
{ {
return GetMemoryManager().ReadTracked<ushort>(address); return GetMemoryManager().ReadGuest<ushort>(address);
} }
public static uint ReadUInt32(ulong address) public static uint ReadUInt32(ulong address)
{ {
return GetMemoryManager().ReadTracked<uint>(address); return GetMemoryManager().ReadGuest<uint>(address);
} }
public static ulong ReadUInt64(ulong address) public static ulong ReadUInt64(ulong address)
{ {
return GetMemoryManager().ReadTracked<ulong>(address); return GetMemoryManager().ReadGuest<ulong>(address);
} }
public static V128 ReadVector128(ulong address) public static V128 ReadVector128(ulong address)
{ {
return GetMemoryManager().ReadTracked<V128>(address); return GetMemoryManager().ReadGuest<V128>(address);
} }
#endregion #endregion
#region "Write" #region "Write"
public static void WriteByte(ulong address, byte value) public static void WriteByte(ulong address, byte value)
{ {
GetMemoryManager().Write(address, value); GetMemoryManager().WriteGuest(address, value);
} }
public static void WriteUInt16(ulong address, ushort value) public static void WriteUInt16(ulong address, ushort value)
{ {
GetMemoryManager().Write(address, value); GetMemoryManager().WriteGuest(address, value);
} }
public static void WriteUInt32(ulong address, uint value) public static void WriteUInt32(ulong address, uint value)
{ {
GetMemoryManager().Write(address, value); GetMemoryManager().WriteGuest(address, value);
} }
public static void WriteUInt64(ulong address, ulong value) public static void WriteUInt64(ulong address, ulong value)
{ {
GetMemoryManager().Write(address, value); GetMemoryManager().WriteGuest(address, value);
} }
public static void WriteVector128(ulong address, V128 value) public static void WriteVector128(ulong address, V128 value)
{ {
GetMemoryManager().Write(address, value); GetMemoryManager().WriteGuest(address, value);
} }
#endregion #endregion

View file

@ -28,6 +28,17 @@ namespace ARMeilleure.Memory
/// <returns>The data</returns> /// <returns>The data</returns>
T ReadTracked<T>(ulong va) where T : unmanaged; T ReadTracked<T>(ulong va) where T : unmanaged;
/// <summary>
/// Reads data from CPU mapped memory, from guest code. (with read tracking)
/// </summary>
/// <typeparam name="T">Type of the data being read</typeparam>
/// <param name="va">Virtual address of the data in memory</param>
/// <returns>The data</returns>
T ReadGuest<T>(ulong va) where T : unmanaged
{
return ReadTracked<T>(va);
}
/// <summary> /// <summary>
/// Writes data to CPU mapped memory. /// Writes data to CPU mapped memory.
/// </summary> /// </summary>
@ -36,6 +47,17 @@ namespace ARMeilleure.Memory
/// <param name="value">Data to be written</param> /// <param name="value">Data to be written</param>
void Write<T>(ulong va, T value) where T : unmanaged; void Write<T>(ulong va, T value) where T : unmanaged;
/// <summary>
/// Writes data to CPU mapped memory, from guest code.
/// </summary>
/// <typeparam name="T">Type of the data being written</typeparam>
/// <param name="va">Virtual address to write the data into</param>
/// <param name="value">Data to be written</param>
void WriteGuest<T>(ulong va, T value) where T : unmanaged
{
Write(va, value);
}
/// <summary> /// <summary>
/// Gets a read-only span of data from CPU mapped memory. /// Gets a read-only span of data from CPU mapped memory.
/// </summary> /// </summary>

View file

@ -8,7 +8,6 @@ using System.Linq;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Runtime.Versioning; using System.Runtime.Versioning;
using System.Threading;
namespace Ryujinx.Cpu.AppleHv namespace Ryujinx.Cpu.AppleHv
{ {
@ -18,21 +17,6 @@ namespace Ryujinx.Cpu.AppleHv
[SupportedOSPlatform("macos")] [SupportedOSPlatform("macos")]
public class HvMemoryManager : VirtualMemoryManagerRefCountedBase<ulong, ulong>, IMemoryManager, IVirtualMemoryManagerTracked, IWritableBlock public class HvMemoryManager : VirtualMemoryManagerRefCountedBase<ulong, ulong>, IMemoryManager, IVirtualMemoryManagerTracked, IWritableBlock
{ {
public const int PageToPteShift = 5; // 32 pages (2 bits each) in one ulong page table entry.
public const ulong BlockMappedMask = 0x5555555555555555; // First bit of each table entry set.
private enum HostMappedPtBits : ulong
{
Unmapped = 0,
Mapped,
WriteTracked,
ReadWriteTracked,
MappedReplicated = 0x5555555555555555,
WriteTrackedReplicated = 0xaaaaaaaaaaaaaaaa,
ReadWriteTrackedReplicated = ulong.MaxValue,
}
private readonly InvalidAccessHandler _invalidAccessHandler; private readonly InvalidAccessHandler _invalidAccessHandler;
private readonly HvAddressSpace _addressSpace; private readonly HvAddressSpace _addressSpace;
@ -42,7 +26,7 @@ namespace Ryujinx.Cpu.AppleHv
private readonly MemoryBlock _backingMemory; private readonly MemoryBlock _backingMemory;
private readonly PageTable<ulong> _pageTable; private readonly PageTable<ulong> _pageTable;
private readonly ulong[] _pageBitmap; private readonly ManagedPageFlags _pages;
public bool Supports4KBPages => true; public bool Supports4KBPages => true;
@ -84,7 +68,7 @@ namespace Ryujinx.Cpu.AppleHv
AddressSpaceBits = asBits; AddressSpaceBits = asBits;
_pageBitmap = new ulong[1 << (AddressSpaceBits - (PageBits + PageToPteShift))]; _pages = new ManagedPageFlags(AddressSpaceBits);
Tracking = new MemoryTracking(this, PageSize, invalidAccessHandler); Tracking = new MemoryTracking(this, PageSize, invalidAccessHandler);
} }
@ -95,7 +79,7 @@ namespace Ryujinx.Cpu.AppleHv
PtMap(va, pa, size); PtMap(va, pa, size);
_addressSpace.MapUser(va, pa, size, MemoryPermission.ReadWriteExecute); _addressSpace.MapUser(va, pa, size, MemoryPermission.ReadWriteExecute);
AddMapping(va, size); _pages.AddMapping(va, size);
Tracking.Map(va, size); Tracking.Map(va, size);
} }
@ -126,7 +110,7 @@ namespace Ryujinx.Cpu.AppleHv
UnmapEvent?.Invoke(va, size); UnmapEvent?.Invoke(va, size);
Tracking.Unmap(va, size); Tracking.Unmap(va, size);
RemoveMapping(va, size); _pages.RemoveMapping(va, size);
_addressSpace.UnmapUser(va, size); _addressSpace.UnmapUser(va, size);
PtUnmap(va, size); PtUnmap(va, size);
} }
@ -360,22 +344,7 @@ namespace Ryujinx.Cpu.AppleHv
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool IsMapped(ulong va) public bool IsMapped(ulong va)
{ {
return ValidateAddress(va) && IsMappedImpl(va); return ValidateAddress(va) && _pages.IsMapped(va);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private bool IsMappedImpl(ulong va)
{
ulong page = va >> PageBits;
int bit = (int)((page & 31) << 1);
int pageIndex = (int)(page >> PageToPteShift);
ref ulong pageRef = ref _pageBitmap[pageIndex];
ulong pte = Volatile.Read(ref pageRef);
return ((pte >> bit) & 3) != 0;
} }
/// <inheritdoc/> /// <inheritdoc/>
@ -383,58 +352,7 @@ namespace Ryujinx.Cpu.AppleHv
{ {
AssertValidAddressAndSize(va, size); AssertValidAddressAndSize(va, size);
return IsRangeMappedImpl(va, size); return _pages.IsRangeMapped(va, size);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void GetPageBlockRange(ulong pageStart, ulong pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex)
{
startMask = ulong.MaxValue << ((int)(pageStart & 31) << 1);
endMask = ulong.MaxValue >> (64 - ((int)(pageEnd & 31) << 1));
pageIndex = (int)(pageStart >> PageToPteShift);
pageEndIndex = (int)((pageEnd - 1) >> PageToPteShift);
}
private bool IsRangeMappedImpl(ulong va, ulong size)
{
int pages = GetPagesCount(va, size, out _);
if (pages == 1)
{
return IsMappedImpl(va);
}
ulong pageStart = va >> PageBits;
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
// Check if either bit in each 2 bit page entry is set.
// OR the block with itself shifted down by 1, and check the first bit of each entry.
ulong mask = BlockMappedMask & startMask;
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask &= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte = Volatile.Read(ref pageRef);
pte |= pte >> 1;
if ((pte & mask) != mask)
{
return false;
}
mask = BlockMappedMask;
}
return true;
} }
private static void ThrowMemoryNotContiguous() => throw new MemoryNotContiguousException(); private static void ThrowMemoryNotContiguous() => throw new MemoryNotContiguousException();
@ -560,76 +478,7 @@ namespace Ryujinx.Cpu.AppleHv
return; return;
} }
// Software table, used for managed memory tracking. _pages.SignalMemoryTracking(Tracking, va, size, write, exemptId);
int pages = GetPagesCount(va, size, out _);
ulong pageStart = va >> PageBits;
if (pages == 1)
{
ulong tag = (ulong)(write ? HostMappedPtBits.WriteTracked : HostMappedPtBits.ReadWriteTracked);
int bit = (int)((pageStart & 31) << 1);
int pageIndex = (int)(pageStart >> PageToPteShift);
ref ulong pageRef = ref _pageBitmap[pageIndex];
ulong pte = Volatile.Read(ref pageRef);
ulong state = ((pte >> bit) & 3);
if (state >= tag)
{
Tracking.VirtualMemoryEvent(va, size, write, precise: false, exemptId);
return;
}
else if (state == 0)
{
ThrowInvalidMemoryRegionException($"Not mapped: va=0x{va:X16}, size=0x{size:X16}");
}
}
else
{
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
ulong mask = startMask;
ulong anyTrackingTag = (ulong)HostMappedPtBits.WriteTrackedReplicated;
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask &= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte = Volatile.Read(ref pageRef);
ulong mappedMask = mask & BlockMappedMask;
ulong mappedPte = pte | (pte >> 1);
if ((mappedPte & mappedMask) != mappedMask)
{
ThrowInvalidMemoryRegionException($"Not mapped: va=0x{va:X16}, size=0x{size:X16}");
}
pte &= mask;
if ((pte & anyTrackingTag) != 0) // Search for any tracking.
{
// Writes trigger any tracking.
// Only trigger tracking from reads if both bits are set on any page.
if (write || (pte & (pte >> 1) & BlockMappedMask) != 0)
{
Tracking.VirtualMemoryEvent(va, size, write, precise: false, exemptId);
break;
}
}
mask = ulong.MaxValue;
}
}
} }
/// <summary> /// <summary>
@ -656,103 +505,28 @@ namespace Ryujinx.Cpu.AppleHv
} }
/// <inheritdoc/> /// <inheritdoc/>
public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection) public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection, bool guest)
{ {
// Protection is inverted on software pages, since the default value is 0. if (guest)
protection = (~protection) & MemoryPermission.ReadAndWrite;
int pages = GetPagesCount(va, size, out va);
ulong pageStart = va >> PageBits;
if (pages == 1)
{ {
ulong protTag = protection switch _addressSpace.ReprotectUser(va, size, protection);
{
MemoryPermission.None => (ulong)HostMappedPtBits.Mapped,
MemoryPermission.Write => (ulong)HostMappedPtBits.WriteTracked,
_ => (ulong)HostMappedPtBits.ReadWriteTracked,
};
int bit = (int)((pageStart & 31) << 1);
ulong tagMask = 3UL << bit;
ulong invTagMask = ~tagMask;
ulong tag = protTag << bit;
int pageIndex = (int)(pageStart >> PageToPteShift);
ref ulong pageRef = ref _pageBitmap[pageIndex];
ulong pte;
do
{
pte = Volatile.Read(ref pageRef);
}
while ((pte & tagMask) != 0 && Interlocked.CompareExchange(ref pageRef, (pte & invTagMask) | tag, pte) != pte);
} }
else else
{ {
ulong pageEnd = pageStart + (ulong)pages; _pages.TrackingReprotect(va, size, protection);
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
ulong mask = startMask;
ulong protTag = protection switch
{
MemoryPermission.None => (ulong)HostMappedPtBits.MappedReplicated,
MemoryPermission.Write => (ulong)HostMappedPtBits.WriteTrackedReplicated,
_ => (ulong)HostMappedPtBits.ReadWriteTrackedReplicated,
};
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask &= endMask;
} }
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte;
ulong mappedMask;
// Change the protection of all 2 bit entries that are mapped.
do
{
pte = Volatile.Read(ref pageRef);
mappedMask = pte | (pte >> 1);
mappedMask |= (mappedMask & BlockMappedMask) << 1;
mappedMask &= mask; // Only update mapped pages within the given range.
}
while (Interlocked.CompareExchange(ref pageRef, (pte & (~mappedMask)) | (protTag & mappedMask), pte) != pte);
mask = ulong.MaxValue;
}
}
protection = protection switch
{
MemoryPermission.None => MemoryPermission.ReadAndWrite,
MemoryPermission.Write => MemoryPermission.Read,
_ => MemoryPermission.None,
};
_addressSpace.ReprotectUser(va, size, protection);
} }
/// <inheritdoc/> /// <inheritdoc/>
public RegionHandle BeginTracking(ulong address, ulong size, int id) public RegionHandle BeginTracking(ulong address, ulong size, int id, RegionFlags flags = RegionFlags.None)
{ {
return Tracking.BeginTracking(address, size, id); return Tracking.BeginTracking(address, size, id, flags);
} }
/// <inheritdoc/> /// <inheritdoc/>
public MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id) public MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id, RegionFlags flags = RegionFlags.None)
{ {
return Tracking.BeginGranularTracking(address, size, handles, granularity, id); return Tracking.BeginGranularTracking(address, size, handles, granularity, id, flags);
} }
/// <inheritdoc/> /// <inheritdoc/>
@ -761,86 +535,6 @@ namespace Ryujinx.Cpu.AppleHv
return Tracking.BeginSmartGranularTracking(address, size, granularity, id); return Tracking.BeginSmartGranularTracking(address, size, granularity, id);
} }
/// <summary>
/// Adds the given address mapping to the page table.
/// </summary>
/// <param name="va">Virtual memory address</param>
/// <param name="size">Size to be mapped</param>
private void AddMapping(ulong va, ulong size)
{
int pages = GetPagesCount(va, size, out _);
ulong pageStart = va >> PageBits;
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
ulong mask = startMask;
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask &= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte;
ulong mappedMask;
// Map all 2-bit entries that are unmapped.
do
{
pte = Volatile.Read(ref pageRef);
mappedMask = pte | (pte >> 1);
mappedMask |= (mappedMask & BlockMappedMask) << 1;
mappedMask |= ~mask; // Treat everything outside the range as mapped, thus unchanged.
}
while (Interlocked.CompareExchange(ref pageRef, (pte & mappedMask) | (BlockMappedMask & (~mappedMask)), pte) != pte);
mask = ulong.MaxValue;
}
}
/// <summary>
/// Removes the given address mapping from the page table.
/// </summary>
/// <param name="va">Virtual memory address</param>
/// <param name="size">Size to be unmapped</param>
private void RemoveMapping(ulong va, ulong size)
{
int pages = GetPagesCount(va, size, out _);
ulong pageStart = va >> PageBits;
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
startMask = ~startMask;
endMask = ~endMask;
ulong mask = startMask;
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask |= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte;
do
{
pte = Volatile.Read(ref pageRef);
}
while (Interlocked.CompareExchange(ref pageRef, pte & mask, pte) != pte);
mask = 0;
}
}
private ulong GetPhysicalAddressChecked(ulong va) private ulong GetPhysicalAddressChecked(ulong va)
{ {
if (!IsMapped(va)) if (!IsMapped(va))

View file

@ -28,8 +28,9 @@ namespace Ryujinx.Cpu
/// <param name="address">CPU virtual address of the region</param> /// <param name="address">CPU virtual address of the region</param>
/// <param name="size">Size of the region</param> /// <param name="size">Size of the region</param>
/// <param name="id">Handle ID</param> /// <param name="id">Handle ID</param>
/// <param name="flags">Region flags</param>
/// <returns>The memory tracking handle</returns> /// <returns>The memory tracking handle</returns>
RegionHandle BeginTracking(ulong address, ulong size, int id); RegionHandle BeginTracking(ulong address, ulong size, int id, RegionFlags flags = RegionFlags.None);
/// <summary> /// <summary>
/// Obtains a memory tracking handle for the given virtual region, with a specified granularity. This should be disposed when finished with. /// Obtains a memory tracking handle for the given virtual region, with a specified granularity. This should be disposed when finished with.
@ -39,8 +40,9 @@ namespace Ryujinx.Cpu
/// <param name="handles">Handles to inherit state from or reuse. When none are present, provide null</param> /// <param name="handles">Handles to inherit state from or reuse. When none are present, provide null</param>
/// <param name="granularity">Desired granularity of write tracking</param> /// <param name="granularity">Desired granularity of write tracking</param>
/// <param name="id">Handle ID</param> /// <param name="id">Handle ID</param>
/// <param name="flags">Region flags</param>
/// <returns>The memory tracking handle</returns> /// <returns>The memory tracking handle</returns>
MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id); MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id, RegionFlags flags = RegionFlags.None);
/// <summary> /// <summary>
/// Obtains a smart memory tracking handle for the given virtual region, with a specified granularity. This should be disposed when finished with. /// Obtains a smart memory tracking handle for the given virtual region, with a specified granularity. This should be disposed when finished with.

View file

@ -33,6 +33,8 @@ namespace Ryujinx.Cpu.Jit
private readonly MemoryBlock _pageTable; private readonly MemoryBlock _pageTable;
private readonly ManagedPageFlags _pages;
/// <summary> /// <summary>
/// Page table base pointer. /// Page table base pointer.
/// </summary> /// </summary>
@ -70,6 +72,8 @@ namespace Ryujinx.Cpu.Jit
AddressSpaceSize = asSize; AddressSpaceSize = asSize;
_pageTable = new MemoryBlock((asSize / PageSize) * PteSize); _pageTable = new MemoryBlock((asSize / PageSize) * PteSize);
_pages = new ManagedPageFlags(AddressSpaceBits);
Tracking = new MemoryTracking(this, PageSize); Tracking = new MemoryTracking(this, PageSize);
} }
@ -89,6 +93,7 @@ namespace Ryujinx.Cpu.Jit
remainingSize -= PageSize; remainingSize -= PageSize;
} }
_pages.AddMapping(oVa, size);
Tracking.Map(oVa, size); Tracking.Map(oVa, size);
} }
@ -111,6 +116,7 @@ namespace Ryujinx.Cpu.Jit
UnmapEvent?.Invoke(va, size); UnmapEvent?.Invoke(va, size);
Tracking.Unmap(va, size); Tracking.Unmap(va, size);
_pages.RemoveMapping(va, size);
ulong remainingSize = size; ulong remainingSize = size;
while (remainingSize != 0) while (remainingSize != 0)
@ -148,6 +154,26 @@ namespace Ryujinx.Cpu.Jit
} }
} }
/// <inheritdoc/>
public T ReadGuest<T>(ulong va) where T : unmanaged
{
try
{
SignalMemoryTrackingImpl(va, (ulong)Unsafe.SizeOf<T>(), false, true);
return Read<T>(va);
}
catch (InvalidMemoryRegionException)
{
if (_invalidAccessHandler == null || !_invalidAccessHandler(va))
{
throw;
}
return default;
}
}
/// <inheritdoc/> /// <inheritdoc/>
public override void Read(ulong va, Span<byte> data) public override void Read(ulong va, Span<byte> data)
{ {
@ -183,6 +209,16 @@ namespace Ryujinx.Cpu.Jit
WriteImpl(va, data); WriteImpl(va, data);
} }
/// <inheritdoc/>
public void WriteGuest<T>(ulong va, T value) where T : unmanaged
{
Span<byte> data = MemoryMarshal.Cast<T, byte>(MemoryMarshal.CreateSpan(ref value, 1));
SignalMemoryTrackingImpl(va, (ulong)data.Length, true, true);
WriteImpl(va, data);
}
/// <inheritdoc/> /// <inheritdoc/>
public void WriteUntracked(ulong va, ReadOnlySpan<byte> data) public void WriteUntracked(ulong va, ReadOnlySpan<byte> data)
{ {
@ -520,10 +556,12 @@ namespace Ryujinx.Cpu.Jit
} }
/// <inheritdoc/> /// <inheritdoc/>
public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection) public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection, bool guest)
{ {
AssertValidAddressAndSize(va, size); AssertValidAddressAndSize(va, size);
if (guest)
{
// Protection is inverted on software pages, since the default value is 0. // Protection is inverted on software pages, since the default value is 0.
protection = (~protection) & MemoryPermission.ReadAndWrite; protection = (~protection) & MemoryPermission.ReadAndWrite;
@ -553,17 +591,22 @@ namespace Ryujinx.Cpu.Jit
pageStart++; pageStart++;
} }
} }
else
/// <inheritdoc/>
public RegionHandle BeginTracking(ulong address, ulong size, int id)
{ {
return Tracking.BeginTracking(address, size, id); _pages.TrackingReprotect(va, size, protection);
}
} }
/// <inheritdoc/> /// <inheritdoc/>
public MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id) public RegionHandle BeginTracking(ulong address, ulong size, int id, RegionFlags flags = RegionFlags.None)
{ {
return Tracking.BeginGranularTracking(address, size, handles, granularity, id); return Tracking.BeginTracking(address, size, id, flags);
}
/// <inheritdoc/>
public MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id, RegionFlags flags = RegionFlags.None)
{
return Tracking.BeginGranularTracking(address, size, handles, granularity, id, flags);
} }
/// <inheritdoc/> /// <inheritdoc/>
@ -572,8 +615,7 @@ namespace Ryujinx.Cpu.Jit
return Tracking.BeginSmartGranularTracking(address, size, granularity, id); return Tracking.BeginSmartGranularTracking(address, size, granularity, id);
} }
/// <inheritdoc/> private void SignalMemoryTrackingImpl(ulong va, ulong size, bool write, bool guest, bool precise = false, int? exemptId = null)
public void SignalMemoryTracking(ulong va, ulong size, bool write, bool precise = false, int? exemptId = null)
{ {
AssertValidAddressAndSize(va, size); AssertValidAddressAndSize(va, size);
@ -583,6 +625,11 @@ namespace Ryujinx.Cpu.Jit
return; return;
} }
// If the memory tracking is coming from the guest, use the tag bits in the page table entry.
// Otherwise, use the managed page flags.
if (guest)
{
// We emulate guard pages for software memory access. This makes for an easy transition to // We emulate guard pages for software memory access. This makes for an easy transition to
// tracking using host guard pages in future, but also supporting platforms where this is not possible. // tracking using host guard pages in future, but also supporting platforms where this is not possible.
@ -602,13 +649,24 @@ namespace Ryujinx.Cpu.Jit
if ((pte & tag) != 0) if ((pte & tag) != 0)
{ {
Tracking.VirtualMemoryEvent(va, size, write, precise: false, exemptId); Tracking.VirtualMemoryEvent(va, size, write, precise: false, exemptId, true);
break; break;
} }
pageStart++; pageStart++;
} }
} }
else
{
_pages.SignalMemoryTracking(Tracking, va, size, write, exemptId);
}
}
/// <inheritdoc/>
public void SignalMemoryTracking(ulong va, ulong size, bool write, bool precise = false, int? exemptId = null)
{
SignalMemoryTrackingImpl(va, size, write, false, precise, exemptId);
}
private ulong PaToPte(ulong pa) private ulong PaToPte(ulong pa)
{ {

View file

@ -6,7 +6,6 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Threading;
namespace Ryujinx.Cpu.Jit namespace Ryujinx.Cpu.Jit
{ {
@ -15,21 +14,6 @@ namespace Ryujinx.Cpu.Jit
/// </summary> /// </summary>
public sealed class MemoryManagerHostMapped : VirtualMemoryManagerRefCountedBase<ulong, ulong>, IMemoryManager, IVirtualMemoryManagerTracked, IWritableBlock public sealed class MemoryManagerHostMapped : VirtualMemoryManagerRefCountedBase<ulong, ulong>, IMemoryManager, IVirtualMemoryManagerTracked, IWritableBlock
{ {
public const int PageToPteShift = 5; // 32 pages (2 bits each) in one ulong page table entry.
public const ulong BlockMappedMask = 0x5555555555555555; // First bit of each table entry set.
private enum HostMappedPtBits : ulong
{
Unmapped = 0,
Mapped,
WriteTracked,
ReadWriteTracked,
MappedReplicated = 0x5555555555555555,
WriteTrackedReplicated = 0xaaaaaaaaaaaaaaaa,
ReadWriteTrackedReplicated = ulong.MaxValue,
}
private readonly InvalidAccessHandler _invalidAccessHandler; private readonly InvalidAccessHandler _invalidAccessHandler;
private readonly bool _unsafeMode; private readonly bool _unsafeMode;
@ -39,7 +23,7 @@ namespace Ryujinx.Cpu.Jit
private readonly MemoryEhMeilleure _memoryEh; private readonly MemoryEhMeilleure _memoryEh;
private readonly ulong[] _pageBitmap; private readonly ManagedPageFlags _pages;
/// <inheritdoc/> /// <inheritdoc/>
public bool Supports4KBPages => MemoryBlock.GetPageSize() == PageSize; public bool Supports4KBPages => MemoryBlock.GetPageSize() == PageSize;
@ -81,7 +65,7 @@ namespace Ryujinx.Cpu.Jit
AddressSpaceBits = asBits; AddressSpaceBits = asBits;
_pageBitmap = new ulong[1 << (AddressSpaceBits - (PageBits + PageToPteShift))]; _pages = new ManagedPageFlags(AddressSpaceBits);
Tracking = new MemoryTracking(this, (int)MemoryBlock.GetPageSize(), invalidAccessHandler); Tracking = new MemoryTracking(this, (int)MemoryBlock.GetPageSize(), invalidAccessHandler);
_memoryEh = new MemoryEhMeilleure(_addressSpace.Base, _addressSpace.Mirror, Tracking); _memoryEh = new MemoryEhMeilleure(_addressSpace.Base, _addressSpace.Mirror, Tracking);
@ -94,7 +78,7 @@ namespace Ryujinx.Cpu.Jit
/// <param name="size">Size of the range in bytes</param> /// <param name="size">Size of the range in bytes</param>
private void AssertMapped(ulong va, ulong size) private void AssertMapped(ulong va, ulong size)
{ {
if (!ValidateAddressAndSize(va, size) || !IsRangeMappedImpl(va, size)) if (!ValidateAddressAndSize(va, size) || !_pages.IsRangeMapped(va, size))
{ {
throw new InvalidMemoryRegionException($"Not mapped: va=0x{va:X16}, size=0x{size:X16}"); throw new InvalidMemoryRegionException($"Not mapped: va=0x{va:X16}, size=0x{size:X16}");
} }
@ -106,7 +90,7 @@ namespace Ryujinx.Cpu.Jit
AssertValidAddressAndSize(va, size); AssertValidAddressAndSize(va, size);
_addressSpace.Map(va, pa, size, flags); _addressSpace.Map(va, pa, size, flags);
AddMapping(va, size); _pages.AddMapping(va, size);
PtMap(va, pa, size); PtMap(va, pa, size);
Tracking.Map(va, size); Tracking.Map(va, size);
@ -126,7 +110,7 @@ namespace Ryujinx.Cpu.Jit
UnmapEvent?.Invoke(va, size); UnmapEvent?.Invoke(va, size);
Tracking.Unmap(va, size); Tracking.Unmap(va, size);
RemoveMapping(va, size); _pages.RemoveMapping(va, size);
PtUnmap(va, size); PtUnmap(va, size);
_addressSpace.Unmap(va, size); _addressSpace.Unmap(va, size);
} }
@ -337,22 +321,7 @@ namespace Ryujinx.Cpu.Jit
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool IsMapped(ulong va) public bool IsMapped(ulong va)
{ {
return ValidateAddress(va) && IsMappedImpl(va); return ValidateAddress(va) && _pages.IsMapped(va);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private bool IsMappedImpl(ulong va)
{
ulong page = va >> PageBits;
int bit = (int)((page & 31) << 1);
int pageIndex = (int)(page >> PageToPteShift);
ref ulong pageRef = ref _pageBitmap[pageIndex];
ulong pte = Volatile.Read(ref pageRef);
return ((pte >> bit) & 3) != 0;
} }
/// <inheritdoc/> /// <inheritdoc/>
@ -360,58 +329,7 @@ namespace Ryujinx.Cpu.Jit
{ {
AssertValidAddressAndSize(va, size); AssertValidAddressAndSize(va, size);
return IsRangeMappedImpl(va, size); return _pages.IsRangeMapped(va, size);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void GetPageBlockRange(ulong pageStart, ulong pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex)
{
startMask = ulong.MaxValue << ((int)(pageStart & 31) << 1);
endMask = ulong.MaxValue >> (64 - ((int)(pageEnd & 31) << 1));
pageIndex = (int)(pageStart >> PageToPteShift);
pageEndIndex = (int)((pageEnd - 1) >> PageToPteShift);
}
private bool IsRangeMappedImpl(ulong va, ulong size)
{
int pages = GetPagesCount(va, size, out _);
if (pages == 1)
{
return IsMappedImpl(va);
}
ulong pageStart = va >> PageBits;
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
// Check if either bit in each 2 bit page entry is set.
// OR the block with itself shifted down by 1, and check the first bit of each entry.
ulong mask = BlockMappedMask & startMask;
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask &= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte = Volatile.Read(ref pageRef);
pte |= pte >> 1;
if ((pte & mask) != mask)
{
return false;
}
mask = BlockMappedMask;
}
return true;
} }
/// <inheritdoc/> /// <inheritdoc/>
@ -486,76 +404,7 @@ namespace Ryujinx.Cpu.Jit
return; return;
} }
// Software table, used for managed memory tracking. _pages.SignalMemoryTracking(Tracking, va, size, write, exemptId);
int pages = GetPagesCount(va, size, out _);
ulong pageStart = va >> PageBits;
if (pages == 1)
{
ulong tag = (ulong)(write ? HostMappedPtBits.WriteTracked : HostMappedPtBits.ReadWriteTracked);
int bit = (int)((pageStart & 31) << 1);
int pageIndex = (int)(pageStart >> PageToPteShift);
ref ulong pageRef = ref _pageBitmap[pageIndex];
ulong pte = Volatile.Read(ref pageRef);
ulong state = ((pte >> bit) & 3);
if (state >= tag)
{
Tracking.VirtualMemoryEvent(va, size, write, precise: false, exemptId);
return;
}
else if (state == 0)
{
ThrowInvalidMemoryRegionException($"Not mapped: va=0x{va:X16}, size=0x{size:X16}");
}
}
else
{
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
ulong mask = startMask;
ulong anyTrackingTag = (ulong)HostMappedPtBits.WriteTrackedReplicated;
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask &= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte = Volatile.Read(ref pageRef);
ulong mappedMask = mask & BlockMappedMask;
ulong mappedPte = pte | (pte >> 1);
if ((mappedPte & mappedMask) != mappedMask)
{
ThrowInvalidMemoryRegionException($"Not mapped: va=0x{va:X16}, size=0x{size:X16}");
}
pte &= mask;
if ((pte & anyTrackingTag) != 0) // Search for any tracking.
{
// Writes trigger any tracking.
// Only trigger tracking from reads if both bits are set on any page.
if (write || (pte & (pte >> 1) & BlockMappedMask) != 0)
{
Tracking.VirtualMemoryEvent(va, size, write, precise: false, exemptId);
break;
}
}
mask = ulong.MaxValue;
}
}
} }
/// <summary> /// <summary>
@ -582,103 +431,28 @@ namespace Ryujinx.Cpu.Jit
} }
/// <inheritdoc/> /// <inheritdoc/>
public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection) public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection, bool guest)
{ {
// Protection is inverted on software pages, since the default value is 0. if (guest)
protection = (~protection) & MemoryPermission.ReadAndWrite;
int pages = GetPagesCount(va, size, out va);
ulong pageStart = va >> PageBits;
if (pages == 1)
{ {
ulong protTag = protection switch _addressSpace.Base.Reprotect(va, size, protection, false);
{
MemoryPermission.None => (ulong)HostMappedPtBits.Mapped,
MemoryPermission.Write => (ulong)HostMappedPtBits.WriteTracked,
_ => (ulong)HostMappedPtBits.ReadWriteTracked,
};
int bit = (int)((pageStart & 31) << 1);
ulong tagMask = 3UL << bit;
ulong invTagMask = ~tagMask;
ulong tag = protTag << bit;
int pageIndex = (int)(pageStart >> PageToPteShift);
ref ulong pageRef = ref _pageBitmap[pageIndex];
ulong pte;
do
{
pte = Volatile.Read(ref pageRef);
}
while ((pte & tagMask) != 0 && Interlocked.CompareExchange(ref pageRef, (pte & invTagMask) | tag, pte) != pte);
} }
else else
{ {
ulong pageEnd = pageStart + (ulong)pages; _pages.TrackingReprotect(va, size, protection);
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
ulong mask = startMask;
ulong protTag = protection switch
{
MemoryPermission.None => (ulong)HostMappedPtBits.MappedReplicated,
MemoryPermission.Write => (ulong)HostMappedPtBits.WriteTrackedReplicated,
_ => (ulong)HostMappedPtBits.ReadWriteTrackedReplicated,
};
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask &= endMask;
} }
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte;
ulong mappedMask;
// Change the protection of all 2 bit entries that are mapped.
do
{
pte = Volatile.Read(ref pageRef);
mappedMask = pte | (pte >> 1);
mappedMask |= (mappedMask & BlockMappedMask) << 1;
mappedMask &= mask; // Only update mapped pages within the given range.
}
while (Interlocked.CompareExchange(ref pageRef, (pte & (~mappedMask)) | (protTag & mappedMask), pte) != pte);
mask = ulong.MaxValue;
}
}
protection = protection switch
{
MemoryPermission.None => MemoryPermission.ReadAndWrite,
MemoryPermission.Write => MemoryPermission.Read,
_ => MemoryPermission.None,
};
_addressSpace.Base.Reprotect(va, size, protection, false);
} }
/// <inheritdoc/> /// <inheritdoc/>
public RegionHandle BeginTracking(ulong address, ulong size, int id) public RegionHandle BeginTracking(ulong address, ulong size, int id, RegionFlags flags = RegionFlags.None)
{ {
return Tracking.BeginTracking(address, size, id); return Tracking.BeginTracking(address, size, id, flags);
} }
/// <inheritdoc/> /// <inheritdoc/>
public MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id) public MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id, RegionFlags flags = RegionFlags.None)
{ {
return Tracking.BeginGranularTracking(address, size, handles, granularity, id); return Tracking.BeginGranularTracking(address, size, handles, granularity, id, flags);
} }
/// <inheritdoc/> /// <inheritdoc/>
@ -687,86 +461,6 @@ namespace Ryujinx.Cpu.Jit
return Tracking.BeginSmartGranularTracking(address, size, granularity, id); return Tracking.BeginSmartGranularTracking(address, size, granularity, id);
} }
/// <summary>
/// Adds the given address mapping to the page table.
/// </summary>
/// <param name="va">Virtual memory address</param>
/// <param name="size">Size to be mapped</param>
private void AddMapping(ulong va, ulong size)
{
int pages = GetPagesCount(va, size, out _);
ulong pageStart = va >> PageBits;
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
ulong mask = startMask;
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask &= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte;
ulong mappedMask;
// Map all 2-bit entries that are unmapped.
do
{
pte = Volatile.Read(ref pageRef);
mappedMask = pte | (pte >> 1);
mappedMask |= (mappedMask & BlockMappedMask) << 1;
mappedMask |= ~mask; // Treat everything outside the range as mapped, thus unchanged.
}
while (Interlocked.CompareExchange(ref pageRef, (pte & mappedMask) | (BlockMappedMask & (~mappedMask)), pte) != pte);
mask = ulong.MaxValue;
}
}
/// <summary>
/// Removes the given address mapping from the page table.
/// </summary>
/// <param name="va">Virtual memory address</param>
/// <param name="size">Size to be unmapped</param>
private void RemoveMapping(ulong va, ulong size)
{
int pages = GetPagesCount(va, size, out _);
ulong pageStart = va >> PageBits;
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
startMask = ~startMask;
endMask = ~endMask;
ulong mask = startMask;
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask |= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte;
do
{
pte = Volatile.Read(ref pageRef);
}
while (Interlocked.CompareExchange(ref pageRef, pte & mask, pte) != pte);
mask = 0;
}
}
/// <summary> /// <summary>
/// Disposes of resources used by the memory manager. /// Disposes of resources used by the memory manager.
/// </summary> /// </summary>

View file

@ -0,0 +1,389 @@
using Ryujinx.Memory;
using Ryujinx.Memory.Tracking;
using System;
using System.Runtime.CompilerServices;
using System.Threading;
namespace Ryujinx.Cpu
{
/// <summary>
/// A page bitmap that keeps track of mapped state and tracking protection
/// for managed memory accesses (not using host page protection).
/// </summary>
internal readonly struct ManagedPageFlags
{
public const int PageBits = 12;
public const int PageSize = 1 << PageBits;
public const int PageMask = PageSize - 1;
private readonly ulong[] _pageBitmap;
public const int PageToPteShift = 5; // 32 pages (2 bits each) in one ulong page table entry.
public const ulong BlockMappedMask = 0x5555555555555555; // First bit of each table entry set.
private enum ManagedPtBits : ulong
{
Unmapped = 0,
Mapped,
WriteTracked,
ReadWriteTracked,
MappedReplicated = 0x5555555555555555,
WriteTrackedReplicated = 0xaaaaaaaaaaaaaaaa,
ReadWriteTrackedReplicated = ulong.MaxValue,
}
public ManagedPageFlags(int addressSpaceBits)
{
int bits = Math.Max(0, addressSpaceBits - (PageBits + PageToPteShift));
_pageBitmap = new ulong[1 << bits];
}
/// <summary>
/// Computes the number of pages in a virtual address range.
/// </summary>
/// <param name="va">Virtual address of the range</param>
/// <param name="size">Size of the range</param>
/// <param name="startVa">The virtual address of the beginning of the first page</param>
/// <remarks>This function does not differentiate between allocated and unallocated pages.</remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int GetPagesCount(ulong va, ulong size, out ulong startVa)
{
// WARNING: Always check if ulong does not overflow during the operations.
startVa = va & ~(ulong)PageMask;
ulong vaSpan = (va - startVa + size + PageMask) & ~(ulong)PageMask;
return (int)(vaSpan / PageSize);
}
/// <summary>
/// Checks if the page at a given CPU virtual address is mapped.
/// </summary>
/// <param name="va">Virtual address to check</param>
/// <returns>True if the address is mapped, false otherwise</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public readonly bool IsMapped(ulong va)
{
ulong page = va >> PageBits;
int bit = (int)((page & 31) << 1);
int pageIndex = (int)(page >> PageToPteShift);
ref ulong pageRef = ref _pageBitmap[pageIndex];
ulong pte = Volatile.Read(ref pageRef);
return ((pte >> bit) & 3) != 0;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void GetPageBlockRange(ulong pageStart, ulong pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex)
{
startMask = ulong.MaxValue << ((int)(pageStart & 31) << 1);
endMask = ulong.MaxValue >> (64 - ((int)(pageEnd & 31) << 1));
pageIndex = (int)(pageStart >> PageToPteShift);
pageEndIndex = (int)((pageEnd - 1) >> PageToPteShift);
}
/// <summary>
/// Checks if a memory range is mapped.
/// </summary>
/// <param name="va">Virtual address of the range</param>
/// <param name="size">Size of the range in bytes</param>
/// <returns>True if the entire range is mapped, false otherwise</returns>
public readonly bool IsRangeMapped(ulong va, ulong size)
{
int pages = GetPagesCount(va, size, out _);
if (pages == 1)
{
return IsMapped(va);
}
ulong pageStart = va >> PageBits;
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
// Check if either bit in each 2 bit page entry is set.
// OR the block with itself shifted down by 1, and check the first bit of each entry.
ulong mask = BlockMappedMask & startMask;
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask &= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte = Volatile.Read(ref pageRef);
pte |= pte >> 1;
if ((pte & mask) != mask)
{
return false;
}
mask = BlockMappedMask;
}
return true;
}
/// <summary>
/// Reprotect a region of virtual memory for tracking.
/// </summary>
/// <param name="va">Virtual address base</param>
/// <param name="size">Size of the region to protect</param>
/// <param name="protection">Memory protection to set</param>
public readonly void TrackingReprotect(ulong va, ulong size, MemoryPermission protection)
{
// Protection is inverted on software pages, since the default value is 0.
protection = (~protection) & MemoryPermission.ReadAndWrite;
int pages = GetPagesCount(va, size, out va);
ulong pageStart = va >> PageBits;
if (pages == 1)
{
ulong protTag = protection switch
{
MemoryPermission.None => (ulong)ManagedPtBits.Mapped,
MemoryPermission.Write => (ulong)ManagedPtBits.WriteTracked,
_ => (ulong)ManagedPtBits.ReadWriteTracked,
};
int bit = (int)((pageStart & 31) << 1);
ulong tagMask = 3UL << bit;
ulong invTagMask = ~tagMask;
ulong tag = protTag << bit;
int pageIndex = (int)(pageStart >> PageToPteShift);
ref ulong pageRef = ref _pageBitmap[pageIndex];
ulong pte;
do
{
pte = Volatile.Read(ref pageRef);
}
while ((pte & tagMask) != 0 && Interlocked.CompareExchange(ref pageRef, (pte & invTagMask) | tag, pte) != pte);
}
else
{
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
ulong mask = startMask;
ulong protTag = protection switch
{
MemoryPermission.None => (ulong)ManagedPtBits.MappedReplicated,
MemoryPermission.Write => (ulong)ManagedPtBits.WriteTrackedReplicated,
_ => (ulong)ManagedPtBits.ReadWriteTrackedReplicated,
};
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask &= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte;
ulong mappedMask;
// Change the protection of all 2 bit entries that are mapped.
do
{
pte = Volatile.Read(ref pageRef);
mappedMask = pte | (pte >> 1);
mappedMask |= (mappedMask & BlockMappedMask) << 1;
mappedMask &= mask; // Only update mapped pages within the given range.
}
while (Interlocked.CompareExchange(ref pageRef, (pte & (~mappedMask)) | (protTag & mappedMask), pte) != pte);
mask = ulong.MaxValue;
}
}
}
/// <summary>
/// Alerts the memory tracking that a given region has been read from or written to.
/// This should be called before read/write is performed.
/// </summary>
/// <param name="tracking">Memory tracking structure to call when pages are protected</param>
/// <param name="va">Virtual address of the region</param>
/// <param name="size">Size of the region</param>
/// <param name="write">True if the region was written, false if read</param>
/// <param name="exemptId">Optional ID of the handles that should not be signalled</param>
/// <remarks>
/// This function also validates that the given range is both valid and mapped, and will throw if it is not.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public readonly void SignalMemoryTracking(MemoryTracking tracking, ulong va, ulong size, bool write, int? exemptId = null)
{
// Software table, used for managed memory tracking.
int pages = GetPagesCount(va, size, out _);
ulong pageStart = va >> PageBits;
if (pages == 1)
{
ulong tag = (ulong)(write ? ManagedPtBits.WriteTracked : ManagedPtBits.ReadWriteTracked);
int bit = (int)((pageStart & 31) << 1);
int pageIndex = (int)(pageStart >> PageToPteShift);
ref ulong pageRef = ref _pageBitmap[pageIndex];
ulong pte = Volatile.Read(ref pageRef);
ulong state = ((pte >> bit) & 3);
if (state >= tag)
{
tracking.VirtualMemoryEvent(va, size, write, precise: false, exemptId);
return;
}
else if (state == 0)
{
ThrowInvalidMemoryRegionException($"Not mapped: va=0x{va:X16}, size=0x{size:X16}");
}
}
else
{
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
ulong mask = startMask;
ulong anyTrackingTag = (ulong)ManagedPtBits.WriteTrackedReplicated;
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask &= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte = Volatile.Read(ref pageRef);
ulong mappedMask = mask & BlockMappedMask;
ulong mappedPte = pte | (pte >> 1);
if ((mappedPte & mappedMask) != mappedMask)
{
ThrowInvalidMemoryRegionException($"Not mapped: va=0x{va:X16}, size=0x{size:X16}");
}
pte &= mask;
if ((pte & anyTrackingTag) != 0) // Search for any tracking.
{
// Writes trigger any tracking.
// Only trigger tracking from reads if both bits are set on any page.
if (write || (pte & (pte >> 1) & BlockMappedMask) != 0)
{
tracking.VirtualMemoryEvent(va, size, write, precise: false, exemptId);
break;
}
}
mask = ulong.MaxValue;
}
}
}
/// <summary>
/// Adds the given address mapping to the page table.
/// </summary>
/// <param name="va">Virtual memory address</param>
/// <param name="size">Size to be mapped</param>
public readonly void AddMapping(ulong va, ulong size)
{
int pages = GetPagesCount(va, size, out _);
ulong pageStart = va >> PageBits;
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
ulong mask = startMask;
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask &= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte;
ulong mappedMask;
// Map all 2-bit entries that are unmapped.
do
{
pte = Volatile.Read(ref pageRef);
mappedMask = pte | (pte >> 1);
mappedMask |= (mappedMask & BlockMappedMask) << 1;
mappedMask |= ~mask; // Treat everything outside the range as mapped, thus unchanged.
}
while (Interlocked.CompareExchange(ref pageRef, (pte & mappedMask) | (BlockMappedMask & (~mappedMask)), pte) != pte);
mask = ulong.MaxValue;
}
}
/// <summary>
/// Removes the given address mapping from the page table.
/// </summary>
/// <param name="va">Virtual memory address</param>
/// <param name="size">Size to be unmapped</param>
public readonly void RemoveMapping(ulong va, ulong size)
{
int pages = GetPagesCount(va, size, out _);
ulong pageStart = va >> PageBits;
ulong pageEnd = pageStart + (ulong)pages;
GetPageBlockRange(pageStart, pageEnd, out ulong startMask, out ulong endMask, out int pageIndex, out int pageEndIndex);
startMask = ~startMask;
endMask = ~endMask;
ulong mask = startMask;
while (pageIndex <= pageEndIndex)
{
if (pageIndex == pageEndIndex)
{
mask |= endMask;
}
ref ulong pageRef = ref _pageBitmap[pageIndex++];
ulong pte;
do
{
pte = Volatile.Read(ref pageRef);
}
while (Interlocked.CompareExchange(ref pageRef, pte & mask, pte) != pte);
mask = 0;
}
}
private static void ThrowInvalidMemoryRegionException(string message) => throw new InvalidMemoryRegionException(message);
}
}

View file

@ -69,7 +69,7 @@ namespace Ryujinx.Graphics.Gpu.Image
Address = address; Address = address;
Size = size; Size = size;
_memoryTracking = physicalMemory.BeginGranularTracking(address, size, ResourceKind.Pool); _memoryTracking = physicalMemory.BeginGranularTracking(address, size, ResourceKind.Pool, RegionFlags.None);
_memoryTracking.RegisterPreciseAction(address, size, PreciseAction); _memoryTracking.RegisterPreciseAction(address, size, PreciseAction);
_modifiedDelegate = RegionModified; _modifiedDelegate = RegionModified;
} }

View file

@ -128,13 +128,13 @@ namespace Ryujinx.Graphics.Gpu.Memory
if (_useGranular) if (_useGranular)
{ {
_memoryTrackingGranular = physicalMemory.BeginGranularTracking(address, size, ResourceKind.Buffer, baseHandles); _memoryTrackingGranular = physicalMemory.BeginGranularTracking(address, size, ResourceKind.Buffer, RegionFlags.UnalignedAccess, baseHandles);
_memoryTrackingGranular.RegisterPreciseAction(address, size, PreciseAction); _memoryTrackingGranular.RegisterPreciseAction(address, size, PreciseAction);
} }
else else
{ {
_memoryTracking = physicalMemory.BeginTracking(address, size, ResourceKind.Buffer); _memoryTracking = physicalMemory.BeginTracking(address, size, ResourceKind.Buffer, RegionFlags.UnalignedAccess);
if (baseHandles != null) if (baseHandles != null)
{ {

View file

@ -368,10 +368,11 @@ namespace Ryujinx.Graphics.Gpu.Memory
/// <param name="address">CPU virtual address of the region</param> /// <param name="address">CPU virtual address of the region</param>
/// <param name="size">Size of the region</param> /// <param name="size">Size of the region</param>
/// <param name="kind">Kind of the resource being tracked</param> /// <param name="kind">Kind of the resource being tracked</param>
/// <param name="flags">Region flags</param>
/// <returns>The memory tracking handle</returns> /// <returns>The memory tracking handle</returns>
public RegionHandle BeginTracking(ulong address, ulong size, ResourceKind kind) public RegionHandle BeginTracking(ulong address, ulong size, ResourceKind kind, RegionFlags flags = RegionFlags.None)
{ {
return _cpuMemory.BeginTracking(address, size, (int)kind); return _cpuMemory.BeginTracking(address, size, (int)kind, flags);
} }
/// <summary> /// <summary>
@ -408,12 +409,19 @@ namespace Ryujinx.Graphics.Gpu.Memory
/// <param name="address">CPU virtual address of the region</param> /// <param name="address">CPU virtual address of the region</param>
/// <param name="size">Size of the region</param> /// <param name="size">Size of the region</param>
/// <param name="kind">Kind of the resource being tracked</param> /// <param name="kind">Kind of the resource being tracked</param>
/// <param name="flags">Region flags</param>
/// <param name="handles">Handles to inherit state from or reuse</param> /// <param name="handles">Handles to inherit state from or reuse</param>
/// <param name="granularity">Desired granularity of write tracking</param> /// <param name="granularity">Desired granularity of write tracking</param>
/// <returns>The memory tracking handle</returns> /// <returns>The memory tracking handle</returns>
public MultiRegionHandle BeginGranularTracking(ulong address, ulong size, ResourceKind kind, IEnumerable<IRegionHandle> handles = null, ulong granularity = 4096) public MultiRegionHandle BeginGranularTracking(
ulong address,
ulong size,
ResourceKind kind,
RegionFlags flags = RegionFlags.None,
IEnumerable<IRegionHandle> handles = null,
ulong granularity = 4096)
{ {
return _cpuMemory.BeginGranularTracking(address, size, handles, granularity, (int)kind); return _cpuMemory.BeginGranularTracking(address, size, handles, granularity, (int)kind, flags);
} }
/// <summary> /// <summary>

View file

@ -392,7 +392,7 @@ namespace Ryujinx.Memory
} }
/// <inheritdoc/> /// <inheritdoc/>
public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection) public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection, bool guest = false)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }

View file

@ -214,6 +214,7 @@ namespace Ryujinx.Memory
/// <param name="va">Virtual address base</param> /// <param name="va">Virtual address base</param>
/// <param name="size">Size of the region to protect</param> /// <param name="size">Size of the region to protect</param>
/// <param name="protection">Memory protection to set</param> /// <param name="protection">Memory protection to set</param>
void TrackingReprotect(ulong va, ulong size, MemoryPermission protection); /// <param name="guest">True if the protection is for guest access, false otherwise</param>
void TrackingReprotect(ulong va, ulong size, MemoryPermission protection, bool guest);
} }
} }

View file

@ -14,9 +14,14 @@ namespace Ryujinx.Memory.Tracking
// Only use these from within the lock. // Only use these from within the lock.
private readonly NonOverlappingRangeList<VirtualRegion> _virtualRegions; private readonly NonOverlappingRangeList<VirtualRegion> _virtualRegions;
// Guest virtual regions are a subset of the normal virtual regions, with potentially different protection
// and expanded area of effect on platforms that don't support misaligned page protection.
private readonly NonOverlappingRangeList<VirtualRegion> _guestVirtualRegions;
private readonly int _pageSize; private readonly int _pageSize;
private readonly bool _singleByteGuestTracking;
/// <summary> /// <summary>
/// This lock must be obtained when traversing or updating the region-handle hierarchy. /// This lock must be obtained when traversing or updating the region-handle hierarchy.
/// It is not required when reading dirty flags. /// It is not required when reading dirty flags.
@ -27,16 +32,27 @@ namespace Ryujinx.Memory.Tracking
/// Create a new tracking structure for the given "physical" memory block, /// Create a new tracking structure for the given "physical" memory block,
/// with a given "virtual" memory manager that will provide mappings and virtual memory protection. /// with a given "virtual" memory manager that will provide mappings and virtual memory protection.
/// </summary> /// </summary>
/// <remarks>
/// If <paramref name="singleByteGuestTracking" /> is true, the memory manager must also support protection on partially
/// unmapped regions without throwing exceptions or dropping protection on the mapped portion.
/// </remarks>
/// <param name="memoryManager">Virtual memory manager</param> /// <param name="memoryManager">Virtual memory manager</param>
/// <param name="block">Physical memory block</param>
/// <param name="pageSize">Page size of the virtual memory space</param> /// <param name="pageSize">Page size of the virtual memory space</param>
public MemoryTracking(IVirtualMemoryManager memoryManager, int pageSize, InvalidAccessHandler invalidAccessHandler = null) /// <param name="invalidAccessHandler">Method to call for invalid memory accesses</param>
/// <param name="singleByteGuestTracking">True if the guest only signals writes for the first byte</param>
public MemoryTracking(
IVirtualMemoryManager memoryManager,
int pageSize,
InvalidAccessHandler invalidAccessHandler = null,
bool singleByteGuestTracking = false)
{ {
_memoryManager = memoryManager; _memoryManager = memoryManager;
_pageSize = pageSize; _pageSize = pageSize;
_invalidAccessHandler = invalidAccessHandler; _invalidAccessHandler = invalidAccessHandler;
_singleByteGuestTracking = singleByteGuestTracking;
_virtualRegions = new NonOverlappingRangeList<VirtualRegion>(); _virtualRegions = new NonOverlappingRangeList<VirtualRegion>();
_guestVirtualRegions = new NonOverlappingRangeList<VirtualRegion>();
} }
private (ulong address, ulong size) PageAlign(ulong address, ulong size) private (ulong address, ulong size) PageAlign(ulong address, ulong size)
@ -62,7 +78,11 @@ namespace Ryujinx.Memory.Tracking
{ {
ref var overlaps = ref ThreadStaticArray<VirtualRegion>.Get(); ref var overlaps = ref ThreadStaticArray<VirtualRegion>.Get();
int count = _virtualRegions.FindOverlapsNonOverlapping(va, size, ref overlaps); for (int type = 0; type < 2; type++)
{
NonOverlappingRangeList<VirtualRegion> regions = type == 0 ? _virtualRegions : _guestVirtualRegions;
int count = regions.FindOverlapsNonOverlapping(va, size, ref overlaps);
for (int i = 0; i < count; i++) for (int i = 0; i < count; i++)
{ {
@ -79,6 +99,7 @@ namespace Ryujinx.Memory.Tracking
} }
} }
} }
}
/// <summary> /// <summary>
/// Indicate that a virtual region has been unmapped. /// Indicate that a virtual region has been unmapped.
@ -95,7 +116,11 @@ namespace Ryujinx.Memory.Tracking
{ {
ref var overlaps = ref ThreadStaticArray<VirtualRegion>.Get(); ref var overlaps = ref ThreadStaticArray<VirtualRegion>.Get();
int count = _virtualRegions.FindOverlapsNonOverlapping(va, size, ref overlaps); for (int type = 0; type < 2; type++)
{
NonOverlappingRangeList<VirtualRegion> regions = type == 0 ? _virtualRegions : _guestVirtualRegions;
int count = regions.FindOverlapsNonOverlapping(va, size, ref overlaps);
for (int i = 0; i < count; i++) for (int i = 0; i < count; i++)
{ {
@ -105,17 +130,44 @@ namespace Ryujinx.Memory.Tracking
} }
} }
} }
}
/// <summary>
/// Alter a tracked memory region to properly capture unaligned accesses.
/// For most memory manager modes, this does nothing.
/// </summary>
/// <param name="address">Original region address</param>
/// <param name="size">Original region size</param>
/// <returns>A new address and size for tracking unaligned accesses</returns>
internal (ulong newAddress, ulong newSize) GetUnalignedSafeRegion(ulong address, ulong size)
{
if (_singleByteGuestTracking)
{
// The guest only signals the first byte of each memory access with the current memory manager.
// To catch unaligned access properly, we need to also protect the page before the address.
// Assume that the address and size are already aligned.
return (address - (ulong)_pageSize, size + (ulong)_pageSize);
}
else
{
return (address, size);
}
}
/// <summary> /// <summary>
/// Get a list of virtual regions that a handle covers. /// Get a list of virtual regions that a handle covers.
/// </summary> /// </summary>
/// <param name="va">Starting virtual memory address of the handle</param> /// <param name="va">Starting virtual memory address of the handle</param>
/// <param name="size">Size of the handle's memory region</param> /// <param name="size">Size of the handle's memory region</param>
/// <param name="guest">True if getting handles for guest protection, false otherwise</param>
/// <returns>A list of virtual regions within the given range</returns> /// <returns>A list of virtual regions within the given range</returns>
internal List<VirtualRegion> GetVirtualRegionsForHandle(ulong va, ulong size) internal List<VirtualRegion> GetVirtualRegionsForHandle(ulong va, ulong size, bool guest)
{ {
List<VirtualRegion> result = new(); List<VirtualRegion> result = new();
_virtualRegions.GetOrAddRegions(result, va, size, (va, size) => new VirtualRegion(this, va, size)); NonOverlappingRangeList<VirtualRegion> regions = guest ? _guestVirtualRegions : _virtualRegions;
regions.GetOrAddRegions(result, va, size, (va, size) => new VirtualRegion(this, va, size, guest));
return result; return result;
} }
@ -125,9 +177,16 @@ namespace Ryujinx.Memory.Tracking
/// </summary> /// </summary>
/// <param name="region">Region to remove</param> /// <param name="region">Region to remove</param>
internal void RemoveVirtual(VirtualRegion region) internal void RemoveVirtual(VirtualRegion region)
{
if (region.Guest)
{
_guestVirtualRegions.Remove(region);
}
else
{ {
_virtualRegions.Remove(region); _virtualRegions.Remove(region);
} }
}
/// <summary> /// <summary>
/// Obtains a memory tracking handle for the given virtual region, with a specified granularity. This should be disposed when finished with. /// Obtains a memory tracking handle for the given virtual region, with a specified granularity. This should be disposed when finished with.
@ -137,10 +196,11 @@ namespace Ryujinx.Memory.Tracking
/// <param name="handles">Handles to inherit state from or reuse. When none are present, provide null</param> /// <param name="handles">Handles to inherit state from or reuse. When none are present, provide null</param>
/// <param name="granularity">Desired granularity of write tracking</param> /// <param name="granularity">Desired granularity of write tracking</param>
/// <param name="id">Handle ID</param> /// <param name="id">Handle ID</param>
/// <param name="flags">Region flags</param>
/// <returns>The memory tracking handle</returns> /// <returns>The memory tracking handle</returns>
public MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id) public MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id, RegionFlags flags = RegionFlags.None)
{ {
return new MultiRegionHandle(this, address, size, handles, granularity, id); return new MultiRegionHandle(this, address, size, handles, granularity, id, flags);
} }
/// <summary> /// <summary>
@ -164,15 +224,16 @@ namespace Ryujinx.Memory.Tracking
/// <param name="address">CPU virtual address of the region</param> /// <param name="address">CPU virtual address of the region</param>
/// <param name="size">Size of the region</param> /// <param name="size">Size of the region</param>
/// <param name="id">Handle ID</param> /// <param name="id">Handle ID</param>
/// <param name="flags">Region flags</param>
/// <returns>The memory tracking handle</returns> /// <returns>The memory tracking handle</returns>
public RegionHandle BeginTracking(ulong address, ulong size, int id) public RegionHandle BeginTracking(ulong address, ulong size, int id, RegionFlags flags = RegionFlags.None)
{ {
var (paAddress, paSize) = PageAlign(address, size); var (paAddress, paSize) = PageAlign(address, size);
lock (TrackingLock) lock (TrackingLock)
{ {
bool mapped = _memoryManager.IsRangeMapped(address, size); bool mapped = _memoryManager.IsRangeMapped(address, size);
RegionHandle handle = new(this, paAddress, paSize, address, size, id, mapped); RegionHandle handle = new(this, paAddress, paSize, address, size, id, flags, mapped);
return handle; return handle;
} }
@ -186,15 +247,16 @@ namespace Ryujinx.Memory.Tracking
/// <param name="bitmap">The bitmap owning the dirty flag for this handle</param> /// <param name="bitmap">The bitmap owning the dirty flag for this handle</param>
/// <param name="bit">The bit of this handle within the dirty flag</param> /// <param name="bit">The bit of this handle within the dirty flag</param>
/// <param name="id">Handle ID</param> /// <param name="id">Handle ID</param>
/// <param name="flags">Region flags</param>
/// <returns>The memory tracking handle</returns> /// <returns>The memory tracking handle</returns>
internal RegionHandle BeginTrackingBitmap(ulong address, ulong size, ConcurrentBitmap bitmap, int bit, int id) internal RegionHandle BeginTrackingBitmap(ulong address, ulong size, ConcurrentBitmap bitmap, int bit, int id, RegionFlags flags = RegionFlags.None)
{ {
var (paAddress, paSize) = PageAlign(address, size); var (paAddress, paSize) = PageAlign(address, size);
lock (TrackingLock) lock (TrackingLock)
{ {
bool mapped = _memoryManager.IsRangeMapped(address, size); bool mapped = _memoryManager.IsRangeMapped(address, size);
RegionHandle handle = new(this, paAddress, paSize, address, size, bitmap, bit, id, mapped); RegionHandle handle = new(this, paAddress, paSize, address, size, bitmap, bit, id, flags, mapped);
return handle; return handle;
} }
@ -202,6 +264,7 @@ namespace Ryujinx.Memory.Tracking
/// <summary> /// <summary>
/// Signal that a virtual memory event happened at the given location. /// Signal that a virtual memory event happened at the given location.
/// The memory event is assumed to be triggered by guest code.
/// </summary> /// </summary>
/// <param name="address">Virtual address accessed</param> /// <param name="address">Virtual address accessed</param>
/// <param name="size">Size of the region affected in bytes</param> /// <param name="size">Size of the region affected in bytes</param>
@ -209,7 +272,7 @@ namespace Ryujinx.Memory.Tracking
/// <returns>True if the event triggered any tracking regions, false otherwise</returns> /// <returns>True if the event triggered any tracking regions, false otherwise</returns>
public bool VirtualMemoryEvent(ulong address, ulong size, bool write) public bool VirtualMemoryEvent(ulong address, ulong size, bool write)
{ {
return VirtualMemoryEvent(address, size, write, precise: false, null); return VirtualMemoryEvent(address, size, write, precise: false, exemptId: null, guest: true);
} }
/// <summary> /// <summary>
@ -222,8 +285,9 @@ namespace Ryujinx.Memory.Tracking
/// <param name="write">Whether the region was written to or read</param> /// <param name="write">Whether the region was written to or read</param>
/// <param name="precise">True if the access is precise, false otherwise</param> /// <param name="precise">True if the access is precise, false otherwise</param>
/// <param name="exemptId">Optional ID that of the handles that should not be signalled</param> /// <param name="exemptId">Optional ID that of the handles that should not be signalled</param>
/// <param name="guest">True if the access is from the guest, false otherwise</param>
/// <returns>True if the event triggered any tracking regions, false otherwise</returns> /// <returns>True if the event triggered any tracking regions, false otherwise</returns>
public bool VirtualMemoryEvent(ulong address, ulong size, bool write, bool precise, int? exemptId = null) public bool VirtualMemoryEvent(ulong address, ulong size, bool write, bool precise, int? exemptId = null, bool guest = false)
{ {
// Look up the virtual region using the region list. // Look up the virtual region using the region list.
// Signal up the chain to relevant handles. // Signal up the chain to relevant handles.
@ -234,7 +298,9 @@ namespace Ryujinx.Memory.Tracking
{ {
ref var overlaps = ref ThreadStaticArray<VirtualRegion>.Get(); ref var overlaps = ref ThreadStaticArray<VirtualRegion>.Get();
int count = _virtualRegions.FindOverlapsNonOverlapping(address, size, ref overlaps); NonOverlappingRangeList<VirtualRegion> regions = guest ? _guestVirtualRegions : _virtualRegions;
int count = regions.FindOverlapsNonOverlapping(address, size, ref overlaps);
if (count == 0 && !precise) if (count == 0 && !precise)
{ {
@ -242,7 +308,7 @@ namespace Ryujinx.Memory.Tracking
{ {
// TODO: There is currently the possibility that a page can be protected after its virtual region is removed. // TODO: There is currently the possibility that a page can be protected after its virtual region is removed.
// This code handles that case when it happens, but it would be better to find out how this happens. // This code handles that case when it happens, but it would be better to find out how this happens.
_memoryManager.TrackingReprotect(address & ~(ulong)(_pageSize - 1), (ulong)_pageSize, MemoryPermission.ReadAndWrite); _memoryManager.TrackingReprotect(address & ~(ulong)(_pageSize - 1), (ulong)_pageSize, MemoryPermission.ReadAndWrite, guest);
return true; // This memory _should_ be mapped, so we need to try again. return true; // This memory _should_ be mapped, so we need to try again.
} }
else else
@ -252,6 +318,12 @@ namespace Ryujinx.Memory.Tracking
} }
else else
{ {
if (guest && _singleByteGuestTracking)
{
// Increase the access size to trigger handles with misaligned accesses.
size += (ulong)_pageSize;
}
for (int i = 0; i < count; i++) for (int i = 0; i < count; i++)
{ {
VirtualRegion region = overlaps[i]; VirtualRegion region = overlaps[i];
@ -285,9 +357,10 @@ namespace Ryujinx.Memory.Tracking
/// </summary> /// </summary>
/// <param name="region">Region to reprotect</param> /// <param name="region">Region to reprotect</param>
/// <param name="permission">Memory permission to protect with</param> /// <param name="permission">Memory permission to protect with</param>
internal void ProtectVirtualRegion(VirtualRegion region, MemoryPermission permission) /// <param name="guest">True if the protection is for guest access, false otherwise</param>
internal void ProtectVirtualRegion(VirtualRegion region, MemoryPermission permission, bool guest)
{ {
_memoryManager.TrackingReprotect(region.Address, region.Size, permission); _memoryManager.TrackingReprotect(region.Address, region.Size, permission, guest);
} }
/// <summary> /// <summary>

View file

@ -37,7 +37,8 @@ namespace Ryujinx.Memory.Tracking
ulong size, ulong size,
IEnumerable<IRegionHandle> handles, IEnumerable<IRegionHandle> handles,
ulong granularity, ulong granularity,
int id) int id,
RegionFlags flags)
{ {
_handles = new RegionHandle[(size + granularity - 1) / granularity]; _handles = new RegionHandle[(size + granularity - 1) / granularity];
Granularity = granularity; Granularity = granularity;
@ -62,7 +63,7 @@ namespace Ryujinx.Memory.Tracking
// Fill any gap left before this handle. // Fill any gap left before this handle.
while (i < startIndex) while (i < startIndex)
{ {
RegionHandle fillHandle = tracking.BeginTrackingBitmap(address + (ulong)i * granularity, granularity, _dirtyBitmap, i, id); RegionHandle fillHandle = tracking.BeginTrackingBitmap(address + (ulong)i * granularity, granularity, _dirtyBitmap, i, id, flags);
fillHandle.Parent = this; fillHandle.Parent = this;
_handles[i++] = fillHandle; _handles[i++] = fillHandle;
} }
@ -83,7 +84,7 @@ namespace Ryujinx.Memory.Tracking
while (i < endIndex) while (i < endIndex)
{ {
RegionHandle splitHandle = tracking.BeginTrackingBitmap(address + (ulong)i * granularity, granularity, _dirtyBitmap, i, id); RegionHandle splitHandle = tracking.BeginTrackingBitmap(address + (ulong)i * granularity, granularity, _dirtyBitmap, i, id, flags);
splitHandle.Parent = this; splitHandle.Parent = this;
splitHandle.Reprotect(handle.Dirty); splitHandle.Reprotect(handle.Dirty);
@ -106,7 +107,7 @@ namespace Ryujinx.Memory.Tracking
// Fill any remaining space with new handles. // Fill any remaining space with new handles.
while (i < _handles.Length) while (i < _handles.Length)
{ {
RegionHandle handle = tracking.BeginTrackingBitmap(address + (ulong)i * granularity, granularity, _dirtyBitmap, i, id); RegionHandle handle = tracking.BeginTrackingBitmap(address + (ulong)i * granularity, granularity, _dirtyBitmap, i, id, flags);
handle.Parent = this; handle.Parent = this;
_handles[i++] = handle; _handles[i++] = handle;
} }

View file

@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Ryujinx.Memory.Tracking
{
[Flags]
public enum RegionFlags
{
None = 0,
/// <summary>
/// Access to the resource is expected to occasionally be unaligned.
/// With some memory managers, guest protection must extend into the previous page to cover unaligned access.
/// If this is not expected, protection is not altered, which can avoid unintended resource dirty/flush.
/// </summary>
UnalignedAccess = 1,
}
}

View file

@ -55,6 +55,8 @@ namespace Ryujinx.Memory.Tracking
private RegionSignal _preAction; // Action to perform before a read or write. This will block the memory access. private RegionSignal _preAction; // Action to perform before a read or write. This will block the memory access.
private PreciseRegionSignal _preciseAction; // Action to perform on a precise read or write. private PreciseRegionSignal _preciseAction; // Action to perform on a precise read or write.
private readonly List<VirtualRegion> _regions; private readonly List<VirtualRegion> _regions;
private readonly List<VirtualRegion> _guestRegions;
private readonly List<VirtualRegion> _allRegions;
private readonly MemoryTracking _tracking; private readonly MemoryTracking _tracking;
private bool _disposed; private bool _disposed;
@ -99,6 +101,7 @@ namespace Ryujinx.Memory.Tracking
/// <param name="bitmap">The bitmap the dirty flag for this handle is stored in</param> /// <param name="bitmap">The bitmap the dirty flag for this handle is stored in</param>
/// <param name="bit">The bit index representing the dirty flag for this handle</param> /// <param name="bit">The bit index representing the dirty flag for this handle</param>
/// <param name="id">Handle ID</param> /// <param name="id">Handle ID</param>
/// <param name="flags">Region flags</param>
/// <param name="mapped">True if the region handle starts mapped</param> /// <param name="mapped">True if the region handle starts mapped</param>
internal RegionHandle( internal RegionHandle(
MemoryTracking tracking, MemoryTracking tracking,
@ -109,6 +112,7 @@ namespace Ryujinx.Memory.Tracking
ConcurrentBitmap bitmap, ConcurrentBitmap bitmap,
int bit, int bit,
int id, int id,
RegionFlags flags,
bool mapped = true) bool mapped = true)
{ {
Bitmap = bitmap; Bitmap = bitmap;
@ -128,11 +132,12 @@ namespace Ryujinx.Memory.Tracking
RealEndAddress = realAddress + realSize; RealEndAddress = realAddress + realSize;
_tracking = tracking; _tracking = tracking;
_regions = tracking.GetVirtualRegionsForHandle(address, size);
foreach (var region in _regions) _regions = tracking.GetVirtualRegionsForHandle(address, size, false);
{ _guestRegions = GetGuestRegions(tracking, address, size, flags);
region.Handles.Add(this); _allRegions = new List<VirtualRegion>(_regions.Count + _guestRegions.Count);
}
InitializeRegions();
} }
/// <summary> /// <summary>
@ -145,8 +150,9 @@ namespace Ryujinx.Memory.Tracking
/// <param name="realAddress">The real, unaligned address of the handle</param> /// <param name="realAddress">The real, unaligned address of the handle</param>
/// <param name="realSize">The real, unaligned size of the handle</param> /// <param name="realSize">The real, unaligned size of the handle</param>
/// <param name="id">Handle ID</param> /// <param name="id">Handle ID</param>
/// <param name="flags">Region flags</param>
/// <param name="mapped">True if the region handle starts mapped</param> /// <param name="mapped">True if the region handle starts mapped</param>
internal RegionHandle(MemoryTracking tracking, ulong address, ulong size, ulong realAddress, ulong realSize, int id, bool mapped = true) internal RegionHandle(MemoryTracking tracking, ulong address, ulong size, ulong realAddress, ulong realSize, int id, RegionFlags flags, bool mapped = true)
{ {
Bitmap = new ConcurrentBitmap(1, mapped); Bitmap = new ConcurrentBitmap(1, mapped);
@ -163,8 +169,37 @@ namespace Ryujinx.Memory.Tracking
RealEndAddress = realAddress + realSize; RealEndAddress = realAddress + realSize;
_tracking = tracking; _tracking = tracking;
_regions = tracking.GetVirtualRegionsForHandle(address, size);
foreach (var region in _regions) _regions = tracking.GetVirtualRegionsForHandle(address, size, false);
_guestRegions = GetGuestRegions(tracking, address, size, flags);
_allRegions = new List<VirtualRegion>(_regions.Count + _guestRegions.Count);
InitializeRegions();
}
private List<VirtualRegion> GetGuestRegions(MemoryTracking tracking, ulong address, ulong size, RegionFlags flags)
{
ulong guestAddress;
ulong guestSize;
if (flags.HasFlag(RegionFlags.UnalignedAccess))
{
(guestAddress, guestSize) = tracking.GetUnalignedSafeRegion(address, size);
}
else
{
(guestAddress, guestSize) = (address, size);
}
return tracking.GetVirtualRegionsForHandle(guestAddress, guestSize, true);
}
private void InitializeRegions()
{
_allRegions.AddRange(_regions);
_allRegions.AddRange(_guestRegions);
foreach (var region in _allRegions)
{ {
region.Handles.Add(this); region.Handles.Add(this);
} }
@ -321,7 +356,7 @@ namespace Ryujinx.Memory.Tracking
lock (_tracking.TrackingLock) lock (_tracking.TrackingLock)
{ {
foreach (VirtualRegion region in _regions) foreach (VirtualRegion region in _allRegions)
{ {
protectionChanged |= region.UpdateProtection(); protectionChanged |= region.UpdateProtection();
} }
@ -379,7 +414,7 @@ namespace Ryujinx.Memory.Tracking
{ {
lock (_tracking.TrackingLock) lock (_tracking.TrackingLock)
{ {
foreach (VirtualRegion region in _regions) foreach (VirtualRegion region in _allRegions)
{ {
region.UpdateProtection(); region.UpdateProtection();
} }
@ -413,10 +448,19 @@ namespace Ryujinx.Memory.Tracking
/// </summary> /// </summary>
/// <param name="region">Virtual region to add as a child</param> /// <param name="region">Virtual region to add as a child</param>
internal void AddChild(VirtualRegion region) internal void AddChild(VirtualRegion region)
{
if (region.Guest)
{
_guestRegions.Add(region);
}
else
{ {
_regions.Add(region); _regions.Add(region);
} }
_allRegions.Add(region);
}
/// <summary> /// <summary>
/// Signal that this handle has been mapped or unmapped. /// Signal that this handle has been mapped or unmapped.
/// </summary> /// </summary>
@ -469,7 +513,7 @@ namespace Ryujinx.Memory.Tracking
lock (_tracking.TrackingLock) lock (_tracking.TrackingLock)
{ {
foreach (VirtualRegion region in _regions) foreach (VirtualRegion region in _allRegions)
{ {
region.RemoveHandle(this); region.RemoveHandle(this);
} }

View file

@ -13,10 +13,14 @@ namespace Ryujinx.Memory.Tracking
private readonly MemoryTracking _tracking; private readonly MemoryTracking _tracking;
private MemoryPermission _lastPermission; private MemoryPermission _lastPermission;
public VirtualRegion(MemoryTracking tracking, ulong address, ulong size, MemoryPermission lastPermission = MemoryPermission.Invalid) : base(address, size) public bool Guest { get; }
public VirtualRegion(MemoryTracking tracking, ulong address, ulong size, bool guest, MemoryPermission lastPermission = MemoryPermission.Invalid) : base(address, size)
{ {
_lastPermission = lastPermission; _lastPermission = lastPermission;
_tracking = tracking; _tracking = tracking;
Guest = guest;
} }
/// <inheritdoc/> /// <inheritdoc/>
@ -103,7 +107,7 @@ namespace Ryujinx.Memory.Tracking
if (_lastPermission != permission) if (_lastPermission != permission)
{ {
_tracking.ProtectVirtualRegion(this, permission); _tracking.ProtectVirtualRegion(this, permission, Guest);
_lastPermission = permission; _lastPermission = permission;
return true; return true;
@ -131,7 +135,7 @@ namespace Ryujinx.Memory.Tracking
public override INonOverlappingRange Split(ulong splitAddress) public override INonOverlappingRange Split(ulong splitAddress)
{ {
VirtualRegion newRegion = new(_tracking, splitAddress, EndAddress - splitAddress, _lastPermission); VirtualRegion newRegion = new(_tracking, splitAddress, EndAddress - splitAddress, Guest, _lastPermission);
Size = splitAddress - Address; Size = splitAddress - Address;
// The new region inherits all of our parents. // The new region inherits all of our parents.

View file

@ -107,7 +107,7 @@ namespace Ryujinx.Tests.Memory
throw new NotImplementedException(); throw new NotImplementedException();
} }
public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection) public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection, bool guest)
{ {
OnProtect?.Invoke(va, size, protection); OnProtect?.Invoke(va, size, protection);
} }