# Copyright (c) Felix Petersen.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from algovision.core import *
class If(AlgoModule):
    """Conditional control structure: run a true branch and a false branch, then merge the results.

    The condition is evaluated to a per-batch value ``p``. The true branch is applied to the incoming
    state and the false branch to a clone of it; afterwards the false state is merged into the true
    state via ``State.merge`` with weight ``1 - p`` (or a thresholded 0/1 weight when ``hard`` is set).

    NOTE(review): ``epsilon``, ``hard`` and ``debug`` default to ``None``. ``__call__`` compares against
    ``self.epsilon`` directly, so it has to be set to a number before the module is called; presumably
    an enclosing container fills in these hyperparameters -- confirm against the caller.
    """

    def __init__(
            self,
            condition,
            if_true=None,
            if_false=None,
            epsilon=None,
            hard=None,
            debug=None,
    ):
        """
        :param condition: callable applied to the state; its result must expose ``shape`` with a first
            dimension equal to ``state.batch_size``.
        :param if_true: ``None``, an ``AlgoModule``, or a list of ``AlgoModule`` s executed for the true case.
        :param if_false: ``None``, an ``AlgoModule``, or a list of ``AlgoModule`` s executed for the false case.
        :param epsilon: threshold below which a branch is skipped entirely.
        :param hard: if truthy, merge with the thresholded condition ``p > 0.5`` instead of ``p`` itself.
        :param debug: if truthy, print the evaluated condition on every call.
        """
        super(If, self).__init__()
        self.condition = condition
        self.if_true = if_true
        self.if_false = if_false
        self.epsilon = epsilon
        self.hard = hard
        self.debug = debug

    def __call__(self, state: State) -> State:
        """Evaluate the condition, execute the branches, and return the merged state.

        :param state: the incoming state; it is used directly as the starting point of the true branch.
        :return: the state of the true branch after the false branch has been merged into it.
        :raises AssertionError: if the condition output has an unexpected shape, or if a branch that is
            about to be executed is neither ``None``, an ``AlgoModule``, nor a list.
        """
        p = self.condition(state)
        if self.debug:
            print(p)
        assert len(p.shape) >= 1, p.shape
        assert p.shape[0] == state.batch_size, (p.shape, state.batch_size)

        # The true branch continues on the incoming state object, the false branch on a clone of it.
        state_true = state
        state_false = state.clone()

        # A branch is only executed if at least one batch element gives it a weight above epsilon.
        if self.if_true is not None and (p > self.epsilon).any():
            if isinstance(self.if_true, AlgoModule):
                state_true = self.if_true(state_true)
            elif isinstance(self.if_true, list):
                for module in self.if_true:
                    state_true = module(state_true)
            else:
                # Raised explicitly (instead of `assert False`) so the check is not stripped under `python -O`.
                raise AssertionError(
                    'The true case has to be either None, an AlgoModule, or a list of AlgoModules; '
                    'however, it is {}: {}'.format(type(self.if_true), self.if_true)
                )

        if self.if_false is not None and ((1 - p) > self.epsilon).any():
            if isinstance(self.if_false, AlgoModule):
                state_false = self.if_false(state_false)
            elif isinstance(self.if_false, list):
                for module in self.if_false:
                    state_false = module(state_false)
            else:
                # Raised explicitly (instead of `assert False`) so the check is not stripped under `python -O`.
                raise AssertionError(
                    'The false case has to be either None, an AlgoModule, or a list of AlgoModules; '
                    'however, it is {}: {}'.format(type(self.if_false), self.if_false)
                )

        if not self.hard:
            state_true.merge(state_false, 1 - p)
        else:
            # In case of hard, both cases are still executed as the condition might hold for some elements in the
            # batch; only the merge weight is thresholded to 0 or 1.
            state_true.merge(state_false, 1 - (p > .5).float())

        return state_true
class While(AlgoModule):
    """Loop control structure that accumulates the states reached after every possible number of iterations.

    The condition value is multiplied up across iterations; the drop of that running product in each
    step is the weight with which the state at that step is added to the accumulated result. Whatever
    weight is left once the loop stops is assigned to the final state.

    NOTE(review): ``max_iter`` and ``epsilon`` default to ``None`` but are compared numerically in
    ``__call__``; presumably they are filled in from outside before the module is called -- confirm.
    ``hard`` is stored but not read anywhere in this class.
    """

    def __init__(
            self,
            condition,
            *sequence,
            max_iter=None,
            epsilon=None,
            hard=None,
            debug=None,
    ):
        """
        :param condition: callable applied to the state, returning the loop condition value.
        :param sequence: the modules forming the loop body, applied in order in every iteration.
        :param max_iter: upper bound on the number of executed iterations.
        :param epsilon: the loop stops once the largest remaining weight is not above this value.
        :param hard: stored on the instance; unused by ``__call__``.
        :param debug: if truthy, print the weights and states before, inside, and after the loop.
        """
        super(While, self).__init__()
        self.condition = condition
        self.sequence = sequence
        self.max_iter = max_iter
        self.epsilon = epsilon
        self.hard = hard
        self.debug = debug

    @staticmethod
    def _dump(*headline, state, accumulate_state):
        """Print one debug record: the headline values, both states, and a separator line."""
        print(*headline)
        print('state', state)
        print('accumulate_state', accumulate_state)
        print('-' * 80)

    def __call__(self, state: State) -> State:
        """Run the loop body while weight remains and return the accumulated state.

        :param state: the incoming state; it is re-bound to the output of each body module.
        :return: the accumulated state built from a reset clone of the incoming state.
        """
        remaining = self.condition(state)
        previous = torch.ones_like(remaining)

        # Start from an emptied copy and add the "zero iterations" case with weight 1 - condition.
        result = state.clone()
        result.reset()
        result.add(state, previous - remaining)
        if self.debug:
            self._dump('Before WHILE', previous - remaining, state=state, accumulate_state=result)

        iteration = 0
        while remaining.max() > self.epsilon and iteration < self.max_iter:
            for module in self.sequence:
                state = module(state)

            # The remaining weight shrinks by the condition evaluated on the updated state.
            previous, remaining = remaining, remaining * self.condition(state)
            result.add(state, previous - remaining)
            if self.debug:
                self._dump('Inside WHILE', previous - remaining, result, state=state, accumulate_state=result)
            iteration += 1

        # Whatever weight is left after stopping goes to the last state.
        result.add(state, remaining)
        if self.debug:
            self._dump('After WHILE', previous - remaining, state=state, accumulate_state=result)

        return result
class For(AlgoModule):
    """Loop control structure that executes a sequence of modules once per value of a loop variable.

    The values to iterate over are given as an int (iterated as ``range(n)``), a list or other
    non-string iterable, the name of a state variable holding an int or a list, or a function whose
    parameter names are looked up in the state and whose return value is an int or an iterable.
    """

    def __init__(
            self,
            var: str,
            range_or_list: Union[int, list, iter, str, LambdaType],
            *sequence: AlgoModule,
    ):
        """
        :param var: name of the loop variable; it must not already exist in the state when the loop runs.
        :param range_or_list: specification of the values to iterate over (see the class description).
            Non-string iterables are materialised into a list immediately.
        :param sequence: the modules forming the loop body, applied in order in every iteration.
        """
        super(For, self).__init__()
        self.var = var
        if isinstance(range_or_list, collections.abc.Iterable) and not isinstance(range_or_list, str):
            range_or_list = list(range_or_list)
        self.range_or_list = range_or_list
        self.sequence = sequence

    def _resolve_iterations(self, state):
        """Turn ``self.range_or_list`` into the concrete iterable of loop values for this call."""
        spec = self.range_or_list

        if isinstance(spec, int):
            return range(spec)
        if isinstance(spec, list):
            # Iterate over a copy so the stored list is not shared with the loop.
            return list(spec)

        if isinstance(spec, LambdaType):
            # The function's parameter names select which state variables are passed to it.
            input_args = inspect.getfullargspec(spec)[0]
            resolved = spec(*[state[k] for k in input_args])
            if isinstance(resolved, collections.abc.Iterable):
                resolved = list(resolved)
            # Exact type check kept from the original validation (subclasses such as bool are rejected).
            if type(resolved) not in [int, list]:
                raise AssertionError(
                    'The return value of the lambda expression has to be one of [int, list, iter] but was type '
                    '{}. It was supposed to be used as range_or_list in a For loop'.format(type(resolved))
                )
        elif isinstance(spec, str):
            resolved = state[spec]
            if not isinstance(resolved, (int, list)):
                raise AssertionError(
                    'The variable {}, which was used for range_or_list should be an int or a list of ints but was '
                    'neither. It was {} of type {}.'.format(spec, resolved, type(resolved))
                )
        else:
            raise AssertionError(
                'Invalid type {} for range_or_list. Supported is Union[int, list, iter, str, LambdaType]. '
                '({})'.format(type(spec), spec)
            )

        return range(resolved) if isinstance(resolved, int) else resolved

    def __call__(self, state: State) -> State:
        """Execute the loop body for every loop value and return the resulting state.

        :param state: the incoming state; it is re-bound to the output of each body module.
        :return: the state after the last iteration, with the loop variable removed again.
        :raises AssertionError: if the range specification is invalid or if the loop variable already
            exists in the state. These are raised explicitly so they also fire under ``python -O``.
        """
        iterations = self._resolve_iterations(state)

        if self.var in state.state:
            raise AssertionError(
                'The variable used in a For loop ({}) must not exist in the outer scope.'.format(self.var)
            )

        # Register the loop variable as a plain Python value before the first iteration assigns to it.
        state.state[self.var] = 0

        for i in iterations:
            state[self.var] = i
            for elem in self.sequence:
                state = elem(state)

        # The loop variable is local to the loop; drop it from the resulting state.
        del state.state[self.var]

        return state