bsd: implement SendMMsg and RecvMMsg (#3660)
* bsd: implement sendmmsg and recvmmsg * Fix wrong increment of vlen
This commit is contained in:
parent
51bb8707ef
commit
f3835dc78b
6 changed files with 528 additions and 0 deletions
|
@ -886,6 +886,91 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
|
|||
return WriteBsdResult(context, newSockFd, errno);
|
||||
}
|
||||
|
||||
|
||||
[CommandHipc(29)] // 7.0.0+
|
||||
// RecvMMsg(u32 fd, u32 vlen, u32 flags, u32 reserved, nn::socket::TimeVal timeout) -> (i32 ret, u32 bsd_errno, buffer<bytes, 6> message);
|
||||
public ResultCode RecvMMsg(ServiceCtx context)
|
||||
{
|
||||
int socketFd = context.RequestData.ReadInt32();
|
||||
int vlen = context.RequestData.ReadInt32();
|
||||
BsdSocketFlags socketFlags = (BsdSocketFlags)context.RequestData.ReadInt32();
|
||||
uint reserved = context.RequestData.ReadUInt32();
|
||||
TimeVal timeout = context.RequestData.ReadStruct<TimeVal>();
|
||||
|
||||
ulong receivePosition = context.Request.ReceiveBuff[0].Position;
|
||||
ulong receiveLength = context.Request.ReceiveBuff[0].Size;
|
||||
|
||||
WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength);
|
||||
|
||||
LinuxError errno = LinuxError.EBADF;
|
||||
ISocket socket = _context.RetrieveSocket(socketFd);
|
||||
int result = -1;
|
||||
|
||||
if (socket != null)
|
||||
{
|
||||
errno = BsdMMsgHdr.Deserialize(out BsdMMsgHdr message, receiveRegion.Memory.Span, vlen);
|
||||
|
||||
if (errno == LinuxError.SUCCESS)
|
||||
{
|
||||
errno = socket.RecvMMsg(out result, message, socketFlags, timeout);
|
||||
|
||||
if (errno == LinuxError.SUCCESS)
|
||||
{
|
||||
errno = BsdMMsgHdr.Serialize(receiveRegion.Memory.Span, message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (errno == LinuxError.SUCCESS)
|
||||
{
|
||||
SetResultErrno(socket, result);
|
||||
receiveRegion.Dispose();
|
||||
}
|
||||
|
||||
return WriteBsdResult(context, result, errno);
|
||||
}
|
||||
|
||||
[CommandHipc(30)] // 7.0.0+
|
||||
// SendMMsg(u32 fd, u32 vlen, u32 flags) -> (i32 ret, u32 bsd_errno, buffer<bytes, 6> message);
|
||||
public ResultCode SendMMsg(ServiceCtx context)
|
||||
{
|
||||
int socketFd = context.RequestData.ReadInt32();
|
||||
int vlen = context.RequestData.ReadInt32();
|
||||
BsdSocketFlags socketFlags = (BsdSocketFlags)context.RequestData.ReadInt32();
|
||||
|
||||
ulong receivePosition = context.Request.ReceiveBuff[0].Position;
|
||||
ulong receiveLength = context.Request.ReceiveBuff[0].Size;
|
||||
|
||||
WritableRegion receiveRegion = context.Memory.GetWritableRegion(receivePosition, (int)receiveLength);
|
||||
|
||||
LinuxError errno = LinuxError.EBADF;
|
||||
ISocket socket = _context.RetrieveSocket(socketFd);
|
||||
int result = -1;
|
||||
|
||||
if (socket != null)
|
||||
{
|
||||
errno = BsdMMsgHdr.Deserialize(out BsdMMsgHdr message, receiveRegion.Memory.Span, vlen);
|
||||
|
||||
if (errno == LinuxError.SUCCESS)
|
||||
{
|
||||
errno = socket.SendMMsg(out result, message, socketFlags);
|
||||
|
||||
if (errno == LinuxError.SUCCESS)
|
||||
{
|
||||
errno = BsdMMsgHdr.Serialize(receiveRegion.Memory.Span, message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (errno == LinuxError.SUCCESS)
|
||||
{
|
||||
SetResultErrno(socket, result);
|
||||
receiveRegion.Dispose();
|
||||
}
|
||||
|
||||
return WriteBsdResult(context, result, errno);
|
||||
}
|
||||
|
||||
[CommandHipc(31)] // 7.0.0+
|
||||
// EventFd(u64 initval, nn::socket::EventFdFlags flags) -> (i32 ret, u32 bsd_errno)
|
||||
public ResultCode EventFd(ServiceCtx context)
|
||||
|
|
|
@ -25,7 +25,12 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
|
|||
|
||||
LinuxError SendTo(out int sendSize, ReadOnlySpan<byte> buffer, int size, BsdSocketFlags flags, IPEndPoint remoteEndPoint);
|
||||
|
||||
LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout);
|
||||
|
||||
LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags);
|
||||
|
||||
LinuxError GetSocketOption(BsdSocketOption option, SocketOptionLevel level, Span<byte> optionValue);
|
||||
|
||||
LinuxError SetSocketOption(BsdSocketOption option, SocketOptionLevel level, ReadOnlySpan<byte> optionValue);
|
||||
|
||||
bool Poll(int microSeconds, SelectMode mode);
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
using Ryujinx.Common.Logging;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using System.Runtime.InteropServices;
|
||||
|
@ -356,5 +358,165 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
|
|||
{
|
||||
return Send(out writeSize, buffer, BsdSocketFlags.None);
|
||||
}
|
||||
|
||||
private bool CanSupportMMsgHdr(BsdMMsgHdr message)
|
||||
{
|
||||
for (int i = 0; i < message.Messages.Length; i++)
|
||||
{
|
||||
if (message.Messages[i].Name != null ||
|
||||
message.Messages[i].Control != null)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static IList<ArraySegment<byte>> ConvertMessagesToBuffer(BsdMMsgHdr message)
|
||||
{
|
||||
int segmentCount = 0;
|
||||
int index = 0;
|
||||
|
||||
foreach (BsdMsgHdr msgHeader in message.Messages)
|
||||
{
|
||||
segmentCount += msgHeader.Iov.Length;
|
||||
}
|
||||
|
||||
ArraySegment<byte>[] buffers = new ArraySegment<byte>[segmentCount];
|
||||
|
||||
foreach (BsdMsgHdr msgHeader in message.Messages)
|
||||
{
|
||||
foreach (byte[] iov in msgHeader.Iov)
|
||||
{
|
||||
buffers[index++] = new ArraySegment<byte>(iov);
|
||||
}
|
||||
|
||||
// Clear the length
|
||||
msgHeader.Length = 0;
|
||||
}
|
||||
|
||||
return buffers;
|
||||
}
|
||||
|
||||
private static void UpdateMessages(out int vlen, BsdMMsgHdr message, int transferedSize)
|
||||
{
|
||||
int bytesLeft = transferedSize;
|
||||
int index = 0;
|
||||
|
||||
while (bytesLeft > 0)
|
||||
{
|
||||
// First ensure we haven't finished all buffers
|
||||
if (index >= message.Messages.Length)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
BsdMsgHdr msgHeader = message.Messages[index];
|
||||
|
||||
int possiblyTransferedBytes = 0;
|
||||
|
||||
foreach (byte[] iov in msgHeader.Iov)
|
||||
{
|
||||
possiblyTransferedBytes += iov.Length;
|
||||
}
|
||||
|
||||
int storedBytes;
|
||||
|
||||
if (bytesLeft > possiblyTransferedBytes)
|
||||
{
|
||||
storedBytes = possiblyTransferedBytes;
|
||||
index++;
|
||||
}
|
||||
else
|
||||
{
|
||||
storedBytes = bytesLeft;
|
||||
}
|
||||
|
||||
msgHeader.Length = (uint)storedBytes;
|
||||
bytesLeft -= storedBytes;
|
||||
}
|
||||
|
||||
Debug.Assert(bytesLeft == 0);
|
||||
|
||||
vlen = index + 1;
|
||||
}
|
||||
|
||||
// TODO: Find a way to support passing the timeout somehow without changing the socket ReceiveTimeout.
|
||||
public LinuxError RecvMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags, TimeVal timeout)
|
||||
{
|
||||
vlen = 0;
|
||||
|
||||
if (message.Messages.Length == 0)
|
||||
{
|
||||
return LinuxError.SUCCESS;
|
||||
}
|
||||
|
||||
if (!CanSupportMMsgHdr(message))
|
||||
{
|
||||
Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported BsdMMsgHdr");
|
||||
|
||||
return LinuxError.EOPNOTSUPP;
|
||||
}
|
||||
|
||||
if (message.Messages.Length == 0)
|
||||
{
|
||||
return LinuxError.SUCCESS;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
int receiveSize = Socket.Receive(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError);
|
||||
|
||||
if (receiveSize > 0)
|
||||
{
|
||||
UpdateMessages(out vlen, message, receiveSize);
|
||||
}
|
||||
|
||||
return WinSockHelper.ConvertError((WsaError)socketError);
|
||||
}
|
||||
catch (SocketException exception)
|
||||
{
|
||||
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
|
||||
}
|
||||
}
|
||||
|
||||
public LinuxError SendMMsg(out int vlen, BsdMMsgHdr message, BsdSocketFlags flags)
|
||||
{
|
||||
vlen = 0;
|
||||
|
||||
if (message.Messages.Length == 0)
|
||||
{
|
||||
return LinuxError.SUCCESS;
|
||||
}
|
||||
|
||||
if (!CanSupportMMsgHdr(message))
|
||||
{
|
||||
Logger.Warning?.Print(LogClass.ServiceBsd, $"Unsupported BsdMMsgHdr");
|
||||
|
||||
return LinuxError.EOPNOTSUPP;
|
||||
}
|
||||
|
||||
if (message.Messages.Length == 0)
|
||||
{
|
||||
return LinuxError.SUCCESS;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
int sendSize = Socket.Send(ConvertMessagesToBuffer(message), ConvertBsdSocketFlags(flags), out SocketError socketError);
|
||||
|
||||
if (sendSize > 0)
|
||||
{
|
||||
UpdateMessages(out vlen, message, sendSize);
|
||||
}
|
||||
|
||||
return WinSockHelper.ConvertError((WsaError)socketError);
|
||||
}
|
||||
catch (SocketException exception)
|
||||
{
|
||||
return WinSockHelper.ConvertError((WsaError)exception.ErrorCode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
56
Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs
Normal file
56
Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMMsgHdr.cs
Normal file
|
@ -0,0 +1,56 @@
|
|||
using System;
|
||||
|
||||
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
|
||||
{
|
||||
class BsdMMsgHdr
|
||||
{
|
||||
public BsdMsgHdr[] Messages { get; }
|
||||
|
||||
private BsdMMsgHdr(BsdMsgHdr[] messages)
|
||||
{
|
||||
Messages = messages;
|
||||
}
|
||||
|
||||
public static LinuxError Serialize(Span<byte> rawData, BsdMMsgHdr message)
|
||||
{
|
||||
rawData[0] = 0x8;
|
||||
rawData = rawData[1..];
|
||||
|
||||
for (int index = 0; index < message.Messages.Length; index++)
|
||||
{
|
||||
LinuxError res = BsdMsgHdr.Serialize(ref rawData, message.Messages[index]);
|
||||
|
||||
if (res != LinuxError.SUCCESS)
|
||||
{
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
return LinuxError.SUCCESS;
|
||||
}
|
||||
|
||||
public static LinuxError Deserialize(out BsdMMsgHdr message, ReadOnlySpan<byte> rawData, int vlen)
|
||||
{
|
||||
message = null;
|
||||
|
||||
BsdMsgHdr[] messages = new BsdMsgHdr[vlen];
|
||||
|
||||
// Skip "header" byte (Nintendo also ignore it)
|
||||
rawData = rawData[1..];
|
||||
|
||||
for (int index = 0; index < messages.Length; index++)
|
||||
{
|
||||
LinuxError res = BsdMsgHdr.Deserialize(out messages[index], ref rawData);
|
||||
|
||||
if (res != LinuxError.SUCCESS)
|
||||
{
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
message = new BsdMMsgHdr(messages);
|
||||
|
||||
return LinuxError.SUCCESS;
|
||||
}
|
||||
}
|
||||
}
|
212
Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs
Normal file
212
Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/BsdMsgHdr.cs
Normal file
|
@ -0,0 +1,212 @@
|
|||
using System;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
|
||||
{
|
||||
class BsdMsgHdr
|
||||
{
|
||||
public byte[] Name { get; }
|
||||
public byte[][] Iov { get; }
|
||||
public byte[] Control { get; }
|
||||
public BsdSocketFlags Flags { get; }
|
||||
public uint Length;
|
||||
|
||||
private BsdMsgHdr(byte[] name, byte[][] iov, byte[] control, BsdSocketFlags flags, uint length)
|
||||
{
|
||||
Name = name;
|
||||
Iov = iov;
|
||||
Control = control;
|
||||
Flags = flags;
|
||||
Length = length;
|
||||
}
|
||||
|
||||
public static LinuxError Serialize(ref Span<byte> rawData, BsdMsgHdr message)
|
||||
{
|
||||
int msgNameLength = message.Name == null ? 0 : message.Name.Length;
|
||||
int iovCount = message.Iov == null ? 0 : message.Iov.Length;
|
||||
int controlLength = message.Control == null ? 0 : message.Control.Length;
|
||||
BsdSocketFlags flags = message.Flags;
|
||||
|
||||
if (!MemoryMarshal.TryWrite(rawData, ref msgNameLength))
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
rawData = rawData[sizeof(uint)..];
|
||||
|
||||
if (msgNameLength > 0)
|
||||
{
|
||||
if (rawData.Length < msgNameLength)
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
message.Name.CopyTo(rawData);
|
||||
rawData = rawData[msgNameLength..];
|
||||
}
|
||||
|
||||
if (!MemoryMarshal.TryWrite(rawData, ref iovCount))
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
rawData = rawData[sizeof(uint)..];
|
||||
|
||||
if (iovCount > 0)
|
||||
{
|
||||
for (int index = 0; index < iovCount; index++)
|
||||
{
|
||||
ulong iovLength = (ulong)message.Iov[index].Length;
|
||||
|
||||
if (!MemoryMarshal.TryWrite(rawData, ref iovLength))
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
rawData = rawData[sizeof(ulong)..];
|
||||
|
||||
if (iovLength > 0)
|
||||
{
|
||||
if ((ulong)rawData.Length < iovLength)
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
message.Iov[index].CopyTo(rawData);
|
||||
rawData = rawData[(int)iovLength..];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!MemoryMarshal.TryWrite(rawData, ref controlLength))
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
rawData = rawData[sizeof(uint)..];
|
||||
|
||||
if (controlLength > 0)
|
||||
{
|
||||
if (rawData.Length < controlLength)
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
message.Control.CopyTo(rawData);
|
||||
rawData = rawData[controlLength..];
|
||||
}
|
||||
|
||||
if (!MemoryMarshal.TryWrite(rawData, ref flags))
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
rawData = rawData[sizeof(BsdSocketFlags)..];
|
||||
|
||||
if (!MemoryMarshal.TryWrite(rawData, ref message.Length))
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
rawData = rawData[sizeof(uint)..];
|
||||
|
||||
return LinuxError.SUCCESS;
|
||||
}
|
||||
|
||||
public static LinuxError Deserialize(out BsdMsgHdr message, ref ReadOnlySpan<byte> rawData)
|
||||
{
|
||||
byte[] name = null;
|
||||
byte[][] iov = null;
|
||||
byte[] control = null;
|
||||
|
||||
message = null;
|
||||
|
||||
if (!MemoryMarshal.TryRead(rawData, out uint msgNameLength))
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
rawData = rawData[sizeof(uint)..];
|
||||
|
||||
if (msgNameLength > 0)
|
||||
{
|
||||
if (rawData.Length < msgNameLength)
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
name = rawData[..(int)msgNameLength].ToArray();
|
||||
rawData = rawData[(int)msgNameLength..];
|
||||
}
|
||||
|
||||
if (!MemoryMarshal.TryRead(rawData, out uint iovCount))
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
rawData = rawData[sizeof(uint)..];
|
||||
|
||||
if (iovCount > 0)
|
||||
{
|
||||
iov = new byte[iovCount][];
|
||||
|
||||
for (int index = 0; index < iov.Length; index++)
|
||||
{
|
||||
if (!MemoryMarshal.TryRead(rawData, out ulong iovLength))
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
rawData = rawData[sizeof(ulong)..];
|
||||
|
||||
if (iovLength > 0)
|
||||
{
|
||||
if ((ulong)rawData.Length < iovLength)
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
iov[index] = rawData[..(int)iovLength].ToArray();
|
||||
rawData = rawData[(int)iovLength..];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!MemoryMarshal.TryRead(rawData, out uint controlLength))
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
rawData = rawData[sizeof(uint)..];
|
||||
|
||||
if (controlLength > 0)
|
||||
{
|
||||
if (rawData.Length < controlLength)
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
control = rawData[..(int)controlLength].ToArray();
|
||||
rawData = rawData[(int)controlLength..];
|
||||
}
|
||||
|
||||
if (!MemoryMarshal.TryRead(rawData, out BsdSocketFlags flags))
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
rawData = rawData[sizeof(BsdSocketFlags)..];
|
||||
|
||||
if (!MemoryMarshal.TryRead(rawData, out uint length))
|
||||
{
|
||||
return LinuxError.EFAULT;
|
||||
}
|
||||
|
||||
rawData = rawData[sizeof(uint)..];
|
||||
|
||||
message = new BsdMsgHdr(name, iov, control, flags, length);
|
||||
|
||||
return LinuxError.SUCCESS;
|
||||
}
|
||||
}
|
||||
}
|
8
Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs
Normal file
8
Ryujinx.HLE/HOS/Services/Sockets/Bsd/Types/TimeVal.cs
Normal file
|
@ -0,0 +1,8 @@
|
|||
namespace Ryujinx.HLE.HOS.Services.Sockets.Bsd
|
||||
{
|
||||
public struct TimeVal
|
||||
{
|
||||
public ulong TvSec;
|
||||
public ulong TvUsec;
|
||||
}
|
||||
}
|
Reference in a new issue