# GPU多卡压测脚本(矩阵乘法, Torch-based)
# Multi-GPU stress-test script (matrix multiplication, PyTorch-based)

import torch
import torch.multiprocessing as mp
import time
import math

def stress_task(gpu_id, mem_fraction=0.65):
    """Keep one GPU busy with an endless float32 matrix-multiply loop.

    The matrix size N is derived from the device's currently free memory so
    that the three N x N float32 matrices held at any time (operands ``x``,
    ``y`` and the reused output buffer ``z``) occupy roughly ``mem_fraction``
    of it. The rest is left as headroom for PyTorch/CUDA overhead.

    The function loops forever. It returns when interrupted (Ctrl+C) or when
    a CUDA ``RuntimeError`` (e.g. out of memory) occurs; the error is printed,
    not re-raised.

    Args:
        gpu_id: Index of the CUDA device to stress.
        mem_fraction: Fraction (0, 1] of the currently free device memory to
            fill with the three matrices. Defaults to 0.65.

    Raises:
        ValueError: If ``mem_fraction`` is outside (0, 1].
    """
    if not 0 < mem_fraction <= 1:
        raise ValueError(f"mem_fraction must be in (0, 1], got {mem_fraction}")

    bytes_per_element = 4  # float32, enforced via dtype= below
    num_matrices = 3       # x, y and the preallocated output z

    try:
        device = torch.device(f"cuda:{gpu_id}")
        torch.cuda.set_device(device)  # make this device current for this process

        free_mem, total_mem = torch.cuda.mem_get_info(device)

        # Size the matrices dynamically:
        #   (N * N * bytes_per_element) * num_matrices <= free_mem * mem_fraction
        target_mem = free_mem * mem_fraction
        matrix_size = int(math.sqrt(target_mem / (bytes_per_element * num_matrices)))

        free_gb = free_mem / (1024**3)
        total_gb = total_mem / (1024**3)
        print(f"[GPU {gpu_id}] Total: {total_gb:.2f}GB | Free: {free_gb:.2f}GB")
        print(f"[GPU {gpu_id}] Calculated Matrix Size: {matrix_size}x{matrix_size} "
              f"(Target: ~{mem_fraction:.0%} of free memory)")

        if matrix_size < 1:
            # Nothing meaningful to multiply; avoid spinning on empty matmuls.
            print(f"[GPU {gpu_id}] Not enough free memory to run the stress loop.")
            return

        print(f"[GPU {gpu_id}] Allocating memory...")
        x = torch.randn(matrix_size, matrix_size, device=device, dtype=torch.float32)
        y = torch.randn(matrix_size, matrix_size, device=device, dtype=torch.float32)
        # Write each product into one preallocated buffer. A plain
        # `z = torch.mm(x, y)` allocates the new result while the previous one
        # is still referenced, briefly needing a 4th matrix and breaking the
        # 3-matrix sizing above.
        z = torch.empty_like(x)

        print(f"[GPU {gpu_id}] Starting loop...")
        while True:
            torch.mm(x, y, out=z)

    except RuntimeError as e:
        print(f"[GPU {gpu_id}] Error: {e}")
        if "out of memory" in str(e):
            print(f"[GPU {gpu_id}] Auto-size was too aggressive. "
                  f"Try lowering mem_fraction (currently {mem_fraction}).")
    except KeyboardInterrupt:
        # Ctrl+C typically reaches the workers too (same foreground process
        # group); exit quietly and let the parent process report the stop.
        pass

if __name__ == '__main__':
    # Launch one stress_task worker process per visible CUDA device and wait
    # until they finish or the user presses Ctrl+C.
    if not torch.cuda.is_available():
        print("CUDA is not available!")
        # Non-zero status so shells/CI see the failure (bare exit() returned 0).
        raise SystemExit(1)

    num_gpus = torch.cuda.device_count()
    print(f"Found {num_gpus} GPUs. Auto-calculating load for each...")

    # Workers use CUDA, which PyTorch documents as unsupported with the
    # 'fork' start method, so each worker is a freshly spawned interpreter.
    mp.set_start_method('spawn', force=True)

    print("Starting processes... (Press Ctrl+C to stop)")
    start_time = time.time()

    processes = []
    try:
        for i in range(num_gpus):
            p = mp.Process(target=stress_task, args=(i,))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

    except KeyboardInterrupt:
        print("\nStop signal received. Terminating all processes...")
        for p in processes:
            if p.is_alive():
                p.terminate()
        # Reap every worker so "All stopped" is actually true and no zombie
        # processes are left behind; escalate to SIGKILL if one hangs.
        for p in processes:
            p.join(timeout=10)
            if p.is_alive():
                p.kill()
                p.join()
        print(f"All stopped. Duration: {time.time() - start_time:.2f} seconds")
    else:
        # Reached only if every worker exited on its own (e.g. CUDA errors).
        print(f"All processes exited. Duration: {time.time() - start_time:.2f} seconds")
