#!/usr/bin/env pytest
###############################################################################
#
# Project:  GDAL/OGR Test Suite
# Purpose:  Test shapefile spatial index mechanism (.qix files). This can serve
#           as a test for the functionality of shapelib's shptree.c
# Author:   Even Rouault <even dot rouault at spatialys.com>
#
###############################################################################
# Copyright (c) 2012, Even Rouault <even dot rouault at spatialys.com>
#
# SPDX-License-Identifier: MIT
###############################################################################

import random

from osgeo import ogr

###############################################################################
#


def check_qix_non_overlapping_geoms(lyr):
    """Check spatial-filter queries on a layer whose geometries are disjoint.

    All geometries of *lyr* are first read back (with no spatial filter).

    For each one, a spatial filter equal to its envelope is installed. The
    first feature returned must have an equal geometry.

    Finally, a filter slightly smaller than the layer extent must still report
    the full feature count.

    :param lyr: OGR layer to check. The last spatial filter set by this
        function is left installed on return.
    """
    lyr.SetSpatialFilter(None)
    extents = lyr.GetExtent()
    fc_ref = lyr.GetFeatureCount()

    # Rewind explicitly: the caller may already have advanced the read cursor,
    # in which case some features would silently be skipped.
    lyr.ResetReading()
    geoms = []
    feat = lyr.GetNextFeature()
    while feat is not None:
        geoms.append(feat.GetGeometryRef().Clone())
        feat = lyr.GetNextFeature()
    assert len(geoms) == fc_ref, "read %d features, expected %d" % (
        len(geoms),
        fc_ref,
    )

    # Test getting each geom 1 by 1
    for geom in geoms:
        # GetEnvelope() is (minx, maxx, miny, maxy);
        # SetSpatialFilterRect() wants (minx, miny, maxx, maxy).
        minx, maxx, miny, maxy = geom.GetEnvelope()
        lyr.SetSpatialFilterRect(minx, miny, maxx, maxy)
        lyr.ResetReading()
        feat = lyr.GetNextFeature()
        assert feat is not None, "no feature returned for %s" % geom.ExportToWkt()
        got_geom = feat.GetGeometryRef()
        assert got_geom.Equals(geom), "expected %s. got %s" % (
            geom.ExportToWkt(),
            got_geom.ExportToWkt(),
        )

    # Get all geoms in a single gulp. We do not use exactly the extent bounds,
    # because the shapefile driver has an optimization that skips the spatial
    # index in that case. That trick can only work with non-point geometries.
    lyr.SetSpatialFilterRect(
        extents[0] + 0.001, extents[2] + 0.001, extents[1] - 0.001, extents[3] - 0.001
    )
    lyr.ResetReading()
    fc = lyr.GetFeatureCount()
    assert fc == fc_ref, "expected %d. got %d" % (fc_ref, fc)


###############################################################################


def build_rectangle_from_point(x, y, radius=0.1):
    """Return an axis-aligned square polygon centred on (x, y).

    The square extends *radius* in every direction from the centre. Its ring
    starts at the lower-left corner and is built from WKT formatted with
    ``%f``.
    """
    xmin, xmax = x - radius, x + radius
    ymin, ymax = y - radius, y + radius
    ring = (xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin, xmin, ymin)
    wkt = "POLYGON((%f %f,%f %f,%f %f,%f %f,%f %f))" % ring
    return ogr.CreateGeometryFromWkt(wkt)


###############################################################################
# Test geoms on a 10x10 grid


def test_ogr_shape_qix_1():
    """Index 10x10 disjoint squares and query each one back through the .qix."""
    filename = "/vsimem/ogr_shape_qix.shp"
    shape_drv = ogr.GetDriverByName("ESRI Shapefile")
    ds = shape_drv.CreateDataSource(filename)
    try:
        lyr = ds.CreateLayer("ogr_shape_qix")

        for x in range(10):
            for y in range(10):
                feat = ogr.Feature(lyr.GetLayerDefn())
                feat.SetGeometry(build_rectangle_from_point(x, y))
                lyr.CreateFeature(feat)
                feat = None

        ds.ExecuteSQL("CREATE SPATIAL INDEX ON ogr_shape_qix")

        # Close before reopening. Drop the layer reference as well, in case the
        # bindings keep the dataset alive through it.
        lyr = None
        ds = None

        ds = ogr.Open(filename)
        lyr = ds.GetLayer(0)
        check_qix_non_overlapping_geoms(lyr)
    finally:
        # Always clean up: later tests reuse the same /vsimem path.
        lyr = None
        ds = None
        shape_drv.DeleteDataSource(filename)


###############################################################################
# Test geoms on a 100x100 grid


def test_ogr_shape_qix_2():
    """Index 100x100 disjoint squares and query each one back through the .qix."""
    filename = "/vsimem/ogr_shape_qix.shp"
    shape_drv = ogr.GetDriverByName("ESRI Shapefile")
    ds = shape_drv.CreateDataSource(filename)
    try:
        lyr = ds.CreateLayer("ogr_shape_qix")

        for x in range(100):
            for y in range(100):
                feat = ogr.Feature(lyr.GetLayerDefn())
                feat.SetGeometry(build_rectangle_from_point(x, y))
                lyr.CreateFeature(feat)
                feat = None

        ds.ExecuteSQL("CREATE SPATIAL INDEX ON ogr_shape_qix")

        # Close before reopening. Drop the layer reference as well, in case the
        # bindings keep the dataset alive through it.
        lyr = None
        ds = None

        ds = ogr.Open(filename)
        lyr = ds.GetLayer(0)
        check_qix_non_overlapping_geoms(lyr)
    finally:
        # Always clean up: other tests reuse the same /vsimem path.
        lyr = None
        ds = None
        shape_drv.DeleteDataSource(filename)


###############################################################################
# Test 2 separated regions of 10x10 geoms


def test_ogr_shape_qix_3():
    """Index two 10x10 square grids, 1000 units apart on X, and query each back."""
    filename = "/vsimem/ogr_shape_qix.shp"
    shape_drv = ogr.GetDriverByName("ESRI Shapefile")
    ds = shape_drv.CreateDataSource(filename)
    try:
        lyr = ds.CreateLayer("ogr_shape_qix")

        # The first grid is written completely before the second one.
        for x_offset in (0, 1000):
            for x in range(10):
                for y in range(10):
                    feat = ogr.Feature(lyr.GetLayerDefn())
                    feat.SetGeometry(build_rectangle_from_point(x + x_offset, y))
                    lyr.CreateFeature(feat)
                    feat = None

        ds.ExecuteSQL("CREATE SPATIAL INDEX ON ogr_shape_qix")

        # Close before reopening. Drop the layer reference as well, in case the
        # bindings keep the dataset alive through it.
        lyr = None
        ds = None

        ds = ogr.Open(filename)
        lyr = ds.GetLayer(0)
        check_qix_non_overlapping_geoms(lyr)
    finally:
        # Always clean up: other tests reuse the same /vsimem path.
        lyr = None
        ds = None
        shape_drv.DeleteDataSource(filename)


###############################################################################
#


def check_qix_random_geoms(lyr):
    """Check spatial-filter queries on a layer whose geometries may overlap.

    All geometries of *lyr* are first read back (with no spatial filter).

    For each one, a spatial filter equal to its envelope is installed, and
    some returned feature must have an equal geometry. Since geometries may
    overlap, it need not be the first one.

    Finally, a filter slightly smaller than the layer extent must still report
    the full feature count.

    :param lyr: OGR layer to check. The last spatial filter set by this
        function is left installed on return.
    """
    lyr.SetSpatialFilter(None)
    extents = lyr.GetExtent()
    fc_ref = lyr.GetFeatureCount()

    # Rewind explicitly: this is called on a layer that was just written to,
    # and the caller may already have advanced its read cursor.
    lyr.ResetReading()
    geoms = []
    feat = lyr.GetNextFeature()
    while feat is not None:
        geoms.append(feat.GetGeometryRef().Clone())
        feat = lyr.GetNextFeature()
    assert len(geoms) == fc_ref, "read %d features, expected %d" % (
        len(geoms),
        fc_ref,
    )

    # Test getting each geom 1 by 1
    for geom in geoms:
        # GetEnvelope() is (minx, maxx, miny, maxy);
        # SetSpatialFilterRect() wants (minx, miny, maxx, maxy).
        minx, maxx, miny, maxy = geom.GetEnvelope()
        lyr.SetSpatialFilterRect(minx, miny, maxx, maxy)
        lyr.ResetReading()
        found_geom = False
        feat = lyr.GetNextFeature()
        while feat is not None:
            if feat.GetGeometryRef().Equals(geom):
                found_geom = True
                break
            feat = lyr.GetNextFeature()
        assert found_geom, "did not find geometry for %s" % (geom.ExportToWkt())

    # Get all geoms in a single gulp. We do not use exactly the extent bounds,
    # because the shapefile driver has an optimization that skips the spatial
    # index in that case. That trick can only work with non-point geometries.
    lyr.SetSpatialFilterRect(
        extents[0] + 0.001, extents[2] + 0.001, extents[1] - 0.001, extents[3] - 0.001
    )
    lyr.ResetReading()
    fc = lyr.GetFeatureCount()
    assert fc == fc_ref, "expected %d. got %d" % (fc_ref, fc)


###############################################################################


def build_rectangle(x1, y1, x2, y2):
    """Return the axis-aligned rectangle polygon spanned by (x1, y1)-(x2, y2).

    The ring starts and ends at (x1, y1), visiting (x1, y2), (x2, y2), and
    (x2, y1) in between. It is built from WKT formatted with ``%f``.
    """
    ring = (x1, y1, x1, y2, x2, y2, x2, y1, x1, y1)
    wkt = "POLYGON((%f %f,%f %f,%f %f,%f %f,%f %f))" % ring
    return ogr.CreateGeometryFromWkt(wkt)


###############################################################################
# Test random geometries


def test_ogr_shape_qix_4():
    """Index 2000 pseudo-random rectangles, some overlapping, and query each back."""
    filename = "/vsimem/ogr_shape_qix.shp"
    # Use a private, fixed-seed generator. A failure is then reproducible, and
    # the global `random` state seen by other tests is left untouched.
    rng = random.Random(12345)

    shape_drv = ogr.GetDriverByName("ESRI Shapefile")
    ds = shape_drv.CreateDataSource(filename)
    try:
        lyr = ds.CreateLayer("ogr_shape_qix")

        # Each pass adds 1000 rectangles with sides of 1 to 10 units.
        # - First pass: origins lie in a 200x200 square, so there is a bit of
        #   overlap between the geometries.
        # - Second pass: origins are spread over 10000x10000, so those features
        #   are statistically non-overlapping.
        for max_origin in (200, 10000):
            for _ in range(1000):
                feat = ogr.Feature(lyr.GetLayerDefn())
                x1 = rng.randint(0, max_origin)
                y1 = rng.randint(0, max_origin)
                x2 = x1 + rng.randint(1, 10)
                y2 = y1 + rng.randint(1, 10)
                feat.SetGeometry(build_rectangle(x1, y1, x2, y2))
                lyr.CreateFeature(feat)
                feat = None

        ds.ExecuteSQL("CREATE SPATIAL INDEX ON ogr_shape_qix")

        check_qix_random_geoms(lyr)
    finally:
        # Always clean up: other tests reuse the same /vsimem path.
        lyr = None
        ds = None
        shape_drv.DeleteDataSource(filename)
