Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ public STATestMethodAttribute(TestMethodAttribute testMethodAttribute)
: base(testMethodAttribute.DeclaringFilePath, testMethodAttribute.DeclaringLineNumber ?? -1)
=> _testMethodAttribute = testMethodAttribute;

/// <summary>
/// Gets or sets a value indicating whether the attribute will set a <see cref="SynchronizationContext"/> that preserves the same
/// STA thread for async continuations.
/// The default is <see langword="false"/>.
/// </summary>
public bool UseSTASynchronizationContext { get; set; }

/// <summary>
/// The core execution of STA test method, which happens on the STA thread.
/// </summary>
Expand All @@ -38,18 +45,39 @@ protected virtual Task<TestResult[]> ExecuteCoreAsync(ITestMethod testMethod)
=> _testMethodAttribute is null ? base.ExecuteAsync(testMethod) : _testMethodAttribute.ExecuteAsync(testMethod);

/// <inheritdoc />
public override Task<TestResult[]> ExecuteAsync(ITestMethod testMethod)
public override async Task<TestResult[]> ExecuteAsync(ITestMethod testMethod)
{
if (UseSTASynchronizationContext)
{
SynchronizationContext? originalContext = SynchronizationContext.Current;
var syncContext = new SingleThreadedSTASynchronizationContext();
try
{
SynchronizationContext.SetSynchronizationContext(syncContext);

// The yield ensures that we switch to the STA thread created by SingleThreadedSTASynchronizationContext.
await Task.Yield();
TestResult[] testResults = await ExecuteCoreAsync(testMethod).ConfigureAwait(false);
return testResults;
}
finally
{
SynchronizationContext.SetSynchronizationContext(originalContext);
syncContext.Complete();
syncContext.Dispose();
}
}

if (Thread.CurrentThread.GetApartmentState() == ApartmentState.STA)
{
return ExecuteCoreAsync(testMethod);
return await ExecuteCoreAsync(testMethod).ConfigureAwait(false);
}

#if !NETFRAMEWORK
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
// TODO: Throw?
return ExecuteCoreAsync(testMethod);
return await ExecuteCoreAsync(testMethod).ConfigureAwait(false);
}
#endif

Expand All @@ -61,6 +89,6 @@ public override Task<TestResult[]> ExecuteAsync(ITestMethod testMethod)
t.SetApartmentState(ApartmentState.STA);
t.Start();
t.Join();
return Task.FromResult(results!);
return results!;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Microsoft.VisualStudio.TestTools.UnitTesting;

internal sealed class SingleThreadedSTASynchronizationContext : SynchronizationContext, IDisposable
{
private readonly BlockingCollection<Action> _queue = [];
private readonly Thread _thread;

public SingleThreadedSTASynchronizationContext()
{
#if !NETFRAMEWORK
if (!OperatingSystem.IsWindows())
{
throw new NotSupportedException("SingleThreadedSTASynchronizationContext is only supported on Windows.");
}
#endif

_thread = new Thread(() =>
{
SetSynchronizationContext(this);
foreach (Action callback in _queue.GetConsumingEnumerable())
{
callback();
}
})
{
IsBackground = true,
};
_thread.SetApartmentState(ApartmentState.STA);
_thread.Start();
}

public override void Post(SendOrPostCallback d, object? state)
=> _queue.Add(() => d(state));

public override void Send(SendOrPostCallback d, object? state)
{
if (Environment.CurrentManagedThreadId == _thread.ManagedThreadId)
{
d(state);
}
else
{
using var done = new ManualResetEventSlim();
_queue.Add(() =>
{
try
{
d(state);
}
finally
{
done.Set();
}
});
done.Wait();
}
}

public void Complete() => _queue.CompleteAdding();

public void Dispose() => _queue.Dispose();
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertGenericIsExactInstance
Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertGenericIsNotExactInstanceOfTypeInterpolatedStringHandler<TArg>
Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertIsExactInstanceOfTypeInterpolatedStringHandler
Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertIsNotExactInstanceOfTypeInterpolatedStringHandler
Microsoft.VisualStudio.TestTools.UnitTesting.STATestMethodAttribute.UseSTASynchronizationContext.get -> bool
Microsoft.VisualStudio.TestTools.UnitTesting.STATestMethodAttribute.UseSTASynchronizationContext.set -> void
static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.ContainsSingle(System.Func<object?, bool>! predicate, System.Collections.IEnumerable! collection, string? message = "", string! predicateExpression = "", string! collectionExpression = "") -> object?
static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.IsExactInstanceOfType(object? value, System.Type? expectedType, string? message = "", string! valueExpression = "") -> void
static Microsoft.VisualStudio.TestTools.UnitTesting.Assert.IsExactInstanceOfType(object? value, System.Type? expectedType, ref Microsoft.VisualStudio.TestTools.UnitTesting.Assert.AssertIsExactInstanceOfTypeInterpolatedStringHandler message, string! valueExpression = "") -> void
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace MSTest.SelfRealExamples.UnitTests;

[TestClass]
public class STATestMethodSyncContext
{
[STATestMethod]
[OSCondition(OperatingSystems.Windows)]
public void STAByDefaultDoesNotUseSynchronizationContext()
{
Assert.IsNull(SynchronizationContext.Current);
Assert.AreEqual(ApartmentState.STA, Thread.CurrentThread.GetApartmentState());
}

[STATestMethod(UseSTASynchronizationContext = true)]
[OSCondition(OperatingSystems.Windows)]
public async Task STAWithSynchronizationContextIsCorrect()
{
Assert.IsNotNull(SynchronizationContext.Current);
Assert.AreEqual(ApartmentState.STA, Thread.CurrentThread.GetApartmentState());

await Task.Delay(100);

Assert.AreEqual(ApartmentState.STA, Thread.CurrentThread.GetApartmentState());
}
}
Loading