#include <windows.h>

typedef int (WINAPI *FnMessageBoxA)(HWND hWnd , LPCSTR lpText, LPCSTR lpCaption, UINT uType);
typedef int (WINAPI *FnMessageBoxW)(HWND hWnd , LPCWSTR lpText, LPCWSTR lpCaption, UINT uType);

FnMessageBoxW        pfnMessageBoxW;
FnMessageBoxA        pfnMessageBoxA;

class A
{
public:
    virtual void fn()
    {
        Sleep(1);
    }

    char data[128];
};

class B
{
public:
    virtual void fn()
    {
        pfnMessageBoxW(NULL, L"Fail", L"ExploitMe", MB_OK);
    }

    char data[128];
};

int ExploitMe()
{
    volatile int         nResult = 0;
    HANDLE               hHeap;
    HANDLE               hFile;
    DWORD                NumberOfBytesRead = 0;
    A                    a;
    B                    b;
    char                 buf[512];
    LPVOID               pHeapBuf;

    hHeap = HeapCreate(0, 0x1000, 0x10000);
    pHeapBuf = HeapAlloc(hHeap, 0, sizeof(buf));
    hFile = CreateFile("exploit.dat", GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
    if (INVALID_HANDLE_VALUE != hFile)
    {
        DWORD dwFileSize = GetFileSize(hFile, NULL);
        if (dwFileSize <= sizeof(buf))
        {
            ReadFile(hFile, buf, dwFileSize, &NumberOfBytesRead, NULL);
            memcpy(pHeapBuf, buf, dwFileSize);
            ZeroMemory(buf, sizeof(buf));
            
            HMODULE hMod = LoadLibrary("user32.dll");
            pfnMessageBoxW = (FnMessageBoxW)GetProcAddress(hMod, "MessageBoxW");
            pfnMessageBoxA = (FnMessageBoxA)GetProcAddress(hMod, "MessageBoxA");

            if (dwFileSize <= sizeof(a))
                memcpy(&a.data, pHeapBuf, dwFileSize);
            HeapFree(hHeap, HEAP_NO_SERIALIZE, pHeapBuf);
            ZeroMemory(pHeapBuf, sizeof(a.data));

            if (dwFileSize <= sizeof(b))
                memcpy(&b.data, pHeapBuf, dwFileSize);

            A* pa = &a;
            B* pb = &b;
            pa->fn();
            pb->fn();
        }
        nResult = 1;
    }

    if (hFile)
        CloseHandle(hFile);
    if (pHeapBuf)
        HeapFree(hHeap, HEAP_NO_SERIALIZE, pHeapBuf);
    if (hHeap)
        HeapDestroy(hHeap);

    return nResult;
}

int main(int argc, char* argv[])
{
    if (ExploitMe())
        return 0;
    else
        return -1;
}