# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from lada.lib.mps_utils import safe_mps_grid_sample, check_mps_tensor_validity, ensure_mps_tensor_contiguous


def flow_warp(x,
              flow,
              interpolation='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or a feature map with optical flow.

    Output pixel (i, j) is sampled from ``x`` at the absolute location
    ``(j + flow[n, i, j, 0], i + flow[n, i, j, 1])``, i.e. the flow's first
    channel is the horizontal (width) offset and the second the vertical
    (height) offset.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
            a two-channel, denoting the width and height relative offsets.
            Note that the values are not normalized to [-1, 1].
        interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
            Default: 'bilinear'.
        padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Whether align corners. Default: True.

    Returns:
        Tensor: Warped image or feature map. On MPS devices, a zero tensor
        shaped like ``x`` is returned instead when
        ``check_mps_tensor_validity`` rejects ``x`` or the sampling grid.

    Raises:
        ValueError: If the spatial sizes (h, w) of ``x`` and ``flow`` differ.
    """
    if x.size()[-2:] != flow.size()[1:3]:
        raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
                         f'flow ({flow.size()[1:3]}) are not the same.')
    _, _, h, w = x.size()

    # Build a mesh grid of absolute pixel coordinates on flow's device so it
    # can be added to the flow directly.
    device = flow.device
    rows = torch.arange(0, h, device=device, dtype=x.dtype)
    cols = torch.arange(0, w, device=device, dtype=x.dtype)
    # torch.meshgrid has been modified in 1.10.0 (compatibility with previous
    # versions), and will be further modified in 1.12 (Breaking Change);
    # request 'ij' indexing explicitly whenever the argument exists.
    if 'indexing' in torch.meshgrid.__code__.co_varnames:
        grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
    else:
        grid_y, grid_x = torch.meshgrid(rows, cols)
    grid = torch.stack((grid_x, grid_y), 2)  # h, w, 2
    grid.requires_grad_(False)

    grid_flow = grid + flow
    # Scale absolute pixel coordinates to [-1, 1] for grid sampling.
    # NOTE(review): this normalisation (divide by size - 1) corresponds to
    # align_corners=True; it is applied regardless of ``align_corners``.
    # Kept unchanged for compatibility with existing callers/models.
    grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
    grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
    grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)

    on_mps = x.device.type == 'mps'
    if on_mps:
        # Best-effort on MPS: return zeros instead of sampling when the helper
        # reports a tensor as unusable (presumably including empty tensors —
        # confirm against check_mps_tensor_validity). silent=True keeps the
        # helper from logging on every call.
        if not check_mps_tensor_validity(x, "input", silent=True):
            return torch.zeros_like(x)
        if not check_mps_tensor_validity(grid_flow, "grid_flow", silent=True):
            return torch.zeros_like(x)

    # Match both dtype and device of x. A type string such as
    # 'torch.cuda.FloatTensor' (as produced by x.type()) carries no device
    # index, so .to() is used instead to pin the grid to x's exact device.
    grid_flow = grid_flow.to(dtype=x.dtype, device=x.device)

    if on_mps:
        # Make both tensors contiguous before sampling on MPS.
        x = ensure_mps_tensor_contiguous(x)
        grid_flow = ensure_mps_tensor_contiguous(grid_flow)

    # safe_mps_grid_sample is presumably a drop-in for F.grid_sample with
    # MPS workarounds — see lada.lib.mps_utils.
    return safe_mps_grid_sample(
        x,
        grid_flow,
        mode=interpolation,
        padding_mode=padding_mode,
        align_corners=align_corners)
