#include "stdafx.h"
#include <cwchar>
#define DECLSPEC __declspec(dllexport)
#include "PdbParser.h"
#undef DECLSPEC
#include "Callback.h"

// Process-wide DIA state shared by the exported functions in this file.
// The three interface pointers start out NULL and are filled in by the
// exported LoadDataFromPdb(wchar_t*).
IDiaDataSource *g_pDiaDataSource(NULL);
IDiaSession    *g_pDiaSession(NULL);
IDiaSymbol     *g_pGlobalSymbol(NULL);
// CodeView CPU type used to get correct register names; defaults to x86 and
// is updated from the machine type of the loaded image.
DWORD g_dwMachineType(CV_CFL_80386);

////////////////////////////////////////////////////////////
// Create an IDiaData source and open a PDB file
//
bool LoadDataFromPdb(
    wchar_t          *wszFilename, 
    IDiaDataSource  **ppSource,
    IDiaSession     **ppSession,
    IDiaSymbol      **ppGlobal){

  HRESULT hr;
  wchar_t wszExt[MAX_PATH];
  wchar_t* wszSearchPath = L"SRV**\\\\symbols\\symbols"; // Alternate path to search for debug data
  DWORD dwMachType = 0;
  
  hr = CoInitialize(NULL);
  // Obtain Access To The Provider
  hr = CoCreateInstance(__uuidof(DiaSource),//CLSID_DiaSource, 
                        NULL, 
                        CLSCTX_INPROC_SERVER, 
                        __uuidof(IDiaDataSource),
                        (void **) ppSource);
  if(hr != S_OK){
    wprintf(L"CoCreateInstance failed - HRESULT = %x\n",hr);
    return false;
  }
  
  _wsplitpath_s(wszFilename,NULL,0,NULL,0,NULL,0,wszExt,MAX_PATH);
  if(!_wcsicmp(wszExt,L".pdb")){
    // Open and prepare a program database (.pdb) file as a debug data source
    hr = (*ppSource)->loadDataFromPdb(wszFilename);
    if(hr != S_OK){
      wprintf(L"loadDataFromPdb failed - HRESULT = %x\n",hr);
      return false;
    }
  }else{
    CCallback callback; // Receives callbacks from the DIA symbol locating procedure,
                        // thus enabling a user interface to report on the progress of 
                        // the location attempt. The client application may optionally 
                        // provide a reference to its own implementation of this 
                        // virtual base class to the IDiaDataSource::loadDataForExe method.
    callback.AddRef();
    // Open and prepare the debug data associated with the .exe/.dll file
    hr = (*ppSource)->loadDataForExe(wszFilename,wszSearchPath,&callback);
    if(hr != S_OK){
      wprintf(L"loadDataForExe failed - HRESULT = %x\n",hr);
      return false;
    }
  }
  // Open a session for querying symbols
  hr = (*ppSource)->openSession(ppSession);
  if(hr != S_OK){
    wprintf(L"openSession failed - HRESULT = %x\n",hr);
    return false;
  }
  // Retrieve a reference to the global scope
  hr = (*ppSession)->get_globalScope(ppGlobal);
  if(hr != S_OK){
    wprintf(L"get_globalScope failed\n");
    return false;
  }
  // Set Machine type for getting correct register names
  if((*ppGlobal)->get_machineType(&dwMachType) == S_OK){
    switch(dwMachType){
      case IMAGE_FILE_MACHINE_I386 : g_dwMachineType = CV_CFL_80386; break;
      case IMAGE_FILE_MACHINE_IA64 : g_dwMachineType = CV_CFL_IA64; break;
      case IMAGE_FILE_MACHINE_AMD64 : g_dwMachineType = CV_CFL_AMD64; break;
    }
  }
  return true;
}

// Exported entry point: checks that pdbPath names a file that can be opened
// for reading, then loads it into the process-wide DIA globals
// (g_pDiaDataSource, g_pDiaSession, g_pGlobalSymbol).
// Interfaces left over from an earlier load are released first. The
// 4-argument overload overwrites the globals without releasing what they
// held, so re-loading would otherwise leak them.
// Returns false for a NULL or unopenable path, or if loading fails.
// NOTE(review): every load performs its own CoInitialize while Cleanup()
// performs a single CoUninitialize; confirm callers pair Load/Cleanup 1:1.
__declspec(dllexport) bool __stdcall LoadDataFromPdb(wchar_t *pdbPath) {
	if (pdbPath == NULL) {
		return false;
	}
	FILE *pFile = NULL;
	if(_wfopen_s(&pFile, pdbPath, L"r") || !pFile){
		// Invalid file name or file does not exist.
		return false;
	}
	fclose(pFile);

	// Drop the state of a previous load before acquiring new interfaces.
	if (g_pGlobalSymbol) {
		g_pGlobalSymbol->Release();
		g_pGlobalSymbol = NULL;
	}
	if (g_pDiaSession) {
		g_pDiaSession->Release();
		g_pDiaSession = NULL;
	}
	if (g_pDiaDataSource) {
		g_pDiaDataSource->Release();
		g_pDiaDataSource = NULL;
	}

	return LoadDataFromPdb(pdbPath, &g_pDiaDataSource, &g_pDiaSession, &g_pGlobalSymbol);
}

// Returns the name reported by IDiaSourceFile::get_fileName for pSource, or
// NULL when that call does not return S_OK. As with the other DIA BSTR
// getters in this file, the returned string is presumably the caller's to free
// with SysFreeString.
BSTR GetSourceFileName(IDiaSourceFile* pSource) {
	BSTR name = NULL;
	const HRESULT status = pSource->get_fileName(&name);
	return (status == S_OK) ? name : NULL;
}

////////////////////////////////////////////////////////////
// Collect all the source file information stored in the PDB, grouped by
//   compiland.
// We have to go through every compiland in order to retrieve all the
//   information; otherwise checksums, for instance, will not be available.
// Compilands can have multiple source files with the same name but
//   different content, which produces different checksums.
//
// Returns a heap-allocated map (owned by the caller) from compiland name to a
//   heap-allocated vector of source file names. Keys and values are BSTRs
//   returned by DIA; none of them are freed here.
// NOTE: the map is keyed by the BSTR *pointer*, so entries are ordered by
//   address and cannot be looked up by string content.
// Compilands whose name cannot be retrieved are skipped, just as source files
//   without a name are skipped.
// Returns an empty map when pSession or pGlobal is NULL (nothing loaded).
//
std::map<BSTR, std::vector<BSTR> *> *GetAllSourceFiles(IDiaSession *pSession, IDiaSymbol *pGlobal) {
  std::map<BSTR, std::vector<BSTR> *> *result = new std::map<BSTR, std::vector<BSTR> *>();
  if (pSession == NULL || pGlobal == NULL) {
    return result;
  }

  IDiaEnumSymbols* pEnumSymbols = NULL;
  if(pGlobal->findChildren(SymTagCompiland, NULL, nsNone, &pEnumSymbols) == S_OK){
    IDiaSymbol* pCompiland = NULL;
    ULONG celt = 0;
    while(pEnumSymbols->Next(1, &pCompiland, &celt) == S_OK && celt == 1){
      // A fresh, initialised name per compiland: a stale or uninitialised
      // pointer must never become a map key.
      BSTR wszCompName = NULL;
      if(pCompiland->get_name(&wszCompName) != S_OK || wszCompName == NULL){
        pCompiland->Release();
        continue;
      }
      std::vector<BSTR> *v = new std::vector<BSTR>();
      result->insert(std::pair<BSTR, std::vector<BSTR> *>(wszCompName, v));

      // To get the complete source file info we must pass the compiland;
      // passing NULL instead would retrieve only the source file names.
      IDiaEnumSourceFiles* pEnumSourceFiles = NULL;
      if(pSession->findFile(pCompiland, NULL, nsNone, &pEnumSourceFiles) == S_OK){
        IDiaSourceFile* pSrcFile = NULL;
        while(pEnumSourceFiles->Next(1, &pSrcFile, &celt) == S_OK && celt == 1){
          BSTR sourceFile = GetSourceFileName(pSrcFile);
          if (sourceFile != NULL) {
            v->push_back(sourceFile);
          }
          pSrcFile->Release();
        }
        pEnumSourceFiles->Release();
      }
      pCompiland->Release();
    }
    pEnumSymbols->Release();
  }
  return result;
}

// Exported: returns every compiland's source files for the PDB loaded by
// LoadDataFromPdb(wchar_t*). The result is heap-allocated and owned by the
// caller. Until a load succeeds the DIA globals are NULL, so an empty map is
// returned instead of dereferencing them.
__declspec(dllexport) std::map<BSTR, std::vector<BSTR> *> * __stdcall GetAllSourceFiles() {
	if (g_pDiaSession == NULL || g_pGlobalSymbol == NULL) {
		return new std::map<BSTR, std::vector<BSTR> *>();
	}
	return GetAllSourceFiles(g_pDiaSession, g_pGlobalSymbol);
}

// Returns the name of the source file that the line record pLine belongs to.
// The IDiaSourceFile obtained from pLine is released before returning.
// When the source file or its name is unavailable, an empty string is
// returned. It is allocated with SysAllocString so that every string this
// function returns is a real BSTR that can be freed uniformly with
// SysFreeString.
BSTR GetSourceFile(IDiaLineNumber &pLine) {
	BSTR wszSourceName = NULL;
	IDiaSourceFile* pSource = NULL;
	if(pLine.get_sourceFile(&pSource) == S_OK && pSource != NULL) {
		if(pSource->get_fileName(&wszSourceName) != S_OK) {
			wszSourceName = NULL;
		}
		pSource->Release();
	}
	if(wszSourceName == NULL) {
		// Callers read the name directly, so keep returning "" rather than NULL.
		wszSourceName = SysAllocString(L"");
	}
	return wszSourceName;
}

////////////////////////////////////////////////////////////
// Append one entry per line record enumerated by pLines to *mappings.
// Each key is a heap-allocated range<DWORD>(dwStart, dwLength) (presumably
//   (start, length); confirm against range's definition), and each value is a
//   heap-allocated (line number, source file name) pair. Ownership of both
//   passes to whoever owns *mappings.
// A record is skipped unless its RVA, section, offset, line number, source
//   file id and length can all be read.
// NOTE(review): dwStart is the running sum of the lengths of the records
//   accepted so far, starting at 0. The record's own RVA/offset is not used
//   for the key, so records are assumed to be contiguous and in order.
//   Confirm this is what consumers expect.
//
void GetLines(IDiaEnumLineNumbers* pLines, std::map<range<DWORD> *,std::pair<unsigned long, BSTR> *> *mappings) {
  IDiaLineNumber* pLine = NULL;
  ULONG celt = 0;
  DWORD dwRVA, dwOffset, dwSeg, dwLinenum, dwSrcId, dwLength;
  DWORD dwStart = 0;

  while(pLines->Next(1, &pLine, &celt) == S_OK && celt == 1){
    if(pLine->get_relativeVirtualAddress(&dwRVA) == S_OK &&
       pLine->get_addressSection(&dwSeg) == S_OK &&
       pLine->get_addressOffset(&dwOffset) == S_OK &&
       pLine->get_lineNumber(&dwLinenum) == S_OK &&
       pLine->get_sourceFileId(&dwSrcId) == S_OK &&
       pLine->get_length(&dwLength) == S_OK ) {
      mappings->insert(std::pair<range<DWORD> *,std::pair<unsigned long, BSTR> *>(
          new range<DWORD>(dwStart, dwLength),
          new std::pair<unsigned long, BSTR>(dwLinenum, GetSourceFile(*pLine))));
      dwStart += dwLength;
    }
    pLine->Release();
  }
}

////////////////////////////////////////////////////////////
// Collect the line mappings of a single function symbol into *mappings.
// Lines are looked up by RVA when the function has one, and by
//   section:offset otherwise. If pFunction is not a function or its length
//   is unavailable, an error is printed and *mappings is left unchanged.
//
void GetLines(IDiaSession* pSession,IDiaSymbol* pFunction,std::map<range<DWORD> *,std::pair<unsigned long,BSTR> *> *mappings) {
  DWORD dwSymTag = 0;
  ULONGLONG ulLength = 0;
  DWORD dwRVA = 0;
  IDiaEnumLineNumbers* pLines = NULL;

  if(pFunction->get_symTag(&dwSymTag) != S_OK || dwSymTag != SymTagFunction){
    wprintf(L"ERROR - GetLines() dwSymTag != SymTagFunction\n");
    return;
  }
  if(pFunction->get_length(&ulLength) != S_OK){
    wprintf(L"ERROR - GetLines() get_length\n");
    return;
  }
  // The DIA line lookups take a DWORD length, so the 64-bit symbol length is
  // narrowed here.
  const DWORD dwLength = static_cast<DWORD>(ulLength);

  if(pFunction->get_relativeVirtualAddress(&dwRVA) == S_OK){
    if(pSession->findLinesByRVA(dwRVA, dwLength, &pLines) == S_OK){
      GetLines(pLines, mappings);
      pLines->Release();
    }
  }else{
    // No RVA available: fall back to the section:offset address form.
    DWORD dwSect = 0;
    DWORD dwOffset = 0;
    if(pFunction->get_addressSection(&dwSect) == S_OK &&
       pFunction->get_addressOffset(&dwOffset) == S_OK){
      if(pSession->findLinesByAddr(dwSect, dwOffset, dwLength, &pLines) == S_OK){
        GetLines(pLines, mappings);
        pLines->Release();
      }
    }
  }
}

// Returns true when the two wide strings have identical contents
// (case-sensitive). Under the COM convention a NULL BSTR means the empty
// string, so NULL arguments compare equal to L"" instead of being
// dereferenced.
bool AreEqual(BSTR A, BSTR B) {
	const wchar_t *aPtr = (A != NULL) ? A : L"";
	const wchar_t *bPtr = (B != NULL) ? B : L"";
	return wcscmp(aPtr, bPtr) == 0;
}

////////////////////////////////////////////////////////////
// Collect all the line numbering information for the first function symbol
//   whose name matches wszFuncName (a regular expression string, searched
//   with nsRegularExpression) and whose lexical parent is named exactly
//   clsName.
// Returns a heap-allocated map owned by the caller. It is empty when no
//   function matches or when pSession/pGlobal is NULL.
//
std::map<range<DWORD> *,std::pair<unsigned long,BSTR> *> *GetLines(IDiaSession* pSession, IDiaSymbol* pGlobal, BSTR wszFuncName, BSTR clsName) {
  std::map<range<DWORD> *,std::pair<unsigned long,BSTR> *> *result = new std::map<range<DWORD> *,std::pair<unsigned long,BSTR> *>();
  if (pSession == NULL || pGlobal == NULL) {
    return result;
  }

  IDiaEnumSymbols* pEnumSymbols = NULL;
  if(pGlobal->findChildren(SymTagFunction, wszFuncName, nsRegularExpression, &pEnumSymbols) == S_OK){
    IDiaSymbol* pFunction = NULL;
    ULONG celt = 0;
    bool found = false;
    // Stop after the first match; `found` is tested first so no further
    // symbol is fetched once a match has been processed.
    while(!found && pEnumSymbols->Next(1, &pFunction, &celt) == S_OK && celt == 1){
      IDiaSymbol *pParent = NULL;
      if (pFunction->get_lexicalParent(&pParent) == S_OK && pParent != NULL) {
        BSTR parentName = NULL;
        if (pParent->get_name(&parentName) == S_OK && parentName != NULL) {
          if (AreEqual(clsName, parentName)) {
            GetLines(pSession, pFunction, result);
            found = true;
          }
          SysFreeString(parentName);
        }
        // Released on every path, including the matching one.
        pParent->Release();
      }
      pFunction->Release();
    }
    pEnumSymbols->Release();
  }
  return result;
}

// Exported: returns the line mappings for the first function matching
// wszFuncName (a regular expression) whose lexical parent is named clsName,
// using the PDB loaded by LoadDataFromPdb(wchar_t*). The result is
// heap-allocated and owned by the caller. Until a load succeeds the DIA
// globals are NULL, so an empty map is returned instead of dereferencing them.
__declspec(dllexport) std::map<range<DWORD> *,std::pair<unsigned long,BSTR> *> * __stdcall GetLineNumbers(BSTR wszFuncName, BSTR clsName) {
	if (g_pDiaSession == NULL || g_pGlobalSymbol == NULL) {
		return new std::map<range<DWORD> *,std::pair<unsigned long,BSTR> *>();
	}
	return ::GetLines(g_pDiaSession, g_pGlobalSymbol, wszFuncName, clsName);
}

// Exported: releases the DIA interfaces held in the process-wide globals
// (global scope symbol, then session, then data source) and resets them to
// NULL. It then calls CoUninitialize to balance the CoInitialize performed
// while loading.
// NOTE(review): CoUninitialize is called unconditionally, even if nothing was
// loaded; confirm callers pair each load with exactly one Cleanup.
__declspec(dllexport) void __stdcall Cleanup(void) {
  if(g_pGlobalSymbol){
    g_pGlobalSymbol->Release();
    g_pGlobalSymbol = NULL;
  }
  if(g_pDiaSession){
    g_pDiaSession->Release();
    g_pDiaSession = NULL;
  }
  if(g_pDiaDataSource){
    g_pDiaDataSource->Release();
    g_pDiaDataSource = NULL;
  }
  CoUninitialize();
}
