Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][RTC] Cache frontend invocation #16823

Merged
merged 20 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sycl-jit/common/include/Kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,9 @@ struct RTCDevImgInfo {

using RTCBundleInfo = DynArray<RTCDevImgInfo>;

// LLVM's APIs prefer `char *` for byte buffers.
using RTCDeviceCodeIR = DynArray<char>;

} // namespace jit_compiler

#endif // SYCL_FUSION_COMMON_KERNEL_H
1 change: 1 addition & 0 deletions sycl-jit/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_llvm_library(sycl-jit

LINK_COMPONENTS
BitReader
BitWriter
Core
Support
Option
Expand Down
49 changes: 46 additions & 3 deletions sycl-jit/jit-compiler/include/KernelFusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,45 @@ class JITResult {
sycl::detail::string ErrorMessage;
};

/// Result of a hash calculation for a runtime-compilation request.
///
/// An instance is either a success carrying a hash string, or a failure
/// carrying a log string. Both payloads share the single `HashOrLog` member;
/// the `Failed` flag tells which one is stored. Instances can only be created
/// through the `success()` and `failure()` factory functions.
class RTCHashResult {
public:
  /// Creates a successful result that stores \p Hash.
  static RTCHashResult success(const char *Hash) {
    return RTCHashResult{/*Failed=*/false, Hash};
  }

  /// Creates a failed result that stores the log text \p PreprocLog.
  static RTCHashResult failure(const char *PreprocLog) {
    return RTCHashResult{/*Failed=*/true, PreprocLog};
  }

  /// Returns true if this object was created via `failure()`.
  bool failed() const { return Failed; }

  /// Returns the stored log. Must only be called on a failed result
  /// (asserted).
  const char *getPreprocLog() const {
    assert(failed() && "No preprocessor log");
    return HashOrLog.c_str();
  }

  /// Returns the stored hash. Must only be called on a successful result
  /// (asserted).
  const char *getHash() const {
    assert(!failed() && "No hash");
    return HashOrLog.c_str();
  }

private:
  // Private on purpose: the factories above give the bare string argument an
  // unambiguous meaning at the call site.
  RTCHashResult(bool Failed, const char *HashOrLog)
      : Failed(Failed), HashOrLog(HashOrLog) {}

  bool Failed;
  // Holds the hash on success, or the log on failure.
  sycl::detail::string HashOrLog;
};

class RTCResult {
public:
explicit RTCResult(const char *BuildLog)
: Failed{true}, BundleInfo{}, BuildLog{BuildLog} {}

RTCResult(RTCBundleInfo &&BundleInfo, const char *BuildLog)
: Failed{false}, BundleInfo{std::move(BundleInfo)}, BuildLog{BuildLog} {}
RTCResult(RTCBundleInfo &&BundleInfo, RTCDeviceCodeIR &&DeviceCodeIR,
const char *BuildLog)
: Failed{false}, BundleInfo{std::move(BundleInfo)},
DeviceCodeIR(std::move(DeviceCodeIR)), BuildLog{BuildLog} {}

bool failed() const { return Failed; }

Expand All @@ -73,9 +105,15 @@ class RTCResult {
return BundleInfo;
}

const RTCDeviceCodeIR &getDeviceCodeIR() const {
assert(!failed() && "No device code IR");
return DeviceCodeIR;
}

private:
bool Failed;
RTCBundleInfo BundleInfo;
RTCDeviceCodeIR DeviceCodeIR;
sycl::detail::string BuildLog;
};

Expand All @@ -100,9 +138,14 @@ KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
const char *KernelName, jit_compiler::SYCLKernelBinaryInfo &BinInfo,
View<unsigned char> SpecConstBlob);

KF_EXPORT_SYMBOL RTCHashResult calculateHash(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs);

KF_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs);
View<const char *> UserArgs,
View<char> CachedIR, bool SaveIR);

KF_EXPORT_SYMBOL void destroyBinary(BinaryAddress Address);

Expand Down
1 change: 1 addition & 0 deletions sycl-jit/jit-compiler/ld-version-script.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
/* Export the library entry points */
fuseKernels;
materializeSpecConstants;
calculateHash;
compileSYCL;
destroyBinary;
resetJITConfiguration;
Expand Down
96 changes: 81 additions & 15 deletions sycl-jit/jit-compiler/lib/KernelFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
#include "translation/SPIRVLLVMTranslation.h"

#include <llvm/ADT/StringExtras.h>
#include <llvm/Bitcode/BitcodeReader.h>
#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/TimeProfiler.h>

#include <clang/Driver/Options.h>
Expand All @@ -31,17 +34,21 @@ using namespace jit_compiler;
using FusedFunction = helper::FusionHelper::FusedFunction;
using FusedFunctionList = std::vector<FusedFunction>;

template <typename ResultType>
static ResultType errorTo(llvm::Error &&Err, const std::string &Msg) {
static std::string formatError(llvm::Error &&Err, const std::string &Msg) {
std::stringstream ErrMsg;
ErrMsg << Msg << "\nDetailed information:\n";
llvm::handleAllErrors(std::move(Err),
[&ErrMsg](const llvm::StringError &StrErr) {
// Cannot throw an exception here if LLVM itself is
// compiled without exception support.
ErrMsg << "\t" << StrErr.getMessage() << "\n";
});
return ResultType{ErrMsg.str().c_str()};
return ErrMsg.str();
}

template <typename ResultType>
static ResultType errorTo(llvm::Error &&Err, const std::string &Msg) {
// Cannot throw an exception here if LLVM itself is compiled without exception
// support.
return ResultType{formatError(std::move(Err), Msg).c_str()};
}

static std::vector<jit_compiler::NDRange>
Expand Down Expand Up @@ -240,10 +247,42 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
return JITResult{FusedKernelInfo};
}

/// Library entry point: computes a hash for the given source file, include
/// files and user arguments.
///
/// \param SourceFile   The in-memory source file to hash.
/// \param IncludeFiles Additional in-memory files passed on to the hashing
///                     routine.
/// \param UserArgs     Raw user-supplied arguments; parsed with
///                     `parseUserArgs` before hashing.
/// \returns `RTCHashResult::success` with the hash string, or
///          `RTCHashResult::failure` with a formatted error message if either
///          the argument parsing or the hashing itself fails.
///
/// If the `-ftime-trace=` option is present in the parsed arguments, the time
/// spent hashing is printed to `llvm::dbgs()`.
extern "C" KF_EXPORT_SYMBOL RTCHashResult
calculateHash(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
              View<const char *> UserArgs) {
  auto UserArgListOrErr = parseUserArgs(UserArgs);
  if (!UserArgListOrErr) {
    return RTCHashResult::failure(
        formatError(UserArgListOrErr.takeError(),
                    "Parsing of user arguments failed")
            .c_str());
  }
  llvm::opt::InputArgList UserArgList = std::move(*UserArgListOrErr);

  const auto Start = std::chrono::high_resolution_clock::now();
  auto HashOrError = calculateHash(SourceFile, IncludeFiles, UserArgList);
  if (!HashOrError) {
    return RTCHashResult::failure(
        formatError(HashOrError.takeError(), "Hashing failed").c_str());
  }
  // Bind by reference: the hash is only read below (via c_str()), so a deep
  // copy of the string is unnecessary. `HashOrError` outlives every use.
  const auto &Hash = *HashOrError;
  const auto Stop = std::chrono::high_resolution_clock::now();

  if (UserArgList.hasArg(clang::driver::options::OPT_ftime_trace_EQ)) {
    const std::chrono::duration<double, std::milli> HashTime = Stop - Start;
    llvm::dbgs() << "Hashing of " << SourceFile.Path << " took "
                 << int(HashTime.count()) << " ms\n";
  }

  return RTCHashResult::success(Hash.c_str());
}

extern "C" KF_EXPORT_SYMBOL RTCResult
compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs) {
View<const char *> UserArgs, View<char> CachedIR, bool SaveIR) {
llvm::LLVMContext Context;
std::string BuildLog;
configureDiagnostics(Context, BuildLog);

auto UserArgListOrErr = parseUserArgs(UserArgs);
if (!UserArgListOrErr) {
Expand Down Expand Up @@ -272,16 +311,43 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
Verbose);
}

auto ModuleOrErr =
compileDeviceCode(SourceFile, IncludeFiles, UserArgList, BuildLog);
if (!ModuleOrErr) {
return errorTo<RTCResult>(ModuleOrErr.takeError(),
"Device compilation failed");
std::unique_ptr<llvm::Module> Module;

if (CachedIR.size() > 0) {
llvm::StringRef IRStr{CachedIR.begin(), CachedIR.size()};
std::unique_ptr<llvm::MemoryBuffer> IRBuf =
llvm::MemoryBuffer::getMemBuffer(IRStr, /*BufferName=*/"",
/*RequiresNullTerminator=*/false);
auto ModuleOrError = llvm::parseBitcodeFile(*IRBuf, Context);
if (!ModuleOrError) {
// Not a fatal error, we'll just compile the source string normally.
BuildLog.append(formatError(ModuleOrError.takeError(),
"Loading of cached device code failed"));
} else {
Module = std::move(*ModuleOrError);
}
}

std::unique_ptr<llvm::LLVMContext> Context;
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);
Context.reset(&Module->getContext());
bool FromSource = false;
if (!Module) {
auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgList,
BuildLog, Context);
if (!ModuleOrErr) {
return errorTo<RTCResult>(ModuleOrErr.takeError(),
"Device compilation failed");
}

Module = std::move(*ModuleOrErr);
FromSource = true;
}

RTCDeviceCodeIR IR;
if (SaveIR && FromSource) {
std::string BCString;
llvm::raw_string_ostream BCStream{BCString};
llvm::WriteBitcodeToFile(*Module, BCStream);
IR = RTCDeviceCodeIR{BCString.data(), BCString.data() + BCString.size()};
}

if (auto Error = linkDeviceLibraries(*Module, UserArgList, BuildLog)) {
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
Expand Down Expand Up @@ -314,7 +380,7 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
}
}

return RTCResult{std::move(BundleInfo), BuildLog.c_str()};
return RTCResult{std::move(BundleInfo), std::move(IR), BuildLog.c_str()};
}

extern "C" KF_EXPORT_SYMBOL void destroyBinary(BinaryAddress Address) {
Expand Down
Loading
Loading