# Copyright (c) Felix Petersen.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from typing import Union, Tuple, List
from types import LambdaType
from abc import abstractmethod
import pprint
import copy
import inspect
import collections.abc
import re
VARIABLE_REGEX = r"^[_a-zA-Z][_a-zA-Z0-9]*$"


class Output(object):
    def __init__(self, name, shape=None, dtype=None):
        self.name = name
        assert re.match(VARIABLE_REGEX, name), (
            'Name {} is invalid, names have to match the following regex: {}'.format(name, VARIABLE_REGEX)
        )
        self.shape = shape
        self.dtype = dtype

    def checks(self, x):
        if self.shape is not None:
            assert x.shape[1:] == self.shape, (
                'Shape of output {} does not match the predefined shape {}. Note that the tensor is expected to '
                'have an additional leading batch dimension.'.format(x.shape, self.shape)
            )
        if self.dtype is not None:
            assert x.dtype == self.dtype, (
                'Data type of output {} does not match the predefined data type {}.'.format(x.dtype, self.dtype)
            )
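

# NOTE: the `Input` class is referenced below (in `State`, `Algorithm`, and the
# demo at the bottom) but its definition is missing from this excerpt. A minimal
# reconstruction, assuming that an Input carries the same name / shape / dtype
# checks as an Output:
class Input(Output):
    pass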


class Variable(object):
    def __init__(self, name, initial_value):
        self.name = name
        assert re.match(VARIABLE_REGEX, name), (
            'Name {} is invalid, names have to match the following regex: {}'.format(name, VARIABLE_REGEX)
        )
        self.initial_value = initial_value

    def get_value(self):
        if isinstance(self.initial_value, LambdaType):
            return self.initial_value
        else:
            return self.initial_value.clone().detach()
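

# Example (illustrative; the names `counter`, `buffer`, and `array` are
# arbitrary): a Variable holds a soft (tensor-valued) state entry. The initial
# value is either a tensor, e.g.,
#     Var('counter', torch.zeros(1))
# which `State` clones, detaches, and tiles over the batch dimension, or a
# lambda, e.g.,
#     Var('buffer', lambda array: torch.zeros(array.shape[1]))
# which `State` evaluates against the same-named state entries ('array') and
# which must return a torch.Tensor of the per-sample shape.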


class VariableInt(object):
    def __init__(self, name, initial_value):
        self.name = name
        assert re.match(VARIABLE_REGEX, name), (
            'Name {} is invalid, names have to match the following regex: {}'.format(name, VARIABLE_REGEX)
        )
        self.initial_value = initial_value
        self.checks_and_cast()

    def checks_and_cast(self):
        if isinstance(self.initial_value, int):
            pass
        elif isinstance(self.initial_value, list):
            for val in self.initial_value:
                assert isinstance(val, int), (
                    'Integer variable is a hard variable and only supports `int`, `List[int]`, `Iterable[int]`, '
                    'and `LambdaType`. For variable {}, {} of type {} was inserted. '
                    'The problematic element is {}.'.format(
                        self.name, self.initial_value, type(self.initial_value), val)
                )
        elif isinstance(self.initial_value, collections.abc.Iterable):
            self.initial_value = list(self.initial_value)
            for val in self.initial_value:
                assert isinstance(val, int), (
                    'Integer variable is a hard variable and only supports `int`, `List[int]`, `Iterable[int]`, '
                    'and `LambdaType`. For variable {}, {} of type {} was inserted. '
                    'The problematic element is {}.'.format(
                        self.name, self.initial_value, type(self.initial_value), val)
                )
        elif isinstance(self.initial_value, LambdaType):
            pass
        else:
            assert False, (
                'Integer variable is a hard variable and only supports `int`, `List[int]`, `Iterable[int]`, '
                'and `LambdaType`. For variable {}, {} of type {} was inserted.'.format(
                    self.name, self.initial_value, type(self.initial_value))
            )

    def get_value(self):
        if isinstance(self.initial_value, (int, LambdaType)):
            return self.initial_value
        else:
            return list(self.initial_value)
Var = Variable
VarInt = VariableInt
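

# Example (illustrative): hard integer variables accept an int, a list of ints,
# any iterable of ints (cast to a list), or a lambda, e.g.,
#     VarInt('counter2', 1221)
#     VarInt('counter3', range(10))  # cast to [0, 1, ..., 9]
# They stay discrete: State.merge asserts that they are never modified
# probabilistically.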


class State(object):
    def __init__(self, input_names, input_values, variables, variable_ints, batch_size):
        assert len(input_names) == len(input_values), (
            'The number of actual inputs {} does not match the predefined number of expected inputs {}.'.format(
                len(input_values), len(input_names))
        )
        self.state = {}
        self.batch_size = batch_size
        for name, value in zip(input_names, input_values):
            assert isinstance(name, Input), name
            assert name.name not in self.state, (
                'Variable / Input with name `{}` is defined for the second time, which is not supported. '
                'The following variables are already defined: {}'.format(name.name, self.state.keys())
            )
            name.checks(value)
            assert value.shape[0] == batch_size, (
                'The 0th dimension of Input `{}` is supposed to be the batch dimension (which was inferred from the '
                'first input to be of size {}); however, the shape is {}.'.format(name.name, batch_size, value.shape)
            )
            self.state[name.name] = value
        for variable in variables:
            assert isinstance(variable, Variable), variable
            assert variable.name not in self.state, (
                'Variable / Input with name `{}` is defined for the second time, which is not supported. '
                'The following variables are already defined: {}'.format(variable.name, self.state.keys())
            )
            self.state[variable.name] = variable.get_value()
            if isinstance(self.state[variable.name], LambdaType):
                input_args = inspect.getfullargspec(self.state[variable.name])[0]
                args = [self.state[k] for k in input_args]
                self.state[variable.name] = self.state[variable.name](*args)
                assert isinstance(self.state[variable.name], torch.Tensor), (
                    'The return value of the lambda expression has to be a torch.Tensor but was of type '
                    '{}. It was supposed to be written to variable {}.'.format(
                        type(self.state[variable.name]), variable.name)
                )
            # Tile the (per-sample) initial value over the batch dimension.
            self.state[variable.name] = self.state[variable.name].unsqueeze(0).repeat(
                batch_size,
                *[1] * len(self.state[variable.name].shape)
            )
        for variable_int in variable_ints:
            assert isinstance(variable_int, VariableInt), variable_int
            assert variable_int.name not in self.state, (
                'Variable / Input with name `{}` is defined for the second time, which is not supported. '
                'The following variables are already defined: {}'.format(variable_int.name, self.state.keys())
            )
            self.state[variable_int.name] = variable_int.get_value()
            if isinstance(self.state[variable_int.name], LambdaType):
                input_args = inspect.getfullargspec(self.state[variable_int.name])[0]
                args = [self.state[k] for k in input_args]
                self.state[variable_int.name] = self.state[variable_int.name](*args)
                assert type(self.state[variable_int.name]) in [int, list], (
                    'The return value of the lambda expression has to be one of [int, list] but was of type '
                    '{}. It was supposed to be written to variable {}.'.format(
                        type(self.state[variable_int.name]), variable_int.name)
                )

    def return_outputs(self, outputs: List[Output]) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        return_values = []
        for output in outputs:
            assert output.name in self.state, (
                'Output with name `{}` is not defined. '
                'The following variables are already defined: {}'.format(output.name, self.state.keys())
            )
            value = self.state[output.name]
            output.checks(value)
            return_values.append(value)
        if len(return_values) == 0:
            assert False, 'There has to be at least one return value.'
        elif len(return_values) == 1:
            return return_values[0]
        else:
            return tuple(return_values)

    def merge(self, to_merge, p):
        r"""(Internal) Merges two states, where the new tensor is used to the extent of :math:`p` .

        For any tensors :math:`t_1, t_2` (`self`, `to_merge`) and probability :math:`p` , the new tensor
        :math:`t^\prime` is defined as

        .. math:: t^\prime = t_1 \cdot (1-p) + t_2 \cdot p
        """
        assert isinstance(to_merge, State), to_merge
        assert self.state.keys() == to_merge.state.keys(), (self.state.keys(), to_merge.state.keys())
        assert p.shape[0] == self.batch_size, (p.shape, self.batch_size)
        for key, value in to_merge.state.items():
            if isinstance(value, int) or (isinstance(value, list) and all([isinstance(v, int) for v in value])):
                assert value == self.state[key], (
                    'You have probabilistically modified a hard Int discrete value ({}). {} {}'.format(
                        key, value, self.state[key])
                )
                continue
            if value.dtype == torch.long:
                assert (value == self.state[key]).all(), (
                    'You have probabilistically modified a discrete value (LongTensor) ({}).'.format(key)
                )
                continue
            assert len(p.shape) <= len(value.shape), (
                'Error because the probability (which is produced by the Condition) is higher-dimensional than the '
                'actual values to be interpolated. This is most likely caused by too high-dimensional inputs to the '
                'condition. Usually, the inputs to the condition should have a shape of (B, ) where B is the batch '
                'dimension, i.e., specifically not something like (B, 1). The shape of p is {} and the shape of '
                'value is {}.'.format(p.shape, value.shape)
            )
            p_0 = 1 - p
            p_1 = p
            while len(p_0.shape) < len(value.shape):
                p_0 = p_0.unsqueeze(-1)
                p_1 = p_1.unsqueeze(-1)
            assert self.state[key].shape == value.shape, (self.state[key].shape, value.shape)
            # If it crashes here, that is most likely because one of the elements in state does not have its batch
            # dimension:
            # print('merge', p_0.shape, self.state[key].shape, value.shape)
            self.state[key] = self.state[key] * p_0 + value * p_1
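
    # Worked example: for a current value t_1 = 0.0, an incoming value t_2 = 1.0,
    # and a branch probability p = 0.3, the merged value is
    # 0.0 * (1 - 0.3) + 1.0 * 0.3 = 0.3, i.e., a per-sample convex combination.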

    def probabilistic_update(self, key, value, p):
        assert p.shape[0] == self.batch_size, (p.shape, self.batch_size)
        assert len(p.shape) == len(self.state[key].shape), (p.shape, self.state[key].shape)
        while len(value.shape) < len(self.state[key].shape):
            value = value.unsqueeze(-1)
        # print('probabilistic_update', p.shape, self.state[key].shape, value.shape)
        self.state[key] = p * value + (1 - p) * self.state[key]

    def reset(self):
        """(Internal) Resets all tensors to zero."""
        for key in self.state.keys():
            if isinstance(self.state[key], torch.Tensor):
                self.state[key] = torch.zeros_like(self.state[key])
            else:
                assert isinstance(self.state[key], (int, LambdaType)) or (
                    isinstance(self.state[key], list) and all([isinstance(v, int) for v in self.state[key]])
                ), 'Illegal data type {} found for variable {}.'.format(type(self.state[key]), key)

    def add(self, to_add, p):
        """(Internal) Merges two states by adding `to_add` weighted by `p`, except for VariableInt types, where
        the new value is used directly.
        """
        for key, value in to_add.state.items():
            p_0 = p
            if isinstance(value, (int, LambdaType)) or (
                    isinstance(value, list) and all([isinstance(v, int) for v in value])):
                self.state[key] = value
            else:
                assert self.state[key].shape == value.shape, (self.state[key].shape, value.shape)
                assert len(p.shape) <= len(value.shape), (
                    'Error because the probability (which is produced by the Condition) is higher-dimensional than '
                    'the actual values to be interpolated. This is most likely caused by too high-dimensional inputs '
                    'to the condition. Usually, the inputs to the condition should have a shape of (B, ) where B is '
                    'the batch dimension, i.e., specifically not something like (B, 1). The shape of p is {} and the '
                    'shape of value is {}.'.format(p.shape, value.shape)
                )
                while len(p_0.shape) < len(value.shape):
                    p_0 = p_0.unsqueeze(-1)
                # print('add', p_0.shape, self.state[key].shape, value.shape)
                self.state[key] = self.state[key] + p_0 * value

    def clone(self):
        """Duplicates a :class:`~State` .

        Does not duplicate the internal objects, i.e., ``copy.copy()`` instead of ``copy.deepcopy()``.
        """
        clone = copy.copy(self)
        clone.state = copy.copy(self.state)
        return clone

    def __str__(self):
        d = self.clone().__dict__
        return pprint.pformat(d, indent=4)

    def to(self, device):
        for obj_key in self.state:
            if isinstance(self.state[obj_key], torch.Tensor):
                self.state[obj_key] = self.state[obj_key].to(device)
            else:
                # Hard integer variables (int / list of int) and lambdas are device-agnostic and remain unchanged.
                assert isinstance(self.state[obj_key], (int, LambdaType)) or (
                    isinstance(self.state[obj_key], list) and all([isinstance(v, int) for v in self.state[obj_key]])
                ), (obj_key, self.state[obj_key])

    def __setitem__(self, key, item):
        assert key in self.state, (
            'The variable {} is not declared but you attempted to write to it.'.format(key)
        )
        self.state[key] = item

    def __getitem__(self, key):
        assert key in self.state, (
            'The variable {} does not exist but you attempted to access it.'.format(key)
        )
        return self.state[key]

    def update(self, new_values: dict):
        """(Internal) Overrides the state with values from a ``dict``."""
        for key, value in new_values.items():
            if isinstance(self.state[key], torch.Tensor):
                if isinstance(value, torch.Tensor):
                    assert value.shape == self.state[key].shape, (
                        'A variable ({}) is being updated but the new shape ({}) does not match the original shape '
                        '({}), which is not legal. '
                        'This might be because the shape of a tensor is (B, 1) or something similar, i.e., where an '
                        'unnecessary dimension is at the end.'.format(key, value.shape, self.state[key].shape)
                    )
                # Multiplying by ones_like broadcasts scalar values to the full shape of the state entry.
                self.state[key] = value * torch.ones_like(self.state[key])
            else:
                assert isinstance(value, type(self.state[key])), (
                    value, type(value), self.state[key], type(self.state[key])
                )
                self.state[key] = value

    def get_device(self):
        # Returns the device of the first tensor in the state (hard integer variables carry no device).
        for value in self.state.values():
            if isinstance(value, torch.Tensor):
                return value.device


class AlgoModule(object):
    def __init__(self):
        self.beta = None

    @abstractmethod
    def __call__(self, state: State) -> State:
        pass

    def set_hyperparameters(
            self,
            beta,
            max_iter,
            epsilon,
            hard,
            debug,
    ):
        kwargs = dict(
            beta=beta,
            max_iter=max_iter,
            epsilon=epsilon,
            hard=hard,
            debug=debug,
        )
        for key, val in kwargs.items():
            if hasattr(self, key):
                if getattr(self, key) is None:
                    setattr(self, key, val)
        for attr_name in dir(self):
            if isinstance(getattr(self, attr_name), AlgoModule):
                if debug:
                    print('Setting hyperparameters for {}:'.format(attr_name))
                getattr(self, attr_name).set_hyperparameters(**kwargs)
            elif isinstance(getattr(self, attr_name), (list, tuple)):
                for elem in getattr(self, attr_name):
                    if debug:
                        print(getattr(self, attr_name), isinstance(elem, AlgoModule))
                    if isinstance(elem, AlgoModule):
                        if debug:
                            print('Setting hyperparameters for {} {}:'.format(attr_name, elem))
                        elem.set_hyperparameters(**kwargs)
            elif isinstance(getattr(self, attr_name), Condition):
                if getattr(self, attr_name).beta is None:
                    if debug:
                        print('Setting beta for {} {}:'.format(attr_name, getattr(self, attr_name)))
                    getattr(self, attr_name).beta = beta
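
    # Propagation example (illustrative; `body` and `cond` are hypothetical
    # attribute names): for a module `m` with
    #     m.body = [module_a, module_b]  # AlgoModules held in a list
    #     m.cond = some_condition        # a Condition with beta=None
    # calling m.set_hyperparameters(beta=10., max_iter=2**10, epsilon=1.e-5,
    # hard=False, debug=False) recurses into module_a and module_b and fills in
    # m.cond.beta = 10. (only fields that are still None are overwritten).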


class Condition(object):
    def __init__(
            self,
            left,
            right,
            beta=None,
    ):
        self.left = left
        self.right = right
        self.beta = beta

    def get_left(self, state):
        if isinstance(self.left, str):
            return state[self.left]
        elif isinstance(self.left, LambdaType):
            kwargs = {key: state[key] for key in inspect.getfullargspec(self.left)[0]}
            return self.left(**kwargs)
        else:
            return self.left

    def get_right(self, state):
        if isinstance(self.right, str):
            return state[self.right]
        elif isinstance(self.right, LambdaType):
            kwargs = {key: state[key] for key in inspect.getfullargspec(self.right)[0]}
            return self.right(**kwargs)
        else:
            return self.right

    @abstractmethod
    def __call__(self, state: State) -> torch.Tensor:
        pass

    def __and__(self, other):
        if self.beta is None or other.beta is None:
            assert False, '`and` / `or` currently do not support implicitly setting beta; set beta explicitly ' \
                          'on both conditions.'
        return lambda state: self(state) * other(state)

    def __or__(self, other):
        if self.beta is None or other.beta is None:
            assert False, '`and` / `or` currently do not support implicitly setting beta; set beta explicitly ' \
                          'on both conditions.'

        def or_(state):
            a = self(state)
            b = other(state)
            return a + b - a * b
        return or_
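

# A minimal Condition sketch (illustrative; `_GreaterThan` is not part of this
# file, and the logistic relaxation below is an assumption rather than
# necessarily the relaxation the library ships): a soft comparison that
# returns, per sample, the probability that left > right, sharpened by the
# inverse temperature `beta`.
class _GreaterThan(Condition):
    def __call__(self, state: State) -> torch.Tensor:
        # A larger (left - right) margin pushes the probability towards 1.
        return torch.sigmoid(self.beta * (self.get_left(state) - self.get_right(state)))

# Usage sketch: `_GreaterThan('x', 0., beta=10.)` reads the state entry `x` by
# name (either side may also be a lambda over state entries or a constant).
# Combined conditions multiply probabilities: `(a & b)(state) == a(state) * b(state)`,
# and `(a | b)(state) == a(state) + b(state) - a(state) * b(state)`.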


class Algorithm(torch.nn.Module):
    def __init__(
            self,
            *sequence,
            beta=10.,
            max_iter=2**10,
            epsilon=1.e-5,
            hard=False,
            debug=False,
    ):
        super(Algorithm, self).__init__()
        self.inputs = []
        self.outputs = []
        self.variables = []
        self.variable_ints = []
        self.algo_modules = []
        self.beta = beta
        self.max_iter = max_iter
        self.epsilon = epsilon
        self.hard = hard
        self.debug = debug
        for elem in sequence:
            if isinstance(elem, Input):
                self.inputs.append(elem)
            elif isinstance(elem, Output):
                self.outputs.append(elem)
            elif isinstance(elem, Variable):
                self.variables.append(elem)
            elif isinstance(elem, VariableInt):
                self.variable_ints.append(elem)
            elif isinstance(elem, AlgoModule):
                self.algo_modules.append(elem)
            else:
                raise SyntaxError('You inserted an object of value {} into Algorithm, but this is not supported.'
                                  ''.format(elem))
        assert len(self.inputs) > 0, 'You need to have at least one Input.'
        assert len(self.outputs) > 0, 'You need to have at least one Output.'
        # assert len(self.algo_modules) > 0, 'You need to have at least one AlgoModule.'
        for algo_module in self.algo_modules:
            algo_module.set_hyperparameters(
                beta=beta,
                max_iter=max_iter,
                epsilon=epsilon,
                hard=hard,
                debug=debug,
            )

    def forward(self, *inputs: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        batch_size = inputs[0].shape[0]
        state = State(self.inputs, inputs, self.variables, self.variable_ints, batch_size)
        if self.debug:
            print('Before Algorithm')
            print(state)
            print('-' * 80)
        for module in self.algo_modules:
            state = module(state)
        if self.debug:
            print('After Algorithm')
            print(state)
            print('-' * 80)
        return state.return_outputs(self.outputs)
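

# A minimal custom AlgoModule sketch (illustrative; `_Scale` is not part of the
# library): it clones the state, transforms one entry, and writes the result
# back via `State.update`, which enforces that the shape is preserved.
class _Scale(AlgoModule):
    def __init__(self, key, factor):
        super(_Scale, self).__init__()
        self.key = key
        self.factor = factor

    def __call__(self, state: State) -> State:
        state = state.clone()
        state.update({self.key: state[self.key] * self.factor})
        return state

# Usage sketch:
#     algo = Algorithm(Input('values'), _Scale('values', 2.), Output('values'))
#     algo(torch.randn(3, 2))  # returns the doubled input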


if __name__ == '__main__':
    a = Algorithm(
        Input('values'),
        Var('counter', torch.zeros(1)),
        VarInt('counter2', 1221),
        VarInt('counter3', range(10)),
        Output('values'),
    )
    v = torch.randn(3, 2)
    print(v)
    print(a(v))