diff --git a/Ryujinx.HLE/FileSystem/Content/ContentManager.cs b/Ryujinx.HLE/FileSystem/Content/ContentManager.cs index dbc18d85..9203f156 100644 --- a/Ryujinx.HLE/FileSystem/Content/ContentManager.cs +++ b/Ryujinx.HLE/FileSystem/Content/ContentManager.cs @@ -10,6 +10,7 @@ using LibHac.Tools.FsSystem.NcaUtils; using LibHac.Tools.Ncm; using Ryujinx.Common.Logging; using Ryujinx.HLE.Exceptions; +using Ryujinx.HLE.HOS.Services.Ssl; using Ryujinx.HLE.HOS.Services.Time; using Ryujinx.HLE.Utilities; using System; @@ -195,6 +196,7 @@ namespace Ryujinx.HLE.FileSystem.Content if (device != null) { TimeManager.Instance.InitializeTimeZone(device); + BuiltInCertificateManager.Instance.Initialize(device); device.System.SharedFontManager.Initialize(); } } diff --git a/Ryujinx.HLE/HOS/Services/Ssl/BuiltInCertificateManager.cs b/Ryujinx.HLE/HOS/Services/Ssl/BuiltInCertificateManager.cs new file mode 100644 index 00000000..b585224d --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Ssl/BuiltInCertificateManager.cs @@ -0,0 +1,237 @@ +using LibHac; +using LibHac.Common; +using LibHac.Fs; +using LibHac.Fs.Fsa; +using LibHac.FsSystem; +using LibHac.Tools.FsSystem; +using LibHac.Tools.FsSystem.NcaUtils; +using Ryujinx.Common.Configuration; +using Ryujinx.Common.Logging; +using Ryujinx.HLE.Exceptions; +using Ryujinx.HLE.FileSystem; +using Ryujinx.HLE.FileSystem.Content; +using Ryujinx.HLE.HOS.Services.Ssl.Types; +using System; +using System.Collections.Generic; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Ryujinx.HLE.HOS.Services.Ssl +{ + class BuiltInCertificateManager + { + private const long CertStoreTitleId = 0x0100000000000800; + + private readonly string CertStoreTitleMissingErrorMessage = "CertStore system title not found! SSL CA retrieving will not work, provide the system archive to fix this error. (See https://github.com/Ryujinx/Ryujinx/wiki/Ryujinx-Setup-&-Configuration-Guide#initial-setup-continued---installation-of-firmware for more information)"; + + private static BuiltInCertificateManager _instance; + + public static BuiltInCertificateManager Instance + { + get + { + if (_instance == null) + { + _instance = new BuiltInCertificateManager(); + } + + return _instance; + } + } + + private VirtualFileSystem _virtualFileSystem; + private IntegrityCheckLevel _fsIntegrityCheckLevel; + private ContentManager _contentManager; + private bool _initialized; + private Dictionary _certificates; + + private object _lock = new object(); + + private struct CertStoreFileHeader + { + private const uint ValidMagic = 0x546C7373; + +#pragma warning disable CS0649 + public uint Magic; + public uint EntriesCount; +#pragma warning restore CS0649 + + public bool IsValid() + { + return Magic == ValidMagic; + } + } + + private struct CertStoreFileEntry + { +#pragma warning disable CS0649 + public CaCertificateId Id; + public TrustedCertStatus Status; + public uint DataSize; + public uint DataOffset; +#pragma warning restore CS0649 + } + + public class CertStoreEntry + { + public CaCertificateId Id; + public TrustedCertStatus Status; + public byte[] Data; + } + + public string GetCertStoreTitleContentPath() + { + return _contentManager.GetInstalledContentPath(CertStoreTitleId, StorageId.NandSystem, NcaContentType.Data); + } + + public bool HasCertStoreTitle() + { + return !string.IsNullOrEmpty(GetCertStoreTitleContentPath()); + } + + private CertStoreEntry ReadCertStoreEntry(ReadOnlySpan buffer, CertStoreFileEntry entry) + { + string customCertificatePath = System.IO.Path.Join(AppDataManager.BaseDirPath, "system", "ssl", $"{entry.Id}.der"); + + byte[] data; + + if (File.Exists(customCertificatePath)) + { + data = File.ReadAllBytes(customCertificatePath); + } + else + { + data = buffer.Slice((int)entry.DataOffset, (int)entry.DataSize).ToArray(); + } + + return new CertStoreEntry + { + Id = entry.Id, + Status = entry.Status, + Data = data + }; + } + + public void Initialize(Switch device) + { + lock (_lock) + { + _certificates = new Dictionary(); + _initialized = false; + _contentManager = device.System.ContentManager; + _virtualFileSystem = device.FileSystem; + _fsIntegrityCheckLevel = device.System.FsIntegrityCheckLevel; + + if (HasCertStoreTitle()) + { + using LocalStorage ncaFile = new LocalStorage(_virtualFileSystem.SwitchPathToSystemPath(GetCertStoreTitleContentPath()), FileAccess.Read, FileMode.Open); + + Nca nca = new Nca(_virtualFileSystem.KeySet, ncaFile); + + IFileSystem romfs = nca.OpenFileSystem(NcaSectionType.Data, _fsIntegrityCheckLevel); + + using var trustedCertsFileRef = new UniqueRef(); + + Result result = romfs.OpenFile(ref trustedCertsFileRef.Ref(), "/ssl_TrustedCerts.bdf".ToU8Span(), OpenMode.Read); + + if (!result.IsSuccess()) + { + // [1.0.0 - 2.3.0] + if (ResultFs.PathNotFound.Includes(result)) + { + result = romfs.OpenFile(ref trustedCertsFileRef.Ref(), "/ssl_TrustedCerts.tcf".ToU8Span(), OpenMode.Read); + } + + if (result.IsFailure()) + { + Logger.Error?.Print(LogClass.ServiceSsl, CertStoreTitleMissingErrorMessage); + + return; + } + } + + using IFile trustedCertsFile = trustedCertsFileRef.Release(); + + trustedCertsFile.GetSize(out long fileSize).ThrowIfFailure(); + + Span trustedCertsRaw = new byte[fileSize]; + + trustedCertsFile.Read(out _, 0, trustedCertsRaw).ThrowIfFailure(); + + CertStoreFileHeader header = MemoryMarshal.Read(trustedCertsRaw); + + if (!header.IsValid()) + { + Logger.Error?.Print(LogClass.ServiceSsl, "Invalid CertStore data found, skipping!"); + + return; + } + + ReadOnlySpan trustedCertsData = trustedCertsRaw[Unsafe.SizeOf()..]; + ReadOnlySpan trustedCertsEntries = MemoryMarshal.Cast(trustedCertsData)[..(int)header.EntriesCount]; + + foreach (CertStoreFileEntry entry in trustedCertsEntries) + { + _certificates.Add(entry.Id, ReadCertStoreEntry(trustedCertsData, entry)); + } + + _initialized = true; + } + } + } + + public bool TryGetCertificates(ReadOnlySpan ids, out CertStoreEntry[] entries) + { + lock (_lock) + { + if (!_initialized) + { + throw new InvalidSystemResourceException(CertStoreTitleMissingErrorMessage); + } + + bool hasAllCertificates = false; + + foreach (CaCertificateId id in ids) + { + if (id == CaCertificateId.All) + { + hasAllCertificates = true; + + break; + } + } + + if (hasAllCertificates) + { + entries = new CertStoreEntry[_certificates.Count]; + + int i = 0; + + foreach (CertStoreEntry entry in _certificates.Values) + { + entries[i++] = entry; + } + + return true; + } + else + { + entries = new CertStoreEntry[ids.Length]; + + for (int i = 0; i < ids.Length; i++) + { + if (!_certificates.TryGetValue(ids[i], out CertStoreEntry entry)) + { + return false; + } + + entries[i] = entry; + } + + return true; + } + } + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Ssl/ISslService.cs b/Ryujinx.HLE/HOS/Services/Ssl/ISslService.cs index b7811ec1..90814cf9 100644 --- a/Ryujinx.HLE/HOS/Services/Ssl/ISslService.cs +++ b/Ryujinx.HLE/HOS/Services/Ssl/ISslService.cs @@ -1,6 +1,11 @@ using Ryujinx.Common.Logging; +using Ryujinx.HLE.Exceptions; using Ryujinx.HLE.HOS.Services.Ssl.SslService; using Ryujinx.HLE.HOS.Services.Ssl.Types; +using Ryujinx.Memory; +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; namespace Ryujinx.HLE.HOS.Services.Ssl { @@ -18,13 +23,85 @@ namespace Ryujinx.HLE.HOS.Services.Ssl SslVersion sslVersion = (SslVersion)context.RequestData.ReadUInt32(); ulong pidPlaceholder = context.RequestData.ReadUInt64(); - MakeObject(context, new ISslContext(context)); + MakeObject(context, new ISslContext(context.Request.HandleDesc.PId, sslVersion)); Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { sslVersion }); return ResultCode.Success; } + private uint ComputeCertificateBufferSizeRequired(ReadOnlySpan entries) + { + uint totalSize = 0; + + for (int i = 0; i < entries.Length; i++) + { + totalSize += (uint)Unsafe.SizeOf(); + totalSize += (uint)entries[i].Data.Length; + } + + return totalSize; + } + + [CommandHipc(2)] + // GetCertificates(buffer ids) -> (u32 certificates_count, buffer certificates) + public ResultCode GetCertificates(ServiceCtx context) + { + ReadOnlySpan ids = MemoryMarshal.Cast(context.Memory.GetSpan(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size)); + + if (!BuiltInCertificateManager.Instance.TryGetCertificates(ids, out BuiltInCertificateManager.CertStoreEntry[] entries)) + { + throw new InvalidOperationException(); + } + + if (ComputeCertificateBufferSizeRequired(entries) > context.Request.ReceiveBuff[0].Size) + { + return ResultCode.InvalidCertBufSize; + } + + using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size)) + { + Span rawData = region.Memory.Span; + Span infos = MemoryMarshal.Cast(rawData)[..entries.Length]; + Span certificatesData = rawData[(Unsafe.SizeOf() * entries.Length)..]; + + for (int i = 0; i < infos.Length; i++) + { + entries[i].Data.CopyTo(certificatesData); + + infos[i] = new BuiltInCertificateInfo + { + Id = entries[i].Id, + Status = entries[i].Status, + CertificateDataSize = (ulong)entries[i].Data.Length, + CertificateDataOffset = (ulong)(rawData.Length - certificatesData.Length) + }; + + certificatesData = certificatesData[entries[i].Data.Length..]; + } + } + + context.ResponseData.Write(entries.Length); + + return ResultCode.Success; + } + + [CommandHipc(3)] + // GetCertificateBufSize(buffer ids) -> u32 buffer_size; + public ResultCode GetCertificateBufSize(ServiceCtx context) + { + ReadOnlySpan ids = MemoryMarshal.Cast(context.Memory.GetSpan(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size)); + + if (!BuiltInCertificateManager.Instance.TryGetCertificates(ids, out BuiltInCertificateManager.CertStoreEntry[] entries)) + { + throw new InvalidOperationException(); + } + + context.ResponseData.Write(ComputeCertificateBufferSizeRequired(entries)); + + return ResultCode.Success; + } + [CommandHipc(5)] // SetInterfaceVersion(u32) public ResultCode SetInterfaceVersion(ServiceCtx context) diff --git a/Ryujinx.HLE/HOS/Services/Ssl/ResultCode.cs b/Ryujinx.HLE/HOS/Services/Ssl/ResultCode.cs new file mode 100644 index 00000000..862c79cd --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Ssl/ResultCode.cs @@ -0,0 +1,20 @@ +namespace Ryujinx.HLE.HOS.Services.Ssl +{ + public enum ResultCode + { + OsModuleId = 123, + ErrorCodeShift = 9, + + Success = 0, + NoSocket = (103 << ErrorCodeShift) | OsModuleId, + InvalidSocket = (106 << ErrorCodeShift) | OsModuleId, + InvalidCertBufSize = (112 << ErrorCodeShift) | OsModuleId, + InvalidOption = (126 << ErrorCodeShift) | OsModuleId, + CertBufferTooSmall = (202 << ErrorCodeShift) | OsModuleId, + AlreadyInUse = (203 << ErrorCodeShift) | OsModuleId, + WouldBlock = (204 << ErrorCodeShift) | OsModuleId, + Timeout = (205 << ErrorCodeShift) | OsModuleId, + ConnectionReset = (209 << ErrorCodeShift) | OsModuleId, + ConnectionAbort = (210 << ErrorCodeShift) | OsModuleId + } +} diff --git a/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs index 24f3d066..fba22f45 100644 --- a/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs +++ b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnection.cs @@ -1,41 +1,101 @@ using Ryujinx.Common.Logging; +using Ryujinx.HLE.Exceptions; +using Ryujinx.HLE.HOS.Services.Sockets.Bsd; using Ryujinx.HLE.HOS.Services.Ssl.Types; +using Ryujinx.Memory; +using System; using System.Text; namespace Ryujinx.HLE.HOS.Services.Ssl.SslService { - class ISslConnection : IpcService + class ISslConnection : IpcService, IDisposable { - public ISslConnection() { } + private bool _doNotClockSocket; + private bool _getServerCertChain; + private bool _skipDefaultVerify; + private bool _enableAlpn; + + private SslVersion _sslVersion; + private IoMode _ioMode; + private VerifyOption _verifyOption; + private SessionCacheMode _sessionCacheMode; + private string _hostName; + + private ISslConnectionBase _connection; + private BsdContext _bsdContext; + private readonly long _processId; + + private byte[] _nextAplnProto; + + public ISslConnection(long processId, SslVersion sslVersion) + { + _processId = processId; + _sslVersion = sslVersion; + _ioMode = IoMode.Blocking; + _sessionCacheMode = SessionCacheMode.None; + _verifyOption = VerifyOption.PeerCa | VerifyOption.HostName; + } [CommandHipc(0)] // SetSocketDescriptor(u32) -> u32 public ResultCode SetSocketDescriptor(ServiceCtx context) { - uint socketFd = context.RequestData.ReadUInt32(); - uint duplicateSocketFd = 0; + if (_connection != null) + { + return ResultCode.AlreadyInUse; + } - context.ResponseData.Write(duplicateSocketFd); + _bsdContext = BsdContext.GetContext(_processId); - Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { socketFd }); + if (_bsdContext == null) + { + return ResultCode.InvalidSocket; + } + + int inputFd = context.RequestData.ReadInt32(); + + int internalFd = _bsdContext.DuplicateFileDescriptor(inputFd); + + if (internalFd == -1) + { + return ResultCode.InvalidSocket; + } + + InitializeConnection(internalFd); + + int outputFd = inputFd; + + if (_doNotClockSocket) + { + outputFd = -1; + } + + context.ResponseData.Write(outputFd); return ResultCode.Success; } + private void InitializeConnection(int socketFd) + { + ISocket bsdSocket = _bsdContext.RetrieveSocket(socketFd); + + _connection = new SslManagedSocketConnection(_bsdContext, _sslVersion, socketFd, bsdSocket); + } + [CommandHipc(1)] // SetHostName(buffer) public ResultCode SetHostName(ServiceCtx context) { ulong hostNameDataPosition = context.Request.SendBuff[0].Position; - ulong hostNameDataSize = context.Request.SendBuff[0].Size; + ulong hostNameDataSize = context.Request.SendBuff[0].Size; byte[] hostNameData = new byte[hostNameDataSize]; context.Memory.Read(hostNameDataPosition, hostNameData); - string hostName = Encoding.ASCII.GetString(hostNameData).Trim('\0'); + _hostName = Encoding.ASCII.GetString(hostNameData).Trim('\0'); - Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { hostName }); + Logger.Info?.Print(LogClass.ServiceSsl, _hostName); return ResultCode.Success; } @@ -44,9 +104,9 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService // SetVerifyOption(nn::ssl::sf::VerifyOption) public ResultCode SetVerifyOption(ServiceCtx context) { - VerifyOption verifyOption = (VerifyOption)context.RequestData.ReadUInt32(); + _verifyOption = (VerifyOption)context.RequestData.ReadUInt32(); - Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { verifyOption }); + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _verifyOption }); return ResultCode.Success; } @@ -55,9 +115,67 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService // SetIoMode(nn::ssl::sf::IoMode) public ResultCode SetIoMode(ServiceCtx context) { - IoMode ioMode = (IoMode)context.RequestData.ReadUInt32(); + if (_connection == null) + { + return ResultCode.NoSocket; + } - Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { ioMode }); + _ioMode = (IoMode)context.RequestData.ReadUInt32(); + + _connection.Socket.Blocking = _ioMode == IoMode.Blocking; + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _ioMode }); + + return ResultCode.Success; + } + + [CommandHipc(4)] + // GetSocketDescriptor() -> u32 + public ResultCode GetSocketDescriptor(ServiceCtx context) + { + context.ResponseData.Write(_connection.SocketFd); + + return ResultCode.Success; + } + + [CommandHipc(5)] + // GetHostName(buffer) -> u32 + public ResultCode GetHostName(ServiceCtx context) + { + ulong hostNameDataPosition = context.Request.ReceiveBuff[0].Position; + ulong hostNameDataSize = context.Request.ReceiveBuff[0].Size; + + byte[] hostNameData = new byte[hostNameDataSize]; + + Encoding.ASCII.GetBytes(_hostName, hostNameData); + + context.Memory.Write(hostNameDataPosition, hostNameData); + + context.ResponseData.Write((uint)_hostName.Length); + + Logger.Info?.Print(LogClass.ServiceSsl, _hostName); + + return ResultCode.Success; + } + + [CommandHipc(6)] + // GetVerifyOption() -> nn::ssl::sf::VerifyOption + public ResultCode GetVerifyOption(ServiceCtx context) + { + context.ResponseData.Write((uint)_verifyOption); + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _verifyOption }); + + return ResultCode.Success; + } + + [CommandHipc(7)] + // GetIoMode() -> nn::ssl::sf::IoMode + public ResultCode GetIoMode(ServiceCtx context) + { + context.ResponseData.Write((uint)_ioMode); + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { _ioMode }); return ResultCode.Success; } @@ -66,30 +184,153 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService // DoHandshake() public ResultCode DoHandshake(ServiceCtx context) { - Logger.Stub?.PrintStub(LogClass.ServiceSsl); + if (_connection == null) + { + return ResultCode.NoSocket; + } + + return _connection.Handshake(_hostName); + } + + [CommandHipc(9)] + // DoHandshakeGetServerCert() -> (u32, u32, buffer) + public ResultCode DoHandshakeGetServerCert(ServiceCtx context) + { + if (_connection == null) + { + return ResultCode.NoSocket; + } + + ResultCode result = _connection.Handshake(_hostName); + + if (result == ResultCode.Success) + { + if (_getServerCertChain) + { + using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size)) + { + result = _connection.GetServerCertificate(_hostName, region.Memory.Span, out uint bufferSize, out uint certificateCount); + + context.ResponseData.Write(bufferSize); + context.ResponseData.Write(certificateCount); + } + } + else + { + context.ResponseData.Write(0); + context.ResponseData.Write(0); + } + } + + return result; + } + + [CommandHipc(10)] + // Read() -> (u32, buffer) + public ResultCode Read(ServiceCtx context) + { + if (_connection == null) + { + return ResultCode.NoSocket; + } + + ResultCode result; + + using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size)) + { + // TODO: Better error management. + result = _connection.Read(out int readCount, region.Memory); + + if (result == ResultCode.Success) + { + context.ResponseData.Write(readCount); + } + } + + return result; + } + + [CommandHipc(11)] + // Write(buffer) -> s32 + public ResultCode Write(ServiceCtx context) + { + if (_connection == null) + { + return ResultCode.NoSocket; + } + + // We don't dispose as this isn't supposed to be modified + WritableRegion region = context.Memory.GetWritableRegion(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size); + + // TODO: Better error management. + ResultCode result = _connection.Write(out int writtenCount, region.Memory); + + if (result == ResultCode.Success) + { + context.ResponseData.Write(writtenCount); + } + + return result; + } + + [CommandHipc(12)] + // Pending() -> s32 + public ResultCode Pending(ServiceCtx context) + { + if (_connection == null) + { + return ResultCode.NoSocket; + } + + context.ResponseData.Write(_connection.Pending()); return ResultCode.Success; } - [CommandHipc(11)] - // Write(buffer) -> u32 - public ResultCode Write(ServiceCtx context) + [CommandHipc(13)] + // Peek() -> (s32, buffer) + public ResultCode Peek(ServiceCtx context) { - ulong inputDataPosition = context.Request.SendBuff[0].Position; - ulong inputDataSize = context.Request.SendBuff[0].Size; + if (_connection == null) + { + return ResultCode.NoSocket; + } - byte[] data = new byte[inputDataSize]; + ResultCode result; - context.Memory.Read(inputDataPosition, data); + using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size)) + { + // TODO: Better error management. + result = _connection.Peek(out int peekCount, region.Memory); - // NOTE: Tell the guest everything is transferred. - uint transferredSize = (uint)inputDataSize; + if (result == ResultCode.Success) + { + context.ResponseData.Write(peekCount); + } + } - context.ResponseData.Write(transferredSize); + return result; + } - Logger.Stub?.PrintStub(LogClass.ServiceSsl); + [CommandHipc(14)] + // Poll(nn::ssl::sf::PollEvent poll_event, u32 timeout) -> nn::ssl::sf::PollEvent + public ResultCode Poll(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } - return ResultCode.Success; + [CommandHipc(15)] + // GetVerifyCertError() + public ResultCode GetVerifyCertError(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + + [CommandHipc(16)] + // GetNeededServerCertBufferSize() -> u32 + public ResultCode GetNeededServerCertBufferSize(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); } [CommandHipc(17)] @@ -100,19 +341,176 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { sessionCacheMode }); + _sessionCacheMode = sessionCacheMode; + return ResultCode.Success; } + [CommandHipc(18)] + // GetSessionCacheMode() -> nn::ssl::sf::SessionCacheMode + public ResultCode GetSessionCacheMode(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + + [CommandHipc(19)] + // FlushSessionCache() + public ResultCode FlushSessionCache(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + + [CommandHipc(20)] + // SetRenegotiationMode(nn::ssl::sf::RenegotiationMode) + public ResultCode SetRenegotiationMode(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + + [CommandHipc(21)] + // GetRenegotiationMode() -> nn::ssl::sf::RenegotiationMode + public ResultCode GetRenegotiationMode(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + [CommandHipc(22)] - // SetOption(b8, nn::ssl::sf::OptionType) + // SetOption(b8 value, nn::ssl::sf::OptionType option) public ResultCode SetOption(ServiceCtx context) { - bool optionEnabled = context.RequestData.ReadBoolean(); - OptionType optionType = (OptionType)context.RequestData.ReadUInt32(); + bool value = context.RequestData.ReadUInt32() != 0; + OptionType option = (OptionType)context.RequestData.ReadUInt32(); - Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { optionType, optionEnabled }); + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { option, value }); + + return SetOption(option, value); + } + + [CommandHipc(23)] + // GetOption(nn::ssl::sf::OptionType) -> b8 + public ResultCode GetOption(ServiceCtx context) + { + OptionType option = (OptionType)context.RequestData.ReadUInt32(); + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { option }); + + ResultCode result = GetOption(option, out bool value); + + if (result == ResultCode.Success) + { + context.ResponseData.Write(value); + } + + return result; + } + + [CommandHipc(24)] + // GetVerifyCertErrors() -> (u32, u32, buffer) + public ResultCode GetVerifyCertErrors(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + + [CommandHipc(25)] // 4.0.0+ + // GetCipherInfo(u32) -> buffer + public ResultCode GetCipherInfo(ServiceCtx context) + { + throw new ServiceNotImplementedException(this, context); + } + + [CommandHipc(26)] + // SetNextAlpnProto(buffer) -> u32 + public ResultCode SetNextAlpnProto(ServiceCtx context) + { + ulong inputDataPosition = context.Request.SendBuff[0].Position; + ulong inputDataSize = context.Request.SendBuff[0].Size; + + _nextAplnProto = new byte[inputDataSize]; + + context.Memory.Read(inputDataPosition, _nextAplnProto); + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { inputDataSize }); return ResultCode.Success; } + + [CommandHipc(27)] + // GetNextAlpnProto(buffer) -> u32 + public ResultCode GetNextAlpnProto(ServiceCtx context) + { + ulong outputDataPosition = context.Request.ReceiveBuff[0].Position; + ulong outputDataSize = context.Request.ReceiveBuff[0].Size; + + context.Memory.Write(outputDataPosition, _nextAplnProto); + + context.ResponseData.Write(_nextAplnProto.Length); + + Logger.Stub?.PrintStub(LogClass.ServiceSsl, new { outputDataSize }); + + return ResultCode.Success; + } + + private ResultCode SetOption(OptionType option, bool value) + { + switch (option) + { + case OptionType.DoNotCloseSocket: + _doNotClockSocket = value; + break; + + case OptionType.GetServerCertChain: + _getServerCertChain = value; + break; + + case OptionType.SkipDefaultVerify: + _skipDefaultVerify = value; + break; + + case OptionType.EnableAlpn: + _enableAlpn = value; + break; + + default: + Logger.Warning?.Print(LogClass.ServiceSsl, $"Unsupported option {option}"); + return ResultCode.InvalidOption; + } + + return ResultCode.Success; + } + + private ResultCode GetOption(OptionType option, out bool value) + { + switch (option) + { + case OptionType.DoNotCloseSocket: + value = _doNotClockSocket; + break; + + case OptionType.GetServerCertChain: + value = _getServerCertChain; + break; + + case OptionType.SkipDefaultVerify: + value = _skipDefaultVerify; + break; + + case OptionType.EnableAlpn: + value = _enableAlpn; + break; + + default: + Logger.Warning?.Print(LogClass.ServiceSsl, $"Unsupported option {option}"); + + value = false; + return ResultCode.InvalidOption; + } + + return ResultCode.Success; + } + + public void Dispose() + { + _connection?.Dispose(); + } } } \ No newline at end of file diff --git a/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnectionBase.cs b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnectionBase.cs new file mode 100644 index 00000000..74e5fcda --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslConnectionBase.cs @@ -0,0 +1,25 @@ +using Ryujinx.HLE.HOS.Services.Sockets.Bsd; +using System; +using System.Net.Sockets; + +namespace Ryujinx.HLE.HOS.Services.Ssl.SslService +{ + interface ISslConnectionBase: IDisposable + { + int SocketFd { get; } + + ISocket Socket { get; } + + ResultCode Handshake(string hostName); + + ResultCode GetServerCertificate(string hostname, Span certificates, out uint storageSize, out uint certificateCount); + + ResultCode Write(out int writtenCount, ReadOnlyMemory buffer); + + ResultCode Read(out int readCount, Memory buffer); + + ResultCode Peek(out int peekCount, Memory buffer); + + int Pending(); + } +} diff --git a/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslContext.cs b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslContext.cs index 718af2cb..0b8cb463 100644 --- a/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslContext.cs +++ b/Ryujinx.HLE/HOS/Services/Ssl/SslService/ISslContext.cs @@ -1,4 +1,5 @@ using Ryujinx.Common.Logging; +using Ryujinx.HLE.HOS.Services.Sockets.Bsd; using Ryujinx.HLE.HOS.Services.Ssl.Types; using System.Text; @@ -8,16 +9,22 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.SslService { private uint _connectionCount; + private readonly long _processId; + private readonly SslVersion _sslVersion; private ulong _serverCertificateId; private ulong _clientCertificateId; - public ISslContext(ServiceCtx context) { } + public ISslContext(long processId, SslVersion sslVersion) + { + _processId = processId; + _sslVersion = sslVersion; + } [CommandHipc(2)] // CreateConnection() -> object public ResultCode CreateConnection(ServiceCtx context) { - MakeObject(context, new ISslConnection()); + MakeObject(context, new ISslConnection(_processId, _sslVersion)); _connectionCount++; diff --git a/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs b/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs new file mode 100644 index 00000000..36c8b51a --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Ssl/SslService/SslManagedSocketConnection.cs @@ -0,0 +1,247 @@ +using Ryujinx.HLE.HOS.Services.Sockets.Bsd; +using Ryujinx.HLE.HOS.Services.Ssl.Types; +using System; +using System.IO; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Authentication; + +namespace Ryujinx.HLE.HOS.Services.Ssl.SslService +{ + class SslManagedSocketConnection : ISslConnectionBase + { + public int SocketFd { get; } + + public ISocket Socket { get; } + + private BsdContext _bsdContext; + private SslVersion _sslVersion; + private SslStream _stream; + private bool _isBlockingSocket; + private int _previousReadTimeout; + + public SslManagedSocketConnection(BsdContext bsdContext, SslVersion sslVersion, int socketFd, ISocket socket) + { + _bsdContext = bsdContext; + _sslVersion = sslVersion; + + SocketFd = socketFd; + Socket = socket; + } + + private void StartSslOperation() + { + // Save blocking state + _isBlockingSocket = Socket.Blocking; + + // Force blocking for SslStream + Socket.Blocking = true; + } + + private void EndSslOperation() + { + // Restore blocking state + Socket.Blocking = _isBlockingSocket; + } + + private void StartSslReadOperation() + { + StartSslOperation(); + + if (!_isBlockingSocket) + { + _previousReadTimeout = _stream.ReadTimeout; + + _stream.ReadTimeout = 1; + } + } + + private void EndSslReadOperation() + { + if (!_isBlockingSocket) + { + _stream.ReadTimeout = _previousReadTimeout; + } + + EndSslOperation(); + } + + private static SslProtocols TranslateSslVersion(SslVersion version) + { + switch (version & SslVersion.VersionMask) + { + case SslVersion.Auto: + return SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13; + case SslVersion.TlsV10: + return SslProtocols.Tls; + case SslVersion.TlsV11: + return SslProtocols.Tls11; + case SslVersion.TlsV12: + return SslProtocols.Tls12; + case SslVersion.TlsV13: + return SslProtocols.Tls13; + default: + throw new NotImplementedException(version.ToString()); + } + } + + public ResultCode Handshake(string hostName) + { + StartSslOperation(); + _stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null); + _stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false); + EndSslOperation(); + + return ResultCode.Success; + } + + public ResultCode Peek(out int peekCount, Memory buffer) + { + // NOTE: We cannot support that on .NET SSL API. + // As Nintendo's curl implementation detail check if a connection is alive via Peek, we just return that it would block to let it know that it's alive. + peekCount = -1; + + return ResultCode.WouldBlock; + } + + public int Pending() + { + // Unsupported + return 0; + } + + private static bool TryTranslateWinSockError(bool isBlocking, WsaError error, out ResultCode resultCode) + { + switch (error) + { + case WsaError.WSAETIMEDOUT: + resultCode = isBlocking ? ResultCode.Timeout : ResultCode.WouldBlock; + return true; + case WsaError.WSAECONNABORTED: + resultCode = ResultCode.ConnectionAbort; + return true; + case WsaError.WSAECONNRESET: + resultCode = ResultCode.ConnectionReset; + return true; + default: + resultCode = ResultCode.Success; + return false; + } + } + + public ResultCode Read(out int readCount, Memory buffer) + { + if (!Socket.Poll(0, SelectMode.SelectRead)) + { + readCount = -1; + + return ResultCode.WouldBlock; + } + + StartSslReadOperation(); + + try + { + readCount = _stream.Read(buffer.Span); + } + catch (IOException exception) + { + readCount = -1; + + if (exception.InnerException is SocketException socketException) + { + WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode; + + if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result)) + { + return result; + } + else + { + throw socketException; + } + } + else + { + throw exception; + } + } + finally + { + EndSslReadOperation(); + } + + return ResultCode.Success; + } + + public ResultCode Write(out int writtenCount, ReadOnlyMemory buffer) + { + if (!Socket.Poll(0, SelectMode.SelectWrite)) + { + writtenCount = 0; + + return ResultCode.WouldBlock; + } + + StartSslOperation(); + + try + { + _stream.Write(buffer.Span); + } + catch (IOException exception) + { + writtenCount = -1; + + if (exception.InnerException is SocketException socketException) + { + WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode; + + if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result)) + { + return result; + } + else + { + throw socketException; + } + } + else + { + throw exception; + } + } + finally + { + EndSslOperation(); + } + + // .NET API doesn't provide the size written, assume all written. + writtenCount = buffer.Length; + + return ResultCode.Success; + } + + public ResultCode GetServerCertificate(string hostname, Span certificates, out uint storageSize, out uint certificateCount) + { + byte[] rawCertData = _stream.RemoteCertificate.GetRawCertData(); + + storageSize = (uint)rawCertData.Length; + certificateCount = 1; + + if (rawCertData.Length > certificates.Length) + { + return ResultCode.CertBufferTooSmall; + } + + rawCertData.CopyTo(certificates); + + return ResultCode.Success; + } + + public void Dispose() + { + _bsdContext.CloseFileDescriptor(SocketFd); + } + } +} diff --git a/Ryujinx.HLE/HOS/Services/Ssl/Types/BuiltInCertificateInfo.cs b/Ryujinx.HLE/HOS/Services/Ssl/Types/BuiltInCertificateInfo.cs new file mode 100644 index 00000000..313220e1 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Ssl/Types/BuiltInCertificateInfo.cs @@ -0,0 +1,10 @@ +namespace Ryujinx.HLE.HOS.Services.Ssl.Types +{ + struct BuiltInCertificateInfo + { + public CaCertificateId Id; + public TrustedCertStatus Status; + public ulong CertificateDataSize; + public ulong CertificateDataOffset; + } +} diff --git a/Ryujinx.HLE/HOS/Services/Ssl/Types/CaCertificateId.cs b/Ryujinx.HLE/HOS/Services/Ssl/Types/CaCertificateId.cs new file mode 100644 index 00000000..5c84579a --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Ssl/Types/CaCertificateId.cs @@ -0,0 +1,68 @@ +namespace Ryujinx.HLE.HOS.Services.Ssl.Types +{ + enum CaCertificateId : uint + { + // Nintendo CAs + NintendoCAG3 = 1, + NintendoClass2CAG3, + + // External CAs + AmazonRootCA1 = 1000, + StarfieldServicesRootCertificateAuthorityG2, + AddTrustExternalCARoot, + COMODOCertificationAuthority, + UTNDATACorpSGC, + UTNUSERFirstHardware, + BaltimoreCyberTrustRoot, + CybertrustGlobalRoot, + VerizonGlobalRootCA, + DigiCertAssuredIDRootCA, + DigiCertAssuredIDRootG2, + DigiCertGlobalRootCA, + DigiCertGlobalRootG2, + DigiCertHighAssuranceEVRootCA, + EntrustnetCertificationAuthority2048, + EntrustRootCertificationAuthority, + EntrustRootCertificationAuthorityG2, + GeoTrustGlobalCA2, + GeoTrustGlobalCA, + GeoTrustPrimaryCertificationAuthorityG3, + GeoTrustPrimaryCertificationAuthority, + GlobalSignRootCA, + GlobalSignRootCAR2, + GlobalSignRootCAR3, + GoDaddyClass2CertificationAuthority, + GoDaddyRootCertificateAuthorityG2, + StarfieldClass2CertificationAuthority, + StarfieldRootCertificateAuthorityG2, + ThawtePrimaryRootCAG3, + ThawtePrimaryRootCA, + VeriSignClass3PublicPrimaryCertificationAuthorityG3, + VeriSignClass3PublicPrimaryCertificationAuthorityG5, + VeriSignUniversalRootCertificationAuthority, + DSTRootCAX3, + USERTrustRSACertificationAuthority, + ISRGRootX10, + USERTrustECCCertificationAuthority, + COMODORSACertificationAuthority, + COMODOECCCertificationAuthority, + AmazonRootCA2, + AmazonRootCA3, + AmazonRootCA4, + DigiCertAssuredIDRootG3, + DigiCertGlobalRootG3, + DigiCertTrustedRootG4, + EntrustRootCertificationAuthorityEC1, + EntrustRootCertificationAuthorityG4, + GlobalSignECCRootCAR4, + GlobalSignECCRootCAR5, + GlobalSignECCRootCAR6, + GTSRootR1, + GTSRootR2, + GTSRootR3, + GTSRootR4, + SecurityCommunicationRootCA, + + All = uint.MaxValue + } +} diff --git a/Ryujinx.HLE/HOS/Services/Ssl/Types/SslVersion.cs b/Ryujinx.HLE/HOS/Services/Ssl/Types/SslVersion.cs index a8897802..7110fd85 100644 --- a/Ryujinx.HLE/HOS/Services/Ssl/Types/SslVersion.cs +++ b/Ryujinx.HLE/HOS/Services/Ssl/Types/SslVersion.cs @@ -10,6 +10,7 @@ namespace Ryujinx.HLE.HOS.Services.Ssl.Types TlsV11 = 1 << 4, TlsV12 = 1 << 5, TlsV13 = 1 << 6, // 11.0.0+ - Auto2 = 1 << 24 // 11.0.0+ + + VersionMask = 0xFFFFFF } } \ No newline at end of file diff --git a/Ryujinx.HLE/HOS/Services/Ssl/Types/TrustedCertStatus.cs b/Ryujinx.HLE/HOS/Services/Ssl/Types/TrustedCertStatus.cs new file mode 100644 index 00000000..7fd5efd6 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Ssl/Types/TrustedCertStatus.cs @@ -0,0 +1,12 @@ +namespace Ryujinx.HLE.HOS.Services.Ssl.Types +{ + enum TrustedCertStatus : uint + { + Removed, + EnabledTrusted, + EnabledNotTrusted, + Revoked, + + Invalid = uint.MaxValue + } +}