dolphin/Externals/WIL/tests/ComTests.cpp
2020-02-09 19:01:44 +01:00

2718 lines
85 KiB
C++

#include <ocidl.h> // Bring in IObjectWithSite
#include <wil/com.h>
#include <wil/resource.h> // wil::scope_exit
#include <wrl/implements.h>
#include "common.h"
using namespace Microsoft::WRL;
// Avoid including <shobjidl.h>: it fails to compile in noprivateapis builds, so the ShellLink CLSID and class are declared here instead.
EXTERN_C const CLSID CLSID_ShellLink;
class DECLSPEC_UUID("00021401-0000-0000-C000-000000000046") ShellLink;
// Uncomment this line to do a more exhaustive test of the concepts covered by this file. By
// default we don't fully compile every combination of tests as this test can substantially impact
// build times with template expansion.
// #define WIL_EXHAUSTIVE_TEST
// Helper objects / functions
// Minimal hand-rolled IUnknown used to observe how wil::com_ptr drives reference counting.
// AddRef/Release free nothing; they only bump static counters (shared by every instance, including
// IUnknownFake2) that tests read back and reset via GetAddRef()/GetRelease(). Both always return 0.
class __declspec(uuid("a817e7a2-43fa-11d0-9e44-00aa00b6770a"))
IUnknownFake : public IUnknown
{
public:
    STDMETHOD_(ULONG, AddRef)()
    {
        ++AddRefCounter;
        return 0;
    }

    STDMETHOD_(ULONG, Release)()
    {
        ++ReleaseCounter;
        return 0;
    }

    // Only IUnknown is recognized. Unlike a real COM object, a successful query does NOT AddRef
    // the returned pointer, so QueryInterface never changes the counters.
    STDMETHOD(QueryInterface)(REFIID riid, _Outptr_result_nullonfailure_ void **ppvObject)
    {
        const bool supported = (riid == __uuidof(IUnknown));
        *ppvObject = supported ? static_cast<void*>(this) : nullptr;
        return supported ? S_OK : E_NOINTERFACE;
    }

    bool ReturnTRUE()
    {
        return true;
    }

    // Resets both counters to zero.
    static void Clear()
    {
        AddRefCounter = ReleaseCounter = 0;
    }

    // Returns the number of AddRef calls since the last reset, then resets that counter.
    static int GetAddRef()
    {
        const int observed = AddRefCounter;
        AddRefCounter = 0;
        return observed;
    }

    // Returns the number of Release calls since the last reset, then resets that counter.
    static int GetRelease()
    {
        const int observed = ReleaseCounter;
        ReleaseCounter = 0;
        return observed;
    }

protected:
    static int AddRefCounter;
    static int ReleaseCounter;
};

int IUnknownFake::AddRefCounter = 0;
int IUnknownFake::ReleaseCounter = 0;
// A second, distinct fake interface (own uuid) deriving from IUnknownFake; it shares the base's
// static AddRef/Release counters and QueryInterface behavior.
class __declspec(uuid("a817e7a2-43fa-11d0-9e44-00aa00b6770b"))
IUnknownFake2 : public IUnknownFake
{
};
// Verifies that each com_ptr constructor (default/null, raw pointer, copy, converting copy, move and
// converting move) stores the right pointer and AddRefs exactly as expected. The counter checks use
// IUnknownFake's shared static counters, so statement order inside each SECTION matters.
TEST_CASE("ComTests::Test_Constructors", "[com][com_ptr]")
{
    IUnknownFake::Clear();
    IUnknownFake helper;
    SECTION("Null/default construction")
    {
        wil::com_ptr_nothrow<IUnknown> ptr; //default constructor
        REQUIRE(ptr.get() == nullptr);
        wil::com_ptr_nothrow<IUnknown> ptr2(nullptr); //default explicit null constructor
        REQUIRE(ptr2.get() == nullptr);
        IUnknown* nullPtr = nullptr;
        wil::com_ptr_nothrow<IUnknown> ptr3(nullPtr);
        REQUIRE(ptr3.get() == nullptr);
    }
    SECTION("Valid pointer construction")
    {
        wil::com_ptr_nothrow<IUnknown> ptr(&helper); // explicit
        REQUIRE(IUnknownFake::GetAddRef() == 1);
        REQUIRE(ptr.get() == &helper);
    }
    SECTION("Copy construction")
    {
        wil::com_ptr_nothrow<IUnknown> ptr(&helper);
        wil::com_ptr_nothrow<IUnknown> ptrCopy(ptr); // the copy shares the same pointer
        // One AddRef from the raw-pointer constructor plus one from the copy.
        REQUIRE(IUnknownFake::GetAddRef() == 2);
        REQUIRE(ptrCopy.get() == ptr.get());
        // Converting copy (IUnknownFake2 -> IUnknownFake): again one AddRef for each construction.
        IUnknownFake2 helper2;
        wil::com_ptr_nothrow<IUnknownFake2> ptr2(&helper2);
        wil::com_ptr_nothrow<IUnknownFake> ptrCopy2(ptr2);
        REQUIRE(IUnknownFake::GetAddRef() == 2);
        REQUIRE(ptrCopy2.get() == &helper2);
    }
    SECTION("Move construction")
    {
        // A move transfers the reference: only the raw-pointer constructor AddRefs, and the source
        // is left empty. wistd::move expresses the rvalue cast without a reinterpret_cast.
        IUnknownFake helper3;
        wil::com_ptr_nothrow<IUnknownFake> ptr(&helper3);
        wil::com_ptr_nothrow<IUnknownFake> ptrMove(wistd::move(ptr));
        REQUIRE(IUnknownFake::GetAddRef() == 1);
        REQUIRE(ptrMove.get() == &helper3);
        REQUIRE(ptr.get() == nullptr);
        // Converting move (IUnknownFake2 -> IUnknownFake).
        IUnknownFake2 helper4;
        wil::com_ptr_nothrow<IUnknownFake2> ptr2(&helper4);
        wil::com_ptr_nothrow<IUnknownFake> ptrMove2(wistd::move(ptr2));
        REQUIRE(IUnknownFake::GetAddRef() == 1);
        REQUIRE(ptrMove2.get() == &helper4);
        REQUIRE(ptr2.get() == nullptr);
    }
}
// Verifies com_ptr assignment (nullptr, copy, self, converting copy, move, converting move) and the
// exact AddRef/Release calls each one makes on the IUnknownFake counters. The expected counts depend
// on exactly where IUnknownFake::Clear() is called, so statement order here is significant.
TEST_CASE("ComTests::Test_Assign", "[com][com_ptr]")
{
    IUnknownFake::Clear();
    IUnknownFake helper;
    SECTION("Null pointer assignment")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
        ptr = nullptr;
        REQUIRE(ptr.get() == nullptr);
        // Assigning nullptr releases the previously held pointer exactly once.
        REQUIRE(IUnknownFake::GetRelease() == 1);
    }
    IUnknownFake::Clear();
    IUnknownFake helper2;
    SECTION("Different pointer assignment")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
        wil::com_ptr_nothrow<IUnknownFake> ptr2(&helper2);
        ptr = static_cast<const wil::com_ptr_nothrow<IUnknownFake>&>(ptr2);
        REQUIRE(ptr.get() == &helper2);
        REQUIRE(ptr2.get() == &helper2);
        // Counters are not cleared after construction: two constructor AddRefs plus one for the copy
        // assignment give 3; the single Release is for ptr's old value (&helper).
        REQUIRE(IUnknownFake::GetRelease() == 1);
        REQUIRE(IUnknownFake::GetAddRef() == 3);
    }
    SECTION("Self assignment")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
        IUnknownFake::Clear();
        ptr = ptr;
        REQUIRE(ptr.get() == &helper);
        // wil::com_ptr<T> survives self-assignment but deliberately skips an explicit "this == &other"
        // check for performance (self-assignment should be rare or never happen), so the extra
        // AddRef/Release pair it may perform is not asserted:
        // REQUIRE(IUnknownFake::GetRelease() == 0);
        // REQUIRE(IUnknownFake::GetAddRef() == 0);
        ptr = std::move(ptr);
        REQUIRE(ptr.get() == &helper);
    }
    IUnknownFake2 helper3;
    SECTION("Assign pointer with different interface")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
        wil::com_ptr_nothrow<IUnknownFake2> ptr2(&helper3);
        IUnknownFake::Clear();
        ptr = static_cast<const wil::com_ptr_nothrow<IUnknownFake2>&>(ptr2);
        REQUIRE(ptr.get() == &helper3);
        REQUIRE(ptr2.get() == &helper3);
        // One AddRef for the new value (&helper3), one Release for the old value (&helper).
        REQUIRE(IUnknownFake::GetRelease() == 1);
        REQUIRE(IUnknownFake::GetAddRef() == 1);
    }
    SECTION("Move assignment")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
        wil::com_ptr_nothrow<IUnknownFake> ptr2(&helper2);
        IUnknownFake::Clear();
        ptr = static_cast<wil::com_ptr_nothrow<IUnknownFake>&&>(ptr2);
        REQUIRE(ptr.get() == &helper2);
        REQUIRE(ptr2.get() == nullptr);
        // A move transfers the reference: no AddRef, only the Release of ptr's old value.
        REQUIRE(IUnknownFake::GetRelease() == 1);
        REQUIRE(IUnknownFake::GetAddRef() == 0);
    }
    SECTION("Move assign with different interface")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
        wil::com_ptr_nothrow<IUnknownFake2> ptr2(&helper3);
        IUnknownFake::Clear();
        ptr = static_cast<wil::com_ptr_nothrow<IUnknownFake2>&&>(ptr2);
        REQUIRE(ptr.get() == &helper3);
        REQUIRE(ptr2.get() == nullptr);
        REQUIRE(IUnknownFake::GetRelease() == 1);
        REQUIRE(IUnknownFake::GetAddRef() == 0);
    }
}
// Verifies com_ptr's ==, != and < operators, including comparisons against an empty com_ptr
// (ptrNULL) and between com_ptrs of different interface types (ptrDiff).
TEST_CASE("ComTests::Test_Operators", "[com][com_ptr]")
{
    IUnknownFake::Clear();
    IUnknownFake helper;
    IUnknownFake helper2;
    IUnknownFake2 helper3;
    wil::com_ptr_nothrow<IUnknownFake> ptrNULL; //NULL one
    wil::com_ptr_nothrow<IUnknownFake> ptrLT(&helper);
    wil::com_ptr_nothrow<IUnknownFake> ptrGT(&helper2);
    wil::com_ptr_nothrow<IUnknownFake2> ptrDiff(&helper3);
    SECTION("equal operator")
    {
        REQUIRE_FALSE(ptrNULL == ptrLT);
        REQUIRE(ptrNULL == ptrNULL);
        REQUIRE(ptrLT == ptrLT);
        REQUIRE_FALSE(ptrDiff == ptrLT);
        REQUIRE_FALSE(ptrLT == ptrGT);
    }
    SECTION("not equals operator")
    {
        REQUIRE(ptrNULL != ptrLT);
        REQUIRE_FALSE(ptrNULL != ptrNULL);
        REQUIRE_FALSE(ptrLT != ptrLT);
        REQUIRE(ptrDiff != ptrLT);
        REQUIRE(ptrLT != ptrGT);
    }
    SECTION("less-than operator")
    {
        REQUIRE_FALSE(ptrNULL < ptrNULL);
        REQUIRE(ptrNULL < ptrLT);
        // The reverse ordering must not also hold: operator< has to be asymmetric.
        REQUIRE_FALSE(ptrLT < ptrNULL);
        // The relative addresses of two stack objects are not known in advance, so check the
        // direction in which the raw pointers actually order.
        if (ptrLT.get() < ptrGT.get())
        {
            REQUIRE(ptrLT < ptrGT);
        }
        else
        {
            REQUIRE(ptrGT < ptrLT);
        }
    }
}
// Checks com_ptr's explicit bool conversion: false while empty, true while holding a pointer.
TEST_CASE("ComTests::Test_Conversion", "[com][com_ptr]")
{
    IUnknownFake::Clear();
    IUnknownFake fake;
    const wil::com_ptr_nothrow<IUnknownFake> empty;
    const wil::com_ptr_nothrow<IUnknownFake> populated(&fake);
    REQUIRE_FALSE(empty);
    REQUIRE(populated);
}
// Verifies the com_ptr accessors that expose the address of the stored pointer. addressof() leaves
// the current value untouched (no AddRef/Release). put(), put_void(), put_unknown() and operator&
// release the current value once (no AddRef) and hand back a slot that now holds nullptr.
TEST_CASE("ComTests::Test_Address", "[com][com_ptr]")
{
    IUnknownFake::Clear();
    IUnknownFake helper;
    IUnknownFake** pFakePtr;
    SECTION("addressof")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
        IUnknownFake::Clear();
        pFakePtr = ptr.addressof();
        // Non-destructive: the stored pointer is still visible through the returned slot.
        REQUIRE(IUnknownFake::GetRelease() == 0);
        REQUIRE(IUnknownFake::GetAddRef() == 0);
        REQUIRE((*pFakePtr) == &helper);
    }
    SECTION("put")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
        IUnknownFake::Clear();
        pFakePtr = ptr.put();
        REQUIRE(IUnknownFake::GetRelease() == 1);
        REQUIRE(IUnknownFake::GetAddRef() == 0);
        REQUIRE((*pFakePtr) == nullptr);
        REQUIRE(ptr == nullptr);
    }
    SECTION("put_void")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
        IUnknownFake::Clear();
        void** pvFakePtr = ptr.put_void();
        REQUIRE(IUnknownFake::GetRelease() == 1);
        REQUIRE(IUnknownFake::GetAddRef() == 0);
        REQUIRE((*pvFakePtr) == nullptr);
        REQUIRE(ptr == nullptr);
    }
    SECTION("put_unknown")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
        IUnknownFake::Clear();
        IUnknown** puFakePtr = ptr.put_unknown();
        REQUIRE(IUnknownFake::GetRelease() == 1);
        REQUIRE(IUnknownFake::GetAddRef() == 0);
        REQUIRE((*puFakePtr) == nullptr);
        REQUIRE(ptr == nullptr);
    }
    SECTION("Address operator")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
        IUnknownFake::Clear();
        // operator& behaves like put(), not like addressof().
        pFakePtr = &ptr;
        REQUIRE(IUnknownFake::GetRelease() == 1);
        REQUIRE(IUnknownFake::GetAddRef() == 0);
        REQUIRE((*pFakePtr) == nullptr);
        REQUIRE(ptr == nullptr);
    }
}
// Verifies detach/attach/get/swap on com_ptr. detach() gives up ownership without touching the
// reference count. attach() releases the current value and adopts the new pointer without an AddRef.
TEST_CASE("ComTests::Test_Helpers", "[com][com_ptr]")
{
    IUnknownFake::Clear();
    IUnknownFake helper;
    IUnknownFake helper2;
    IUnknownFake *ptrHelper;
    wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
    SECTION("detach")
    {
        IUnknownFake::Clear(); //clear addref counter
        ptrHelper = ptr.detach();
        REQUIRE(ptr.get() == nullptr);
        REQUIRE(ptrHelper == &helper);
        REQUIRE(IUnknownFake::GetAddRef() == 0);
    }
    SECTION("attach")
    {
        ptrHelper = &helper;
        wil::com_ptr_nothrow<IUnknownFake> ptr2(&helper2); //have some non null pointer
        IUnknownFake::Clear(); //clear addref counter
        ptr2.attach(ptrHelper);
        REQUIRE(ptr2.get() == ptrHelper);
        // The previous value (&helper2) is released once; the attached pointer is not AddRef'd.
        REQUIRE(IUnknownFake::GetRelease() == 1);
        REQUIRE(IUnknownFake::GetAddRef() == 0);
    }
    SECTION("get")
    {
        wil::com_ptr_nothrow<IUnknown> ptr2;
        REQUIRE(ptr2.get() == nullptr);
        IUnknownFake helper3;
        wil::com_ptr_nothrow<IUnknownFake> ptr4(&helper3);
        REQUIRE(ptr4.get() == &helper3);
    }
    SECTION("l-value swap")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr2(&helper);
        wil::com_ptr_nothrow<IUnknownFake> ptr3(&helper2);
        ptr2.swap(ptr3);
        REQUIRE(ptr2.get() == &helper2);
        REQUIRE(ptr3.get() == &helper);
    }
    SECTION("r-value swap")
    {
        wil::com_ptr_nothrow<IUnknownFake> ptr2(&helper);
        wil::com_ptr_nothrow<IUnknownFake> ptr3(&helper2);
        // Swapping with an rvalue still exchanges contents with ptr3 (it is only cast, not moved-from).
        ptr2.swap(wistd::move(ptr3));
        REQUIRE(ptr2.get() == &helper2);
        REQUIRE(ptr3.get() == &helper);
    }
}
// Exercises com_ptr::query_to with an explicit IID and with typed destinations.
// IUnknownFake::QueryInterface only answers IUnknown, so the IDispatch request must fail and leave
// the destination empty. The same-interface request for IUnknownFake2 succeeds only because
// wil::com_ptr short-circuits same-type queries (QueryInterface itself would reject that IID).
TEST_CASE("ComTests::Test_As", "[com][com_ptr]")
{
    IUnknownFake::Clear();
    IUnknownFake helper;
    wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
    SECTION("query by IID")
    {
        wil::com_ptr_nothrow<IUnknown> ptr2;
        // put_void() yields the void** out slot directly instead of reinterpret_cast-ing &ptr2.
        REQUIRE(S_OK == ptr.query_to(__uuidof(IUnknown), ptr2.put_void()));
        REQUIRE(ptr2 != nullptr);
    }
    SECTION("query by invalid IID")
    {
        wil::com_ptr_nothrow<IUnknown> ptr2;
        REQUIRE(S_OK != ptr.query_to(__uuidof(IDispatch), ptr2.put_void()));
        REQUIRE(ptr2 == nullptr);
    }
    SECTION("same interface query")
    {
        // wil::com_ptr optimizes same-type assignment to just call AddRef
        IUnknownFake2 helper2;
        wil::com_ptr_nothrow<IUnknownFake2> ptr2(&helper2);
        wil::com_ptr_nothrow<IUnknownFake2> ptr3;
        REQUIRE(S_OK == ptr2.query_to<IUnknownFake2>(&ptr3));
        REQUIRE(ptr3 != nullptr);
    }
    SECTION("base interface query")
    {
        IUnknownFake2 helper2;
        wil::com_ptr_nothrow<IUnknownFake2> ptr2(&helper2);
        wil::com_ptr_nothrow<IUnknown> ptr3;
        REQUIRE(S_OK == ptr2.query_to<IUnknown>(&ptr3));
        REQUIRE(ptr3 != nullptr);
    }
}
// Exercises com_ptr::copy_to with an explicit IID and with typed destinations (same interface and
// base interface). IUnknownFake::QueryInterface only answers IUnknown, so the IDispatch request must
// fail and leave the destination empty.
TEST_CASE("ComTests::Test_CopyTo", "[com][com_ptr]")
{
    IUnknownFake::Clear();
    IUnknownFake helper;
    IUnknownFake2 helper2;
    wil::com_ptr_nothrow<IUnknownFake> ptr(&helper);
    SECTION("copy by IID")
    {
        wil::com_ptr_nothrow<IUnknown> ptr2;
        // put_void() yields the void** out slot directly instead of reinterpret_cast-ing &ptr2.
        REQUIRE(S_OK == ptr.copy_to(__uuidof(IUnknown), ptr2.put_void()));
        REQUIRE(ptr2 != nullptr);
    }
    SECTION("copy by invalid IID")
    {
        wil::com_ptr_nothrow<IUnknown> ptr2;
        REQUIRE(S_OK != ptr.copy_to(__uuidof(IDispatch), ptr2.put_void()));
        REQUIRE(ptr2 == nullptr);
    }
    SECTION("same interface copy")
    {
        wil::com_ptr_nothrow<IUnknownFake2> ptr2(&helper2);
        wil::com_ptr_nothrow<IUnknownFake2> ptr3;
        REQUIRE(S_OK == ptr2.copy_to(&ptr3));
        REQUIRE(ptr3 != nullptr);
    }
    SECTION("base interface copy")
    {
        wil::com_ptr_nothrow<IUnknownFake2> ptr2(&helper2);
        wil::com_ptr_nothrow<IUnknown> ptr3;
        REQUIRE(S_OK == ptr2.copy_to(ptr3.addressof()));
        REQUIRE(ptr3 != nullptr);
    }
}
// Helper used to verify correctness of IID_PPV_ARGS support.
// Expected call shape: IID_PPV_ARGS_Test_Helper(IID_PPV_ARGS(&comPtrToIUnknown)). It checks that the
// IID is IUnknown's and that the out slot arrives cleared, then writes the sentinel 0x01 into it.
// All checks run BEFORE the sentinel is written. If a failing REQUIRE unwound after the write, the
// caller's smart pointer would be destroyed while holding 0x01 and would call Release through it.
void IID_PPV_ARGS_Test_Helper(REFIID iid, void** pv)
{
    __analysis_assume(pv != nullptr);
    REQUIRE(pv != nullptr);
    REQUIRE(*pv == nullptr);
    REQUIRE(iid == __uuidof(IUnknown));
    *pv = reinterpret_cast<void*>(0x01); // Set check value
}
// Verifies that IID_PPV_ARGS(&com_ptr) produces IUnknown's IID and a writable, cleared out slot
// that feeds back into the com_ptr.
TEST_CASE("ComTests::Test_IID_PPV_ARGS", "[com][com_ptr]")
{
    wil::com_ptr_nothrow<IUnknown> unk;
    IID_PPV_ARGS_Test_Helper(IID_PPV_ARGS(&unk));
    // Take the sentinel out of the smart pointer BEFORE asserting on it. 0x01 is not a real object,
    // so 'unk' must never try to Release it, even if the check below fails and unwinds.
    auto checkValue = unk.detach();
    //Test if we got the correct check value back
    REQUIRE(checkValue == reinterpret_cast<void*>(0x01));
}
// Helps with testing wil::com_ptr<const ExtensionHelper> configuration when the operator -> is used.
// Not a COM interface: it only provides AddRef/Release (const-qualified, so they can be called
// through a pointer-to-const) plus a const Extend() to call via operator->. All are no-ops.
class ExtensionHelper
{
public:
    HRESULT Extend() const { return S_OK; }

    STDMETHOD_(ULONG, AddRef)() const { return 0; }
    STDMETHOD_(ULONG, Release)() const { return 0; }
};
// Mostly a compile check that the read-only com_ptr API can be called on a const com_ptr (spUnk).
// Return values are deliberately ignored: the IInspectable requests fail at runtime because
// IUnknownFake::QueryInterface only answers IUnknown, which is fine since only constness is tested.
// The final part checks operator-> on a com_ptr to a const, non-COM type (ExtensionHelper).
TEST_CASE("ComTests::Test_ConstPointer", "[com][com_ptr]")
{
    IUnknownFake::Clear();
    IUnknownFake helper;
    const wil::com_ptr_nothrow<IUnknown> spUnk(&helper);
    wil::com_ptr_nothrow<IUnknown> spUnkHelper;
    wil::com_ptr_nothrow<IInspectable> spInspectable;
    REQUIRE(spUnk.get() != nullptr);
    REQUIRE(spUnk);
    spUnk.addressof();
    spUnk.copy_to(spUnkHelper.addressof());
    spUnk.copy_to(spInspectable.addressof());
    spUnk.copy_to(IID_PPV_ARGS(&spInspectable));
    spUnk.query_to(&spUnkHelper);
    spUnk.query_to(&spInspectable);
    spUnk.query_to(__uuidof(IUnknown), reinterpret_cast<void**>(&spUnkHelper));
    const ExtensionHelper extHelper;
    wil::com_ptr_nothrow<const ExtensionHelper> spExt(&extHelper);
    REQUIRE(spExt->Extend() == S_OK);
}
// Make sure that the pointer can be defined just with forward declaration of the class.
// spClass is declared while MyClass is still incomplete; MyClass is completed afterwards in the same
// scope, presumably so that com_ptr's destructor (run at the end of the scope) sees a complete type.
// MyClass is never instantiated (it does not implement QueryInterface), so this is a compile-only check.
TEST_CASE("ComTests::Test_ComPtrWithForwardDeclaration", "[com][com_ptr]")
{
    class MyClass;
    wil::com_ptr_nothrow<MyClass> spClass;
    class MyClass : public IUnknown
    {
    public:
        STDMETHOD_(ULONG, AddRef)()
        {
            return 0;
        }
        STDMETHOD_(ULONG, Release)()
        {
            return 0;
        }
    };
}
//*****************************************************************************
// various com_ptr tests
//*****************************************************************************
// Base test interface; implemented by both ComObject and WinRtObject.
interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20a00"))
ITest : public IUnknown
{
    STDMETHOD_(void, Test)() = 0;
};
// Derived test interface, used to check derived-to-base (IDerivedTest -> ITest) conversions.
interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20a01"))
IDerivedTest : public ITest
{
    STDMETHOD_(void, TestDerived)() = 0;
};
// IInspectable-based test interface (implemented only by WinRtObject).
// The method name is spelled "TestInspctable" (sic); WinRtObject's override uses the same spelling.
interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20a02"))
ITestInspectable : public IInspectable
{
    STDMETHOD_(void, TestInspctable)() = 0;
};
// Derived IInspectable-based test interface (implemented only by WinRtObject).
interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20a03"))
IDerivedTestInspectable : public ITestInspectable
{
    STDMETHOD_(void, TestInspctableDerived)() = 0;
};
// Interface that no test object implements: queries for it are expected to fail.
interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20a04"))
INever : public IUnknown
{
    STDMETHOD_(void, Never)() = 0;
};
// Interface implemented by every COM test object (ComObject and WinRtObject).
interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20a05"))
IAlways : public IUnknown
{
    STDMETHOD_(void, Always)() = 0;
};
// Classic COM test object built with WRL. It exposes IDerivedTest (chained with ITest) and IAlways,
// but not INever. It also derives from witest::AllocatedObject, whose instance count the tests check
// for leaks via witest::g_objectCount.
class __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20b00")) // class uuid is deliberately NOT an implemented interface, so a QI for the class itself is attempted and fails
ComObject : witest::AllocatedObject,
    public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::RuntimeClassType::ClassicCom>,
    Microsoft::WRL::ChainInterfaces<IDerivedTest, ITest>,
    IAlways>{
public:
    COM_DECLSPEC_NOTHROW IFACEMETHODIMP_(void) Test() {}
    COM_DECLSPEC_NOTHROW IFACEMETHODIMP_(void) TestDerived() {}
    COM_DECLSPEC_NOTHROW IFACEMETHODIMP_(void) Always() {}
};
// Mixed WinRT/classic COM test object built with WRL (WinRtClassicComMix). It exposes ITest,
// IDerivedTest, ITestInspectable, IDerivedTestInspectable and IAlways (not INever) and mixes in
// WRL's FtmBase. It also derives from witest::AllocatedObject for leak tracking.
class __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20b01")) // class uuid is deliberately NOT an implemented interface, so a QI for the class itself is attempted and fails
WinRtObject : witest::AllocatedObject,
    public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::RuntimeClassType::WinRtClassicComMix>,
    ITest, IDerivedTest, ITestInspectable, IDerivedTestInspectable, IAlways, Microsoft::WRL::FtmBase>
{
public:
    COM_DECLSPEC_NOTHROW IFACEMETHODIMP_(void) Test() {}
    COM_DECLSPEC_NOTHROW IFACEMETHODIMP_(void) TestDerived() {}
    COM_DECLSPEC_NOTHROW IFACEMETHODIMP_(void) TestInspctable() {}
    COM_DECLSPEC_NOTHROW IFACEMETHODIMP_(void) TestInspctableDerived() {}
    COM_DECLSPEC_NOTHROW IFACEMETHODIMP_(void) Always() {}
};
// Reference-counted object that is NOT a COM object (no IUnknown base, no QueryInterface). It shows
// that wil::com_ptr works with any type offering AddRef/Release. It starts with a reference count
// of 1 and deletes itself when the count drops to zero. It derives from witest::AllocatedObject so
// the leak checks in this file can track it.
class NoCom : witest::AllocatedObject
{
public:
    // Returns the NEW reference count, following the IUnknown::AddRef convention. The previous
    // post-increment returned the count from before the increment.
    ULONG __stdcall AddRef()
    {
        return ++m_ref;
    }

    // Returns the new reference count; destroys the object when it reaches zero.
    ULONG __stdcall Release()
    {
        auto retVal = (--m_ref);
        if (retVal == 0)
        {
            delete this;
        }
        return retVal;
    }

private:
    ULONG m_ref = 1;
};
// cast_object overload for mismatched types (removed by SFINAE when T == U). make_object compiles a
// call to cast_object for every <IFace, Object> pairing but only executes it after QueryInterface
// failed, which is expected only when IFace and Object are the same type. Reaching this overload
// therefore indicates a test bug, so it fails fast. FAIL_FAST() is presumably noreturn, which is
// why there is no return statement.
template <typename T, typename U, typename = wistd::enable_if_t<!wistd::is_same_v<T, U>>>
T* cast_object(U*)
{
    FAIL_FAST();
}
// Identity overload of cast_object, picked when the requested type is exactly the pointee type:
// the pointer is already of the right type and is handed back unchanged.
template <typename T>
T* cast_object(T* object) { return object; }
// Creates a new Object via WRL::Make and returns it viewed as IFace, carrying one reference that
// the caller owns and must Release. On a successful QueryInterface, the QI reference is returned and
// the local ComPtr drops its own reference when it goes out of scope.
template <typename IFace, typename Object>
static IFace* make_object()
{
    auto obj = Microsoft::WRL::Make<Object>();
    IFace* result = nullptr;
    if (FAILED(obj.Get()->QueryInterface(__uuidof(IFace), reinterpret_cast<void**>(&result))))
    {
        // The QI only fails when we're asking for a CFoo from a CFoo (equivalent types): the object
        // classes deliberately use a uuid they do not implement. In that case hand back the original
        // pointer (and its reference) via Detach(). cast_object is used because this code is compiled
        // for every <IFace, Object> pairing; only the identity overload is valid at runtime, and the
        // mismatched overload fails fast if ever reached.
        result = cast_object<IFace>(obj.Detach());
    }
    return result;
}
// NoCom is not a WRL RuntimeClass, so it cannot go through WRL::Make/QueryInterface. Allocate it
// directly instead. A fresh NoCom starts with the single reference the caller now owns.
template <>
NoCom* make_object<NoCom, NoCom>()
{
    auto* const created = new NoCom{};
    return created;
}
// Generic checks for one com_ptr flavour (Ptr) holding ptr1/ptr2, which may be null or the same
// object. It covers swap (including with a WRL::ComPtr), reset, attach/detach, addressof, put,
// operator& and the const accessors. Each SECTION works on local copies, so the inputs stay unchanged.
template <typename Ptr>
void TestSmartPointer(const Ptr& ptr1, const Ptr& ptr2)
{
    SECTION("swap (method and global)")
    {
        auto p1 = ptr1;
        auto p2 = ptr2;
        p1.swap(p2); // l-value
        REQUIRE(((p1 == ptr2) && (p2 == ptr1)));
        p1.swap(wistd::move(p2)); // r-value
        REQUIRE(((p1 == ptr1) && (p2 == ptr2)));
        wil::swap(p1, p2);
        REQUIRE(((p1 == ptr2) && (p2 == ptr1)));
    }
    SECTION("WRL swap (method and global)")
    {
        // Same as above, but the second operand is a WRL::ComPtr, in both argument orders.
        auto p1 = ptr1;
        Microsoft::WRL::ComPtr<typename Ptr::element_type> p2 = ptr2.get();
        p1.swap(p2); // l-value
        REQUIRE(((p1 == ptr2) && (p2 == ptr1)));
        p1.swap(wistd::move(p2)); // r-value
        REQUIRE(((p1 == ptr1) && (p2 == ptr2)));
        wil::swap(p1, p2);
        REQUIRE(((p1 == ptr2) && (p2 == ptr1)));
        wil::swap(p2, p1);
        REQUIRE(((p1 == ptr1) && (p2 == ptr2)));
    }
    SECTION("reset")
    {
        auto p = ptr1;
        p.reset();
        REQUIRE_FALSE(p);
        p = ptr1;
        p.reset(nullptr);
        REQUIRE_FALSE(p);
    }
    SECTION("attach / detach")
    {
        // Ownership moves from p2 to p1 without an extra AddRef; p1's old value is released.
        auto p1 = ptr1;
        auto p2 = ptr2;
        p1.attach(p2.detach());
        REQUIRE(((p1.get() == ptr2.get()) && !p2));
    }
    SECTION("addressof")
    {
        auto p1 = ptr1;
        auto p2 = ptr2;
        p1.addressof(); // Doesn't reset
        REQUIRE(p1.get() == ptr1.get());
        // addressof() does not release, so clear p1 first to avoid leaking when overwriting the slot.
        p1.reset();
        *(p1.addressof()) = p2.detach();
        REQUIRE(p1.get() == ptr2.get());
    }
    SECTION("put")
    {
        // put() releases the current value; writing through it adopts p2's reference.
        auto p1 = ptr1;
        auto p2 = ptr2;
        p1.put();
        REQUIRE_FALSE(p1);
        *p1.put() = p2.detach();
        REQUIRE(p1.get() == ptr2.get());
    }
    SECTION("operator&")
    {
        // operator& behaves like put(): the bare expression statement below releases p1's value.
        auto p1 = ptr1;
        auto p2 = ptr2;
        &p1;
        REQUIRE_FALSE(p1);
        *(&p1) = p2.detach();
        REQUIRE(p1.get() == ptr2.get());
    }
    SECTION("exercise const methods on the const param (ensure const)")
    {
        // Mostly a compile check: each accessor must be callable through const Ptr&.
        auto address = ptr1.addressof();
        REQUIRE(*address == ptr1.get());
        (void)static_cast<bool>(ptr1);
        ptr1.get();
        auto deref = ptr1.operator->();
        (void)deref;
        if (ptr1)
        {
            auto& ref = ptr1.operator*();
            (void)ref;
        }
    }
}
// Runs TestSmartPointer for the raw pointers p1/p2 wrapped in each com_ptr error policy. The throwing
// com_ptr is only used when WIL_ENABLE_EXCEPTIONS is defined. The raw pointers are borrowed: each
// temporary com_ptr AddRefs on construction and releases on destruction.
template <typename IFace>
static void TestPointerCombination(IFace* p1, IFace* p2)
{
#ifdef WIL_ENABLE_EXCEPTIONS
    TestSmartPointer(wil::com_ptr<IFace>(p1), wil::com_ptr<IFace>(p2));
#endif
    TestSmartPointer(wil::com_ptr_failfast<IFace>(p1), wil::com_ptr_failfast<IFace>(p2));
    TestSmartPointer(wil::com_ptr_nothrow<IFace>(p1), wil::com_ptr_nothrow<IFace>(p2));
}
// Creates two distinct Objects viewed as IFace and runs TestPointerCombination over the pairs
// (distinct, null/non-null, non-null/null, null/null, same object). It then releases the two
// references returned by make_object.
// NOTE(review): if a REQUIRE inside fails and unwinds, the Release calls below are skipped and both
// objects leak.
template <typename IFace, typename Object>
static void TestPointer()
{
    auto p1 = make_object<IFace, Object>();
    auto p2 = make_object<IFace, Object>();
    IFace* nullPtr = nullptr;
    TestPointerCombination(p1, p2);
    TestPointerCombination(nullPtr, p2);
    TestPointerCombination(p1, nullPtr);
    TestPointerCombination(nullPtr, nullPtr);
    TestPointerCombination(p1, p1); // same object
    p1->Release();
    p2->Release();
}
// Runs the generic com_ptr member-function checks (TestPointer) for NoCom, ComObject and WinRtObject,
// each viewed through every interface it exposes. It then verifies that every witest::AllocatedObject
// created along the way was destroyed.
TEST_CASE("ComTests::Test_MemberFunctions", "[com][com_ptr]")
{
    // avoid overwhelming debug logging, perhaps the COM helpers are over reporting
    auto restoreDebugString = wil::g_fResultOutputDebugString;
    wil::g_fResultOutputDebugString = false;
    // Restore the global even if a REQUIRE below fails and unwinds out of this test case. Otherwise
    // debug output would stay disabled for every test that runs afterwards.
    auto restoreOnExit = wil::scope_exit([&]
    {
        wil::g_fResultOutputDebugString = restoreDebugString;
    });
    TestPointer<NoCom, NoCom>();
    TestPointer<ComObject, ComObject>();
    TestPointer<IUnknown, ComObject>();
    TestPointer<ITest, ComObject>();
    TestPointer<IDerivedTest, ComObject>();
    TestPointer<IAlways, ComObject>();
    TestPointer<WinRtObject, WinRtObject>();
    TestPointer<IUnknown, WinRtObject>();
    TestPointer<IInspectable, WinRtObject>();
    TestPointer<ITest, WinRtObject>();
    TestPointer<IDerivedTest, WinRtObject>();
    TestPointer<ITestInspectable, WinRtObject>();
    TestPointer<IDerivedTestInspectable, WinRtObject>();
    TestPointer<IAlways, WinRtObject>();
    REQUIRE_FALSE(witest::g_objectCount.Leaked());
}
// Checks interoperability between two com_ptr flavours (Ptr1, Ptr2), WRL::ComPtr and raw pointers.
// Ptr2's element type must convert to Ptr1's element type (see the constructions below). wrl1/wrl2
// are ComPtr<Ptr1::element_type> copies of the two inputs. Every comparison operator must agree
// with the same comparison on the underlying raw pointers.
template <typename Ptr1, typename Ptr2>
static void TestSmartPointerConversion(const Ptr1& ptr1, const Ptr2& ptr2)
{
    const Microsoft::WRL::ComPtr<typename Ptr1::element_type> wrl1 = ptr1.get();
    const Microsoft::WRL::ComPtr<typename Ptr1::element_type> wrl2 = ptr2.get();
    SECTION("global comparison operators")
    {
        auto p1 = ptr1.get();
        auto p2 = ptr2.get();
        // com_ptr to com_ptr
        REQUIRE((ptr1 == ptr2) == (p1 == p2));
        REQUIRE((ptr1 != ptr2) == (p1 != p2));
        REQUIRE((ptr1 < ptr2) == (p1 < p2));
        REQUIRE((ptr1 <= ptr2) == (p1 <= p2));
        REQUIRE((ptr1 > ptr2) == (p1 > p2));
        REQUIRE((ptr1 >= ptr2) == (p1 >= p2));
        // com_ptr to ComPtr
        REQUIRE((wrl1 == ptr2) == (p1 == p2));
        REQUIRE((wrl1 != ptr2) == (p1 != p2));
        REQUIRE((wrl1 < ptr2) == (p1 < p2));
        REQUIRE((wrl1 <= ptr2) == (p1 <= p2));
        REQUIRE((wrl1 > ptr2) == (p1 > p2));
        REQUIRE((wrl1 >= ptr2) == (p1 >= p2));
        REQUIRE((ptr1 == wrl2) == (p1 == p2));
        REQUIRE((ptr1 != wrl2) == (p1 != p2));
        REQUIRE((ptr1 < wrl2) == (p1 < p2));
        REQUIRE((ptr1 <= wrl2) == (p1 <= p2));
        REQUIRE((ptr1 > wrl2) == (p1 > p2));
        REQUIRE((ptr1 >= wrl2) == (p1 >= p2));
        // com_ptr to raw pointer
        REQUIRE((ptr1 == p2) == (p1 == p2));
        REQUIRE((ptr1 != p2) == (p1 != p2));
        REQUIRE((ptr1 < p2) == (p1 < p2));
        REQUIRE((ptr1 <= p2) == (p1 <= p2));
        REQUIRE((ptr1 > p2) == (p1 > p2));
        REQUIRE((ptr1 >= p2) == (p1 >= p2));
        REQUIRE((p1 == ptr2) == (p1 == p2));
        REQUIRE((p1 != ptr2) == (p1 != p2));
        REQUIRE((p1 < ptr2) == (p1 < p2));
        REQUIRE((p1 <= ptr2) == (p1 <= p2));
        REQUIRE((p1 > ptr2) == (p1 > p2));
        REQUIRE((p1 >= ptr2) == (p1 >= p2));
    }
    // The remaining sections build or assign a Ptr1 from each kind of source (raw pointer, com_ptr,
    // ComPtr; lvalue and rvalue) and check that it ends up holding the source's pointer.
    SECTION("construct from raw pointer")
    {
        Ptr1 p1(ptr2.get());
        Ptr1 p2 = ptr2.get();
        REQUIRE(((p1 == ptr2) && (p2 == ptr2)));
    }
    SECTION("construct from com_ptr ref<>")
    {
        Ptr1 p1(ptr2);
        Ptr1 p2 = (ptr2);
        REQUIRE(((p1 == ptr2) && (p2 == ptr2)));
    }
    SECTION("r-value construct from com_ptr ref<>")
    {
        auto move1 = ptr2;
        auto move2 = ptr2;
        Ptr1 p1(wistd::move(move1));
        Ptr1 p2 = wistd::move(move2);
        REQUIRE(((p1 == ptr2) && (p2 == ptr2)));
    }
    SECTION("assign from raw pointer")
    {
        Ptr1 p = ptr1;
        p = (ptr2.get());
        REQUIRE(p == ptr2);
    }
    SECTION("assign from com_ptr ref<>")
    {
        Ptr1 p = ptr1;
        p = ptr2;
        REQUIRE(p == ptr2);
    }
    SECTION("r-value assign from com_ptr ref<>")
    {
        Ptr1 p = ptr1;
        p = Ptr2(ptr2);
        REQUIRE(p == ptr2);
    }
    SECTION("construct from ComPtr ref<>")
    {
        Ptr1 p1(wrl2);
        Ptr1 p2 = (wrl2);
        REQUIRE(((p1 == wrl2) && (p2 == wrl2)));
    }
    SECTION("r-value construct from ComPtr ref<>")
    {
        auto move1 = wrl2;
        auto move2 = wrl2;
        Ptr1 p1(wistd::move(move1));
        Ptr1 p2 = wistd::move(move2);
        REQUIRE(((p1 == wrl2) && (p2 == wrl2)));
    }
    SECTION("assign from ComPtr ref<>")
    {
        Ptr1 p = ptr1;
        p = wrl2;
        REQUIRE(p == wrl2);
    }
    SECTION("r-value assign from ComPtr ref<>")
    {
        Ptr1 p = ptr1;
        p = decltype(wrl2)(wrl2);
        REQUIRE(p == wrl2);
    }
}
// Runs TestSmartPointerConversion for p1/p2 wrapped in several com_ptr error-policy pairings. By
// default the destination (first) pointer varies over com_ptr / com_ptr_failfast / com_ptr_nothrow
// with a com_ptr_nothrow source. The remaining pairings are compiled only under WIL_EXHAUSTIVE_TEST,
// and the throwing com_ptr only when WIL_ENABLE_EXCEPTIONS is defined.
template <typename IFace1, typename IFace2>
static void TestPointerConversionCombination(IFace1* p1, IFace2* p2)
{
#ifdef WIL_ENABLE_EXCEPTIONS
    TestSmartPointerConversion(wil::com_ptr<IFace1>(p1), wil::com_ptr_nothrow<IFace2>(p2));
#endif
    TestSmartPointerConversion(wil::com_ptr_failfast<IFace1>(p1), wil::com_ptr_nothrow<IFace2>(p2));
    TestSmartPointerConversion(wil::com_ptr_nothrow<IFace1>(p1), wil::com_ptr_nothrow<IFace2>(p2));
#ifdef WIL_EXHAUSTIVE_TEST
#ifdef WIL_ENABLE_EXCEPTIONS
    TestSmartPointerConversion(wil::com_ptr<IFace1>(p1), wil::com_ptr<IFace2>(p2));
    TestSmartPointerConversion(wil::com_ptr_failfast<IFace1>(p1), wil::com_ptr<IFace2>(p2));
    TestSmartPointerConversion(wil::com_ptr_nothrow<IFace1>(p1), wil::com_ptr<IFace2>(p2));
    TestSmartPointerConversion(wil::com_ptr<IFace1>(p1), wil::com_ptr_failfast<IFace2>(p2));
#endif
    TestSmartPointerConversion(wil::com_ptr_failfast<IFace1>(p1), wil::com_ptr_failfast<IFace2>(p2));
    TestSmartPointerConversion(wil::com_ptr_nothrow<IFace1>(p1), wil::com_ptr_failfast<IFace2>(p2));
#endif
}
// Creates two distinct Objects (one viewed as IFace1, one as IFace2) and runs the conversion
// checks over the pairs (distinct, null/non-null, non-null/null, null/null, and the same object
// seen through both interfaces). IFace2* must convert implicitly to IFace1* for the last
// static_cast. The two references returned by make_object are released at the end.
template <typename IFace1, typename IFace2, typename Object>
static void TestPointerConversion()
{
    auto p1 = make_object<IFace1, Object>();
    auto p2 = make_object<IFace2, Object>();
    IFace1* nullPtr1 = nullptr;
    IFace2* nullPtr2 = nullptr;
    TestPointerConversionCombination(p1, p2);
    TestPointerConversionCombination(nullPtr1, p2);
    TestPointerConversionCombination(p1, nullPtr2);
    TestPointerConversionCombination(nullPtr1, nullPtr2);
    TestPointerConversionCombination(static_cast<IFace1*>(p2), p2); // same object
    p1->Release();
    p2->Release();
}
// Runs the com_ptr conversion/comparison checks (TestPointerConversion) across interface pairings
// of the test objects. The full matrix is compiled only under WIL_EXHAUSTIVE_TEST to limit build
// time. Finally it verifies that every witest::AllocatedObject created was destroyed.
TEST_CASE("ComTests::Test_PointerConversion", "[com][com_ptr]")
{
    // avoid overwhelming debug logging, perhaps the COM helpers are over reporting
    auto restoreDebugString = wil::g_fResultOutputDebugString;
    wil::g_fResultOutputDebugString = false;
    // Restore the global even if a REQUIRE below fails and unwinds out of this test case. Otherwise
    // debug output would stay disabled for every test that runs afterwards.
    auto restoreOnExit = wil::scope_exit([&]
    {
        wil::g_fResultOutputDebugString = restoreDebugString;
    });
    TestPointerConversion<NoCom, NoCom, NoCom>();
    TestPointerConversion<ComObject, ComObject, ComObject>();
    TestPointerConversion<IUnknown, ITest, ComObject>();
    TestPointerConversion<IUnknown, IDerivedTest, ComObject>();
    TestPointerConversion<ITest, IDerivedTest, ComObject>();
#ifdef WIL_EXHAUSTIVE_TEST
    TestPointerConversion<IUnknown, IUnknown, ComObject>();
    TestPointerConversion<ITest, ITest, ComObject>();
    TestPointerConversion<IDerivedTest, IDerivedTest, ComObject>();
    TestPointerConversion<IAlways, IAlways, ComObject>();
    TestPointerConversion<IUnknown, IAlways, ComObject>();
    TestPointerConversion<WinRtObject, WinRtObject, WinRtObject>();
    TestPointerConversion<IUnknown, IUnknown, WinRtObject>();
    TestPointerConversion<IUnknown, ITest, WinRtObject>();
    TestPointerConversion<IUnknown, IDerivedTest, WinRtObject>();
    TestPointerConversion<IUnknown, ITestInspectable, WinRtObject>();
    TestPointerConversion<IUnknown, IDerivedTestInspectable, WinRtObject>();
    TestPointerConversion<IUnknown, IAlways, WinRtObject>();
    TestPointerConversion<IInspectable, IInspectable, WinRtObject>();
    TestPointerConversion<IInspectable, ITestInspectable, WinRtObject>();
    TestPointerConversion<IInspectable, IDerivedTestInspectable, WinRtObject>();
    TestPointerConversion<ITest, ITest, WinRtObject>();
    TestPointerConversion<ITest, IDerivedTest, WinRtObject>();
    TestPointerConversion<ITestInspectable, ITestInspectable, WinRtObject>();
    TestPointerConversion<ITestInspectable, IDerivedTestInspectable, WinRtObject>();
    TestPointerConversion<IDerivedTest, IDerivedTest, WinRtObject>();
    TestPointerConversion<IDerivedTestInspectable, IDerivedTestInspectable, WinRtObject>();
    TestPointerConversion<IAlways, IAlways, WinRtObject>();
#endif
    REQUIRE_FALSE(witest::g_objectCount.Leaked());
}
// Exercises the free-function IID/ppv helpers (com_query_to*, try_com_query_to, com_copy_to*,
// try_com_copy_to) for an interface target type; the wistd::true_type tag selects this overload.
// With a non-null 'source', requests for TargetIFace must succeed and requests for INever (which no
// test object implements) must fail. With a null 'source', the query_* family must crash, while the
// copy_* family tolerates null and must leave every output pointer empty.
template <typename TargetIFace, typename Ptr>
void TestGlobalQueryIidPpv(wistd::true_type, const Ptr& source) // interface
{
    using DestPtr = wil::com_ptr_nothrow<TargetIFace>;
    wil::com_ptr_nothrow<INever> never;
    SECTION("com_query_to(iid, ppv)")
    {
        if (source)
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            DestPtr dest1;
            wil::com_query_to(source, IID_PPV_ARGS(&dest1));
            REQUIRE_ERROR(wil::com_query_to(source, IID_PPV_ARGS(&never)));
            REQUIRE((dest1 && !never));
#endif
            DestPtr dest2, dest3;
            wil::com_query_to_failfast(source, IID_PPV_ARGS(&dest2));
            REQUIRE_ERROR(wil::com_query_to_failfast(source, IID_PPV_ARGS(&never)));
            wil::com_query_to_nothrow(source, IID_PPV_ARGS(&dest3));
            REQUIRE_ERROR(wil::com_query_to_nothrow(source, IID_PPV_ARGS(&never)));
            REQUIRE((dest2 && dest3 && !never));
        }
        else
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            DestPtr dest1;
            REQUIRE_CRASH(wil::com_query_to(source, IID_PPV_ARGS(&dest1)));
            REQUIRE_CRASH(wil::com_query_to(source, IID_PPV_ARGS(&never)));
#endif
            DestPtr dest2, dest3;
            REQUIRE_CRASH(wil::com_query_to_failfast(source, IID_PPV_ARGS(&dest2)));
            REQUIRE_CRASH(wil::com_query_to_failfast(source, IID_PPV_ARGS(&never)));
            REQUIRE_CRASH(wil::com_query_to_nothrow(source, IID_PPV_ARGS(&dest3)));
            REQUIRE_CRASH(wil::com_query_to_nothrow(source, IID_PPV_ARGS(&never)));
        }
    }
    SECTION("try_com_query_to(iid, ppv)")
    {
        if (source)
        {
            DestPtr dest1;
            REQUIRE(wil::try_com_query_to(source, IID_PPV_ARGS(&dest1)));
            REQUIRE_FALSE(wil::try_com_query_to(source, IID_PPV_ARGS(&never)));
            REQUIRE((dest1 && !never));
        }
        else
        {
            DestPtr dest1;
            REQUIRE_CRASH(wil::try_com_query_to(source, IID_PPV_ARGS(&dest1)));
            REQUIRE_CRASH(wil::try_com_query_to(source, IID_PPV_ARGS(&never)));
        }
    }
    SECTION("com_copy_to(iid, ppv)")
    {
        if (source)
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            DestPtr dest1;
            wil::com_copy_to(source, IID_PPV_ARGS(&dest1));
            REQUIRE_ERROR(wil::com_copy_to(source, IID_PPV_ARGS(&never)));
            REQUIRE((dest1 && !never));
#endif
            DestPtr dest2, dest3;
            wil::com_copy_to_failfast(source, IID_PPV_ARGS(&dest2));
            REQUIRE_ERROR(wil::com_copy_to_failfast(source, IID_PPV_ARGS(&never)));
            wil::com_copy_to_nothrow(source, IID_PPV_ARGS(&dest3));
            REQUIRE_ERROR(wil::com_copy_to_nothrow(source, IID_PPV_ARGS(&never)));
            REQUIRE((dest2 && dest3 && !never));
        }
        else
        {
            // Copying from a null source is allowed: it must not fail and must produce null outputs.
#ifdef WIL_ENABLE_EXCEPTIONS
            DestPtr dest1;
            wil::com_copy_to(source, IID_PPV_ARGS(&dest1));
            wil::com_copy_to(source, IID_PPV_ARGS(&never));
            REQUIRE((!dest1 && !never));
#endif
            DestPtr dest2, dest3;
            wil::com_copy_to_failfast(source, IID_PPV_ARGS(&dest2));
            wil::com_copy_to_failfast(source, IID_PPV_ARGS(&never));
            wil::com_copy_to_nothrow(source, IID_PPV_ARGS(&dest3));
            wil::com_copy_to_nothrow(source, IID_PPV_ARGS(&never));
            REQUIRE((!dest2 && !dest3 && !never));
        }
    }
    SECTION("try_com_copy_to(iid, ppv)")
    {
        if (source)
        {
            DestPtr dest1;
            REQUIRE(wil::try_com_copy_to(source, IID_PPV_ARGS(&dest1)));
            REQUIRE_FALSE(wil::try_com_copy_to(source, IID_PPV_ARGS(&never)));
            REQUIRE((dest1 && !never));
        }
        else
        {
            DestPtr dest1;
            REQUIRE_FALSE(wil::try_com_copy_to(source, IID_PPV_ARGS(&dest1)));
            REQUIRE_FALSE(wil::try_com_copy_to(source, IID_PPV_ARGS(&never)));
        }
    }
}
// Overload selected (via the wistd::false_type tag) when the target is a class rather than an
// interface. The IID/ppv helpers cannot compile against a class, so there is intentionally nothing
// to exercise here.
template <typename TargetIFace, typename Ptr>
void TestGlobalQueryIidPpv(wistd::false_type, const Ptr&) {}
// Exercises every free-function query/copy helper (com_query*, com_query_to*, try_com_query*,
// try_com_query_to, com_copy*, com_copy_to*, try_com_copy*, try_com_copy_to) against `source`.
//
// TargetIFace - a type the object behind `source` is expected to provide.
// Ptr         - raw interface pointer or smart pointer wrapping the source object (may be null).
//
// Expectations encoded below:
//  * query helpers on a null source must crash (fail fast), whatever the requested type;
//  * copy helpers on a null source succeed and produce null;
//  * requesting INever from a valid source reports an error (query/copy) or yields null (try_*).
template <typename TargetIFace, typename Ptr>
static void TestGlobalQuery(const Ptr& source)
{
    using DestPtr = wil::com_ptr_nothrow<TargetIFace>;
    // Output slot for requests that are expected to fail; it must remain null.
    wil::com_ptr_nothrow<INever> never;
    SECTION("com_query")
    {
        if (source)
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            REQUIRE(wil::com_query<TargetIFace>(source));
            REQUIRE_ERROR(wil::com_query<INever>(source));
#endif
            REQUIRE(wil::com_query_failfast<TargetIFace>(source));
            REQUIRE_ERROR(wil::com_query_failfast<INever>(source));
        }
        else
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            REQUIRE_CRASH(wil::com_query<TargetIFace>(source));
            REQUIRE_CRASH(wil::com_query<INever>(source));
#endif
            REQUIRE_CRASH(wil::com_query_failfast<TargetIFace>(source));
            REQUIRE_CRASH(wil::com_query_failfast<INever>(source));
        }
    }
    SECTION("com_query_to(U**)")
    {
        if (source)
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            DestPtr dest1;
            wil::com_query_to(source, &dest1);
            REQUIRE_ERROR(wil::com_query_to(source, &never));
            REQUIRE((dest1 && !never));
#endif
            DestPtr dest2, dest3;
            wil::com_query_to_failfast(source, &dest2);
            REQUIRE_ERROR(wil::com_query_to_failfast(source, &never));
            wil::com_query_to_nothrow(source, &dest3);
            REQUIRE_ERROR(wil::com_query_to_nothrow(source, &never));
            REQUIRE((dest2 && dest3 && !never));
        }
        else
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            DestPtr dest1;
            REQUIRE_CRASH(wil::com_query_to(source, &dest1));
            REQUIRE_CRASH(wil::com_query_to(source, &never));
#endif
            DestPtr dest2, dest3;
            REQUIRE_CRASH(wil::com_query_to_failfast(source, &dest2));
            REQUIRE_CRASH(wil::com_query_to_failfast(source, &never));
            REQUIRE_CRASH(wil::com_query_to_nothrow(source, &dest3));
            REQUIRE_CRASH(wil::com_query_to_nothrow(source, &never));
        }
    }
    SECTION("try_com_query")
    {
        if (source)
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            REQUIRE(wil::try_com_query<TargetIFace>(source));
            REQUIRE_FALSE(wil::try_com_query<INever>(source));
#endif
            REQUIRE(wil::try_com_query_failfast<TargetIFace>(source));
            REQUIRE_FALSE(wil::try_com_query_failfast<INever>(source));
            REQUIRE(wil::try_com_query_nothrow<TargetIFace>(source));
            REQUIRE_FALSE(wil::try_com_query_nothrow<INever>(source));
        }
        else
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            REQUIRE_CRASH(wil::try_com_query<TargetIFace>(source));
            REQUIRE_CRASH(wil::try_com_query<INever>(source));
#endif
            REQUIRE_CRASH(wil::try_com_query_failfast<TargetIFace>(source));
            REQUIRE_CRASH(wil::try_com_query_failfast<INever>(source));
            REQUIRE_CRASH(wil::try_com_query_nothrow<TargetIFace>(source));
            REQUIRE_CRASH(wil::try_com_query_nothrow<INever>(source));
        }
    }
    SECTION("try_com_query_to(U**)")
    {
        if (source)
        {
            DestPtr dest1;
            REQUIRE(wil::try_com_query_to(source, &dest1));
            REQUIRE_FALSE(wil::try_com_query_to(source, &never));
            REQUIRE((dest1 && !never));
        }
        else
        {
            DestPtr dest1;
            REQUIRE_CRASH(wil::try_com_query_to(source, &dest1));
            REQUIRE_CRASH(wil::try_com_query_to(source, &never));
        }
    }
    SECTION("com_copy")
    {
        if (source)
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            REQUIRE(wil::com_copy<TargetIFace>(source));
            REQUIRE_ERROR(wil::com_copy<INever>(source));
#endif
            REQUIRE(wil::com_copy_failfast<TargetIFace>(source));
            REQUIRE_ERROR(wil::com_copy_failfast<INever>(source));
        }
        else
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            REQUIRE_FALSE(wil::com_copy<TargetIFace>(source));
            REQUIRE_FALSE(wil::com_copy<INever>(source));
#endif
            REQUIRE_FALSE(wil::com_copy_failfast<TargetIFace>(source));
            REQUIRE_FALSE(wil::com_copy_failfast<INever>(source));
        }
    }
    SECTION("com_copy_to(U**)")
    {
        if (source)
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            DestPtr dest1;
            wil::com_copy_to(source, &dest1);
            REQUIRE_ERROR(wil::com_copy_to(source, &never));
            REQUIRE((dest1 && !never));
#endif
            DestPtr dest2, dest3;
            wil::com_copy_to_failfast(source, &dest2);
            REQUIRE_ERROR(wil::com_copy_to_failfast(source, &never));
            wil::com_copy_to_nothrow(source, &dest3);
            REQUIRE_ERROR(wil::com_copy_to_nothrow(source, &never));
            REQUIRE((dest2 && dest3 && !never));
        }
        else
        {
            // Copying from null is not an error, but every output must still come back null
            // (mirrors the copy_to(U**) expectations in TestSmartPointerQuery).
#ifdef WIL_ENABLE_EXCEPTIONS
            DestPtr dest1;
            wil::com_copy_to(source, &dest1);
            wil::com_copy_to(source, &never);
            REQUIRE((!dest1 && !never));
#endif
            DestPtr dest2, dest3;
            wil::com_copy_to_failfast(source, &dest2);
            wil::com_copy_to_failfast(source, &never);
            wil::com_copy_to_nothrow(source, &dest3);
            wil::com_copy_to_nothrow(source, &never);
            REQUIRE((!dest2 && !dest3 && !never));
        }
    }
    SECTION("try_com_copy")
    {
        if (source)
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            REQUIRE(wil::try_com_copy<TargetIFace>(source));
            REQUIRE_FALSE(wil::try_com_copy<INever>(source));
#endif
            REQUIRE(wil::try_com_copy_failfast<TargetIFace>(source));
            REQUIRE_FALSE(wil::try_com_copy_failfast<INever>(source));
            REQUIRE(wil::try_com_copy_nothrow<TargetIFace>(source));
            REQUIRE_FALSE(wil::try_com_copy_nothrow<INever>(source));
        }
        else
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            REQUIRE_FALSE(wil::try_com_copy<TargetIFace>(source));
            REQUIRE_FALSE(wil::try_com_copy<INever>(source));
#endif
            REQUIRE_FALSE(wil::try_com_copy_failfast<TargetIFace>(source));
            REQUIRE_FALSE(wil::try_com_copy_failfast<INever>(source));
            REQUIRE_FALSE(wil::try_com_copy_nothrow<TargetIFace>(source));
            REQUIRE_FALSE(wil::try_com_copy_nothrow<INever>(source));
        }
    }
    SECTION("try_com_copy_to(U**)")
    {
        if (source)
        {
            DestPtr dest1;
            REQUIRE(wil::try_com_copy_to(source, &dest1));
            REQUIRE_FALSE(wil::try_com_copy_to(source, &never));
            REQUIRE((dest1 && !never));
        }
        else
        {
            DestPtr dest1;
            REQUIRE_FALSE(wil::try_com_copy_to(source, &dest1));
            REQUIRE_FALSE(wil::try_com_copy_to(source, &never));
        }
    }
    // The (iid, ppv) overloads are tag-dispatched: they are only exercised when TargetIFace is abstract (an interface).
    TestGlobalQueryIidPpv<TargetIFace, Ptr>(typename wistd::is_abstract<TargetIFace>::type(), source);
}
// Tests the fluent query<T>() / copy<T>() members. They are exercised only for smart pointers whose
// error policy returns void (the exception and fail-fast flavors). The caller selects this overload
// by passing wistd::true_type.
template <typename IFace, typename Ptr>
void TestSmartPointerQueryFluent(wistd::true_type, const Ptr& source) // void return (non-error based)
{
    SECTION("query")
    {
        if (source)
        {
            REQUIRE(source.template query<IFace>());
            REQUIRE_ERROR(source.template query<INever>());
        }
        else
        {
            // Querying through a null pointer is expected to crash (fail fast).
            REQUIRE_CRASH(source.template query<IFace>());
            REQUIRE_CRASH(source.template query<INever>());
        }
    }
    SECTION("copy")
    {
        if (source)
        {
            REQUIRE(source.template copy<IFace>());
            REQUIRE_ERROR(source.template copy<INever>());
        }
        else
        {
            // Copying from a null pointer is not an error; it just yields an empty result.
            REQUIRE_FALSE(source.template copy<IFace>());
            REQUIRE_FALSE(source.template copy<INever>());
        }
    }
}
// No-op overload of TestSmartPointerQueryFluent for error-code based pointers (Ptr::result is not
// void). The caller selects it by passing wistd::false_type.
template <typename IFace, typename Ptr>
void TestSmartPointerQueryFluent(wistd::false_type, const Ptr& /*source*/) // error-code based return
{
    // Error-code based pointers cannot call the fluent query<T>()/copy<T>() methods, so there is nothing to test.
}
// Tests the smart pointer (REFIID, void**) members: query_to, try_query_to, copy_to and try_copy_to.
// This overload is selected with wistd::true_type, i.e. when IFace is abstract (an interface), so
// IID_PPV_ARGS can be formed for it.
template <typename IFace, typename Ptr>
void TestSmartPointerQueryIidPpv(wistd::true_type, const Ptr& source) // interface
{
    // Output slot for requests that are expected to fail; it must remain null.
    wil::com_ptr_nothrow<INever> never;
    using DestPtr = wil::com_ptr_nothrow<IFace>;
    SECTION("query_to(iid, ppv)")
    {
        if (source)
        {
            DestPtr dest;
            source.query_to(IID_PPV_ARGS(&dest));
            REQUIRE_ERROR(source.query_to(IID_PPV_ARGS(&never)));
            REQUIRE((dest && !never));
        }
        else
        {
            // query_to on a null pointer must crash and leave the outputs untouched (null).
            DestPtr dest;
            REQUIRE_CRASH(source.query_to(IID_PPV_ARGS(&dest)));
            REQUIRE_CRASH(source.query_to(IID_PPV_ARGS(&never)));
            REQUIRE((!dest && !never));
        }
    }
    SECTION("try_query_to(iid, ppv)")
    {
        if (source)
        {
            DestPtr dest;
            REQUIRE(source.try_query_to(IID_PPV_ARGS(&dest)));
            REQUIRE(!source.try_query_to(IID_PPV_ARGS(&never)));
            REQUIRE((dest && !never));
        }
        else
        {
            // Even the "try" form crashes on a null source; only the outcome of a valid query is optional.
            DestPtr dest;
            REQUIRE_CRASH(source.try_query_to(IID_PPV_ARGS(&dest)));
            REQUIRE_CRASH(source.try_query_to(IID_PPV_ARGS(&never)));
            REQUIRE((!dest && !never));
        }
    }
    SECTION("copy_to(iid, ppv)")
    {
        if (source)
        {
            DestPtr dest;
            source.copy_to(IID_PPV_ARGS(&dest));
            REQUIRE_ERROR(source.copy_to(IID_PPV_ARGS(&never)));
            REQUIRE((dest && !never));
        }
        else
        {
            // copy_to tolerates a null source and simply produces null outputs.
            DestPtr dest;
            source.copy_to(IID_PPV_ARGS(&dest));
            source.copy_to(IID_PPV_ARGS(&never));
            REQUIRE((!dest && !never));
        }
    }
    SECTION("try_copy_to(iid, ppv)")
    {
        if (source)
        {
            DestPtr dest;
            REQUIRE(source.try_copy_to(IID_PPV_ARGS(&dest)));
            REQUIRE(!source.try_copy_to(IID_PPV_ARGS(&never)));
            REQUIRE((dest && !never));
        }
        else
        {
            DestPtr dest;
            REQUIRE(!source.try_copy_to(IID_PPV_ARGS(&dest)));
            REQUIRE(!source.try_copy_to(IID_PPV_ARGS(&never)));
            REQUIRE((!dest && !never));
        }
    }
}
// No-op overload of TestSmartPointerQueryIidPpv, selected with wistd::false_type when IFace is a
// concrete class rather than an interface. The (iid, ppv) forms are only tested for interfaces.
template <typename IFace, typename Ptr>
void TestSmartPointerQueryIidPpv(wistd::false_type, const Ptr& /*source*/) // class
{
    // we can't compile against iid, ppv with a class
}
// Tests the member query/copy methods of a WIL smart pointer `source` (which may be null), using
// IFace as the type expected to succeed and INever as the type expected to fail.
// Expectations: query-style calls crash on a null source, while copy-style calls tolerate null and
// produce null results.
// The fluent and (iid, ppv) variants are tag-dispatched to helpers at the end:
//  - the fluent helper is selected by whether Ptr::result is void;
//  - the (iid, ppv) helper is selected by whether IFace is abstract.
template <typename IFace, typename Ptr>
void TestSmartPointerQuery(const Ptr& source)
{
    // Output slot for requests that are expected to fail; it must remain null.
    wil::com_ptr_nothrow<INever> never;
    using DestPtr = wil::com_ptr_nothrow<IFace>;
    SECTION("query_to(U**)")
    {
        if (source)
        {
            DestPtr dest;
            source.query_to(&dest);
            REQUIRE_ERROR(source.query_to(&never));
            REQUIRE((dest && !never));
        }
        else
        {
            DestPtr dest;
            REQUIRE_CRASH(source.query_to(&dest));
            REQUIRE_CRASH(source.query_to(&never));
            REQUIRE((!dest && !never));
        }
    }
    SECTION("try_query")
    {
        if (source)
        {
            REQUIRE(source.template try_query<IFace>());
            REQUIRE_FALSE(source.template try_query<INever>());
        }
        else
        {
            REQUIRE_CRASH(source.template try_query<IFace>());
            REQUIRE_CRASH(source.template try_query<INever>());
        }
    }
    SECTION("try_query_to(U**)")
    {
        if (source)
        {
            DestPtr dest;
            REQUIRE(source.try_query_to(&dest));
            REQUIRE_FALSE(source.try_query_to(&never));
            REQUIRE((dest && !never));
        }
        else
        {
            DestPtr dest;
            REQUIRE_CRASH(source.try_query_to(&dest));
            REQUIRE_CRASH(source.try_query_to(&never));
            REQUIRE((!dest && !never));
        }
    }
    SECTION("copy_to(U**)")
    {
        if (source)
        {
            DestPtr dest;
            source.copy_to(&dest);
            REQUIRE_ERROR(source.copy_to(&never));
            REQUIRE((dest && !never));
        }
        else
        {
            // copy_to from a null source is not an error and leaves the outputs null.
            DestPtr dest;
            source.copy_to(&dest);
            source.copy_to(&never);
            REQUIRE((!dest && !never));
        }
    }
    SECTION("try_copy")
    {
        if (source)
        {
            REQUIRE(source.template try_copy<IFace>());
            REQUIRE_FALSE(source.template try_copy<INever>());
        }
        else
        {
            REQUIRE_FALSE(source.template try_copy<IFace>());
            REQUIRE_FALSE(source.template try_copy<INever>());
        }
    }
    SECTION("try_copy_to(U**)")
    {
        if (source)
        {
            DestPtr dest;
            REQUIRE(source.try_copy_to(&dest));
            REQUIRE_FALSE(source.try_copy_to(&never));
            REQUIRE((dest && !never));
        }
        else
        {
            DestPtr dest;
            REQUIRE_FALSE(source.try_copy_to(&dest));
            REQUIRE_FALSE(source.try_copy_to(&never));
            REQUIRE((!dest && !never));
        }
    }
    // Fluent query<T>()/copy<T>() are only tested when the error policy returns void (exceptions/fail fast).
    TestSmartPointerQueryFluent<IFace, Ptr>(typename wistd::is_same<void, typename Ptr::result>::type(), source);
    // (iid, ppv) overloads are only tested when IFace is an interface (abstract).
    TestSmartPointerQueryIidPpv<IFace, Ptr>(typename wistd::is_abstract<IFace>::type(), source);
}
// Runs the free-function suite (TestGlobalQuery) and the member suite (TestSmartPointerQuery) for
// `ptr`, wrapped in each supported pointer flavor. The raw pointer, com_ptr_nothrow and WRL ComPtr
// always run through TestGlobalQuery. The com_ptr (exceptions) and com_ptr_failfast wrappers go
// through TestGlobalQuery only under WIL_EXHAUSTIVE_TEST, to limit template expansion.
template <typename TargetIFace, typename IFace>
static void TestQueryCombination(IFace* ptr)
{
    TestGlobalQuery<TargetIFace>(ptr);
#ifdef WIL_EXHAUSTIVE_TEST
#ifdef WIL_ENABLE_EXCEPTIONS
    TestGlobalQuery<TargetIFace>(wil::com_ptr<IFace>(ptr));
#endif
    TestGlobalQuery<TargetIFace>(wil::com_ptr_failfast<IFace>(ptr));
#endif
    TestGlobalQuery<TargetIFace>(wil::com_ptr_nothrow<IFace>(ptr));
    TestGlobalQuery<TargetIFace>(Microsoft::WRL::ComPtr<IFace>(ptr));
    // Member-function tests only apply to the WIL smart pointer types.
#ifdef WIL_ENABLE_EXCEPTIONS
    TestSmartPointerQuery<TargetIFace>(wil::com_ptr<IFace>(ptr));
#endif
    TestSmartPointerQuery<TargetIFace>(wil::com_ptr_failfast<IFace>(ptr));
    TestSmartPointerQuery<TargetIFace>(wil::com_ptr_nothrow<IFace>(ptr));
}
// Runs the full query/copy matrix twice: once for the live object and once for a null pointer of the
// same static type, so both the success path and the null-source expectations get covered.
template <typename TargetIFace, typename IFace>
static void TestQuery(IFace* ptr)
{
    TestQueryCombination<TargetIFace>(ptr);
    TestQueryCombination<TargetIFace>(static_cast<IFace*>(nullptr));
}
// Creates an Object viewed through IFace, runs the TargetIFace query matrix against it and then
// releases the reference obtained from make_object.
template <typename IFace, typename TargetIFace, typename Object>
static void TestQuery()
{
    auto instance = make_object<IFace, Object>();
    TestQuery<TargetIFace>(instance);
    // Balance the reference handed out by make_object (presumably one owned reference).
    instance->Release();
}
// Drives the query/copy matrix over combinations of source interface, target interface and object
// type (ComObject and WinRtObject). Without WIL_EXHAUSTIVE_TEST only three ComObject cases run, to
// keep build times reasonable. At the end it verifies that no test objects leaked.
TEST_CASE("ComTests::Test_Query", "[com][com_ptr]")
{
    // avoid overwhelming debug logging, perhaps the COM helpers are over reporting
    auto restoreDebugString = wil::g_fResultOutputDebugString;
    wil::g_fResultOutputDebugString = false;
    // Restore the global on every exit path. A failing REQUIRE throws out of this test case, and a
    // plain trailing assignment would then leave debug output disabled for every later test.
    auto restoreOnExit = wil::scope_exit([&] { wil::g_fResultOutputDebugString = restoreDebugString; });
    TestQuery<ComObject, ComObject, ComObject>(); // Same type (no QI)
    TestQuery<ComObject, IUnknown, ComObject>(); // Ambiguous base (must QI)
    TestQuery<ComObject, ITest, ComObject>(); // Non-ambiguous base (no QI)
    // This adds a significant amount of time to the compilation duration, so most tests are disabled by default...
#ifdef WIL_EXHAUSTIVE_TEST
    TestQuery<ComObject, IDerivedTest, ComObject>(); // ComObject
    TestQuery<ComObject, IAlways, ComObject>();
    TestQuery<IUnknown, IUnknown, ComObject>(); // IUnknown
    TestQuery<IUnknown, ITest, ComObject>();
    TestQuery<IUnknown, IDerivedTest, ComObject>();
    TestQuery<IUnknown, IAlways, ComObject>();
    TestQuery<ITest, IUnknown, ComObject>(); // ITest
    TestQuery<ITest, ITest, ComObject>();
    TestQuery<ITest, IDerivedTest, ComObject>();
    TestQuery<ITest, IAlways, ComObject>();
    TestQuery<IDerivedTest, IUnknown, ComObject>(); // IDerivedTest
    TestQuery<IDerivedTest, ITest, ComObject>();
    TestQuery<IDerivedTest, IDerivedTest, ComObject>();
    TestQuery<IDerivedTest, IAlways, ComObject>();
    TestQuery<IAlways, IUnknown, ComObject>(); // IAlways
    TestQuery<IAlways, ITest, ComObject>();
    TestQuery<IAlways, IDerivedTest, ComObject>();
    TestQuery<IAlways, IAlways, ComObject>();
    TestQuery<WinRtObject, WinRtObject, WinRtObject>(); // WinRtObject
    TestQuery<WinRtObject, IUnknown, WinRtObject>();
    TestQuery<WinRtObject, ITest, WinRtObject>();
    TestQuery<WinRtObject, IInspectable, WinRtObject>();
    TestQuery<WinRtObject, ITestInspectable, WinRtObject>();
    TestQuery<WinRtObject, IDerivedTest, WinRtObject>();
    TestQuery<WinRtObject, IDerivedTestInspectable, WinRtObject>();
    TestQuery<WinRtObject, IAlways, WinRtObject>();
    TestQuery<IUnknown, IUnknown, WinRtObject>(); // IUnknown
    TestQuery<IUnknown, IInspectable, WinRtObject>();
    TestQuery<IUnknown, ITest, WinRtObject>();
    TestQuery<IUnknown, IDerivedTest, WinRtObject>();
    TestQuery<IUnknown, ITestInspectable, WinRtObject>();
    TestQuery<IUnknown, IDerivedTestInspectable, WinRtObject>();
    TestQuery<IUnknown, IAlways, WinRtObject>();
    TestQuery<IInspectable, IUnknown, WinRtObject>(); // IInspectable
    TestQuery<IInspectable, IInspectable, WinRtObject>();
    TestQuery<IInspectable, ITest, WinRtObject>();
    TestQuery<IInspectable, IDerivedTest, WinRtObject>();
    TestQuery<IInspectable, ITestInspectable, WinRtObject>();
    TestQuery<IInspectable, IDerivedTestInspectable, WinRtObject>();
    TestQuery<IInspectable, IAlways, WinRtObject>();
    TestQuery<ITest, IUnknown, WinRtObject>(); // ITest
    TestQuery<ITest, IInspectable, WinRtObject>();
    TestQuery<ITest, ITest, WinRtObject>();
    TestQuery<ITest, IDerivedTest, WinRtObject>();
    TestQuery<ITest, ITestInspectable, WinRtObject>();
    TestQuery<ITest, IDerivedTestInspectable, WinRtObject>();
    TestQuery<ITest, IAlways, WinRtObject>();
    TestQuery<IDerivedTest, IUnknown, WinRtObject>(); // IDerivedTest
    TestQuery<IDerivedTest, IInspectable, WinRtObject>();
    TestQuery<IDerivedTest, ITest, WinRtObject>();
    TestQuery<IDerivedTest, IDerivedTest, WinRtObject>();
    TestQuery<IDerivedTest, ITestInspectable, WinRtObject>();
    TestQuery<IDerivedTest, IDerivedTestInspectable, WinRtObject>();
    TestQuery<IDerivedTest, IAlways, WinRtObject>();
    TestQuery<ITestInspectable, IUnknown, WinRtObject>(); // ITestInspectable
    TestQuery<ITestInspectable, IInspectable, WinRtObject>();
    TestQuery<ITestInspectable, ITest, WinRtObject>();
    TestQuery<ITestInspectable, IDerivedTest, WinRtObject>();
    TestQuery<ITestInspectable, ITestInspectable, WinRtObject>();
    TestQuery<ITestInspectable, IDerivedTestInspectable, WinRtObject>();
    TestQuery<ITestInspectable, IAlways, WinRtObject>();
    TestQuery<IDerivedTestInspectable, IUnknown, WinRtObject>(); // IDerivedTestInspectable
    TestQuery<IDerivedTestInspectable, IInspectable, WinRtObject>();
    TestQuery<IDerivedTestInspectable, ITest, WinRtObject>();
    TestQuery<IDerivedTestInspectable, IDerivedTest, WinRtObject>();
    TestQuery<IDerivedTestInspectable, ITestInspectable, WinRtObject>();
    TestQuery<IDerivedTestInspectable, IDerivedTestInspectable, WinRtObject>();
    TestQuery<IDerivedTestInspectable, IAlways, WinRtObject>();
    TestQuery<IAlways, IUnknown, WinRtObject>(); // IAlways
    TestQuery<IAlways, IInspectable, WinRtObject>();
    TestQuery<IAlways, ITest, WinRtObject>();
    TestQuery<IAlways, IDerivedTest, WinRtObject>();
    TestQuery<IAlways, ITestInspectable, WinRtObject>();
    TestQuery<IAlways, IDerivedTestInspectable, WinRtObject>();
    TestQuery<IAlways, IAlways, WinRtObject>();
#endif
    REQUIRE_FALSE(witest::g_objectCount.Leaked());
}
#if (NTDDI_VERSION >= NTDDI_WINBLUE)
template <typename Ptr>
void TestAgile(const Ptr& source)
{
bool source_valid = (source != nullptr);
if (source)
{
#ifdef WIL_ENABLE_EXCEPTIONS
auto agile1 = wil::com_agile_query(source);
REQUIRE(agile1);
#endif
auto agile2 = wil::com_agile_query_failfast(source);
wil::com_agile_ref_nothrow agile3;
REQUIRE_SUCCEEDED(wil::com_agile_query_nothrow(source, &agile3));
REQUIRE((agile2 && agile3));
}
else
{
#ifdef WIL_ENABLE_EXCEPTIONS
REQUIRE_CRASH(wil::com_agile_query(source));
#endif
REQUIRE_CRASH(wil::com_agile_query_failfast(source));
wil::com_agile_ref_nothrow agile3;
REQUIRE_CRASH(wil::com_agile_query_nothrow(source, &agile3));
}
#ifdef WIL_ENABLE_EXCEPTIONS
auto agile1 = wil::com_agile_copy(source);
REQUIRE(static_cast<bool>(agile1) == source_valid);
#endif
auto agile2 = wil::com_agile_copy_failfast(source);
wil::com_agile_ref_nothrow agile3;
REQUIRE_SUCCEEDED(wil::com_agile_copy_nothrow(source, &agile3));
REQUIRE(static_cast<bool>(agile2) == source_valid);
REQUIRE(static_cast<bool>(agile3) == source_valid);
}
// Creates a WinRtObject exposed as IFace, initializes an STA for the duration of the test, and runs
// TestAgile over a raw pointer, com_ptr_nothrow and WRL ComPtr. It then runs the query matrix
// against the pointer resolved through an agile reference.
template <typename IFace>
void TestAgileCombinations()
{
    auto object = make_object<IFace, WinRtObject>();
    REQUIRE_SUCCEEDED(::CoInitializeEx(nullptr, COINIT_APARTMENTTHREADED));
    auto uninitializeCom = wil::scope_exit([] { ::CoUninitialize(); });

    TestAgile(object);
    TestAgile(wil::com_ptr_nothrow<IFace>(object));
    TestAgile(Microsoft::WRL::ComPtr<IFace>(object));

    auto agileRef = wil::com_agile_query_failfast(object);
    TestQuery<ITest>(agileRef.get());
#ifdef WIL_EXHAUSTIVE_TEST
    TestQuery<IUnknown>(agileRef.get());
    TestQuery<IInspectable>(agileRef.get());
    TestQuery<IDerivedTest>(agileRef.get());
    TestQuery<ITestInspectable>(agileRef.get());
    TestQuery<IDerivedTestInspectable>(agileRef.get());
    TestQuery<IAlways>(agileRef.get());
#endif

    // Drop the reference from make_object before COM is uninitialized.
    object->Release();
}
// Runs the agile-reference checks for a few interface views of WinRtObject, plus more under
// WIL_EXHAUSTIVE_TEST. At the end it verifies that no test objects leaked.
TEST_CASE("ComTests::Test_Agile", "[com][com_agile_ref]")
{
    // Disabled in the original source; the reason is not recorded here.
    // TestAgileCombinations<WinRtObject>();
    TestAgileCombinations<IUnknown>();
    TestAgileCombinations<IInspectable>();
    TestAgileCombinations<ITest>();
#ifdef WIL_EXHAUSTIVE_TEST
    TestAgileCombinations<IDerivedTest>();
    TestAgileCombinations<ITestInspectable>();
    TestAgileCombinations<IDerivedTestInspectable>();
    TestAgileCombinations<IAlways>();
#endif
    REQUIRE_FALSE(witest::g_objectCount.Leaked());
}
#endif
template <typename Ptr>
void TestWeak(const Ptr& source)
{
bool supports_weak = (source && (wil::try_com_query_nothrow<IInspectable>(source)));
if (supports_weak && source)
{
#ifdef WIL_ENABLE_EXCEPTIONS
auto weak1 = wil::com_weak_query(source);
REQUIRE(weak1);
#endif
auto weak2 = wil::com_weak_query_failfast(source);
wil::com_weak_ref_nothrow weak3;
REQUIRE_SUCCEEDED(wil::com_weak_query_nothrow(source, &weak3));
REQUIRE((weak2 && weak3));
#ifdef WIL_ENABLE_EXCEPTIONS
auto weak1copy = wil::com_weak_copy(source);
REQUIRE(weak1copy);
#endif
auto weak2copy = wil::com_weak_copy_failfast(source);
wil::com_weak_ref_nothrow weak3copy;
REQUIRE_SUCCEEDED(wil::com_weak_copy_nothrow(source, &weak3copy));
REQUIRE((weak2copy && weak3copy));
}
else if (source)
{
#ifdef WIL_ENABLE_EXCEPTIONS
REQUIRE_ERROR(wil::com_weak_query(source));
#endif
REQUIRE_ERROR(wil::com_weak_query_failfast(source));
wil::com_weak_ref_nothrow weak3err;
REQUIRE_ERROR(wil::com_weak_query_nothrow(source, &weak3err));
#ifdef WIL_ENABLE_EXCEPTIONS
REQUIRE_ERROR(wil::com_weak_copy(source));
#endif
REQUIRE_ERROR(wil::com_weak_copy_failfast(source));
wil::com_weak_ref_nothrow weak3;
REQUIRE_ERROR(wil::com_weak_copy_nothrow(source, &weak3));
}
else // !source
{
#ifdef WIL_ENABLE_EXCEPTIONS
REQUIRE_CRASH(wil::com_weak_query(source));
#endif
REQUIRE_CRASH(wil::com_weak_query_failfast(source));
wil::com_weak_ref_nothrow weak3crash;
REQUIRE_CRASH(wil::com_weak_query_nothrow(source, &weak3crash));
#ifdef WIL_ENABLE_EXCEPTIONS
auto weak1 = wil::com_weak_copy(source);
REQUIRE(!weak1);
#endif
auto weak2 = wil::com_weak_copy_failfast(source);
wil::com_weak_ref_nothrow weak3;
REQUIRE_SUCCEEDED(wil::com_weak_copy_nothrow(source, &weak3));
REQUIRE((!weak2 && !weak3));
}
}
// Exercises the free-function query/copy helpers against a non-null `source` that cannot produce
// TargetIFace. In this file that is a weak reference whose object has already been destroyed.
// Expectations:
//  - the query_*/copy_* helpers report an error;
//  - the try_* helpers return false/null;
//  - every output pointer stays null.
template <typename TargetIFace, typename Ptr>
void TestGlobalQueryWithFailedResolve(const Ptr& source)
{
    // No need to test the null source and wrong interface query
    // since that's covered in the TestGlobalQuery.
    using DestPtr = wil::com_ptr_nothrow<TargetIFace>;
    SECTION("com_query")
    {
#ifdef WIL_ENABLE_EXCEPTIONS
        REQUIRE_ERROR(wil::com_query<TargetIFace>(source));
#endif
        REQUIRE_ERROR(wil::com_query_failfast<TargetIFace>(source));
    }
    SECTION("com_query_to(U**)")
    {
#ifdef WIL_ENABLE_EXCEPTIONS
        DestPtr dest1;
        REQUIRE_ERROR(wil::com_query_to(source, &dest1));
        REQUIRE(!dest1);
#endif
        DestPtr dest2, dest3;
        REQUIRE_ERROR(wil::com_query_to_failfast(source, &dest2));
        REQUIRE_ERROR(wil::com_query_to_nothrow(source, &dest3));
        REQUIRE((!dest2 && !dest3));
    }
    SECTION("try_com_query")
    {
#ifdef WIL_ENABLE_EXCEPTIONS
        REQUIRE(!wil::try_com_query<TargetIFace>(source));
#endif
        REQUIRE(!wil::try_com_query_failfast<TargetIFace>(source));
        REQUIRE(!wil::try_com_query_nothrow<TargetIFace>(source));
    }
    SECTION("try_com_query_to(U**)")
    {
        DestPtr dest1;
        REQUIRE(!wil::try_com_query_to(source, &dest1));
        REQUIRE(!dest1);
    }
    SECTION("com_copy")
    {
#ifdef WIL_ENABLE_EXCEPTIONS
        REQUIRE_ERROR(wil::com_copy<TargetIFace>(source));
#endif
        REQUIRE_ERROR(wil::com_copy_failfast<TargetIFace>(source));
    }
    SECTION("com_copy_to(U**)")
    {
#ifdef WIL_ENABLE_EXCEPTIONS
        DestPtr dest1;
        REQUIRE_ERROR(wil::com_copy_to(source, &dest1));
        REQUIRE(!dest1);
#endif
        DestPtr dest2, dest3;
        REQUIRE_ERROR(wil::com_copy_to_failfast(source, &dest2));
        REQUIRE_ERROR(wil::com_copy_to_nothrow(source, &dest3));
        REQUIRE((!dest2 && !dest3));
    }
    SECTION("try_com_copy")
    {
#ifdef WIL_ENABLE_EXCEPTIONS
        REQUIRE(!wil::try_com_copy<TargetIFace>(source));
#endif
        REQUIRE(!wil::try_com_copy_failfast<TargetIFace>(source));
        REQUIRE(!wil::try_com_copy_nothrow<TargetIFace>(source));
    }
    SECTION("try_com_copy_to(U**)")
    {
        DestPtr dest1;
        REQUIRE(!wil::try_com_copy_to(source, &dest1));
        REQUIRE(!dest1);
    }
    // (iid, ppv) forms are only exercised for interface (abstract) targets.
    // NOTE(review): this is a runtime `if` on a compile-time constant, so the body must still compile
    // for every TargetIFace this template is instantiated with.
    if (wistd::is_abstract<TargetIFace>::value)
    {
        SECTION("com_query_to(iid, ppv)")
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            DestPtr dest1;
            REQUIRE_ERROR(wil::com_query_to(source, IID_PPV_ARGS(&dest1)));
            REQUIRE(!dest1);
#endif
            DestPtr dest2, dest3;
            REQUIRE_ERROR(wil::com_query_to_failfast(source, IID_PPV_ARGS(&dest2)));
            REQUIRE_ERROR(wil::com_query_to_nothrow(source, IID_PPV_ARGS(&dest3)));
            REQUIRE((!dest2 && !dest3));
        }
        SECTION("try_com_query_to(iid, ppv)")
        {
            DestPtr dest1;
            REQUIRE(!wil::try_com_query_to(source, IID_PPV_ARGS(&dest1)));
            REQUIRE(!dest1);
        }
        SECTION("com_copy_to(iid, ppv)")
        {
#ifdef WIL_ENABLE_EXCEPTIONS
            DestPtr dest1;
            REQUIRE_ERROR(wil::com_copy_to(source, IID_PPV_ARGS(&dest1)));
            REQUIRE(!dest1);
#endif
            DestPtr dest2, dest3;
            REQUIRE_ERROR(wil::com_copy_to_failfast(source, IID_PPV_ARGS(&dest2)));
            REQUIRE_ERROR(wil::com_copy_to_nothrow(source, IID_PPV_ARGS(&dest3)));
            REQUIRE((!dest2 && !dest3));
        }
        SECTION("try_com_copy_to(iid, ppv)")
        {
            DestPtr dest1;
            REQUIRE(!wil::try_com_copy_to(source, IID_PPV_ARGS(&dest1)));
            REQUIRE(!dest1);
        }
    }
}
// No-op overload, selected with wistd::false_type for error-code based pointers. Those cannot use
// the fluent query<T>()/copy<T>() methods.
template <typename TargetIFace, typename Ptr>
void TestSmartPointerQueryFluentWithFailedResolve(wistd::false_type, const Ptr& /*source*/)
{
}
// For void-result (exception / fail-fast) smart pointers: the fluent query<T>() and copy<T>() must
// both report an error when `source` cannot resolve TargetIFace.
template <typename TargetIFace, typename Ptr>
void TestSmartPointerQueryFluentWithFailedResolve(wistd::true_type, const Ptr& source)
{
    REQUIRE_ERROR(source.template query<TargetIFace>());
    REQUIRE_ERROR(source.template copy<TargetIFace>());
}
// Member-function counterpart of TestGlobalQueryWithFailedResolve. For a non-null smart pointer
// `source` that cannot produce TargetIFace:
//  - query_to/copy_to report an error;
//  - try_* return false/null;
//  - all output pointers stay null.
// `source` is taken by const reference, consistent with the other helpers in this file, which avoids
// an extra AddRef/Release for the by-value copy.
template <typename TargetIFace, typename Ptr>
void TestSmartPointerQueryWithFailedResolve(const Ptr& source)
{
    using DestPtr = wil::com_ptr_nothrow<TargetIFace>;
    SECTION("query_to(U**)")
    {
        DestPtr dest;
        REQUIRE_ERROR(source.query_to(&dest));
        REQUIRE(!dest);
    }
    SECTION("try_query")
    {
        REQUIRE(!source.template try_query<TargetIFace>());
    }
    SECTION("try_query_to(U**)")
    {
        DestPtr dest;
        REQUIRE(!source.try_query_to(&dest));
        REQUIRE(!dest);
    }
    SECTION("copy_to(U**)")
    {
        DestPtr dest;
        REQUIRE_ERROR(source.copy_to(&dest));
        REQUIRE(!dest);
    }
    SECTION("try_copy")
    {
        REQUIRE(!source.template try_copy<TargetIFace>());
    }
    SECTION("try_copy_to(U**)")
    {
        DestPtr dest;
        REQUIRE(!source.try_copy_to(&dest));
        REQUIRE(!dest);
    }
    // Fluent query<T>()/copy<T>() only exist for void-result (exception / fail-fast) policies.
    TestSmartPointerQueryFluentWithFailedResolve<TargetIFace, Ptr>(typename wistd::is_same<void, typename Ptr::result>::type(), source);
    // (iid, ppv) forms are only exercised for interface (abstract) targets.
    if (wistd::is_abstract<TargetIFace>::value)
    {
        SECTION("query_to(iid, ppv)")
        {
            DestPtr dest;
            REQUIRE_ERROR(source.query_to(IID_PPV_ARGS(&dest)));
            REQUIRE(!dest);
        }
        SECTION("try_query_to(iid, ppv)")
        {
            DestPtr dest;
            REQUIRE(!source.try_query_to(IID_PPV_ARGS(&dest)));
            REQUIRE(!dest);
        }
        SECTION("copy_to(iid, ppv)")
        {
            DestPtr dest;
            REQUIRE_ERROR(source.copy_to(IID_PPV_ARGS(&dest)));
            REQUIRE(!dest);
        }
        SECTION("try_copy_to(iid, ppv)")
        {
            DestPtr dest;
            REQUIRE(!source.try_copy_to(IID_PPV_ARGS(&dest)));
            REQUIRE(!dest);
        }
    }
}
// Runs the "resolve fails" suites for `ptr` across pointer flavors. This mirrors
// TestQueryCombination, including which wrappers only run under WIL_EXHAUSTIVE_TEST. Unlike
// TestQuery, there is no separate null-pointer pass here.
template <typename TargetIFace, typename IFace>
void TestQueryWithFailedResolve(IFace* ptr)
{
    TestGlobalQueryWithFailedResolve<TargetIFace>(ptr);
#ifdef WIL_EXHAUSTIVE_TEST
#ifdef WIL_ENABLE_EXCEPTIONS
    TestGlobalQueryWithFailedResolve<TargetIFace>(wil::com_ptr<IFace>(ptr));
#endif
    TestGlobalQueryWithFailedResolve<TargetIFace>(wil::com_ptr_failfast<IFace>(ptr));
#endif
    TestGlobalQueryWithFailedResolve<TargetIFace>(wil::com_ptr_nothrow<IFace>(ptr));
    TestGlobalQueryWithFailedResolve<TargetIFace>(Microsoft::WRL::ComPtr<IFace>(ptr));
    // Member-function suites only apply to the WIL smart pointer types.
#ifdef WIL_ENABLE_EXCEPTIONS
    TestSmartPointerQueryWithFailedResolve<TargetIFace>(wil::com_ptr<IFace>(ptr));
#endif
    TestSmartPointerQueryWithFailedResolve<TargetIFace>(wil::com_ptr_nothrow<IFace>(ptr));
    TestSmartPointerQueryWithFailedResolve<TargetIFace>(wil::com_ptr_failfast<IFace>(ptr));
}
// Creates a WinRtObject exposed as IFace and runs TestWeak over a raw pointer, com_ptr_nothrow and
// WRL ComPtr. It then queries through a weak reference twice: once while the object is alive, and
// once after the last strong reference is released, when resolution must fail.
template <typename IFace>
void TestWeakCombinations()
{
    auto object = make_object<IFace, WinRtObject>();

    TestWeak(object);
    TestWeak(wil::com_ptr_nothrow<IFace>(object));
    TestWeak(Microsoft::WRL::ComPtr<IFace>(object));

    auto weakRef = wil::com_weak_query_failfast(object);
    TestQuery<IUnknown>(weakRef.get()); // Not IInspectable derived
    TestQuery<ITest>(weakRef.get());    // IInspectable derived
#ifdef WIL_EXHAUSTIVE_TEST
    TestQuery<IInspectable>(weakRef.get());
    TestQuery<IDerivedTest>(weakRef.get());
    TestQuery<ITestInspectable>(weakRef.get());
    TestQuery<IDerivedTestInspectable>(weakRef.get());
    TestQuery<IAlways>(weakRef.get());
#endif

    // Releasing the final strong reference leaves the weak reference unable to resolve.
    object->Release();
    TestQueryWithFailedResolve<IUnknown>(weakRef.get());
    TestQueryWithFailedResolve<ITest>(weakRef.get());
#ifdef WIL_EXHAUSTIVE_TEST
    TestQueryWithFailedResolve<IInspectable>(weakRef.get());
#endif
}
// Runs the weak-reference checks for ITest, plus other interface views under WIL_EXHAUSTIVE_TEST.
// At the end it verifies that no test objects leaked.
TEST_CASE("ComTests::Test_Weak", "[com][com_weak_ref]")
{
    // Disabled in the original source; the reason is not recorded here.
    // TestWeakCombinations<WinRtObject>();
    TestWeakCombinations<ITest>();
#ifdef WIL_EXHAUSTIVE_TEST
    TestWeakCombinations<IUnknown>();
    TestWeakCombinations<IInspectable>();
    TestWeakCombinations<IDerivedTest>();
    TestWeakCombinations<ITestInspectable>();
    TestWeakCombinations<IDerivedTestInspectable>();
    TestWeakCombinations<IAlways>();
#endif
    REQUIRE_FALSE(witest::g_objectCount.Leaked());
}
#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP)
// Checks that every CoCreateInstance flavor creates ShellLink, both by class type and by CLSID.
// It also checks that requesting an interface ShellLink is expected not to expose (IStream) fails.
TEST_CASE("ComTests::VerifyCoCreate", "[com][CoCreateInstance]")
{
    auto comScope = wil::CoInitializeEx_failfast();

#ifdef WIL_ENABLE_EXCEPTIONS
    auto throwingByType = wil::CoCreateInstance<ShellLink>();
    auto throwingByClsid = wil::CoCreateInstance(CLSID_ShellLink);
#endif
    auto failFastByType = wil::CoCreateInstanceFailFast<ShellLink>();
    auto failFastByClsid = wil::CoCreateInstanceFailFast(CLSID_ShellLink);
    auto noThrowByType = wil::CoCreateInstanceNoThrow<ShellLink>();
    auto noThrowByClsid = wil::CoCreateInstanceNoThrow(CLSID_ShellLink);

    // Failure path.
#ifdef WIL_ENABLE_EXCEPTIONS
    REQUIRE_THROWS((wil::CoCreateInstance<ShellLink, IStream>()));
#endif
    // The fail-fast failure path is deliberately not exercised; the exception-based check above is
    // assumed to be sufficient:
    // auto failFastMismatch = wil::CoCreateInstanceFailFast<ShellLink, IStream>();
    REQUIRE_FALSE(static_cast<bool>(wil::CoCreateInstanceNoThrow<ShellLink, IStream>().get()));
}
// Checks that every CoGetClassObject flavor returns ShellLink's class factory, both by class type and
// by CLSID. It also checks that requesting IStream from the class object fails.
TEST_CASE("ComTests::VerifyCoGetClassObject", "[com][CoGetClassObject]")
{
    auto comScope = wil::CoInitializeEx_failfast();

#ifdef WIL_ENABLE_EXCEPTIONS
    auto throwingByType = wil::CoGetClassObject<ShellLink>();
    auto throwingByClsid = wil::CoGetClassObject(CLSID_ShellLink);
#endif
    auto failFastByType = wil::CoGetClassObjectFailFast<ShellLink>();
    auto failFastByClsid = wil::CoGetClassObjectFailFast(CLSID_ShellLink);
    auto noThrowByType = wil::CoGetClassObjectNoThrow<ShellLink>();
    auto noThrowByClsid = wil::CoGetClassObjectNoThrow(CLSID_ShellLink);

    // Failure path.
#ifdef WIL_ENABLE_EXCEPTIONS
    REQUIRE_THROWS((wil::CoGetClassObject<ShellLink, IStream>()));
#endif
    // The fail-fast failure path is deliberately not exercised; the exception-based check above is
    // assumed to be sufficient:
    // auto failFastMismatch = wil::CoGetClassObjectFailFast<ShellLink, IStream>();
    REQUIRE_FALSE(static_cast<bool>(wil::CoGetClassObjectNoThrow<ShellLink, IStream>()));
}
#endif
#ifdef __IObjectWithSite_INTERFACE_DEFINED__
// Compile-time check that unique_set_site_null_call can be moved but not copied. At runtime it only
// exercises move construction and move assignment on a null site setter.
TEST_CASE("ComTests::VerifyComSetSiteNullIsMoveOnly", "[com][com_set_site]")
{
    wil::unique_set_site_null_call defaultCall;
    // Copying is intentionally ill-formed; uncommenting either line must break the build:
    // wil::unique_set_site_null_call copiedCall = defaultCall;
    // copiedCall = defaultCall;

    auto original = wil::com_set_site(nullptr, nullptr);
    auto moved = std::move(original); // move construction
    moved = std::move(original);      // move assignment (source is already moved-from)
}
// Checks that com_set_site installs a site for the lifetime of its returned guard and clears it
// afterwards. Also checks that for_each_site walks the site chain: the starting object, then its site.
TEST_CASE("ComTests::VerifyComSetSite", "[com][com_set_site]")
{
    // Minimal IObjectWithSite that remembers whatever site it is handed.
    class ObjectWithSite WrlFinal : public RuntimeClass<RuntimeClassFlags<ClassicCom>, IObjectWithSite>
    {
    public:
        STDMETHODIMP SetSite(IUnknown* newSite) noexcept override
        {
            m_site = newSite;
            return S_OK;
        }

        // Always returns S_OK; whether try_copy_to found anything is deliberately ignored.
        STDMETHODIMP GetSite(REFIID riid, void** ppv) noexcept override
        {
            m_site.try_copy_to(riid, ppv);
            return S_OK;
        }

    private:
        wil::com_ptr_nothrow<IUnknown> m_site;
    };

    // IServiceProvider stub that can hold an (unused) parent site; QueryService always fails.
    class ServiceObject WrlFinal : public RuntimeClass<RuntimeClassFlags<ClassicCom>, IServiceProvider>
    {
    public:
        ServiceObject(IServiceProvider* site = nullptr)
        {
            m_site = site;
        }

        STDMETHODIMP QueryService(REFIID /*sid*/, REFIID /*riid*/, void** ppv) noexcept override
        {
            *ppv = nullptr;
            return E_NOTIMPL;
        }

    private:
        wil::com_ptr_nothrow<IUnknown> m_site;
    };

    auto host = Make<ObjectWithSite>();
    auto innerService = Make<ServiceObject>();
    auto outerService = Make<ServiceObject>(innerService.Get());

    {
        auto siteGuard = wil::com_set_site(host.Get(), outerService.Get());

        wil::com_ptr_nothrow<IUnknown> currentSite;
        REQUIRE_SUCCEEDED(host->GetSite(IID_PPV_ARGS(&currentSite)));
        REQUIRE(static_cast<bool>(currentSite));

        int visited = 0;
        wil::for_each_site(host.Get(), [&](IUnknown* /*site*/) { ++visited; });
        REQUIRE(visited == 2);
    }

    // Once the guard is gone, the host must no longer report a site.
    wil::com_ptr_nothrow<IUnknown> site;
    REQUIRE_SUCCEEDED(host->GetSite(IID_PPV_ARGS(&site)));
    REQUIRE_FALSE(static_cast<bool>(site));
}
#endif
class FakeStream : public IStream
{
public:
STDMETHOD(QueryInterface)(REFIID riid, PVOID* ppv) override
{
if ((riid == __uuidof(IStream)) ||
(riid == __uuidof(ISequentialStream)) ||
(riid == __uuidof(IUnknown)))
{
*ppv = static_cast<IStream*>(this);
return S_OK;
}
return E_NOTIMPL;
}
STDMETHOD_(ULONG, AddRef)() override
{
return 2;
}
STDMETHOD_(ULONG, Release)() override
{
return 1;
}
unsigned long long Position = 0;
unsigned long long PositionMax = 0;
unsigned long MaxReadSize = 0;
unsigned long MaxWriteSize = 0;
unsigned long long TotalSize = 0;
// ISequentialStream
STDMETHOD(Read)(_Out_writes_bytes_to_(cb, *pcbRead) void *pv, _In_ ULONG cb, _Out_opt_ ULONG *pcbRead) override
{
if (pcbRead)
{
*pcbRead = min(MaxReadSize, cb);
}
ZeroMemory(pv, cb);
return (MaxReadSize <= cb) ? S_OK : S_FALSE;
}
STDMETHOD(Write)(_In_reads_bytes_(cb) const void *, _In_ ULONG cb, _Out_opt_ ULONG *pcbWritten) override
{
if (pcbWritten)
{
*pcbWritten = min(MaxWriteSize, cb);
}
return (MaxWriteSize <= cb) ? S_OK : S_FALSE;
}
// IStream
STDMETHOD(Seek)(LARGE_INTEGER dlibMove, DWORD dwOrigin, _Out_opt_ ULARGE_INTEGER *plibNewPosition)
{
if (dwOrigin == STREAM_SEEK_CUR)
{
if ((dlibMove.QuadPart < 0) && (static_cast<unsigned long long>(-dlibMove.QuadPart) > Position))
{
Position = 0;
}
else
{
Position += dlibMove.QuadPart;
}
}
else if (dwOrigin == STREAM_SEEK_SET)
{
Position = static_cast<unsigned long long>(dlibMove.QuadPart);
}
else if (dwOrigin == STREAM_SEEK_END)
{
if ((dlibMove.QuadPart < 0) && (static_cast<unsigned long long>(-dlibMove.QuadPart) > Position))
{
Position = 0;
}
else
{
Position = PositionMax + dlibMove.QuadPart;
}
}
Position = min(Position, PositionMax);
if (plibNewPosition)
{
plibNewPosition->QuadPart = Position;
}
return S_OK;
}
STDMETHOD(Stat)(__RPC__out STATSTG *pstatstg, DWORD) override
{
*pstatstg = {};
pstatstg->cbSize.QuadPart = TotalSize;
return S_OK;
}
STDMETHOD(Revert)(void) override
{
return E_NOTIMPL;
}
STDMETHOD(SetSize)(ULARGE_INTEGER) override
{
return E_NOTIMPL;
}
STDMETHOD(Clone)(__RPC__deref_out_opt IStream **ppstm) override
{
*ppstm = this;
return S_OK;
}
STDMETHOD(Commit)(DWORD) override
{
return E_NOTIMPL;
}
STDMETHOD(CopyTo)(_In_ IStream *pstm, ULARGE_INTEGER cb, _Out_opt_ ULARGE_INTEGER *pcbRead, _Out_opt_ ULARGE_INTEGER *pcbWritten) override
{
unsigned long didWrite;
unsigned long didRead;
FAIL_FAST_IF(cb.HighPart != 0);
RETURN_IF_FAILED(this->Read(nullptr, cb.LowPart, &didRead));
RETURN_IF_FAILED(pstm->Write(nullptr, didRead, &didWrite));
pcbRead->QuadPart = didRead;
pcbWritten->QuadPart = didWrite;
return S_OK;
}
STDMETHOD(LockRegion)(ULARGE_INTEGER, ULARGE_INTEGER, DWORD) override
{
return E_NOTIMPL;
}
STDMETHOD(UnlockRegion)(ULARGE_INTEGER, ULARGE_INTEGER, DWORD) override
{
return E_NOTIMPL;
}
void SetPosition(unsigned long long position, unsigned long long positionMax)
{
Position = position;
PositionMax = positionMax;
}
void SetPosition(unsigned long long position)
{
return SetPosition(position, position);
}
};
// stream_read_partial* must accept short reads and report exactly how many bytes the stream gave back.
TEST_CASE("StreamTests::ReadPartial", "[com][IStream]")
{
    FakeStream fake;
    fake.MaxReadSize = 16;
    BYTE data[32];
    ULONG bytesRead;

    // Requesting more than the stream hands out yields just what is available...
    REQUIRE_SUCCEEDED(wil::stream_read_partial_nothrow(&fake, data, 32, &bytesRead));
    REQUIRE(fake.MaxReadSize == bytesRead);

    // ...and requesting less yields exactly the amount requested.
    REQUIRE_SUCCEEDED(wil::stream_read_partial_nothrow(&fake, data, 5, &bytesRead));
    REQUIRE(5 == bytesRead);

#ifdef WIL_ENABLE_EXCEPTIONS
    // The throwing flavor returns the byte count directly.
    REQUIRE(fake.MaxReadSize == wil::stream_read_partial(&fake, data, 32));
    REQUIRE(5ULL == wil::stream_read_partial(&fake, data, 5));
#endif
}
TEST_CASE("StreamTests::Read", "[com][IStream]")
{
// Verifies wil::stream_read(_nothrow): unlike the partial variant, a read that cannot
// be fully satisfied is an error (failure HRESULT, or a throw).
FakeStream stream;
stream.MaxReadSize = 10;
BYTE buffer[32];
// Reading less than available is OK
REQUIRE_SUCCEEDED(wil::stream_read_nothrow(&stream, buffer, 5));
// Reading more is not.
REQUIRE(stream.MaxReadSize < sizeof(buffer));
REQUIRE_FAILED(wil::stream_read_nothrow(&stream, buffer, sizeof(buffer)));
struct Header
{
ULONG Flags;
ULONG Other;
} header;
// Reading a POD when there's not enough fails
stream.MaxReadSize = sizeof(header) - 1;
REQUIRE_FAILED(wil::stream_read_nothrow(&stream, &header));
// Reading a POD when there is enough data succeeds. Seed the fields with non-zero
// values first so the zero checks below prove the read overwrote them (this
// presumes the fake's Read fills the buffer with zeros).
header.Flags = 1;
header.Other = 2;
stream.MaxReadSize = sizeof(header);
REQUIRE_SUCCEEDED(wil::stream_read_nothrow(&stream, &header));
REQUIRE(0UL == header.Flags);
REQUIRE(0UL == header.Other);
#ifdef WIL_ENABLE_EXCEPTIONS
// MaxReadSize is still sizeof(header) here, so 5 bytes fit and sizeof(buffer) does not.
// Reading less than available is OK
REQUIRE_NOTHROW(wil::stream_read(&stream, buffer, 5));
REQUIRE_THROWS(wil::stream_read(&stream, buffer, sizeof(buffer)));
// Reading a POD when there's not enough fails
stream.MaxReadSize = sizeof(Header) - 1;
REQUIRE_THROWS(wil::stream_read<Header>(&stream));
// Reading a POD when there is enough data succeeds (and the returned value shows the read happened)
stream.MaxReadSize = sizeof(Header);
header = wil::stream_read<Header>(&stream);
REQUIRE(0UL == header.Flags);
REQUIRE(0UL == header.Other);
#endif
}
TEST_CASE("StreamTests::Write", "[com][IStream]")
{
// Verifies wil::stream_write(_nothrow): the write succeeds when the fake accepts at
// least as many bytes as requested (MaxWriteSize), and fails on a short write.
FakeStream stream;
BYTE buffer[16] = { 8, 6, 7, 5, 3, 0, 9 };
stream.MaxWriteSize = sizeof(buffer) + 1;
REQUIRE_SUCCEEDED(wil::stream_write_nothrow(&stream, buffer, sizeof(buffer)));
stream.MaxWriteSize = sizeof(buffer) - 1;
REQUIRE_FAILED(wil::stream_write_nothrow(&stream, buffer, sizeof(buffer)));
// Same checks for the overload that writes a POD value as raw bytes.
struct Header
{
ULONG Flags;
ULONG Other;
} header = { 1, 2 };
stream.MaxWriteSize = sizeof(header) + 1;
REQUIRE_SUCCEEDED(wil::stream_write_nothrow(&stream, header));
stream.MaxWriteSize = sizeof(header) - 1;
REQUIRE_FAILED(wil::stream_write_nothrow(&stream, header));
#ifdef WIL_ENABLE_EXCEPTIONS
// Throwing variants: a short write surfaces as an exception instead of an HRESULT.
stream.MaxWriteSize = sizeof(buffer) + 1;
REQUIRE_NOTHROW(wil::stream_write(&stream, buffer, sizeof(buffer)));
stream.MaxWriteSize = sizeof(buffer) - 1;
REQUIRE_THROWS(wil::stream_write(&stream, buffer, sizeof(buffer)));
header = { 1, 2 };
stream.MaxWriteSize = sizeof(header) + 1;
REQUIRE_NOTHROW(wil::stream_write(&stream, header));
stream.MaxWriteSize = sizeof(header) - 1;
REQUIRE_THROWS(wil::stream_write(&stream, header));
#endif
}
TEST_CASE("StreamTests::Size", "[com][IStream]")
{
    // The fake's Stat reports TotalSize as cbSize. wil::stream_size(_nothrow) presumably
    // queries Stat, so both helpers should echo TotalSize back unchanged.
    FakeStream stream;
    stream.TotalSize = 150;

    unsigned long long reportedSize;
    REQUIRE_SUCCEEDED(wil::stream_size_nothrow(&stream, &reportedSize));
    REQUIRE(reportedSize == stream.TotalSize);
#ifdef WIL_ENABLE_EXCEPTIONS
    REQUIRE(wil::stream_size(&stream) == stream.TotalSize);
#endif
}
TEST_CASE("StreamTests::SeekStart", "[com][IStream]")
{
// Verifies absolute positioning (wil::stream_set_position / stream_reset). The fake's
// Seek clamps the result to PositionMax, so seeking past the end lands at PositionMax.
FakeStream stream;
unsigned long long landed;
// Seek within the stream
stream.SetPosition(100, 1000);
REQUIRE_SUCCEEDED(wil::stream_set_position_nothrow(&stream, 10));
REQUIRE(10ULL == stream.Position);
// Seek and get the landing position
REQUIRE_SUCCEEDED(wil::stream_set_position_nothrow(&stream, 11, &landed));
REQUIRE(11ULL == stream.Position);
REQUIRE(11ULL == landed);
// Seek past the end
REQUIRE_SUCCEEDED(wil::stream_set_position_nothrow(&stream, 5000, &landed));
REQUIRE(stream.PositionMax == landed);
// Seek to the start
REQUIRE_SUCCEEDED(wil::stream_reset_nothrow(&stream));
REQUIRE(0ULL == stream.Position);
#ifdef WIL_ENABLE_EXCEPTIONS
// Throwing variants return the landing position directly.
// Seek within the stream
stream.SetPosition(100, 1000);
REQUIRE(10ULL == wil::stream_set_position(&stream, 10));
// Seek past the end
REQUIRE(stream.PositionMax == wil::stream_set_position(&stream, 5000));
// Seek to the start
REQUIRE_NOTHROW(wil::stream_reset(&stream));
REQUIRE(0ULL == stream.Position);
#endif
}
TEST_CASE("StreamTests::SeekCur", "[com][IStream]")
{
// Verifies relative seeks (wil::stream_seek_from_current_position). Each step starts
// from where the previous one landed. The expectations rely on the fake clamping
// below at 0 and above at PositionMax (5000 here).
FakeStream stream;
unsigned long long landed;
stream.SetPosition(100, 5000);
// 100 + 10
REQUIRE_SUCCEEDED(wil::stream_seek_from_current_position_nothrow(&stream, 10, &landed));
REQUIRE(110ULL == landed);
// 110 - 10
REQUIRE_SUCCEEDED(wil::stream_seek_from_current_position_nothrow(&stream, -10, &landed));
REQUIRE(100ULL == landed);
// 100 - 1000 would go negative; clamps to 0
REQUIRE_SUCCEEDED(wil::stream_seek_from_current_position_nothrow(&stream, -1000, &landed));
REQUIRE(0ULL == landed);
// 0 + 6000 exceeds PositionMax; clamps to 5000
REQUIRE_SUCCEEDED(wil::stream_seek_from_current_position_nothrow(&stream, 6000, &landed));
REQUIRE(5000ULL == landed);
#ifdef WIL_ENABLE_EXCEPTIONS
// Same sequence through the throwing variant, which returns the landing position.
stream.SetPosition(100, 5000);
REQUIRE(110ULL == wil::stream_seek_from_current_position(&stream, 10));
REQUIRE(100ULL == wil::stream_seek_from_current_position(&stream, -10));
REQUIRE(0ULL == wil::stream_seek_from_current_position(&stream, -1000));
REQUIRE(5000ULL == wil::stream_seek_from_current_position(&stream, 6000));
#endif
}
TEST_CASE("StreamTests::GetPosition", "[com][IStream]")
{
    // Park the fake at offset 50 (PositionMax is also 50) and check that both the
    // nothrow and throwing helpers report exactly that offset.
    FakeStream stream;
    stream.SetPosition(50);

    unsigned long long current;
    REQUIRE_SUCCEEDED(wil::stream_get_position_nothrow(&stream, &current));
    REQUIRE(current == stream.Position);
#ifdef WIL_ENABLE_EXCEPTIONS
    REQUIRE(wil::stream_get_position(&stream) == stream.Position);
#endif
}
#ifdef WIL_ENABLE_EXCEPTIONS
TEST_CASE("StreamTests::Saver", "[com][IStream]")
{
// Verifies wil::stream_position_saver, which records a stream's position and restores
// it later. Covers restore on scope exit, restore on reset(), cancel via dismiss(), and
// retargeting via reset(otherStream).
FakeStream first;
FakeStream second;
// Restores the saved position when the saver goes out of scope.
first.SetPosition(200);
{
auto saved = wil::stream_position_saver(&first);
first.SetPosition(250);
}
REQUIRE(200ULL == first.Position);
// reset() restores immediately, while the saver is still alive.
first.SetPosition(200);
{
auto saved = wil::stream_position_saver(&first);
first.SetPosition(250);
saved.reset();
REQUIRE(200ULL == first.Position);
}
// dismiss() cancels the restore, so the moved position sticks.
first.SetPosition(200);
{
auto saved = wil::stream_position_saver(&first);
first.SetPosition(250);
saved.dismiss();
}
REQUIRE(250ULL == first.Position);
// reset(&second) restores `first` and then tracks `second` from its current position (250).
first.SetPosition(200);
second.SetPosition(250);
{
auto saved = wil::stream_position_saver(&first);
first.SetPosition(210);
saved.reset(&second);
REQUIRE(200ULL == first.Position);
second.SetPosition(300);
saved.reset();
REQUIRE(250ULL == second.Position);
}
}
#endif