1
0
mirror of https://github.com/kubernetes-sigs/descheduler.git synced 2026-01-28 14:41:10 +01:00

[v0.34.0] bump to kubernetes 1.34 deps

Signed-off-by: Amir Alavi <amiralavi7@gmail.com>
This commit is contained in:
Amir Alavi
2025-09-17 16:55:29 -04:00
parent e9188852ef
commit 1db6b615d1
1266 changed files with 100906 additions and 40660 deletions

View File

@@ -11,15 +11,17 @@ go_library(
"decls.go",
"env.go",
"folding.go",
"io.go",
"inlining.go",
"io.go",
"library.go",
"macro.go",
"optimizer.go",
"options.go",
"program.go",
"prompt.go",
"validator.go",
],
embedsrcs = ["//cel/templates"],
importpath = "github.com/google/cel-go/cel",
visibility = ["//visibility:public"],
deps = [
@@ -29,6 +31,7 @@ go_library(
"//common/ast:go_default_library",
"//common/containers:go_default_library",
"//common/decls:go_default_library",
"//common/env:go_default_library",
"//common/functions:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
@@ -61,9 +64,10 @@ go_test(
"decls_test.go",
"env_test.go",
"folding_test.go",
"io_test.go",
"inlining_test.go",
"io_test.go",
"optimizer_test.go",
"prompt_test.go",
"validator_test.go",
],
data = [
@@ -72,6 +76,9 @@ go_test(
embed = [
":go_default_library",
],
embedsrcs = [
"//cel/testdata:prompts",
],
deps = [
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
@@ -83,8 +90,8 @@ go_test(
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//encoding/prototext:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
],

View File

@@ -142,8 +142,23 @@ func Constant(name string, t *Type, v ref.Val) EnvOption {
// Variable creates an instance of a variable declaration with a variable name and type.
func Variable(name string, t *Type) EnvOption {
return VariableWithDoc(name, t, "")
}
// VariableWithDoc creates an instance of a variable declaration with a variable name, type, and doc string.
func VariableWithDoc(name string, t *Type, doc string) EnvOption {
return func(e *Env) (*Env, error) {
e.variables = append(e.variables, decls.NewVariable(name, t))
e.variables = append(e.variables, decls.NewVariableWithDoc(name, t, doc))
return e, nil
}
}
// VariableDecls configures a set of fully defined cel.VariableDecl instances in the environment.
func VariableDecls(vars ...*decls.VariableDecl) EnvOption {
return func(e *Env) (*Env, error) {
for _, v := range vars {
e.variables = append(e.variables, v)
}
return e, nil
}
}
@@ -183,13 +198,38 @@ func Function(name string, opts ...FunctionOpt) EnvOption {
if err != nil {
return nil, err
}
if existing, found := e.functions[fn.Name()]; found {
fn, err = existing.Merge(fn)
if err != nil {
return nil, err
return FunctionDecls(fn)(e)
}
}
// OverloadSelector selects an overload associated with a given function when it returns true.
//
// Used in combination with the FunctionDecl.Subset method.
type OverloadSelector = decls.OverloadSelector
// IncludeOverloads defines an OverloadSelector which allow-lists a set of overloads by their ids.
func IncludeOverloads(overloadIDs ...string) OverloadSelector {
return decls.IncludeOverloads(overloadIDs...)
}
// ExcludeOverloads defines an OverloadSelector which deny-lists a set of overloads by their ids.
func ExcludeOverloads(overloadIDs ...string) OverloadSelector {
return decls.ExcludeOverloads(overloadIDs...)
}
// FunctionDecls provides one or more fully formed function declarations to be added to the environment.
func FunctionDecls(funcs ...*decls.FunctionDecl) EnvOption {
return func(e *Env) (*Env, error) {
var err error
for _, fn := range funcs {
if existing, found := e.functions[fn.Name()]; found {
fn, err = existing.Merge(fn)
if err != nil {
return nil, err
}
}
e.functions[fn.Name()] = fn
}
e.functions[fn.Name()] = fn
return e, nil
}
}
@@ -197,6 +237,13 @@ func Function(name string, opts ...FunctionOpt) EnvOption {
// FunctionOpt defines a functional option for configuring a function declaration.
type FunctionOpt = decls.FunctionOpt
// FunctionDocs provides a general usage documentation for the function.
//
// Use OverloadExamples to provide example usage instructions for specific overloads.
func FunctionDocs(docs ...string) FunctionOpt {
return decls.FunctionDocs(docs...)
}
// SingletonUnaryBinding creates a singleton function definition to be used for all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
@@ -270,6 +317,11 @@ func MemberOverload(overloadID string, args []*Type, resultType *Type, opts ...O
// OverloadOpt is a functional option for configuring a function overload.
type OverloadOpt = decls.OverloadOpt
// OverloadExamples configures an example of how to invoke the overload.
func OverloadExamples(docs ...string) OverloadOpt {
return decls.OverloadExamples(docs...)
}
// UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
@@ -288,6 +340,12 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
return decls.FunctionBinding(binding)
}
// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
// This is useful for functions which have side-effects or are not deterministically computable.
func LateFunctionBinding() OverloadOpt {
return decls.LateFunctionBinding()
}
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
//
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.

View File

@@ -16,6 +16,8 @@ package cel
import (
"errors"
"fmt"
"math"
"sync"
"github.com/google/cel-go/checker"
@@ -24,12 +26,15 @@ import (
celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/env"
"github.com/google/cel-go/common/stdlib"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"google.golang.org/protobuf/reflect/protoreflect"
)
// Source interface representing a user-provided expression.
@@ -127,12 +132,13 @@ type Env struct {
Container *containers.Container
variables []*decls.VariableDecl
functions map[string]*decls.FunctionDecl
macros []parser.Macro
macros []Macro
contextProto protoreflect.MessageDescriptor
adapter types.Adapter
provider types.Provider
features map[int]bool
appliedFeatures map[int]bool
libraries map[string]bool
libraries map[string]SingletonLibrary
validators []ASTValidator
costOptions []checker.CostOption
@@ -151,6 +157,134 @@ type Env struct {
progOpts []ProgramOption
}
// ToConfig produces a YAML-serializable env.Config object from the given environment.
//
// The serialized configuration value is intended to represent a baseline set of config
// options which could be used as input to an EnvOption to configure the majority of the
// environment from a file.
//
// Note: validators, features, flags, and safe-guard settings are not yet supported by
// the serialize method. Since optimizers are a separate construct from the environment
// and the standard expression components (parse, check, evalute), they are also not
// supported by the serialize method.
func (e *Env) ToConfig(name string) (*env.Config, error) {
conf := env.NewConfig(name)
// Container settings
if e.Container != containers.DefaultContainer {
conf.SetContainer(e.Container.Name())
}
for _, typeName := range e.Container.AliasSet() {
conf.AddImports(env.NewImport(typeName))
}
libOverloads := map[string][]string{}
for libName, lib := range e.libraries {
// Track the options which have been configured by a library and
// then diff the library version against the configured function
// to detect incremental overloads or rewrites.
libEnv, _ := NewCustomEnv()
libEnv, _ = Lib(lib)(libEnv)
for fnName, fnDecl := range libEnv.Functions() {
if len(fnDecl.OverloadDecls()) == 0 {
continue
}
overloads, exist := libOverloads[fnName]
if !exist {
overloads = make([]string, 0, len(fnDecl.OverloadDecls()))
}
for _, o := range fnDecl.OverloadDecls() {
overloads = append(overloads, o.ID())
}
libOverloads[fnName] = overloads
}
subsetLib, canSubset := lib.(LibrarySubsetter)
alias := ""
if aliasLib, canAlias := lib.(LibraryAliaser); canAlias {
alias = aliasLib.LibraryAlias()
libName = alias
}
if libName == "stdlib" && canSubset {
conf.SetStdLib(subsetLib.LibrarySubset())
continue
}
version := uint32(math.MaxUint32)
if versionLib, isVersioned := lib.(LibraryVersioner); isVersioned {
version = versionLib.LibraryVersion()
}
conf.AddExtensions(env.NewExtension(libName, version))
}
// If this is a custom environment without the standard env, mark the stdlib as disabled.
if conf.StdLib == nil && !e.HasLibrary("cel.lib.std") {
conf.SetStdLib(env.NewLibrarySubset().SetDisabled(true))
}
// Serialize the variables
vars := make([]*decls.VariableDecl, 0, len(e.Variables()))
stdTypeVars := map[string]*decls.VariableDecl{}
for _, v := range stdlib.Types() {
stdTypeVars[v.Name()] = v
}
for _, v := range e.Variables() {
if _, isStdType := stdTypeVars[v.Name()]; isStdType {
continue
}
vars = append(vars, v)
}
if e.contextProto != nil {
conf.SetContextVariable(env.NewContextVariable(string(e.contextProto.FullName())))
skipVariables := map[string]bool{}
fields := e.contextProto.Fields()
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
variable, err := fieldToVariable(field)
if err != nil {
return nil, fmt.Errorf("could not serialize context field variable %q, reason: %w", field.FullName(), err)
}
skipVariables[variable.Name()] = true
}
for _, v := range vars {
if _, found := skipVariables[v.Name()]; !found {
conf.AddVariableDecls(v)
}
}
} else {
conf.AddVariableDecls(vars...)
}
// Serialize functions which are distinct from the ones configured by libraries.
for fnName, fnDecl := range e.Functions() {
if excludedOverloads, found := libOverloads[fnName]; found {
if newDecl := fnDecl.Subset(decls.ExcludeOverloads(excludedOverloads...)); newDecl != nil {
conf.AddFunctionDecls(newDecl)
}
} else {
conf.AddFunctionDecls(fnDecl)
}
}
// Serialize validators
for _, val := range e.Validators() {
// Only add configurable validators to the env.Config as all others are
// expected to be implicitly enabled via extension libraries.
if confVal, ok := val.(ConfigurableASTValidator); ok {
conf.AddValidators(confVal.ToConfig())
}
}
// Serialize features
for featID, enabled := range e.features {
featName, found := featureNameByID(featID)
if !found {
// If the feature isn't named, it isn't intended to be publicly exposed
continue
}
conf.AddFeatures(env.NewFeature(featName, enabled))
}
return conf, nil
}
// NewEnv creates a program environment configured with the standard library of CEL functions and
// macros. The Env value returned can parse and check any CEL program which builds upon the core
// features documented in the CEL specification.
@@ -194,7 +328,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
provider: registry,
features: map[int]bool{},
appliedFeatures: map[int]bool{},
libraries: map[string]bool{},
libraries: map[string]SingletonLibrary{},
validators: []ASTValidator{},
progOpts: []ProgramOption{},
costOptions: []checker.CostOption{},
@@ -362,7 +496,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
for k, v := range e.functions {
funcsCopy[k] = v
}
libsCopy := make(map[string]bool, len(e.libraries))
libsCopy := make(map[string]SingletonLibrary, len(e.libraries))
for k, v := range e.libraries {
libsCopy[k] = v
}
@@ -376,6 +510,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
variables: varsCopy,
functions: funcsCopy,
macros: macsCopy,
contextProto: e.contextProto,
progOpts: progOptsCopy,
adapter: adapter,
features: featuresCopy,
@@ -399,8 +534,8 @@ func (e *Env) HasFeature(flag int) bool {
// HasLibrary returns whether a specific SingletonLibrary has been configured in the environment.
func (e *Env) HasLibrary(libName string) bool {
configured, exists := e.libraries[libName]
return exists && configured
_, exists := e.libraries[libName]
return exists
}
// Libraries returns a list of SingletonLibrary that have been configured in the environment.
@@ -418,9 +553,27 @@ func (e *Env) HasFunction(functionName string) bool {
return ok
}
// Functions returns map of Functions, keyed by function name, that have been configured in the environment.
// Functions returns a shallow copy of the Functions, keyed by function name, that have been configured in the environment.
func (e *Env) Functions() map[string]*decls.FunctionDecl {
return e.functions
shallowCopy := make(map[string]*decls.FunctionDecl, len(e.functions))
for nm, fn := range e.functions {
shallowCopy[nm] = fn
}
return shallowCopy
}
// Variables returns a shallow copy of the variables associated with the environment.
func (e *Env) Variables() []*decls.VariableDecl {
shallowCopy := make([]*decls.VariableDecl, len(e.variables))
copy(shallowCopy, e.variables)
return shallowCopy
}
// Macros returns a shallow copy of macros associated with the environment.
func (e *Env) Macros() []Macro {
shallowCopy := make([]Macro, len(e.macros))
copy(shallowCopy, e.macros)
return shallowCopy
}
// HasValidator returns whether a specific ASTValidator has been configured in the environment.
@@ -433,6 +586,11 @@ func (e *Env) HasValidator(name string) bool {
return false
}
// Validators returns the set of ASTValidators configured on the environment.
func (e *Env) Validators() []ASTValidator {
return e.validators[:]
}
// Parse parses the input expression value `txt` to a Ast and/or a set of Issues.
//
// This form of Parse creates a Source value for the input `txt` and forwards to the
@@ -502,31 +660,30 @@ func (e *Env) TypeProvider() ref.TypeProvider {
return &interopLegacyTypeProvider{Provider: e.provider}
}
// UnknownVars returns an interpreter.PartialActivation which marks all variables declared in the
// Env as unknown AttributePattern values.
// UnknownVars returns a PartialActivation which marks all variables declared in the Env as
// unknown AttributePattern values.
//
// Note, the UnknownVars will behave the same as an interpreter.EmptyActivation unless the
// PartialAttributes option is provided as a ProgramOption.
func (e *Env) UnknownVars() interpreter.PartialActivation {
// Note, the UnknownVars will behave the same as an cel.NoVars() unless the PartialAttributes
// option is provided as a ProgramOption.
func (e *Env) UnknownVars() PartialActivation {
act := interpreter.EmptyActivation()
part, _ := PartialVars(act, e.computeUnknownVars(act)...)
return part
}
// PartialVars returns an interpreter.PartialActivation where all variables not in the input variable
// PartialVars returns a PartialActivation where all variables not in the input variable
// set, but which have been configured in the environment, are marked as unknown.
//
// The `vars` value may either be an interpreter.Activation or any valid input to the
// interpreter.NewActivation call.
// The `vars` value may either be an Activation or any valid input to the cel.NewActivation call.
//
// Note, this is equivalent to calling cel.PartialVars and manually configuring the set of unknown
// variables. For more advanced use cases of partial state where portions of an object graph, rather
// than top-level variables, are missing the PartialVars() method may be a more suitable choice.
//
// Note, the PartialVars will behave the same as an interpreter.EmptyActivation unless the
// PartialAttributes option is provided as a ProgramOption.
func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) {
act, err := interpreter.NewActivation(vars)
// Note, the PartialVars will behave the same as cel.NoVars() unless the PartialAttributes
// option is provided as a ProgramOption.
func (e *Env) PartialVars(vars any) (PartialActivation, error) {
act, err := NewActivation(vars)
if err != nil {
return nil, err
}
@@ -598,10 +755,15 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
}
}
// If the default UTC timezone fix has been enabled, make sure the library is configured
e, err = e.maybeApplyFeature(featureDefaultUTCTimeZone, Lib(timeUTCLibrary{}))
if err != nil {
return nil, err
// If the default UTC timezone has been disabled, configure the legacy overloads
if utcTime, isSet := e.features[featureDefaultUTCTimeZone]; isSet && !utcTime {
if !e.appliedFeatures[featureDefaultUTCTimeZone] {
e.appliedFeatures[featureDefaultUTCTimeZone] = true
e, err = Lib(timeLegacyLibrary{})(e)
if err != nil {
return nil, err
}
}
}
// Configure the parser.
@@ -685,30 +847,9 @@ func (e *Env) getCheckerOrError() (*checker.Env, error) {
return e.chk, e.chkErr
}
// maybeApplyFeature determines whether the feature-guarded option is enabled, and if so applies
// the feature if it has not already been enabled.
func (e *Env) maybeApplyFeature(feature int, option EnvOption) (*Env, error) {
if !e.HasFeature(feature) {
return e, nil
}
_, applied := e.appliedFeatures[feature]
if applied {
return e, nil
}
e, err := option(e)
if err != nil {
return nil, err
}
// record that the feature has been applied since it will generate declarations
// and functions which will be propagated on Extend() calls and which should only
// be registered once.
e.appliedFeatures[feature] = true
return e, nil
}
// computeUnknownVars determines a set of missing variables based on the input activation and the
// environment's configured declaration set.
func (e *Env) computeUnknownVars(vars interpreter.Activation) []*interpreter.AttributePattern {
func (e *Env) computeUnknownVars(vars Activation) []*interpreter.AttributePattern {
var unknownPatterns []*interpreter.AttributePattern
for _, v := range e.variables {
varName := v.Name()

View File

@@ -38,6 +38,23 @@ func MaxConstantFoldIterations(limit int) ConstantFoldingOption {
}
}
// Adds an Activation which provides known values for the folding evaluator
//
// Any values the activation provides will be used by the constant folder and turned into
// literals in the AST.
//
// Defaults to the NoVars() Activation
func FoldKnownValues(knownValues Activation) ConstantFoldingOption {
return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) {
if knownValues != nil {
opt.knownValues = knownValues
} else {
opt.knownValues = NoVars()
}
return opt, nil
}
}
// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate
// literal values within function calls and select statements with their evaluated result.
func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) {
@@ -56,6 +73,7 @@ func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, e
type constantFoldingOptimizer struct {
maxFoldIterations int
knownValues Activation
}
// Optimize queries the expression graph for scalar and aggregate literal expressions within call and
@@ -68,7 +86,8 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
// Walk the list of foldable expression and continue to fold until there are no more folds left.
// All of the fold candidates returned by the constantExprMatcher should succeed unless there's
// a logic bug with the selection of expressions.
foldableExprs := ast.MatchDescendants(root, constantExprMatcher)
constantExprMatcherCapture := func(e ast.NavigableExpr) bool { return opt.constantExprMatcher(ctx, a, e) }
foldableExprs := ast.MatchDescendants(root, constantExprMatcherCapture)
foldCount := 0
for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
for _, fold := range foldableExprs {
@@ -77,21 +96,27 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) {
continue
}
// Late-bound function calls cannot be folded.
if fold.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, fold) {
continue
}
// Otherwise, assume all context is needed to evaluate the expression.
err := tryFold(ctx, a, fold)
if err != nil {
err := opt.tryFold(ctx, a, fold)
// Ignore errors for identifiers, since there is no guarantee that the environment
// has a value for them.
if err != nil && fold.Kind() != ast.IdentKind {
ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error())
return a
}
}
foldCount++
foldableExprs = ast.MatchDescendants(root, constantExprMatcher)
foldableExprs = ast.MatchDescendants(root, constantExprMatcherCapture)
}
// Once all of the constants have been folded, try to run through the remaining comprehensions
// one last time. In this case, there's no guarantee they'll run, so we only update the
// target comprehension node with the literal value if the evaluation succeeds.
for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) {
tryFold(ctx, a, compre)
opt.tryFold(ctx, a, compre)
}
// If the output is a list, map, or struct which contains optional entries, then prune it
@@ -121,7 +146,7 @@ func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST)
//
// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise
// the method will return an error.
func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
func (opt *constantFoldingOptimizer) tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
// Assume all context is needed to evaluate the expression.
subAST := &Ast{
impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()),
@@ -130,7 +155,11 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
if err != nil {
return err
}
out, _, err := prg.Eval(NoVars())
activation := opt.knownValues
if activation == nil {
activation = NoVars()
}
out, _, err := prg.Eval(activation)
if err != nil {
return err
}
@@ -139,6 +168,15 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
return nil
}
func isLateBoundFunctionCall(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) bool {
call := expr.AsCall()
function := ctx.Functions()[call.FunctionName()]
if function == nil {
return false
}
return function.HasLateBinding()
}
// maybePruneBranches inspects the non-strict call expression to determine whether
// a branch can be removed. Evaluation will naturally prune logical and / or calls,
// but conditional will not be pruned cleanly, so this is one small area where the
@@ -455,13 +493,15 @@ func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) {
// Only comprehensions which are not nested are included as possible constant folds, and only
// if all variables referenced in the comprehension stack exist are only iteration or
// accumulation variables.
func constantExprMatcher(e ast.NavigableExpr) bool {
func (opt *constantFoldingOptimizer) constantExprMatcher(ctx *OptimizerContext, a *ast.AST, e ast.NavigableExpr) bool {
switch e.Kind() {
case ast.CallKind:
return constantCallMatcher(e)
case ast.SelectKind:
sel := e.AsSelect() // guaranteed to be a navigable value
return constantMatcher(sel.Operand().(ast.NavigableExpr))
case ast.IdentKind:
return opt.knownValues != nil && a.ReferenceMap()[e.ID()] != nil
case ast.ComprehensionKind:
if isNestedComprehension(e) {
return false
@@ -477,6 +517,10 @@ func constantExprMatcher(e ast.NavigableExpr) bool {
if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] {
constantExprs = false
}
// Late-bound function calls cannot be folded.
if e.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, e) {
constantExprs = false
}
})
ast.PreOrderVisit(e, visitor)
return constantExprs

View File

@@ -99,7 +99,13 @@ func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) {
// Note, the conversion may not be an exact replica of the original expression, but will produce
// a string that is semantically equivalent and whose textual representation is stable.
func AstToString(a *Ast) (string, error) {
return parser.Unparse(a.NativeRep().Expr(), a.NativeRep().SourceInfo())
return ExprToString(a.NativeRep().Expr(), a.NativeRep().SourceInfo())
}
// ExprToString converts an AST Expr node back to a string using macro call tracking metadata from
// source info if any macros are encountered within the expression.
func ExprToString(e ast.Expr, info *ast.SourceInfo) (string, error) {
return parser.Unparse(e, info)
}
// RefValueToValue converts between ref.Val and google.api.expr.v1alpha1.Value.
@@ -120,6 +126,55 @@ func ValueAsAlphaProto(res ref.Val) (*exprpb.Value, error) {
return alpha, err
}
// RefValToExprValue converts between ref.Val and google.api.expr.v1alpha1.ExprValue.
// The result ExprValue is the serialized proto form.
func RefValToExprValue(res ref.Val) (*exprpb.ExprValue, error) {
return ExprValueAsAlphaProto(res)
}
// ExprValueAsAlphaProto converts between ref.Val and google.api.expr.v1alpha1.ExprValue.
// The result ExprValue is the serialized proto form.
func ExprValueAsAlphaProto(res ref.Val) (*exprpb.ExprValue, error) {
canonical, err := ExprValueAsProto(res)
if err != nil {
return nil, err
}
alpha := &exprpb.ExprValue{}
err = convertProto(canonical, alpha)
return alpha, err
}
// ExprValueAsProto converts between ref.Val and cel.expr.ExprValue.
// The result ExprValue is the serialized proto form.
func ExprValueAsProto(res ref.Val) (*celpb.ExprValue, error) {
switch res := res.(type) {
case *types.Unknown:
return &celpb.ExprValue{
Kind: &celpb.ExprValue_Unknown{
Unknown: &celpb.UnknownSet{
Exprs: res.IDs(),
},
}}, nil
case *types.Err:
return &celpb.ExprValue{
Kind: &celpb.ExprValue_Error{
Error: &celpb.ErrorSet{
// Keeping the error code as UNKNOWN since there's no error codes associated with
// Cel-Go runtime errors.
Errors: []*celpb.Status{{Code: 2, Message: res.Error()}},
},
},
}, nil
default:
val, err := ValueAsProto(res)
if err != nil {
return nil, err
}
return &celpb.ExprValue{
Kind: &celpb.ExprValue_Value{Value: val}}, nil
}
}
// ValueAsProto converts between ref.Val and cel.expr.Value.
// The result Value is the serialized proto form. The ref.Val must not be error or unknown.
func ValueAsProto(res ref.Val) (*celpb.Value, error) {

View File

@@ -17,11 +17,11 @@ package cel
import (
"fmt"
"math"
"strconv"
"strings"
"time"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/env"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
@@ -71,6 +71,23 @@ type SingletonLibrary interface {
LibraryName() string
}
// LibraryAliaser generates a simple named alias for the library, for use during environment serialization.
type LibraryAliaser interface {
LibraryAlias() string
}
// LibrarySubsetter provides the subset description associated with the library, nil if not subset.
type LibrarySubsetter interface {
LibrarySubset() *env.LibrarySubset
}
// LibraryVersioner provides a version number for the library.
//
// If not implemented, the library version will be flagged as 'latest' during environment serialization.
type LibraryVersioner interface {
LibraryVersion() uint32
}
// Lib creates an EnvOption out of a Library, allowing libraries to be provided as functional args,
// and to be linked to each other.
func Lib(l Library) EnvOption {
@@ -80,7 +97,7 @@ func Lib(l Library) EnvOption {
if e.HasLibrary(singleton.LibraryName()) {
return e, nil
}
e.libraries[singleton.LibraryName()] = true
e.libraries[singleton.LibraryName()] = singleton
}
var err error
for _, opt := range l.CompileOptions() {
@@ -94,26 +111,79 @@ func Lib(l Library) EnvOption {
}
}
// StdLibOption specifies a functional option for configuring the standard CEL library.
type StdLibOption func(*stdLibrary) *stdLibrary
// StdLibSubset configures the standard library to use a subset of its functions and macros.
//
// Since the StdLib is a singleton library, only the first instance of the StdLib() environment options
// will be configured on the environment which means only the StdLibSubset() initially configured with
// the library will be used.
func StdLibSubset(subset *env.LibrarySubset) StdLibOption {
return func(lib *stdLibrary) *stdLibrary {
lib.subset = subset
return lib
}
}
// StdLib returns an EnvOption for the standard library of CEL functions and macros.
func StdLib() EnvOption {
return Lib(stdLibrary{})
func StdLib(opts ...StdLibOption) EnvOption {
lib := &stdLibrary{}
for _, o := range opts {
lib = o(lib)
}
return Lib(lib)
}
// stdLibrary implements the Library interface and provides functional options for the core CEL
// features documented in the specification.
type stdLibrary struct{}
type stdLibrary struct {
subset *env.LibrarySubset
}
// LibraryName implements the SingletonLibrary interface method.
func (stdLibrary) LibraryName() string {
func (*stdLibrary) LibraryName() string {
return "cel.lib.std"
}
// LibraryAlias returns the simple name of the library.
func (*stdLibrary) LibraryAlias() string {
return "stdlib"
}
// LibrarySubset returns the env.LibrarySubset definition associated with the CEL Library.
func (lib *stdLibrary) LibrarySubset() *env.LibrarySubset {
return lib.subset
}
// CompileOptions returns options for the standard CEL function declarations and macros.
func (stdLibrary) CompileOptions() []EnvOption {
func (lib *stdLibrary) CompileOptions() []EnvOption {
funcs := stdlib.Functions()
macros := StandardMacros
if lib.subset != nil {
subMacros := []Macro{}
for _, m := range macros {
if lib.subset.SubsetMacro(m.Function()) {
subMacros = append(subMacros, m)
}
}
macros = subMacros
subFuncs := []*decls.FunctionDecl{}
for _, fn := range funcs {
if f, include := lib.subset.SubsetFunction(fn); include {
subFuncs = append(subFuncs, f)
}
}
funcs = subFuncs
}
return []EnvOption{
func(e *Env) (*Env, error) {
var err error
for _, fn := range stdlib.Functions() {
if err = lib.subset.Validate(); err != nil {
return nil, err
}
e.variables = append(e.variables, stdlib.Types()...)
for _, fn := range funcs {
existing, found := e.functions[fn.Name()]
if found {
fn, err = existing.Merge(fn)
@@ -125,16 +195,12 @@ func (stdLibrary) CompileOptions() []EnvOption {
}
return e, nil
},
func(e *Env) (*Env, error) {
e.variables = append(e.variables, stdlib.Types()...)
return e, nil
},
Macros(StandardMacros...),
Macros(macros...),
}
}
// ProgramOptions returns function implementations for the standard CEL functions.
func (stdLibrary) ProgramOptions() []ProgramOption {
func (*stdLibrary) ProgramOptions() []ProgramOption {
return []ProgramOption{}
}
@@ -263,7 +329,7 @@ func (stdLibrary) ProgramOptions() []ProgramOption {
// be expressed with `optMap`.
//
// msg.?elements.optFlatMap(e, e[?0]) // return the first element if present.
//
// # First
//
// Introduced in version: 2
@@ -272,7 +338,7 @@ func (stdLibrary) ProgramOptions() []ProgramOption {
// optional.None.
//
// [1, 2, 3].first().value() == 1
//
// # Last
//
// Introduced in version: 2
@@ -283,7 +349,7 @@ func (stdLibrary) ProgramOptions() []ProgramOption {
// [1, 2, 3].last().value() == 3
//
// This is syntactic sugar for msg.elements[msg.elements.size()-1].
//
// # Unwrap / UnwrapOpt
//
// Introduced in version: 2
@@ -293,7 +359,6 @@ func (stdLibrary) ProgramOptions() []ProgramOption {
//
// optional.unwrap([optional.of(42), optional.none()]) == [42]
// [optional.of(42), optional.none()].unwrapOpt() == [42]
func OptionalTypes(opts ...OptionalTypesOption) EnvOption {
lib := &optionalLib{version: math.MaxUint32}
for _, opt := range opts {
@@ -326,10 +391,20 @@ func OptionalTypesVersion(version uint32) OptionalTypesOption {
}
// LibraryName implements the SingletonLibrary interface method.
func (lib *optionalLib) LibraryName() string {
func (*optionalLib) LibraryName() string {
return "cel.lib.optional"
}
// LibraryAlias returns the simple name of the library.
func (*optionalLib) LibraryAlias() string {
return "optional"
}
// LibraryVersion returns the version of the library.
func (lib *optionalLib) LibraryVersion() uint32 {
return lib.version
}
// CompileOptions implements the Library interface method.
func (lib *optionalLib) CompileOptions() []EnvOption {
paramTypeK := TypeParamType("K")
@@ -347,16 +422,29 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
Types(types.OptionalType),
// Configure the optMap and optFlatMap macros.
Macros(ReceiverMacro(optMapMacro, 2, optMap)),
Macros(ReceiverMacro(optMapMacro, 2, optMap,
MacroDocs(`perform computation on the value if present and return the result as an optional`),
MacroExamples(
common.MultilineDescription(
`// sub with the prefix 'dev.cel' or optional.none()`,
`request.auth.tokens.?sub.optMap(id, 'dev.cel.' + id)`),
`optional.none().optMap(i, i * 2) // optional.none()`))),
// Global and member functions for working with optional values.
Function(optionalOfFunc,
FunctionDocs(`create a new optional_type(T) with a value where any value is considered valid`),
Overload("optional_of", []*Type{paramTypeV}, optionalTypeV,
OverloadExamples(`optional.of(1) // optional(1)`),
UnaryBinding(func(value ref.Val) ref.Val {
return types.OptionalOf(value)
}))),
Function(optionalOfNonZeroValueFunc,
FunctionDocs(`create a new optional_type(T) with a value, if the value is not a zero or empty value`),
Overload("optional_ofNonZeroValue", []*Type{paramTypeV}, optionalTypeV,
OverloadExamples(
`optional.ofNonZeroValue(null) // optional.none()`,
`optional.ofNonZeroValue("") // optional.none()`,
`optional.ofNonZeroValue("hello") // optional.of('hello')`),
UnaryBinding(func(value ref.Val) ref.Val {
v, isZeroer := value.(traits.Zeroer)
if !isZeroer || !v.IsZeroValue() {
@@ -365,18 +453,26 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
return types.OptionalNone
}))),
Function(optionalNoneFunc,
FunctionDocs(`singleton value representing an optional without a value`),
Overload("optional_none", []*Type{}, optionalTypeV,
OverloadExamples(`optional.none()`),
FunctionBinding(func(values ...ref.Val) ref.Val {
return types.OptionalNone
}))),
Function(valueFunc,
FunctionDocs(`obtain the value contained by the optional, error if optional.none()`),
MemberOverload("optional_value", []*Type{optionalTypeV}, paramTypeV,
OverloadExamples(
`optional.of(1).value() // 1`,
`optional.none().value() // error`),
UnaryBinding(func(value ref.Val) ref.Val {
opt := value.(*types.Optional)
return opt.GetValue()
}))),
Function(hasValueFunc,
FunctionDocs(`determine whether the optional contains a value`),
MemberOverload("optional_hasValue", []*Type{optionalTypeV}, BoolType,
OverloadExamples(`optional.of({1: 2}).hasValue() // true`),
UnaryBinding(func(value ref.Val) ref.Val {
opt := value.(*types.Optional)
return types.Bool(opt.HasValue())
@@ -385,21 +481,43 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
// Implementation of 'or' and 'orValue' are special-cased to support short-circuiting in the
// evaluation chain.
Function("or",
MemberOverload("optional_or_optional", []*Type{optionalTypeV, optionalTypeV}, optionalTypeV)),
FunctionDocs(`chain optional expressions together, picking the first valued optional expression`),
MemberOverload("optional_or_optional", []*Type{optionalTypeV, optionalTypeV}, optionalTypeV,
OverloadExamples(
`optional.none().or(optional.of(1)) // optional.of(1)`,
common.MultilineDescription(
`// either a value from the first list, a value from the second, or optional.none()`,
`[1, 2, 3][?x].or([3, 4, 5][?y])`)))),
Function("orValue",
MemberOverload("optional_orValue_value", []*Type{optionalTypeV, paramTypeV}, paramTypeV)),
FunctionDocs(`chain optional expressions together picking the first valued optional or the default value`),
MemberOverload("optional_orValue_value", []*Type{optionalTypeV, paramTypeV}, paramTypeV,
OverloadExamples(
common.MultilineDescription(
`// pick the value for the given key if the key exists, otherwise return 'you'`,
`{'hello': 'world', 'goodbye': 'cruel world'}[?greeting].orValue('you')`)))),
// OptSelect is handled specially by the type-checker, so the receiver's field type is used to determine the
// optput type.
Function(operators.OptSelect,
Overload("select_optional_field", []*Type{DynType, StringType}, optionalTypeV)),
FunctionDocs(`if the field is present create an optional of the field value, otherwise return optional.none()`),
Overload("select_optional_field", []*Type{DynType, StringType}, optionalTypeV,
OverloadExamples(
`msg.?field // optional.of(field) if non-empty, otherwise optional.none()`,
`msg.?field.?nested_field // optional.of(nested_field) if both field and nested_field are non-empty.`))),
// OptIndex is handled mostly like any other indexing operation on a list or map, so the type-checker can use
// these signatures to determine type-agreement without any special handling.
Function(operators.OptIndex,
Overload("list_optindex_optional_int", []*Type{listTypeV, IntType}, optionalTypeV),
FunctionDocs(`if the index is present create an optional of the field value, otherwise return optional.none()`),
Overload("list_optindex_optional_int", []*Type{listTypeV, IntType}, optionalTypeV,
OverloadExamples(`[1, 2, 3][?x] // element value if x is in the list size, else optional.none()`)),
Overload("optional_list_optindex_optional_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV),
Overload("map_optindex_optional_value", []*Type{mapTypeKV, paramTypeK}, optionalTypeV),
Overload("map_optindex_optional_value", []*Type{mapTypeKV, paramTypeK}, optionalTypeV,
OverloadExamples(
`map_value[?key] // value at the key if present, else optional.none()`,
common.MultilineDescription(
`// map key-value if index is a valid map key, else optional.none()`,
`{0: 2, 2: 4, 6: 8}[?index]`))),
Overload("optional_map_optindex_optional_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
// Index overloads to accommodate using an optional value as the operand.
@@ -408,45 +526,62 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
Overload("optional_map_index_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
}
if lib.version >= 1 {
opts = append(opts, Macros(ReceiverMacro(optFlatMapMacro, 2, optFlatMap)))
opts = append(opts, Macros(ReceiverMacro(optFlatMapMacro, 2, optFlatMap,
MacroDocs(`perform computation on the value if present and produce an optional value within the computation`),
MacroExamples(
common.MultilineDescription(
`// m = {'key': {}}`,
`m.?key.optFlatMap(k, k.?subkey) // optional.none()`),
common.MultilineDescription(
`// m = {'key': {'subkey': 'value'}}`,
`m.?key.optFlatMap(k, k.?subkey) // optional.of('value')`),
))))
}
if lib.version >= 2 {
opts = append(opts, Function("last",
FunctionDocs(`return the last value in a list if present, otherwise optional.none()`),
MemberOverload("list_last", []*Type{listTypeV}, optionalTypeV,
OverloadExamples(
`[].last() // optional.none()`,
`[1, 2, 3].last() ? optional.of(3)`),
UnaryBinding(func(v ref.Val) ref.Val {
list := v.(traits.Lister)
sz := list.Size().Value().(int64)
if sz == 0 {
sz := list.Size().(types.Int)
if sz == types.IntZero {
return types.OptionalNone
}
return types.OptionalOf(list.Get(types.Int(sz - 1)))
}),
),
))
opts = append(opts, Function("first",
FunctionDocs(`return the first value in a list if present, otherwise optional.none()`),
MemberOverload("list_first", []*Type{listTypeV}, optionalTypeV,
OverloadExamples(
`[].first() // optional.none()`,
`[1, 2, 3].first() ? optional.of(1)`),
UnaryBinding(func(v ref.Val) ref.Val {
list := v.(traits.Lister)
sz := list.Size().Value().(int64)
if sz == 0 {
sz := list.Size().(types.Int)
if sz == types.IntZero {
return types.OptionalNone
}
return types.OptionalOf(list.Get(types.Int(0)))
}),
),
))
opts = append(opts, Function(optionalUnwrapFunc,
FunctionDocs(`convert a list of optional values to a list containing only value which are not optional.none()`),
Overload("optional_unwrap", []*Type{listOptionalTypeV}, listTypeV,
OverloadExamples(`optional.unwrap([optional.of(1), optional.none()]) // [1]`),
UnaryBinding(optUnwrap))))
opts = append(opts, Function(unwrapOptFunc,
FunctionDocs(`convert a list of optional values to a list containing only value which are not optional.none()`),
MemberOverload("optional_unwrapOpt", []*Type{listOptionalTypeV}, listTypeV,
OverloadExamples(`[optional.of(1), optional.none()].unwrapOpt() // [1]`),
UnaryBinding(optUnwrap))))
}
@@ -460,6 +595,11 @@ func (lib *optionalLib) ProgramOptions() []ProgramOption {
}
}
// Version returns the current version of the library.
func (lib *optionalLib) Version() uint32 {
return lib.version
}
func optMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) {
varIdent := args[0]
varName := ""
@@ -633,250 +773,99 @@ func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val {
return opt.rhs.Eval(ctx)
}
type timeUTCLibrary struct{}
type timeLegacyLibrary struct{}
func (timeUTCLibrary) CompileOptions() []EnvOption {
func (timeLegacyLibrary) CompileOptions() []EnvOption {
return timeOverloadDeclarations
}
func (timeUTCLibrary) ProgramOptions() []ProgramOption {
func (timeLegacyLibrary) ProgramOptions() []ProgramOption {
return []ProgramOption{}
}
// Declarations and functions which enable using UTC on time.Time inputs when the timezone is unspecified
// in the CEL expression.
var (
utcTZ = types.String("UTC")
timeOverloadDeclarations = []EnvOption{
Function(overloads.TimeGetHours,
MemberOverload(overloads.DurationToHours, []*Type{DurationType}, IntType,
UnaryBinding(types.DurationGetHours))),
Function(overloads.TimeGetMinutes,
MemberOverload(overloads.DurationToMinutes, []*Type{DurationType}, IntType,
UnaryBinding(types.DurationGetMinutes))),
Function(overloads.TimeGetSeconds,
MemberOverload(overloads.DurationToSeconds, []*Type{DurationType}, IntType,
UnaryBinding(types.DurationGetSeconds))),
Function(overloads.TimeGetMilliseconds,
MemberOverload(overloads.DurationToMilliseconds, []*Type{DurationType}, IntType,
UnaryBinding(types.DurationGetMilliseconds))),
Function(overloads.TimeGetFullYear,
MemberOverload(overloads.TimestampToYear, []*Type{TimestampType}, IntType,
UnaryBinding(func(ts ref.Val) ref.Val {
return timestampGetFullYear(ts, utcTZ)
t := ts.(types.Timestamp)
return t.Receive(overloads.TimeGetFullYear, overloads.TimestampToYear, []ref.Val{})
}),
),
MemberOverload(overloads.TimestampToYearWithTz, []*Type{TimestampType, StringType}, IntType,
BinaryBinding(timestampGetFullYear),
),
),
Function(overloads.TimeGetMonth,
MemberOverload(overloads.TimestampToMonth, []*Type{TimestampType}, IntType,
UnaryBinding(func(ts ref.Val) ref.Val {
return timestampGetMonth(ts, utcTZ)
t := ts.(types.Timestamp)
return t.Receive(overloads.TimeGetMonth, overloads.TimestampToMonth, []ref.Val{})
}),
),
MemberOverload(overloads.TimestampToMonthWithTz, []*Type{TimestampType, StringType}, IntType,
BinaryBinding(timestampGetMonth),
),
),
Function(overloads.TimeGetDayOfYear,
MemberOverload(overloads.TimestampToDayOfYear, []*Type{TimestampType}, IntType,
UnaryBinding(func(ts ref.Val) ref.Val {
return timestampGetDayOfYear(ts, utcTZ)
}),
),
MemberOverload(overloads.TimestampToDayOfYearWithTz, []*Type{TimestampType, StringType}, IntType,
BinaryBinding(func(ts, tz ref.Val) ref.Val {
return timestampGetDayOfYear(ts, tz)
t := ts.(types.Timestamp)
return t.Receive(overloads.TimeGetDayOfYear, overloads.TimestampToDayOfYear, []ref.Val{})
}),
),
),
Function(overloads.TimeGetDayOfMonth,
MemberOverload(overloads.TimestampToDayOfMonthZeroBased, []*Type{TimestampType}, IntType,
UnaryBinding(func(ts ref.Val) ref.Val {
return timestampGetDayOfMonthZeroBased(ts, utcTZ)
t := ts.(types.Timestamp)
return t.Receive(overloads.TimeGetDayOfMonth, overloads.TimestampToDayOfMonthZeroBased, []ref.Val{})
}),
),
MemberOverload(overloads.TimestampToDayOfMonthZeroBasedWithTz, []*Type{TimestampType, StringType}, IntType,
BinaryBinding(timestampGetDayOfMonthZeroBased),
),
),
Function(overloads.TimeGetDate,
MemberOverload(overloads.TimestampToDayOfMonthOneBased, []*Type{TimestampType}, IntType,
UnaryBinding(func(ts ref.Val) ref.Val {
return timestampGetDayOfMonthOneBased(ts, utcTZ)
t := ts.(types.Timestamp)
return t.Receive(overloads.TimeGetDate, overloads.TimestampToDayOfMonthOneBased, []ref.Val{})
}),
),
MemberOverload(overloads.TimestampToDayOfMonthOneBasedWithTz, []*Type{TimestampType, StringType}, IntType,
BinaryBinding(timestampGetDayOfMonthOneBased),
),
),
Function(overloads.TimeGetDayOfWeek,
MemberOverload(overloads.TimestampToDayOfWeek, []*Type{TimestampType}, IntType,
UnaryBinding(func(ts ref.Val) ref.Val {
return timestampGetDayOfWeek(ts, utcTZ)
t := ts.(types.Timestamp)
return t.Receive(overloads.TimeGetDayOfWeek, overloads.TimestampToDayOfWeek, []ref.Val{})
}),
),
MemberOverload(overloads.TimestampToDayOfWeekWithTz, []*Type{TimestampType, StringType}, IntType,
BinaryBinding(timestampGetDayOfWeek),
),
),
Function(overloads.TimeGetHours,
MemberOverload(overloads.TimestampToHours, []*Type{TimestampType}, IntType,
UnaryBinding(func(ts ref.Val) ref.Val {
return timestampGetHours(ts, utcTZ)
t := ts.(types.Timestamp)
return t.Receive(overloads.TimeGetHours, overloads.TimestampToHours, []ref.Val{})
}),
),
MemberOverload(overloads.TimestampToHoursWithTz, []*Type{TimestampType, StringType}, IntType,
BinaryBinding(timestampGetHours),
),
),
Function(overloads.TimeGetMinutes,
MemberOverload(overloads.TimestampToMinutes, []*Type{TimestampType}, IntType,
UnaryBinding(func(ts ref.Val) ref.Val {
return timestampGetMinutes(ts, utcTZ)
t := ts.(types.Timestamp)
return t.Receive(overloads.TimeGetMinutes, overloads.TimestampToMinutes, []ref.Val{})
}),
),
MemberOverload(overloads.TimestampToMinutesWithTz, []*Type{TimestampType, StringType}, IntType,
BinaryBinding(timestampGetMinutes),
),
),
Function(overloads.TimeGetSeconds,
MemberOverload(overloads.TimestampToSeconds, []*Type{TimestampType}, IntType,
UnaryBinding(func(ts ref.Val) ref.Val {
return timestampGetSeconds(ts, utcTZ)
t := ts.(types.Timestamp)
return t.Receive(overloads.TimeGetSeconds, overloads.TimestampToSeconds, []ref.Val{})
}),
),
MemberOverload(overloads.TimestampToSecondsWithTz, []*Type{TimestampType, StringType}, IntType,
BinaryBinding(timestampGetSeconds),
),
),
Function(overloads.TimeGetMilliseconds,
MemberOverload(overloads.TimestampToMilliseconds, []*Type{TimestampType}, IntType,
UnaryBinding(func(ts ref.Val) ref.Val {
return timestampGetMilliseconds(ts, utcTZ)
t := ts.(types.Timestamp)
return t.Receive(overloads.TimeGetMilliseconds, overloads.TimestampToMilliseconds, []ref.Val{})
}),
),
MemberOverload(overloads.TimestampToMillisecondsWithTz, []*Type{TimestampType, StringType}, IntType,
BinaryBinding(timestampGetMilliseconds),
),
),
}
)
func timestampGetFullYear(ts, tz ref.Val) ref.Val {
t, err := inTimeZone(ts, tz)
if err != nil {
return types.NewErrFromString(err.Error())
}
return types.Int(t.Year())
}
func timestampGetMonth(ts, tz ref.Val) ref.Val {
t, err := inTimeZone(ts, tz)
if err != nil {
return types.NewErrFromString(err.Error())
}
// CEL spec indicates that the month should be 0-based, but the Time value
// for Month() is 1-based.
return types.Int(t.Month() - 1)
}
func timestampGetDayOfYear(ts, tz ref.Val) ref.Val {
t, err := inTimeZone(ts, tz)
if err != nil {
return types.NewErrFromString(err.Error())
}
return types.Int(t.YearDay() - 1)
}
func timestampGetDayOfMonthZeroBased(ts, tz ref.Val) ref.Val {
t, err := inTimeZone(ts, tz)
if err != nil {
return types.NewErrFromString(err.Error())
}
return types.Int(t.Day() - 1)
}
func timestampGetDayOfMonthOneBased(ts, tz ref.Val) ref.Val {
t, err := inTimeZone(ts, tz)
if err != nil {
return types.NewErrFromString(err.Error())
}
return types.Int(t.Day())
}
func timestampGetDayOfWeek(ts, tz ref.Val) ref.Val {
t, err := inTimeZone(ts, tz)
if err != nil {
return types.NewErrFromString(err.Error())
}
return types.Int(t.Weekday())
}
func timestampGetHours(ts, tz ref.Val) ref.Val {
t, err := inTimeZone(ts, tz)
if err != nil {
return types.NewErrFromString(err.Error())
}
return types.Int(t.Hour())
}
func timestampGetMinutes(ts, tz ref.Val) ref.Val {
t, err := inTimeZone(ts, tz)
if err != nil {
return types.NewErrFromString(err.Error())
}
return types.Int(t.Minute())
}
func timestampGetSeconds(ts, tz ref.Val) ref.Val {
t, err := inTimeZone(ts, tz)
if err != nil {
return types.NewErrFromString(err.Error())
}
return types.Int(t.Second())
}
func timestampGetMilliseconds(ts, tz ref.Val) ref.Val {
t, err := inTimeZone(ts, tz)
if err != nil {
return types.NewErrFromString(err.Error())
}
return types.Int(t.Nanosecond() / 1000000)
}
func inTimeZone(ts, tz ref.Val) (time.Time, error) {
t := ts.(types.Timestamp)
val := string(tz.(types.String))
ind := strings.Index(val, ":")
if ind == -1 {
loc, err := time.LoadLocation(val)
if err != nil {
return time.Time{}, err
}
return t.In(loc), nil
}
// If the input is not the name of a timezone (for example, 'US/Central'), it should be a numerical offset from UTC
// in the format ^(+|-)(0[0-9]|1[0-4]):[0-5][0-9]$. The numerical input is parsed in terms of hours and minutes.
hr, err := strconv.Atoi(string(val[0:ind]))
if err != nil {
return time.Time{}, err
}
min, err := strconv.Atoi(string(val[ind+1:]))
if err != nil {
return time.Time{}, err
}
var offset int
if string(val[0]) == "-" {
offset = hr*60 - min
} else {
offset = hr*60 + min
}
secondsEastOfUTC := int((time.Duration(offset) * time.Minute).Seconds())
timezone := time.FixedZone("", secondsEastOfUTC)
return t.In(timezone), nil
}

View File

@@ -142,24 +142,38 @@ type MacroExprHelper interface {
NewError(exprID int64, message string) *Error
}
// MacroOpt defines a functional option for configuring macro behavior.
type MacroOpt = parser.MacroOpt
// MacroDocs configures a list of strings into a multiline description for the macro.
func MacroDocs(docs ...string) MacroOpt {
return parser.MacroDocs(docs...)
}
// MacroExamples configures a list of examples, either as a string or common.MultilineString,
// into an example set to be provided with the macro Documentation() call.
func MacroExamples(examples ...string) MacroOpt {
return parser.MacroExamples(examples...)
}
// GlobalMacro creates a Macro for a global function with the specified arg count.
func GlobalMacro(function string, argCount int, factory MacroFactory) Macro {
return parser.NewGlobalMacro(function, argCount, factory)
func GlobalMacro(function string, argCount int, factory MacroFactory, opts ...MacroOpt) Macro {
return parser.NewGlobalMacro(function, argCount, factory, opts...)
}
// ReceiverMacro creates a Macro for a receiver function matching the specified arg count.
func ReceiverMacro(function string, argCount int, factory MacroFactory) Macro {
return parser.NewReceiverMacro(function, argCount, factory)
func ReceiverMacro(function string, argCount int, factory MacroFactory, opts ...MacroOpt) Macro {
return parser.NewReceiverMacro(function, argCount, factory, opts...)
}
// GlobalVarArgMacro creates a Macro for a global function with a variable arg count.
func GlobalVarArgMacro(function string, factory MacroFactory) Macro {
return parser.NewGlobalVarArgMacro(function, factory)
func GlobalVarArgMacro(function string, factory MacroFactory, opts ...MacroOpt) Macro {
return parser.NewGlobalVarArgMacro(function, factory, opts...)
}
// ReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count.
func ReceiverVarArgMacro(function string, factory MacroFactory) Macro {
return parser.NewReceiverVarArgMacro(function, factory)
func ReceiverVarArgMacro(function string, factory MacroFactory, opts ...MacroOpt) Macro {
return parser.NewReceiverVarArgMacro(function, factory, opts...)
}
// NewGlobalMacro creates a Macro for a global function with the specified arg count.

View File

@@ -15,6 +15,7 @@
package cel
import (
"errors"
"fmt"
"google.golang.org/protobuf/proto"
@@ -25,6 +26,8 @@ import (
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/env"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/pb"
@@ -70,6 +73,26 @@ const (
featureIdentEscapeSyntax
)
var featureIDsToNames = map[int]string{
featureEnableMacroCallTracking: "cel.feature.macro_call_tracking",
featureCrossTypeNumericComparisons: "cel.feature.cross_type_numeric_comparisons",
featureIdentEscapeSyntax: "cel.feature.backtick_escape_syntax",
}
func featureNameByID(id int) (string, bool) {
name, found := featureIDsToNames[id]
return name, found
}
func featureIDByName(name string) (int, bool) {
for id, n := range featureIDsToNames {
if n == name {
return id, true
}
}
return 0, false
}
// EnvOption is a functional interface for configuring the environment.
type EnvOption func(e *Env) (*Env, error)
@@ -112,6 +135,8 @@ func CustomTypeProvider(provider any) EnvOption {
// Note: Declarations will by default be appended to the pre-existing declaration set configured
// for the environment. The NewEnv call builds on top of the standard CEL declarations. For a
// purely custom set of declarations use NewCustomEnv.
//
// Deprecated: use FunctionDecls and VariableDecls or FromConfig instead.
func Declarations(decls ...*exprpb.Decl) EnvOption {
declOpts := []EnvOption{}
var err error
@@ -379,7 +404,7 @@ type ProgramOption func(p *prog) (*prog, error)
// InterpretableDecorators can be used to inspect, alter, or replace the Program plan.
func CustomDecorator(dec interpreter.InterpretableDecorator) ProgramOption {
return func(p *prog) (*prog, error) {
p.decorators = append(p.decorators, dec)
p.plannerOptions = append(p.plannerOptions, interpreter.CustomDecorator(dec))
return p, nil
}
}
@@ -401,10 +426,10 @@ func Functions(funcs ...*functions.Overload) ProgramOption {
// variables with the same name provided to the Eval() call. If Globals is used in a Library with
// a Lib EnvOption, vars may shadow variables provided by previously added libraries.
//
// The vars value may either be an `interpreter.Activation` instance or a `map[string]any`.
// The vars value may either be an `cel.Activation` instance or a `map[string]any`.
func Globals(vars any) ProgramOption {
return func(p *prog) (*prog, error) {
defaultVars, err := interpreter.NewActivation(vars)
defaultVars, err := NewActivation(vars)
if err != nil {
return nil, err
}
@@ -426,6 +451,174 @@ func OptimizeRegex(regexOptimizations ...*interpreter.RegexOptimization) Program
}
}
// ConfigOptionFactory declares a signature which accepts a configuration element, e.g. env.Extension
// and optionally produces an EnvOption in response.
//
// If there are multiple ConfigOptionFactory values which could apply to the same configuration node
// the first one that returns an EnvOption and a `true` response will be used, and the config node
// will not be passed along to any other option factory.
//
// Only the *env.Extension type is provided at this time, but validators, optimizers, and other tuning
// parameters may be supported in the future.
type ConfigOptionFactory func(any) (EnvOption, bool)
// FromConfig produces and applies a set of EnvOption values derived from an env.Config object.
//
// For configuration elements which refer to features outside of the `cel` package, an optional set of
// ConfigOptionFactory values may be passed in to support the conversion from static configuration to
// configured cel.Env value.
//
// Note: disabling the standard library will clear the EnvOptions values previously set for the
// environment with the exception of propagating types and adapters over to the new environment.
//
// Note: to support custom types referenced in the configuration file, you must ensure that one of
// the following options appears before the FromConfig option: Types, TypeDescs, or CustomTypeProvider
// as the type provider configured at the time when the config is processed is the one used to derive
// type references from the configuration.
func FromConfig(config *env.Config, optFactories ...ConfigOptionFactory) EnvOption {
return func(e *Env) (*Env, error) {
if err := config.Validate(); err != nil {
return nil, err
}
opts, err := configToEnvOptions(config, e.CELTypeProvider(), optFactories)
if err != nil {
return nil, err
}
for _, o := range opts {
e, err = o(e)
if err != nil {
return nil, err
}
}
return e, nil
}
}
// configToEnvOptions generates a set of EnvOption values (or error) based on a config, a type provider,
// and an optional set of environment options.
func configToEnvOptions(config *env.Config, provider types.Provider, optFactories []ConfigOptionFactory) ([]EnvOption, error) {
envOpts := []EnvOption{}
// Configure the standard lib subset.
if config.StdLib != nil {
envOpts = append(envOpts, func(e *Env) (*Env, error) {
if e.HasLibrary("cel.lib.std") {
return nil, errors.New("invalid subset of stdlib: create a custom env")
}
return e, nil
})
if !config.StdLib.Disabled {
envOpts = append(envOpts, StdLib(StdLibSubset(config.StdLib)))
}
} else {
envOpts = append(envOpts, StdLib())
}
// Configure the container
if config.Container != "" {
envOpts = append(envOpts, Container(config.Container))
}
// Configure abbreviations
for _, imp := range config.Imports {
envOpts = append(envOpts, Abbrevs(imp.Name))
}
// Configure the context variable declaration
if config.ContextVariable != nil {
typeName := config.ContextVariable.TypeName
if _, found := provider.FindStructType(typeName); !found {
return nil, fmt.Errorf("invalid context proto type: %q", typeName)
}
// Attempt to instantiate the proto in order to reflect to its descriptor
msg := provider.NewValue(typeName, map[string]ref.Val{})
pbMsg, ok := msg.Value().(proto.Message)
if !ok {
return nil, fmt.Errorf("unsupported context type: %T", msg.Value())
}
envOpts = append(envOpts, DeclareContextProto(pbMsg.ProtoReflect().Descriptor()))
}
// Configure variables
if len(config.Variables) != 0 {
vars := make([]*decls.VariableDecl, 0, len(config.Variables))
for _, v := range config.Variables {
vDef, err := v.AsCELVariable(provider)
if err != nil {
return nil, err
}
vars = append(vars, vDef)
}
envOpts = append(envOpts, VariableDecls(vars...))
}
// Configure functions
if len(config.Functions) != 0 {
funcs := make([]*decls.FunctionDecl, 0, len(config.Functions))
for _, f := range config.Functions {
fnDef, err := f.AsCELFunction(provider)
if err != nil {
return nil, err
}
funcs = append(funcs, fnDef)
}
envOpts = append(envOpts, FunctionDecls(funcs...))
}
// Configure features
for _, feat := range config.Features {
// Note, if a feature is not found, it is skipped as it is possible the feature
// is not intended to be supported publicly. In the future, a refinement of
// to this strategy to report unrecognized features and validators should probably
// be covered as a standard ConfigOptionFactory
if id, found := featureIDByName(feat.Name); found {
envOpts = append(envOpts, features(id, feat.Enabled))
}
}
// Configure validators
for _, val := range config.Validators {
if fac, found := astValidatorFactories[val.Name]; found {
envOpts = append(envOpts, func(e *Env) (*Env, error) {
validator, err := fac(val)
if err != nil {
return nil, fmt.Errorf("%w", err)
}
return ASTValidators(validator)(e)
})
} else if opt, handled := handleExtendedConfigOption(val, optFactories); handled {
envOpts = append(envOpts, opt)
}
// we don't error when the validator isn't found as it may be part
// of an extension library and enabled implicitly.
}
// Configure extensions
for _, ext := range config.Extensions {
// version number has been validated by the call to `Validate`
ver, _ := ext.VersionNumber()
if ext.Name == "optional" {
envOpts = append(envOpts, OptionalTypes(OptionalTypesVersion(ver)))
} else {
opt, handled := handleExtendedConfigOption(ext, optFactories)
if !handled {
return nil, fmt.Errorf("unrecognized extension: %s", ext.Name)
}
envOpts = append(envOpts, opt)
}
}
return envOpts, nil
}
func handleExtendedConfigOption(conf any, optFactories []ConfigOptionFactory) (EnvOption, bool) {
for _, optFac := range optFactories {
if opt, useOption := optFac(conf); useOption {
return opt, true
}
}
return nil, false
}
// EvalOption indicates an evaluation option that may affect the evaluation behavior or information
// in the output result.
type EvalOption int
@@ -534,7 +727,7 @@ func fieldToCELType(field protoreflect.FieldDescriptor) (*Type, error) {
return nil, fmt.Errorf("field %s type %s not implemented", field.FullName(), field.Kind().String())
}
func fieldToVariable(field protoreflect.FieldDescriptor) (EnvOption, error) {
func fieldToVariable(field protoreflect.FieldDescriptor) (*decls.VariableDecl, error) {
name := string(field.Name())
if field.IsMap() {
mapKey := field.MapKey()
@@ -547,20 +740,20 @@ func fieldToVariable(field protoreflect.FieldDescriptor) (EnvOption, error) {
if err != nil {
return nil, err
}
return Variable(name, MapType(keyType, valueType)), nil
return decls.NewVariable(name, MapType(keyType, valueType)), nil
}
if field.IsList() {
elemType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return Variable(name, ListType(elemType)), nil
return decls.NewVariable(name, ListType(elemType)), nil
}
celType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return Variable(name, celType), nil
return decls.NewVariable(name, celType), nil
}
// DeclareContextProto returns an option to extend CEL environment with declarations from the given context proto.
@@ -568,17 +761,25 @@ func fieldToVariable(field protoreflect.FieldDescriptor) (EnvOption, error) {
// https://github.com/google/cel-spec/blob/master/doc/langdef.md#evaluation-environment
func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption {
return func(e *Env) (*Env, error) {
if e.contextProto != nil {
return nil, fmt.Errorf("context proto already declared as %q, got %q",
e.contextProto.FullName(), descriptor.FullName())
}
e.contextProto = descriptor
fields := descriptor.Fields()
vars := make([]*decls.VariableDecl, 0, fields.Len())
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
variable, err := fieldToVariable(field)
if err != nil {
return nil, err
}
e, err = variable(e)
if err != nil {
return nil, err
}
vars = append(vars, variable)
}
var err error
e, err = VariableDecls(vars...)(e)
if err != nil {
return nil, err
}
return Types(dynamicpb.NewMessage(descriptor))(e)
}
@@ -588,7 +789,7 @@ func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption {
//
// Consider using with `DeclareContextProto` to simplify variable type declarations and publishing when using
// protocol buffers.
func ContextProtoVars(ctx proto.Message) (interpreter.Activation, error) {
func ContextProtoVars(ctx proto.Message) (Activation, error) {
if ctx == nil || !ctx.ProtoReflect().IsValid() {
return interpreter.EmptyActivation(), nil
}
@@ -612,7 +813,7 @@ func ContextProtoVars(ctx proto.Message) (interpreter.Activation, error) {
}
vars[field.TextName()] = fieldVal
}
return interpreter.NewActivation(vars)
return NewActivation(vars)
}
// EnableMacroCallTracking ensures that call expressions which are replaced by macros

View File

@@ -29,7 +29,7 @@ import (
type Program interface {
// Eval returns the result of an evaluation of the Ast and environment against the input vars.
//
// The vars value may either be an `interpreter.Activation` or a `map[string]any`.
// The vars value may either be an `Activation` or a `map[string]any`.
//
// If the `OptTrackState`, `OptTrackCost` or `OptExhaustiveEval` flags are used, the `details` response will
// be non-nil. Given this caveat on `details`, the return state from evaluation will be:
@@ -47,14 +47,39 @@ type Program interface {
// to support cancellation and timeouts. This method must be used in conjunction with the
// InterruptCheckFrequency() option for cancellation interrupts to be impact evaluation.
//
// The vars value may either be an `interpreter.Activation` or `map[string]any`.
// The vars value may either be an `Activation` or `map[string]any`.
//
// The output contract for `ContextEval` is otherwise identical to the `Eval` method.
ContextEval(context.Context, any) (ref.Val, *EvalDetails, error)
}
// Activation used to resolve identifiers by name and references by id.
//
// An Activation is the primary mechanism by which a caller supplies input into a CEL program.
type Activation = interpreter.Activation
// NewActivation returns an activation based on a map-based binding where the map keys are
// expected to be qualified names used with ResolveName calls.
//
// The input `bindings` may either be of type `Activation` or `map[string]any`.
//
// Lazy bindings may be supplied within the map-based input in either of the following forms:
// - func() any
// - func() ref.Val
//
// The output of the lazy binding will overwrite the variable reference in the internal map.
//
// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using
// the types.Adapter configured in the environment.
func NewActivation(bindings any) (Activation, error) {
return interpreter.NewActivation(bindings)
}
// PartialActivation extends the Activation interface with a set of unknown AttributePatterns.
type PartialActivation = interpreter.PartialActivation
// NoVars returns an empty Activation.
func NoVars() interpreter.Activation {
func NoVars() Activation {
return interpreter.EmptyActivation()
}
@@ -64,10 +89,9 @@ func NoVars() interpreter.Activation {
// This method relies on manually configured sets of missing attribute patterns. For a method which
// infers the missing variables from the input and the configured environment, use Env.PartialVars().
//
// The `vars` value may either be an interpreter.Activation or any valid input to the
// interpreter.NewActivation call.
// The `vars` value may either be an Activation or any valid input to the NewActivation call.
func PartialVars(vars any,
unknowns ...*interpreter.AttributePattern) (interpreter.PartialActivation, error) {
unknowns ...*AttributePatternType) (PartialActivation, error) {
return interpreter.NewPartialActivation(vars, unknowns...)
}
@@ -84,12 +108,15 @@ func PartialVars(vars any,
// fully qualified variable name may be `ns.app.a`, `ns.a`, or `a` per the CEL namespace resolution
// rules. Pick the fully qualified variable name that makes sense within the container as the
// AttributePattern `varName` argument.
func AttributePattern(varName string) *AttributePatternType {
return interpreter.NewAttributePattern(varName)
}
// AttributePatternType represents a top-level variable with an optional set of qualifier patterns.
//
// See the interpreter.AttributePattern and interpreter.AttributeQualifierPattern for more info
// about how to create and manipulate AttributePattern values.
func AttributePattern(varName string) *interpreter.AttributePattern {
return interpreter.NewAttributePattern(varName)
}
type AttributePatternType = interpreter.AttributePattern
// EvalDetails holds additional information observed during the Eval() call.
type EvalDetails struct {
@@ -120,37 +147,24 @@ func (ed *EvalDetails) ActualCost() *uint64 {
type prog struct {
*Env
evalOpts EvalOption
defaultVars interpreter.Activation
defaultVars Activation
dispatcher interpreter.Dispatcher
interpreter interpreter.Interpreter
interruptCheckFrequency uint
// Intermediate state used to configure the InterpretableDecorator set provided
// to the initInterpretable call.
decorators []interpreter.InterpretableDecorator
plannerOptions []interpreter.PlannerOption
regexOptimizations []*interpreter.RegexOptimization
// Interpretable configured from an Ast and aggregate decorator set based on program options.
interpretable interpreter.Interpretable
observable *interpreter.ObservableInterpretable
callCostEstimator interpreter.ActualCostEstimator
costOptions []interpreter.CostTrackerOption
costLimit *uint64
}
func (p *prog) clone() *prog {
costOptsCopy := make([]interpreter.CostTrackerOption, len(p.costOptions))
copy(costOptsCopy, p.costOptions)
return &prog{
Env: p.Env,
evalOpts: p.evalOpts,
defaultVars: p.defaultVars,
dispatcher: p.dispatcher,
interpreter: p.interpreter,
interruptCheckFrequency: p.interruptCheckFrequency,
}
}
// newProgram creates a program instance with an environment, an ast, and an optional list of
// ProgramOption values.
//
@@ -162,10 +176,10 @@ func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) {
// Ensure the default attribute factory is set after the adapter and provider are
// configured.
p := &prog{
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
costOptions: []interpreter.CostTrackerOption{},
Env: e,
plannerOptions: []interpreter.PlannerOption{},
dispatcher: disp,
costOptions: []interpreter.CostTrackerOption{},
}
// Configure the program via the ProgramOption values.
@@ -203,74 +217,71 @@ func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) {
p.interpreter = interp
// Translate the EvalOption flags into InterpretableDecorator instances.
decorators := make([]interpreter.InterpretableDecorator, len(p.decorators))
copy(decorators, p.decorators)
plannerOptions := make([]interpreter.PlannerOption, len(p.plannerOptions))
copy(plannerOptions, p.plannerOptions)
// Enable interrupt checking if there's a non-zero check frequency
if p.interruptCheckFrequency > 0 {
decorators = append(decorators, interpreter.InterruptableEval())
plannerOptions = append(plannerOptions, interpreter.InterruptableEval())
}
// Enable constant folding first.
if p.evalOpts&OptOptimize == OptOptimize {
decorators = append(decorators, interpreter.Optimize())
plannerOptions = append(plannerOptions, interpreter.Optimize())
p.regexOptimizations = append(p.regexOptimizations, interpreter.MatchesRegexOptimization)
}
// Enable regex compilation of constants immediately after folding constants.
if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
plannerOptions = append(plannerOptions, interpreter.CompileRegexConstants(p.regexOptimizations...))
}
// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) {
costTracker.Estimator = p.callCostEstimator
costTracker.Limit = p.costLimit
for _, costOpt := range p.costOptions {
err := costOpt(costTracker)
if err != nil {
return nil, err
}
}
// Limit capacity to guarantee a reallocation when calling 'append(decs, ...)' below. This
// prevents the underlying memory from being shared between factory function calls causing
// undesired mutations.
decs := decorators[:len(decorators):len(decorators)]
var observers []interpreter.EvalObserver
if p.evalOpts&(OptExhaustiveEval|OptTrackState) != 0 {
// EvalStateObserver is required for OptExhaustiveEval.
observers = append(observers, interpreter.EvalStateObserver(state))
}
if p.evalOpts&OptTrackCost == OptTrackCost {
observers = append(observers, interpreter.CostObserver(costTracker))
}
// Enable exhaustive eval over a basic observer since it offers a superset of features.
if p.evalOpts&OptExhaustiveEval == OptExhaustiveEval {
decs = append(decs, interpreter.ExhaustiveEval(), interpreter.Observe(observers...))
} else if len(observers) > 0 {
decs = append(decs, interpreter.Observe(observers...))
}
return p.clone().initInterpretable(a, decs)
costOptCount := len(p.costOptions)
if p.costLimit != nil {
costOptCount++
}
costOpts := make([]interpreter.CostTrackerOption, 0, costOptCount)
costOpts = append(costOpts, p.costOptions...)
if p.costLimit != nil {
costOpts = append(costOpts, interpreter.CostTrackerLimit(*p.costLimit))
}
trackerFactory := func() (*interpreter.CostTracker, error) {
return interpreter.NewCostTracker(p.callCostEstimator, costOpts...)
}
var observers []interpreter.PlannerOption
if p.evalOpts&(OptExhaustiveEval|OptTrackState) != 0 {
// EvalStateObserver is required for OptExhaustiveEval.
observers = append(observers, interpreter.EvalStateObserver())
}
if p.evalOpts&OptTrackCost == OptTrackCost {
observers = append(observers, interpreter.CostObserver(interpreter.CostTrackerFactory(trackerFactory)))
}
// Enable exhaustive eval over a basic observer since it offers a superset of features.
if p.evalOpts&OptExhaustiveEval == OptExhaustiveEval {
plannerOptions = append(plannerOptions,
append([]interpreter.PlannerOption{interpreter.ExhaustiveEval()}, observers...)...)
} else if len(observers) > 0 {
plannerOptions = append(plannerOptions, observers...)
}
return newProgGen(factory)
}
return p.initInterpretable(a, decorators)
return p.initInterpretable(a, plannerOptions)
}
func (p *prog) initInterpretable(a *ast.AST, decs []interpreter.InterpretableDecorator) (*prog, error) {
func (p *prog) initInterpretable(a *ast.AST, plannerOptions []interpreter.PlannerOption) (*prog, error) {
// When the AST has been exprAST it contains metadata that can be used to speed up program execution.
interpretable, err := p.interpreter.NewInterpretable(a, decs...)
interpretable, err := p.interpreter.NewInterpretable(a, plannerOptions...)
if err != nil {
return nil, err
}
p.interpretable = interpretable
if oi, ok := interpretable.(*interpreter.ObservableInterpretable); ok {
p.observable = oi
}
return p, nil
}
// Eval implements the Program interface method.
func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
func (p *prog) Eval(input any) (out ref.Val, det *EvalDetails, err error) {
// Configure error recovery for unexpected panics during evaluation. Note, the use of named
// return values makes it possible to modify the error response during the recovery
// function.
@@ -285,9 +296,9 @@ func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
}
}()
// Build a hierarchical activation if there are default vars set.
var vars interpreter.Activation
var vars Activation
switch v := input.(type) {
case interpreter.Activation:
case Activation:
vars = v
case map[string]any:
vars = activationPool.Setup(v)
@@ -298,12 +309,24 @@ func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
if p.defaultVars != nil {
vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars)
}
v = p.interpretable.Eval(vars)
if p.observable != nil {
det = &EvalDetails{}
out = p.observable.ObserveEval(vars, func(observed any) {
switch o := observed.(type) {
case interpreter.EvalState:
det.state = o
case *interpreter.CostTracker:
det.costTracker = o
}
})
} else {
out = p.interpretable.Eval(vars)
}
// The output of an internal Eval may have a value (`v`) that is a types.Err. This step
// translates the CEL value to a Go error response. This interface does not quite match the
// RPC signature which allows for multiple errors to be returned, but should be sufficient.
if types.IsError(v) {
err = v.(*types.Err)
if types.IsError(out) {
err = out.(*types.Err)
}
return
}
@@ -315,9 +338,9 @@ func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetail
}
// Configure the input, making sure to wrap Activation inputs in the special ctxActivation which
// exposes the #interrupted variable and manages rate-limited checks of the ctx.Done() state.
var vars interpreter.Activation
var vars Activation
switch v := input.(type) {
case interpreter.Activation:
case Activation:
vars = ctxActivationPool.Setup(v, ctx.Done(), p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars)
case map[string]any:
@@ -331,90 +354,8 @@ func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetail
return p.Eval(vars)
}
// progFactory is a helper alias for marking a program creation factory function.
type progFactory func(interpreter.EvalState, *interpreter.CostTracker) (Program, error)
// progGen holds a reference to a progFactory instance and implements the Program interface.
type progGen struct {
factory progFactory
}
// newProgGen tests the factory object by calling it once and returns a factory-based Program if
// the test is successful.
func newProgGen(factory progFactory) (Program, error) {
// Test the factory to make sure that configuration errors are spotted at config
tracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, err
}
_, err = factory(interpreter.NewEvalState(), tracker)
if err != nil {
return nil, err
}
return &progGen{factory: factory}, nil
}
// Eval implements the Program interface method.
func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
// The factory based Eval() differs from the standard evaluation model in that it generates a
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}
// Generate a new instance of the interpretable using the factory configured during the call to
// newProgram(). It is incredibly unlikely that the factory call will generate an error given
// the factory test performed within the Program() call.
p, err := gen.factory(state, costTracker)
if err != nil {
return nil, det, err
}
// Evaluate the input, returning the result and the 'state' within EvalDetails.
v, _, err := p.Eval(input)
if err != nil {
return v, det, err
}
return v, det, nil
}
// ContextEval implements the Program interface method.
func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
if ctx == nil {
return nil, nil, fmt.Errorf("context can not be nil")
}
// The factory based Eval() differs from the standard evaluation model in that it generates a
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}
// Generate a new instance of the interpretable using the factory configured during the call to
// newProgram(). It is incredibly unlikely that the factory call will generate an error given
// the factory test performed within the Program() call.
p, err := gen.factory(state, costTracker)
if err != nil {
return nil, det, err
}
// Evaluate the input, returning the result and the 'state' within EvalDetails.
v, _, err := p.ContextEval(ctx, input)
if err != nil {
return v, det, err
}
return v, det, nil
}
type ctxEvalActivation struct {
parent interpreter.Activation
parent Activation
interrupt <-chan struct{}
interruptCheckCount uint
interruptCheckFrequency uint
@@ -438,10 +379,15 @@ func (a *ctxEvalActivation) ResolveName(name string) (any, bool) {
return a.parent.ResolveName(name)
}
func (a *ctxEvalActivation) Parent() interpreter.Activation {
func (a *ctxEvalActivation) Parent() Activation {
return a.parent
}
func (a *ctxEvalActivation) AsPartialActivation() (interpreter.PartialActivation, bool) {
pa, ok := a.parent.(interpreter.PartialActivation)
return pa, ok
}
func newCtxEvalActivationPool() *ctxEvalActivationPool {
return &ctxEvalActivationPool{
Pool: sync.Pool{
@@ -457,7 +403,7 @@ type ctxEvalActivationPool struct {
}
// Setup initializes a pooled Activation with the ability check for context.Context cancellation
func (p *ctxEvalActivationPool) Setup(vars interpreter.Activation, done <-chan struct{}, interruptCheckRate uint) *ctxEvalActivation {
func (p *ctxEvalActivationPool) Setup(vars Activation, done <-chan struct{}, interruptCheckRate uint) *ctxEvalActivation {
a := p.Pool.Get().(*ctxEvalActivation)
a.parent = vars
a.interrupt = done
@@ -506,8 +452,8 @@ func (a *evalActivation) ResolveName(name string) (any, bool) {
}
}
// Parent implements the interpreter.Activation interface
func (a *evalActivation) Parent() interpreter.Activation {
// Parent implements the Activation interface
func (a *evalActivation) Parent() Activation {
return nil
}

155
vendor/github.com/google/cel-go/cel/prompt.go generated vendored Normal file
View File

@@ -0,0 +1,155 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
_ "embed"
"sort"
"strings"
"text/template"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
)
//go:embed templates/authoring.tmpl
var authoringPrompt string
// AuthoringPrompt creates a prompt template from a CEL environment for the purpose of AI-assisted authoring.
func AuthoringPrompt(env *Env) (*Prompt, error) {
funcMap := template.FuncMap{
"split": func(str string) []string { return strings.Split(str, "\n") },
}
tmpl := template.New("cel").Funcs(funcMap)
tmpl, err := tmpl.Parse(authoringPrompt)
if err != nil {
return nil, err
}
return &Prompt{
Persona: defaultPersona,
FormatRules: defaultFormatRules,
GeneralUsage: defaultGeneralUsage,
tmpl: tmpl,
env: env,
}, nil
}
// Prompt represents the core components of an LLM prompt based on a CEL environment.
//
// All fields of the prompt may be overwritten / modified with support for rendering the
// prompt to a human-readable string.
type Prompt struct {
// Persona indicates something about the kind of user making the request
Persona string
// FormatRules indicate how the LLM should generate its output
FormatRules string
// GeneralUsage specifies additional context on how CEL should be used.
GeneralUsage string
// tmpl is the text template base-configuration for rendering text.
tmpl *template.Template
// env reference used to collect variables, functions, and macros available to the prompt.
env *Env
}
type promptInst struct {
*Prompt
Variables []*common.Doc
Macros []*common.Doc
Functions []*common.Doc
UserPrompt string
}
// Render renders the user prompt with the associated context from the prompt template
// for use with LLM generators.
func (p *Prompt) Render(userPrompt string) string {
var buffer strings.Builder
vars := make([]*common.Doc, len(p.env.Variables()))
for i, v := range p.env.Variables() {
vars[i] = v.Documentation()
}
sort.SliceStable(vars, func(i, j int) bool {
return vars[i].Name < vars[j].Name
})
macs := make([]*common.Doc, len(p.env.Macros()))
for i, m := range p.env.Macros() {
macs[i] = m.(common.Documentor).Documentation()
}
funcs := make([]*common.Doc, 0, len(p.env.Functions()))
for _, f := range p.env.Functions() {
if _, hidden := hiddenFunctions[f.Name()]; hidden {
continue
}
funcs = append(funcs, f.Documentation())
}
sort.SliceStable(funcs, func(i, j int) bool {
return funcs[i].Name < funcs[j].Name
})
inst := &promptInst{
Prompt: p,
Variables: vars,
Macros: macs,
Functions: funcs,
UserPrompt: userPrompt}
p.tmpl.Execute(&buffer, inst)
return buffer.String()
}
const (
defaultPersona = `You are a software engineer with expertise in networking and application security
authoring boolean Common Expression Language (CEL) expressions to ensure firewall,
networking, authentication, and data access is only permitted when all conditions
are satisfied.`
defaultFormatRules = `Output your response as a CEL expression.
Write the expression with the comment on the first line and the expression on the
subsequent lines. Format the expression using 80-character line limits commonly
found in C++ or Java code.`
defaultGeneralUsage = `CEL supports Protocol Buffer and JSON types, as well as simple types and aggregate types.
Simple types include bool, bytes, double, int, string, and uint:
* double literals must always include a decimal point: 1.0, 3.5, -2.2
* uint literals must be positive values suffixed with a 'u': 42u
* byte literals are strings prefixed with a 'b': b'1235'
* string literals can use either single quotes or double quotes: 'hello', "world"
* string literals can also be treated as raw strings that do not require any
escaping within the string by using the 'R' prefix: R"""quote: "hi" """
Aggregate types include list and map:
* list literals consist of zero or more values between brackets: "['a', 'b', 'c']"
* map literal consist of colon-separated key-value pairs within braces: "{'key1': 1, 'key2': 2}"
* Only int, uint, string, and bool types are valid map keys.
* Maps containing HTTP headers must always use lower-cased string keys.
Comments start with two-forward slashes followed by text and a newline.`
)
var (
hiddenFunctions = map[string]bool{
overloads.DeprecatedIn: true,
operators.OldIn: true,
operators.OldNotStrictlyFalse: true,
operators.NotStrictlyFalse: true,
}
)

View File

@@ -0,0 +1,56 @@
{{define "variable"}}{{.Name}} is a {{.Type}}
{{- end -}}
{{define "macro" -}}
{{.Name}} macro{{if .Description}} - {{range split .Description}}{{.}} {{end}}
{{end}}
{{range .Children}}{{range split .Description}} {{.}}
{{end}}
{{- end -}}
{{- end -}}
{{define "overload" -}}
{{if .Children}}{{range .Children}}{{range split .Description}} {{.}}
{{end}}
{{- end -}}
{{else}} {{.Signature}}
{{end}}
{{- end -}}
{{define "function" -}}
{{.Name}}{{if .Description}} - {{range split .Description}}{{.}} {{end}}
{{end}}
{{range .Children}}{{template "overload" .}}{{end}}
{{- end -}}
{{.Persona}}
{{.FormatRules}}
{{if or .Variables .Macros .Functions -}}
Only use the following variables, macros, and functions in expressions.
{{if .Variables}}
Variables:
{{range .Variables}}* {{template "variable" .}}
{{end -}}
{{end -}}
{{if .Macros}}
Macros:
{{range .Macros}}* {{template "macro" .}}
{{end -}}
{{end -}}
{{if .Functions}}
Functions:
{{range .Functions}}* {{template "function" .}}
{{end -}}
{{end -}}
{{- end -}}
{{.GeneralUsage}}
{{.UserPrompt}}

View File

@@ -20,11 +20,16 @@ import (
"regexp"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/env"
"github.com/google/cel-go/common/overloads"
)
const (
homogeneousValidatorName = "cel.lib.std.validate.types.homogeneous"
durationValidatorName = "cel.validator.duration"
regexValidatorName = "cel.validator.matches"
timestampValidatorName = "cel.validator.timestamp"
homogeneousValidatorName = "cel.validator.homogeneous_literals"
nestingLimitValidatorName = "cel.validator.comprehension_nesting_limit"
// HomogeneousAggregateLiteralExemptFunctions is the ValidatorConfig key used to configure
// the set of function names which are exempt from homogeneous type checks. The expected type
@@ -36,6 +41,35 @@ const (
HomogeneousAggregateLiteralExemptFunctions = homogeneousValidatorName + ".exempt"
)
var (
astValidatorFactories = map[string]ASTValidatorFactory{
nestingLimitValidatorName: func(val *env.Validator) (ASTValidator, error) {
if limit, found := val.ConfigValue("limit"); found {
if val, isInt := limit.(int); isInt {
return ValidateComprehensionNestingLimit(val), nil
}
return nil, fmt.Errorf("invalid validator: %s unsupported limit type: %v", nestingLimitValidatorName, limit)
}
return nil, fmt.Errorf("invalid validator: %s missing limit", nestingLimitValidatorName)
},
durationValidatorName: func(*env.Validator) (ASTValidator, error) {
return ValidateDurationLiterals(), nil
},
regexValidatorName: func(*env.Validator) (ASTValidator, error) {
return ValidateRegexLiterals(), nil
},
timestampValidatorName: func(*env.Validator) (ASTValidator, error) {
return ValidateTimestampLiterals(), nil
},
homogeneousValidatorName: func(*env.Validator) (ASTValidator, error) {
return ValidateHomogeneousAggregateLiterals(), nil
},
}
)
// ASTValidatorFactory creates an ASTValidator as configured by the input map
type ASTValidatorFactory func(*env.Validator) (ASTValidator, error)
// ASTValidators configures a set of ASTValidator instances into the target environment.
//
// Validators are applied in the order in which the are specified and are treated as singletons.
@@ -70,6 +104,18 @@ type ASTValidator interface {
Validate(*Env, ValidatorConfig, *ast.AST, *Issues)
}
// ConfigurableASTValidator supports conversion of an object to an `env.Validator` instance used for
// YAML serialization.
type ConfigurableASTValidator interface {
// ToConfig converts the internal configuration of an ASTValidator into an env.Validator instance
// which minimally must include the validator name, but may also include a map[string]any config
// object to be serialized to YAML. The string keys represent the configuration parameter name,
// and the any value must mirror the internally supported type associated with the config key.
//
// Note: only primitive CEL types are supported by CEL validators at this time.
ToConfig() *env.Validator
}
// ValidatorConfig provides an accessor method for querying validator configuration state.
type ValidatorConfig interface {
GetOrDefault(name string, value any) any
@@ -196,7 +242,12 @@ type formatValidator struct {
// Name returns the unique name of this function format validator.
func (v formatValidator) Name() string {
return fmt.Sprintf("cel.lib.std.validate.functions.%s", v.funcName)
return fmt.Sprintf("cel.validator.%s", v.funcName)
}
// ToConfig converts the ASTValidator to an env.Validator specifying the validator name.
func (v formatValidator) ToConfig() *env.Validator {
return env.NewValidator(v.Name())
}
// Validate searches the AST for uses of a given function name with a constant argument and performs a check
@@ -242,6 +293,11 @@ func (homogeneousAggregateLiteralValidator) Name() string {
return homogeneousValidatorName
}
// ToConfig converts the ASTValidator to an env.Validator specifying the validator name.
func (v homogeneousAggregateLiteralValidator) ToConfig() *env.Validator {
return env.NewValidator(v.Name())
}
// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types.
//
// This validator makes an exception for list and map literals which occur at any level of nesting within
@@ -336,10 +392,18 @@ type nestingLimitValidator struct {
limit int
}
// Name returns the name of the nesting limit validator.
func (v nestingLimitValidator) Name() string {
return "cel.lib.std.validate.comprehension_nesting_limit"
return nestingLimitValidatorName
}
// ToConfig converts the ASTValidator to an env.Validator specifying the validator name and the nesting limit
// as an integer value: {"limit": int}
func (v nestingLimitValidator) ToConfig() *env.Validator {
return env.NewValidator(v.Name()).SetConfig(map[string]any{"limit": v.limit})
}
// Validate implements the ASTValidator interface method.
func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
root := ast.NavigateAST(a)
comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind))

View File

@@ -145,6 +145,17 @@ func (c *checker) checkSelect(e ast.Expr) {
func (c *checker) checkOptSelect(e ast.Expr) {
// Collect metadata related to the opt select call packaged by the parser.
call := e.AsCall()
if len(call.Args()) != 2 || call.IsMemberFunction() {
t := ""
if call.IsMemberFunction() {
t = " member call with"
}
c.errors.notAnOptionalFieldSelectionCall(e.ID(), c.location(e),
fmt.Sprintf(
"incorrect signature.%s argument count: %d", t, len(call.Args())))
return
}
operand := call.Args()[0]
field := call.Args()[1]
fieldName, isString := maybeUnwrapString(field)

View File

@@ -545,16 +545,17 @@ func (c *coster) costCall(e ast.Expr) CostEstimate {
if len(overloadIDs) == 0 {
return CostEstimate{}
}
var targetType AstNode
var targetType *AstNode
if call.IsMemberFunction() {
sum = sum.Add(c.cost(call.Target()))
targetType = c.newAstNode(call.Target())
var t AstNode = c.newAstNode(call.Target())
targetType = &t
}
// Pick a cost estimate range that covers all the overload cost estimation ranges
fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0}
var resultSize *SizeEstimate
for _, overload := range overloadIDs {
overloadCost := c.functionCost(e, call.FunctionName(), overload, &targetType, argTypes, argCosts)
overloadCost := c.functionCost(e, call.FunctionName(), overload, targetType, argTypes, argCosts)
fnCost = fnCost.Union(overloadCost.CostEstimate)
if overloadCost.ResultSize != nil {
if resultSize == nil {

View File

@@ -91,6 +91,17 @@ func NewFunction(name string,
Overloads: overloads}}}
}
// NewFunctionWithDoc creates a named function declaration with a description and one or more overloads.
func NewFunctionWithDoc(name, doc string,
overloads ...*exprpb.Decl_FunctionDecl_Overload) *exprpb.Decl {
return &exprpb.Decl{
Name: name,
DeclKind: &exprpb.Decl_Function{
Function: &exprpb.Decl_FunctionDecl{
// Doc: desc,
Overloads: overloads}}}
}
// NewIdent creates a named identifier declaration with an optional literal
// value.
//
@@ -98,28 +109,37 @@ func NewFunction(name string,
//
// Deprecated: Use NewVar or NewConst instead.
func NewIdent(name string, t *exprpb.Type, v *exprpb.Constant) *exprpb.Decl {
return newIdent(name, t, v, "")
}
func newIdent(name string, t *exprpb.Type, v *exprpb.Constant, desc string) *exprpb.Decl {
return &exprpb.Decl{
Name: name,
DeclKind: &exprpb.Decl_Ident{
Ident: &exprpb.Decl_IdentDecl{
Type: t,
Value: v}}}
Value: v,
Doc: desc}}}
}
// NewConst creates a constant identifier with a CEL constant literal value.
func NewConst(name string, t *exprpb.Type, v *exprpb.Constant) *exprpb.Decl {
return NewIdent(name, t, v)
return newIdent(name, t, v, "")
}
// NewVar creates a variable identifier.
func NewVar(name string, t *exprpb.Type) *exprpb.Decl {
return NewIdent(name, t, nil)
return newIdent(name, t, nil, "")
}
// NewVarWithDoc creates a variable identifier with a type and a description string.
func NewVarWithDoc(name string, t *exprpb.Type, desc string) *exprpb.Decl {
return newIdent(name, t, nil, desc)
}
// NewInstanceOverload creates a instance function overload contract.
// First element of argTypes is instance.
func NewInstanceOverload(id string, argTypes []*exprpb.Type,
resultType *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload {
func NewInstanceOverload(id string, argTypes []*exprpb.Type, resultType *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload {
return &exprpb.Decl_FunctionDecl_Overload{
OverloadId: id,
ResultType: resultType,
@@ -154,8 +174,7 @@ func NewObjectType(typeName string) *exprpb.Type {
// NewOverload creates a function overload declaration which contains a unique
// overload id as well as the expected argument and result types. Overloads
// must be aggregated within a Function declaration.
func NewOverload(id string, argTypes []*exprpb.Type,
resultType *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload {
func NewOverload(id string, argTypes []*exprpb.Type, resultType *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload {
return &exprpb.Decl_FunctionDecl_Overload{
OverloadId: id,
ResultType: resultType,
@@ -231,7 +250,5 @@ func NewWrapperType(wrapped *exprpb.Type) *exprpb.Type {
// TODO: return an error
panic("Wrapped type must be a primitive")
}
return &exprpb.Type{
TypeKind: &exprpb.Type_Wrapper{
Wrapper: primitive}}
return &exprpb.Type{TypeKind: &exprpb.Type_Wrapper{Wrapper: primitive}}
}

View File

@@ -45,6 +45,10 @@ func (e *typeErrors) notAComprehensionRange(id int64, l common.Location, t *type
FormatCELType(t))
}
func (e *typeErrors) notAnOptionalFieldSelectionCall(id int64, l common.Location, err string) {
e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %s", err)
}
func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field ast.Expr) {
e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %v", field)
}

View File

@@ -9,6 +9,7 @@ go_library(
name = "go_default_library",
srcs = [
"cost.go",
"doc.go",
"error.go",
"errors.go",
"location.go",
@@ -25,6 +26,7 @@ go_test(
name = "go_default_test",
size = "small",
srcs = [
"doc_test.go",
"errors_test.go",
"source_test.go",
],

View File

@@ -160,6 +160,13 @@ func MaxID(a *AST) int64 {
return visitor.maxID + 1
}
// Heights computes the heights of all AST expressions and returns a map from expression id to height.
func Heights(a *AST) map[int64]int {
visitor := make(heightVisitor)
PostOrderVisit(a.Expr(), visitor)
return visitor
}
// NewSourceInfo creates a simple SourceInfo object from an input common.Source value.
func NewSourceInfo(src common.Source) *SourceInfo {
var lineOffsets []int32
@@ -455,3 +462,74 @@ func (v *maxIDVisitor) VisitEntryExpr(e EntryExpr) {
v.maxID = e.ID()
}
}
type heightVisitor map[int64]int
// VisitExpr computes the height of a given node as the max height of its children plus one.
//
// Identifiers and literals are treated as having a height of zero.
func (hv heightVisitor) VisitExpr(e Expr) {
// default includes IdentKind, LiteralKind
hv[e.ID()] = 0
switch e.Kind() {
case SelectKind:
hv[e.ID()] = 1 + hv[e.AsSelect().Operand().ID()]
case CallKind:
c := e.AsCall()
height := hv.maxHeight(c.Args()...)
if c.IsMemberFunction() {
tHeight := hv[c.Target().ID()]
if tHeight > height {
height = tHeight
}
}
hv[e.ID()] = 1 + height
case ListKind:
l := e.AsList()
hv[e.ID()] = 1 + hv.maxHeight(l.Elements()...)
case MapKind:
m := e.AsMap()
hv[e.ID()] = 1 + hv.maxEntryHeight(m.Entries()...)
case StructKind:
s := e.AsStruct()
hv[e.ID()] = 1 + hv.maxEntryHeight(s.Fields()...)
case ComprehensionKind:
comp := e.AsComprehension()
hv[e.ID()] = 1 + hv.maxHeight(comp.IterRange(), comp.AccuInit(), comp.LoopCondition(), comp.LoopStep(), comp.Result())
}
}
// VisitEntryExpr computes the max height of a map or struct entry and associates the height with the entry id.
func (hv heightVisitor) VisitEntryExpr(e EntryExpr) {
hv[e.ID()] = 0
switch e.Kind() {
case MapEntryKind:
me := e.AsMapEntry()
hv[e.ID()] = hv.maxHeight(me.Value(), me.Key())
case StructFieldKind:
sf := e.AsStructField()
hv[e.ID()] = hv[sf.Value().ID()]
}
}
func (hv heightVisitor) maxHeight(exprs ...Expr) int {
max := 0
for _, e := range exprs {
h := hv[e.ID()]
if h > max {
max = h
}
}
return max
}
func (hv heightVisitor) maxEntryHeight(entries ...EntryExpr) int {
max := 0
for _, e := range entries {
h := hv[e.ID()]
if h > max {
max = h
}
}
return max
}

View File

@@ -237,8 +237,13 @@ func visit(expr Expr, visitor Visitor, order visitOrder, depth, maxDepth int) {
case StructKind:
s := expr.AsStruct()
for _, f := range s.Fields() {
visitor.VisitEntryExpr(f)
if order == preOrder {
visitor.VisitEntryExpr(f)
}
visit(f.AsStructField().Value(), visitor, order, depth+1, maxDepth)
if order == postOrder {
visitor.VisitEntryExpr(f)
}
}
}
if order == postOrder {

View File

@@ -63,9 +63,9 @@ func (c *Container) Extend(opts ...ContainerOption) (*Container, error) {
}
// Copy the name and aliases of the existing container.
ext := &Container{name: c.Name()}
if len(c.aliasSet()) > 0 {
aliasSet := make(map[string]string, len(c.aliasSet()))
for k, v := range c.aliasSet() {
if len(c.AliasSet()) > 0 {
aliasSet := make(map[string]string, len(c.AliasSet()))
for k, v := range c.AliasSet() {
aliasSet[k] = v
}
ext.aliases = aliasSet
@@ -133,8 +133,8 @@ func (c *Container) ResolveCandidateNames(name string) []string {
return append(candidates, name)
}
// aliasSet returns the alias to fully-qualified name mapping stored in the container.
func (c *Container) aliasSet() map[string]string {
// AliasSet returns the alias to fully-qualified name mapping stored in the container.
func (c *Container) AliasSet() map[string]string {
if c == nil || c.aliases == nil {
return noAliases
}
@@ -160,7 +160,7 @@ func (c *Container) findAlias(name string) (string, bool) {
simple = name[0:dot]
qualifier = name[dot:]
}
alias, found := c.aliasSet()[simple]
alias, found := c.AliasSet()[simple]
if !found {
return "", false
}
@@ -264,7 +264,7 @@ func aliasAs(kind, qualifiedName, alias string) ContainerOption {
return nil, fmt.Errorf("%s must refer to a valid qualified name: %s",
kind, qualifiedName)
}
aliasRef, found := c.aliasSet()[alias]
aliasRef, found := c.AliasSet()[alias]
if found {
return nil, fmt.Errorf(
"%s collides with existing reference: name=%s, %s=%s, existing=%s",

View File

@@ -13,7 +13,9 @@ go_library(
importpath = "github.com/google/cel-go/common/decls",
deps = [
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/functions:go_default_library",
"//common/operators:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",

View File

@@ -20,7 +20,9 @@ import (
"strings"
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
@@ -54,6 +56,7 @@ func NewFunction(name string, opts ...FunctionOpt) (*FunctionDecl, error) {
// overload instances.
type FunctionDecl struct {
name string
doc string
// overloads associated with the function name.
overloads map[string]*OverloadDecl
@@ -84,6 +87,26 @@ const (
declarationEnabled
)
// Documentation generates documentation about the Function and its overloads as a common.Doc object.
func (f *FunctionDecl) Documentation() *common.Doc {
if f == nil {
return nil
}
children := make([]*common.Doc, len(f.OverloadDecls()))
for i, o := range f.OverloadDecls() {
var examples []*common.Doc
for _, ex := range o.Examples() {
examples = append(examples, common.NewExampleDoc(ex))
}
od := common.NewOverloadDoc(o.ID(), formatSignature(f.Name(), o), examples...)
children[i] = od
}
return common.NewFunctionDoc(
f.Name(),
f.Description(),
children...)
}
// Name returns the function name in human-readable terms, e.g. 'contains' of 'math.least'
func (f *FunctionDecl) Name() string {
if f == nil {
@@ -92,9 +115,22 @@ func (f *FunctionDecl) Name() string {
return f.name
}
// Description provides an overview of the function's purpose.
//
// Usage examples should be included on specific overloads.
func (f *FunctionDecl) Description() string {
if f == nil {
return ""
}
return f.doc
}
// IsDeclarationDisabled indicates that the function implementation should be added to the dispatcher, but the
// declaration should not be exposed for use in expressions.
func (f *FunctionDecl) IsDeclarationDisabled() bool {
if f == nil {
return true
}
return f.state == declarationDisabled
}
@@ -107,8 +143,8 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) {
if f == other {
return f, nil
}
if f.Name() != other.Name() {
return nil, fmt.Errorf("cannot merge unrelated functions. %s and %s", f.Name(), other.Name())
if f == nil || other == nil || f.Name() != other.Name() {
return nil, fmt.Errorf("cannot merge unrelated functions. %q and %q", f.Name(), other.Name())
}
merged := &FunctionDecl{
name: f.Name(),
@@ -120,12 +156,17 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) {
disableTypeGuards: f.disableTypeGuards && other.disableTypeGuards,
// default to the current functions declaration state.
state: f.state,
doc: f.doc,
}
// If the other state indicates that the declaration should be explicitly enabled or
// disabled, then update the merged state with the most recent value.
if other.state != declarationStateUnset {
merged.state = other.state
}
// Allow for non-empty overrides of documentation
if len(other.doc) != 0 && f.doc != other.doc {
merged.doc = other.doc
}
// baseline copy of the overloads and their ordinals
copy(merged.overloadOrdinals, f.overloadOrdinals)
for oID, o := range f.overloads {
@@ -148,6 +189,70 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) {
return merged, nil
}
// FunctionSubsetter subsets a function declaration or returns nil and false if the function
// subset was empty.
type FunctionSubsetter func(fn *FunctionDecl) (*FunctionDecl, bool)
// OverloadSelector selects an overload associated with a given function when it returns true.
//
// Used in combination with the Subset method.
type OverloadSelector func(overload *OverloadDecl) bool
// IncludeOverloads defines an OverloadSelector which allow-lists a set of overloads by their ids.
func IncludeOverloads(overloadIDs ...string) OverloadSelector {
return func(overload *OverloadDecl) bool {
for _, oID := range overloadIDs {
if overload.id == oID {
return true
}
}
return false
}
}
// ExcludeOverloads defines an OverloadSelector which deny-lists a set of overloads by their ids.
func ExcludeOverloads(overloadIDs ...string) OverloadSelector {
return func(overload *OverloadDecl) bool {
for _, oID := range overloadIDs {
if overload.id == oID {
return false
}
}
return true
}
}
// Subset returns a new function declaration which contains only the overloads with the specified IDs.
// If the subset function contains no overloads, then nil is returned to indicate the function is not
// functional.
func (f *FunctionDecl) Subset(selector OverloadSelector) *FunctionDecl {
if f == nil {
return nil
}
overloads := make(map[string]*OverloadDecl)
overloadOrdinals := make([]string, 0, len(f.overloadOrdinals))
for _, oID := range f.overloadOrdinals {
overload := f.overloads[oID]
if selector(overload) {
overloads[oID] = overload
overloadOrdinals = append(overloadOrdinals, oID)
}
}
if len(overloads) == 0 {
return nil
}
subset := &FunctionDecl{
name: f.Name(),
doc: f.doc,
overloads: overloads,
singleton: f.singleton,
disableTypeGuards: f.disableTypeGuards,
state: f.state,
overloadOrdinals: overloadOrdinals,
}
return subset
}
// AddOverload ensures that the new overload does not collide with an existing overload signature;
// however, if the function signatures are identical, the implementation may be rewritten as its
// difficult to compare functions by object identity.
@@ -155,6 +260,9 @@ func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
if f == nil {
return fmt.Errorf("nil function cannot add overload: %s", overload.ID())
}
if overload == nil {
return fmt.Errorf("cannot add nil overload to funciton: %s", f.Name())
}
for oID, o := range f.overloads {
if oID != overload.ID() && o.SignatureOverlaps(overload) {
return fmt.Errorf("overload signature collision in function %s: %s collides with %s", f.Name(), oID, overload.ID())
@@ -165,10 +273,17 @@ func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
if overload.hasBinding() {
f.overloads[oID] = overload
}
// Allow redefinition of the doc string.
if len(overload.doc) != 0 && o.doc != overload.doc {
o.doc = overload.doc
}
return nil
}
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.Name(), oID)
}
if overload.HasLateBinding() != o.HasLateBinding() {
return fmt.Errorf("overload with late binding cannot be added to function %s: cannot mix late and non-late bindings", f.Name())
}
}
f.overloadOrdinals = append(f.overloadOrdinals, overload.ID())
f.overloads[overload.ID()] = overload
@@ -177,8 +292,9 @@ func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
// OverloadDecls returns the overload declarations in the order in which they were declared.
func (f *FunctionDecl) OverloadDecls() []*OverloadDecl {
var emptySet []*OverloadDecl
if f == nil {
return []*OverloadDecl{}
return emptySet
}
overloads := make([]*OverloadDecl, 0, len(f.overloads))
for _, oID := range f.overloadOrdinals {
@@ -187,15 +303,31 @@ func (f *FunctionDecl) OverloadDecls() []*OverloadDecl {
return overloads
}
// HasLateBinding returns true if the function has late bindings. A function cannot mix late bindings with other bindings.
func (f *FunctionDecl) HasLateBinding() bool {
if f == nil {
return false
}
for _, oID := range f.overloadOrdinals {
if f.overloads[oID].HasLateBinding() {
return true
}
}
return false
}
// Bindings produces a set of function bindings, if any are defined.
func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
var emptySet []*functions.Overload
if f == nil {
return []*functions.Overload{}, nil
return emptySet, nil
}
overloads := []*functions.Overload{}
nonStrict := false
hasLateBinding := false
for _, oID := range f.overloadOrdinals {
o := f.overloads[oID]
hasLateBinding = hasLateBinding || o.HasLateBinding()
if o.hasBinding() {
overload := &functions.Overload{
Operator: o.ID(),
@@ -213,6 +345,9 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
if len(overloads) != 0 {
return nil, fmt.Errorf("singleton function incompatible with specialized overloads: %s", f.Name())
}
if hasLateBinding {
return nil, fmt.Errorf("singleton function incompatible with late bindings: %s", f.Name())
}
overloads = []*functions.Overload{
{
Operator: f.Name(),
@@ -298,6 +433,14 @@ func MaybeNoSuchOverload(funcName string, args ...ref.Val) ref.Val {
// FunctionOpt defines a functional option for mutating a function declaration.
type FunctionOpt func(*FunctionDecl) (*FunctionDecl, error)
// FunctionDocs configures documentation from a list of strings separated by newlines.
func FunctionDocs(docs ...string) FunctionOpt {
return func(fn *FunctionDecl) (*FunctionDecl, error) {
fn.doc = common.MultilineDescription(docs...)
return fn, nil
}
}
// DisableTypeGuards disables automatically generated function invocation guards on direct overload calls.
// Type guards remain on during dynamic dispatch for parsed-only expressions.
func DisableTypeGuards(value bool) FunctionOpt {
@@ -450,9 +593,13 @@ func newOverloadInternal(overloadID string,
// implementation.
type OverloadDecl struct {
id string
doc string
argTypes []*types.Type
resultType *types.Type
isMemberFunction bool
// hasLateBinding indicates that the function has a binding which is not known at compile time.
// This is useful for functions which have side-effects or are not deterministically computable.
hasLateBinding bool
// nonStrict indicates that the function will accept error and unknown arguments as inputs.
nonStrict bool
// operandTrait indicates whether the member argument should have a specific type-trait.
@@ -469,6 +616,15 @@ type OverloadDecl struct {
functionOp functions.FunctionOp
}
// Examples returns a list of string examples for the overload.
func (o *OverloadDecl) Examples() []string {
var emptySet []string
if o == nil || len(o.doc) == 0 {
return emptySet
}
return common.ParseDescriptions(o.doc)
}
// ID mirrors the overload signature and provides a unique id which may be referenced within the type-checker
// and interpreter to optimize performance.
//
@@ -508,6 +664,14 @@ func (o *OverloadDecl) IsNonStrict() bool {
return o.nonStrict
}
// HasLateBinding returns whether the overload has a binding which is not known at compile time.
func (o *OverloadDecl) HasLateBinding() bool {
if o == nil {
return false
}
return o.hasLateBinding
}
// OperandTrait returns the trait mask of the first operand to the overload call, e.g.
// `traits.Indexer`
func (o *OverloadDecl) OperandTrait() int {
@@ -666,6 +830,14 @@ func matchOperandTrait(trait int, arg ref.Val) bool {
// OverloadOpt is a functional option for configuring a function overload.
type OverloadOpt func(*OverloadDecl) (*OverloadDecl, error)
// OverloadExamples configures example expressions for the overload.
func OverloadExamples(examples ...string) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
o.doc = common.MultilineDescription(examples...)
return o, nil
}
}
// UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
@@ -676,6 +848,9 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
if len(o.ArgTypes()) != 1 {
return nil, fmt.Errorf("unary function bound to non-unary overload: %s", o.ID())
}
if o.hasLateBinding {
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
}
o.unaryOp = binding
return o, nil
}
@@ -691,6 +866,9 @@ func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
if len(o.ArgTypes()) != 2 {
return nil, fmt.Errorf("binary function bound to non-binary overload: %s", o.ID())
}
if o.hasLateBinding {
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
}
o.binaryOp = binding
return o, nil
}
@@ -703,11 +881,26 @@ func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
}
if o.hasLateBinding {
return nil, fmt.Errorf("overload already has a late binding: %s", o.ID())
}
o.functionOp = binding
return o, nil
}
}
// LateFunctionBinding indicates that the function has a binding which is not known at compile time.
// This is useful for functions which have side-effects or are not deterministically computable.
func LateFunctionBinding() OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
}
o.hasLateBinding = true
return o, nil
}
}
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
//
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.
@@ -737,13 +930,27 @@ func NewVariable(name string, t *types.Type) *VariableDecl {
return &VariableDecl{name: name, varType: t}
}
// NewVariableWithDoc creates a new variable declaration with usage documentation.
func NewVariableWithDoc(name string, t *types.Type, doc string) *VariableDecl {
return &VariableDecl{name: name, varType: t, doc: doc}
}
// VariableDecl defines a variable declaration which may optionally have a constant value.
type VariableDecl struct {
name string
doc string
varType *types.Type
value ref.Val
}
// Documentation returns name, type, and description for the variable.
func (v *VariableDecl) Documentation() *common.Doc {
if v == nil {
return nil
}
return common.NewVariableDoc(v.Name(), describeCELType(v.Type()), v.Description())
}
// Name returns the fully-qualified variable name
func (v *VariableDecl) Name() string {
if v == nil {
@@ -752,6 +959,16 @@ func (v *VariableDecl) Name() string {
return v.name
}
// Description returns the usage documentation for the variable, if set.
//
// Good usage instructions provide information about the valid formats, ranges, sizes for the variable type.
func (v *VariableDecl) Description() string {
if v == nil {
return ""
}
return v.doc
}
// Type returns the types.Type value associated with the variable.
func (v *VariableDecl) Type() *types.Type {
if v == nil {
@@ -793,7 +1010,7 @@ func variableDeclToExprDecl(v *VariableDecl) (*exprpb.Decl, error) {
if err != nil {
return nil, err
}
return chkdecls.NewVar(v.Name(), varType), nil
return chkdecls.NewVarWithDoc(v.Name(), varType, v.doc), nil
}
// FunctionDeclToExprDecl converts a go-native function declaration into a protobuf-typed function declaration.
@@ -838,8 +1055,10 @@ func functionDeclToExprDecl(f *FunctionDecl) (*exprpb.Decl, error) {
overloads[i] = chkdecls.NewParameterizedOverload(oID, argTypes, resultType, params)
}
}
doc := common.MultilineDescription(o.Examples()...)
overloads[i].Doc = doc
}
return chkdecls.NewFunction(f.Name(), overloads...), nil
return chkdecls.NewFunctionWithDoc(f.Name(), f.Description(), overloads...), nil
}
func collectParamNames(paramNames map[string]struct{}, arg *types.Type) {
@@ -851,6 +1070,60 @@ func collectParamNames(paramNames map[string]struct{}, arg *types.Type) {
}
}
func formatSignature(fnName string, o *OverloadDecl) string {
if opName, isOperator := operators.FindReverse(fnName); isOperator {
if opName == "" {
opName = fnName
}
return formatOperator(opName, o)
}
return formatCall(fnName, o)
}
func formatOperator(opName string, o *OverloadDecl) string {
args := o.ArgTypes()
argTypes := make([]string, len(o.ArgTypes()))
for j, a := range args {
argTypes[j] = describeCELType(a)
}
ret := describeCELType(o.ResultType())
switch len(args) {
case 1:
return fmt.Sprintf("%s%s -> %s", opName, argTypes[0], ret)
case 2:
if opName == operators.Index {
return fmt.Sprintf("%s[%s] -> %s", argTypes[0], argTypes[1], ret)
}
return fmt.Sprintf("%s %s %s -> %s", argTypes[0], opName, argTypes[1], ret)
default:
if opName == operators.Conditional {
return fmt.Sprint("bool ? <T> : <T> -> <T>")
}
return formatCall(opName, o)
}
}
func formatCall(funcName string, o *OverloadDecl) string {
args := make([]string, len(o.ArgTypes()))
ret := describeCELType(o.ResultType())
for j, a := range o.ArgTypes() {
args[j] = describeCELType(a)
}
if o.IsMemberFunction() {
target := args[0]
args = args[1:]
return fmt.Sprintf("%s.%s(%s) -> %s", target, funcName, strings.Join(args, ", "), ret)
}
return fmt.Sprintf("%s(%s) -> %s", funcName, strings.Join(args, ", "), ret)
}
func describeCELType(t *types.Type) string {
if t.Kind() == types.TypeKind {
return "type"
}
return t.String()
}
var (
emptyArgs = []*types.Type{}
emptyArgs []*types.Type
)

View File

@@ -15,3 +15,157 @@
// Package common defines types and utilities common to expression parsing,
// checking, and interpretation
package common
import (
"strings"
"unicode"
)
// DocKind indicates the type of documentation element.
type DocKind int
const (
// DocEnv represents environment variable documentation.
DocEnv DocKind = iota + 1
// DocFunction represents function documentation.
DocFunction
// DocOverload represents function overload documentation.
DocOverload
// DocVariable represents variable documentation.
DocVariable
// DocMacro represents macro documentation.
DocMacro
// DocExample represents example documentation.
DocExample
)
// Doc holds the documentation details for a specific program element like
// a variable, function, macro, or example.
type Doc struct {
// Kind specifies the type of documentation element (e.g., Function, Variable).
Kind DocKind
// Name is the identifier of the documented element (e.g., function name, variable name).
Name string
// Type is the data type associated with the element, primarily used for variables.
Type string
// Signature represents the function or overload signature.
Signature string
// Description holds the textual description of the element, potentially spanning multiple lines.
Description string
// Children holds nested documentation elements, such as overloads for a function
// or examples for a function/macro.
Children []*Doc
}
// MultilineDescription combines multiple lines into a newline separated string.
func MultilineDescription(lines ...string) string {
return strings.Join(lines, "\n")
}
// ParseDescription takes a single string containing newline characters and splits
// it into a multiline description. All empty lines will be skipped.
//
// Returns an empty string if the input string is empty.
func ParseDescription(doc string) string {
var lines []string
if len(doc) != 0 {
// Split the input string by newline characters.
for _, line := range strings.Split(doc, "\n") {
l := strings.TrimRightFunc(line, unicode.IsSpace)
if len(l) == 0 {
continue
}
lines = append(lines, l)
}
}
// Return an empty slice if the input is empty.
return MultilineDescription(lines...)
}
// ParseDescriptions splits a documentation string into multiple multi-line description
// sections, using blank lines as delimiters.
func ParseDescriptions(doc string) []string {
var examples []string
if len(doc) != 0 {
lines := strings.Split(doc, "\n")
lineStart := 0
for i, l := range lines {
// Trim trailing whitespace to identify effectively blank lines.
l = strings.TrimRightFunc(l, unicode.IsSpace)
// If a line is blank, it marks the end of the current section.
if len(l) == 0 {
// Start the next section after the blank line.
ex := lines[lineStart:i]
if len(ex) != 0 {
examples = append(examples, MultilineDescription(ex...))
}
lineStart = i + 1
}
}
// Append the last section if it wasn't terminated by a blank line.
if lineStart < len(lines) {
examples = append(examples, MultilineDescription(lines[lineStart:]...))
}
}
return examples
}
// NewVariableDoc creates a new Doc struct specifically for documenting a variable.
func NewVariableDoc(name, celType, description string) *Doc {
return &Doc{
Kind: DocVariable,
Name: name,
Type: celType,
Description: ParseDescription(description),
}
}
// NewFunctionDoc creates a new Doc struct for documenting a function.
func NewFunctionDoc(name, description string, overloads ...*Doc) *Doc {
return &Doc{
Kind: DocFunction,
Name: name,
Description: ParseDescription(description),
Children: overloads,
}
}
// NewOverloadDoc creates a new Doc struct for a function example.
func NewOverloadDoc(id, signature string, examples ...*Doc) *Doc {
return &Doc{
Kind: DocOverload,
Name: id,
Signature: signature,
Children: examples,
}
}
// NewMacroDoc creates a new Doc struct for documenting a macro.
func NewMacroDoc(name, description string, examples ...*Doc) *Doc {
return &Doc{
Kind: DocMacro,
Name: name,
Description: ParseDescription(description),
Children: examples,
}
}
// NewExampleDoc creates a new Doc struct specifically for holding an example.
func NewExampleDoc(ex string) *Doc {
return &Doc{
Kind: DocExample,
Description: ex,
}
}
// Documentor is an interface for types that can provide their own documentation.
type Documentor interface {
// Documentation returns the documentation coded by the DocKind to assist
// with text formatting.
Documentation() *Doc
}

50
vendor/github.com/google/cel-go/common/env/BUILD.bazel generated vendored Normal file
View File

@@ -0,0 +1,50 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
go_library(
name = "go_default_library",
srcs = [
"env.go",
],
importpath = "github.com/google/cel-go/common/env",
deps = [
"//common:go_default_library",
"//common/decls:go_default_library",
"//common/types:go_default_library",
],
)
go_test(
name = "go_default_test",
size = "small",
srcs = [
"env_test.go",
],
data = glob(["testdata/**"]),
embed = [":go_default_library"],
deps = [
"//common/decls:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"@in_gopkg_yaml_v3//:go_default_library",
],
)

887
vendor/github.com/google/cel-go/common/env/env.go generated vendored Normal file
View File

@@ -0,0 +1,887 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package env provides a representation of a CEL environment.
package env
import (
"errors"
"fmt"
"math"
"strconv"
"strings"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/types"
)
// NewConfig creates an instance of a YAML serializable CEL environment configuration.
func NewConfig(name string) *Config {
return &Config{
Name: name,
}
}
// Config represents a serializable form of the CEL environment configuration.
//
// Note: custom validations, feature flags, and performance tuning parameters are not (yet)
// considered part of the core CEL environment configuration and should be managed separately
// until a common convention for such settings is developed.
type Config struct {
Name string `yaml:"name,omitempty"`
Description string `yaml:"description,omitempty"`
Container string `yaml:"container,omitempty"`
Imports []*Import `yaml:"imports,omitempty"`
StdLib *LibrarySubset `yaml:"stdlib,omitempty"`
Extensions []*Extension `yaml:"extensions,omitempty"`
ContextVariable *ContextVariable `yaml:"context_variable,omitempty"`
Variables []*Variable `yaml:"variables,omitempty"`
Functions []*Function `yaml:"functions,omitempty"`
Validators []*Validator `yaml:"validators,omitempty"`
Features []*Feature `yaml:"features,omitempty"`
}
// Validate validates the whole configuration is well-formed.
func (c *Config) Validate() error {
if c == nil {
return nil
}
var errs []error
for _, imp := range c.Imports {
if err := imp.Validate(); err != nil {
errs = append(errs, err)
}
}
if err := c.StdLib.Validate(); err != nil {
errs = append(errs, err)
}
for _, ext := range c.Extensions {
if err := ext.Validate(); err != nil {
errs = append(errs, err)
}
}
if err := c.ContextVariable.Validate(); err != nil {
errs = append(errs, err)
}
if c.ContextVariable != nil && len(c.Variables) != 0 {
errs = append(errs, errors.New("invalid config: either context variable or variables may be set, but not both"))
}
for _, v := range c.Variables {
if err := v.Validate(); err != nil {
errs = append(errs, err)
}
}
for _, fn := range c.Functions {
if err := fn.Validate(); err != nil {
errs = append(errs, err)
}
}
for _, feat := range c.Features {
if err := feat.Validate(); err != nil {
errs = append(errs, err)
}
}
for _, val := range c.Validators {
if err := val.Validate(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
// SetContainer configures the container name for this configuration.
func (c *Config) SetContainer(container string) *Config {
c.Container = container
return c
}
// AddVariableDecls adds one or more variables to the config, converting them to serializable values first.
//
// VariableDecl inputs are expected to be well-formed.
func (c *Config) AddVariableDecls(vars ...*decls.VariableDecl) *Config {
convVars := make([]*Variable, len(vars))
for i, v := range vars {
if v == nil {
continue
}
cv := NewVariable(v.Name(), SerializeTypeDesc(v.Type()))
cv.Description = v.Description()
convVars[i] = cv
}
return c.AddVariables(convVars...)
}
// AddVariables adds one or more vairables to the config.
func (c *Config) AddVariables(vars ...*Variable) *Config {
c.Variables = append(c.Variables, vars...)
return c
}
// SetContextVariable configures the ContextVariable for this configuration.
func (c *Config) SetContextVariable(ctx *ContextVariable) *Config {
c.ContextVariable = ctx
return c
}
// AddFunctionDecls adds one or more functions to the config, converting them to serializable values first.
//
// FunctionDecl inputs are expected to be well-formed.
func (c *Config) AddFunctionDecls(funcs ...*decls.FunctionDecl) *Config {
convFuncs := make([]*Function, len(funcs))
for i, fn := range funcs {
if fn == nil {
continue
}
overloads := make([]*Overload, 0, len(fn.OverloadDecls()))
for _, o := range fn.OverloadDecls() {
overloadID := o.ID()
args := make([]*TypeDesc, 0, len(o.ArgTypes()))
for _, a := range o.ArgTypes() {
args = append(args, SerializeTypeDesc(a))
}
ret := SerializeTypeDesc(o.ResultType())
var overload *Overload
if o.IsMemberFunction() {
overload = NewMemberOverload(overloadID, args[0], args[1:], ret)
} else {
overload = NewOverload(overloadID, args, ret)
}
exampleCount := len(o.Examples())
if exampleCount > 0 {
overload.Examples = o.Examples()
}
overloads = append(overloads, overload)
}
cf := NewFunction(fn.Name(), overloads...)
cf.Description = fn.Description()
convFuncs[i] = cf
}
return c.AddFunctions(convFuncs...)
}
// AddFunctions adds one or more functions to the config.
func (c *Config) AddFunctions(funcs ...*Function) *Config {
c.Functions = append(c.Functions, funcs...)
return c
}
// SetStdLib configures the LibrarySubset for the standard library.
func (c *Config) SetStdLib(subset *LibrarySubset) *Config {
c.StdLib = subset
return c
}
// AddImports appends a set of imports to the config.
func (c *Config) AddImports(imps ...*Import) *Config {
c.Imports = append(c.Imports, imps...)
return c
}
// AddExtensions appends a set of extensions to the config.
func (c *Config) AddExtensions(exts ...*Extension) *Config {
c.Extensions = append(c.Extensions, exts...)
return c
}
// AddValidators appends one or more validators to the config.
func (c *Config) AddValidators(vals ...*Validator) *Config {
c.Validators = append(c.Validators, vals...)
return c
}
// AddFeatures appends one or more features to the config.
func (c *Config) AddFeatures(feats ...*Feature) *Config {
c.Features = append(c.Features, feats...)
return c
}
// NewImport returns a serializable import value from the qualified type name.
func NewImport(name string) *Import {
return &Import{Name: name}
}
// Import represents a type name that will be appreviated by its simple name using
// the cel.Abbrevs() option.
type Import struct {
Name string `yaml:"name"`
}
// Validate validates the import configuration is well-formed.
func (imp *Import) Validate() error {
if imp == nil {
return errors.New("invalid import: nil")
}
if imp.Name == "" {
return errors.New("invalid import: missing type name")
}
return nil
}
// NewVariable returns a serializable variable from a name and type definition
func NewVariable(name string, t *TypeDesc) *Variable {
return NewVariableWithDoc(name, t, "")
}
// NewVariableWithDoc returns a serializable variable from a name, type definition, and doc string.
func NewVariableWithDoc(name string, t *TypeDesc, doc string) *Variable {
return &Variable{Name: name, TypeDesc: t, Description: doc}
}
// Variable represents a typed variable declaration which will be published via the
// cel.VariableDecls() option.
type Variable struct {
Name string `yaml:"name"`
Description string `yaml:"description,omitempty"`
// Type represents the type declaration for the variable.
//
// Deprecated: use the embedded *TypeDesc fields directly.
Type *TypeDesc `yaml:"type,omitempty"`
// TypeDesc is an embedded set of fields allowing for the specification of the Variable type.
*TypeDesc `yaml:",inline"`
}
// Validate validates the variable configuration is well-formed.
func (v *Variable) Validate() error {
if v == nil {
return errors.New("invalid variable: nil")
}
if v.Name == "" {
return errors.New("invalid variable: missing variable name")
}
if err := v.GetType().Validate(); err != nil {
return fmt.Errorf("invalid variable %q: %w", v.Name, err)
}
return nil
}
// GetType returns the variable type description.
//
// Note, if both the embedded TypeDesc and the field Type are non-nil, the embedded TypeDesc will
// take precedence.
func (v *Variable) GetType() *TypeDesc {
if v == nil {
return nil
}
if v.TypeDesc != nil {
return v.TypeDesc
}
if v.Type != nil {
return v.Type
}
return nil
}
// AsCELVariable converts the serializable form of the Variable into a CEL environment declaration.
func (v *Variable) AsCELVariable(tp types.Provider) (*decls.VariableDecl, error) {
if err := v.Validate(); err != nil {
return nil, err
}
t, err := v.GetType().AsCELType(tp)
if err != nil {
return nil, fmt.Errorf("invalid variable %q: %w", v.Name, err)
}
return decls.NewVariableWithDoc(v.Name, t, v.Description), nil
}
// NewContextVariable returns a serializable context variable with a specific type name.
func NewContextVariable(typeName string) *ContextVariable {
return &ContextVariable{TypeName: typeName}
}
// ContextVariable represents a structured message whose fields are to be treated as the top-level
// variable identifiers within CEL expressions.
type ContextVariable struct {
// TypeName represents the fully qualified typename of the context variable.
// Currently, only protobuf types are supported.
TypeName string `yaml:"type_name"`
}
// Validate validates the context-variable configuration is well-formed.
func (ctx *ContextVariable) Validate() error {
if ctx == nil {
return nil
}
if ctx.TypeName == "" {
return errors.New("invalid context variable: missing type name")
}
return nil
}
// NewFunction creates a serializable function and overload set.
func NewFunction(name string, overloads ...*Overload) *Function {
return &Function{Name: name, Overloads: overloads}
}
// NewFunctionWithDoc creates a serializable function and overload set.
func NewFunctionWithDoc(name, doc string, overloads ...*Overload) *Function {
return &Function{Name: name, Description: doc, Overloads: overloads}
}
// Function represents the serializable format of a function and its overloads.
type Function struct {
Name string `yaml:"name"`
Description string `yaml:"description,omitempty"`
Overloads []*Overload `yaml:"overloads,omitempty"`
}
// Validate validates the function configuration is well-formed.
func (fn *Function) Validate() error {
if fn == nil {
return errors.New("invalid function: nil")
}
if fn.Name == "" {
return errors.New("invalid function: missing function name")
}
if len(fn.Overloads) == 0 {
return fmt.Errorf("invalid function %q: missing overloads", fn.Name)
}
var errs []error
for _, o := range fn.Overloads {
if err := o.Validate(); err != nil {
errs = append(errs, fmt.Errorf("invalid function %q: %w", fn.Name, err))
}
}
return errors.Join(errs...)
}
// AsCELFunction converts the serializable form of the Function into CEL environment declaration.
func (fn *Function) AsCELFunction(tp types.Provider) (*decls.FunctionDecl, error) {
if err := fn.Validate(); err != nil {
return nil, err
}
opts := make([]decls.FunctionOpt, 0, len(fn.Overloads)+1)
for _, o := range fn.Overloads {
opt, err := o.AsFunctionOption(tp)
opts = append(opts, opt)
if err != nil {
return nil, fmt.Errorf("invalid function %q: %w", fn.Name, err)
}
}
if len(fn.Description) != 0 {
opts = append(opts, decls.FunctionDocs(fn.Description))
}
return decls.NewFunction(fn.Name, opts...)
}
// NewOverload returns a new serializable representation of a global overload.
func NewOverload(id string, args []*TypeDesc, ret *TypeDesc, examples ...string) *Overload {
return &Overload{ID: id, Args: args, Return: ret, Examples: examples}
}
// NewMemberOverload returns a new serializable representation of a member (receiver) overload.
func NewMemberOverload(id string, target *TypeDesc, args []*TypeDesc, ret *TypeDesc, examples ...string) *Overload {
return &Overload{ID: id, Target: target, Args: args, Return: ret, Examples: examples}
}
// Overload represents the serializable format of a function overload.
type Overload struct {
ID string `yaml:"id"`
Examples []string `yaml:"examples,omitempty"`
Target *TypeDesc `yaml:"target,omitempty"`
Args []*TypeDesc `yaml:"args,omitempty"`
Return *TypeDesc `yaml:"return,omitempty"`
}
// Validate validates the overload configuration is well-formed.
func (od *Overload) Validate() error {
if od == nil {
return errors.New("invalid overload: nil")
}
if od.ID == "" {
return errors.New("invalid overload: missing overload id")
}
var errs []error
if od.Target != nil {
if err := od.Target.Validate(); err != nil {
errs = append(errs, fmt.Errorf("invalid overload %q target: %w", od.ID, err))
}
}
for i, arg := range od.Args {
if err := arg.Validate(); err != nil {
errs = append(errs, fmt.Errorf("invalid overload %q arg[%d]: %w", od.ID, i, err))
}
}
if err := od.Return.Validate(); err != nil {
errs = append(errs, fmt.Errorf("invalid overload %q return: %w", od.ID, err))
}
return errors.Join(errs...)
}
// AsFunctionOption converts the serializable form of the Overload into a function declaration option.
func (od *Overload) AsFunctionOption(tp types.Provider) (decls.FunctionOpt, error) {
if err := od.Validate(); err != nil {
return nil, err
}
args := make([]*types.Type, len(od.Args))
var err error
var errs []error
for i, a := range od.Args {
args[i], err = a.AsCELType(tp)
if err != nil {
errs = append(errs, err)
}
}
result, err := od.Return.AsCELType(tp)
if err != nil {
errs = append(errs, err)
}
if od.Target != nil {
t, err := od.Target.AsCELType(tp)
if err != nil {
return nil, errors.Join(append(errs, err)...)
}
args = append([]*types.Type{t}, args...)
return decls.MemberOverload(od.ID, args, result), nil
}
if len(errs) != 0 {
return nil, errors.Join(errs...)
}
return decls.Overload(od.ID, args, result, decls.OverloadExamples(od.Examples...)), nil
}
// NewExtension creates a serializable Extension from a name and version string.
func NewExtension(name string, version uint32) *Extension {
versionString := "latest"
if version < math.MaxUint32 {
versionString = strconv.FormatUint(uint64(version), 10)
}
return &Extension{
Name: name,
Version: versionString,
}
}
// Extension represents a named and optionally versioned extension library configured in the environment.
type Extension struct {
// Name is either the LibraryName() or some short-hand simple identifier which is understood by the config-handler.
Name string `yaml:"name"`
// Version may either be an unsigned long value or the string 'latest'. If empty, the value is treated as '0'.
Version string `yaml:"version,omitempty"`
}
// Validate validates the extension configuration is well-formed.
func (e *Extension) Validate() error {
_, err := e.VersionNumber()
return err
}
// VersionNumber returns the parsed version string, or an error if the version cannot be parsed.
func (e *Extension) VersionNumber() (uint32, error) {
if e == nil {
return 0, fmt.Errorf("invalid extension: nil")
}
if e.Name == "" {
return 0, fmt.Errorf("invalid extension: missing name")
}
if e.Version == "latest" {
return math.MaxUint32, nil
}
if e.Version == "" {
return 0, nil
}
ver, err := strconv.ParseUint(e.Version, 10, 32)
if err != nil {
return 0, fmt.Errorf("invalid extension %q version: %w", e.Name, err)
}
return uint32(ver), nil
}
// NewLibrarySubset returns an empty library subsetting config which permits all library features.
func NewLibrarySubset() *LibrarySubset {
return &LibrarySubset{}
}
// LibrarySubset indicates a subset of the macros and function supported by a subsettable library.
type LibrarySubset struct {
// Disabled indicates whether the library has been disabled, typically only used for
// default-enabled libraries like stdlib.
Disabled bool `yaml:"disabled,omitempty"`
// DisableMacros disables macros for the given library.
DisableMacros bool `yaml:"disable_macros,omitempty"`
// IncludeMacros specifies a set of macro function names to include in the subset.
IncludeMacros []string `yaml:"include_macros,omitempty"`
// ExcludeMacros specifies a set of macro function names to exclude from the subset.
// Note: if IncludeMacros is non-empty, then ExcludeFunctions is ignored.
ExcludeMacros []string `yaml:"exclude_macros,omitempty"`
// IncludeFunctions specifies a set of functions to include in the subset.
//
// Note: the overloads specified in the subset need only specify their ID.
// Note: if IncludeFunctions is non-empty, then ExcludeFunctions is ignored.
IncludeFunctions []*Function `yaml:"include_functions,omitempty"`
// ExcludeFunctions specifies the set of functions to exclude from the subset.
//
// Note: the overloads specified in the subset need only specify their ID.
ExcludeFunctions []*Function `yaml:"exclude_functions,omitempty"`
}
// Validate validates the library configuration is well-formed.
//
// For example, setting both the IncludeMacros and ExcludeMacros together could be confusing
// and create a broken expectation, likewise for IncludeFunctions and ExcludeFunctions.
func (lib *LibrarySubset) Validate() error {
if lib == nil {
return nil
}
var errs []error
if len(lib.IncludeMacros) != 0 && len(lib.ExcludeMacros) != 0 {
errs = append(errs, errors.New("invalid subset: cannot both include and exclude macros"))
}
if len(lib.IncludeFunctions) != 0 && len(lib.ExcludeFunctions) != 0 {
errs = append(errs, errors.New("invalid subset: cannot both include and exclude functions"))
}
return errors.Join(errs...)
}
// SubsetFunction produces a function declaration which matches the supported subset, or nil
// if the function is not supported by the LibrarySubset.
//
// For IncludeFunctions, if the function does not specify a set of overloads to include, the
// whole function definition is included. If overloads are set, then a new function which
// includes only the specified overloads is produced.
//
// For ExcludeFunctions, if the function does not specify a set of overloads to exclude, the
// whole function definition is excluded. If overloads are set, then a new function which
// includes only the permitted overloads is produced.
func (lib *LibrarySubset) SubsetFunction(fn *decls.FunctionDecl) (*decls.FunctionDecl, bool) {
// When lib is null, it should indicate that all values are included in the subset.
if lib == nil {
return fn, true
}
if lib.Disabled {
return nil, false
}
if len(lib.IncludeFunctions) != 0 {
for _, include := range lib.IncludeFunctions {
if include.Name != fn.Name() {
continue
}
if len(include.Overloads) == 0 {
return fn, true
}
overloadIDs := make([]string, len(include.Overloads))
for i, o := range include.Overloads {
overloadIDs[i] = o.ID
}
return fn.Subset(decls.IncludeOverloads(overloadIDs...)), true
}
return nil, false
}
if len(lib.ExcludeFunctions) != 0 {
for _, exclude := range lib.ExcludeFunctions {
if exclude.Name != fn.Name() {
continue
}
if len(exclude.Overloads) == 0 {
return nil, false
}
overloadIDs := make([]string, len(exclude.Overloads))
for i, o := range exclude.Overloads {
overloadIDs[i] = o.ID
}
return fn.Subset(decls.ExcludeOverloads(overloadIDs...)), true
}
return fn, true
}
return fn, true
}
// SubsetMacro indicates whether the macro function should be included in the library subset.
func (lib *LibrarySubset) SubsetMacro(macroFunction string) bool {
// When lib is null, it should indicate that all values are included in the subset.
if lib == nil {
return true
}
if lib.Disabled || lib.DisableMacros {
return false
}
if len(lib.IncludeMacros) != 0 {
for _, name := range lib.IncludeMacros {
if name == macroFunction {
return true
}
}
return false
}
if len(lib.ExcludeMacros) != 0 {
for _, name := range lib.ExcludeMacros {
if name == macroFunction {
return false
}
}
return true
}
return true
}
// SetDisabled disables or enables the library.
func (lib *LibrarySubset) SetDisabled(value bool) *LibrarySubset {
lib.Disabled = value
return lib
}
// SetDisableMacros disables the macros for the library.
func (lib *LibrarySubset) SetDisableMacros(value bool) *LibrarySubset {
lib.DisableMacros = value
return lib
}
// AddIncludedMacros allow-lists one or more macros by function name.
//
// Note, this option will override any excluded macros.
func (lib *LibrarySubset) AddIncludedMacros(macros ...string) *LibrarySubset {
lib.IncludeMacros = append(lib.IncludeMacros, macros...)
return lib
}
// AddExcludedMacros deny-lists one or more macros by function name.
func (lib *LibrarySubset) AddExcludedMacros(macros ...string) *LibrarySubset {
lib.ExcludeMacros = append(lib.ExcludeMacros, macros...)
return lib
}
// AddIncludedFunctions allow-lists one or more functions from the subset.
//
// Note, this option will override any excluded functions.
func (lib *LibrarySubset) AddIncludedFunctions(funcs ...*Function) *LibrarySubset {
lib.IncludeFunctions = append(lib.IncludeFunctions, funcs...)
return lib
}
// AddExcludedFunctions deny-lists one or more functions from the subset.
func (lib *LibrarySubset) AddExcludedFunctions(funcs ...*Function) *LibrarySubset {
lib.ExcludeFunctions = append(lib.ExcludeFunctions, funcs...)
return lib
}
// NewValidator returns a named Validator instance.
func NewValidator(name string) *Validator {
return &Validator{Name: name}
}
// Validator represents a named validator with an optional map-based configuration object.
//
// Note: the map-keys must directly correspond to the internal representation of the original
// validator, and should only use primitive scalar types as values at this time.
type Validator struct {
Name string `yaml:"name"`
Config map[string]any `yaml:"config,omitempty"`
}
// Validate validates the configuration of the validator object.
func (v *Validator) Validate() error {
if v == nil {
return errors.New("invalid validator: nil")
}
if v.Name == "" {
return errors.New("invalid validator: missing name")
}
return nil
}
// SetConfig sets the set of map key-value pairs associated with this validator's configuration.
func (v *Validator) SetConfig(config map[string]any) *Validator {
v.Config = config
return v
}
// ConfigValue retrieves the value associated with the config key name, if one exists.
func (v *Validator) ConfigValue(name string) (any, bool) {
if v == nil {
return nil, false
}
value, found := v.Config[name]
return value, found
}
// NewFeature creates a new feature flag with a boolean enablement flag.
func NewFeature(name string, enabled bool) *Feature {
return &Feature{Name: name, Enabled: enabled}
}
// Feature represents a named boolean feature flag supported by CEL.
type Feature struct {
Name string `yaml:"name"`
Enabled bool `yaml:"enabled"`
}
// Validate validates whether the feature is well-configured.
func (feat *Feature) Validate() error {
if feat == nil {
return errors.New("invalid feature: nil")
}
if feat.Name == "" {
return errors.New("invalid feature: missing name")
}
return nil
}
// NewTypeDesc describes a simple or complex type with parameters.
func NewTypeDesc(typeName string, params ...*TypeDesc) *TypeDesc {
return &TypeDesc{TypeName: typeName, Params: params}
}
// NewTypeParam describe a type-param type.
func NewTypeParam(paramName string) *TypeDesc {
return &TypeDesc{TypeName: paramName, IsTypeParam: true}
}
// TypeDesc represents the serializable format of a CEL *types.Type value.
type TypeDesc struct {
TypeName string `yaml:"type_name"`
Params []*TypeDesc `yaml:"params,omitempty"`
IsTypeParam bool `yaml:"is_type_param,omitempty"`
}
// String implements the strings.Stringer interface method.
func (td *TypeDesc) String() string {
ps := make([]string, len(td.Params))
for i, p := range td.Params {
ps[i] = p.String()
}
typeName := td.TypeName
if len(ps) != 0 {
typeName = fmt.Sprintf("%s(%s)", typeName, strings.Join(ps, ","))
}
return typeName
}
// Validate validates the type configuration is well-formed.
func (td *TypeDesc) Validate() error {
if td == nil {
return errors.New("invalid type: nil")
}
if td.TypeName == "" {
return errors.New("invalid type: missing type name")
}
if td.IsTypeParam && len(td.Params) != 0 {
return errors.New("invalid type: param type cannot have parameters")
}
switch td.TypeName {
case "list":
if len(td.Params) != 1 {
return fmt.Errorf("invalid type: list expects 1 parameter, got %d", len(td.Params))
}
return td.Params[0].Validate()
case "map":
if len(td.Params) != 2 {
return fmt.Errorf("invalid type: map expects 2 parameters, got %d", len(td.Params))
}
if err := td.Params[0].Validate(); err != nil {
return err
}
if err := td.Params[1].Validate(); err != nil {
return err
}
case "optional_type":
if len(td.Params) != 1 {
return fmt.Errorf("invalid type: optional_type expects 1 parameter, got %d", len(td.Params))
}
return td.Params[0].Validate()
default:
}
return nil
}
// AsCELType converts the serializable object to a *types.Type value.
func (td *TypeDesc) AsCELType(tp types.Provider) (*types.Type, error) {
err := td.Validate()
if err != nil {
return nil, err
}
switch td.TypeName {
case "dyn":
return types.DynType, nil
case "map":
kt, err := td.Params[0].AsCELType(tp)
if err != nil {
return nil, err
}
vt, err := td.Params[1].AsCELType(tp)
if err != nil {
return nil, err
}
return types.NewMapType(kt, vt), nil
case "list":
et, err := td.Params[0].AsCELType(tp)
if err != nil {
return nil, err
}
return types.NewListType(et), nil
case "optional_type":
et, err := td.Params[0].AsCELType(tp)
if err != nil {
return nil, err
}
return types.NewOptionalType(et), nil
default:
if td.IsTypeParam {
return types.NewTypeParamType(td.TypeName), nil
}
if msgType, found := tp.FindStructType(td.TypeName); found {
// First parameter is the type name.
return msgType.Parameters()[0], nil
}
t, found := tp.FindIdent(td.TypeName)
if !found {
return nil, fmt.Errorf("undefined type name: %q", td.TypeName)
}
_, ok := t.(*types.Type)
if ok && len(td.Params) == 0 {
return t.(*types.Type), nil
}
params := make([]*types.Type, len(td.Params))
for i, p := range td.Params {
params[i], err = p.AsCELType(tp)
if err != nil {
return nil, err
}
}
return types.NewOpaqueType(td.TypeName, params...), nil
}
}
// SerializeTypeDesc converts a CEL native *types.Type to a serializable TypeDesc.
func SerializeTypeDesc(t *types.Type) *TypeDesc {
typeName := t.TypeName()
if t.Kind() == types.TypeParamKind {
return NewTypeParam(typeName)
}
if t != types.NullType && t.IsAssignableType(types.NullType) {
if wrapperTypeName, found := wrapperTypes[t.Kind()]; found {
return NewTypeDesc(wrapperTypeName)
}
}
var params []*TypeDesc
for _, p := range t.Parameters() {
params = append(params, SerializeTypeDesc(p))
}
return NewTypeDesc(typeName, params...)
}
var wrapperTypes = map[types.Kind]string{
types.BoolKind: "google.protobuf.BoolValue",
types.BytesKind: "google.protobuf.BytesValue",
types.DoubleKind: "google.protobuf.DoubleValue",
types.IntKind: "google.protobuf.Int64Value",
types.StringKind: "google.protobuf.StringValue",
types.UintKind: "google.protobuf.UInt64Value",
}

View File

@@ -12,6 +12,7 @@ go_library(
],
importpath = "github.com/google/cel-go/common/stdlib",
deps = [
"//common:go_default_library",
"//common/decls:go_default_library",
"//common/functions:go_default_library",
"//common/operators:go_default_library",

File diff suppressed because it is too large Load Diff

View File

@@ -18,6 +18,7 @@ go_library(
"int.go",
"iterator.go",
"json_value.go",
"format.go",
"list.go",
"map.go",
"null.go",

View File

@@ -18,6 +18,7 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/google/cel-go/common/types/ref"
@@ -128,6 +129,14 @@ func (b Bool) Value() any {
return bool(b)
}
func (b Bool) format(sb *strings.Builder) {
if b {
sb.WriteString("true")
} else {
sb.WriteString("false")
}
}
// IsBool returns whether the input ref.Val or ref.Type is equal to BoolType.
func IsBool(elem ref.Val) bool {
switch v := elem.(type) {

View File

@@ -19,6 +19,7 @@ import (
"encoding/base64"
"fmt"
"reflect"
"strings"
"unicode/utf8"
"github.com/google/cel-go/common/types/ref"
@@ -138,3 +139,17 @@ func (b Bytes) Type() ref.Type {
func (b Bytes) Value() any {
return []byte(b)
}
func (b Bytes) format(sb *strings.Builder) {
fmt.Fprintf(sb, "b\"%s\"", bytesToOctets([]byte(b)))
}
// bytesToOctets converts byte sequences to a string using a three digit octal encoded value
// per byte.
func bytesToOctets(byteVal []byte) string {
var b strings.Builder
for _, c := range byteVal {
fmt.Fprintf(&b, "\\%03o", c)
}
return b.String()
}

View File

@@ -18,6 +18,8 @@ import (
"fmt"
"math"
"reflect"
"strconv"
"strings"
"github.com/google/cel-go/common/types/ref"
@@ -209,3 +211,23 @@ func (d Double) Type() ref.Type {
func (d Double) Value() any {
return float64(d)
}
func (d Double) format(sb *strings.Builder) {
if math.IsNaN(float64(d)) {
sb.WriteString(`double("NaN")`)
return
}
if math.IsInf(float64(d), -1) {
sb.WriteString(`double("-Infinity")`)
return
}
if math.IsInf(float64(d), 1) {
sb.WriteString(`double("Infinity")`)
return
}
s := strconv.FormatFloat(float64(d), 'f', -1, 64)
sb.WriteString(s)
if !strings.ContainsRune(s, '.') {
sb.WriteString(".0")
}
}

View File

@@ -18,6 +18,7 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
"github.com/google/cel-go/common/overloads"
@@ -185,6 +186,10 @@ func (d Duration) Value() any {
return d.Duration
}
func (d Duration) format(sb *strings.Builder) {
fmt.Fprintf(sb, `duration("%ss")`, strconv.FormatFloat(d.Seconds(), 'f', -1, 64))
}
// DurationGetHours returns the duration in hours.
func DurationGetHours(val ref.Val) ref.Val {
dur, ok := val.(Duration)

42
vendor/github.com/google/cel-go/common/types/format.go generated vendored Normal file
View File

@@ -0,0 +1,42 @@
package types
import (
"fmt"
"strings"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
type formattable interface {
format(*strings.Builder)
}
// Format formats the value as a string. The result is only intended for human consumption and ignores errors.
// Do not depend on the output being stable. It may change at any time.
func Format(val ref.Val) string {
var sb strings.Builder
formatTo(&sb, val)
return sb.String()
}
func formatTo(sb *strings.Builder, val ref.Val) {
if fmtable, ok := val.(formattable); ok {
fmtable.format(sb)
return
}
// All of the builtins implement formattable. Try to deal with traits.
if l, ok := val.(traits.Lister); ok {
formatList(l, sb)
return
}
if m, ok := val.(traits.Mapper); ok {
formatMap(m, sb)
return
}
// This could be an error, unknown, opaque or object.
// Unfortunately we have no consistent way of inspecting
// opaque and object. So we just fallback to fmt.Stringer
// and hope it is relavent.
fmt.Fprintf(sb, "%s", val)
}

View File

@@ -19,6 +19,7 @@ import (
"math"
"reflect"
"strconv"
"strings"
"time"
"github.com/google/cel-go/common/types/ref"
@@ -290,6 +291,10 @@ func (i Int) Value() any {
return int64(i)
}
func (i Int) format(sb *strings.Builder) {
sb.WriteString(strconv.FormatInt(int64(i), 10))
}
// isJSONSafe indicates whether the int is safely representable as a floating point value in JSON.
func (i Int) isJSONSafe() bool {
return i >= minIntJSON && i <= maxIntJSON

View File

@@ -299,6 +299,22 @@ func (l *baseList) String() string {
return sb.String()
}
func formatList(l traits.Lister, sb *strings.Builder) {
sb.WriteString("[")
n, _ := l.Size().(Int)
for i := 0; i < int(n); i++ {
formatTo(sb, l.Get(Int(i)))
if i != int(n)-1 {
sb.WriteString(", ")
}
}
sb.WriteString("]")
}
func (l *baseList) format(sb *strings.Builder) {
formatList(l, sb)
}
// mutableList aggregates values into its internal storage. For use with internal CEL variables only.
type mutableList struct {
*baseList

View File

@@ -17,6 +17,7 @@ package types
import (
"fmt"
"reflect"
"sort"
"strings"
"github.com/stoewer/go-strcase"
@@ -318,6 +319,41 @@ func (m *baseMap) String() string {
return sb.String()
}
type baseMapEntry struct {
key string
val string
}
func formatMap(m traits.Mapper, sb *strings.Builder) {
it := m.Iterator()
var ents []baseMapEntry
if s, ok := m.Size().(Int); ok {
ents = make([]baseMapEntry, 0, int(s))
}
for it.HasNext() == True {
k := it.Next()
v, _ := m.Find(k)
ents = append(ents, baseMapEntry{Format(k), Format(v)})
}
sort.SliceStable(ents, func(i, j int) bool {
return ents[i].key < ents[j].key
})
sb.WriteString("{")
for i, ent := range ents {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(ent.key)
sb.WriteString(": ")
sb.WriteString(ent.val)
}
sb.WriteString("}")
}
func (m *baseMap) format(sb *strings.Builder) {
formatMap(m, sb)
}
// Type implements the ref.Val interface method.
func (m *baseMap) Type() ref.Type {
return MapType

View File

@@ -17,6 +17,7 @@ package types
import (
"fmt"
"reflect"
"strings"
"google.golang.org/protobuf/proto"
@@ -117,3 +118,7 @@ func (n Null) Type() ref.Type {
func (n Null) Value() any {
return structpb.NullValue_NULL_VALUE
}
func (n Null) format(sb *strings.Builder) {
sb.WriteString("null")
}

View File

@@ -17,9 +17,12 @@ package types
import (
"fmt"
"reflect"
"sort"
"strings"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
@@ -163,3 +166,29 @@ func (o *protoObj) Type() ref.Type {
func (o *protoObj) Value() any {
return o.value
}
type protoObjField struct {
fd protoreflect.FieldDescriptor
v protoreflect.Value
}
func (o *protoObj) format(sb *strings.Builder) {
var fields []protoreflect.FieldDescriptor
o.value.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
fields = append(fields, fd)
return true
})
sort.SliceStable(fields, func(i, j int) bool {
return fields[i].Number() < fields[j].Number()
})
sb.WriteString(o.Type().TypeName())
sb.WriteString("{")
for i, field := range fields {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%s: ", field.Name()))
formatTo(sb, o.Get(String(field.Name())))
}
sb.WriteString("}")
}

View File

@@ -18,6 +18,7 @@ import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/google/cel-go/common/types/ref"
)
@@ -94,6 +95,16 @@ func (o *Optional) String() string {
return "optional.none()"
}
func (o *Optional) format(sb *strings.Builder) {
if o.HasValue() {
sb.WriteString(`optional.of(`)
formatTo(sb, o.GetValue())
sb.WriteString(`)`)
} else {
sb.WriteString("optional.none()")
}
}
// Type implements the ref.Val interface method.
func (o *Optional) Type() ref.Type {
return OptionalType

View File

@@ -186,6 +186,10 @@ func (s String) Value() any {
return string(s)
}
func (s String) format(sb *strings.Builder) {
sb.WriteString(strconv.Quote(string(s)))
}
// StringContains returns whether the string contains a substring.
func StringContains(s, sub ref.Val) ref.Val {
str, ok := s.(String)

View File

@@ -179,6 +179,10 @@ func (t Timestamp) Value() any {
return t.Time
}
func (t Timestamp) format(sb *strings.Builder) {
fmt.Fprintf(sb, `timestamp("%s")`, t.Time.UTC().Format(time.RFC3339Nano))
}
var (
timestampValueType = reflect.TypeOf(&tpb.Timestamp{})

View File

@@ -164,9 +164,9 @@ var (
traits.SubtractorType,
}
// ListType represents the runtime list type.
ListType = NewListType(nil)
ListType = NewListType(DynType)
// MapType represents the runtime map type.
MapType = NewMapType(nil, nil)
MapType = NewMapType(DynType, DynType)
// NullType represents the type of a null value.
NullType = &Type{
kind: NullTypeKind,
@@ -376,6 +376,10 @@ func (t *Type) TypeName() string {
return t.runtimeTypeName
}
func (t *Type) format(sb *strings.Builder) {
sb.WriteString(t.TypeName())
}
// WithTraits creates a copy of the current Type and sets the trait mask to the traits parameter.
//
// This method should be used with Opaque types where the type acts like a container, e.g. vector.
@@ -395,6 +399,9 @@ func (t *Type) WithTraits(traits int) *Type {
// String returns a human-readable definition of the type name.
func (t *Type) String() string {
if t.Kind() == TypeParamKind {
return fmt.Sprintf("<%s>", t.DeclaredTypeName())
}
if len(t.Parameters()) == 0 {
return t.DeclaredTypeName()
}

View File

@@ -19,6 +19,7 @@ import (
"math"
"reflect"
"strconv"
"strings"
"github.com/google/cel-go/common/types/ref"
@@ -250,6 +251,11 @@ func (i Uint) Value() any {
return uint64(i)
}
func (i Uint) format(sb *strings.Builder) {
sb.WriteString(strconv.FormatUint(uint64(i), 10))
sb.WriteString("u")
}
// isJSONSafe indicates whether the uint is safely representable as a floating point value in JSON.
func (i Uint) isJSONSafe() bool {
return i <= maxIntJSON

View File

@@ -10,12 +10,15 @@ go_library(
"bindings.go",
"comprehensions.go",
"encoders.go",
"extension_option_factory.go",
"formatting.go",
"formatting_v2.go",
"guards.go",
"lists.go",
"math.go",
"native.go",
"protos.go",
"regex.go",
"sets.go",
"strings.go",
],
@@ -24,10 +27,12 @@ go_library(
deps = [
"//cel:go_default_library",
"//checker:go_default_library",
"//common:go_default_library",
"//common/ast:go_default_library",
"//common/decls:go_default_library",
"//common/overloads:go_default_library",
"//common/env:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
@@ -48,11 +53,15 @@ go_test(
srcs = [
"bindings_test.go",
"comprehensions_test.go",
"encoders_test.go",
"encoders_test.go",
"extension_option_factory_test.go",
"formatting_test.go",
"formatting_v2_test.go",
"lists_test.go",
"math_test.go",
"native_test.go",
"protos_test.go",
"regex_test.go",
"sets_test.go",
"strings_test.go",
],
@@ -62,14 +71,16 @@ go_test(
deps = [
"//cel:go_default_library",
"//checker:go_default_library",
"//common:go_default_library",
"//common/env:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
],
)

View File

@@ -356,6 +356,23 @@ Examples:
math.isFinite(0.0/0.0) // returns false
math.isFinite(1.2) // returns true
### Math.Sqrt
Introduced at version: 2
Returns the square root of the given input as double
Throws error for negative or non-numeric inputs
math.sqrt(<double>) -> <double>
math.sqrt(<int>) -> <double>
math.sqrt(<uint>) -> <double>
Examples:
math.sqrt(81) // returns 9.0
math.sqrt(985.25) // returns 31.388692231439016
math.sqrt(-15) // returns NaN
## Protos
Protos configure extended macros and functions for proto manipulation.
@@ -395,7 +412,7 @@ zero-based.
### Distinct
**Introduced in version 2**
**Introduced in version 2 (cost support in version 3)**
Returns the distinct elements of a list.
@@ -409,7 +426,7 @@ Examples:
### Flatten
**Introduced in version 1**
**Introduced in version 1 (cost support in version 3)**
Flattens a list recursively.
If an optional depth is provided, the list is flattened to a the specificied level.
@@ -428,7 +445,7 @@ Examples:
### Range
**Introduced in version 2**
**Introduced in version 2 (cost support in version 3)**
Returns a list of integers from 0 to n-1.
@@ -441,7 +458,7 @@ Examples:
### Reverse
**Introduced in version 2**
**Introduced in version 2 (cost support in version 3)**
Returns the elements of a list in reverse order.
@@ -454,6 +471,7 @@ Examples:
### Slice
**Introduced in version 0 (cost support in version 3)**
Returns a new sub-list using the indexes provided.
@@ -466,7 +484,7 @@ Examples:
### Sort
**Introduced in version 2**
**Introduced in version 2 (cost support in version 3)**
Sorts a list with comparable elements. If the element type is not comparable
or the element types are not the same, the function will produce an error.
@@ -483,7 +501,7 @@ Examples:
### SortBy
**Introduced in version 2**
**Introduced in version 2 (cost support in version 3)**
Sorts a list by a key value, i.e., the order is determined by the result of
an expression applied to each element of the list.

View File

@@ -149,7 +149,7 @@ type blockValidationExemption struct{}
// Name returns the name of the validator.
func (blockValidationExemption) Name() string {
return "cel.lib.ext.validate.functions.cel.block"
return "cel.validator.cel_block"
}
// Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip
@@ -224,7 +224,7 @@ func (b *dynamicBlock) ID() int64 {
}
// Eval implements the Interpretable interface method.
func (b *dynamicBlock) Eval(activation interpreter.Activation) ref.Val {
func (b *dynamicBlock) Eval(activation cel.Activation) ref.Val {
sa := b.slotActivationPool.Get().(*dynamicSlotActivation)
sa.Activation = activation
defer b.clearSlots(sa)
@@ -242,7 +242,7 @@ type slotVal struct {
}
type dynamicSlotActivation struct {
interpreter.Activation
cel.Activation
slotExprs []interpreter.Interpretable
slotCount int
slotVals []*slotVal
@@ -295,13 +295,13 @@ func (b *constantBlock) ID() int64 {
// Eval implements the interpreter.Interpretable interface method, and will proxy @index prefixed variable
// lookups into a set of constant slots determined from the plan step.
func (b *constantBlock) Eval(activation interpreter.Activation) ref.Val {
func (b *constantBlock) Eval(activation cel.Activation) ref.Val {
vars := constantSlotActivation{Activation: activation, slots: b.slots, slotCount: b.slotCount}
return b.expr.Eval(vars)
}
type constantSlotActivation struct {
interpreter.Activation
cel.Activation
slots traits.Lister
slotCount int
}

View File

@@ -0,0 +1,75 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"fmt"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/env"
)
// ExtensionOptionFactory converts an ExtensionConfig value to a CEL environment option.
func ExtensionOptionFactory(configElement any) (cel.EnvOption, bool) {
ext, isExtension := configElement.(*env.Extension)
if !isExtension {
return nil, false
}
fac, found := extFactories[ext.Name]
if !found {
return nil, false
}
// If the version is 'latest', set the version value to the max uint.
ver, err := ext.VersionNumber()
if err != nil {
return func(*cel.Env) (*cel.Env, error) {
return nil, fmt.Errorf("invalid extension version: %s - %s", ext.Name, ext.Version)
}, true
}
return fac(ver), true
}
// extensionFactory accepts a version and produces a CEL environment associated with the versioned extension.
type extensionFactory func(uint32) cel.EnvOption
var extFactories = map[string]extensionFactory{
"bindings": func(version uint32) cel.EnvOption {
return Bindings(BindingsVersion(version))
},
"encoders": func(version uint32) cel.EnvOption {
return Encoders(EncodersVersion(version))
},
"lists": func(version uint32) cel.EnvOption {
return Lists(ListsVersion(version))
},
"math": func(version uint32) cel.EnvOption {
return Math(MathVersion(version))
},
"protos": func(version uint32) cel.EnvOption {
return Protos(ProtosVersion(version))
},
"sets": func(version uint32) cel.EnvOption {
return Sets(SetsVersion(version))
},
"strings": func(version uint32) cel.EnvOption {
return Strings(StringsVersion(version))
},
"two-var-comprehensions": func(version uint32) cel.EnvOption {
return TwoVarComprehensions(TwoVarComprehensionsVersion(version))
},
"regex": func(version uint32) cel.EnvOption {
return Regex(RegexVersion(version))
},
}

View File

@@ -268,14 +268,17 @@ func makeMatcher(locale string) (language.Matcher, error) {
type stringFormatter struct{}
// String implements formatStringInterpolator.String.
func (c *stringFormatter) String(arg ref.Val, locale string) (string, error) {
return FormatString(arg, locale)
}
// Decimal implements formatStringInterpolator.Decimal.
func (c *stringFormatter) Decimal(arg ref.Val, locale string) (string, error) {
return formatDecimal(arg, locale)
}
// Fixed implements formatStringInterpolator.Fixed.
func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
@@ -307,6 +310,7 @@ func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, e
}
}
// Scientific implements formatStringInterpolator.Scientific.
func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
@@ -337,6 +341,7 @@ func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (stri
}
}
// Binary implements formatStringInterpolator.Binary.
func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
@@ -358,6 +363,7 @@ func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) {
}
}
// Hex implements formatStringInterpolator.Hex.
func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
fmtStr := "%x"
@@ -388,6 +394,7 @@ func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, erro
}
}
// Octal implements formatStringInterpolator.Octal.
func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
@@ -407,7 +414,7 @@ type stringFormatValidator struct{}
// Name returns the name of the validator.
func (stringFormatValidator) Name() string {
return "cel.lib.ext.validate.functions.string.format"
return "cel.validator.string_format"
}
// Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip
@@ -504,6 +511,7 @@ type stringFormatChecker struct {
ast *ast.AST
}
// String implements formatStringInterpolator.String.
func (c *stringFormatChecker) String(arg ref.Val, locale string) (string, error) {
formatArg := c.args[c.currArgIndex]
valid, badID := c.verifyString(formatArg)
@@ -513,6 +521,7 @@ func (c *stringFormatChecker) String(arg ref.Val, locale string) (string, error)
return "", nil
}
// Decimal implements formatStringInterpolator.Decimal.
func (c *stringFormatChecker) Decimal(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType)
@@ -522,6 +531,7 @@ func (c *stringFormatChecker) Decimal(arg ref.Val, locale string) (string, error
return "", nil
}
// Fixed implements formatStringInterpolator.Fixed.
func (c *stringFormatChecker) Fixed(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
@@ -534,6 +544,7 @@ func (c *stringFormatChecker) Fixed(precision *int) func(ref.Val, string) (strin
}
}
// Scientific implements formatStringInterpolator.Scientific.
func (c *stringFormatChecker) Scientific(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
@@ -545,6 +556,7 @@ func (c *stringFormatChecker) Scientific(precision *int) func(ref.Val, string) (
}
}
// Binary implements formatStringInterpolator.Binary.
func (c *stringFormatChecker) Binary(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.BoolType)
@@ -554,6 +566,7 @@ func (c *stringFormatChecker) Binary(arg ref.Val, locale string) (string, error)
return "", nil
}
// Hex implements formatStringInterpolator.Hex.
func (c *stringFormatChecker) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
@@ -565,6 +578,7 @@ func (c *stringFormatChecker) Hex(useUpper bool) func(ref.Val, string) (string,
}
}
// Octal implements formatStringInterpolator.Octal.
func (c *stringFormatChecker) Octal(arg ref.Val, locale string) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType)
@@ -574,6 +588,7 @@ func (c *stringFormatChecker) Octal(arg ref.Val, locale string) (string, error)
return "", nil
}
// Arg implements formatListArgs.Arg.
func (c *stringFormatChecker) Arg(index int64) (ref.Val, error) {
c.argsRequested++
c.currArgIndex = index
@@ -582,6 +597,7 @@ func (c *stringFormatChecker) Arg(index int64) (ref.Val, error) {
return types.Int(0), nil
}
// Size implements formatListArgs.Size.
func (c *stringFormatChecker) Size() int64 {
return int64(len(c.args))
}
@@ -686,10 +702,12 @@ func newFormatError(id int64, msg string, args ...any) error {
}
}
// Error implements error.
func (e formatError) Error() string {
return e.msg
}
// Is implements errors.Is.
func (e formatError) Is(target error) bool {
return e.msg == target.Error()
}
@@ -699,6 +717,7 @@ type stringArgList struct {
args traits.Lister
}
// Arg implements formatListArgs.Arg.
func (c *stringArgList) Arg(index int64) (ref.Val, error) {
if index >= c.args.Size().Value().(int64) {
return nil, fmt.Errorf("index %d out of range", index)
@@ -706,6 +725,7 @@ func (c *stringArgList) Arg(index int64) (ref.Val, error) {
return c.args.Get(types.Int(index)), nil
}
// Size implements formatListArgs.Size.
func (c *stringArgList) Size() int64 {
return c.args.Size().Value().(int64)
}
@@ -887,14 +907,17 @@ func newParseFormatError(msg string, wrapped error) error {
return parseFormatError{msg: msg, wrapped: wrapped}
}
// Error implements error.
func (e parseFormatError) Error() string {
return fmt.Sprintf("%s: %s", e.msg, e.wrapped.Error())
}
// Is implements errors.Is.
func (e parseFormatError) Is(target error) bool {
return e.Error() == target.Error()
}
// Is implements errors.Unwrap.
func (e parseFormatError) Unwrap() error {
return e.wrapped
}

788
vendor/github.com/google/cel-go/ext/formatting_v2.go generated vendored Normal file
View File

@@ -0,0 +1,788 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"errors"
"fmt"
"math"
"sort"
"strconv"
"strings"
"time"
"unicode"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
type clauseImplV2 func(ref.Val) (string, error)
type appendingFormatterV2 struct {
buf []byte
}
type formattedMapEntryV2 struct {
key string
val string
}
func (af *appendingFormatterV2) format(arg ref.Val) error {
switch arg.Type() {
case types.BoolType:
argBool, ok := arg.Value().(bool)
if !ok {
return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BoolType)
}
af.buf = strconv.AppendBool(af.buf, argBool)
return nil
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType)
}
af.buf = strconv.AppendInt(af.buf, argInt, 10)
return nil
case types.UintType:
argUint, ok := arg.Value().(uint64)
if !ok {
return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType)
}
af.buf = strconv.AppendUint(af.buf, argUint, 10)
return nil
case types.DoubleType:
argDbl, ok := arg.Value().(float64)
if !ok {
return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType)
}
if math.IsNaN(argDbl) {
af.buf = append(af.buf, "NaN"...)
return nil
}
if math.IsInf(argDbl, -1) {
af.buf = append(af.buf, "-Infinity"...)
return nil
}
if math.IsInf(argDbl, 1) {
af.buf = append(af.buf, "Infinity"...)
return nil
}
af.buf = strconv.AppendFloat(af.buf, argDbl, 'f', -1, 64)
return nil
case types.BytesType:
argBytes, ok := arg.Value().([]byte)
if !ok {
return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BytesType)
}
af.buf = append(af.buf, argBytes...)
return nil
case types.StringType:
argStr, ok := arg.Value().(string)
if !ok {
return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.StringType)
}
af.buf = append(af.buf, argStr...)
return nil
case types.DurationType:
argDur, ok := arg.Value().(time.Duration)
if !ok {
return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DurationType)
}
af.buf = strconv.AppendFloat(af.buf, argDur.Seconds(), 'f', -1, 64)
af.buf = append(af.buf, "s"...)
return nil
case types.TimestampType:
argTime, ok := arg.Value().(time.Time)
if !ok {
return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.TimestampType)
}
af.buf = argTime.UTC().AppendFormat(af.buf, time.RFC3339Nano)
return nil
case types.NullType:
af.buf = append(af.buf, "null"...)
return nil
case types.TypeType:
argType, ok := arg.Value().(string)
if !ok {
return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.TypeType)
}
af.buf = append(af.buf, argType...)
return nil
case types.ListType:
argList, ok := arg.(traits.Lister)
if !ok {
return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.ListType)
}
argIter := argList.Iterator()
af.buf = append(af.buf, "["...)
if argIter.HasNext() == types.True {
if err := af.format(argIter.Next()); err != nil {
return err
}
for argIter.HasNext() == types.True {
af.buf = append(af.buf, ", "...)
if err := af.format(argIter.Next()); err != nil {
return err
}
}
}
af.buf = append(af.buf, "]"...)
return nil
case types.MapType:
argMap, ok := arg.(traits.Mapper)
if !ok {
return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.MapType)
}
argIter := argMap.Iterator()
ents := []formattedMapEntryV2{}
for argIter.HasNext() == types.True {
key := argIter.Next()
val, ok := argMap.Find(key)
if !ok {
return fmt.Errorf("key missing from map: '%s'", key)
}
keyStr, err := formatStringV2(key)
if err != nil {
return err
}
valStr, err := formatStringV2(val)
if err != nil {
return err
}
ents = append(ents, formattedMapEntryV2{keyStr, valStr})
}
sort.SliceStable(ents, func(x, y int) bool {
return ents[x].key < ents[y].key
})
af.buf = append(af.buf, "{"...)
for i, e := range ents {
if i > 0 {
af.buf = append(af.buf, ", "...)
}
af.buf = append(af.buf, e.key...)
af.buf = append(af.buf, ": "...)
af.buf = append(af.buf, e.val...)
}
af.buf = append(af.buf, "}"...)
return nil
default:
return stringFormatErrorV2(runtimeID, arg.Type().TypeName())
}
}
func formatStringV2(arg ref.Val) (string, error) {
var fmter appendingFormatterV2
if err := fmter.format(arg); err != nil {
return "", err
}
return string(fmter.buf), nil
}
type stringFormatterV2 struct{}
// String implements formatStringInterpolatorV2.String.
func (c *stringFormatterV2) String(arg ref.Val) (string, error) {
return formatStringV2(arg)
}
// Decimal implements formatStringInterpolatorV2.Decimal.
func (c *stringFormatterV2) Decimal(arg ref.Val) (string, error) {
switch arg.Type() {
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType)
}
return strconv.FormatInt(argInt, 10), nil
case types.UintType:
argUint, ok := arg.Value().(uint64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType)
}
return strconv.FormatUint(argUint, 10), nil
case types.DoubleType:
argDbl, ok := arg.Value().(float64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType)
}
if math.IsNaN(argDbl) {
return "NaN", nil
}
if math.IsInf(argDbl, -1) {
return "-Infinity", nil
}
if math.IsInf(argDbl, 1) {
return "Infinity", nil
}
return strconv.FormatFloat(argDbl, 'f', -1, 64), nil
default:
return "", decimalFormatErrorV2(runtimeID, arg.Type().TypeName())
}
}
// Fixed implements formatStringInterpolatorV2.Fixed.
func (c *stringFormatterV2) Fixed(precision int) func(ref.Val) (string, error) {
return func(arg ref.Val) (string, error) {
fmtStr := fmt.Sprintf("%%.%df", precision)
switch arg.Type() {
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType)
}
return fmt.Sprintf(fmtStr, argInt), nil
case types.UintType:
argUint, ok := arg.Value().(uint64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType)
}
return fmt.Sprintf(fmtStr, argUint), nil
case types.DoubleType:
argDbl, ok := arg.Value().(float64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType)
}
if math.IsNaN(argDbl) {
return "NaN", nil
}
if math.IsInf(argDbl, -1) {
return "-Infinity", nil
}
if math.IsInf(argDbl, 1) {
return "Infinity", nil
}
return fmt.Sprintf(fmtStr, argDbl), nil
default:
return "", fixedPointFormatErrorV2(runtimeID, arg.Type().TypeName())
}
}
}
// Scientific implements formatStringInterpolatorV2.Scientific.
func (c *stringFormatterV2) Scientific(precision int) func(ref.Val) (string, error) {
return func(arg ref.Val) (string, error) {
fmtStr := fmt.Sprintf("%%1.%de", precision)
switch arg.Type() {
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType)
}
return fmt.Sprintf(fmtStr, argInt), nil
case types.UintType:
argUint, ok := arg.Value().(uint64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType)
}
return fmt.Sprintf(fmtStr, argUint), nil
case types.DoubleType:
argDbl, ok := arg.Value().(float64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType)
}
if math.IsNaN(argDbl) {
return "NaN", nil
}
if math.IsInf(argDbl, -1) {
return "-Infinity", nil
}
if math.IsInf(argDbl, 1) {
return "Infinity", nil
}
return fmt.Sprintf(fmtStr, argDbl), nil
default:
return "", scientificFormatErrorV2(runtimeID, arg.Type().TypeName())
}
}
}
// Binary implements formatStringInterpolatorV2.Binary.
func (c *stringFormatterV2) Binary(arg ref.Val) (string, error) {
switch arg.Type() {
case types.BoolType:
argBool, ok := arg.Value().(bool)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BoolType)
}
if argBool {
return "1", nil
}
return "0", nil
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType)
}
return strconv.FormatInt(argInt, 2), nil
case types.UintType:
argUint, ok := arg.Value().(uint64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType)
}
return strconv.FormatUint(argUint, 2), nil
default:
return "", binaryFormatErrorV2(runtimeID, arg.Type().TypeName())
}
}
// Hex implements formatStringInterpolatorV2.Hex.
func (c *stringFormatterV2) Hex(useUpper bool) func(ref.Val) (string, error) {
return func(arg ref.Val) (string, error) {
var fmtStr string
if useUpper {
fmtStr = "%X"
} else {
fmtStr = "%x"
}
switch arg.Type() {
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType)
}
return fmt.Sprintf(fmtStr, argInt), nil
case types.UintType:
argUint, ok := arg.Value().(uint64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType)
}
return fmt.Sprintf(fmtStr, argUint), nil
case types.StringType:
argStr, ok := arg.Value().(string)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.StringType)
}
return fmt.Sprintf(fmtStr, argStr), nil
case types.BytesType:
argBytes, ok := arg.Value().([]byte)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BytesType)
}
return fmt.Sprintf(fmtStr, argBytes), nil
default:
return "", hexFormatErrorV2(runtimeID, arg.Type().TypeName())
}
}
}
// Octal implements formatStringInterpolatorV2.Octal.
func (c *stringFormatterV2) Octal(arg ref.Val) (string, error) {
switch arg.Type() {
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType)
}
return strconv.FormatInt(argInt, 8), nil
case types.UintType:
argUint, ok := arg.Value().(uint64)
if !ok {
return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType)
}
return strconv.FormatUint(argUint, 8), nil
default:
return "", octalFormatErrorV2(runtimeID, arg.Type().TypeName())
}
}
// stringFormatValidatorV2 implements the cel.ASTValidator interface allowing for static validation
// of string.format calls.
type stringFormatValidatorV2 struct{}
// Name returns the name of the validator.
func (stringFormatValidatorV2) Name() string {
return "cel.validator.string_format"
}
// Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip
// during homogeneous aggregate literal type-checks.
func (stringFormatValidatorV2) Configure(config cel.MutableValidatorConfig) error {
functions := config.GetOrDefault(cel.HomogeneousAggregateLiteralExemptFunctions, []string{}).([]string)
functions = append(functions, "format")
return config.Set(cel.HomogeneousAggregateLiteralExemptFunctions, functions)
}
// Validate parses all literal format strings and type checks the format clause against the argument
// at the corresponding ordinal within the list literal argument to the function, if one is specified.
func (stringFormatValidatorV2) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) {
root := ast.NavigateAST(a)
formatCallExprs := ast.MatchDescendants(root, matchConstantFormatStringWithListLiteralArgs(a))
for _, e := range formatCallExprs {
call := e.AsCall()
formatStr := call.Target().AsLiteral().Value().(string)
args := call.Args()[0].AsList().Elements()
formatCheck := &stringFormatCheckerV2{
args: args,
ast: a,
}
// use a placeholder locale, since locale doesn't affect syntax
_, err := parseFormatStringV2(formatStr, formatCheck, formatCheck)
if err != nil {
iss.ReportErrorAtID(getErrorExprID(e.ID(), err), "%v", err)
continue
}
seenArgs := formatCheck.argsRequested
if len(args) > seenArgs {
iss.ReportErrorAtID(e.ID(),
"too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(args))
}
}
}
// stringFormatCheckerV2 implements the formatStringInterpolater interface
type stringFormatCheckerV2 struct {
args []ast.Expr
argsRequested int
currArgIndex int64
ast *ast.AST
}
// String implements formatStringInterpolatorV2.String.
func (c *stringFormatCheckerV2) String(arg ref.Val) (string, error) {
formatArg := c.args[c.currArgIndex]
valid, badID := c.verifyString(formatArg)
if !valid {
return "", stringFormatErrorV2(badID, c.typeOf(badID).TypeName())
}
return "", nil
}
// Decimal implements formatStringInterpolatorV2.Decimal.
func (c *stringFormatCheckerV2) Decimal(arg ref.Val) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.DoubleType)
if !valid {
return "", decimalFormatErrorV2(id, c.typeOf(id).TypeName())
}
return "", nil
}
// Fixed implements formatStringInterpolatorV2.Fixed.
func (c *stringFormatCheckerV2) Fixed(precision int) func(ref.Val) (string, error) {
return func(arg ref.Val) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.DoubleType)
if !valid {
return "", fixedPointFormatErrorV2(id, c.typeOf(id).TypeName())
}
return "", nil
}
}
// Scientific implements formatStringInterpolatorV2.Scientific.
func (c *stringFormatCheckerV2) Scientific(precision int) func(ref.Val) (string, error) {
return func(arg ref.Val) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.DoubleType)
if !valid {
return "", scientificFormatErrorV2(id, c.typeOf(id).TypeName())
}
return "", nil
}
}
// Binary implements formatStringInterpolatorV2.Binary.
func (c *stringFormatCheckerV2) Binary(arg ref.Val) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.BoolType, types.IntType, types.UintType)
if !valid {
return "", binaryFormatErrorV2(id, c.typeOf(id).TypeName())
}
return "", nil
}
// Hex implements formatStringInterpolatorV2.Hex.
func (c *stringFormatCheckerV2) Hex(useUpper bool) func(ref.Val) (string, error) {
return func(arg ref.Val) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.StringType, types.BytesType)
if !valid {
return "", hexFormatErrorV2(id, c.typeOf(id).TypeName())
}
return "", nil
}
}
// Octal implements formatStringInterpolatorV2.Octal.
func (c *stringFormatCheckerV2) Octal(arg ref.Val) (string, error) {
id := c.args[c.currArgIndex].ID()
valid := c.verifyTypeOneOf(id, types.IntType, types.UintType)
if !valid {
return "", octalFormatErrorV2(id, c.typeOf(id).TypeName())
}
return "", nil
}
// Arg implements formatListArgs.Arg.
func (c *stringFormatCheckerV2) Arg(index int64) (ref.Val, error) {
c.argsRequested++
c.currArgIndex = index
// return a dummy value - this is immediately passed to back to us
// through one of the FormatCallback functions, so anything will do
return types.Int(0), nil
}
// Size implements formatListArgs.Size.
func (c *stringFormatCheckerV2) Size() int64 {
return int64(len(c.args))
}
func (c *stringFormatCheckerV2) typeOf(id int64) *cel.Type {
return c.ast.GetType(id)
}
func (c *stringFormatCheckerV2) verifyTypeOneOf(id int64, validTypes ...*cel.Type) bool {
t := c.typeOf(id)
if t == cel.DynType {
return true
}
for _, vt := range validTypes {
// Only check runtime type compatibility without delving deeper into parameterized types
if t.Kind() == vt.Kind() {
return true
}
}
return false
}
func (c *stringFormatCheckerV2) verifyString(sub ast.Expr) (bool, int64) {
paramA := cel.TypeParamType("A")
paramB := cel.TypeParamType("B")
subVerified := c.verifyTypeOneOf(sub.ID(),
cel.ListType(paramA), cel.MapType(paramA, paramB),
cel.IntType, cel.UintType, cel.DoubleType, cel.BoolType, cel.StringType,
cel.TimestampType, cel.BytesType, cel.DurationType, cel.TypeType, cel.NullType)
if !subVerified {
return false, sub.ID()
}
switch sub.Kind() {
case ast.ListKind:
for _, e := range sub.AsList().Elements() {
// recursively verify if we're dealing with a list/map
verified, id := c.verifyString(e)
if !verified {
return false, id
}
}
return true, sub.ID()
case ast.MapKind:
for _, e := range sub.AsMap().Entries() {
// recursively verify if we're dealing with a list/map
entry := e.AsMapEntry()
verified, id := c.verifyString(entry.Key())
if !verified {
return false, id
}
verified, id = c.verifyString(entry.Value())
if !verified {
return false, id
}
}
return true, sub.ID()
default:
return true, sub.ID()
}
}
// helper routines for reporting common errors during string formatting static validation and
// runtime execution.
func binaryFormatErrorV2(id int64, badType string) error {
return newFormatError(id, "only ints, uints, and bools can be formatted as binary, was given %s", badType)
}
func decimalFormatErrorV2(id int64, badType string) error {
return newFormatError(id, "decimal clause can only be used on ints, uints, and doubles, was given %s", badType)
}
func fixedPointFormatErrorV2(id int64, badType string) error {
return newFormatError(id, "fixed-point clause can only be used on ints, uints, and doubles, was given %s", badType)
}
func hexFormatErrorV2(id int64, badType string) error {
return newFormatError(id, "only ints, uints, bytes, and strings can be formatted as hex, was given %s", badType)
}
func octalFormatErrorV2(id int64, badType string) error {
return newFormatError(id, "octal clause can only be used on ints and uints, was given %s", badType)
}
func scientificFormatErrorV2(id int64, badType string) error {
return newFormatError(id, "scientific clause can only be used on ints, uints, and doubles, was given %s", badType)
}
func stringFormatErrorV2(id int64, badType string) error {
return newFormatError(id, "string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", badType)
}
// formatStringInterpolatorV2 is an interface that allows user-defined behavior
// for formatting clause implementations, as well as argument retrieval.
// Each function is expected to support the appropriate types as laid out in
// the string.format documentation, and to return an error if given an inappropriate type.
type formatStringInterpolatorV2 interface {
// String takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a string, or an error if one occurred.
String(ref.Val) (string, error)
// Decimal takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a decimal integer, or an error if one occurred.
Decimal(ref.Val) (string, error)
// Fixed takes an int pointer representing precision (or nil if none was given) and
// returns a function operating in a similar manner to String and Decimal, taking a
// ref.Val and locale and returning the appropriate string. A closure is returned
// so precision can be set without needing an additional function call/configuration.
Fixed(int) func(ref.Val) (string, error)
// Scientific functions identically to Fixed, except the string returned from the closure
// is expected to be in scientific notation.
Scientific(int) func(ref.Val) (string, error)
// Binary takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a binary integer, or an error if one occurred.
Binary(ref.Val) (string, error)
// Hex takes a boolean that, if true, indicates the hex string output by the returned
// closure should use uppercase letters for A-F.
Hex(bool) func(ref.Val) (string, error)
// Octal takes a ref.Val and a string representing the current locale identifier and
// returns the Val formatted in octal, or an error if one occurred.
Octal(ref.Val) (string, error)
}
// parseFormatString formats a string according to the string.format syntax, taking the clause implementations
// from the provided FormatCallback and the args from the given FormatList.
func parseFormatStringV2(formatStr string, callback formatStringInterpolatorV2, list formatListArgs) (string, error) {
i := 0
argIndex := 0
var builtStr strings.Builder
for i < len(formatStr) {
if formatStr[i] == '%' {
if i+1 < len(formatStr) && formatStr[i+1] == '%' {
err := builtStr.WriteByte('%')
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += 2
continue
} else {
argAny, err := list.Arg(int64(argIndex))
if err != nil {
return "", err
}
if i+1 >= len(formatStr) {
return "", errors.New("unexpected end of string")
}
if int64(argIndex) >= list.Size() {
return "", fmt.Errorf("index %d out of range", argIndex)
}
numRead, val, refErr := parseAndFormatClauseV2(formatStr[i:], argAny, callback, list)
if refErr != nil {
return "", refErr
}
_, err = builtStr.WriteString(val)
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += numRead
argIndex++
}
} else {
err := builtStr.WriteByte(formatStr[i])
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i++
}
}
return builtStr.String(), nil
}
// parseAndFormatClause parses the format clause at the start of the given string with val, and returns
// how many characters were consumed and the substituted string form of val, or an error if one occurred.
func parseAndFormatClauseV2(formatStr string, val ref.Val, callback formatStringInterpolatorV2, list formatListArgs) (int, string, error) {
i := 1
read, formatter, err := parseFormattingClauseV2(formatStr[i:], callback)
i += read
if err != nil {
return -1, "", newParseFormatError("could not parse formatting clause", err)
}
valStr, err := formatter(val)
if err != nil {
return -1, "", newParseFormatError("error during formatting", err)
}
return i, valStr, nil
}
func parseFormattingClauseV2(formatStr string, callback formatStringInterpolatorV2) (int, clauseImplV2, error) {
i := 0
read, precision, err := parsePrecisionV2(formatStr[i:])
i += read
if err != nil {
return -1, nil, fmt.Errorf("error while parsing precision: %w", err)
}
r := rune(formatStr[i])
i++
switch r {
case 's':
return i, callback.String, nil
case 'd':
return i, callback.Decimal, nil
case 'f':
return i, callback.Fixed(precision), nil
case 'e':
return i, callback.Scientific(precision), nil
case 'b':
return i, callback.Binary, nil
case 'x', 'X':
return i, callback.Hex(unicode.IsUpper(r)), nil
case 'o':
return i, callback.Octal, nil
default:
return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r)
}
}
func parsePrecisionV2(formatStr string) (int, int, error) {
i := 0
if formatStr[i] != '.' {
return i, defaultPrecision, nil
}
i++
var buffer strings.Builder
for {
if i >= len(formatStr) {
return -1, -1, errors.New("could not find end of precision specifier")
}
if !isASCIIDigit(rune(formatStr[i])) {
break
}
buffer.WriteByte(formatStr[i])
i++
}
precision, err := strconv.Atoi(buffer.String())
if err != nil {
return -1, -1, fmt.Errorf("error while converting precision to integer: %w", err)
}
if precision < 0 {
return -1, -1, fmt.Errorf("negative precision: %d", precision)
}
return i, precision, nil
}

View File

@@ -20,11 +20,14 @@ import (
"sort"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/parser"
)
@@ -44,7 +47,7 @@ var comparableTypes = []*cel.Type{
//
// # Distinct
//
// Introduced in version: 2
// Introduced in version: 2 (cost support in version 3)
//
// Returns the distinct elements of a list.
//
@@ -58,7 +61,7 @@ var comparableTypes = []*cel.Type{
//
// # Range
//
// Introduced in version: 2
// Introduced in version: 2 (cost support in version 3)
//
// Returns a list of integers from 0 to n-1.
//
@@ -70,7 +73,7 @@ var comparableTypes = []*cel.Type{
//
// # Reverse
//
// Introduced in version: 2
// Introduced in version: 2 (cost support in version 3)
//
// Returns the elements of a list in reverse order.
//
@@ -82,6 +85,8 @@ var comparableTypes = []*cel.Type{
//
// # Slice
//
// Introduced in version: 0 (cost support in version 3)
//
// Returns a new sub-list using the indexes provided.
//
// <list>.slice(<int>, <int>) -> <list>
@@ -93,12 +98,14 @@ var comparableTypes = []*cel.Type{
//
// # Flatten
//
// Introduced in version: 1 (cost support in version 3)
//
// Flattens a list recursively.
// If an optional depth is provided, the list is flattened to a the specificied level.
// If an optional depth is provided, the list is flattened to a the specified level.
// A negative depth value will result in an error.
//
// <list>.flatten(<list>) -> <list>
// <list>.flatten(<list>, <int>) -> <list>
// <list>.flatten() -> <list>
// <list>.flatten(<int>) -> <list>
//
// Examples:
//
@@ -110,7 +117,7 @@ var comparableTypes = []*cel.Type{
//
// # Sort
//
// Introduced in version: 2
// Introduced in version: 2 (cost support in version 3)
//
// Sorts a list with comparable elements. If the element type is not comparable
// or the element types are not the same, the function will produce an error.
@@ -127,6 +134,8 @@ var comparableTypes = []*cel.Type{
//
// # SortBy
//
// Introduced in version: 2 (cost support in version 3)
//
// Sorts a list by a key value, i.e., the order is determined by the result of
// an expression applied to each element of the list.
// The output of the key expression must be a comparable type, otherwise the
@@ -134,7 +143,7 @@ var comparableTypes = []*cel.Type{
//
// <list(T)>.sortBy(<bindingName>, <keyExpr>) -> <list(T)>
// keyExpr returns a value in {int, uint, double, bool, duration, timestamp, string, bytes}
//
// Examples:
//
// [
@@ -143,7 +152,6 @@ var comparableTypes = []*cel.Type{
// Player { name: "baz", score: 1000 },
// ].sortBy(e, e.score).map(e, e.name)
// == ["bar", "foo", "baz"]
func Lists(options ...ListsOption) cel.EnvOption {
l := &listsLib{version: math.MaxUint32}
for _, o := range options {
@@ -304,9 +312,8 @@ func (lib listsLib) CompileOptions() []cel.EnvOption {
opts = append(opts, cel.Function("lists.range",
cel.Overload("lists_range",
[]*cel.Type{cel.IntType}, cel.ListType(cel.IntType),
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
n := args[0].(types.Int)
result, err := genRange(n)
cel.UnaryBinding(func(n ref.Val) ref.Val {
result, err := genRange(n.(types.Int))
if err != nil {
return types.WrapErr(err)
}
@@ -317,9 +324,8 @@ func (lib listsLib) CompileOptions() []cel.EnvOption {
opts = append(opts, cel.Function("reverse",
cel.MemberOverload("list_reverse",
[]*cel.Type{listType}, listType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
list := args[0].(traits.Lister)
result, err := reverseList(list)
cel.UnaryBinding(func(list ref.Val) ref.Val {
result, err := reverseList(list.(traits.Lister))
if err != nil {
return types.WrapErr(err)
}
@@ -340,13 +346,61 @@ func (lib listsLib) CompileOptions() []cel.EnvOption {
),
))
}
if lib.version >= 3 {
estimators := []checker.CostOption{
checker.OverloadCostEstimate("list_slice", estimateListSlice),
checker.OverloadCostEstimate("list_flatten", estimateListFlatten),
checker.OverloadCostEstimate("list_flatten_int", estimateListFlatten),
checker.OverloadCostEstimate("lists_range", estimateListsRange),
checker.OverloadCostEstimate("list_reverse", estimateListReverse),
checker.OverloadCostEstimate("list_distinct", estimateListDistinct),
}
for _, t := range comparableTypes {
estimators = append(estimators,
checker.OverloadCostEstimate(
fmt.Sprintf("list_%s_sort", t.TypeName()),
estimateListSort(t),
),
checker.OverloadCostEstimate(
fmt.Sprintf("list_%s_sortByAssociatedKeys", t.TypeName()),
estimateListSortBy(t),
),
)
}
opts = append(opts, cel.CostEstimatorOptions(estimators...))
}
return opts
}
// ProgramOptions implements the Library interface method.
func (listsLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
func (lib *listsLib) ProgramOptions() []cel.ProgramOption {
var opts []cel.ProgramOption
if lib.version >= 3 {
// TODO: Add cost trackers for list operations
trackers := []interpreter.CostTrackerOption{
interpreter.OverloadCostTracker("list_slice", trackListOutputSize),
interpreter.OverloadCostTracker("list_flatten", trackListFlatten),
interpreter.OverloadCostTracker("list_flatten_int", trackListFlatten),
interpreter.OverloadCostTracker("lists_range", trackListOutputSize),
interpreter.OverloadCostTracker("list_reverse", trackListOutputSize),
interpreter.OverloadCostTracker("list_distinct", trackListDistinct),
}
for _, t := range comparableTypes {
trackers = append(trackers,
interpreter.OverloadCostTracker(
fmt.Sprintf("list_%s_sort", t.TypeName()),
trackListSort,
),
interpreter.OverloadCostTracker(
fmt.Sprintf("list_%s_sortByAssociatedKeys", t.TypeName()),
trackListSortBy,
),
)
}
opts = append(opts, cel.CostTrackerOptions(trackers...))
}
return opts
}
func genRange(n types.Int) (ref.Val, error) {
@@ -451,20 +505,24 @@ func sortListByAssociatedKeys(list, keys traits.Lister) (ref.Val, error) {
sortedIndices := make([]ref.Val, 0, listLength)
for i := types.IntZero; i < listLength; i++ {
if keys.Get(i).Type() != elem.Type() {
return nil, fmt.Errorf("list elements must have the same type")
}
sortedIndices = append(sortedIndices, i)
}
var err error
sort.Slice(sortedIndices, func(i, j int) bool {
iKey := keys.Get(sortedIndices[i])
jKey := keys.Get(sortedIndices[j])
if iKey.Type() != elem.Type() || jKey.Type() != elem.Type() {
err = fmt.Errorf("list elements must have the same type")
return false
}
return iKey.(traits.Comparer).Compare(jKey) == types.IntNegOne
})
if err != nil {
return nil, err
}
sorted := make([]ref.Val, 0, listLength)
for _, sortedIdx := range sortedIndices {
sorted = append(sorted, list.Get(sortedIdx))
}
@@ -551,3 +609,171 @@ func templatedOverloads(types []*cel.Type, template func(t *cel.Type) cel.Functi
}
return overloads
}
// estimateListSlice computes an O(n) slice operation with a cost factor of 1.
func estimateListSlice(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if target == nil || len(args) != 2 {
return nil
}
sz := estimateSize(estimator, *target)
start := nodeAsIntValue(args[0], 0)
end := nodeAsIntValue(args[1], sz.Max)
return estimateAllocatingListCall(1, checker.FixedSizeEstimate(end-start))
}
// estimateListsRange computes an O(n) range operation with a cost factor of 1.
func estimateListsRange(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if target != nil || len(args) != 1 {
return nil
}
return estimateAllocatingListCall(1, checker.FixedSizeEstimate(nodeAsIntValue(args[0], math.MaxUint)))
}
// estimateListReverse computes an O(n) reverse operation with a cost factor of 1.
func estimateListReverse(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if target == nil || len(args) != 0 {
return nil
}
return estimateAllocatingListCall(1, estimateSize(estimator, *target))
}
// estimateListFlatten computes an O(n) flatten operation with a cost factor proportional to the flatten depth.
func estimateListFlatten(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if target == nil || len(args) > 1 {
return nil
}
depth := uint64(1)
if len(args) == 1 {
depth = nodeAsIntValue(args[0], math.MaxUint)
}
return estimateAllocatingListCall(float64(depth), estimateSize(estimator, *target))
}
// Compute an O(n^2) with a cost factor of 2, equivalent to sets.contains with a result list
// which can vary in size from 1 element to the original list size.
func estimateListDistinct(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if target == nil || len(args) != 0 {
return nil
}
sz := estimateSize(estimator, *target)
costFactor := 2.0
return estimateAllocatingListCall(costFactor, sz.Multiply(sz))
}
// estimateListSort computes an O(n^2) sort operation with a cost factor of 2 for the equality
// operations against the elements in the list against themselves which occur during the sort computation.
func estimateListSort(t *types.Type) checker.FunctionEstimator {
return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if target == nil || len(args) != 0 {
return nil
}
return estimateListSortCost(estimator, *target, t)
}
}
// estimateListSortBy computes an O(n^2) sort operation with a cost factor of 2 for the equality
// operations against the sort index list which occur during the sort computation.
func estimateListSortBy(u *types.Type) checker.FunctionEstimator {
return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if target == nil || len(args) != 1 {
return nil
}
// Estimate the size of the list used as the sort index
return estimateListSortCost(estimator, args[0], u)
}
}
// estimateListSortCost estimates an O(n^2) sort operation with a cost factor of 2 for the equality
// operations which occur during the sort computation.
func estimateListSortCost(estimator checker.CostEstimator, node checker.AstNode, elemType *types.Type) *checker.CallEstimate {
sz := estimateSize(estimator, node)
costFactor := 2.0
switch elemType {
case types.StringType, types.BytesType:
costFactor += common.StringTraversalCostFactor
}
return estimateAllocatingListCall(costFactor, sz.Multiply(sz))
}
// estimateAllocatingListCall computes cost as a function of the size of the result list with a
// baseline cost for the call dispatch and the associated list allocation.
func estimateAllocatingListCall(costFactor float64, listSize checker.SizeEstimate) *checker.CallEstimate {
return estimateListCall(costFactor, listSize, true)
}
// estimateListCall computes cost as a function of the size of the target list and whether the
// call allocates memory.
func estimateListCall(costFactor float64, listSize checker.SizeEstimate, allocates bool) *checker.CallEstimate {
cost := listSize.MultiplyByCostFactor(costFactor).Add(callCostEstimate)
if allocates {
cost = cost.Add(checker.FixedCostEstimate(common.ListCreateBaseCost))
}
return &checker.CallEstimate{CostEstimate: cost, ResultSize: &listSize}
}
// trackListOutputSize computes cost as a function of the size of the result list.
func trackListOutputSize(_ []ref.Val, result ref.Val) *uint64 {
return trackAllocatingListCall(1, actualSize(result))
}
// trackListFlatten computes cost as a function of the size of the result list and the depth of
// the flatten operation.
func trackListFlatten(args []ref.Val, _ ref.Val) *uint64 {
depth := 1.0
if len(args) == 2 {
depth = float64(args[1].(types.Int))
}
inputSize := actualSize(args[0])
return trackAllocatingListCall(depth, inputSize)
}
// trackListDistinct computes costs as a worst-case O(n^2) operation over the input list.
func trackListDistinct(args []ref.Val, _ ref.Val) *uint64 {
return trackListSelfCompare(args[0].(traits.Lister))
}
// trackListSort computes costs as a worst-case O(n^2) operation over the input list.
func trackListSort(args []ref.Val, result ref.Val) *uint64 {
return trackListSelfCompare(args[0].(traits.Lister))
}
// trackListSortBy computes costs as a worst-case O(n^2) operation over the sort index list.
func trackListSortBy(args []ref.Val, result ref.Val) *uint64 {
return trackListSelfCompare(args[1].(traits.Lister))
}
// trackListSelfCompare computes costs as a worst-case O(n^2) operation over the input list.
func trackListSelfCompare(l traits.Lister) *uint64 {
sz := actualSize(l)
costFactor := 2.0
if sz == 0 {
return trackAllocatingListCall(costFactor, 0)
}
elem := l.Get(types.IntZero)
if elem.Type() == types.StringType || elem.Type() == types.BytesType {
costFactor += common.StringTraversalCostFactor
}
return trackAllocatingListCall(costFactor, sz*sz)
}
// trackAllocatingListCall computes costs as a function of the size of the result list with a baseline cost
// for the call dispatch and the associated list allocation.
func trackAllocatingListCall(costFactor float64, size uint64) *uint64 {
cost := uint64(float64(size)*costFactor) + callCost + common.ListCreateBaseCost
return &cost
}
func nodeAsIntValue(node checker.AstNode, defaultVal uint64) uint64 {
if node.Expr().Kind() != ast.LiteralKind {
return defaultVal
}
lit := node.Expr().AsLiteral()
if lit.Type() != types.IntType {
return defaultVal
}
val := lit.(types.Int)
if val < types.IntZero {
return 0
}
return uint64(lit.(types.Int))
}

View File

@@ -325,6 +325,23 @@ import (
//
// math.isFinite(0.0/0.0) // returns false
// math.isFinite(1.2) // returns true
//
// # Math.Sqrt
//
// Introduced at version: 2
//
// Returns the square root of the given input as double
// Throws error for negative or non-numeric inputs
//
// math.sqrt(<double>) -> <double>
// math.sqrt(<int>) -> <double>
// math.sqrt(<uint>) -> <double>
//
// Examples:
//
// math.sqrt(81) // returns 9.0
// math.sqrt(985.25) // returns 31.388692231439016
// math.sqrt(-15) // returns NaN
func Math(options ...MathOption) cel.EnvOption {
m := &mathLib{version: math.MaxUint32}
for _, o := range options {
@@ -357,6 +374,9 @@ const (
absFunc = "math.abs"
signFunc = "math.sign"
// SquareRoot function
sqrtFunc = "math.sqrt"
// Bitwise functions
bitAndFunc = "math.bitAnd"
bitOrFunc = "math.bitOr"
@@ -548,6 +568,18 @@ func (lib *mathLib) CompileOptions() []cel.EnvOption {
),
)
}
if lib.version >= 2 {
opts = append(opts,
cel.Function(sqrtFunc,
cel.Overload("math_sqrt_double", []*cel.Type{cel.DoubleType}, cel.DoubleType,
cel.UnaryBinding(sqrt)),
cel.Overload("math_sqrt_int", []*cel.Type{cel.IntType}, cel.DoubleType,
cel.UnaryBinding(sqrt)),
cel.Overload("math_sqrt_uint", []*cel.Type{cel.UintType}, cel.DoubleType,
cel.UnaryBinding(sqrt)),
),
)
}
return opts
}
@@ -691,6 +723,21 @@ func sign(val ref.Val) ref.Val {
}
}
func sqrt(val ref.Val) ref.Val {
switch v := val.(type) {
case types.Double:
return types.Double(math.Sqrt(float64(v)))
case types.Int:
return types.Double(math.Sqrt(float64(v)))
case types.Uint:
return types.Double(math.Sqrt(float64(v)))
default:
return types.NewErr("no such overload: sqrt")
}
}
func bitAndPairInt(first, second ref.Val) ref.Val {
l := first.(types.Int)
r := second.(types.Int)

View File

@@ -81,7 +81,7 @@ var (
// the time that it is invoked.
//
// There is also the possibility to rename the fields of native structs by setting the `cel` tag
// for fields you want to override. In order to enable this feature, pass in the `EnableStructTag`
// for fields you want to override. In order to enable this feature, pass in the `ParseStructTags(true)`
// option. Here is an example to see it in action:
//
// ```go

332
vendor/github.com/google/cel-go/ext/regex.go generated vendored Normal file
View File

@@ -0,0 +1,332 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"errors"
"fmt"
"math"
"regexp"
"strconv"
"strings"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
const (
regexReplace = "regex.replace"
regexExtract = "regex.extract"
regexExtractAll = "regex.extractAll"
)
// Regex returns a cel.EnvOption to configure extended functions for regular
// expression operations.
//
// Note: all functions use the 'regex' namespace. If you are
// currently using a variable named 'regex', the functions will likely work as
// intended, however there is some chance for collision.
//
// This library depends on the CEL optional type. Please ensure that the
// cel.OptionalTypes() is enabled when using regex extensions.
//
// # Replace
//
// The `regex.replace` function replaces all non-overlapping substring of a regex
// pattern in the target string with a replacement string. Optionally, you can
// limit the number of replacements by providing a count argument. When the count
// is a negative number, the function acts as replace all. Only numeric (\N)
// capture group references are supported in the replacement string, with
// validation for correctness. Backslashed-escaped digits (\1 to \9) within the
// replacement argument can be used to insert text matching the corresponding
// parenthesized group in the regexp pattern. An error will be thrown for invalid
// regex or replace string.
//
// regex.replace(target: string, pattern: string, replacement: string) -> string
// regex.replace(target: string, pattern: string, replacement: string, count: int) -> string
//
// Examples:
//
// regex.replace('hello world hello', 'hello', 'hi') == 'hi world hi'
// regex.replace('banana', 'a', 'x', 0) == 'banana'
// regex.replace('banana', 'a', 'x', 1) == 'bxnana'
// regex.replace('banana', 'a', 'x', 2) == 'bxnxna'
// regex.replace('banana', 'a', 'x', -12) == 'bxnxnx'
// regex.replace('foo bar', '(fo)o (ba)r', r'\2 \1') == 'ba fo'
// regex.replace('test', '(.)', r'\2') \\ Runtime Error invalid replace string
// regex.replace('foo bar', '(', '$2 $1') \\ Runtime Error invalid regex string
// regex.replace('id=123', r'id=(?P<value>\d+)', r'value: \values') \\ Runtime Error invalid replace string
//
// # Extract
//
// The `regex.extract` function returns the first match of a regex pattern in a
// string. If no match is found, it returns an optional none value. An error will
// be thrown for invalid regex or for multiple capture groups.
//
// regex.extract(target: string, pattern: string) -> optional<string>
//
// Examples:
//
// regex.extract('hello world', 'hello(.*)') == optional.of(' world')
// regex.extract('item-A, item-B', 'item-(\\w+)') == optional.of('A')
// regex.extract('HELLO', 'hello') == optional.empty()
// regex.extract('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error multiple capture group
//
// # Extract All
//
// The `regex.extractAll` function returns a list of all matches of a regex
// pattern in a target string. If no matches are found, it returns an empty list. An error will
// be thrown for invalid regex or for multiple capture groups.
//
// regex.extractAll(target: string, pattern: string) -> list<string>
//
// Examples:
//
// regex.extractAll('id:123, id:456', 'id:\\d+') == ['id:123', 'id:456']
// regex.extractAll('id:123, id:456', 'assa') == []
// regex.extractAll('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error multiple capture group
func Regex(options ...RegexOptions) cel.EnvOption {
s := &regexLib{
version: math.MaxUint32,
}
for _, o := range options {
s = o(s)
}
return cel.Lib(s)
}
// RegexOptions declares a functional operator for configuring regex extension.
type RegexOptions func(*regexLib) *regexLib
// RegexVersion configures the version of the Regex library definitions to use. See [Regex] for supported values.
func RegexVersion(version uint32) RegexOptions {
return func(lib *regexLib) *regexLib {
lib.version = version
return lib
}
}
type regexLib struct {
version uint32
}
// LibraryName implements that SingletonLibrary interface method.
func (r *regexLib) LibraryName() string {
return "cel.lib.ext.regex"
}
// CompileOptions implements the cel.Library interface method.
func (r *regexLib) CompileOptions() []cel.EnvOption {
optionalTypesEnabled := func(env *cel.Env) (*cel.Env, error) {
if !env.HasLibrary("cel.lib.optional") {
return nil, errors.New("regex library requires the optional library")
}
return env, nil
}
opts := []cel.EnvOption{
cel.Function(regexExtract,
cel.Overload("regex_extract_string_string", []*cel.Type{cel.StringType, cel.StringType}, cel.OptionalType(cel.StringType),
cel.BinaryBinding(extract))),
cel.Function(regexExtractAll,
cel.Overload("regex_extractAll_string_string", []*cel.Type{cel.StringType, cel.StringType}, cel.ListType(cel.StringType),
cel.BinaryBinding(extractAll))),
cel.Function(regexReplace,
cel.Overload("regex_replace_string_string_string", []*cel.Type{cel.StringType, cel.StringType, cel.StringType}, cel.StringType,
cel.FunctionBinding(regReplace)),
cel.Overload("regex_replace_string_string_string_int", []*cel.Type{cel.StringType, cel.StringType, cel.StringType, cel.IntType}, cel.StringType,
cel.FunctionBinding((regReplaceN))),
),
cel.EnvOption(optionalTypesEnabled),
}
return opts
}
// ProgramOptions implements the cel.Library interface method
func (r *regexLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func compileRegex(regexStr string) (*regexp.Regexp, error) {
re, err := regexp.Compile(regexStr)
if err != nil {
return nil, fmt.Errorf("given regex is invalid: %w", err)
}
return re, nil
}
func regReplace(args ...ref.Val) ref.Val {
target := args[0].(types.String)
regexStr := args[1].(types.String)
replaceStr := args[2].(types.String)
return regReplaceN(target, regexStr, replaceStr, types.Int(-1))
}
func regReplaceN(args ...ref.Val) ref.Val {
target := string(args[0].(types.String))
regexStr := string(args[1].(types.String))
replaceStr := string(args[2].(types.String))
replaceCount := int64(args[3].(types.Int))
if replaceCount == 0 {
return types.String(target)
}
if replaceCount > math.MaxInt32 {
return types.NewErr("integer overflow")
}
// If replaceCount is negative, just do a replaceAll.
if replaceCount < 0 {
replaceCount = -1
}
re, err := regexp.Compile(regexStr)
if err != nil {
return types.WrapErr(err)
}
var resultBuilder strings.Builder
var lastIndex int
counter := int64(0)
matches := re.FindAllStringSubmatchIndex(target, -1)
for _, match := range matches {
if replaceCount != -1 && counter >= replaceCount {
break
}
processedReplacement, err := replaceStrValidator(target, re, match, replaceStr)
if err != nil {
return types.WrapErr(err)
}
resultBuilder.WriteString(target[lastIndex:match[0]])
resultBuilder.WriteString(processedReplacement)
lastIndex = match[1]
counter++
}
resultBuilder.WriteString(target[lastIndex:])
return types.String(resultBuilder.String())
}
func replaceStrValidator(target string, re *regexp.Regexp, match []int, replacement string) (string, error) {
groupCount := re.NumSubexp()
var sb strings.Builder
runes := []rune(replacement)
for i := 0; i < len(runes); i++ {
c := runes[i]
if c != '\\' {
sb.WriteRune(c)
continue
}
if i+1 >= len(runes) {
return "", fmt.Errorf("invalid replacement string: '%s' \\ not allowed at end", replacement)
}
i++
nextChar := runes[i]
if nextChar == '\\' {
sb.WriteRune('\\')
continue
}
groupNum, err := strconv.Atoi(string(nextChar))
if err != nil {
return "", fmt.Errorf("invalid replacement string: '%s' \\ must be followed by a digit or \\", replacement)
}
if groupNum > groupCount {
return "", fmt.Errorf("replacement string references group %d but regex has only %d group(s)", groupNum, groupCount)
}
if match[2*groupNum] != -1 {
sb.WriteString(target[match[2*groupNum]:match[2*groupNum+1]])
}
}
return sb.String(), nil
}
func extract(target, regexStr ref.Val) ref.Val {
t := string(target.(types.String))
r := string(regexStr.(types.String))
re, err := compileRegex(r)
if err != nil {
return types.WrapErr(err)
}
if len(re.SubexpNames())-1 > 1 {
return types.WrapErr(fmt.Errorf("regular expression has more than one capturing group: %q", r))
}
matches := re.FindStringSubmatch(t)
if len(matches) == 0 {
return types.OptionalNone
}
// If there is a capturing group, return the first match; otherwise, return the whole match.
if len(matches) > 1 {
capturedGroup := matches[1]
// If optional group is empty, return OptionalNone.
if capturedGroup == "" {
return types.OptionalNone
}
return types.OptionalOf(types.String(capturedGroup))
}
return types.OptionalOf(types.String(matches[0]))
}
func extractAll(target, regexStr ref.Val) ref.Val {
t := string(target.(types.String))
r := string(regexStr.(types.String))
re, err := compileRegex(r)
if err != nil {
return types.WrapErr(err)
}
groupCount := len(re.SubexpNames()) - 1
if groupCount > 1 {
return types.WrapErr(fmt.Errorf("regular expression has more than one capturing group: %q", r))
}
matches := re.FindAllStringSubmatch(t, -1)
result := make([]string, 0, len(matches))
if len(matches) == 0 {
return types.NewStringList(types.DefaultTypeAdapter, result)
}
if groupCount != 1 {
for _, match := range matches {
result = append(result, match[0])
}
return types.NewStringList(types.DefaultTypeAdapter, result)
}
for _, match := range matches {
if match[1] != "" {
result = append(result, match[1])
}
}
return types.NewStringList(types.DefaultTypeAdapter, result)
}

View File

@@ -236,13 +236,13 @@ func setsEquivalent(listA, listB ref.Val) ref.Val {
func estimateSetsCost(costFactor float64) checker.FunctionEstimator {
return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if len(args) == 2 {
arg0Size := estimateSize(estimator, args[0])
arg1Size := estimateSize(estimator, args[1])
costEstimate := arg0Size.Multiply(arg1Size).MultiplyByCostFactor(costFactor).Add(callCostEstimate)
return &checker.CallEstimate{CostEstimate: costEstimate}
if len(args) != 2 {
return nil
}
return nil
arg0Size := estimateSize(estimator, args[0])
arg1Size := estimateSize(estimator, args[1])
costEstimate := arg0Size.Multiply(arg1Size).MultiplyByCostFactor(costFactor).Add(callCostEstimate)
return &checker.CallEstimate{CostEstimate: costEstimate}
}
}
@@ -273,6 +273,6 @@ func actualSize(value ref.Val) uint64 {
}
var (
callCostEstimate = checker.CostEstimate{Min: 1, Max: 1}
callCostEstimate = checker.FixedCostEstimate(1)
callCost = uint64(1)
)

View File

@@ -286,10 +286,15 @@ const (
//
// 'gums'.reverse() // returns 'smug'
// 'John Smith'.reverse() // returns 'htimS nhoJ'
//
// Introduced at version: 4
//
// Formatting updated to adhere to https://github.com/google/cel-spec/blob/master/doc/extensions/strings.md.
//
// <string>.format(<list>) -> <string>
func Strings(options ...StringsOption) cel.EnvOption {
s := &stringLib{
version: math.MaxUint32,
validateFormat: true,
version: math.MaxUint32,
}
for _, o := range options {
s = o(s)
@@ -298,9 +303,8 @@ func Strings(options ...StringsOption) cel.EnvOption {
}
type stringLib struct {
locale string
version uint32
validateFormat bool
locale string
version uint32
}
// LibraryName implements the SingletonLibrary interface method.
@@ -314,6 +318,8 @@ type StringsOption func(*stringLib) *stringLib
// StringsLocale configures the library with the given locale. The locale tag will
// be checked for validity at the time that EnvOptions are configured. If this option
// is not passed, string.format will behave as if en_US was passed as the locale.
//
// If StringsVersion is greater than or equal to 4, this option is ignored.
func StringsLocale(locale string) StringsOption {
return func(sl *stringLib) *stringLib {
sl.locale = locale
@@ -340,10 +346,9 @@ func StringsVersion(version uint32) StringsOption {
// StringsValidateFormatCalls validates type-checked ASTs to ensure that string.format() calls have
// valid formatting clauses and valid argument types for each clause.
//
// Enabled by default.
// Deprecated
func StringsValidateFormatCalls(value bool) StringsOption {
return func(s *stringLib) *stringLib {
s.validateFormat = value
return s
}
}
@@ -351,7 +356,7 @@ func StringsValidateFormatCalls(value bool) StringsOption {
// CompileOptions implements the Library interface method.
func (lib *stringLib) CompileOptions() []cel.EnvOption {
formatLocale := "en_US"
if lib.locale != "" {
if lib.version < 4 && lib.locale != "" {
// ensure locale is properly-formed if set
_, err := language.Parse(lib.locale)
if err != nil {
@@ -466,21 +471,29 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption {
}))),
}
if lib.version >= 1 {
opts = append(opts, cel.Function("format",
cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
s := string(args[0].(types.String))
formatArgs := args[1].(traits.Lister)
return stringOrError(parseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale))
}))),
if lib.version >= 4 {
opts = append(opts, cel.Function("format",
cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
s := string(args[0].(types.String))
formatArgs := args[1].(traits.Lister)
return stringOrError(parseFormatStringV2(s, &stringFormatterV2{}, &stringArgList{formatArgs}))
}))))
} else {
opts = append(opts, cel.Function("format",
cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
s := string(args[0].(types.String))
formatArgs := args[1].(traits.Lister)
return stringOrError(parseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale))
}))))
}
opts = append(opts,
cel.Function("strings.quote", cel.Overload("strings_quote", []*cel.Type{cel.StringType}, cel.StringType,
cel.UnaryBinding(func(str ref.Val) ref.Val {
s := str.(types.String)
return stringOrError(quote(string(s)))
}))),
cel.ASTValidators(stringFormatValidator{}))
}))))
}
if lib.version >= 2 {
opts = append(opts,
@@ -529,8 +542,12 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption {
}))),
)
}
if lib.validateFormat {
opts = append(opts, cel.ASTValidators(stringFormatValidator{}))
if lib.version >= 1 {
if lib.version >= 4 {
opts = append(opts, cel.ASTValidators(stringFormatValidatorV2{}))
} else {
opts = append(opts, cel.ASTValidators(stringFormatValidator{}))
}
}
return opts
}
@@ -590,6 +607,10 @@ func lastIndexOf(str, substr string) (int64, error) {
if substr == "" {
return int64(len(runes)), nil
}
if len(str) < len(substr) {
return -1, nil
}
return lastIndexOfOffset(str, substr, int64(len(runes)-1))
}

View File

@@ -158,7 +158,8 @@ type PartialActivation interface {
// partialActivationConverter indicates whether an Activation implementation supports conversion to a PartialActivation
type partialActivationConverter interface {
asPartialActivation() (PartialActivation, bool)
// AsPartialActivation converts the current activation to a PartialActivation
AsPartialActivation() (PartialActivation, bool)
}
// partActivation is the default implementations of the PartialActivation interface.
@@ -172,19 +173,20 @@ func (a *partActivation) UnknownAttributePatterns() []*AttributePattern {
return a.unknowns
}
// asPartialActivation returns the partActivation as a PartialActivation interface.
func (a *partActivation) asPartialActivation() (PartialActivation, bool) {
// AsPartialActivation returns the partActivation as a PartialActivation interface.
func (a *partActivation) AsPartialActivation() (PartialActivation, bool) {
return a, true
}
func asPartialActivation(vars Activation) (PartialActivation, bool) {
// AsPartialActivation walks the activation hierarchy and returns the first PartialActivation, if found.
func AsPartialActivation(vars Activation) (PartialActivation, bool) {
// Only internal activation instances may implement this interface
if pv, ok := vars.(partialActivationConverter); ok {
return pv.asPartialActivation()
return pv.AsPartialActivation()
}
// Since Activations may be hierarchical, test whether a parent converts to a PartialActivation
if vars.Parent() != nil {
return asPartialActivation(vars.Parent())
return AsPartialActivation(vars.Parent())
}
return nil, false
}

View File

@@ -358,7 +358,7 @@ func (m *attributeMatcher) AddQualifier(qual Qualifier) (Attribute, error) {
func (m *attributeMatcher) Resolve(vars Activation) (any, error) {
id := m.NamespacedAttribute.ID()
// Bug in how partial activation is resolved, should search parents as well.
partial, isPartial := asPartialActivation(vars)
partial, isPartial := AsPartialActivation(vars)
if isPartial {
unk, err := m.fac.matchesUnknownPatterns(
partial,

View File

@@ -109,6 +109,47 @@ type InterpretableConstructor interface {
Type() ref.Type
}
// ObservableInterpretable is an Interpretable which supports stateful observation, such as tracing
// or cost-tracking.
type ObservableInterpretable struct {
Interpretable
observers []StatefulObserver
}
// ID implements the Interpretable method to get the expression id associated with the step.
func (oi *ObservableInterpretable) ID() int64 {
return oi.Interpretable.ID()
}
// Eval proxies to the ObserveEval method while invoking a no-op callback to report the observations.
func (oi *ObservableInterpretable) Eval(vars Activation) ref.Val {
return oi.ObserveEval(vars, func(any) {})
}
// ObserveEval evaluates an interpretable and performs per-evaluation state-tracking.
//
// This method is concurrency safe and the expectation is that the observer function will use
// a switch statement to determine the type of the state which has been reported back from the call.
func (oi *ObservableInterpretable) ObserveEval(vars Activation, observer func(any)) ref.Val {
var err error
// Initialize the state needed for the observers to function.
for _, obs := range oi.observers {
vars, err = obs.InitState(vars)
if err != nil {
return types.WrapErr(err)
}
// Provide an initial reference to the state to ensure state is available
// even in cases of interrupting errors generated during evaluation.
observer(obs.GetState(vars))
}
result := oi.Interpretable.Eval(vars)
// Get the state which needs to be reported back as having been observed.
for _, obs := range oi.observers {
observer(obs.GetState(vars))
}
return result
}
// Core Interpretable implementations used during the program planning phase.
type evalTestOnly struct {
@@ -156,9 +197,6 @@ func (q *testOnlyQualifier) Qualify(vars Activation, obj any) (any, error) {
if unk, isUnk := out.(types.Unknown); isUnk {
return unk, nil
}
if opt, isOpt := out.(types.Optional); isOpt {
return opt.HasValue(), nil
}
return present, nil
}
@@ -822,9 +860,9 @@ type evalWatch struct {
}
// Eval implements the Interpretable interface method.
func (e *evalWatch) Eval(ctx Activation) ref.Val {
val := e.Interpretable.Eval(ctx)
e.observer(e.ID(), e.Interpretable, val)
func (e *evalWatch) Eval(vars Activation) ref.Val {
val := e.Interpretable.Eval(vars)
e.observer(vars, e.ID(), e.Interpretable, val)
return val
}
@@ -883,7 +921,7 @@ func (e *evalWatchAttr) AddQualifier(q Qualifier) (Attribute, error) {
// Eval implements the Interpretable interface method.
func (e *evalWatchAttr) Eval(vars Activation) ref.Val {
val := e.InterpretableAttribute.Eval(vars)
e.observer(e.ID(), e.InterpretableAttribute, val)
e.observer(vars, e.ID(), e.InterpretableAttribute, val)
return val
}
@@ -904,7 +942,7 @@ func (e *evalWatchConstQual) Qualify(vars Activation, obj any) (any, error) {
} else {
val = e.adapter.NativeToValue(out)
}
e.observer(e.ID(), e.ConstantQualifier, val)
e.observer(vars, e.ID(), e.ConstantQualifier, val)
return out, err
}
@@ -920,7 +958,7 @@ func (e *evalWatchConstQual) QualifyIfPresent(vars Activation, obj any, presence
val = types.Bool(present)
}
if present || presenceOnly {
e.observer(e.ID(), e.ConstantQualifier, val)
e.observer(vars, e.ID(), e.ConstantQualifier, val)
}
return out, present, err
}
@@ -947,7 +985,7 @@ func (e *evalWatchAttrQual) Qualify(vars Activation, obj any) (any, error) {
} else {
val = e.adapter.NativeToValue(out)
}
e.observer(e.ID(), e.Attribute, val)
e.observer(vars, e.ID(), e.Attribute, val)
return out, err
}
@@ -963,7 +1001,7 @@ func (e *evalWatchAttrQual) QualifyIfPresent(vars Activation, obj any, presenceO
val = types.Bool(present)
}
if present || presenceOnly {
e.observer(e.ID(), e.Attribute, val)
e.observer(vars, e.ID(), e.Attribute, val)
}
return out, present, err
}
@@ -984,7 +1022,7 @@ func (e *evalWatchQual) Qualify(vars Activation, obj any) (any, error) {
} else {
val = e.adapter.NativeToValue(out)
}
e.observer(e.ID(), e.Qualifier, val)
e.observer(vars, e.ID(), e.Qualifier, val)
return out, err
}
@@ -1000,7 +1038,7 @@ func (e *evalWatchQual) QualifyIfPresent(vars Activation, obj any, presenceOnly
val = types.Bool(present)
}
if present || presenceOnly {
e.observer(e.ID(), e.Qualifier, val)
e.observer(vars, e.ID(), e.Qualifier, val)
}
return out, present, err
}
@@ -1014,7 +1052,7 @@ type evalWatchConst struct {
// Eval implements the Interpretable interface method.
func (e *evalWatchConst) Eval(vars Activation) ref.Val {
val := e.Value()
e.observer(e.ID(), e.InterpretableConst, val)
e.observer(vars, e.ID(), e.InterpretableConst, val)
return val
}
@@ -1187,13 +1225,13 @@ func (a *evalAttr) Eval(ctx Activation) ref.Val {
}
// Qualify proxies to the Attribute's Qualify method.
func (a *evalAttr) Qualify(ctx Activation, obj any) (any, error) {
return a.attr.Qualify(ctx, obj)
func (a *evalAttr) Qualify(vars Activation, obj any) (any, error) {
return a.attr.Qualify(vars, obj)
}
// QualifyIfPresent proxies to the Attribute's QualifyIfPresent method.
func (a *evalAttr) QualifyIfPresent(ctx Activation, obj any, presenceOnly bool) (any, bool, error) {
return a.attr.QualifyIfPresent(ctx, obj, presenceOnly)
func (a *evalAttr) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) {
return a.attr.QualifyIfPresent(vars, obj, presenceOnly)
}
func (a *evalAttr) IsOptional() bool {
@@ -1226,9 +1264,9 @@ func (c *evalWatchConstructor) ID() int64 {
}
// Eval implements the Interpretable Eval function.
func (c *evalWatchConstructor) Eval(ctx Activation) ref.Val {
val := c.constructor.Eval(ctx)
c.observer(c.ID(), c.constructor, val)
func (c *evalWatchConstructor) Eval(vars Activation) ref.Val {
val := c.constructor.Eval(vars)
c.observer(vars, c.ID(), c.constructor, val)
return val
}
@@ -1370,16 +1408,16 @@ func (f *folder) Parent() Activation {
// if they were provided to the input activation, or an empty set if the proxied activation is not partial.
func (f *folder) UnknownAttributePatterns() []*AttributePattern {
if pv, ok := f.activation.(partialActivationConverter); ok {
if partial, isPartial := pv.asPartialActivation(); isPartial {
if partial, isPartial := pv.AsPartialActivation(); isPartial {
return partial.UnknownAttributePatterns()
}
}
return []*AttributePattern{}
}
func (f *folder) asPartialActivation() (PartialActivation, bool) {
func (f *folder) AsPartialActivation() (PartialActivation, bool) {
if pv, ok := f.activation.(partialActivationConverter); ok {
if _, isPartial := pv.asPartialActivation(); isPartial {
if _, isPartial := pv.AsPartialActivation(); isPartial {
return f, true
}
}

View File

@@ -18,36 +18,41 @@
package interpreter
import (
"errors"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// PlannerOption configures the program plan options during interpretable setup.
type PlannerOption func(*planner) (*planner, error)
// Interpreter generates a new Interpretable from a checked or unchecked expression.
type Interpreter interface {
// NewInterpretable creates an Interpretable from a checked expression and an
// optional list of InterpretableDecorator values.
NewInterpretable(exprAST *ast.AST, decorators ...InterpretableDecorator) (Interpretable, error)
// optional list of PlannerOption values.
NewInterpretable(exprAST *ast.AST, opts ...PlannerOption) (Interpretable, error)
}
// EvalObserver is a functional interface that accepts an expression id and an observed value.
// The id identifies the expression that was evaluated, the programStep is the Interpretable or Qualifier that
// was evaluated and value is the result of the evaluation.
type EvalObserver func(id int64, programStep any, value ref.Val)
type EvalObserver func(vars Activation, id int64, programStep any, value ref.Val)
// Observe constructs a decorator that calls all the provided observers in order after evaluating each Interpretable
// or Qualifier during program evaluation.
func Observe(observers ...EvalObserver) InterpretableDecorator {
if len(observers) == 1 {
return decObserveEval(observers[0])
}
observeFn := func(id int64, programStep any, val ref.Val) {
for _, observer := range observers {
observer(id, programStep, val)
}
}
return decObserveEval(observeFn)
// StatefulObserver observes evaluation while tracking or utilizing stateful behavior.
type StatefulObserver interface {
// InitState configures stateful metadata on the activation.
InitState(Activation) (Activation, error)
// GetState retrieves the stateful metadata from the activation.
GetState(Activation) any
// Observe passes the activation and relevant evaluation metadata to the observer.
// The observe method is expected to do the equivalent of GetState(vars) in order
// to find the metadata that needs to be updated upon invocation.
Observe(vars Activation, id int64, programStep any, value ref.Val)
}
// EvalCancelledError represents a cancelled program evaluation operation.
@@ -73,24 +78,110 @@ const (
CostLimitExceeded
)
// TODO: Replace all usages of TrackState with EvalStateObserver
// evalStateOption configures the evalStateFactory behavior.
type evalStateOption func(*evalStateFactory) *evalStateFactory
// TrackState decorates each expression node with an observer which records the value
// associated with the given expression id. EvalState must be provided to the decorator.
// This decorator is not thread-safe, and the EvalState must be reset between Eval()
// calls.
// DEPRECATED: Please use EvalStateObserver instead. It composes gracefully with additional observers.
func TrackState(state EvalState) InterpretableDecorator {
return Observe(EvalStateObserver(state))
// EvalStateFactory configures the EvalState generator to be used by the EvalStateObserver.
func EvalStateFactory(factory func() EvalState) evalStateOption {
return func(fac *evalStateFactory) *evalStateFactory {
fac.factory = factory
return fac
}
}
// EvalStateObserver provides an observer which records the value
// associated with the given expression id. EvalState must be provided to the observer.
// This decorator is not thread-safe, and the EvalState must be reset between Eval()
// calls.
func EvalStateObserver(state EvalState) EvalObserver {
return func(id int64, programStep any, val ref.Val) {
state.SetValue(id, val)
// EvalStateObserver provides an observer which records the value associated with the given expression id.
// EvalState must be provided to the observer.
func EvalStateObserver(opts ...evalStateOption) PlannerOption {
et := &evalStateFactory{factory: NewEvalState}
for _, o := range opts {
et = o(et)
}
return func(p *planner) (*planner, error) {
if et.factory == nil {
return nil, errors.New("eval state factory not configured")
}
p.observers = append(p.observers, et)
p.decorators = append(p.decorators, decObserveEval(et.Observe))
return p, nil
}
}
// evalStateConverter identifies an object which is convertible to an EvalState instance.
type evalStateConverter interface {
asEvalState() EvalState
}
// evalStateActivation hides state in the Activation in a manner not accessible to expressions.
type evalStateActivation struct {
vars Activation
state EvalState
}
// ResolveName proxies variable lookups to the backing activation.
func (esa evalStateActivation) ResolveName(name string) (any, bool) {
return esa.vars.ResolveName(name)
}
// Parent proxies parent lookups to the backing activation.
func (esa evalStateActivation) Parent() Activation {
return esa.vars
}
// AsPartialActivation supports conversion to a partial activation in order to detect unknown attributes.
func (esa evalStateActivation) AsPartialActivation() (PartialActivation, bool) {
return AsPartialActivation(esa.vars)
}
// asEvalState implements the evalStateConverter method.
func (esa evalStateActivation) asEvalState() EvalState {
return esa.state
}
// asEvalState walks the Activation hierarchy and returns the first EvalState found, if present.
func asEvalState(vars Activation) (EvalState, bool) {
if conv, ok := vars.(evalStateConverter); ok {
return conv.asEvalState(), true
}
if vars.Parent() != nil {
return asEvalState(vars.Parent())
}
return nil, false
}
// evalStateFactory holds a reference to a factory function that produces an EvalState instance.
type evalStateFactory struct {
factory func() EvalState
}
// InitState produces an EvalState instance and bundles it into the Activation in a way which is
// not visible to expression evaluation.
func (et *evalStateFactory) InitState(vars Activation) (Activation, error) {
state := et.factory()
return evalStateActivation{vars: vars, state: state}, nil
}
// GetState extracts the EvalState from the Activation.
func (et *evalStateFactory) GetState(vars Activation) any {
if state, found := asEvalState(vars); found {
return state
}
return nil
}
// Observe records the evaluation state for a given expression node and program step.
func (et *evalStateFactory) Observe(vars Activation, id int64, programStep any, val ref.Val) {
state, found := asEvalState(vars)
if !found {
return
}
state.SetValue(id, val)
}
// CustomDecorator configures a custom interpretable decorator for the program.
func CustomDecorator(dec InterpretableDecorator) PlannerOption {
return func(p *planner) (*planner, error) {
p.decorators = append(p.decorators, dec)
return p, nil
}
}
@@ -99,11 +190,8 @@ func EvalStateObserver(state EvalState) EvalObserver {
// insight into the evaluation state of the entire expression. EvalState must be
// provided to the decorator. This decorator is not thread-safe, and the EvalState
// must be reset between Eval() calls.
func ExhaustiveEval() InterpretableDecorator {
ex := decDisableShortcircuits()
return func(i Interpretable) (Interpretable, error) {
return ex(i)
}
func ExhaustiveEval() PlannerOption {
return CustomDecorator(decDisableShortcircuits())
}
// InterruptableEval annotates comprehension loops with information that indicates they
@@ -111,14 +199,14 @@ func ExhaustiveEval() InterpretableDecorator {
//
// The custom activation is currently managed higher up in the stack within the 'cel' package
// and should not require any custom support on behalf of callers.
func InterruptableEval() InterpretableDecorator {
return decInterruptFolds()
func InterruptableEval() PlannerOption {
return CustomDecorator(decInterruptFolds())
}
// Optimize will pre-compute operations such as list and map construction and optimize
// call arguments to set membership tests. The set of optimizations will increase over time.
func Optimize() InterpretableDecorator {
return decOptimize()
func Optimize() PlannerOption {
return CustomDecorator(decOptimize())
}
// RegexOptimization provides a way to replace an InterpretableCall for a regex function when the
@@ -142,8 +230,8 @@ type RegexOptimization struct {
// CompileRegexConstants compiles regex pattern string constants at program creation time and reports any regex pattern
// compile errors.
func CompileRegexConstants(regexOptimizations ...*RegexOptimization) InterpretableDecorator {
return decRegexOptimizer(regexOptimizations...)
func CompileRegexConstants(regexOptimizations ...*RegexOptimization) PlannerOption {
return CustomDecorator(decRegexOptimizer(regexOptimizations...))
}
type exprInterpreter struct {
@@ -172,14 +260,14 @@ func NewInterpreter(dispatcher Dispatcher,
// NewIntepretable implements the Interpreter interface method.
func (i *exprInterpreter) NewInterpretable(
checked *ast.AST,
decorators ...InterpretableDecorator) (Interpretable, error) {
p := newPlanner(
i.dispatcher,
i.provider,
i.adapter,
i.attrFactory,
i.container,
checked,
decorators...)
opts ...PlannerOption) (Interpretable, error) {
p := newPlanner(i.dispatcher, i.provider, i.adapter, i.attrFactory, i.container, checked)
var err error
for _, o := range opts {
p, err = o(p)
if err != nil {
return nil, err
}
}
return p.Plan(checked.Expr())
}

View File

@@ -25,12 +25,6 @@ import (
"github.com/google/cel-go/common/types"
)
// interpretablePlanner creates an Interpretable evaluation plan from a proto Expr value.
type interpretablePlanner interface {
// Plan generates an Interpretable value (or error) from the input proto Expr.
Plan(expr ast.Expr) (Interpretable, error)
}
// newPlanner creates an interpretablePlanner which references a Dispatcher, TypeProvider,
// TypeAdapter, Container, and CheckedExpr value. These pieces of data are used to resolve
// functions, types, and namespaced identifiers at plan time rather than at runtime since
@@ -40,8 +34,7 @@ func newPlanner(disp Dispatcher,
adapter types.Adapter,
attrFactory AttributeFactory,
cont *containers.Container,
exprAST *ast.AST,
decorators ...InterpretableDecorator) interpretablePlanner {
exprAST *ast.AST) *planner {
return &planner{
disp: disp,
provider: provider,
@@ -50,7 +43,8 @@ func newPlanner(disp Dispatcher,
container: cont,
refMap: exprAST.ReferenceMap(),
typeMap: exprAST.TypeMap(),
decorators: decorators,
decorators: make([]InterpretableDecorator, 0),
observers: make([]StatefulObserver, 0),
}
}
@@ -64,6 +58,7 @@ type planner struct {
refMap map[int64]*ast.ReferenceInfo
typeMap map[int64]*types.Type
decorators []InterpretableDecorator
observers []StatefulObserver
}
// Plan implements the interpretablePlanner interface. This implementation of the Plan method also
@@ -72,6 +67,17 @@ type planner struct {
// such as state-tracking, expression re-write, and possibly efficient thread-safe memoization of
// repeated expressions.
func (p *planner) Plan(expr ast.Expr) (Interpretable, error) {
i, err := p.plan(expr)
if err != nil {
return nil, err
}
if len(p.observers) == 0 {
return i, nil
}
return &ObservableInterpretable{Interpretable: i, observers: p.observers}, nil
}
func (p *planner) plan(expr ast.Expr) (Interpretable, error) {
switch expr.Kind() {
case ast.CallKind:
return p.decorate(p.planCall(expr))
@@ -161,7 +167,7 @@ func (p *planner) planSelect(expr ast.Expr) (Interpretable, error) {
sel := expr.AsSelect()
// Plan the operand evaluation.
op, err := p.Plan(sel.Operand())
op, err := p.plan(sel.Operand())
if err != nil {
return nil, err
}
@@ -220,14 +226,14 @@ func (p *planner) planCall(expr ast.Expr) (Interpretable, error) {
args := make([]Interpretable, argCount)
if target != nil {
arg, err := p.Plan(target)
arg, err := p.plan(target)
if err != nil {
return nil, err
}
args[0] = arg
}
for i, argExpr := range call.Args() {
arg, err := p.Plan(argExpr)
arg, err := p.plan(argExpr)
if err != nil {
return nil, err
}
@@ -496,7 +502,7 @@ func (p *planner) planCreateList(expr ast.Expr) (Interpretable, error) {
}
elems := make([]Interpretable, len(elements))
for i, elem := range elements {
elemVal, err := p.Plan(elem)
elemVal, err := p.plan(elem)
if err != nil {
return nil, err
}
@@ -521,13 +527,13 @@ func (p *planner) planCreateMap(expr ast.Expr) (Interpretable, error) {
hasOptionals := false
for i, e := range entries {
entry := e.AsMapEntry()
keyVal, err := p.Plan(entry.Key())
keyVal, err := p.plan(entry.Key())
if err != nil {
return nil, err
}
keys[i] = keyVal
valVal, err := p.Plan(entry.Value())
valVal, err := p.plan(entry.Value())
if err != nil {
return nil, err
}
@@ -560,7 +566,7 @@ func (p *planner) planCreateStruct(expr ast.Expr) (Interpretable, error) {
for i, f := range objFields {
field := f.AsStructField()
fields[i] = field.Name()
val, err := p.Plan(field.Value())
val, err := p.plan(field.Value())
if err != nil {
return nil, err
}
@@ -582,23 +588,23 @@ func (p *planner) planCreateStruct(expr ast.Expr) (Interpretable, error) {
// planComprehension generates an Interpretable fold operation.
func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) {
fold := expr.AsComprehension()
accu, err := p.Plan(fold.AccuInit())
accu, err := p.plan(fold.AccuInit())
if err != nil {
return nil, err
}
iterRange, err := p.Plan(fold.IterRange())
iterRange, err := p.plan(fold.IterRange())
if err != nil {
return nil, err
}
cond, err := p.Plan(fold.LoopCondition())
cond, err := p.plan(fold.LoopCondition())
if err != nil {
return nil, err
}
step, err := p.Plan(fold.LoopStep())
step, err := p.plan(fold.LoopStep())
if err != nil {
return nil, err
}
result, err := p.Plan(fold.Result())
result, err := p.plan(fold.Result())
if err != nil {
return nil, err
}

View File

@@ -15,6 +15,7 @@
package interpreter
import (
"errors"
"math"
"github.com/google/cel-go/common"
@@ -34,78 +35,172 @@ type ActualCostEstimator interface {
CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64
}
// CostObserver provides an observer that tracks runtime cost.
func CostObserver(tracker *CostTracker) EvalObserver {
observer := func(id int64, programStep any, val ref.Val) {
switch t := programStep.(type) {
case ConstantQualifier:
// TODO: Push identifiers on to the stack before observing constant qualifiers that apply to them
// and enable the below pop. Once enabled this can case can be collapsed into the Qualifier case.
tracker.cost++
case InterpretableConst:
// zero cost
case InterpretableAttribute:
switch a := t.Attr().(type) {
case *conditionalAttribute:
// Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions.
tracker.stack.drop(a.falsy.ID(), a.truthy.ID(), a.expr.ID())
default:
tracker.stack.drop(t.Attr().ID())
tracker.cost += common.SelectAndIdentCost
}
if !tracker.presenceTestHasCost {
if _, isTestOnly := programStep.(*evalTestOnly); isTestOnly {
tracker.cost -= common.SelectAndIdentCost
}
}
case *evalExhaustiveConditional:
// Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions.
tracker.stack.drop(t.attr.falsy.ID(), t.attr.truthy.ID(), t.attr.expr.ID())
// costTrackPlanOption modifies the cost tracking factory associatied with the CostObserver
type costTrackPlanOption func(*costTrackerFactory) *costTrackerFactory
// While the field names are identical, the boolean operation eval structs do not share an interface and so
// must be handled individually.
case *evalOr:
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalAnd:
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalExhaustiveOr:
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalExhaustiveAnd:
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalFold:
tracker.stack.drop(t.iterRange.ID())
case Qualifier:
tracker.cost++
case InterpretableCall:
if argVals, ok := tracker.stack.dropArgs(t.Args()); ok {
tracker.cost += tracker.costCall(t, argVals, val)
}
case InterpretableConstructor:
tracker.stack.dropArgs(t.InitVals())
switch t.Type() {
case types.ListType:
tracker.cost += common.ListCreateBaseCost
case types.MapType:
tracker.cost += common.MapCreateBaseCost
default:
tracker.cost += common.StructCreateBaseCost
// CostTrackerFactory configures the factory method to generate a new cost-tracker per-evaluation.
func CostTrackerFactory(factory func() (*CostTracker, error)) costTrackPlanOption {
return func(fac *costTrackerFactory) *costTrackerFactory {
fac.factory = factory
return fac
}
}
// CostObserver provides an observer that tracks runtime cost.
func CostObserver(opts ...costTrackPlanOption) PlannerOption {
ct := &costTrackerFactory{}
for _, o := range opts {
ct = o(ct)
}
return func(p *planner) (*planner, error) {
if ct.factory == nil {
return nil, errors.New("cost tracker factory not configured")
}
p.observers = append(p.observers, ct)
p.decorators = append(p.decorators, decObserveEval(ct.Observe))
return p, nil
}
}
// costTrackerConverter identifies an object which is convertible to a CostTracker instance.
type costTrackerConverter interface {
asCostTracker() *CostTracker
}
// costTrackActivation hides state in the Activation in a manner not accessible to expressions.
type costTrackActivation struct {
vars Activation
costTracker *CostTracker
}
// ResolveName proxies variable lookups to the backing activation.
func (cta costTrackActivation) ResolveName(name string) (any, bool) {
return cta.vars.ResolveName(name)
}
// Parent proxies parent lookups to the backing activation.
func (cta costTrackActivation) Parent() Activation {
return cta.vars
}
// AsPartialActivation supports conversion to a partial activation in order to detect unknown attributes.
func (cta costTrackActivation) AsPartialActivation() (PartialActivation, bool) {
return AsPartialActivation(cta.vars)
}
// asCostTracker implements the costTrackerConverter method.
func (cta costTrackActivation) asCostTracker() *CostTracker {
return cta.costTracker
}
// asCostTracker walks the Activation hierarchy and returns the first cost tracker found, if present.
func asCostTracker(vars Activation) (*CostTracker, bool) {
if conv, ok := vars.(costTrackerConverter); ok {
return conv.asCostTracker(), true
}
if vars.Parent() != nil {
return asCostTracker(vars.Parent())
}
return nil, false
}
// costTrackerFactory holds a factory for producing new CostTracker instances on each Eval call.
type costTrackerFactory struct {
factory func() (*CostTracker, error)
}
// InitState produces a CostTracker and bundles it into an Activation in a way which is not visible
// to expression evaluation.
func (ct *costTrackerFactory) InitState(vars Activation) (Activation, error) {
tracker, err := ct.factory()
if err != nil {
return nil, err
}
return costTrackActivation{vars: vars, costTracker: tracker}, nil
}
// GetState extracts the CostTracker from the Activation.
func (ct *costTrackerFactory) GetState(vars Activation) any {
if tracker, found := asCostTracker(vars); found {
return tracker
}
return nil
}
// Observe computes the incremental cost of each step and records it into the CostTracker associated
// with the evaluation.
func (ct *costTrackerFactory) Observe(vars Activation, id int64, programStep any, val ref.Val) {
tracker, found := asCostTracker(vars)
if !found {
return
}
switch t := programStep.(type) {
case ConstantQualifier:
// TODO: Push identifiers on to the stack before observing constant qualifiers that apply to them
// and enable the below pop. Once enabled this can case can be collapsed into the Qualifier case.
tracker.cost++
case InterpretableConst:
// zero cost
case InterpretableAttribute:
switch a := t.Attr().(type) {
case *conditionalAttribute:
// Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions.
tracker.stack.drop(a.falsy.ID(), a.truthy.ID(), a.expr.ID())
default:
tracker.stack.drop(t.Attr().ID())
tracker.cost += common.SelectAndIdentCost
}
if !tracker.presenceTestHasCost {
if _, isTestOnly := programStep.(*evalTestOnly); isTestOnly {
tracker.cost -= common.SelectAndIdentCost
}
}
tracker.stack.push(val, id)
case *evalExhaustiveConditional:
// Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions.
tracker.stack.drop(t.attr.falsy.ID(), t.attr.truthy.ID(), t.attr.expr.ID())
if tracker.Limit != nil && tracker.cost > *tracker.Limit {
panic(EvalCancelledError{Cause: CostLimitExceeded, Message: "operation cancelled: actual cost limit exceeded"})
// While the field names are identical, the boolean operation eval structs do not share an interface and so
// must be handled individually.
case *evalOr:
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalAnd:
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalExhaustiveOr:
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalExhaustiveAnd:
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalFold:
tracker.stack.drop(t.iterRange.ID())
case Qualifier:
tracker.cost++
case InterpretableCall:
if argVals, ok := tracker.stack.dropArgs(t.Args()); ok {
tracker.cost += tracker.costCall(t, argVals, val)
}
case InterpretableConstructor:
tracker.stack.dropArgs(t.InitVals())
switch t.Type() {
case types.ListType:
tracker.cost += common.ListCreateBaseCost
case types.MapType:
tracker.cost += common.MapCreateBaseCost
default:
tracker.cost += common.StructCreateBaseCost
}
}
return observer
tracker.stack.push(val, id)
if tracker.Limit != nil && tracker.cost > *tracker.Limit {
panic(EvalCancelledError{Cause: CostLimitExceeded, Message: "operation cancelled: actual cost limit exceeded"})
}
}
// CostTrackerOption configures the behavior of CostTracker objects.

View File

@@ -24,38 +24,74 @@ import (
"github.com/google/cel-go/common/types/ref"
)
// MacroOpt defines a functional option for configuring macro behavior.
type MacroOpt func(*macro) *macro
// MacroDocs configures a list of strings into a multiline description for the macro.
func MacroDocs(docs ...string) MacroOpt {
return func(m *macro) *macro {
m.doc = common.MultilineDescription(docs...)
return m
}
}
// MacroExamples configures a list of examples, either as a string or common.MultilineString,
// into an example set to be provided with the macro Documentation() call.
func MacroExamples(examples ...string) MacroOpt {
return func(m *macro) *macro {
m.examples = examples
return m
}
}
// NewGlobalMacro creates a Macro for a global function with the specified arg count.
func NewGlobalMacro(function string, argCount int, expander MacroExpander) Macro {
return &macro{
func NewGlobalMacro(function string, argCount int, expander MacroExpander, opts ...MacroOpt) Macro {
m := &macro{
function: function,
argCount: argCount,
expander: expander}
for _, opt := range opts {
m = opt(m)
}
return m
}
// NewReceiverMacro creates a Macro for a receiver function matching the specified arg count.
func NewReceiverMacro(function string, argCount int, expander MacroExpander) Macro {
return &macro{
func NewReceiverMacro(function string, argCount int, expander MacroExpander, opts ...MacroOpt) Macro {
m := &macro{
function: function,
argCount: argCount,
expander: expander,
receiverStyle: true}
for _, opt := range opts {
m = opt(m)
}
return m
}
// NewGlobalVarArgMacro creates a Macro for a global function with a variable arg count.
func NewGlobalVarArgMacro(function string, expander MacroExpander) Macro {
return &macro{
func NewGlobalVarArgMacro(function string, expander MacroExpander, opts ...MacroOpt) Macro {
m := &macro{
function: function,
expander: expander,
varArgStyle: true}
for _, opt := range opts {
m = opt(m)
}
return m
}
// NewReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count.
func NewReceiverVarArgMacro(function string, expander MacroExpander) Macro {
return &macro{
func NewReceiverVarArgMacro(function string, expander MacroExpander, opts ...MacroOpt) Macro {
m := &macro{
function: function,
expander: expander,
receiverStyle: true,
varArgStyle: true}
for _, opt := range opts {
m = opt(m)
}
return m
}
// Macro interface for describing the function signature to match and the MacroExpander to apply.
@@ -95,6 +131,8 @@ type macro struct {
varArgStyle bool
argCount int
expander MacroExpander
doc string
examples []string
}
// Function returns the macro's function name (i.e. the function whose syntax it mimics).
@@ -125,6 +163,15 @@ func (m *macro) MacroKey() string {
return makeMacroKey(m.function, m.argCount, m.receiverStyle)
}
// Documentation generates documentation and examples for the macro.
func (m *macro) Documentation() *common.Doc {
examples := make([]*common.Doc, len(m.examples))
for i, ex := range m.examples {
examples[i] = common.NewExampleDoc(ex)
}
return common.NewMacroDoc(m.Function(), m.doc, examples...)
}
func makeMacroKey(name string, args int, receiverStyle bool) string {
return fmt.Sprintf("%s:%d:%v", name, args, receiverStyle)
}
@@ -250,37 +297,139 @@ type ExprHelper interface {
var (
// HasMacro expands "has(m.f)" which tests the presence of a field, avoiding the need to
// specify the field as a string.
HasMacro = NewGlobalMacro(operators.Has, 1, MakeHas)
HasMacro = NewGlobalMacro(operators.Has, 1, MakeHas,
MacroDocs(
`check a protocol buffer message for the presence of a field, or check a map`,
`for the presence of a string key.`,
`Only map accesses using the select notation are supported.`),
MacroExamples(
common.MultilineDescription(
`// true if the 'address' field exists in the 'user' message`,
`has(user.address)`),
common.MultilineDescription(
`// test whether the 'key_name' is set on the map which defines it`,
`has({'key_name': 'value'}.key_name) // true`),
common.MultilineDescription(
`// test whether the 'id' field is set to a non-default value on the Expr{} message literal`,
`has(Expr{}.id) // false`),
))
// AllMacro expands "range.all(var, predicate)" into a comprehension which ensures that all
// elements in the range satisfy the predicate.
AllMacro = NewReceiverMacro(operators.All, 2, MakeAll)
AllMacro = NewReceiverMacro(operators.All, 2, MakeAll,
MacroDocs(`tests whether all elements in the input list or all keys in a map`,
`satisfy the given predicate. The all macro behaves in a manner consistent with`,
`the Logical AND operator including in how it absorbs errors and short-circuits.`),
MacroExamples(
`[1, 2, 3].all(x, x > 0) // true`,
`[1, 2, 0].all(x, x > 0) // false`,
`['apple', 'banana', 'cherry'].all(fruit, fruit.size() > 3) // true`,
`[3.14, 2.71, 1.61].all(num, num < 3.0) // false`,
`{'a': 1, 'b': 2, 'c': 3}.all(key, key != 'b') // false`,
common.MultilineDescription(
`// an empty list or map as the range will result in a trivially true result`,
`[].all(x, x > 0) // true`),
))
// ExistsMacro expands "range.exists(var, predicate)" into a comprehension which ensures that
// some element in the range satisfies the predicate.
ExistsMacro = NewReceiverMacro(operators.Exists, 2, MakeExists)
ExistsMacro = NewReceiverMacro(operators.Exists, 2, MakeExists,
MacroDocs(`tests whether any value in the list or any key in the map`,
`satisfies the predicate expression. The exists macro behaves in a manner`,
`consistent with the Logical OR operator including in how it absorbs errors and`,
`short-circuits.`),
MacroExamples(
`[1, 2, 3].exists(i, i % 2 != 0) // true`,
`[0, -1, 5].exists(num, num < 0) // true`,
`{'x': 'foo', 'y': 'bar'}.exists(key, key.startsWith('z')) // false`,
common.MultilineDescription(
`// an empty list or map as the range will result in a trivially false result`,
`[].exists(i, i > 0) // false`),
common.MultilineDescription(
`// test whether a key name equalling 'iss' exists in the map and the`,
`// value contains the substring 'cel.dev'`,
`// tokens = {'sub': 'me', 'iss': 'https://issuer.cel.dev'}`,
`tokens.exists(k, k == 'iss' && tokens[k].contains('cel.dev'))`),
))
// ExistsOneMacro expands "range.exists_one(var, predicate)", which is true if for exactly one
// element in range the predicate holds.
// Deprecated: Use ExistsOneMacroNew
ExistsOneMacro = NewReceiverMacro(operators.ExistsOne, 2, MakeExistsOne)
ExistsOneMacro = NewReceiverMacro(operators.ExistsOne, 2, MakeExistsOne,
MacroDocs(`tests whether exactly one list element or map key satisfies`,
`the predicate expression. This macro does not short-circuit in order to remain`,
`consistent with logical operators being the only operators which can absorb`,
`errors within CEL.`),
MacroExamples(
`[1, 2, 2].exists_one(i, i < 2) // true`,
`{'a': 'hello', 'aa': 'hellohello'}.exists_one(k, k.startsWith('a')) // false`,
`[1, 2, 3, 4].exists_one(num, num % 2 == 0) // false`,
common.MultilineDescription(
`// ensure exactly one key in the map ends in @acme.co`,
`{'wiley@acme.co': 'coyote', 'aa@milne.co': 'bear'}.exists_one(k, k.endsWith('@acme.co')) // true`),
))
// ExistsOneMacroNew expands "range.existsOne(var, predicate)", which is true if for exactly one
// element in range the predicate holds.
ExistsOneMacroNew = NewReceiverMacro("existsOne", 2, MakeExistsOne)
ExistsOneMacroNew = NewReceiverMacro("existsOne", 2, MakeExistsOne,
MacroDocs(
`tests whether exactly one list element or map key satisfies the predicate`,
`expression. This macro does not short-circuit in order to remain consistent`,
`with logical operators being the only operators which can absorb errors`,
`within CEL.`),
MacroExamples(
`[1, 2, 2].existsOne(i, i < 2) // true`,
`{'a': 'hello', 'aa': 'hellohello'}.existsOne(k, k.startsWith('a')) // false`,
`[1, 2, 3, 4].existsOne(num, num % 2 == 0) // false`,
common.MultilineDescription(
`// ensure exactly one key in the map ends in @acme.co`,
`{'wiley@acme.co': 'coyote', 'aa@milne.co': 'bear'}.existsOne(k, k.endsWith('@acme.co')) // true`),
))
// MapMacro expands "range.map(var, function)" into a comprehension which applies the function
// to each element in the range to produce a new list.
MapMacro = NewReceiverMacro(operators.Map, 2, MakeMap)
MapMacro = NewReceiverMacro(operators.Map, 2, MakeMap,
MacroDocs("the three-argument form of map transforms all elements in the input range."),
MacroExamples(
`[1, 2, 3].map(x, x * 2) // [2, 4, 6]`,
`[5, 10, 15].map(x, x / 5) // [1, 2, 3]`,
`['apple', 'banana'].map(fruit, fruit.upperAscii()) // ['APPLE', 'BANANA']`,
common.MultilineDescription(
`// Combine all map key-value pairs into a list`,
`{'hi': 'you', 'howzit': 'bruv'}.map(k,`,
` k + ":" + {'hi': 'you', 'howzit': 'bruv'}[k]) // ['hi:you', 'howzit:bruv']`),
))
// MapFilterMacro expands "range.map(var, predicate, function)" into a comprehension which
// first filters the elements in the range by the predicate, then applies the transform function
// to produce a new list.
MapFilterMacro = NewReceiverMacro(operators.Map, 3, MakeMap)
MapFilterMacro = NewReceiverMacro(operators.Map, 3, MakeMap,
MacroDocs(`the four-argument form of the map transforms only elements which satisfy`,
`the predicate which is equivalent to chaining the filter and three-argument`,
`map macros together.`),
MacroExamples(
common.MultilineDescription(
`// multiply only numbers divisible two, by 2`,
`[1, 2, 3, 4].map(num, num % 2 == 0, num * 2) // [4, 8]`),
))
// FilterMacro expands "range.filter(var, predicate)" into a comprehension which filters
// elements in the range, producing a new list from the elements that satisfy the predicate.
FilterMacro = NewReceiverMacro(operators.Filter, 2, MakeFilter)
FilterMacro = NewReceiverMacro(operators.Filter, 2, MakeFilter,
MacroDocs(`returns a list containing only the elements from the input list`,
`that satisfy the given predicate`),
MacroExamples(
`[1, 2, 3].filter(x, x > 1) // [2, 3]`,
`['cat', 'dog', 'bird', 'fish'].filter(pet, pet.size() == 3) // ['cat', 'dog']`,
`[{'a': 10, 'b': 5, 'c': 20}].map(m, m.filter(key, m[key] > 10)) // [['c']]`,
common.MultilineDescription(
`// filter a list to select only emails with the @cel.dev suffix`,
`['alice@buf.io', 'tristan@cel.dev'].filter(v, v.endsWith('@cel.dev')) // ['tristan@cel.dev']`),
common.MultilineDescription(
`// filter a map into a list, selecting only the values for keys that start with 'http-auth'`,
`{'http-auth-agent': 'secret', 'user-agent': 'mozilla'}.filter(k,`,
` k.startsWith('http-auth')) // ['secret']`),
))
// AllMacros includes the list of all spec-supported macros.
AllMacros = []Macro{

View File

@@ -15,7 +15,7 @@
package compiler
import (
yaml "gopkg.in/yaml.v3"
yaml "go.yaml.in/yaml/v3"
)
// Context contains state of the compiler as it traverses a document.

View File

@@ -20,9 +20,9 @@ import (
"os/exec"
"strings"
yaml "go.yaml.in/yaml/v3"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
yaml "gopkg.in/yaml.v3"
extensions "github.com/google/gnostic-models/extensions"
)

View File

@@ -20,7 +20,7 @@ import (
"sort"
"strconv"
"gopkg.in/yaml.v3"
yaml "go.yaml.in/yaml/v3"
"github.com/google/gnostic-models/jsonschema"
)

View File

@@ -24,7 +24,7 @@ import (
"strings"
"sync"
yaml "gopkg.in/yaml.v3"
yaml "go.yaml.in/yaml/v3"
)
var verboseReader = false

View File

@@ -16,7 +16,7 @@
// of JSON Schemas.
package jsonschema
import "gopkg.in/yaml.v3"
import "go.yaml.in/yaml/v3"
// The Schema struct models a JSON Schema and, because schemas are
// defined hierarchically, contains many references to itself.

View File

@@ -21,7 +21,7 @@ import (
"io/ioutil"
"strconv"
"gopkg.in/yaml.v3"
yaml "go.yaml.in/yaml/v3"
)
// This is a global map of all known Schemas.

View File

@@ -17,7 +17,7 @@ package jsonschema
import (
"fmt"
"gopkg.in/yaml.v3"
yaml "go.yaml.in/yaml/v3"
)
const indentation = " "

View File

@@ -21,7 +21,7 @@ import (
"regexp"
"strings"
"gopkg.in/yaml.v3"
yaml "go.yaml.in/yaml/v3"
"github.com/google/gnostic-models/compiler"
)
@@ -60,7 +60,7 @@ func NewAdditionalPropertiesItem(in *yaml.Node, context *compiler.Context) (*Add
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid AdditionalPropertiesItem")
message := "contains an invalid AdditionalPropertiesItem"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -2543,7 +2543,7 @@ func NewNonBodyParameter(in *yaml.Node, context *compiler.Context) (*NonBodyPara
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid NonBodyParameter")
message := "contains an invalid NonBodyParameter"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -3271,7 +3271,7 @@ func NewParameter(in *yaml.Node, context *compiler.Context) (*Parameter, error)
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid Parameter")
message := "contains an invalid Parameter"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -3345,7 +3345,7 @@ func NewParametersItem(in *yaml.Node, context *compiler.Context) (*ParametersIte
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid ParametersItem")
message := "contains an invalid ParametersItem"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -4561,7 +4561,7 @@ func NewResponseValue(in *yaml.Node, context *compiler.Context) (*ResponseValue,
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid ResponseValue")
message := "contains an invalid ResponseValue"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -5030,7 +5030,7 @@ func NewSchemaItem(in *yaml.Node, context *compiler.Context) (*SchemaItem, error
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid SchemaItem")
message := "contains an invalid SchemaItem"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -5160,7 +5160,7 @@ func NewSecurityDefinitionsItem(in *yaml.Node, context *compiler.Context) (*Secu
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid SecurityDefinitionsItem")
message := "contains an invalid SecurityDefinitionsItem"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -6930,7 +6930,7 @@ func (m *BodyParameter) ToRawInfo() *yaml.Node {
// always include this required field.
info.Content = append(info.Content, compiler.NewScalarNodeForString("in"))
info.Content = append(info.Content, compiler.NewScalarNodeForString(m.In))
if m.Required != false {
if m.Required {
info.Content = append(info.Content, compiler.NewScalarNodeForString("required"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.Required))
}
@@ -7149,7 +7149,7 @@ func (m *FileSchema) ToRawInfo() *yaml.Node {
// always include this required field.
info.Content = append(info.Content, compiler.NewScalarNodeForString("type"))
info.Content = append(info.Content, compiler.NewScalarNodeForString(m.Type))
if m.ReadOnly != false {
if m.ReadOnly {
info.Content = append(info.Content, compiler.NewScalarNodeForString("readOnly"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ReadOnly))
}
@@ -7176,7 +7176,7 @@ func (m *FormDataParameterSubSchema) ToRawInfo() *yaml.Node {
if m == nil {
return info
}
if m.Required != false {
if m.Required {
info.Content = append(info.Content, compiler.NewScalarNodeForString("required"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.Required))
}
@@ -7192,7 +7192,7 @@ func (m *FormDataParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("name"))
info.Content = append(info.Content, compiler.NewScalarNodeForString(m.Name))
}
if m.AllowEmptyValue != false {
if m.AllowEmptyValue {
info.Content = append(info.Content, compiler.NewScalarNodeForString("allowEmptyValue"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.AllowEmptyValue))
}
@@ -7220,7 +7220,7 @@ func (m *FormDataParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("maximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Maximum))
}
if m.ExclusiveMaximum != false {
if m.ExclusiveMaximum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMaximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMaximum))
}
@@ -7228,7 +7228,7 @@ func (m *FormDataParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Minimum))
}
if m.ExclusiveMinimum != false {
if m.ExclusiveMinimum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMinimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMinimum))
}
@@ -7252,7 +7252,7 @@ func (m *FormDataParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForInt(m.MinItems))
}
if m.UniqueItems != false {
if m.UniqueItems {
info.Content = append(info.Content, compiler.NewScalarNodeForString("uniqueItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.UniqueItems))
}
@@ -7306,7 +7306,7 @@ func (m *Header) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("maximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Maximum))
}
if m.ExclusiveMaximum != false {
if m.ExclusiveMaximum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMaximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMaximum))
}
@@ -7314,7 +7314,7 @@ func (m *Header) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Minimum))
}
if m.ExclusiveMinimum != false {
if m.ExclusiveMinimum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMinimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMinimum))
}
@@ -7338,7 +7338,7 @@ func (m *Header) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForInt(m.MinItems))
}
if m.UniqueItems != false {
if m.UniqueItems {
info.Content = append(info.Content, compiler.NewScalarNodeForString("uniqueItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.UniqueItems))
}
@@ -7373,7 +7373,7 @@ func (m *HeaderParameterSubSchema) ToRawInfo() *yaml.Node {
if m == nil {
return info
}
if m.Required != false {
if m.Required {
info.Content = append(info.Content, compiler.NewScalarNodeForString("required"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.Required))
}
@@ -7413,7 +7413,7 @@ func (m *HeaderParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("maximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Maximum))
}
if m.ExclusiveMaximum != false {
if m.ExclusiveMaximum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMaximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMaximum))
}
@@ -7421,7 +7421,7 @@ func (m *HeaderParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Minimum))
}
if m.ExclusiveMinimum != false {
if m.ExclusiveMinimum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMinimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMinimum))
}
@@ -7445,7 +7445,7 @@ func (m *HeaderParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForInt(m.MinItems))
}
if m.UniqueItems != false {
if m.UniqueItems {
info.Content = append(info.Content, compiler.NewScalarNodeForString("uniqueItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.UniqueItems))
}
@@ -7940,7 +7940,7 @@ func (m *Operation) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("schemes"))
info.Content = append(info.Content, compiler.NewSequenceNodeForStringArray(m.Schemes))
}
if m.Deprecated != false {
if m.Deprecated {
info.Content = append(info.Content, compiler.NewScalarNodeForString("deprecated"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.Deprecated))
}
@@ -8110,7 +8110,7 @@ func (m *PathParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("maximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Maximum))
}
if m.ExclusiveMaximum != false {
if m.ExclusiveMaximum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMaximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMaximum))
}
@@ -8118,7 +8118,7 @@ func (m *PathParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Minimum))
}
if m.ExclusiveMinimum != false {
if m.ExclusiveMinimum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMinimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMinimum))
}
@@ -8142,7 +8142,7 @@ func (m *PathParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForInt(m.MinItems))
}
if m.UniqueItems != false {
if m.UniqueItems {
info.Content = append(info.Content, compiler.NewScalarNodeForString("uniqueItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.UniqueItems))
}
@@ -8218,7 +8218,7 @@ func (m *PrimitivesItems) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("maximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Maximum))
}
if m.ExclusiveMaximum != false {
if m.ExclusiveMaximum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMaximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMaximum))
}
@@ -8226,7 +8226,7 @@ func (m *PrimitivesItems) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Minimum))
}
if m.ExclusiveMinimum != false {
if m.ExclusiveMinimum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMinimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMinimum))
}
@@ -8250,7 +8250,7 @@ func (m *PrimitivesItems) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForInt(m.MinItems))
}
if m.UniqueItems != false {
if m.UniqueItems {
info.Content = append(info.Content, compiler.NewScalarNodeForString("uniqueItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.UniqueItems))
}
@@ -8296,7 +8296,7 @@ func (m *QueryParameterSubSchema) ToRawInfo() *yaml.Node {
if m == nil {
return info
}
if m.Required != false {
if m.Required {
info.Content = append(info.Content, compiler.NewScalarNodeForString("required"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.Required))
}
@@ -8312,7 +8312,7 @@ func (m *QueryParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("name"))
info.Content = append(info.Content, compiler.NewScalarNodeForString(m.Name))
}
if m.AllowEmptyValue != false {
if m.AllowEmptyValue {
info.Content = append(info.Content, compiler.NewScalarNodeForString("allowEmptyValue"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.AllowEmptyValue))
}
@@ -8340,7 +8340,7 @@ func (m *QueryParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("maximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Maximum))
}
if m.ExclusiveMaximum != false {
if m.ExclusiveMaximum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMaximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMaximum))
}
@@ -8348,7 +8348,7 @@ func (m *QueryParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Minimum))
}
if m.ExclusiveMinimum != false {
if m.ExclusiveMinimum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMinimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMinimum))
}
@@ -8372,7 +8372,7 @@ func (m *QueryParameterSubSchema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForInt(m.MinItems))
}
if m.UniqueItems != false {
if m.UniqueItems {
info.Content = append(info.Content, compiler.NewScalarNodeForString("uniqueItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.UniqueItems))
}
@@ -8514,7 +8514,7 @@ func (m *Schema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("maximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Maximum))
}
if m.ExclusiveMaximum != false {
if m.ExclusiveMaximum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMaximum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMaximum))
}
@@ -8522,7 +8522,7 @@ func (m *Schema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForFloat(m.Minimum))
}
if m.ExclusiveMinimum != false {
if m.ExclusiveMinimum {
info.Content = append(info.Content, compiler.NewScalarNodeForString("exclusiveMinimum"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ExclusiveMinimum))
}
@@ -8546,7 +8546,7 @@ func (m *Schema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("minItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForInt(m.MinItems))
}
if m.UniqueItems != false {
if m.UniqueItems {
info.Content = append(info.Content, compiler.NewScalarNodeForString("uniqueItems"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.UniqueItems))
}
@@ -8610,7 +8610,7 @@ func (m *Schema) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("discriminator"))
info.Content = append(info.Content, compiler.NewScalarNodeForString(m.Discriminator))
}
if m.ReadOnly != false {
if m.ReadOnly {
info.Content = append(info.Content, compiler.NewScalarNodeForString("readOnly"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.ReadOnly))
}
@@ -8796,11 +8796,11 @@ func (m *Xml) ToRawInfo() *yaml.Node {
info.Content = append(info.Content, compiler.NewScalarNodeForString("prefix"))
info.Content = append(info.Content, compiler.NewScalarNodeForString(m.Prefix))
}
if m.Attribute != false {
if m.Attribute {
info.Content = append(info.Content, compiler.NewScalarNodeForString("attribute"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.Attribute))
}
if m.Wrapped != false {
if m.Wrapped {
info.Content = append(info.Content, compiler.NewScalarNodeForString("wrapped"))
info.Content = append(info.Content, compiler.NewScalarNodeForBool(m.Wrapped))
}

View File

@@ -15,7 +15,7 @@
package openapi_v2
import (
"gopkg.in/yaml.v3"
yaml "go.yaml.in/yaml/v3"
"github.com/google/gnostic-models/compiler"
)

View File

@@ -21,7 +21,7 @@ import (
"regexp"
"strings"
"gopkg.in/yaml.v3"
yaml "go.yaml.in/yaml/v3"
"github.com/google/gnostic-models/compiler"
)
@@ -60,7 +60,7 @@ func NewAdditionalPropertiesItem(in *yaml.Node, context *compiler.Context) (*Add
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid AdditionalPropertiesItem")
message := "contains an invalid AdditionalPropertiesItem"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -113,7 +113,7 @@ func NewAnyOrExpression(in *yaml.Node, context *compiler.Context) (*AnyOrExpress
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid AnyOrExpression")
message := "contains an invalid AnyOrExpression"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -227,7 +227,7 @@ func NewCallbackOrReference(in *yaml.Node, context *compiler.Context) (*Callback
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid CallbackOrReference")
message := "contains an invalid CallbackOrReference"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -979,7 +979,7 @@ func NewExampleOrReference(in *yaml.Node, context *compiler.Context) (*ExampleOr
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid ExampleOrReference")
message := "contains an invalid ExampleOrReference"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -1320,7 +1320,7 @@ func NewHeaderOrReference(in *yaml.Node, context *compiler.Context) (*HeaderOrRe
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid HeaderOrReference")
message := "contains an invalid HeaderOrReference"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -1713,7 +1713,7 @@ func NewLinkOrReference(in *yaml.Node, context *compiler.Context) (*LinkOrRefere
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid LinkOrReference")
message := "contains an invalid LinkOrReference"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -3090,7 +3090,7 @@ func NewParameterOrReference(in *yaml.Node, context *compiler.Context) (*Paramet
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid ParameterOrReference")
message := "contains an invalid ParameterOrReference"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -3606,7 +3606,7 @@ func NewRequestBodyOrReference(in *yaml.Node, context *compiler.Context) (*Reque
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid RequestBodyOrReference")
message := "contains an invalid RequestBodyOrReference"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -3743,7 +3743,7 @@ func NewResponseOrReference(in *yaml.Node, context *compiler.Context) (*Response
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid ResponseOrReference")
message := "contains an invalid ResponseOrReference"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -4310,7 +4310,7 @@ func NewSchemaOrReference(in *yaml.Node, context *compiler.Context) (*SchemaOrRe
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid SchemaOrReference")
message := "contains an invalid SchemaOrReference"
err := compiler.NewError(context, message)
errors = []error{err}
}
@@ -4543,7 +4543,7 @@ func NewSecuritySchemeOrReference(in *yaml.Node, context *compiler.Context) (*Se
// since the oneof matched one of its possibilities, discard any matching errors
errors = make([]error, 0)
} else {
message := fmt.Sprintf("contains an invalid SecuritySchemeOrReference")
message := "contains an invalid SecuritySchemeOrReference"
err := compiler.NewError(context, message)
errors = []error{err}
}

View File

@@ -15,7 +15,7 @@
package openapi_v3
import (
"gopkg.in/yaml.v3"
yaml "go.yaml.in/yaml/v3"
"github.com/google/gnostic-models/compiler"
)

View File

@@ -1,10 +0,0 @@
language: go
go:
- 1.11.x
- 1.12.x
- 1.13.x
- master
script:
- go test -cover

View File

@@ -1,67 +0,0 @@
# How to contribute #
We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.
## Contributor License Agreement ##
Contributions to any Google project must be accompanied by a Contributor
License Agreement. This is not a copyright **assignment**, it simply gives
Google permission to use and redistribute your contributions as part of the
project.
* If you are an individual writing original source code and you're sure you
own the intellectual property, then you'll need to sign an [individual
CLA][].
* If you work for a company that wants to allow you to contribute your work,
then you'll need to sign a [corporate CLA][].
You generally only need to submit a CLA once, so if you've already submitted
one (even if it was for a different project), you probably don't need to do it
again.
[individual CLA]: https://developers.google.com/open-source/cla/individual
[corporate CLA]: https://developers.google.com/open-source/cla/corporate
## Submitting a patch ##
1. It's generally best to start by opening a new issue describing the bug or
feature you're intending to fix. Even if you think it's relatively minor,
it's helpful to know what people are working on. Mention in the initial
issue that you are planning to work on that bug or feature so that it can
be assigned to you.
1. Follow the normal process of [forking][] the project, and setup a new
branch to work in. It's important that each group of changes be done in
separate branches in order to ensure that a pull request only includes the
commits related to that bug or feature.
1. Go makes it very simple to ensure properly formatted code, so always run
`go fmt` on your code before committing it. You should also run
[golint][] over your code. As noted in the [golint readme][], it's not
strictly necessary that your code be completely "lint-free", but this will
help you find common style issues.
1. Any significant changes should almost always be accompanied by tests. The
project already has good test coverage, so look at some of the existing
tests if you're unsure how to go about it. [gocov][] and [gocov-html][]
are invaluable tools for seeing which parts of your code aren't being
exercised by your tests.
1. Do your best to have [well-formed commit messages][] for each change.
This provides consistency throughout the project, and ensures that commit
messages are able to be formatted properly by various git tools.
1. Finally, push the commits to your fork and submit a [pull request][].
[forking]: https://help.github.com/articles/fork-a-repo
[golint]: https://github.com/golang/lint
[golint readme]: https://github.com/golang/lint/blob/master/README
[gocov]: https://github.com/axw/gocov
[gocov-html]: https://github.com/matm/gocov-html
[well-formed commit messages]: http://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html
[squash]: http://git-scm.com/book/en/Git-Tools-Rewriting-History#Squashing-Commits
[pull request]: https://help.github.com/articles/creating-a-pull-request

View File

@@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,89 +0,0 @@
gofuzz
======
gofuzz is a library for populating go objects with random values.
[![GoDoc](https://godoc.org/github.com/google/gofuzz?status.svg)](https://godoc.org/github.com/google/gofuzz)
[![Travis](https://travis-ci.org/google/gofuzz.svg?branch=master)](https://travis-ci.org/google/gofuzz)
This is useful for testing:
* Do your project's objects really serialize/unserialize correctly in all cases?
* Is there an incorrectly formatted object that will cause your project to panic?
Import with ```import "github.com/google/gofuzz"```
You can use it on single variables:
```go
f := fuzz.New()
var myInt int
f.Fuzz(&myInt) // myInt gets a random value.
```
You can use it on maps:
```go
f := fuzz.New().NilChance(0).NumElements(1, 1)
var myMap map[ComplexKeyType]string
f.Fuzz(&myMap) // myMap will have exactly one element.
```
Customize the chance of getting a nil pointer:
```go
f := fuzz.New().NilChance(.5)
var fancyStruct struct {
A, B, C, D *string
}
f.Fuzz(&fancyStruct) // About half the pointers should be set.
```
You can even customize the randomization completely if needed:
```go
type MyEnum string
const (
A MyEnum = "A"
B MyEnum = "B"
)
type MyInfo struct {
Type MyEnum
AInfo *string
BInfo *string
}
f := fuzz.New().NilChance(0).Funcs(
func(e *MyInfo, c fuzz.Continue) {
switch c.Intn(2) {
case 0:
e.Type = A
c.Fuzz(&e.AInfo)
case 1:
e.Type = B
c.Fuzz(&e.BInfo)
}
},
)
var myObject MyInfo
f.Fuzz(&myObject) // Type will correspond to whether A or B info is set.
```
See more examples in ```example_test.go```.
You can use this library for easier [go-fuzz](https://github.com/dvyukov/go-fuzz)ing.
go-fuzz provides the user a byte-slice, which should be converted to different inputs
for the tested function. This library can help convert the byte slice. Consider for
example a fuzz test for a the function `mypackage.MyFunc` that takes an int arguments:
```go
// +build gofuzz
package mypackage
import fuzz "github.com/google/gofuzz"
func Fuzz(data []byte) int {
var i int
fuzz.NewFromGoFuzz(data).Fuzz(&i)
MyFunc(i)
return 0
}
```
Happy testing!

View File

@@ -1,81 +0,0 @@
/*
Copyright 2014 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package bytesource provides a rand.Source64 that is determined by a slice of bytes.
package bytesource
import (
"bytes"
"encoding/binary"
"io"
"math/rand"
)
// ByteSource implements rand.Source64 determined by a slice of bytes. The random numbers are
// generated from each 8 bytes in the slice, until the last bytes are consumed, from which a
// fallback pseudo random source is created in case more random numbers are required.
// It also exposes a `bytes.Reader` API, which lets callers consume the bytes directly.
type ByteSource struct {
*bytes.Reader
fallback rand.Source
}
// New returns a new ByteSource from a given slice of bytes.
func New(input []byte) *ByteSource {
s := &ByteSource{
Reader: bytes.NewReader(input),
fallback: rand.NewSource(0),
}
if len(input) > 0 {
s.fallback = rand.NewSource(int64(s.consumeUint64()))
}
return s
}
func (s *ByteSource) Uint64() uint64 {
// Return from input if it was not exhausted.
if s.Len() > 0 {
return s.consumeUint64()
}
// Input was exhausted, return random number from fallback (in this case fallback should not be
// nil). Try first having a Uint64 output (Should work in current rand implementation),
// otherwise return a conversion of Int63.
if s64, ok := s.fallback.(rand.Source64); ok {
return s64.Uint64()
}
return uint64(s.fallback.Int63())
}
func (s *ByteSource) Int63() int64 {
return int64(s.Uint64() >> 1)
}
func (s *ByteSource) Seed(seed int64) {
s.fallback = rand.NewSource(seed)
s.Reader = bytes.NewReader(nil)
}
// consumeUint64 reads 8 bytes from the input and convert them to a uint64. It assumes that the the
// bytes reader is not empty.
func (s *ByteSource) consumeUint64() uint64 {
var bytes [8]byte
_, err := s.Read(bytes[:])
if err != nil && err != io.EOF {
panic("failed reading source") // Should not happen.
}
return binary.BigEndian.Uint64(bytes[:])
}

View File

@@ -1,18 +0,0 @@
/*
Copyright 2014 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package fuzz is a library for populating go objects with random values.
package fuzz

View File

@@ -1,605 +0,0 @@
/*
Copyright 2014 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fuzz
import (
"fmt"
"math/rand"
"reflect"
"regexp"
"time"
"github.com/google/gofuzz/bytesource"
"strings"
)
// fuzzFuncMap is a map from a type to a fuzzFunc that handles that type.
type fuzzFuncMap map[reflect.Type]reflect.Value
// Fuzzer knows how to fill any object with random fields.
type Fuzzer struct {
fuzzFuncs fuzzFuncMap
defaultFuzzFuncs fuzzFuncMap
r *rand.Rand
nilChance float64
minElements int
maxElements int
maxDepth int
skipFieldPatterns []*regexp.Regexp
}
// New returns a new Fuzzer. Customize your Fuzzer further by calling Funcs,
// RandSource, NilChance, or NumElements in any order.
func New() *Fuzzer {
return NewWithSeed(time.Now().UnixNano())
}
func NewWithSeed(seed int64) *Fuzzer {
f := &Fuzzer{
defaultFuzzFuncs: fuzzFuncMap{
reflect.TypeOf(&time.Time{}): reflect.ValueOf(fuzzTime),
},
fuzzFuncs: fuzzFuncMap{},
r: rand.New(rand.NewSource(seed)),
nilChance: .2,
minElements: 1,
maxElements: 10,
maxDepth: 100,
}
return f
}
// NewFromGoFuzz is a helper function that enables using gofuzz (this
// project) with go-fuzz (https://github.com/dvyukov/go-fuzz) for continuous
// fuzzing. Essentially, it enables translating the fuzzing bytes from
// go-fuzz to any Go object using this library.
//
// This implementation promises a constant translation from a given slice of
// bytes to the fuzzed objects. This promise will remain over future
// versions of Go and of this library.
//
// Note: the returned Fuzzer should not be shared between multiple goroutines,
// as its deterministic output will no longer be available.
//
// Example: use go-fuzz to test the function `MyFunc(int)` in the package
// `mypackage`. Add the file: "mypacakge_fuzz.go" with the content:
//
// // +build gofuzz
// package mypacakge
// import fuzz "github.com/google/gofuzz"
// func Fuzz(data []byte) int {
// var i int
// fuzz.NewFromGoFuzz(data).Fuzz(&i)
// MyFunc(i)
// return 0
// }
func NewFromGoFuzz(data []byte) *Fuzzer {
return New().RandSource(bytesource.New(data))
}
// Funcs adds each entry in fuzzFuncs as a custom fuzzing function.
//
// Each entry in fuzzFuncs must be a function taking two parameters.
// The first parameter must be a pointer or map. It is the variable that
// function will fill with random data. The second parameter must be a
// fuzz.Continue, which will provide a source of randomness and a way
// to automatically continue fuzzing smaller pieces of the first parameter.
//
// These functions are called sensibly, e.g., if you wanted custom string
// fuzzing, the function `func(s *string, c fuzz.Continue)` would get
// called and passed the address of strings. Maps and pointers will always
// be made/new'd for you, ignoring the NilChange option. For slices, it
// doesn't make much sense to pre-create them--Fuzzer doesn't know how
// long you want your slice--so take a pointer to a slice, and make it
// yourself. (If you don't want your map/pointer type pre-made, take a
// pointer to it, and make it yourself.) See the examples for a range of
// custom functions.
func (f *Fuzzer) Funcs(fuzzFuncs ...interface{}) *Fuzzer {
for i := range fuzzFuncs {
v := reflect.ValueOf(fuzzFuncs[i])
if v.Kind() != reflect.Func {
panic("Need only funcs!")
}
t := v.Type()
if t.NumIn() != 2 || t.NumOut() != 0 {
panic("Need 2 in and 0 out params!")
}
argT := t.In(0)
switch argT.Kind() {
case reflect.Ptr, reflect.Map:
default:
panic("fuzzFunc must take pointer or map type")
}
if t.In(1) != reflect.TypeOf(Continue{}) {
panic("fuzzFunc's second parameter must be type fuzz.Continue")
}
f.fuzzFuncs[argT] = v
}
return f
}
// RandSource causes f to get values from the given source of randomness.
// Use if you want deterministic fuzzing.
func (f *Fuzzer) RandSource(s rand.Source) *Fuzzer {
f.r = rand.New(s)
return f
}
// NilChance sets the probability of creating a nil pointer, map, or slice to
// 'p'. 'p' should be between 0 (no nils) and 1 (all nils), inclusive.
func (f *Fuzzer) NilChance(p float64) *Fuzzer {
if p < 0 || p > 1 {
panic("p should be between 0 and 1, inclusive.")
}
f.nilChance = p
return f
}
// NumElements sets the minimum and maximum number of elements that will be
// added to a non-nil map or slice.
func (f *Fuzzer) NumElements(atLeast, atMost int) *Fuzzer {
if atLeast > atMost {
panic("atLeast must be <= atMost")
}
if atLeast < 0 {
panic("atLeast must be >= 0")
}
f.minElements = atLeast
f.maxElements = atMost
return f
}
func (f *Fuzzer) genElementCount() int {
if f.minElements == f.maxElements {
return f.minElements
}
return f.minElements + f.r.Intn(f.maxElements-f.minElements+1)
}
func (f *Fuzzer) genShouldFill() bool {
return f.r.Float64() >= f.nilChance
}
// MaxDepth sets the maximum number of recursive fuzz calls that will be made
// before stopping. This includes struct members, pointers, and map and slice
// elements.
func (f *Fuzzer) MaxDepth(d int) *Fuzzer {
f.maxDepth = d
return f
}
// Skip fields which match the supplied pattern. Call this multiple times if needed
// This is useful to skip XXX_ fields generated by protobuf
func (f *Fuzzer) SkipFieldsWithPattern(pattern *regexp.Regexp) *Fuzzer {
f.skipFieldPatterns = append(f.skipFieldPatterns, pattern)
return f
}
// Fuzz recursively fills all of obj's fields with something random. First
// this tries to find a custom fuzz function (see Funcs). If there is no
// custom function this tests whether the object implements fuzz.Interface and,
// if so, calls Fuzz on it to fuzz itself. If that fails, this will see if
// there is a default fuzz function provided by this package. If all of that
// fails, this will generate random values for all primitive fields and then
// recurse for all non-primitives.
//
// This is safe for cyclic or tree-like structs, up to a limit. Use the
// MaxDepth method to adjust how deep you need it to recurse.
//
// obj must be a pointer. Only exported (public) fields can be set (thanks,
// golang :/ ) Intended for tests, so will panic on bad input or unimplemented
// fields.
func (f *Fuzzer) Fuzz(obj interface{}) {
v := reflect.ValueOf(obj)
if v.Kind() != reflect.Ptr {
panic("needed ptr!")
}
v = v.Elem()
f.fuzzWithContext(v, 0)
}
// FuzzNoCustom is just like Fuzz, except that any custom fuzz function for
// obj's type will not be called and obj will not be tested for fuzz.Interface
// conformance. This applies only to obj and not other instances of obj's
// type.
// Not safe for cyclic or tree-like structs!
// obj must be a pointer. Only exported (public) fields can be set (thanks, golang :/ )
// Intended for tests, so will panic on bad input or unimplemented fields.
func (f *Fuzzer) FuzzNoCustom(obj interface{}) {
v := reflect.ValueOf(obj)
if v.Kind() != reflect.Ptr {
panic("needed ptr!")
}
v = v.Elem()
f.fuzzWithContext(v, flagNoCustomFuzz)
}
const (
// Do not try to find a custom fuzz function. Does not apply recursively.
flagNoCustomFuzz uint64 = 1 << iota
)
func (f *Fuzzer) fuzzWithContext(v reflect.Value, flags uint64) {
fc := &fuzzerContext{fuzzer: f}
fc.doFuzz(v, flags)
}
// fuzzerContext carries context about a single fuzzing run, which lets Fuzzer
// be thread-safe.
type fuzzerContext struct {
fuzzer *Fuzzer
curDepth int
}
func (fc *fuzzerContext) doFuzz(v reflect.Value, flags uint64) {
if fc.curDepth >= fc.fuzzer.maxDepth {
return
}
fc.curDepth++
defer func() { fc.curDepth-- }()
if !v.CanSet() {
return
}
if flags&flagNoCustomFuzz == 0 {
// Check for both pointer and non-pointer custom functions.
if v.CanAddr() && fc.tryCustom(v.Addr()) {
return
}
if fc.tryCustom(v) {
return
}
}
if fn, ok := fillFuncMap[v.Kind()]; ok {
fn(v, fc.fuzzer.r)
return
}
switch v.Kind() {
case reflect.Map:
if fc.fuzzer.genShouldFill() {
v.Set(reflect.MakeMap(v.Type()))
n := fc.fuzzer.genElementCount()
for i := 0; i < n; i++ {
key := reflect.New(v.Type().Key()).Elem()
fc.doFuzz(key, 0)
val := reflect.New(v.Type().Elem()).Elem()
fc.doFuzz(val, 0)
v.SetMapIndex(key, val)
}
return
}
v.Set(reflect.Zero(v.Type()))
case reflect.Ptr:
if fc.fuzzer.genShouldFill() {
v.Set(reflect.New(v.Type().Elem()))
fc.doFuzz(v.Elem(), 0)
return
}
v.Set(reflect.Zero(v.Type()))
case reflect.Slice:
if fc.fuzzer.genShouldFill() {
n := fc.fuzzer.genElementCount()
v.Set(reflect.MakeSlice(v.Type(), n, n))
for i := 0; i < n; i++ {
fc.doFuzz(v.Index(i), 0)
}
return
}
v.Set(reflect.Zero(v.Type()))
case reflect.Array:
if fc.fuzzer.genShouldFill() {
n := v.Len()
for i := 0; i < n; i++ {
fc.doFuzz(v.Index(i), 0)
}
return
}
v.Set(reflect.Zero(v.Type()))
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
skipField := false
fieldName := v.Type().Field(i).Name
for _, pattern := range fc.fuzzer.skipFieldPatterns {
if pattern.MatchString(fieldName) {
skipField = true
break
}
}
if !skipField {
fc.doFuzz(v.Field(i), 0)
}
}
case reflect.Chan:
fallthrough
case reflect.Func:
fallthrough
case reflect.Interface:
fallthrough
default:
panic(fmt.Sprintf("Can't handle %#v", v.Interface()))
}
}
// tryCustom searches for custom handlers, and returns true iff it finds a match
// and successfully randomizes v.
func (fc *fuzzerContext) tryCustom(v reflect.Value) bool {
// First: see if we have a fuzz function for it.
doCustom, ok := fc.fuzzer.fuzzFuncs[v.Type()]
if !ok {
// Second: see if it can fuzz itself.
if v.CanInterface() {
intf := v.Interface()
if fuzzable, ok := intf.(Interface); ok {
fuzzable.Fuzz(Continue{fc: fc, Rand: fc.fuzzer.r})
return true
}
}
// Finally: see if there is a default fuzz function.
doCustom, ok = fc.fuzzer.defaultFuzzFuncs[v.Type()]
if !ok {
return false
}
}
switch v.Kind() {
case reflect.Ptr:
if v.IsNil() {
if !v.CanSet() {
return false
}
v.Set(reflect.New(v.Type().Elem()))
}
case reflect.Map:
if v.IsNil() {
if !v.CanSet() {
return false
}
v.Set(reflect.MakeMap(v.Type()))
}
default:
return false
}
doCustom.Call([]reflect.Value{v, reflect.ValueOf(Continue{
fc: fc,
Rand: fc.fuzzer.r,
})})
return true
}
// Interface represents an object that knows how to fuzz itself. Any time we
// find a type that implements this interface we will delegate the act of
// fuzzing itself.
type Interface interface {
Fuzz(c Continue)
}
// Continue can be passed to custom fuzzing functions to allow them to use
// the correct source of randomness and to continue fuzzing their members.
type Continue struct {
fc *fuzzerContext
// For convenience, Continue implements rand.Rand via embedding.
// Use this for generating any randomness if you want your fuzzing
// to be repeatable for a given seed.
*rand.Rand
}
// Fuzz continues fuzzing obj. obj must be a pointer.
func (c Continue) Fuzz(obj interface{}) {
v := reflect.ValueOf(obj)
if v.Kind() != reflect.Ptr {
panic("needed ptr!")
}
v = v.Elem()
c.fc.doFuzz(v, 0)
}
// FuzzNoCustom continues fuzzing obj, except that any custom fuzz function for
// obj's type will not be called and obj will not be tested for fuzz.Interface
// conformance. This applies only to obj and not other instances of obj's
// type.
func (c Continue) FuzzNoCustom(obj interface{}) {
v := reflect.ValueOf(obj)
if v.Kind() != reflect.Ptr {
panic("needed ptr!")
}
v = v.Elem()
c.fc.doFuzz(v, flagNoCustomFuzz)
}
// RandString makes a random string up to 20 characters long. The returned string
// may include a variety of (valid) UTF-8 encodings.
func (c Continue) RandString() string {
return randString(c.Rand)
}
// RandUint64 makes random 64 bit numbers.
// Weirdly, rand doesn't have a function that gives you 64 random bits.
func (c Continue) RandUint64() uint64 {
return randUint64(c.Rand)
}
// RandBool returns true or false randomly.
func (c Continue) RandBool() bool {
return randBool(c.Rand)
}
func fuzzInt(v reflect.Value, r *rand.Rand) {
v.SetInt(int64(randUint64(r)))
}
func fuzzUint(v reflect.Value, r *rand.Rand) {
v.SetUint(randUint64(r))
}
func fuzzTime(t *time.Time, c Continue) {
var sec, nsec int64
// Allow for about 1000 years of random time values, which keeps things
// like JSON parsing reasonably happy.
sec = c.Rand.Int63n(1000 * 365 * 24 * 60 * 60)
c.Fuzz(&nsec)
*t = time.Unix(sec, nsec)
}
var fillFuncMap = map[reflect.Kind]func(reflect.Value, *rand.Rand){
reflect.Bool: func(v reflect.Value, r *rand.Rand) {
v.SetBool(randBool(r))
},
reflect.Int: fuzzInt,
reflect.Int8: fuzzInt,
reflect.Int16: fuzzInt,
reflect.Int32: fuzzInt,
reflect.Int64: fuzzInt,
reflect.Uint: fuzzUint,
reflect.Uint8: fuzzUint,
reflect.Uint16: fuzzUint,
reflect.Uint32: fuzzUint,
reflect.Uint64: fuzzUint,
reflect.Uintptr: fuzzUint,
reflect.Float32: func(v reflect.Value, r *rand.Rand) {
v.SetFloat(float64(r.Float32()))
},
reflect.Float64: func(v reflect.Value, r *rand.Rand) {
v.SetFloat(r.Float64())
},
reflect.Complex64: func(v reflect.Value, r *rand.Rand) {
v.SetComplex(complex128(complex(r.Float32(), r.Float32())))
},
reflect.Complex128: func(v reflect.Value, r *rand.Rand) {
v.SetComplex(complex(r.Float64(), r.Float64()))
},
reflect.String: func(v reflect.Value, r *rand.Rand) {
v.SetString(randString(r))
},
reflect.UnsafePointer: func(v reflect.Value, r *rand.Rand) {
panic("unimplemented")
},
}
// randBool returns true or false randomly.
func randBool(r *rand.Rand) bool {
return r.Int31()&(1<<30) == 0
}
type int63nPicker interface {
Int63n(int64) int64
}
// UnicodeRange describes a sequential range of unicode characters.
// Last must be numerically greater than First.
type UnicodeRange struct {
First, Last rune
}
// UnicodeRanges describes an arbitrary number of sequential ranges of unicode characters.
// To be useful, each range must have at least one character (First <= Last) and
// there must be at least one range.
type UnicodeRanges []UnicodeRange
// choose returns a random unicode character from the given range, using the
// given randomness source.
func (ur UnicodeRange) choose(r int63nPicker) rune {
count := int64(ur.Last - ur.First + 1)
return ur.First + rune(r.Int63n(count))
}
// CustomStringFuzzFunc constructs a FuzzFunc which produces random strings.
// Each character is selected from the range ur. If there are no characters
// in the range (cr.Last < cr.First), this will panic.
func (ur UnicodeRange) CustomStringFuzzFunc() func(s *string, c Continue) {
ur.check()
return func(s *string, c Continue) {
*s = ur.randString(c.Rand)
}
}
// check is a function that used to check whether the first of ur(UnicodeRange)
// is greater than the last one.
func (ur UnicodeRange) check() {
if ur.Last < ur.First {
panic("The last encoding must be greater than the first one.")
}
}
// randString of UnicodeRange makes a random string up to 20 characters long.
// Each character is selected form ur(UnicodeRange).
func (ur UnicodeRange) randString(r *rand.Rand) string {
n := r.Intn(20)
sb := strings.Builder{}
sb.Grow(n)
for i := 0; i < n; i++ {
sb.WriteRune(ur.choose(r))
}
return sb.String()
}
// defaultUnicodeRanges sets a default unicode range when user do not set
// CustomStringFuzzFunc() but wants fuzz string.
var defaultUnicodeRanges = UnicodeRanges{
{' ', '~'}, // ASCII characters
{'\u00a0', '\u02af'}, // Multi-byte encoded characters
{'\u4e00', '\u9fff'}, // Common CJK (even longer encodings)
}
// CustomStringFuzzFunc constructs a FuzzFunc which produces random strings.
// Each character is selected from one of the ranges of ur(UnicodeRanges).
// Each range has an equal probability of being chosen. If there are no ranges,
// or a selected range has no characters (.Last < .First), this will panic.
// Do not modify any of the ranges in ur after calling this function.
func (ur UnicodeRanges) CustomStringFuzzFunc() func(s *string, c Continue) {
// Check unicode ranges slice is empty.
if len(ur) == 0 {
panic("UnicodeRanges is empty.")
}
// if not empty, each range should be checked.
for i := range ur {
ur[i].check()
}
return func(s *string, c Continue) {
*s = ur.randString(c.Rand)
}
}
// randString of UnicodeRanges makes a random string up to 20 characters long.
// Each character is selected form one of the ranges of ur(UnicodeRanges),
// and each range has an equal probability of being chosen.
func (ur UnicodeRanges) randString(r *rand.Rand) string {
n := r.Intn(20)
sb := strings.Builder{}
sb.Grow(n)
for i := 0; i < n; i++ {
sb.WriteRune(ur[r.Intn(len(ur))].choose(r))
}
return sb.String()
}
// randString makes a random string up to 20 characters long. The returned string
// may include a variety of (valid) UTF-8 encodings.
func randString(r *rand.Rand) string {
return defaultUnicodeRanges.randString(r)
}
// randUint64 makes random 64 bit numbers.
// Weirdly, rand doesn't have a function that gives you 64 random bits.
func randUint64(r *rand.Rand) uint64 {
return uint64(r.Uint32())<<32 | uint64(r.Uint32())
}