"""
MPS 兼容性工具
"""
import torch
import logging
import functools

# Module-level logger; the helpers in this module report MPS fallbacks and failures through it.
logger = logging.getLogger(__name__)


# Substrings of RuntimeError messages that are treated as MPS backend failures.
_MPS_ERROR_MARKERS = ('INTERNAL ASSERT FAILED', 'Placeholder tensor is empty', 'MPS')


def _map_tensors(obj, fn):
    """Apply ``fn`` to every tensor in ``obj``, recursing into plain lists, tuples and dicts.

    Only the exact built-in ``list``/``tuple``/``dict`` types are rebuilt;
    other containers (including subclasses such as namedtuples) and
    non-tensor leaves are returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return fn(obj)
    if type(obj) in (list, tuple):
        return type(obj)(_map_tensors(item, fn) for item in obj)
    if type(obj) is dict:
        return {key: _map_tensors(value, fn) for key, value in obj.items()}
    return obj


def _is_mps_failure(error):
    """Return True if ``error`` looks like an MPS backend failure worth retrying on CPU."""
    if not torch.backends.mps.is_available():
        return False
    message = str(error)
    return any(marker in message for marker in _MPS_ERROR_MARKERS)


def _run_on_cpu(func, args, kwargs):
    """Call ``func`` with every MPS tensor argument copied to CPU, then move tensor results back.

    MPS tensors are collected from positional and keyword arguments, including
    those nested in plain lists/tuples/dicts. Tensors in the result (also
    nested) are moved to the first MPS device encountered; if no argument was
    on MPS the result is returned untouched.
    """
    mps_devices = []

    def to_cpu(tensor):
        if tensor.device.type != 'mps':
            return tensor
        if not mps_devices:
            mps_devices.append(tensor.device)
        return tensor.cpu()

    cpu_args = _map_tensors(args, to_cpu)
    cpu_kwargs = _map_tensors(kwargs, to_cpu)
    result = func(*cpu_args, **cpu_kwargs)

    if not mps_devices:
        return result
    target_device = mps_devices[0]
    return _map_tensors(result, lambda tensor: tensor.to(target_device))


def mps_safe_operation(fallback_to_cpu=True):
    """
    Decorator factory that adds a CPU fallback for operations failing on MPS.

    The wrapped function is first called normally. If it raises a
    ``RuntimeError`` while the MPS backend is available and the message
    contains ``'INTERNAL ASSERT FAILED'``, ``'Placeholder tensor is empty'``
    or ``'MPS'``, then:

    * ``fallback_to_cpu=True``: MPS tensor arguments (positional or keyword,
      including inside plain lists/tuples/dicts) are copied to CPU, the
      function is re-run there, and tensors in the result are moved back to
      the first MPS device found among the arguments;
    * ``fallback_to_cpu=False``: the failure is logged and re-raised.

    Any other exception propagates unchanged. Use as
    ``@mps_safe_operation()`` (note the call parentheses).

    Args:
        fallback_to_cpu: Whether to retry on CPU after an MPS failure.

    Returns:
        A decorator that wraps the target function.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except RuntimeError as e:
                if not _is_mps_failure(e):
                    raise
                if not fallback_to_cpu:
                    logger.error(f"MPS operation failed: {func.__name__}: {e}")
                    raise
                logger.warning(f"MPS operation failed, falling back to CPU: {func.__name__}")
                # Retry inside the handler so a CPU failure is chained to the MPS error.
                return _run_on_cpu(func, args, kwargs)
        return wrapper
    return decorator


def check_mps_tensor_validity(tensor, name="tensor", silent=False, check_finite=True):
    """
    Check whether an MPS tensor is safe to hand to an MPS kernel.

    Non-tensors and tensors on any other device are always reported as valid.
    An MPS tensor is invalid if it has no elements, or (when ``check_finite``
    is true and its dtype is floating-point or complex) if it contains NaN or
    Inf.

    Args:
        tensor: Object to check.
        name: Label used in log messages.
        silent: When True, suppress all log messages.
        check_finite: Scan for NaN/Inf. This is a full pass over the tensor
            whose result is read back on the host, so callers on hot paths
            may want to disable it.

    Returns:
        bool: True if the tensor is considered valid.
    """
    if not isinstance(tensor, torch.Tensor) or tensor.device.type != 'mps':
        return True

    # numel() == 0 also covers every shape that contains a zero-sized dimension.
    if tensor.numel() == 0:
        if not silent:
            logger.debug(f"Empty MPS tensor detected: {name}")
        return False

    # Integer and bool tensors cannot hold NaN/Inf, so there is nothing to scan.
    if not check_finite or not (tensor.is_floating_point() or tensor.is_complex()):
        return True

    try:
        # One isfinite pass replaces the separate isnan/isinf scans.
        has_non_finite = not bool(torch.isfinite(tensor).all())
    except Exception as e:
        # Best effort: if the check itself fails on this backend, skip it.
        if not silent:
            logger.debug(f"Skipped NaN/Inf check for MPS tensor {name}: {e}")
        return True

    if has_non_finite:
        if not silent:
            logger.warning(f"MPS tensor contains NaN or Inf: {name}")
        return False
    return True


def ensure_mps_tensor_contiguous(tensor):
    """Return a contiguous copy of a non-contiguous MPS tensor; pass everything else through as-is."""
    needs_copy = (
        isinstance(tensor, torch.Tensor)
        and tensor.device.type == 'mps'
        and not tensor.is_contiguous()
    )
    return tensor.contiguous() if needs_copy else tensor


def _grid_sample_zeros(input, grid):
    """Return zeros shaped like ``grid_sample(input, grid)``'s output when the shapes allow it.

    grid_sample maps input (N, C, *spatial_in) and grid (N, *spatial_out, k),
    with k equal to the number of spatial dims (2 or 3), to an output of shape
    (N, C, *spatial_out). If the pair is not a valid 4-D/5-D combination, fall
    back to zeros shaped like ``input``.
    """
    spatial_dims = input.dim() - 2
    if spatial_dims in (2, 3) and grid.dim() == input.dim() and grid.shape[-1] == spatial_dims:
        return input.new_zeros((input.shape[0], input.shape[1], *grid.shape[1:-1]))
    return torch.zeros_like(input)


def safe_mps_grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None):
    """
    ``torch.nn.functional.grid_sample`` with guards for the MPS backend.

    * Invalid MPS inputs (empty, or containing NaN/Inf) produce a zero tensor
      with grid_sample's output shape instead of invoking the kernel.
    * MPS tensors are made contiguous before the call.
    * On MPS, ``padding_mode='border'`` is replaced by ``'zeros'``.
    * If the MPS kernel fails with ``'INTERNAL ASSERT FAILED'`` or
      ``'Placeholder tensor is empty'``, the op is recomputed on CPU with the
      caller's original ``padding_mode`` and the result is moved back to
      ``input``'s device.

    Args mirror ``torch.nn.functional.grid_sample``.

    Returns:
        The sampled tensor.

    Raises:
        RuntimeError: Any grid_sample error other than the MPS failures above.
    """
    if not check_mps_tensor_validity(input, "input") or not check_mps_tensor_validity(grid, "grid"):
        return _grid_sample_zeros(input, grid)

    input = ensure_mps_tensor_contiguous(input)
    grid = ensure_mps_tensor_contiguous(grid)

    on_mps = input.device.type == 'mps'
    mps_padding_mode = padding_mode
    # MPS does not support border padding (per the original implementation); downgrade it.
    if on_mps and padding_mode == 'border':
        mps_padding_mode = 'zeros'
        logger.debug("MPS grid_sample: 'border' padding replaced by 'zeros'")

    try:
        return torch.nn.functional.grid_sample(
            input, grid, mode=mode, padding_mode=mps_padding_mode, align_corners=align_corners
        )
    except RuntimeError as e:
        message = str(e)
        is_mps_failure = on_mps and (
            'INTERNAL ASSERT FAILED' in message or 'Placeholder tensor is empty' in message
        )
        if not is_mps_failure:
            raise
        logger.warning(f"MPS grid_sample failed, falling back to CPU: {e}")

        # The CPU kernel supports every padding mode, so honour the caller's request here.
        result_cpu = torch.nn.functional.grid_sample(
            input.cpu(), grid.cpu(), mode=mode, padding_mode=padding_mode, align_corners=align_corners
        )
        return result_cpu.to(input.device)


def optimize_mps_memory():
    """Ask PyTorch to release cached MPS memory via ``torch.mps.empty_cache()``.

    Does nothing when MPS is unavailable. A failure of the cache-clearing call
    is logged as a warning instead of being raised.
    """
    if not torch.backends.mps.is_available():
        return

    try:
        torch.mps.empty_cache()
        logger.debug("MPS cache cleared")
    except Exception as err:
        logger.warning(f"Failed to clear MPS cache: {err}")


def get_mps_info():
    """Summarise MPS availability as a plain dict.

    Returns:
        dict with keys ``'available'`` (result of
        ``torch.backends.mps.is_available()``, or False if this torch build
        has no MPS backend module), ``'device_count'`` (1 when available,
        else 0) and ``'current_device'`` (``'mps'`` when available, else None).
    """
    mps_backend = getattr(torch.backends, 'mps', None)
    available = False if mps_backend is None else mps_backend.is_available()

    if not available:
        return {'available': available, 'device_count': 0, 'current_device': None}
    # MPS exposes a single logical device.
    return {'available': available, 'device_count': 1, 'current_device': 'mps'}