"""
Device manager - smart device selection and fallback
"""
import torch
import logging
import os
import platform
import subprocess
from typing import Optional, Union

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Minimum total system RAM (in GB) required before MPS acceleration is used.
# Systems with less RAM experience crashes due to MPS threading issues.
# v0.8.6: Lowered from 24GB to 8GB - 16GB M1 confirmed working after timestamp fix
MPS_MIN_RAM_GB = 8


def _get_system_ram_gb() -> float:
    """Get total system RAM in GB. Returns 0 if unable to detect."""
    try:
        if platform.system() == 'Darwin':
            result = subprocess.run(['sysctl', '-n', 'hw.memsize'],
                                  capture_output=True, text=True, timeout=5)
            total_mem_bytes = int(result.stdout.strip())
            return total_mem_bytes / (1024**3)
    except Exception as e:
        logger.warning(f"Could not detect system RAM: {e}")
    return 0


def _is_mps_ram_sufficient() -> bool:
    """Report whether the detected system RAM meets the MPS minimum.

    A detected value of 0 (detection unavailable or failed) is treated as
    sufficient; only a positive reading below ``MPS_MIN_RAM_GB`` disables
    MPS, in which case a warning is logged.
    """
    detected_gb = _get_system_ram_gb()
    below_minimum = 0 < detected_gb < MPS_MIN_RAM_GB
    if below_minimum:
        logger.warning(f"MPS disabled: System has {detected_gb:.1f}GB RAM, minimum {MPS_MIN_RAM_GB}GB required for stability")
    return not below_minimum


class DeviceManager:
    """Pick a torch device and track MPS failures to decide on CPU fallback.

    Selection order in ``get_best_device`` is: forced CPU (environment
    override), CUDA, MPS (subject to the failure counter, the RAM check and
    a small runtime probe), then CPU.
    """

    def __init__(self):
        # Assigned here but not read anywhere in this class.
        self._current_device = None
        # Device returned by handle_device_error for non-MPS devices.
        self._fallback_device = 'cpu'
        # Number of MPS errors reported through handle_device_error.
        self._mps_failures = 0
        # Once this many MPS failures are recorded, MPS is no longer selected.
        self._max_mps_failures = 3
        # LADA_FORCE_CPU=1/true/yes (case-insensitive) forces CPU selection.
        self._force_cpu = os.environ.get('LADA_FORCE_CPU', '0').lower() in ('1', 'true', 'yes')

    @staticmethod
    def _is_mps_device(device) -> bool:
        """Return True when *device* names MPS, with or without an index.

        Accepts both ``'mps'`` and indexed forms such as ``'mps:0'`` (the
        string form of an MPS tensor's ``.device``), as well as any object
        whose ``str()`` yields one of those.
        """
        return str(device).split(':', 1)[0] == 'mps'

    def get_best_device(self) -> str:
        """Return the best available device string.

        Returns:
            ``'cuda:0'``, ``'mps'`` or ``'cpu'``.
        """
        if self._force_cpu:
            logger.info("Forced to use CPU due to LADA_FORCE_CPU environment variable")
            return 'cpu'

        # CUDA is checked before the MPS failure counter: that counter only
        # describes MPS health and must not veto a working CUDA device.
        if torch.cuda.is_available():
            return 'cuda:0'

        if self._mps_failures >= self._max_mps_failures:
            logger.warning(f"MPS failed {self._mps_failures} times, falling back to CPU")
            return 'cpu'

        # MPS is used only if the backend is present, the system has enough
        # RAM, and the basic runtime probe succeeds.
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            if not _is_mps_ram_sufficient():
                return 'cpu'
            if self._test_mps_basic():
                return 'mps'
            logger.warning("MPS basic test failed, using CPU")
            return 'cpu'

        return 'cpu'

    def _test_mps_basic(self) -> bool:
        """Run a small MPS probe; return True if no exception is raised.

        The probe performs a tensor addition and a ``grid_sample`` call on
        the ``'mps'`` device. Any exception is logged at debug level and
        reported as failure.
        """
        try:
            # Basic tensor arithmetic.
            torch.randn(2, 2, device='mps') + 1

            # Simple grid_sample check.
            import torch.nn.functional as F
            input_tensor = torch.randn(1, 1, 4, 4, device='mps')
            grid = torch.randn(1, 2, 2, 2, device='mps')
            F.grid_sample(input_tensor, grid, align_corners=False, padding_mode='zeros')

            return True
        except Exception as e:
            logger.debug(f"MPS basic test failed: {e}")
            return False

    def handle_device_error(self, device: str, error: Exception) -> str:
        """Record a device error and return the device to fall back to.

        Args:
            device: Device string on which the error occurred. ``'mps'`` and
                indexed forms such as ``'mps:0'`` both count as MPS.
            error: The exception that was raised; only used for logging.

        Returns:
            ``'cpu'`` for MPS errors (the failure counter is incremented),
            otherwise the configured fallback device.
        """
        if self._is_mps_device(device):
            self._mps_failures += 1
            logger.warning(f"MPS error (failure #{self._mps_failures}): {error}")

            if self._mps_failures >= self._max_mps_failures:
                logger.warning("Too many MPS failures, permanently switching to CPU")
            else:
                logger.info("Temporarily falling back to CPU")
            return 'cpu'

        return self._fallback_device

    def reset_failure_count(self, device: str):
        """Reset the failure counter for *device* (only MPS is tracked).

        Indexed MPS strings such as ``'mps:0'`` are accepted; other devices
        are ignored.
        """
        if self._is_mps_device(device):
            self._mps_failures = 0

    def is_mps_available_and_stable(self) -> bool:
        """Return True if MPS may be used right now.

        False when CPU is forced, the failure limit has been reached, the
        MPS backend is missing or unavailable, or RAM is insufficient;
        otherwise the result of the basic MPS probe.
        """
        if self._force_cpu:
            return False

        if self._mps_failures >= self._max_mps_failures:
            return False

        if not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
            return False

        if not _is_mps_ram_sufficient():
            return False

        return self._test_mps_basic()


# Global device manager instance shared by the module-level helper functions
device_manager = DeviceManager()


def get_optimal_device() -> str:
    """Return the device chosen by the module-level ``device_manager``."""
    chosen = device_manager.get_best_device()
    return chosen


def handle_device_error(device: str, error: Exception) -> str:
    """Forward a device error to ``device_manager`` and return its fallback device."""
    fallback = device_manager.handle_device_error(device, error)
    return fallback


def safe_to_device(tensor: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
    """Move *tensor* to *device*, retrying on the fallback device on failure.

    If the first ``tensor.to(device)`` raises, the error is reported through
    ``handle_device_error`` (using ``str(device)``), a warning is logged, and
    the tensor is moved to the returned fallback device instead. An error
    raised by that second move is not caught.
    """
    try:
        moved = tensor.to(device)
    except Exception as exc:
        alternate = handle_device_error(str(device), exc)
        logger.warning(f"Failed to move tensor to {device}, using {alternate}: {exc}")
        moved = tensor.to(alternate)
    return moved


def create_tensor_on_device(device: Union[str, torch.device], *args, **kwargs) -> torch.Tensor:
    """Create a tensor on *device*, retrying on the fallback device on failure.

    ``*args`` and ``**kwargs`` are passed straight to ``torch.tensor``. If
    creation raises, the error is reported through ``handle_device_error``
    (using ``str(device)``), a warning is logged, and creation is retried on
    the returned fallback device. An error from the retry is not caught.
    """
    try:
        created = torch.tensor(*args, device=device, **kwargs)
    except Exception as exc:
        alternate = handle_device_error(str(device), exc)
        logger.warning(f"Failed to create tensor on {device}, using {alternate}: {exc}")
        created = torch.tensor(*args, device=alternate, **kwargs)
    return created


def get_device_info() -> dict:
    """Collect a summary of device availability and manager state.

    Always includes ``optimal_device``, ``cuda_available``, ``mps_available``,
    ``mps_stable``, ``mps_failures`` and ``force_cpu``. When CUDA is
    available, ``cuda_device_count`` and ``cuda_devices`` (device names) are
    added as well.
    """
    report = {'optimal_device': get_optimal_device()}
    report['cuda_available'] = torch.cuda.is_available()
    report['mps_available'] = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    report['mps_stable'] = device_manager.is_mps_available_and_stable()
    report['mps_failures'] = device_manager._mps_failures
    report['force_cpu'] = device_manager._force_cpu

    if not report['cuda_available']:
        return report

    gpu_total = torch.cuda.device_count()
    report['cuda_device_count'] = gpu_total
    report['cuda_devices'] = [torch.cuda.get_device_name(idx) for idx in range(gpu_total)]
    return report