Generating Seamlines for Remote Sensing Image Mosaics in Python

A seamline algorithm for remote sensing mosaics, based on segmentation and optimal paths

Preface

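A seamline is what computer vision calls a stitching seam. When mosaicking remote sensing images, the seam should meander so that the transition between the images looks natural, and some requirements also ask the seamline to avoid buildings, follow terrain lines, and so on. ENVI's Auto Generate Seamlines tool shows generated seamlines very directly, but the lines it produces are very straight and cut across obvious curved textures, which leaves visible stitching marks in the mosaic. The algorithm implemented here mainly addresses this straight-seamline problem.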

Algorithm Overview

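The method has a few key parts. Mosaicking is normally done after orthorectification. An orthorectified image is rotated to some degree, and the resulting black areas are filled with 0. These areas need no processing, so the valid region must be extracted first. The path is generated from image texture using superpixel segmentation, the method covered in an earlier post. The overall mosaicking idea is to generate the seamline first, use it to cut an irregular boundary for a single image, and finally stitch the images together with GDAL's Warp function.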

Packages Used

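| Package | Version | Package | Version | Package | Version | Package | Version |
|---|---|---|---|---|---|---|---|
| os | built-in | cv2 | 4.6.0.66 | networkx | 2.8.6 | pathlib | built-in |
| math | built-in | ogr | 3.4.3 | itertools | built-in | geopandas | 0.11.1 |
| shapefile | third-party (pyshp) | osr | 3.4.3 | shapely | 2.3.1 | skimage | 0.19.3 |
| tarfile | built-in | gdal | 3.4.3 | shutil | built-in | numpy | 1.23.3 |
| time | built-in | gdalconst | 3.4.3 | topojson | 1.5 | python | 3.10.4 |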

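Some of the entries above are module names rather than package names, but most are correct. The main dependencies are scikit-image, networkx, GDAL, and OpenCV. A few installation details are worth noting: `ogr`, `osr`, `gdal`, and `gdalconst` all come from GDAL's Python bindings (imported via `from osgeo import ...`), `skimage` is installed as `scikit-image`, `cv2` as `opencv-python`, and the `shapefile` module is not part of the standard library but is installed with

pip install pyshp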

Full Code

The code is long, so below it is broken down function by function, following the order of the main function.

Main Function


def main(input1, input2, output):
    time_start=time.time()
    print("start deal")
    tempdir = os.path.join(os.path.dirname(output),"temp")
    os.makedirs(tempdir, exist_ok=True)  # create the scratch folder if it does not exist yet


    # intermediate files
    out_raster1 = os.path.join(tempdir,"out_raster1.tif")
    out_raster2 = os.path.join(tempdir,"out_raster2.tif")
    outputfile1 = os.path.join(tempdir,"bina_shp1.shp")
    outputfile2 = os.path.join(tempdir,"bina_shp2.shp")
    outShp1 = os.path.join(tempdir,"select_shp1.shp")
    outShp2 = os.path.join(tempdir,"select_shp2.shp")
    outshp_intersect = os.path.join(tempdir,"intersect_shp.shp")
    inter_sim_shp = os.path.join(tempdir,"intersect_simply_shp.shp")
    output_raster = os.path.join(tempdir,"clip_interest_raster.tif")
    resample_raster = os.path.join(tempdir,"clip_resample_raster.tif")
    out_shp_point = os.path.join(tempdir,"start_end_point.shp")
    seg_raster = os.path.join(tempdir,"seg_raster.tif")
    seg_poly = os.path.join(tempdir,"seg_poly_shp.shp")
    seg_line = os.path.join(tempdir,"seg_line_shp.shp")
    seg_line_sim = os.path.join(tempdir,"seg_line_sim.shp")
    seg_line_mer_inter = os.path.join(tempdir,"seg_line_mer_inter_shp.shp")
    intersect_buffer = os.path.join(tempdir,"intersect_buffer.shp")
    seg_line_mer = os.path.join(tempdir,"seg_line_mer.shp")
    # simplify_seg_line = os.path.join(tempdir,"sim_seg_line.shp")
    shortestpath = os.path.join(tempdir,"shortestpath.shp")
    bufferline = os.path.join(tempdir,"cutline_buffer.shp")
    mosaic_mask_clip = os.path.join(tempdir,"mosaic_mask1.shp")
    mosaic_mask_true = os.path.join(tempdir,"mosaic_mask_true.shp")
    mask_temp = os.path.join(tempdir,"mask_temp.shp")
    mask_raster = os.path.join(tempdir,"mask_raster.tif")
    buffer_line = os.path.join(tempdir,"buffer_line.shp")
   
    
    # processing pipeline
    raster_binary(input1,out_raster1)
    raster_binary(input2,out_raster2)

    PolygonizeTheRaster_bina(out_raster1,outputfile1)
    PolygonizeTheRaster_bina(out_raster2,outputfile2)
    print("Binarize done")
    SelectByAttribute(outputfile1, outShp1)
    SelectByAttribute(outputfile2, outShp2)

    get_intersect_shp(outShp1 , outShp2, outshp_intersect)
    simplifyshp(outshp_intersect, inter_sim_shp)
    print("Simplify done")
    clip_raster_from_intersect(input1, inter_sim_shp, output_raster)
    resample_for_seg(output_raster, resample_raster)
    start2 = time.time()
    print("Start segment")
    segementation_img(resample_raster, seg_raster)
    end2 = time.time()
    print('seg time cost',end2-start2,'s')
    print("Segment done")
    tolerance = PolygonizeTheRaster(seg_raster,seg_poly)
    pol2line(seg_poly, seg_line)
    topo_simplify(seg_line, seg_line_sim, tolerance)
    
    merge_all_feature_in_one(seg_line_sim,seg_line_mer)
    get_intersect_shp(seg_line_mer , outshp_intersect, seg_line_mer_inter)
    print("Polygonized done")
    buffer(outshp_intersect, intersect_buffer, -0.00005)
    get_intersect_shp(seg_line_mer_inter , intersect_buffer, buffer_line)

    
    
    get_start_end_points(inter_sim_shp, out_shp_point)
    print("start to find shortest path")
    shortest_path_dijsktra(buffer_line, out_shp_point, shortestpath)
    print("Find shortest path done")
    buffer(shortestpath, bufferline, 0.000000001)
    get_differ_shp(outShp1, bufferline, mosaic_mask_clip)
    explord(mosaic_mask_clip, mosaic_mask_true, mask_temp)
    RasterMosaic(input1, input2, mask_raster, output, mosaic_mask_true)
    print("Mosaic done")
    time_end=time.time()
    print('time cost',time_end-time_start,'s')
    #shutil.rmtree(tempdir) 
    buildpyramid(output)  # `outputfile` was undefined here; the mosaic is written to `output`

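To make debugging and inspection easier, every intermediate result is written to its own file here. GDAL can also keep data in memory instead (for example with the MEM driver or the /vsimem/ virtual file system), and the code will be rewritten that way once the project is closer to complete.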
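Until then, one way to keep the convenience of on-disk intermediate files while debugging, without leaving the `temp` folder behind afterwards, is a small context manager. This is only a sketch using the standard library: `intermediate_dir` and its `keep` flag are hypothetical names, not part of the original script.

```python
import os
import shutil
import tempfile
from contextlib import contextmanager


@contextmanager
def intermediate_dir(parent, keep=False):
    """Yield a fresh scratch directory under `parent` for intermediate files.

    keep=True leaves everything on disk for inspection (debugging);
    keep=False deletes the directory on exit, even after an exception.
    """
    os.makedirs(parent, exist_ok=True)  # the parent folder may not exist yet
    path = tempfile.mkdtemp(prefix="seamline_", dir=parent)
    try:
        yield path
    finally:
        if not keep:
            shutil.rmtree(path, ignore_errors=True)
```

In `main`, the `os.path.join(tempdir, ...)` paths and the processing calls would move inside `with intermediate_dir(os.path.dirname(output), keep=True) as tempdir:`, switching to `keep=False` once the pipeline is trusted. This takes over the role of the commented-out `shutil.rmtree(tempdir)` line.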

Raster Binarization

def raster_binary(input_raster,out_raster):
    dataset = gdal.Open(input_raster)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    # re0 = im_data.transpose((2, 1, 0)) 
    ret, border0 = cv2.threshold(im_data[0], 0, 1, cv2.THRESH_BINARY)
    # border0 = border0.transpose((2, 1, 0)) 
    write_img(out_raster, im_proj, im_geotrans, border0) 
    del dataset

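This function targets multi-band rasters: it uses OpenCV to binarize the first band and thereby extract the valid extent of the raster. With `cv2.THRESH_BINARY`, a threshold of 0 and a maximum value of 1, every pixel greater than 0 becomes 1 and every other pixel becomes 0.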

[Figure: pixel values in the gray (valid) area and in the black (background) area]

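After binarization the raster contains only 0 and 1, which makes it easy to carry the value over to the polygons during raster-to-vector conversion.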
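`raster_binary` thresholds only the first band, so a pixel that is 0 in band 1 but non-zero in another band is treated as background. A minimal NumPy sketch of the same idea that checks every band is below. `valid_mask` is a hypothetical helper, and it assumes 0 is the fill value, as in the orthorectified images described above. For non-negative single-band data it gives the same result as `cv2.threshold(band, 0, 1, cv2.THRESH_BINARY)`.

```python
import numpy as np


def valid_mask(im_data, nodata=0):
    """Return a uint8 mask: 1 where the pixel holds data, 0 for background.

    im_data is shaped like gdal's ReadAsArray output: (rows, cols) for one
    band, (bands, rows, cols) for several. A pixel counts as valid if ANY
    band differs from `nodata`.
    """
    arr = np.asarray(im_data)
    if arr.ndim == 2:
        arr = arr[np.newaxis, ...]  # treat a single band as (1, rows, cols)
    return np.any(arr != nodata, axis=0).astype(np.uint8)
```

Because the result is `uint8`, the script's `write_img` would store it as a Byte raster, just like `border0`.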

Converting the Binary Raster to Vector

def PolygonizeTheRaster_bina(inputfile,outputfile):
    dataset = gdal.Open(inputfile, gdal.GA_ReadOnly)
    srcband=dataset.GetRasterBand(1)
    im_proj = dataset.GetProjection()
    prj = osr.SpatialReference() 
    prj.ImportFromWkt(im_proj)
    drv = ogr.GetDriverByName('ESRI Shapefile')
    dst_ds = drv.CreateDataSource(outputfile)
    dst_layername = 'out'
    dst_layer = dst_ds.CreateLayer(dst_layername, srs=prj)
    dst_fieldname = 'DN'
    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
    dst_layer.CreateField(fd)
    dst_field = 0
    gdal.Polygonize(srcband, None, dst_layer, dst_field) 

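This step is built on `gdal.Polygonize`, with `osr` supplying the spatial reference and `ogr` the output shapefile. A side note for those who know GDAL: GDAL and its `osr` and `ogr` modules have for some time been maintained under the OSGeo umbrella, whose open source foundation sustains their development, and in Python they are imported from the `osgeo` package. The binary pixel value is written to each polygon's `DN` field, so the polygons with value 1 can be selected afterwards.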

[Figure: raster-to-vector result]

Selecting Polygons by Attribute


def SelectByAttribute(InShp, outShp):
    open_parks = ogr.Open(InShp)
    layer_park = open_parks.GetLayer(0)
    # DN is an integer field, so compare it with a number, not the string '1'
    layer_park.SetAttributeFilter("DN = 1")
    driver = ogr.GetDriverByName("ESRI Shapefile")
    if os.path.exists(outShp):
        driver.DeleteDataSource(outShp)
    dataset = driver.CreateDataSource(outShp)
    # the rest of the pipeline assumes geographic coordinates (EPSG:4326)
    spatialref_new = osr.SpatialReference()
    spatialref_new.ImportFromEPSG(4326)
    new_layer = dataset.CreateLayer("selected", geom_type=ogr.wkbPolygon, srs=spatialref_new)
    layer_defn = new_layer.GetLayerDefn()
    for feat in layer_park:  # iterates only over features that pass the filter
        out_feat = ogr.Feature(layer_defn)
        out_feat.SetGeometry(feat.GetGeometryRef())
        new_layer.CreateFeature(out_feat)
        out_feat = None
    dataset = None  # close and flush to disk

Using ogr, the key call `SetAttributeFilter` selects the polygons whose DN value equals 1, and they are saved to a new shapefile.

[Figure: polygons selected by attribute]

Computing the Seamline Processing Region

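This step intersects the valid-area shapefiles. After both images to be mosaicked have gone through the functions above, geopandas computes the intersection of their valid areas, which is the overlap region of the two images.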

def get_intersect_shp(out_shp_ring1 , out_shp_ring2, outshp_intersect):
    gdf_left = gpd.read_file(out_shp_ring1)
    gdf_right = gpd.read_file(out_shp_ring2)
    gdf_intersect = gpd.overlay(gdf_left,gdf_right,'intersection',keep_geom_type=True)
    gdf_intersect.to_file(outshp_intersect)
    

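After the intersection, part of the polygon boundary is jagged, because raster-to-vector conversion does not smooth the staircase edges that follow pixel boundaries. Here is the simplification function: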

def simplifyshp(input, output):
    gdf_main = gpd.read_file(input)
    # tolerance is in CRS units: degrees for EPSG:4326, so 0.001 is roughly 111 m north-south
    simp = gdf_main.simplify(tolerance=0.001, preserve_topology=True)
    simp.to_file(output)

[Figure: seamline processing region]
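The `tolerance` in `simplifyshp` is easiest to understand through the Douglas-Peucker algorithm, which shapely uses when `preserve_topology=False`. The `preserve_topology=True` mode used above is a slower variant that additionally avoids producing invalid geometry. The pure-Python sketch below (hypothetical function names, planar coordinates) shows the core rule: a vertex is dropped when it lies within `tolerance` of the simplified line.

```python
import math


def _point_segment_distance(p, a, b):
    """Distance from point p to segment ab (projection clamped to the segment)."""
    (px, py), (ax, ay), (bx, by) = p, a, b
    dx, dy = bx - ax, by - ay
    if dx == 0 and dy == 0:
        return math.hypot(px - ax, py - ay)
    t = ((px - ax) * dx + (py - ay) * dy) / (dx * dx + dy * dy)
    t = max(0.0, min(1.0, t))
    return math.hypot(px - (ax + t * dx), py - (ay + t * dy))


def douglas_peucker(points, tolerance):
    """Keep only vertices that deviate more than `tolerance` from the chord."""
    if len(points) < 3:
        return list(points)
    start, end = points[0], points[-1]
    # find the interior vertex farthest from the chord start-end
    index, dmax = 0, -1.0
    for i in range(1, len(points) - 1):
        d = _point_segment_distance(points[i], start, end)
        if d > dmax:
            index, dmax = i, d
    if dmax <= tolerance:
        return [start, end]  # every interior vertex is close enough: drop them all
    # otherwise keep the farthest vertex and simplify both halves recursively
    left = douglas_peucker(points[: index + 1], tolerance)
    right = douglas_peucker(points[index:], tolerance)
    return left[:-1] + right  # the split vertex appears in both halves
```

A pixel staircase such as `[(0, 0), (1, 0), (1, 1), (2, 1), (2, 2)]` collapses to its two end points once the tolerance exceeds the step offset of about 0.71 units.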

Clipping the Original Raster to the Valid Overlap Region

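def clip_raster_from_intersect(input_main_raster, input_shp, output_raster):
    input_raster = gdal.Open(input_main_raster)
    # the original cutlineWhere="FIELD = 'whatever'" was a placeholder; the overlap
    # shapefile has no such field, so the whole cutline layer is used instead
    ds = gdal.Warp(output_raster,
                   input_raster,
                   format='GTiff',
                   cutlineDSName=input_shp,  # overlap polygon from simplifyshp
                   cropToCutline=True,       # shrink the output extent to the overlap
                   dstNodata=0)              # 0 fits unsigned imagery; -9999 does not
    ds = None  # close and flush to disk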

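This uses `gdal.Warp` to clip the original raster to the overlap of the two images and write it out as a GeoTIFF for the superpixel segmentation that follows.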

Superpixel Segmentation

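SLIC and quickshift take a very long time on high-resolution imagery, so the image is first downsampled and the segmentation is run on the lower-resolution copy. In my test this cut the run time from 500 s to 90 s. The resampling function: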

def resample_for_seg(input_Dir, output_dir):
    dataset = gdal.Open(input_Dir, gdal.GA_ReadOnly)
    ds_trans = dataset.GetGeoTransform()
    res = ds_trans[1] * 2  # double the pixel size, i.e. halve the resolution
    gdal.Translate(output_dir, input_Dir, xRes=res, yRes=res, resampleAlg="bilinear", format="GTiff")
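As a rough check on the speed-up: SLIC's cost grows roughly linearly with the number of pixels, and doubling the pixel size cuts the pixel count to a quarter, which is of the same order as the drop from 500 s to 90 s reported above. The NumPy sketch below shows the effect of a factor-2 reduction. It uses simple block averaging (a box filter), not the bilinear resampling that `resample_for_seg` requests from `gdal.Translate`, and `downsample_mean` is a hypothetical helper for illustration only.

```python
import numpy as np


def downsample_mean(band, factor=2):
    """Shrink a 2-D array by averaging non-overlapping factor x factor blocks.

    Trailing rows/columns that do not fill a whole block are trimmed.
    The pixel count drops by factor**2.
    """
    rows = band.shape[0] // factor * factor
    cols = band.shape[1] // factor * factor
    trimmed = band[:rows, :cols].astype(np.float64)
    # axes 1 and 3 index the pixels inside each block; average them away
    return trimmed.reshape(rows // factor, factor, cols // factor, factor).mean(axis=(1, 3))
```

A 4 x 4 array becomes 2 x 2, so a segmentation pass has four times fewer pixels to cluster.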
    

Next comes the core segmentation function.


def segementation_img(input_raster, output_raster):
    dataset = gdal.Open(input_raster)
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray().astype(int)  # (rows, cols) or (bands, rows, cols)
    if im_data.ndim == 3 and im_data.shape[0] >= 3:
        # first three bands, moved to skimage's (rows, cols, channels) layout
        image, channel_axis = np.moveaxis(im_data[:3], 0, -1), -1
        mask_arra = image[..., 0]
    else:
        # one or two bands: segment the first band as a grayscale image
        # (the old `temp, mask_arra = im_data.transpose((1, 0))` unpacking failed here)
        image = im_data if im_data.ndim == 2 else im_data[0]
        channel_axis, mask_arra = None, image
    mask = create_mask(mask_arra)  # restrict superpixels to the valid (non-zero) area
    label = slic(image, n_segments=2000, compactness=10, mask=mask,
                 channel_axis=channel_axis)
    # label = quickshift(image, ratio=1.0, kernel_size=5)  # slower alternative
    write_img(output_raster, im_proj, im_geotrans, label)

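SLIC is used here with fairly standard parameters, and I did not experiment much, because the segmentation only serves to produce a network of candidate paths. Very detailed segments bring little benefit, while overly regular, blocky segments would be no better than a Voronoi diagram, so a middle-of-the-road setting was chosen.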

[Figure: raw segmentation result]
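To relate `n_segments=2000` to a physical size: SLIC places its initial cluster centres on a roughly regular grid with a spacing of about S = sqrt(N / K) pixels for an image of N pixels and K segments, so the same `n_segments` gives larger superpixels on a larger overlap. The helpers below are hypothetical and only approximate, since scikit-image derives its actual grid step per axis and the mask may also affect where seeds are placed.

```python
import math


def slic_grid_spacing(width, height, n_segments):
    """Approximate superpixel edge length in pixels: S = sqrt(N / K)."""
    return math.sqrt(width * height / n_segments)


def slic_n_segments(width, height, superpixel_px):
    """Approximate n_segments for a desired superpixel edge length: K = N / S**2."""
    if superpixel_px <= 0:
        raise ValueError("superpixel_px must be positive")
    return max(1, round(width * height / superpixel_px ** 2))
```

For a hypothetical 4000 x 3000 overlap at the downsampled resolution, `n_segments=2000` gives superpixels roughly 77 pixels across, or about 155 pixels of the original image after the factor-2 downsampling.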

Processing the Segmentation Result

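The segmented raster is then converted to vector lines. This was a sticking point: GDAL's raster-to-vector conversion (`gdal.Polygonize`) only produces polygons, not lines, so the raster is first converted to polygons and then the polygons to lines: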

def PolygonizeTheRaster(inputfile,outputfile):
    dataset = gdal.Open(inputfile, gdal.GA_ReadOnly)
    srcband=dataset.GetRasterBand(1)
    im_proj = dataset.GetProjection()
    im_trans = dataset.GetGeoTransform()
    tolerance = im_trans[1]*5  # simplification tolerance: five pixels, in map units
    prj = osr.SpatialReference() 
    prj.ImportFromWkt(im_proj)
    drv = ogr.GetDriverByName('ESRI Shapefile')
    dst_ds = drv.CreateDataSource(outputfile)
    dst_layername = 'out'
    dst_layer = dst_ds.CreateLayer(dst_layername, srs=prj)
    dst_fieldname = 'DN'
    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
    dst_layer.CreateField(fd)
    dst_field = 0
    gdal.Polygonize(srcband, None, dst_layer, dst_field) 
    return tolerance

def pol2line(polyfn, linefn):
    driver = ogr.GetDriverByName('ESRI Shapefile')
    polyds = ogr.Open(polyfn, 0)
    polyLayer = polyds.GetLayer()
    spatialref = polyLayer.GetSpatialRef()
    if os.path.exists(linefn):
        driver.DeleteDataSource(linefn)
    lineds =driver.CreateDataSource(linefn)
    linelayer = lineds.CreateLayer(linefn, srs=spatialref, geom_type=ogr.wkbLineString)
    featuredefn = linelayer.GetLayerDefn()
    for feat in polyLayer:
        geom = feat.GetGeometryRef()
        ring = geom.GetGeometryRef(0)
        outfeature = ogr.Feature(featuredefn)
        outfeature.SetGeometry(ring)
        linelayer.CreateFeature(outfeature)
        outfeature = None

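This is fairly conventional code; every step writes a file so that problems can be checked and analysed. After the segmentation lines are written, they are simplified once more, again to remove the staircase artifacts of GDAL's raster-to-vector conversion. This time a different method is used: topojson simplifies each boundary shared by two neighbouring superpixels once, as a single arc, so the neighbours stay aligned instead of each copy of the boundary drifting apart: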

def topo_simplify(input, output, tolerance):
    gdf = gpd.read_file(input)
    topo = tp.Topology(gdf, prequantize=False)
    gdf = topo.toposimplify(tolerance).to_gdf()
    gdf.to_file(output, driver="ESRI Shapefile")

[Figure: simplified vs. unsimplified lines]

After segmentation and conversion to lines, each segment is a separate feature, so all features need to be merged into one:

def merge_all_feature_in_one(input, output):
    gdf = gpd.read_file(input)
    geom = gdf['geometry']
    new_geom = gpd.tools.collect(geom)
    df = {'id': [0], 'geometry': [new_geom]}
    new_gdf = gpd.GeoDataFrame(df, crs="EPSG:4326")
    new_gdf.to_file(output)

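In testing, the merged result still contained stray line fragments and occasionally other problems. To be safe, the problem is approached from the other direction: intersecting the segmentation lines with the overlap region yields the complete line layer, which contains only a single feature: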

def get_intersect_shp(out_shp_ring1 , out_shp_ring2, outshp_intersect):
    gdf_left = gpd.read_file(out_shp_ring1)
    gdf_right = gpd.read_file(out_shp_ring2)
    gdf_intersect = gpd.overlay(gdf_left,gdf_right,'intersection',keep_geom_type=True)
    gdf_intersect.to_file(outshp_intersect)

Getting the Seamline Start and End Points

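The idea is to use two opposite corners of the overlap polygon's bounding box as the start and end points. Their coordinates are easy to get from vector libraries such as ogr or geopandas; the code below reads them with ogr's `GetExtent`: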



def get_start_end_points(intersect_shp, out_shp_point):
    inDriver = ogr.GetDriverByName("ESRI Shapefile")
    inDataSource = inDriver.Open(intersect_shp, 0)
    inLayer = inDataSource.GetLayer()
    extent = inLayer.GetExtent()
    elon = abs( extent[0] - extent[1] )
    elat = abs( extent[2] - extent[3] )

    if elat> elon:
        start = geometry.Point(extent[0],extent[2])
        end = geometry.Point(extent[1], extent[3])
    else:
        start = geometry.Point(extent[1], extent[2])
        end = geometry.Point(extent[0], extent[3])
        
    pointshp = gpd.GeoSeries([start, end],               
                                crs='EPSG:4326', 
                                index=['0', '1']
                                )
    pointshp.to_file(out_shp_point,driver='ESRI Shapefile',encoding='utf-8')
    inDataSource = None
    out_shp_point = None
    

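[Figure: start and end points]

Some readers may wonder why the corner points are not on the edge of the image. In fact, the full extent of the image looks like this:

[Figure: full image extent, including the zero-filled background]

The background only appears to be missing because ArcMap's "ignore background value" display setting hides it.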
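The corner rule in `get_start_end_points` can be separated from the file I/O, which makes it easy to test: the seam runs along the diagonal that spans the longer side of the overlap. A sketch, assuming the extent tuple follows OGR's `Layer.GetExtent()` order `(minx, maxx, miny, maxy)`; `seam_endpoints` is a hypothetical name.

```python
def seam_endpoints(extent):
    """Pick diagonal corners of the overlap's bounding box as seam start/end.

    Mirrors get_start_end_points: a tall overlap runs lower-left -> upper-right,
    a wide one runs lower-right -> upper-left.
    """
    minx, maxx, miny, maxy = extent
    width, height = abs(maxx - minx), abs(maxy - miny)
    if height > width:
        return (minx, miny), (maxx, maxy)
    return (maxx, miny), (minx, maxy)
```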

Optimal Path Algorithm

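An earlier post covered this as well. It is based on networkx, which offers A*, Dijkstra, and Bellman-Ford; A* is used here. Because no heuristic is passed to `nx.astar_path`, it behaves exactly like Dijkstra, which matches the function's name: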


def shortest_path_dijsktra(input_road_shp, input_point_shp, outpath):
    # one undirected graph built from consecutive vertices of every line;
    # the old list(G.to_undirected(c) for c in strongly_connected_components(G))
    # made a full copy of the graph per component (c was silently taken as `reciprocal`)
    G = nx.Graph()
    r1 = shapefile.Reader(input_road_shp)
    for s in r1.shapes():
        for p1, p2 in pairwise(s.points):
            G.add_edge(tuple(p1), tuple(p2))
    # keep the largest connected piece of the line network
    sg = G.subgraph(max(nx.connected_components(G), key=len)).copy()
    r2 = shapefile.Reader(input_point_shp)
    start = r2.shape(0).points[0]
    end = r2.shape(1).points[0]
    # edge weight: great-circle length in km (coordinates are lon/lat)
    for n0, n1 in sg.edges():
        sg.edges[n0, n1]["dist"] = haversine(n0, n1)

    # snap the start/end points to the nearest graph nodes
    nn_start = min(sg.nodes(), key=lambda n: haversine(start, n))
    nn_end = min(sg.nodes(), key=lambda n: haversine(end, n))
    # without a heuristic, A* explores the graph exactly like Dijkstra
    path = nx.astar_path(sg, source=nn_start, target=nn_end, weight="dist")

    line = ogr.Geometry(ogr.wkbLineString)
    line.AddPoint_2D(start[0], start[1])
    for point in path:
        line.AddPoint_2D(point[0], point[1])
    line.AddPoint_2D(end[0], end[1])

    driver = ogr.GetDriverByName("ESRI Shapefile")
    if os.path.exists(outpath):
        driver.DeleteDataSource(outpath)
    data_source = driver.CreateDataSource(outpath)
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)
    layer = data_source.CreateLayer("path", srs, ogr.wkbLineString)
    field_name = ogr.FieldDefn("Name", ogr.OFTString)
    field_name.SetWidth(14)
    layer.CreateField(field_name)
    feature = ogr.Feature(layer.GetLayerDefn())
    feature.SetField("Name", "path")
    feature.SetGeometry(line)
    layer.CreateFeature(feature)
    feature = None
    data_source = None  # close and flush to disk

[Figure: optimal path result]
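Since the path search runs without a heuristic, A* gives no speed advantage here. networkx's `astar_path` accepts a `heuristic(u, target)` callable, and for edges weighted by haversine length the great-circle distance to the target is a natural choice: it can never exceed the length of any path along the edges, so A* stays optimal while usually expanding fewer nodes. The self-contained sketch below implements that idea with `heapq` rather than networkx; the function names and the adjacency-dict input are assumptions for illustration.

```python
import heapq
import math

EARTH_RADIUS_KM = 6371.0  # same radius as the script's haversine helper


def haversine(a, b):
    """Great-circle distance in km between (lon, lat) points given in degrees."""
    lon1, lat1 = map(math.radians, a)
    lon2, lat2 = map(math.radians, b)
    h = (math.sin((lat2 - lat1) / 2) ** 2
         + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2)
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(h))


def astar(adjacency, source, target):
    """A* over an undirected graph whose nodes are (lon, lat) tuples.

    adjacency maps node -> iterable of neighbour nodes; edge cost is the
    haversine length, and the heuristic is the straight great-circle
    distance to the target. Returns the node list, or None if unreachable.
    """
    open_heap = [(haversine(source, target), 0.0, source)]  # (f = g + h, g, node)
    best = {source: 0.0}
    parent = {source: None}
    while open_heap:
        _, g, node = heapq.heappop(open_heap)
        if node == target:
            path = []
            while node is not None:
                path.append(node)
                node = parent[node]
            return path[::-1]
        if g > best[node]:
            continue  # stale heap entry: a shorter route was found later
        for nxt in adjacency[node]:
            cost = g + haversine(node, nxt)
            if cost < best.get(nxt, float("inf")):
                best[nxt] = cost
                parent[nxt] = node
                heapq.heappush(open_heap, (cost + haversine(nxt, target), cost, nxt))
    return None
```

With networkx itself, the equivalent change would be `nx.astar_path(sg, nn_start, nn_end, heuristic=haversine, weight="dist")`, reusing the script's own `haversine(n0, n1)` helper.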

Splitting a Polygon with a Line

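This turned out to be harder than expected, and few bloggers seem to write about it, so here is the idea. The goal is to split the polygon of the raster's extent with the optimal-path polyline. First, buffer the line by a very small distance. Then take the difference between the raster-extent polygon and this buffer polygon, which is practically a line; the result is the split shapefile: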

 
def buffer(inShp, out, value):
    """
    :param inShp: path of the input vector file
    :param out: path of the output vector file
    :param value: buffer distance in layer units (degrees for EPSG:4326); negative shrinks polygons
    :return: None
    """
    ogr.UseExceptions()
    in_ds = ogr.Open(inShp)
    in_lyr = in_ds.GetLayer()
    driver = ogr.GetDriverByName('ESRI Shapefile')
    if Path(out).exists():
        driver.DeleteDataSource(out)
    out_ds = driver.CreateDataSource(out)
    out_lyr = out_ds.CreateLayer(out, in_lyr.GetSpatialRef(), ogr.wkbPolygon)
    def_feature = out_lyr.GetLayerDefn()
    for feature in in_lyr:
        geom = feature.GetGeometryRef()
        buffered = geom.Buffer(value)  # renamed so it no longer shadows this function
        out_feature = ogr.Feature(def_feature)
        out_feature.SetGeometry(buffered)
        out_lyr.CreateFeature(out_feature)
        out_feature = None
    out_ds = None  # close and flush to disk

def get_differ_shp(input_main,input_minor,output_differ):
    gdf_main = gpd.read_file(input_main)
    gdf_minor = gpd.read_file(input_minor)
    gdf_differ= gpd.overlay(gdf_main,gdf_minor,'difference')
    gdf_differ.to_file(output_differ)
    

A related function appeared earlier, but it is repeated here anyway (admittedly a bit of padding).
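The buffer distances used in `main` (`-0.00005` to shrink the overlap polygon and `0.000000001` to thicken the seamline) are in degrees, because the data is in EPSG:4326. A quick spherical-Earth conversion helps when choosing them. The sketch below uses the same 6371 km radius as the script's `haversine`; `degrees_to_metres` is a hypothetical helper.

```python
import math

EARTH_RADIUS_M = 6_371_000.0  # spherical Earth, same radius as the haversine helper


def degrees_to_metres(deg, latitude_deg=0.0, along="lat"):
    """Approximate ground length of an angular distance on a spherical Earth.

    along="lat": north-south distance (the same at every latitude).
    along="lon": east-west distance, which shrinks by cos(latitude).
    """
    metres = math.radians(deg) * EARTH_RADIUS_M
    if along == "lon":
        metres *= math.cos(math.radians(latitude_deg))
    elif along != "lat":
        raise ValueError("along must be 'lat' or 'lon'")
    return metres
```

By this estimate `-0.00005` degrees is about 5.6 m and `0.000000001` degrees about 0.1 mm north-south. The seamline buffer is therefore effectively a zero-width cut, while the inward buffer presumably keeps the overlap's own outline out of the path network so the path cannot simply run along it.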

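After these steps, the two halves of the split polygon still belong to one multi-part feature. That feature has to be "exploded" into single parts, which are then sorted by area; writing out the largest part gives the actual seamline clipping area for the raster.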

def explord(input,output , output_temp):
    gdf_main = gpd.read_file(input)
    explord = gpd.GeoDataFrame.explode(gdf_main)
    explord.to_file(output_temp, driver="ESRI Shapefile")

    output_temp_shp = gpd.read_file(output_temp)
    areas = output_temp_shp.area
    index = areas.sort_values(ascending = False).index.tolist()[0]
    row = output_temp_shp.loc[[index]]
    row.to_file(output, driver="ESRI Shapefile")
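The selection in `explord` can be sketched in plain Python with the shoelace formula applied to each part's exterior ring. The helper names are hypothetical, holes are ignored, and the area is planar in coordinate units (square degrees here), which, like geopandas' `.area` in the original, is adequate for ranking pieces that lie side by side.

```python
def ring_area(ring):
    """Planar area of a ring [(x, y), ...] via the shoelace formula.

    Works whether or not the first vertex is repeated at the end.
    """
    total = 0.0
    for (x1, y1), (x2, y2) in zip(ring, ring[1:] + ring[:1]):
        total += x1 * y2 - x2 * y1
    return abs(total) / 2.0


def largest_part(parts):
    """Index of the part (given as its exterior ring) with the largest area."""
    return max(range(len(parts)), key=lambda i: ring_area(parts[i]))
```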

Mosaicking

With the seamline clipping area in hand, the actual mosaicking follows:


def RasterMosaic(inputfilePath, referencefilefilePath, outputfile, outputfile2, cutline):
  
    inputrasfile1 = gdal.Open(inputfilePath, gdal.GA_ReadOnly)  # first image (clipped by the seamline)
    inputProj1 = inputrasfile1.GetProjection()
    inputrasfile2 = gdal.Open(referencefilefilePath, gdal.GA_ReadOnly)  # second image
    options=gdal.WarpOptions(srcSRS=inputProj1, 
                             dstSRS=inputProj1,
                             format='GTiff', 
                            resampleAlg=gdalconst.GRA_Bilinear,
                            srcNodata=0, 
                            dstNodata=0, 
                            cutlineDSName=cutline
                            )
    options2=gdal.WarpOptions(srcSRS=inputProj1, 
                            dstSRS=inputProj1,
                            format='GTiff', 
                            resampleAlg=gdalconst.GRA_Bilinear,
                            srcNodata=0, 
                            dstNodata=0
                            )
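    # step 1: clip image 1 with the seamline polygon
    gdal.Warp(outputfile, inputrasfile1, options=options)
    # step 2: image 2 first, clipped image 1 last, so image 1 is painted on top
    gdal.Warp(outputfile2, [inputrasfile2, outputfile], options=options2)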

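I looked into this a bit: GDAL does not seem able to clip and mosaic in a single Warp call, because a cutline passed to `gdal.Warp` applies to every input dataset rather than to just one of them. So it has to be done step by step, and both steps are essentially the Warp function; GDAL really is powerful.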
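The order of the source list in the second `gdal.Warp` call matters: where sources overlap, later inputs are written over earlier ones, except where they hold the nodata value. That is why the clipped image 1 is listed after image 2. Below is a simplified NumPy sketch of this painting rule, assuming all arrays already share one grid (no reprojection or resampling) and that 0 is nodata, as in `RasterMosaic`; `composite` is a hypothetical helper.

```python
import numpy as np


def composite(sources, nodata=0):
    """Paint same-shaped 2-D arrays in order; later sources win where they hold data."""
    out = np.full_like(sources[0], nodata)
    for src in sources:
        valid = src != nodata  # nodata pixels leave the earlier value untouched
        out[valid] = src[valid]
    return out
```

Swapping the order in this sketch lets the unclipped image cover the seamline entirely, which is the effect the list order in `RasterMosaic` avoids.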

Because both images are of good quality, no obvious seam can be seen in the result.

Post-processing

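This mainly means building pyramids (overviews). Since it is the final step, pyramids were not built for the intermediate files; they are built only for the finished mosaic, mainly to make viewing easier. The code is not difficult: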


def buildpyramid(input):
    Image = gdal.Open(input, 0)  # read-only: GeoTIFF overviews then go to an external .ovr file
    gdal.SetConfigOption('COMPRESS_OVERVIEW', 'DEFLATE')
    Image.BuildOverviews('NEAREST', [4, 8, 16, 32, 64, 128], gdal.TermProgress_nocb)
    del Image  # close the dataset (Python object and pointers)
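`buildpyramid` hard-codes the levels `[4, 8, 16, 32, 64, 128]`. An alternative is to derive them from the raster size, similar in spirit to how `gdaladdo` chooses levels when none are given. The helper below is a hypothetical sketch: it keeps doubling the factor while the overview's longer side would still be at least `min_size` pixels.

```python
def overview_levels(width, height, min_size=256):
    """Overview factors 2, 4, 8, ... whose overview keeps a side >= min_size.

    The longer side of the raster is divided by each factor; doubling stops
    as soon as the result would drop below min_size pixels.
    """
    levels = []
    factor = 2
    while max(width, height) / factor >= min_size:
        levels.append(factor)
        factor *= 2
    return levels
```

Inside `buildpyramid` this would read `Image.BuildOverviews('NEAREST', overview_levels(Image.RasterXSize, Image.RasterYSize), gdal.TermProgress_nocb)`.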

Summary

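This approach took quite a while to settle on. My first idea was OpenCV's seamless stitching, but it alters the true pixel values of the imagery, so a physical seamline had to be generated instead. The task also requires the seamline to avoid buildings, and I do not have a solution for that yet. My plan was to use the mask argument of the superpixel segmentation to exclude buildings: extract built-up areas with NDVI and Otsu thresholding and treat them as non-selectable regions. In practice, though, the optimal path still ran through all regions, whereas lines converted in ArcMap did follow the non-building areas. The cause may be a difference between the underlying raster-to-line algorithms of ArcGIS and GDAL, but I lack the capacity and time to dig into such low-level issues, so for now the segmentation simply follows building textures as closely as possible.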

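On the code side, the main change needed is to the vector and raster file output: constant reading and writing is very time-consuming, and in practice it should be replaced with virtual (in-memory) datasets, with all vector processing moved to a single library. This raises an open question: how do vectorization steps such as raster-to-polygon work on an in-memory raster? The existing tools all work on raster files on disk, and whether an in-memory raster would simply be held as an array still needs to be verified. (`gdal.Polygonize` takes a band object rather than a file path, which suggests a dataset from the MEM driver would work, but this has not been tested.)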

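geopandas is certainly powerful, but it has many dependencies, fiona among them. Strictly speaking, geopandas relies on fiona for reading and writing and on shapely for the geometric operations, and ogr alone covers both, so all of the vector steps above could in principle be done without geopandas. Whether the whole vector stack can be switched to ogr (or fiona plus shapely) remains to be verified.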

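The Voronoi-based path network mentioned earlier was meant as an alternative to SLIC. In theory a Voronoi diagram can be generated far faster than SLIC or quickshift, but the resulting mesh is too regular, so some texture-following paths are lost. For an engineering application that adds manual seamline editing afterwards, i.e. semi-automatic seamline generation, the Voronoi idea should still be workable. I have not had time to study Voronoi generation yet, and it needs further investigation.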

Appendix: Full Code

import os
import math
import shapefile
import tarfile
import time, shutil
from shapely import geometry
from itertools import tee
import networkx as nx
import cv2
from osgeo import ogr, osr, gdal, gdalconst
import numpy as np
from skimage.segmentation import slic
from skimage import morphology
import geopandas as gpd
from pathlib import Path
import topojson as tp


def create_mask(input_arra):
    mask = morphology.remove_small_holes(morphology.remove_small_objects(input_arra > 0, 500),500)
    mask = morphology.opening(mask, morphology.disk(3))
    return mask



def untar(fname, dirs):
    """Extract the tar archive `fname` into the directory `dirs`."""
    with tarfile.open(fname) as t:
        t.extractall(path=dirs)

def write_img(filename, im_proj, im_geotrans, im_data):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) >= 3:
        im_bands, im_height, im_width = im_data.shape  # GDAL array order: (bands, rows, cols)
    else:
        im_height, im_width = im_data.shape
        im_bands = 1
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])  # band i+1 <- array index i

    del dataset
    

def merge_mean_color(graph, src, dst):
    graph.nodes[dst]['total color'] += graph.nodes[src]['total color']
    graph.nodes[dst]['pixel count'] += graph.nodes[src]['pixel count']
    graph.nodes[dst]['mean color'] = (graph.nodes[dst]['total color'] /
                                    graph.nodes[dst]['pixel count'])


def segementation_img(input_raster, output_raster):
    dataset = gdal.Open(input_raster)
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray().astype(int)  # (rows, cols) or (bands, rows, cols)
    if im_data.ndim == 3 and im_data.shape[0] >= 3:
        # first three bands, moved to skimage's (rows, cols, channels) layout
        image, channel_axis = np.moveaxis(im_data[:3], 0, -1), -1
        mask_arra = image[..., 0]
    else:
        # one or two bands: segment the first band as a grayscale image
        image = im_data if im_data.ndim == 2 else im_data[0]
        channel_axis, mask_arra = None, image
    mask = create_mask(mask_arra)  # restrict superpixels to the valid (non-zero) area
    label = slic(image, n_segments=2000, compactness=10, mask=mask,
                 channel_axis=channel_axis)
    # label = quickshift(image, ratio=1.0, kernel_size=5)  # slower alternative
    write_img(output_raster, im_proj, im_geotrans, label)

def PolygonizeTheRaster_bina(inputfile,outputfile):
    dataset = gdal.Open(inputfile, gdal.GA_ReadOnly)
    srcband=dataset.GetRasterBand(1)
    im_proj = dataset.GetProjection()
    prj = osr.SpatialReference() 
    prj.ImportFromWkt(im_proj)
    drv = ogr.GetDriverByName('ESRI Shapefile')
    dst_ds = drv.CreateDataSource(outputfile)
    dst_layername = 'out'
    dst_layer = dst_ds.CreateLayer(dst_layername, srs=prj)
    dst_fieldname = 'DN'
    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
    dst_layer.CreateField(fd)
    dst_field = 0
    gdal.Polygonize(srcband, None, dst_layer, dst_field) 
    
    
def SelectByAttribute(InShp, outShp):
    open_parks = ogr.Open(InShp)
    layer_park = open_parks.GetLayer(0)
    # DN is an integer field, so compare it with a number, not the string '1'
    layer_park.SetAttributeFilter("DN = 1")
    driver = ogr.GetDriverByName("ESRI Shapefile")
    if os.path.exists(outShp):
        driver.DeleteDataSource(outShp)
    dataset = driver.CreateDataSource(outShp)
    # the rest of the pipeline assumes geographic coordinates (EPSG:4326)
    spatialref_new = osr.SpatialReference()
    spatialref_new.ImportFromEPSG(4326)
    new_layer = dataset.CreateLayer("selected", geom_type=ogr.wkbPolygon, srs=spatialref_new)
    layer_defn = new_layer.GetLayerDefn()
    for feat in layer_park:  # iterates only over features that pass the filter
        out_feat = ogr.Feature(layer_defn)
        out_feat.SetGeometry(feat.GetGeometryRef())
        new_layer.CreateFeature(out_feat)
        out_feat = None
    dataset = None  # close and flush to disk


def raster_binary(input_raster,out_raster):
    dataset = gdal.Open(input_raster)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    # re0 = im_data.transpose((2, 1, 0)) 
    ret, border0 = cv2.threshold(im_data[0], 0, 1, cv2.THRESH_BINARY)
    # border0 = border0.transpose((2, 1, 0)) 
    
    
    write_img(out_raster, im_proj, im_geotrans, border0) 
    del dataset
    

def PolygonizeTheRaster(inputfile,outputfile):
    dataset = gdal.Open(inputfile, gdal.GA_ReadOnly)
    srcband=dataset.GetRasterBand(1)
    im_proj = dataset.GetProjection()
    im_trans = dataset.GetGeoTransform()
    tolerance = im_trans[1]*5
    prj = osr.SpatialReference() 
    prj.ImportFromWkt(im_proj)
    drv = ogr.GetDriverByName('ESRI Shapefile')
    dst_ds = drv.CreateDataSource(outputfile)
    dst_layername = 'out'
    dst_layer = dst_ds.CreateLayer(dst_layername, srs=prj)
    dst_fieldname = 'DN'
    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
    dst_layer.CreateField(fd)
    dst_field = 0
    gdal.Polygonize(srcband, None, dst_layer, dst_field) 
    return tolerance
    

def pol2line(polyfn, linefn):
    """
        This function is used to make polygon convert to line
    :param polyfn: the path of input, the shapefile of polygon
    :param linefn: the path of output, the shapefile of line
    :return:
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    polyds = ogr.Open(polyfn, 0)
    polyLayer = polyds.GetLayer()
    spatialref = polyLayer.GetSpatialRef()
    if os.path.exists(linefn):
        driver.DeleteDataSource(linefn)
    lineds =driver.CreateDataSource(linefn)
    linelayer = lineds.CreateLayer(linefn, srs=spatialref, geom_type=ogr.wkbLineString)
    featuredefn = linelayer.GetLayerDefn()
    for feat in polyLayer:
        geom = feat.GetGeometryRef()
        ring = geom.GetGeometryRef(0)
        outfeature = ogr.Feature(featuredefn)
        outfeature.SetGeometry(ring)
        linelayer.CreateFeature(outfeature)
        outfeature = None





def waibao(inShapefile, outShapefile):
    inDriver = ogr.GetDriverByName("ESRI Shapefile")
    inDataSource = inDriver.Open(inShapefile, 0)
    inLayer = inDataSource.GetLayer()
    geomcol = ogr.Geometry(ogr.wkbGeometryCollection) 
    for feature in inLayer:
        geomcol.AddGeometry(feature.GetGeometryRef())
    convexhull = geomcol.ConvexHull()
    outDriver = ogr.GetDriverByName("ESRI Shapefile")
    if os.path.exists(outShapefile):
        outDriver.DeleteDataSource(outShapefile)
    outDataSource = outDriver.CreateDataSource(outShapefile)
    outLayer = outDataSource.CreateLayer("test_convexhull", geom_type=ogr.wkbPolygon)
    idField = ogr.FieldDefn("id", ogr.OFTInteger)
    outLayer.CreateField(idField)
    featureDefn = outLayer.GetLayerDefn()
    feature = ogr.Feature(featureDefn)
    feature.SetGeometry(convexhull)
    feature.SetField("id", 1)
    outLayer.CreateFeature(feature)
    feature = None
    inDataSource = None
    outDataSource = None
    


def get_differ_shp(input_main,input_minor,output_differ):
    gdf_main = gpd.read_file(input_main)
    gdf_minor = gpd.read_file(input_minor)
    gdf_differ= gpd.overlay(gdf_main,gdf_minor,'difference')
    gdf_differ.to_file(output_differ)
    


    
    
def get_intersect_shp(out_shp_ring1 , out_shp_ring2, outshp_intersect):
    gdf_left = gpd.read_file(out_shp_ring1)
    gdf_right = gpd.read_file(out_shp_ring2)
    gdf_intersect = gpd.overlay(gdf_left,gdf_right,'intersection',keep_geom_type=True)
    gdf_intersect.to_file(outshp_intersect)
    


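def clip_raster_from_intersect(input_main_raster, input_shp, output_raster):
    input_raster = gdal.Open(input_main_raster)
    # no cutlineWhere: the overlap shapefile has no FIELD attribute to filter on
    ds = gdal.Warp(output_raster,
                   input_raster,
                   format='GTiff',
                   cutlineDSName=input_shp,  # overlap polygon from simplifyshp
                   cropToCutline=True,       # shrink the output extent to the overlap
                   dstNodata=0)              # 0 fits unsigned imagery; -9999 does not
    ds = None  # close and flush to disk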
            


def get_start_end_points(intersect_shp, out_shp_point):
    inDriver = ogr.GetDriverByName("ESRI Shapefile")
    inDataSource = inDriver.Open(intersect_shp, 0)
    inLayer = inDataSource.GetLayer()
    extent = inLayer.GetExtent()
    elon = abs( extent[0] - extent[1] )
    elat = abs( extent[2] - extent[3] )

    if elat> elon:
        start = geometry.Point(extent[0],extent[2])
        end = geometry.Point(extent[1], extent[3])
    else:
        start = geometry.Point(extent[1], extent[2])
        end = geometry.Point(extent[0], extent[3])
        
    pointshp = gpd.GeoSeries([start, end],               
                                crs='EPSG:4326', 
                                index=['0', '1']
                                )
    pointshp.to_file(out_shp_point,driver='ESRI Shapefile',encoding='utf-8')
    inDataSource = None
    out_shp_point = None
    
    
def BetterMedianFilter(src_arr, k = 3, padding = None):
    # Replace a pixel by its k x k window median only when it is the local max or min;
    # pixels whose window would leave the image are copied unchanged
    height, width = src_arr.shape

    if not padding:
        edge = int((k-1)/2)
        if height - 1 - edge <= edge or width - 1 - edge <= edge:
            print("The parameter k is too large.")
            return None
        new_arr = np.zeros((height, width), dtype = "uint16")
        for i in range(height):
            for j in range(width):
                if i < edge or i >= height - edge or j < edge or j >= width - edge:
                    new_arr[i, j] = src_arr[i, j]
                else:
                    nm = src_arr[i - edge:i + edge + 1, j - edge:j + edge + 1]
                    local_max = np.max(nm)
                    local_min = np.min(nm)
                    if src_arr[i, j] == local_max or src_arr[i, j] == local_min:
                        new_arr[i, j] = np.median(nm)
                    else:
                        new_arr[i, j] = src_arr[i, j]
        return new_arr
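BetterMedianFilter is a switching-style median filter. A pixel is replaced by its window median only when it equals the window maximum or minimum, so likely impulse noise is removed and other pixels are left untouched. Below is a pure-Python sketch of the same rule on a list of rows. `switching_median` is a hypothetical helper, and border pixels are copied unchanged.

```python
from statistics import median


def switching_median(img, k=3):
    """Median-filter only the pixels that are the max or min of their k x k window."""
    edge = k // 2
    h, w = len(img), len(img[0])
    out = [row[:] for row in img]
    for i in range(edge, h - edge):
        for j in range(edge, w - edge):
            window = [img[r][c]
                      for r in range(i - edge, i + edge + 1)
                      for c in range(j - edge, j + edge + 1)]
            # Only local extremes are treated as suspected noise
            if img[i][j] in (max(window), min(window)):
                out[i][j] = median(window)
    return out
```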


    


def haversine(n0, n1):
    x1, y1 = n0
    x2, y2 = n1
    x_dist = math.radians(x1 - x2)
    y_dist = math.radians(y1 - y2)
    y1_rad = math.radians(y1)
    y2_rad = math.radians(y2)
    a = math.sin(y_dist/2)**2 + math.sin(x_dist/2)**2 \
    * math.cos(y1_rad) * math.cos(y2_rad)
    c = 2 * math.asin(math.sqrt(a))
    distance = c * 6371
    return distance
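haversine returns the great-circle distance in kilometres on a sphere of radius 6371 km. It unpacks each node as `(x, y)`, that is `(lon, lat)` in EPSG:4326, which is the order shapefile vertices use. A self-contained copy of the formula gives a quick sanity check: one degree of latitude should come out at roughly 111.19 km.

```python
import math


def haversine_km(p0, p1, radius_km=6371.0):
    """Great-circle distance between two (lon, lat) points given in degrees."""
    lon1, lat1 = p0
    lon2, lat2 = p1
    dlon = math.radians(lon2 - lon1)
    dlat = math.radians(lat2 - lat1)
    a = (math.sin(dlat / 2) ** 2
         + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon / 2) ** 2)
    return 2 * radius_km * math.asin(math.sqrt(a))


# One degree of latitude on a 6371 km sphere
print(round(haversine_km((0.0, 0.0), (0.0, 1.0)), 2))  # 111.19
```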

def pairwise(iterable):
    """Return consecutive pairs of an iterable:
    s -> (s0,s1), (s1,s2), (s2, s3), ..."""
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)
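pyshp stores a multi-part polyline as one flat `points` list plus a `parts` list of start indices. Running `pairwise` over `s.points` directly would therefore also create a segment joining the end of one part to the start of the next. The sketch below walks each part separately. `segments_by_part` is a hypothetical helper, and `parts` is assumed to follow pyshp's `Shape.parts` convention.

```python
from itertools import tee


def pairwise(iterable):
    """s -> (s0, s1), (s1, s2), ..."""
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)


def segments_by_part(points, parts):
    """Yield consecutive vertex pairs of a multi-part polyline without bridging parts.

    `parts` lists the index of the first vertex of each part.
    """
    bounds = list(parts) + [len(points)]
    for start, stop in pairwise(bounds):
        yield from pairwise(points[start:stop])
```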

def shortest_path_dijsktra(input_road_shp, input_point_shp, outpath):
    # Undirected graph: nodes are line vertices, edges are line segments
    G = nx.Graph()
    r1 = shapefile.Reader(input_road_shp)
    for s in r1.shapes():
        # s.parts holds the first-vertex index of each part; walk the parts separately
        # so the last vertex of one part is not joined to the first vertex of the next
        bounds = list(s.parts) + [len(s.points)]
        for i0, i1 in pairwise(bounds):
            for p1, p2 in pairwise(s.points[i0:i1]):
                G.add_edge(tuple(p1), tuple(p2))
    # Keep the largest connected component so that a path between the endpoints exists
    sg = G.subgraph(max(nx.connected_components(G), key=len)).copy()

    r2 = shapefile.Reader(input_point_shp)
    start = r2.shape(0).points[0]
    end = r2.shape(1).points[0]
    # Edge weight = great-circle length of the segment in km
    for n0, n1 in sg.edges():
        sg.edges[n0, n1]["dist"] = haversine(n0, n1)

    # Snap the start and end points to their nearest graph nodes
    nn_start = min(sg.nodes(), key=lambda n: haversine(start, n))
    nn_end = min(sg.nodes(), key=lambda n: haversine(end, n))
    path = nx.dijkstra_path(sg, source=nn_start, target=nn_end, weight="dist")

    # Seamline = start point + graph path + end point
    line = ogr.Geometry(ogr.wkbLineString)
    line.AddPoint_2D(start[0], start[1])
    for point in path:
        line.AddPoint_2D(point[0], point[1])
    line.AddPoint_2D(end[0], end[1])

    driver = ogr.GetDriverByName("ESRI Shapefile")
    if Path(outpath).exists():
        driver.DeleteDataSource(outpath)
    data_source = driver.CreateDataSource(outpath)
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)
    layer = data_source.CreateLayer("path", srs, ogr.wkbLineString)
    field_name = ogr.FieldDefn("Name", ogr.OFTString)
    field_name.SetWidth(14)
    layer.CreateField(field_name)
    feature = ogr.Feature(layer.GetLayerDefn())
    feature.SetField("Name", "path")
    feature.SetGeometry(line)
    layer.CreateFeature(feature)
    feature = None
    data_source = None
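shortest_path_dijsktra hands the search to networkx. The plain-Python Dijkstra below shows what that search does, using an adjacency list built from boundary segments. In the pipeline, the weight would be the haversine length of each segment, but any non-negative weight works. `build_adjacency` and `dijkstra_path` are hypothetical helpers for illustration, not replacements for the networkx calls.

```python
import heapq


def build_adjacency(segments, weight):
    """Undirected adjacency list {node: [(neighbour, w), ...]} from (p, q) segments."""
    adj = {}
    for p, q in segments:
        w = weight(p, q)
        adj.setdefault(p, []).append((q, w))
        adj.setdefault(q, []).append((p, w))
    return adj


def dijkstra_path(adj, source, target):
    """Return (cost, [nodes]) of the cheapest path, or (inf, []) if unreachable."""
    dist = {source: 0.0}
    prev = {}
    heap = [(0.0, source)]
    visited = set()
    while heap:
        d, u = heapq.heappop(heap)
        if u in visited:
            continue  # stale heap entry
        visited.add(u)
        if u == target:
            break
        for v, w in adj.get(u, []):
            nd = d + w
            if nd < dist.get(v, float("inf")):
                dist[v] = nd
                prev[v] = u
                heapq.heappush(heap, (nd, v))
    if target not in dist:
        return float("inf"), []
    # Walk the predecessors back from target to source
    path = [target]
    while path[-1] != source:
        path.append(prev[path[-1]])
    return dist[target], path[::-1]
```

Because the weights are segment lengths, the result is the shortest chain of superpixel boundaries between the two corner points. This is why the seamline follows texture edges instead of running straight.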

def simplifyshp(input, output): # tolerance = sqrt(2) x resolution (hard-coded to 0.001 below)
  
    gdf_main = gpd.read_file(input)            
    simp = gdf_main.simplify(tolerance=0.001, preserve_topology=True)
    simp.to_file(output)
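simplifyshp delegates to GeoPandas/Shapely. Shapely's `simplify` is based on the Douglas-Peucker algorithm and uses a topology-preserving variant when `preserve_topology=True`. The plain Douglas-Peucker sketch below works on one open polyline and shows what `tolerance` controls: vertices closer than `tolerance` to the chord between the kept endpoints are dropped. Unlike the `preserve_topology=True` call, it ignores topology. A tolerance of sqrt(2) times the pixel size, as the comment above suggests, equals one pixel diagonal. `douglas_peucker` is a hypothetical name.

```python
import math


def _point_segment_distance(p, a, b):
    """Distance from p to segment a-b (endpoint distance if a == b)."""
    ax, ay = a
    bx, by = b
    px, py = p
    dx, dy = bx - ax, by - ay
    seg_len2 = dx * dx + dy * dy
    if seg_len2 == 0:
        return math.hypot(px - ax, py - ay)
    # Project p onto the segment and clamp the parameter to [0, 1]
    t = max(0.0, min(1.0, ((px - ax) * dx + (py - ay) * dy) / seg_len2))
    return math.hypot(px - (ax + t * dx), py - (ay + t * dy))


def douglas_peucker(points, tolerance):
    """Simplify an open polyline, always keeping its first and last vertex."""
    if len(points) < 3:
        return list(points)
    # Find the vertex farthest from the chord between the endpoints
    index, max_dist = 0, -1.0
    for i in range(1, len(points) - 1):
        d = _point_segment_distance(points[i], points[0], points[-1])
        if d > max_dist:
            index, max_dist = i, d
    if max_dist <= tolerance:
        return [points[0], points[-1]]
    # Recurse on both halves and drop the duplicated split vertex
    left = douglas_peucker(points[:index + 1], tolerance)
    right = douglas_peucker(points[index:], tolerance)
    return left[:-1] + right
```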


def merge_all_feature_in_one(input, output):
    gdf = gpd.read_file(input)
    geom = gdf['geometry']
    new_geom = gpd.tools.collect(geom)
    df = {'id': [0], 'geometry': [new_geom]}
    new_gdf = gpd.GeoDataFrame(df, crs="EPSG:4326")
    new_gdf.to_file(output)

    



def get_non_building_field_raster(input,output):
    np.seterr(invalid='ignore')
    dataset = gdal.Open(input)
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    im_data = im_data.astype(np.float64)
    ndvi = (im_data[3] - im_data[2]) / (im_data[3] + im_data[2])
    slics =np.where( ndvi > -0.1, ndvi * 1, 0)
    write_img(output, im_proj, im_geotrans, slics) 
    del dataset
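get_non_building_field_raster computes NDVI = (NIR - Red) / (NIR + Red) from array indices 3 and 2. This assumes the bands are ordered Blue, Green, Red, NIR, which the code does not check. The `astype(np.float64)` cast matters because subtracting uint16 arrays directly would wrap around. `np.where` then keeps NDVI values above -0.1 and writes 0 elsewhere, including the NaN produced by 0/0 at nodata pixels. Below is a per-pixel sketch of the same rule with the zero-denominator case handled explicitly. `ndvi` and `ndvi_mask_value` are hypothetical helpers.

```python
def ndvi(nir, red):
    """NDVI = (NIR - Red) / (NIR + Red); None where both bands are 0 (nodata)."""
    total = nir + red
    if total == 0:
        return None
    return (nir - red) / total


def ndvi_mask_value(nir, red, threshold=-0.1):
    """Mirror np.where(ndvi > threshold, ndvi, 0) for a single pixel."""
    value = ndvi(nir, red)
    if value is None or value <= threshold:
        return 0.0
    return value
```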
    

def collect_non_builidng_shp(input, output):
    nonbuilding = gpd.read_file(input)
    selec =  nonbuilding[nonbuilding.DN == 1]
    selec.to_file(output, driver="ESRI Shapefile")

def RasterMosaic(inputfilePath, referencefilefilePath, outputfile, outputfile2, cutline):
  
    inputrasfile1 = gdal.Open(inputfilePath, gdal.GA_ReadOnly)  # first image
    inputProj1 = inputrasfile1.GetProjection()
    inputrasfile2 = gdal.Open(referencefilefilePath, gdal.GA_ReadOnly)  # second image
    # Clip image 1 with the mosaic mask polygon (cutline)
    options = gdal.WarpOptions(srcSRS=inputProj1,
                               dstSRS=inputProj1,
                               format='GTiff',
                               resampleAlg=gdalconst.GRA_Bilinear,
                               srcNodata=0,
                               dstNodata=0,
                               cutlineDSName=cutline)
    # Mosaic image 2 with the clipped image 1
    options2 = gdal.WarpOptions(srcSRS=inputProj1,
                                dstSRS=inputProj1,
                                format='GTiff',
                                resampleAlg=gdalconst.GRA_Bilinear,
                                srcNodata=0,
                                dstNodata=0)
    gdal.Warp(outputfile,inputrasfile1,options=options)
    gdal.Warp(outputfile2, [inputrasfile2,outputfile], options=options2)
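RasterMosaic first warps image 1 through the cutline polygon, so pixels outside the polygon become 0 (the nodata value). It then warps image 2 and the clipped image 1 into one output. When several sources are passed to gdal.Warp, later sources are written over earlier ones wherever they hold valid data. As a result, the clipped image 1 wins inside the cutline and image 2 fills the rest. The sketch below shows that compositing rule. It assumes both inputs are already on the same grid, whereas gdal.Warp also handles reprojection and resampling. `composite` is a hypothetical name.

```python
def composite(base, top, nodata=0):
    """Overlay `top` on `base` pixel by pixel: top wins wherever it is not nodata."""
    return [
        [t if t != nodata else b for b, t in zip(base_row, top_row)]
        for base_row, top_row in zip(base, top)
    ]


image2 = [[5, 5, 5], [5, 5, 5]]
clipped_image1 = [[9, 9, 0], [9, 0, 0]]  # 0 = outside the cutline
print(composite(image2, clipped_image1))  # [[9, 9, 5], [9, 5, 5]]
```

A side effect of `srcNodata=0` is that genuine zero-valued pixels in image 1 are also treated as gaps and filled from image 2.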


def buildpyramid(input):
    Image = gdal.Open(input, 0)  # 0 = read-only, 1 = read-write.
    gdal.SetConfigOption('COMPRESS_OVERVIEW', 'DEFLATE')
    Image.BuildOverviews('NEAREST', [4, 8, 16, 32, 64, 128], gdal.TermProgress_nocb)
    del Image  # close the dataset (Python object and pointers)
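buildpyramid adds overview levels at reduction factors 4 to 128 (factor 2 is skipped), using nearest-neighbour resampling and DEFLATE compression. Because the dataset is opened read-only, GDAL writes the overviews to an external `.ovr` file next to the GeoTIFF. The small helper below estimates each level's size. It assumes every level is the full size divided by the factor and rounded up. `overview_sizes` is a hypothetical name.

```python
def overview_sizes(width, height, factors=(4, 8, 16, 32, 64, 128)):
    """Estimated (width, height) of each overview level, using ceiling division."""
    return [(-(-width // f), -(-height // f)) for f in factors]


print(overview_sizes(10000, 7001, (4, 128)))  # [(2500, 1751), (79, 55)]
```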

    
def buffer(inShp, out, value):
    """
    :param inShp: path of the input vector file
    :param out: path of the output vector file
    :param value: buffer distance in the units of the input CRS (negative shrinks)
    :return:
    """
    ogr.UseExceptions()
    in_ds = ogr.Open(inShp)
    in_lyr = in_ds.GetLayer()
    driver = ogr.GetDriverByName('ESRI Shapefile')
    if Path(out).exists():
        driver.DeleteDataSource(out)
    out_ds = driver.CreateDataSource(out)
    out_lyr = out_ds.CreateLayer(Path(out).stem, in_lyr.GetSpatialRef(), ogr.wkbPolygon)
    def_feature = out_lyr.GetLayerDefn()
    for feature in in_lyr:
        geom = feature.GetGeometryRef()
        buffered = geom.Buffer(value)
        out_feature = ogr.Feature(def_feature)
        out_feature.SetGeometry(buffered)
        out_lyr.CreateFeature(out_feature)
        out_feature = None
    
    
    out_ds.FlushCache()
    del in_ds, out_ds
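Buffer distances are in the units of the layer's CRS. The pipeline assumes EPSG:4326, so the values passed in main are in degrees. The -0.00005 buffer shrinks the overlap polygon by roughly 5.6 m north-south, and less east-west away from the equator. The 0.000000001 buffer turns the seamline into a hair-thin polygon. The rough conversion helper below assumes a spherical Earth with the same 6371 km radius as `haversine`, which gives about 111.19 km per degree of latitude. `degrees_to_metres` is a hypothetical name.

```python
import math

EARTH_RADIUS_M = 6371000.0  # spherical approximation, same radius as haversine()


def degrees_to_metres(deg, latitude=0.0, along="lat"):
    """Approximate ground length of `deg` degrees.

    along="lat": north-south (independent of latitude);
    along="lon": east-west, shrinking with cos(latitude).
    """
    metres_per_degree = math.pi * EARTH_RADIUS_M / 180.0
    if along == "lon":
        metres_per_degree *= math.cos(math.radians(latitude))
    return deg * metres_per_degree


print(round(degrees_to_metres(0.00005), 2))  # 5.56
```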



def explord(input,output , output_temp):
    gdf_main = gpd.read_file(input)
    exploded = gdf_main.explode(index_parts=True)
    exploded.to_file(output_temp, driver="ESRI Shapefile")

    output_temp_shp = gpd.read_file(output_temp)
    areas = output_temp_shp.area
    index = areas.sort_values(ascending = False).index.tolist()[0]
    row = output_temp_shp.loc[[index]]
    row.to_file(output, driver="ESRI Shapefile")
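After the thin seamline buffer is subtracted from image 1's valid area, the difference can fall apart into several pieces. explord keeps only the piece with the largest area. GeoPandas computes that area in CRS units, which are square degrees here, and warns about the geographic CRS. Square degrees are still fine for ranking pieces. The sketch below does the same ranking with the shoelace formula, assuming simple rings without holes. `ring_area` and `largest_part` are hypothetical names.

```python
def ring_area(ring):
    """Planar area of a simple polygon via the shoelace formula (ring not closed)."""
    n = len(ring)
    s = 0.0
    for i in range(n):
        x1, y1 = ring[i]
        x2, y2 = ring[(i + 1) % n]
        s += x1 * y2 - x2 * y1
    return abs(s) / 2.0


def largest_part(parts):
    """Return the part with the largest area, as explord() keeps."""
    return max(parts, key=ring_area)
```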


def resample_for_seg(input_Dir, output_dir):
    dataset = gdal.Open(input_Dir, gdal.GA_ReadOnly)
    ds_trans = dataset.GetGeoTransform()
    res = ds_trans[1]*2
    gdal.Translate(output_dir, input_Dir, xRes=res, yRes=res, resampleAlg="bilinear", format="GTiff")
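resample_for_seg doubles the pixel size before segmentation, which quarters the number of pixels the superpixel step has to process. The arithmetic is sketched below for a north-up geotransform with zero rotation terms. Rounding to the nearest integer size is an assumption, and GDAL's exact rule may differ by a pixel at the edges. `coarser_grid` is a hypothetical name.

```python
def coarser_grid(geotransform, width, height, factor=2):
    """Geotransform and approximate size after multiplying the pixel size by `factor`.

    Assumes a north-up geotransform (rotation terms 0), as in res = ds_trans[1] * 2.
    """
    x0, dx, rx, y0, ry, dy = geotransform
    new_gt = (x0, dx * factor, rx, y0, ry, dy * factor)
    return new_gt, round(width / factor), round(height / factor)


gt, w, h = coarser_grid((100.0, 0.5, 0.0, 40.0, 0.0, -0.5), 1000, 800)
print(w * h, "pixels instead of", 1000 * 800)  # 200000 pixels instead of 800000
```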
    
    
def topo_simplify(input, output, tolerance):
    gdf = gpd.read_file(input)
    topo = tp.Topology(gdf, prequantize=False)
    gdf = topo.toposimplify(tolerance).to_gdf()
    gdf.to_file(output, driver="ESRI Shapefile")




def main(input1, input2, output):
    time_start=time.time()
    print("start deal")
    tempdir = os.path.join(os.path.dirname(output),"temp")
    os.makedirs(tempdir, exist_ok=True)
    
    out_raster1 = os.path.join(tempdir,"out_raster1.tif")
    out_raster2 = os.path.join(tempdir,"out_raster2.tif")
    outputfile1 = os.path.join(tempdir,"bina_shp1.shp")
    outputfile2 = os.path.join(tempdir,"bina_shp2.shp")
    outShp1 = os.path.join(tempdir,"select_shp1.shp")
    outShp2 = os.path.join(tempdir,"select_shp2.shp")
    outshp_intersect = os.path.join(tempdir,"intersect_shp.shp")
    inter_sim_shp = os.path.join(tempdir,"intersect_simply_shp.shp")
    output_raster = os.path.join(tempdir,"clip_interest_raster.tif")
    resample_raster = os.path.join(tempdir,"clip_resample_raster.tif")
    out_shp_point = os.path.join(tempdir,"start_end_point.shp")
    seg_raster = os.path.join(tempdir,"seg_raster.tif")
    seg_poly = os.path.join(tempdir,"seg_poly_shp.shp")
    seg_line = os.path.join(tempdir,"seg_line_shp.shp")
    seg_line_sim = os.path.join(tempdir,"seg_line_sim.shp")
    
    
    
    
    
    seg_line_mer_inter = os.path.join(tempdir,"seg_line_mer_inter_shp.shp")
    intersect_buffer = os.path.join(tempdir,"intersect_buffer.shp")
    seg_line_mer = os.path.join(tempdir,"seg_line_mer.shp")
    # simplify_seg_line = os.path.join(tempdir,"sim_seg_line.shp")
    shortestpath = os.path.join(tempdir,"shortestpath.shp")
    bufferline = os.path.join(tempdir,"cutline_buffer.shp")
    mosaic_mask_clip = os.path.join(tempdir,"mosaic_mask1.shp")
    mosaic_mask_true = os.path.join(tempdir,"mosaic_mask_true.shp")
    mask_temp = os.path.join(tempdir,"mask_temp.shp")
    mask_raster = os.path.join(tempdir,"mask_raster.tif")
    buffer_line = os.path.join(tempdir,"buffer_line.shp")
   
    
    
    raster_binary(input1,out_raster1)
    raster_binary(input2,out_raster2)

    PolygonizeTheRaster_bina(out_raster1,outputfile1)
    PolygonizeTheRaster_bina(out_raster2,outputfile2)
    print("Binarize done")
    SelectByAttribute(outputfile1, outShp1)
    SelectByAttribute(outputfile2, outShp2)

    get_intersect_shp(outShp1 , outShp2, outshp_intersect)
    simplifyshp(outshp_intersect, inter_sim_shp)
    print("Simplify done")
    clip_raster_from_intersect(input1, inter_sim_shp, output_raster)
    resample_for_seg(output_raster, resample_raster)
    start2 = time.time()
    print("Start segment")
    segementation_img(resample_raster, seg_raster)
    end2 = time.time()
    print('seg time cost',end2-start2,'s')
    print("Segment done")
    tolerance = PolygonizeTheRaster(seg_raster,seg_poly)
    pol2line(seg_poly, seg_line)
    topo_simplify(seg_line, seg_line_sim, tolerance)
    
    merge_all_feature_in_one(seg_line_sim,seg_line_mer)
    get_intersect_shp(seg_line_mer , outshp_intersect, seg_line_mer_inter)
    print("Polygonized done")
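    # Shrink the overlap polygon by 0.00005 (CRS units, degrees here) so that
    # segment boundaries lying on the overlap edge are cut away
    buffer(outshp_intersect, intersect_buffer, -0.00005)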
    get_intersect_shp(seg_line_mer_inter , intersect_buffer, buffer_line)

    
    
    get_start_end_points(inter_sim_shp, out_shp_point)
    print("start to find shortest path")
    shortest_path_dijsktra(buffer_line, out_shp_point, shortestpath)
    print("Find shortest path done")
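    # A near-zero buffer turns the seamline into a thin polygon that can be
    # subtracted from image 1's valid-area polygon
    buffer(shortestpath, bufferline, 0.000000001)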
    get_differ_shp(outShp1, bufferline, mosaic_mask_clip)
    explord(mosaic_mask_clip, mosaic_mask_true, mask_temp)
    RasterMosaic(input1, input2, mask_raster, output, mosaic_mask_true)
    print("Mosaic done")
    time_end=time.time()
    print('time cost',time_end-time_start,'s')
    #shutil.rmtree(tempdir) 
    buildpyramid(output)


if __name__ == "__main__":
        
    input1 = r"D:\Data\testdata\mosaic\orthx\orthx1.tif"
    input2 = r"D:\Data\testdata\mosaic\orthx\orthx2.tif"
    outputfile = r"D:\Data\testdata\mosaic\mosaic.tif"
    main(input1,input2,outputfile)
Author by Jerrychoices