aboutsummaryrefslogtreecommitdiff
path: root/tools/cru-py/cru/util/_func.py
blob: 5fb49a9a428e890ffa9ab73db7941aa158665ec5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from collections.abc import Callable, Iterable
from typing import TypeVar, Any, ParamSpec

from ._list import ListOperations, CruList

T = TypeVar("T")
R = TypeVar("R")
R1 = TypeVar("R1")
P = ParamSpec("P")
P1 = ParamSpec("P1")


class _Placeholder:
    """Marker type for positional slots left open in ``MetaFunction.bind``.

    Use the module-level ``PLACEHOLDER`` instance. ``bind`` detects open slots
    with ``isinstance``, so any instance of this class acts as a placeholder.
    """

    def __repr__(self) -> str:
        # Readable in debug output instead of "<..._Placeholder object at 0x...>".
        return "PLACEHOLDER"


# Sentinel passed as a bound positional argument to mark a slot that is filled
# from the positional arguments supplied when the bound function is called.
PLACEHOLDER = _Placeholder()


class RawFunctions:
    """Plain (unwrapped) utility functions grouped as static methods.

    They are small building blocks meant to be passed around as callables,
    e.g. as predicates or as the tail of a function chain.
    """

    @staticmethod
    def ignore(*_v, **_kwargs) -> None:
        """Accept any arguments and return ``None``."""
        return None

    @staticmethod
    def true_(*_v, **_kwargs) -> bool:
        """Accept any arguments and always return ``True``."""
        # Annotation is ``bool``: ``True`` is a value, not a type.
        return True

    @staticmethod
    def false_(*_v, **_kwargs) -> bool:
        """Accept any arguments and always return ``False``."""
        return False

    @staticmethod
    def i_dont_care(r: T, *_v, **_kwargs) -> T:
        """Return the first positional argument and discard all others."""
        return r

    @staticmethod
    def identity(v: T) -> T:
        """Return *v* unchanged."""
        return v

    @staticmethod
    def equal(a: Any, b: Any) -> bool:
        """Return ``a == b`` (whatever the operands' ``__eq__`` produces)."""
        return a == b

    @staticmethod
    def not_equal(a: Any, b: Any) -> bool:
        """Return ``a != b`` (whatever the operands' ``__ne__`` produces)."""
        return a != b

    @staticmethod
    def not_(v: Any) -> bool:
        """Return the logical negation of *v*'s truthiness."""
        return not v


class MetaFunction:
    """Static combinators that build new callables from existing ones."""

    @staticmethod
    def bind(f: Callable[P, R], *bind_args, **bind_kwargs) -> Callable[P1, R1]:
        """Return a partial application of *f*.

        Positional ``bind_args`` are passed to *f* in order. Each
        ``_Placeholder`` (e.g. ``PLACEHOLDER``) among them is replaced, left to
        right, by the next positional argument given at call time. Call-time
        positionals not consumed by placeholders are appended after the bound
        ones. Keyword arguments are merged, with call-time ``kwargs`` overriding
        ``bind_kwargs``.

        The returned callable raises ``IndexError`` if it is called with fewer
        positional arguments than there are placeholders.
        """

        def bound(*args, **kwargs):
            popped = 0  # call-time positionals consumed by placeholders so far
            real_args = []
            for a in bind_args:
                if isinstance(a, _Placeholder):
                    real_args.append(args[popped])
                    popped += 1
                else:
                    real_args.append(a)
            real_args.extend(args[popped:])
            return f(*real_args, **(bind_kwargs | kwargs))

        return bound

    @staticmethod
    def chain(*fs: Callable) -> Callable:
        """Compose *fs* left to right into a single callable.

        The returned callable passes all its arguments to ``fs[0]``. Each later
        function receives the previous result: unpacked as positional arguments
        if it is a tuple, otherwise as a single argument. Only ``fs[0]``
        receives keyword arguments. With a single function, that function
        itself is returned.

        Raises:
            ValueError: if no functions are given.
        """
        if len(fs) == 0:
            raise ValueError("At least one function is required!")

        def link(first: Callable, then: Callable) -> Callable:
            # Each link captures its own ``first``/``then`` as parameters.
            # Closing over the loop variables directly binds them late: every
            # link would then call the final ``rf`` (i.e. itself) and recurse
            # forever.
            def linked(*args, **kwargs):
                r = first(*args, **kwargs)
                r = r if isinstance(r, tuple) else (r,)
                return then(*r)

            return linked

        rf = fs[0]
        for f in fs[1:]:
            rf = link(rf, f)
        return rf

    @staticmethod
    def chain_single(
        f: Callable[P, R], f1: Callable[P1, R1], *bind_args, **bind_kwargs
    ) -> Callable[P, R1]:
        """Call *f*, then pass its result to *f1* partially applied via ``bind``.

        Without placeholders in ``bind_args``, *f*'s result is appended after
        the bound positionals.
        """
        return MetaFunction.chain(f, MetaFunction.bind(f1, *bind_args, **bind_kwargs))

    # Alternative name: "convert the result" of f with f1.
    convert_r = chain_single

    @staticmethod
    def neg(f: Callable[P, bool]) -> Callable[P, bool]:
        """Return a callable computing ``not f(...)``."""
        return MetaFunction.convert_r(f, RawFunctions.not_)


# Advanced Function Wrapper
class CruFunction:
    """Chainable wrapper around a callable.

    The combinator methods (``bind``, ``chain``, ``chain_single``,
    ``convert_r``, ``neg``) replace the wrapped callable *in place* and return
    ``self`` for fluent use. They do not create a new wrapper.
    """

    # Private sentinel: lets chain_single tell "no f1 given" apart from any
    # real bound argument, including None.
    _MISSING = object()

    def __init__(self, f: Callable):
        self._f = f

    @property
    def f(self) -> Callable:
        """The currently wrapped callable."""
        return self._f

    def bind(self, *bind_args, **bind_kwargs) -> "CruFunction":
        """Partially apply the wrapped callable in place (see ``MetaFunction.bind``)."""
        self._f = MetaFunction.bind(self._f, *bind_args, **bind_kwargs)
        return self

    def chain(self, *fs: Callable) -> "CruFunction":
        """Append *fs* after the wrapped callable in place (see ``MetaFunction.chain``)."""
        self._f = MetaFunction.chain(self._f, *fs)
        return self

    def chain_single(self, f: Callable[P, R], f1: Any = _MISSING, *bind_args,
                     **bind_kwargs) -> "CruFunction":
        """Feed the wrapped callable's result into *f*, partially applied.

        ``f1`` and ``bind_args`` are bound to *f* in that order. ``f1`` is
        simply the first bound argument, not a second function. The wrapped
        callable's result is appended after them unless a ``PLACEHOLDER``
        marks its position. ``f1`` is optional so that *f* can be chained
        without binding anything.
        """
        args = bind_args if f1 is self._MISSING else (f1, *bind_args)
        self._f = MetaFunction.chain_single(self._f, f, *args, **bind_kwargs)
        return self

    # Same operation under the alternative name, mirroring MetaFunction.convert_r.
    convert_r = chain_single

    def neg(self) -> "CruFunction":
        """Negate the wrapped callable's result in place (see ``MetaFunction.neg``)."""
        self._f = MetaFunction.neg(self._f)
        return self

    def __call__(self, *args, **kwargs):
        return self._f(*args, **kwargs)

    def list_transform(self, l: Iterable[T]) -> CruList[T]:
        """Return ``CruList(l).transform(self)``, i.e. apply this function over *l*."""
        return CruList(l).transform(self)

    def list_all(self, l: Iterable[T]) -> bool:
        """Delegate to ``ListOperations.all`` with this function as the predicate."""
        return ListOperations.all(l, self)

    def list_any(self, l: Iterable[T]) -> bool:
        """Delegate to ``ListOperations.any`` with this function as the predicate."""
        return ListOperations.any(l, self)

    def list_remove_all_if(self, l: Iterable[T]) -> CruList[T]:
        """Wrap ``ListOperations.remove_all_if(l, self)`` in a ``CruList``.

        Presumably drops the items for which this function is truthy; confirm
        against ``ListOperations``.
        """
        return CruList(ListOperations.remove_all_if(l, self))


class WrappedFunctions:
    """Ready-made ``CruFunction`` wrappers around ``RawFunctions``.

    NOTE: these are shared, class-level instances, and ``CruFunction``'s
    combinator methods (``bind``, ``chain``, ``neg``, ...) replace the wrapped
    callable in place. Calling a combinator directly on one of these
    attributes changes it for every user. Wrap ``.f`` in a fresh wrapper
    first, e.g. ``CruFunction(WrappedFunctions.not_.f).neg()``.
    """

    identity = CruFunction(RawFunctions.identity)
    ignore = CruFunction(RawFunctions.ignore)
    equal = CruFunction(RawFunctions.equal)
    not_equal = CruFunction(RawFunctions.not_equal)
    not_ = CruFunction(RawFunctions.not_)
    # Constant predicates, wrapped for parity with RawFunctions.
    true_ = CruFunction(RawFunctions.true_)
    false_ = CruFunction(RawFunctions.false_)