simplifying multilinear polynomials

This commit is contained in:
Thibaud Gasser 2024-01-25 17:26:34 +01:00
parent 751e369a99
commit f6620624df
2 changed files with 66 additions and 52 deletions

View File

@ -4,14 +4,14 @@ from dataclasses import dataclass
class OpKind(Enum): class OpKind(Enum):
LEFT = '<' LEFT = "<"
RIGHT = '>' RIGHT = ">"
INC = '+' INC = "+"
DEC = '-' DEC = "-"
INPT = ',' INPT = ","
OUTP = '.' OUTP = "."
JMPZ = '[' JMPZ = "["
JMPNZ = ']' JMPNZ = "]"
@dataclass @dataclass
@ -33,7 +33,6 @@ def optimize(tokens):
if len(new_tokens) <= 1: if len(new_tokens) <= 1:
return new_tokens return new_tokens
tokens = otpi_cancel_opposing(new_tokens) tokens = otpi_cancel_opposing(new_tokens)
while len(new_tokens := otpi_cancel_opposing(tokens)) != len(tokens): while len(new_tokens := otpi_cancel_opposing(tokens)) != len(tokens):
tokens = new_tokens tokens = new_tokens
@ -49,7 +48,10 @@ def opti_collapse_tokens(tokens):
new_tokens.append(next_token) new_tokens.append(next_token)
continue continue
current_token = new_tokens.pop() current_token = new_tokens.pop()
if current_token.kind == next_token.kind and str(current_token.kind.value) in "<>+-": if (
current_token.kind == next_token.kind
and str(current_token.kind.value) in "<>+-"
):
current_token.operand += 1 current_token.operand += 1
new_tokens.append(current_token) new_tokens.append(current_token)
else: else:
@ -66,8 +68,11 @@ def otpi_cancel_opposing(tokens):
new_tokens.append(next_token) new_tokens.append(next_token)
continue continue
current_token = new_tokens.pop() current_token = new_tokens.pop()
if (current_token.kind.value + next_token.kind.value in ("+-", "-+", "<>", "><", "[]") if (
and current_token.operand == next_token.operand): current_token.kind.value + next_token.kind.value
in ("+-", "-+", "<>", "><", "[]")
and current_token.operand == next_token.operand
):
continue continue
else: else:
new_tokens.append(current_token) new_tokens.append(current_token)
@ -81,14 +86,24 @@ def generate_code(tokens):
# code generation # code generation
for token in tokens: for token in tokens:
match token.kind: match token.kind:
case OpKind.LEFT: out.append(f"{' ' * nest * 2}p -= {token.operand};\n") case OpKind.LEFT:
case OpKind.RIGHT: out.append(f"{' ' * nest * 2}p += {token.operand};\n") out.append(f"{' ' * nest * 2}p -= {token.operand};\n")
case OpKind.INC: out.append(f"{' ' * nest * 2}*p += {token.operand};\n") case OpKind.RIGHT:
case OpKind.DEC: out.append(f"{' ' * nest * 2}*p -= {token.operand};\n") out.append(f"{' ' * nest * 2}p += {token.operand};\n")
case OpKind.INPT: out.append(f"{' ' * nest * 2}*p = getchar();\n") case OpKind.INC:
case OpKind.OUTP: out.append(f"{' ' * nest * 2}putchar(*p);\n") out.append(f"{' ' * nest * 2}*p += {token.operand};\n")
case OpKind.JMPZ: out.append(f"{' ' * nest * 2}" + "if (*p) do {\n");nest += 1 case OpKind.DEC:
case OpKind.JMPNZ: nest -= 1;out.append(f"{' ' * nest * 2}" + "} while (*p);\n") out.append(f"{' ' * nest * 2}*p -= {token.operand};\n")
case OpKind.INPT:
out.append(f"{' ' * nest * 2}*p = getchar();\n")
case OpKind.OUTP:
out.append(f"{' ' * nest * 2}putchar(*p);\n")
case OpKind.JMPZ:
out.append(f"{' ' * nest * 2}" + "if (*p) do {\n")
nest += 1
case OpKind.JMPNZ:
nest -= 1
out.append(f"{' ' * nest * 2}" + "} while (*p);\n")
if nest < 0: if nest < 0:
return "Error!" return "Error!"

View File

@ -3,6 +3,7 @@
import re import re
from dataclasses import dataclass from dataclasses import dataclass
@dataclass @dataclass
class Term: class Term:
coef: int coef: int
@ -18,47 +19,37 @@ class Term:
return f"{a}{b}" return f"{a}{b}"
def __add__(self, other): def __add__(self, other):
print(f"add {self} {other}")
if self.variables == other.variables: if self.variables == other.variables:
r = Term(self.coef + other.coef, self.variables) r = Term(self.coef + other.coef, self.variables)
print(r)
return r return r
else:
print(self.variables, other.variables)
raise ArithmeticError("Variables do not match") raise ArithmeticError("Variables do not match")
def __lt__(self, other): def __lt__(self, other):
return (len(self.variables) < len(other.variables)) and (self.variables < other.variables) if len(self.variables) == len(other.variables):
return self.variables < other.variables
return len(self.variables) < len(other.variables)
def __eq__(self, other): def __eq__(self, other):
return self.coef == other.coef and self.variables == other.variables return self.coef == other.coef and self.variables == other.variables
def parse_term(inp: str) -> Term: def parse_term(inp: str) -> Term:
sign, coef, variables = '', [], [] sign, coef, variables = "", [], []
for n in inp: for n in inp:
if n == '-': if n == "-":
sign = '-' sign = "-"
elif n.isnumeric(): elif n.isnumeric():
coef.append(n) coef.append(n)
elif n.isalpha(): elif n.isalpha():
variables.append(n) variables.append(n)
c = -1 if sign == '-' else 1 c = -1 if sign == "-" else 1
c *= int("".join(coef)) if coef else 1 c *= int("".join(coef)) if coef else 1
return Term(c, sorted(variables)) return Term(c, sorted(variables))
def simplify(poly): def simplify_terms(terms: [Term]) -> [Term]:
print(poly)
# parse
pattern = r"([+-]?\d*\w+)"
terms = sorted(parse_term(t) for t in re.findall(pattern, poly))
print(terms)
# simplify
simplified_terms = [] simplified_terms = []
for term in terms: for term in terms:
if not simplified_terms: if not simplified_terms:
@ -70,14 +61,15 @@ def simplify(poly):
res = previous + term res = previous + term
if res.coef != 0: if res.coef != 0:
simplified_terms.append(res) simplified_terms.append(res)
except ArithmeticError as e: except ArithmeticError: # terms not compatible to be added
print("no match")
simplified_terms.append(previous) simplified_terms.append(previous)
simplified_terms.append(term) simplified_terms.append(term)
return simplified_terms
# build output str
def poly_to_str(terms: [Term]) -> str:
res = "" res = ""
for term in simplified_terms: for term in terms:
s = str(term) s = str(term)
if res and not s.startswith("-"): if res and not s.startswith("-"):
res += "+" res += "+"
@ -85,3 +77,10 @@ def simplify(poly):
else: else:
res += s res += s
return res return res
def simplify(poly: str):
pattern = r"([+-]?\d*\w+)"
terms = sorted(parse_term(t) for t in re.findall(pattern, poly))
simplified_terms = simplify_terms(terms)
return poly_to_str(simplified_terms)