|
4 | 4 |
|
5 | 5 | from setuptools import find_packages, setup
|
6 | 6 |
|
7 |
| -npu_available = False |
8 |
| -try: |
9 |
| - import torch_npu |
10 |
| - |
11 |
| - npu_available = torch_npu.npu.is_available() |
12 |
| -except ImportError: |
13 |
| - pass |
14 |
| - |
15 | 7 | pwd = os.path.dirname(__file__)
|
16 | 8 | version_file = 'lmdeploy/version.py'
|
17 | 9 |
|
18 | 10 |
|
| 11 | +def get_target_device(): |
| 12 | + return os.getenv('LMDEPLOY_TARGET_DEVICE', 'cuda') |
| 13 | + |
| 14 | + |
19 | 15 | def readme():
|
20 | 16 | with open(os.path.join(pwd, 'README.md'), encoding='utf-8') as f:
|
21 | 17 | content = f.read()
|
@@ -154,16 +150,12 @@ def gen_packages_items():
|
154 | 150 | setup_requires=parse_requirements('requirements/build.txt'),
|
155 | 151 | tests_require=parse_requirements('requirements/test.txt'),
|
156 | 152 | install_requires=parse_requirements(
|
157 |
| - 'requirements/runtime_ascend.txt' |
158 |
| - if npu_available else 'requirements/runtime.txt'), |
| 153 | + f'requirements/runtime_{get_target_device()}.txt'), |
159 | 154 | extras_require={
|
160 | 155 | 'all':
|
161 |
| - parse_requirements('requirements_ascend.txt' |
162 |
| - if npu_available else 'requirements.txt'), |
163 |
| - 'lite': |
164 |
| - parse_requirements('requirements/lite.txt'), |
165 |
| - 'serve': |
166 |
| - parse_requirements('requirements/serve.txt') |
| 156 | + parse_requirements(f'requirements_{get_target_device()}.txt'), |
| 157 | + 'lite': parse_requirements('requirements/lite.txt'), |
| 158 | + 'serve': parse_requirements('requirements/serve.txt') |
167 | 159 | },
|
168 | 160 | has_ext_modules=check_ext_modules,
|
169 | 161 | classifiers=[
|
|
0 commit comments