| | import torch |
| | import torch.nn as nn |
| | import math |
| |
|
| | try: |
| | from . import helpers as h |
| | except: |
| | import helpers as h |
| |
|
| |
|
| |
|
class Const():
    """Constant schedule value and base class for the schedules in this module.

    Subclasses override ``getVal`` to compute a value from keyword arguments
    (for example ``time``); ``Const`` itself ignores them.
    """

    def __init__(self, c):
        # ``None`` is kept as-is so that ``getVal`` can fall back to the
        # value supplied by the caller; anything else is coerced to float.
        self.c = c if c is None else float(c)

    def getVal(self, c = None, **kargs):
        """Return the stored constant, or ``c`` when the stored value is ``None``.

        Extra keyword arguments are accepted and ignored so a ``Const`` can be
        used wherever another schedule is expected.
        """
        return self.c if self.c is not None else c

    def __str__(self):
        return str(self.c)

    @staticmethod
    def initConst(x):
        """Return ``x`` unchanged if it is already a ``Const`` (or subclass), else wrap it.

        Declared as a static method so it works both as ``Const.initConst(x)``
        and when reached through an instance.
        """
        return x if isinstance(x, Const) else Const(x)
| |
|
class Lin(Const):
    """Linear ramp from ``start`` to ``end`` over ``steps`` units of time.

    The value is ``start`` until ``time`` reaches ``initial``. It then moves
    linearly towards ``end``, which it reaches ``steps`` time units later, and
    stays at ``end`` afterwards.
    """

    def __init__(self, start, end, steps, initial = 0, quant = False):
        self.start = float(start)
        self.end = float(end)
        self.steps = float(steps)
        self.initial = float(initial)
        # When true, ``time`` is floored before interpolating, so the ramp
        # only advances in whole-unit increments.
        self.quant = quant

    def getVal(self, time = 0, **kargs):
        """Return the ramp value at ``time``; other keyword arguments are ignored."""
        if self.quant:
            time = math.floor(time)
        elapsed = float(time - self.initial)
        if self.steps == 0:
            # A zero-length ramp is a step. Dividing by ``steps`` would raise,
            # so jump straight to ``end`` once ``initial`` has been reached.
            progress = 1.0 if elapsed >= 0 else 0.0
        else:
            progress = max(0, min(1, elapsed / self.steps))
        return (self.end - self.start) * progress + self.start

    def __str__(self):
        # %-formatting: the placeholders are printf-style ``%s``.
        return "Lin(%s,%s,%s,%s, quant=%s)" % (
            str(self.start),
            str(self.end),
            str(self.steps),
            str(self.initial),
            str(self.quant),
        )
| |
|
class Until(Const):
    """Use schedule ``a`` before time ``thresh`` and schedule ``b`` from then on.

    ``b`` is queried with time measured from ``thresh`` rather than from zero.
    Plain values given for ``a``/``b`` are wrapped in ``Const``.
    """

    def __init__(self, thresh, a, b):
        self.a = Const.initConst(a)
        self.b = Const.initConst(b)
        self.thresh = thresh

    def getVal(self, *args, time = 0, **kargs):
        if time < self.thresh:
            return self.a.getVal(*args, time = time, **kargs)
        # Past the threshold: restart the clock for the second schedule.
        shifted = time - self.thresh
        return self.b.getVal(*args, time = shifted, **kargs)

    def __str__(self):
        pieces = tuple(str(p) for p in (self.thresh, self.a, self.b))
        return "Until(%s, %s, %s)" % pieces
| |
|
class Scale(Const):
    """Map the wrapped schedule's value ``c`` to ``c / (1 - c)``.

    The wrapped value must lie in ``[0, 1)``; this is checked with ``assert``.
    A value of exactly 0 short-circuits to the integer ``0``.
    """

    def __init__(self, c):
        self.c = Const.initConst(c)

    def getVal(self, *args, **kargs):
        ratio = self.c.getVal(*args, **kargs)
        if ratio == 0:
            return 0
        assert 0 <= ratio < 1
        return ratio / (1 - ratio)

    def __str__(self):
        return "Scale(%s)" % str(self.c)
| |
|
def MixLin(*args, **kargs):
    """Build a ``Scale`` wrapped around a ``Lin`` ramp made from the given arguments."""
    ramp = Lin(*args, **kargs)
    return Scale(ramp)
| |
|
class Normal(Const):
    """Random schedule: absolute standard-normal samples scaled by ``c``."""

    def __init__(self, c):
        self.c = Const.initConst(c)

    def getVal(self, *args, shape = None, **kargs):
        """Return ``|randn(shape)| * c`` as a tensor allocated on ``h.device``.

        ``shape`` defaults to ``[1]``. It is forwarded both to the wrapped
        schedule (as ``shape=``) and to ``torch.randn``.
        """
        # A fresh list per call: a shared mutable default could be mutated by
        # the wrapped schedule (e.g. a ``Fun``) and leak into later calls.
        if shape is None:
            shape = [1]
        c = self.c.getVal(*args, shape = shape, **kargs)
        return torch.randn(shape, device = h.device).abs() * c

    def __str__(self):
        return "Normal(%s)" % str(self.c)
| |
|
class Clip(Const):
    """Clamp the wrapped schedule's value ``c`` to the range ``[l, u]``.

    ``l`` and ``u`` are themselves schedules, evaluated with the same
    arguments as ``c``.
    """

    def __init__(self, c, l, u):
        self.c = Const.initConst(c)
        self.l = Const.initConst(l)
        self.u = Const.initConst(u)

    def getVal(self, *args, **kargs):
        """Return ``c`` clamped to ``[l, u]``.

        Plain Python numbers are clamped with ``min``/``max``. Any other value
        is assumed to be tensor-like and to provide ``.clamp(lo, hi)``.
        """
        value = self.c.getVal(*args, **kargs)
        lo = self.l.getVal(*args, **kargs)
        hi = self.u.getVal(*args, **kargs)
        # ``int`` must be included: ``Scale.getVal`` can return the integer 0,
        # and ints have no ``.clamp`` method.
        if isinstance(value, (int, float)):
            return min(max(value, lo), hi)
        return value.clamp(lo, hi)

    def __str__(self):
        return "Clip(%s, %s, %s)" % (str(self.c), str(self.l), str(self.u))
| |
|
class Fun(Const):
    """Schedule backed by an arbitrary callable.

    Every positional and keyword argument passed to ``getVal`` is forwarded
    to ``foo`` unchanged.
    """

    def __init__(self, foo):
        self.foo = foo

    def getVal(self, *args, **kargs):
        fn = self.foo
        return fn(*args, **kargs)

    def __str__(self):
        return "Fun(...)"
| |
|
class Complement(Const):
    """Return ``1 - c`` for a wrapped schedule value ``c`` in ``[0, 1]``.

    The range is checked with ``assert``.
    """

    def __init__(self, c):
        self.c = Const.initConst(c)

    def getVal(self, *args, **kargs):
        value = self.c.getVal(*args, **kargs)
        assert 0 <= value <= 1
        return 1 - value

    def __str__(self):
        return "Complement(%s)" % str(self.c)
| |
|