tensorflow-core-framework-kernel_def_builder.cc 2019-06-19 817 tensorflow-core-framework ```cpp #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/kernel_def.pb_text.h" #include "tensorflow/core/framework/kernel_def.pb.h" namespace tensorflow { KernelDefBuilder::KernelDefBuilder(const char* op_name) { kernel_def_ = new KernelDef; kernel_def_->set_op(op_name); } KernelDefBuilder::~KernelDefBuilder() { DCHECK(kernel_def_ == nullptr) << "Did not call Build()"; } KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) { kernel_def_->set_device_type(device_type); return *this; } template <> KernelDefBuilder& KernelDefBuilder::AttrConstraint( const char* attr_name, gtl::ArraySlice allowed) { auto* constraint = kernel_def_->add_constraint(); constraint->set_name(attr_name); auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); for (const int64 integer : allowed) { LOG(INFO) << integer; allowed_values->add_i(integer); } return *this; } template <> KernelDefBuilder& KernelDefBuilder::AttrConstraint(const char* attr_name, int64 allowed) { return AttrConstraint( attr_name, gtl::ArraySlice(std::initializer_list({allowed}))); } template <> KernelDefBuilder& KernelDefBuilder::AttrConstraint( const char* attr_name, gtl::ArraySlice allowed) { auto* constraint = kernel_def_->add_constraint(); constraint->set_name(attr_name); auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); for (const auto& str : allowed) { allowed_values->add_s(str); } return *this; } template <> KernelDefBuilder& KernelDefBuilder::AttrConstraint( const char* attr_name, string allowed) { return AttrConstraint( attr_name, gtl::ArraySlice(std::initializer_list({allowed}))); } template <> KernelDefBuilder& KernelDefBuilder::AttrConstraint( const char* attr_name, gtl::ArraySlice allowed) { auto* constraint = kernel_def_->add_constraint(); constraint->set_name(attr_name); auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); for (const auto& str : allowed) { allowed_values->add_s(str); } return *this; } template <> KernelDefBuilder& KernelDefBuilder::AttrConstraint( const char* attr_name, const char* allowed) { return AttrConstraint(attr_name, gtl::ArraySlice( std::initializer_list({allowed}))); } template <> KernelDefBuilder& KernelDefBuilder::AttrConstraint(const char* attr_name, bool allowed) { auto* constraint = kernel_def_->add_constraint(); constraint->set_name(attr_name); auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); allowed_values->add_b(allowed); return *this; } KernelDefBuilder& KernelDefBuilder::TypeConstraint( const char* attr_name, gtl::ArraySlice allowed) { auto* constraint = kernel_def_->add_constraint(); constraint->set_name(attr_name); auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); for (DataType dt : allowed) { allowed_values->add_type(dt); } return *this; } KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name, DataType allowed) { auto* constraint = kernel_def_->add_constraint(); constraint->set_name(attr_name); constraint->mutable_allowed_values()->mutable_list()->add_type(allowed); return *this; } KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) { kernel_def_->add_host_memory_arg(arg_name); return *this; } KernelDefBuilder& KernelDefBuilder::Label(const char* label) { CHECK_EQ(kernel_def_->label(), "") << "Trying to set a kernel's label a second time: '" << label << "' in: " << ProtoShortDebugString(*kernel_def_); kernel_def_->set_label(label); return *this; } KernelDefBuilder& KernelDefBuilder::Priority(int32 priority) { kernel_def_->set_priority(priority); return *this; } const KernelDef* KernelDefBuilder::Build() { KernelDef* r = kernel_def_; kernel_def_ = nullptr; return r; } } // namespace tensorflow ``` 本文链接: http://www.codeeyes.net/archives/tensorflow-core-framework-kernel_def_builder_cc.html