Skip to content

Commit 41ef7a1

Browse files
author
shengtsui
authoredNov 12, 2024··
VTEN-13-Add-astype-method (#15)
1 parent 35e19dd commit 41ef7a1

File tree

6 files changed

+44
-10
lines changed

6 files changed

+44
-10
lines changed
 

‎README.md

-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ sphinx-build -b html docs/source docs/build/html
2222
```
2323

2424
### Future updates
25-
- Support Solvers (Cholesky, QR etc.)
2625
- Support GPUDirect
2726
- Support more matrix operations
2827
- Support Sparse martix with CuSparse

‎docs/source/api/core/astype.rst

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
vt::astype
2+
=======================
3+
4+
.. toctree::
5+
:maxdepth: 1
6+
7+
.. doxygenfunction:: vt::astype

‎docs/source/api/core/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Core
44
.. toctree::
55
:maxdepth: 1
66

7+
astype
78
broadcast
89
broadcast_to
910
cutensor

‎lib/core/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ cuda_library(
55
visibility = ["//visibility:public"],
66
hdrs = [
77
"assertions.hpp",
8+
"astype.hpp",
89
"broadcast.hpp",
910
"cutensor.hpp",
1011
"iterator.hpp",

‎lib/core/astype.hpp

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once
2+
3+
#include "lib/core/tensor.hpp"
4+
5+
namespace vt {
6+
7+
/// Forward declaration of the Tensor class.
8+
template <typename T, size_t N>
9+
class Tensor;
10+
11+
/**
12+
* @brief Convert the tensor to a new data type.
13+
*
14+
* @tparam T: Data type of the tensor.
15+
* @tparam U: The data type to cast.
16+
* @tparam N: Number of dimensions of the tensor.
17+
* @param tensor: The tensor object.
18+
* @return Tensor<U, N>: The new tensor object.
19+
*/
20+
template <typename T, typename U, size_t N>
21+
Tensor<U, N> astype(Tensor<T, N> tensor) {
22+
if constexpr (std::is_same_v<T, U>) {
23+
return tensor;
24+
} else {
25+
auto result = Tensor<U, N>(tensor.shape());
26+
thrust::transform(tensor.begin(), tensor.end(), result.begin(), [] __device__(const T& x) { return static_cast<U>(x); });
27+
return result;
28+
}
29+
}
30+
31+
} // namespace vt

‎lib/core/tensor.hpp

+4-9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <rmm/device_vector.hpp>
1212

1313
#include "lib/core/assertions.hpp"
14+
#include "lib/core/astype.hpp"
1415
#include "lib/core/iterator.hpp"
1516
#include "lib/core/slice.hpp"
1617

@@ -30,7 +31,7 @@ using Shape = std::array<size_t, N>;
3031
template <size_t N>
3132
size_t get_size(const Shape<N>& shape) {
3233
size_t size = 1;
33-
for (size_t i = 0; i < N; ++i) size *= shape[i];
34+
for (int i = 0; i < N; ++i) size *= shape[i];
3435
return size;
3536
}
3637

@@ -417,14 +418,8 @@ class Tensor {
417418
* @return Tensor<U, N>: The new tensor object.
418419
*/
419420
template <typename U>
420-
Tensor<U, N> astype() const {
421-
if constexpr (std::is_same_v<T, U>) {
422-
return *this;
423-
} else {
424-
auto result = Tensor<U, N>(_shape);
425-
thrust::transform(this->begin(), this->end(), result.begin(), [] __device__(const T& x) { return static_cast<U>(x); });
426-
return result;
427-
}
421+
Tensor<U, N> astype() {
422+
return vt::astype<T, U, N>(*this);
428423
}
429424

430425
/**

0 commit comments

Comments
 (0)
Please sign in to comment.