0
0
Fork 0
mirror of https://github.com/GreemDev/Ryujinx.git synced 2025-01-24 12:22:00 +00:00

Use SIMD acceleration for audio upsampler (#4410)

* Use SIMD acceleration for audio upsampler filter kernel for a moderate speedup

* Address formatting. Implement AVX2 fast path for high quality resampling in ResamplerHelper

* now really, are we really getting the benefit of inlining 50+ line methods?

* adding unit tests for resampler + upsampler. The upsampler ones fail for some reason

* Fixing upsampler test. Apparently this algo only works at specific ratios

---------

Co-authored-by: Logan Stromberg <lostromb@microsoft.com>
This commit is contained in:
Logan Stromberg 2023-02-21 02:44:57 -08:00 committed by GitHub
parent fc43aecbbd
commit edfd4d70c0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 279 additions and 84 deletions

View file

@ -1,5 +1,6 @@
using System; using System;
using System.Linq; using System.Linq;
using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics; using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86; using System.Runtime.Intrinsics.X86;
@ -380,7 +381,6 @@ namespace Ryujinx.Audio.Renderer.Dsp
return _normalCurveLut2F; return _normalCurveLut2F;
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private unsafe static void ResampleDefaultQuality(Span<float> outputBuffer, ReadOnlySpan<short> inputBuffer, float ratio, ref float fraction, int sampleCount, bool needPitch) private unsafe static void ResampleDefaultQuality(Span<float> outputBuffer, ReadOnlySpan<short> inputBuffer, float ratio, ref float fraction, int sampleCount, bool needPitch)
{ {
ReadOnlySpan<float> parameters = GetDefaultParameter(ratio); ReadOnlySpan<float> parameters = GetDefaultParameter(ratio);
@ -394,35 +394,33 @@ namespace Ryujinx.Audio.Renderer.Dsp
if (ratio == 1f) if (ratio == 1f)
{ {
fixed (short* pInput = inputBuffer) fixed (short* pInput = inputBuffer)
fixed (float* pOutput = outputBuffer, pParameters = parameters)
{ {
fixed (float* pOutput = outputBuffer, pParameters = parameters) Vector128<float> parameter = Sse.LoadVector128(pParameters);
for (; i < (sampleCount & ~3); i += 4)
{ {
Vector128<float> parameter = Sse.LoadVector128(pParameters); Vector128<int> intInput0 = Sse41.ConvertToVector128Int32(pInput + (uint)i);
Vector128<int> intInput1 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 1);
Vector128<int> intInput2 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 2);
Vector128<int> intInput3 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 3);
for (; i < (sampleCount & ~3); i += 4) Vector128<float> input0 = Sse2.ConvertToVector128Single(intInput0);
{ Vector128<float> input1 = Sse2.ConvertToVector128Single(intInput1);
Vector128<int> intInput0 = Sse41.ConvertToVector128Int32(pInput + (uint)i); Vector128<float> input2 = Sse2.ConvertToVector128Single(intInput2);
Vector128<int> intInput1 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 1); Vector128<float> input3 = Sse2.ConvertToVector128Single(intInput3);
Vector128<int> intInput2 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 2);
Vector128<int> intInput3 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 3);
Vector128<float> input0 = Sse2.ConvertToVector128Single(intInput0); Vector128<float> mix0 = Sse.Multiply(input0, parameter);
Vector128<float> input1 = Sse2.ConvertToVector128Single(intInput1); Vector128<float> mix1 = Sse.Multiply(input1, parameter);
Vector128<float> input2 = Sse2.ConvertToVector128Single(intInput2); Vector128<float> mix2 = Sse.Multiply(input2, parameter);
Vector128<float> input3 = Sse2.ConvertToVector128Single(intInput3); Vector128<float> mix3 = Sse.Multiply(input3, parameter);
Vector128<float> mix0 = Sse.Multiply(input0, parameter); Vector128<float> mix01 = Sse3.HorizontalAdd(mix0, mix1);
Vector128<float> mix1 = Sse.Multiply(input1, parameter); Vector128<float> mix23 = Sse3.HorizontalAdd(mix2, mix3);
Vector128<float> mix2 = Sse.Multiply(input2, parameter);
Vector128<float> mix3 = Sse.Multiply(input3, parameter);
Vector128<float> mix01 = Sse3.HorizontalAdd(mix0, mix1); Vector128<float> mix0123 = Sse3.HorizontalAdd(mix01, mix23);
Vector128<float> mix23 = Sse3.HorizontalAdd(mix2, mix3);
Vector128<float> mix0123 = Sse3.HorizontalAdd(mix01, mix23); Sse.Store(pOutput + (uint)i, Sse41.RoundToNearestInteger(mix0123));
Sse.Store(pOutput + (uint)i, Sse41.RoundToNearestInteger(mix0123));
}
} }
} }
@ -431,62 +429,60 @@ namespace Ryujinx.Audio.Renderer.Dsp
else else
{ {
fixed (short* pInput = inputBuffer) fixed (short* pInput = inputBuffer)
fixed (float* pOutput = outputBuffer, pParameters = parameters)
{ {
fixed (float* pOutput = outputBuffer, pParameters = parameters) for (; i < (sampleCount & ~3); i += 4)
{ {
for (; i < (sampleCount & ~3); i += 4) uint baseIndex0 = (uint)(fraction * 128) * 4;
{ uint inputIndex0 = (uint)inputBufferIndex;
uint baseIndex0 = (uint)(fraction * 128) * 4;
uint inputIndex0 = (uint)inputBufferIndex;
fraction += ratio; fraction += ratio;
uint baseIndex1 = ((uint)(fraction * 128) & 127) * 4; uint baseIndex1 = ((uint)(fraction * 128) & 127) * 4;
uint inputIndex1 = (uint)inputBufferIndex + (uint)fraction; uint inputIndex1 = (uint)inputBufferIndex + (uint)fraction;
fraction += ratio; fraction += ratio;
uint baseIndex2 = ((uint)(fraction * 128) & 127) * 4; uint baseIndex2 = ((uint)(fraction * 128) & 127) * 4;
uint inputIndex2 = (uint)inputBufferIndex + (uint)fraction; uint inputIndex2 = (uint)inputBufferIndex + (uint)fraction;
fraction += ratio; fraction += ratio;
uint baseIndex3 = ((uint)(fraction * 128) & 127) * 4; uint baseIndex3 = ((uint)(fraction * 128) & 127) * 4;
uint inputIndex3 = (uint)inputBufferIndex + (uint)fraction; uint inputIndex3 = (uint)inputBufferIndex + (uint)fraction;
fraction += ratio; fraction += ratio;
inputBufferIndex += (int)fraction; inputBufferIndex += (int)fraction;
// Only keep lower part (safe as fraction isn't supposed to be negative) // Only keep lower part (safe as fraction isn't supposed to be negative)
fraction -= (int)fraction; fraction -= (int)fraction;
Vector128<float> parameter0 = Sse.LoadVector128(pParameters + baseIndex0); Vector128<float> parameter0 = Sse.LoadVector128(pParameters + baseIndex0);
Vector128<float> parameter1 = Sse.LoadVector128(pParameters + baseIndex1); Vector128<float> parameter1 = Sse.LoadVector128(pParameters + baseIndex1);
Vector128<float> parameter2 = Sse.LoadVector128(pParameters + baseIndex2); Vector128<float> parameter2 = Sse.LoadVector128(pParameters + baseIndex2);
Vector128<float> parameter3 = Sse.LoadVector128(pParameters + baseIndex3); Vector128<float> parameter3 = Sse.LoadVector128(pParameters + baseIndex3);
Vector128<int> intInput0 = Sse41.ConvertToVector128Int32(pInput + inputIndex0); Vector128<int> intInput0 = Sse41.ConvertToVector128Int32(pInput + inputIndex0);
Vector128<int> intInput1 = Sse41.ConvertToVector128Int32(pInput + inputIndex1); Vector128<int> intInput1 = Sse41.ConvertToVector128Int32(pInput + inputIndex1);
Vector128<int> intInput2 = Sse41.ConvertToVector128Int32(pInput + inputIndex2); Vector128<int> intInput2 = Sse41.ConvertToVector128Int32(pInput + inputIndex2);
Vector128<int> intInput3 = Sse41.ConvertToVector128Int32(pInput + inputIndex3); Vector128<int> intInput3 = Sse41.ConvertToVector128Int32(pInput + inputIndex3);
Vector128<float> input0 = Sse2.ConvertToVector128Single(intInput0); Vector128<float> input0 = Sse2.ConvertToVector128Single(intInput0);
Vector128<float> input1 = Sse2.ConvertToVector128Single(intInput1); Vector128<float> input1 = Sse2.ConvertToVector128Single(intInput1);
Vector128<float> input2 = Sse2.ConvertToVector128Single(intInput2); Vector128<float> input2 = Sse2.ConvertToVector128Single(intInput2);
Vector128<float> input3 = Sse2.ConvertToVector128Single(intInput3); Vector128<float> input3 = Sse2.ConvertToVector128Single(intInput3);
Vector128<float> mix0 = Sse.Multiply(input0, parameter0); Vector128<float> mix0 = Sse.Multiply(input0, parameter0);
Vector128<float> mix1 = Sse.Multiply(input1, parameter1); Vector128<float> mix1 = Sse.Multiply(input1, parameter1);
Vector128<float> mix2 = Sse.Multiply(input2, parameter2); Vector128<float> mix2 = Sse.Multiply(input2, parameter2);
Vector128<float> mix3 = Sse.Multiply(input3, parameter3); Vector128<float> mix3 = Sse.Multiply(input3, parameter3);
Vector128<float> mix01 = Sse3.HorizontalAdd(mix0, mix1); Vector128<float> mix01 = Sse3.HorizontalAdd(mix0, mix1);
Vector128<float> mix23 = Sse3.HorizontalAdd(mix2, mix3); Vector128<float> mix23 = Sse3.HorizontalAdd(mix2, mix3);
Vector128<float> mix0123 = Sse3.HorizontalAdd(mix01, mix23); Vector128<float> mix0123 = Sse3.HorizontalAdd(mix01, mix23);
Sse.Store(pOutput + (uint)i, Sse41.RoundToNearestInteger(mix0123)); Sse.Store(pOutput + (uint)i, Sse41.RoundToNearestInteger(mix0123));
}
} }
} }
} }
@ -526,34 +522,59 @@ namespace Ryujinx.Audio.Renderer.Dsp
return _highCurveLut2F; return _highCurveLut2F;
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)] private static unsafe void ResampleHighQuality(Span<float> outputBuffer, ReadOnlySpan<short> inputBuffer, float ratio, ref float fraction, int sampleCount)
private static void ResampleHighQuality(Span<float> outputBuffer, ReadOnlySpan<short> inputBuffer, float ratio, ref float fraction, int sampleCount)
{ {
ReadOnlySpan<float> parameters = GetHighParameter(ratio); ReadOnlySpan<float> parameters = GetHighParameter(ratio);
int inputBufferIndex = 0; int inputBufferIndex = 0;
// TODO: fast path if (Avx2.IsSupported)
for (int i = 0; i < sampleCount; i++)
{ {
int baseIndex = (int)(fraction * 128) * 8; // Fast path; assumes 256-bit vectors for simplicity because the filter is 8 taps
ReadOnlySpan<float> parameter = parameters.Slice(baseIndex, 8); fixed (short* pInput = inputBuffer)
ReadOnlySpan<short> currentInput = inputBuffer.Slice(inputBufferIndex, 8); fixed (float* pParameters = parameters)
{
for (int i = 0; i < sampleCount; i++)
{
int baseIndex = (int)(fraction * 128) * 8;
outputBuffer[i] = (float)Math.Round(currentInput[0] * parameter[0] + Vector256<int> intInput = Avx2.ConvertToVector256Int32(pInput + inputBufferIndex);
currentInput[1] * parameter[1] + Vector256<float> floatInput = Avx.ConvertToVector256Single(intInput);
currentInput[2] * parameter[2] + Vector256<float> parameter = Avx.LoadVector256(pParameters + baseIndex);
currentInput[3] * parameter[3] + Vector256<float> dp = Avx.DotProduct(floatInput, parameter, control: 0xFF);
currentInput[4] * parameter[4] +
currentInput[5] * parameter[5] +
currentInput[6] * parameter[6] +
currentInput[7] * parameter[7]);
fraction += ratio; // avx2 does an 8-element dot product piecewise so we have to sum up 2 intermediate results
inputBufferIndex += (int)MathF.Truncate(fraction); outputBuffer[i] = (float)Math.Round(dp[0] + dp[4]);
fraction -= (int)fraction; fraction += ratio;
inputBufferIndex += (int)MathF.Truncate(fraction);
fraction -= (int)fraction;
}
}
}
else
{
for (int i = 0; i < sampleCount; i++)
{
int baseIndex = (int)(fraction * 128) * 8;
ReadOnlySpan<float> parameter = parameters.Slice(baseIndex, 8);
ReadOnlySpan<short> currentInput = inputBuffer.Slice(inputBufferIndex, 8);
outputBuffer[i] = (float)Math.Round(currentInput[0] * parameter[0] +
currentInput[1] * parameter[1] +
currentInput[2] * parameter[2] +
currentInput[3] * parameter[3] +
currentInput[4] * parameter[4] +
currentInput[5] * parameter[5] +
currentInput[6] * parameter[6] +
currentInput[7] * parameter[7]);
fraction += ratio;
inputBufferIndex += (int)MathF.Truncate(fraction);
fraction -= (int)fraction;
}
} }
} }

View file

@ -2,6 +2,7 @@ using Ryujinx.Audio.Renderer.Server.Upsampler;
using Ryujinx.Common.Memory; using Ryujinx.Common.Memory;
using System; using System;
using System.Diagnostics; using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
namespace Ryujinx.Audio.Renderer.Dsp namespace Ryujinx.Audio.Renderer.Dsp
@ -70,16 +71,32 @@ namespace Ryujinx.Audio.Renderer.Dsp
return; return;
} }
[MethodImpl(MethodImplOptions.AggressiveInlining)]
float DoFilterBank(ref UpsamplerBufferState state, in Array20<float> bank) float DoFilterBank(ref UpsamplerBufferState state, in Array20<float> bank)
{ {
float result = 0.0f; float result = 0.0f;
Debug.Assert(state.History.Length == HistoryLength); Debug.Assert(state.History.Length == HistoryLength);
Debug.Assert(bank.Length == FilterBankLength); Debug.Assert(bank.Length == FilterBankLength);
for (int j = 0; j < FilterBankLength; j++)
int curIdx = 0;
if (Vector.IsHardwareAccelerated)
{ {
result += bank[j] * state.History[j]; // Do SIMD-accelerated block operations where possible.
// Only about a 2x speedup since filter bank length is short
int stopIdx = FilterBankLength - (FilterBankLength % Vector<float>.Count);
while (curIdx < stopIdx)
{
result += Vector.Dot(
new Vector<float>(bank.AsSpan().Slice(curIdx, Vector<float>.Count)),
new Vector<float>(state.History.AsSpan().Slice(curIdx, Vector<float>.Count)));
curIdx += Vector<float>.Count;
}
}
while (curIdx < FilterBankLength)
{
result += bank[curIdx] * state.History[curIdx];
curIdx++;
} }
return result; return result;

View file

@ -0,0 +1,93 @@
using NUnit.Framework;
using Ryujinx.Audio.Renderer.Dsp;
using Ryujinx.Audio.Renderer.Parameter;
using Ryujinx.Audio.Renderer.Server.Upsampler;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;
namespace Ryujinx.Tests.Audio.Renderer.Dsp
{
class ResamplerTests
{
[Test]
[TestCase(VoiceInParameter.SampleRateConversionQuality.Low)]
[TestCase(VoiceInParameter.SampleRateConversionQuality.Default)]
[TestCase(VoiceInParameter.SampleRateConversionQuality.High)]
public void TestResamplerConsistencyUpsampling(VoiceInParameter.SampleRateConversionQuality quality)
{
DoResamplingTest(44100, 48000, quality);
}
[Test]
[TestCase(VoiceInParameter.SampleRateConversionQuality.Low)]
[TestCase(VoiceInParameter.SampleRateConversionQuality.Default)]
[TestCase(VoiceInParameter.SampleRateConversionQuality.High)]
public void TestResamplerConsistencyDownsampling(VoiceInParameter.SampleRateConversionQuality quality)
{
DoResamplingTest(48000, 44100, quality);
}
/// <summary>
/// Generates a 1-second sine wave sample at input rate, resamples it to output rate, and
/// ensures that it resampled at the expected rate with no discontinuities
/// </summary>
/// <param name="inputRate">The input sample rate to test</param>
/// <param name="outputRate">The output sample rate to test</param>
/// <param name="quality">The resampler quality to use</param>
private static void DoResamplingTest(int inputRate, int outputRate, VoiceInParameter.SampleRateConversionQuality quality)
{
float inputSampleRate = (float)inputRate;
float outputSampleRate = (float)outputRate;
int inputSampleCount = inputRate;
int outputSampleCount = outputRate;
short[] inputBuffer = new short[inputSampleCount + 100]; // add some safety buffer at the end
float[] outputBuffer = new float[outputSampleCount + 100];
for (int sample = 0; sample < inputBuffer.Length; sample++)
{
// 440 hz sine wave with amplitude = 0.5f at input sample rate
inputBuffer[sample] = (short)(32767 * MathF.Sin((440 / inputSampleRate) * (float)sample * MathF.PI * 2f) * 0.5f);
}
float fraction = 0;
ResamplerHelper.Resample(
outputBuffer.AsSpan(),
inputBuffer.AsSpan(),
inputSampleRate / outputSampleRate,
ref fraction,
outputSampleCount,
quality,
false);
float[] expectedOutput = new float[outputSampleCount];
float sumDifference = 0;
int delay = quality switch
{
VoiceInParameter.SampleRateConversionQuality.High => 3,
VoiceInParameter.SampleRateConversionQuality.Default => 1,
_ => 0
};
for (int sample = 0; sample < outputSampleCount; sample++)
{
outputBuffer[sample] /= 32767;
// 440 hz sine wave with amplitude = 0.5f at output sample rate
expectedOutput[sample] = MathF.Sin((440 / outputSampleRate) * (float)(sample + delay) * MathF.PI * 2f) * 0.5f;
float thisDelta = Math.Abs(expectedOutput[sample] - outputBuffer[sample]);
// Ensure no discontinuities
Assert.IsTrue(thisDelta < 0.1f);
sumDifference += thisDelta;
}
sumDifference = sumDifference / (float)outputSampleCount;
// Expect the output to be 99% similar to the expected resampled sine wave
Assert.IsTrue(sumDifference < 0.01f);
}
}
}

View file

@ -0,0 +1,64 @@
using NUnit.Framework;
using Ryujinx.Audio.Renderer.Dsp;
using Ryujinx.Audio.Renderer.Parameter;
using Ryujinx.Audio.Renderer.Server.Upsampler;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;
namespace Ryujinx.Tests.Audio.Renderer.Dsp
{
class UpsamplerTests
{
[Test]
public void TestUpsamplerConsistency()
{
UpsamplerBufferState bufferState = new UpsamplerBufferState();
int inputBlockSize = 160;
int numInputSamples = 32000;
int numOutputSamples = 48000;
float inputSampleRate = numInputSamples;
float outputSampleRate = numOutputSamples;
float[] inputBuffer = new float[numInputSamples + 100];
float[] outputBuffer = new float[numOutputSamples + 100];
for (int sample = 0; sample < inputBuffer.Length; sample++)
{
// 440 hz sine wave with amplitude = 0.5f at input sample rate
inputBuffer[sample] = MathF.Sin((440 / inputSampleRate) * (float)sample * MathF.PI * 2f) * 0.5f;
}
int inputIdx = 0;
int outputIdx = 0;
while (inputIdx + inputBlockSize < numInputSamples)
{
int outputBufLength = (int)Math.Round((float)(inputIdx + inputBlockSize) * outputSampleRate / inputSampleRate) - outputIdx;
UpsamplerHelper.Upsample(
outputBuffer.AsSpan(outputIdx),
inputBuffer.AsSpan(inputIdx),
outputBufLength,
inputBlockSize,
ref bufferState);
inputIdx += inputBlockSize;
outputIdx += outputBufLength;
}
float[] expectedOutput = new float[numOutputSamples];
float sumDifference = 0;
for (int sample = 0; sample < numOutputSamples; sample++)
{
// 440 hz sine wave with amplitude = 0.5f at output sample rate with an offset of 15
expectedOutput[sample] = MathF.Sin((440 / outputSampleRate) * (float)(sample - 15) * MathF.PI * 2f) * 0.5f;
sumDifference += Math.Abs(expectedOutput[sample] - outputBuffer[sample]);
}
sumDifference = sumDifference / (float)expectedOutput.Length;
// Expect the output to be 98% similar to the expected resampled sine wave
Assert.IsTrue(sumDifference < 0.02f);
}
}
}