generated from pikalab-unibo/python-structure-example
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
92 lines (69 loc) · 2.55 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from tuprolog.core import *
from tuprolog.core.parsing import *
from tuprolog.core.visitors import AbstractTermVisitor
# Module-wide term formatter; the print calls at the bottom of this file use
# it to render terms as text.
# NOTE(review): `prettyExpressions` presumably renders operators in infix
# form -- confirm against the tuProlog/2ppy documentation.
formatter = TermFormatter.prettyExpressions()
def is_sum(term: Struct):
    """Tell whether *term* is a binary structure whose functor is ``'+'``.

    The functor is only inspected once the arity is known to be 2.
    """
    if term.getArity() != 2:
        return False
    return term.getFunctor() == '+'
def is_mult(term: Struct):
    """Tell whether *term* is a binary structure whose functor is ``'*'``."""
    is_binary = term.getArity() == 2
    # `and` short-circuits, so the functor is only read for binary terms.
    return is_binary and term.getFunctor() == '*'
def is_negative(term: Term):
    """Tell whether *term* carries a negative sign.

    * ``Integer`` / ``Real``: negative when strictly below the type's ``ZERO``.
    * Binary product (``'*'`` with two arguments): negative when exactly one
      of the two factors is negative, so that two negative factors cancel
      out (``-2 * -3`` is not negative).
    * Anything else (variables, other structures): ``False``.
    """
    if isinstance(term, Integer):
        return term < Integer.ZERO
    if isinstance(term, Real):
        return term < Real.ZERO
    if isinstance(term, Struct) and is_mult(term):
        # is_mult guarantees exactly two arguments.
        left, right = term.getArgs()
        # Sign parity: the product is negative iff the factors' signs differ.
        return is_negative(left) != is_negative(right)
    return False
def is_zero(term: Term):
    """Tell whether *term* is a numeric zero or a product with a zero factor.

    ``Integer`` and ``Real`` terms are compared with their type's ``ZERO``;
    a binary ``'*'`` structure is zero as soon as any argument is zero;
    every other term is reported as non-zero.
    """
    for numeric_type in (Integer, Real):
        if isinstance(term, numeric_type):
            return term == numeric_type.ZERO
    if isinstance(term, Struct) and is_mult(term):
        return any(is_zero(factor) for factor in term.getArgs())
    return False
def absolute(term: Term):
    """Return *term* with its negative sign removed.

    Terms that ``is_negative`` does not flag are returned unchanged.
    A negative ``Numeric`` is replaced by the numeric built from its negated
    value; any other negative term is rebuilt with the same functor and with
    ``absolute`` applied to each of its arguments.
    """
    if not is_negative(term):
        return term
    if isinstance(term, Numeric):
        return numeric(term.getValue().unaryMinus())
    return struct(term.getFunctor(), map(absolute, term.getArgs()))
# NOTE: this repeats the `is_sum` definition found earlier in this file with
# the same behavior; being the later one, it is the binding in effect below.
def is_sum(term: Struct):
    """Tell whether *term* is a binary ``'+'`` structure."""
    return term.getFunctor() == '+' if term.getArity() == 2 else False
def foldr(accumulator, iterable, default=None):
    """Right-fold *iterable* using the binary function *accumulator*.

    The last item seeds the fold and the remaining items are combined from
    right to left as ``accumulator(item, folded_so_far)``; ``[a, b, c]``
    therefore yields ``accumulator(a, accumulator(b, c))``.

    Args:
        accumulator: Callable taking ``(item, folded_so_far)``.
        iterable: Any iterable; it is materialized into a list once.
        default: Value returned when *iterable* is empty.

    Returns:
        The folded value, the single item of a one-element iterable, or
        *default* for an empty iterable.
    """
    items = list(iterable)
    if not items:
        return default
    # Split off the seed (last item) and walk the rest right-to-left.
    *leading, current = items
    for item in reversed(leading):
        current = accumulator(item, current)
    return current
class Simplifier(AbstractTermVisitor):
    """Term visitor that turns additions of negative products into subtractions.

    Rewrites applied by ``visitStruct`` (each preserves the arithmetic value):

    * ``a + (-b)``      becomes ``a - b``        (``-b`` a negative product)
    * ``a + (-b + c)``  becomes ``(a - b) + c``  (then simplified again)

    Every other structure is rebuilt with the same functor and with each
    argument visited in turn.
    """

    def defaultValue(self, term):
        """Return *term* unchanged.

        NOTE(review): presumably the fallback ``AbstractTermVisitor`` uses for
        term kinds without a dedicated ``visit*`` override -- confirm.
        """
        return term

    def visitStruct(self, term):
        """Return the simplified counterpart of the structure *term*."""
        args = term.getArgs()
        if is_sum(term):
            left, right = args
            # is_mult / is_sum call Struct-only accessors, so make sure the
            # right operand is a structure before inspecting it.
            if isinstance(right, Struct):
                if is_mult(right) and is_negative(right):
                    # a + (-b) -> a - b. Only the right operand changes sign;
                    # the left operand keeps its own sign and is simplified.
                    return struct('-', left.accept(self), absolute(right))
                if is_sum(right):
                    # Unpack the operands of the nested sum (not the Struct
                    # object itself).
                    right_left, right_right = right.getArgs()
                    if is_negative(right_left):
                        # a + (-b + c) -> (a - b) + c. Re-associating keeps
                        # the sign of c intact; the result is visited again
                        # so that c gets simplified as well.
                        return struct(
                            '+',
                            struct('-', left, absolute(right_left)),
                            right_right
                        ).accept(self)
        return struct(term.getFunctor(), [a.accept(self) for a in args])
# Demo: build `Y is w1*F1 + (w2*F2 + ...)` from the features and weights
# below, then print it before and after running the Simplifier.
features = map(var, ["A", 'B', 'C', 'D', 'E', 'F'])
weights = map(real, [2.5, -3.4, -0.09, 0.2, 0.0, -2.0])
# Pair each feature with its weight, dropping the pairs whose weight is zero.
x = (
    (feature, weight)
    for feature, weight in zip(features, weights)
    if not is_zero(weight)
)
# One `weight * feature` product per remaining pair.
x = (struct('*', weight, feature) for feature, weight in x)
# Right-nested chain of additions over the products.
x = foldr(lambda a, b: struct('+', a, b), x)
x = struct('is', var('Y'), x)
print(formatter.format(x))
print(formatter.format(x.accept(Simplifier())))