首页 前端知识 语义分割数据集png和json相互转化

语义分割数据集png和json相互转化

2024-06-16 01:06:55 前端知识 前端哥 643 637 我要收藏

为了适配语义分割网络训练数据,我使用anylabeling进行图像标注,标注完之后的图像格式应该为png,而且每个点像素值0,1,2,3,4等对应了某一类。
在这里插入图片描述
我们希望标注好的图像格式为png,需要训练的原图像为jpg,二者文件名相同。
下面就是完成该操作。

json转换为png

先导入需要的包

import os
import numpy as np
import json
from tqdm import tqdm
import cv2

首先给出json转换为png的操作,

def json2png(json_root, class_root, name_classes):

    if not os.path.exists(json_root): 
        print("No such folder!")
        return 
    if not os.path.exists(class_root) and class_root != "":
        os.makedirs(class_root)
    
    print("Start conterting annotations...")
    for root,dirs,files in os.walk(json_root):
        for file in files:
            if not file.endswith(".json"):
                continue

            json_file = file

            # print(json_root+json_file)
            json_data = json.load(open(json_root+json_file,"r",encoding="utf-8"))
            
            # 生成空白图片
            W, H = json_data["imageWidth"], json_data["imageHeight"]
            res = np.zeros((H, W)).astype('uint8')
            
            
            for multi in json_data["shapes"]:
                if multi['label'] not in name_classes:
                    continue
                
                # 获取第几个编号
                fillColor = name_classes.index(multi['label'])
                # fillColor = 255
                pts = np.array(multi['points']).astype('int')
                cv2.fillPoly(res,[pts],fillColor)
            
            cv2.imwrite(class_root+get_root_file_name(file)+'.png', res)
            # Image.fromarray(res).save(class_root+get_root_file_name(file)+'.png')

其中 name_classes 表示语义分割的训练类别,比如:

name_classes = ["background","fire"]

而需要注意的是 name_classes 必须要添加背景 “background” .

png转换为json

为了完成png转换为json,这个相对会麻烦一点。

判断连通域的包含操作

采用点是否在多边形内部来判断连通域是否包含于另外一个连通域。因为提取好的连通域只可能存在相离和包含两种关系,不可能存在相交关系。因此只需要判断其中一个点是否包含于另外一个连通域即可判断连通域的包含关系。
采用射线交点法,首先,选择一个点P,该点可以是多边形外部的一个点。从点P向任意方向发射一条射线,与多边形的每条边进行求交。然后统计射线与多边形的边的交点个数。如果交点个数为奇数,则点P在多边形内部;如果交点个数为偶数,则点P在多边形外部。
使用python算法完成如下:

def point_in_polygon(point, polygon):
    x, y = point[0]
    n = len(polygon)
    inside = False

    p1x, p1y = polygon[0][0]
    for i in range(n + 1):
        p2x, p2y = polygon[i % n][0]
        if y > min(p1y, p2y):
            if y <= max(p1y, p2y):
                if x <= max(p1x, p2x):
                    if p1y != p2y:
                        xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
                    if p1x == p2x or x <= xinters:
                        inside = not inside
        p1x, p1y = p2x, p2y

    return inside

连通域提取配对

由于标注图像的某个类别区域可能不是单连通区域,因此采用裁剪法完成
在这里插入图片描述

def get_approx_countors(layer, length_p=0.002):
    img_bin, contours = cv2.findContours(layer, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)

    length_p=0.002

    cir_all = []
    approx_all = []
    cir_in_or_out = []

    for ss in img_bin:
        epsilon = length_p * cv2.arcLength(ss, True)
        approx = cv2.approxPolyDP(ss, epsilon, True)
        approx_all.append(approx.tolist())
        cir_all.append(ss.tolist())
        cir_in_or_out.append(-1)

    # 匹配轮廓
    for i in range(len(cir_all)):
        for j in range(len(cir_all)):
            if i == j:
                continue
            if point_in_polygon(cir_all[i][0], cir_all[j]):
                cir_in_or_out[i] = j
                break
    white_cir = []
    black_cir = []
    black_match = []
    for i in range(len(cir_in_or_out)):
        if cir_in_or_out[i] == -1: #白色
            white_cir.append(i)
        else:# 黑色
            black_cir.append(cir_in_or_out[i])
            black_match.append(i)
    cir_matched = []

    for i in range(len(white_cir)):
        cir_matched.append(approx_all[white_cir[i]])
        
    for i in range(len(black_cir)):
        index_1 = white_cir.index(black_cir[i])
        s1 = cir_matched[index_1]  + approx_all[black_match[i]] + [approx_all[black_match[i]][0]] + [cir_matched[index_1][-1]]
        cir_matched[index_1] = s1
        
    # print(cir_matched)
    for i in range(len(cir_matched)):
        s1 = cir_matched[i] + [cir_matched[i][0]]
        cir_matched[i] = s1

    return cir_matched

这个函数完成的是根据某一个二值化图像提取连通域,并配对好,输出每个连通域的点集。其中length_p表示多边形逼近的程度,该值越小表示采样点越多。

转换为json

批量转化的函数如下:

def png2json(png_root, json_root, name_classes, length_p=0.002):
    if not os.path.exists(png_root): 
        print("No such folder!")
        return 
    if not os.path.exists(json_root) and json_root != "":
        os.makedirs(json_root)

    print("Start conterting annotations...")
    for root,dirs,files in os.walk(png_root):
        for file in tqdm(files):
            if not file.endswith(".png"):
                continue
            
            img = cv2.imread(png_root+file)
            img = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
            H, W = img.shape

            countors_all = []
            labels_all = []
            for i in range(1, len(name_classes)):
                layer = np.where(img==i,1,0).astype('uint8')
                # print(layer)
                countors_l = get_approx_countors(layer, length_p=0.002)
                for j in range(len(countors_l)):
                    countors_all.append(countors_l[j])
                    labels_all.append(name_classes[i])

            with open(json_root+get_root_file_name(file)+".json", 'w', encoding='utf-8') as json_f:
                

                json_f.write("{\n")

                json_f.write("  \"version\": \"0.3.3\",\n")
                json_f.write("  \"flags\": {},\n")
                json_f.write("  \"shapes\": [\n")

            
                # 提取好的点集countors_all[i][j][0][(0~1)],第i个连通域第j个点的坐标(x, y)
                        
                for i in range(len(countors_all)):
                    json_f.write("    {\n")
                    json_f.write("      \"label\": \""+labels_all[i]+"\",\n")
                    json_f.write("      \"text\": \"\",\n")
                    json_f.write("      \"points\": [\n")
                    for j in range(len(countors_all[i])):
                        json_f.write("       [\n")
                        json_f.write("        "+str(countors_all[i][j][0][0])+",\n")
                        json_f.write("        "+str(countors_all[i][j][0][1])+"\n")
                        json_f.write("       ]")   
                        if j != len(countors_all[i])-1:
                            json_f.write(",")
                        json_f.write("\n")

                    json_f.write("      ],\n")
                    json_f.write("      \"shape_type\": \"polygon\",\n")
                    json_f.write("      \"flags\": {}\n")  
                    json_f.write("    }")
                    if i != len(countors_all)-1:
                        json_f.write(",")
                    json_f.write("\n")

                json_f.write("  ],\n")
                json_f.write("  \"imagePath\": \""+get_root_file_name(file)+".jpg\",\n")
                json_f.write("  \"imageData\": null,\n")
                json_f.write("  \"imageHeight\": "+str(H)+",\n")
                json_f.write("  \"imageWidth\": "+str(W)+"\n")
                json_f.write("}\n")

总结

全部代码如下:

import os
import numpy as np
import json
# from shutil import copyfile
from tqdm import tqdm
# from xml.etree.ElementTree import parse
# from PIL import Image
import cv2

def get_root_file_name(root1):

    s0, s1 = -1, -1
    for i in range(len(root1)):
        if root1[i] == "/":
            s0 = i
        if root1[i] == ".":
            s1 = i
    return root1[s0+1:s1]  

def point_in_polygon(point, polygon):
    x, y = point[0]
    n = len(polygon)
    inside = False

    p1x, p1y = polygon[0][0]
    for i in range(n + 1):
        p2x, p2y = polygon[i % n][0]
        if y > min(p1y, p2y):
            if y <= max(p1y, p2y):
                if x <= max(p1x, p2x):
                    if p1y != p2y:
                        xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
                    if p1x == p2x or x <= xinters:
                        inside = not inside
        p1x, p1y = p2x, p2y

    return inside

def get_approx_countors(layer, length_p=0.002):
    img_bin, contours = cv2.findContours(layer, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)

    length_p=0.002

    cir_all = []
    approx_all = []
    cir_in_or_out = []

    for ss in img_bin:
        epsilon = length_p * cv2.arcLength(ss, True)
        approx = cv2.approxPolyDP(ss, epsilon, True)
    
        approx_all.append(approx.tolist())
        # print(ss.tolist())
        cir_all.append(ss.tolist())
        cir_in_or_out.append(-1)

    # 匹配轮廓
    for i in range(len(cir_all)):
        for j in range(len(cir_all)):
            if i == j:
                continue
            if point_in_polygon(cir_all[i][0], cir_all[j]):
                cir_in_or_out[i] = j
                break
    # print(cir_in_or_out)
    white_cir = []
    black_cir = []
    black_match = []
    for i in range(len(cir_in_or_out)):
        if cir_in_or_out[i] == -1: #白色
            white_cir.append(i)
        else:# 黑色
            black_cir.append(cir_in_or_out[i])
            black_match.append(i)
    cir_matched = []

    for i in range(len(white_cir)):
        cir_matched.append(approx_all[white_cir[i]])

    # for i in range(len(cir_matched)):
    #     s1 = cir_matched[i] + [cir_matched[i][0]]
    #     cir_matched[i] = s1

    for i in range(len(black_cir)):
        index_1 = white_cir.index(black_cir[i])
        s1 = cir_matched[index_1]  + approx_all[black_match[i]] + [approx_all[black_match[i]][0]] + [cir_matched[index_1][-1]]
        cir_matched[index_1] = s1


    # print(cir_matched)
    for i in range(len(cir_matched)):
        s1 = cir_matched[i] + [cir_matched[i][0]]
        cir_matched[i] = s1

    return cir_matched

def json2png(json_root, class_root, name_classes):

    if not os.path.exists(json_root): 
        print("No such folder!")
        return 
    if not os.path.exists(class_root) and class_root != "":
        os.makedirs(class_root)
    
    print("Start conterting annotations...")
    for root,dirs,files in os.walk(json_root):
        for file in files:
            if not file.endswith(".json"):
                continue

            json_file = file

            # print(json_root+json_file)
            json_data = json.load(open(json_root+json_file,"r",encoding="utf-8"))
            
            # 生成空白图片
            W, H = json_data["imageWidth"], json_data["imageHeight"]
            res = np.zeros((H, W)).astype('uint8')
            
            
            for multi in json_data["shapes"]:
                if multi['label'] not in name_classes:
                    continue
                
                # 获取第几个编号
                fillColor = name_classes.index(multi['label'])
                # fillColor = 255
                pts = np.array(multi['points']).astype('int')
                cv2.fillPoly(res,[pts],fillColor)
            
            cv2.imwrite(class_root+get_root_file_name(file)+'.png', res)
            # Image.fromarray(res).save(class_root+get_root_file_name(file)+'.png')

def png2json(png_root, json_root, name_classes, length_p=0.002):
    if not os.path.exists(png_root): 
        print("No such folder!")
        return 
    if not os.path.exists(json_root) and json_root != "":
        os.makedirs(json_root)

    print("Start conterting annotations...")
    for root,dirs,files in os.walk(png_root):
        for file in tqdm(files):
            if not file.endswith(".png"):
                continue
            
            img = cv2.imread(png_root+file)
            img = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
            H, W = img.shape

            countors_all = []
            labels_all = []
            for i in range(1, len(name_classes)):
                layer = np.where(img==i,1,0).astype('uint8')
                # print(layer)
                countors_l = get_approx_countors(layer, length_p=0.002)
                for j in range(len(countors_l)):
                    countors_all.append(countors_l[j])
                    labels_all.append(name_classes[i])

            with open(json_root+get_root_file_name(file)+".json", 'w', encoding='utf-8') as json_f:
                

                json_f.write("{\n")

                json_f.write("  \"version\": \"0.3.3\",\n")
                json_f.write("  \"flags\": {},\n")
                json_f.write("  \"shapes\": [\n")

            
                # 提取好的点集countors_all[i][j][0][(0~1)],第i个连通域第j个点的坐标(x, y)
                        
                for i in range(len(countors_all)):
                    json_f.write("    {\n")
                    json_f.write("      \"label\": \""+labels_all[i]+"\",\n")
                    json_f.write("      \"text\": \"\",\n")
                    json_f.write("      \"points\": [\n")
                    for j in range(len(countors_all[i])):
                        json_f.write("       [\n")
                        json_f.write("        "+str(countors_all[i][j][0][0])+",\n")
                        json_f.write("        "+str(countors_all[i][j][0][1])+"\n")
                        json_f.write("       ]")   
                        if j != len(countors_all[i])-1:
                            json_f.write(",")
                        json_f.write("\n")

                    json_f.write("      ],\n")
                    json_f.write("      \"shape_type\": \"polygon\",\n")
                    json_f.write("      \"flags\": {}\n")  
                    json_f.write("    }")
                    if i != len(countors_all)-1:
                        json_f.write(",")
                    json_f.write("\n")

                json_f.write("  ],\n")
                json_f.write("  \"imagePath\": \""+get_root_file_name(file)+".jpg\",\n")
                json_f.write("  \"imageData\": null,\n")
                json_f.write("  \"imageHeight\": "+str(H)+",\n")
                json_f.write("  \"imageWidth\": "+str(W)+"\n")
                json_f.write("}\n")


    


if __name__ == "__main__":
    name_classes = ["background","fire"]

    json_root = './img/JPEGImages/'
    class_root = './img/SegmentationClass/'
    json2png(json_root, class_root, name_classes)
	
	png_root = './img/SegmentationClass/'
	json_root = './img/JPEGImages/'
	png2json(png_root ,json_root  ,name_classes, length_p=0.001)

    

以火焰数据集标注举例,原图为
在这里插入图片描述
先把图像归一化到0,1,2,3…像素之后用上述函数转化效果如下
在这里插入图片描述
这样就可以用训练好的网络去预测数据集之外的图像,并作适量调整从而添加进数据集完善网络训练!
由于png转换为json只能采用采样的方法,因此这样做必然会导致一定程度的精度降低。如果读者有更好的思路欢迎讨论!

转载请注明出处或者链接地址:https://www.qianduange.cn//article/12281.html
标签
神经网络
评论
发布的文章

JQuery中的load()、$

2024-05-10 08:05:15

大家推荐的文章
会员中心 联系我 留言建议 回顶部
复制成功!