lisa/features/gpu.py (116 changes: 110 additions & 6 deletions)

@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from functools import partial
-from typing import Any, List, Type
+from typing import Any, Dict, List, Type
 
 from dataclasses_json import dataclass_json

@@ -136,17 +136,121 @@ def install_compute_sdk(self, version: str = "") -> None:
             raise LisaException(f"{driver} is not a valid value of ComputeSDK")
 
     def get_gpu_count_with_lsvmbus(self) -> int:
+        """
+        Count GPU devices using lsvmbus.
+        First tries the hardcoded list, then groups by last segment of device ID.
+        """
+        lsvmbus_tool = self._node.tools[Lsvmbus]
+
+        # Get all VMBus devices
+        vmbus_devices = lsvmbus_tool.get_device_channels()
+        self._log.debug(f"Found {len(vmbus_devices)} VMBus devices")
+
+        # First try the hardcoded list (original approach)
+        gpu_count = self._get_gpu_count_hardcoded(vmbus_devices)
+
+        if gpu_count > 0:
+            self._log.debug(f"Found {gpu_count} GPU(s) using hardcoded list")
+            return gpu_count
+
+        # If nothing matches the hardcoded list, group by last segment
+        self._log.debug("No GPUs in hardcoded list, trying last-segment grouping")
+        gpu_count = self._get_gpu_count_by_last_segment(vmbus_devices)
+
+        if gpu_count > 0:
+            self._log.debug(f"Found {gpu_count} GPU(s) using last-segment grouping")
+        else:
+            self._log.debug("No GPU devices found in lsvmbus")
+
+        return gpu_count
+
+    def _get_gpu_count_by_last_segment(self, vmbus_devices: List[Any]) -> int:
+        """
+        Group VMBus devices by last segment of device ID and find the GPU group.
+        GPUs typically share the same last segment (e.g., '423331303142' for GB200).
+        """
+        try:
+            # Get the actual GPU count from nvidia-smi
+            nvidia_smi = self._node.tools[NvidiaSmi]
+            # Count GPUs from nvidia-smi output, not from a pre-existing list
+            actual_gpu_count = nvidia_smi.get_gpu_count_without_list()
+
+            if actual_gpu_count == 0:
+                self._log.debug("nvidia-smi reports 0 GPUs")
+                return 0
+
+            self._log.debug(f"nvidia-smi reports {actual_gpu_count} GPU(s)")
+
+            # Group devices by the last segment of the device ID
+            last_segment_groups: Dict[str, List[Any]] = {}
+
+            for device in vmbus_devices:
+                device_id = device.device_id
+                # Device ID format: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
+                # Extract the last segment after the last hyphen
+                id_parts = device_id.split("-")
+                if len(id_parts) >= 5:
+                    last_segment = id_parts[-1].lower()
+                    if last_segment not in last_segment_groups:
+                        last_segment_groups[last_segment] = []
+                    last_segment_groups[last_segment].append(device)
+
+            # Find a group whose size exactly matches the GPU count
+            for last_segment, devices in last_segment_groups.items():
+                if len(devices) == actual_gpu_count:
+                    # All members should be PCI Express pass-through devices
+                    all_pci_passthrough = all(
+                        "PCI Express pass-through" in device.name for device in devices
+                    )
+
+                    if all_pci_passthrough:
+                        self._log.debug(
+                            f"Found {len(devices)} PCI Express pass-through devices "
+                            f"with last segment '{last_segment}' matching GPU count"
+                        )
+                        # Log the matched devices for debugging
+                        for device in devices:
+                            self._log.debug(f"  GPU device: {device.device_id}")
+                        return actual_gpu_count
+
+            # If there is no exact match, log what was found for debugging
+            self._log.debug(
+                f"No last-segment device group matches "
+                f"GPU count {actual_gpu_count}"
+            )
+            for last_segment, devices in last_segment_groups.items():
+                # Only log groups with PCI Express pass-through devices
+                pci_devices = [
+                    d for d in devices if "PCI Express pass-through" in d.name
+                ]
+                if pci_devices:
+                    self._log.debug(
+                        f"  Last segment '{last_segment}': "
+                        f"{len(pci_devices)} PCI devices"
+                    )
+
+            return 0
+
+        except Exception as e:
+            self._log.debug(f"Last-segment grouping failed: {e}")
+            return 0
+
+    def _get_gpu_count_hardcoded(self, vmbus_devices: List[Any]) -> int:
+        """
+        Original approach: check devices against the hardcoded list.
+        """
         lsvmbus_device_count = 0
         bridge_device_count = 0
 
-        lsvmbus_tool = self._node.tools[Lsvmbus]
-        device_list = lsvmbus_tool.get_device_channels()
-        for device in device_list:
+        for device in vmbus_devices:
             for name, id_, bridge_count in NvidiaSmi.gpu_devices:
                 if id_ in device.device_id:
                     lsvmbus_device_count += 1
                     bridge_device_count = bridge_count
-                    self._log.debug(f"GPU device {name} found!")
+                    self._log.debug(
+                        f"GPU device {name} found using hardcoded list! "
+                        f"Device ID: {device.device_id}"
+                    )
                     break
 
         return lsvmbus_device_count - bridge_device_count

@@ -156,7 +260,7 @@ def get_gpu_count_with_lspci(self) -> int:
 
     def get_gpu_count_with_vendor_cmd(self) -> int:
        nvidiasmi = self._node.tools[NvidiaSmi]
-        return nvidiasmi.get_gpu_count()
+        return nvidiasmi.get_gpu_count_without_list()
 
     def get_supported_driver(self) -> List[ComputeSDK]:
         raise NotImplementedError()
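
The fallback heuristic is easier to follow on concrete input. Below is a minimal, self-contained sketch of the grouping step, using hypothetical (name, device_id) pairs and invented GUIDs in place of the lsvmbus channel objects the feature code actually receives:

    # Minimal standalone sketch of the last-segment grouping heuristic.
    # The (name, device_id) pairs and GUIDs are hypothetical stand-ins for
    # the lsvmbus channel objects the feature code actually receives.
    from collections import defaultdict
    from typing import Dict, List, Tuple

    Device = Tuple[str, str]  # (device name, VMBus device GUID)

    def group_by_last_segment(devices: List[Device]) -> Dict[str, List[Device]]:
        # Bucket devices by the last hyphen-separated segment of their GUID.
        groups: Dict[str, List[Device]] = defaultdict(list)
        for name, device_id in devices:
            parts = device_id.split("-")
            if len(parts) >= 5:  # a full GUID has five segments
                groups[parts[-1].lower()].append((name, device_id))
        return dict(groups)

    devices = [
        ("PCI Express pass-through", "47505500-0001-0000-3130-444531303142"),
        ("PCI Express pass-through", "47505500-0002-0000-3130-444531303142"),
        ("Synthetic network adapter", "f8615163-df3e-46c5-913f-f2d2f965ed0e"),
    ]
    for segment, members in group_by_last_segment(devices).items():
        print(f"{segment}: {len(members)} device(s)")
    # If nvidia-smi reports 2 GPUs, the group sharing '444531303142' wins,
    # provided every member is a PCI Express pass-through device.

Requiring that every device in the matching group is a "PCI Express pass-through" entry prevents a coincidentally sized group of unrelated devices from being mistaken for the GPU group.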

lisa/tools/nvidiasmi.py (26 changes: 26 additions & 0 deletions)

@@ -41,3 +41,29 @@ def get_gpu_count(self) -> int:
             device_count += result.stdout.count(gpu_type)
 
         return device_count
+
+    def get_gpu_count_without_list(self) -> int:
+        """
+        Get GPU count directly from nvidia-smi output without using the
+        hardcoded device list.
+        Counts the number of GPU entries in the nvidia-smi -L output.
+        """
+        result = self.run("-L")
+        if result.exit_code != 0 or result.stdout == "":
+            result = self.run("-L", sudo=True)
+        if result.exit_code != 0 or result.stdout == "":
+            raise LisaException(
+                f"nvidia-smi -L failed or gave no output (exit code {result.exit_code})"
+            )
+        gpu_lines = [
+            line
+            for line in result.stdout.splitlines()
+            if line.strip().startswith("GPU ")
+        ]
+        gpu_count = len(gpu_lines)
+
+        self._log.debug(f"nvidia-smi detected {gpu_count} GPU(s)")
+        for line in gpu_lines:
+            self._log.debug(f"  {line}")
+
+        return gpu_count
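
For reference, nvidia-smi -L prints exactly one line per GPU, each starting with "GPU <index>:", which is why counting those lines gives the device count without any hardcoded model list. A small sketch against hypothetical output (model names and UUIDs invented):

    # Counting rule from get_gpu_count_without_list, applied to hypothetical
    # nvidia-smi -L output (model names and UUIDs invented).
    sample_output = (
        "GPU 0: NVIDIA GB200 (UUID: GPU-0aaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee)\n"
        "GPU 1: NVIDIA GB200 (UUID: GPU-1aaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee)\n"
    )

    gpu_lines = [
        line
        for line in sample_output.splitlines()
        if line.strip().startswith("GPU ")
    ]
    assert len(gpu_lines) == 2  # one "GPU <index>:" line per device

Because the count comes straight from the driver's own listing, new GPU models are counted correctly without extending the gpu_devices table.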