1
+ import operator
1
2
import warnings
2
3
from functools import partial
3
4
from numbers import Number
11
12
from ._magic_ops import concat , pack_dims , expand , rename_dims , stack , unpack_dim , unstack
12
13
from ._shape import Shape , non_batch , merge_shapes , instance , batch , non_instance , shape , channel , spatial , DimFilter , \
13
14
concat_shapes , EMPTY_SHAPE , dual , non_channel , DEBUG_CHECKS , primal , concat_shapes_
14
- from ._tensors import Tensor , TensorStack , Dense , cached , wrap , reshaped_tensor , tensor , backend_for
15
+ from ._tensors import Tensor , TensorStack , Dense , cached , wrap , reshaped_tensor , tensor , backend_for , custom_op2
15
16
from ..backend import choose_backend , NUMPY , Backend , get_precision
16
17
from ..backend ._dtype import DType , INT64
17
18
@@ -367,20 +368,20 @@ def _with_shape_replaced(self, new_shape: Shape):
367
368
def _op1 (self , native_function ):
368
369
return self ._with_values (self ._values ._op1 (native_function ))
369
370
370
- def _op2 (self , other , operator : Callable , native_function : Callable , op_name : str = 'unknown' , op_symbol : str = '?' ) -> 'Tensor' :
371
+ def _op2 (self , other , op : Callable , switch_args : bool ) -> 'Tensor' :
371
372
other_shape = shape (other )
372
373
affects_only_values = self ._dense_shape .isdisjoint (other_shape )
373
374
if affects_only_values :
374
- return self ._with_values (operator (self ._values , other ))
375
+ return self ._with_values (op (self ._values , other ))
375
376
if isinstance (other , CompressedSparseMatrix ):
376
377
other = other .decompress ()
377
378
if isinstance (other , SparseCoordinateTensor ):
378
379
if same_sparsity_pattern (self , other ):
379
- return self ._with_values (operator (self ._values , other ._values ))
380
+ return self ._with_values (op (self ._values , other ._values ))
380
381
else :
381
- if op_name not in [ ' add' , 'radd' , ' sub' , 'rsub' ] :
382
+ if op not in { operator . add , operator . sub } :
382
383
same_sparsity_pattern (self , other ) # debug checkpoint
383
- raise AssertionError (f"Operation '{ op_symbol } ' ( { op_name } ) requires sparse matrices with the same sparsity pattern." )
384
+ raise AssertionError (f"Operation '{ op } ' requires sparse matrices with the same sparsity pattern." )
384
385
all_sparse_dims = sparse_dims (other ) & sparse_dims (self )
385
386
self_indices = pack_dims (self ._indices , instance , instance ('sp_entries' ))
386
387
other_indices = pack_dims (other ._indices , instance , instance ('sp_entries' ))
@@ -389,21 +390,19 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st
389
390
self_indices , self_values = with_sparsified_dim (self_indices , self_values , all_sparse_dims )
390
391
other_indices , other_values = with_sparsified_dim (other_indices , other_values , all_sparse_dims )
391
392
indices = concat ([self_indices , other_indices ], 'sp_entries' )
392
- if op_symbol == '+' :
393
+ if op == operator . add :
393
394
values = concat ([self_values , other_values ], instance (self_values ), expand_values = True )
394
- elif op_name == 'sub' :
395
- values = concat ([self_values , - other_values ], instance (self_values ), expand_values = True )
396
- else : # op_name == 'rsub':
397
- values = concat ([- self_values , other_values ], instance (self_values ), expand_values = True )
395
+ else :
396
+ values = concat ([- self_values , other_values ] if switch_args else [self_values , - other_values ], instance (self_values ), expand_values = True )
398
397
return SparseCoordinateTensor (indices , values , self ._dense_shape & other ._dense_shape , can_contain_double_entries = True , indices_sorted = False , indices_constant = self ._indices_constant )
399
398
else : # other is dense
400
399
if self ._dense_shape in other .shape : # all dims dense -> convert to dense
401
- return dense (self )._op2 (other , operator , native_function , op_name , op_symbol )
400
+ return dense (self )._op2 (other , op , switch_args )
402
401
else : # only some dims dense -> stay sparse
403
402
dense_dims = self ._dense_shape .only (other .shape )
404
403
assert instance (other ).without (self ._dense_shape ).is_empty , f"Instance dims cannot be added to sparse tensors from sparse-dense operations but got { other .shape } for sparse tensor { self .shape } "
405
404
other_values = other [self ._indices .sparse_idx [dense_dims .name_list ]]
406
- values = operator (self ._values , other_values )
405
+ values = custom_op2 (self ._values , other_values , op , switch_args )
407
406
return self ._with_values (values )
408
407
409
408
def _getitem (self , selection : dict ) -> 'Tensor' :
@@ -702,19 +701,19 @@ def __expand__(self, dims: Shape, **kwargs) -> 'Tensor':
702
701
def _op1 (self , native_function ):
703
702
return self ._with_values (self ._values ._op1 (native_function ))
704
703
705
- def _op2 (self , other , operator : Callable , native_function : Callable , op_name : str = 'unknown' , op_symbol : str = '?' ) -> 'Tensor' :
704
+ def _op2 (self , other , op : Callable , switch_args : bool ) -> 'Tensor' :
706
705
other_shape = shape (other )
707
706
affects_only_values = self .sparse_dims .isdisjoint (other_shape ) and non_instance (self ._indices ).isdisjoint (other_shape )
708
707
if affects_only_values :
709
- return self ._with_values (operator (self ._values , other ))
708
+ return self ._with_values (custom_op2 (self ._values , other , op , switch_args ))
710
709
elif isinstance (other , CompressedSparseMatrix ):
711
710
if same_sparsity_pattern (self , other ):
712
- result = operator (self ._values , other ._values )
711
+ result = op (self ._values , other ._values )
713
712
if self ._uncompressed_offset is not None :
714
713
from ._ops import where
715
714
result = where (self ._valid_mask (), result , 0 )
716
715
return self ._with_values (result )
717
- elif op_symbol == '+' :
716
+ elif op == operator . add :
718
717
raise NotImplementedError ("Compressed addition not yet implemented" )
719
718
else :
720
719
# convert to COO, then perform operation
@@ -723,19 +722,19 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st
723
722
from ._ops import gather , boolean_mask , clip , where
724
723
if self ._uncompressed_offset is None :
725
724
other_values = gather (other , self ._indices , self ._uncompressed_dims )
726
- return self ._with_values (operator (self ._values , other_values ))
725
+ return self ._with_values (op (self ._values , other_values ))
727
726
# if bake_slice:
728
727
# baked = self._bake_slice()
729
728
# other_values = gather(other, baked._indices, self._uncompressed_dims)
730
729
# return baked._with_values(operator(baked._values, other_values))
731
730
indices = clip (self ._indices - self ._uncompressed_offset , 0 , self ._uncompressed_dims .volume - 1 )
732
731
other_values = gather (other , indices , self ._uncompressed_dims )
733
- return self ._with_values (where (self ._valid_mask (), operator (self ._values , other_values ), 0 ))
732
+ return self ._with_values (where (self ._valid_mask (), op (self ._values , other_values ), 0 ))
734
733
elif self ._compressed_dims in other_shape and self ._uncompressed_dims .isdisjoint (other_shape ):
735
734
from ._ops import gather , boolean_mask , clip , where
736
735
row_indices , _ = self ._coo_indices ('clamp' )
737
736
other_values = gather (other , row_indices , self ._compressed_dims )
738
- result_values = operator (self ._values , other_values )
737
+ result_values = op (self ._values , other_values )
739
738
if self ._uncompressed_offset is not None :
740
739
result_values = where (self ._valid_mask (), result_values , 0 )
741
740
return self ._with_values (result_values )
@@ -960,16 +959,16 @@ def _with_shape_replaced(self, new_shape: Shape):
960
959
def _op1 (self , native_function ):
961
960
return self ._with_values (self ._values ._op1 (native_function ))
962
961
963
- def _op2 (self , other , operator : Callable , native_function : Callable , op_name : str = 'unknown' , op_symbol : str = '?' ) -> 'Tensor' :
962
+ def _op2 (self , other , op : Callable , switch_args : bool ) -> 'Tensor' :
964
963
other_shape = shape (other )
965
964
affects_only_values = self ._compressed_dims .isdisjoint (other_shape )
966
965
if affects_only_values :
967
- return self ._with_values (operator (self ._values , other ))
966
+ return self ._with_values (op (self ._values , other ))
968
967
elif isinstance (other , (CompressedSparseMatrix , CompactSparseTensor )):
969
968
if same_sparsity_pattern (self , other ):
970
- result = operator (self ._values , other ._values )
969
+ result = op (self ._values , other ._values )
971
970
return self ._with_values (result )
972
- elif op_symbol == '+' :
971
+ elif op == operator . add :
973
972
raise NotImplementedError ("Compressed addition not yet implemented" )
974
973
else :
975
974
# convert to COO, then perform operation
@@ -978,18 +977,18 @@ def _op2(self, other, operator: Callable, native_function: Callable, op_name: st
978
977
from ._ops import gather , boolean_mask , clip , where
979
978
if self ._uncompressed_offset is None :
980
979
other_values = gather (other , self ._indices , self ._uncompressed_dims )
981
- return self ._with_values (operator (self ._values , other_values ))
980
+ return self ._with_values (op (self ._values , other_values ))
982
981
# if bake_slice:
983
982
# baked = self._bake_slice()
984
983
# other_values = gather(other, baked._indices, self._uncompressed_dims)
985
984
# return baked._with_values(operator(baked._values, other_values))
986
985
indices = clip (self ._indices - self ._uncompressed_offset , 0 , self ._uncompressed_dims .volume - 1 )
987
986
other_values = gather (other , indices , self ._uncompressed_dims )
988
- return self ._with_values (where (self ._valid_mask (), operator (self ._values , other_values ), 0 ))
987
+ return self ._with_values (where (self ._valid_mask (), op (self ._values , other_values ), 0 ))
989
988
elif self ._compressed_dims in other_shape and self ._uncompressed_dims .isdisjoint (other_shape ):
990
989
from ._ops import gather , boolean_mask , clip , where
991
990
other_values = gather (other , self ._indices , self ._compressed_dims )
992
- result_values = operator (self ._values , other_values )
991
+ result_values = op (self ._values , other_values )
993
992
return self ._with_values (result_values )
994
993
else :
995
994
raise NotImplementedError
0 commit comments