0
0
Fork 0

Move partial unmap handler to the native signal handler (#3437)

* Initial commit with a lot of testing stuff.

* Partial Unmap Cleanup Part 1

* Fix some minor issues, hopefully windows tests.

* Disable partial unmap tests on macos for now

Weird issue.

* Goodbye magic number

* Add COMPlus_EnableAlternateStackCheck for tests

`COMPlus_EnableAlternateStackCheck` is needed for NullReferenceException handling to work on linux after registering the signal handler, due to how dotnet registers its own signal handler.

* Address some feedback

* Force retry when memory is mapped in memory tracking

This case existed before, but returning `false` no longer retries, so it would crash immediately after unprotecting the memory... Now, we return `true` to deliberately retry.

This case existed before (was just broken by this change) and I don't really want to look into fixing the issue right now. Technically, this means that on guest code partial unmaps will retry _due to this_ rather than hitting the handler. I don't expect this to cause any issues.

This should fix random crashes in Xenoblade Chronicles 2.

* Use IsRangeMapped

* Suppress MockMemoryManager.UnmapEvent warning

This event is not signalled by the mock memory manager.

* Remove 4kb mapping
This commit is contained in:
riperiperi 2022-07-30 00:16:29 +02:00 committed by GitHub
parent 952d013c67
commit 14ce9e1567
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 1355 additions and 391 deletions

View file

@ -197,12 +197,29 @@ namespace ARMeilleure.Signal
// Only call tracking if in range. // Only call tracking if in range.
context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold); context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold);
context.Copy(inRegionLocal, Const(1));
Operand offset = context.BitwiseAnd(context.Subtract(faultAddress, rangeAddress), Const(~PageMask)); Operand offset = context.BitwiseAnd(context.Subtract(faultAddress, rangeAddress), Const(~PageMask));
// Call the tracking action, with the pointer's relative offset to the base address. // Call the tracking action, with the pointer's relative offset to the base address.
Operand trackingActionPtr = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 20)); Operand trackingActionPtr = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 20));
context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0));
context.Copy(inRegionLocal, Const(0));
Operand skipActionLabel = Label();
// Tracking action should be non-null to call it, otherwise assume false return.
context.BranchIfFalse(skipActionLabel, trackingActionPtr);
Operand result = context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0));
context.Copy(inRegionLocal, result);
context.MarkLabel(skipActionLabel);
// If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows.
if (OperatingSystem.IsWindows())
{
context.BranchIfTrue(endLabel, inRegionLocal);
context.Copy(inRegionLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context));
}
context.Branch(endLabel); context.Branch(endLabel);

View file

@ -0,0 +1,84 @@
using ARMeilleure.IntermediateRepresentation;
using ARMeilleure.Translation;
using System;
using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
namespace ARMeilleure.Signal
{
public struct NativeWriteLoopState
{
public int Running;
public int Error;
}
public static class TestMethods
{
public delegate bool DebugPartialUnmap();
public delegate int DebugThreadLocalMapGetOrReserve(int threadId, int initialState);
public delegate void DebugNativeWriteLoop(IntPtr nativeWriteLoopPtr, IntPtr writePtr);
public static DebugPartialUnmap GenerateDebugPartialUnmap()
{
EmitterContext context = new EmitterContext();
var result = WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context);
context.Return(result);
// Compile and return the function.
ControlFlowGraph cfg = context.GetControlFlowGraph();
OperandType[] argTypes = new OperandType[] { OperandType.I64 };
return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map<DebugPartialUnmap>();
}
public static DebugThreadLocalMapGetOrReserve GenerateDebugThreadLocalMapGetOrReserve(IntPtr structPtr)
{
EmitterContext context = new EmitterContext();
var result = WindowsPartialUnmapHandler.EmitThreadLocalMapIntGetOrReserve(context, structPtr, context.LoadArgument(OperandType.I32, 0), context.LoadArgument(OperandType.I32, 1));
context.Return(result);
// Compile and return the function.
ControlFlowGraph cfg = context.GetControlFlowGraph();
OperandType[] argTypes = new OperandType[] { OperandType.I64 };
return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map<DebugThreadLocalMapGetOrReserve>();
}
public static DebugNativeWriteLoop GenerateDebugNativeWriteLoop()
{
EmitterContext context = new EmitterContext();
// Loop a write to the target address until "running" is false.
Operand structPtr = context.Copy(context.LoadArgument(OperandType.I64, 0));
Operand writePtr = context.Copy(context.LoadArgument(OperandType.I64, 1));
Operand loopLabel = Label();
context.MarkLabel(loopLabel);
context.Store(writePtr, Const(12345));
Operand running = context.Load(OperandType.I32, structPtr);
context.BranchIfTrue(loopLabel, running);
context.Return();
// Compile and return the function.
ControlFlowGraph cfg = context.GetControlFlowGraph();
OperandType[] argTypes = new OperandType[] { OperandType.I64 };
return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq).Map<DebugNativeWriteLoop>();
}
}
}

View file

@ -0,0 +1,186 @@
using ARMeilleure.IntermediateRepresentation;
using ARMeilleure.Translation;
using Ryujinx.Common.Memory.PartialUnmaps;
using System;
using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
namespace ARMeilleure.Signal
{
/// <summary>
/// Methods to handle signals caused by partial unmaps. See the structs for C# implementations of the methods.
/// </summary>
internal static class WindowsPartialUnmapHandler
{
public static Operand EmitRetryFromAccessViolation(EmitterContext context)
{
IntPtr partialRemapStatePtr = PartialUnmapState.GlobalState;
IntPtr localCountsPtr = IntPtr.Add(partialRemapStatePtr, PartialUnmapState.LocalCountsOffset);
// Get the lock first.
EmitNativeReaderLockAcquire(context, IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset));
IntPtr getCurrentThreadId = WindowsSignalHandlerRegistration.GetCurrentThreadIdFunc();
Operand threadId = context.Call(Const((ulong)getCurrentThreadId), OperandType.I32);
Operand threadIndex = EmitThreadLocalMapIntGetOrReserve(context, localCountsPtr, threadId, Const(0));
Operand endLabel = Label();
Operand retry = context.AllocateLocal(OperandType.I32);
Operand threadIndexValidLabel = Label();
context.BranchIfFalse(threadIndexValidLabel, context.ICompareEqual(threadIndex, Const(-1)));
context.Copy(retry, Const(1)); // Always retry when thread local cannot be allocated.
context.Branch(endLabel);
context.MarkLabel(threadIndexValidLabel);
Operand threadLocalPartialUnmapsPtr = EmitThreadLocalMapIntGetValuePtr(context, localCountsPtr, threadIndex);
Operand threadLocalPartialUnmaps = context.Load(OperandType.I32, threadLocalPartialUnmapsPtr);
Operand partialUnmapsCount = context.Load(OperandType.I32, Const((ulong)IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapsCountOffset)));
context.Copy(retry, context.ICompareNotEqual(threadLocalPartialUnmaps, partialUnmapsCount));
Operand noRetryLabel = Label();
context.BranchIfFalse(noRetryLabel, retry);
// if (retry) {
context.Store(threadLocalPartialUnmapsPtr, partialUnmapsCount);
context.Branch(endLabel);
context.MarkLabel(noRetryLabel);
// }
context.MarkLabel(endLabel);
// Finally, release the lock and return the retry value.
EmitNativeReaderLockRelease(context, IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset));
return retry;
}
public static Operand EmitThreadLocalMapIntGetOrReserve(EmitterContext context, IntPtr threadLocalMapPtr, Operand threadId, Operand initialState)
{
Operand idsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.ThreadIdsOffset));
Operand i = context.AllocateLocal(OperandType.I32);
context.Copy(i, Const(0));
// (Loop 1) Check all slots for a matching Thread ID (while also trying to allocate)
Operand endLabel = Label();
Operand loopLabel = Label();
context.MarkLabel(loopLabel);
Operand offset = context.Multiply(i, Const(sizeof(int)));
Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset));
// Check that this slot has the thread ID.
Operand existingId = context.CompareAndSwap(idPtr, threadId, threadId);
// If it was already the thread ID, then we just need to return i.
context.BranchIfTrue(endLabel, context.ICompareEqual(existingId, threadId));
context.Copy(i, context.Add(i, Const(1)));
context.BranchIfTrue(loopLabel, context.ICompareLess(i, Const(ThreadLocalMap<int>.MapSize)));
// (Loop 2) Try take a slot that is 0 with our Thread ID.
context.Copy(i, Const(0)); // Reset i.
Operand loop2Label = Label();
context.MarkLabel(loop2Label);
Operand offset2 = context.Multiply(i, Const(sizeof(int)));
Operand idPtr2 = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset2));
// Try and swap in the thread id on top of 0.
Operand existingId2 = context.CompareAndSwap(idPtr2, Const(0), threadId);
Operand idNot0Label = Label();
// If it was 0, then we need to initialize the struct entry and return i.
context.BranchIfFalse(idNot0Label, context.ICompareEqual(existingId2, Const(0)));
Operand structsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.StructsOffset));
Operand structPtr = context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset2));
context.Store(structPtr, initialState);
context.Branch(endLabel);
context.MarkLabel(idNot0Label);
context.Copy(i, context.Add(i, Const(1)));
context.BranchIfTrue(loop2Label, context.ICompareLess(i, Const(ThreadLocalMap<int>.MapSize)));
context.Copy(i, Const(-1)); // Could not place the thread in the list.
context.MarkLabel(endLabel);
return context.Copy(i);
}
private static Operand EmitThreadLocalMapIntGetValuePtr(EmitterContext context, IntPtr threadLocalMapPtr, Operand index)
{
Operand offset = context.Multiply(index, Const(sizeof(int)));
Operand structsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.StructsOffset));
return context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset));
}
private static void EmitThreadLocalMapIntRelease(EmitterContext context, IntPtr threadLocalMapPtr, Operand threadId, Operand index)
{
Operand offset = context.Multiply(index, Const(sizeof(int)));
Operand idsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.ThreadIdsOffset));
Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset));
context.CompareAndSwap(idPtr, threadId, Const(0));
}
private static void EmitAtomicAddI32(EmitterContext context, Operand ptr, Operand additive)
{
Operand loop = Label();
context.MarkLabel(loop);
Operand initial = context.Load(OperandType.I32, ptr);
Operand newValue = context.Add(initial, additive);
Operand replaced = context.CompareAndSwap(ptr, initial, newValue);
context.BranchIfFalse(loop, context.ICompareEqual(initial, replaced));
}
private static void EmitNativeReaderLockAcquire(EmitterContext context, IntPtr nativeReaderLockPtr)
{
Operand writeLockPtr = Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.WriteLockOffset));
// Spin until we can acquire the write lock.
Operand spinLabel = Label();
context.MarkLabel(spinLabel);
// Old value must be 0 to continue (we gained the write lock)
context.BranchIfTrue(spinLabel, context.CompareAndSwap(writeLockPtr, Const(0), Const(1)));
// Increment reader count.
EmitAtomicAddI32(context, Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(1));
// Release write lock.
context.CompareAndSwap(writeLockPtr, Const(1), Const(0));
}
private static void EmitNativeReaderLockRelease(EmitterContext context, IntPtr nativeReaderLockPtr)
{
// Decrement reader count.
EmitAtomicAddI32(context, Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(-1));
}
}
}

View file

@ -1,9 +1,10 @@
using System; using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
namespace ARMeilleure.Signal namespace ARMeilleure.Signal
{ {
class WindowsSignalHandlerRegistration unsafe class WindowsSignalHandlerRegistration
{ {
[DllImport("kernel32.dll")] [DllImport("kernel32.dll")]
private static extern IntPtr AddVectoredExceptionHandler(uint first, IntPtr handler); private static extern IntPtr AddVectoredExceptionHandler(uint first, IntPtr handler);
@ -11,6 +12,14 @@ namespace ARMeilleure.Signal
[DllImport("kernel32.dll")] [DllImport("kernel32.dll")]
private static extern ulong RemoveVectoredExceptionHandler(IntPtr handle); private static extern ulong RemoveVectoredExceptionHandler(IntPtr handle);
[DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Ansi)]
static extern IntPtr LoadLibrary([MarshalAs(UnmanagedType.LPStr)] string lpFileName);
[DllImport("kernel32.dll", CharSet = CharSet.Ansi, ExactSpelling = true, SetLastError = true)]
private static extern IntPtr GetProcAddress(IntPtr hModule, string procName);
private static IntPtr _getCurrentThreadIdPtr;
public static IntPtr RegisterExceptionHandler(IntPtr action) public static IntPtr RegisterExceptionHandler(IntPtr action)
{ {
return AddVectoredExceptionHandler(1, action); return AddVectoredExceptionHandler(1, action);
@ -20,5 +29,17 @@ namespace ARMeilleure.Signal
{ {
return RemoveVectoredExceptionHandler(handle) != 0; return RemoveVectoredExceptionHandler(handle) != 0;
} }
public static IntPtr GetCurrentThreadIdFunc()
{
if (_getCurrentThreadIdPtr == IntPtr.Zero)
{
IntPtr handle = LoadLibrary("kernel32.dll");
_getCurrentThreadIdPtr = GetProcAddress(handle, "GetCurrentThreadId");
}
return _getCurrentThreadIdPtr;
}
} }
} }

View file

@ -0,0 +1,80 @@
using System.Runtime.InteropServices;
using System.Threading;
using static Ryujinx.Common.Memory.PartialUnmaps.PartialUnmapHelpers;
namespace Ryujinx.Common.Memory.PartialUnmaps
{
/// <summary>
/// A simple implementation of a ReaderWriterLock which can be used from native code.
/// </summary>
[StructLayout(LayoutKind.Sequential, Pack = 1)]
public struct NativeReaderWriterLock
{
public int WriteLock;
public int ReaderCount;
public static int WriteLockOffset;
public static int ReaderCountOffset;
/// <summary>
/// Populates the field offsets for use when emitting native code.
/// </summary>
static NativeReaderWriterLock()
{
NativeReaderWriterLock instance = new NativeReaderWriterLock();
WriteLockOffset = OffsetOf(ref instance, ref instance.WriteLock);
ReaderCountOffset = OffsetOf(ref instance, ref instance.ReaderCount);
}
/// <summary>
/// Acquires the reader lock.
/// </summary>
public void AcquireReaderLock()
{
// Must take write lock for a very short time to become a reader.
while (Interlocked.CompareExchange(ref WriteLock, 1, 0) != 0) { }
Interlocked.Increment(ref ReaderCount);
Interlocked.Exchange(ref WriteLock, 0);
}
/// <summary>
/// Releases the reader lock.
/// </summary>
public void ReleaseReaderLock()
{
Interlocked.Decrement(ref ReaderCount);
}
/// <summary>
/// Upgrades to a writer lock. The reader lock is temporarily released while obtaining the writer lock.
/// </summary>
public void UpgradeToWriterLock()
{
// Prevent any more threads from entering reader.
// If the write lock is already taken, wait for it to not be taken.
Interlocked.Decrement(ref ReaderCount);
while (Interlocked.CompareExchange(ref WriteLock, 1, 0) != 0) { }
// Wait for reader count to drop to 0, then take the lock again as the only reader.
while (Interlocked.CompareExchange(ref ReaderCount, 1, 0) != 0) { }
}
/// <summary>
/// Downgrades from a writer lock, back to a reader one.
/// </summary>
public void DowngradeFromWriterLock()
{
// Release the WriteLock.
Interlocked.Exchange(ref WriteLock, 0);
}
}
}

View file

@ -0,0 +1,20 @@
using System.Runtime.CompilerServices;
namespace Ryujinx.Common.Memory.PartialUnmaps
{
static class PartialUnmapHelpers
{
/// <summary>
/// Calculates a byte offset of a given field within a struct.
/// </summary>
/// <typeparam name="T">Struct type</typeparam>
/// <typeparam name="T2">Field type</typeparam>
/// <param name="storage">Parent struct</param>
/// <param name="target">Field</param>
/// <returns>The byte offset of the given field in the given struct</returns>
public static int OffsetOf<T, T2>(ref T2 storage, ref T target)
{
return (int)Unsafe.ByteOffset(ref Unsafe.As<T2, T>(ref storage), ref target);
}
}
}

View file

@ -0,0 +1,160 @@
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Versioning;
using System.Threading;
using static Ryujinx.Common.Memory.PartialUnmaps.PartialUnmapHelpers;
namespace Ryujinx.Common.Memory.PartialUnmaps
{
/// <summary>
/// State for partial unmaps. Intended to be used on Windows.
/// </summary>
[StructLayout(LayoutKind.Sequential, Pack = 1)]
public struct PartialUnmapState
{
public NativeReaderWriterLock PartialUnmapLock;
public int PartialUnmapsCount;
public ThreadLocalMap<int> LocalCounts;
public readonly static int PartialUnmapLockOffset;
public readonly static int PartialUnmapsCountOffset;
public readonly static int LocalCountsOffset;
public readonly static IntPtr GlobalState;
[SupportedOSPlatform("windows")]
[DllImport("kernel32.dll")]
public static extern int GetCurrentThreadId();
[SupportedOSPlatform("windows")]
[DllImport("kernel32.dll", SetLastError = true)]
static extern IntPtr OpenThread(int dwDesiredAccess, bool bInheritHandle, uint dwThreadId);
[SupportedOSPlatform("windows")]
[DllImport("kernel32.dll", SetLastError = true)]
public static extern bool CloseHandle(IntPtr hObject);
[SupportedOSPlatform("windows")]
[DllImport("kernel32.dll", SetLastError = true)]
static extern bool GetExitCodeThread(IntPtr hThread, out uint lpExitCode);
/// <summary>
/// Creates a global static PartialUnmapState and populates the field offsets.
/// </summary>
static unsafe PartialUnmapState()
{
PartialUnmapState instance = new PartialUnmapState();
PartialUnmapLockOffset = OffsetOf(ref instance, ref instance.PartialUnmapLock);
PartialUnmapsCountOffset = OffsetOf(ref instance, ref instance.PartialUnmapsCount);
LocalCountsOffset = OffsetOf(ref instance, ref instance.LocalCounts);
int size = Unsafe.SizeOf<PartialUnmapState>();
GlobalState = Marshal.AllocHGlobal(size);
Unsafe.InitBlockUnaligned((void*)GlobalState, 0, (uint)size);
}
/// <summary>
/// Resets the global state.
/// </summary>
public static unsafe void Reset()
{
int size = Unsafe.SizeOf<PartialUnmapState>();
Unsafe.InitBlockUnaligned((void*)GlobalState, 0, (uint)size);
}
/// <summary>
/// Gets a reference to the global state.
/// </summary>
/// <returns>A reference to the global state</returns>
public static unsafe ref PartialUnmapState GetRef()
{
return ref Unsafe.AsRef<PartialUnmapState>((void*)GlobalState);
}
/// <summary>
/// Checks if an access violation handler should retry execution due to a fault caused by partial unmap.
/// </summary>
/// <remarks>
/// Due to Windows limitations, <see cref="UnmapView"/> might need to unmap more memory than requested.
/// The additional memory that was unmapped is later remapped, however this leaves a time gap where the
/// memory might be accessed but is unmapped. Users of the API must compensate for that by catching the
/// access violation and retrying if it happened between the unmap and remap operation.
/// This method can be used to decide if retrying in such cases is necessary or not.
///
/// This version of the function is not used, but serves as a reference for the native
/// implementation in ARMeilleure.
/// </remarks>
/// <returns>True if execution should be retried, false otherwise</returns>
[SupportedOSPlatform("windows")]
public bool RetryFromAccessViolation()
{
PartialUnmapLock.AcquireReaderLock();
int threadID = GetCurrentThreadId();
int threadIndex = LocalCounts.GetOrReserve(threadID, 0);
if (threadIndex == -1)
{
// Out of thread local space... try again later.
PartialUnmapLock.ReleaseReaderLock();
return true;
}
ref int threadLocalPartialUnmapsCount = ref LocalCounts.GetValue(threadIndex);
bool retry = threadLocalPartialUnmapsCount != PartialUnmapsCount;
if (retry)
{
threadLocalPartialUnmapsCount = PartialUnmapsCount;
}
PartialUnmapLock.ReleaseReaderLock();
return retry;
}
/// <summary>
/// Iterates and trims threads in the thread -> count map that
/// are no longer active.
/// </summary>
[SupportedOSPlatform("windows")]
public void TrimThreads()
{
const uint ExitCodeStillActive = 259;
const int ThreadQueryInformation = 0x40;
Span<int> ids = LocalCounts.ThreadIds.ToSpan();
for (int i = 0; i < ids.Length; i++)
{
int id = ids[i];
if (id != 0)
{
IntPtr handle = OpenThread(ThreadQueryInformation, false, (uint)id);
if (handle == IntPtr.Zero)
{
Interlocked.CompareExchange(ref ids[i], 0, id);
}
else
{
GetExitCodeThread(handle, out uint exitCode);
if (exitCode != ExitCodeStillActive)
{
Interlocked.CompareExchange(ref ids[i], 0, id);
}
CloseHandle(handle);
}
}
}
}
}
}

View file

@ -0,0 +1,92 @@
using System.Runtime.InteropServices;
using System.Threading;
using static Ryujinx.Common.Memory.PartialUnmaps.PartialUnmapHelpers;
namespace Ryujinx.Common.Memory.PartialUnmaps
{
/// <summary>
/// A simple fixed size thread safe map that can be used from native code.
/// Integer thread IDs map to corresponding structs.
/// </summary>
/// <typeparam name="T">The value type for the map</typeparam>
[StructLayout(LayoutKind.Sequential, Pack = 1)]
public struct ThreadLocalMap<T> where T : unmanaged
{
public const int MapSize = 20;
public Array20<int> ThreadIds;
public Array20<T> Structs;
public static int ThreadIdsOffset;
public static int StructsOffset;
/// <summary>
/// Populates the field offsets for use when emitting native code.
/// </summary>
static ThreadLocalMap()
{
ThreadLocalMap<T> instance = new ThreadLocalMap<T>();
ThreadIdsOffset = OffsetOf(ref instance, ref instance.ThreadIds);
StructsOffset = OffsetOf(ref instance, ref instance.Structs);
}
/// <summary>
/// Gets the index of a given thread ID in the map, or reserves one.
/// When reserving a struct, its value is set to the given initial value.
/// Returns -1 when there is no space to reserve a new entry.
/// </summary>
/// <param name="threadId">Thread ID to use as a key</param>
/// <param name="initial">Initial value of the associated struct.</param>
/// <returns>The index of the entry, or -1 if none</returns>
public int GetOrReserve(int threadId, T initial)
{
// Try get a match first.
for (int i = 0; i < MapSize; i++)
{
int compare = Interlocked.CompareExchange(ref ThreadIds[i], threadId, threadId);
if (compare == threadId)
{
return i;
}
}
// Try get a free entry. Since the id is assumed to be unique to this thread, we know it doesn't exist yet.
for (int i = 0; i < MapSize; i++)
{
int compare = Interlocked.CompareExchange(ref ThreadIds[i], threadId, 0);
if (compare == 0)
{
Structs[i] = initial;
return i;
}
}
return -1;
}
/// <summary>
/// Gets the struct value for a given map entry.
/// </summary>
/// <param name="index">Index of the entry</param>
/// <returns>A reference to the struct value</returns>
public ref T GetValue(int index)
{
return ref Structs[index];
}
/// <summary>
/// Releases an entry from the map.
/// </summary>
/// <param name="index">Index of the entry to release</param>
public void Release(int index)
{
Interlocked.Exchange(ref ThreadIds[index], 0);
}
}
}

View file

@ -89,10 +89,10 @@ namespace Ryujinx.Cpu.Jit
MemoryAllocationFlags asFlags = MemoryAllocationFlags.Reserve | MemoryAllocationFlags.ViewCompatible; MemoryAllocationFlags asFlags = MemoryAllocationFlags.Reserve | MemoryAllocationFlags.ViewCompatible;
_addressSpace = new MemoryBlock(asSize, asFlags); _addressSpace = new MemoryBlock(asSize, asFlags);
_addressSpaceMirror = new MemoryBlock(asSize, asFlags | MemoryAllocationFlags.ForceWindows4KBViewMapping); _addressSpaceMirror = new MemoryBlock(asSize, asFlags);
Tracking = new MemoryTracking(this, PageSize, invalidAccessHandler); Tracking = new MemoryTracking(this, PageSize, invalidAccessHandler);
_memoryEh = new MemoryEhMeilleure(_addressSpace, Tracking); _memoryEh = new MemoryEhMeilleure(_addressSpace, _addressSpaceMirror, Tracking);
} }
/// <summary> /// <summary>

View file

@ -6,36 +6,57 @@ using System.Runtime.InteropServices;
namespace Ryujinx.Cpu namespace Ryujinx.Cpu
{ {
class MemoryEhMeilleure : IDisposable public class MemoryEhMeilleure : IDisposable
{ {
private delegate bool TrackingEventDelegate(ulong address, ulong size, bool write, bool precise = false); private delegate bool TrackingEventDelegate(ulong address, ulong size, bool write, bool precise = false);
private readonly MemoryBlock _addressSpace;
private readonly MemoryTracking _tracking; private readonly MemoryTracking _tracking;
private readonly TrackingEventDelegate _trackingEvent; private readonly TrackingEventDelegate _trackingEvent;
private readonly ulong _baseAddress; private readonly ulong _baseAddress;
private readonly ulong _mirrorAddress;
public MemoryEhMeilleure(MemoryBlock addressSpace, MemoryTracking tracking) public MemoryEhMeilleure(MemoryBlock addressSpace, MemoryBlock addressSpaceMirror, MemoryTracking tracking)
{ {
_addressSpace = addressSpace;
_tracking = tracking; _tracking = tracking;
_baseAddress = (ulong)_addressSpace.Pointer; _baseAddress = (ulong)addressSpace.Pointer;
ulong endAddress = _baseAddress + addressSpace.Size; ulong endAddress = _baseAddress + addressSpace.Size;
_trackingEvent = new TrackingEventDelegate(tracking.VirtualMemoryEventEh); _trackingEvent = new TrackingEventDelegate(tracking.VirtualMemoryEvent);
bool added = NativeSignalHandler.AddTrackedRegion((nuint)_baseAddress, (nuint)endAddress, Marshal.GetFunctionPointerForDelegate(_trackingEvent)); bool added = NativeSignalHandler.AddTrackedRegion((nuint)_baseAddress, (nuint)endAddress, Marshal.GetFunctionPointerForDelegate(_trackingEvent));
if (!added) if (!added)
{ {
throw new InvalidOperationException("Number of allowed tracked regions exceeded."); throw new InvalidOperationException("Number of allowed tracked regions exceeded.");
} }
if (OperatingSystem.IsWindows())
{
// Add a tracking event with no signal handler for the mirror on Windows.
// The native handler has its own code to check for the partial overlap race when regions are protected by accident,
// and when there is no signal handler present.
_mirrorAddress = (ulong)addressSpaceMirror.Pointer;
ulong endAddressMirror = _mirrorAddress + addressSpace.Size;
bool addedMirror = NativeSignalHandler.AddTrackedRegion((nuint)_mirrorAddress, (nuint)endAddressMirror, IntPtr.Zero);
if (!addedMirror)
{
throw new InvalidOperationException("Number of allowed tracked regions exceeded.");
}
}
} }
public void Dispose() public void Dispose()
{ {
NativeSignalHandler.RemoveTrackedRegion((nuint)_baseAddress); NativeSignalHandler.RemoveTrackedRegion((nuint)_baseAddress);
if (_mirrorAddress != 0)
{
NativeSignalHandler.RemoveTrackedRegion((nuint)_mirrorAddress);
}
} }
} }
} }

View file

@ -4,7 +4,7 @@ using System.Collections.Generic;
namespace Ryujinx.Memory.Tests namespace Ryujinx.Memory.Tests
{ {
class MockVirtualMemoryManager : IVirtualMemoryManager public class MockVirtualMemoryManager : IVirtualMemoryManager
{ {
public bool NoMappings = false; public bool NoMappings = false;

View file

@ -38,9 +38,15 @@ namespace Ryujinx.Memory.Tests
Assert.AreEqual(Marshal.ReadInt32(_memoryBlock.Pointer, 0x2040), 0xbadc0de); Assert.AreEqual(Marshal.ReadInt32(_memoryBlock.Pointer, 0x2040), 0xbadc0de);
} }
[Test, Explicit] [Test]
public void Test_Alias() public void Test_Alias()
{ {
if (OperatingSystem.IsMacOS())
{
// Memory aliasing tests fail on CI at the moment.
return;
}
using MemoryBlock backing = new MemoryBlock(0x10000, MemoryAllocationFlags.Mirrorable); using MemoryBlock backing = new MemoryBlock(0x10000, MemoryAllocationFlags.Mirrorable);
using MemoryBlock toAlias = new MemoryBlock(0x10000, MemoryAllocationFlags.Reserve | MemoryAllocationFlags.ViewCompatible); using MemoryBlock toAlias = new MemoryBlock(0x10000, MemoryAllocationFlags.Reserve | MemoryAllocationFlags.ViewCompatible);
@ -51,9 +57,15 @@ namespace Ryujinx.Memory.Tests
Assert.AreEqual(Marshal.ReadInt32(backing.Pointer, 0x1000), 0xbadc0de); Assert.AreEqual(Marshal.ReadInt32(backing.Pointer, 0x1000), 0xbadc0de);
} }
[Test, Explicit] [Test]
public void Test_AliasRandom() public void Test_AliasRandom()
{ {
if (OperatingSystem.IsMacOS())
{
// Memory aliasing tests fail on CI at the moment.
return;
}
using MemoryBlock backing = new MemoryBlock(0x80000, MemoryAllocationFlags.Mirrorable); using MemoryBlock backing = new MemoryBlock(0x80000, MemoryAllocationFlags.Mirrorable);
using MemoryBlock toAlias = new MemoryBlock(0x80000, MemoryAllocationFlags.Reserve | MemoryAllocationFlags.ViewCompatible); using MemoryBlock toAlias = new MemoryBlock(0x80000, MemoryAllocationFlags.Reserve | MemoryAllocationFlags.ViewCompatible);

View file

@ -35,12 +35,6 @@ namespace Ryujinx.Memory
/// Indicates that the memory block should support mapping views of a mirrorable memory block. /// Indicates that the memory block should support mapping views of a mirrorable memory block.
/// The block that is to have their views mapped should be created with the <see cref="Mirrorable"/> flag. /// The block that is to have their views mapped should be created with the <see cref="Mirrorable"/> flag.
/// </summary> /// </summary>
ViewCompatible = 1 << 3, ViewCompatible = 1 << 3
/// <summary>
/// Forces views to be mapped page by page on Windows. When partial unmaps are done, this avoids the need
/// to unmap the full range and remap sub-ranges, which creates a time window with incorrectly unmapped memory.
/// </summary>
ForceWindows4KBViewMapping = 1 << 4
} }
} }

View file

@ -13,14 +13,11 @@ namespace Ryujinx.Memory
private readonly bool _usesSharedMemory; private readonly bool _usesSharedMemory;
private readonly bool _isMirror; private readonly bool _isMirror;
private readonly bool _viewCompatible; private readonly bool _viewCompatible;
private readonly bool _forceWindows4KBView;
private IntPtr _sharedMemory; private IntPtr _sharedMemory;
private IntPtr _pointer; private IntPtr _pointer;
private ConcurrentDictionary<MemoryBlock, byte> _viewStorages; private ConcurrentDictionary<MemoryBlock, byte> _viewStorages;
private int _viewCount; private int _viewCount;
internal bool ForceWindows4KBView => _forceWindows4KBView;
/// <summary> /// <summary>
/// Pointer to the memory block data. /// Pointer to the memory block data.
/// </summary> /// </summary>
@ -49,8 +46,7 @@ namespace Ryujinx.Memory
else if (flags.HasFlag(MemoryAllocationFlags.Reserve)) else if (flags.HasFlag(MemoryAllocationFlags.Reserve))
{ {
_viewCompatible = flags.HasFlag(MemoryAllocationFlags.ViewCompatible); _viewCompatible = flags.HasFlag(MemoryAllocationFlags.ViewCompatible);
_forceWindows4KBView = flags.HasFlag(MemoryAllocationFlags.ForceWindows4KBViewMapping); _pointer = MemoryManagement.Reserve(size, _viewCompatible);
_pointer = MemoryManagement.Reserve(size, _viewCompatible, _forceWindows4KBView);
} }
else else
{ {
@ -173,7 +169,7 @@ namespace Ryujinx.Memory
/// <exception cref="MemoryProtectionException">Throw when <paramref name="permission"/> is invalid</exception> /// <exception cref="MemoryProtectionException">Throw when <paramref name="permission"/> is invalid</exception>
public void Reprotect(ulong offset, ulong size, MemoryPermission permission, bool throwOnFail = true) public void Reprotect(ulong offset, ulong size, MemoryPermission permission, bool throwOnFail = true)
{ {
MemoryManagement.Reprotect(GetPointerInternal(offset, size), size, permission, _viewCompatible, _forceWindows4KBView, throwOnFail); MemoryManagement.Reprotect(GetPointerInternal(offset, size), size, permission, _viewCompatible, throwOnFail);
} }
/// <summary> /// <summary>
@ -406,7 +402,7 @@ namespace Ryujinx.Memory
} }
else else
{ {
MemoryManagement.Free(ptr, Size, _forceWindows4KBView); MemoryManagement.Free(ptr, Size);
} }
foreach (MemoryBlock viewStorage in _viewStorages.Keys) foreach (MemoryBlock viewStorage in _viewStorages.Keys)

View file

@ -20,11 +20,11 @@ namespace Ryujinx.Memory
} }
} }
public static IntPtr Reserve(ulong size, bool viewCompatible, bool force4KBMap) public static IntPtr Reserve(ulong size, bool viewCompatible)
{ {
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
return MemoryManagementWindows.Reserve((IntPtr)size, viewCompatible, force4KBMap); return MemoryManagementWindows.Reserve((IntPtr)size, viewCompatible);
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
{ {
@ -72,14 +72,7 @@ namespace Ryujinx.Memory
{ {
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
if (owner.ForceWindows4KBView) MemoryManagementWindows.MapView(sharedMemory, srcOffset, address, (IntPtr)size, owner);
{
MemoryManagementWindows.MapView4KB(sharedMemory, srcOffset, address, (IntPtr)size);
}
else
{
MemoryManagementWindows.MapView(sharedMemory, srcOffset, address, (IntPtr)size, owner);
}
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
{ {
@ -95,14 +88,7 @@ namespace Ryujinx.Memory
{ {
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
if (owner.ForceWindows4KBView) MemoryManagementWindows.UnmapView(sharedMemory, address, (IntPtr)size, owner);
{
MemoryManagementWindows.UnmapView4KB(address, (IntPtr)size);
}
else
{
MemoryManagementWindows.UnmapView(sharedMemory, address, (IntPtr)size, owner);
}
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
{ {
@ -114,20 +100,13 @@ namespace Ryujinx.Memory
} }
} }
public static void Reprotect(IntPtr address, ulong size, MemoryPermission permission, bool forView, bool force4KBMap, bool throwOnFail) public static void Reprotect(IntPtr address, ulong size, MemoryPermission permission, bool forView, bool throwOnFail)
{ {
bool result; bool result;
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
if (forView && force4KBMap) result = MemoryManagementWindows.Reprotect(address, (IntPtr)size, permission, forView);
{
result = MemoryManagementWindows.Reprotect4KB(address, (IntPtr)size, permission, forView);
}
else
{
result = MemoryManagementWindows.Reprotect(address, (IntPtr)size, permission, forView);
}
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
{ {
@ -144,11 +123,11 @@ namespace Ryujinx.Memory
} }
} }
public static bool Free(IntPtr address, ulong size, bool force4KBMap) public static bool Free(IntPtr address, ulong size)
{ {
if (OperatingSystem.IsWindows()) if (OperatingSystem.IsWindows())
{ {
return MemoryManagementWindows.Free(address, (IntPtr)size, force4KBMap); return MemoryManagementWindows.Free(address, (IntPtr)size);
} }
else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
{ {

View file

@ -10,23 +10,19 @@ namespace Ryujinx.Memory
public const int PageSize = 0x1000; public const int PageSize = 0x1000;
private static readonly PlaceholderManager _placeholders = new PlaceholderManager(); private static readonly PlaceholderManager _placeholders = new PlaceholderManager();
private static readonly PlaceholderManager4KB _placeholders4KB = new PlaceholderManager4KB();
public static IntPtr Allocate(IntPtr size) public static IntPtr Allocate(IntPtr size)
{ {
return AllocateInternal(size, AllocationType.Reserve | AllocationType.Commit); return AllocateInternal(size, AllocationType.Reserve | AllocationType.Commit);
} }
public static IntPtr Reserve(IntPtr size, bool viewCompatible, bool force4KBMap) public static IntPtr Reserve(IntPtr size, bool viewCompatible)
{ {
if (viewCompatible) if (viewCompatible)
{ {
IntPtr baseAddress = AllocateInternal2(size, AllocationType.Reserve | AllocationType.ReservePlaceholder); IntPtr baseAddress = AllocateInternal2(size, AllocationType.Reserve | AllocationType.ReservePlaceholder);
if (!force4KBMap) _placeholders.ReserveRange((ulong)baseAddress, (ulong)size);
{
_placeholders.ReserveRange((ulong)baseAddress, (ulong)size);
}
return baseAddress; return baseAddress;
} }
@ -73,49 +69,11 @@ namespace Ryujinx.Memory
_placeholders.MapView(sharedMemory, srcOffset, location, size, owner); _placeholders.MapView(sharedMemory, srcOffset, location, size, owner);
} }
public static void MapView4KB(IntPtr sharedMemory, ulong srcOffset, IntPtr location, IntPtr size)
{
_placeholders4KB.UnmapAndMarkRangeAsMapped(location, size);
ulong uaddress = (ulong)location;
ulong usize = (ulong)size;
IntPtr endLocation = (IntPtr)(uaddress + usize);
while (location != endLocation)
{
WindowsApi.VirtualFree(location, (IntPtr)PageSize, AllocationType.Release | AllocationType.PreservePlaceholder);
var ptr = WindowsApi.MapViewOfFile3(
sharedMemory,
WindowsApi.CurrentProcessHandle,
location,
srcOffset,
(IntPtr)PageSize,
0x4000,
MemoryProtection.ReadWrite,
IntPtr.Zero,
0);
if (ptr == IntPtr.Zero)
{
throw new WindowsApiException("MapViewOfFile3");
}
location += PageSize;
srcOffset += PageSize;
}
}
public static void UnmapView(IntPtr sharedMemory, IntPtr location, IntPtr size, MemoryBlock owner) public static void UnmapView(IntPtr sharedMemory, IntPtr location, IntPtr size, MemoryBlock owner)
{ {
_placeholders.UnmapView(sharedMemory, location, size, owner); _placeholders.UnmapView(sharedMemory, location, size, owner);
} }
public static void UnmapView4KB(IntPtr location, IntPtr size)
{
_placeholders4KB.UnmapView(location, size);
}
public static bool Reprotect(IntPtr address, IntPtr size, MemoryPermission permission, bool forView) public static bool Reprotect(IntPtr address, IntPtr size, MemoryPermission permission, bool forView)
{ {
if (forView) if (forView)
@ -128,34 +86,9 @@ namespace Ryujinx.Memory
} }
} }
public static bool Reprotect4KB(IntPtr address, IntPtr size, MemoryPermission permission, bool forView) public static bool Free(IntPtr address, IntPtr size)
{ {
ulong uaddress = (ulong)address; _placeholders.UnreserveRange((ulong)address, (ulong)size);
ulong usize = (ulong)size;
while (usize > 0)
{
if (!WindowsApi.VirtualProtect((IntPtr)uaddress, (IntPtr)PageSize, WindowsApi.GetProtection(permission), out _))
{
return false;
}
uaddress += PageSize;
usize -= PageSize;
}
return true;
}
public static bool Free(IntPtr address, IntPtr size, bool force4KBMap)
{
if (force4KBMap)
{
_placeholders4KB.UnmapRange(address, size);
}
else
{
_placeholders.UnreserveRange((ulong)address, (ulong)size);
}
return WindowsApi.VirtualFree(address, IntPtr.Zero, AllocationType.Release); return WindowsApi.VirtualFree(address, IntPtr.Zero, AllocationType.Release);
} }
@ -207,10 +140,5 @@ namespace Ryujinx.Memory
throw new ArgumentException("Invalid address.", nameof(address)); throw new ArgumentException("Invalid address.", nameof(address));
} }
} }
public static bool RetryFromAccessViolation()
{
return _placeholders.RetryFromAccessViolation();
}
} }
} }

View file

@ -188,30 +188,6 @@ namespace Ryujinx.Memory.Tracking
return VirtualMemoryEvent(address, 1, write); return VirtualMemoryEvent(address, 1, write);
} }
/// <summary>
/// Signal that a virtual memory event happened at the given location.
/// This is similar VirtualMemoryEvent, but on Windows, it might also return true after a partial unmap.
/// This should only be called from the exception handler.
/// </summary>
/// <param name="address">Virtual address accessed</param>
/// <param name="size">Size of the region affected in bytes</param>
/// <param name="write">Whether the region was written to or read</param>
/// <param name="precise">True if the access is precise, false otherwise</param>
/// <returns>True if the event triggered any tracking regions, false otherwise</returns>
public bool VirtualMemoryEventEh(ulong address, ulong size, bool write, bool precise = false)
{
// Windows has a limitation, it can't do partial unmaps.
// For this reason, we need to unmap the whole range and then remap the sub-ranges.
// When this happens, we might have caused a undesirable access violation from the time that the range was unmapped.
// In this case, try again as the memory might be mapped now.
if (OperatingSystem.IsWindows() && MemoryManagementWindows.RetryFromAccessViolation())
{
return true;
}
return VirtualMemoryEvent(address, size, write, precise);
}
/// <summary> /// <summary>
/// Signal that a virtual memory event happened at the given location. /// Signal that a virtual memory event happened at the given location.
/// This can be flagged as a precise event, which will avoid reprotection and call special handlers if possible. /// This can be flagged as a precise event, which will avoid reprotection and call special handlers if possible.
@ -237,10 +213,12 @@ namespace Ryujinx.Memory.Tracking
if (count == 0 && !precise) if (count == 0 && !precise)
{ {
if (_memoryManager.IsMapped(address)) if (_memoryManager.IsRangeMapped(address, size))
{ {
// TODO: There is currently the possibility that a page can be protected after its virtual region is removed.
// This code handles that case when it happens, but it would be better to find out how this happens.
_memoryManager.TrackingReprotect(address & ~(ulong)(_pageSize - 1), (ulong)_pageSize, MemoryPermission.ReadAndWrite); _memoryManager.TrackingReprotect(address & ~(ulong)(_pageSize - 1), (ulong)_pageSize, MemoryPermission.ReadAndWrite);
return false; // We can't handle this - it's probably a real invalid access. return true; // This memory _should_ be mapped, so we need to try again.
} }
else else
{ {

View file

@ -1,5 +1,7 @@
using Ryujinx.Common.Memory.PartialUnmaps;
using System; using System;
using System.Diagnostics; using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.Versioning; using System.Runtime.Versioning;
using System.Threading; using System.Threading;
@ -13,13 +15,10 @@ namespace Ryujinx.Memory.WindowsShared
{ {
private const ulong MinimumPageSize = 0x1000; private const ulong MinimumPageSize = 0x1000;
[ThreadStatic]
private static int _threadLocalPartialUnmapsCount;
private readonly IntervalTree<ulong, ulong> _mappings; private readonly IntervalTree<ulong, ulong> _mappings;
private readonly IntervalTree<ulong, MemoryPermission> _protections; private readonly IntervalTree<ulong, MemoryPermission> _protections;
private readonly ReaderWriterLock _partialUnmapLock; private readonly IntPtr _partialUnmapStatePtr;
private int _partialUnmapsCount; private readonly Thread _partialUnmapTrimThread;
/// <summary> /// <summary>
/// Creates a new instance of the Windows memory placeholder manager. /// Creates a new instance of the Windows memory placeholder manager.
@ -28,7 +27,35 @@ namespace Ryujinx.Memory.WindowsShared
{ {
_mappings = new IntervalTree<ulong, ulong>(); _mappings = new IntervalTree<ulong, ulong>();
_protections = new IntervalTree<ulong, MemoryPermission>(); _protections = new IntervalTree<ulong, MemoryPermission>();
_partialUnmapLock = new ReaderWriterLock();
_partialUnmapStatePtr = PartialUnmapState.GlobalState;
_partialUnmapTrimThread = new Thread(TrimThreadLocalMapLoop);
_partialUnmapTrimThread.Name = "CPU.PartialUnmapTrimThread";
_partialUnmapTrimThread.IsBackground = true;
_partialUnmapTrimThread.Start();
}
/// <summary>
/// Gets a reference to the partial unmap state struct.
/// </summary>
/// <returns>A reference to the partial unmap state struct</returns>
private unsafe ref PartialUnmapState GetPartialUnmapState()
{
return ref Unsafe.AsRef<PartialUnmapState>((void*)_partialUnmapStatePtr);
}
/// <summary>
/// Trims inactive threads from the partial unmap state's thread mapping every few seconds.
/// Should be run in a Background thread so that it doesn't stop the program from closing.
/// </summary>
private void TrimThreadLocalMapLoop()
{
while (true)
{
Thread.Sleep(2000);
GetPartialUnmapState().TrimThreads();
}
} }
/// <summary> /// <summary>
@ -98,7 +125,8 @@ namespace Ryujinx.Memory.WindowsShared
/// <param name="owner">Memory block that owns the mapping</param> /// <param name="owner">Memory block that owns the mapping</param>
public void MapView(IntPtr sharedMemory, ulong srcOffset, IntPtr location, IntPtr size, MemoryBlock owner) public void MapView(IntPtr sharedMemory, ulong srcOffset, IntPtr location, IntPtr size, MemoryBlock owner)
{ {
_partialUnmapLock.AcquireReaderLock(Timeout.Infinite); ref var partialUnmapLock = ref GetPartialUnmapState().PartialUnmapLock;
partialUnmapLock.AcquireReaderLock();
try try
{ {
@ -107,7 +135,7 @@ namespace Ryujinx.Memory.WindowsShared
} }
finally finally
{ {
_partialUnmapLock.ReleaseReaderLock(); partialUnmapLock.ReleaseReaderLock();
} }
} }
@ -221,7 +249,8 @@ namespace Ryujinx.Memory.WindowsShared
/// <param name="owner">Memory block that owns the mapping</param> /// <param name="owner">Memory block that owns the mapping</param>
public void UnmapView(IntPtr sharedMemory, IntPtr location, IntPtr size, MemoryBlock owner) public void UnmapView(IntPtr sharedMemory, IntPtr location, IntPtr size, MemoryBlock owner)
{ {
_partialUnmapLock.AcquireReaderLock(Timeout.Infinite); ref var partialUnmapLock = ref GetPartialUnmapState().PartialUnmapLock;
partialUnmapLock.AcquireReaderLock();
try try
{ {
@ -229,7 +258,7 @@ namespace Ryujinx.Memory.WindowsShared
} }
finally finally
{ {
_partialUnmapLock.ReleaseReaderLock(); partialUnmapLock.ReleaseReaderLock();
} }
} }
@ -265,11 +294,6 @@ namespace Ryujinx.Memory.WindowsShared
if (IsMapped(overlap.Value)) if (IsMapped(overlap.Value))
{ {
if (!WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)overlap.Start, 2))
{
throw new WindowsApiException("UnmapViewOfFile2");
}
// Tree operations might modify the node start/end values, so save a copy before we modify the tree. // Tree operations might modify the node start/end values, so save a copy before we modify the tree.
ulong overlapStart = overlap.Start; ulong overlapStart = overlap.Start;
ulong overlapEnd = overlap.End; ulong overlapEnd = overlap.End;
@ -291,30 +315,46 @@ namespace Ryujinx.Memory.WindowsShared
// This is necessary because Windows does not support partial view unmaps. // This is necessary because Windows does not support partial view unmaps.
// That is, you can only fully unmap a view that was previously mapped, you can't just unmap a chunck of it. // That is, you can only fully unmap a view that was previously mapped, you can't just unmap a chunck of it.
LockCookie lockCookie = _partialUnmapLock.UpgradeToWriterLock(Timeout.Infinite); ref var partialUnmapState = ref GetPartialUnmapState();
ref var partialUnmapLock = ref partialUnmapState.PartialUnmapLock;
partialUnmapLock.UpgradeToWriterLock();
_partialUnmapsCount++; try
if (overlapStartsBefore)
{ {
ulong remapSize = startAddress - overlapStart; partialUnmapState.PartialUnmapsCount++;
MapViewInternal(sharedMemory, overlapValue, (IntPtr)overlapStart, (IntPtr)remapSize); if (!WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)overlapStart, 2))
RestoreRangeProtection(overlapStart, remapSize); {
throw new WindowsApiException("UnmapViewOfFile2");
}
if (overlapStartsBefore)
{
ulong remapSize = startAddress - overlapStart;
MapViewInternal(sharedMemory, overlapValue, (IntPtr)overlapStart, (IntPtr)remapSize);
RestoreRangeProtection(overlapStart, remapSize);
}
if (overlapEndsAfter)
{
ulong overlappedSize = endAddress - overlapStart;
ulong remapBackingOffset = overlapValue + overlappedSize;
ulong remapAddress = overlapStart + overlappedSize;
ulong remapSize = overlapEnd - endAddress;
MapViewInternal(sharedMemory, remapBackingOffset, (IntPtr)remapAddress, (IntPtr)remapSize);
RestoreRangeProtection(remapAddress, remapSize);
}
} }
finally
if (overlapEndsAfter)
{ {
ulong overlappedSize = endAddress - overlapStart; partialUnmapLock.DowngradeFromWriterLock();
ulong remapBackingOffset = overlapValue + overlappedSize;
ulong remapAddress = overlapStart + overlappedSize;
ulong remapSize = overlapEnd - endAddress;
MapViewInternal(sharedMemory, remapBackingOffset, (IntPtr)remapAddress, (IntPtr)remapSize);
RestoreRangeProtection(remapAddress, remapSize);
} }
}
_partialUnmapLock.DowngradeFromWriterLock(ref lockCookie); else if (!WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)overlapStart, 2))
{
throw new WindowsApiException("UnmapViewOfFile2");
} }
} }
} }
@ -394,7 +434,8 @@ namespace Ryujinx.Memory.WindowsShared
/// <returns>True if the reprotection was successful, false otherwise</returns> /// <returns>True if the reprotection was successful, false otherwise</returns>
public bool ReprotectView(IntPtr address, IntPtr size, MemoryPermission permission) public bool ReprotectView(IntPtr address, IntPtr size, MemoryPermission permission)
{ {
_partialUnmapLock.AcquireReaderLock(Timeout.Infinite); ref var partialUnmapLock = ref GetPartialUnmapState().PartialUnmapLock;
partialUnmapLock.AcquireReaderLock();
try try
{ {
@ -402,7 +443,7 @@ namespace Ryujinx.Memory.WindowsShared
} }
finally finally
{ {
_partialUnmapLock.ReleaseReaderLock(); partialUnmapLock.ReleaseReaderLock();
} }
} }
@ -659,31 +700,5 @@ namespace Ryujinx.Memory.WindowsShared
ReprotectViewInternal((IntPtr)protAddress, (IntPtr)(protEndAddress - protAddress), protection.Value, true); ReprotectViewInternal((IntPtr)protAddress, (IntPtr)(protEndAddress - protAddress), protection.Value, true);
} }
} }
/// <summary>
/// Checks if an access violation handler should retry execution due to a fault caused by partial unmap.
/// </summary>
/// <remarks>
/// Due to Windows limitations, <see cref="UnmapView"/> might need to unmap more memory than requested.
/// The additional memory that was unmapped is later remapped, however this leaves a time gap where the
/// memory might be accessed but is unmapped. Users of the API must compensate for that by catching the
/// access violation and retrying if it happened between the unmap and remap operation.
/// This method can be used to decide if retrying in such cases is necessary or not.
/// </remarks>
/// <returns>True if execution should be retried, false otherwise</returns>
public bool RetryFromAccessViolation()
{
_partialUnmapLock.AcquireReaderLock(Timeout.Infinite);
bool retry = _threadLocalPartialUnmapsCount != _partialUnmapsCount;
if (retry)
{
_threadLocalPartialUnmapsCount = _partialUnmapsCount;
}
_partialUnmapLock.ReleaseReaderLock();
return retry;
}
} }
} }

View file

@ -1,170 +0,0 @@
using System;
using System.Runtime.Versioning;
namespace Ryujinx.Memory.WindowsShared
{
/// <summary>
/// Windows 4KB memory placeholder manager.
/// </summary>
[SupportedOSPlatform("windows")]
class PlaceholderManager4KB
{
private const int PageSize = MemoryManagementWindows.PageSize;
private readonly IntervalTree<ulong, byte> _mappings;
/// <summary>
/// Creates a new instance of the Windows 4KB memory placeholder manager.
/// </summary>
public PlaceholderManager4KB()
{
_mappings = new IntervalTree<ulong, byte>();
}
/// <summary>
/// Unmaps the specified range of memory and marks it as mapped internally.
/// </summary>
/// <remarks>
/// Since this marks the range as mapped, the expectation is that the range will be mapped after calling this method.
/// </remarks>
/// <param name="location">Memory address to unmap and mark as mapped</param>
/// <param name="size">Size of the range in bytes</param>
public void UnmapAndMarkRangeAsMapped(IntPtr location, IntPtr size)
{
ulong startAddress = (ulong)location;
ulong unmapSize = (ulong)size;
ulong endAddress = startAddress + unmapSize;
var overlaps = Array.Empty<IntervalTreeNode<ulong, byte>>();
int count = 0;
lock (_mappings)
{
count = _mappings.Get(startAddress, endAddress, ref overlaps);
}
for (int index = 0; index < count; index++)
{
var overlap = overlaps[index];
// Tree operations might modify the node start/end values, so save a copy before we modify the tree.
ulong overlapStart = overlap.Start;
ulong overlapEnd = overlap.End;
ulong overlapValue = overlap.Value;
_mappings.Remove(overlap);
ulong unmapStart = Math.Max(overlapStart, startAddress);
ulong unmapEnd = Math.Min(overlapEnd, endAddress);
if (overlapStart < startAddress)
{
startAddress = overlapStart;
}
if (overlapEnd > endAddress)
{
endAddress = overlapEnd;
}
ulong currentAddress = unmapStart;
while (currentAddress < unmapEnd)
{
WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)currentAddress, 2);
currentAddress += PageSize;
}
}
_mappings.Add(startAddress, endAddress, 0);
}
/// <summary>
/// Unmaps views at the specified memory range.
/// </summary>
/// <param name="location">Address of the range</param>
/// <param name="size">Size of the range in bytes</param>
public void UnmapView(IntPtr location, IntPtr size)
{
ulong startAddress = (ulong)location;
ulong unmapSize = (ulong)size;
ulong endAddress = startAddress + unmapSize;
var overlaps = Array.Empty<IntervalTreeNode<ulong, byte>>();
int count = 0;
lock (_mappings)
{
count = _mappings.Get(startAddress, endAddress, ref overlaps);
}
for (int index = 0; index < count; index++)
{
var overlap = overlaps[index];
// Tree operations might modify the node start/end values, so save a copy before we modify the tree.
ulong overlapStart = overlap.Start;
ulong overlapEnd = overlap.End;
_mappings.Remove(overlap);
if (overlapStart < startAddress)
{
_mappings.Add(overlapStart, startAddress, 0);
}
if (overlapEnd > endAddress)
{
_mappings.Add(endAddress, overlapEnd, 0);
}
ulong unmapStart = Math.Max(overlapStart, startAddress);
ulong unmapEnd = Math.Min(overlapEnd, endAddress);
ulong currentAddress = unmapStart;
while (currentAddress < unmapEnd)
{
WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)currentAddress, 2);
currentAddress += PageSize;
}
}
}
/// <summary>
/// Unmaps mapped memory at a given range.
/// </summary>
/// <param name="location">Address of the range</param>
/// <param name="size">Size of the range in bytes</param>
public void UnmapRange(IntPtr location, IntPtr size)
{
ulong startAddress = (ulong)location;
ulong unmapSize = (ulong)size;
ulong endAddress = startAddress + unmapSize;
var overlaps = Array.Empty<IntervalTreeNode<ulong, byte>>();
int count = 0;
lock (_mappings)
{
count = _mappings.Get(startAddress, endAddress, ref overlaps);
}
for (int index = 0; index < count; index++)
{
var overlap = overlaps[index];
// Tree operations might modify the node start/end values, so save a copy before we modify the tree.
ulong unmapStart = Math.Max(overlap.Start, startAddress);
ulong unmapEnd = Math.Min(overlap.End, endAddress);
_mappings.Remove(overlap);
ulong currentAddress = unmapStart;
while (currentAddress < unmapEnd)
{
WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)currentAddress, 2);
currentAddress += PageSize;
}
}
}
}
}

View file

@ -76,6 +76,9 @@ namespace Ryujinx.Memory.WindowsShared
[DllImport("kernel32.dll")] [DllImport("kernel32.dll")]
public static extern uint GetLastError(); public static extern uint GetLastError();
[DllImport("kernel32.dll")]
public static extern int GetCurrentThreadId();
public static MemoryProtection GetProtection(MemoryPermission permission) public static MemoryProtection GetProtection(MemoryPermission permission)
{ {
return permission switch return permission switch

View file

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<RunSettings>
<RunConfiguration>
<EnvironmentVariables>
<COMPlus_EnableAlternateStackCheck>1</COMPlus_EnableAlternateStackCheck>
</EnvironmentVariables>
</RunConfiguration>
</RunSettings>

View file

@ -0,0 +1,53 @@
using ARMeilleure.Memory;
using System;
namespace Ryujinx.Tests.Memory
{
internal class MockMemoryManager : IMemoryManager
{
public int AddressSpaceBits => throw new NotImplementedException();
public IntPtr PageTablePointer => throw new NotImplementedException();
public MemoryManagerType Type => MemoryManagerType.HostMappedUnsafe;
#pragma warning disable CS0067
public event Action<ulong, ulong> UnmapEvent;
#pragma warning restore CS0067
public ref T GetRef<T>(ulong va) where T : unmanaged
{
throw new NotImplementedException();
}
public ReadOnlySpan<byte> GetSpan(ulong va, int size, bool tracked = false)
{
throw new NotImplementedException();
}
public bool IsMapped(ulong va)
{
throw new NotImplementedException();
}
public T Read<T>(ulong va) where T : unmanaged
{
throw new NotImplementedException();
}
public T ReadTracked<T>(ulong va) where T : unmanaged
{
throw new NotImplementedException();
}
public void SignalMemoryTracking(ulong va, ulong size, bool write, bool precise = false)
{
throw new NotImplementedException();
}
public void Write<T>(ulong va, T value) where T : unmanaged
{
throw new NotImplementedException();
}
}
}

View file

@ -0,0 +1,484 @@
using ARMeilleure.Signal;
using ARMeilleure.Translation;
using NUnit.Framework;
using Ryujinx.Common.Memory.PartialUnmaps;
using Ryujinx.Cpu;
using Ryujinx.Cpu.Jit;
using Ryujinx.Memory;
using Ryujinx.Memory.Tests;
using Ryujinx.Memory.Tracking;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
namespace Ryujinx.Tests.Memory
{
[TestFixture]
internal class PartialUnmaps
{
private static Translator _translator;
private (MemoryBlock virt, MemoryBlock mirror, MemoryEhMeilleure exceptionHandler) GetVirtual(ulong asSize)
{
MemoryAllocationFlags asFlags = MemoryAllocationFlags.Reserve | MemoryAllocationFlags.ViewCompatible;
var addressSpace = new MemoryBlock(asSize, asFlags);
var addressSpaceMirror = new MemoryBlock(asSize, asFlags);
var tracking = new MemoryTracking(new MockVirtualMemoryManager(asSize, 0x1000), 0x1000);
var exceptionHandler = new MemoryEhMeilleure(addressSpace, addressSpaceMirror, tracking);
return (addressSpace, addressSpaceMirror, exceptionHandler);
}
private int CountThreads(ref PartialUnmapState state)
{
int count = 0;
ref var ids = ref state.LocalCounts.ThreadIds;
for (int i = 0; i < ids.Length; i++)
{
if (ids[i] != 0)
{
count++;
}
}
return count;
}
private void EnsureTranslator()
{
// Create a translator, as one is needed to register the signal handler or emit methods.
_translator ??= new Translator(new JitMemoryAllocator(), new MockMemoryManager(), true);
}
[Test]
public void PartialUnmap([Values] bool readOnly)
{
if (OperatingSystem.IsMacOS())
{
// Memory aliasing tests fail on CI at the moment.
return;
}
// Set up an address space to test partial unmapping.
// Should register the signal handler to deal with this on Windows.
ulong vaSize = 0x100000;
// The first 0x100000 is mapped to start. It is replaced from the center with the 0x200000 mapping.
var backing = new MemoryBlock(vaSize * 2, MemoryAllocationFlags.Mirrorable);
(MemoryBlock unusedMainMemory, MemoryBlock memory, MemoryEhMeilleure exceptionHandler) = GetVirtual(vaSize * 2);
EnsureTranslator();
ref var state = ref PartialUnmapState.GetRef();
try
{
// Globally reset the struct for handling partial unmap races.
PartialUnmapState.Reset();
bool shouldAccess = true;
bool error = false;
// Create a large mapping.
memory.MapView(backing, 0, 0, vaSize);
if (readOnly)
{
memory.Reprotect(0, vaSize, MemoryPermission.Read);
}
Thread testThread;
if (readOnly)
{
// Write a value to the physical memory, then try to read it repeately from virtual.
// It should not change.
testThread = new Thread(() =>
{
int i = 12345;
backing.Write(vaSize - 0x1000, i);
while (shouldAccess)
{
if (memory.Read<int>(vaSize - 0x1000) != i)
{
error = true;
shouldAccess = false;
}
}
});
}
else
{
// Repeatedly write and check the value on the last page of the mapping on another thread.
testThread = new Thread(() =>
{
int i = 0;
while (shouldAccess)
{
memory.Write(vaSize - 0x1000, i);
if (memory.Read<int>(vaSize - 0x1000) != i)
{
error = true;
shouldAccess = false;
}
i++;
}
});
}
testThread.Start();
// Create a smaller mapping, covering the larger mapping.
// Immediately try to write to the part of the larger mapping that did not change.
// Do this a lot, with the smaller mapping gradually increasing in size. Should not crash, data should not be lost.
ulong pageSize = 0x1000;
int mappingExpandCount = (int)(vaSize / (pageSize * 2)) - 1;
ulong vaCenter = vaSize / 2;
for (int i = 1; i <= mappingExpandCount; i++)
{
ulong start = vaCenter - (pageSize * (ulong)i);
ulong size = pageSize * (ulong)i * 2;
ulong startPa = start + vaSize;
memory.MapView(backing, startPa, start, size);
}
// On Windows, this should put unmap counts on the thread local map.
if (OperatingSystem.IsWindows())
{
// One thread should be present on the thread local map. Trimming should remove it.
Assert.AreEqual(1, CountThreads(ref state));
}
shouldAccess = false;
testThread.Join();
Assert.False(error);
string test = null;
try
{
test.IndexOf('1');
}
catch (NullReferenceException)
{
// This shouldn't freeze.
}
if (OperatingSystem.IsWindows())
{
state.TrimThreads();
Assert.AreEqual(0, CountThreads(ref state));
}
/*
* Use this to test invalid access. Can't put this in the test suite unfortunately as invalid access crashes the test process.
* memory.Reprotect(vaSize - 0x1000, 0x1000, MemoryPermission.None);
* //memory.UnmapView(backing, vaSize - 0x1000, 0x1000);
* memory.Read<int>(vaSize - 0x1000);
*/
}
finally
{
exceptionHandler.Dispose();
unusedMainMemory.Dispose();
memory.Dispose();
backing.Dispose();
}
}
[Test]
public unsafe void PartialUnmapNative()
{
if (OperatingSystem.IsMacOS())
{
// Memory aliasing tests fail on CI at the moment.
return;
}
// Set up an address space to test partial unmapping.
// Should register the signal handler to deal with this on Windows.
ulong vaSize = 0x100000;
// The first 0x100000 is mapped to start. It is replaced from the center with the 0x200000 mapping.
var backing = new MemoryBlock(vaSize * 2, MemoryAllocationFlags.Mirrorable);
(MemoryBlock mainMemory, MemoryBlock unusedMirror, MemoryEhMeilleure exceptionHandler) = GetVirtual(vaSize * 2);
EnsureTranslator();
ref var state = ref PartialUnmapState.GetRef();
// Create some state to be used for managing the native writing loop.
int stateSize = Unsafe.SizeOf<NativeWriteLoopState>();
var statePtr = Marshal.AllocHGlobal(stateSize);
Unsafe.InitBlockUnaligned((void*)statePtr, 0, (uint)stateSize);
ref NativeWriteLoopState writeLoopState = ref Unsafe.AsRef<NativeWriteLoopState>((void*)statePtr);
writeLoopState.Running = 1;
writeLoopState.Error = 0;
try
{
// Globally reset the struct for handling partial unmap races.
PartialUnmapState.Reset();
// Create a large mapping.
mainMemory.MapView(backing, 0, 0, vaSize);
var writeFunc = TestMethods.GenerateDebugNativeWriteLoop();
IntPtr writePtr = mainMemory.GetPointer(vaSize - 0x1000, 4);
Thread testThread = new Thread(() =>
{
writeFunc(statePtr, writePtr);
});
testThread.Start();
// Create a smaller mapping, covering the larger mapping.
// Immediately try to write to the part of the larger mapping that did not change.
// Do this a lot, with the smaller mapping gradually increasing in size. Should not crash, data should not be lost.
ulong pageSize = 0x1000;
int mappingExpandCount = (int)(vaSize / (pageSize * 2)) - 1;
ulong vaCenter = vaSize / 2;
for (int i = 1; i <= mappingExpandCount; i++)
{
ulong start = vaCenter - (pageSize * (ulong)i);
ulong size = pageSize * (ulong)i * 2;
ulong startPa = start + vaSize;
mainMemory.MapView(backing, startPa, start, size);
}
writeLoopState.Running = 0;
testThread.Join();
Assert.False(writeLoopState.Error != 0);
}
finally
{
Marshal.FreeHGlobal(statePtr);
exceptionHandler.Dispose();
mainMemory.Dispose();
unusedMirror.Dispose();
backing.Dispose();
}
}
[Test]
public void ThreadLocalMap()
{
if (!OperatingSystem.IsWindows())
{
// Only test in Windows, as this is only used on Windows and uses Windows APIs for trimming.
return;
}
PartialUnmapState.Reset();
ref var state = ref PartialUnmapState.GetRef();
bool running = true;
var testThread = new Thread(() =>
{
if (!OperatingSystem.IsWindows())
{
// Need this here to avoid a warning.
return;
}
PartialUnmapState.GetRef().RetryFromAccessViolation();
while (running)
{
Thread.Sleep(1);
}
});
testThread.Start();
Thread.Sleep(200);
Assert.AreEqual(1, CountThreads(ref state));
// Trimming should not remove the thread as it's still active.
state.TrimThreads();
Assert.AreEqual(1, CountThreads(ref state));
running = false;
testThread.Join();
// Should trim now that it's inactive.
state.TrimThreads();
Assert.AreEqual(0, CountThreads(ref state));
}
[Test]
public unsafe void ThreadLocalMapNative()
{
if (!OperatingSystem.IsWindows())
{
// Only test in Windows, as this is only used on Windows and uses Windows APIs for trimming.
return;
}
EnsureTranslator();
PartialUnmapState.Reset();
ref var state = ref PartialUnmapState.GetRef();
fixed (void* localMap = &state.LocalCounts)
{
var getOrReserve = TestMethods.GenerateDebugThreadLocalMapGetOrReserve((IntPtr)localMap);
for (int i = 0; i < ThreadLocalMap<int>.MapSize; i++)
{
// Should obtain the index matching the call #.
Assert.AreEqual(i, getOrReserve(i + 1, i));
// Check that this and all previously reserved thread IDs and struct contents are intact.
for (int j = 0; j <= i; j++)
{
Assert.AreEqual(j + 1, state.LocalCounts.ThreadIds[j]);
Assert.AreEqual(j, state.LocalCounts.Structs[j]);
}
}
// Trying to reserve again when the map is full should return -1.
Assert.AreEqual(-1, getOrReserve(200, 0));
for (int i = 0; i < ThreadLocalMap<int>.MapSize; i++)
{
// Should obtain the index matching the call #, as it already exists.
Assert.AreEqual(i, getOrReserve(i + 1, -1));
// The struct should not be reset to -1.
Assert.AreEqual(i, state.LocalCounts.Structs[i]);
}
// Clear one of the ids as if it were freed.
state.LocalCounts.ThreadIds[13] = 0;
// GetOrReserve should now obtain and return 13.
Assert.AreEqual(13, getOrReserve(300, 301));
Assert.AreEqual(300, state.LocalCounts.ThreadIds[13]);
Assert.AreEqual(301, state.LocalCounts.Structs[13]);
}
}
[Test]
public void NativeReaderWriterLock()
{
var rwLock = new NativeReaderWriterLock();
var threads = new List<Thread>();
int value = 0;
bool running = true;
bool error = false;
int readersAllowed = 1;
for (int i = 0; i < 5; i++)
{
var readThread = new Thread(() =>
{
int count = 0;
while (running)
{
rwLock.AcquireReaderLock();
int originalValue = Thread.VolatileRead(ref value);
count++;
// Spin a bit.
for (int i = 0; i < 100; i++)
{
if (Thread.VolatileRead(ref readersAllowed) == 0)
{
error = true;
running = false;
}
}
// Should not change while the lock is held.
if (Thread.VolatileRead(ref value) != originalValue)
{
error = true;
running = false;
}
rwLock.ReleaseReaderLock();
}
});
threads.Add(readThread);
}
for (int i = 0; i < 2; i++)
{
var writeThread = new Thread(() =>
{
int count = 0;
while (running)
{
rwLock.AcquireReaderLock();
rwLock.UpgradeToWriterLock();
Thread.Sleep(2);
count++;
Interlocked.Exchange(ref readersAllowed, 0);
for (int i = 0; i < 10; i++)
{
Interlocked.Increment(ref value);
}
Interlocked.Exchange(ref readersAllowed, 1);
rwLock.DowngradeFromWriterLock();
rwLock.ReleaseReaderLock();
Thread.Sleep(1);
}
});
threads.Add(writeThread);
}
foreach (var thread in threads)
{
thread.Start();
}
Thread.Sleep(1000);
running = false;
foreach (var thread in threads)
{
thread.Join();
}
Assert.False(error);
}
}
}

View file

@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk"> <Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup> <PropertyGroup>
<TargetFramework>net6.0</TargetFramework> <TargetFramework>net6.0</TargetFramework>
@ -9,10 +9,12 @@
<TargetOS Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::OSX)))' == 'true'">osx</TargetOS> <TargetOS Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::OSX)))' == 'true'">osx</TargetOS>
<TargetOS Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::Linux)))' == 'true'">linux</TargetOS> <TargetOS Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::IsOSPlatform($([System.Runtime.InteropServices.OSPlatform]::Linux)))' == 'true'">linux</TargetOS>
<Configurations>Debug;Release</Configurations> <Configurations>Debug;Release</Configurations>
<RunSettingsFilePath>$(MSBuildProjectDirectory)\.runsettings</RunSettingsFilePath>
</PropertyGroup> </PropertyGroup>
<PropertyGroup> <PropertyGroup>
<GenerateAssemblyInfo>false</GenerateAssemblyInfo> <GenerateAssemblyInfo>false</GenerateAssemblyInfo>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
@ -25,6 +27,7 @@
<ProjectReference Include="..\Ryujinx.Audio\Ryujinx.Audio.csproj" /> <ProjectReference Include="..\Ryujinx.Audio\Ryujinx.Audio.csproj" />
<ProjectReference Include="..\Ryujinx.Cpu\Ryujinx.Cpu.csproj" /> <ProjectReference Include="..\Ryujinx.Cpu\Ryujinx.Cpu.csproj" />
<ProjectReference Include="..\Ryujinx.HLE\Ryujinx.HLE.csproj" /> <ProjectReference Include="..\Ryujinx.HLE\Ryujinx.HLE.csproj" />
<ProjectReference Include="..\Ryujinx.Memory.Tests\Ryujinx.Memory.Tests.csproj" />
<ProjectReference Include="..\Ryujinx.Memory\Ryujinx.Memory.csproj" /> <ProjectReference Include="..\Ryujinx.Memory\Ryujinx.Memory.csproj" />
<ProjectReference Include="..\Ryujinx.Tests.Unicorn\Ryujinx.Tests.Unicorn.csproj" /> <ProjectReference Include="..\Ryujinx.Tests.Unicorn\Ryujinx.Tests.Unicorn.csproj" />
<ProjectReference Include="..\ARMeilleure\ARMeilleure.csproj" /> <ProjectReference Include="..\ARMeilleure\ARMeilleure.csproj" />