diff --git a/lslopt/lslfoldconst.py b/lslopt/lslfoldconst.py index a22e3a5..92a99b2 100644 --- a/lslopt/lslfoldconst.py +++ b/lslopt/lslfoldconst.py @@ -300,14 +300,12 @@ class foldconst(object): parent[index] = child[0]['ch'][0] return - if nt == 'NEG': - # bool(-a) equals bool(a) - parent[index] = child[0] - self.FoldCond(parent, index, ParentIsNegation) - return - - if nt in self.binary_ops and child[0]['t'] == child[1]['t'] == 'integer': - if nt == '!=': + if (child[0]['nt'] == '==' and child[0]['ch'][0]['t'] == 'integer' + and child[0]['ch'][1]['t'] == 'integer' + ): + # We have !(int == int). Replace with int ^ int or with int - 1 + node = parent[index] = child[0] # remove the negation + child = child[0]['ch'] if child[0]['nt'] == 'CONST' and child[0]['value'] == 1 \ or child[1]['nt'] == 'CONST' and child[1]['value'] == 1: # a != 1 -> a - 1 (which FoldTree will transform to ~-a) @@ -319,6 +317,14 @@ class foldconst(object): self.FoldTree(parent, index) return + + if nt == 'NEG': + # bool(-a) equals bool(a) + parent[index] = child[0] + self.FoldCond(parent, index, ParentIsNegation) + return + + if nt in self.binary_ops and child[0]['t'] == child[1]['t'] == 'integer': if nt == '==': if child[0]['nt'] == 'CONST' and -1 <= child[0]['value'] <= 1 \ or child[1]['nt'] == 'CONST' and -1 <= child[1]['value'] <= 1: @@ -702,7 +708,7 @@ class foldconst(object): subexpr['ch'] = [subexpr['ch'][b], subexpr['ch'][a]] parent[index] = subexpr return - if snt == '!=': + if snt == '!=' or snt == '^': subexpr['nt'] = '==' parent[index] = subexpr self.FoldTree(parent, index)