From 3ac2b5527af405e56b7545d0b484e6f481e186d1 Mon Sep 17 00:00:00 2001
From: Mykhailo Bobrovskyi <mikhail.bobrovsky@gmail.com>
Date: Wed, 5 Mar 2025 17:30:32 +0200
Subject: [PATCH 1/3] Kjob storage configuration for run command.

---
 .github/workflows/filestore_tests.yaml        |  4 +++
 .github/workflows/fuse_tests.yaml             |  2 ++
 src/xpk/commands/batch.py                     | 16 +++++-----
 src/xpk/commands/run.py                       | 18 ++++++++++-
 src/xpk/commands/shell.py                     | 18 +++++------
 src/xpk/commands/storage.py                   |  2 +-
 src/xpk/commands/workload.py                  | 17 ++++------
 src/xpk/core/cluster.py                       | 17 +++++++---
 src/xpk/core/config.py                        |  4 ---
 src/xpk/core/kjob.py                          | 32 ++++++++-----------
 src/xpk/core/pathways.py                      |  9 ++----
 src/xpk/core/storage.py                       |  6 ++--
 .../workload_decorators/storage_decorator.py  |  3 +-
 src/xpk/parser/run.py                         | 25 +++++++--------
 14 files changed, 92 insertions(+), 81 deletions(-)

diff --git a/.github/workflows/filestore_tests.yaml b/.github/workflows/filestore_tests.yaml
index ea205a8da..114d1dd90 100644
--- a/.github/workflows/filestore_tests.yaml
+++ b/.github/workflows/filestore_tests.yaml
@@ -124,6 +124,8 @@ jobs:
       run: python3 xpk.py job cancel $READ_JOB_NAME --cluster ${{inputs.cluster-name}} --zone=${{inputs.zone}} | grep "job.batch/$READ_JOB_NAME deleted"
     - name: Delete batch-read.log file
       run: rm batch-read.log
+    - name: Run a run-read job on the cluster
+      run: python3 xpk.py run --cluster ${{inputs.cluster-name}} --zone=${{inputs.zone}} batch-read.sh --timeout 60
     - name: Delete batch-read.sh file
       run: rm batch-read.sh
     - name: Create shell and exit it immediately
@@ -247,6 +249,8 @@ jobs:
       run: python3 xpk.py job cancel $READ_JOB_NAME --cluster ${{inputs.cluster-name}} --zone=${{inputs.zone}} | grep "job.batch/$READ_JOB_NAME deleted"
     - name: Delete batch-read.log file
       run: rm batch-read.log
+    - name: Run a run-read job on the cluster
+      run: python3 xpk.py run --cluster ${{inputs.cluster-name}} --zone=${{inputs.zone}} batch-read.sh --timeout 60
     - name: Delete batch-read.sh file
       run: rm batch-read.sh
     - name: Create shell and exit it immediately
diff --git a/.github/workflows/fuse_tests.yaml b/.github/workflows/fuse_tests.yaml
index 067225e56..91234746f 100644
--- a/.github/workflows/fuse_tests.yaml
+++ b/.github/workflows/fuse_tests.yaml
@@ -121,6 +121,8 @@ jobs:
       run: python3 xpk.py job cancel $READ_JOB_NAME --cluster ${{inputs.cluster-name}} --zone=${{inputs.zone}} | grep "job.batch/$READ_JOB_NAME deleted"
     - name: Delete batch-read.log file
       run: rm batch-read.log
+    - name: Run a run-read job on the cluster
+      run: python3 xpk.py run --cluster ${{inputs.cluster-name}} --zone=${{inputs.zone}} batch-read.sh --timeout 60
     - name: Delete batch-read.sh file
       run: rm batch-read.sh
     - name: Create shell and exit it immediately
diff --git a/src/xpk/commands/batch.py b/src/xpk/commands/batch.py
index 4a3397225..f6ad43dd5 100644
--- a/src/xpk/commands/batch.py
+++ b/src/xpk/commands/batch.py
@@ -16,15 +16,18 @@
 
 from argparse import Namespace
 
-from ..core.cluster import setup_k8s_env, create_k8s_service_account
 from ..core.commands import run_command_for_value
-from ..core.config import GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE, XPK_SA, DEFAULT_NAMESPACE
 from ..core.gcloud_context import add_zone_and_project
 from ..core.kueue import LOCAL_QUEUE_NAME
-from ..core.storage import get_auto_mount_gcsfuse_storages
+from ..core.storage import GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
 from ..utils.console import xpk_exit, xpk_print
 from .common import set_cluster_command
-from ..core.kjob import AppProfileDefaults, prepare_kjob, Kueue_TAS_annotation
+from ..core.kjob import (
+    AppProfileDefaults,
+    prepare_kjob,
+    Kueue_TAS_annotation,
+    create_service_account_and_get_gcsfuse_storages,
+)
 from .kind import set_local_cluster_command
 import re
 
@@ -54,10 +57,6 @@ def batch(args: Namespace) -> None:
 
 
 def submit_job(args: Namespace) -> None:
-  k8s_api_client = setup_k8s_env(args)
-  create_k8s_service_account(XPK_SA, DEFAULT_NAMESPACE)
-  gcs_fuse_storages = get_auto_mount_gcsfuse_storages(k8s_api_client)
-
   cmd = (
       'kubectl kjob create slurm'
       f' --profile {AppProfileDefaults.NAME.value}'
@@ -66,6 +65,7 @@ def submit_job(args: Namespace) -> None:
       ' --first-node-ip'
   )
 
+  gcs_fuse_storages = create_service_account_and_get_gcsfuse_storages(args)
   if len(gcs_fuse_storages) > 0:
     cmd += (
         ' --pod-template-annotation'
diff --git a/src/xpk/commands/run.py b/src/xpk/commands/run.py
index bf90a25e2..1d719c7b8 100644
--- a/src/xpk/commands/run.py
+++ b/src/xpk/commands/run.py
@@ -19,9 +19,15 @@
 from ..core.commands import run_command_with_full_controls
 from ..core.gcloud_context import add_zone_and_project
 from ..core.kueue import LOCAL_QUEUE_NAME
+from ..core.storage import GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
 from ..utils.console import xpk_exit, xpk_print
 from .common import set_cluster_command
-from ..core.kjob import AppProfileDefaults, prepare_kjob, Kueue_TAS_annotation
+from ..core.kjob import (
+    AppProfileDefaults,
+    prepare_kjob,
+    Kueue_TAS_annotation,
+    create_service_account_and_get_gcsfuse_storages,
+)
 from .kind import set_local_cluster_command
 
 
@@ -59,6 +65,16 @@ def submit_job(args: Namespace) -> None:
       ' --rm'
   )
 
+  gcs_fuse_storages = create_service_account_and_get_gcsfuse_storages(args)
+  if len(gcs_fuse_storages) > 0:
+    cmd += (
+        ' --pod-template-annotation'
+        f' {GCS_FUSE_ANNOTATION_KEY}={GCS_FUSE_ANNOTATION_VALUE}'
+    )
+
+  if args.timeout:
+    cmd += f' --wait-timeout {args.timeout}s'
+
   if args.ignore_unknown_flags:
     cmd += ' --ignore-unknown-flags'
 
diff --git a/src/xpk/commands/shell.py b/src/xpk/commands/shell.py
index bd448b755..cc05a283e 100644
--- a/src/xpk/commands/shell.py
+++ b/src/xpk/commands/shell.py
@@ -12,14 +12,17 @@
 """
 
 from ..core.commands import run_command_with_full_controls, run_command_for_value, run_command_with_updates
-from ..core.cluster import get_cluster_credentials, add_zone_and_project, setup_k8s_env, create_k8s_service_account
-from ..core.config import GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE, XPK_SA, DEFAULT_NAMESPACE
-from ..core.storage import get_auto_mount_gcsfuse_storages
+from ..core.cluster import get_cluster_credentials, add_zone_and_project
+from ..core.storage import GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
 from ..utils.console import xpk_exit, xpk_print
 from argparse import Namespace
 
-from ..core.kjob import AppProfileDefaults, prepare_kjob, get_pod_template_interactive_command
-
+from ..core.kjob import (
+    AppProfileDefaults,
+    prepare_kjob,
+    get_pod_template_interactive_command,
+    create_service_account_and_get_gcsfuse_storages,
+)
 
 exit_instructions = 'To exit the shell input "exit".'
 
@@ -81,15 +84,12 @@ def connect_to_new_interactive_shell(args: Namespace) -> int:
   if err_code > 0:
     xpk_exit(err_code)
 
-  k8s_api_client = setup_k8s_env(args)
-  create_k8s_service_account(XPK_SA, DEFAULT_NAMESPACE)
-  gcs_fuse_storages = get_auto_mount_gcsfuse_storages(k8s_api_client)
-
   cmd = (
       'kubectl-kjob create interactive --profile'
       f' {AppProfileDefaults.NAME.value} --pod-running-timeout 180s'
   )
 
+  gcs_fuse_storages = create_service_account_and_get_gcsfuse_storages(args)
   if len(gcs_fuse_storages) > 0:
     cmd += (
         ' --pod-template-annotation'
diff --git a/src/xpk/commands/storage.py b/src/xpk/commands/storage.py
index 5befb7118..a35ffa446 100644
--- a/src/xpk/commands/storage.py
+++ b/src/xpk/commands/storage.py
@@ -26,8 +26,8 @@
     update_cluster_with_gcpfilestore_driver_if_necessary,
     add_zone_and_project,
     get_cluster_network,
+    DEFAULT_NAMESPACE,
 )
-from ..core.config import DEFAULT_NAMESPACE
 from ..core.kjob import (
     KJOB_API_GROUP_NAME,
     KJOB_API_GROUP_VERSION,
diff --git a/src/xpk/commands/workload.py b/src/xpk/commands/workload.py
index ecbecf5f8..d75c8c2df 100644
--- a/src/xpk/commands/workload.py
+++ b/src/xpk/commands/workload.py
@@ -15,20 +15,13 @@
 """
 
 from ..core.cluster import (
-    create_k8s_service_account,
+    create_xpk_k8s_service_account,
     get_cluster_credentials,
     setup_k8s_env,
-)
-from ..core.commands import run_command_with_updates, run_commands
-from ..core.config import (
-    GCS_FUSE_ANNOTATION_KEY,
-    GCS_FUSE_ANNOTATION_VALUE,
-    VERTEX_TENSORBOARD_FEATURE_FLAG,
-    XPK_CURRENT_VERSION,
-    parse_env_config,
     XPK_SA,
-    DEFAULT_NAMESPACE,
 )
+from ..core.commands import run_command_with_updates, run_commands
+from ..core.config import VERTEX_TENSORBOARD_FEATURE_FLAG, XPK_CURRENT_VERSION, parse_env_config
 from ..core.docker_container import (
     get_main_container_docker_image,
     get_user_workload_container,
@@ -68,6 +61,8 @@
     get_storages_to_mount,
     get_storage_volume_mounts_yaml_for_gpu,
     get_storage_volumes_yaml_for_gpu,
+    GCS_FUSE_ANNOTATION_KEY,
+    GCS_FUSE_ANNOTATION_VALUE,
 )
 from ..core.system_characteristics import (
     AcceleratorType,
@@ -491,7 +486,7 @@ def workload_create(args) -> None:
     0 if successful and 1 otherwise.
   """
   k8s_api_client = setup_k8s_env(args)
-  create_k8s_service_account(XPK_SA, DEFAULT_NAMESPACE)
+  create_xpk_k8s_service_account()
 
   workload_exists = check_if_workload_exists(args)
 
diff --git a/src/xpk/core/cluster.py b/src/xpk/core/cluster.py
index 4d14f5b1e..c7ae017b4 100644
--- a/src/xpk/core/cluster.py
+++ b/src/xpk/core/cluster.py
@@ -35,6 +35,9 @@
 INSTALLER_NCC_TCPX = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpx/nccl-tcpx-installer.yaml'
 INSTALLER_NCC_TCPXO = 'https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/gpudirect-tcpxo/nccl-tcpxo-installer.yaml'
 
+DEFAULT_NAMESPACE = 'default'
+XPK_SA = 'xpk-sa'
+
 
 # TODO(vbarr): Remove this function when jobsets gets enabled by default on
 # GKE clusters.
@@ -232,18 +235,22 @@ def setup_k8s_env(args) -> k8s_client.ApiClient:
   return k8s_client.ApiClient()  # pytype: disable=bad-return-type
 
 
-def create_k8s_service_account(name: str, namespace: str) -> None:
+def create_xpk_k8s_service_account() -> None:
   k8s_core_client = k8s_client.CoreV1Api()
-  sa = k8s_client.V1ServiceAccount(metadata=k8s_client.V1ObjectMeta(name=name))
+  sa = k8s_client.V1ServiceAccount(
+      metadata=k8s_client.V1ObjectMeta(name=XPK_SA)
+  )
 
-  xpk_print(f'Creating a new service account: {name}')
+  xpk_print(f'Creating a new service account: {XPK_SA}')
   try:
     k8s_core_client.create_namespaced_service_account(
-        namespace, sa, pretty=True
+        DEFAULT_NAMESPACE, sa, pretty=True
     )
     xpk_print(f'Created a new service account: {sa} successfully')
   except ApiException:
-    xpk_print(f'Service account: {name} already exists. Skipping its creation')
+    xpk_print(
+        f'Service account: {XPK_SA} already exists. Skipping its creation'
+    )
 
 
 def update_gke_cluster_with_clouddns(args) -> int:
diff --git a/src/xpk/core/config.py b/src/xpk/core/config.py
index e0a9fb8bd..404c0e62f 100644
--- a/src/xpk/core/config.py
+++ b/src/xpk/core/config.py
@@ -30,8 +30,6 @@
 XPK_CONFIG_FILE = os.path.expanduser('~/.config/xpk/config.yaml')
 
 CONFIGS_KEY = 'configs'
-DEFAULT_NAMESPACE = 'default'
-XPK_SA = 'xpk-sa'
 CFG_BUCKET_KEY = 'cluster-state-gcs-bucket'
 CLUSTER_NAME_KEY = 'cluster-name'
 PROJECT_KEY = 'project-id'
@@ -58,8 +56,6 @@
     KJOB_SHELL_WORKING_DIRECTORY,
 ]
 VERTEX_TENSORBOARD_FEATURE_FLAG = XPK_CURRENT_VERSION >= '0.4.0'
-GCS_FUSE_ANNOTATION_KEY = 'gke-gcsfuse/volumes'
-GCS_FUSE_ANNOTATION_VALUE = 'true'
 
 
 yaml = ruamel.yaml.YAML()
diff --git a/src/xpk/core/kjob.py b/src/xpk/core/kjob.py
index 31432ebc5..2d64b3dcf 100644
--- a/src/xpk/core/kjob.py
+++ b/src/xpk/core/kjob.py
@@ -21,9 +21,8 @@
 from kubernetes import client as k8s_client
 from kubernetes.client import ApiClient
 from kubernetes.client.rest import ApiException
-from .cluster import setup_k8s_env
-from .config import XPK_SA, DEFAULT_NAMESPACE
-from .storage import Storage, get_auto_mount_storages, GCS_FUSE_TYPE, GCP_FILESTORE_TYPE
+from .cluster import setup_k8s_env, XPK_SA, DEFAULT_NAMESPACE, create_xpk_k8s_service_account
+from .storage import get_auto_mount_storages, get_auto_mount_gcsfuse_storages
 from ..utils.console import xpk_print, xpk_exit
 from .commands import run_command_for_value, run_kubectl_apply, run_command_with_updates
 from .config import XpkConfig, KJOB_SHELL_IMAGE, KJOB_SHELL_INTERACTIVE_COMMAND, KJOB_SHELL_WORKING_DIRECTORY, KJOB_BATCH_IMAGE, KJOB_BATCH_WORKING_DIRECTORY
@@ -32,11 +31,7 @@
 
 KJOB_API_GROUP_NAME = "kjobctl.x-k8s.io"
 KJOB_API_GROUP_VERSION = "v1alpha1"
-KJOB_API_VOLUME_BUNDLE_KIND = "VolumeBundle"
-KJOB_API_VOLUME_BUNDLE_PLURAL = KJOB_API_VOLUME_BUNDLE_KIND.lower() + "s"
-KJOB_API_VOLUME_BUNDLE_CRD_NAME = (
-    f"{KJOB_API_VOLUME_BUNDLE_PLURAL}.{KJOB_API_GROUP_NAME}"
-)
+KJOB_API_VOLUME_BUNDLE_PLURAL = "volumebundles"
 VOLUME_BUNDLE_TEMPLATE_PATH = "/../templates/volume_bundle.yaml"
 
 
@@ -288,16 +283,10 @@ def prepare_kjob(args: Namespace) -> int:
   system = get_cluster_system_characteristics(args)
 
   k8s_api_client = setup_k8s_env(args)
-  storages: list[Storage] = get_auto_mount_storages(k8s_api_client)
-  gcs_fuse_storages = list(
-      filter(lambda storage: storage.type == GCS_FUSE_TYPE, storages)
-  )
-  gcp_filestore_storages = list(
-      filter(lambda storage: storage.type == GCP_FILESTORE_TYPE, storages)
-  )
+  storages = get_auto_mount_storages(k8s_api_client)
 
   service_account = ""
-  if len(gcs_fuse_storages) > 0 or len(gcp_filestore_storages) > 0:
+  if len(storages) > 0:
     service_account = XPK_SA
 
   job_err_code = create_job_template_instance(args, system, service_account)
@@ -308,8 +297,7 @@ def prepare_kjob(args: Namespace) -> int:
   if pod_err_code > 0:
     return pod_err_code
 
-  all_storages = gcs_fuse_storages + gcp_filestore_storages
-  volume_bundles = [item.name for item in all_storages]
+  volume_bundles = [item.name for item in storages]
 
   return create_app_profile_instance(args, volume_bundles)
 
@@ -387,7 +375,7 @@ def create_volume_bundle_instance(
         body=data,
     )
     xpk_print(
-        f"Created {KJOB_API_VOLUME_BUNDLE_CRD_NAME} object:"
+        f"Created {KJOB_API_VOLUME_BUNDLE_PLURAL}.{KJOB_API_GROUP_NAME} object:"
         f" {data['metadata']['name']}"
     )
   except ApiException as e:
@@ -398,3 +386,9 @@ def create_volume_bundle_instance(
     else:
       xpk_print(f"Encountered error during VolumeBundle creation: {e}")
       xpk_exit(1)
+
+
+def create_service_account_and_get_gcsfuse_storages(args: Namespace):
+  k8s_api_client = setup_k8s_env(args)
+  create_xpk_k8s_service_account()
+  return get_auto_mount_gcsfuse_storages(k8s_api_client)
diff --git a/src/xpk/core/pathways.py b/src/xpk/core/pathways.py
index f969f58aa..4e964f326 100644
--- a/src/xpk/core/pathways.py
+++ b/src/xpk/core/pathways.py
@@ -14,16 +14,13 @@
 limitations under the License.
 """
 
+from .cluster import XPK_SA
 from ..core.docker_container import get_user_workload_container
 from ..core.gcloud_context import zone_to_region
 from ..core.nodepool import get_all_nodepools_programmatic
 from ..utils.console import xpk_exit, xpk_print
-from .config import (
-    GCS_FUSE_ANNOTATION_KEY,
-    GCS_FUSE_ANNOTATION_VALUE,
-    AcceleratorType,
-)
-from .storage import XPK_SA, Storage, get_storage_volumes_yaml
+from .config import AcceleratorType
+from .storage import Storage, get_storage_volumes_yaml, GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
 from .system_characteristics import SystemCharacteristics
 
 PathwaysExpectedInstancesMap = {
diff --git a/src/xpk/core/storage.py b/src/xpk/core/storage.py
index 419d5b872..ae4598dbe 100644
--- a/src/xpk/core/storage.py
+++ b/src/xpk/core/storage.py
@@ -28,7 +28,7 @@
 from kubernetes.utils import FailToCreateError
 from tabulate import tabulate
 
-from .config import XPK_SA
+from .cluster import XPK_SA
 from ..utils.console import xpk_exit, xpk_print
 
 STORAGE_CRD_PATH = "/../api/storage_crd.yaml"
@@ -36,10 +36,12 @@
 XPK_API_GROUP_NAME = "xpk.x-k8s.io"
 XPK_API_GROUP_VERSION = "v1"
 STORAGE_CRD_KIND = "Storage"
-STORAGE_CRD_PLURAL = STORAGE_CRD_KIND.lower() + "s"
+STORAGE_CRD_PLURAL = "storages"
 STORAGE_CRD_NAME = f"{XPK_API_GROUP_NAME}.{STORAGE_CRD_PLURAL}"
 GCS_FUSE_TYPE = "gcsfuse"
 GCP_FILESTORE_TYPE = "gcpfilestore"
+GCS_FUSE_ANNOTATION_KEY = "gke-gcsfuse/volumes"
+GCS_FUSE_ANNOTATION_VALUE = "true"
 
 
 @dataclass
diff --git a/src/xpk/core/workload_decorators/storage_decorator.py b/src/xpk/core/workload_decorators/storage_decorator.py
index 37a2c02bf..32ded713f 100644
--- a/src/xpk/core/workload_decorators/storage_decorator.py
+++ b/src/xpk/core/workload_decorators/storage_decorator.py
@@ -16,8 +16,7 @@
 
 import yaml
 
-from ..config import GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
-from ...core.storage import GCS_FUSE_TYPE, get_storage_volumes_yaml_dict
+from ...core.storage import GCS_FUSE_TYPE, get_storage_volumes_yaml_dict, GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
 
 
 def decorate_jobset(jobset_manifest_str, storages) -> str:
diff --git a/src/xpk/parser/run.py b/src/xpk/parser/run.py
index ac97d025b..f5298afaa 100644
--- a/src/xpk/parser/run.py
+++ b/src/xpk/parser/run.py
@@ -14,10 +14,13 @@
 limitations under the License.
 """
 
-import argparse
-
 from ..commands.run import run
-from .common import add_shared_arguments, add_slurm_arguments
+from .common import (
+    add_shared_arguments,
+    add_slurm_arguments,
+    add_cluster_arguments,
+    add_kind_cluster_arguments,
+)
 
 
 def set_run_parser(run_parser):
@@ -30,19 +33,15 @@ def set_run_parser(run_parser):
 
   run_required_arguments.add_argument('script', help='script with task to run')
   run_optional_arguments.add_argument(
-      '--cluster',
-      type=str,
+      '--timeout',
+      type=int,
       default=None,
-      help='Cluster to which command applies.',
-  )
-  run_optional_arguments.add_argument(
-      '--kind-cluster',
-      type=bool,
-      action=argparse.BooleanOptionalAction,
-      default=False,
-      help='Apply command to a local test cluster.',
+      help='Amount of time to wait for job in seconds',
+      required=False,
   )
 
+  add_cluster_arguments(run_optional_arguments)
+  add_kind_cluster_arguments(run_optional_arguments)
   add_slurm_arguments(run_optional_arguments)
   add_shared_arguments(run_parser)
   run_parser.set_defaults(func=run)

From a2aaee5a9b3857baf7e319f35c35abe2192eb081 Mon Sep 17 00:00:00 2001
From: Mykhailo Bobrovskyi <mikhail.bobrovsky@gmail.com>
Date: Thu, 6 Mar 2025 18:50:26 +0200
Subject: [PATCH 2/3] Use get_storage_annotations.

---
 src/xpk/commands/batch.py | 14 ++++++--------
 src/xpk/commands/run.py   | 14 ++++++--------
 src/xpk/commands/shell.py | 15 ++++++---------
 src/xpk/core/kjob.py      | 10 ++++++----
 4 files changed, 24 insertions(+), 29 deletions(-)

diff --git a/src/xpk/commands/batch.py b/src/xpk/commands/batch.py
index f6ad43dd5..2168b3478 100644
--- a/src/xpk/commands/batch.py
+++ b/src/xpk/commands/batch.py
@@ -16,17 +16,17 @@
 
 from argparse import Namespace
 
+from ..core.cluster import create_xpk_k8s_service_account
 from ..core.commands import run_command_for_value
 from ..core.gcloud_context import add_zone_and_project
 from ..core.kueue import LOCAL_QUEUE_NAME
-from ..core.storage import GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
 from ..utils.console import xpk_exit, xpk_print
 from .common import set_cluster_command
 from ..core.kjob import (
     AppProfileDefaults,
     prepare_kjob,
     Kueue_TAS_annotation,
-    create_service_account_and_get_gcsfuse_storages,
+    get_gcsfuse_annotation,
 )
 from .kind import set_local_cluster_command
 import re
@@ -52,6 +52,7 @@ def batch(args: Namespace) -> None:
   err_code = prepare_kjob(args)
   if err_code > 0:
     xpk_exit(err_code)
+  create_xpk_k8s_service_account()
 
   submit_job(args)
 
@@ -65,12 +66,9 @@ def submit_job(args: Namespace) -> None:
       ' --first-node-ip'
   )
 
-  gcs_fuse_storages = create_service_account_and_get_gcsfuse_storages(args)
-  if len(gcs_fuse_storages) > 0:
-    cmd += (
-        ' --pod-template-annotation'
-        f' {GCS_FUSE_ANNOTATION_KEY}={GCS_FUSE_ANNOTATION_VALUE}'
-    )
+  gcsfuse_annotation = get_gcsfuse_annotation(args)
+  if gcsfuse_annotation is not None:
+    cmd += f' --pod-template-annotation {gcsfuse_annotation}'
 
   if args.ignore_unknown_flags:
     cmd += ' --ignore-unknown-flags'
diff --git a/src/xpk/commands/run.py b/src/xpk/commands/run.py
index 1d719c7b8..a998e8cde 100644
--- a/src/xpk/commands/run.py
+++ b/src/xpk/commands/run.py
@@ -16,17 +16,17 @@
 
 from argparse import Namespace
 
+from ..core.cluster import create_xpk_k8s_service_account
 from ..core.commands import run_command_with_full_controls
 from ..core.gcloud_context import add_zone_and_project
 from ..core.kueue import LOCAL_QUEUE_NAME
-from ..core.storage import GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
 from ..utils.console import xpk_exit, xpk_print
 from .common import set_cluster_command
 from ..core.kjob import (
     AppProfileDefaults,
     prepare_kjob,
     Kueue_TAS_annotation,
-    create_service_account_and_get_gcsfuse_storages,
+    get_gcsfuse_annotation,
 )
 from .kind import set_local_cluster_command
 
@@ -51,6 +51,7 @@ def run(args: Namespace) -> None:
   err_code = prepare_kjob(args)
   if err_code > 0:
     xpk_exit(err_code)
+  create_xpk_k8s_service_account()
 
   submit_job(args)
 
@@ -65,12 +66,9 @@ def submit_job(args: Namespace) -> None:
       ' --rm'
   )
 
-  gcs_fuse_storages = create_service_account_and_get_gcsfuse_storages(args)
-  if len(gcs_fuse_storages) > 0:
-    cmd += (
-        ' --pod-template-annotation'
-        f' {GCS_FUSE_ANNOTATION_KEY}={GCS_FUSE_ANNOTATION_VALUE}'
-    )
+  gcsfuse_annotation = get_gcsfuse_annotation(args)
+  if gcsfuse_annotation is not None:
+    cmd += f' --pod-template-annotation {gcsfuse_annotation}'
 
   if args.timeout:
     cmd += f' --wait-timeout {args.timeout}s'
diff --git a/src/xpk/commands/shell.py b/src/xpk/commands/shell.py
index cc05a283e..5b67d28cf 100644
--- a/src/xpk/commands/shell.py
+++ b/src/xpk/commands/shell.py
@@ -12,8 +12,7 @@
 """
 
 from ..core.commands import run_command_with_full_controls, run_command_for_value, run_command_with_updates
-from ..core.cluster import get_cluster_credentials, add_zone_and_project
-from ..core.storage import GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
+from ..core.cluster import get_cluster_credentials, add_zone_and_project, create_xpk_k8s_service_account
 from ..utils.console import xpk_exit, xpk_print
 from argparse import Namespace
 
@@ -21,7 +20,7 @@
     AppProfileDefaults,
     prepare_kjob,
     get_pod_template_interactive_command,
-    create_service_account_and_get_gcsfuse_storages,
+    get_gcsfuse_annotation,
 )
 
 exit_instructions = 'To exit the shell input "exit".'
@@ -83,18 +82,16 @@ def connect_to_new_interactive_shell(args: Namespace) -> int:
   err_code = prepare_kjob(args)
   if err_code > 0:
     xpk_exit(err_code)
+  create_xpk_k8s_service_account()
 
   cmd = (
       'kubectl-kjob create interactive --profile'
       f' {AppProfileDefaults.NAME.value} --pod-running-timeout 180s'
   )
 
-  gcs_fuse_storages = create_service_account_and_get_gcsfuse_storages(args)
-  if len(gcs_fuse_storages) > 0:
-    cmd += (
-        ' --pod-template-annotation'
-        f' {GCS_FUSE_ANNOTATION_KEY}={GCS_FUSE_ANNOTATION_VALUE}'
-    )
+  gcsfuse_annotation = get_gcsfuse_annotation(args)
+  if gcsfuse_annotation is not None:
+    cmd += f' --pod-template-annotation {gcsfuse_annotation}'
 
   return run_command_with_full_controls(
       command=cmd,
diff --git a/src/xpk/core/kjob.py b/src/xpk/core/kjob.py
index 2d64b3dcf..b17a6d907 100644
--- a/src/xpk/core/kjob.py
+++ b/src/xpk/core/kjob.py
@@ -21,7 +21,7 @@
 from kubernetes import client as k8s_client
 from kubernetes.client import ApiClient
 from kubernetes.client.rest import ApiException
-from .cluster import setup_k8s_env, XPK_SA, DEFAULT_NAMESPACE, create_xpk_k8s_service_account
+from .cluster import setup_k8s_env, XPK_SA, DEFAULT_NAMESPACE
 from .storage import get_auto_mount_storages, get_auto_mount_gcsfuse_storages
 from ..utils.console import xpk_print, xpk_exit
 from .commands import run_command_for_value, run_kubectl_apply, run_command_with_updates
@@ -388,7 +388,9 @@ def create_volume_bundle_instance(
       xpk_exit(1)
 
 
-def create_service_account_and_get_gcsfuse_storages(args: Namespace):
+def get_gcsfuse_annotation(args: Namespace) -> str | None:
   k8s_api_client = setup_k8s_env(args)
-  create_xpk_k8s_service_account()
-  return get_auto_mount_gcsfuse_storages(k8s_api_client)
+  gcsfuse_storages = get_auto_mount_gcsfuse_storages(k8s_api_client)
+  if len(gcsfuse_storages) > 0:
+    return "gke-gcsfuse/volumes=true"
+  return None

From 189b21406e44eb9000f4841a868814c3bb921628 Mon Sep 17 00:00:00 2001
From: Mykhailo Bobrovskyi <mikhail.bobrovsky@gmail.com>
Date: Thu, 6 Mar 2025 19:01:48 +0200
Subject: [PATCH 3/3] Use GCS_FUSE_ANNOTATION.

---
 src/xpk/commands/workload.py                          | 7 ++-----
 src/xpk/core/pathways.py                              | 6 ++----
 src/xpk/core/storage.py                               | 3 +--
 src/xpk/core/workload_decorators/storage_decorator.py | 4 ++--
 4 files changed, 7 insertions(+), 13 deletions(-)

diff --git a/src/xpk/commands/workload.py b/src/xpk/commands/workload.py
index d75c8c2df..c197aadbf 100644
--- a/src/xpk/commands/workload.py
+++ b/src/xpk/commands/workload.py
@@ -61,8 +61,7 @@
     get_storages_to_mount,
     get_storage_volume_mounts_yaml_for_gpu,
     get_storage_volumes_yaml_for_gpu,
-    GCS_FUSE_ANNOTATION_KEY,
-    GCS_FUSE_ANNOTATION_VALUE,
+    GCS_FUSE_ANNOTATION,
 )
 from ..core.system_characteristics import (
     AcceleratorType,
@@ -569,9 +568,7 @@ def workload_create(args) -> None:
   storage_annotations = ''
   service_account = ''
   if len(gcs_fuse_storages) > 0:
-    storage_annotations = (
-        f'{GCS_FUSE_ANNOTATION_KEY}: "{GCS_FUSE_ANNOTATION_VALUE}"'
-    )
+    storage_annotations = GCS_FUSE_ANNOTATION
     service_account = XPK_SA
     xpk_print(f'Detected gcsfuse Storages to add: {gcs_fuse_storages}')
   else:
diff --git a/src/xpk/core/pathways.py b/src/xpk/core/pathways.py
index 4e964f326..395205df6 100644
--- a/src/xpk/core/pathways.py
+++ b/src/xpk/core/pathways.py
@@ -20,7 +20,7 @@
 from ..core.nodepool import get_all_nodepools_programmatic
 from ..utils.console import xpk_exit, xpk_print
 from .config import AcceleratorType
-from .storage import Storage, get_storage_volumes_yaml, GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
+from .storage import Storage, get_storage_volumes_yaml, GCS_FUSE_ANNOTATION
 from .system_characteristics import SystemCharacteristics
 
 PathwaysExpectedInstancesMap = {
@@ -330,9 +330,7 @@ def get_user_workload_for_pathways(
         storage_volumes=storage_volumes,
         pod_failure_policy=pod_failure_policy,
         service_account=XPK_SA,
-        gcs_fuse_annotation=(
-            f'{GCS_FUSE_ANNOTATION_KEY}: "{GCS_FUSE_ANNOTATION_VALUE}"'
-        ),
+        gcs_fuse_annotation=GCS_FUSE_ANNOTATION,
     )
 
 
diff --git a/src/xpk/core/storage.py b/src/xpk/core/storage.py
index ae4598dbe..5ac916f00 100644
--- a/src/xpk/core/storage.py
+++ b/src/xpk/core/storage.py
@@ -40,8 +40,7 @@
 STORAGE_CRD_NAME = f"{XPK_API_GROUP_NAME}.{STORAGE_CRD_PLURAL}"
 GCS_FUSE_TYPE = "gcsfuse"
 GCP_FILESTORE_TYPE = "gcpfilestore"
-GCS_FUSE_ANNOTATION_KEY = "gke-gcsfuse/volumes"
-GCS_FUSE_ANNOTATION_VALUE = "true"
+GCS_FUSE_ANNOTATION = 'gke-gcsfuse/volumes: "true"'
 
 
 @dataclass
diff --git a/src/xpk/core/workload_decorators/storage_decorator.py b/src/xpk/core/workload_decorators/storage_decorator.py
index 32ded713f..12d60df6a 100644
--- a/src/xpk/core/workload_decorators/storage_decorator.py
+++ b/src/xpk/core/workload_decorators/storage_decorator.py
@@ -16,7 +16,7 @@
 
 import yaml
 
-from ...core.storage import GCS_FUSE_TYPE, get_storage_volumes_yaml_dict, GCS_FUSE_ANNOTATION_KEY, GCS_FUSE_ANNOTATION_VALUE
+from ...core.storage import GCS_FUSE_TYPE, get_storage_volumes_yaml_dict, GCS_FUSE_ANNOTATION
 
 
 def decorate_jobset(jobset_manifest_str, storages) -> str:
@@ -44,7 +44,7 @@ def add_annotations(job_manifest, storages):
   annotations = job_manifest['spec']['template']['metadata']['annotations']
   gcs_present = [storage.type == GCS_FUSE_TYPE for storage in storages]
   if gcs_present:
-    annotations.update({GCS_FUSE_ANNOTATION_KEY: GCS_FUSE_ANNOTATION_VALUE})
+    annotations.update(GCS_FUSE_ANNOTATION)
 
 
 def add_volumes(job_manifest, storage_volumes):