import { is } from 'ramda';
import * as ast from '../ast';
import Parsimmon from 'parsimmon';
import { List } from 'immutable';

// Build an ast.Func node from the `[name, args]` pair produced by the `function`
// rule; the argument array is wrapped in an Immutable List.
const func = (pair) => {
  const [functionName, argumentList] = pair;
  return new ast.Func(functionName, List(argumentList));
};

// Comparison operators accepted by the `bin` rule, keyed by the lower-cased operator
// text. 'notlike' is the synthetic key the grammar emits for `NOT LIKE`.
const binaryOperators = {
  '>': ast.GtOp,
  '>=': ast.GeOp,
  '<': ast.LtOp,
  '<=': ast.LeOp,
  '=': ast.EqOp,
  '!=': ast.NeOp,
  'rlike': ast.RLikeOp,
  'like': ast.LikeOp,
  'notlike': ast.NotLikeOp,
};

// Logical connectives (lower-cased AND / OR) -> AST node class.
const opMap = {
  'and': ast.ModelAnd,
  'or': ast.ModelOr,
};

// Set-membership operators; 'notin' is the synthetic key the grammar emits for `NOT IN`.
const inOps = {
  'in': ast.InOp,
  'notin': ast.NotInOp,
};

// Optional run of whitespace (newlines included) that is discarded after every token.
const whitespace = Parsimmon.regexp(/\s*/m);

// Wrap a parser so that whitespace following it is consumed and discarded.
const token = (parser) => parser.skip(whitespace);
// Exact, case-sensitive punctuation token such as '(' or ','.
const word = (str) => Parsimmon.string(str).thru(token);

// Case-insensitive keyword token. `str` is interpolated into a RegExp unescaped, so it
// must be a plain word (letters, digits, underscore). The trailing `\b` ensures the
// keyword is a whole word and not the prefix of a longer identifier, e.g. so that
// `notin` is not silently accepted as `NOT IN`.
const literal = (str) => Parsimmon.regexp(new RegExp(`${str}\\b`, 'i')).thru(token);

// Logical connectives. `\b` keeps input like `a = 1 order = 2` from being read as
// `a = 1 OR der = 2`.
const andIdentifier = Parsimmon.regexp(/and\b/i).desc('and');
const orIdentifier = Parsimmon.regexp(/or\b/i).desc('or');

// Optional whitespace (Parsimmon built-in), used to pad the AND/OR connectives.
const _ = Parsimmon.optWhitespace;

/**
 * Build a Parsimmon parser for a small SQL expression dialect.
 *
 * Grammar overview (rules below): CASE WHEN ... [ELSE ...] END, AND/OR chains,
 * comparisons (=, !=, <, <=, >, >=, LIKE, RLIKE, NOT LIKE), [NOT] IN lists,
 * CAST(expr AS type), string / number / boolean literals, function calls and
 * dotted column references.
 *
 * @param {Set<string>} [available_columns] - column names accepted by the `var` rule.
 *   Names are looked up lower-cased, so entries are expected to be lower-case.
 * @param {Set<string>} [available_functions] - function names accepted by the
 *   `function` rule, also looked up lower-cased. Defaults to `lower` and `upper`.
 * @returns {Parsimmon.Parser} the parser for the top-level `expr` rule.
 */
const SQLParser = (
  available_columns = new Set(),
  available_functions = new Set(['lower', 'upper'])
) => {
  // Fold [a, b, c] into new Node(a, new Node(b, c)). Right-associative, which keeps the
  // tree shape the grammar has always produced for a chain of a single connective.
  const foldRight = (Node, operands) => {
    let tree = operands[operands.length - 1];
    for (let i = operands.length - 2; i >= 0; i -= 1) {
      tree = new Node(operands[i], tree);
    }
    return tree;
  };

  // Build the AND/OR tree for `first (op operand)+` using SQL precedence: AND binds
  // tighter than OR, so `a AND b OR c` becomes OR(AND(a, b), c) rather than
  // AND(a, OR(b, c)).
  const combineLogic = (first, rest) => {
    const disjuncts = [[first]];
    rest.forEach(([op, operand]) => {
      if (op.toLowerCase() === 'or') {
        disjuncts.push([operand]);
      } else {
        disjuncts[disjuncts.length - 1].push(operand);
      }
    });
    return foldRight(
      opMap.or,
      disjuncts.map((conjuncts) => foldRight(opMap.and, conjuncts))
    );
  };

  return Parsimmon.createLanguage({
    // Entry point. Alternatives are tried in order and the first success wins, so
    // rules that begin with an operand but need more input (`bin`, `inop`) must be
    // tried before the bare operands (`string`, `number`, `cast`, ...).
    expr: (r) =>
      Parsimmon.alt(
        r.case,
        r.andorBin,
        r.bin,
        r.inop,
        r.group,
        r.cast,
        r.string,
        r.number,
        r.bool,
        r.function,
        r.var
      ).thru((parser) => whitespace.then(parser)),

    identifier: () => token(Parsimmon.regexp(/[a-z_][a-z0-9_]*/i)),

    // Longer operators are listed before their prefixes (`<=` before `<`).
    operator: () => token(Parsimmon.regexp(/(=|!=|rlike|like|<=|>=|>|<)/i)),

    // Single-quoted literal; group 1 is the raw contents (backslash escapes are kept as-is).
    string: () =>
      token(Parsimmon.regexp(/'((?:\\.|.)*?)'/, 1))
        .map((value) => new ast.Const(value))
        .desc('string'),
    number: () =>
      token(Parsimmon.regexp(/-?(0|[1-9][0-9]*)([.][0-9]+)?([eE][+-]?[0-9]+)?/))
        .map(Number)
        .map((value) => new ast.Const(value))
        .desc('number'),
    // `\b` stops `true`/`false` from matching the prefix of a column such as
    // `true_flag`; otherwise `expr` would commit to the literal and the parse would fail.
    bool: () =>
      token(Parsimmon.regexp(/(true|false)\b/i))
        .map((value) => value.toLowerCase() === 'true')
        .map((value) => new ast.Const(value))
        .desc('boolean'),
    // `lhs [NOT] IN (e1, e2, ...)`; `NOT IN` is normalised to the 'notin' key.
    inop: (r) =>
      Parsimmon.seq(
        Parsimmon.alt(r.cast, r.number, r.var, r.string, r.function),
        Parsimmon.alt(
          literal('in'),
          Parsimmon.seq(literal('not'), literal('in')).map(() => 'notin')
        ),
        r.lparen.then(r.expr.sepBy(r.comma)).skip(r.rparen)
      ).map(([lhs, op, rhs]) => new inOps[op.toLowerCase()](lhs, rhs)),
    // `name(arg, ...)`; only names present in `available_functions` are accepted.
    function: (r) =>
      Parsimmon.seq(r.identifier, r.lparen.then(r.expr.sepBy(r.comma)).skip(r.rparen))
        .map(func)
        .chain((f) => {
          if (available_functions.has(f.name.toLowerCase())) {
            return Parsimmon.of(f);
          }
          return Parsimmon.fail(`a valid function, but ${f.name}() is unknown`);
        }),
    // Dotted column reference (`a.b.c`); only names present in `available_columns`.
    var: (r) =>
      r.identifier
        .sepBy1(r.dot)
        .map((name) => new ast.VarExpr(name.join('.')))
        .chain((v) => {
          if (available_columns.has(v.name.toLowerCase())) {
            return Parsimmon.of(v);
          }
          return Parsimmon.fail('a valid column name');
        }),

    dot: () => word('.'),
    comma: () => word(','),
    lparen: () => word('('),
    rparen: () => word(')'),

    case: (r) =>
      Parsimmon.seqObj(
        literal('case'),
        ['conditions', r.when.atLeast(1)],
        ['default_', literal('else').then(r.expr).fallback(null)],
        literal('end')
      )
        .map((stmt) => new ast.CaseWhen(List(stmt.conditions), stmt.default_))
        .desc('case when ... end'),
    // NOTE(review): ast.Condition is called without `new` on a `{ when, then }` object;
    // presumably it is a factory such as an Immutable Record — confirm.
    when: (r) =>
      Parsimmon.seqObj(literal('when'), ['when', r.expr], literal('then'), ['then', r.expr]).map(
        ast.Condition
      ),
    cast: (r) =>
      Parsimmon.seqObj(
        literal('cast'),
        r.lparen,
        ['expr', r.expr],
        literal('as'),
        [
          'cast_type',
          Parsimmon.alt(
            // DECIMAL(precision, scale) is the only cast type that looks like a
            // function call; it is flattened back to a string such as 'decimal(10,2)'.
            Parsimmon.seq(
              literal('decimal'),
              r.lparen,
              token(Parsimmon.regexp(/\d+/)),
              r.comma,
              token(Parsimmon.regexp(/\d+/)),
              r.rparen
            ).map((val) => val.join('')),
            r.identifier
          ),
        ],
        r.rparen
      ).map((stmt) => new ast.Cast(stmt.expr, stmt.cast_type)),
    // `lhs <op> rhs` comparison; `NOT LIKE` is normalised to the 'notlike' key.
    bin: (r) =>
      Parsimmon.seq(
        Parsimmon.alt(r.cast, r.number, r.var, r.string, r.function),
        Parsimmon.alt(
          r.operator,
          Parsimmon.seq(literal('not'), literal('like')).map(() => 'notlike')
        ),
        // Same operand set as the left-hand side, so `col = lower('X')` is accepted.
        Parsimmon.alt(r.cast, r.number, r.var, r.string, r.function)
      ).map(([lhs, op, rhs]) => new binaryOperators[op.toLowerCase()](lhs, rhs)),

    andorLogic: () => Parsimmon.alt(andIdentifier.trim(_), orIdentifier.trim(_)),
    // Two or more operands joined by AND/OR, combined with SQL precedence.
    andorBin: (r) => {
      const operand = Parsimmon.alt(r.group, r.bin, r.inop);
      return Parsimmon.seq(operand, Parsimmon.seq(r.andorLogic, operand).atLeast(1)).map(
        ([first, rest]) => combineLogic(first, rest)
      );
    },
    // Parenthesised expression. The captured text is parsed as a full `expr`; a
    // failure is reported as a parse failure instead of an exception escaping parse().
    group: (r) =>
      r.complex.chain((inner) => {
        const result = r.expr.parse(inner);
        return result.status
          ? Parsimmon.of(result.value)
          : Parsimmon.fail('a valid expression inside parentheses');
      }),
    // Raw text between a '(' and its matching ')', with the parentheses consumed.
    complex: (r) =>
      r.lparen.chain(() => {
        // Scanner state is per attempt: a fresh closure is created each time '(' matches.
        let depth = 1;
        let inString = false;
        let escaped = false;
        return Parsimmon.takeWhile((char) => {
          if (inString) {
            // Parentheses inside a quoted literal are plain text; as in the `string`
            // rule, a backslash escapes the next character.
            if (escaped) {
              escaped = false;
            } else if (char === '\\') {
              escaped = true;
            } else if (char === "'") {
              inString = false;
            }
            return true;
          }
          if (char === "'") {
            inString = true;
          } else if (char === '(') {
            depth += 1;
          } else if (char === ')') {
            depth -= 1;
          }
          // Stop (without consuming) at the ')' that closes the opening '('.
          return depth > 0;
        }).skip(r.rparen);
      }),
  }).expr;
};

/**
 * Parse `text` as a SQL expression.
 *
 * @param {object} options
 * @param {Set<string>} [options.available_columns] - column names the parser accepts
 * @param {Set<string>} [options.available_functions] - function names the parser accepts
 * @param {string} options.text - expression source to parse
 * @returns {object} the result object of Parsimmon's `Parser.parse`
 *   (`status` flag plus either the AST `value` or failure details).
 */
const sql2ast = (options) => {
  const { available_columns, available_functions, text } = options;
  return SQLParser(available_columns, available_functions).parse(text);
};

export default sql2ast;
