Scripts that automatically generate DataX JSON job configs and run them with multiple threads (MySQL --> Hive only; for other source/target combinations, adapt the JSON template yourself).
Running the DataX jobs
datax_run.py
# 执行指定Hive库(--hive)下的所有ods抽取任务;不指定参数时,默认执行job/dw下按prefix过滤后的所有抽取任务
# 使用方式:python3 datax_run.py --hive hive库名 -f 过滤json的前缀或多个文件名(多个值用逗号拼接)
import os
import re
import sys
import json
import time
import argparse
import subprocess
from threading import Thread
# from logger import create_logger
# 必须在datax目录下执行
datax_home = None
datax = f"bin/datax.py"
datax_config_dir = f"job/dw"
logs_dir = "logs"
prefix = ["ods_"]
subfix = [".json"]
thread_num = 8
task_count = 0
def init():
global datax_home
# 检查配置及环境变量是否存在问题
# 必须在datax目录下执行
environ = os.environ
if not environ.keys().__contains__("DATAX_HOME"):
print("未找到环境变量[DATAX_HOME]")
return False
datax_home = environ.get("DATAX_HOME")
if datax_home is None:
print("环境变量[DATAX_HOME]未设置值")
return False
else:
        print(f"datax_home:{datax_home}")
        return True
if not init():
    print("初始化失败,正在退出...")
    exit(-1)
# log = create_logger("datax_run", f"{datax_home}/logs/")
def extract_filename(file_path):
    re_search = re.search(r'([^\\/]+)\.json$', file_path)
    if re_search:
        return re_search.group(1)
    else:
        print("未匹配到文件名,请检查...")
        return None
def check_args():
    parser = argparse.ArgumentParser(description='datax批量任务执行脚本')
    # 添加--hive参数
    parser.add_argument("--hive", type=str, help="hive数据库")
    # 添加-f/--filter参数
    parser.add_argument("-f", "--filter", type=str, help="输入过滤条件")
    # 解析命令行参数
    args = parser.parse_args()
    hive = args.hive
    args_filter = args.filter
    # 输出结果
    global prefix, datax_config_dir, logs_dir
    if hive is None:
        print(f"默认使用配置目录[{datax_config_dir}]")
    else:
        print("目标Hive库:", hive)
        datax_config_dir = f"job/{hive}"
        logs_dir = f"{logs_dir}/{hive}"
    if args_filter is not None:
        print("过滤条件:", args_filter)
        prefix.clear()
        for config in args_filter.split(","):
            prefix.append(config)
    elif hive is not None:
        prefix = ["ods_"]
    print(f"初始化参数:配置路径--> {datax_config_dir}\nprefix--> {prefix}")
def run_sub_shell(cmd):
    try:
        # cmd = f"source /etc/profile && cd {datax_home} && " + cmd
        # print(cmd)
        output = subprocess.check_output(cmd, shell=True, stderr=subprocess.STDOUT, encoding="utf-8", cwd=datax_home)
        print(f'Command "{cmd}" executed successfully. Output: \n{output}')
        return output
    except subprocess.CalledProcessError as e:
        print(f'Command "{cmd}" failed with error: \n{e.output}')
        exit(-1)
def check_status(log_path):
    with open(file=log_path, mode="r", encoding="utf-8") as fp:
        if fp.read().__contains__("completed successfully"):
            print(f"执行datax任务成功:{log_path}")
            return True
        else:
            print(f"执行datax任务失败:{log_path}")
            exit(-1)
def init_log():
# 获取今天日期
from datetime import datetime
date_str = datetime.today().date()
# 创建目录
global logs_dir
logs_dir = f"{datax_home}/{logs_dir}/{date_str}"
os.makedirs(logs_dir, exist_ok=True)
print(f"logs dir[{logs_dir}]")
def match_config(x: str, prefix: [], subfix: []):
for pre in prefix:
if x.startswith(pre):
for sub in subfix:
if x.endswith(sub):
return True
return False
def thread_run(config):
    # task_count为模块级计数器,需声明global才能自减
    global task_count
    config_name = extract_filename(config)
cmd = f"python {datax} {config}"
# cmd = f"python {datax} {config} > {logs_dir}/{config_name}.log"
output = run_sub_shell(cmd)
if output.__contains__("completed successfully"):
task_count -= 1
print(f"同步数据[{config_name}]成功,剩余任务{task_count}...")
else:
print(f"同步数据[{config_name}]失败!")
exit(-1)
def gen_thread_data():
full_paths = []
# 指定配置文件目录
for dirpath, dirnames, filenames in os.walk(f"{datax_home}/{datax_config_dir}"):
configs = filter(lambda x: match_config(x, prefix, subfix), filenames)
full_paths = [dirpath + "/" + x for x in configs]
return full_paths
def future_thread():
from concurrent import futures
thread_data = gen_thread_data()
global task_count
task_count = len(thread_data)
print(f"待执行抽取任务数量:{task_count}")
with futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
for elem in thread_data:
executor.submit(thread_run, elem)
def start():
check_args()
# init_log()
future_thread()
if __name__ == '__main__':
start()
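As the header comment notes, -f/--filter replaces the default ods_ prefix list: every comma-separated value is matched as a file-name prefix (plus the .json suffix) against the configs in the job directory. A quick sketch with hypothetical prefixes:
cd $DATAX_HOME
# only runs configs under job/dw whose names start with ods_user_ or ods_order_
python3 bin/datax_run.py --hive dw -f ods_user_,ods_order_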
Generating the DataX configs and creating the ODS tables
build_core.py generates a DataX JSON config for every table and automatically creates the matching Hive ODS table, driven by a Python config module (config_xx.py below). Generated tables follow the ods_{source}_{table}[_pt]_a naming pattern; for example, with the sample config below, MySQL table view_online becomes ods_database_view_online_a.
import json
import re
import sys
from pathlib import Path
import mysql.connector
from pyhive import hive
import os
import subprocess
datax_home = None
# 初始化检查环境
def init():
global datax_home
# 检查配置及环境变量是否存在问题
# 必须在datax目录下执行
environ = os.environ
if not environ.keys().__contains__("DATAX_HOME"):
print("未找到环境变量[DATAX_HOME]")
return False
datax_home = environ.get("DATAX_HOME")
if datax_home is None:
print("环境变量[DATAX_HOME]未设置值")
return False
else:
print(f"datax_home:{datax_home}")
return True
if not init():
print("初始化失败,正在退出...")
exit(-1)
# 主要内容用于生成datax的mysql到hive的配置文件
# 对于不同的项目或数据库,指定不同的配置文件,生成不同的json
def dynamic_import():
import importlib.util
argv = sys.argv
if len(argv) <= 1:
print("请输出加载的python配置模块名!")
exit(-1)
module_ = argv[1]
try:
print(f"使用__import__导入模块")
module = __import__(module_)
except Exception as e:
print(f"使用__import__导入模块失败")
print(f"使用importlib导入模块")
args = module_.split(os.sep)
if len(args) == 1:
module_name = args[0]
module_path = module_name
elif len(args) > 1:
module_name = args[-1]
module_path = module_
print(f"module_path:{module_path}\nmodule_name:{module_name}")
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
m_keys = module.__dict__
key_list = list(m_keys.keys())
for k in key_list:
if not str(k).startswith("__"):
globals()[k] = m_keys.get(k)
return module
dynamic_import()
# project_config、source_ds等配置变量由dynamic_import()从配置模块注入globals()
global_config = project_config
config_source = source_ds
project_path = global_config['project_path']
hive_host = global_config['hive_host']
hive_port = global_config['hive_port']
hive_db = global_config['hive_db']
use_kb = global_config['enable_kerberos']
use_pt = global_config['enable_partition']
keytab = global_config['key_table']
principal = global_config['principal']
# 加载当前项目数据库及表
def load_db():
if isinstance(config_source, list):
# 多数据源多数据库模式
for source in config_source:
db_tables_ = source["db_tables"]
db_connect = source["connect"]
host_ = db_connect['host']
port_ = db_connect['port']
username_ = db_connect['username']
password_ = db_connect['password']
for db_info in db_tables_:
db_ = db_info["db"]
if dict(db_info).keys().__contains__("project"):
project_ = db_info["project"]
else:
project_ = None
tables_ = db_info["tables"]
query_table(host_, port_, username_, password_, db_, project_, tables_)
else:
print("加载source_ds的config配置出现问题...")
def save_local(save_path, datax_json):
path = Path(f'../job/{save_path}')
if datax_home is not None:
path = Path(f"{datax_home}/job/{save_path}")
elif not Path('../').exists():
path = Path(f"job/{save_path}")
path.parent.mkdir(parents=True, exist_ok=True)
# 覆盖文件写入
path.write_text(datax_json, encoding="utf-8")
def camel_to_snake(field: str):
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', field).lower()
def is_camel(s):
return bool(re.match(r'^[a-z]+([A-Z][a-z]*)*$', s))
def convert_field(field):
table_name = field[0]
field_name = field[1]
field_type = field[2]
field_comment = field[3]
# 是否为驼峰
if is_camel(field_name):
table_name = f"camel_{table_name}"
field_name = camel_to_snake(field_name)
field_comment = f"({field[1]}){field_comment}"
return [table_name, field_name, field_type, field_comment]
def convert_ods_field(field):
field_name = field['field_name']
field_type = field['field_type']
field_hive_type = field['field_hive_type']
field_comment = field['field_comment']
# 是否为驼峰
if is_camel(field_name):
field_name = camel_to_snake(field_name)
field_comment = f"({field['field_name']}){field_comment}"
return {"field_name": field_name, "field_type": field_type, "field_hive_type": field_hive_type,
"field_comment": field_comment}
def build_db(tables: list):
database = {}
for table in tables:
# 查询指定表的所有字段名和类型
table_name = table[0]
field_name = table[1]
field_type = table[2]
field_comment = table[3]
table_fields: list = database.get(table_name)
field_hive_type = hive_type(field_type)
field_one = {"field_name": field_name, "field_type": field_type, "field_hive_type": field_hive_type,
"field_comment": field_comment}
if table_fields is not None:
table_fields.append(field_one)
else:
table_fields = [field_one]
database[table_name] = table_fields
return database
def run_sub_shell(cmd):
try:
# cmd = f"source /etc/profile && cd {datax_home} && " + cmd
# print(cmd)
output = subprocess.check_output(cmd, shell=True, stderr=subprocess.STDOUT, encoding="utf-8", cwd=datax_home)
print(f'Command "{cmd}" executed successfully. Output: \n{output}')
return output
except subprocess.CalledProcessError as e:
print(f'Command "{cmd}" failed with error: \n{e.output}')
exit(-1)
def hive_file_sql(create_db_sql):
    # 创建临时文件(确保tmp目录存在)
    os.makedirs(f"{datax_home}/tmp", exist_ok=True)
    tmp_hql = f"{datax_home}/tmp/_hive_sql.hql"
    with open(tmp_hql, mode="w", encoding="utf-8") as fp:
        fp.write(create_db_sql)
# 执行hive -f
if os.path.exists(tmp_hql):
run_sub_shell(f"hive -f {tmp_hql}")
else:
print(f"{tmp_hql}文件不存在...")
# 删除临时文件
os.remove(tmp_hql)
def query_table(host, port, user, password, db, project_, include_tables):
# 连接 MySQL 数据库
conn = mysql.connector.connect(
host=host,
port=port,
user=user,
password=password,
database=db
)
# 获取游标对象
cursor = conn.cursor()
query_col_sql = f"select table_name,column_name,data_type,column_comment from information_schema.`COLUMNS` where table_schema='{db}' "
if len(include_tables) > 0:
name_str = ",".join([f"'{x}'" for x in include_tables])
table_filter = f' and table_name in({name_str})'
query_col_sql += table_filter
else:
print(f"查询数据库:[{db}]的所有表")
# 查询指定数据库中的所有表名
# print(query_col_sql)
cursor.execute(query_col_sql)
tables = cursor.fetchall()
# 数据库的json
database = build_db(tables)
create_db_sql = f"use {hive_db};"
# 生成各个表的datax配置文件
for table_name in database.keys():
table_fields = database[table_name]
ods_source, ods_table, datax_json = build_datax(host, port, user, password, db, project_, table_name,
table_fields)
# datax和hive的表名全部小写,datax配置文件中的表名使用原始的大小写
save_local(f"{hive_db}/{ods_table}.json", datax_json)
print(f"生成datax配置文件-->{hive_db}/{ods_table}.json")
# 生成建表语句
create_db_sql += build_create_hive(ods_table, table_fields)
print(f"创建hive表-->{hive_db}.{ods_table}")
hive_file_sql(create_db_sql)
# print(create_db_sql)
# 关闭游标和连接
cursor.close()
conn.close()
print(f"自动处理数据库[{db}]的datax配置及hive库ods表完成\n")
def exec_hive_sql(sql_list=["show databases"]):
# 连接到Hive服务器
if use_kb:
conn = hive.Connection(host=hive_host, port=hive_port, database=hive_db, auth='KERBEROS',
kerberos_service_name='hive')
else:
conn = hive.Connection(host=hive_host, port=hive_port, database=hive_db)
# 执行查询
cursor = conn.cursor()
for sql in sql_list:
# print(f"执行sql:\n{sql}\n")
cursor.execute(sql)
# 关闭连接
cursor.close()
conn.close()
def build_create_hive(hive_table, fields):
# 生成建表语句
stored = "orc"
hive_fields = list(map(convert_ods_field, fields))
field_sql = ",\n".join(
map(lambda x: f"\t\t`{x['field_name']}` {x['field_hive_type']} comment '{x['field_comment']}'", hive_fields))
dw_type = hive_table.split("_")[0]
partition_sql = ""
if use_pt:
partition_sql = "partitioned by(pt_day string comment '格式:YYYYMMDD')"
create_sql = f"""
drop table if exists {hive_table};
create external table if not exists {hive_table}
(
{field_sql}
){partition_sql}
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t'
stored as {stored}
LOCATION '{project_path}/{dw_type}/{hive_table}'
TBLPROPERTIES('orc.compress'='SNAPPY');
"""
# print(create_sql)
return create_sql
def unify_name(database: str):
snake_case = re.sub(r'(?<!^)(?=[A-Z])', '_', database).lower()
return snake_case.replace("_", "").replace("-", "")
def build_datax(host, port, username, password, database, project_, source_table, fields):
if project_ is None or len(project_) == 0:
ods_source = unify_name(database)
else:
ods_source = unify_name(project_)
pt_str = '_pt' if use_pt else ''
ods_table = f"ods_{ods_source}_{source_table}{pt_str}_a".lower()
jdbc_url = f"jdbc:mysql://{host}:{port}/{database}?useSSL=false&useUnicode=true&allowMultiQueries=true&characterEncoding=utf8&characterSetResults=utf8&serverTimezone=Asia/Shanghai"
columns = ",".join([f'"`{field["field_name"]}`"' for field in fields])
hive_fields = list(map(convert_ods_field, fields))
hive_columns = ",".join(
[f'{{"name":"{field["field_name"]}","type":"{field["field_hive_type"]}" }}' for field in hive_fields])
pt_config = '"ptDay": "pt_day",' if use_pt else ""
kerberos_config = f'"haveKerberos": "{use_kb}","kerberosKeytabFilePath": "{keytab}","kerberosPrincipal": "{principal}",' if use_kb else ""
mysql_hive_tpl = '''
{
"job": {
"setting": {
"speed": {
"channel": 3,
"byte":-1
},
"errorLimit": {
"record": 0,
"percentage": 0
}
},
"content": [
{
"reader": {
"name": "mysqlreader",
"parameter": {
"username": "${username}",
"password": "${password}",
"column": [${columns}],
"splitPk": null,
"connection": [
{
"table": ["${sourceTable}"],
"jdbcUrl": ["${jdbcUrl}"]
}
]
}
},
"writer": {
"name": "hdfswriter",
"parameter": {
"defaultFS": "hdfs://master:8020",
"fileType": "orc",${kerberosConfig}${ptConfig}
"path": "${projectPath}/ods/${odsTable}",
"fileName": "${odsTable}",
"column": [${hiveColumns}],
"writeMode": "truncate",
"fieldDelimiter": "\\t",
"compress": "SNAPPY",
"database":"${database}"
}
}
}
]
}
}
'''
var_dict = {"username": username, "password": password, "columns": columns, "sourceTable": source_table,
"jdbcUrl": jdbc_url,
"kerberosConfig": kerberos_config, "ptConfig": pt_config, "projectPath": project_path,
"odsTable": ods_table,
"hiveColumns": hive_columns, "database": hive_db}
for k in var_dict.keys():
mysql_hive_tpl = mysql_hive_tpl.replace('${' + k + '}', var_dict[k])
data = json.loads(mysql_hive_tpl)
data = json.dumps(data, indent=2, ensure_ascii=False).replace("True", "true").replace("False", "false")
return ods_source, ods_table, data
def hive_type(mysql_type):
    if mysql_type in ("tinyint", "smallint", "boolean"):
        return "smallint"
    elif mysql_type in ("int", "integer", "mediumint", "bigint", "long"):
        return "bigint"
    elif mysql_type in ("float", "double", "decimal"):
        return "double"
    elif mysql_type == "date":
        return "date"
    elif mysql_type == "timestamp":
        return "timestamp"
    else:
        return "string"
if __name__ == '__main__':
load_db()
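The Hive-specific pieces above are the hdfswriter block inside mysql_hive_tpl and the DDL built by build_create_hive; supporting a different target mostly means swapping that writer for another DataX writer plugin and adjusting the table-creation step. A minimal, hedged sketch of a mysqlwriter replacement for the writer block (target-host/target_db are placeholders; verify the parameter names against the writer plugin bundled with your DataX release):
            "writer": {
                "name": "mysqlwriter",
                "parameter": {
                    "writeMode": "insert",
                    "username": "${username}",
                    "password": "${password}",
                    "column": [${columns}],
                    "connection": [
                        {
                            "jdbcUrl": "jdbc:mysql://target-host:3306/target_db?useUnicode=true&characterEncoding=utf8",
                            "table": ["${odsTable}"]
                        }
                    ]
                }
            }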
Config file template config_xx.py
# 目前只支持从mysql到hive
project_name = "project_name"
hive_host = "master"
hive_db = "hive_db"
# hdfs路径
if project_name == 'project_name':
project_path = f"/project/{project_name}/{hive_db}"
else:
project_path = f"/project/{project_name}/warehouse/{hive_db}"
# 主要配置
project_config = {
"project_name": project_name,
"hive_host": hive_host,
"hive_port": 10000,
"project_path": project_path,
# write:hive数据库
"hive_db": hive_db,
"enable_kerberos": True, #是否启用Kerberos
"enable_partition": False, #是否分区表
# 默认不用修改
"default_fs": f"hdfs://{hive_host}:8020",
"principal": "principal",
"key_table": "hdfs.keytab" #如果存在Kerberos的话
}
# 不指定表名,则扫描全库
# 抽取的mysql数据源
source_ds = [
{
"db_tables": [
{
"db": "database",
"tables": ["view_online"],
}
],
"connect": {
"host": "host",
"port": 23306,
"username": "username",
"password": "password",
}
}
]
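load_db() accepts multiple sources and multiple databases per source; an empty tables list scans the whole schema, and the optional project key overrides the source name used in the generated ods_ table names. A sketch with hypothetical hosts and databases:
source_ds = [
    {
        "db_tables": [
            # tables为空则扫描整库;project覆盖生成的ods表名中的来源名
            {"db": "order_db", "project": "mall", "tables": []},
            {"db": "user_db", "tables": ["user_info", "user_login_log"]},
        ],
        "connect": {"host": "mysql-a", "port": 3306, "username": "reader", "password": "password"},
    },
    {
        "db_tables": [{"db": "cms_db", "tables": []}],
        "connect": {"host": "mysql-b", "port": 3306, "username": "reader", "password": "password"},
    },
]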
Usage:
Set the DATAX_HOME environment variable, place the scripts under $DATAX_HOME/bin, and create the job directory yourself.
Generate the DataX job JSON; the generated files are written to $DATAX_HOME/job/{hive_database}:
python3 build_core.py config/config_dw.py
Run the extraction jobs with multiple threads:
cd $DATAX_HOME
python3 bin/datax_run.py --hive hive_database
--hive selects the Hive database; the JSON configs to execute are read from $DATAX_HOME/job/{hive_database}.
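A single generated job can also be run directly with the stock DataX launcher (this is exactly what datax_run.py does per thread); e.g. for the view_online table from the sample config, assuming hive_db as configured above:
cd $DATAX_HOME
python bin/datax.py job/hive_db/ods_database_view_online_a.json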