[Impeller] Add ability to unregister shaders (flutter/engine#37229)

This commit is contained in:
Brandon DeRosier
2022-11-02 03:40:54 -07:00
committed by GitHub
parent 7abb3d1b23
commit 4da499244f
8 changed files with 120 additions and 20 deletions

View File

@@ -137,4 +137,23 @@ void ShaderLibraryGLES::RegisterFunction(std::string name,
callback(true);
}
// |ShaderLibrary|
void ShaderLibraryGLES::UnregisterFunction(std::string name,
ShaderStage stage) {
ReaderLock lock(functions_mutex_);
const auto key = ShaderKey{name, stage};
auto found = functions_.find(key);
if (found != functions_.end()) {
VALIDATION_LOG << "Library function named " << name
<< " was not found, so it couldn't be unregistered.";
return;
}
functions_.erase(found);
return;
}
} // namespace impeller

View File

@@ -43,6 +43,9 @@ class ShaderLibraryGLES final : public ShaderLibrary {
std::shared_ptr<fml::Mapping> code,
RegistrationCallback callback) override;
// |ShaderLibrary|
void UnregisterFunction(std::string name, ShaderStage stage) override;
FML_DISALLOW_COPY_AND_ASSIGN(ShaderLibraryGLES);
};

View File

@@ -51,6 +51,9 @@ class ShaderLibraryMTL final : public ShaderLibrary {
std::shared_ptr<fml::Mapping> code,
RegistrationCallback callback) override;
// |ShaderLibrary|
void UnregisterFunction(std::string name, ShaderStage stage) override;
id<MTLDevice> GetDevice() const;
void RegisterLibrary(id<MTLLibrary> library);

View File

@@ -54,36 +54,38 @@ std::shared_ptr<const ShaderFunction> ShaderLibraryMTL::GetFunction(
ShaderKey key(name, stage);
if (auto found = functions_.find(key); found != functions_.end()) {
return found->second;
}
id<MTLFunction> function = nil;
{
ReaderLock lock(libraries_mutex_);
if (auto found = functions_.find(key); found != functions_.end()) {
return found->second;
}
for (size_t i = 0, count = [libraries_ count]; i < count; i++) {
function = [libraries_[i] newFunctionWithName:@(name.data())];
if (function) {
break;
}
}
}
if (function == nil) {
return nullptr;
}
if (function == nil) {
return nullptr;
}
if (function.functionType != ToMTLFunctionType(stage)) {
VALIDATION_LOG << "Library function named " << name
<< " was for an unexpected shader stage.";
return nullptr;
}
if (function.functionType != ToMTLFunctionType(stage)) {
VALIDATION_LOG << "Library function named " << name
<< " was for an unexpected shader stage.";
return nullptr;
}
auto func = std::shared_ptr<ShaderFunctionMTL>(new ShaderFunctionMTL(
library_id_, function, {name.data(), name.size()}, stage));
functions_[key] = func;
return func;
auto func = std::shared_ptr<ShaderFunctionMTL>(new ShaderFunctionMTL(
library_id_, function, {name.data(), name.size()}, stage));
functions_[key] = func;
return func;
}
}
id<MTLDevice> ShaderLibraryMTL::GetDevice() const {
@@ -141,6 +143,41 @@ void ShaderLibraryMTL::RegisterFunction(std::string name, // unused
}];
}
// |ShaderLibrary|
void ShaderLibraryMTL::UnregisterFunction(std::string name, ShaderStage stage) {
ReaderLock lock(libraries_mutex_);
// Find the shader library containing this function name and remove it.
bool found_library = false;
for (size_t i = [libraries_ count] - 1; i >= 0; i--) {
id<MTLFunction> function =
[libraries_[i] newFunctionWithName:@(name.data())];
if (function) {
[libraries_ removeObjectAtIndex:i];
found_library = true;
break;
}
}
if (!found_library) {
VALIDATION_LOG << "Library containing function " << name
<< " was not found, so it couldn't be unregistered.";
}
// Remove the shader from the function cache.
ShaderKey key(name, stage);
auto found = functions_.find(key);
if (found == functions_.end()) {
VALIDATION_LOG << "Library function named " << name
<< " was not found, so it couldn't be unregistered.";
return;
}
functions_.erase(found);
}
void ShaderLibraryMTL::RegisterLibrary(id<MTLLibrary> library) {
WriterLock lock(libraries_mutex_);
[libraries_ addObject:library];

View File

@@ -114,6 +114,8 @@ bool ShaderLibraryVK::IsValid() const {
std::shared_ptr<const ShaderFunction> ShaderLibraryVK::GetFunction(
std::string_view name,
ShaderStage stage) {
ReaderLock lock(functions_mutex_);
const auto key = ShaderKey{{name.data(), name.size()}, stage};
auto found = functions_.find(key);
if (found != functions_.end()) {
@@ -122,4 +124,22 @@ std::shared_ptr<const ShaderFunction> ShaderLibraryVK::GetFunction(
return nullptr;
}
// |ShaderLibrary|
void ShaderLibraryVK::UnregisterFunction(std::string name, ShaderStage stage) {
ReaderLock lock(functions_mutex_);
const auto key = ShaderKey{name, stage};
auto found = functions_.find(key);
if (found != functions_.end()) {
VALIDATION_LOG << "Library function named " << name
<< " was not found, so it couldn't be unregistered.";
return;
}
functions_.erase(found);
return;
}
} // namespace impeller

View File

@@ -6,6 +6,7 @@
#include "flutter/fml/macros.h"
#include "impeller/base/comparable.h"
#include "impeller/base/thread.h"
#include "impeller/renderer/backend/vulkan/vk.h"
#include "impeller/renderer/shader_key.h"
#include "impeller/renderer/shader_library.h"
@@ -23,6 +24,7 @@ class ShaderLibraryVK final : public ShaderLibrary {
private:
friend class ContextVK;
const UniqueID library_id_;
mutable RWMutex functions_mutex_;
ShaderFunctionMap functions_;
bool is_valid_ = false;
@@ -34,6 +36,9 @@ class ShaderLibraryVK final : public ShaderLibrary {
std::shared_ptr<const ShaderFunction> GetFunction(std::string_view name,
ShaderStage stage) override;
// |ShaderLibrary|
void UnregisterFunction(std::string name, ShaderStage stage) override;
FML_DISALLOW_COPY_AND_ASSIGN(ShaderLibraryVK);
};

View File

@@ -33,6 +33,8 @@ class ShaderLibrary : public std::enable_shared_from_this<ShaderLibrary> {
std::shared_ptr<fml::Mapping> code,
RegistrationCallback callback);
virtual void UnregisterFunction(std::string name, ShaderStage stage) = 0;
protected:
ShaderLibrary();

View File

@@ -207,9 +207,20 @@ TEST_P(RuntimeStageTest, CanRegisterStage) {
reg.set_value(result);
}));
ASSERT_TRUE(future.get());
auto function =
library->GetFunction(stage.GetEntrypoint(), ShaderStage::kFragment);
ASSERT_NE(function, nullptr);
{
auto function =
library->GetFunction(stage.GetEntrypoint(), ShaderStage::kFragment);
ASSERT_NE(function, nullptr);
}
// Check if unregistering works.
library->UnregisterFunction(stage.GetEntrypoint(), ShaderStage::kFragment);
{
auto function =
library->GetFunction(stage.GetEntrypoint(), ShaderStage::kFragment);
ASSERT_EQ(function, nullptr);
}
}
TEST_P(RuntimeStageTest, CanCreatePipelineFromRuntimeStage) {