#ifndef HALIDE_IR_VISITOR_H
#define HALIDE_IR_VISITOR_H

#include <set>
#include <utility>

#include "IR.h"

/** \file
 * Defines the base classes for things that recursively walk over the IR.
 */

namespace Halide {
namespace Internal {

/** Base class for algorithms that need to recursively walk over the
 * IR. Every default visit implementation simply walks into the node's
 * children, so a subclass only overrides the node types it cares about.
 */
class IRVisitor {
public:
    IRVisitor() = default;
    virtual ~IRVisitor() = default;

protected:
    // accept() lives in ExprNode<> and StmtNode<>; they need access to the
    // protected visit overloads below in order to call them.
    template<typename T>
    friend struct ExprNode;

    template<typename T>
    friend struct StmtNode;

    /** One overload per IR node type. None of these are pure and none have
     * a body here, so the default (recursive) implementations are defined
     * out of line. The declaration order feeds the vtable layout on common
     * ABIs, so new node types should be appended rather than reordered. */
    virtual void visit(const IntImm *op);
    virtual void visit(const UIntImm *op);
    virtual void visit(const FloatImm *op);
    virtual void visit(const StringImm *op);
    virtual void visit(const Cast *op);
    virtual void visit(const Variable *op);
    virtual void visit(const Add *op);
    virtual void visit(const Sub *op);
    virtual void visit(const Mul *op);
    virtual void visit(const Div *op);
    virtual void visit(const Mod *op);
    virtual void visit(const Min *op);
    virtual void visit(const Max *op);
    virtual void visit(const EQ *op);
    virtual void visit(const NE *op);
    virtual void visit(const LT *op);
    virtual void visit(const LE *op);
    virtual void visit(const GT *op);
    virtual void visit(const GE *op);
    virtual void visit(const And *op);
    virtual void visit(const Or *op);
    virtual void visit(const Not *op);
    virtual void visit(const Select *op);
    virtual void visit(const Load *op);
    virtual void visit(const Ramp *op);
    virtual void visit(const Broadcast *op);
    virtual void visit(const Call *op);
    virtual void visit(const Let *op);
    virtual void visit(const LetStmt *op);
    virtual void visit(const AssertStmt *op);
    virtual void visit(const ProducerConsumer *op);
    virtual void visit(const For *op);
    virtual void visit(const Store *op);
    virtual void visit(const Provide *op);
    virtual void visit(const Allocate *op);
    virtual void visit(const Free *op);
    virtual void visit(const Realize *op);
    virtual void visit(const Block *op);
    virtual void visit(const IfThenElse *op);
    virtual void visit(const Evaluate *op);
    virtual void visit(const Shuffle *op);
    virtual void visit(const VectorReduce *op);
    virtual void visit(const Prefetch *op);
    virtual void visit(const Fork *op);
    virtual void visit(const Acquire *op);
    virtual void visit(const Atomic *op);
};

/** A base class for algorithms that walk recursively over the IR
 * without visiting the same node twice. This is for passes that are
 * capable of interpreting the IR as a DAG instead of a tree. */
class IRGraphVisitor : public IRVisitor {
protected:
    /** By default these methods add the node to the visited set, and
     * return whether or not it was already there. If it wasn't there,
     * it delegates to the appropriate visit method. You can override
     * them if you like. */
    // @{
    virtual void include(const Expr &);
    virtual void include(const Stmt &);
    // @}

private:
    /** The nodes visited so far */
    std::set<IRHandle> visited;

protected:
    /** These methods should call 'include' on the children to only
     * visit them if they haven't been visited already. */
    // @{
    void visit(const IntImm *) override;
    void visit(const UIntImm *) override;
    void visit(const FloatImm *) override;
    void visit(const StringImm *) override;
    void visit(const Cast *) override;
    void visit(const Variable *) override;
    void visit(const Add *) override;
    void visit(const Sub *) override;
    void visit(const Mul *) override;
    void visit(const Div *) override;
    void visit(const Mod *) override;
    void visit(const Min *) override;
    void visit(const Max *) override;
    void visit(const EQ *) override;
    void visit(const NE *) override;
    void visit(const LT *) override;
    void visit(const LE *) override;
    void visit(const GT *) override;
    void visit(const GE *) override;
    void visit(const And *) override;
    void visit(const Or *) override;
    void visit(const Not *) override;
    void visit(const Select *) override;
    void visit(const Load *) override;
    void visit(const Ramp *) override;
    void visit(const Broadcast *) override;
    void visit(const Call *) override;
    void visit(const Let *) override;
    void visit(const LetStmt *) override;
    void visit(const AssertStmt *) override;
    void visit(const ProducerConsumer *) override;
    void visit(const For *) override;
    void visit(const Store *) override;
    void visit(const Provide *) override;
    void visit(const Allocate *) override;
    void visit(const Free *) override;
    void visit(const Realize *) override;
    void visit(const Block *) override;
    void visit(const IfThenElse *) override;
    void visit(const Evaluate *) override;
    void visit(const Shuffle *) override;
    void visit(const VectorReduce *) override;
    void visit(const Prefetch *) override;
    void visit(const Acquire *) override;
    void visit(const Fork *) override;
    void visit(const Atomic *) override;
    // @}
};

/** A visitor/mutator capable of passing arbitrary arguments to the
 * visit methods using CRTP and returning any types from them. All
 * Expr visitors must have the same signature, and all Stmt visitors
 * must have the same signature. Does not have default implementations
 * of the visit methods. */
template<typename T, typename ExprRet, typename StmtRet>
class VariadicVisitor {
private:
    template<typename... Args>
    ExprRet dispatch_expr(const BaseExprNode *node, Args &&...args) {
        if (node == nullptr) {
            return ExprRet{};
        }
        switch (node->node_type) {
        case IRNodeType::IntImm:
            return ((T *)this)->visit((const IntImm *)node, std::forward<Args>(args)...);
        case IRNodeType::UIntImm:
            return ((T *)this)->visit((const UIntImm *)node, std::forward<Args>(args)...);
        case IRNodeType::FloatImm:
            return ((T *)this)->visit((const FloatImm *)node, std::forward<Args>(args)...);
        case IRNodeType::StringImm:
            return ((T *)this)->visit((const StringImm *)node, std::forward<Args>(args)...);
        case IRNodeType::Broadcast:
            return ((T *)this)->visit((const Broadcast *)node, std::forward<Args>(args)...);
        case IRNodeType::Cast:
            return ((T *)this)->visit((const Cast *)node, std::forward<Args>(args)...);
        case IRNodeType::Variable:
            return ((T *)this)->visit((const Variable *)node, std::forward<Args>(args)...);
        case IRNodeType::Add:
            return ((T *)this)->visit((const Add *)node, std::forward<Args>(args)...);
        case IRNodeType::Sub:
            return ((T *)this)->visit((const Sub *)node, std::forward<Args>(args)...);
        case IRNodeType::Mod:
            return ((T *)this)->visit((const Mod *)node, std::forward<Args>(args)...);
        case IRNodeType::Mul:
            return ((T *)this)->visit((const Mul *)node, std::forward<Args>(args)...);
        case IRNodeType::Div:
            return ((T *)this)->visit((const Div *)node, std::forward<Args>(args)...);
        case IRNodeType::Min:
            return ((T *)this)->visit((const Min *)node, std::forward<Args>(args)...);
        case IRNodeType::Max:
            return ((T *)this)->visit((const Max *)node, std::forward<Args>(args)...);
        case IRNodeType::EQ:
            return ((T *)this)->visit((const EQ *)node, std::forward<Args>(args)...);
        case IRNodeType::NE:
            return ((T *)this)->visit((const NE *)node, std::forward<Args>(args)...);
        case IRNodeType::LT:
            return ((T *)this)->visit((const LT *)node, std::forward<Args>(args)...);
        case IRNodeType::LE:
            return ((T *)this)->visit((const LE *)node, std::forward<Args>(args)...);
        case IRNodeType::GT:
            return ((T *)this)->visit((const GT *)node, std::forward<Args>(args)...);
        case IRNodeType::GE:
            return ((T *)this)->visit((const GE *)node, std::forward<Args>(args)...);
        case IRNodeType::And:
            return ((T *)this)->visit((const And *)node, std::forward<Args>(args)...);
        case IRNodeType::Or:
            return ((T *)this)->visit((const Or *)node, std::forward<Args>(args)...);
        case IRNodeType::Not:
            return ((T *)this)->visit((const Not *)node, std::forward<Args>(args)...);
        case IRNodeType::Select:
            return ((T *)this)->visit((const Select *)node, std::forward<Args>(args)...);
        case IRNodeType::Load:
            return ((T *)this)->visit((const Load *)node, std::forward<Args>(args)...);
        case IRNodeType::Ramp:
            return ((T *)this)->visit((const Ramp *)node, std::forward<Args>(args)...);
        case IRNodeType::Call:
            return ((T *)this)->visit((const Call *)node, std::forward<Args>(args)...);
        case IRNodeType::Let:
            return ((T *)this)->visit((const Let *)node, std::forward<Args>(args)...);
        case IRNodeType::Shuffle:
            return ((T *)this)->visit((const Shuffle *)node, std::forward<Args>(args)...);
        case IRNodeType::VectorReduce:
            return ((T *)this)->visit((const VectorReduce *)node, std::forward<Args>(args)...);
            // Explicitly list the Stmt types rather than using a
            // default case so that when new IR nodes are added we
            // don't miss them here.
        case IRNodeType::LetStmt:
        case IRNodeType::AssertStmt:
        case IRNodeType::ProducerConsumer:
        case IRNodeType::For:
        case IRNodeType::Acquire:
        case IRNodeType::Store:
        case IRNodeType::Provide:
        case IRNodeType::Allocate:
        case IRNodeType::Free:
        case IRNodeType::Realize:
        case IRNodeType::Block:
        case IRNodeType::Fork:
        case IRNodeType::IfThenElse:
        case IRNodeType::Evaluate:
        case IRNodeType::Prefetch:
        case IRNodeType::Atomic:
            internal_error << "Unreachable";
        }
        return ExprRet{};
    }

    template<typename... Args>
    StmtRet dispatch_stmt(const BaseStmtNode *node, Args &&...args) {
        if (node == nullptr) {
            return StmtRet{};
        }
        switch (node->node_type) {
        case IRNodeType::IntImm:
        case IRNodeType::UIntImm:
        case IRNodeType::FloatImm:
        case IRNodeType::StringImm:
        case IRNodeType::Broadcast:
        case IRNodeType::Cast:
        case IRNodeType::Variable:
        case IRNodeType::Add:
        case IRNodeType::Sub:
        case IRNodeType::Mod:
        case IRNodeType::Mul:
        case IRNodeType::Div:
        case IRNodeType::Min:
        case IRNodeType::Max:
        case IRNodeType::EQ:
        case IRNodeType::NE:
        case IRNodeType::LT:
        case IRNodeType::LE:
        case IRNodeType::GT:
        case IRNodeType::GE:
        case IRNodeType::And:
        case IRNodeType::Or:
        case IRNodeType::Not:
        case IRNodeType::Select:
        case IRNodeType::Load:
        case IRNodeType::Ramp:
        case IRNodeType::Call:
        case IRNodeType::Let:
        case IRNodeType::Shuffle:
        case IRNodeType::VectorReduce:
            internal_error << "Unreachable";
            break;
        case IRNodeType::LetStmt:
            return ((T *)this)->visit((const LetStmt *)node, std::forward<Args>(args)...);
        case IRNodeType::AssertStmt:
            return ((T *)this)->visit((const AssertStmt *)node, std::forward<Args>(args)...);
        case IRNodeType::ProducerConsumer:
            return ((T *)this)->visit((const ProducerConsumer *)node, std::forward<Args>(args)...);
        case IRNodeType::For:
            return ((T *)this)->visit((const For *)node, std::forward<Args>(args)...);
        case IRNodeType::Acquire:
            return ((T *)this)->visit((const Acquire *)node, std::forward<Args>(args)...);
        case IRNodeType::Store:
            return ((T *)this)->visit((const Store *)node, std::forward<Args>(args)...);
        case IRNodeType::Provide:
            return ((T *)this)->visit((const Provide *)node, std::forward<Args>(args)...);
        case IRNodeType::Allocate:
            return ((T *)this)->visit((const Allocate *)node, std::forward<Args>(args)...);
        case IRNodeType::Free:
            return ((T *)this)->visit((const Free *)node, std::forward<Args>(args)...);
        case IRNodeType::Realize:
            return ((T *)this)->visit((const Realize *)node, std::forward<Args>(args)...);
        case IRNodeType::Block:
            return ((T *)this)->visit((const Block *)node, std::forward<Args>(args)...);
        case IRNodeType::Fork:
            return ((T *)this)->visit((const Fork *)node, std::forward<Args>(args)...);
        case IRNodeType::IfThenElse:
            return ((T *)this)->visit((const IfThenElse *)node, std::forward<Args>(args)...);
        case IRNodeType::Evaluate:
            return ((T *)this)->visit((const Evaluate *)node, std::forward<Args>(args)...);
        case IRNodeType::Prefetch:
            return ((T *)this)->visit((const Prefetch *)node, std::forward<Args>(args)...);
        case IRNodeType::Atomic:
            return ((T *)this)->visit((const Atomic *)node, std::forward<Args>(args)...);
        }
        return StmtRet{};
    }

public:
    template<typename... Args>
    HALIDE_ALWAYS_INLINE StmtRet dispatch(const Stmt &s, Args &&...args) {
        return dispatch_stmt(s.get(), std::forward<Args>(args)...);
    }

    template<typename... Args>
    HALIDE_ALWAYS_INLINE StmtRet dispatch(Stmt &&s, Args &&...args) {
        return dispatch_stmt(s.get(), std::forward<Args>(args)...);
    }

    template<typename... Args>
    HALIDE_ALWAYS_INLINE ExprRet dispatch(const Expr &e, Args &&...args) {
        return dispatch_expr(e.get(), std::forward<Args>(args)...);
    }

    template<typename... Args>
    HALIDE_ALWAYS_INLINE ExprRet dispatch(Expr &&e, Args &&...args) {
        return dispatch_expr(e.get(), std::forward<Args>(args)...);
    }
};

}  // namespace Internal
}  // namespace Halide

#endif
