sample of calc version 2

This commit is contained in:
MITSUNARI Shigeo 2012-02-11 13:47:12 +09:00
parent f73a88b01d
commit ec8b25b8bf

300
sample/calc2.cpp Normal file
View file

@ -0,0 +1,300 @@
/*
@author herumi
@date $Date: 2010/11/17 11:58:06 $
tiny calculator 2
This program generates a function to calc the value of
polynomial given by user in run-time.
use boost::spirit::qi
*/
#ifdef _WIN32
#pragma warning(disable : 4127) // for boost(constant condition)
#pragma warning(disable : 4512) // for boost
#pragma warning(disable : 4819)
#endif
#include <boost/config/warning_disable.hpp>
#include <boost/spirit/include/qi.hpp>
#include <boost/spirit/include/phoenix_core.hpp>
#include <boost/spirit/include/phoenix_container.hpp>
#include <boost/spirit/include/phoenix_bind.hpp>
#include <boost/timer.hpp>
#include <stdio.h>
#include <assert.h>
#include <string>
#include <vector>
#include "xbyak/xbyak.h"
enum Operand {
OpAdd,
OpSub,
OpMul,
OpDiv,
OpNeg,
OpImm,
OpVarX
};
struct Code {
Operand op_;
double val_;
Code(Operand op)
: op_(op)
, val_(0)
{
}
Code(double val)
: op_(OpImm)
, val_(val)
{
}
};
typedef std::vector<Code> CodeSet;
struct Vm {
CodeSet code_;
double operator()(double x) const
{
const size_t maxStack = 16;
double stack[maxStack];
double *p = stack;
CodeSet::const_iterator pc = code_.begin();
while (pc != code_.end()) {
switch (pc->op_) {
case OpVarX:
*p++ = x;
break;
case OpImm:
*p++ = pc->val_;
break;
case OpNeg:
p[-1] = -p[-1];
break;
case OpAdd:
--p;
p[-1] += p[0];
break;
case OpSub:
--p;
p[-1] -= p[0];
break;
case OpMul:
--p;
p[-1] *= p[0];
break;
case OpDiv:
--p;
p[-1] /= p[0];
break;
}
++pc;
assert(p < stack + maxStack);
}
return p[-1];
}
};
class Jit : public Xbyak::CodeGenerator {
private:
enum {
MAX_CONST_NUM = 32
};
MIE_ALIGN(16) double constTbl_[MAX_CONST_NUM];
Xbyak::uint64 negConst_;
size_t constTblPos_;
int regIdx_;
#ifdef XBYAK32
const Xbyak::Reg32e& varTbl_;
#endif
const Xbyak::Reg32e& tbl_;
public:
/*
double jit(double x);
@note 32bit: x : [esp+4], return fp0
64bit: x [rcx](win), xmm0(gcc), return xmm0
*/
Jit()
: negConst_(0x8000000000000000ULL)
, constTblPos_(0)
, regIdx_(-1)
#ifdef XBYAK32
, varTbl_(eax)
, tbl_(edx)
#elif defined(XBYAK64_WIN)
, tbl_(rcx)
#else
, tbl_(rdi)
#endif
{
#ifdef XBYAK32
lea(varTbl_, ptr[esp+4]);
#else
#ifdef XBYAK64_WIN
movaps(ptr [rsp + 8], xm6); // save xm6, xm7
movaps(ptr [rsp + 8 + 16], xm7);
#endif
movaps(xm7, xm0); // save xm0
#endif
mov(tbl_, (size_t)constTbl_);
}
void genPush(double n)
{
if (constTblPos_ >= MAX_CONST_NUM) throw;
constTbl_[constTblPos_] = n;
if (regIdx_ == 7) throw;
movsd(Xbyak::Xmm(++regIdx_), ptr[tbl_ + constTblPos_ * sizeof(double)]);
constTblPos_++;
}
void genVarX()
{
#ifdef XBYAK32
if (regIdx_ == 7) throw;
movsd(Xbyak::Xmm(++regIdx_), ptr[varTbl_]);
#else
if (regIdx_ == 6) throw;
movsd(Xbyak::Xmm(++regIdx_), xm7);
#endif
}
void genAdd()
{
addsd(Xbyak::Xmm(regIdx_ - 1), Xbyak::Xmm(regIdx_)); regIdx_--;
}
void genSub()
{
subsd(Xbyak::Xmm(regIdx_ - 1), Xbyak::Xmm(regIdx_)); regIdx_--;
}
void genMul()
{
mulsd(Xbyak::Xmm(regIdx_ - 1), Xbyak::Xmm(regIdx_)); regIdx_--;
}
void genDiv()
{
divsd(Xbyak::Xmm(regIdx_ - 1), Xbyak::Xmm(regIdx_)); regIdx_--;
}
void genNeg()
{
xorpd(Xbyak::Xmm(regIdx_), ptr [tbl_ + MAX_CONST_NUM * sizeof(double)]);
}
void complete()
{
#ifdef XBYAK32
sub(esp, 8);
movsd(ptr [esp], xm0);
fld(qword [esp]);
add(esp, 8);
#else
#ifdef XBYAK64_WIN
movaps(xm6, ptr [rsp + 8]);
movaps(xm7, ptr [rsp + 8 + 16]);
#endif
#endif
ret();
}
};
template<typename Iterator>
struct Parser : boost::spirit::qi::grammar<Iterator, boost::spirit::ascii::space_type> {
boost::spirit::qi::rule<Iterator, boost::spirit::ascii::space_type> expression, term, factor;
CodeSet& code_;
Parser(CodeSet& code)
: Parser::base_type(expression)
, code_(code)
{
namespace qi = boost::spirit::qi;
using namespace qi::labels;
using boost::phoenix::ref;
using boost::phoenix::push_back;
expression = term >> *(('+' > term[push_back(ref(code_), OpAdd)])
| ('-' > term[push_back(ref(code_), OpSub)]));
term = factor >> *(('*' > factor[push_back(ref(code_), OpMul)])
| ('/' > factor[push_back(ref(code_), OpDiv)]));
factor = qi::double_[push_back(ref(code_), _1)]
| qi::lit('x')[push_back(ref(code_), OpVarX)]
| ('(' > expression > ')')
| ('-' > factor[push_back(ref(code_), OpNeg)])
| ('+' > factor);
}
};
template<typename Iterator>
struct ParserJit : boost::spirit::qi::grammar<Iterator, boost::spirit::ascii::space_type> {
boost::spirit::qi::rule<Iterator, boost::spirit::ascii::space_type> expression, term, factor;
Jit code_;
ParserJit()
: ParserJit::base_type(expression)
{
namespace qi = boost::spirit::qi;
using namespace qi::labels;
using boost::phoenix::ref;
using boost::phoenix::push_back;
using boost::phoenix::bind;
expression = term >> *(('+' > term[bind(&Jit::genAdd, ref(code_))])
| ('-' > term[bind(&Jit::genSub, ref(code_))]));
term = factor >> *(('*' > factor[bind(&Jit::genMul, ref(code_))])
| ('/' > factor[bind(&Jit::genDiv, ref(code_))]));
factor = qi::double_[bind(&Jit::genPush, ref(code_), _1)]
| qi::lit('x')[bind(&Jit::genVarX, ref(code_))]
| ('(' > expression > ')')
| ('-' > factor[bind(&Jit::genNeg, ref(code_))])
| ('+' > factor);
}
};
template<class Func>
void Test(const char *msg, const Func& f)
{
printf("%s:", msg);
boost::timer t;
double sum = 0;
for (double x = 0; x < 1000; x += 0.0001) {
sum += f(x);
}
printf("sum=%f, %fsec\n", sum, t.elapsed());
}
int main(int argc, char *argv[])
{
if (argc < 2) {
fprintf(stderr, "input formula\n");
return 1;
}
const std::string str(argv[1]);
try {
Vm vm;
Parser<std::string::const_iterator> parser(vm.code_);
ParserJit<std::string::const_iterator> parserJit;
const std::string::const_iterator end = str.end();
std::string::const_iterator i = str.begin();
if (!phrase_parse(i, end, parser, boost::spirit::ascii::space) || i != end) {
puts("err 1");
return 1;
}
printf("ret=%f\n", vm(2.3));
i = str.begin();
if (!phrase_parse(i, end, parserJit, boost::spirit::ascii::space) || i != end) {
puts("err 2");
return 1;
}
parserJit.code_.complete();
double (*jit)(double) = (double (*)(double))parserJit.code_.getCode();
Test("VM ", vm);
Test("JIT", jit);
} catch (...) {
fprintf(stderr, "err\n");
}
}