Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit b1319dd

Browse files
committedJan 6, 2025·
Refactor Tensor._op2
1 parent b30a72f commit b1319dd

File tree

6 files changed

+256
-224
lines changed

6 files changed

+256
-224
lines changed
 

‎phiml/backend/_backend.py

+43-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import dataclasses
22
import logging
3+
import operator
34
import sys
45
import warnings
56
from builtins import ValueError
@@ -13,6 +14,7 @@
1314
import numpy as np
1415
from numpy import ndarray
1516

17+
from . import xops
1618
from ._dtype import DType, combine_types, INT32, INT64
1719

1820
TensorType = TypeVar('TensorType')
@@ -194,15 +196,15 @@ def auto_cast(self, *tensors, bool_to_int=False, int_to_float=False) -> list:
194196
return tensors
195197

196198
def auto_cast1(self, tensor):
197-
if isinstance(tensor, (bool, Number)):
199+
if isinstance(tensor, (bool, int, float, complex)):
198200
return tensor
199201
dtype = self.dtype(tensor)
200202
if dtype.kind in {int, bool}:
201203
return tensor
202-
result_type = combine_types(dtype, fp_precision=self.precision)
203-
if result_type.bits == dtype.bits:
204-
return tensor
205-
return self.cast(tensor, result_type)
204+
if dtype.precision != get_precision():
205+
result_type = DType(dtype.kind, precision=get_precision())
206+
return self.cast(tensor, result_type)
207+
return tensor
206208

207209
def __str__(self):
208210
return self.name
@@ -2114,3 +2116,39 @@ def assemble(b: Backend, *args):
21142116
all_values = {f.name: re_values[f.name] if f.name in tensor_fields else getattr(data, f.name) for f in fields}
21152117
return type(data)(**all_values)
21162118
return assemble, tensors
2119+
2120+
2121+
_BACKEND_OPERATORS = {
2122+
operator.eq: Backend.equal,
2123+
operator.ne: Backend.not_equal,
2124+
operator.gt: Backend.greater_than,
2125+
operator.ge: Backend.greater_or_equal,
2126+
operator.add: Backend.add,
2127+
operator.sub: Backend.sub,
2128+
operator.mul: Backend.mul,
2129+
operator.truediv: Backend.div,
2130+
operator.pow: Backend.pow,
2131+
operator.mod: Backend.mod,
2132+
operator.and_: Backend.and_,
2133+
operator.or_: Backend.or_,
2134+
operator.xor: Backend.xor,
2135+
operator.floordiv: Backend.floordiv,
2136+
operator.lshift: Backend.shift_bits_left,
2137+
operator.rshift: Backend.shift_bits_right,
2138+
operator.inv: Backend.invert,
2139+
operator.invert: Backend.invert,
2140+
divmod: divmod,
2141+
abs: Backend.abs,
2142+
xops.save_div: Backend.divide_no_nan,
2143+
xops.gamma_inc_l: Backend.gamma_inc_l,
2144+
xops.gamma_inc_u: Backend.gamma_inc_u,
2145+
xops.arctan2: Backend.arctan2,
2146+
xops.minimum: Backend.minimum,
2147+
xops.maximum: Backend.maximum,
2148+
}
2149+
2150+
def get_operator(op: Callable, backend: Backend):
2151+
fun = _BACKEND_OPERATORS.get(op)
2152+
if fun is not None:
2153+
return getattr(backend, fun.__name__)
2154+
return fun

‎phiml/backend/xops.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""
2+
Extra operators.
3+
4+
This module is an extension to the built-in operators.
5+
"""
6+
7+
8+
class ExtraOperator(Exception):
9+
pass
10+
11+
12+
def save_div(numerator, denominator):
13+
raise ExtraOperator
14+
15+
16+
def gamma_inc_l(a, x):
17+
raise ExtraOperator
18+
19+
20+
def gamma_inc_u(a, x):
21+
raise ExtraOperator
22+
23+
24+
def arctan2(tan, divide_by):
25+
raise ExtraOperator
26+
27+
28+
def minimum(x1, x2):
29+
raise ExtraOperator
30+
31+
32+
def maximum(x1, x2):
33+
raise ExtraOperator

‎phiml/math/_ops.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from ..backend import default_backend, choose_backend, Backend, get_precision, convert as b_convert, BACKENDS, NoBackendFound, ComputeDevice, NUMPY
1010
from ..backend._dtype import DType, combine_types, INT32
11+
from phiml.backend import xops
1112
from .magic import PhiTreeNode
1213
from ._magic_ops import expand, pack_dims, unpack_dim, cast, value_attributes, bool_to_int, tree_map, concat, stack, unstack, rename_dims, slice_, all_attributes, squeeze, ipack
1314
from ._shape import (Shape, EMPTY_SHAPE,
@@ -2137,9 +2138,9 @@ def tensor_dot(x, y):
21372138
if is_sparse(x) or is_sparse(y):
21382139
if x_dims.isdisjoint(sparse_dims(x)) and y_dims.isdisjoint(sparse_dims(y)):
21392140
if is_sparse(x):
2140-
return x._op2(y, lambda vx, vy: dot(vx, x_dims, vy, y_dims), None, 'dot', '@')
2141+
return x._op2(y, lambda vx, vy: dot(vx, x_dims, vy, y_dims), False)
21412142
else:
2142-
return y._op2(x, lambda vy, vx: dot(vx, x_dims, vy, y_dims), None, 'dot', '@')
2143+
return y._op2(x, lambda vy, vx: dot(vx, x_dims, vy, y_dims), False)
21432144
else:
21442145
return sparse_dot(x, x_dims, y, y_dims)
21452146
if x._is_tracer:
@@ -2289,11 +2290,10 @@ def incomplete_gamma(a: TensorOrTree, x: TensorOrTree, upper=False, regularized=
22892290
upper: Whether to complete the upper integral (x to infinity) or the lower integral (0 to x).
22902291
regularized: Whether the integral is divided by Γ(a).
22912292
"""
2292-
call = lambda a, x: incomplete_gamma(a, x, upper=upper, regularized=regularized)
22932293
if upper:
2294-
reg = custom_op2(a, x, call, lambda a, x: choose_backend(a, x).gamma_inc_u(a, x), 'gamma_inc_u')
2294+
reg = custom_op2(a, x, xops.gamma_inc_u)
22952295
else:
2296-
reg = custom_op2(a, x, call, lambda a, x: choose_backend(a, x).gamma_inc_l(a, x), 'gamma_inc_l')
2296+
reg = custom_op2(a, x, xops.gamma_inc_l)
22972297
return reg if regularized else reg * exp(log_gamma(a))
22982298

22992299

@@ -2464,7 +2464,7 @@ def arctan(x: TensorOrTree, divide_by=None) -> TensorOrTree:
24642464
return _backend_op1(x, Backend.arctan)
24652465
else:
24662466
divide_by = to_float(divide_by)
2467-
return custom_op2(x, divide_by, arctan, lambda a, b: choose_backend(a, b).arctan2(a, b), 'arctan')
2467+
return custom_op2(x, divide_by, xops.arctan2)
24682468

24692469

24702470
def angle(x: TensorOrTree) -> TensorOrTree:
@@ -2560,12 +2560,7 @@ def cast_same(*values: Tensor) -> Tuple[Tensor]:
25602560

25612561
def safe_div(x: Union[Number, Tensor], y: Union[Number, Tensor]):
25622562
""" Computes *x/y* with the `Tensor`s `x` and `y` but returns 0 where *y=0*. """
2563-
return custom_op2(x, y,
2564-
l_operator=safe_div,
2565-
l_native_function=lambda x_, y_: choose_backend(x_, y_).divide_no_nan(x_, y_),
2566-
r_operator=lambda y_, x_: safe_div(x_, y_),
2567-
r_native_function=lambda y_, x_: choose_backend(x_, y_).divide_no_nan(x_, y_),
2568-
op_name='divide_no_nan')
2563+
return custom_op2(x, y, xops.save_div)
25692564

25702565

25712566
def maximum(x: Union[Tensor, float], y: Union[Tensor, float], allow_none=False):
@@ -2575,7 +2570,7 @@ def maximum(x: Union[Tensor, float], y: Union[Tensor, float], allow_none=False):
25752570
return y
25762571
elif y is None:
25772572
return x
2578-
return custom_op2(x, y, maximum, lambda x_, y_: choose_backend(x_, y_).maximum(x_, y_), op_name='maximum')
2573+
return custom_op2(x, y, xops.maximum)
25792574

25802575

25812576
def minimum(x: Union[Tensor, float], y: Union[Tensor, float], allow_none=False):
@@ -2585,7 +2580,7 @@ def minimum(x: Union[Tensor, float], y: Union[Tensor, float], allow_none=False):
25852580
return y
25862581
elif y is None:
25872582
return x
2588-
return custom_op2(x, y, minimum, lambda x_, y_: choose_backend(x_, y_).minimum(x_, y_), op_name='minimum')
2583+
return custom_op2(x, y, xops.minimum)
25892584

25902585

25912586
def clip(x: Tensor, lower_limit: Union[float, Tensor] = 0, upper_limit: Union[float, Tensor, Shape] = 1):

‎phiml/math/_sparse.py

+26-27
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import operator
12
import warnings
23
from functools import partial
34
from numbers import Number
@@ -11,7 +12,7 @@
1112
from ._magic_ops import concat, pack_dims, expand, rename_dims, stack, unpack_dim, unstack
1213
from ._shape import Shape, non_batch, merge_shapes, instance, batch, non_instance, shape, channel, spatial, DimFilter, \
1314
concat_shapes, EMPTY_SHAPE, dual, non_channel, DEBUG_CHECKS, primal, concat_shapes_
14-
from ._tensors import Tensor, TensorStack, Dense, cached, wrap, reshaped_tensor, tensor, backend_for
15+
from ._tensors import Tensor, TensorStack, Dense, cached, wrap, reshaped_tensor, tensor, backend_for, custom_op2
1516
from ..backend import choose_backend, NUMPY, Backend, get_precision
1617
from ..backend._dtype import DType, INT64
1718

@@ -367,20 +368,20 @@ def _with_shape_replaced(self, new_shape: Shape):
367368
def _op1(self, native_function):
368369
return self._with_values(self._values._op1(native_function))
369370

370-
def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> 'Tensor':
371+
def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor':
371372
other_shape = shape(other)
372373
affects_only_values = self._dense_shape.isdisjoint(other_shape)
373374
if affects_only_values:
374-
return self._with_values(operator(self._values, other))
375+
return self._with_values(op(self._values, other))
375376
if isinstance(other, CompressedSparseMatrix):
376377
other = other.decompress()
377378
if isinstance(other, SparseCoordinateTensor):
378379
if same_sparsity_pattern(self, other):
379-
return self._with_values(operator(self._values, other._values))
380+
return self._with_values(op(self._values, other._values))
380381
else:
381-
if op_name not in ['add', 'radd', 'sub', 'rsub']:
382+
if op not in {operator.add, operator.sub}:
382383
same_sparsity_pattern(self, other) # debug checkpoint
383-
raise AssertionError(f"Operation '{op_symbol}' ({op_name}) requires sparse matrices with the same sparsity pattern.")
384+
raise AssertionError(f"Operation '{op}' requires sparse matrices with the same sparsity pattern.")
384385
all_sparse_dims = sparse_dims(other) & sparse_dims(self)
385386
self_indices = pack_dims(self._indices, instance, instance('sp_entries'))
386387
other_indices = pack_dims(other._indices, instance, instance('sp_entries'))
@@ -389,21 +390,19 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st
389390
self_indices, self_values = with_sparsified_dim(self_indices, self_values, all_sparse_dims)
390391
other_indices, other_values = with_sparsified_dim(other_indices, other_values, all_sparse_dims)
391392
indices = concat([self_indices, other_indices], 'sp_entries')
392-
if op_symbol == '+':
393+
if op == operator.add:
393394
values = concat([self_values, other_values], instance(self_values), expand_values=True)
394-
elif op_name == 'sub':
395-
values = concat([self_values, -other_values], instance(self_values), expand_values=True)
396-
else: # op_name == 'rsub':
397-
values = concat([-self_values, other_values], instance(self_values), expand_values=True)
395+
else:
396+
values = concat([-self_values, other_values] if switch_args else [self_values, -other_values], instance(self_values), expand_values=True)
398397
return SparseCoordinateTensor(indices, values, self._dense_shape & other._dense_shape, can_contain_double_entries=True, indices_sorted=False, indices_constant=self._indices_constant)
399398
else: # other is dense
400399
if self._dense_shape in other.shape: # all dims dense -> convert to dense
401-
return dense(self)._op2(other, operator, native_function, op_name, op_symbol)
400+
return dense(self)._op2(other, op, switch_args)
402401
else: # only some dims dense -> stay sparse
403402
dense_dims = self._dense_shape.only(other.shape)
404403
assert instance(other).without(self._dense_shape).is_empty, f"Instance dims cannot be added to sparse tensors from sparse-dense operations but got {other.shape} for sparse tensor {self.shape}"
405404
other_values = other[self._indices.sparse_idx[dense_dims.name_list]]
406-
values = operator(self._values, other_values)
405+
values = custom_op2(self._values, other_values, op, switch_args)
407406
return self._with_values(values)
408407

409408
def _getitem(self, selection: dict) -> 'Tensor':
@@ -702,19 +701,19 @@ def __expand__(self, dims: Shape, **kwargs) -> 'Tensor':
702701
def _op1(self, native_function):
703702
return self._with_values(self._values._op1(native_function))
704703

705-
def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> 'Tensor':
704+
def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor':
706705
other_shape = shape(other)
707706
affects_only_values = self.sparse_dims.isdisjoint(other_shape) and non_instance(self._indices).isdisjoint(other_shape)
708707
if affects_only_values:
709-
return self._with_values(operator(self._values, other))
708+
return self._with_values(custom_op2(self._values, other, op, switch_args))
710709
elif isinstance(other, CompressedSparseMatrix):
711710
if same_sparsity_pattern(self, other):
712-
result = operator(self._values, other._values)
711+
result = op(self._values, other._values)
713712
if self._uncompressed_offset is not None:
714713
from ._ops import where
715714
result = where(self._valid_mask(), result, 0)
716715
return self._with_values(result)
717-
elif op_symbol == '+':
716+
elif op == operator.add:
718717
raise NotImplementedError("Compressed addition not yet implemented")
719718
else:
720719
# convert to COO, then perform operation
@@ -723,19 +722,19 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st
723722
from ._ops import gather, boolean_mask, clip, where
724723
if self._uncompressed_offset is None:
725724
other_values = gather(other, self._indices, self._uncompressed_dims)
726-
return self._with_values(operator(self._values, other_values))
725+
return self._with_values(op(self._values, other_values))
727726
# if bake_slice:
728727
# baked = self._bake_slice()
729728
# other_values = gather(other, baked._indices, self._uncompressed_dims)
730729
# return baked._with_values(operator(baked._values, other_values))
731730
indices = clip(self._indices - self._uncompressed_offset, 0, self._uncompressed_dims.volume - 1)
732731
other_values = gather(other, indices, self._uncompressed_dims)
733-
return self._with_values(where(self._valid_mask(), operator(self._values, other_values), 0))
732+
return self._with_values(where(self._valid_mask(), op(self._values, other_values), 0))
734733
elif self._compressed_dims in other_shape and self._uncompressed_dims.isdisjoint(other_shape):
735734
from ._ops import gather, boolean_mask, clip, where
736735
row_indices, _ = self._coo_indices('clamp')
737736
other_values = gather(other, row_indices, self._compressed_dims)
738-
result_values = operator(self._values, other_values)
737+
result_values = op(self._values, other_values)
739738
if self._uncompressed_offset is not None:
740739
result_values = where(self._valid_mask(), result_values, 0)
741740
return self._with_values(result_values)
@@ -960,16 +959,16 @@ def _with_shape_replaced(self, new_shape: Shape):
960959
def _op1(self, native_function):
961960
return self._with_values(self._values._op1(native_function))
962961

963-
def _op2(self, other, operator: Callable, native_function: Callable, op_name: str = 'unknown', op_symbol: str = '?') -> 'Tensor':
962+
def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor':
964963
other_shape = shape(other)
965964
affects_only_values = self._compressed_dims.isdisjoint(other_shape)
966965
if affects_only_values:
967-
return self._with_values(operator(self._values, other))
966+
return self._with_values(op(self._values, other))
968967
elif isinstance(other, (CompressedSparseMatrix, CompactSparseTensor)):
969968
if same_sparsity_pattern(self, other):
970-
result = operator(self._values, other._values)
969+
result = op(self._values, other._values)
971970
return self._with_values(result)
972-
elif op_symbol == '+':
971+
elif op == operator.add:
973972
raise NotImplementedError("Compressed addition not yet implemented")
974973
else:
975974
# convert to COO, then perform operation
@@ -978,18 +977,18 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st
978977
from ._ops import gather, boolean_mask, clip, where
979978
if self._uncompressed_offset is None:
980979
other_values = gather(other, self._indices, self._uncompressed_dims)
981-
return self._with_values(operator(self._values, other_values))
980+
return self._with_values(op(self._values, other_values))
982981
# if bake_slice:
983982
# baked = self._bake_slice()
984983
# other_values = gather(other, baked._indices, self._uncompressed_dims)
985984
# return baked._with_values(operator(baked._values, other_values))
986985
indices = clip(self._indices - self._uncompressed_offset, 0, self._uncompressed_dims.volume - 1)
987986
other_values = gather(other, indices, self._uncompressed_dims)
988-
return self._with_values(where(self._valid_mask(), operator(self._values, other_values), 0))
987+
return self._with_values(where(self._valid_mask(), op(self._values, other_values), 0))
989988
elif self._compressed_dims in other_shape and self._uncompressed_dims.isdisjoint(other_shape):
990989
from ._ops import gather, boolean_mask, clip, where
991990
other_values = gather(other, self._indices, self._compressed_dims)
992-
result_values = operator(self._values, other_values)
991+
result_values = op(self._values, other_values)
993992
return self._with_values(result_values)
994993
else:
995994
raise NotImplementedError

‎phiml/math/_tensors.py

+106-128
Large diffs are not rendered by default.

‎phiml/math/_trace.py

+39-50
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import operator
12
from collections import namedtuple
23
from typing import Callable, Dict, Set, Tuple, Union, Any, Optional, Sequence, List, Collection
34

@@ -180,18 +181,14 @@ def _op1(self, native_function):
180181
else:
181182
raise NotImplementedError('Only linear operations are supported')
182183

183-
def _op2(self, other: Tensor,
184-
operator: Callable,
185-
native_function: Callable,
186-
op_name: str = 'unknown',
187-
op_symbol: str = '?') -> Tensor:
184+
def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor':
188185
if is_sparse(other):
189186
return NotImplemented
190187
if isinstance(other, SparseLinTracer):
191-
return to_sparse_tracer(self, other)._op2(other, operator, native_function, op_name, op_symbol)
192-
assert op_symbol in '+-*/', f"Unsupported operation encountered while tracing linear function: {native_function}"
193-
zeros_for_missing_self = op_name not in ['add', 'radd', 'rsub'] # perform `operator` where `self == 0`
194-
zeros_for_missing_other = op_name not in ['add', 'radd', 'sub'] # perform `operator` where `other == 0`
188+
return to_sparse_tracer(self, other)._op2(other, op, switch_args)
189+
assert op in {operator.add, operator.sub, operator.mul, operator.truediv}, f"Unsupported operation encountered while tracing linear function: {op}"
190+
zeros_for_missing_self = op != operator.add and not (op == operator.sub and switch_args) # perform `operator` where `self == 0`
191+
zeros_for_missing_other = op != operator.add and not (op == operator.sub and not switch_args) # perform `operator` where `other == 0`
195192
if isinstance(other, Tensor) and other._is_tracer:
196193
if not isinstance(other, ShiftLinTracer):
197194
raise NotImplementedError
@@ -201,36 +198,36 @@ def _op2(self, other: Tensor,
201198
nz_edge = {}
202199
for dim_shift in self.val.keys():
203200
if dim_shift in other.val:
204-
values[dim_shift] = operator(self.val[dim_shift], other.val[dim_shift])
201+
values[dim_shift] = op(self.val[dim_shift], other.val[dim_shift])
205202
nz_edge[dim_shift] = self._nz_edge[dim_shift] or other._nz_edge[dim_shift]
206203
else:
207204
if zeros_for_missing_other:
208-
values[dim_shift] = operator(self.val[dim_shift], math.zeros_like(self.val[dim_shift]))
205+
values[dim_shift] = op(self.val[dim_shift], math.zeros_like(self.val[dim_shift]))
209206
else:
210207
values[dim_shift] = self.val[dim_shift]
211208
nz_edge[dim_shift] = self._nz_edge[dim_shift]
212209
for dim_shift, other_values in other.val.items():
213210
if dim_shift not in self.val:
214211
if zeros_for_missing_self:
215-
values[dim_shift] = operator(math.zeros_like(other_values), other_values)
212+
values[dim_shift] = op(math.zeros_like(other_values), other_values)
216213
else:
217214
values[dim_shift] = other_values
218215
nz_edge[dim_shift] = other._nz_edge[dim_shift]
219-
bias = operator(self._bias, other._bias)
216+
bias = op(self._bias, other._bias)
220217
return ShiftLinTracer(self._source, values, self._shape, bias, self._renamed, nz_edge)
221218
else:
222219
other = self._tensor(other)
223-
if op_symbol in '*/':
220+
if op in {operator.mul, operator.truediv}:
224221
values = {}
225222
for dim_shift, val in self.val.items():
226-
values[dim_shift] = operator(val, other)
227-
bias = operator(self._bias, other)
223+
values[dim_shift] = op(val, other)
224+
bias = op(self._bias, other)
228225
return ShiftLinTracer(self._source, values, self._shape & other.shape, bias, self._renamed, self._nz_edge)
229-
elif op_symbol in '+-':
230-
bias = operator(self._bias, other)
226+
elif op in {operator.add, operator.sub}:
227+
bias = op(self._bias, other)
231228
return ShiftLinTracer(self._source, self.val, self._shape & other.shape, bias, self._renamed, self._nz_edge)
232229
else:
233-
raise ValueError(f"Unsupported operation encountered while tracing linear function: {native_function}")
230+
raise ValueError(f"Unsupported operation encountered while tracing linear function: {op}")
234231

235232
def _natives(self) -> tuple:
236233
"""
@@ -370,34 +367,30 @@ def _op1(self, native_function):
370367
else:
371368
raise NotImplementedError('Only linear operations are supported')
372369

373-
def _op2(self, other: Tensor,
374-
operator: Callable,
375-
native_function: Callable,
376-
op_name: str = 'unknown',
377-
op_symbol: str = '?') -> Tensor:
378-
assert op_symbol in '+-*/', f"Unsupported operation '{op_symbol}' encountered while tracing linear function: {native_function}"
370+
def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor':
371+
assert op in {operator.add, operator.sub, operator.mul, operator.truediv}, f"Unsupported operation '{op}' encountered while tracing linear function"
379372
if isinstance(other, ShiftLinTracer):
380373
other = other._to_gather_tracer()
381374
if isinstance(other, GatherLinTracer):
382-
assert op_symbol in '+-', f"Non-linear operation '{op_symbol}' cannot be converted to matrix"
375+
assert op in {operator.add, operator.sub}, f"Non-linear operation '{op}' cannot be converted to matrix"
383376
if not math.always_close(self._selection, other._selection):
384-
return to_sparse_tracer(self, other)._op2(other, operator, native_function, op_name, op_symbol)
385-
diag = operator(self._diag, other._diag)
386-
bias = operator(self._bias, other._bias)
377+
return to_sparse_tracer(self, other)._op2(other, op, switch_args)
378+
diag = op(self._diag, other._diag)
379+
bias = op(self._bias, other._bias)
387380
return GatherLinTracer(self._source, diag, bias, self._shape, self._selection, self._renamed)
388381
if isinstance(other, SparseLinTracer) or is_sparse(other):
389382
return NotImplemented
390383
else:
391384
other = self._tensor(other)
392-
if op_symbol in '*/':
393-
matrix = operator(self._diag, other)
394-
bias = operator(self._bias, other)
385+
if op in {operator.mul, operator.truediv}:
386+
matrix = op(self._diag, other)
387+
bias = op(self._bias, other)
395388
return GatherLinTracer(self._source, matrix, bias, self._shape & other.shape, self._selection, self._renamed)
396-
elif op_symbol in '+-':
397-
bias = operator(self._bias, other)
389+
elif op in {operator.add, operator.sub}:
390+
bias = op(self._bias, other)
398391
return GatherLinTracer(self._source, self._matrix, bias, self._shape & other.shape, self._selection, self._renamed)
399392
else:
400-
raise ValueError(f"Unsupported operation encountered while tracing linear function: {native_function}")
393+
raise ValueError(f"Unsupported operation {op} encountered while tracing linear function")
401394

402395
@property
403396
def _is_tracer(self) -> bool:
@@ -522,35 +515,31 @@ def _op1(self, native_function):
522515
else:
523516
raise NotImplementedError('Only linear operations are supported')
524517

525-
def _op2(self, other,
526-
operator: Callable,
527-
native_function: Callable,
528-
op_name: str = 'unknown',
529-
op_symbol: str = '?') -> 'SparseLinTracer':
518+
def _op2(self, other, op: Callable, switch_args: bool) -> 'Tensor':
530519
other = self._tensor(other)
531-
assert op_symbol in '+-*/', f"Unsupported operation encountered while tracing linear function: {native_function}"
520+
assert op in {operator.add, operator.sub, operator.mul, operator.truediv}, f"Unsupported operation {op} encountered while tracing linear function"
532521
if other._is_tracer and not isinstance(other, SparseLinTracer):
533522
other = to_sparse_tracer(other, self)
534523
if isinstance(other, SparseLinTracer):
535-
assert op_symbol in '+-', f"Non-linear operation '{op_symbol}' cannot be converted to matrix"
536-
bias = operator(self._bias, other._bias)
524+
assert op in {operator.add, operator.sub}, f"Non-linear operation '{op}' cannot be converted to matrix"
525+
bias = op(self._bias, other._bias)
537526
matrix_dims = sparse_dims(self._matrix) & sparse_dims(other._matrix)
538527
self_matrix = expand_matrix(self._matrix, matrix_dims)
539528
other_matrix = expand_matrix(other._matrix, matrix_dims)
540-
matrix = operator(self_matrix, other_matrix) # ToDo if other has no dependence on vector, it would also be in the output
529+
matrix = op(self_matrix, other_matrix) # ToDo if other has no dependence on vector, it would also be in the output
541530
shape = self._shape & other._shape
542531
return SparseLinTracer(self._source, matrix, bias, shape)
543532
else:
544533
# other = self._tensor(other)
545-
if op_symbol in '*/':
546-
matrix = operator(self._matrix, other)
547-
bias = operator(self._bias, other)
534+
if op in {operator.mul, operator.truediv}:
535+
matrix = op(self._matrix, other)
536+
bias = op(self._bias, other)
548537
return SparseLinTracer(self._source, matrix, bias, self._shape & other.shape)
549-
elif op_symbol in '+-':
550-
bias = operator(self._bias, other)
538+
elif op in {operator.add, operator.sub}:
539+
bias = op(self._bias, other)
551540
return SparseLinTracer(self._source, self._matrix, bias, self._shape & other.shape)
552541
else:
553-
raise ValueError(f"Unsupported operation encountered while tracing linear function: {native_function}")
542+
raise ValueError(f"Unsupported operation {op} encountered while tracing linear function")
554543

555544
@property
556545
def _is_tracer(self) -> bool:

0 commit comments

Comments
 (0)
Please sign in to comment.