diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/expr.cc | 21 | ||||
-rw-r--r-- | src/expression.h | 2 | ||||
-rw-r--r-- | src/modcontext.cc | 26 | ||||
-rw-r--r-- | src/modcontext.h | 2 | ||||
-rw-r--r-- | src/module.cc | 20 | ||||
-rw-r--r-- | src/module.h | 3 |
6 files changed, 44 insertions, 30 deletions
diff --git a/src/expr.cc b/src/expr.cc index 746c0e3..9eaeeef 100644 --- a/src/expr.cc +++ b/src/expr.cc @@ -31,6 +31,7 @@ #include <sstream> #include <algorithm> #include "stl-utils.h" +#include "printutils.h" #include <boost/bind.hpp> #include <boost/foreach.hpp> @@ -40,7 +41,7 @@ Expression::Expression() Expression::Expression(const std::string &type, Expression *left, Expression *right) - : type(type) + : type(type), recursioncount(0) { this->children.push_back(left); this->children.push_back(right); @@ -61,6 +62,18 @@ Expression::~Expression() std::for_each(this->children.begin(), this->children.end(), del_fun<Expression>()); } +class FuncRecursionGuard +{ +public: + FuncRecursionGuard(const Expression &e) : expr(e) { + expr.recursioncount++; + } + ~FuncRecursionGuard() { expr.recursioncount--; } + bool recursion_detected() const { return (expr.recursioncount > 100); } +private: + const Expression &expr; +}; + Value Expression::evaluate(const Context *context) const { if (this->type == "!") @@ -141,6 +154,12 @@ Value Expression::evaluate(const Context *context) const return Value(); } if (this->type == "F") { + FuncRecursionGuard g(*this); + if (g.recursion_detected()) { + PRINTB("ERROR: Recursion detected calling function '%s'", this->call_funcname); + return Value(); + } + EvalContext c(context, this->call_arguments); return context->evaluate_function(this->call_funcname, &c); } diff --git a/src/expression.h b/src/expression.h index 6c03f52..3629704 100644 --- a/src/expression.h +++ b/src/expression.h @@ -40,6 +40,8 @@ public: Value evaluate(const class Context *context) const; std::string toString() const; + + mutable int recursioncount; }; std::ostream &operator<<(std::ostream &stream, const Expression &expr); diff --git a/src/modcontext.cc b/src/modcontext.cc index 44c2002..3879811 100644 --- a/src/modcontext.cc +++ b/src/modcontext.cc @@ -27,20 +27,6 @@ void ModuleContext::initializeModule(const class Module &module) } } -class RecursionGuard -{ -public: - RecursionGuard(const ModuleContext &c, const std::string &name) : c(c), name(name) { - c.recursioncount[name]++; - } - ~RecursionGuard() { if (--c.recursioncount[name] == 0) c.recursioncount.erase(name); } - bool recursion_detected() const { return (c.recursioncount[name] > 100); } -private: - const ModuleContext &c; - const std::string &name; -}; - - /*! Only used to initialize builtins for the top-level root context */ @@ -81,12 +67,6 @@ const AbstractModule *ModuleContext::findLocalModule(const std::string &name) co Value ModuleContext::evaluate_function(const std::string &name, const EvalContext *evalctx) const { - RecursionGuard g(*this, name); - if (g.recursion_detected()) { - PRINTB("Recursion detected calling function '%s'", name); - return Value(); - } - const AbstractFunction *foundf = findLocalFunction(name); if (foundf) return foundf->evaluate(this, evalctx); @@ -141,12 +121,6 @@ FileContext::FileContext(const class FileModule &module, const Context *parent) Value FileContext::evaluate_function(const std::string &name, const EvalContext *evalctx) const { - RecursionGuard g(*this, name); - if (g.recursion_detected()) { - PRINTB("Recursion detected calling function '%s'", name); - return Value(); - } - const AbstractFunction *foundf = findLocalFunction(name); if (foundf) return foundf->evaluate(this, evalctx); diff --git a/src/modcontext.h b/src/modcontext.h index 4479051..0b3e48c 100644 --- a/src/modcontext.h +++ b/src/modcontext.h @@ -36,8 +36,6 @@ public: #ifdef DEBUG virtual void dump(const class AbstractModule *mod, const ModuleInstantiation *inst); #endif - - mutable boost::unordered_map<std::string, int> recursioncount; }; class FileContext : public ModuleContext diff --git a/src/module.cc b/src/module.cc index 8b84c07..e853457 100644 --- a/src/module.cc +++ b/src/module.cc @@ -135,8 +135,28 @@ Module::~Module() { } +class ModRecursionGuard +{ +public: + ModRecursionGuard(const ModuleInstantiation &inst) : inst(inst) { + inst.recursioncount++; + } + ~ModRecursionGuard() { + inst.recursioncount--; + } + bool recursion_detected() const { return (inst.recursioncount > 100); } +private: + const ModuleInstantiation &inst; +}; + AbstractNode *Module::instantiate(const Context *ctx, const ModuleInstantiation *inst, const EvalContext *evalctx) const { + ModRecursionGuard g(*inst); + if (g.recursion_detected()) { + PRINTB("ERROR: Recursion detected calling module '%s'", inst->name()); + return NULL; + } + ModuleContext c(ctx, evalctx); c.initializeModule(*this); c.set_variable("$children", Value(double(inst->scope.children.size()))); diff --git a/src/module.h b/src/module.h index 8f1ccb7..9f46d37 100644 --- a/src/module.h +++ b/src/module.h @@ -13,7 +13,7 @@ class ModuleInstantiation { public: ModuleInstantiation(const std::string &name = "") - : tag_root(false), tag_highlight(false), tag_background(false), modname(name) { } + : tag_root(false), tag_highlight(false), tag_background(false), recursioncount(0), modname(name) { } virtual ~ModuleInstantiation(); std::string dump(const std::string &indent) const; @@ -35,6 +35,7 @@ public: bool tag_root; bool tag_highlight; bool tag_background; + mutable int recursioncount; protected: std::string modname; std::string modpath; |