From d300a5a45b73c7d1fdb631a9d93db66c70903486 Mon Sep 17 00:00:00 2001 From: Mary Date: Wed, 12 Jan 2022 17:43:00 +0100 Subject: [PATCH] sfdnsres: Fix serialization issues (#2992) * sfdnsres: Fix serialization issues Fix a crash on Monster Hunter Rise * Address gdkchan's comments * Address gdkchan's comments --- .../Services/Sockets/Sfdnsres/IResolver.cs | 85 +++++------ .../Sockets/Sfdnsres/Types/AddrInfo4.cs | 31 +++- .../Sfdnsres/Types/AddrInfoSerialized.cs | 136 ++++++++++++++++++ .../Types/AddrInfoSerializedHeader.cs | 34 ++++- Ryujinx.HLE/Utilities/StringUtils.cs | 12 ++ 5 files changed, 244 insertions(+), 54 deletions(-) create mode 100644 Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/Types/AddrInfoSerialized.cs diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/IResolver.cs b/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/IResolver.cs index a07fc518..971c5d65 100644 --- a/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/IResolver.cs +++ b/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/IResolver.cs @@ -1,4 +1,5 @@ using Ryujinx.Common.Logging; +using Ryujinx.Common.Memory; using Ryujinx.Cpu; using Ryujinx.HLE.HOS.Services.Sockets.Nsd.Manager; using Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Proxy; @@ -11,6 +12,7 @@ using System.Linq; using System.Net; using System.Net.Sockets; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Text; namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres @@ -268,7 +270,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres NetDbError netDbErrorCode = NetDbError.Success; GaiError errno = GaiError.Overflow; - ulong serializedSize = 0; + int serializedSize = 0; if (host.Length <= byte.MaxValue) { @@ -368,7 +370,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres NetDbError netDbErrorCode = NetDbError.Success; GaiError errno = GaiError.AddressFamily; - ulong serializedSize = 0; + int serializedSize = 0; if (rawIp.Length == 4) { @@ -400,7 +402,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres return ResultCode.Success; } - private static ulong SerializeHostEntries(ServiceCtx context, ulong outputBufferPosition, ulong outputBufferSize, IPHostEntry hostEntry, IEnumerable addresses = null) + private static int SerializeHostEntries(ServiceCtx context, ulong outputBufferPosition, ulong outputBufferSize, IPHostEntry hostEntry, IEnumerable addresses = null) { ulong originalBufferPosition = outputBufferPosition; ulong bufferPosition = originalBufferPosition; @@ -443,7 +445,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres } } - return bufferPosition - originalBufferPosition; + return (int)(bufferPosition - originalBufferPosition); } private static ResultCode GetAddrInfoRequestImpl( @@ -470,7 +472,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres } // NOTE: We ignore hints for now. - DeserializeAddrInfos(context.Memory, (ulong)context.Request.SendBuff[2].Position, (ulong)context.Request.SendBuff[2].Size); + List hints = DeserializeAddrInfos(context.Memory, context.Request.SendBuff[2].Position, context.Request.SendBuff[2].Size); if (withOptions) { @@ -484,7 +486,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres NetDbError netDbErrorCode = NetDbError.Success; GaiError errno = GaiError.AddressFamily; - ulong serializedSize = 0; + int serializedSize = 0; if (host.Length <= byte.MaxValue) { @@ -538,74 +540,73 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres return ResultCode.Success; } - private static void DeserializeAddrInfos(IVirtualMemoryManager memory, ulong address, ulong size) + private static List DeserializeAddrInfos(IVirtualMemoryManager memory, ulong address, ulong size) { - ulong endAddress = address + size; + List result = new List(); - while (address < endAddress) + ReadOnlySpan data = memory.GetSpan(address, (int)size); + + while (!data.IsEmpty) { - AddrInfoSerializedHeader header = memory.Read(address); + AddrInfoSerialized info = AddrInfoSerialized.Read(data, out data); - if (header.Magic != SfdnsresContants.AddrInfoMagic) + if (info == null) { break; } - address += (ulong)Unsafe.SizeOf() + header.AddressLength; - - // ai_canonname - string canonname = MemoryHelper.ReadAsciiString(memory, address); + result.Add(info); } + + return result; } - private static ulong SerializeAddrInfos(ServiceCtx context, ulong responseBufferPosition, ulong responseBufferSize, IPHostEntry hostEntry, int port) + private static int SerializeAddrInfos(ServiceCtx context, ulong responseBufferPosition, ulong responseBufferSize, IPHostEntry hostEntry, int port) { - ulong originalBufferPosition = (ulong)responseBufferPosition; + ulong originalBufferPosition = responseBufferPosition; ulong bufferPosition = originalBufferPosition; - string hostName = hostEntry.HostName + '\0'; + byte[] hostName = Encoding.ASCII.GetBytes(hostEntry.HostName + '\0'); - for (int i = 0; i < hostEntry.AddressList.Length; i++) + using (WritableRegion region = context.Memory.GetWritableRegion(responseBufferPosition, (int)responseBufferSize)) { - IPAddress ip = hostEntry.AddressList[i]; + Span data = region.Memory.Span; - if (ip.AddressFamily != AddressFamily.InterNetwork) + for (int i = 0; i < hostEntry.AddressList.Length; i++) { - continue; + IPAddress ip = hostEntry.AddressList[i]; + + if (ip.AddressFamily != AddressFamily.InterNetwork) + { + continue; + } + + // NOTE: 0 = Any + AddrInfoSerializedHeader header = new AddrInfoSerializedHeader(ip, 0); + AddrInfo4 addr = new AddrInfo4(ip, (short)port); + AddrInfoSerialized info = new AddrInfoSerialized(header, addr, null, hostEntry.HostName); + + data = info.Write(data); } - AddrInfoSerializedHeader header = new AddrInfoSerializedHeader(ip, 0); + uint sentinel = 0; + MemoryMarshal.Write(data, ref sentinel); + data = data[sizeof(uint)..]; - // NOTE: 0 = Any - context.Memory.Write(bufferPosition, header); - bufferPosition += (ulong)Unsafe.SizeOf(); - - // addrinfo_in - context.Memory.Write(bufferPosition, new AddrInfo4(ip, (short)port)); - bufferPosition += header.AddressLength; - - // ai_canonname - context.Memory.Write(bufferPosition, Encoding.ASCII.GetBytes(hostName)); - bufferPosition += (ulong)hostName.Length; + return region.Memory.Span.Length - data.Length; } - - // Termination zero value. - context.Memory.Write(bufferPosition, 0); - bufferPosition += sizeof(int); - - return bufferPosition - originalBufferPosition; } private static void WriteResponse( ServiceCtx context, bool withOptions, - ulong serializedSize, + int serializedSize, GaiError errno, NetDbError netDbErrorCode) { if (withOptions) { - context.ResponseData.Write((int)serializedSize); + context.ResponseData.Write(serializedSize); context.ResponseData.Write((int)errno); context.ResponseData.Write((int)netDbErrorCode); context.ResponseData.Write(0); diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/Types/AddrInfo4.cs b/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/Types/AddrInfo4.cs index 0e1d3aae..e2041d2e 100644 --- a/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/Types/AddrInfo4.cs +++ b/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/Types/AddrInfo4.cs @@ -2,6 +2,7 @@ using System; using System.Net; using System.Net.Sockets; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Types @@ -16,14 +17,34 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Types public AddrInfo4(IPAddress address, short port) { - Length = 0; + Length = (byte)Unsafe.SizeOf>(); Family = (byte)AddressFamily.InterNetwork; Port = port; - Address = default; + Address = new Array4(); - Span outAddress = Address.ToSpan(); - address.TryWriteBytes(outAddress, out _); - outAddress.Reverse(); + address.TryWriteBytes(Address.ToSpan(), out _); + } + + public void ToNetworkOrder() + { + Port = IPAddress.HostToNetworkOrder(Port); + + RawIpv4AddressNetworkEndianSwap(ref Address); + } + + public void ToHostOrder() + { + Port = IPAddress.NetworkToHostOrder(Port); + + RawIpv4AddressNetworkEndianSwap(ref Address); + } + + public static void RawIpv4AddressNetworkEndianSwap(ref Array4 address) + { + if (BitConverter.IsLittleEndian) + { + address.ToSpan().Reverse(); + } } } } \ No newline at end of file diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/Types/AddrInfoSerialized.cs b/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/Types/AddrInfoSerialized.cs new file mode 100644 index 00000000..47012396 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/Types/AddrInfoSerialized.cs @@ -0,0 +1,136 @@ +using Ryujinx.Common.Memory; +using Ryujinx.HLE.Utilities; +using System; +using System.Diagnostics; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; + +namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Types +{ + class AddrInfoSerialized + { + public AddrInfoSerializedHeader Header; + public AddrInfo4? SocketAddress; + public Array4? RawIPv4Address; + public string CanonicalName; + + public AddrInfoSerialized(AddrInfoSerializedHeader header, AddrInfo4? address, Array4? rawIPv4Address, string canonicalName) + { + Header = header; + SocketAddress = address; + RawIPv4Address = rawIPv4Address; + CanonicalName = canonicalName; + } + + public static AddrInfoSerialized Read(ReadOnlySpan buffer, out ReadOnlySpan rest) + { + if (!MemoryMarshal.TryRead(buffer, out AddrInfoSerializedHeader header)) + { + rest = buffer; + + return null; + } + + AddrInfo4? socketAddress = null; + Array4? rawIPv4Address = null; + string canonicalName = null; + + buffer = buffer[Unsafe.SizeOf()..]; + + header.ToHostOrder(); + + if (header.Magic != SfdnsresContants.AddrInfoMagic) + { + rest = buffer; + + return null; + } + + Debug.Assert(header.Magic == SfdnsresContants.AddrInfoMagic); + + if (header.Family == (int)AddressFamily.InterNetwork) + { + socketAddress = MemoryMarshal.Read(buffer); + socketAddress.Value.ToHostOrder(); + + buffer = buffer[Unsafe.SizeOf()..]; + } + // AF_INET6 + else if (header.Family == 28) + { + throw new NotImplementedException(); + } + else + { + // Nintendo hardcode 4 bytes in that case here. + Array4 address = MemoryMarshal.Read>(buffer); + AddrInfo4.RawIpv4AddressNetworkEndianSwap(ref address); + + rawIPv4Address = address; + + buffer = buffer[Unsafe.SizeOf>()..]; + } + + canonicalName = StringUtils.ReadUtf8String(buffer, out int dataRead); + buffer = buffer[dataRead..]; + + rest = buffer; + + return new AddrInfoSerialized(header, socketAddress, rawIPv4Address, canonicalName); + } + + public Span Write(Span buffer) + { + int familly = Header.Family; + + Header.ToNetworkOrder(); + + MemoryMarshal.Write(buffer, ref Header); + + buffer = buffer[Unsafe.SizeOf()..]; + + if (familly == (int)AddressFamily.InterNetwork) + { + AddrInfo4 socketAddress = SocketAddress.Value; + socketAddress.ToNetworkOrder(); + + MemoryMarshal.Write(buffer, ref socketAddress); + + buffer = buffer[Unsafe.SizeOf()..]; + } + // AF_INET6 + else if (familly == 28) + { + throw new NotImplementedException(); + } + else + { + Array4 rawIPv4Address = RawIPv4Address.Value; + AddrInfo4.RawIpv4AddressNetworkEndianSwap(ref rawIPv4Address); + + MemoryMarshal.Write(buffer, ref rawIPv4Address); + + buffer = buffer[Unsafe.SizeOf>()..]; + } + + if (CanonicalName == null) + { + buffer[0] = 0; + + buffer = buffer[1..]; + } + else + { + byte[] canonicalName = Encoding.ASCII.GetBytes(CanonicalName + '\0'); + + canonicalName.CopyTo(buffer); + + buffer = buffer[canonicalName.Length..]; + } + + return buffer; + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/Types/AddrInfoSerializedHeader.cs b/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/Types/AddrInfoSerializedHeader.cs index b6251a45..8e304dfa 100644 --- a/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/Types/AddrInfoSerializedHeader.cs +++ b/Ryujinx.HLE/HOS/Services/Sockets/Sfdnsres/Types/AddrInfoSerializedHeader.cs @@ -1,4 +1,4 @@ -using System.Buffers.Binary; +using Ryujinx.Common.Memory; using System.Net; using System.Net.Sockets; using System.Runtime.CompilerServices; @@ -18,11 +18,11 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Types public AddrInfoSerializedHeader(IPAddress address, SocketType socketType) { - Magic = (uint)BinaryPrimitives.ReverseEndianness(unchecked((int)SfdnsresContants.AddrInfoMagic)); - Flags = 0; // Big Endian - Family = BinaryPrimitives.ReverseEndianness((int)address.AddressFamily); - SocketType = BinaryPrimitives.ReverseEndianness((int)socketType); - Protocol = 0; // Big Endian + Magic = SfdnsresContants.AddrInfoMagic; + Flags = 0; + Family = (int)address.AddressFamily; + SocketType = (int)socketType; + Protocol = 0; if (address.AddressFamily == AddressFamily.InterNetwork) { @@ -30,8 +30,28 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres.Types } else { - AddressLength = 4; + AddressLength = (uint)Unsafe.SizeOf>(); } } + + public void ToNetworkOrder() + { + Magic = (uint)IPAddress.HostToNetworkOrder((int)Magic); + Flags = IPAddress.HostToNetworkOrder(Flags); + Family = IPAddress.HostToNetworkOrder(Family); + SocketType = IPAddress.HostToNetworkOrder(SocketType); + Protocol = IPAddress.HostToNetworkOrder(Protocol); + AddressLength = (uint)IPAddress.HostToNetworkOrder((int)AddressLength); + } + + public void ToHostOrder() + { + Magic = (uint)IPAddress.NetworkToHostOrder((int)Magic); + Flags = IPAddress.NetworkToHostOrder(Flags); + Family = IPAddress.NetworkToHostOrder(Family); + SocketType = IPAddress.NetworkToHostOrder(SocketType); + Protocol = IPAddress.NetworkToHostOrder(Protocol); + AddressLength = (uint)IPAddress.NetworkToHostOrder((int)AddressLength); + } } } \ No newline at end of file diff --git a/Ryujinx.HLE/Utilities/StringUtils.cs b/Ryujinx.HLE/Utilities/StringUtils.cs index 2b7cbffe..3027139b 100644 --- a/Ryujinx.HLE/Utilities/StringUtils.cs +++ b/Ryujinx.HLE/Utilities/StringUtils.cs @@ -60,6 +60,18 @@ namespace Ryujinx.HLE.Utilities return output; } + public static string ReadUtf8String(ReadOnlySpan data, out int dataRead) + { + dataRead = data.IndexOf((byte)0) + 1; + + if (dataRead <= 1) + { + return string.Empty; + } + + return Encoding.UTF8.GetString(data[..dataRead]); + } + public static string ReadUtf8String(ServiceCtx context, int index = 0) { ulong position = context.Request.PtrBuff[index].Position;